Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/azure-cli/azure/cli/command_modules/acr/_docker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,28 @@ def _handle_challenge_phase(login_server,
logger.debug(add_timestamp("Sending a HTTP Get request to {}".format(request_url)))
challenge = requests.get(request_url, verify=not should_disable_connection_verify())

if challenge.status_code != 401 or 'WWW-Authenticate' not in challenge.headers:
if challenge.status_code not in [401, 403]:
from ._errors import CONNECTIVITY_CHALLENGE_ERROR
if is_diagnostics_context:
return CONNECTIVITY_CHALLENGE_ERROR.format_error_message(login_server)
raise CLIError(CONNECTIVITY_CHALLENGE_ERROR.format_error_message(login_server).get_error_message())

authenticate = challenge.headers['WWW-Authenticate']
authenticate = challenge.headers.get('WWW-Authenticate')
if not authenticate:
if is_aad_token and challenge.status_code == 403:
logger.warning(
"Received 403 challenge response without WWW-Authenticate from '%s'. "
"Falling back to default ACR token endpoints.",
login_server,
)
return {
'realm': 'https://{}/oauth2/token'.format(login_server),
'service': login_server
}
from ._errors import CONNECTIVITY_CHALLENGE_ERROR
if is_diagnostics_context:
return CONNECTIVITY_CHALLENGE_ERROR.format_error_message(login_server)
raise CLIError(CONNECTIVITY_CHALLENGE_ERROR.format_error_message(login_server).get_error_message())

tokens = authenticate.split(' ', 2)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from unittest import mock
import sys

from knack.util import CLIError

from azure.cli.command_modules.acr.repository import (
acr_repository_list,
acr_repository_show_tags,
Expand Down Expand Up @@ -44,6 +46,7 @@
get_access_credentials,
get_authorization_header,
get_manifest_authorization_header,
_handle_challenge_phase,
_resolve_acr_scope,
RepoAccessTokenPermission,
HelmAccessTokenPermission,
Expand Down Expand Up @@ -1219,6 +1222,71 @@ def _core_token_scenarios(self, mock_get_raw_token, mock_requests_get, mock_requ
get_access_credentials(cmd, registry_name, tenant_suffix=tenant_suffix, artifact_repository=TEST_REPOSITORY, permission=HelmAccessTokenPermission.PULL.value)
self._validate_access_token_request(mock_requests_get, mock_requests_post, login_server, 'artifact-repository:{}:{}'.format(TEST_REPOSITORY, HelmAccessTokenPermission.PULL.value))

@mock.patch('azure.cli.core._profile.Profile.get_subscription_id', autospec=True)
@mock.patch('azure.cli.command_modules.acr._docker_utils.get_registry_by_name')
@mock.patch('requests.post', autospec=True)
@mock.patch('requests.get', autospec=True)
@mock.patch('azure.cli.core._profile.Profile.get_raw_token')
def test_get_access_credentials_fallback_on_403_without_www_authenticate(
self, mock_get_raw_token, mock_requests_get, mock_requests_post, mock_get_registry_by_name,
mock_get_subscription):
from azure.mgmt.containerregistry.models import Registry, Sku

registry = Registry(location='westus', sku=Sku(name='Standard'))
login_server = 'testregistry.azurecr.io'
registry.login_server = login_server
mock_get_registry_by_name.return_value = registry, None

cmd = self._setup_cmd()
mock_get_subscription.return_value = TEST_SUBSCRIPTION
mock_get_raw_token.return_value = ('Bearer', TEST_AAD_ACCESS_TOKEN, {}), TEST_SUBSCRIPTION, TEST_TENANT

initial_connectivity_response = mock.MagicMock()
initial_connectivity_response.status_code = 200
initial_connectivity_response.headers = {}
challenge_response = mock.MagicMock()
challenge_response.status_code = 403
challenge_response.headers = {}
mock_requests_get.side_effect = [initial_connectivity_response, challenge_response]

token_response = mock.MagicMock()
token_response.status_code = 200
token_response.headers = {}
token_response.content = json.dumps({
'refresh_token': TEST_ACR_REFRESH_TOKEN,
'access_token': TEST_ACR_ACCESS_TOKEN
}).encode()
mock_requests_post.return_value = token_response

login_server, username, password = get_access_credentials(
cmd,
'testregistry',
artifact_repository=TEST_REPOSITORY,
permission=HelmAccessTokenPermission.PULL.value
)
self.assertEqual((login_server, username, password), ('testregistry.azurecr.io', EMPTY_GUID, TEST_ACR_ACCESS_TOKEN))

mock_requests_post.assert_any_call(
'https://{}/oauth2/exchange'.format(login_server),
urlencode({
'grant_type': 'access_token',
'service': login_server,
'tenant': TEST_TENANT,
'access_token': TEST_AAD_ACCESS_TOKEN
}),
headers={'Content-Type': 'application/x-www-form-urlencoded'},
verify=mock.ANY)
mock_requests_post.assert_any_call(
'https://{}/oauth2/token'.format(login_server),
urlencode({
'grant_type': 'refresh_token',
'service': login_server,
'scope': 'artifact-repository:{}:{}'.format(TEST_REPOSITORY, HelmAccessTokenPermission.PULL.value),
'refresh_token': TEST_ACR_REFRESH_TOKEN
}),
headers={'Content-Type': 'application/x-www-form-urlencoded'},
verify=mock.ANY)

def _setup_mock_token_requests(self, mock_get_aad_token, mock_requests_get, mock_requests_post, login_server):
# Set up AAD token with only access token
mock_get_aad_token.return_value = ('Bearer', TEST_AAD_ACCESS_TOKEN, {}), TEST_SUBSCRIPTION, TEST_TENANT
Expand All @@ -1240,6 +1308,42 @@ def _setup_mock_token_requests(self, mock_get_aad_token, mock_requests_get, mock
'access_token': TEST_ACR_ACCESS_TOKEN}).encode()
mock_requests_post.return_value = token_response

@mock.patch('requests.get', autospec=True)
def test_handle_challenge_phase_allows_403_with_www_authenticate(self, mock_requests_get):
challenge_response = mock.MagicMock()
challenge_response.status_code = 403
challenge_response.headers = {
'WWW-Authenticate': 'Bearer realm="https://testregistry.azurecr.io/oauth2/token",service="testregistry.azurecr.io"'
}
mock_requests_get.return_value = challenge_response

token_params = _handle_challenge_phase(
login_server='testregistry.azurecr.io',
repository=TEST_REPOSITORY,
artifact_repository=None,
permission=RepoAccessTokenPermission.METADATA_READ.value
)
self.assertEqual(
token_params,
{'realm': 'https://testregistry.azurecr.io/oauth2/token', 'service': 'testregistry.azurecr.io'}
)

@mock.patch('requests.get', autospec=True)
def test_handle_challenge_phase_rejects_403_without_www_authenticate_for_non_aad_auth(self, mock_requests_get):
challenge_response = mock.MagicMock()
challenge_response.status_code = 403
challenge_response.headers = {}
mock_requests_get.return_value = challenge_response

with self.assertRaises(CLIError):
_handle_challenge_phase(
login_server='testregistry.azurecr.io',
repository=TEST_REPOSITORY,
artifact_repository=None,
permission=RepoAccessTokenPermission.METADATA_READ.value,
is_aad_token=False
)

def _validate_raw_token_request(self, mock_get_raw_token):
mock_get_raw_token.assert_called_with(mock.ANY, resource="https://containerregistry.azure.net", subscription=mock.ANY)

Expand Down