diff --git a/src/ghstack/github.py b/src/ghstack/github.py index 31343ca..f03dd27 100644 --- a/src/ghstack/github.py +++ b/src/ghstack/github.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import asyncio from abc import ABCMeta, abstractmethod from typing import Any, Sequence @@ -64,7 +65,7 @@ def get(self, path: str, **kwargs: Any) -> Any: Returns: parsed JSON response """ - return self.rest("get", path, **kwargs) + return self._run_async(self.aget(path, **kwargs)) def post(self, path: str, **kwargs: Any) -> Any: """ @@ -72,7 +73,7 @@ def post(self, path: str, **kwargs: Any) -> Any: Returns: parsed JSON response """ - return self.rest("post", path, **kwargs) + return self._run_async(self.apost(path, **kwargs)) def patch(self, path: str, **kwargs: Any) -> Any: """ @@ -80,13 +81,43 @@ def patch(self, path: str, **kwargs: Any) -> Any: Returns: parsed JSON response """ - return self.rest("patch", path, **kwargs) + return self._run_async(self.apatch(path, **kwargs)) - @abstractmethod def rest(self, method: str, path: str, **kwargs: Any) -> Any: """ Send a 'method' request to endpoint 'path'. + Args: + method: 'GET', 'POST', etc. + path: relative URL path to access on endpoint + **kwargs: dictionary of JSON payload to send + + Returns: parsed JSON response + """ + return self._run_async(self.arest(method, path, **kwargs)) + + @staticmethod + def _run_async(coro: Any) -> Any: + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + async def aget(self, path: str, **kwargs: Any) -> Any: + return await self.arest("get", path, **kwargs) + + async def apost(self, path: str, **kwargs: Any) -> Any: + return await self.arest("post", path, **kwargs) + + async def apatch(self, path: str, **kwargs: Any) -> Any: + return await self.arest("patch", path, **kwargs) + + @abstractmethod + async def arest(self, method: str, path: str, **kwargs: Any) -> Any: + """ + Send an async 'method' request to endpoint 'path'. + Args: method: 'GET', 'POST', etc. path: relative URL path to access on endpoint diff --git a/src/ghstack/github_fake.py b/src/ghstack/github_fake.py index a54a3bb..3072c47 100644 --- a/src/ghstack/github_fake.py +++ b/src/ghstack/github_fake.py @@ -219,6 +219,11 @@ def pullRequests(self, info: GraphQLResolveInfo) -> "PullRequestConnection": # TODO: This should take which repository the ref is in # This only works if you have upstream_sh def _make_ref(self, state: GitHubState, refName: str) -> "Ref": + return ghstack.github.GitHubEndpoint._run_async( + self._make_ref_async(state, refName) + ) + + async def _make_ref_async(self, state: GitHubState, refName: str) -> "Ref": # TODO: Probably should preserve object identity here when # you call this with refName/oid that are the same assert state.upstream_sh @@ -226,7 +231,7 @@ def _make_ref(self, state: GitHubState, refName: str) -> "Ref": id=state.next_id(), # TODO: this upstream_sh hardcode wrong, but ok for now # because we only have one repo - oid=GitObjectID(state.upstream_sh.git("rev-parse", refName)), + oid=GitObjectID(await state.upstream_sh.agit("rev-parse", refName)), _repository=self.id, ) ref = Ref( @@ -366,7 +371,7 @@ def push_hook(self, refNames: Sequence[str]) -> None: def notify_merged(self, pr_resolved: ghstack.diff.PullRequestResolved) -> None: self.state.notify_merged(pr_resolved) - def _create_pull( + async def _create_pull_async( self, owner: str, name: str, input: CreatePullRequestInput ) -> CreatePullRequestPayload: state = self.state @@ -378,8 +383,8 @@ def _create_pull( # TODO: When we support forks, this needs rewriting to stop # hard coded the repo we opened the pull request on if state.upstream_sh: - baseRef = repo._make_ref(state, input["base"]) - headRef = repo._make_ref(state, input["head"]) + baseRef = await repo._make_ref_async(state, input["base"]) + headRef = await repo._make_ref_async(state, input["head"]) pr = PullRequest( id=id, _repository=repo.id, @@ -403,7 +408,7 @@ def _create_pull( # NB: This technically does have a payload, but we don't # use it so I didn't bother constructing it. - def _update_pull( + async def _update_pull_async( self, owner: str, name: str, number: GitHubNumber, input: UpdatePullRequestInput ) -> None: state = self.state @@ -415,11 +420,11 @@ def _update_pull( pr.title = input["title"] if "base" in input and input["base"] is not None: pr.baseRefName = input["base"] - pr.baseRef = repo._make_ref(state, pr.baseRefName) + pr.baseRef = await repo._make_ref_async(state, pr.baseRefName) if "body" in input and input["body"] is not None: pr.body = input["body"] - def _create_issue_comment( + async def _create_issue_comment_async( self, owner: str, name: str, comment_id: int, input: CreateIssueCommentInput ) -> CreateIssueCommentPayload: state = self.state @@ -439,7 +444,7 @@ def _create_issue_comment( "id": comment_id, } - def _update_issue_comment( + async def _update_issue_comment_async( self, owner: str, name: str, comment_id: int, input: UpdateIssueCommentInput ) -> None: state = self.state @@ -450,14 +455,19 @@ def _update_issue_comment( # NB: This may have a payload, but we don't # use it so I didn't bother constructing it. - def _set_default_branch( + async def _set_default_branch_async( self, owner: str, name: str, input: SetDefaultBranchInput ) -> None: state = self.state repo = state.repository(owner, name) - repo.defaultBranchRef = repo._make_ref(state, input["default_branch"]) + repo.defaultBranchRef = await repo._make_ref_async( + state, input["default_branch"] + ) - def rest(self, method: str, path: str, **kwargs: Any) -> Any: + async def arest(self, method: str, path: str, **kwargs: Any) -> Any: + return await self._arest_impl(method, path, **kwargs) + + async def _arest_impl(self, method: str, path: str, **kwargs: Any) -> Any: if method == "get": m = re.match(r"^repos/([^/]+)/([^/]+)/branches/([^/]+)/protection", path) if m: @@ -472,6 +482,12 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: "state": "closed" if pr.closed else "open", "title": pr.title, "body": pr.body, + "head": { + "ref": pr.headRefName, + }, + "base": { + "ref": pr.baseRefName, + }, } if m := re.match(r"^repos/([^/]+)/([^/]+)/issues/comments/([^/]+)$", path): state = self.state @@ -484,11 +500,11 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: elif method == "post": if m := re.match(r"^repos/([^/]+)/([^/]+)/pulls$", path): - return self._create_pull( + return await self._create_pull_async( m.group(1), m.group(2), cast(CreatePullRequestInput, kwargs) ) if m := re.match(r"^repos/([^/]+)/([^/]+)/issues/([^/]+)/comments", path): - return self._create_issue_comment( + return await self._create_issue_comment_async( m.group(1), m.group(2), GitHubNumber(int(m.group(3))), @@ -516,23 +532,24 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: if m := re.match(r"^repos/([^/]+)/([^/]+)(?:/pulls/([^/]+))?$", path): owner, name, number = m.groups() if number is not None: - return self._update_pull( + return await self._update_pull_async( owner, name, GitHubNumber(int(number)), cast(UpdatePullRequestInput, kwargs), ) elif "default_branch" in kwargs: - return self._set_default_branch( + return await self._set_default_branch_async( owner, name, cast(SetDefaultBranchInput, kwargs) ) if m := re.match(r"^repos/([^/]+)/([^/]+)/issues/comments/([^/]+)$", path): - return self._update_issue_comment( + return await self._update_issue_comment_async( m.group(1), m.group(2), int(m.group(3)), cast(UpdateIssueCommentInput, kwargs), ) + raise NotImplementedError( "FakeGitHubEndpoint REST {} {} not implemented".format(method.upper(), path) ) diff --git a/src/ghstack/github_real.py b/src/ghstack/github_real.py index 600fe64..ec0c9f5 100644 --- a/src/ghstack/github_real.py +++ b/src/ghstack/github_real.py @@ -1,11 +1,15 @@ #!/usr/bin/env python3 +import asyncio +import itertools import json import logging import re +import ssl import time from typing import Any, Dict, Optional, Sequence, Tuple, Union +import aiohttp import requests import ghstack.github @@ -70,6 +74,7 @@ def __init__( self.github_url = github_url self.verify = verify self.cert = cert + self._rest_request_ids = itertools.count(1) def push_hook(self, refName: Sequence[str]) -> None: pass @@ -120,12 +125,6 @@ def graphql(self, query: str, **kwargs: Any) -> Any: return r - def _proxies(self) -> Dict[str, str]: - if self.proxy: - return {"http": self.proxy, "https": self.proxy} - else: - return {} - def get_head_ref(self, **params: Any) -> str: if self.oauth_token: @@ -149,73 +148,106 @@ def get_head_ref(self, **params: Any) -> str: # couldn't find, fall back to regular query return super().get_head_ref(**params) - def rest(self, method: str, path: str, **kwargs: Any) -> Any: + def _proxies(self) -> Dict[str, str]: + if self.proxy: + return {"http": self.proxy, "https": self.proxy} + else: + return {} + + def _rest_headers(self) -> Dict[str, str]: assert self.oauth_token - headers = { + return { "Authorization": "token " + self.oauth_token, "Content-Type": "application/json", "User-Agent": "ghstack", "Accept": "application/vnd.github.v3+json", } + def _aiohttp_ssl(self) -> Any: + if self.verify is False: + return False + if self.verify is None and self.cert is None: + return None + + context = ssl.create_default_context( + cafile=self.verify if isinstance(self.verify, str) else None + ) + if isinstance(self.cert, tuple): + context.load_cert_chain(self.cert[0], self.cert[1]) + elif self.cert is not None: + context.load_cert_chain(self.cert) + return context + + async def arest(self, method: str, path: str, **kwargs: Any) -> Any: + assert self.oauth_token + headers = self._rest_headers() url = self.rest_endpoint.format(github_url=self.github_url) + "/" + path - backoff_seconds = INITIAL_BACKOFF_SECONDS - for attempt in range(0, MAX_RETRIES): - logging.debug("# {} {}".format(method, url)) - logging.debug("Request body:\n{}".format(json.dumps(kwargs, indent=1))) - - resp: requests.Response = getattr(requests, method)( - url, - json=kwargs, - headers=headers, - proxies=self._proxies(), - verify=self.verify, - cert=self.cert, - ) + request_kwargs: Dict[str, Any] = { + "json": kwargs, + "headers": headers, + } + if self.proxy: + request_kwargs["proxy"] = self.proxy + aiohttp_ssl = self._aiohttp_ssl() + if aiohttp_ssl is not None: + request_kwargs["ssl"] = aiohttp_ssl - logging.debug("Response status: {}".format(resp.status_code)) + backoff_seconds = INITIAL_BACKOFF_SECONDS + request_id = next(self._rest_request_ids) + log_prefix = f"rest[{request_id}]" + async with aiohttp.ClientSession() as session: + for attempt in range(0, MAX_RETRIES): + logging.debug("# %s %s %s", log_prefix, method, url) + logging.debug( + "%s request body:\n%s", log_prefix, json.dumps(kwargs, indent=1) + ) - try: - r = resp.json() - except ValueError: - logging.debug("Response body:\n{}".format(r.text)) - raise - else: - pretty_json = json.dumps(r, indent=1) - logging.debug("Response JSON:\n{}".format(pretty_json)) - - # Per Github rate limiting: https://docs.github.com/en/rest/using-the-rest-api/rate-limits-for-the-rest-api?apiVersion=2022-11-28#exceeding-the-rate-limit - if resp.status_code in (403, 429): - remaining_count = resp.headers.get("x-ratelimit-remaining") - reset_time = resp.headers.get("x-ratelimit-reset") - - if remaining_count == "0" and reset_time: - sleep_time = int(reset_time) - int(time.time()) - logging.warning( - f"Rate limit exceeded. Sleeping until reset in {sleep_time} seconds." - ) - time.sleep(sleep_time) - continue - else: - retry_after_seconds = resp.headers.get("retry-after") - if retry_after_seconds: - sleep_time = int(retry_after_seconds) - logging.warning( - f"Secondary rate limit hit. Sleeping for {sleep_time} seconds." + async with getattr(session, method)(url, **request_kwargs) as resp: + logging.debug("%s response status: %s", log_prefix, resp.status) + try: + r = await resp.json() + except (aiohttp.ContentTypeError, ValueError): + logging.debug( + "%s response body:\n%s", log_prefix, await resp.text() ) + raise else: - sleep_time = backoff_seconds - logging.warning( - f"Secondary rate limit hit. Sleeping for {sleep_time} seconds (exponential backoff)." - ) - backoff_seconds *= 2 - time.sleep(sleep_time) - continue - - if resp.status_code == 404: - raise ghstack.github.NotFoundError( - """\ + pretty_json = json.dumps(r, indent=1) + logging.debug("%s response JSON:\n%s", log_prefix, pretty_json) + + # Per Github rate limiting: + # https://docs.github.com/en/rest/using-the-rest-api/rate-limits-for-the-rest-api?apiVersion=2022-11-28#exceeding-the-rate-limit + if resp.status in (403, 429): + remaining_count = resp.headers.get("x-ratelimit-remaining") + reset_time = resp.headers.get("x-ratelimit-reset") + + if remaining_count == "0" and reset_time: + sleep_time = int(reset_time) - int(time.time()) + logging.warning( + f"Rate limit exceeded. Sleeping until reset in {sleep_time} seconds." + ) + await asyncio.sleep(sleep_time) + continue + else: + retry_after_seconds = resp.headers.get("retry-after") + if retry_after_seconds: + sleep_time = int(retry_after_seconds) + logging.warning( + f"Secondary rate limit hit. Sleeping for {sleep_time} seconds." + ) + else: + sleep_time = backoff_seconds + logging.warning( + f"Secondary rate limit hit. Sleeping for {sleep_time} seconds (exponential backoff)." + ) + backoff_seconds *= 2 + await asyncio.sleep(sleep_time) + continue + + if resp.status == 404: + raise ghstack.github.NotFoundError( + """\ GitHub raised a 404 error on the request for {url}. Usually, this doesn't actually mean the page doesn't exist; instead, it @@ -229,15 +261,13 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: to a new location or been renamed. Check that the repository URL is still correct. """.format( - url=url, github_url=self.github_url - ) - ) + url=url, github_url=self.github_url + ) + ) - try: - resp.raise_for_status() - except requests.HTTPError: - raise RuntimeError(pretty_json) + if resp.status >= 400: + raise RuntimeError(pretty_json) - return r + return r raise RuntimeError("Exceeded maximum retries due to GitHub rate limiting") diff --git a/src/ghstack/shell.py b/src/ghstack/shell.py index 91d1f6a..5cf6933 100644 --- a/src/ghstack/shell.py +++ b/src/ghstack/shell.py @@ -93,6 +93,39 @@ def sh( stdout: _HANDLE = subprocess.PIPE, exitcode: bool = False, tick: bool = False, + ) -> _SHELL_RET: + return self._run_async( + self.ash( + *args, + env=env, + stderr=stderr, + input=input, + stdin=stdin, + stdout=stdout, + exitcode=exitcode, + tick=tick, + ) + ) + + @staticmethod + def _run_async(coro: Any) -> Any: + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + async def ash( + self, + *args: str, # noqa: C901 + env: Optional[Dict[str, str]] = None, + stderr: _HANDLE = None, + # TODO: Arguably bytes should be accepted here too + input: Optional[str] = None, + stdin: _HANDLE = None, + stdout: _HANDLE = subprocess.PIPE, + exitcode: bool = False, + tick: bool = False, ) -> _SHELL_RET: """ Run a command specified by args, and return string representing @@ -197,8 +230,7 @@ async def run() -> Tuple[int, bytes, bytes]: assert proc.returncode is not None return (proc.returncode, out, err) - loop = asyncio.get_event_loop() - returncode, out, err = loop.run_until_complete(run()) + returncode, out, err = await run() def decode(b: bytes) -> str: return ( @@ -252,6 +284,17 @@ def git(self, *args: str, **kwargs: Any) -> _SHELL_RET: # noqa: F811 *args: Arguments to git **kwargs: Any valid kwargs for sh() """ + return self._run_async(self.agit(*args, **kwargs)) + + async def agit(self, *args: str, **kwargs: Any) -> _SHELL_RET: + """ + Run a git command asynchronously. The returned stdout has trailing + newlines stripped. + + Args: + *args: Arguments to git + **kwargs: Any valid kwargs for ash() + """ env = kwargs.setdefault("env", {}) # For git hooks to detect execution inside ghstack env.setdefault("GHSTACK", "1") @@ -278,7 +321,7 @@ def git(self, *args: str, **kwargs: Any) -> _SHELL_RET: # noqa: F811 if "stderr" not in kwargs: kwargs["stderr"] = subprocess.PIPE - return self._maybe_rstrip(self.sh(*(("git",) + args), **kwargs)) + return self._maybe_rstrip(await self.ash(*(("git",) + args), **kwargs)) @overload # noqa: F811 def hg(self, *args: str) -> str: ... @@ -298,7 +341,19 @@ def hg(self, *args: str, **kwargs: Any) -> _SHELL_RET: # noqa: F811 **kwargs: Any valid kwargs for sh() """ - return self._maybe_rstrip(self.sh(*(("hg",) + args), **kwargs)) + return self._run_async(self.ahg(*args, **kwargs)) + + async def ahg(self, *args: str, **kwargs: Any) -> _SHELL_RET: + """ + Run a hg command asynchronously. The returned stdout has trailing + newlines stripped. + + Args: + *args: Arguments to hg + **kwargs: Any valid kwargs for ash() + """ + + return self._maybe_rstrip(await self.ash(*(("hg",) + args), **kwargs)) def jf(self, *args: str, **kwargs: Any) -> _SHELL_RET: """ @@ -309,9 +364,21 @@ def jf(self, *args: str, **kwargs: Any) -> _SHELL_RET: **kwargs: Any valid kwargs for sh() """ + return self._run_async(self.ajf(*args, **kwargs)) + + async def ajf(self, *args: str, **kwargs: Any) -> _SHELL_RET: + """ + Run a jf command asynchronously. The returned stdout has trailing + newlines stripped. + + Args: + *args: Arguments to jf + **kwargs: Any valid kwargs for ash() + """ + kwargs.setdefault("stdout", sys.stderr) - return self._maybe_rstrip(self.sh(*(("jf",) + args), **kwargs)) + return self._maybe_rstrip(await self.ash(*(("jf",) + args), **kwargs)) def test_tick(self) -> None: """ @@ -327,7 +394,15 @@ def open(self, fn: str, mode: str) -> IO[Any]: fn: filename to open mode: mode to open the file as """ - return open(os.path.join(self.cwd, fn), mode) + return open(self.abspath(fn), mode) + + def abspath(self, fn: str) -> str: + """ + Resolve a path against this shell's current working directory. + """ + if os.path.isabs(fn): + return fn + return os.path.join(self.cwd, fn) def cd(self, d: str) -> None: """