Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/derivations-narrative.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add ``policyengine.derivations`` for per-variable computation explanations: ``derive(simulation, variable, period)`` returns a structured ``Derivation`` (with pruned trace and top-level contributions); ``narrate(derivation)`` optionally hands it to an LLM for a plain-prose walkthrough.
35 changes: 35 additions & 0 deletions src/policyengine/derivations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Derivations: structured + narrated explanations of one variable's value.

A ``Derivation`` is the pruned, deterministic computation tree for a single
``(simulation, variable)`` pair. The tree is the same information OpenFisca
already records when ``simulation.trace`` is on, but presented as a stable
data class (independent of OpenFisca internals) so callers can:

- print or serialize the structured tree (deterministic, free),
- pull out top-level contributions for charts or tables, and
- optionally hand the derivation to an LLM via :func:`narrate` for a plain-prose
walkthrough (the only step that requires a network call).

This module deliberately separates the *deterministic* part of the explanation
(everything in ``Derivation``) from the *narration* (an external LLM call). A
caller can use one without the other.
"""

from .narrate import narrate, narrate_async
from .trace import (
Derivation,
TraceNode,
derive,
is_zero_value,
top_level_contributions,
)

__all__ = [
"Derivation",
"TraceNode",
"derive",
"is_zero_value",
"narrate",
"narrate_async",
"top_level_contributions",
]
124 changes: 124 additions & 0 deletions src/policyengine/derivations/narrate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""LLM narration of a structured :class:`Derivation`.

This is the only piece of the derivations API that makes a network call. It
is kept in its own module so that callers who only want the deterministic
structured tree don't drag a network/LLM dependency into the import graph.

LiteLLM is imported lazily inside the call so that ``import
policyengine.derivations`` doesn't require any LLM credentials to succeed.
"""

from __future__ import annotations

from typing import Any

from .trace import Derivation

DEFAULT_MODEL = "claude-sonnet-4-6"
DEFAULT_MAX_TOKENS = 500
DEFAULT_TEMPERATURE = 0.0


def _build_prompt(
    derivation: Derivation,
    *,
    country: str | None,
    household_summary: str | None,
    extra_context: str | None,
    trace_max_depth: int,
) -> str:
    """Assemble the single user message that asks an LLM to narrate ``derivation``.

    The message has three parts: a header (variable, optional country,
    period, reference value, optional household summary and free-form extra
    context), the rendered computation trace inside a fenced block, and the
    writing instructions.

    Parameters
    ----------
    derivation:
        Source of ``variable``, ``period``, ``value`` and the trace text.
    country:
        When truthy, added upper-cased as a ``COUNTRY:`` line directly after
        the ``VARIABLE:`` line.
    household_summary:
        When truthy, added as a ``HOUSEHOLD:`` line after the reference value.
    extra_context:
        When truthy, appended verbatim after a blank line.
    trace_max_depth:
        Forwarded to ``derivation.trace_text`` as ``max_depth``.
    """
    preamble = (
        "You are summarizing how PolicyEngine derived a single variable's value "
        "for one household."
    )
    header = [preamble, "", f"VARIABLE: {derivation.variable}"]
    if country:
        # Sits between VARIABLE and PERIOD in the final header.
        header.append(f"COUNTRY: {country.upper()}")
    header.append(f"PERIOD: {derivation.period}")
    header.append(f"REFERENCE VALUE: {derivation.value}")
    if household_summary:
        header.append(f"HOUSEHOLD: {household_summary}")
    if extra_context:
        header.extend(["", extra_context])

    rendered_trace = derivation.trace_text(max_depth=trace_max_depth)
    task = (
        "Write a 3-5 sentence narrative explaining how PolicyEngine arrived at "
        "this value. Reference the most important intermediate quantities by "
        "name and amount. Be concrete and quantitative. Plain prose, no "
        "headers, no bullet lists."
    )
    sections = [
        "\n".join(header),
        "\n\nPolicyEngine computation trace "
        "(indented dependency tree, non-zero nodes only):\n```\n",
        rendered_trace,
        "\n```\n\n",
        task,
        "\n",
    ]
    return "".join(sections)


def narrate(
    derivation: Derivation,
    *,
    country: str | None = None,
    household_summary: str | None = None,
    extra_context: str | None = None,
    model: str = DEFAULT_MODEL,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    trace_max_depth: int = 8,
) -> str:
    """Ask an LLM, synchronously, for a prose walkthrough of ``derivation``.

    LiteLLM is imported inside the function so that importing
    ``policyengine.derivations`` never pulls in an LLM dependency.

    Parameters
    ----------
    derivation:
        The structured derivation to narrate.
    country, household_summary, extra_context:
        Optional header material passed through to the prompt builder.
    model, max_tokens, temperature:
        Forwarded unchanged to ``litellm.completion``.
    trace_max_depth:
        Depth limit used when rendering the trace into the prompt.

    Returns
    -------
    str
        The first choice's message content, stripped of surrounding
        whitespace.
    """
    import litellm  # noqa: PLC0415 — lazy import keeps the base module light

    user_message = {
        "role": "user",
        "content": _build_prompt(
            derivation,
            country=country,
            household_summary=household_summary,
            extra_context=extra_context,
            trace_max_depth=trace_max_depth,
        ),
    }
    completion = litellm.completion(
        model=model,
        messages=[user_message],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content.strip()


async def narrate_async(
    derivation: Derivation,
    *,
    country: str | None = None,
    household_summary: str | None = None,
    extra_context: str | None = None,
    model: str = DEFAULT_MODEL,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    trace_max_depth: int = 8,
) -> str:
    """Awaitable counterpart of :func:`narrate`, with the same parameters.

    Builds the same prompt and awaits ``litellm.acompletion`` instead of
    calling ``litellm.completion``. LiteLLM is imported inside the function
    so the package import stays free of any LLM dependency.

    Returns
    -------
    str
        The first choice's message content, stripped of surrounding
        whitespace.
    """
    import litellm  # noqa: PLC0415 — lazy import keeps the base module light

    user_message = {
        "role": "user",
        "content": _build_prompt(
            derivation,
            country=country,
            household_summary=household_summary,
            extra_context=extra_context,
            trace_max_depth=trace_max_depth,
        ),
    }
    completion: Any = await litellm.acompletion(
        model=model,
        messages=[user_message],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content.strip()
219 changes: 219 additions & 0 deletions src/policyengine/derivations/trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
"""Deterministic computation-trace extraction.

Turns the live OpenFisca tracer output for a single variable into a stable
:class:`Derivation` data class. Everything here is pure and side-effect-free
once the tracer has captured the calculation.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Iterable


@dataclass(frozen=True)
class TraceNode:
    """A single entry of the pruned computation tree.

    Each node carries the variable's name, the value captured for the
    requested household, and the nodes for its immediate dependencies.
    Booleans are stored as Python ``bool`` and numeric values as ``float``.
    Instances are immutable (the dataclass is frozen).
    """

    # Variable name as recorded by the tracer.
    name: str
    # Captured value for this variable.
    value: Any
    # Immediate dependencies, in recorded order; empty for leaves.
    children: tuple[TraceNode, ...] = field(default_factory=tuple)

    def to_text(self, *, max_depth: int = 8, prune_zero: bool = True) -> str:
        """Return the subtree rooted here as indented, newline-joined text.

        Parameters
        ----------
        max_depth:
            Deepest level that is still rendered; this node is depth 0.
        prune_zero:
            When True (the default), zero-valued subtrees deeper than
            depth 1 are left out. This is the form fed to LLMs: zero
            categories at the top level stay visible so the model knows
            they were considered, while the zero leaves beneath them are
            dropped.
        """
        rendered: list[str] = []
        _render(
            self,
            depth=0,
            lines=rendered,
            max_depth=max_depth,
            prune_zero=prune_zero,
        )
        return "\n".join(rendered)


@dataclass(frozen=True)
class Derivation:
    """Structured record of how one variable's value was computed.

    This is the deterministic half of an explanation: it can be rendered as
    text, walked programmatically (for example to build charts), or handed
    to :func:`policyengine.derivations.narrate` for a prose summary.
    Instances are immutable (the dataclass is frozen).
    """

    # Name of the variable that was derived.
    variable: str
    # Captured value of that variable.
    value: Any
    # Root of the computation tree for the variable.
    trace: TraceNode
    # Period the variable was calculated for.
    period: Any

    def trace_text(self, *, max_depth: int = 8, prune_zero: bool = True) -> str:
        """Render the trace as text; shorthand for :meth:`TraceNode.to_text`."""
        root = self.trace
        return root.to_text(max_depth=max_depth, prune_zero=prune_zero)

    def top_level_contributions(self) -> list[tuple[str, Any]]:
        """Return ``[(child_variable_name, value), ...]`` for the root's children.

        Handy for showing a deterministic structured breakdown beside the
        prose narrative, e.g. "the answer is the sum of these named pieces".
        """
        # Resolves to the module-level function of the same name.
        return top_level_contributions(self)


def _to_python(value: Any) -> Any:
    """Unwrap ``value`` through its ``item()`` method when it has one.

    NumPy scalars expose ``item()`` and come back as native Python scalars;
    anything without that attribute is returned unchanged.
    """
    return value.item() if hasattr(value, "item") else value


def _capture(value: Any) -> Any:
    """Capture an OpenFisca trace value as a Python scalar or tuple.

    Per-person variables come through as numpy arrays of length N (one per
    person in the household); tax-unit / household variables come through as
    length-1 arrays. Length-1 arrays collapse to a scalar so most renderings
    stay terse; multi-entity arrays are preserved as tuples so that
    summarising "$45,000 SE income + $40,000 wages" doesn't silently drop
    the spouse's row.

    Returns
    -------
    Any
        ``None`` for an empty sequence, a single converted element for a
        length-1 sequence, a tuple of converted elements for longer
        sequences, and the converted value itself for anything unsized.
    """
    # Text is a single value, not one entry per entity: never split it
    # into characters.
    if isinstance(value, (str, bytes)):
        return value
    try:
        size = len(value)
    except TypeError:
        # Either the value has no length at all, or it is a zero-dimensional
        # array (which defines ``__len__`` but raises from ``len()``).
        # Both are scalars for our purposes.
        return _to_python(value)
    if size == 0:
        return None
    if size == 1:
        return _to_python(value[0])
    return tuple(_to_python(item) for item in value)


def is_zero_value(value: Any) -> bool:
    """Report whether ``value`` is zero for every entity it covers.

    A tuple (multi-entity value) is zero only if each entry is zero; an
    empty tuple therefore counts as zero. ``False`` and numeric ``0`` are
    zero. Every other value, including ``None`` and strings, is not.

    Exported so callers can filter their own copies of the tree without
    redefining what "zero" means.
    """
    if isinstance(value, tuple):
        for entry in value:
            if not is_zero_value(entry):
                return False
        return True
    # ``bool`` is a subclass of ``int``, so this guard also covers booleans;
    # ``False == 0`` and ``True != 0`` give the intended answers.
    if not isinstance(value, (int, float)):
        return False
    return value == 0


def _convert(node: Any) -> TraceNode:
    """Recursively copy a tracer node into the stable ``TraceNode`` shape.

    Reads ``name``, ``value`` and ``children`` from ``node``; the value is
    passed through ``_capture`` and each child is converted the same way.
    """
    captured = _capture(node.value)
    converted_children = tuple(_convert(dependency) for dependency in node.children)
    return TraceNode(name=node.name, value=captured, children=converted_children)


def _render(
    node: TraceNode,
    *,
    depth: int,
    lines: list[str],
    max_depth: int,
    prune_zero: bool,
) -> None:
    """Append one ``name = value`` line per visible node to ``lines``.

    Nodes are visited parent-first, children in stored order. A node deeper
    than ``max_depth`` is skipped together with its subtree. When
    ``prune_zero`` is set, so is any zero-valued node deeper than depth 1.
    Each line is indented by one repetition of the indent unit per level.
    """
    # Explicit stack instead of recursion; children are pushed reversed so
    # they pop in their original order.
    pending = [(node, depth)]
    while pending:
        current, level = pending.pop()
        if level > max_depth:
            continue
        if prune_zero and level > 1 and is_zero_value(current.value):
            continue
        lines.append(" " * level + current.name + " = " + _format_value(current.value))
        pending.extend(
            (dependency, level + 1) for dependency in reversed(tuple(current.children))
        )


def _format_value(value: Any) -> str:
    """Format a captured value (a scalar or a per-entity tuple) for display.

    Scalars go straight to ``_format_scalar``. A tuple made up only of
    non-boolean numbers is shown as its sum followed by the per-entity
    breakdown; any other tuple is shown as a bracketed, comma-separated
    list of its formatted entries.
    """
    if not isinstance(value, tuple):
        return _format_scalar(value)

    joined = ", ".join([_format_value(entry) for entry in value])
    # ``bool`` is an ``int`` subclass, so it has to be excluded explicitly.
    summable = all(
        isinstance(entry, (int, float)) and not isinstance(entry, bool)
        for entry in value
    )
    if summable:
        return f"{_format_scalar(sum(value))} (per entity: {joined})"
    return "[" + joined + "]"


def _format_scalar(value: Any) -> str:
    """Format one scalar for the text rendering of a trace.

    Booleans render as ``"True"`` / ``"False"``. Floats are rounded to two
    decimals with trailing zeros and a dangling decimal point removed
    (``100.0`` -> ``"100"``, ``1234.5`` -> ``"1234.5"``). Anything else is
    rendered with ``str``.

    Floats that round to zero always render as ``"0"``, including small
    negative values and ``-0.0``, which would otherwise come out as
    ``"-0"``.
    """
    # ``bool`` must be checked before the numeric branch.
    if isinstance(value, bool):
        return "True" if value else "False"
    if isinstance(value, float):
        text = f"{value:.2f}".rstrip("0").rstrip(".")
        # "-0.00" trims down to "-0"; report a plain zero instead.
        if text in ("-0", ""):
            return "0"
        return text
    return str(value)


def _find_root(roots: Iterable[Any], target: str) -> Any | None:
    """Return the first node named ``target``, searching depth-first.

    Each root is checked before its descendants, and a root's whole subtree
    is searched before the next root. Returns ``None`` when no node has
    that name.
    """
    # Explicit stack; entries are pushed reversed so they pop in their
    # original order, giving a pre-order walk.
    pending = list(roots)[::-1]
    while pending:
        candidate = pending.pop()
        if candidate.name == target:
            return candidate
        pending.extend(reversed(list(candidate.children)))
    return None


def derive(simulation: Any, variable: str, period: Any) -> Derivation:
    """Calculate ``variable`` on ``simulation`` and capture how it was derived.

    The caller owns the ``Simulation`` and any reform applied to it. This
    function switches tracing on, empties any previously recorded trees so
    that only this calculation is captured, runs the calculation, and turns
    the recorded tree into a stable ``TraceNode``.

    Parameters
    ----------
    simulation:
        Object exposing ``trace``, ``tracer.trees`` and ``calculate``.
    variable:
        Name of the variable to calculate and explain.
    period:
        Period passed to ``simulation.calculate`` and stored on the result.

    Raises
    ------
    RuntimeError
        If nothing was traced, or no traced node is named ``variable``.
    """
    simulation.trace = True
    if hasattr(simulation, "tracer"):
        if hasattr(simulation.tracer, "trees"):
            # Drop earlier calculations so the capture is exactly this one.
            simulation.tracer.trees.clear()
    simulation.calculate(variable, period)

    recorded = simulation.tracer.trees if hasattr(simulation, "tracer") else None
    if not recorded:
        raise RuntimeError(
            f"No trace recorded after calculating {variable!r}. "
            "Ensure the simulation backend supports tracing."
        )

    root = _find_root(recorded, variable)
    if root is None:
        raise RuntimeError(
            f"Tracer did not produce a root for {variable!r}. "
            "This usually means the variable was already cached."
        )

    captured_value = _capture(root.value)
    tree = _convert(root)
    return Derivation(
        variable=variable,
        value=captured_value,
        trace=tree,
        period=period,
    )


def top_level_contributions(derivation: Derivation) -> list[tuple[str, Any]]:
    """List ``(name, value)`` for each immediate dependency of the root.

    Pairs keep the order in which the children are stored on
    ``derivation.trace``. Use this for a deterministic structured breakdown
    to sit alongside the prose narrative.
    """
    breakdown: list[tuple[str, Any]] = []
    for dependency in derivation.trace.children:
        breakdown.append((dependency.name, dependency.value))
    return breakdown
Loading
Loading