diff --git a/app/auth.py b/app/auth.py index a36b1d6..4dd88d6 100644 --- a/app/auth.py +++ b/app/auth.py @@ -49,8 +49,11 @@ def _decode_token(token: str): def get_current_user_id(token: str = Depends(oauth2_scheme)): - user: dict = _decode_token(token) - return user["sub"] + return get_current_user_claims(token)["sub"] + + +def get_current_user_claims(token: str = Depends(oauth2_scheme)) -> Dict[str, Any]: + return _decode_token(token) async def websocket_authenticate(websocket: WebSocket) -> str | None: diff --git a/app/platforms/implementations/ogc_api_process.py b/app/platforms/implementations/ogc_api_process.py index dd2cffa..546f74f 100644 --- a/app/platforms/implementations/ogc_api_process.py +++ b/app/platforms/implementations/ogc_api_process.py @@ -1,6 +1,6 @@ import re from typing import List -from app.auth import exchange_token +from app.auth import exchange_token, get_current_user_claims from fastapi import Response from loguru import logger @@ -34,11 +34,14 @@ @register_platform(ProcessTypeEnum.OGC_API_PROCESS) class OGCAPIProcessPlatform(BaseProcessingPlatform): input_type_map = { + "date-time": ParamTypeEnum.DATETIME, "date-interval": ParamTypeEnum.DATE_INTERVAL, "bounding-box": ParamTypeEnum.BOUNDING_BOX, "boolean": ParamTypeEnum.BOOLEAN, "integer": ParamTypeEnum.INTEGER, "double": ParamTypeEnum.DOUBLE, + "number": ParamTypeEnum.DOUBLE, + "string": ParamTypeEnum.STRING, } status_mapping = { @@ -53,6 +56,12 @@ class OGCAPIProcessPlatform(BaseProcessingPlatform): r"(?P.+)/processes/(?P[^/]+)$" ) + geojson_schema_references = { + GEOJSON_FEATURECOLLECTION_SCHEMA, + "https://geojson.org/schema/FeatureCollection.json", + "https://geojson.org/schema/Feature.json", + } + """ OGC API Process processing platform implementation. This class handles the execution of processing jobs on the OGC API Process platform. @@ -64,6 +73,76 @@ def _split_job_id(self, job_id) -> tuple[str, ...]: return ("", job_id) return tuple(parts) + def _get_type_from_schema( + self, schema: dict | str | None, input_id: str = "" + ) -> ParamTypeEnum: + if isinstance(schema, str): + if schema in self.__class__.geojson_schema_references: + return ParamTypeEnum.POLYGON + return self.__class__.input_type_map.get(schema, ParamTypeEnum.STRING) + + if not isinstance(schema, dict): + return ParamTypeEnum.STRING + + schema_type = str(schema.get("type")) + schema_format = schema.get("format") + schema_subtype = schema.get("subtype") + + if schema_type == "array" and schema_subtype == "date-interval": + return ParamTypeEnum.DATE_INTERVAL + if schema_type == "array" and schema.get("items", {}).get("type") == "string": + return ParamTypeEnum.ARRAY_STRING + if schema_subtype == "geojson": + return ParamTypeEnum.POLYGON + if schema_subtype == "bounding-box": + return ParamTypeEnum.BOUNDING_BOX + if schema_format == "geojson": + return ParamTypeEnum.POLYGON + if schema_format == "date-time": + return ParamTypeEnum.DATETIME + if schema_type == "object": + required = schema.get("required") or [] + if "type" in required and "coordinates" in required: + type_properties = schema.get("properties", {}).get("type", {}) + type_instance = type_properties + while "actual_instance" in type_instance: + type_instance = type_instance["actual_instance"] + if "Polygon" in type_instance.get("enum", []): + return ParamTypeEnum.POLYGON + if "Point" in type_instance.get("enum", []): + return ParamTypeEnum.POINT + elif "bbox" in required: + return ParamTypeEnum.BOUNDING_BOX + + if isinstance(schema.get("$ref"), str): + return self._get_type_from_schema(schema.get("$ref"), input_id) + + for variant_key in ("oneOf", "anyOf", "allOf"): + variants = schema.get(variant_key) or [] + if not isinstance(variants, list): + continue + for variant in variants: + detected_type = self._get_type_from_schema(variant, input_id) + if detected_type != ParamTypeEnum.STRING: + return detected_type + + properties = schema.get("properties") or {} + if ( + schema.get("title") == "GeoJSON" + or "geometry" in properties + or "features" in properties + or input_id.lower() in {"aoi", "geometry", "geom", "geojson"} + ): + return ParamTypeEnum.POLYGON + + return self.__class__.input_type_map.get(schema_type, ParamTypeEnum.STRING) + + def _get_options_from_schema(self, schema: dict | str | None) -> list: + if not isinstance(schema, dict): + return [] + options = schema.get("enum") + return options if isinstance(options, list) else [] + async def _create_api_client_instance( self, endpoint: str, @@ -91,6 +170,8 @@ async def execute_job( ) -> str: logger.info(f"Executing OGC API job with title={title}") + parameters = await self._transform_parameters(user_token, details, parameters) + # Exchanging token logger.debug("Exchanging user token for OGC API Process execution...") exchanged_token = await exchange_token( @@ -112,7 +193,22 @@ async def execute_job( if exchanged_token: headers["Authorization"] = f"Bearer {exchanged_token}" - data = {"inputs": {key: value for key, value in parameters.items()}} + user_claims = get_current_user_claims(user_token) + properties = { + "title": title, + "application": details.application, + } + if user_claims.get("sub"): + properties["user_id"] = user_claims["sub"] + if user_claims.get("preferred_username"): + properties["username"] = user_claims["preferred_username"] + if user_claims.get("email"): + properties["email"] = user_claims["email"] + + data = { + "inputs": parameters, + "properties": properties, + } content = api_client.execute_simple( process_id=details.application, execute=data, _headers=headers @@ -125,6 +221,49 @@ async def execute_job( return f"{details.namespace}:{job_id}" return job_id + def _transform_bbox_parameter(self, param_name: str, value) -> list[float]: + if isinstance(value, (list, tuple)) and len(value) == 4: + return [float(coord) for coord in value] + + if isinstance(value, dict): + if ["east", "north", "south", "west"] == sorted(value.keys()): + return [ + float(value["west"]), + float(value["south"]), + float(value["east"]), + float(value["north"]), + ] + + raise ValueError( + f"Unsupported bounding box value for parameter {param_name}: {value}" + ) + + async def _transform_parameters( + self, user_token: str, details: ServiceDetails, parameters: dict + ) -> dict: + service_params = await self.get_service_parameters(user_token, details) + + modifiers = { + ParamTypeEnum.BOUNDING_BOX: self._transform_bbox_parameter, + } + + transformed_parameters = parameters.copy() + for param in service_params: + if param.name not in parameters: + continue + + modifier = modifiers.get(param.type) + + if modifier: + transformed_parameters[param.name] = modifier( + param.name, parameters[param.name] + ) + + logger.debug( + f"Transformed parameters for OGC API Process: {transformed_parameters}" + ) + return transformed_parameters + async def execute_synchronous_job( self, user_token: str, @@ -303,53 +442,10 @@ async def get_service_parameters( if process_description.inputs: for input_id, input_details in process_description.inputs.items(): - input_type = ( - input_id, - input_details.model_dump() - .get("var_schema", {}) - .get("actual_instance", {}) - .get("type", ""), - ) - if isinstance(input_type, tuple): - input_type_str = next( - ( - t - for t in input_type - if t - in [ - "date-interval", - "bounding-box", - "boolean", - "integer", - "double", - ] - ), - None, - ) - else: - input_type_str = None - - if input_type_str: - input_type_str = self.__class__.input_type_map.get(input_type_str) - - if not input_type_str: - input_type_str = ParamTypeEnum.STRING - input_types = ( - input_details.model_dump() - .get("var_schema", {}) - .get("actual_instance", {}) - .get("required") - or [] - ) - if "bbox" in input_types: - input_type_str = ParamTypeEnum.BOUNDING_BOX - - input_options = ( + schema = ( input_details.model_dump() .get("var_schema", {}) - .get("actual_instance", {}) - .get("enum") - or [] + .get("actual_instance") ) parameters.append( Parameter( @@ -359,8 +455,8 @@ async def get_service_parameters( else f"Parameter: {input_id}", default=None, optional=(input_details.min_occurs == 0), - type=input_type_str, - options=input_options, + type=self._get_type_from_schema(schema, input_id), + options=self._get_options_from_schema(schema), ) ) diff --git a/app/schemas/parameters.py b/app/schemas/parameters.py index 25ce9f4..90e8b2d 100644 --- a/app/schemas/parameters.py +++ b/app/schemas/parameters.py @@ -11,6 +11,7 @@ class ParamTypeEnum(str, Enum): DATE_INTERVAL = "date-interval" BOUNDING_BOX = "bounding-box" POLYGON = "polygon" + POINT = "point" BOOLEAN = "boolean" INTEGER = "integer" DOUBLE = "double" diff --git a/tests/platforms/test_ogc_api_process_platform.py b/tests/platforms/test_ogc_api_process_platform.py new file mode 100644 index 0000000..75ba9fc --- /dev/null +++ b/tests/platforms/test_ogc_api_process_platform.py @@ -0,0 +1,894 @@ +import sys +import types +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +def _install_ogc_api_processes_client_stub(): + root_module = types.ModuleType("ogc_api_processes_client") + api_client_wrapper_module = types.ModuleType( + "ogc_api_processes_client.api_client_wrapper" + ) + configuration_module = types.ModuleType( + "ogc_api_processes_client.configuration" + ) + models_module = types.ModuleType("ogc_api_processes_client.models") + inline_or_ref_data_module = types.ModuleType( + "ogc_api_processes_client.models.inline_or_ref_data" + ) + input_value_no_object_module = types.ModuleType( + "ogc_api_processes_client.models.input_value_no_object" + ) + link_module = types.ModuleType("ogc_api_processes_client.models.link") + qualified_input_value_module = types.ModuleType( + "ogc_api_processes_client.models.qualified_input_value" + ) + status_code_module = types.ModuleType("ogc_api_processes_client.models.status_code") + status_info_module = types.ModuleType("ogc_api_processes_client.models.status_info") + + class ApiClientWrapper: + pass + + class Configuration: + def __init__(self, host): + self.host = host + + class InlineOrRefData: + pass + + class InputValueNoObject: + pass + + class Link: + pass + + class QualifiedInputValue: + pass + + class StatusCode: + ACCEPTED = "accepted" + RUNNING = "running" + DISMISSED = "dismissed" + SUCCESSFUL = "successful" + FAILED = "failed" + + class StatusInfo: + pass + + api_client_wrapper_module.ApiClientWrapper = ApiClientWrapper + configuration_module.Configuration = Configuration + inline_or_ref_data_module.InlineOrRefData = InlineOrRefData + input_value_no_object_module.InputValueNoObject = InputValueNoObject + link_module.Link = Link + qualified_input_value_module.QualifiedInputValue = QualifiedInputValue + status_code_module.StatusCode = StatusCode + status_info_module.StatusInfo = StatusInfo + + sys.modules["ogc_api_processes_client"] = root_module + sys.modules["ogc_api_processes_client.api_client_wrapper"] = ( + api_client_wrapper_module + ) + sys.modules["ogc_api_processes_client.configuration"] = configuration_module + sys.modules["ogc_api_processes_client.models"] = models_module + sys.modules["ogc_api_processes_client.models.inline_or_ref_data"] = ( + inline_or_ref_data_module + ) + sys.modules["ogc_api_processes_client.models.input_value_no_object"] = ( + input_value_no_object_module + ) + sys.modules["ogc_api_processes_client.models.link"] = link_module + sys.modules["ogc_api_processes_client.models.qualified_input_value"] = ( + qualified_input_value_module + ) + sys.modules["ogc_api_processes_client.models.status_code"] = status_code_module + sys.modules["ogc_api_processes_client.models.status_info"] = status_info_module + + +try: + import ogc_api_processes_client # noqa: F401 +except ModuleNotFoundError: + _install_ogc_api_processes_client_stub() + +from app.platforms.implementations.ogc_api_process import ( # noqa: E402 + GEOJSON_FEATURECOLLECTION_SCHEMA, + OGCAPIProcessPlatform, + STAC_COLLECTION_SCHEMA, +) +from app.schemas.enum import ( # noqa: E402 + OutputFormatEnum, + ProcessingStatusEnum, +) +from app.schemas.parameters import ParamTypeEnum, Parameter # noqa: E402 +from app.schemas.unit_job import ServiceDetails # noqa: E402 + + +@pytest.fixture +def platform(): + return OGCAPIProcessPlatform() + + +def build_input(description, min_occurs, schema): + return SimpleNamespace( + description=description, + min_occurs=min_occurs, + model_dump=lambda: { + "var_schema": { + "actual_instance": schema, + } + }, + ) + + +def build_collection_payload(collection_id="collection-1", title="Test Collection"): + return { + "id": collection_id, + "stac_version": "1.0.0", + "title": title, + "description": "Test STAC collection", + "type": "Collection", + "license": "proprietary", + "links": [], + "extent": { + "spatial": {"bbox": [[-180.0, -90.0, 180.0, 90.0]]}, + "temporal": {"interval": [[None, None]]}, + }, + } + + +@pytest.mark.parametrize( + ("schema", "input_id", "expected_type"), + [ + ({"format": "date-time"}, "datetime", ParamTypeEnum.DATETIME), + ( + {"type": "array", "subtype": "date-interval"}, + "temporal_extent", + ParamTypeEnum.DATE_INTERVAL, + ), + ( + {"type": "object", "subtype": "bounding-box"}, + "spatial_extent", + ParamTypeEnum.BOUNDING_BOX, + ), + (GEOJSON_FEATURECOLLECTION_SCHEMA, "area", ParamTypeEnum.POLYGON), + ({"format": "geojson"}, "area", ParamTypeEnum.POLYGON), + ({"type": "object", "subtype": "geojson"}, "area", ParamTypeEnum.POLYGON), + ( + {"type": "object", "properties": {"features": {"type": "array"}}}, + "area", + ParamTypeEnum.POLYGON, + ), + ({"type": "boolean"}, "enabled", ParamTypeEnum.BOOLEAN), + ({"type": "integer"}, "limit", ParamTypeEnum.INTEGER), + ({"type": "double"}, "scale", ParamTypeEnum.DOUBLE), + ( + {"type": "object", "required": ["coordinates", "type", "bbox"], + "properties": { + "type": { + "actual_instance": { + "actual_instance": { + "enum": ["Polygon"] + } + } + } + } + }, + "aoi", + ParamTypeEnum.POLYGON, + ), + ( + {"type": "object", "required": ["coordinates", "type", "bbox"], + "properties": { + "type": { + "actual_instance": { + "actual_instance": { + "enum": ["Point"] + } + } + } + } + }, + "poi", + ParamTypeEnum.POINT, + ), + ( + {"type": "object", "required": ["bbox", "crs"]}, + "bbox", + ParamTypeEnum.BOUNDING_BOX, + ), + ( + {"type": "array", "items": {"type": "string"}}, + "bands", + ParamTypeEnum.ARRAY_STRING, + ), + ({"type": "number"}, "threshold", ParamTypeEnum.DOUBLE), + ({"type": "string"}, "mode", ParamTypeEnum.STRING), + ( + {"oneOf": [{"type": "string"}, {"format": "geojson"}]}, + "geometry", + ParamTypeEnum.POLYGON, + ), + ], +) +def test_get_type_from_schema(platform, schema, input_id, expected_type): + assert platform._get_type_from_schema(schema, input_id) == expected_type + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + ( + {"west": 4, "south": 50, "east": 5, "north": 51}, + [4.0, 50.0, 5.0, 51.0], + ), + ( + [4, 50, 5, 51], + [4.0, 50.0, 5.0, 51.0], + ), + ], +) +def test_transform_bbox_parameter(platform, value, expected): + assert platform._transform_bbox_parameter("bbox", value) == expected + + +@pytest.mark.parametrize( + "value", + [ + "invalid", + { + "type": "Polygon", + "coordinates": [ + [[4, 50], [5, 50], [5, 51], [4, 51], [4, 50]] + ], + }, + { + "type": "Feature", + "geometry": { + "type": "Polygon", + "coordinates": [ + [[4, 50], [5, 50], [5, 51], [4, 51], [4, 50]] + ], + }, + }, + ], +) +def test_transform_bbox_parameter_invalid_value(platform, value): + with pytest.raises(ValueError, match="Unsupported bounding box value"): + platform._transform_bbox_parameter("bbox", value) + + +@pytest.mark.parametrize( + ("job_id", "expected"), + [ + ("namespace:job-123", ("namespace", "job-123")), + ("job-123", ("", "job-123")), + ], +) +def test_split_job_id(platform, job_id, expected): + assert platform._split_job_id(job_id) == expected + + +@pytest.mark.parametrize( + ("ogc_status, expected_status"), + [ + ("accepted", ProcessingStatusEnum.CREATED), + ("running", ProcessingStatusEnum.RUNNING), + ("dismissed", ProcessingStatusEnum.CANCELED), + ("successful", ProcessingStatusEnum.FINISHED), + ("failed", ProcessingStatusEnum.FAILED), + ("unknown", ProcessingStatusEnum.UNKNOWN), + (None, ProcessingStatusEnum.UNKNOWN), + ], +) +def test_map_ogcapi_status(platform, ogc_status, expected_status): + assert platform._map_ogcapi_status(ogc_status) == expected_status + + +@pytest.mark.asyncio +@patch.object(OGCAPIProcessPlatform, "get_service_parameters", new_callable=AsyncMock) +async def test_transform_parameters_applies_bbox_modifier( + mock_get_service_parameters, platform +): + mock_get_service_parameters.return_value = [ + Parameter( + name="bbox", + description="Spatial extent", + type=ParamTypeEnum.BOUNDING_BOX, + optional=False, + ), + Parameter( + name="mode", + description="Execution mode", + type=ParamTypeEnum.STRING, + optional=False, + ), + ] + + result = await platform._transform_parameters( + "token", + ServiceDetails(endpoint="https://example.com", application="buffer"), + { + "bbox": {"west": 4, "south": 50, "east": 5, "north": 51}, + "mode": "fast", + }, + ) + + assert result == { + "bbox": [4.0, 50.0, 5.0, 51.0], + "mode": "fast", + } + + +@pytest.mark.asyncio +@patch("app.platforms.implementations.ogc_api_process.ApiClientWrapper") +async def test_create_api_client_instance_without_namespace(mock_api_client, platform): + await platform._create_api_client_instance("https://example.com", "", None) + + configuration = mock_api_client.call_args.args[0] + assert configuration.host == "https://example.com" + assert mock_api_client.call_args.kwargs == {} + + +@pytest.mark.asyncio +@patch("app.platforms.implementations.ogc_api_process.ApiClientWrapper") +async def test_create_api_client_instance_with_token_and_namespace( + mock_api_client, platform +): + await platform._create_api_client_instance( + "https://example.com", "ns", "exchanged-token" + ) + + configuration = mock_api_client.call_args.args[0] + assert configuration.host == "https://example.com/ns" + assert mock_api_client.call_args.kwargs == { + "header_name": "Authorization", + "header_value": "Bearer exchanged-token", + } + + +@pytest.mark.asyncio +@patch( + "app.platforms.implementations.ogc_api_process.exchange_token", + new_callable=AsyncMock, +) +@patch("app.platforms.implementations.ogc_api_process.get_current_user_claims") +@patch.object(OGCAPIProcessPlatform, "_transform_parameters", new_callable=AsyncMock) +@patch.object(OGCAPIProcessPlatform, "_create_api_client_instance", new_callable=AsyncMock) +async def test_execute_job_returns_namespaced_job_id( + mock_create_api_client, + mock_transform_parameters, + mock_get_current_user_claims, + mock_exchange_token, + platform, +): + mock_transform_parameters.return_value = {"bbox": [4.0, 50.0, 5.0, 51.0]} + mock_exchange_token.return_value = "exchanged-token" + mock_get_current_user_claims.return_value = { + "sub": "user-123", + "preferred_username": "alice", + "email": "alice@example.com", + } + api_client = MagicMock() + api_client.execute_simple.return_value = SimpleNamespace(job_id="job-123") + mock_create_api_client.return_value = api_client + + result = await platform.execute_job( + user_token="token", + title="My job", + details=ServiceDetails( + endpoint="https://example.com", + namespace="ns", + application="buffer", + ), + parameters={"geometry": {"type": "Polygon"}}, + format=OutputFormatEnum.GEOTIFF, + ) + + assert result == "ns:job-123" + api_client.execute_simple.assert_called_once_with( + process_id="buffer", + execute={ + "inputs": {"bbox": [4.0, 50.0, 5.0, 51.0]}, + "properties": { + "title": "My job", + "application": "buffer", + "user_id": "user-123", + "username": "alice", + "email": "alice@example.com", + }, + }, + _headers={ + "accept": "*/*", + "Content-Type": "application/json", + "Authorization": "Bearer exchanged-token", + }, + ) + + +@pytest.mark.asyncio +@patch( + "app.platforms.implementations.ogc_api_process.exchange_token", + new_callable=AsyncMock, +) +@patch("app.platforms.implementations.ogc_api_process.get_current_user_claims") +@patch.object(OGCAPIProcessPlatform, "_transform_parameters", new_callable=AsyncMock) +@patch.object(OGCAPIProcessPlatform, "_create_api_client_instance", new_callable=AsyncMock) +async def test_execute_job_returns_plain_job_id_without_namespace( + mock_create_api_client, + mock_transform_parameters, + mock_get_current_user_claims, + mock_exchange_token, + platform, +): + mock_transform_parameters.return_value = {"limit": 10} + mock_exchange_token.return_value = None + mock_get_current_user_claims.return_value = {"sub": "user-123"} + api_client = MagicMock() + api_client.execute_simple.return_value = SimpleNamespace(job_id="job-123") + mock_create_api_client.return_value = api_client + + result = await platform.execute_job( + user_token="token", + title="My job", + details=ServiceDetails( + endpoint="https://example.com", + application="buffer", + ), + parameters={"limit": 10}, + format=OutputFormatEnum.GEOTIFF, + ) + + assert result == "job-123" + api_client.execute_simple.assert_called_once_with( + process_id="buffer", + execute={ + "inputs": {"limit": 10}, + "properties": { + "title": "My job", + "application": "buffer", + "user_id": "user-123", + }, + }, + _headers={ + "accept": "*/*", + "Content-Type": "application/json", + }, + ) + + +@pytest.mark.asyncio +@patch( + "app.platforms.implementations.ogc_api_process.exchange_token", + new_callable=AsyncMock, +) +@patch("app.platforms.implementations.ogc_api_process.get_current_user_claims") +@patch.object(OGCAPIProcessPlatform, "_transform_parameters", new_callable=AsyncMock) +@patch.object(OGCAPIProcessPlatform, "_create_api_client_instance", new_callable=AsyncMock) +async def test_execute_job_omits_missing_optional_user_fields( + mock_create_api_client, + mock_transform_parameters, + mock_get_current_user_claims, + mock_exchange_token, + platform, +): + mock_transform_parameters.return_value = {"limit": 10} + mock_exchange_token.return_value = "exchanged-token" + mock_get_current_user_claims.return_value = {} + api_client = MagicMock() + api_client.execute_simple.return_value = SimpleNamespace(job_id="job-123") + mock_create_api_client.return_value = api_client + + await platform.execute_job( + user_token="token", + title="My job", + details=ServiceDetails( + endpoint="https://example.com", + application="buffer", + ), + parameters={"limit": 10}, + format=OutputFormatEnum.GEOTIFF, + ) + + api_client.execute_simple.assert_called_once_with( + process_id="buffer", + execute={ + "inputs": {"limit": 10}, + "properties": {"title": "My job", "application": "buffer"}, + }, + _headers={ + "accept": "*/*", + "Content-Type": "application/json", + "Authorization": "Bearer exchanged-token", + }, + ) + + +@pytest.mark.asyncio +async def test_execute_synchronous_job_not_implemented(platform): + with pytest.raises(NotImplementedError, match="not implemented yet"): + await platform.execute_synchronous_job( + user_token="token", + title="My job", + details=ServiceDetails( + endpoint="https://example.com", + application="buffer", + ), + parameters={}, + format=OutputFormatEnum.GEOTIFF, + ) + + +@pytest.mark.asyncio +@patch( + "app.platforms.implementations.ogc_api_process.exchange_token", + new_callable=AsyncMock, +) +@patch.object(OGCAPIProcessPlatform, "_create_api_client_instance", new_callable=AsyncMock) +async def test_get_job_status_maps_client_status( + mock_create_api_client, mock_exchange_token, platform +): + mock_exchange_token.return_value = "exchanged-token" + api_client = MagicMock() + api_client.get_status.return_value = SimpleNamespace(status="running") + mock_create_api_client.return_value = api_client + + result = await platform.get_job_status( + user_token="token", + job_id="ns:job-123", + details=ServiceDetails( + endpoint="https://example.com", + namespace="ignored-by-job-id", + application="buffer", + ), + ) + + assert result == ProcessingStatusEnum.RUNNING + mock_create_api_client.assert_awaited_once_with( + "https://example.com", "ns", "exchanged-token" + ) + api_client.get_status.assert_called_once_with(job_id="job-123") + + +@pytest.mark.asyncio +@patch( + "app.platforms.implementations.ogc_api_process.exchange_token", + new_callable=AsyncMock, +) +@patch.object(OGCAPIProcessPlatform, "_create_api_client_instance", new_callable=AsyncMock) +async def test_get_job_results_returns_stac_collection( + mock_create_api_client, mock_exchange_token, platform +): + mock_exchange_token.return_value = "exchanged-token" + api_client = MagicMock() + api_client.get_result.return_value = { + "result": SimpleNamespace( + actual_instance=SimpleNamespace( + var_schema=SimpleNamespace(actual_instance=STAC_COLLECTION_SCHEMA), + value=SimpleNamespace( + actual_instance=build_collection_payload("collection-from-stac") + ), + ) + ) + } + mock_create_api_client.return_value = api_client + + result = await platform.get_job_results( + user_token="token", + job_id="ns:job-123", + details=ServiceDetails( + endpoint="https://example.com", + namespace="ns", + application="buffer", + ), + ) + + assert result.id == "collection-from-stac" + assert result.title == "Test Collection" + + +@pytest.mark.asyncio +@patch("app.platforms.implementations.ogc_api_process.http_get") +@patch( + "app.platforms.implementations.ogc_api_process.exchange_token", + new_callable=AsyncMock, +) +@patch.object(OGCAPIProcessPlatform, "_create_api_client_instance", new_callable=AsyncMock) +async def test_get_job_results_follows_geojson_collection_link( + mock_create_api_client, mock_exchange_token, mock_http_get, platform +): + mock_exchange_token.return_value = "exchanged-token" + api_client = MagicMock() + api_client.get_result.return_value = { + "result": SimpleNamespace( + actual_instance=SimpleNamespace( + var_schema=SimpleNamespace( + actual_instance=GEOJSON_FEATURECOLLECTION_SCHEMA + ), + value=SimpleNamespace( + oneof_schema_2_validator={ + "features": [ + { + "links": [ + { + "rel": "collection", + "href": "https://example.com/collections/1", + } + ] + } + ] + } + ), + ) + ) + } + mock_create_api_client.return_value = api_client + mock_http_get.return_value = MagicMock() + mock_http_get.return_value.json.return_value = build_collection_payload( + "collection-from-geojson" + ) + + result = await platform.get_job_results( + user_token="token", + job_id="ns:job-123", + details=ServiceDetails( + endpoint="https://example.com", + namespace="ns", + application="buffer", + ), + ) + + assert result.id == "collection-from-geojson" + mock_http_get.assert_called_once_with( + "https://example.com/collections/1", + follow_redirects=True, + headers={"Authorization": "Bearer exchanged-token"}, + ) + mock_http_get.return_value.raise_for_status.assert_called_once_with() + + +@pytest.mark.asyncio +@patch( + "app.platforms.implementations.ogc_api_process.exchange_token", + new_callable=AsyncMock, +) +@patch.object(OGCAPIProcessPlatform, "_create_api_client_instance", new_callable=AsyncMock) +async def test_get_job_results_returns_empty_collection_when_no_supported_result( + mock_create_api_client, mock_exchange_token, platform +): + mock_exchange_token.return_value = "exchanged-token" + api_client = MagicMock() + api_client.get_result.return_value = { + "result": SimpleNamespace(actual_instance=None) + } + mock_create_api_client.return_value = api_client + + result = await platform.get_job_results( + user_token="token", + job_id="ns:job-123", + details=ServiceDetails( + endpoint="https://example.com", + namespace="ns", + application="buffer", + ), + ) + + assert result.id == "ns-job-123" + assert result.title == "Results for buffer" + assert result.license == "proprietary" + + +@pytest.mark.asyncio +@patch( + "app.platforms.implementations.ogc_api_process.exchange_token", + new_callable=AsyncMock, +) +@patch.object(OGCAPIProcessPlatform, "_create_api_client_instance", new_callable=AsyncMock) +async def test_get_service_parameters_maps_geojson_and_options( + mock_create_api_client, mock_exchange_token, platform +): + mock_exchange_token.return_value = "exchanged-token" + + geojson_input = build_input( + "Area of interest", + 1, + { + "oneOf": [ + {"type": "string"}, + {"format": "geojson"}, + ] + }, + ) + enum_input = build_input( + "Output mode", + 0, + { + "type": "string", + "enum": ["fast", "accurate"], + }, + ) + + api_client = MagicMock() + api_client.get_process_description.return_value = SimpleNamespace( + inputs={ + "geometry": geojson_input, + "mode": enum_input, + } + ) + mock_create_api_client.return_value = api_client + + result = await platform.get_service_parameters( + user_token="token", + details=ServiceDetails( + endpoint="https://example.com", + namespace="ns", + application="my-process", + ), + ) + + assert result == [ + Parameter( + name="geometry", + description="Area of interest", + default=None, + optional=False, + type=ParamTypeEnum.POLYGON, + options=[], + ), + Parameter( + name="mode", + description="Output mode", + default=None, + optional=True, + type=ParamTypeEnum.STRING, + options=["fast", "accurate"], + ), + ] + + +@pytest.mark.asyncio +@patch( + "app.platforms.implementations.ogc_api_process.exchange_token", + new_callable=AsyncMock, +) +@patch.object(OGCAPIProcessPlatform, "_create_api_client_instance", new_callable=AsyncMock) +async def test_get_service_parameters_maps_all_supported_types( + mock_create_api_client, mock_exchange_token, platform +): + mock_exchange_token.return_value = "exchanged-token" + + api_client = MagicMock() + api_client.get_process_description.return_value = SimpleNamespace( + inputs={ + "acquired_at": build_input( + "Acquisition datetime", + 1, + {"format": "date-time"}, + ), + "temporal_extent": build_input( + "Temporal range", + 0, + {"type": "array", "subtype": "date-interval"}, + ), + "bbox": build_input( + "Spatial extent", + 1, + {"type": "object", "subtype": "bounding-box"}, + ), + "geometry": build_input( + "Area of interest", + 0, + {"format": "geojson"}, + ), + "enabled": build_input( + "Boolean flag", + 1, + {"type": "boolean"}, + ), + "limit": build_input( + "Maximum number of items", + 0, + {"type": "integer"}, + ), + "threshold": build_input( + "Threshold value", + 1, + {"type": "number"}, + ), + "mode": build_input( + "Execution mode", + 0, + {"type": "string", "enum": ["fast", "accurate"]}, + ), + "bands": build_input( + "Band list", + 1, + {"type": "array", "items": {"type": "string"}}, + ), + } + ) + mock_create_api_client.return_value = api_client + + result = await platform.get_service_parameters( + user_token="token", + details=ServiceDetails( + endpoint="https://example.com", + namespace="ns", + application="my-process", + ), + ) + + assert result == [ + Parameter( + name="acquired_at", + description="Acquisition datetime", + default=None, + optional=False, + type=ParamTypeEnum.DATETIME, + options=[], + ), + Parameter( + name="temporal_extent", + description="Temporal range", + default=None, + optional=True, + type=ParamTypeEnum.DATE_INTERVAL, + options=[], + ), + Parameter( + name="bbox", + description="Spatial extent", + default=None, + optional=False, + type=ParamTypeEnum.BOUNDING_BOX, + options=[], + ), + Parameter( + name="geometry", + description="Area of interest", + default=None, + optional=True, + type=ParamTypeEnum.POLYGON, + options=[], + ), + Parameter( + name="enabled", + description="Boolean flag", + default=None, + optional=False, + type=ParamTypeEnum.BOOLEAN, + options=[], + ), + Parameter( + name="limit", + description="Maximum number of items", + default=None, + optional=True, + type=ParamTypeEnum.INTEGER, + options=[], + ), + Parameter( + name="threshold", + description="Threshold value", + default=None, + optional=False, + type=ParamTypeEnum.DOUBLE, + options=[], + ), + Parameter( + name="mode", + description="Execution mode", + default=None, + optional=True, + type=ParamTypeEnum.STRING, + options=["fast", "accurate"], + ), + Parameter( + name="bands", + description="Band list", + default=None, + optional=False, + type=ParamTypeEnum.ARRAY_STRING, + options=[], + ), + ] diff --git a/tests/test_auth.py b/tests/test_auth.py index 881fe3b..9e2cbad 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -3,13 +3,47 @@ import httpx from fastapi import status -from app.auth import exchange_token, _exchange_token_for_provider +from app.auth import ( + exchange_token, + _exchange_token_for_provider, + get_current_user_claims, + get_current_user_id, +) from app.config.settings import settings from app.config.schemas import BackendAuthConfig, AuthMethod from app.error import AuthException # Tests for exchange_token function +@patch("app.auth._decode_token") +def test_get_current_user_claims(mock_decode_token): + mock_decode_token.return_value = { + "sub": "user-123", + "preferred_username": "alice", + } + + result = get_current_user_claims("token") + + assert result == { + "sub": "user-123", + "preferred_username": "alice", + } + mock_decode_token.assert_called_once_with("token") + + +@patch("app.auth._decode_token") +def test_get_current_user_id(mock_decode_token): + mock_decode_token.return_value = { + "sub": "user-123", + "preferred_username": "alice", + } + + result = get_current_user_id("token") + + assert result == "user-123" + mock_decode_token.assert_called_once_with("token") + + @pytest.mark.asyncio async def test_exchange_token_missing_provider():