diff --git a/changelog.d/us-long-term-datasets.added.md b/changelog.d/us-long-term-datasets.added.md new file mode 100644 index 00000000..15c38fd0 --- /dev/null +++ b/changelog.d/us-long-term-datasets.added.md @@ -0,0 +1 @@ +Add metadata-aware loading for pre-built long-term US projected datasets. diff --git a/pyproject.toml b/pyproject.toml index d76cc0f6..3d4a43e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ dependencies = [ "pydantic>=2.0.0", "pandas>=2.0.0", + "h5py>=3.0.0", "microdf_python>=1.2.1", "jsonschema>=4.0.0", "requests>=2.31.0", diff --git a/src/policyengine/tax_benefit_models/us/__init__.py b/src/policyengine/tax_benefit_models/us/__init__.py index d49d46d4..3bc605bd 100644 --- a/src/policyengine/tax_benefit_models/us/__init__.py +++ b/src/policyengine/tax_benefit_models/us/__init__.py @@ -37,6 +37,8 @@ create_datasets, ensure_datasets, load_datasets, + load_long_term_datasets, + validate_long_term_dataset_metadata, ) from .household import calculate_household from .model import ( @@ -60,7 +62,9 @@ "PolicyEngineUSDataset", "create_datasets", "load_datasets", + "load_long_term_datasets", "ensure_datasets", + "validate_long_term_dataset_metadata", "PolicyEngineUS", "PolicyEngineUSLatest", "managed_microsimulation", diff --git a/src/policyengine/tax_benefit_models/us/datasets.py b/src/policyengine/tax_benefit_models/us/datasets.py index 014309db..6ba01b6f 100644 --- a/src/policyengine/tax_benefit_models/us/datasets.py +++ b/src/policyengine/tax_benefit_models/us/datasets.py @@ -1,10 +1,13 @@ +import json import warnings +from importlib import metadata as importlib_metadata from pathlib import Path -from typing import Optional +from typing import Any, Optional +import h5py import pandas as pd from microdf import MicroDataFrame -from pydantic import ConfigDict +from pydantic import ConfigDict, Field from policyengine.core import Dataset, YearData from policyengine.provenance.manifest import ( @@ -42,6 +45,8 @@ class PolicyEngineUSDataset(Dataset): """US dataset with multi-year entity-level data.""" data: Optional[USYearData] = None + metadata: dict[str, Any] = Field(default_factory=dict) + metadata_filepath: Optional[str] = None def model_post_init(self, __context) -> None: """Called after Pydantic initialization.""" @@ -73,6 +78,10 @@ def save(self) -> None: def load(self) -> None: """Load dataset from HDF5 file into this instance.""" filepath = self.filepath + if _is_policyengine_core_h5(Path(filepath)): + self.data = _load_policyengine_core_h5(Path(filepath), self.year) + return + with pd.HDFStore(filepath, mode="r") as store: self.data = USYearData( person=MicroDataFrame(store["person"], weights="person_weight"), @@ -100,6 +109,154 @@ def __repr__(self) -> str: return f"" +US_ENTITY_KEYS = ( + "person", + "household", + "tax_unit", + "spm_unit", + "family", + "marital_unit", +) + +US_ENTITY_ID_COLUMNS = {entity: f"{entity}_id" for entity in US_ENTITY_KEYS} +US_ENTITY_WEIGHT_COLUMNS = {entity: f"{entity}_weight" for entity in US_ENTITY_KEYS} +US_PERSON_ENTITY_ID_COLUMNS = { + "household": "person_household_id", + "tax_unit": "person_tax_unit_id", + "spm_unit": "person_spm_unit_id", + "family": "person_family_id", + "marital_unit": "person_marital_unit_id", +} + + +def _is_policyengine_core_h5(path: Path) -> bool: + """Return whether ``path`` uses PolicyEngine core's variable/period H5 layout.""" + + try: + with h5py.File(path, "r") as h5_file: + node = h5_file.get("person_id") + return isinstance(node, h5py.Group) + except OSError: + 
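+        # h5py could not open the file at all; report "not core layout" so
+        # that load() falls back to the pandas HDFStore reader instead.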
return False + + +def _read_core_h5_period_values( + h5_file: h5py.File, + variable_name: str, + year: int, +) -> Any: + group = h5_file[variable_name] + period = str(year) + if period not in group: + periods = [key for key in group.keys() if key != "ETERNITY"] + period = sorted(periods)[0] if periods else sorted(group.keys())[0] + values = group[period][:] + if getattr(values, "dtype", None) is not None and values.dtype.kind in {"O", "S"}: + return ( + pd.Series(values) + .map( + lambda value: ( + value.decode("utf-8") if isinstance(value, bytes) else value + ) + ) + .to_numpy() + ) + return values + + +def _core_h5_entity_lengths(h5_file: h5py.File, year: int) -> dict[str, int]: + lengths: dict[str, int] = {} + for entity, id_column in US_ENTITY_ID_COLUMNS.items(): + if id_column in h5_file: + lengths[entity] = len(_read_core_h5_period_values(h5_file, id_column, year)) + return lengths + + +def _core_h5_variable_entities() -> dict[str, str]: + from policyengine_us.system import system + + return {name: variable.entity.key for name, variable in system.variables.items()} + + +def _assign_missing_entity_weights(data: dict[str, pd.DataFrame]) -> None: + household = data["household"] + if "household_id" not in household or "household_weight" not in household: + return + + household_weights = household[["household_id", "household_weight"]] + person = data["person"] + if ( + "person_weight" not in person + and "person_household_id" in person + and len(person) > 0 + ): + person.loc[:, "person_weight"] = person["person_household_id"].map( + household_weights.set_index("household_id")["household_weight"] + ) + + for entity, person_entity_id in US_PERSON_ENTITY_ID_COLUMNS.items(): + if entity == "household": + continue + entity_id = US_ENTITY_ID_COLUMNS[entity] + entity_weight = US_ENTITY_WEIGHT_COLUMNS[entity] + if ( + entity_weight in data[entity] + or entity_id not in data[entity] + or person_entity_id not in person + or "person_household_id" not in person + ): + continue + + entity_households = person[ + [person_entity_id, "person_household_id"] + ].drop_duplicates(subset=[person_entity_id]) + weight_lookup = entity_households.merge( + household_weights, + left_on="person_household_id", + right_on="household_id", + how="left", + ).set_index(person_entity_id)["household_weight"] + data[entity].loc[:, entity_weight] = data[entity][entity_id].map(weight_lookup) + + +def _load_policyengine_core_h5(path: Path, year: int) -> USYearData: + """Load a PolicyEngine core variable/period H5 into .py entity DataFrames.""" + + data = {entity: pd.DataFrame() for entity in US_ENTITY_KEYS} + variable_entities = _core_h5_variable_entities() + + with h5py.File(path, "r") as h5_file: + entity_lengths = _core_h5_entity_lengths(h5_file, year) + for variable_name in h5_file.keys(): + values = _read_core_h5_period_values(h5_file, variable_name, year) + entity = variable_entities.get(variable_name) + if entity is None: + matching_entities = [ + key + for key, length in entity_lengths.items() + if length == len(values) + ] + if len(matching_entities) != 1: + continue + entity = matching_entities[0] + if entity not in data: + continue + data[entity][variable_name] = values + + _assign_missing_entity_weights(data) + + return USYearData( + person=MicroDataFrame(data["person"], weights="person_weight"), + household=MicroDataFrame(data["household"], weights="household_weight"), + tax_unit=MicroDataFrame(data["tax_unit"], weights="tax_unit_weight"), + spm_unit=MicroDataFrame(data["spm_unit"], weights="spm_unit_weight"), + 
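+        # The remaining entities follow the same "<entity>_weight" convention;
+        # _assign_missing_entity_weights fills any weights the H5 omitted.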
family=MicroDataFrame(data["family"], weights="family_weight"), + marital_unit=MicroDataFrame( + data["marital_unit"], weights="marital_unit_weight" + ), + ) + + def create_datasets( datasets: list[str] = [ "enhanced_cps_2024", @@ -323,6 +480,357 @@ def load_datasets( return result +CALIBRATION_QUALITY_RANK = { + "aggregate": 0, + "approximate": 1, + "exact": 2, +} + + +def _metadata_path_for_h5(path: Path) -> Path: + return Path(f"{path}.metadata.json") + + +def _load_dataset_metadata( + path: Path, require_metadata: bool +) -> tuple[dict, Optional[Path]]: + metadata_path = _metadata_path_for_h5(path) + if not metadata_path.exists(): + if require_metadata: + raise FileNotFoundError( + f"Long-term dataset metadata missing for {path}. Expected " + f"{metadata_path}." + ) + return {}, None + return json.loads(metadata_path.read_text(encoding="utf-8")), metadata_path + + +def _quality_rank(quality: str) -> int: + try: + return CALIBRATION_QUALITY_RANK[quality] + except KeyError as error: + valid = ", ".join(sorted(CALIBRATION_QUALITY_RANK)) + raise ValueError( + f"Unknown calibration quality {quality!r}. Valid qualities: {valid}." + ) from error + + +def _require_metadata_value( + metadata: dict, + path: Path, + label: str, + actual: Any, + expected: Any, +) -> None: + if expected is None: + return + if actual != expected: + raise ValueError( + f"Long-term dataset {path} has {label}={actual!r}, expected {expected!r}." + ) + + +def _policyengine_us_git_sha(policyengine_us_metadata: dict) -> Optional[str]: + direct_url = policyengine_us_metadata.get("direct_url") or {} + vcs_info = direct_url.get("vcs_info") or {} + for key in ("commit_id", "git_commit_id", "vcs_commit_id"): + value = vcs_info.get(key) or policyengine_us_metadata.get(key) + if value: + return str(value) + return None + + +def _runtime_policyengine_us_metadata() -> dict[str, Any]: + try: + distribution = importlib_metadata.distribution("policyengine-us") + except importlib_metadata.PackageNotFoundError: + return {} + + result: dict[str, Any] = {"version": distribution.version} + direct_url_text = distribution.read_text("direct_url.json") + if direct_url_text: + try: + result["direct_url"] = json.loads(direct_url_text) + except json.JSONDecodeError: + result["direct_url"] = {} + return result + + +def _validate_runtime_policyengine_us_match( + metadata: dict, + *, + path: Path, +) -> None: + policyengine_us = metadata.get("policyengine_us") or {} + runtime_policyengine_us = _runtime_policyengine_us_metadata() + metadata_version = policyengine_us.get("version") + runtime_version = runtime_policyengine_us.get("version") + + if metadata_version and runtime_version != metadata_version: + raise ValueError( + f"Long-term dataset {path} was built with policyengine-us " + f"version {metadata_version!r}, but the installed runtime is " + f"{runtime_version!r}." + ) + + metadata_git_sha = _policyengine_us_git_sha(policyengine_us) + runtime_git_sha = _policyengine_us_git_sha(runtime_policyengine_us) + if metadata_git_sha and runtime_git_sha != metadata_git_sha: + raise ValueError( + f"Long-term dataset {path} was built with policyengine-us git SHA " + f"{metadata_git_sha!r}, but the installed runtime has " + f"{runtime_git_sha!r}." 
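+            # Fail closed: a SHA mismatch means the installed model code may
+            # not reproduce the revision the dataset was built against.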
+ ) + + +def validate_long_term_dataset_metadata( + metadata: dict, + *, + path: Path, + year: int, + required_profile: Optional[str] = None, + required_target_source: Optional[str] = None, + required_tax_assumption: Optional[str] = None, + required_support_augmentation_profile: Optional[str] = None, + required_support_augmentation_target_year: Optional[int] = None, + required_support_augmentation_target_year_strategy: Optional[str] = None, + required_support_augmentation_blueprint_base_weight_scale: Optional[float] = None, + require_support_augmentation_sanitize_clone_non_target_income: Optional[ + bool + ] = None, + require_support_augmentation_sanitize_worker_non_target_income: Optional[ + bool + ] = None, + minimum_calibration_quality: Optional[str] = None, + require_validation_passed: bool = False, + required_policyengine_us_version: Optional[str] = None, + required_policyengine_us_git_sha: Optional[str] = None, + require_policyengine_us_clean_build: bool = False, + require_runtime_policyengine_us_match: bool = False, +) -> None: + """Validate sidecar metadata for a long-term projected US dataset.""" + + metadata_year = metadata.get("year") + if metadata_year is not None and int(metadata_year) != int(year): + raise ValueError( + f"Long-term dataset {path} metadata year={metadata_year!r}, " + f"expected {year}." + ) + + profile = metadata.get("profile") or {} + target_source = metadata.get("target_source") or {} + tax_assumption = metadata.get("tax_assumption") or {} + support_augmentation = metadata.get("support_augmentation") or {} + calibration_audit = metadata.get("calibration_audit") or {} + policyengine_us = metadata.get("policyengine_us") or {} + + _require_metadata_value( + metadata, + path, + "profile.name", + profile.get("name"), + required_profile, + ) + _require_metadata_value( + metadata, + path, + "target_source.name", + target_source.get("name"), + required_target_source, + ) + _require_metadata_value( + metadata, + path, + "tax_assumption.name", + tax_assumption.get("name"), + required_tax_assumption, + ) + _require_metadata_value( + metadata, + path, + "support_augmentation.name", + support_augmentation.get("name"), + required_support_augmentation_profile, + ) + _require_metadata_value( + metadata, + path, + "support_augmentation.target_year", + support_augmentation.get("target_year"), + required_support_augmentation_target_year, + ) + _require_metadata_value( + metadata, + path, + "support_augmentation.target_year_strategy", + support_augmentation.get("target_year_strategy"), + required_support_augmentation_target_year_strategy, + ) + _require_metadata_value( + metadata, + path, + "support_augmentation.blueprint_base_weight_scale", + support_augmentation.get("blueprint_base_weight_scale"), + required_support_augmentation_blueprint_base_weight_scale, + ) + _require_metadata_value( + metadata, + path, + "support_augmentation.sanitize_clone_non_target_income", + support_augmentation.get("sanitize_clone_non_target_income"), + require_support_augmentation_sanitize_clone_non_target_income, + ) + _require_metadata_value( + metadata, + path, + "support_augmentation.sanitize_worker_non_target_income", + support_augmentation.get("sanitize_worker_non_target_income"), + require_support_augmentation_sanitize_worker_non_target_income, + ) + + if minimum_calibration_quality is not None: + quality = calibration_audit.get("calibration_quality") + if quality is None: + raise ValueError( + f"Long-term dataset {path} is missing " + "calibration_audit.calibration_quality." 
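+                # A missing audit record cannot satisfy any minimum threshold,
+                # so absent metadata is a hard failure rather than a pass.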
+ ) + if _quality_rank(quality) < _quality_rank(minimum_calibration_quality): + raise ValueError( + f"Long-term dataset {path} has calibration quality {quality!r}, " + f"below required minimum {minimum_calibration_quality!r}." + ) + + if ( + require_validation_passed + and calibration_audit.get("validation_passed") is not True + ): + raise ValueError( + f"Long-term dataset {path} has " + f"calibration_audit.validation_passed=" + f"{calibration_audit.get('validation_passed')!r}; expected true." + ) + + _require_metadata_value( + metadata, + path, + "policyengine_us.version", + policyengine_us.get("version"), + required_policyengine_us_version, + ) + _require_metadata_value( + metadata, + path, + "policyengine_us.git_sha", + _policyengine_us_git_sha(policyengine_us), + required_policyengine_us_git_sha, + ) + if ( + require_policyengine_us_clean_build + and policyengine_us.get("git_dirty") is not False + ): + raise ValueError( + f"Long-term dataset {path} has policyengine_us.git_dirty=" + f"{policyengine_us.get('git_dirty')!r}; expected false." + ) + if require_runtime_policyengine_us_match: + _validate_runtime_policyengine_us_match(metadata, path=path) + + +def load_long_term_datasets( + years: list[int], + data_folder: str = "./projected_datasets", + dataset_template: str = "{year}.h5", + dataset_name: str = "long_term_cps", + require_metadata: bool = True, + required_profile: Optional[str] = None, + required_target_source: Optional[str] = None, + required_tax_assumption: Optional[str] = None, + required_support_augmentation_profile: Optional[str] = None, + required_support_augmentation_target_year: Optional[int] = None, + required_support_augmentation_target_year_strategy: Optional[str] = None, + required_support_augmentation_blueprint_base_weight_scale: Optional[float] = None, + require_support_augmentation_sanitize_clone_non_target_income: Optional[ + bool + ] = None, + require_support_augmentation_sanitize_worker_non_target_income: Optional[ + bool + ] = None, + minimum_calibration_quality: Optional[str] = None, + require_validation_passed: bool = False, + required_policyengine_us_version: Optional[str] = None, + required_policyengine_us_git_sha: Optional[str] = None, + require_policyengine_us_clean_build: bool = False, + require_runtime_policyengine_us_match: bool = False, +) -> dict[str, PolicyEngineUSDataset]: + """Load pre-built long-term US projected datasets. + + The country data repo still owns the expensive projection and calibration + build. This helper lets policyengine.py consume those year-specific H5 + artifacts with sidecar metadata validation, including optional checks that + the installed ``policyengine-us`` runtime matches the H5 build metadata. 
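+
+    Example (illustrative; the folder layout, years, and thresholds shown
+    are assumptions rather than defaults)::
+
+        datasets = load_long_term_datasets(
+            years=[2075, 2100],
+            data_folder="./projected_datasets",
+            minimum_calibration_quality="approximate",
+            require_validation_passed=True,
+            require_runtime_policyengine_us_match=True,
+        )
+        person_2075 = datasets["long_term_cps_2075"].data.person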
+ """ + + result = {} + root = Path(data_folder).expanduser() + for year in years: + path = root / dataset_template.format(year=year) + if not path.exists(): + raise FileNotFoundError(f"Long-term dataset not found: {path}") + + metadata, metadata_path = _load_dataset_metadata(path, require_metadata) + if metadata_path is not None: + validate_long_term_dataset_metadata( + metadata, + path=path, + year=year, + required_profile=required_profile, + required_target_source=required_target_source, + required_tax_assumption=required_tax_assumption, + required_support_augmentation_profile=( + required_support_augmentation_profile + ), + required_support_augmentation_target_year=( + required_support_augmentation_target_year + ), + required_support_augmentation_target_year_strategy=( + required_support_augmentation_target_year_strategy + ), + required_support_augmentation_blueprint_base_weight_scale=( + required_support_augmentation_blueprint_base_weight_scale + ), + require_support_augmentation_sanitize_clone_non_target_income=( + require_support_augmentation_sanitize_clone_non_target_income + ), + require_support_augmentation_sanitize_worker_non_target_income=( + require_support_augmentation_sanitize_worker_non_target_income + ), + minimum_calibration_quality=minimum_calibration_quality, + require_validation_passed=require_validation_passed, + required_policyengine_us_version=required_policyengine_us_version, + required_policyengine_us_git_sha=required_policyengine_us_git_sha, + require_policyengine_us_clean_build=( + require_policyengine_us_clean_build + ), + require_runtime_policyengine_us_match=( + require_runtime_policyengine_us_match + ), + ) + + dataset = PolicyEngineUSDataset( + id=f"{dataset_name}_{year}", + name=f"{dataset_name}-{year}", + description=f"US long-term projected dataset for {year}", + filepath=str(path), + year=int(year), + metadata=metadata, + metadata_filepath=str(metadata_path) if metadata_path else None, + ) + result[f"{dataset_name}_{year}"] = dataset + + return result + + def ensure_datasets( datasets: list[str] = [ "enhanced_cps_2024", diff --git a/tests/test_models.py b/tests/test_models.py index 577b9886..0d10a7f4 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -183,8 +183,8 @@ def test__given_breakdown_label__then_includes_enum_value_in_parentheses( found = False for p in us_latest.parameters: if ".SINGLE" in p.name and p.label and "(" in p.label: - assert "(Single)" in p.label, ( - f"Label '{p.label}' should contain '(Single)'" + assert re.search(r"\([^)]*\bSingle\b[^)]*\)", p.label), ( + f"Label '{p.label}' should contain 'Single' in parentheses" ) found = True break diff --git a/tests/test_us_long_term_datasets.py b/tests/test_us_long_term_datasets.py new file mode 100644 index 00000000..0990287f --- /dev/null +++ b/tests/test_us_long_term_datasets.py @@ -0,0 +1,311 @@ +import json +from pathlib import Path + +import h5py +import pandas as pd +import pytest +from microdf import MicroDataFrame + +import policyengine.tax_benefit_models.us.datasets as us_datasets_module +from policyengine.tax_benefit_models.us.datasets import ( + PolicyEngineUSDataset, + USYearData, + load_long_term_datasets, +) + + +def _simple_entity(entity: str, count: int = 1) -> MicroDataFrame: + return MicroDataFrame( + pd.DataFrame( + { + f"{entity}_id": list(range(1, count + 1)), + f"{entity}_weight": [1_000.0] * count, + } + ), + weights=f"{entity}_weight", + ) + + +def _write_us_h5(path: Path, year: int) -> None: + person = MicroDataFrame( + pd.DataFrame( + { + "person_id": 
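+                # Single-person fixture; the ids below tie this person to one
+                # instance of each group entity.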
[1], + "household_id": [1], + "tax_unit_id": [1], + "spm_unit_id": [1], + "family_id": [1], + "marital_unit_id": [1], + "person_weight": [1_000.0], + "age": [70], + } + ), + weights="person_weight", + ) + household = MicroDataFrame( + pd.DataFrame({"household_id": [1], "household_weight": [1_000.0]}), + weights="household_weight", + ) + PolicyEngineUSDataset( + id=f"fixture_{year}", + name=f"fixture-{year}", + description="Long-term fixture", + filepath=str(path), + year=year, + data=USYearData( + person=person, + household=household, + tax_unit=_simple_entity("tax_unit"), + spm_unit=_simple_entity("spm_unit"), + family=_simple_entity("family"), + marital_unit=_simple_entity("marital_unit"), + ), + ) + + +def _write_core_h5(path: Path, year: int) -> None: + values = { + "person_id": [1, 2], + "person_household_id": [10, 10], + "person_tax_unit_id": [20, 20], + "person_spm_unit_id": [30, 30], + "person_family_id": [40, 40], + "person_marital_unit_id": [50, 50], + "age": [70, 68], + "household_id": [10], + "household_weight": [1_000.0], + "state_code": ["CA"], + "tax_unit_id": [20], + "spm_unit_id": [30], + "family_id": [40], + "marital_unit_id": [50], + } + with h5py.File(path, "w") as h5_file: + for variable, data in values.items(): + group = h5_file.create_group(variable) + group.create_dataset(str(year), data=data) + + +def _write_metadata(path: Path, year: int, **overrides) -> None: + metadata = { + "year": year, + "profile": {"name": "ss-payroll-tob"}, + "target_source": {"name": "trustees_2025_current_law"}, + "tax_assumption": {"name": "trustees-core-thresholds-v1"}, + "support_augmentation": { + "name": "donor-backed-composite-v1", + "target_year": 2100, + "target_year_strategy": "fixed", + "blueprint_base_weight_scale": 5.0, + "sanitize_clone_non_target_income": True, + "sanitize_worker_non_target_income": False, + }, + "calibration_audit": { + "calibration_quality": "exact", + "validation_passed": True, + }, + "policyengine_us": { + "version": "1.691.10", + "direct_url": { + "vcs_info": { + "commit_id": "4fd79e6608bc2dac3a7fde0be37191cb4870bd85", + "vcs": "git", + }, + }, + "git_dirty": False, + }, + } + for key, value in overrides.items(): + metadata[key] = value + Path(f"{path}.metadata.json").write_text( + json.dumps(metadata, indent=2), + encoding="utf-8", + ) + + +def test__load_long_term_datasets__loads_h5_and_sidecar_metadata(tmp_path): + h5_path = tmp_path / "2075.h5" + _write_us_h5(h5_path, 2075) + _write_metadata(h5_path, 2075) + + datasets = load_long_term_datasets( + [2075], + data_folder=str(tmp_path), + required_profile="ss-payroll-tob", + required_target_source="trustees_2025_current_law", + required_tax_assumption="trustees-core-thresholds-v1", + required_support_augmentation_profile="donor-backed-composite-v1", + required_support_augmentation_target_year=2100, + required_support_augmentation_target_year_strategy="fixed", + required_support_augmentation_blueprint_base_weight_scale=5.0, + require_support_augmentation_sanitize_clone_non_target_income=True, + require_support_augmentation_sanitize_worker_non_target_income=False, + minimum_calibration_quality="exact", + require_validation_passed=True, + ) + + dataset = datasets["long_term_cps_2075"] + assert dataset.year == 2075 + assert dataset.filepath == str(h5_path) + assert dataset.metadata["profile"]["name"] == "ss-payroll-tob" + assert dataset.metadata_filepath == f"{h5_path}.metadata.json" + assert len(dataset.data.household) == 1 + + +def test__load_long_term_datasets__loads_policyengine_core_h5(tmp_path): + 
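+    # The core-format fixture carries no person weights; the loader is
+    # expected to derive them from household weights via the id links.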
h5_path = tmp_path / "2100.h5" + _write_core_h5(h5_path, 2100) + _write_metadata(h5_path, 2100) + + datasets = load_long_term_datasets( + [2100], + data_folder=str(tmp_path), + required_profile="ss-payroll-tob", + required_target_source="trustees_2025_current_law", + required_tax_assumption="trustees-core-thresholds-v1", + minimum_calibration_quality="exact", + require_validation_passed=True, + ) + + dataset = datasets["long_term_cps_2100"] + assert len(dataset.data.person) == 2 + assert len(dataset.data.household) == 1 + assert dataset.data.person["person_weight"].tolist() == [1_000.0, 1_000.0] + assert dataset.data.tax_unit["tax_unit_weight"].tolist() == [1_000.0] + assert dataset.data.household["state_code"].tolist() == ["CA"] + + +def test__load_long_term_datasets__rejects_metadata_contract_mismatch(tmp_path): + h5_path = tmp_path / "2075.h5" + _write_us_h5(h5_path, 2075) + _write_metadata(h5_path, 2075, profile={"name": "age-only"}) + + with pytest.raises(ValueError, match="profile.name"): + load_long_term_datasets( + [2075], + data_folder=str(tmp_path), + required_profile="ss-payroll-tob", + ) + + +def test__load_long_term_datasets__rejects_empty_metadata_when_contract_required( + tmp_path, +): + h5_path = tmp_path / "2075.h5" + _write_us_h5(h5_path, 2075) + Path(f"{h5_path}.metadata.json").write_text("{}", encoding="utf-8") + + with pytest.raises(ValueError, match="profile.name"): + load_long_term_datasets( + [2075], + data_folder=str(tmp_path), + required_profile="ss-payroll-tob", + ) + + +def test__load_long_term_datasets__rejects_support_contract_mismatch(tmp_path): + h5_path = tmp_path / "2075.h5" + _write_us_h5(h5_path, 2075) + _write_metadata( + h5_path, + 2075, + support_augmentation={ + "name": "donor-backed-composite-v1", + "target_year": 2100, + "target_year_strategy": "fixed", + "blueprint_base_weight_scale": 5.0, + "sanitize_clone_non_target_income": False, + "sanitize_worker_non_target_income": False, + }, + ) + + with pytest.raises( + ValueError, + match="support_augmentation.sanitize_clone_non_target_income", + ): + load_long_term_datasets( + [2075], + data_folder=str(tmp_path), + require_support_augmentation_sanitize_clone_non_target_income=True, + ) + + +def test__load_long_term_datasets__rejects_policyengine_us_version_mismatch( + tmp_path, +): + h5_path = tmp_path / "2075.h5" + _write_us_h5(h5_path, 2075) + _write_metadata(h5_path, 2075) + + with pytest.raises(ValueError, match="policyengine_us.version"): + load_long_term_datasets( + [2075], + data_folder=str(tmp_path), + required_policyengine_us_version="1.691.3", + ) + + +def test__load_long_term_datasets__rejects_policyengine_us_git_sha_mismatch( + tmp_path, +): + h5_path = tmp_path / "2075.h5" + _write_us_h5(h5_path, 2075) + _write_metadata(h5_path, 2075) + + with pytest.raises(ValueError, match="policyengine_us.git_sha"): + load_long_term_datasets( + [2075], + data_folder=str(tmp_path), + required_policyengine_us_git_sha="a" * 40, + ) + + +def test__load_long_term_datasets__rejects_dirty_policyengine_us_build( + tmp_path, +): + h5_path = tmp_path / "2075.h5" + _write_us_h5(h5_path, 2075) + _write_metadata( + h5_path, + 2075, + policyengine_us={"version": "1.691.10", "git_dirty": True}, + ) + + with pytest.raises(ValueError, match="policyengine_us.git_dirty"): + load_long_term_datasets( + [2075], + data_folder=str(tmp_path), + require_policyengine_us_clean_build=True, + ) + + +def test__load_long_term_datasets__can_require_runtime_policyengine_us_match( + monkeypatch, + tmp_path, +): + h5_path = tmp_path / 
"2075.h5" + _write_us_h5(h5_path, 2075) + _write_metadata(h5_path, 2075) + monkeypatch.setattr( + us_datasets_module, + "_runtime_policyengine_us_metadata", + lambda: { + "version": "1.691.10", + "direct_url": { + "vcs_info": { + "commit_id": "4fd79e6608bc2dac3a7fde0be37191cb4870bd85", + "vcs": "git", + }, + }, + }, + ) + + datasets = load_long_term_datasets( + [2075], + data_folder=str(tmp_path), + require_runtime_policyengine_us_match=True, + ) + + assert datasets["long_term_cps_2075"].metadata["policyengine_us"]["version"] == ( + "1.691.10" + ) diff --git a/uv.lock b/uv.lock index 469bc909..3a0cfe59 100644 --- a/uv.lock +++ b/uv.lock @@ -2411,9 +2411,11 @@ wheels = [ [[package]] name = "policyengine" -version = "4.4.3" +version = "4.4.4" source = { editable = "." } dependencies = [ + { name = "h5py", version = "3.14.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "h5py", version = "3.16.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "jsonschema" }, { name = "microdf-python" }, { name = "packaging" }, @@ -2465,6 +2467,7 @@ requires-dist = [ { name = "autodoc-pydantic", marker = "extra == 'dev'" }, { name = "build", marker = "extra == 'dev'" }, { name = "furo", marker = "extra == 'dev'" }, + { name = "h5py", specifier = ">=3.0.0" }, { name = "itables", marker = "extra == 'dev'" }, { name = "jsonschema", specifier = ">=4.0.0" }, { name = "jupyter-book", marker = "extra == 'dev'" },