diff --git a/src/config.py b/src/config.py index ffc4fb89..0d411161 100644 --- a/src/config.py +++ b/src/config.py @@ -1,11 +1,31 @@ -import functools +"""Configuration logic and schema definitions. + +The `get_config` function provides access to the most recently loaded configuration. +A default configuration is loaded if none is explicitly set. + +Use `set_config` to use a different configuration. +To parse a configuration from a file and environment variables, use `parse_config`. +Example of loading a configuration with a custom TOML and .env file: + +``` +config = parse_config( + dotenv_file=Path("path/to/.env"), + configuration_file=Path("path/to/config.toml") +) +set_config(config) +``` +and then subsequent calls to `get_config` will return that configuration. +""" + import os import tomllib import typing from pathlib import Path +from typing import Literal, cast from dotenv import load_dotenv from loguru import logger +from pydantic import AnyUrl, BaseModel, Field TomlTable = dict[str, typing.Any] @@ -13,72 +33,149 @@ CONFIG_FILE_ENV = "OPENML_REST_API_CONFIG_FILE" DOTENV_FILE_ENV = "OPENML_REST_API_DOTENV_FILE" -OPENML_DB_USERNAME_ENV = "OPENML_DATABASES_OPENML_USERNAME" -OPENML_DB_PASSWORD_ENV = "OPENML_DATABASES_OPENML_PASSWORD" # noqa: S105 # not a password -EXPDB_DB_USERNAME_ENV = "OPENML_DATABASES_EXPDB_USERNAME" -EXPDB_DB_PASSWORD_ENV = "OPENML_DATABASES_EXPDB_PASSWORD" # noqa: S105 # not a password -_config_directory = Path(os.getenv(CONFIG_DIRECTORY_ENV, Path(__file__).parent)) -_config_directory = _config_directory.expanduser().absolute() -_config_file = Path(os.getenv(CONFIG_FILE_ENV, _config_directory / "config.toml")) -_config_file = _config_file.expanduser().absolute() -_dotenv_file = Path(os.getenv(DOTENV_FILE_ENV, _config_directory / ".env")) -_dotenv_file = _dotenv_file.expanduser().absolute() +_config: Configuration | None = None -logger.info( - "Determined configuration sources.", - configuration_directory=_config_directory, - 
configuration_file=_config_file, - dotenv_file=_dotenv_file, -) +# The reason we use a module variable instead of functools.cache +# is that this method allows a custom configuration to be set +# through `set_config` and subsequently loaded through `get_config`. +def get_config() -> Configuration: + if _config is None: + config = parse_config() + set_config(config) + return cast("Configuration", _config) -load_dotenv(dotenv_path=_dotenv_file) +def set_config(configuration: Configuration) -> None: + global _config # noqa: PLW0603 + _config = configuration -def _apply_defaults_to_siblings(configuration: TomlTable) -> TomlTable: - defaults = configuration["defaults"] - return { - subtable: (defaults | overrides) if isinstance(overrides, dict) else overrides - for subtable, overrides in configuration.items() - if subtable != "defaults" - } +class Configuration(BaseModel, frozen=True): + openml_database: DatabaseConfiguration + expdb_database: DatabaseConfiguration + development: DevelopmentConfiguration + routing: RoutingConfiguration + logging: list[LoggingConfiguration] -@functools.cache -def _load_configuration(file: Path) -> TomlTable: - return tomllib.loads(file.read_text()) +class DatabaseConfiguration(BaseModel, frozen=True): + """Settings for one database connection.""" -def load_routing_configuration(file: Path = _config_file) -> TomlTable: - return typing.cast("TomlTable", _load_configuration(file)["routing"]) + host: str = Field(default="database", description="Database server host name") + port: int = Field(default=3306, gt=0) + database: str = Field(description="Database name") + username: str = Field(default="root") + password: str = Field(default="ok") + echo: bool = Field( + default=False, + description="https://docs.sqlalchemy.org/en/20/core/engines.html#sqlalchemy.create_engine.params.echo", + ) + drivername: str = Field( + default="mysql+aiomysql", + description="SQLAlchemy `dialect` and `driver`: 
https://docs.sqlalchemy.org/en/20/dialects/index.html", + ) + + +class DevelopmentConfiguration(BaseModel, frozen=True): + """Settings for development or test specific features.""" + allow_test_api_keys: bool = Field(default=False) -@functools.cache -def load_database_configuration(file: Path = _config_file) -> TomlTable: - configuration = _load_configuration(file) - database_configuration = _apply_defaults_to_siblings( - configuration["databases"], + +class RoutingConfiguration(BaseModel, frozen=True): + root_path: str = Field(default="", description="Path prefix under which the service is hosted.") + minio_url: AnyUrl = Field(description="URL to the MinIO server or service") + server_url: AnyUrl = Field( + description="URL to this server (excluding the path prefix of `fastapi.root_path`).", ) - database_configuration["openml"]["username"] = os.environ.get( - OPENML_DB_USERNAME_ENV, - "root", + + +class LoggingConfiguration(BaseModel, frozen=True): + """Configuration for a single log sink. + + You can add any arguments that `loguru.logger.add` allows, + the `sink` will be used as first positional argument. + See also: https://loguru.readthedocs.io/en/stable/api/logger.html + """ + + sink: str + level: Literal["TRACE", "DEBUG", "INFO", "SUCCESS", "WARNING", "ERROR"] + rotation: str | None = Field( + default=None, + description="Set rotation policy by date or file size.", ) - database_configuration["openml"]["password"] = os.environ.get( - OPENML_DB_PASSWORD_ENV, - "ok", + retention: str | None = Field( + default=None, + description="Timespan after which automatic cleanup occurs.", ) - database_configuration["expdb"]["username"] = os.environ.get( - EXPDB_DB_USERNAME_ENV, - "root", + compression: str | None = Field(default="gz") + # Logs provided variables as JSON + serialize: bool = Field(default=True) + # Decouples log calls from I/O and makes it multiprocessing safe. 
+ enqueue: bool = Field(default=True) + + +def _db_env_credentials(alias: str) -> dict[str, str]: + return { + "username": os.environ.get( + f"OPENML_DATABASES_{alias.upper()}_USERNAME", + "root", + ), + "password": os.environ.get( + f"OPENML_DATABASES_{alias.upper()}_PASSWORD", + "ok", + ), + } + + +def parse_config( + dotenv_file: Path | None = None, + configuration_file: Path | None = None, +) -> Configuration: + """Load configuration from file and environment variables. + + The parsed configuration is returned but not used by default for other calls in this module. + """ + _config_directory = Path(os.getenv(CONFIG_DIRECTORY_ENV, Path(__file__).parent)) + _config_directory = _config_directory.expanduser().absolute() + logger.info( + "Determined configuration directory to be {configuration_directory}.", + configuration_directory=_config_directory, ) - database_configuration["expdb"]["password"] = os.environ.get( - EXPDB_DB_PASSWORD_ENV, - "ok", + + if not dotenv_file: + dotenv_filepath = os.getenv(DOTENV_FILE_ENV, _config_directory / ".env") + dotenv_file = Path(dotenv_filepath).expanduser().absolute() + + logger.info( + "Determined dotenv file path to be {dotenv_file}.", + dotenv_file=dotenv_file, ) - return database_configuration + load_dotenv(dotenv_file) + if not configuration_file: + config_filepath = os.getenv(CONFIG_FILE_ENV, _config_directory / "config.toml") + configuration_file = Path(config_filepath).expanduser().absolute() -def load_configuration(file: Path | None = None) -> TomlTable: - file = file or _config_file - return tomllib.loads(file.read_text()) + logger.info( + "Determined config file path to be {config_file}.", + config_file=configuration_file, + ) + + config = tomllib.loads(configuration_file.read_text()) + db_section = config["databases"] + openml_db = DatabaseConfiguration(**db_section["openml"], **_db_env_credentials("openml")) + expdb_db = DatabaseConfiguration(**db_section["expdb"], **_db_env_credentials("expdb")) + + return 
Configuration( + routing=RoutingConfiguration(**config["routing"]), + logging=[ + LoggingConfiguration(**sink_configuration) + for sink_configuration in config["logging"].values() + ], + openml_database=openml_db, + expdb_database=expdb_db, + development=DevelopmentConfiguration(**config["development"]), + ) diff --git a/src/config.toml b/src/config.toml index 384067d7..8a23dd87 100644 --- a/src/config.toml +++ b/src/config.toml @@ -1,33 +1,17 @@ -arff_base_url="https://test.openml.org" -minio_base_url="https://openml1.win.tue.nl" - [development] allow_test_api_keys=true # Any number of logging.NAME configurations can be added. # NAME is for reference only, it has no meaning otherwise. -# You can add any arguments to `loguru.logger.add`, -# the `sink` variable will be used as first positional argument. -# https://loguru.readthedocs.io/en/stable/api/logger.html [logging.develop] sink="develop.log" # One of loguru levels: TRACE, DEBUG, INFO, SUCCESS, WARNING, ERROR level="DEBUG" -# Automatically create a new file by date or file size rotation="50 MB" # Retention specifies the timespan after which automatic cleanup occurs. 
retention="1 day" compression="gz" -[fastapi] -root_path="" - -[databases.defaults] -host="database" -port="3306" -# SQLAlchemy `dialect` and `driver`: https://docs.sqlalchemy.org/en/20/dialects/index.html -drivername="mysql+aiomysql" - [databases.expdb] database="openml_expdb" @@ -35,5 +19,6 @@ database="openml_expdb" database="openml" [routing] +root_path="" minio_url="http://minio:9000/" server_url="http://php-api:80/" diff --git a/src/core/formatting.py b/src/core/formatting.py index 406659fa..faf6d423 100644 --- a/src/core/formatting.py +++ b/src/core/formatting.py @@ -1,7 +1,7 @@ import html from typing import TYPE_CHECKING -from config import load_routing_configuration +from config import get_config from schemas.datasets.openml import DatasetFileFormat if TYPE_CHECKING: @@ -21,14 +21,14 @@ def _format_parquet_url(dataset: Row) -> str | None: if dataset.format.lower() != DatasetFileFormat.ARFF: return None - minio_base_url = load_routing_configuration()["minio_url"] + minio_base_url = get_config().routing.minio_url ten_thousands_prefix = f"{dataset.did // 10_000:04d}" padded_id = f"{dataset.did:04d}" return f"{minio_base_url}datasets/{ten_thousands_prefix}/{padded_id}/dataset_{dataset.did}.pq" def _format_dataset_url(dataset: Row) -> str: - base_url = load_routing_configuration()["server_url"] + base_url = get_config().routing.server_url filename = f"{html.escape(dataset.name)}.{dataset.format.lower()}" return f"{base_url}data/v1/download/{dataset.file_id}/{filename}" diff --git a/src/core/logging.py b/src/core/logging.py index 6546f714..36979733 100644 --- a/src/core/logging.py +++ b/src/core/logging.py @@ -4,31 +4,31 @@ import time import uuid from collections.abc import Awaitable, Callable -from pathlib import Path from typing import TYPE_CHECKING from loguru import logger -from config import load_configuration +from config import LoggingConfiguration if TYPE_CHECKING: from starlette.requests import Request from starlette.responses import Response -def 
setup_log_sinks(configuration_file: Path | None = None) -> None: +def setup_log_sinks(*configurations: LoggingConfiguration) -> None: """Configure loguru based on app configuration.""" - configuration = load_configuration(configuration_file) - for nickname, sink_configuration in configuration.get("logging", {}).items(): - logger.info("Configuring sink", nickname=nickname, **sink_configuration) - sink = sink_configuration.pop("sink") + for sink_configuration in configurations: + conf = sink_configuration.model_dump() + logger.info("Configuring sink", **conf) + sink = conf.pop("sink") if sink == "sys.stderr": sink = sys.stderr - # Logs the additionally provided data as JSON. - sink_configuration.setdefault("serialize", True) - # Decouples log calls from I/O and makes it multiprocessing safe. - sink_configuration.setdefault("enqueue", True) - logger.add(sink, **sink_configuration) + # defaults may be provided for rotation, retention and compression, + # but they are not valid options for stderr logging. 
+ conf.pop("rotation", None) + conf.pop("retention", None) + conf.pop("compression", None) + logger.add(sink, **conf) async def add_request_context_to_log( diff --git a/src/database/setup.py b/src/database/setup.py index ca877138..f02f379d 100644 --- a/src/database/setup.py +++ b/src/database/setup.py @@ -1,21 +1,24 @@ from sqlalchemy.engine import URL from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine -from config import load_database_configuration +from config import DatabaseConfiguration, get_config _user_engine = None _expdb_engine = None -def _create_engine(database_name: str) -> AsyncEngine: - database_configuration = load_database_configuration() - db_config = dict(database_configuration[database_name]) - echo = db_config.pop("echo", False) - - db_url = URL.create(**db_config) +def _create_engine(db_config: DatabaseConfiguration) -> AsyncEngine: + db_url = URL.create( + drivername=db_config.drivername, + username=db_config.username, + password=db_config.password, + host=db_config.host, + port=db_config.port, + database=db_config.database, + ) return create_async_engine( db_url, - echo=echo, + echo=db_config.echo, pool_recycle=3600, ) @@ -23,14 +26,14 @@ def _create_engine(database_name: str) -> AsyncEngine: def user_database() -> AsyncEngine: global _user_engine # noqa: PLW0603 if _user_engine is None: - _user_engine = _create_engine("openml") + _user_engine = _create_engine(get_config().openml_database) return _user_engine def expdb_database() -> AsyncEngine: global _expdb_engine # noqa: PLW0603 if _expdb_engine is None: - _expdb_engine = _create_engine("expdb") + _expdb_engine = _create_engine(get_config().expdb_database) return _expdb_engine diff --git a/src/database/users.py b/src/database/users.py index 8a812d69..bc2d645b 100644 --- a/src/database/users.py +++ b/src/database/users.py @@ -1,24 +1,38 @@ import dataclasses +import functools +import re from enum import IntEnum from typing import TYPE_CHECKING, Annotated, Self -from pydantic 
import StringConstraints +from pydantic import AfterValidator from sqlalchemy import text -from config import load_configuration +from config import get_config if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncConnection -# If `allow_test_api_keys` is set, the key may also be one of `normaluser`, -# `normaluser2`, or `abc` (admin). -api_key_pattern = r"^[0-9a-fA-F]{32}$" -if load_configuration().get("development", {}).get("allow_test_api_keys"): - api_key_pattern = r"^([0-9a-fA-F]{32}|normaluser|normaluser2|abc)$" + +api_key_pattern = re.compile(r"^[0-9a-fA-F]{32}$") +# The test database currently contains some non-standard API keys +api_key_pattern_with_test = re.compile(r"^([0-9a-fA-F]{32}|normaluser|normaluser2|abc)$") + + +@functools.cache +def is_valid_api_key(key: str) -> str: + """Raise ValueError if key is not valid, return key otherwise.""" + pattern = api_key_pattern + if get_config().development.allow_test_api_keys: + pattern = api_key_pattern_with_test + if not pattern.match(key): + msg = f"API key {key!r} format is not valid." 
+ raise ValueError(msg) + return key + APIKey = Annotated[ str, - StringConstraints(pattern=api_key_pattern), + AfterValidator(is_valid_api_key), ] diff --git a/src/main.py b/src/main.py index 46cd79cf..fa5ea54f 100644 --- a/src/main.py +++ b/src/main.py @@ -3,14 +3,18 @@ import sys from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from pathlib import Path import uvicorn from fastapi import FastAPI from fastapi.exceptions import RequestValidationError from loguru import logger -from config import load_configuration +from config import ( + Configuration, + get_config, + parse_config, + set_config, +) from core.errors import ( ProblemDetailError, problem_detail_exception_handler, @@ -74,15 +78,17 @@ def _parse_args() -> argparse.Namespace: return parser.parse_args() -def create_api(configuration_file: Path | None = None) -> FastAPI: +def create_api(configuration: Configuration | None = None) -> FastAPI: # Default logging configuration so we have logs during setup logger.remove() setup_sink = logger.add(sys.stderr, serialize=True) - setup_log_sinks(configuration_file) + config = configuration or parse_config() + set_config(config) + setup_log_sinks(*get_config().logging) - fastapi_kwargs = load_configuration(configuration_file)["fastapi"] - logger.info("Creating FastAPI App", lifespan=lifespan, **fastapi_kwargs) - app = FastAPI(**fastapi_kwargs, lifespan=lifespan) + root_path = get_config().routing.root_path + logger.info("Creating FastAPI App", lifespan=lifespan, root_path=root_path) + app = FastAPI(lifespan=lifespan, root_path=root_path) logger.info("Setting up middleware and exception handlers.") # Order matters! Each added middleware wraps the previous, creating a stack. 
diff --git a/src/routers/mldcat_ap/dataset.py b/src/routers/mldcat_ap/dataset.py index 2749664f..277fe07e 100644 --- a/src/routers/mldcat_ap/dataset.py +++ b/src/routers/mldcat_ap/dataset.py @@ -5,6 +5,7 @@ """ import asyncio +import functools from typing import TYPE_CHECKING, Annotated from fastapi import APIRouter, Depends, HTTPException @@ -31,10 +32,12 @@ from sqlalchemy.ext.asyncio import AsyncConnection router = APIRouter(prefix="/mldcat_ap", tags=["MLDCAT-AP"]) -_configuration = config.load_configuration() -_server_url = ( - f"{_configuration['arff_base_url']}{_configuration['fastapi']['root_path']}{router.prefix}" -) + + +@functools.cache +def server_url() -> str: + _routing_configuration = config.get_config().routing + return f"{_routing_configuration.server_url}{_routing_configuration.root_path}{router.prefix}" @router.get( @@ -59,6 +62,7 @@ async def get_mldcat_ap_distribution( get_dataset_features(distribution_id, user, expdb), get_qualities(distribution_id, user, expdb), ) + _server_url = server_url() features = [ Feature( id_=f"{_server_url}/feature/{distribution_id}/{feature.index}", @@ -131,6 +135,7 @@ def get_dataservice(service_id: int) -> JsonLDGraph: if service_id != 1: msg = f"Service with id {service_id} not found." 
raise ServiceNotFoundError(msg) + _server_url = server_url() return JsonLDGraph( context="https://semiceu.github.io/MLDCAT-AP/releases/1.0.0/context.jsonld", graph=[ @@ -161,6 +166,7 @@ async def get_distribution_quality( status_code=404, detail=f"Quality '{quality_name}' not found for distribution {distribution_id}.", ) + _server_url = server_url() example_quality = Quality( id_=f"{_server_url}/quality/{quality_name}/{distribution_id}", quality_type=f"{_server_url}/quality/{quality_name}", @@ -192,6 +198,7 @@ async def get_distribution_feature( expdb=expdb, ) feature = features[feature_no] + _server_url = server_url() mldcat_feature = Feature( id_=f"{_server_url}/feature/{distribution_id}/{feature.index}", name=feature.name, diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 6627d79f..411ea5b5 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -8,9 +8,9 @@ from fastapi import APIRouter, Body, Depends from sqlalchemy import bindparam, text -import config import database.datasets import database.tasks +from config import get_config from core.errors import InternalError, NoResultsError, TaskNotFoundError from routers.dependencies import Pagination, expdb_connection from routers.types import ( @@ -165,7 +165,8 @@ async def _fill_json_template( # noqa: C901 # I believe that the operations below are always part of string output, so # we don't need to be careful to avoid losing typedness template = template.replace("[TASK:id]", str(task.task_id)) - server_url = config.load_routing_configuration()["server_url"] + url = get_config().routing.server_url + server_url = f"{url.scheme}://{url.host}:{url.port}/" return template.replace("[CONSTANT:base_url]", server_url) diff --git a/tests/config.test.toml b/tests/config.test.toml deleted file mode 100644 index 6942c904..00000000 --- a/tests/config.test.toml +++ /dev/null @@ -1,28 +0,0 @@ -arff_base_url="https://test.openml.org" -minio_base_url="https://openml1.win.tue.nl" - 
-[development] -allow_test_api_keys=true - -[logging.develop] -sink="sys.stderr" -level="DEBUG" - -[fastapi] -root_path="" - -[databases.defaults] -host="database" -port="3306" -# SQLAlchemy `dialect` and `driver`: https://docs.sqlalchemy.org/en/20/dialects/index.html -drivername="mysql+aiomysql" - -[databases.expdb] -database="openml_expdb" - -[databases.openml] -database="openml" - -[routing] -minio_url="http://minio:9000/" -server_url="http://php-api:80/" diff --git a/tests/config_test.py b/tests/config_test.py index 3218f802..6935a3c3 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -1,36 +1,19 @@ import os -from pathlib import Path from unittest import mock -from config import _apply_defaults_to_siblings, load_database_configuration +from config import _db_env_credentials -def test_apply_defaults_to_siblings_applies_defaults() -> None: - input_ = {"defaults": {1: 1}, "other": {}} - expected = {"other": {1: 1}} - output = _apply_defaults_to_siblings(input_) - assert output == expected +def test__db_env_credentials() -> None: + db_alias = "openml" + credentials = _db_env_credentials(db_alias) + assert credentials["username"] == "root" + assert credentials["password"] == "ok" # noqa: S105 + env_var_name = f"OPENML_DATABASES_{db_alias.upper()}_USERNAME" + env_var_pass = f"OPENML_DATABASES_{db_alias.upper()}_PASSWORD" + with mock.patch.dict(os.environ, {env_var_name: "foo", env_var_pass: "bar"}): + credentials = _db_env_credentials(db_alias) -def test_apply_defaults_to_siblings_does_not_override() -> None: - input_ = {"defaults": {1: 1}, "other": {1: 2}} - expected = {"other": {1: 2}} - output = _apply_defaults_to_siblings(input_) - assert output == expected - - -def test_apply_defaults_to_siblings_ignores_nontables() -> None: - input_ = {"defaults": {1: 1}, "other": {1: 2}, "not-a-table": 3} - expected = {"other": {1: 2}, "not-a-table": 3} - output = _apply_defaults_to_siblings(input_) - assert output == expected - - -def 
test_load_configuration_adds_environment_variables(default_configuration_file: Path) -> None: - database_configuration = load_database_configuration(default_configuration_file) - assert database_configuration["openml"]["username"] == "root" - - load_database_configuration.cache_clear() - with mock.patch.dict(os.environ, {"OPENML_DATABASES_OPENML_USERNAME": "foo"}): - database_configuration = load_database_configuration(default_configuration_file) - assert database_configuration["openml"]["username"] == "foo" + assert credentials["username"] == "foo" + assert credentials["password"] == "bar" # noqa: S105 diff --git a/tests/conftest.py b/tests/conftest.py index 01ac8c82..23a52955 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,13 @@ from asgi_lifespan import LifespanManager from sqlalchemy import text +from config import ( + Configuration, + DatabaseConfiguration, + DevelopmentConfiguration, + LoggingConfiguration, + RoutingConfiguration, +) from database.setup import expdb_database, user_database from main import create_api from routers.dependencies import expdb_connection, userdb_connection @@ -78,7 +85,16 @@ async def php_api() -> AsyncIterator[httpx.AsyncClient]: @pytest.fixture(scope="session") async def app() -> AsyncIterator[FastAPI]: - _app = create_api(Path(__file__).parent / "config.test.toml") + config = Configuration( + openml_database=DatabaseConfiguration(database="openml"), + expdb_database=DatabaseConfiguration(database="openml_expdb"), + development=DevelopmentConfiguration(allow_test_api_keys=True), + routing=RoutingConfiguration( + minio_url="http://minio:9000", server_url="http://php-api:80/" + ), + logging=[LoggingConfiguration(sink="sys.stderr", level="DEBUG")], + ) + _app = create_api(config) async with LifespanManager(_app): yield _app @@ -123,11 +139,6 @@ def dataset_130() -> Iterator[dict[str, Any]]: yield json.load(dataset_file) -@pytest.fixture -def default_configuration_file() -> Path: - return Path().parent.parent / 
"src" / "config.toml" - - class Flow(NamedTuple): """To be replaced by an actual ORM class."""