diff --git a/changelog.d/derive-us-program-entities.changed.md b/changelog.d/derive-us-program-entities.changed.md new file mode 100644 index 00000000..53ba0269 --- /dev/null +++ b/changelog.d/derive-us-program-entities.changed.md @@ -0,0 +1 @@ +Derive US program-statistics entity from variable metadata instead of duplicating it in the program list. diff --git a/src/policyengine/tax_benefit_models/us/analysis.py b/src/policyengine/tax_benefit_models/us/analysis.py index 7bc1cd52..a285020d 100644 --- a/src/policyengine/tax_benefit_models/us/analysis.py +++ b/src/policyengine/tax_benefit_models/us/analysis.py @@ -28,18 +28,23 @@ ) from policyengine.utils.errors import format_conditional_error_detail -US_PROGRAMS = { - "income_tax": {"entity": "tax_unit", "is_tax": True}, - "employee_payroll_tax": {"entity": "tax_unit", "is_tax": True}, - "state_income_tax": {"entity": "tax_unit", "is_tax": True}, - "snap": {"entity": "spm_unit", "is_tax": False}, - "tanf": {"entity": "spm_unit", "is_tax": False}, - "ssi": {"entity": "person", "is_tax": False}, - "social_security": {"entity": "person", "is_tax": False}, - "medicare_cost": {"entity": "person", "is_tax": False}, - "medicaid": {"entity": "person", "is_tax": False}, - "eitc": {"entity": "tax_unit", "is_tax": False}, - "ctc": {"entity": "tax_unit", "is_tax": False}, +# Map of US program-statistics variable name -> program metadata. The +# entity for each program is derived from the variable's own metadata +# at runtime (see ``_validate_program_statistics_config`` and +# ``economic_impact_analysis``), so this list cannot silently drift +# when policyengine-us moves a variable between entities. +US_PROGRAMS: dict[str, dict] = { + "income_tax": {"is_tax": True}, + "employee_payroll_tax": {"is_tax": True}, + "state_income_tax": {"is_tax": True}, + "snap": {"is_tax": False}, + "tanf": {"is_tax": False}, + "ssi": {"is_tax": False}, + "social_security": {"is_tax": False}, + "medicare_cost": {"is_tax": False}, + "medicaid": {"is_tax": False}, + "eitc": {"is_tax": False}, + "ctc": {"is_tax": False}, } @@ -95,7 +100,7 @@ def _validate_program_statistics_config( missing_outputs: set[tuple[str, str]] = set() simulations = (baseline_simulation, reform_simulation) - for program_name, program_info in US_PROGRAMS.items(): + for program_name in US_PROGRAMS: for simulation in simulations: model_version = simulation.tax_benefit_model_version try: @@ -153,13 +158,14 @@ def economic_impact_analysis( income_variable="household_net_income", ) + model_version = baseline_simulation.tax_benefit_model_version program_statistics = [] for program_name, program_info in US_PROGRAMS.items(): stats = ProgramStatistics( baseline_simulation=baseline_simulation, reform_simulation=reform_simulation, program_name=program_name, - entity=program_info["entity"], + entity=model_version.get_variable(program_name).entity, is_tax=program_info["is_tax"], ) stats.run() diff --git a/tests/test_us_program_statistics.py b/tests/test_us_program_statistics.py index 2c5044f8..cef71cdf 100644 --- a/tests/test_us_program_statistics.py +++ b/tests/test_us_program_statistics.py @@ -103,13 +103,14 @@ def test_us_program_statistics_config_runs_against_mocked_outputs(tmp_path): _validate_program_statistics_config(baseline, reform) + model_version = baseline.tax_benefit_model_version results = {} for program_name, program_info in US_PROGRAMS.items(): stats = ProgramStatistics( baseline_simulation=baseline, reform_simulation=reform, program_name=program_name, - entity=program_info["entity"], + entity=model_version.get_variable(program_name).entity, is_tax=program_info["is_tax"], ) stats.run() @@ -144,3 +145,10 @@ def test_us_program_statistics_config_fails_before_simulation_run( _validate_program_statistics_config(baseline, reform) assert "medicare_cost" in str(exc_info.value) + + +def test_us_programs_entities_match_model_metadata(): + for program_name in US_PROGRAMS: + assert program_name in us_latest.variables_by_name, ( + f"{program_name} is not defined in the US model" + )