diff --git a/src/ghstack/checkout.py b/src/ghstack/checkout.py index 13230f7..72fd152 100644 --- a/src/ghstack/checkout.py +++ b/src/ghstack/checkout.py @@ -1,13 +1,25 @@ #!/usr/bin/env python3 +import asyncio import logging import re +from typing import Iterable import ghstack.github import ghstack.github_utils import ghstack.shell +async def _fetch_refs( + sh: ghstack.shell.Shell, *, remote_name: str, refs: Iterable[str] +) -> None: + refspecs = [ + f"+refs/heads/{ref}:refs/remotes/{remote_name}/{ref}" + for ref in sorted(set(refs)) + ] + await sh.agit("fetch", "--prune", remote_name, *refspecs) + + async def main( pull_request: str, github: ghstack.github.GitHubEndpoint, @@ -19,7 +31,24 @@ async def main( params = await ghstack.github_utils.parse_pull_request( pull_request, sh=sh, remote_name=remote_name ) - head_ref = await github.get_head_ref(**params) + head_ref_task = asyncio.ensure_future(github.get_head_ref(**params)) + + if same_base: + repo_info_task = asyncio.ensure_future( + ghstack.github_utils.get_github_repo_info( + github=github, + sh=sh, + repo_owner=params["owner"], + repo_name=params["name"], + github_url=params["github_url"], + remote_name=remote_name, + ) + ) + head_ref, repo_info = await asyncio.gather(head_ref_task, repo_info_task) + else: + head_ref = await head_ref_task + repo_info = None + orig_ref = re.sub(r"/head$", "/orig", head_ref) if orig_ref == head_ref: logging.warning( @@ -30,15 +59,7 @@ async def main( # If --same-base is specified, check if checkout would change the merge-base if same_base: - # Get the default branch name from the repo - repo_info = await ghstack.github_utils.get_github_repo_info( - github=github, - sh=sh, - repo_owner=params["owner"], - repo_name=params["name"], - github_url=params["github_url"], - remote_name=remote_name, - ) + assert repo_info is not None default_branch = repo_info["default_branch"] default_branch_ref = f"{remote_name}/{default_branch}" @@ -48,7 +69,11 @@ async def main( current_base = None default_branch_ref = None - await sh.agit("fetch", "--prune", remote_name) + refs_to_fetch = [orig_ref] + if same_base: + assert repo_info is not None + refs_to_fetch.append(repo_info["default_branch"]) + await _fetch_refs(sh, remote_name=remote_name, refs=refs_to_fetch) # If --same-base is specified, check what the new merge-base would be if same_base: