diff --git a/src/github_app.py b/src/github_app.py index c18b6a9..97e9785 100644 --- a/src/github_app.py +++ b/src/github_app.py @@ -4,12 +4,21 @@ from __future__ import annotations import contextlib +import logging import time from typing import Generator import jwt import requests +GITHUB_API_TIMEOUT = 10 +GITHUB_API_MAX_ATTEMPTS = 3 +GITHUB_API_RETRYABLE_ERRORS = ( + requests.exceptions.ConnectionError, + requests.exceptions.SSLError, + requests.exceptions.Timeout, +) + class GithubAppToken: def __init__(self, private_key, app_id) -> None: @@ -19,8 +28,9 @@ def __init__(self, private_key, app_id) -> None: # configured by the GitHub App and expire after one hour. @contextlib.contextmanager def get_token(self, installation_id: int) -> Generator[str, None, None]: - req = requests.post( - url=f"https://api.github.com/app/installations/{installation_id}/access_tokens", + req = _request_github_api( + "POST", + f"https://api.github.com/app/installations/{installation_id}/access_tokens", headers=self.headers, ) req.raise_for_status() @@ -29,7 +39,8 @@ def get_token(self, installation_id: int) -> Generator[str, None, None]: # This token expires in an hour yield resp["token"] finally: - requests.delete( + _request_github_api( + "DELETE", "https://api.github.com/installation/token", headers={"Authorization": f"token {resp['token']}"}, ) @@ -51,3 +62,25 @@ def get_authentication_header(self, private_key, app_id): "Accept": "application/vnd.github.v3+json", "Authorization": f"Bearer {jwt_token}", } + + +def _request_github_api(method: str, url: str, **kwargs) -> requests.Response: + for attempt in range(1, GITHUB_API_MAX_ATTEMPTS + 1): + try: + return requests.request( + method, + url, + timeout=GITHUB_API_TIMEOUT, + **kwargs, + ) + except GITHUB_API_RETRYABLE_ERRORS: + if attempt == GITHUB_API_MAX_ATTEMPTS: + raise + + logging.warning( + "Transient GitHub App API request failed; retrying", + exc_info=True, + ) + time.sleep(2 ** (attempt - 1)) + + raise RuntimeError("Unreachable GitHub App API retry state") diff --git a/tests/test_github_app.py b/tests/test_github_app.py new file mode 100644 index 0000000..612f68d --- /dev/null +++ b/tests/test_github_app.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from unittest.mock import Mock + +import pytest +import requests + +from src.github_app import GITHUB_API_TIMEOUT, GithubAppToken + + +class DummyResponse: + def __init__(self, payload=None, error=None) -> None: + self.payload = payload or {} + self.error = error + + def raise_for_status(self): + if self.error: + raise self.error + + def json(self): + return self.payload + + +@pytest.fixture +def github_app_token(): + token = GithubAppToken.__new__(GithubAppToken) + token.headers = {"Authorization": "Bearer jwt"} + return token + + +def test_get_token_retries_transient_token_creation_errors( + github_app_token, + monkeypatch, +): + calls = [] + sleep = Mock() + monkeypatch.setattr("src.github_app.time.sleep", sleep) + + def request(method, url, **kwargs): + calls.append((method, url, kwargs)) + post_attempts = len([call for call in calls if call[0] == "POST"]) + if method == "POST" and post_attempts == 1: + raise requests.exceptions.ConnectTimeout("connect timeout") + return DummyResponse({"token": "github-token"}) + + monkeypatch.setattr("src.github_app.requests.request", request) + + with github_app_token.get_token(42) as token: + assert token == "github-token" + + post_calls = [call for call in calls if call[0] == "POST"] + assert len(post_calls) == 2 + assert calls[-1][0] == "DELETE" + assert all(call[2]["timeout"] == GITHUB_API_TIMEOUT for call in calls) + sleep.assert_called_once_with(1) + + +def test_get_token_does_not_retry_http_status_errors(github_app_token, monkeypatch): + calls = [] + + def request(method, url, **kwargs): + calls.append((method, url, kwargs)) + return DummyResponse(error=requests.HTTPError("server error")) + + monkeypatch.setattr("src.github_app.requests.request", request) + + with pytest.raises(requests.HTTPError): + with github_app_token.get_token(42): + pass + + assert len(calls) == 1