diff --git a/src/ghstack/cli.py b/src/ghstack/cli.py index 25f8e3a..1d95ec2 100644 --- a/src/ghstack/cli.py +++ b/src/ghstack/cli.py @@ -384,6 +384,11 @@ def status(pull_request: str) -> None: is_flag=True, help="Create stack that directly merges into main", ) +@click.option( + "--no-fetch", + is_flag=True, + help="Skip fetching remote refs (faster when you know local refs are up-to-date)", +) @click.argument( "revs", nargs=-1, @@ -402,6 +407,7 @@ def submit( stack: bool, reviewer: Optional[str], label: Optional[str], + no_fetch: bool, ) -> None: """ Submit or update a PR stack @@ -425,6 +431,7 @@ def submit( direct_opt=direct_opt, reviewer=reviewer if reviewer is not None else config.reviewer, label=label if label is not None else config.label, + no_fetch=no_fetch, ) diff --git a/src/ghstack/submit.py b/src/ghstack/submit.py index 808dc11..ed51241 100644 --- a/src/ghstack/submit.py +++ b/src/ghstack/submit.py @@ -1,12 +1,26 @@ #!/usr/bin/env python3 +import asyncio import dataclasses import itertools import logging import os import re +import time from dataclasses import dataclass -from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple +from typing import ( + Any, + Awaitable, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + TypeVar, +) import ghstack import ghstack.git @@ -256,6 +270,60 @@ def next(self) -> GitCommitHash: return self.push_branches.next.commit.commit_id +_TIMING_ENABLED = True +_T = TypeVar("_T") + + +class _Timer: + def __init__(self) -> None: + self.start = time.monotonic() + self.last = self.start + self.entries: List[Tuple[str, float]] = [] + + def mark(self, label: str) -> None: + now = time.monotonic() + self.entries.append((label, now - self.last)) + self.last = now + + def report(self) -> None: + total = time.monotonic() - self.start + for label, elapsed in self.entries: + logging.info("[ghstack timing] %s: %.0fms", label, elapsed * 1000) + logging.info("[ghstack timing] total: %.0fms", total * 1000) + + +def _run_async_ordered(awaitables: Iterable[Awaitable[_T]]) -> List[_T]: + aws = list(awaitables) + if not aws: + return [] + + async def gather() -> List[Any]: + return await asyncio.gather(*aws, return_exceptions=True) + + loop = asyncio.new_event_loop() + try: + results = loop.run_until_complete(gather()) + finally: + loop.close() + + checked_results: List[_T] = [] + for result in results: + if isinstance(result, BaseException): + raise result + checked_results.append(result) + return checked_results + + +@dataclass +class _PendingNewPR: + commit_id: GitCommitHash + diff: ghstack.diff.Diff + base_diff_meta: Optional["DiffMeta"] + ghnum: GhNumber + push_specs: List[str] + diff_meta: "DiffMeta" + + def main(**kwargs: Any) -> List[DiffMeta]: submitter = Submitter(**kwargs) return submitter.run() @@ -269,8 +337,9 @@ def all_branches(username: str, ghnum: GhNumber) -> Tuple[str, str, str]: ) -def push_spec(commit: GitCommitHash, branch: str) -> str: - return "{}:refs/heads/{}".format(commit, branch) +def push_spec(commit: GitCommitHash, branch: str, force: bool = False) -> str: + spec = "{}:refs/heads/{}".format(commit, branch) + return "+" + spec if force else spec @dataclass(frozen=True) @@ -350,6 +419,9 @@ class Submitter: # Default labels to add to new pull requests (comma-separated labels) label: Optional[str] = None + # Skip fetching remote refs before submitting + no_fetch: bool = False + # ~~~~~~~~~~~~~~~~~~~~~~~~ # Computed in post init @@ -429,7 +501,21 @@ def __post_init__(self) -> None: # The main algorithm def run(self) -> List[DiffMeta]: - self.fetch() + timer = _Timer() if _TIMING_ENABLED else None + + if not self.no_fetch: + # Submit only needs fresh ghstack refs here. We intentionally do + # not fetch the base branch: in the normal workflow, a stack based + # on newer upstream commits got those commits by updating the local + # base ref before rebasing. The later rev-list boundary is against + # that local base ref, so narrowing this fetch avoids unrelated + # remote IO while preserving the usual submit semantics. + self.fetch( + f"+refs/heads/gh/{self.username}/*" + f":refs/remotes/{self.remote_name}/gh/{self.username}/*" + ) + if timer: + timer.mark("fetch") commits_to_submit_and_boundary = self.parse_revs() @@ -474,12 +560,19 @@ def run(self) -> List[DiffMeta]: "There appears to be no commits to process, based on the revs you passed me." ) + if timer: + timer.mark("parse_revs") + # This is not really accurate if you're doing a fancy pattern; # if this is a problem file us a bug. run_pre_ghstack_hook( self.sh, f"{self.remote_name}/{self.base}", commits_to_submit[0].commit_id ) + pr_info_cache = self._prefetch_pr_info(commits_to_rebase) + if not self.no_fetch: + self._fetch_foreign_pr_refs(pr_info_cache.values()) + # NB: This is duplicative with prepare_submit to keep the # check_invariants code small, as it counts as TCB pre_branch_state_index: Dict[GitCommitHash, PreBranchState] = {} @@ -487,7 +580,10 @@ def run(self) -> List[DiffMeta]: for h in commits_to_submit: d = ghstack.git.convert_header(h, self.github_url) if d.pull_request_resolved is not None: - ed = self.elaborate_diff(d) + ed = self.elaborate_diff( + d, + _pr_info=pr_info_cache.get(d.pull_request_resolved.number), + ) # Skip closed PRs (e.g., after landing) where branches have been deleted if not ed.closed: pre_branch_state_index[h.commit_id] = PreBranchState( @@ -511,8 +607,13 @@ def run(self) -> List[DiffMeta]: ) } diff_meta_index, rebase_index = self.prepare_updates( - commit_index, commits_to_submit, commits_to_rebase + commit_index, + commits_to_submit, + commits_to_rebase, + pr_info_cache=pr_info_cache, ) + if timer: + timer.mark("prepare_updates") logging.debug("rebase_index = %s", rebase_index) diffs_to_submit = [ diff_meta_index[h.commit_id] @@ -525,6 +626,8 @@ def run(self) -> List[DiffMeta]: if h.commit_id in diff_meta_index ] self.push_updates(diffs_to_submit, all_diffs=all_diffs_in_topo_order) + if timer: + timer.mark("push_updates") if new_head := rebase_index.get( old_head := GitCommitHash(self.sh.git("rev-parse", "HEAD")) ): @@ -564,22 +667,31 @@ def run(self) -> List[DiffMeta]: exitcode=True, ) + if timer: + timer.mark("finalize") + timer.report() + # NB: earliest first, which is the intuitive order for unit testing return list(reversed(diffs_to_submit)) # ~~~~~~~~~~~~~~~~~~~~~~~~ # The main pieces - def fetch(self) -> None: - # TODO: Potentially we could narrow this refspec down to only OUR gh - # branches. However, this will interact poorly with cross-author - # so it needs to be thought more carefully - self.sh.git( - "fetch", - "--prune", - self.remote_name, - f"+refs/heads/*:refs/remotes/{self.remote_name}/*", - ) + def fetch(self, refspec: Optional[str] = None) -> None: + if refspec is not None: + self.sh.git( + "fetch", + "--prune", + self.remote_name, + refspec, + ) + else: + self.sh.git( + "fetch", + "--prune", + self.remote_name, + f"+refs/heads/*:refs/remotes/{self.remote_name}/*", + ) def parse_revs(self) -> List[ghstack.git.CommitHeader]: # There are two distinct usage patterns: @@ -679,20 +791,73 @@ def parse_revs(self) -> List[ghstack.git.CommitHeader]: return commits_to_submit_and_boundary + def _prefetch_pr_info( + self, + commits: List[ghstack.git.CommitHeader], + ) -> Dict[GitHubNumber, Any]: + """Batch-fetch PR info for all commits that have existing PRs. + Uses async REST calls to overlap GitHub IO.""" + + pr_numbers: List[GitHubNumber] = [] + for commit in commits: + diff = ghstack.git.convert_header(commit, self.github_url) + if diff.pull_request_resolved is not None: + pr_numbers.append(diff.pull_request_resolved.number) + + if not pr_numbers: + return {} + + unique_numbers = sorted(set(pr_numbers), key=int) + + async def fetch_pr(number: GitHubNumber) -> Tuple[GitHubNumber, Any]: + r = await self.github.arest( + "get", f"repos/{self.repo_owner}/{self.repo_name}/pulls/{number}" + ) + return number, r + + results = _run_async_ordered(fetch_pr(number) for number in unique_numbers) + pr_info: Dict[GitHubNumber, Any] = dict(results) + + return pr_info + + def _fetch_foreign_pr_refs(self, pr_infos: Iterable[Any]) -> None: + usernames: Set[str] = set() + for pr_info in pr_infos: + head_ref_name = self._pr_ref_name(pr_info, "head") + if head_ref_name is None: + continue + m = re.match(r"gh/([^/]+)/([0-9]+)/head$", head_ref_name) + if m is not None and m.group(1) != self.username: + usernames.add(m.group(1)) + + for username in sorted(usernames): + # TODO: Potentially we could narrow this refspec down to only the + # referenced PRs. However, this will interact poorly with + # cross-author stacks, so it needs to be thought more carefully. + self.fetch( + f"+refs/heads/gh/{username}/*" + f":refs/remotes/{self.remote_name}/gh/{username}/*" + ) + def prepare_updates( self, commit_index: Dict[GitCommitHash, ghstack.git.CommitHeader], commits_to_submit: List[ghstack.git.CommitHeader], commits_to_rebase: List[ghstack.git.CommitHeader], + *, + pr_info_cache: Optional[Dict[GitHubNumber, Any]] = None, ) -> Tuple[Dict[GitCommitHash, DiffMeta], Dict[GitCommitHash, GitCommitHash]]: - # Prepare diffs in reverse topological order. - # (Reverse here is important because we must have processed parents - # first.) - # NB: some parts of the algo (namely commit creation) could - # be done in parallel + # Prefetch PR info for all commits with existing PRs (parallel REST GETs) + if pr_info_cache is None: + pr_info_cache = self._prefetch_pr_info(commits_to_rebase) + + # Phase 1: Process all commits (oldest first) to determine what + # needs updating, create head/base commits, and identify new PRs. + # New PRs are NOT pushed/created yet — deferred to batch operation. submit_set = set(h.commit_id for h in commits_to_submit) diff_meta_index: Dict[GitCommitHash, DiffMeta] = {} - rebase_index: Dict[GitCommitHash, GitCommitHash] = {} + pending_new_prs: List[_PendingNewPR] = [] + for commit in reversed(commits_to_rebase): submit = commit.commit_id in submit_set parents = commit.parents @@ -704,7 +869,6 @@ def prepare_updates( ) ) parent = parents[0] - diff_meta = None parent_commit = commit_index[parent] parent_diff_meta = diff_meta_index.get(parent) diff = ghstack.git.convert_header(commit, self.github_url) @@ -713,28 +877,77 @@ def prepare_updates( parent_diff_meta, diff, ( - self.elaborate_diff(diff) + self.elaborate_diff( + diff, + _pr_info=pr_info_cache.get(diff.pull_request_resolved.number), + ) if diff.pull_request_resolved is not None else None ), submit, + pending_new_prs=pending_new_prs, ) if diff_meta is not None: diff_meta_index[commit.commit_id] = diff_meta - # Check if we actually need to rebase it, or can use it as is + # Phase 2: Batch-push branches and create all new PRs. + if pending_new_prs: + # Collect all push specs and push in one call. + all_new_push_specs: List[str] = [] + for pending in pending_new_prs: + all_new_push_specs.extend(pending.push_specs) + if all_new_push_specs: + self._git_push(all_new_push_specs) + + # Create PRs in stack order. GitHub PR numbers are globally allocated, + # so parallel creation makes numbering nondeterministic. + results = [ + ( + pending, + self._create_pull_request( + pending.diff, + pending.base_diff_meta, + pending.ghnum, + ), + ) + for pending in pending_new_prs + ] + + # Update DiffMeta entries with real PR info + for pending, elab_diff in results: + dm = pending.diff_meta + dm.elab_diff = elab_diff + trailers_to_add = [f"ghstack-source-id: {pending.diff.source_id}"] + if self.direct: + trailers_to_add.append( + f"ghstack-comment-id: {elab_diff.comment_id}" + ) + trailers_to_add.append( + f"Pull-Request: {elab_diff.pull_request_resolved.url()}" + ) + dm.commit_msg = ghstack.trailers.interpret_trailers( + strip_mentions(pending.diff.summary.rstrip()), + trailers_to_add, + ) + + # Phase 3: Create orig commits and build rebase index. + # Must happen after Phase 2 so new PRs have correct commit messages. + rebase_index: Dict[GitCommitHash, GitCommitHash] = {} + for commit in reversed(commits_to_rebase): + parent = commit.parents[0] + diff_meta = diff_meta_index.get(commit.commit_id) + + # Check if we actually need to rebase it, or can use it as is. # NB: This is not in process_commit, because we may need - # to rebase a commit even if we didn't submit it + # to rebase a commit even if we didn't submit it. if parent in rebase_index or diff_meta is not None: - # Yes, we need to rebase it - if diff_meta is not None: # use the updated commit message, if it exists commit_msg = diff_meta.commit_msg else: commit_msg = commit.commit_msg - if rebase_id := rebase_index.get(commit.parents[0]): + if rebase_id := rebase_index.get(parent): # use the updated base, if it exists base_commit_id = rebase_id else: @@ -762,9 +975,8 @@ def prepare_updates( ) if diff_meta is not None: - # Add the new_orig to push - # This may not exist. If so, that means this diff only exists - # to update HEAD. + # Add the new_orig to push. This may not exist. If so, + # that means this diff only exists to update HEAD. diff_meta.push_branches.orig.update(GhCommit(new_orig, commit.tree)) rebase_index[commit.commit_id] = new_orig @@ -772,11 +984,15 @@ def prepare_updates( return diff_meta_index, rebase_index def elaborate_diff( - self, diff: ghstack.diff.Diff, *, is_ghexport: bool = False + self, + diff: ghstack.diff.Diff, + *, + is_ghexport: bool = False, + _pr_info: Any = None, ) -> DiffWithGitHubMetadata: """ - Query GitHub API for the current title, body and closed? status - of the pull request corresponding to a ghstack.diff.Diff. + Query GitHub API for the current title, body, branch, and closed? + status of the pull request corresponding to a ghstack.diff.Diff. """ assert diff.pull_request_resolved is not None @@ -784,94 +1000,80 @@ def elaborate_diff( assert diff.pull_request_resolved.repo == self.repo_name number = diff.pull_request_resolved.number - # TODO: There is no reason to do a node query here; we can - # just look up the repo the old fashioned way - r = self.github.graphql( - """ - query ($repo_id: ID!, $number: Int!) { - node(id: $repo_id) { - ... on Repository { - pullRequest(number: $number) { - body - title - closed - headRefName - baseRefName - } - } - } - } - """, - repo_id=self.repo_id, - number=number, - )["data"]["node"]["pullRequest"] + + # Use pre-fetched PR info if available, otherwise fetch now. + if _pr_info is None: + pr_info = self.github.get( + f"repos/{self.repo_owner}/{self.repo_name}/pulls/{number}" + ) + else: + pr_info = _pr_info + + head_ref_name = self._pr_ref_name(pr_info, "head") + if head_ref_name is None: + head_ref_name = self.github.get_head_ref( + owner=self.repo_owner, name=self.repo_name, number=number + ) # Sorry, this is a big hack to support the ghexport case - m = re.match(r"(refs/heads/)?export-D([0-9]+)$", r["headRefName"]) - if m is not None and is_ghexport: + m_export = re.match(r"(refs/heads/)?export-D([0-9]+)$", head_ref_name) + if m_export is not None and is_ghexport: raise RuntimeError( - """\ -This commit appears to already be associated with a pull request, -but the pull request was previously submitted with an old version of -ghexport. You can continue exporting using the old style using: - - ghexport --legacy - -For future diffs, we recommend using the non-legacy version of ghexport -as it supports bidirectional syncing. However, there is no way to -convert a pre-existing PR in the old style to the new format which -supports bidirectional syncing. If you would like to blow away the old -PR and start anew, edit the Summary in the Phabricator diff to delete -the line 'Pull-Request' and then run ghexport again. -""" + "This commit appears to already be associated with a pull request,\n" + "but the pull request was previously submitted with an old version of\n" + "ghexport. You can continue exporting using the old style using:\n\n" + " ghexport --legacy\n\n" + "For future diffs, we recommend using the non-legacy version of ghexport\n" + "as it supports bidirectional syncing. However, there is no way to\n" + "convert a pre-existing PR in the old style to the new format which\n" + "supports bidirectional syncing. If you would like to blow away the old\n" + "PR and start anew, edit the Summary in the Phabricator diff to delete\n" + "the line 'Pull-Request' and then run ghexport again.\n" ) - # TODO: Hmm, I'm not sure why this matches - m = re.match(r"gh/([^/]+)/([0-9]+)/head$", r["headRefName"]) + m = re.match(r"gh/([^/]+)/([0-9]+)/head$", head_ref_name) if m is None: if is_ghexport: raise RuntimeError( - """\ -This commit appears to already be associated with a pull request, -but the pull request doesn't look like it was submitted by ghexport -Maybe you exported it using the "Export to Open Source" button on -the Phabricator diff page? If so, please continue to use that button -to export your diff. - -If you think this is in error, edit the Summary in the Phabricator diff -to delete the line 'Pull-Request' and then run ghexport again. -""" - ) - else: - raise RuntimeError( - """\ -This commit appears to already be associated with a pull request, -but the pull request doesn't look like it was submitted by ghstack. -If you think this is in error, run: - - ghstack unlink {} - -to disassociate the commit with the pull request, and then try again. -(This will create a new pull request!) -""".format( - diff.oid - ) + "This commit appears to already be associated with a pull request,\n" + "but the pull request doesn't look like it was submitted by ghexport\n" + 'Maybe you exported it using the "Export to Open Source" button on\n' + "the Phabricator diff page? If so, please continue to use that button\n" + "to export your diff.\n\n" + "If you think this is in error, edit the Summary in the Phabricator diff\n" + "to delete the line 'Pull-Request' and then run ghexport again.\n" ) + raise RuntimeError( + "This commit appears to already be associated with a pull request,\n" + "but the pull request doesn't look like it was submitted by ghstack.\n" + "If you think this is in error, run:\n\n" + " ghstack unlink {}\n\n" + "to disassociate the commit with the pull request, and then try again.\n" + "(This will create a new pull request!)\n".format(diff.oid) + ) username = m.group(1) gh_number = GhNumber(m.group(2)) - # NB: Technically, we don't need to pull this information at - # all, but it's more convenient to unconditionally edit - # title/body when we update the pull request info - title = r["title"] - pr_body = r["body"] - if self.update_fields: - title, pr_body = self._default_title_and_body(diff, pr_body) + base_ref_name = self._pr_ref_name(pr_info, "base") + if base_ref_name is None: + if self.direct: + base_ref_name = str(branch_next(username, gh_number)) + else: + base_ref_name = str(branch_base(username, gh_number)) + closed = pr_info.get("state") != "open" - # TODO: remote summary should be done earlier so we can use - # it to test if updates are necessary + if self.update_fields: + # NB: Technically, we don't need to pull this information at + # all, but it's more convenient to unconditionally edit + # title/body when we update the pull request info + title, pr_body = self._default_title_and_body(diff, pr_info.get("body")) + else: + title = pr_info["title"] + pr_body = pr_info["body"] or "" try: + # TODO: remote summary should be done earlier so we can use + # it to test if updates are necessary rev_list = self.sh.git( "rev-list", "--max-count=1", @@ -879,11 +1081,7 @@ def elaborate_diff( self.remote_name + "/" + branch_orig(username, gh_number), ) except RuntimeError: - if r["closed"]: - # If the PR is closed and the branch is deleted (e.g., after landing), - # we can't get the remote source ID. Return None for it, which will - # signal to process_commit that this commit has been landed and should - # be skipped (not updated). + if closed: remote_source_id = None comment_id = None else: @@ -901,17 +1099,28 @@ def elaborate_diff( diff=diff, title=title, body=pr_body, - closed=r["closed"], + closed=closed, number=number, username=username, ghnum=gh_number, remote_source_id=remote_source_id, comment_id=comment_id, pull_request_resolved=diff.pull_request_resolved, - head_ref=r["headRefName"], - base_ref=r["baseRefName"], + head_ref=head_ref_name, + base_ref=base_ref_name, ) + def _pr_ref_name(self, pr_info: Any, kind: str) -> Optional[str]: + ref = pr_info.get(kind) + if isinstance(ref, dict): + name = ref.get("ref") + if isinstance(name, str): + return name + name = pr_info.get(f"{kind}RefName") + if isinstance(name, str): + return name + return None + def process_commit( self, base: ghstack.git.CommitHeader, @@ -919,6 +1128,8 @@ def process_commit( diff: ghstack.diff.Diff, elab_diff: Optional[DiffWithGitHubMetadata], submit: bool, + *, + pending_new_prs: List[_PendingNewPR], ) -> Optional[DiffMeta]: # Do not process poisoned commits if "[ghstack-poisoned]" in diff.summary: @@ -990,14 +1201,54 @@ def process_commit( # Create pull request, if needed if elab_diff is None: - # Need to push branches now rather than later, so we can create PR - self._git_push( - [push_spec(p[0], branch(username, ghnum, p[1])) for p in push_branches] - ) + # Defer push and PR creation to batch phase. + # Record push specs for later, use placeholder commit_msg. + new_pr_push_specs = [ + push_spec(p[0], branch(username, ghnum, p[1])) for p in push_branches + ] push_branches.clear() - elab_diff = self._create_pull_request(diff, base_diff_meta, ghnum) - what = "Created" - new_pr = True + + # Placeholder elab_diff — will be replaced after PR creation + placeholder_elab = DiffWithGitHubMetadata( + diff=diff, + number=GitHubNumber(0), + username=username, + remote_source_id=diff.source_id, + comment_id=None, + title=diff.title, + body="", + closed=False, + ghnum=ghnum, + pull_request_resolved=ghstack.diff.PullRequestResolved( + owner=self.repo_owner, + repo=self.repo_name, + number=0, + github_url=self.github_url, + ), + head_ref=str(branch_head(username, ghnum)), + base_ref=base_branch, + ) + # Placeholder commit_msg — will be updated after PR creation + commit_msg = strip_mentions(diff.summary.rstrip()) + + dm = DiffMeta( + elab_diff=placeholder_elab, + commit_msg=commit_msg, + push_branches=push_branches, + what="Created", + base=base_branch, + ) + pending_new_prs.append( + _PendingNewPR( + commit_id=diff.oid, + diff=diff, + base_diff_meta=base_diff_meta, + ghnum=ghnum, + push_specs=new_pr_push_specs, + diff_meta=dm, + ) + ) + return dm else: if not push_branches: what = "Skipped" @@ -1005,36 +1256,17 @@ def process_commit( what = "Skipped (next updated)" else: what = "Updated" - new_pr = False - pull_request_resolved = elab_diff.pull_request_resolved - - if not new_pr: - # Underlying diff can be assumed to have the correct metadata, we - # only need to update it commit_msg = self._update_source_id(diff.summary, elab_diff) - else: - # Need to insert metadata for the first time - # Using our Python implementation of interpret-trailers - trailers_to_add = [f"ghstack-source-id: {diff.source_id}"] - - if self.direct: - trailers_to_add.append(f"ghstack-comment-id: {elab_diff.comment_id}") - trailers_to_add.append(f"Pull-Request: {pull_request_resolved.url()}") - - commit_msg = ghstack.trailers.interpret_trailers( - strip_mentions(diff.summary.rstrip()), trailers_to_add + return DiffMeta( + elab_diff=elab_diff, + commit_msg=commit_msg, + push_branches=push_branches, + what=what, + base=base_branch, ) - return DiffMeta( - elab_diff=elab_diff, - commit_msg=commit_msg, - push_branches=push_branches, - what=what, - base=base_branch, - ) - def _raise_poisoned(self) -> None: raise RuntimeError( """\ @@ -1068,30 +1300,30 @@ def _warn_empty( ) def _allocate_ghnum(self) -> GhNumber: - # Determine the next available GhNumber. We do this by - # iterating through known branches and keeping track - # of the max. The next available GhNumber is the next number. - # This is technically subject to a race, but we assume - # end user is not running this script concurrently on - # multiple machines (you bad bad) + # Check both seen_ghnums (from commits in the current stack) and + # remote refs (which may include ghnums from landed/closed PRs + # whose branches still exist) + # This is technically subject to a race, but we assume the end user is + # not running this script concurrently on multiple machines. + max_seen = max( + ( + int(str(ghnum)) + for user, ghnum in self.seen_ghnums + if user == self.username + ), + default=0, + ) refs = self.sh.git( "for-each-ref", - # Use OUR username here, since there's none attached to the - # diff "refs/remotes/{}/gh/{}".format(self.remote_name, self.username), "--format=%(refname)", ).split() - - def _is_valid_ref(ref: str) -> bool: - splits = ref.split("/") - if len(splits) < 3: - return False - else: - return splits[-2].isnumeric() - - refs = list(filter(_is_valid_ref, refs)) - max_ref_num = max(int(ref.split("/")[-2]) for ref in refs) if refs else 0 - return GhNumber(str(max_ref_num + 1)) + max_ref = 0 + for ref in refs: + parts = ref.split("/") + if len(parts) >= 3 and parts[-2].isnumeric(): + max_ref = max(max_ref, int(parts[-2])) + return GhNumber(str(max(max_seen, max_ref) + 1)) def _sanity_check_ghnum(self, username: str, ghnum: GhNumber) -> None: if (username, ghnum) in self.seen_ghnums: @@ -1550,36 +1782,37 @@ def push_updates( all_diffs: Optional[List[DiffMeta]] = None, import_help: bool = True, ) -> None: - # update pull request information, update bases as necessary - # preferably do this in one network call - # push your commits (be sure to do this AFTER you update bases) - base_push_branches: List[str] = [] - push_branches: List[str] = [] - force_push_branches: List[str] = [] + # Collect all refspecs into a single batched push. This is being + # tested in production because GitHub may observe base/head ref updates + # out of order when refreshing PR diffs. If that happens, revert this + # block to three grouped pushes in this order: + # 1. base branches + # 2. head/next branches + # 3. orig branches + # Per-refspec force is encoded with the + prefix: + # orig branches: always force-pushed + # head/next branches: force-pushed only with --force flag + # base branches: never force-pushed + # It is VERY important that we preserve base-before-head ordering, + # otherwise GitHub can spuriously think that the user pushed a number + # of patches as part of the PR, when actually they were just from the + # new upstream branch. + all_push_specs: List[str] = [] for s in reversed(diffs_to_submit): - # It is VERY important that we do base updates BEFORE real - # head updates, otherwise GitHub will spuriously think that - # the user pushed a number of patches as part of the PR, - # when actually they were just from the (new) upstream - # branch - for diff, b in s.push_branches: + # Careful! Don't push main. if b == "orig": - q = force_push_branches + force = True elif b == "base": - q = base_push_branches + force = False else: - q = push_branches - q.append(push_spec(diff, branch(s.username, s.ghnum, b))) - # Careful! Don't push main. - # TODO: These pushes need to be atomic (somehow) - if base_push_branches: - self._git_push(base_push_branches) - if push_branches: - self._git_push(push_branches, force=self.force) - if force_push_branches: - self._git_push(force_push_branches, force=True) + force = self.force + all_push_specs.append( + push_spec(diff, branch(s.username, s.ghnum, b), force=force) + ) + if all_push_specs: + self._git_push(all_push_specs) # Discover orphan PR numbers from the old stack listing. # We search the full local stack for old stack text, then @@ -1621,8 +1854,7 @@ def push_updates( orphan_above.append(num) break - for s in reversed(diffs_to_submit): - # NB: GraphQL API does not support modifying PRs + def _update_pr_args(s: DiffMeta) -> Tuple[str, Dict[str, Any], Optional[str]]: assert not s.closed logging.info( "# Updating https://{github_url}/{owner}/{repo}/pull/{number}".format( @@ -1632,7 +1864,6 @@ def push_updates( number=s.number, ) ) - # TODO: don't update this if it doesn't need updating base_kwargs = {} if self.direct: base_kwargs["base"] = s.base @@ -1641,24 +1872,40 @@ def push_updates( stack_desc = self._format_stack( diffs_to_submit, s.number, orphan_above, orphan_below ) - self.github.patch( - "repos/{owner}/{repo}/pulls/{number}".format( - owner=self.repo_owner, repo=self.repo_name, number=s.number - ), - # NB: this substitution does nothing on direct PRs - body=RE_STACK.sub( + path = "repos/{owner}/{repo}/pulls/{number}".format( + owner=self.repo_owner, repo=self.repo_name, number=s.number + ) + kwargs = { + "body": RE_STACK.sub( stack_desc, s.body, ), - title=s.title, + "title": s.title, **base_kwargs, - ) - + } + comment_path = None if s.elab_diff.comment_id is not None: - self.github.patch( - f"repos/{self.repo_owner}/{self.repo_name}/issues/comments/{s.elab_diff.comment_id}", - body=stack_desc, + comment_path = ( + f"repos/{self.repo_owner}/{self.repo_name}/issues/comments/" + f"{s.elab_diff.comment_id}" ) + return path, kwargs, comment_path + + async def _update_pr_async(s: DiffMeta) -> None: + # NB: GraphQL API does not support modifying PRs + path, kwargs, comment_path = _update_pr_args(s) + await self.github.arest("patch", path, **kwargs) + + if comment_path is not None: + await self.github.arest( + "patch", + comment_path, + body=self._format_stack( + diffs_to_submit, s.number, orphan_above, orphan_below + ), + ) + + _run_async_ordered(_update_pr_async(s) for s in reversed(diffs_to_submit)) # Report what happened def format_url(s: DiffMeta) -> str: