Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
84742cb
feat(configurator): env_params as first-class trial-identity field
rutayan-nv Jun 20, 2026
7d68edf
refactor(configurator): env_params as cmd_args annotation, excluded f…
rutayan-nv Jun 23, 2026
bf2b690
fix(configurator): tighten env_params validation, align env.csv, fix …
rutayan-nv Jun 24, 2026
c341433
fix(workload): reject empty env_params candidate lists at build time
rutayan-nv Jun 24, 2026
7ac83fd
refactor(configurator): model env_params as cohesive EnvParam/EnvPara…
rutayan-nv Jun 24, 2026
7fdae40
refactor(configurator): collapse env_params sink into one concrete class
rutayan-nv Jun 25, 2026
f4c93e3
refactor(configurator): gate env_params sampling on an agent capabili…
rutayan-nv Jun 25, 2026
a7c6587
style(configurator): trim verbose comments toward self-documenting code
rutayan-nv Jun 25, 2026
822e8a3
fix(test_scenario): skip env_params on post-overlay revalidation
rutayan-nv Jun 25, 2026
1ecadee
fix(workload): reject scalar-only env_params annotations at parse time
rutayan-nv Jun 26, 2026
6dbe05e
refactor(env_params): thread per-trial sample as a local; split DR pr…
rutayan-nv Jun 29, 2026
ee55ba4
refactor(test_scenario): hoist Registry import to module level
rutayan-nv Jun 29, 2026
d02769e
fix(workload): reject singleton env_params candidate lists at parse time
rutayan-nv Jun 29, 2026
2437aa7
refactor(configurator): rename agent flag samples_env_params -> suppo…
rutayan-nv Jun 30, 2026
988bbaf
feat(core): add ContinuousSpace action-space primitive
rutayan-nv Jun 16, 2026
ec1a67d
test(action-space): suppress pyright on intentional ContinuousSpace r…
rutayan-nv Jun 16, 2026
b1cf51c
fix(core): require integer bounds for ContinuousSpace(dtype="int")
rutayan-nv Jun 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions src/cloudai/_core/action_space.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Action-space primitives for CloudAI DSE.

CloudAI describes a tunable action space as a mapping from parameter name to
its candidate domain. Discrete parameters use a ``list`` of candidate values;
continuous parameters use :class:`ContinuousSpace`, a closed real interval.
Agents and adapters (e.g. ``GymnasiumAdapter``) read these to build their own
action representation and to decode sampled actions back to native values.
"""

from __future__ import annotations

from typing import Literal

from pydantic import BaseModel, ConfigDict, model_validator
from typing_extensions import Self


class ContinuousSpace(BaseModel):
"""
A continuous (closed-interval) action-space dimension.

Represents a single tunable parameter drawn from ``[low, high]``. ``dtype``
declares whether decoded samples should be quantized to integers
(``"int"``) or kept as floats (``"float"``); quantization is applied by
consumers when decoding an action, not stored here.
"""

model_config = ConfigDict(extra="forbid")

low: float
high: float
dtype: Literal["int", "float"] = "float"
Comment thread
coderabbitai[bot] marked this conversation as resolved.

@model_validator(mode="after")
def _validate_bounds(self) -> Self:
if self.low >= self.high:
raise ValueError(f"ContinuousSpace requires low < high; got low={self.low}, high={self.high}")
if self.dtype == "int" and (not self.low.is_integer() or not self.high.is_integer()):
Comment on lines +45 to +55

@coderabbitai coderabbitai Bot Jun 23, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
python - <<'PY'
from pydantic import BaseModel, ConfigDict

class Model(BaseModel):
    model_config = ConfigDict(extra="forbid")
    x: float

for value in (float("nan"), float("inf"), float("-inf")):
    try:
        print("accepted", value, "->", Model(x=value).x)
    except Exception as exc:
        print("rejected", value, type(exc).__name__)
PY

Repository: NVIDIA/cloudai

Length of output: 214


🏁 Script executed:

# First, locate and examine the ContinuousSpace class
cat -n src/cloudai/_core/action_space.py | head -70

Repository: NVIDIA/cloudai

Length of output: 2879


🏁 Script executed:

# Search for usage of ContinuousSpace in the codebase
rg "ContinuousSpace" --type py -A 2 -B 2

Repository: NVIDIA/cloudai

Length of output: 152


🏁 Script executed:

# Check for any existing finite validation or tests related to bounds
rg "isfinite|NaN|infinity" --type py src/cloudai/

Repository: NVIDIA/cloudai

Length of output: 396


🏁 Script executed:

# Retry searching for ContinuousSpace usage with broader search
rg "ContinuousSpace" --type py

Repository: NVIDIA/cloudai

Length of output: 152


🏁 Script executed:

# Find test files related to action_space
find . -name "*test*action*" -o -name "*action*test*" | grep -E "\.py$"

Repository: NVIDIA/cloudai

Length of output: 181


🏁 Script executed:

# Search for imports from action_space to find consumers
rg "from.*action_space import|import.*action_space" --type py

Repository: NVIDIA/cloudai

Length of output: 152


🏁 Script executed:

# Check test file for ContinuousSpace
cat tests/test_action_space.py

Repository: NVIDIA/cloudai

Length of output: 2514


🏁 Script executed:

# Search broadly for ContinuousSpace in all text files
rg "ContinuousSpace" .

Repository: NVIDIA/cloudai

Length of output: 1501


🏁 Script executed:

# Look for files that use low/high bounds or action space
rg "\.low|\.high" --type py | grep -i "space\|adapter" | head -20

Repository: NVIDIA/cloudai

Length of output: 152


Reject non-finite bounds explicitly.

low >= high does not reject NaN (comparison with NaN is always False), and unbounded infinities can create invalid intervals for downstream adapters. Add finite-bound validation.

Proposed fix
+import math
 from typing import Literal
@@
     @model_validator(mode="after")
     def _validate_bounds(self) -> Self:
+        if not math.isfinite(self.low) or not math.isfinite(self.high):
+            raise ValueError(f"ContinuousSpace requires finite bounds; got low={self.low}, high={self.high}")
         if self.low >= self.high:
             raise ValueError(f"ContinuousSpace requires low < high; got low={self.low}, high={self.high}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
model_config = ConfigDict(extra="forbid")
low: float
high: float
dtype: Literal["int", "float"] = "float"
@model_validator(mode="after")
def _validate_bounds(self) -> Self:
if self.low >= self.high:
raise ValueError(f"ContinuousSpace requires low < high; got low={self.low}, high={self.high}")
if self.dtype == "int" and (not self.low.is_integer() or not self.high.is_integer()):
model_config = ConfigDict(extra="forbid")
low: float
high: float
dtype: Literal["int", "float"] = "float"
@model_validator(mode="after")
def _validate_bounds(self) -> Self:
if not math.isfinite(self.low) or not math.isfinite(self.high):
raise ValueError(f"ContinuousSpace requires finite bounds; got low={self.low}, high={self.high}")
if self.low >= self.high:
raise ValueError(f"ContinuousSpace requires low < high; got low={self.low}, high={self.high}")
if self.dtype == "int" and (not self.low.is_integer() or not self.high.is_integer()):
🧰 Tools
🪛 Ruff (0.15.18)

[warning] 54-54: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/cloudai/_core/action_space.py` around lines 45 - 55, The _validate_bounds
method in the ContinuousSpace class needs explicit validation for non-finite
bounds. The current comparison check (self.low >= self.high) does not reject NaN
values since NaN comparisons always return False, and it allows infinities which
can create invalid intervals. Add checks to ensure both self.low and self.high
are finite values (using a method like math.isfinite) before or alongside the
existing bounds validation, raising a ValueError with a descriptive message if
either bound is NaN or infinite.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Valid; deferred with the PR. The finite-bound (NaN/inf) guard will land when ContinuousSpace is revived and wired into param_space.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rutayan-nv, understood — thanks for the context. Would you like me to open a GitHub issue to track adding the finite-bound (math.isfinite) guard to ContinuousSpace._validate_bounds so it doesn't get lost before the wiring-in PR lands?


🧠 Learnings used
Learnt from: amaslenn
Repo: NVIDIA/cloudai PR: 754
File: src/cloudai/_core/registry.py:226-234
Timestamp: 2025-12-16T19:47:41.994Z
Learning: In this repository, prefer expressing behavioral documentation through tests rather than docstrings. Tests act as living, verified documentation. Reserve docstrings for interfaces or high-level descriptions, and avoid duplicating behavior that is already covered by tests.

Learnt from: amaslenn
Repo: NVIDIA/cloudai PR: 799
File: src/cloudai/workloads/vllm/report_generation_strategy.py:47-48
Timestamp: 2026-02-10T13:29:25.671Z
Learning: In src/cloudai/workloads/vllm/report_generation_strategy.py, it is acceptable to annotate parse_vllm_bench_output with a caching decorator (e.g., functools.cache) because benchmark result files are immutable after creation in typical usage. This improves repeated calls by avoiding re-parsing unchanged data. Ensure the function is side-effect free and that cache invalidation is not needed for these inputs. If inputs include non-file-state variability, consider cache keys on file path or content hash.

Learnt from: podkidyshev
Repo: NVIDIA/cloudai PR: 821
File: src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py:87-97
Timestamp: 2026-03-10T11:01:25.158Z
Learning: In NVIDIA/cloudai repository, there is no secrets management infrastructure; environment variable values (including tokens/secrets) are serialized as literal strings into generated commands and run artifacts (e.g., cloudai_generated_command.sh). This is an architectural limitation acknowledged by the maintainer (podkidyshev). During code reviews for this repo, do not flag secret-serialization issues unless a secrets management solution is available. If a future change introduces proper secrets handling, update this guideline accordingly.

Learnt from: amaslenn
Repo: NVIDIA/cloudai PR: 836
File: tests/workloads/test_llm_serving.py:35-83
Timestamp: 2026-03-16T14:10:11.280Z
Learning: In the NVIDIA/cloudai repository, require meaningful docstrings for non-test Python classes (production/library code) to improve readability and maintainability. Test doubles, test helper classes, and other classes defined inside test files (e.g., Fake*, Plain*, stub classes) are exempt from the docstring requirement. Enforce via automated checks (e.g., a docstring style/quality gate using CI) that validate presence and quality of class docstrings for production code while skipping test-only classes within test files.

Learnt from: amaslenn
Repo: NVIDIA/cloudai PR: 839
File: src/cloudai/workloads/sglang/slurm_command_gen_strategy.py:41-41
Timestamp: 2026-03-17T20:01:53.137Z
Learning: In src/cloudai/workloads/sglang/slurm_command_gen_strategy.py, 0.0.0.0 bindings (default prefill_host/decode_host in get_serve_commands and --host 0.0.0.0 in get_helper_command) are intentional for in-cluster serving. Ruff S104 warnings on these lines should not be flagged or require suppression comments, as the maintainer (amaslenn) has explicitly confirmed this is the intended behavior.

Learnt from: podkidyshev
Repo: NVIDIA/cloudai PR: 846
File: src/cloudai/reporter.py:168-168
Timestamp: 2026-03-24T23:24:17.668Z
Learning: In the NVIDIA/cloudai repository, follow the project-wide logging convention: use root/global logging calls like `logging.info(...)` and `logging.warning(...)` rather than creating module-level loggers via `logging.getLogger(__name__)`. When reviewing Python code, do not flag or reference LOG015 for this reason and do not suggest switching to module-level loggers. Separately, Ruff’s formatting rule (e.g., G004 f-string vs %-style) is independent and may still be raised if applicable.

Learnt from: podkidyshev
Repo: NVIDIA/cloudai PR: 869
File: src/cloudai/_core/test_scenario.py:23-23
Timestamp: 2026-04-14T11:05:59.444Z
Learning: In the NVIDIA/cloudai repository, do not surface or fail code-review findings for Ruff UP035 warnings about deprecated typing aliases (e.g., typing.List, typing.Set, typing.Type) on individual files. This deprecation cleanup is planned for a dedicated whole-repo modernization bump, so UP035 should be treated as an expected/known issue rather than something reviewers should block at the per-file level.

Learnt from: orbalayla-nvidia
Repo: NVIDIA/cloudai PR: 899
File: src/cloudai/_core/grading_strategy.py:2-2
Timestamp: 2026-06-07T12:01:43.285Z
Learning: In the NVIDIA/cloudai repo, copyright headers’ year formatting must match the test logic enforced by tests/test_check_copyright_headers.py (via collect_years_same_file and _format_years_to_ranges): use comma-separated year lists for non-consecutive “touched” years (e.g., "2024, 2026") and use hyphen ranges only for consecutive years (e.g., "2024-2026"). Do not “normalize” "YYYY, YYYY" into "YYYY-YYYY" unless the missing years are actually present in git history; otherwise the copyright header check will fail.

Learnt from: rutayan-nv
Repo: NVIDIA/cloudai PR: 901
File: src/cloudai/cli/handlers.py:147-162
Timestamp: 2026-06-24T01:05:39.500Z
Learning: For NVIDIA/cloudai: Ruff rule G004 (logging-f-string) is not enabled in this repo’s Ruff configuration, and CI Ruff checks pass even with f-string usage in logging calls. During code reviews, do not flag or request changes specifically for Ruff G004 violations; leave logging f-strings as-is in .py files unless CI or other enabled rules fail.

raise ValueError(
f"ContinuousSpace(dtype='int') requires integer bounds; got low={self.low}, high={self.high}"
)
return self
36 changes: 32 additions & 4 deletions src/cloudai/_core/test_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import TYPE_CHECKING, Any, List, Optional, Set, Type, TypeAlias, Union

from ..util import flatten_dict
from .registry import Registry
from .system import System

if TYPE_CHECKING:
Expand Down Expand Up @@ -140,6 +141,22 @@ def get_metric_value(self, system: System, metric: str) -> MetricValue:
def is_dse_job(self) -> bool:
return self.test.is_dse_job or isinstance(self.num_nodes, list)

@property
def is_domain_randomization_active(self) -> bool:
    """
    Report whether this run will sample environment parameters on each trial.

    Returns ``False`` immediately when the test does not declare domain randomization.
    Otherwise the agent registered under ``self.test.agent`` is looked up, and the result is
    ``True`` only for a DSE job whose agent is either not registered or sets
    ``supports_variable_environment``. An unregistered agent is treated as opted-in so the
    dedicated agent-resolution error surfaces instead of this one.
    """
    if not self.test.is_domain_randomization_enabled:
        return False

    # ``None`` here means no agent is registered under that name.
    agent_cls = Registry().agents_map.get(self.test.agent)
    if not self.is_dse_job:
        return False
    return agent_cls is None or agent_cls.supports_variable_environment

@property
def nnodes(self) -> int:
"""Type safe getter for num_nodes, should only be used on an unrolled DSE job."""
Expand All @@ -156,7 +173,9 @@ def param_space(self) -> dict[str, Any]:
**{
key: value
for key, value in cmd_args_dict.items()
if isinstance(value, list) and not self.test.is_dse_excluded_arg(key)
if isinstance(value, list)
and not self.test.is_dse_excluded_arg(key)
and not self.test.is_env_sampled(key)
},
**{f"extra_env_vars.{key}": value for key, value in extra_env_vars_dict.items() if isinstance(value, list)},
}
Expand Down Expand Up @@ -184,9 +203,13 @@ def all_combinations(self) -> list[dict[str, Any]]:

return all_combinations

def apply_params_set(self, action: dict[str, Any]) -> "TestRun":
def apply_params_set(self, action: dict[str, Any], env_params: dict[str, Any] | None = None) -> "TestRun":
tdef = self.test.model_copy(deep=True)
for key, value in action.items():

# RNG runs in the env before this call; applying only concrete values keeps this deterministic.
# action and env_params target disjoint keys, so a plain merge applies both in one pass.
full_action = action | (env_params or {})
for key, value in full_action.items():
if key.startswith("extra_env_vars."):
tdef.extra_env_vars[key[len("extra_env_vars.") :]] = value
else:
Expand All @@ -199,7 +222,12 @@ def apply_params_set(self, action: dict[str, Any]) -> "TestRun":
else:
setattr(obj, attrs[-1], value)

type(tdef)(**tdef.model_dump()) # trigger validation
# env_params is validated at parse time; after the overlay its target cmd_args fields hold
# concrete scalar draws, so re-validating it here would reject weighted specs. Drop it for
# this validation-only pass, which exists to validate the applied action values.
validation_args = tdef.model_dump()
validation_args.pop("env_params", None)
type(tdef)(**validation_args) # trigger validation

new_tr = copy.deepcopy(self)
new_tr.test = tdef
Expand Down
14 changes: 11 additions & 3 deletions src/cloudai/cli/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import toml
import yaml

from cloudai.configurator.env_params import validate_domain_randomization_active
from cloudai.core import (
BaseInstaller,
CloudAIGymEnv,
Expand All @@ -39,6 +40,7 @@
System,
TestParser,
TestScenario,
TestScenarioParsingError,
)
from cloudai.models.scenario import ReportConfig
from cloudai.models.workload import TestDefinition
Expand Down Expand Up @@ -133,8 +135,7 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:
return 1

err = 0
# Recoverable failures return a non-zero rc and are accumulated here; an unexpected exception
# (a bug) is a hard-fail. We capture it so reports still generate, then re-raise below.
# Capture an unexpected error so reports still generate, then re-raise below.
run_error: Exception | None = None
try:
for tr in runner.runner.test_scenario.test_runs:
Expand Down Expand Up @@ -303,6 +304,12 @@ def handle_dry_run_and_run(args: argparse.Namespace) -> int:
return 1
system, test_scenario, tests = setup_result

try:
validate_domain_randomization_active(test_scenario)
except TestScenarioParsingError as e:
logging.error(str(e))
return 1
Comment thread
coderabbitai[bot] marked this conversation as resolved.

if not _handle_single_sbatch(args, system):
return 1

Expand Down Expand Up @@ -491,7 +498,8 @@ def verify_test_scenarios(
tests = Parser.parse_tests(test_tomls, system)
hook_tests = Parser.parse_tests(hook_test_tomls, system)
hooks = Parser.parse_hooks(hook_tomls, system, {t.name: t for t in hook_tests})
Parser.parse_test_scenario(scenario_file, system, {t.name: t for t in tests}, hooks)
scenario = Parser.parse_test_scenario(scenario_file, system, {t.name: t for t in tests}, hooks)
validate_domain_randomization_active(scenario)
except Exception:
nfailed += 1

Expand Down
17 changes: 11 additions & 6 deletions src/cloudai/configurator/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ class BaseAgent(ABC):
Provides a unified interface and parameter management for action spaces.
"""

