diff --git a/src/ghstack/cli.py b/src/ghstack/cli.py index 24f77ca..ea65e67 100644 --- a/src/ghstack/cli.py +++ b/src/ghstack/cli.py @@ -33,6 +33,16 @@ def run_async(coro: Coroutine[Any, Any, object]) -> object: return asyncio.run(coro) +async def run_with_github( + github: ghstack.github_real.RealGitHubEndpoint, + coro: Coroutine[Any, Any, object], +) -> object: + try: + return await coro + finally: + await github.aclose() + + @contextlib.contextmanager def cli_context( *, @@ -176,11 +186,14 @@ def action(close: bool, pull_request: str) -> None: """ with cli_context() as (shell, _, github): run_async( - ghstack.action.main( - pull_request=pull_request, - github=github, - sh=shell, - close=close, + run_with_github( + github, + ghstack.action.main( + pull_request=pull_request, + github=github, + sh=shell, + close=close, + ), ) ) @@ -198,12 +211,15 @@ def checkout(same_base: bool, pull_request: str) -> None: """ with cli_context(request_github_token=False) as (shell, config, github): run_async( - ghstack.checkout.main( - pull_request=pull_request, - github=github, - sh=shell, - remote_name=config.remote_name, - same_base=same_base, + run_with_github( + github, + ghstack.checkout.main( + pull_request=pull_request, + github=github, + sh=shell, + remote_name=config.remote_name, + same_base=same_base, + ), ) ) @@ -227,13 +243,16 @@ def cherry_pick(stack: bool, no_fetch: bool, pull_request: str) -> None: """ with cli_context(request_github_token=False) as (shell, config, github): run_async( - ghstack.cherry_pick.main( - pull_request=pull_request, - github=github, - sh=shell, - remote_name=config.remote_name, - stack=stack, - no_fetch=no_fetch, + run_with_github( + github, + ghstack.cherry_pick.main( + pull_request=pull_request, + github=github, + sh=shell, + remote_name=config.remote_name, + stack=stack, + no_fetch=no_fetch, + ), ) ) @@ -247,13 +266,16 @@ def land(force: bool, pull_request: str) -> None: """ with cli_context() as (shell, config, github): run_async( - ghstack.land.main( - pull_request=pull_request, - github=github, - sh=shell, - github_url=config.github_url, - remote_name=config.remote_name, - force=force, + run_with_github( + github, + ghstack.land.main( + pull_request=pull_request, + github=github, + sh=shell, + github_url=config.github_url, + remote_name=config.remote_name, + force=force, + ), ) ) @@ -278,13 +300,16 @@ def log(pull_request: Optional[str], git_log_args: Tuple[str, ...]) -> None: """ with cli_context(request_github_token=False) as (shell, config, github): run_async( - ghstack.log.main( - github=github, - sh=shell, - remote_name=config.remote_name, - github_url=config.github_url, - args=list(git_log_args), - pull_request=pull_request, + run_with_github( + github, + ghstack.log.main( + github=github, + sh=shell, + remote_name=config.remote_name, + github_url=config.github_url, + args=list(git_log_args), + pull_request=pull_request, + ), ) ) @@ -312,10 +337,13 @@ def status(pull_request: str) -> None: ) run_async( - ghstack.status.main( - pull_request=pull_request, - github=github, - circleci=circleci, + run_with_github( + github, + ghstack.status.main( + pull_request=pull_request, + github=github, + circleci=circleci, + ), ) ) @@ -420,25 +448,28 @@ def submit( """ with cli_context() as (shell, config, github): run_async( - ghstack.submit.main( - msg=message, - username=config.github_username, - sh=shell, - github=github, - update_fields=update_fields, - short=short, - force=force, - no_skip=no_skip, - draft=draft, - github_url=config.github_url, - remote_name=config.remote_name, - base_opt=base, - revs=revs, - stack=stack, - direct_opt=direct_opt, - reviewer=reviewer if reviewer is not None else config.reviewer, - label=label if label is not None else config.label, - no_fetch=no_fetch, + run_with_github( + github, + ghstack.submit.main( + msg=message, + username=config.github_username, + sh=shell, + github=github, + update_fields=update_fields, + short=short, + force=force, + no_skip=no_skip, + draft=draft, + github_url=config.github_url, + remote_name=config.remote_name, + base_opt=base, + revs=revs, + stack=stack, + direct_opt=direct_opt, + reviewer=reviewer if reviewer is not None else config.reviewer, + label=label if label is not None else config.label, + no_fetch=no_fetch, + ), ) ) @@ -450,11 +481,14 @@ def sync() -> None: """ with cli_context() as (shell, config, github): run_async( - ghstack.sync.main( - github=github, - sh=shell, - github_url=config.github_url, - remote_name=config.remote_name, + run_with_github( + github, + ghstack.sync.main( + github=github, + sh=shell, + github_url=config.github_url, + remote_name=config.remote_name, + ), ) ) @@ -467,11 +501,14 @@ def unlink(commits: List[str]) -> None: """ with cli_context() as (shell, config, github): run_async( - ghstack.unlink.main( - commits=commits, - github=github, - sh=shell, - github_url=config.github_url, - remote_name=config.remote_name, + run_with_github( + github, + ghstack.unlink.main( + commits=commits, + github=github, + sh=shell, + github_url=config.github_url, + remote_name=config.remote_name, + ), ) ) diff --git a/src/ghstack/github_real.py b/src/ghstack/github_real.py index b0b10f7..d6308ac 100644 --- a/src/ghstack/github_real.py +++ b/src/ghstack/github_real.py @@ -59,6 +59,7 @@ def rest_endpoint(self) -> str: # Client side certificate to use when connecitng. # Passed to requests as 'cert'. cert: Optional[Union[str, Tuple[str, str]]] + _session: Optional[aiohttp.ClientSession] def __init__( self, @@ -74,6 +75,17 @@ def __init__( self.verify = verify self.cert = cert self._rest_request_ids = itertools.count(1) + self._session = None + + def _get_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession() + return self._session + + async def aclose(self) -> None: + if self._session is not None and not self._session.closed: + await self._session.close() + self._session = None def push_hook(self, refName: Sequence[str]) -> None: pass @@ -101,27 +113,27 @@ async def graphql(self, query: str, **kwargs: Any) -> Any: if aiohttp_ssl is not None: request_kwargs["ssl"] = aiohttp_ssl - async with aiohttp.ClientSession() as session: - async with session.post( - self.graphql_endpoint.format(github_url=self.github_url), - **request_kwargs, - ) as resp: - logging.debug("Response status: {}".format(resp.status)) - - try: - r = await resp.json() - except (aiohttp.ContentTypeError, ValueError): - logging.debug("Response body:\n{}".format(await resp.text())) - raise - else: - pretty_json = json.dumps(r, indent=1) - logging.debug("Response JSON:\n{}".format(pretty_json)) - - # Actually, this code is dead on the GitHub GraphQL API, because - # they seem to always return 200, even in error case (as of - # 11/5/2018) - if resp.status >= 400: - raise RuntimeError(pretty_json) + session = self._get_session() + async with session.post( + self.graphql_endpoint.format(github_url=self.github_url), + **request_kwargs, + ) as resp: + logging.debug("Response status: {}".format(resp.status)) + + try: + r = await resp.json() + except (aiohttp.ContentTypeError, ValueError): + logging.debug("Response body:\n{}".format(await resp.text())) + raise + else: + pretty_json = json.dumps(r, indent=1) + logging.debug("Response JSON:\n{}".format(pretty_json)) + + # Actually, this code is dead on the GitHub GraphQL API, because + # they seem to always return 200, even in error case (as of + # 11/5/2018) + if resp.status >= 400: + raise RuntimeError(pretty_json) if "errors" in r: raise RuntimeError(pretty_json) @@ -142,13 +154,13 @@ async def get_head_ref(self, **params: Any) -> str: aiohttp_ssl = self._aiohttp_ssl() if aiohttp_ssl is not None: request_kwargs["ssl"] = aiohttp_ssl - async with aiohttp.ClientSession() as session: - async with session.get( - f"{self.www_endpoint.format(github_url=self.github_url)}/{owner}/{name}/pull/{number}", - **request_kwargs, - ) as resp: - logging.debug("Response status: {}".format(resp.status)) - r = await resp.text() + session = self._get_session() + async with session.get( + f"{self.www_endpoint.format(github_url=self.github_url)}/{owner}/{name}/pull/{number}", + **request_kwargs, + ) as resp: + logging.debug("Response status: {}".format(resp.status)) + r = await resp.text() if m := re.search(r' Any: backoff_seconds = INITIAL_BACKOFF_SECONDS request_id = next(self._rest_request_ids) log_prefix = f"rest[{request_id}]" - async with aiohttp.ClientSession() as session: - for attempt in range(0, MAX_RETRIES): - logging.debug("# %s %s %s", log_prefix, method, url) - logging.debug( - "%s request body:\n%s", log_prefix, json.dumps(kwargs, indent=1) - ) - - async with getattr(session, method)(url, **request_kwargs) as resp: - logging.debug("%s response status: %s", log_prefix, resp.status) - try: - r = await resp.json() - except (aiohttp.ContentTypeError, ValueError): - logging.debug( - "%s response body:\n%s", log_prefix, await resp.text() + session = self._get_session() + for attempt in range(0, MAX_RETRIES): + logging.debug("# %s %s %s", log_prefix, method, url) + logging.debug( + "%s request body:\n%s", log_prefix, json.dumps(kwargs, indent=1) + ) + + async with getattr(session, method)(url, **request_kwargs) as resp: + logging.debug("%s response status: %s", log_prefix, resp.status) + try: + r = await resp.json() + except (aiohttp.ContentTypeError, ValueError): + logging.debug( + "%s response body:\n%s", log_prefix, await resp.text() + ) + raise + else: + pretty_json = json.dumps(r, indent=1) + logging.debug("%s response JSON:\n%s", log_prefix, pretty_json) + + # Per Github rate limiting: + # https://docs.github.com/en/rest/using-the-rest-api/rate-limits-for-the-rest-api?apiVersion=2022-11-28#exceeding-the-rate-limit + if resp.status in (403, 429): + remaining_count = resp.headers.get("x-ratelimit-remaining") + reset_time = resp.headers.get("x-ratelimit-reset") + + if remaining_count == "0" and reset_time: + sleep_time = int(reset_time) - int(time.time()) + logging.warning( + f"Rate limit exceeded. Sleeping until reset in {sleep_time} seconds." ) - raise + await asyncio.sleep(sleep_time) + continue else: - pretty_json = json.dumps(r, indent=1) - logging.debug("%s response JSON:\n%s", log_prefix, pretty_json) - - # Per Github rate limiting: - # https://docs.github.com/en/rest/using-the-rest-api/rate-limits-for-the-rest-api?apiVersion=2022-11-28#exceeding-the-rate-limit - if resp.status in (403, 429): - remaining_count = resp.headers.get("x-ratelimit-remaining") - reset_time = resp.headers.get("x-ratelimit-reset") - - if remaining_count == "0" and reset_time: - sleep_time = int(reset_time) - int(time.time()) + retry_after_seconds = resp.headers.get("retry-after") + if retry_after_seconds: + sleep_time = int(retry_after_seconds) logging.warning( - f"Rate limit exceeded. Sleeping until reset in {sleep_time} seconds." + f"Secondary rate limit hit. Sleeping for {sleep_time} seconds." ) - await asyncio.sleep(sleep_time) - continue else: - retry_after_seconds = resp.headers.get("retry-after") - if retry_after_seconds: - sleep_time = int(retry_after_seconds) - logging.warning( - f"Secondary rate limit hit. Sleeping for {sleep_time} seconds." - ) - else: - sleep_time = backoff_seconds - logging.warning( - f"Secondary rate limit hit. Sleeping for {sleep_time} seconds (exponential backoff)." - ) - backoff_seconds *= 2 - await asyncio.sleep(sleep_time) - continue - - if resp.status == 404: - raise ghstack.github.NotFoundError( - """\ + sleep_time = backoff_seconds + logging.warning( + f"Secondary rate limit hit. Sleeping for {sleep_time} seconds (exponential backoff)." + ) + backoff_seconds *= 2 + await asyncio.sleep(sleep_time) + continue + + if resp.status == 404: + raise ghstack.github.NotFoundError( + """\ GitHub raised a 404 error on the request for {url}. Usually, this doesn't actually mean the page doesn't exist; instead, it @@ -268,13 +280,13 @@ async def arest(self, method: str, path: str, **kwargs: Any) -> Any: to a new location or been renamed. Check that the repository URL is still correct. """.format( - url=url, github_url=self.github_url - ) + url=url, github_url=self.github_url ) + ) - if resp.status >= 400: - raise RuntimeError(pretty_json) + if resp.status >= 400: + raise RuntimeError(pretty_json) - return r + return r raise RuntimeError("Exceeded maximum retries due to GitHub rate limiting")