# Opt-in: agents that operate over a variable environment - one that changes per trial, whether
# by env_params sampling (domain randomization) or a curriculum schedule - set this True. Default
# False so env_params declared for an agent that cannot handle a varying env are rejected rather
# than silently ignored.
supports_variable_environment: bool = False

def __init__(self, env: BaseGym, config: BaseAgentConfig):
"""
Initialize the agent with the environment.
Expand Down Expand Up @@ -94,9 +100,8 @@ def select_action(self, observation: list[float] | None = None) -> tuple[int, di

Args:
observation: Latest observation produced by the environment (``env.reset()`` on the
first call, then the result of the prior ``env.step()``). Stateless agents such
as grid search or Bayesian optimization may ignore this; observation-conditioned
agents (RL, contextual bandits) should use it.
first call, then the result of the prior ``env.step()``). Stateless agents may
ignore this; observation-conditioned agents should use it.

Returns:
Tuple[int, Dict[str, Any]] | None: The current step index and a dictionary mapping action keys
Expand All @@ -120,8 +125,7 @@ def run(self) -> int:

Default: a step loop driven by the dispatcher (``select_action`` →
``env.step`` → ``update_policy`` per trial). Agents that drive their
own training loop (e.g. RLlib-based agents calling ``algo.train()``)
override this method.
own training loop override this method.

Failure contract (``handle_dse_job`` consumes the result via
``err |= agent.run()``):
Expand All @@ -131,7 +135,8 @@ def run(self) -> int:
accumulated and the next ``TestRun`` still executes. Workload-level
failures are already surfaced this way: ``CloudAIGymEnv.step`` maps a
failed metric to ``rewards.metric_failure`` rather than raising, and
``rllib_run`` catches training errors and returns ``rc=1``.
agents with their own training loop should likewise catch training
errors and return a non-zero code.
- Raise for *unexpected* failures (framework/agent bugs). Exceptions
propagate out of ``handle_dse_job`` and hard-fail the job so the bug
is surfaced instead of masked as a penalizing reward.
Expand Down
50 changes: 44 additions & 6 deletions src/cloudai/configurator/cloudai_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from .base_agent import RewardOverrides
from .base_gym import BaseGym
from .env_params import EnvParams, write_env_params


@dataclasses.dataclass(frozen=True)
Expand All @@ -36,6 +37,7 @@ class TrajectoryEntry:
action: dict[str, Any]
reward: float
observation: list
env_params: dict[str, Any] = dataclasses.field(default_factory=dict)


class CloudAIGymEnv(BaseGym):
Expand All @@ -61,8 +63,14 @@ def __init__(self, test_run: TestRun, runner: BaseRunner, rewards: RewardOverrid
self.max_steps = test_run.test.agent_steps
self.reward_function = Registry().get_reward_function(test_run.test.agent_reward_function)
self.trajectory: dict[int, list[TrajectoryEntry]] = {}
self.params: EnvParams | None = EnvParams.from_test(test_run.test)
super().__init__()

@property
def env_params_record_path(self) -> Path:
    """Location of ``env.csv``, placed directly inside the per-iteration directory."""
    record_dir = self.iteration_dir
    return record_dir / "env.csv"

def define_action_space(self) -> Dict[str, list[Any]]:
    """Expose the wrapped test run's ``param_space`` as this environment's action space."""
    action_space = self.test_run.param_space
    return action_space

Expand Down Expand Up @@ -119,9 +127,11 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
- info (dict): Additional info for debugging.
"""
self.test_run.increment_step()
self.test_run = self.test_run.apply_params_set(action)
# RNG lives in the env: sample here, then apply action + sample so the run and cache key see them.
sampled_env_params = self.params.sample(self.test_run.step) if self.params else {}
self.test_run = self.test_run.apply_params_set(action, env_params=sampled_env_params)

cached_result = self.get_cached_trajectory_result(action)
cached_result = self.get_cached_trajectory_result(action, sampled_env_params)
if cached_result is not None:
logging.info(
"Retrieved cached result from trajectory with reward %s (from step %s). Skipping execution.",
Expand All @@ -134,6 +144,7 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
action=action,
reward=cached_result.reward,
observation=cached_result.observation,
env_params=sampled_env_params,
)
)
return cached_result.observation, cached_result.reward, False, {}
Expand Down Expand Up @@ -171,6 +182,7 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
action=action,
reward=reward,
observation=observation,
env_params=sampled_env_params,
)
)

Expand Down Expand Up @@ -230,7 +242,14 @@ def get_observation(self, action: Any) -> list:
return observation

def write_trajectory(self, entry: TrajectoryEntry):
"""Append the trajectory to the CSV file and to the local attribute."""
"""
Append the entry to the in-memory cache and trajectory.csv (plus env.csv when declared).

``trajectory.csv`` and the ``env.csv`` projection are sunk from the same
``TrajectoryEntry`` here, so a trial that never produces an entry (e.g. a
constraint failure returns before this call) lands in neither file and the
two stay 1:1 step-aligned.
"""
self.current_trajectory.append(entry)

file_exists = self.trajectory_file_path.exists()
Expand All @@ -243,17 +262,36 @@ def write_trajectory(self, entry: TrajectoryEntry):
writer.writerow(["step", "action", "reward", "observation"])
writer.writerow([entry.step, entry.action, entry.reward, entry.observation])

write_env_params(self.env_params_record_path, entry.step, entry.env_params)

@property
def iteration_dir(self) -> Path:
    """Directory for the current iteration: ``<scenario_root>/<test run name>/<current_iteration>``."""
    run_root = self.runner.scenario_root / self.test_run.name
    iteration_name = f"{self.test_run.current_iteration}"
    return run_root / iteration_name

@property
def trajectory_file_path(self) -> Path:
return self.runner.scenario_root / self.test_run.name / f"{self.test_run.current_iteration}" / "trajectory.csv"
return self.iteration_dir / "trajectory.csv"

@property
def current_trajectory(self) -> list[TrajectoryEntry]:
    """Entries recorded for the current iteration; an empty list is registered on first access."""
    iteration = self.test_run.current_iteration
    return self.trajectory.setdefault(iteration, [])

def get_cached_trajectory_result(self, action: Any) -> TrajectoryEntry | None:
def get_cached_trajectory_result(self, action: Any, env_params: dict[str, Any]) -> TrajectoryEntry | None:
"""
Return a cached entry only when the full trial identity matches.

Trial identity is ``(action, env_params)``: env-randomized parameters
change the workload's behaviour, so a trial repeating the same action
under a different ``env_params`` sample must miss and re-run. Empty
env_params on both sides is the back-compat path for workloads that
do not declare any ``[env_params.*]`` block. The sample is passed in (a
per-trial local owned by ``step``), exactly like ``action``.
"""
for entry in self.current_trajectory:
if self._values_match_exact(entry.action, action):
action_match = self._values_match_exact(entry.action, action)
env_params_match = self._values_match_exact(entry.env_params, env_params)
if action_match and env_params_match:
return entry

return None
Expand Down
Loading
Loading