From 427459fc11e1beaa86fbebaaf6d347685ff7ddab Mon Sep 17 00:00:00 2001 From: Teerth Sharma Date: Sun, 28 Jun 2026 04:45:21 +0530 Subject: [PATCH 1/6] feat: improve topology-aware ACG convergence Signed-off-by: Teerth Sharma --- Cargo.lock | 196 ++++- crates/adaptive/Cargo.toml | 13 + crates/adaptive/benches/convergence_bench.rs | 183 +++++ crates/adaptive/src/acg/ir_builder.rs | 101 ++- crates/adaptive/src/acg/stability.rs | 106 ++- crates/adaptive/src/acg_component.rs | 26 + crates/adaptive/src/acg_learner.rs | 203 +++++- crates/adaptive/src/acg_profile.rs | 63 +- .../adaptive/src/adaptive_hints_intercept.rs | 64 +- crates/adaptive/src/config.rs | 146 ++++ crates/adaptive/src/lib.rs | 12 +- crates/adaptive/src/plugin_component.rs | 45 +- crates/adaptive/src/runtime/features.rs | 35 +- crates/adaptive/src/runtime/validation.rs | 73 ++ .../adaptive/src/tool_parallelism_learner.rs | 73 +- crates/adaptive/src/topology.rs | 361 ++++++++++ .../integration/runtime_integration_tests.rs | 7 +- .../tool_parallelism_plan_tests.rs | 13 +- .../integration/topology_convergence_tests.rs | 272 +++++++ .../unit/acg/economics_internal_tests.rs | 2 + .../tests/unit/acg/economics_policy_tests.rs | 2 + .../tests/unit/acg/ir_builder_tests.rs | 86 +++ .../tests/unit/acg/multi_breakpoint_tests.rs | 2 + .../unit/acg/stability_internal_tests.rs | 57 ++ .../tests/unit/acg_component_tests.rs | 283 +++++++- .../adaptive/tests/unit/acg_learner_tests.rs | 675 +++++++++++++++++- .../adaptive/tests/unit/acg_profile_tests.rs | 60 ++ .../unit/adaptive_hints_intercept_tests.rs | 75 ++ .../tests/unit/cache_diagnostics_tests.rs | 2 + crates/adaptive/tests/unit/config_tests.rs | 43 ++ .../adaptive/tests/unit/intercepts_tests.rs | 2 + .../tests/unit/plugin_component_tests.rs | 38 +- .../tests/unit/runtime_features_tests.rs | 65 +- crates/adaptive/tests/unit/runtime_tests.rs | 75 +- .../unit/storage_memory_internal_tests.rs | 2 + crates/adaptive/tests/unit/storage_tests.rs | 3 + .../unit/tool_parallelism_learner_tests.rs | 59 ++ crates/adaptive/tests/unit/topology_tests.rs | 146 ++++ crates/adaptive/tests/unit/types_tests.rs | 2 + crates/node/adaptive.d.ts | 29 + crates/node/adaptive.js | 46 ++ crates/node/tests/adaptive_tests.mjs | 32 + .../coverage/py_storage_coverage_tests.rs | 2 + go/nemo_relay/adaptive.go | 50 +- go/nemo_relay/adaptive/adaptive.go | 24 + go/nemo_relay/adaptive_test.go | 32 +- python/nemo_relay/adaptive.py | 78 +- python/nemo_relay/adaptive.pyi | 58 +- python/tests/test_adaptive.py | 37 + python/tests/test_adaptive_config.py | 6 + 50 files changed, 3927 insertions(+), 138 deletions(-) create mode 100644 crates/adaptive/benches/convergence_bench.rs create mode 100644 crates/adaptive/src/topology.rs create mode 100644 crates/adaptive/tests/integration/topology_convergence_tests.rs create mode 100644 crates/adaptive/tests/unit/topology_tests.rs diff --git a/Cargo.lock b/Cargo.lock index 87581d293..ab5596607 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,6 +20,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "1.0.0" @@ -325,6 +331,33 @@ dependencies = [ "windows-link", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "clap" version = "4.6.0" @@ -477,12 +510,73 @@ dependencies = [ "libc", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + [[package]] name = "crypto-common" version = "0.1.7" @@ -813,6 +907,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if", + "crunchy", + "zerocopy", +] + [[package]] name = "hashbrown" version = "0.15.5" @@ -834,6 +939,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "http" version = "1.4.0" @@ -1129,12 +1240,32 @@ dependencies = [ "serde", ] +[[package]] +name = "is-terminal" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys 0.61.2", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.14.0" @@ -1388,6 +1519,7 @@ name = "nemo-relay-adaptive" version = "0.5.0" dependencies = [ "chrono", + "criterion", "nemo-relay", "redis", "regex", @@ -1637,7 +1769,7 @@ dependencies = [ "http-body-util", "humantime", "hyper", - "itertools", + "itertools 0.14.0", "md-5", "parking_lot", "percent-encoding", @@ -1846,6 +1978,34 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "portable-atomic" version = "1.13.1" @@ -1906,7 +2066,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" dependencies = [ "heck", - "itertools", + "itertools 0.14.0", "log", "multimap", "petgraph", @@ -1927,7 +2087,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", - "itertools", + "itertools 0.14.0", "proc-macro2", "quote", "syn", @@ -2240,6 +2400,26 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" +[[package]] +name = "rayon" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redis" version = "1.2.0" @@ -2875,6 +3055,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.11.0" diff --git a/crates/adaptive/Cargo.toml b/crates/adaptive/Cargo.toml index c59bf27e4..56f20d559 100644 --- a/crates/adaptive/Cargo.toml +++ b/crates/adaptive/Cargo.toml @@ -34,6 +34,11 @@ redis-backend = ["redis"] [dev-dependencies] tokio = { version = "1", default-features = false, features = ["rt", "macros", "sync", "time", "test-util", "rt-multi-thread"] } tokio-stream = { version = "0.1", default-features = false } +criterion = { version = "0.5", features = ["html_reports"] } + +[[bench]] +name = "convergence_bench" +harness = false [[test]] name = "redis_integration" @@ -46,3 +51,11 @@ path = "tests/integration/runtime_integration_tests.rs" [[test]] name = "acg_module_surface" path = "tests/integration/acg_module_surface_tests.rs" + +[[test]] +name = "topology_convergence" +path = "tests/integration/topology_convergence_tests.rs" + +[[test]] +name = "tool_parallelism_plan" +path = "tests/integration/tool_parallelism_plan_tests.rs" diff --git a/crates/adaptive/benches/convergence_bench.rs b/crates/adaptive/benches/convergence_bench.rs new file mode 100644 index 000000000..845b4df71 --- /dev/null +++ b/crates/adaptive/benches/convergence_bench.rs @@ -0,0 +1,183 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Benchmark comparing observations-to-decision for the ACG learner with and +//! without topological convergence detection. +//! +//! The synthetic prompt profile consists of 50 identical observations. Without +//! convergence detection the learner processes the full observation sequence +//! before deciding, while topological convergence detection declares +//! convergence after the configured stability window. +#![allow(missing_docs)] + +use std::collections::HashMap; +use std::sync::{Arc, LazyLock, RwLock}; + +use chrono::Utc; +use criterion::{Criterion, black_box, criterion_group, criterion_main}; +use nemo_relay::codec::request::{AnnotatedLlmRequest, Message, MessageContent}; +use nemo_relay_adaptive::acg::build_prompt_ir; +use nemo_relay_adaptive::acg::prompt_ir::PromptIR; +use nemo_relay_adaptive::acg::stability::StabilityThresholds; +use nemo_relay_adaptive::acg_learner::AcgLearner; +use nemo_relay_adaptive::config::ConvergenceConfig; +use nemo_relay_adaptive::learner::traits::Learner; +use nemo_relay_adaptive::types::cache::HotCache; +use nemo_relay_adaptive::types::records::{CallKind, CallRecord, RunRecord}; +use nemo_relay_adaptive::{InMemoryBackend, StorageBackendDyn}; +use uuid::Uuid; + +static RUNTIME: LazyLock = + LazyLock::new(|| tokio::runtime::Runtime::new().expect("tokio runtime")); + +fn identical_request() -> AnnotatedLlmRequest { + AnnotatedLlmRequest { + messages: vec![ + Message::System { + content: MessageContent::Text("You are a helpful assistant.".to_string()), + name: None, + }, + Message::User { + content: MessageContent::Text("Summarize this.".to_string()), + name: None, + }, + ], + model: Some("gpt-4o".to_string()), + params: None, + tools: None, + tool_choice: None, + store: None, + previous_response_id: None, + truncation: None, + reasoning: None, + include: None, + user: None, + metadata: None, + service_tier: None, + parallel_tool_calls: None, + max_output_tokens: None, + max_tool_calls: None, + top_logprobs: None, + stream: None, + extra: serde_json::Map::new(), + } +} + +fn build_stable_observations(count: usize) -> Vec { + let request = identical_request(); + (0..count) + .map(|_| build_prompt_ir(&request).expect("valid prompt IR")) + .collect() +} + +fn build_run(request: AnnotatedLlmRequest) -> RunRecord { + let now = Utc::now(); + RunRecord { + id: Uuid::now_v7(), + agent_id: "convergence-agent".to_string(), + calls: vec![CallRecord { + kind: CallKind::Llm, + name: "planner".to_string(), + started_at: now, + ended_at: Some(now), + metadata_snapshot: None, + output_tokens: None, + prompt_tokens: None, + total_tokens: None, + model_name: None, + tool_call_count: None, + annotated_request: Some(request.into()), + annotated_response: None, + }], + started_at: now, + ended_at: Some(now), + } +} + +fn empty_cache() -> Arc> { + Arc::new(RwLock::new(HotCache { + plan: None, + trie: None, + agent_hints_default: None, + acg_profiles: HashMap::new(), + acg_profile_observation_counts: HashMap::new(), + acg_stability: None, + acg_observation_count: 0, + })) +} + +fn observations_without_convergence(observations: &[PromptIR]) -> usize { + RUNTIME.block_on(async { + let learner = AcgLearner::new("convergence-agent", 100, StabilityThresholds::default()); + let backend = InMemoryBackend::new(); + let cache = empty_cache(); + + for _ in 0..observations.len() { + let run = build_run(identical_request()); + learner + .process_run(&run, &backend, &cache) + .await + .expect("process run"); + } + observations.len() + }) +} + +fn observations_with_convergence(observations: &[PromptIR]) -> usize { + RUNTIME.block_on(async { + let config = ConvergenceConfig { + enabled: true, + epsilon: 0.001, + stability_window: 3, + }; + let learner = AcgLearner::new_with_convergence( + "convergence-agent", + 100, + StabilityThresholds::default(), + Some(config), + ); + let backend = InMemoryBackend::new(); + let cache = empty_cache(); + + for index in 0..observations.len() { + let run = build_run(identical_request()); + learner + .process_run(&run, &backend, &cache) + .await + .expect("process run"); + + let stability = backend + .load_stability("convergence-agent") + .await + .expect("load stability") + .expect("stability exists"); + if stability.converged { + return index + 1; + } + } + observations.len() + }) +} + +fn bench_convergence(c: &mut Criterion) { + let observations = build_stable_observations(50); + + c.bench_function("without_convergence", |b| { + b.iter(|| observations_without_convergence(black_box(&observations))) + }); + + c.bench_function("with_convergence", |b| { + b.iter(|| observations_with_convergence(black_box(&observations))) + }); + + let without = observations_without_convergence(&observations); + let with = observations_with_convergence(&observations); + println!("observations-to-decision: without={without}, with={with}"); + assert!( + with <= without, + "convergence path should use fewer or equal observations: {with} <= {without}" + ); +} + +criterion_group!(benches, bench_convergence); +criterion_main!(benches); diff --git a/crates/adaptive/src/acg/ir_builder.rs b/crates/adaptive/src/acg/ir_builder.rs index a40702d63..47d937fe3 100644 --- a/crates/adaptive/src/acg/ir_builder.rs +++ b/crates/adaptive/src/acg/ir_builder.rs @@ -17,6 +17,8 @@ use crate::acg::prompt_ir::{ ToolSchemaHash, }; +const RESPONSE_FORMAT_EXTRA_KEY: &str = "response_format"; + /// Build a normalized [`PromptIR`] from an annotated LLM request. /// /// The builder preserves prompt order, inserts tool-schema blocks before the @@ -35,43 +37,70 @@ use crate::acg::prompt_ir::{ pub fn build_prompt_ir(request: &AnnotatedLlmRequest) -> Result { let mut blocks: Vec = Vec::new(); let mut sequence_index: u32 = 0; - let mut inserted_tool_blocks = false; + let mut inserted_static_contract_blocks = false; + let structured_output_contract = build_structured_output_contract(request)?; + let has_static_contract_blocks = + request.tools.is_some() || structured_output_contract.is_some(); for message in &request.messages { - if should_insert_tool_blocks_before_message(inserted_tool_blocks, request, message) { - append_tool_schema_blocks(&mut blocks, &mut sequence_index, request.tools.as_deref())?; - inserted_tool_blocks = true; + if should_insert_static_contract_blocks_before_message( + inserted_static_contract_blocks, + has_static_contract_blocks, + message, + ) { + append_static_contract_blocks( + &mut blocks, + &mut sequence_index, + request.tools.as_deref(), + structured_output_contract.as_ref(), + )?; + inserted_static_contract_blocks = true; } append_message_blocks(&mut blocks, &mut sequence_index, message)?; } - if !inserted_tool_blocks { - append_tool_schema_blocks(&mut blocks, &mut sequence_index, request.tools.as_deref())?; + if !inserted_static_contract_blocks { + append_static_contract_blocks( + &mut blocks, + &mut sequence_index, + request.tools.as_deref(), + structured_output_contract.as_ref(), + )?; } let tool_schema_hashes = match &request.tools { Some(tools) => Some(build_tool_schema_hashes(tools)?), None => None, }; + let structured_output_schema_id = structured_output_contract + .as_ref() + .map(|contract| contract.schema_id.clone()); let source_request_hash = Some(compute_request_hash(request)?); Ok(PromptIR { ir_id: Uuid::new_v4(), blocks, tool_schema_hashes, - structured_output_schema_id: None, + structured_output_schema_id, source_request_hash, created_at: Utc::now(), }) } -fn should_insert_tool_blocks_before_message( - inserted_tool_blocks: bool, - request: &AnnotatedLlmRequest, +struct StructuredOutputContract { + content: String, + schema_id: String, +} + +fn should_insert_static_contract_blocks_before_message( + inserted_static_contract_blocks: bool, + has_static_contract_blocks: bool, message: &Message, ) -> bool { - !inserted_tool_blocks && !matches!(message, Message::System { .. }) && request.tools.is_some() + !inserted_static_contract_blocks + && has_static_contract_blocks + && !matches!(message, Message::System { .. }) } fn append_message_blocks( @@ -252,6 +281,21 @@ fn append_tool_schema_blocks( Ok(()) } +fn append_static_contract_blocks( + blocks: &mut Vec, + seq: &mut u32, + tools: Option<&[ToolDefinition]>, + structured_output_contract: Option<&StructuredOutputContract>, +) -> Result<()> { + append_tool_schema_blocks(blocks, seq, tools)?; + + if let Some(contract) = structured_output_contract { + blocks.push(build_structured_output_block(seq, contract)); + } + + Ok(()) +} + fn build_tool_schema_block(seq: &mut u32, tool: &ToolDefinition) -> Result { let tool_value = serde_json::to_value(tool)?; let canonical = canonicalize_value(&tool_value)?; @@ -272,6 +316,41 @@ fn build_tool_schema_block(seq: &mut u32, tool: &ToolDefinition) -> Result Result> { + let Some(response_format) = request.extra.get(RESPONSE_FORMAT_EXTRA_KEY) else { + return Ok(None); + }; + + let canonical = canonicalize_value(response_format)?; + let content = normalize_whitespace(&canonical); + Ok(Some(StructuredOutputContract { + schema_id: sha256_hex(&canonical), + content, + })) +} + +fn build_structured_output_block( + seq: &mut u32, + contract: &StructuredOutputContract, +) -> PromptBlock { + let index = *seq; + let span_id = generate_span_id(PromptRole::System, index, Some("structured-output")); + *seq += 1; + + PromptBlock { + span_id, + sequence_index: index, + role: PromptRole::System, + content: contract.content.clone(), + content_type: BlockContentType::StructuredOutput, + provenance: ProvenanceLabel::System, + sensitivity: SensitivityLabel::default(), + token_metadata: None, + } +} + fn build_tool_schema_hashes(tools: &[ToolDefinition]) -> Result> { tools .iter() diff --git a/crates/adaptive/src/acg/stability.rs b/crates/adaptive/src/acg/stability.rs index 6141ffa54..820fc7758 100644 --- a/crates/adaptive/src/acg/stability.rs +++ b/crates/adaptive/src/acg/stability.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; use crate::acg::canonicalize::sha256_hex; use crate::acg::profile::{BlockStabilityScore, StabilityClass}; -use crate::acg::prompt_ir::{PromptIR, SpanId}; +use crate::acg::prompt_ir::{BlockContentType, PromptBlock, PromptIR, PromptRole, SpanId}; /// Thresholds controlling prompt-block stability classification. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -39,8 +39,15 @@ pub struct StabilityAnalysisResult { pub scores: Vec, /// Number of leading blocks that were classified as stable. pub stable_prefix_length: usize, + /// Fingerprint of the dominant observed stable prefix content. + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default)] + pub stable_prefix_fingerprint: Option, /// Total number of observations included in the analysis. pub total_observations: u32, + /// Whether topological convergence has been declared for this profile. + #[serde(default)] + pub converged: bool, } struct SpanObservations { @@ -69,7 +76,9 @@ pub fn analyze_stability( return StabilityAnalysisResult { scores: Vec::new(), stable_prefix_length: 0, + stable_prefix_fingerprint: None, total_observations: 0, + converged: false, }; } @@ -85,15 +94,108 @@ pub fn analyze_stability( .map(|(span_id, obs)| build_stability_score(span_id, obs, total_observations, thresholds)) .collect(); - indexed_scores.sort_by_key(|(idx, _)| *idx); + sort_indexed_scores(&mut indexed_scores); let scores: Vec = indexed_scores.into_iter().map(|(_, score)| score).collect(); let stable_prefix_length = find_stable_prefix_length(&scores); + let stable_prefix_fingerprint = + dominant_stable_prefix_fingerprint(observations, stable_prefix_length); StabilityAnalysisResult { scores, stable_prefix_length, + stable_prefix_fingerprint, total_observations, + converged: false, + } +} + +pub(crate) fn prompt_prefix_fingerprint( + observation: &PromptIR, + prefix_length: usize, +) -> Option { + if prefix_length == 0 || observation.blocks.len() < prefix_length { + return None; + } + + let prefix = observation + .blocks + .iter() + .take(prefix_length) + .map(block_fingerprint_part) + .collect::>() + .join("\n"); + Some(sha256_hex(&prefix)) +} + +fn dominant_stable_prefix_fingerprint( + observations: &[PromptIR], + stable_prefix_length: usize, +) -> Option { + let mut counts: HashMap = HashMap::new(); + for observation in observations { + let Some(fingerprint) = prompt_prefix_fingerprint(observation, stable_prefix_length) else { + continue; + }; + *counts.entry(fingerprint).or_insert(0) += 1; + } + + select_dominant_prefix_fingerprint(counts) +} + +fn select_dominant_prefix_fingerprint(counts: HashMap) -> Option { + counts + .into_iter() + .fold(None, |best, candidate| match best { + None => Some(candidate), + Some((best_fingerprint, best_count)) => { + let (candidate_fingerprint, candidate_count) = candidate; + if candidate_count > best_count + || (candidate_count == best_count && candidate_fingerprint < best_fingerprint) + { + Some((candidate_fingerprint, candidate_count)) + } else { + Some((best_fingerprint, best_count)) + } + } + }) + .map(|(fingerprint, _)| fingerprint) +} + +fn sort_indexed_scores(indexed_scores: &mut [(u32, BlockStabilityScore)]) { + indexed_scores.sort_by(|(left_index, left_score), (right_index, right_score)| { + left_index + .cmp(right_index) + .then_with(|| left_score.span_id.0.cmp(&right_score.span_id.0)) + }); +} + +fn block_fingerprint_part(block: &PromptBlock) -> String { + format!( + "{}|{}|{}|{}", + block.span_id.0, + prompt_role_tag(block.role), + content_type_tag(block.content_type), + sha256_hex(&block.content) + ) +} + +fn prompt_role_tag(role: PromptRole) -> &'static str { + match role { + PromptRole::System => "system", + PromptRole::User => "user", + PromptRole::Assistant => "assistant", + PromptRole::Tool => "tool", + } +} + +fn content_type_tag(content_type: BlockContentType) -> &'static str { + match content_type { + BlockContentType::Text => "text", + BlockContentType::ToolSchema => "tool_schema", + BlockContentType::ToolResult => "tool_result", + BlockContentType::StructuredOutput => "structured_output", + BlockContentType::Image => "image", } } diff --git a/crates/adaptive/src/acg_component.rs b/crates/adaptive/src/acg_component.rs index 7d5ccaddc..d83ec90ef 100644 --- a/crates/adaptive/src/acg_component.rs +++ b/crates/adaptive/src/acg_component.rs @@ -14,6 +14,7 @@ use serde_json::json; use crate::acg::economics; use crate::acg::plugin::{PluginInput, ProviderPlugin}; use crate::acg::request_surfaces::{RequestSurface, resolve_request_surface_from_request}; +use crate::acg::stability::prompt_prefix_fingerprint; use crate::acg::translation::anthropic::AnthropicHintTranslator; use crate::acg::translation::openai::OpenAIHintTranslator; use crate::acg::translation::{HintPlan, HintTranslation, HintTranslator}; @@ -144,6 +145,19 @@ fn build_intent_bundle( ); return None; } + if !stable_prefix_fingerprint_matches_prompt_ir(stability, prompt_ir) { + acg_debug::emit( + "build_intent_bundle_skipped", + json!({ + "reason": "stable_prefix_fingerprint_mismatch", + "agent_id": agent_id, + "provider": provider, + "observation_count": observation_count, + "stable_prefix_length": stability.stable_prefix_length, + }), + ); + return None; + } let toolset_hash = annotated_request .tools @@ -210,6 +224,18 @@ fn build_intent_bundle( }) } +fn stable_prefix_fingerprint_matches_prompt_ir( + stability: &StabilityAnalysisResult, + prompt_ir: &crate::acg::PromptIR, +) -> bool { + let Some(fingerprint) = stability.stable_prefix_fingerprint.as_deref() else { + return stability.stable_prefix_length == 0; + }; + + prompt_prefix_fingerprint(prompt_ir, stability.stable_prefix_length).as_deref() + == Some(fingerprint) +} + fn build_cache_stability_intent( stability: &StabilityAnalysisResult, stable_prefix_end: usize, diff --git a/crates/adaptive/src/acg_learner.rs b/crates/adaptive/src/acg_learner.rs index 5c0a31ae3..d94559e0d 100644 --- a/crates/adaptive/src/acg_learner.rs +++ b/crates/adaptive/src/acg_learner.rs @@ -8,14 +8,19 @@ use std::future::Future; use std::pin::Pin; use std::sync::{Arc, RwLock}; +use serde_json::json; + +use crate::acg::debug as acg_debug; use crate::acg::ir_builder::build_prompt_ir; use crate::acg::prompt_ir::PromptIR; -use crate::acg::stability::{StabilityThresholds, analyze_stability}; +use crate::acg::stability::{StabilityThresholds, analyze_stability, prompt_prefix_fingerprint}; +use crate::config::ConvergenceConfig; use crate::acg_profile::derive_acg_learning_key; use crate::error::{AdaptiveError, Result}; use crate::learner::traits::Learner; use crate::storage::traits::StorageBackendDyn; +use crate::topology::{BettiNumbers, ConvergenceDetector}; use crate::types::cache::HotCache; use crate::types::records::{CallKind, RunRecord}; @@ -28,6 +33,8 @@ pub struct AcgLearner { agent_id: String, observation_window: usize, thresholds: StabilityThresholds, + convergence: Option, + convergence_detectors: Arc>>, } impl AcgLearner { @@ -45,12 +52,126 @@ impl AcgLearner { agent_id: impl Into, observation_window: usize, thresholds: StabilityThresholds, + ) -> Self { + Self::new_with_convergence(agent_id, observation_window, thresholds, None) + } + + /// Create a new ACG learner with optional topological convergence + /// detection. + /// + /// # Parameters + /// - `agent_id`: Agent identifier whose observations should be updated. + /// - `observation_window`: Maximum number of observations to retain per + /// profile. + /// - `thresholds`: Stability thresholds used during analysis. + /// - `convergence`: Optional convergence configuration; takes precedence + /// over any global settings when provided. + /// + /// # Returns + /// A configured [`AcgLearner`]. + pub fn new_with_convergence( + agent_id: impl Into, + observation_window: usize, + thresholds: StabilityThresholds, + convergence: Option, ) -> Self { Self { agent_id: agent_id.into(), observation_window, thresholds, + convergence, + convergence_detectors: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Map a stability analysis result to the topological feature vector used + /// by the convergence detector. + /// + /// The mapping treats the stable reusable prefix as the cacheable topology. + /// Suffix spans are cache-opaque work-item content, so they do not create + /// topology holes for convergence. Drift measures whether a cacheable + /// prefix boundary has appeared, and error is the complement of the average + /// stable-prefix score. + fn stability_to_convergence_features( + stability: &crate::acg::stability::StabilityAnalysisResult, + ) -> (BettiNumbers, f64, f64) { + let betti_0 = stability.stable_prefix_length as u32; + let betti_1 = 0; + let drift = if stability.stable_prefix_length == 0 { + 1.0 + } else { + 0.0 + }; + let stable_prefix_len = stability.stable_prefix_length.min(stability.scores.len()); + let avg_score = if stable_prefix_len == 0 { + 0.0 + } else { + stability + .scores + .iter() + .take(stable_prefix_len) + .map(|score| score.score) + .sum::() + / stable_prefix_len as f64 + }; + let error = 1.0 - avg_score; + + (BettiNumbers::new(betti_0, betti_1), drift, error) + } + + fn prompt_topology_matches_stability( + stability: &crate::acg::stability::StabilityAnalysisResult, + observation: &PromptIR, + ) -> bool { + let stable_prefix_length = stability.stable_prefix_length; + stable_prefix_length > 0 + && stability.scores.len() >= stable_prefix_length + && observation.blocks.len() >= stable_prefix_length + && stability + .stable_prefix_fingerprint + .as_deref() + .is_some_and(|fingerprint| { + prompt_prefix_fingerprint(observation, stable_prefix_length).as_deref() + == Some(fingerprint) + }) + && stability + .scores + .iter() + .take(stable_prefix_length) + .zip(observation.blocks.iter().take(stable_prefix_length)) + .all(|(score, block)| score.span_id == block.span_id) + } + + /// Update the per-profile topological convergence detector and return + /// whether the profile has converged. + fn record_stability_epoch( + &self, + profile_key: &str, + stability: &crate::acg::stability::StabilityAnalysisResult, + ) -> Result { + let Some(ref config) = self.convergence else { + return Ok(false); + }; + if !config.enabled { + return Ok(false); } + + let mut detectors = self.convergence_detectors.write().map_err(|error| { + AdaptiveError::Internal(format!("convergence detector lock poisoned: {error}")) + })?; + let stability_window = config.stability_window.max(3); + let detector = detectors + .entry(profile_key.to_string()) + .or_insert_with(|| ConvergenceDetector::new(config.epsilon, stability_window)); + + let (betti, drift, error) = Self::stability_to_convergence_features(stability); + detector.record_epoch(betti, drift, error); + + // Require at least `stability_window` epochs before allowing + // convergence so that error-based convergence cannot fire on the very + // first observation. + let enough_epochs = detector.epoch() as usize >= stability_window; + Ok(detector.is_converged() && enough_epochs) } } @@ -87,9 +208,58 @@ impl Learner for AcgLearner { Vec, crate::acg::stability::StabilityAnalysisResult, )> = None; + let mut best_aggregate_stability: Option< + crate::acg::stability::StabilityAnalysisResult, + > = None; for (profile_key, new_observations) in grouped_observations.drain() { + let existing_stability = backend.load_stability(&profile_key).await?; + let stability_window = self + .convergence + .as_ref() + .map(|config| config.stability_window.max(3)) + .unwrap_or(3); + + // If the profile has already converged, reuse the cached + // stability result and skip loading or adding observations. + // Stale records below the stability window fall through to + // the normal repair path. Requests whose span topology changed + // under the same learning key also reopen learning. + if let Some(cached) = existing_stability.as_ref().filter(|stability| { + stability.converged + && stability.total_observations as usize >= stability_window + && new_observations.iter().all(|observation| { + Self::prompt_topology_matches_stability(stability, observation) + }) + }) { + profile_counts.insert(profile_key.clone(), cached.total_observations); + profile_stability.insert(profile_key.clone(), cached.clone()); + + let replace_best = best_aggregate_stability + .as_ref() + .map(|current| { + (cached.stable_prefix_length, cached.total_observations) + > (current.stable_prefix_length, current.total_observations) + }) + .unwrap_or(true); + if replace_best { + best_aggregate_stability = Some(cached.clone()); + } + acg_debug::emit( + "learner_profile_reused", + json!({ + "agent_id": self.agent_id, + "learning_key": profile_key, + "total_observations": cached.total_observations, + "stable_prefix_length": cached.stable_prefix_length, + "converged": cached.converged, + }), + ); + continue; + } + let existing = backend.load_observations(&profile_key).await?; + let mut window: VecDeque = existing.unwrap_or_default().into_iter().collect(); @@ -101,17 +271,41 @@ impl Learner for AcgLearner { } let observations_vec: Vec = window.into_iter().collect(); + let mut stability_result = analyze_stability(&observations_vec, &self.thresholds); + + let converged_now = self.record_stability_epoch(&profile_key, &stability_result)?; + + // Store the observations that produced this stability result. + // On the epoch that first declares convergence these + // observations are preserved; on subsequent runs the cached + // converged result is reused and this path is skipped. backend .store_observations(&profile_key, &observations_vec) .await?; - let stability_result = analyze_stability(&observations_vec, &self.thresholds); + if converged_now { + stability_result.converged = true; + } + backend .store_stability(&profile_key, &stability_result) .await?; + acg_debug::emit( + "learner_profile_updated", + json!({ + "agent_id": self.agent_id, + "learning_key": profile_key, + "total_observations": stability_result.total_observations, + "stable_prefix_length": stability_result.stable_prefix_length, + "converged": stability_result.converged, + "converged_now": converged_now, + "stability_window": stability_window, + }), + ); + profile_counts.insert(profile_key.clone(), stability_result.total_observations); - profile_stability.insert(profile_key, stability_result.clone()); + profile_stability.insert(profile_key.clone(), stability_result.clone()); let replace_best = best_profile_seed .as_ref() @@ -124,6 +318,7 @@ impl Learner for AcgLearner { .unwrap_or(true); if replace_best { best_profile_seed = Some((observations_vec.clone(), stability_result.clone())); + best_aggregate_stability = Some(stability_result.clone()); } } @@ -144,7 +339,7 @@ impl Learner for AcgLearner { })?; guard.acg_profiles.extend(profile_stability); guard.acg_profile_observation_counts.extend(profile_counts); - if let Some((_, aggregate_stability)) = best_profile_seed { + if let Some(aggregate_stability) = best_aggregate_stability { guard.acg_observation_count = aggregate_stability.total_observations; guard.acg_stability = Some(aggregate_stability); } diff --git a/crates/adaptive/src/acg_profile.rs b/crates/adaptive/src/acg_profile.rs index 01e5c3b29..7e76cfa3e 100644 --- a/crates/adaptive/src/acg_profile.rs +++ b/crates/adaptive/src/acg_profile.rs @@ -14,26 +14,37 @@ const HASH_PREFIX_LEN: usize = 16; struct AcgKeyParts<'a> { model: &'a str, system_hash: String, + anchor_hash: String, tool_hash: String, + contract_hash: String, } /// Derive the stable ACG learning key used to bucket observations and hot-cache state. /// /// The learning key intentionally excludes the full role sequence because normal -/// multi-turn conversations grow every request. Instead it uses a coarse -/// conversation class plus the stable template fingerprints that should remain -/// distinct across prompt families. +/// multi-turn conversations grow every request. When the request has a stable +/// scaffold such as system policy, tool schemas, or an output contract, the key +/// follows that cacheable scaffold and leaves volatile work-item text out of +/// the bucket. One-off prompts without any scaffold retain a seed hash so +/// unrelated direct prompts are not collapsed together. pub(crate) fn derive_acg_learning_key( agent_id: &str, annotated_request: &AnnotatedLlmRequest, ) -> String { let parts = derive_key_parts(annotated_request); - let seed_fingerprint = learning_seed_fingerprint(annotated_request); - let seed_hash = short_hash(&seed_fingerprint); - format!( - "{agent_id}::model={}::seed={seed_hash}::system={}::tools={}", - parts.model, parts.system_hash, parts.tool_hash - ) + if has_cacheable_scaffold(&parts) { + format!( + "{agent_id}::model={}::scaffold=stable::system={}::tools={}", + parts.model, parts.system_hash, parts.tool_hash + ) + } else { + let seed_fingerprint = learning_seed_fingerprint(annotated_request); + let seed_hash = short_hash(&seed_fingerprint); + format!( + "{agent_id}::model={}::seed={seed_hash}::system={}::tools={}", + parts.model, parts.system_hash, parts.tool_hash + ) + } } /// Derive the exact ACG profile key used for diagnostics and debug output. @@ -45,11 +56,6 @@ pub(crate) fn derive_acg_profile_key( annotated_request: &AnnotatedLlmRequest, ) -> String { let parts = derive_key_parts(annotated_request); - let anchor_fingerprint = layered_anchor_fingerprint(annotated_request); - let anchor_hash = anchor_fingerprint - .as_deref() - .map(short_hash) - .unwrap_or("no-anchor"); let role_signature = annotated_request .messages .iter() @@ -57,22 +63,37 @@ pub(crate) fn derive_acg_profile_key( .collect::>() .join("."); format!( - "{agent_id}::model={}::roles={role_signature}::system={}::anchor={}::tools={}", - parts.model, parts.system_hash, anchor_hash, parts.tool_hash + "{agent_id}::model={}::roles={role_signature}::system={}::anchor={}::tools={}::contract={}", + parts.model, parts.system_hash, parts.anchor_hash, parts.tool_hash, parts.contract_hash ) } fn derive_key_parts(annotated_request: &AnnotatedLlmRequest) -> AcgKeyParts<'_> { let system_fingerprint = system_prompt_fingerprint(annotated_request); + let anchor_fingerprint = layered_anchor_fingerprint(annotated_request); let tool_fingerprint = tool_schema_fingerprint(annotated_request.tools.as_deref()); + let contract_fingerprint = output_contract_fingerprint(annotated_request); AcgKeyParts { model: annotated_request.model.as_deref().unwrap_or("unknown"), system_hash: short_hash(&system_fingerprint).to_string(), + anchor_hash: anchor_fingerprint + .as_deref() + .map(short_hash) + .unwrap_or("no-anchor") + .to_string(), tool_hash: short_hash(&tool_fingerprint).to_string(), + contract_hash: short_hash(&contract_fingerprint).to_string(), } } +fn has_cacheable_scaffold(parts: &AcgKeyParts<'_>) -> bool { + parts.system_hash != "no-system" + || parts.anchor_hash != "no-anchor" + || !matches!(parts.tool_hash.as_str(), "no-tools" | "tools-unavailab") + || parts.contract_hash != "no-contract" +} + fn message_role_tag(message: &Message) -> &'static str { match message { Message::System { .. } => "system", @@ -183,6 +204,16 @@ fn tool_schema_fingerprint(tools: Option<&[ToolDefinition]>) -> String { } } +fn output_contract_fingerprint(annotated_request: &AnnotatedLlmRequest) -> String { + let Some(contract) = annotated_request.extra.get("response_format") else { + return "no-contract".to_string(); + }; + + canonicalize_value(contract) + .map(|canonical| sha256_hex(&canonical)) + .unwrap_or_else(|_| "contract-unavailable".to_string()) +} + fn extract_text(content: &MessageContent) -> String { match content { MessageContent::Text(text) => text.clone(), diff --git a/crates/adaptive/src/adaptive_hints_intercept.rs b/crates/adaptive/src/adaptive_hints_intercept.rs index 05613aadb..3afbc4c96 100644 --- a/crates/adaptive/src/adaptive_hints_intercept.rs +++ b/crates/adaptive/src/adaptive_hints_intercept.rs @@ -10,16 +10,19 @@ //! transforms the [`LlmRequest`] before it reaches the callable. use std::sync::atomic::{AtomicU32, Ordering}; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; +use std::time::Instant; use nemo_relay::api::llm::LlmRequest; use nemo_relay::api::runtime::LlmRequestInterceptFn; use nemo_relay::codec::request::AnnotatedLlmRequest; +use crate::config::GovernorConfig; use crate::context_helpers::{ extract_scope_path, read_manual_latency_sensitivity, resolve_agent_id, }; use crate::intercepts::AGENT_HINTS_HEADER_KEY; +use crate::topology::GeometricGovernor; use crate::trie::builder::SensitivityConfig; use crate::trie::lookup::PredictionTrieLookup; use crate::types::cache::HotCache; @@ -121,15 +124,29 @@ pub struct AdaptiveHintsIntercept { hot_cache: Arc>, agent_id: String, call_counter: AtomicU32, + governor: Option>>, } impl AdaptiveHintsIntercept { /// Creates a new `AdaptiveHintsIntercept`. pub fn new(hot_cache: Arc>, agent_id: String) -> Self { + Self::with_governor(hot_cache, agent_id, None) + } + + /// Creates a new `AdaptiveHintsIntercept` with optional load shedding. + pub fn with_governor( + hot_cache: Arc>, + agent_id: String, + governor: Option, + ) -> Self { + let governor = governor + .filter(|config| config.enabled) + .map(|config| Arc::new(Mutex::new(HintGovernor::new(config.epsilon)))); Self { hot_cache, agent_id, call_counter: AtomicU32::new(1), + governor, } } @@ -163,6 +180,21 @@ impl AdaptiveHintsIntercept { } } + fn should_inject_hints(&self, hints: &AgentHints, manual_ls: Option) -> bool { + if manual_ls.is_some() { + return true; + } + + let Some(governor) = &self.governor else { + return true; + }; + + governor + .lock() + .map(|mut governor| governor.allow(hints.latency_sensitivity)) + .unwrap_or(true) + } + /// Converts this intercept into an [`LlmRequestInterceptFn`] suitable for /// registration with [`register_llm_request_intercept`]. /// @@ -188,7 +220,9 @@ impl AdaptiveHintsIntercept { scope_depth, ); - if let Some(hints) = final_hints { + if let Some(hints) = final_hints + && this.should_inject_hints(&hints, manual_ls) + { inject_agent_hints(&mut request, &hints); } @@ -198,6 +232,32 @@ impl AdaptiveHintsIntercept { } } +struct HintGovernor { + governor: GeometricGovernor, + last_seen: Option, +} + +impl HintGovernor { + fn new(epsilon: f64) -> Self { + Self { + governor: GeometricGovernor::with_epsilon(epsilon), + last_seen: None, + } + } + + fn allow(&mut self, latency_sensitivity: f64) -> bool { + let allow = self.governor.should_trigger(latency_sensitivity); + let now = Instant::now(); + if let Some(last_seen) = self.last_seen { + let dt = now.duration_since(last_seen).as_secs_f64().max(0.000_001); + let observed_rate = 1.0 / dt; + self.governor.adapt(observed_rate, dt); + } + self.last_seen = Some(now); + allow + } +} + #[cfg(test)] #[path = "../tests/unit/adaptive_hints_intercept_tests.rs"] mod tests; diff --git a/crates/adaptive/src/config.rs b/crates/adaptive/src/config.rs index c24bc4e50..799cb955b 100644 --- a/crates/adaptive/src/config.rs +++ b/crates/adaptive/src/config.rs @@ -32,6 +32,9 @@ pub struct AdaptiveConfig { /// Adaptive Cache Governor settings. #[serde(default, skip_serializing_if = "Option::is_none")] pub acg: Option, + /// Global topological convergence settings. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub convergence: Option, /// Adaptive-local unsupported-config policy. #[serde(default)] pub policy: ConfigPolicy, @@ -47,6 +50,7 @@ impl Default for AdaptiveConfig { adaptive_hints: None, tool_parallelism: None, acg: None, + convergence: None, policy: ConfigPolicy::default(), } } @@ -123,6 +127,9 @@ pub struct AdaptiveHintsComponentConfig { /// JSON path used when injecting request-body hints. #[serde(default = "default_adaptive_hints_path")] pub inject_body_path: String, + /// Optional topology-aware load-shedding governor for hint injection. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub governor: Option, } impl Default for AdaptiveHintsComponentConfig { @@ -132,6 +139,7 @@ impl Default for AdaptiveHintsComponentConfig { break_chain: false, inject_header: true, inject_body_path: default_adaptive_hints_path(), + governor: None, } } } @@ -145,6 +153,9 @@ pub struct ToolParallelismComponentConfig { /// Scheduling mode such as `observe_only`, `inject_hints`, or `schedule`. #[serde(default = "default_tool_parallelism_mode")] pub mode: String, + /// Optional topology-aware drift detector for stale plan invalidation. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub drift: Option, } impl Default for ToolParallelismComponentConfig { @@ -152,6 +163,71 @@ impl Default for ToolParallelismComponentConfig { Self { priority: default_priority(), mode: default_tool_parallelism_mode(), + drift: None, + } + } +} + +/// Typed helper for topology-aware hint load shedding. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GovernorConfig { + /// Whether the governor is active. + #[serde(default)] + pub enabled: bool, + /// Initial sensitivity threshold used by the governor. + #[serde(default = "default_governor_epsilon")] + pub epsilon: f64, +} + +impl Default for GovernorConfig { + fn default() -> Self { + Self { + enabled: false, + epsilon: default_governor_epsilon(), + } + } +} + +/// Typed helper for topology-aware tool plan drift detection. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DriftConfig { + /// Whether drift detection is active. + #[serde(default)] + pub enabled: bool, + /// Drift distance above which the existing execution plan is invalidated. + #[serde(default = "default_drift_threshold")] + pub threshold: f64, +} + +impl Default for DriftConfig { + fn default() -> Self { + Self { + enabled: false, + threshold: default_drift_threshold(), + } + } +} + +/// Typed helper for topological convergence detection settings. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConvergenceConfig { + /// Whether convergence detection is active. + #[serde(default)] + pub enabled: bool, + /// Error threshold below which the detector is considered converged. + #[serde(default = "default_convergence_epsilon")] + pub epsilon: f64, + /// Minimum number of epochs required to judge Betti-number stability. + #[serde(default = "default_convergence_stability_window")] + pub stability_window: usize, +} + +impl Default for ConvergenceConfig { + fn default() -> Self { + Self { + enabled: false, + epsilon: default_convergence_epsilon(), + stability_window: default_convergence_stability_window(), } } } @@ -171,6 +247,9 @@ pub struct AcgComponentConfig { /// Stability classification thresholds used by the learner. #[serde(default)] pub stability_thresholds: crate::acg::stability::StabilityThresholds, + /// Optional component-scoped topological convergence settings. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub convergence: Option, } impl Default for AcgComponentConfig { @@ -180,6 +259,7 @@ impl Default for AcgComponentConfig { observation_window: default_acg_observation_window(), priority: default_acg_priority(), stability_thresholds: crate::acg::stability::StabilityThresholds::default(), + convergence: None, } } } @@ -216,6 +296,22 @@ fn default_acg_priority() -> i32 { 50 } +fn default_convergence_epsilon() -> f64 { + 0.001 +} + +fn default_convergence_stability_window() -> usize { + 3 +} + +fn default_governor_epsilon() -> f64 { + 1.0 +} + +fn default_drift_threshold() -> f64 { + 0.75 +} + nemo_relay::editor_config! { impl AdaptiveConfig { agent_id => { label: "fallback_agent_id", kind: String, optional: true }, @@ -254,6 +350,13 @@ nemo_relay::editor_config! { nested: AcgComponentConfig, default: AcgComponentConfig, }, + convergence => { + label: "convergence", + kind: Section, + optional: true, + nested: ConvergenceConfig, + default: ConvergenceConfig, + }, policy => { label: "policy", kind: Section, @@ -294,6 +397,13 @@ nemo_relay::editor_config! { break_chain => { label: "break_chain", kind: Boolean }, inject_header => { label: "inject_header", kind: Boolean }, inject_body_path => { label: "inject_body_path", kind: String }, + governor => { + label: "governor", + kind: Section, + optional: true, + nested: GovernorConfig, + default: GovernorConfig, + }, } } @@ -305,6 +415,27 @@ nemo_relay::editor_config! { kind: Enum, values: ["observe_only", "inject_hints", "schedule"], }, + drift => { + label: "drift", + kind: Section, + optional: true, + nested: DriftConfig, + default: DriftConfig, + }, + } +} + +nemo_relay::editor_config! { + impl GovernorConfig { + enabled => { label: "enabled", kind: Boolean }, + epsilon => { label: "epsilon", kind: Float }, + } +} + +nemo_relay::editor_config! { + impl DriftConfig { + enabled => { label: "enabled", kind: Boolean }, + threshold => { label: "threshold", kind: Float }, } } @@ -323,6 +454,21 @@ nemo_relay::editor_config! { nested: crate::acg::stability::StabilityThresholds, default: crate::acg::stability::StabilityThresholds, }, + convergence => { + label: "convergence", + kind: Section, + optional: true, + nested: ConvergenceConfig, + default: ConvergenceConfig, + }, + } +} + +nemo_relay::editor_config! { + impl ConvergenceConfig { + enabled => { label: "enabled", kind: Boolean }, + epsilon => { label: "epsilon", kind: Float }, + stability_window => { label: "stability_window", kind: Integer }, } } diff --git a/crates/adaptive/src/lib.rs b/crates/adaptive/src/lib.rs index 9843a2e56..f8e8197ba 100644 --- a/crates/adaptive/src/lib.rs +++ b/crates/adaptive/src/lib.rs @@ -31,13 +31,15 @@ pub mod storage; pub mod subscriber; /// Learner that derives tool fan-out plans from observed runs. pub mod tool_parallelism_learner; +pub(crate) mod topology; pub mod trie; /// Serializable adaptive data models shared across runtime components. pub mod types; pub use config::{ - AcgComponentConfig, AdaptiveConfig, AdaptiveHintsComponentConfig, BackendSpec, StateConfig, - TelemetryComponentConfig, ToolParallelismComponentConfig, + AcgComponentConfig, AdaptiveConfig, AdaptiveHintsComponentConfig, BackendSpec, + ConvergenceConfig, DriftConfig, GovernorConfig, StateConfig, TelemetryComponentConfig, + ToolParallelismComponentConfig, }; pub use context_helpers::{ LATENCY_SENSITIVITY_POINTER, extract_scope_path, read_manual_latency_sensitivity, @@ -50,3 +52,9 @@ pub use runtime::features::AdaptiveRuntime; pub use storage::erased::AnyBackend; pub use storage::memory::InMemoryBackend; pub use storage::traits::{StorageBackend, StorageBackendDyn}; + +#[cfg(test)] +pub(crate) mod test_support { + pub(crate) static GLOBAL_RUNTIME_TEST_MUTEX: tokio::sync::Mutex<()> = + tokio::sync::Mutex::const_new(()); +} diff --git a/crates/adaptive/src/plugin_component.rs b/crates/adaptive/src/plugin_component.rs index 62ccfd4b6..f1d1ccdfa 100644 --- a/crates/adaptive/src/plugin_component.rs +++ b/crates/adaptive/src/plugin_component.rs @@ -182,6 +182,7 @@ fn validate_adaptive_plugin_config(plugin_config: &Map) -> Vec) -> Vec) -> Vec) -> Vec, acg_config: Option, + convergence: Option, ) -> Self { let subscriber_name = config .subscriber_name .unwrap_or_else(|| format!("adaptive_{runtime_id}_subscriber")); Self { - learners: build_learners(&agent_id, &config.learners, acg_config.as_ref()), + learners: build_learners( + &agent_id, + &config.learners, + tool_parallelism_config.as_ref(), + acg_config.as_ref(), + convergence.as_ref(), + ), agent_id, subscriber_name, } @@ -597,6 +607,7 @@ struct AdaptiveHintsFeature { break_chain: bool, hot_cache: Arc>, agent_id: String, + governor: Option, } impl AdaptiveHintsFeature { @@ -612,6 +623,7 @@ impl AdaptiveHintsFeature { break_chain: config.break_chain, hot_cache, agent_id, + governor: config.governor, } } } @@ -622,8 +634,11 @@ impl AdaptiveFeature for AdaptiveHintsFeature { ctx: &'a mut RegistrationContext<'_>, ) -> Pin> + Send + 'a>> { Box::pin(async move { - let adaptive_hints = - AdaptiveHintsIntercept::new(self.hot_cache.clone(), self.agent_id.clone()); + let adaptive_hints = AdaptiveHintsIntercept::with_governor( + self.hot_cache.clone(), + self.agent_id.clone(), + self.governor.clone(), + ); ctx.register_llm_request_intercept( &self.name, self.priority, @@ -770,7 +785,9 @@ impl AdaptiveFeature for AcgFeature { fn build_learners( agent_id: &str, learners: &[String], + tool_parallelism_config: Option<&ToolParallelismComponentConfig>, acg_config: Option<&AcgComponentConfig>, + convergence: Option<&crate::config::ConvergenceConfig>, ) -> Vec> { let mut built: Vec> = vec![]; for learner in learners { @@ -779,13 +796,21 @@ fn build_learners( agent_id, crate::trie::builder::SensitivityConfig::default(), ))), - "tool_parallelism" => built.push(Box::new(ToolParallelismLearner::new(agent_id))), + "tool_parallelism" => { + let drift = tool_parallelism_config.and_then(|config| config.drift.clone()); + built.push(Box::new(ToolParallelismLearner::new_with_drift( + agent_id, drift, + ))); + } "acg" => { if let Some(config) = acg_config { - built.push(Box::new(AcgLearner::new( + let profile_convergence = + config.convergence.clone().or_else(|| convergence.cloned()); + built.push(Box::new(AcgLearner::new_with_convergence( agent_id, config.observation_window, config.stability_thresholds.clone(), + profile_convergence, ))); } } diff --git a/crates/adaptive/src/runtime/validation.rs b/crates/adaptive/src/runtime/validation.rs index 8d0393791..e3c93aa66 100644 --- a/crates/adaptive/src/runtime/validation.rs +++ b/crates/adaptive/src/runtime/validation.rs @@ -62,6 +62,29 @@ pub fn validate_config(config: &AdaptiveConfig) -> ConfigReport { ), ); } + if let Some(tool_parallelism) = &config.tool_parallelism + && let Some(drift) = &tool_parallelism.drift + { + validate_positive_finite( + &mut report, + &config.policy, + "tool_parallelism.drift", + "threshold", + drift.threshold, + ); + } + + if let Some(adaptive_hints) = &config.adaptive_hints + && let Some(governor) = &adaptive_hints.governor + { + validate_positive_finite( + &mut report, + &config.policy, + "adaptive_hints.governor", + "epsilon", + governor.epsilon, + ); + } if let Some(acg) = &config.acg && acg.provider != "anthropic" @@ -80,10 +103,60 @@ pub fn validate_config(config: &AdaptiveConfig) -> ConfigReport { ), ); } + if let Some(acg) = &config.acg + && let Some(convergence) = &acg.convergence + { + validate_convergence(&mut report, &config.policy, "acg.convergence", convergence); + } + if let Some(convergence) = &config.convergence { + validate_convergence(&mut report, &config.policy, "convergence", convergence); + } report } +fn validate_convergence( + report: &mut ConfigReport, + policy: &ConfigPolicy, + component: &str, + convergence: &crate::config::ConvergenceConfig, +) { + validate_positive_finite(report, policy, component, "epsilon", convergence.epsilon); + if convergence.stability_window < 3 { + push_policy_diag( + &mut report.diagnostics, + policy.unsupported_value, + "adaptive.unsupported_value", + Some(component.to_string()), + Some("stability_window".to_string()), + format!( + "{component} stability_window must be at least 3, got {}", + convergence.stability_window + ), + ); + } +} + +fn validate_positive_finite( + report: &mut ConfigReport, + policy: &ConfigPolicy, + component: &str, + field: &str, + value: f64, +) { + if value.is_finite() && value > 0.0 { + return; + } + push_policy_diag( + &mut report.diagnostics, + policy.unsupported_value, + "adaptive.unsupported_value", + Some(component.to_string()), + Some(field.to_string()), + format!("{component} {field} must be a positive finite number, got {value}"), + ); +} + fn validate_backend(report: &mut ConfigReport, policy: &ConfigPolicy, backend: &BackendSpec) { let kind = backend.kind.as_str(); match kind { diff --git a/crates/adaptive/src/tool_parallelism_learner.rs b/crates/adaptive/src/tool_parallelism_learner.rs index f42340736..aef5b3105 100644 --- a/crates/adaptive/src/tool_parallelism_learner.rs +++ b/crates/adaptive/src/tool_parallelism_learner.rs @@ -13,9 +13,11 @@ use chrono::{DateTime, Utc}; use serde_json::json; use uuid::Uuid; +use crate::config::DriftConfig; use crate::error::{AdaptiveError, Result}; use crate::learner::traits::Learner; use crate::storage::traits::StorageBackendDyn; +use crate::topology::DriftDetector; use crate::types::cache::HotCache; use crate::types::metadata::{MetadataEnvelope, ParallelHint}; use crate::types::plan::{ExecutionPlan, ParallelGroup}; @@ -24,6 +26,8 @@ use crate::types::records::{CallKind, RunRecord}; /// Learner that discovers tool fan-out groups from run telemetry. pub struct ToolParallelismLearner { agent_id: String, + drift: Option, + drift_detector: Arc>>, } impl ToolParallelismLearner { @@ -35,10 +39,40 @@ impl ToolParallelismLearner { /// # Returns /// A configured [`ToolParallelismLearner`]. pub fn new(agent_id: impl Into) -> Self { + Self::new_with_drift(agent_id, None) + } + + /// Create a new tool-parallelism learner with optional drift detection. + /// + /// # Parameters + /// - `agent_id`: Agent identifier whose execution plan should be updated. + /// - `drift`: Optional topology-aware drift detection settings. + /// + /// # Returns + /// A configured [`ToolParallelismLearner`]. + pub fn new_with_drift(agent_id: impl Into, drift: Option) -> Self { Self { agent_id: agent_id.into(), + drift, + drift_detector: Arc::new(RwLock::new(DriftDetector::new())), } } + + fn record_cohort_drift(&self, observed_cohorts: &[Vec]) -> Result { + let Some(config) = &self.drift else { + return Ok(false); + }; + if !config.enabled { + return Ok(false); + } + + let centroid = cohort_feature_vector(observed_cohorts); + let mut detector = self.drift_detector.write().map_err(|error| { + AdaptiveError::Internal(format!("tool drift detector lock poisoned: {error}")) + })?; + let drift = detector.update(¢roid); + Ok(drift > config.threshold) + } } impl Learner for ToolParallelismLearner { @@ -54,10 +88,15 @@ impl Learner for ToolParallelismLearner { return Ok(()); } - let mut plan = backend - .load_plan_dyn(&self.agent_id) - .await? - .unwrap_or_else(|| empty_execution_plan(&self.agent_id, run.id)); + let drifted = self.record_cohort_drift(&observed_cohorts)?; + let mut plan = if drifted { + empty_execution_plan(&self.agent_id, run.id) + } else { + backend + .load_plan_dyn(&self.agent_id) + .await? + .unwrap_or_else(|| empty_execution_plan(&self.agent_id, run.id)) + }; plan.agent_id = self.agent_id.clone(); merge_observed_cohorts(&mut plan, &observed_cohorts, run.id); @@ -113,6 +152,32 @@ fn derive_observed_cohorts(run: &RunRecord) -> Vec> { observed } +fn cohort_feature_vector(observed_cohorts: &[Vec]) -> [f64; 4] { + if observed_cohorts.is_empty() { + return [0.0; 4]; + } + + let mut unique_tools = BTreeSet::new(); + let mut total_tool_refs = 0usize; + let mut max_cohort_size = 0usize; + + for cohort in observed_cohorts { + total_tool_refs += cohort.len(); + max_cohort_size = max_cohort_size.max(cohort.len()); + for tool in cohort { + unique_tools.insert(tool); + } + } + + let duplicate_refs = total_tool_refs.saturating_sub(unique_tools.len()); + [ + observed_cohorts.len() as f64, + unique_tools.len() as f64, + duplicate_refs as f64 / total_tool_refs.max(1) as f64, + max_cohort_size as f64, + ] +} + fn merge_observed_cohorts( plan: &mut ExecutionPlan, observed_cohorts: &[Vec], diff --git a/crates/adaptive/src/topology.rs b/crates/adaptive/src/topology.rs new file mode 100644 index 000000000..932f60389 --- /dev/null +++ b/crates/adaptive/src/topology.rs @@ -0,0 +1,361 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Internal topology-aware control primitives for adaptive learners. + +/// Maximum number of epochs retained in each history ring buffer. +const MAX_HISTORY: usize = 32; + +/// Minimum stability window length. +const MIN_STABILITY_WINDOW: usize = 3; + +/// Drift values below this threshold are considered converged. +const FINAL_DRIFT_THRESHOLD: f64 = 0.01; + +/// Default convergence error threshold. +const DEFAULT_EPSILON: f64 = 0.001; + +/// Target effective tick rate the governor tries to maintain, measured in Hz. +const TARGET_TICK_RATE: f64 = 1000.0; + +/// Proportional gain applied to the instantaneous control error. +const GOVERNOR_ALPHA: f64 = 0.01; + +/// Derivative gain applied to the rate of change of the control error. +const GOVERNOR_BETA: f64 = 0.05; + +/// Minimum allowed governor threshold. +const GOVERNOR_EPSILON_MIN: f64 = 0.001; + +/// Maximum allowed governor threshold. +const GOVERNOR_EPSILON_MAX: f64 = 10.0; + +/// Default governor threshold. +const GOVERNOR_EPSILON_INITIAL: f64 = 0.1; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub(crate) struct BettiNumbers { + pub(crate) beta_0: u32, + pub(crate) beta_1: u32, +} + +impl BettiNumbers { + pub(crate) const fn new(beta_0: u32, beta_1: u32) -> Self { + Self { beta_0, beta_1 } + } +} + +impl Default for BettiNumbers { + fn default() -> Self { + Self { + beta_0: 1, + beta_1: 0, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +struct RingBuffer { + data: [T; N], + len: usize, + pos: usize, +} + +impl RingBuffer { + fn new() -> Self { + Self { + data: [T::default(); N], + len: 0, + pos: 0, + } + } + + fn push(&mut self, value: T) { + self.data[self.pos] = value; + self.pos = (self.pos + 1) % N; + if self.len < N { + self.len += 1; + } + } + + fn len(&self) -> usize { + self.len + } + + fn last(&self) -> Option { + if self.len == 0 { + return None; + } + + let index = if self.pos == 0 { N - 1 } else { self.pos - 1 }; + Some(self.data[index]) + } + + fn copy_window(&self, window_size: usize, out: &mut [T]) -> usize { + let window_size = window_size.min(self.len).min(out.len()); + if window_size == 0 { + return 0; + } + + let start = (self.pos + N - window_size) % N; + for (index, slot) in out.iter_mut().enumerate().take(window_size) { + *slot = self.data[(start + index) % N]; + } + window_size + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub(crate) struct ConvergenceDecision { + pub(crate) epoch: u64, + pub(crate) stability_window: usize, + pub(crate) latest_betti: BettiNumbers, + pub(crate) latest_drift: f64, + pub(crate) latest_error: f64, + pub(crate) betti_stable: bool, + pub(crate) drift_decreasing: bool, + pub(crate) error_converged: bool, + pub(crate) converged: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub(crate) struct ConvergenceDetector { + betti_history: RingBuffer, + drift_history: RingBuffer, + error_history: RingBuffer, + stability_window: usize, + epsilon: f64, + epoch: u64, +} + +impl ConvergenceDetector { + pub(crate) fn new(epsilon: f64, stability_window: usize) -> Self { + Self { + betti_history: RingBuffer::new(), + drift_history: RingBuffer::new(), + error_history: RingBuffer::new(), + stability_window: stability_window.clamp(MIN_STABILITY_WINDOW, MAX_HISTORY), + epsilon: sanitize_positive(epsilon, DEFAULT_EPSILON), + epoch: 0, + } + } + + pub(crate) fn record_epoch( + &mut self, + betti: BettiNumbers, + drift: f64, + error: f64, + ) -> ConvergenceDecision { + self.betti_history.push(betti); + self.drift_history.push(sanitize_non_negative(drift)); + self.error_history.push(sanitize_non_negative(error)); + self.epoch = self.epoch.saturating_add(1); + self.decision() + .expect("record_epoch should always leave a latest convergence decision") + } + + pub(crate) fn decision(&self) -> Option { + let latest_betti = self.betti_history.last()?; + let latest_drift = self.drift_history.last()?; + let latest_error = self.error_history.last()?; + let betti_stable = self.is_betti_stable(); + let drift_decreasing = self.is_drift_decreasing(); + let error_converged = self.is_error_window_converged(); + + Some(ConvergenceDecision { + epoch: self.epoch, + stability_window: self.stability_window, + latest_betti, + latest_drift, + latest_error, + betti_stable, + drift_decreasing, + error_converged, + converged: betti_stable && drift_decreasing && error_converged, + }) + } + + pub(crate) fn is_converged(&self) -> bool { + self.is_betti_stable() && self.is_drift_decreasing() && self.is_error_window_converged() + } + + pub(crate) fn epoch(&self) -> u64 { + self.epoch + } + + fn is_error_window_converged(&self) -> bool { + if self.error_history.len() < self.stability_window { + return false; + } + + let mut window = [0.0; MAX_HISTORY]; + let count = self + .error_history + .copy_window(self.stability_window, &mut window); + window[..count] + .iter() + .all(|error| error.is_finite() && *error < self.epsilon) + } + + fn is_betti_stable(&self) -> bool { + if self.betti_history.len() < self.stability_window { + return false; + } + + let mut window = [BettiNumbers::default(); MAX_HISTORY]; + let count = self + .betti_history + .copy_window(self.stability_window, &mut window); + let first = window[0]; + window[..count].iter().all(|betti| *betti == first) + } + + fn is_drift_decreasing(&self) -> bool { + if self.drift_history.len() < self.stability_window { + return false; + } + + let mut window = [0.0; MAX_HISTORY]; + let count = self + .drift_history + .copy_window(self.stability_window, &mut window); + if window[..count].iter().any(|drift| !drift.is_finite()) { + return false; + } + + for pair in window[..count].windows(2) { + if pair[1] > pair[0] { + return false; + } + } + + window[count - 1] < FINAL_DRIFT_THRESHOLD + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub(crate) struct DriftDetector { + previous: [f64; D], + has_previous: bool, + velocity: [f64; D], + expected: [f64; D], +} + +impl DriftDetector { + pub(crate) fn new() -> Self { + Self { + previous: [0.0; D], + has_previous: false, + velocity: [0.0; D], + expected: [0.0; D], + } + } + + pub(crate) fn update(&mut self, centroid: &[f64; D]) -> f64 { + if centroid.iter().any(|coord| !coord.is_finite()) { + self.reset(); + return f64::INFINITY; + } + + let drift = if self.has_previous { + l2_distance(&self.expected, centroid) + } else { + 0.0 + }; + + if self.has_previous { + for (dimension, velocity) in self.velocity.iter_mut().enumerate() { + *velocity = centroid[dimension] - self.previous[dimension]; + } + } + + for (dimension, expected) in self.expected.iter_mut().enumerate() { + *expected = centroid[dimension] + self.velocity[dimension]; + } + + self.previous = *centroid; + self.has_previous = true; + + drift + } + + fn reset(&mut self) { + *self = Self::new(); + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub(crate) struct GeometricGovernor { + epsilon: f64, + last_error: f64, + adjustment_count: u64, +} + +impl GeometricGovernor { + fn new() -> Self { + Self { + epsilon: GOVERNOR_EPSILON_INITIAL, + last_error: 0.0, + adjustment_count: 0, + } + } + + pub(crate) fn with_epsilon(epsilon: f64) -> Self { + let mut governor = Self::new(); + if epsilon.is_finite() { + governor.epsilon = epsilon.clamp(GOVERNOR_EPSILON_MIN, GOVERNOR_EPSILON_MAX); + } + governor + } + + pub(crate) fn adapt(&mut self, deviation_delta: f64, dt: f64) -> f64 { + if dt <= 0.0 || !dt.is_finite() || !deviation_delta.is_finite() || self.epsilon <= 0.0 { + return self.epsilon; + } + + let current_rate = deviation_delta / self.epsilon; + let error = TARGET_TICK_RATE - current_rate; + let d_error = (error - self.last_error) / dt; + let adjustment = GOVERNOR_ALPHA * error + GOVERNOR_BETA * d_error; + + self.epsilon = + (self.epsilon - adjustment).clamp(GOVERNOR_EPSILON_MIN, GOVERNOR_EPSILON_MAX); + self.last_error = error; + self.adjustment_count = self.adjustment_count.saturating_add(1); + + self.epsilon + } + + pub(crate) fn should_trigger(&self, deviation: f64) -> bool { + deviation.is_finite() && deviation >= self.epsilon + } +} + +fn sanitize_positive(value: f64, fallback: f64) -> f64 { + if value.is_finite() && value > 0.0 { + value + } else { + fallback + } +} + +fn sanitize_non_negative(value: f64) -> f64 { + if value.is_finite() && value >= 0.0 { + value + } else { + f64::INFINITY + } +} + +fn l2_distance(a: &[f64; D], b: &[f64; D]) -> f64 { + let mut sum = 0.0; + for dimension in 0..D { + let diff = a[dimension] - b[dimension]; + sum += diff * diff; + } + sum.sqrt() +} + +#[cfg(test)] +#[path = "../tests/unit/topology_tests.rs"] +mod tests; diff --git a/crates/adaptive/tests/integration/runtime_integration_tests.rs b/crates/adaptive/tests/integration/runtime_integration_tests.rs index cdee3c050..2d2f8137a 100644 --- a/crates/adaptive/tests/integration/runtime_integration_tests.rs +++ b/crates/adaptive/tests/integration/runtime_integration_tests.rs @@ -504,11 +504,7 @@ async fn runtime_integration_acg_learner_reuses_learning_buckets_across_growing_ let requests = sample_growing_chat_requests("claude-3-5-sonnet"); let learner = AcgLearner::new(agent_id, 8, StabilityThresholds::default()); let learning_key = format!( - "{agent_id}::model=claude-3-5-sonnet::seed={}::system={}::tools=no-tools", - short_hash(&format!( - "user:{}", - nemo_relay_adaptive::acg::sha256_hex("Summarize the latest findings") - )), + "{agent_id}::model=claude-3-5-sonnet::scaffold=stable::system={}::tools=no-tools", short_hash(&nemo_relay_adaptive::acg::sha256_hex( "You are a careful planner" )), @@ -675,6 +671,7 @@ async fn test_adaptive_plugin_rejects_unsupported_mode_with_strict_policy() { tool_parallelism: Some(ToolParallelismComponentConfig { priority: 100, mode: "broken".into(), + drift: None, }), ..AdaptiveConfig::default() }) diff --git a/crates/adaptive/tests/integration/tool_parallelism_plan_tests.rs b/crates/adaptive/tests/integration/tool_parallelism_plan_tests.rs index 7ecfa4290..968fe8009 100644 --- a/crates/adaptive/tests/integration/tool_parallelism_plan_tests.rs +++ b/crates/adaptive/tests/integration/tool_parallelism_plan_tests.rs @@ -9,14 +9,13 @@ use chrono::{Duration, Utc}; use serde_json::json; use uuid::Uuid; -use nemo_relay_adaptive::{ - InMemoryBackend, StorageBackend, StorageBackendDyn, ToolParallelismLearner, -}; use nemo_relay_adaptive::learner::traits::Learner; +use nemo_relay_adaptive::tool_parallelism_learner::ToolParallelismLearner; use nemo_relay_adaptive::types::cache::HotCache; use nemo_relay_adaptive::types::metadata::{MetadataEnvelope, ParallelHint}; use nemo_relay_adaptive::types::plan::{ExecutionPlan, ParallelGroup}; use nemo_relay_adaptive::types::records::{CallKind, CallRecord, RunRecord}; +use nemo_relay_adaptive::{InMemoryBackend, StorageBackend, StorageBackendDyn}; fn make_hot_cache() -> Arc> { Arc::new(RwLock::new(HotCache { @@ -27,10 +26,6 @@ fn make_hot_cache() -> Arc> { acg_profile_observation_counts: std::collections::HashMap::new(), acg_stability: None, acg_observation_count: 0, - acg_profiles: std::collections::HashMap::new(), - acg_profile_observation_counts: std::collections::HashMap::new(), - acg_stability: None, - acg_observation_count: 0, })) } @@ -51,10 +46,6 @@ fn make_tool_call( total_tokens: None, model_name: None, tool_call_count: None, - annotated_request: None, - annotated_response: None, - annotated_request: None, - annotated_response: None, annotated_request: None, annotated_response: None, } diff --git a/crates/adaptive/tests/integration/topology_convergence_tests.rs b/crates/adaptive/tests/integration/topology_convergence_tests.rs new file mode 100644 index 000000000..e63e7c696 --- /dev/null +++ b/crates/adaptive/tests/integration/topology_convergence_tests.rs @@ -0,0 +1,272 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Integration tests for topological convergence detection in the ACG learner. + +use std::sync::{Arc, RwLock}; + +use chrono::Utc; +use nemo_relay::codec::request::{AnnotatedLlmRequest, Message, MessageContent}; +use nemo_relay_adaptive::acg::stability::StabilityThresholds; +use nemo_relay_adaptive::acg_learner::AcgLearner; +use nemo_relay_adaptive::learner::traits::Learner; +use nemo_relay_adaptive::types::cache::HotCache; +use nemo_relay_adaptive::types::records::{CallKind, CallRecord, RunRecord}; +use nemo_relay_adaptive::{ConvergenceConfig, InMemoryBackend, StorageBackendDyn}; +use uuid::Uuid; + +fn identical_request() -> AnnotatedLlmRequest { + AnnotatedLlmRequest { + messages: vec![ + Message::System { + content: MessageContent::Text("You are a helpful assistant.".to_string()), + name: None, + }, + Message::User { + content: MessageContent::Text("Summarize this.".to_string()), + name: None, + }, + ], + model: Some("gpt-4o".to_string()), + params: None, + tools: None, + tool_choice: None, + store: None, + previous_response_id: None, + truncation: None, + reasoning: None, + include: None, + user: None, + metadata: None, + service_tier: None, + parallel_tool_calls: None, + max_output_tokens: None, + max_tool_calls: None, + top_logprobs: None, + stream: None, + extra: serde_json::Map::new(), + } +} + +fn coding_agent_request(work_item: &str) -> AnnotatedLlmRequest { + AnnotatedLlmRequest { + messages: vec![ + Message::System { + content: MessageContent::Text("You are a repo coding agent.".to_string()), + name: None, + }, + Message::User { + content: MessageContent::Text("Apply the repository review checklist.".to_string()), + name: None, + }, + Message::Assistant { + content: Some(MessageContent::Text( + "Acknowledged. I will review with that checklist.".to_string(), + )), + tool_calls: None, + name: None, + }, + Message::User { + content: MessageContent::Text(work_item.to_string()), + name: None, + }, + ], + model: Some("gpt-4o".to_string()), + params: None, + tools: None, + tool_choice: None, + store: None, + previous_response_id: None, + truncation: None, + reasoning: None, + include: None, + user: None, + metadata: None, + service_tier: None, + parallel_tool_calls: None, + max_output_tokens: None, + max_tool_calls: None, + top_logprobs: None, + stream: None, + extra: serde_json::Map::new(), + } +} + +fn run_with_requests(requests: Vec) -> RunRecord { + let now = Utc::now(); + RunRecord { + id: Uuid::now_v7(), + agent_id: "convergence-agent".to_string(), + calls: requests + .into_iter() + .map(|request| CallRecord { + kind: CallKind::Llm, + name: "planner".to_string(), + started_at: now, + ended_at: Some(now), + metadata_snapshot: None, + output_tokens: None, + prompt_tokens: None, + total_tokens: None, + model_name: None, + tool_call_count: None, + annotated_request: Some(request.into()), + annotated_response: None, + }) + .collect(), + started_at: now, + ended_at: Some(now), + } +} + +fn empty_cache() -> Arc> { + Arc::new(RwLock::new(HotCache { + plan: None, + trie: None, + agent_hints_default: None, + acg_profiles: std::collections::HashMap::new(), + acg_profile_observation_counts: std::collections::HashMap::new(), + acg_stability: None, + acg_observation_count: 0, + })) +} + +#[tokio::test(flavor = "current_thread")] +async fn acg_learner_declares_convergence_before_window_exhausted() { + let observation_window = 20; + let stability_window = 3; + let learner = AcgLearner::new_with_convergence( + "convergence-agent", + observation_window, + StabilityThresholds::default(), + Some(ConvergenceConfig { + enabled: true, + epsilon: 0.001, + stability_window, + }), + ); + let backend = InMemoryBackend::new(); + let hot_cache = empty_cache(); + let request = identical_request(); + + let mut converged_at = None; + let mut agent_observations_at_convergence = 0; + for iteration in 0..observation_window { + let run = run_with_requests(vec![request.clone()]); + learner + .process_run(&run, &backend, &hot_cache) + .await + .unwrap(); + + let stability = backend + .load_stability("convergence-agent") + .await + .unwrap() + .expect("stability should be stored"); + if stability.converged { + converged_at = Some(iteration + 1); + agent_observations_at_convergence = backend + .load_observations("convergence-agent") + .await + .unwrap() + .map(|observations| observations.len()) + .unwrap_or(0); + break; + } + } + + assert!( + converged_at.is_some(), + "expected convergence to be declared before exhausting the observation window" + ); + assert!( + converged_at.unwrap() < observation_window, + "convergence should be declared before the observation window is exhausted" + ); + assert!( + converged_at.unwrap() >= stability_window, + "convergence should require at least stability_window epochs" + ); + + // Continue running after convergence to verify the cached result is reused + // and observations are no longer updated. + for _ in 0..3 { + let run = run_with_requests(vec![request.clone()]); + learner + .process_run(&run, &backend, &hot_cache) + .await + .unwrap(); + } + + let final_stability = backend + .load_stability("convergence-agent") + .await + .unwrap() + .expect("stability should still be stored"); + assert!( + final_stability.converged, + "cached stability result should remain converged" + ); + + let final_agent_observations = backend + .load_observations("convergence-agent") + .await + .unwrap() + .expect("agent aggregate observations should remain stored after convergence"); + assert!( + !final_agent_observations.is_empty(), + "agent aggregate observations should be non-empty after convergence" + ); + assert_eq!( + final_agent_observations.len(), + agent_observations_at_convergence, + "agent aggregate observation storage should be skipped after convergence" + ); +} + +#[tokio::test(flavor = "current_thread")] +async fn acg_convergence_accepts_stable_prefix_with_variable_work_item_suffix() { + let observation_window = 12; + let stability_window = 3; + let learner = AcgLearner::new_with_convergence( + "convergence-agent", + observation_window, + StabilityThresholds::default(), + Some(ConvergenceConfig { + enabled: true, + epsilon: 0.001, + stability_window, + }), + ); + let backend = InMemoryBackend::new(); + let hot_cache = empty_cache(); + + let mut converged_at = None; + let mut stable_prefix = 0; + for iteration in 0..observation_window { + let run = run_with_requests(vec![coding_agent_request(&format!( + "Review changed bundle #{iteration}" + ))]); + learner + .process_run(&run, &backend, &hot_cache) + .await + .unwrap(); + + let stability = backend + .load_stability("convergence-agent") + .await + .unwrap() + .expect("stability should be stored"); + if stability.converged { + converged_at = Some(iteration + 1); + stable_prefix = stability.stable_prefix_length; + break; + } + } + + assert!( + converged_at.is_some(), + "stable active-agent prelude should converge despite variable work item suffix" + ); + assert_eq!(stable_prefix, 3); +} diff --git a/crates/adaptive/tests/unit/acg/economics_internal_tests.rs b/crates/adaptive/tests/unit/acg/economics_internal_tests.rs index 7a6a01364..806120fb5 100644 --- a/crates/adaptive/tests/unit/acg/economics_internal_tests.rs +++ b/crates/adaptive/tests/unit/acg/economics_internal_tests.rs @@ -121,7 +121,9 @@ fn economics_internal_build_prefix_stats_stops_at_first_non_stable_block() { score(1, StabilityClass::Variable, 0.4, 0.9), ], stable_prefix_length: 2, + stable_prefix_fingerprint: None, total_observations: 6, + converged: false, }; let stats = build_prefix_stats( diff --git a/crates/adaptive/tests/unit/acg/economics_policy_tests.rs b/crates/adaptive/tests/unit/acg/economics_policy_tests.rs index 1d4a00966..70f4a2202 100644 --- a/crates/adaptive/tests/unit/acg/economics_policy_tests.rs +++ b/crates/adaptive/tests/unit/acg/economics_policy_tests.rs @@ -84,7 +84,9 @@ fn stability_result(scores: &[(f64, f64)], observation_count: u32) -> StabilityA }) .collect(), stable_prefix_length: scores.len(), + stable_prefix_fingerprint: None, total_observations: observation_count, + converged: false, } } diff --git a/crates/adaptive/tests/unit/acg/ir_builder_tests.rs b/crates/adaptive/tests/unit/acg/ir_builder_tests.rs index 2fee6ba62..6f42d7257 100644 --- a/crates/adaptive/tests/unit/acg/ir_builder_tests.rs +++ b/crates/adaptive/tests/unit/acg/ir_builder_tests.rs @@ -154,6 +154,92 @@ fn build_prompt_ir_appends_tool_blocks_when_request_contains_only_system_message assert_eq!(prompt_ir.blocks[2].sequence_index, 2); } +#[test] +fn build_prompt_ir_inserts_tool_schema_and_output_contract_before_workflow_scaffold() { + let mut extra = serde_json::Map::new(); + extra.insert( + "response_format".to_string(), + serde_json::json!({ + "type": "json_schema", + "json_schema": { + "name": "moderation_decision", + "schema": { + "type": "object", + "properties": { + "decision": {"type": "string"}, + "reason": {"type": "string"} + }, + "required": ["decision", "reason"] + } + } + }), + ); + let request = AnnotatedLlmRequest { + messages: vec![ + Message::System { + content: MessageContent::Text("Apply policy exactly.".to_string()), + name: None, + }, + Message::User { + content: MessageContent::Text("Use the moderation workflow.".to_string()), + name: None, + }, + Message::Assistant { + content: Some(MessageContent::Text( + "I will return the required moderation decision object.".to_string(), + )), + tool_calls: None, + name: None, + }, + Message::User { + content: MessageContent::Text("Review this changing post.".to_string()), + name: None, + }, + ], + model: Some("gpt-4o".to_string()), + params: None, + tools: Some(vec![sample_tool_definition("policy_lookup")]), + tool_choice: None, + store: None, + previous_response_id: None, + truncation: None, + reasoning: None, + include: None, + user: None, + metadata: None, + service_tier: None, + parallel_tool_calls: None, + max_output_tokens: None, + max_tool_calls: None, + top_logprobs: None, + stream: None, + extra, + }; + + let prompt_ir = build_prompt_ir(&request).unwrap(); + + assert_eq!(prompt_ir.blocks.len(), 6); + assert_eq!(prompt_ir.blocks[0].content_type, BlockContentType::Text); + assert_eq!( + prompt_ir.blocks[1].content_type, + BlockContentType::ToolSchema + ); + assert_eq!( + prompt_ir.blocks[2].content_type, + BlockContentType::StructuredOutput + ); + assert_eq!(prompt_ir.blocks[3].role, PromptRole::User); + assert_eq!(prompt_ir.blocks[4].role, PromptRole::Assistant); + assert_eq!(prompt_ir.blocks[5].role, PromptRole::User); + assert!( + prompt_ir + .structured_output_schema_id + .as_deref() + .is_some_and(|schema_id| schema_id.starts_with("sha256:")) + ); + assert!(prompt_ir.blocks[2].content.contains("moderation_decision")); +} + #[test] fn build_prompt_ir_omits_tool_schema_hashes_when_no_tools_are_present() { let request = AnnotatedLlmRequest { diff --git a/crates/adaptive/tests/unit/acg/multi_breakpoint_tests.rs b/crates/adaptive/tests/unit/acg/multi_breakpoint_tests.rs index f4c228482..6176d53a1 100644 --- a/crates/adaptive/tests/unit/acg/multi_breakpoint_tests.rs +++ b/crates/adaptive/tests/unit/acg/multi_breakpoint_tests.rs @@ -81,7 +81,9 @@ fn layered_stability(observation_count: u32, layers: usize) -> StabilityAnalysis }) .collect(), stable_prefix_length: layers, + stable_prefix_fingerprint: None, total_observations: observation_count, + converged: false, } } diff --git a/crates/adaptive/tests/unit/acg/stability_internal_tests.rs b/crates/adaptive/tests/unit/acg/stability_internal_tests.rs index c39843586..908b31c86 100644 --- a/crates/adaptive/tests/unit/acg/stability_internal_tests.rs +++ b/crates/adaptive/tests/unit/acg/stability_internal_tests.rs @@ -78,3 +78,60 @@ fn stability_internal_effective_score_handles_zero_present_count() { assert_eq!(effective_stability_score(&observations, 3), 0.0); } + +#[test] +fn stability_internal_prefix_fingerprint_ties_choose_lexicographically_smallest_hash() { + let mut counts = std::collections::HashMap::new(); + counts.insert("sha256:bbbb".to_string(), 2); + counts.insert("sha256:aaaa".to_string(), 2); + counts.insert("sha256:cccc".to_string(), 1); + + assert_eq!( + select_dominant_prefix_fingerprint(counts).as_deref(), + Some("sha256:aaaa") + ); +} + +#[test] +fn stability_internal_equal_sequence_indexes_sort_by_span_id() { + let mut indexed_scores = vec![ + ( + 1, + BlockStabilityScore { + span_id: SpanId("span-b".to_string()), + classification: StabilityClass::Stable, + score: 1.0, + confidence: 1.0, + observation_count: 3, + }, + ), + ( + 1, + BlockStabilityScore { + span_id: SpanId("span-a".to_string()), + classification: StabilityClass::Stable, + score: 1.0, + confidence: 1.0, + observation_count: 3, + }, + ), + ( + 0, + BlockStabilityScore { + span_id: SpanId("span-0".to_string()), + classification: StabilityClass::Stable, + score: 1.0, + confidence: 1.0, + observation_count: 3, + }, + ), + ]; + + sort_indexed_scores(&mut indexed_scores); + + let ordered_span_ids = indexed_scores + .iter() + .map(|(_, score)| score.span_id.0.as_str()) + .collect::>(); + assert_eq!(ordered_span_ids, vec!["span-0", "span-a", "span-b"]); +} diff --git a/crates/adaptive/tests/unit/acg_component_tests.rs b/crates/adaptive/tests/unit/acg_component_tests.rs index cc4c2f7e7..338fe4441 100644 --- a/crates/adaptive/tests/unit/acg_component_tests.rs +++ b/crates/adaptive/tests/unit/acg_component_tests.rs @@ -19,7 +19,9 @@ use crate::storage::traits::StorageBackendDyn; use nemo_relay::api::llm::LlmRequest; use nemo_relay::api::runtime::LlmExecutionNextFn; use nemo_relay::api::runtime::LlmStreamExecutionNextFn; -use nemo_relay::codec::request::{AnnotatedLlmRequest, Message, MessageContent}; +use nemo_relay::codec::request::{ + AnnotatedLlmRequest, FunctionDefinition, Message, MessageContent, ToolDefinition, +}; use serde_json::{Value, json}; use tokio_stream::StreamExt; @@ -48,7 +50,12 @@ fn sample_hot_cache() -> Arc> { }, ], stable_prefix_length: 2, + stable_prefix_fingerprint: prompt_fingerprint_for_llm_request( + &sample_openai_responses_request(), + 2, + ), total_observations: 8, + converged: false, }), acg_observation_count: 8, })) @@ -115,6 +122,102 @@ fn sample_annotated_request(model: &str) -> AnnotatedLlmRequest { } } +fn pipe_tool_definition() -> ToolDefinition { + ToolDefinition { + tool_type: "function".to_string(), + function: FunctionDefinition { + name: "policy_lookup".to_string(), + description: Some("Look up policy guidance for a moderation item.".to_string()), + parameters: Some(json!({ + "type": "object", + "properties": { + "policy_area": {"type": "string"} + }, + "required": ["policy_area"] + })), + }, + } +} + +fn pipe_response_format(include_severity: bool) -> Value { + let mut properties = json!({ + "decision": {"type": "string"}, + "reason": {"type": "string"} + }); + let mut required = vec![ + Value::String("decision".to_string()), + Value::String("reason".to_string()), + ]; + if include_severity { + properties["severity"] = json!({"type": "string"}); + required.push(Value::String("severity".to_string())); + } + + json!({ + "type": "json_schema", + "json_schema": { + "name": "moderation_decision", + "schema": { + "type": "object", + "properties": properties, + "required": required + } + } + }) +} + +fn sample_agent_pipe_request(work_item: &str, include_severity: bool) -> AnnotatedLlmRequest { + let mut extra = serde_json::Map::new(); + extra.insert( + "response_format".to_string(), + pipe_response_format(include_severity), + ); + + AnnotatedLlmRequest { + messages: vec![ + Message::System { + content: MessageContent::Text("Apply the moderation policy exactly.".to_string()), + name: None, + }, + Message::User { + content: MessageContent::Text( + "Use the reusable moderation workflow before judging the item.".to_string(), + ), + name: None, + }, + Message::Assistant { + content: Some(MessageContent::Text( + "I will return only the required moderation decision object.".to_string(), + )), + tool_calls: None, + name: None, + }, + Message::User { + content: MessageContent::Text(work_item.to_string()), + name: None, + }, + ], + model: Some("gpt-4o".to_string()), + params: None, + tools: Some(vec![pipe_tool_definition()]), + tool_choice: None, + store: None, + previous_response_id: None, + truncation: None, + reasoning: None, + include: None, + user: None, + metadata: None, + service_tier: None, + parallel_tool_calls: None, + max_output_tokens: None, + max_tool_calls: None, + top_logprobs: None, + stream: None, + extra, + } +} + fn sample_openai_responses_request() -> LlmRequest { LlmRequest { headers: serde_json::Map::new(), @@ -174,6 +277,51 @@ fn sample_layered_anthropic_request() -> LlmRequest { } } +fn prompt_ir_for_llm_request(request: &LlmRequest) -> PromptIR { + let semantic_request_view = + build_semantic_request_view(request).expect("request should decode to semantic view"); + crate::acg::ir_builder::build_prompt_ir(&semantic_request_view.annotated_request) + .expect("semantic request should build PromptIR") +} + +fn prompt_fingerprint_for_llm_request(request: &LlmRequest, prefix_len: usize) -> Option { + let prompt_ir = prompt_ir_for_llm_request(request); + crate::acg::stability::prompt_prefix_fingerprint(&prompt_ir, prefix_len) +} + +fn stability_with_prefix_for_prompt( + prefix_len: u32, + observations: u32, + prompt_ir: &PromptIR, +) -> StabilityAnalysisResult { + let mut stability = stability_with_prefix(prefix_len, observations); + stability.stable_prefix_fingerprint = + crate::acg::stability::prompt_prefix_fingerprint(prompt_ir, prefix_len as usize); + stability +} + +fn stability_with_prefix_for_llm_request( + prefix_len: u32, + observations: u32, + request: &LlmRequest, +) -> StabilityAnalysisResult { + let prompt_ir = prompt_ir_for_llm_request(request); + stability_with_prefix_for_prompt(prefix_len, observations, &prompt_ir) +} + +fn layered_stability_result_for_prompt( + observation_count: u32, + prompt_ir: &PromptIR, +) -> StabilityAnalysisResult { + let mut stability = layered_stability_result(observation_count); + let stable_prefix_length = stability.stable_prefix_length.min(prompt_ir.blocks.len()); + stability.scores.truncate(stable_prefix_length); + stability.stable_prefix_length = stable_prefix_length; + stability.stable_prefix_fingerprint = + crate::acg::stability::prompt_prefix_fingerprint(prompt_ir, stable_prefix_length); + stability +} + fn marker_positions(req: &LlmRequest) -> Vec<(String, usize)> { let mut positions = Vec::new(); append_system_marker_positions(req, &mut positions); @@ -232,12 +380,14 @@ fn stability_with_prefix(prefix_len: u32, observations: u32) -> StabilityAnalysi }) .collect(), stable_prefix_length: prefix_len as usize, + stable_prefix_fingerprint: None, total_observations: observations, + converged: false, } } fn layered_stability_result(observation_count: u32) -> StabilityAnalysisResult { - StabilityAnalysisResult { + let mut stability = StabilityAnalysisResult { scores: vec![ BlockStabilityScore { span_id: SpanId("block-0".to_string()), @@ -262,8 +412,13 @@ fn layered_stability_result(observation_count: u32) -> StabilityAnalysisResult { }, ], stable_prefix_length: 3, + stable_prefix_fingerprint: None, total_observations: observation_count, - } + converged: false, + }; + stability.stable_prefix_fingerprint = + prompt_fingerprint_for_llm_request(&sample_layered_anthropic_request(), 3); + stability } fn sample_prompt_ir(span: &str) -> PromptIR { @@ -407,7 +562,15 @@ impl StorageBackendDyn for FailingStabilityBackend { #[test] fn acg_component_translate_request_degrades_when_provider_semantics_do_not_match_request_surface() { let request = sample_openai_chat_request(); - let hot_cache = sample_hot_cache(); + let hot_cache = Arc::new(RwLock::new(HotCache { + plan: None, + trie: None, + agent_hints_default: None, + acg_profiles: std::collections::HashMap::new(), + acg_profile_observation_counts: std::collections::HashMap::new(), + acg_stability: Some(stability_with_prefix_for_llm_request(2, 8, &request)), + acg_observation_count: 8, + })); let plugin = build_provider_plugin("anthropic").expect("anthropic plugin should build"); let translated = translate_request( @@ -443,7 +606,15 @@ fn acg_component_translate_request_applies_openai_semantics_on_resolved_request_ #[test] fn acg_component_translate_request_passes_through_when_planner_finds_no_profitable_breakpoints() { let request = sample_anthropic_request(); - let hot_cache = sample_hot_cache(); + let hot_cache = Arc::new(RwLock::new(HotCache { + plan: None, + trie: None, + agent_hints_default: None, + acg_profiles: std::collections::HashMap::new(), + acg_profile_observation_counts: std::collections::HashMap::new(), + acg_stability: Some(stability_with_prefix_for_llm_request(2, 8, &request)), + acg_observation_count: 8, + })); let plugin = build_provider_plugin("anthropic").expect("anthropic plugin should build"); let translated = translate_request( @@ -483,7 +654,7 @@ fn acg_component_translate_request_errors_are_fail_open() { agent_hints_default: None, acg_profiles: std::collections::HashMap::new(), acg_profile_observation_counts: std::collections::HashMap::new(), - acg_stability: Some(layered_stability_result(6)), + acg_stability: Some(stability_with_prefix_for_llm_request(2, 6, &request)), acg_observation_count: 6, })); @@ -549,7 +720,7 @@ fn rewrite_request_with_hot_cache_adaptive_placement_differs_by_state() { agent_hints_default: None, acg_profiles: std::collections::HashMap::new(), acg_profile_observation_counts: std::collections::HashMap::new(), - acg_stability: Some(stability_with_prefix(1, 8)), + acg_stability: Some(stability_with_prefix_for_llm_request(1, 8, &request)), acg_observation_count: 8, })); let translated_short = @@ -562,7 +733,7 @@ fn rewrite_request_with_hot_cache_adaptive_placement_differs_by_state() { agent_hints_default: None, acg_profiles: std::collections::HashMap::new(), acg_profile_observation_counts: std::collections::HashMap::new(), - acg_stability: Some(stability_with_prefix(3, 8)), + acg_stability: Some(stability_with_prefix_for_llm_request(3, 8, &request)), acg_observation_count: 8, })); let translated_long = @@ -691,7 +862,9 @@ fn acg_component_build_intent_bundle_requires_at_least_two_observations() { }, ], stable_prefix_length: 2, + stable_prefix_fingerprint: crate::acg::stability::prompt_prefix_fingerprint(&prompt_ir, 2), total_observations: 1, + converged: false, }; let intent_bundle = build_intent_bundle( @@ -725,6 +898,87 @@ fn acg_component_build_intent_bundle_requires_at_least_two_observations() { ); } +#[test] +fn acg_component_build_intent_bundle_rejects_stale_output_contract_stability() { + let plugin = build_provider_plugin("openai").expect("openai plugin should build"); + let observations = ["#1", "#2", "#3"] + .into_iter() + .map(|suffix| { + crate::acg::ir_builder::build_prompt_ir(&sample_agent_pipe_request( + &format!("Review forum post {suffix}"), + false, + )) + .unwrap() + }) + .collect::>(); + let stability = crate::acg::stability::analyze_stability( + &observations, + &crate::acg::stability::StabilityThresholds::default(), + ); + assert_eq!(stability.stable_prefix_length, 5); + assert!(stability.stable_prefix_fingerprint.is_some()); + + let changed_contract = sample_agent_pipe_request("Review forum post #4", true); + let changed_prompt_ir = crate::acg::ir_builder::build_prompt_ir(&changed_contract).unwrap(); + + let intent_bundle = build_intent_bundle( + "agent-1", + "openai", + plugin.as_ref(), + RequestSurface::OpenAIChat, + &changed_contract, + &changed_prompt_ir, + &stability, + 3, + ); + + assert!( + intent_bundle.is_none(), + "request-time ACG hints must fail open when cached stability was learned from a different output contract" + ); +} + +#[test] +fn acg_component_build_intent_bundle_rejects_missing_prefix_fingerprint_stability() { + let plugin = build_provider_plugin("openai").expect("openai plugin should build"); + let observations = ["#1", "#2", "#3"] + .into_iter() + .map(|suffix| { + crate::acg::ir_builder::build_prompt_ir(&sample_agent_pipe_request( + &format!("Review forum post {suffix}"), + false, + )) + .unwrap() + }) + .collect::>(); + let mut stability = crate::acg::stability::analyze_stability( + &observations, + &crate::acg::stability::StabilityThresholds::default(), + ); + assert_eq!(stability.stable_prefix_length, 5); + assert!(stability.stable_prefix_fingerprint.is_some()); + stability.stable_prefix_fingerprint = None; + + let request = sample_agent_pipe_request("Review forum post #4", false); + let prompt_ir = crate::acg::ir_builder::build_prompt_ir(&request).unwrap(); + + let intent_bundle = build_intent_bundle( + "agent-1", + "openai", + plugin.as_ref(), + RequestSurface::OpenAIChat, + &request, + &prompt_ir, + &stability, + 3, + ); + + assert!( + intent_bundle.is_none(), + "request-time ACG hints must fail open when cached stability cannot prove the stable prefix fingerprint" + ); +} + #[tokio::test] async fn acg_component_load_persisted_state_prefers_stability_and_falls_back_to_observations() { let backend = InMemoryBackend::new(); @@ -865,7 +1119,7 @@ fn acg_component_build_intent_bundle_supports_openai_and_rejects_unknown_provide let request = sample_annotated_request("gpt-4o"); let prompt_ir = crate::acg::ir_builder::build_prompt_ir(&request).unwrap(); let plugin = build_provider_plugin("openai").unwrap(); - let stability = layered_stability_result(4); + let stability = layered_stability_result_for_prompt(4, &prompt_ir); let bundle = build_intent_bundle( "agent-openai", @@ -949,7 +1203,7 @@ fn acg_component_build_hint_translation_and_apply_hint_translation_cover_passthr let prompt_ir = crate::acg::ir_builder::build_prompt_ir(&semantic_view.annotated_request).unwrap(); let plugin = build_provider_plugin("openai").unwrap(); - let stability = layered_stability_result(4); + let stability = layered_stability_result_for_prompt(4, &prompt_ir); let bundle = build_intent_bundle( "agent-openai", "openai", @@ -1004,7 +1258,10 @@ fn acg_component_translate_request_uses_profile_specific_stability_and_fails_ope agent_hints_default: None, acg_profiles: std::collections::HashMap::from([( learning_key.clone(), - layered_stability_result(6), + layered_stability_result_for_prompt( + 6, + &crate::acg::ir_builder::build_prompt_ir(&semantic_view.annotated_request).unwrap(), + ), )]), acg_profile_observation_counts: std::collections::HashMap::from([(learning_key, 6)]), acg_stability: None, @@ -1148,7 +1405,7 @@ fn acg_component_build_hint_translation_succeeds_for_openai_and_anthropic() { RequestSurface::OpenAIChat, &openai_view.annotated_request, &openai_prompt_ir, - &layered_stability_result(4), + &layered_stability_result_for_prompt(4, &openai_prompt_ir), 4, ) .unwrap(); @@ -1189,7 +1446,7 @@ fn acg_component_build_hint_translation_succeeds_for_openai_and_anthropic() { RequestSurface::AnthropicMessages, &anthropic_view.annotated_request, &anthropic_prompt_ir, - &layered_stability_result(6), + &layered_stability_result_for_prompt(6, &anthropic_prompt_ir), 6, ) .unwrap(); diff --git a/crates/adaptive/tests/unit/acg_learner_tests.rs b/crates/adaptive/tests/unit/acg_learner_tests.rs index 664ff5126..8eecaba09 100644 --- a/crates/adaptive/tests/unit/acg_learner_tests.rs +++ b/crates/adaptive/tests/unit/acg_learner_tests.rs @@ -5,13 +5,18 @@ use std::future::Future; use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use chrono::Utc; -use nemo_relay::codec::request::{AnnotatedLlmRequest, Message, MessageContent}; +use nemo_relay::codec::request::{ + AnnotatedLlmRequest, FunctionDefinition, Message, MessageContent, ToolDefinition, +}; use uuid::Uuid; use super::*; +use crate::acg::profile::{BlockStabilityScore, StabilityClass}; +use crate::acg::prompt_ir::SpanId; use crate::acg_profile::derive_acg_learning_key; use crate::trie::accumulator::AccumulatorState; use crate::trie::serialization::TrieEnvelope; @@ -51,6 +56,137 @@ fn sample_request(model: &str, system: &str, user: &str) -> AnnotatedLlmRequest } } +fn layered_agent_request(work_item: &str) -> AnnotatedLlmRequest { + AnnotatedLlmRequest { + messages: vec![ + Message::System { + content: MessageContent::Text("You are a repo coding agent.".to_string()), + name: None, + }, + Message::User { + content: MessageContent::Text("Apply the repository review checklist.".to_string()), + name: None, + }, + Message::Assistant { + content: Some(MessageContent::Text( + "Acknowledged. I will review with that checklist.".to_string(), + )), + tool_calls: None, + name: None, + }, + Message::User { + content: MessageContent::Text(work_item.to_string()), + name: None, + }, + ], + model: Some("gpt-4o".to_string()), + params: None, + tools: None, + tool_choice: None, + store: None, + previous_response_id: None, + truncation: None, + reasoning: None, + include: None, + user: None, + metadata: None, + service_tier: None, + parallel_tool_calls: None, + max_output_tokens: None, + max_tool_calls: None, + top_logprobs: None, + stream: None, + extra: serde_json::Map::new(), + } +} + +fn layered_agent_request_with_extra_suffix(work_item: &str) -> AnnotatedLlmRequest { + let mut request = layered_agent_request(work_item); + request.messages.push(Message::Assistant { + content: Some(MessageContent::Text( + "I need one more repository fact before final review.".to_string(), + )), + tool_calls: None, + name: None, + }); + request.messages.push(Message::User { + content: MessageContent::Text("Additional volatile suffix context".to_string()), + name: None, + }); + request +} + +fn pipe_tool_definition() -> ToolDefinition { + ToolDefinition { + tool_type: "function".to_string(), + function: FunctionDefinition { + name: "policy_lookup".to_string(), + description: Some("Look up policy guidance for a moderation item.".to_string()), + parameters: Some(serde_json::json!({ + "type": "object", + "properties": { + "policy_area": {"type": "string"} + }, + "required": ["policy_area"] + })), + }, + } +} + +fn pipe_response_format(include_severity: bool) -> serde_json::Value { + let mut properties = serde_json::json!({ + "decision": {"type": "string"}, + "reason": {"type": "string"} + }); + let mut required = vec![ + serde_json::Value::String("decision".to_string()), + serde_json::Value::String("reason".to_string()), + ]; + if include_severity { + properties["severity"] = serde_json::json!({"type": "string"}); + required.push(serde_json::Value::String("severity".to_string())); + } + + serde_json::json!({ + "type": "json_schema", + "json_schema": { + "name": "moderation_decision", + "schema": { + "type": "object", + "properties": properties, + "required": required + } + } + }) +} + +fn layered_agent_pipe_request(work_item: &str, include_severity: bool) -> AnnotatedLlmRequest { + let mut request = layered_agent_request(work_item); + request.messages[0] = Message::System { + content: MessageContent::Text("Apply the moderation policy exactly.".to_string()), + name: None, + }; + request.messages[1] = Message::User { + content: MessageContent::Text( + "Use the reusable moderation workflow before judging the item.".to_string(), + ), + name: None, + }; + request.messages[2] = Message::Assistant { + content: Some(MessageContent::Text( + "I will return only the required moderation decision object.".to_string(), + )), + tool_calls: None, + name: None, + }; + request.tools = Some(vec![pipe_tool_definition()]); + request.extra.insert( + "response_format".to_string(), + pipe_response_format(include_severity), + ); + request +} + fn sample_run(requests: Vec) -> RunRecord { let now = Utc::now(); RunRecord { @@ -78,6 +214,60 @@ fn sample_run(requests: Vec) -> RunRecord { } } +fn stable_score(index: usize) -> BlockStabilityScore { + BlockStabilityScore { + span_id: SpanId(format!("span-{index}")), + classification: StabilityClass::Stable, + score: 1.0, + confidence: 1.0, + observation_count: 3, + } +} + +fn variable_score(index: usize) -> BlockStabilityScore { + BlockStabilityScore { + span_id: SpanId(format!("span-{index}")), + classification: StabilityClass::Variable, + score: 0.0, + confidence: 1.0, + observation_count: 3, + } +} + +#[test] +fn acg_convergence_features_ignore_variable_suffix_shape() { + let short_suffix = crate::acg::stability::StabilityAnalysisResult { + scores: vec![stable_score(0), stable_score(1), variable_score(2)], + stable_prefix_length: 2, + stable_prefix_fingerprint: None, + total_observations: 3, + converged: false, + }; + let long_suffix = crate::acg::stability::StabilityAnalysisResult { + scores: vec![ + stable_score(0), + stable_score(1), + variable_score(2), + variable_score(3), + variable_score(4), + ], + stable_prefix_length: 2, + stable_prefix_fingerprint: None, + total_observations: 3, + converged: false, + }; + + let (short_betti, short_drift, short_error) = + AcgLearner::stability_to_convergence_features(&short_suffix); + let (long_betti, long_drift, long_error) = + AcgLearner::stability_to_convergence_features(&long_suffix); + + assert_eq!(short_betti, BettiNumbers::new(2, 0)); + assert_eq!(short_betti, long_betti); + assert!((short_drift - long_drift).abs() < f64::EPSILON); + assert!((short_error - long_error).abs() < f64::EPSILON); +} + fn empty_cache() -> Arc> { Arc::new(RwLock::new(HotCache { plan: None, @@ -93,18 +283,48 @@ fn empty_cache() -> Arc> { struct SeedObservationBackend { observations: std::sync::RwLock>>, stability: std::sync::RwLock>, + fail_observation_store: AtomicBool, + load_observation_count: AtomicUsize, } impl SeedObservationBackend { - fn new(seed_key: &str, observations: Vec) -> Self { + fn empty() -> Self { Self { - observations: std::sync::RwLock::new(HashMap::from([( - seed_key.to_string(), - observations, - )])), + observations: std::sync::RwLock::new(HashMap::new()), stability: std::sync::RwLock::new(HashMap::new()), + fail_observation_store: AtomicBool::new(false), + load_observation_count: AtomicUsize::new(0), } } + + fn new(seed_key: &str, observations: Vec) -> Self { + let backend = Self::empty(); + backend + .observations + .write() + .unwrap() + .insert(seed_key.to_string(), observations); + backend + } + + fn fail_observation_stores(&self) { + self.fail_observation_store.store(true, Ordering::SeqCst); + } + + fn seed_stability( + &self, + agent_id: &str, + stability: crate::acg::stability::StabilityAnalysisResult, + ) { + self.stability + .write() + .unwrap() + .insert(agent_id.to_string(), stability); + } + + fn load_observation_count(&self) -> usize { + self.load_observation_count.load(Ordering::SeqCst) + } } impl StorageBackendDyn for SeedObservationBackend { @@ -166,6 +386,11 @@ impl StorageBackendDyn for SeedObservationBackend { ) -> Pin> + Send + 'a>> { let observations = observations.to_vec(); Box::pin(async move { + if self.fail_observation_store.load(Ordering::SeqCst) { + return Err(AdaptiveError::Storage( + "forced observation storage failure".to_string(), + )); + } self.observations .write() .unwrap() @@ -178,6 +403,7 @@ impl StorageBackendDyn for SeedObservationBackend { &'a self, agent_id: &'a str, ) -> Pin>>> + Send + 'a>> { + self.load_observation_count.fetch_add(1, Ordering::SeqCst); Box::pin(async move { Ok(self.observations.read().unwrap().get(agent_id).cloned()) }) } @@ -306,6 +532,443 @@ async fn acg_learner_prefers_profile_with_longer_stable_prefix_and_handles_poiso ); } +#[tokio::test(flavor = "current_thread")] +async fn acg_learner_does_not_persist_converged_stability_when_observation_store_fails() { + let stability_window = 3; + let learner = AcgLearner::new_with_convergence( + "agent-a", + 20, + StabilityThresholds::default(), + Some(ConvergenceConfig { + enabled: true, + epsilon: 0.001, + stability_window, + }), + ); + let request = sample_request("gpt-4o", "Stable system", "Stable prompt"); + let learning_key = derive_acg_learning_key("agent-a", &request); + let backend = SeedObservationBackend::empty(); + let hot_cache = empty_cache(); + + for _ in 0..stability_window - 1 { + learner + .process_run(&sample_run(vec![request.clone()]), &backend, &hot_cache) + .await + .unwrap(); + } + + let before_failure = backend + .load_stability(&learning_key) + .await + .unwrap() + .expect("stability should be stored before the failing epoch"); + assert!(!before_failure.converged); + + backend.fail_observation_stores(); + let error = learner + .process_run(&sample_run(vec![request]), &backend, &hot_cache) + .await + .unwrap_err(); + assert!( + matches!(error, AdaptiveError::Storage(message) if message.contains("forced observation storage failure")) + ); + + let after_failure = backend + .load_stability(&learning_key) + .await + .unwrap() + .expect("previous non-converged stability should remain stored"); + assert!( + !after_failure.converged, + "converged stability must not be persisted before observations are stored" + ); +} + +#[tokio::test(flavor = "current_thread")] +async fn acg_learner_repairs_converged_stability_without_observations() { + let learner = AcgLearner::new_with_convergence( + "agent-a", + 20, + StabilityThresholds::default(), + Some(ConvergenceConfig { + enabled: true, + epsilon: 0.001, + stability_window: 3, + }), + ); + let request = sample_request("gpt-4o", "Stable system", "Stable prompt"); + let learning_key = derive_acg_learning_key("agent-a", &request); + let seed_observation = build_prompt_ir(&request).unwrap(); + let mut stale_stability = analyze_stability( + std::slice::from_ref(&seed_observation), + &StabilityThresholds::default(), + ); + stale_stability.converged = true; + + let backend = SeedObservationBackend::empty(); + backend.seed_stability(&learning_key, stale_stability); + + learner + .process_run(&sample_run(vec![request]), &backend, &empty_cache()) + .await + .unwrap(); + + let repaired_observations = backend + .load_observations(&learning_key) + .await + .unwrap() + .expect("missing observations should be repaired instead of trusting converged stability"); + assert_eq!(repaired_observations.len(), 1); + + let repaired_stability = backend + .load_stability(&learning_key) + .await + .unwrap() + .expect("repaired stability should be stored"); + assert!( + !repaired_stability.converged, + "a repaired single-observation profile should re-enter normal convergence" + ); +} + +#[tokio::test(flavor = "current_thread")] +async fn acg_learner_repairs_converged_stability_with_empty_observations() { + let learner = AcgLearner::new_with_convergence( + "agent-a", + 20, + StabilityThresholds::default(), + Some(ConvergenceConfig { + enabled: true, + epsilon: 0.001, + stability_window: 3, + }), + ); + let request = sample_request("gpt-4o", "Stable system", "Stable prompt"); + let learning_key = derive_acg_learning_key("agent-a", &request); + let seed_observation = build_prompt_ir(&request).unwrap(); + let mut stale_stability = analyze_stability( + std::slice::from_ref(&seed_observation), + &StabilityThresholds::default(), + ); + stale_stability.converged = true; + + let backend = SeedObservationBackend::new(&learning_key, Vec::new()); + backend.seed_stability(&learning_key, stale_stability); + + learner + .process_run(&sample_run(vec![request]), &backend, &empty_cache()) + .await + .unwrap(); + + let repaired_observations = backend + .load_observations(&learning_key) + .await + .unwrap() + .expect("empty observations should be repaired instead of trusting converged stability"); + assert_eq!(repaired_observations.len(), 1); +} + +#[tokio::test(flavor = "current_thread")] +async fn acg_learner_reuses_converged_stability_without_loading_observations() { + let learner = AcgLearner::new_with_convergence( + "agent-a", + 20, + StabilityThresholds::default(), + Some(ConvergenceConfig { + enabled: true, + epsilon: 0.001, + stability_window: 3, + }), + ); + let request = sample_request("gpt-4o", "Stable system", "Stable prompt"); + let learning_key = derive_acg_learning_key("agent-a", &request); + let seed_observation = build_prompt_ir(&request).unwrap(); + let observations = vec![ + seed_observation.clone(), + seed_observation.clone(), + seed_observation.clone(), + ]; + let mut converged_stability = analyze_stability(&observations, &StabilityThresholds::default()); + converged_stability.converged = true; + + let backend = SeedObservationBackend::new(&learning_key, observations); + backend.seed_stability(&learning_key, converged_stability); + let hot_cache = empty_cache(); + + learner + .process_run(&sample_run(vec![request]), &backend, &hot_cache) + .await + .unwrap(); + + assert_eq!( + backend.load_observation_count(), + 0, + "converged profiles should reuse cached stability without reading the observation window" + ); + let guard = hot_cache.read().unwrap(); + assert_eq!(guard.acg_profiles.len(), 1); + assert_eq!(guard.acg_observation_count, 3); + assert!(guard.acg_stability.as_ref().unwrap().converged); +} + +#[tokio::test(flavor = "current_thread")] +async fn acg_learner_reuses_converged_profile_when_suffix_topology_changes() { + let learner = AcgLearner::new_with_convergence( + "agent-a", + 20, + StabilityThresholds::default(), + Some(ConvergenceConfig { + enabled: true, + epsilon: 0.001, + stability_window: 3, + }), + ); + let base = layered_agent_request("Review changed bundle #1"); + let grown = layered_agent_request_with_extra_suffix("Review changed bundle #2"); + + let learning_key = derive_acg_learning_key("agent-a", &base); + assert_eq!(learning_key, derive_acg_learning_key("agent-a", &grown)); + + let observations = ["#1", "#2", "#3"] + .into_iter() + .map(|suffix| { + build_prompt_ir(&layered_agent_request(&format!( + "Review changed bundle {suffix}" + ))) + .unwrap() + }) + .collect::>(); + let mut converged_stability = analyze_stability(&observations, &StabilityThresholds::default()); + assert_eq!(converged_stability.stable_prefix_length, 3); + converged_stability.converged = true; + + let backend = SeedObservationBackend::new(&learning_key, observations); + backend.seed_stability(&learning_key, converged_stability); + + learner + .process_run(&sample_run(vec![grown]), &backend, &empty_cache()) + .await + .unwrap(); + + assert_eq!( + backend.load_observation_count(), + 0, + "suffix-only topology changes should reuse the cacheable stable prefix" + ); +} + +#[tokio::test(flavor = "current_thread")] +async fn acg_learner_reopens_converged_profile_when_stable_prefix_topology_changes() { + let learner = AcgLearner::new_with_convergence( + "agent-a", + 20, + StabilityThresholds::default(), + Some(ConvergenceConfig { + enabled: true, + epsilon: 0.001, + stability_window: 3, + }), + ); + let base = layered_agent_request("Review changed bundle #1"); + let mut prefix_changed = layered_agent_request("Review changed bundle #2"); + prefix_changed.messages[2] = Message::User { + content: MessageContent::Text("Inserted user context before the work item.".to_string()), + name: None, + }; + + let learning_key = derive_acg_learning_key("agent-a", &base); + assert_eq!( + learning_key, + derive_acg_learning_key("agent-a", &prefix_changed) + ); + + let observations = ["#1", "#2", "#3"] + .into_iter() + .map(|suffix| { + build_prompt_ir(&layered_agent_request(&format!( + "Review changed bundle {suffix}" + ))) + .unwrap() + }) + .collect::>(); + let mut converged_stability = analyze_stability(&observations, &StabilityThresholds::default()); + assert_eq!(converged_stability.stable_prefix_length, 3); + converged_stability.converged = true; + + let backend = SeedObservationBackend::new(&learning_key, observations); + backend.seed_stability(&learning_key, converged_stability); + + learner + .process_run(&sample_run(vec![prefix_changed]), &backend, &empty_cache()) + .await + .unwrap(); + + assert!( + backend.load_observation_count() > 0, + "stable-prefix topology changes must inspect observations instead of reusing convergence" + ); +} + +#[tokio::test(flavor = "current_thread")] +async fn acg_learner_reopens_converged_profile_when_stable_prefix_content_changes() { + let learner = AcgLearner::new_with_convergence( + "agent-a", + 20, + StabilityThresholds::default(), + Some(ConvergenceConfig { + enabled: true, + epsilon: 0.001, + stability_window: 3, + }), + ); + let base = layered_agent_request("Review changed bundle #1"); + let mut prefix_changed = layered_agent_request("Review changed bundle #2"); + prefix_changed.messages[2] = Message::Assistant { + content: Some(MessageContent::Text( + "Acknowledged. I will use the updated stable review lens.".to_string(), + )), + tool_calls: None, + name: None, + }; + + let learning_key = derive_acg_learning_key("agent-a", &base); + assert_eq!( + learning_key, + derive_acg_learning_key("agent-a", &prefix_changed) + ); + + let observations = ["#1", "#2", "#3"] + .into_iter() + .map(|suffix| { + build_prompt_ir(&layered_agent_request(&format!( + "Review changed bundle {suffix}" + ))) + .unwrap() + }) + .collect::>(); + let mut converged_stability = analyze_stability(&observations, &StabilityThresholds::default()); + assert_eq!(converged_stability.stable_prefix_length, 3); + converged_stability.converged = true; + + let backend = SeedObservationBackend::new(&learning_key, observations); + backend.seed_stability(&learning_key, converged_stability); + + learner + .process_run(&sample_run(vec![prefix_changed]), &backend, &empty_cache()) + .await + .unwrap(); + + assert!( + backend.load_observation_count() > 0, + "stable-prefix content changes must inspect observations instead of reusing convergence" + ); +} + +#[tokio::test(flavor = "current_thread")] +async fn acg_learner_reuses_converged_agent_pipe_when_only_task_suffix_changes() { + let learner = AcgLearner::new_with_convergence( + "agent-a", + 20, + StabilityThresholds::default(), + Some(ConvergenceConfig { + enabled: true, + epsilon: 0.001, + stability_window: 3, + }), + ); + let base = layered_agent_pipe_request("Review forum post #1", false); + let next_task = layered_agent_pipe_request("Review forum post #2", false); + + let learning_key = derive_acg_learning_key("agent-a", &base); + assert_eq!(learning_key, derive_acg_learning_key("agent-a", &next_task)); + + let observations = ["#1", "#2", "#3"] + .into_iter() + .map(|suffix| { + build_prompt_ir(&layered_agent_pipe_request( + &format!("Review forum post {suffix}"), + false, + )) + .unwrap() + }) + .collect::>(); + let mut converged_stability = analyze_stability(&observations, &StabilityThresholds::default()); + assert_eq!( + converged_stability.stable_prefix_length, 5, + "system policy, tool schema, output contract, workflow scaffold, and output-contract acknowledgement should be the reusable pipe" + ); + assert!(converged_stability.stable_prefix_fingerprint.is_some()); + converged_stability.converged = true; + + let backend = SeedObservationBackend::new(&learning_key, observations); + backend.seed_stability(&learning_key, converged_stability); + + learner + .process_run(&sample_run(vec![next_task]), &backend, &empty_cache()) + .await + .unwrap(); + + assert_eq!( + backend.load_observation_count(), + 0, + "task-specific suffix content should not invalidate the reusable agent pipe" + ); +} + +#[tokio::test(flavor = "current_thread")] +async fn acg_learner_reopens_converged_agent_pipe_when_output_contract_changes() { + let learner = AcgLearner::new_with_convergence( + "agent-a", + 20, + StabilityThresholds::default(), + Some(ConvergenceConfig { + enabled: true, + epsilon: 0.001, + stability_window: 3, + }), + ); + let base = layered_agent_pipe_request("Review forum post #1", false); + let changed_contract = layered_agent_pipe_request("Review forum post #2", true); + + let learning_key = derive_acg_learning_key("agent-a", &base); + assert_eq!( + learning_key, + derive_acg_learning_key("agent-a", &changed_contract) + ); + + let observations = ["#1", "#2", "#3"] + .into_iter() + .map(|suffix| { + build_prompt_ir(&layered_agent_pipe_request( + &format!("Review forum post {suffix}"), + false, + )) + .unwrap() + }) + .collect::>(); + let mut converged_stability = analyze_stability(&observations, &StabilityThresholds::default()); + assert_eq!(converged_stability.stable_prefix_length, 5); + assert!(converged_stability.stable_prefix_fingerprint.is_some()); + converged_stability.converged = true; + + let backend = SeedObservationBackend::new(&learning_key, observations); + backend.seed_stability(&learning_key, converged_stability); + + learner + .process_run( + &sample_run(vec![changed_contract]), + &backend, + &empty_cache(), + ) + .await + .unwrap(); + + assert!( + backend.load_observation_count() > 0, + "output contract changes must reopen learning instead of reusing stale convergence" + ); +} + #[tokio::test(flavor = "current_thread")] async fn acg_learner_seeds_agent_cache_from_profile_with_more_observations_when_prefixes_tie() { let learner = AcgLearner::new("agent-a", 4, StabilityThresholds::default()); diff --git a/crates/adaptive/tests/unit/acg_profile_tests.rs b/crates/adaptive/tests/unit/acg_profile_tests.rs index b09dfc9ad..9fe835552 100644 --- a/crates/adaptive/tests/unit/acg_profile_tests.rs +++ b/crates/adaptive/tests/unit/acg_profile_tests.rs @@ -125,6 +125,66 @@ fn acg_profile_helpers_cover_none_paths_and_short_hash() { assert_eq!(message_role_tag(&too_short.messages[0]), "user"); } +#[test] +fn acg_learning_key_groups_variable_first_user_under_stable_system_scaffold() { + let spam_item = request( + vec![ + Message::System { + content: MessageContent::Text("Apply the moderation policy exactly.".to_string()), + name: None, + }, + Message::User { + content: MessageContent::Text("Review forum post #1 about spam links".to_string()), + name: None, + }, + ], + None, + ); + let bug_item = request( + vec![ + Message::System { + content: MessageContent::Text("Apply the moderation policy exactly.".to_string()), + name: None, + }, + Message::User { + content: MessageContent::Text("Review forum post #2 about a vague bug".to_string()), + name: None, + }, + ], + None, + ); + + assert_eq!( + derive_acg_learning_key("moderator-agent", &spam_item), + derive_acg_learning_key("moderator-agent", &bug_item), + "the learning key should follow the reusable agent scaffold, not the variable work item" + ); +} + +#[test] +fn acg_learning_key_keeps_seed_fallback_when_no_stable_scaffold_exists() { + let first = request( + vec![Message::User { + content: MessageContent::Text("One-off prompt A".to_string()), + name: None, + }], + None, + ); + let second = request( + vec![Message::User { + content: MessageContent::Text("One-off prompt B".to_string()), + name: None, + }], + None, + ); + + assert_ne!( + derive_acg_learning_key("direct-agent", &first), + derive_acg_learning_key("direct-agent", &second), + "without a stable scaffold, ACG should not collapse unrelated one-off prompts" + ); +} + #[test] fn acg_profile_image_parts_contribute_stable_fingerprint_signal() { let with_image_a = request( diff --git a/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs b/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs index a7efa190e..a2b4e92de 100644 --- a/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs +++ b/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs @@ -316,6 +316,81 @@ fn test_adaptive_hints_intercept_uses_defaults_and_ignores_poisoned_cache() { reset_root_metadata(); } +#[test] +fn test_adaptive_hints_governor_sheds_low_sensitivity_hints_but_keeps_manual_override() { + let _guard = test_mutex().lock().unwrap(); + reset_root_metadata(); + + let defaults = AgentHints { + osl: 9, + iat: 12, + priority: 3, + latency_sensitivity: 2.0, + prefix_id: "defaults".into(), + total_requests: 11, + }; + let hot_cache = Arc::new(RwLock::new(HotCache { + plan: None, + trie: None, + agent_hints_default: Some(defaults), + acg_profiles: std::collections::HashMap::new(), + acg_profile_observation_counts: std::collections::HashMap::new(), + acg_stability: None, + acg_observation_count: 0, + })); + let governor = GovernorConfig { + enabled: true, + epsilon: 10.0, + }; + let req_fn = AdaptiveHintsIntercept::with_governor( + hot_cache.clone(), + "fallback-agent".to_string(), + Some(governor.clone()), + ) + .into_request_fn(); + + let (request, _) = req_fn( + "model", + LlmRequest { + headers: serde_json::Map::new(), + content: serde_json::json!({}), + }, + None, + ) + .unwrap(); + assert!(request.headers.get(AGENT_HINTS_HEADER_KEY).is_none()); + assert!( + request + .content + .get("nvext") + .and_then(|nvext| nvext.get("agent_hints")) + .is_none() + ); + + crate::context_helpers::set_latency_sensitivity(11).unwrap(); + let manual_req_fn = AdaptiveHintsIntercept::with_governor( + hot_cache, + "fallback-agent".to_string(), + Some(governor), + ) + .into_request_fn(); + let (manual_request, _) = manual_req_fn( + "model", + LlmRequest { + headers: serde_json::Map::new(), + content: serde_json::json!({}), + }, + None, + ) + .unwrap(); + assert_eq!( + manual_request.content["nvext"]["agent_hints"]["latency_sensitivity"], + serde_json::json!(11.0) + ); + + reset_root_metadata(); +} + #[test] fn test_apply_manual_latency_override_and_inject_agent_hints_cover_manual_paths() { let base_hints = AgentHints { diff --git a/crates/adaptive/tests/unit/cache_diagnostics_tests.rs b/crates/adaptive/tests/unit/cache_diagnostics_tests.rs index b46c7b846..4f96137da 100644 --- a/crates/adaptive/tests/unit/cache_diagnostics_tests.rs +++ b/crates/adaptive/tests/unit/cache_diagnostics_tests.rs @@ -82,7 +82,9 @@ fn make_hot_cache(stable_prefix_length: Option) -> HotCache { }) .collect(), stable_prefix_length, + stable_prefix_fingerprint: None, total_observations: 4, + converged: false, }), acg_observation_count: 4, } diff --git a/crates/adaptive/tests/unit/config_tests.rs b/crates/adaptive/tests/unit/config_tests.rs index 22f8ed6f6..a2e488ae0 100644 --- a/crates/adaptive/tests/unit/config_tests.rs +++ b/crates/adaptive/tests/unit/config_tests.rs @@ -25,9 +25,19 @@ fn test_typed_section_helpers_default() { let adaptive_hints = AdaptiveHintsComponentConfig::default(); assert_eq!(adaptive_hints.priority, 100); assert!(adaptive_hints.inject_header); + assert!(adaptive_hints.governor.is_none()); let tool_parallelism = ToolParallelismComponentConfig::default(); assert_eq!(tool_parallelism.mode, "observe_only"); + assert!(tool_parallelism.drift.is_none()); + + let governor = GovernorConfig::default(); + assert!(!governor.enabled); + assert_eq!(governor.epsilon, 1.0); + + let drift = DriftConfig::default(); + assert!(!drift.enabled); + assert_eq!(drift.threshold, 0.75); } #[test] @@ -66,11 +76,29 @@ fn test_component_configs_deserialize_with_default_helpers() { assert!(!adaptive_hints.break_chain); assert!(adaptive_hints.inject_header); assert_eq!(adaptive_hints.inject_body_path, "nvext.agent_hints"); + assert!(adaptive_hints.governor.is_none()); let tool_parallelism: ToolParallelismComponentConfig = serde_json::from_value(json!({})).unwrap(); assert_eq!(tool_parallelism.priority, 100); assert_eq!(tool_parallelism.mode, "observe_only"); + assert!(tool_parallelism.drift.is_none()); + + let adaptive_hints: AdaptiveHintsComponentConfig = serde_json::from_value(json!({ + "governor": {"enabled": true} + })) + .unwrap(); + let governor = adaptive_hints.governor.unwrap(); + assert!(governor.enabled); + assert_eq!(governor.epsilon, 1.0); + + let tool_parallelism: ToolParallelismComponentConfig = serde_json::from_value(json!({ + "drift": {"enabled": true} + })) + .unwrap(); + let drift = tool_parallelism.drift.unwrap(); + assert!(drift.enabled); + assert_eq!(drift.threshold, 0.75); } #[test] @@ -90,6 +118,7 @@ fn test_adaptive_editor_schema_covers_canonical_options() { "adaptive_hints", "tool_parallelism", "acg", + "convergence", "policy", ] ); @@ -105,6 +134,20 @@ fn test_adaptive_editor_schema_covers_canonical_options() { EditorFieldKind::Json ); + let adaptive_hints = schema.field("adaptive_hints").unwrap().schema().unwrap(); + let governor = adaptive_hints.field("governor").unwrap().schema().unwrap(); + assert_eq!( + governor.field("epsilon").unwrap().kind, + EditorFieldKind::Float + ); + + let tool_parallelism = schema.field("tool_parallelism").unwrap().schema().unwrap(); + let drift = tool_parallelism.field("drift").unwrap().schema().unwrap(); + assert_eq!( + drift.field("threshold").unwrap().kind, + EditorFieldKind::Float + ); + let acg = schema.field("acg").unwrap().schema().unwrap(); let thresholds = acg.field("stability_thresholds").unwrap().schema().unwrap(); assert_eq!( diff --git a/crates/adaptive/tests/unit/intercepts_tests.rs b/crates/adaptive/tests/unit/intercepts_tests.rs index 98fd16410..e87f8b0b5 100644 --- a/crates/adaptive/tests/unit/intercepts_tests.rs +++ b/crates/adaptive/tests/unit/intercepts_tests.rs @@ -50,7 +50,9 @@ fn make_hot_cache( acg_stability: Some(StabilityAnalysisResult { scores: vec![], stable_prefix_length, + stable_prefix_fingerprint: None, total_observations: observation_count, + converged: false, }), acg_observation_count: observation_count, })) diff --git a/crates/adaptive/tests/unit/plugin_component_tests.rs b/crates/adaptive/tests/unit/plugin_component_tests.rs index 03fe3f009..1d840831e 100644 --- a/crates/adaptive/tests/unit/plugin_component_tests.rs +++ b/crates/adaptive/tests/unit/plugin_component_tests.rs @@ -14,10 +14,10 @@ use nemo_relay::api::runtime::global_context; use nemo_relay::plugin::{DiagnosticLevel, UnsupportedBehavior, clear_plugin_configuration}; use nemo_relay::plugin::{Plugin, PluginRegistrationContext, rollback_registrations}; use serde_json::json; -use tokio::sync::Mutex as AsyncMutex; + +use crate::test_support::GLOBAL_RUNTIME_TEST_MUTEX; static TEST_MUTEX: OnceLock> = OnceLock::new(); -static ASYNC_TEST_MUTEX: AsyncMutex<()> = AsyncMutex::const_new(()); fn test_mutex() -> &'static Mutex<()> { TEST_MUTEX.get_or_init(|| Mutex::new(())) @@ -92,6 +92,7 @@ fn validate_adaptive_plugin_config_reports_unknown_fields_and_backend_errors() { #[test] fn register_adaptive_component_is_idempotent_and_deregisters_cleanly() { let _guard = test_mutex().lock().unwrap(); + let _runtime_guard = GLOBAL_RUNTIME_TEST_MUTEX.blocking_lock(); let _ = clear_plugin_configuration(); let _ = deregister_adaptive_component(); @@ -349,9 +350,40 @@ fn validate_adaptive_plugin_config_reports_component_specific_unknown_fields() { })); } +#[test] +fn validate_adaptive_plugin_config_accepts_topology_sections() { + let config = json!({ + "version": 1, + "adaptive_hints": { + "governor": {"enabled": true, "epsilon": 2.0} + }, + "tool_parallelism": { + "mode": "observe_only", + "drift": {"enabled": true, "threshold": 0.5} + }, + "acg": { + "provider": "anthropic", + "convergence": {"enabled": true, "epsilon": 0.01, "stability_window": 3} + }, + "convergence": {"enabled": true, "epsilon": 0.01, "stability_window": 3}, + "policy": { + "unknown_field": "error", + "unsupported_value": "error" + } + }); + + let diagnostics = validate_adaptive_plugin_config(config.as_object().unwrap()); + assert!( + diagnostics + .iter() + .all(|diag| diag.code != "adaptive.unknown_field"), + "topology config fields should not be reported as unknown: {diagnostics:?}" + ); +} + #[tokio::test(flavor = "current_thread")] async fn adaptive_plugin_registers_runtime_and_rolls_back_registration() { - let _guard = ASYNC_TEST_MUTEX.lock().await; + let _guard = GLOBAL_RUNTIME_TEST_MUTEX.lock().await; reset_global(); let plugin = AdaptivePlugin; diff --git a/crates/adaptive/tests/unit/runtime_features_tests.rs b/crates/adaptive/tests/unit/runtime_features_tests.rs index c64818606..b91526dc3 100644 --- a/crates/adaptive/tests/unit/runtime_features_tests.rs +++ b/crates/adaptive/tests/unit/runtime_features_tests.rs @@ -7,6 +7,17 @@ use super::*; use std::sync::Arc; +use crate::acg::profile::{BlockStabilityScore, StabilityClass}; +use crate::acg::prompt_ir::SpanId; +use crate::acg::stability::StabilityAnalysisResult; +use crate::config::{BackendSpec, StateConfig}; +use crate::intercepts::AGENT_HINTS_HEADER_KEY; +use crate::test_support::GLOBAL_RUNTIME_TEST_MUTEX; +use crate::trie::accumulator::AccumulatorState; +use crate::trie::serialization::TrieEnvelope; +use crate::types::metadata::{AgentHints, MetadataEnvelope, ParallelHint}; +use crate::types::plan::{ExecutionPlan, ParallelGroup}; +use crate::types::records::RunRecord; use nemo_relay::api::llm::{ LlmCallExecuteParams, LlmRequest, LlmStreamCallExecuteParams, llm_call_execute, llm_request_intercepts, llm_stream_call_execute, @@ -24,26 +35,14 @@ use nemo_relay::api::runtime::{ }; use nemo_relay::api::subscriber::{deregister_subscriber, register_subscriber}; use nemo_relay::api::tool::tool_call_execute; +use nemo_relay::codec::anthropic::AnthropicMessagesCodec; +use nemo_relay::codec::traits::LlmCodec; use nemo_relay::error::FlowError; use nemo_relay::plugin::{ConfigPolicy, DiagnosticLevel, UnsupportedBehavior}; use nemo_relay::plugin::{clear_plugin_configuration, rollback_registrations}; use serde_json::json; -use tokio::sync::Mutex; - -use crate::acg::profile::{BlockStabilityScore, StabilityClass}; -use crate::acg::prompt_ir::SpanId; -use crate::acg::stability::StabilityAnalysisResult; -use crate::config::{BackendSpec, StateConfig}; -use crate::intercepts::AGENT_HINTS_HEADER_KEY; -use crate::trie::accumulator::AccumulatorState; -use crate::trie::serialization::TrieEnvelope; -use crate::types::metadata::{AgentHints, MetadataEnvelope, ParallelHint}; -use crate::types::plan::{ExecutionPlan, ParallelGroup}; -use crate::types::records::RunRecord; use tokio_stream::StreamExt; -static TEST_MUTEX: Mutex<()> = Mutex::const_new(()); - fn reset_global() { let _ = clear_plugin_configuration(); let ctx = global_context(); @@ -90,6 +89,12 @@ fn layered_acg_request() -> LlmRequest { } fn layered_acg_stability_result(observation_count: u32) -> StabilityAnalysisResult { + let annotated_request = AnthropicMessagesCodec + .decode(&layered_acg_request()) + .expect("layered request should decode"); + let prompt_ir = crate::acg::ir_builder::build_prompt_ir(&annotated_request) + .expect("layered request should build PromptIR"); + StabilityAnalysisResult { scores: vec![ BlockStabilityScore { @@ -115,7 +120,9 @@ fn layered_acg_stability_result(observation_count: u32) -> StabilityAnalysisResu }, ], stable_prefix_length: 3, + stable_prefix_fingerprint: crate::acg::stability::prompt_prefix_fingerprint(&prompt_ir, 3), total_observations: observation_count, + converged: false, } } @@ -288,6 +295,8 @@ fn build_learners_filters_unknown_entries() { "agent-a", &["latency_sensitivity".to_string(), "unknown".to_string()], None, + None, + None, ); assert_eq!(learners.len(), 1); } @@ -427,7 +436,7 @@ async fn adaptive_runtime_new_rejects_invalid_configs_with_joined_errors() { #[tokio::test(flavor = "current_thread")] async fn registration_context_take_event_receiver_only_allows_one_consumer() { - let _lock = TEST_MUTEX.lock().await; + let _lock = GLOBAL_RUNTIME_TEST_MUTEX.lock().await; reset_global(); let mut runtime = AdaptiveRuntime::new(AdaptiveConfig::default()) @@ -444,7 +453,7 @@ async fn registration_context_take_event_receiver_only_allows_one_consumer() { #[tokio::test(flavor = "current_thread")] async fn telemetry_feature_registers_subscriber_and_starts_drain_task() { - let _lock = TEST_MUTEX.lock().await; + let _lock = GLOBAL_RUNTIME_TEST_MUTEX.lock().await; reset_global(); let mut runtime = AdaptiveRuntime::new(AdaptiveConfig { @@ -463,6 +472,8 @@ async fn telemetry_feature_registers_subscriber_and_starts_drain_task() { "agent-telemetry".into(), Uuid::now_v7(), None, + None, + None, ); let name = feature.subscriber_name.clone(); @@ -484,7 +495,7 @@ async fn telemetry_feature_registers_subscriber_and_starts_drain_task() { #[tokio::test(flavor = "current_thread")] async fn telemetry_feature_requires_backend() { - let _lock = TEST_MUTEX.lock().await; + let _lock = GLOBAL_RUNTIME_TEST_MUTEX.lock().await; reset_global(); let mut runtime = AdaptiveRuntime::new(AdaptiveConfig::default()) @@ -495,6 +506,8 @@ async fn telemetry_feature_requires_backend() { "agent-telemetry".into(), Uuid::now_v7(), None, + None, + None, ); let mut ctx = RegistrationContext::new(&mut runtime); @@ -506,7 +519,7 @@ async fn telemetry_feature_requires_backend() { #[tokio::test(flavor = "current_thread")] async fn adaptive_hints_feature_registers_request_intercept() { - let _lock = TEST_MUTEX.lock().await; + let _lock = GLOBAL_RUNTIME_TEST_MUTEX.lock().await; reset_global(); let mut runtime = AdaptiveRuntime::new(AdaptiveConfig::default()) @@ -562,7 +575,7 @@ async fn adaptive_hints_feature_registers_request_intercept() { #[tokio::test(flavor = "current_thread")] async fn tool_parallelism_feature_registers_execution_intercept() { - let _lock = TEST_MUTEX.lock().await; + let _lock = GLOBAL_RUNTIME_TEST_MUTEX.lock().await; reset_global(); let mut runtime = AdaptiveRuntime::new(AdaptiveConfig::default()) @@ -611,7 +624,7 @@ async fn tool_parallelism_feature_registers_execution_intercept() { #[tokio::test(flavor = "current_thread")] async fn adaptive_runtime_register_survives_hot_cache_seed_failures() { - let _lock = TEST_MUTEX.lock().await; + let _lock = GLOBAL_RUNTIME_TEST_MUTEX.lock().await; reset_global(); let config = AdaptiveConfig { @@ -653,7 +666,7 @@ async fn adaptive_runtime_register_survives_hot_cache_seed_failures() { #[tokio::test(flavor = "current_thread")] async fn adaptive_runtime_register_is_idempotent_for_active_features() { - let _lock = TEST_MUTEX.lock().await; + let _lock = GLOBAL_RUNTIME_TEST_MUTEX.lock().await; reset_global(); let mut runtime = AdaptiveRuntime::new(AdaptiveConfig { @@ -678,7 +691,7 @@ async fn adaptive_runtime_register_is_idempotent_for_active_features() { #[tokio::test(flavor = "current_thread")] async fn adaptive_runtime_register_rolls_back_when_telemetry_receiver_is_missing() { - let _lock = TEST_MUTEX.lock().await; + let _lock = GLOBAL_RUNTIME_TEST_MUTEX.lock().await; reset_global(); let mut runtime = AdaptiveRuntime::new(AdaptiveConfig { @@ -702,7 +715,7 @@ async fn adaptive_runtime_register_rolls_back_when_telemetry_receiver_is_missing #[tokio::test(flavor = "current_thread")] async fn registration_context_registers_all_supported_callback_types() { - let _lock = TEST_MUTEX.lock().await; + let _lock = GLOBAL_RUNTIME_TEST_MUTEX.lock().await; reset_global(); let mut runtime = AdaptiveRuntime::new(AdaptiveConfig::default()) @@ -781,7 +794,9 @@ async fn adaptive_runtime_helper_methods_cover_report_wait_for_idle_and_feature_ build_learners( "agent-a", &["tool_parallelism".to_string(), "acg".to_string()], + config.tool_parallelism.as_ref(), config.acg.as_ref(), + config.convergence.as_ref(), ) .len(), 2 @@ -811,7 +826,7 @@ async fn adaptive_runtime_helper_methods_cover_report_wait_for_idle_and_feature_ #[tokio::test(flavor = "current_thread")] async fn acg_feature_registers_execution_and_stream_intercepts() { - let _lock = TEST_MUTEX.lock().await; + let _lock = GLOBAL_RUNTIME_TEST_MUTEX.lock().await; reset_global(); let mut runtime = AdaptiveRuntime::new(AdaptiveConfig::default()) @@ -918,7 +933,7 @@ async fn acg_feature_registers_execution_and_stream_intercepts() { #[tokio::test(flavor = "current_thread")] async fn adaptive_runtime_register_feature_rolls_back_partial_registrations_and_abort_handle() { - let _lock = TEST_MUTEX.lock().await; + let _lock = GLOBAL_RUNTIME_TEST_MUTEX.lock().await; reset_global(); let mut runtime = AdaptiveRuntime::new(AdaptiveConfig::default()) diff --git a/crates/adaptive/tests/unit/runtime_tests.rs b/crates/adaptive/tests/unit/runtime_tests.rs index 5802a58a6..b6e7cd404 100644 --- a/crates/adaptive/tests/unit/runtime_tests.rs +++ b/crates/adaptive/tests/unit/runtime_tests.rs @@ -11,13 +11,15 @@ use nemo_relay::api::scope::{PopScopeParams, PushScopeParams, ScopeType, pop_sco use serde_json::{Map, Value as Json}; use crate::config::{ - AcgComponentConfig, AdaptiveConfig, BackendSpec, StateConfig, TelemetryComponentConfig, + AcgComponentConfig, AdaptiveConfig, AdaptiveHintsComponentConfig, BackendSpec, + ConvergenceConfig, DriftConfig, GovernorConfig, StateConfig, TelemetryComponentConfig, ToolParallelismComponentConfig, }; use crate::error::AdaptiveError; use crate::runtime::backend::build_backend; use crate::runtime::features::AdaptiveRuntime; use crate::runtime::validation::validate_config; +use crate::test_support::GLOBAL_RUNTIME_TEST_MUTEX; use nemo_relay::codec::request::{AnnotatedLlmRequest, Message, MessageContent}; use nemo_relay::plugin::{ConfigPolicy, UnsupportedBehavior}; @@ -317,6 +319,64 @@ fn validate_config_reports_unknown_backend_and_acg_provider_per_policy() { assert!(ignore_report.diagnostics.is_empty()); } +#[test] +fn validate_config_reports_invalid_topology_numeric_fields() { + let report = validate_config(&AdaptiveConfig { + adaptive_hints: Some(AdaptiveHintsComponentConfig { + governor: Some(GovernorConfig { + enabled: true, + epsilon: f64::NAN, + }), + ..AdaptiveHintsComponentConfig::default() + }), + tool_parallelism: Some(ToolParallelismComponentConfig { + drift: Some(DriftConfig { + enabled: true, + threshold: 0.0, + }), + ..ToolParallelismComponentConfig::default() + }), + acg: Some(AcgComponentConfig { + convergence: Some(ConvergenceConfig { + enabled: true, + epsilon: -1.0, + stability_window: 2, + }), + ..AcgComponentConfig::default() + }), + convergence: Some(ConvergenceConfig { + enabled: true, + epsilon: f64::INFINITY, + stability_window: 0, + }), + policy: ConfigPolicy { + unsupported_value: UnsupportedBehavior::Error, + ..ConfigPolicy::default() + }, + ..AdaptiveConfig::default() + }); + + assert!(report.has_errors()); + for (component, field) in [ + ("adaptive_hints.governor", "epsilon"), + ("tool_parallelism.drift", "threshold"), + ("acg.convergence", "epsilon"), + ("acg.convergence", "stability_window"), + ("convergence", "epsilon"), + ("convergence", "stability_window"), + ] { + assert!( + report + .diagnostics + .iter() + .any(|diag| diag.code == "adaptive.unsupported_value" + && diag.component.as_deref() == Some(component) + && diag.field.as_deref() == Some(field)), + "expected unsupported value diagnostic for {component}.{field}" + ); + } +} + #[tokio::test(flavor = "current_thread")] async fn adaptive_runtime_new_accepts_valid_in_memory_configuration() { let runtime = AdaptiveRuntime::new(AdaptiveConfig { @@ -410,18 +470,14 @@ fn adaptive_acg_defaults_and_profile_key_behavior_stay_stable() { ); assert_eq!( profile_key, - "agent-1::model=claude-sonnet-4::roles=system.user::system=sha256:97f793c76::anchor=no-anchor::tools=no-tools" + "agent-1::model=claude-sonnet-4::roles=system.user::system=sha256:97f793c76::anchor=no-anchor::tools=no-tools::contract=no-contract" ); let learning_key = crate::acg_profile::derive_acg_learning_key( "agent-1", &sample_annotated_request(Some("claude-sonnet-4")), ); let expected_learning_key = format!( - "agent-1::model=claude-sonnet-4::seed={}::system={}::tools=no-tools", - short_hash(&format!( - "user:{}", - crate::acg::sha256_hex("Summarize the latest findings") - )), + "agent-1::model=claude-sonnet-4::scaffold=stable::system={}::tools=no-tools", short_hash(&crate::acg::sha256_hex("You are a careful planner")), ); assert_eq!(learning_key, expected_learning_key,); @@ -498,9 +554,9 @@ fn adaptive_acg_defaults_and_profile_key_behavior_stay_stable() { "agent-1", &sample_layered_request(Some("claude-sonnet-4"), "Python review guide"), ); - assert_ne!( + assert_eq!( rust_learning_key, python_learning_key, - "layered requests should still separate learning buckets when the stable anchor differs", + "layered requests share the coarse learning bucket; topology checks reopen when the stable anchor differs", ); let rust_bundle_variant = AnnotatedLlmRequest { @@ -581,6 +637,7 @@ async fn adaptive_runtime_build_cache_request_facts_keeps_missing_stability_sema #[tokio::test(flavor = "current_thread")] async fn adaptive_runtime_bind_scope_requires_registration_and_passes_through_without_state() { + let _guard = GLOBAL_RUNTIME_TEST_MUTEX.lock().await; reset_runtime_context(); let mut runtime = AdaptiveRuntime::new(AdaptiveConfig { agent_id: Some("agent-1".to_string()), diff --git a/crates/adaptive/tests/unit/storage_memory_internal_tests.rs b/crates/adaptive/tests/unit/storage_memory_internal_tests.rs index 311f007dc..7585f250e 100644 --- a/crates/adaptive/tests/unit/storage_memory_internal_tests.rs +++ b/crates/adaptive/tests/unit/storage_memory_internal_tests.rs @@ -84,7 +84,9 @@ fn sample_stability_record() -> StabilityAnalysisResult { observation_count: 1, }], stable_prefix_length: 1, + stable_prefix_fingerprint: None, total_observations: 1, + converged: false, } } diff --git a/crates/adaptive/tests/unit/storage_tests.rs b/crates/adaptive/tests/unit/storage_tests.rs index 955dc8c76..0beb39da6 100644 --- a/crates/adaptive/tests/unit/storage_tests.rs +++ b/crates/adaptive/tests/unit/storage_tests.rs @@ -89,7 +89,9 @@ fn sample_stability(agent_id: &str) -> StabilityAnalysisResult { observation_count: 3, }], stable_prefix_length: 1, + stable_prefix_fingerprint: None, total_observations: 3, + converged: false, } } @@ -249,6 +251,7 @@ async fn in_memory_backend_round_trips_observations_and_stability() { ); assert_eq!(loaded_stability.stable_prefix_length, 1); assert_eq!(loaded_stability.total_observations, 3); + assert!(!loaded_stability.converged); } #[tokio::test(flavor = "current_thread")] diff --git a/crates/adaptive/tests/unit/tool_parallelism_learner_tests.rs b/crates/adaptive/tests/unit/tool_parallelism_learner_tests.rs index 5d28bacc7..0aac5cd36 100644 --- a/crates/adaptive/tests/unit/tool_parallelism_learner_tests.rs +++ b/crates/adaptive/tests/unit/tool_parallelism_learner_tests.rs @@ -10,6 +10,7 @@ use serde_json::json; use uuid::Uuid; use super::*; +use crate::config::DriftConfig; use crate::storage::memory::InMemoryBackend; use crate::storage::traits::StorageBackend; use crate::types::cache::HotCache; @@ -299,6 +300,64 @@ async fn process_run_merges_new_cohorts_into_existing_plan() { ); } +#[tokio::test] +async fn process_run_invalidates_existing_plan_when_tool_cohort_topology_drifts() { + let backend = InMemoryBackend::new(); + let hot_cache = make_hot_cache(); + let learner = ToolParallelismLearner::new_with_drift( + "agent-drift", + Some(DriftConfig { + enabled: true, + threshold: 0.01, + }), + ); + let base = Utc::now(); + let first_run = make_run( + "agent-drift", + vec![ + make_tool_call("search", base, 0, Some(90)), + make_tool_call("fetch", base, 10, Some(100)), + ], + ); + let drifted_run = make_run( + "agent-drift", + vec![ + make_tool_call("compile", base, 0, Some(120)), + make_tool_call("test", base, 10, Some(130)), + make_tool_call("lint", base, 20, Some(140)), + ], + ); + + backend + .store_plan(&make_existing_plan("agent-drift")) + .unwrap(); + learner + .process_run(&first_run, &backend, &hot_cache) + .await + .unwrap(); + learner + .process_run(&drifted_run, &backend, &hot_cache) + .await + .unwrap(); + + let plan = backend.load_plan("agent-drift").await.unwrap().unwrap(); + assert!( + !plan + .parallel_groups + .iter() + .any(|group| group.group_id == "fanout:existing"), + "drifted cohort topology should invalidate stale plan groups" + ); + assert!(plan.parallel_groups.iter().any(|group| { + group.tool_names + == vec![ + "compile".to_string(), + "lint".to_string(), + "test".to_string(), + ] + })); +} + #[tokio::test] async fn process_run_reports_hot_cache_lock_poisoning() { let backend = InMemoryBackend::new(); diff --git a/crates/adaptive/tests/unit/topology_tests.rs b/crates/adaptive/tests/unit/topology_tests.rs new file mode 100644 index 000000000..15a107f2f --- /dev/null +++ b/crates/adaptive/tests/unit/topology_tests.rs @@ -0,0 +1,146 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Unit tests for internal topology-aware adaptive control primitives. + +use super::*; + +#[test] +fn empty_detector_is_not_converged() { + let detector = ConvergenceDetector::new(0.001, 3); + + assert!(!detector.is_converged()); + assert_eq!(detector.epoch(), 0); +} + +#[test] +fn non_finite_metrics_do_not_converge() { + let mut detector = ConvergenceDetector::new(f64::NAN, 0); + + detector.record_epoch(BettiNumbers::new(1, 0), f64::NAN, f64::NAN); + + assert!(!detector.is_converged()); + assert_eq!(detector.epoch(), 1); +} + +#[test] +fn stability_window_is_clamped_to_history_capacity() { + let detector = ConvergenceDetector::new(0.001, MAX_HISTORY + 1); + + assert_eq!(detector.stability_window, MAX_HISTORY); +} + +#[test] +fn low_error_requires_full_window_before_converging() { + let mut detector = ConvergenceDetector::new(0.01, 3); + + detector.record_epoch(BettiNumbers::new(1, 0), 0.5, 0.005); + assert!(!detector.is_converged()); + + detector.record_epoch(BettiNumbers::new(1, 0), 0.1, 0.005); + assert!(!detector.is_converged()); + + detector.record_epoch(BettiNumbers::new(1, 0), 0.001, 0.005); + assert!(detector.is_converged()); +} + +#[test] +fn stable_betti_and_decreasing_drift_converges() { + let mut detector = ConvergenceDetector::new(0.001, 3); + + detector.record_epoch(BettiNumbers::new(1, 0), 0.1, 0.0005); + detector.record_epoch(BettiNumbers::new(1, 0), 0.05, 0.0004); + detector.record_epoch(BettiNumbers::new(1, 0), 0.001, 0.0003); + + assert!(detector.is_converged()); +} + +#[test] +fn convergence_decision_reports_latest_epoch_and_gate_status() { + let mut detector = ConvergenceDetector::new(0.001, 3); + + detector.record_epoch(BettiNumbers::new(2, 1), 0.1, 0.0005); + detector.record_epoch(BettiNumbers::new(2, 1), 0.05, 0.0004); + let decision = detector.record_epoch(BettiNumbers::new(2, 1), 0.001, 0.0003); + + assert_eq!(decision.epoch, 3); + assert_eq!(decision.stability_window, 3); + assert_eq!(decision.latest_betti, BettiNumbers::new(2, 1)); + assert_eq!(decision.latest_drift, 0.001); + assert_eq!(decision.latest_error, 0.0003); + assert!(decision.betti_stable); + assert!(decision.drift_decreasing); + assert!(decision.error_converged); + assert!(decision.converged); +} + +#[test] +fn unstable_betti_does_not_converge() { + let mut detector = ConvergenceDetector::new(0.001, 3); + + detector.record_epoch(BettiNumbers::new(1, 0), 0.05, 0.0005); + detector.record_epoch(BettiNumbers::new(2, 0), 0.03, 0.0004); + detector.record_epoch(BettiNumbers::new(1, 0), 0.001, 0.0003); + + assert!(!detector.is_converged()); +} + +#[test] +fn increasing_drift_does_not_converge() { + let mut detector = ConvergenceDetector::new(0.001, 3); + + detector.record_epoch(BettiNumbers::new(1, 0), 0.01, 0.0005); + detector.record_epoch(BettiNumbers::new(1, 0), 0.05, 0.0004); + detector.record_epoch(BettiNumbers::new(1, 0), 0.001, 0.0003); + + assert!(!detector.is_converged()); +} + +#[test] +fn drift_detector_tracks_sudden_centroid_change() { + let mut detector = DriftDetector::<3>::new(); + + assert_eq!(detector.update(&[0.0, 0.0, 0.0]), 0.0); + assert_eq!(detector.update(&[1.0, 0.0, 0.0]), 1.0); + assert_eq!(detector.update(&[2.0, 0.0, 0.0]), 0.0); + + let drift = detector.update(&[5.0, 0.0, 0.0]); + + assert!(drift > 1.0); + assert!(detector.velocity[0] > 1.0); +} + +#[test] +fn drift_detector_resets_on_non_finite_centroid() { + let mut detector = DriftDetector::<3>::new(); + + detector.update(&[0.0, 0.0, 0.0]); + detector.update(&[1.0, 0.0, 0.0]); + + assert_eq!(detector.update(&[f64::NAN, 0.0, 0.0]), f64::INFINITY); + assert!(!detector.has_previous); + assert_eq!(detector.update(&[2.0, 0.0, 0.0]), 0.0); +} + +#[test] +fn governor_sheds_below_threshold_and_allows_at_threshold() { + let governor = GeometricGovernor::with_epsilon(0.5); + + assert!(!governor.should_trigger(0.4)); + assert!(governor.should_trigger(0.5)); + assert!(governor.should_trigger(0.6)); +} + +#[test] +fn governor_clamps_non_finite_and_extreme_inputs() { + let mut governor = GeometricGovernor::with_epsilon(f64::NAN); + + assert!((governor.epsilon - GOVERNOR_EPSILON_INITIAL).abs() < f64::EPSILON); + assert_eq!(governor.adapt(f64::NAN, 1.0), GOVERNOR_EPSILON_INITIAL); + assert_eq!(governor.adapt(1.0, f64::INFINITY), GOVERNOR_EPSILON_INITIAL); + assert!(!governor.should_trigger(f64::NAN)); + + governor.adapt(1_000_000.0, 0.001); + assert!(governor.epsilon >= GOVERNOR_EPSILON_MIN); + assert!(governor.epsilon <= GOVERNOR_EPSILON_MAX); +} diff --git a/crates/adaptive/tests/unit/types_tests.rs b/crates/adaptive/tests/unit/types_tests.rs index c31b0c58d..82e4ab734 100644 --- a/crates/adaptive/tests/unit/types_tests.rs +++ b/crates/adaptive/tests/unit/types_tests.rs @@ -41,7 +41,9 @@ fn sample_stability_result() -> StabilityAnalysisResult { observation_count: 4, }], stable_prefix_length: 1, + stable_prefix_fingerprint: None, total_observations: 4, + converged: false, } } diff --git a/crates/node/adaptive.d.ts b/crates/node/adaptive.d.ts index a2ad4d30c..4817719b5 100644 --- a/crates/node/adaptive.d.ts +++ b/crates/node/adaptive.d.ts @@ -29,12 +29,33 @@ export interface AdaptiveHintsConfig { break_chain?: boolean; inject_header?: boolean; inject_body_path?: string; + governor?: GovernorConfig; } /** Built-in adaptive tool scheduling settings. */ export interface ToolParallelismConfig { priority?: number; mode?: 'observe_only' | 'inject_hints' | 'schedule' | string; + drift?: DriftConfig; +} + +/** Topology-aware hint load-shedding settings. */ +export interface GovernorConfig { + enabled?: boolean; + epsilon?: number; +} + +/** Topology-aware tool-plan drift detection settings. */ +export interface DriftConfig { + enabled?: boolean; + threshold?: number; +} + +/** Topological convergence detector settings. */ +export interface ConvergenceConfig { + enabled?: boolean; + epsilon?: number; + stability_window?: number; } /** ACG prompt-stability classification thresholds. */ @@ -50,6 +71,7 @@ export interface AcgConfig { observation_window?: number; priority?: number; stability_thresholds?: AcgStabilityThresholds; + convergence?: ConvergenceConfig; } /** Canonical config object for the top-level adaptive component. */ @@ -61,6 +83,7 @@ export interface Config { adaptive_hints?: AdaptiveHintsConfig; tool_parallelism?: ToolParallelismConfig; acg?: AcgConfig; + convergence?: ConvergenceConfig; policy?: ConfigPolicy; } @@ -208,6 +231,12 @@ export declare function redisBackend(url: string, keyPrefix?: string): BackendSp * append learner names without checking for initialization first. */ export declare function telemetryConfig(config?: TelemetryConfig): TelemetryConfig; +/** Create topology-aware hint load-shedding settings with defaults applied. */ +export declare function governorConfig(config?: GovernorConfig): GovernorConfig; +/** Create topology-aware tool-plan drift detection settings with defaults applied. */ +export declare function driftConfig(config?: DriftConfig): DriftConfig; +/** Create topological convergence detector settings with defaults applied. */ +export declare function convergenceConfig(config?: ConvergenceConfig): ConvergenceConfig; /** * Create adaptive hint-injection settings with defaults applied. * diff --git a/crates/node/adaptive.js b/crates/node/adaptive.js index 8392f46a3..5bc5e205e 100644 --- a/crates/node/adaptive.js +++ b/crates/node/adaptive.js @@ -80,6 +80,49 @@ function telemetryConfig(config = {}) { }; } +/** + * Create topology-aware hint load-shedding settings with defaults applied. + * + * @param {object} [config={}] - Partial governor settings to override. + * @returns {object} A normalized governor config object. + */ +function governorConfig(config = {}) { + return { + enabled: false, + epsilon: 1.0, + ...config, + }; +} + +/** + * Create topology-aware tool-plan drift detection settings with defaults applied. + * + * @param {object} [config={}] - Partial drift settings to override. + * @returns {object} A normalized drift config object. + */ +function driftConfig(config = {}) { + return { + enabled: false, + threshold: 0.75, + ...config, + }; +} + +/** + * Create topological convergence detector settings with defaults applied. + * + * @param {object} [config={}] - Partial convergence settings to override. + * @returns {object} A normalized convergence config object. + */ +function convergenceConfig(config = {}) { + return { + enabled: false, + epsilon: 0.001, + stability_window: 3, + ...config, + }; +} + /** * Create adaptive hint-injection settings with defaults applied. * @@ -202,6 +245,9 @@ module.exports = { inMemoryBackend, redisBackend, telemetryConfig, + governorConfig, + driftConfig, + convergenceConfig, adaptiveHintsConfig, toolParallelismConfig, acgConfig, diff --git a/crates/node/tests/adaptive_tests.mjs b/crates/node/tests/adaptive_tests.mjs index 540590fa2..44d100559 100644 --- a/crates/node/tests/adaptive_tests.mjs +++ b/crates/node/tests/adaptive_tests.mjs @@ -203,4 +203,36 @@ describe('adaptive helpers', () => { }, ); }); + + it('builds topology-aware adaptive config helpers', () => { + assert.deepEqual(adaptive.governorConfig({ enabled: true }), { + enabled: true, + epsilon: 1.0, + }); + assert.deepEqual(adaptive.driftConfig({ enabled: true }), { + enabled: true, + threshold: 0.75, + }); + assert.deepEqual(adaptive.convergenceConfig({ enabled: true }), { + enabled: true, + epsilon: 0.001, + stability_window: 3, + }); + + const config = adaptive.defaultConfig(); + config.state = { backend: adaptive.inMemoryBackend() }; + config.adaptive_hints = adaptive.adaptiveHintsConfig({ + governor: adaptive.governorConfig({ enabled: true }), + }); + config.tool_parallelism = adaptive.toolParallelismConfig({ + drift: adaptive.driftConfig({ enabled: true }), + }); + config.acg = adaptive.acgConfig({ + provider: 'anthropic', + convergence: adaptive.convergenceConfig({ enabled: true }), + }); + config.convergence = adaptive.convergenceConfig({ enabled: true }); + + assert.equal(adaptive.validateConfig(config).diagnostics.length, 0); + }); }); diff --git a/crates/python/tests/coverage/py_storage_coverage_tests.rs b/crates/python/tests/coverage/py_storage_coverage_tests.rs index 0661dcf90..810a5bb38 100644 --- a/crates/python/tests/coverage/py_storage_coverage_tests.rs +++ b/crates/python/tests/coverage/py_storage_coverage_tests.rs @@ -140,7 +140,9 @@ fn sample_stability_result() -> StabilityAnalysisResult { observation_count: 2, }], stable_prefix_length: 1, + stable_prefix_fingerprint: None, total_observations: 2, + converged: false, } } diff --git a/go/nemo_relay/adaptive.go b/go/nemo_relay/adaptive.go index eb7ce81f5..65fffe7c5 100644 --- a/go/nemo_relay/adaptive.go +++ b/go/nemo_relay/adaptive.go @@ -17,6 +17,7 @@ type AdaptiveConfig struct { AdaptiveHints *AdaptiveHintsConfig `json:"adaptive_hints,omitempty"` ToolParallelism *ToolParallelismConfig `json:"tool_parallelism,omitempty"` Acg *AcgConfig `json:"acg,omitempty"` + Convergence *ConvergenceConfig `json:"convergence,omitempty"` Policy *ConfigPolicy `json:"policy,omitempty"` } @@ -39,16 +40,37 @@ type TelemetryConfig struct { // AdaptiveHintsConfig configures built-in LLM request hint injection. type AdaptiveHintsConfig struct { - Priority int32 `json:"priority,omitempty"` - BreakChain bool `json:"break_chain,omitempty"` - InjectHeader bool `json:"inject_header,omitempty"` - InjectBodyPath string `json:"inject_body_path,omitempty"` + Priority int32 `json:"priority,omitempty"` + BreakChain bool `json:"break_chain,omitempty"` + InjectHeader bool `json:"inject_header,omitempty"` + InjectBodyPath string `json:"inject_body_path,omitempty"` + Governor *GovernorConfig `json:"governor,omitempty"` } // ToolParallelismConfig configures built-in adaptive tool scheduling. type ToolParallelismConfig struct { - Priority int32 `json:"priority,omitempty"` - Mode string `json:"mode,omitempty"` + Priority int32 `json:"priority,omitempty"` + Mode string `json:"mode,omitempty"` + Drift *DriftConfig `json:"drift,omitempty"` +} + +// GovernorConfig configures topology-aware hint load shedding. +type GovernorConfig struct { + Enabled bool `json:"enabled,omitempty"` + Epsilon float64 `json:"epsilon,omitempty"` +} + +// DriftConfig configures topology-aware stale-plan invalidation. +type DriftConfig struct { + Enabled bool `json:"enabled,omitempty"` + Threshold float64 `json:"threshold,omitempty"` +} + +// ConvergenceConfig configures topological convergence detection. +type ConvergenceConfig struct { + Enabled bool `json:"enabled,omitempty"` + Epsilon float64 `json:"epsilon,omitempty"` + StabilityWindow uint32 `json:"stability_window,omitempty"` } // AcgStabilityThresholds configures prompt stability classification thresholds. @@ -64,6 +86,7 @@ type AcgConfig struct { ObservationWindow uint32 `json:"observation_window,omitempty"` Priority int32 `json:"priority,omitempty"` StabilityThresholds *AcgStabilityThresholds `json:"stability_thresholds,omitempty"` + Convergence *ConvergenceConfig `json:"convergence,omitempty"` } // AdaptiveComponentSpec wraps one adaptive config as a top-level plugin component. @@ -101,6 +124,21 @@ func NewTelemetryConfig() TelemetryConfig { return TelemetryConfig{} } +// NewGovernorConfig returns default topology-aware hint load-shedding settings. +func NewGovernorConfig() GovernorConfig { + return GovernorConfig{Epsilon: 1.0} +} + +// NewDriftConfig returns default topology-aware stale-plan detection settings. +func NewDriftConfig() DriftConfig { + return DriftConfig{Threshold: 0.75} +} + +// NewConvergenceConfig returns default topological convergence detection settings. +func NewConvergenceConfig() ConvergenceConfig { + return ConvergenceConfig{Epsilon: 0.001, StabilityWindow: 3} +} + // NewAdaptiveHintsConfig returns default adaptive hints injection settings. func NewAdaptiveHintsConfig() AdaptiveHintsConfig { return AdaptiveHintsConfig{ diff --git a/go/nemo_relay/adaptive/adaptive.go b/go/nemo_relay/adaptive/adaptive.go index 493064225..44fd1ba84 100644 --- a/go/nemo_relay/adaptive/adaptive.go +++ b/go/nemo_relay/adaptive/adaptive.go @@ -46,6 +46,15 @@ type AdaptiveHintsConfig = nemo_relay.AdaptiveHintsConfig // ToolParallelismConfig configures built-in adaptive tool scheduling. type ToolParallelismConfig = nemo_relay.ToolParallelismConfig +// GovernorConfig configures topology-aware hint load shedding. +type GovernorConfig = nemo_relay.GovernorConfig + +// DriftConfig configures topology-aware stale-plan invalidation. +type DriftConfig = nemo_relay.DriftConfig + +// ConvergenceConfig configures topological convergence detection. +type ConvergenceConfig = nemo_relay.ConvergenceConfig + // AcgStabilityThresholds configures ACG prompt-stability classification. type AcgStabilityThresholds = nemo_relay.AcgStabilityThresholds @@ -93,6 +102,21 @@ func NewTelemetryConfig() TelemetryConfig { return nemo_relay.NewTelemetryConfig() } +// NewGovernorConfig returns default topology-aware hint load-shedding settings. +func NewGovernorConfig() GovernorConfig { + return nemo_relay.NewGovernorConfig() +} + +// NewDriftConfig returns default topology-aware stale-plan detection settings. +func NewDriftConfig() DriftConfig { + return nemo_relay.NewDriftConfig() +} + +// NewConvergenceConfig returns default topological convergence detection settings. +func NewConvergenceConfig() ConvergenceConfig { + return nemo_relay.NewConvergenceConfig() +} + // NewAdaptiveHintsConfig returns default adaptive hints injection settings. func NewAdaptiveHintsConfig() AdaptiveHintsConfig { return nemo_relay.NewAdaptiveHintsConfig() diff --git a/go/nemo_relay/adaptive_test.go b/go/nemo_relay/adaptive_test.go index fba86b6ff..5d57e54ec 100644 --- a/go/nemo_relay/adaptive_test.go +++ b/go/nemo_relay/adaptive_test.go @@ -19,7 +19,7 @@ func TestNewAdaptiveConfigDefaults(t *testing.T) { if config.Version != 1 { t.Fatalf("expected version 1, got %d", config.Version) } - if config.Telemetry != nil || config.AdaptiveHints != nil || config.ToolParallelism != nil || config.Acg != nil { + if config.Telemetry != nil || config.AdaptiveHints != nil || config.ToolParallelism != nil || config.Acg != nil || config.Convergence != nil { t.Fatal("expected adaptive feature sections to default to nil") } } @@ -31,6 +31,15 @@ func TestAdaptiveHelperConstructors(t *testing.T) { telemetry := NewTelemetryConfig() assertTelemetryDefaults(t, telemetry) + governor := NewGovernorConfig() + assertGovernorDefaults(t, governor) + + drift := NewDriftConfig() + assertDriftDefaults(t, drift) + + convergence := NewConvergenceConfig() + assertConvergenceDefaults(t, convergence) + hints := NewAdaptiveHintsConfig() assertAdaptiveHintsDefaults(t, hints) @@ -73,6 +82,27 @@ func assertTelemetryDefaults(t *testing.T, telemetry TelemetryConfig) { } } +func assertGovernorDefaults(t *testing.T, governor GovernorConfig) { + t.Helper() + if governor.Enabled || governor.Epsilon != 1.0 { + t.Fatalf("unexpected governor defaults: %#v", governor) + } +} + +func assertDriftDefaults(t *testing.T, drift DriftConfig) { + t.Helper() + if drift.Enabled || drift.Threshold != 0.75 { + t.Fatalf("unexpected drift defaults: %#v", drift) + } +} + +func assertConvergenceDefaults(t *testing.T, convergence ConvergenceConfig) { + t.Helper() + if convergence.Enabled || convergence.Epsilon != 0.001 || convergence.StabilityWindow != 3 { + t.Fatalf("unexpected convergence defaults: %#v", convergence) + } +} + func assertAdaptiveHintsDefaults(t *testing.T, hints AdaptiveHintsConfig) { t.Helper() if hints.Priority != 100 || !hints.InjectHeader || hints.InjectBodyPath != "nvext.agent_hints" { diff --git a/python/nemo_relay/adaptive.py b/python/nemo_relay/adaptive.py index e4f375457..c160f6a9c 100644 --- a/python/nemo_relay/adaptive.py +++ b/python/nemo_relay/adaptive.py @@ -150,6 +150,65 @@ def to_dict(self) -> JsonObject: ) +@dataclass(slots=True) +class GovernorConfig: + """Topology-aware load-shedding settings for adaptive hints. + + Args: + enabled: Whether the governor is active. + epsilon: Initial sensitivity threshold. + """ + + enabled: bool = False + epsilon: float = 1.0 + + def to_dict(self) -> JsonObject: + """Serialize this governor config to the canonical JSON object shape.""" + return _normalize_object({"enabled": self.enabled, "epsilon": self.epsilon}) + + +@dataclass(slots=True) +class DriftConfig: + """Topology-aware drift detection settings for tool plans. + + Args: + enabled: Whether drift detection is active. + threshold: Drift distance above which existing plans are invalidated. + """ + + enabled: bool = False + threshold: float = 0.75 + + def to_dict(self) -> JsonObject: + """Serialize this drift config to the canonical JSON object shape.""" + return _normalize_object({"enabled": self.enabled, "threshold": self.threshold}) + + +@dataclass(slots=True) +class ConvergenceConfig: + """Topological convergence detection settings. + + Args: + enabled: Whether convergence detection is active. + epsilon: Error threshold below which convergence can be declared. + stability_window: Minimum number of epochs required for stability. + """ + + enabled: bool = False + epsilon: float = 0.001 + stability_window: int = 3 + + def to_dict(self) -> JsonObject: + """Serialize this convergence config to the canonical JSON object shape.""" + return _normalize_object( + { + "enabled": self.enabled, + "epsilon": self.epsilon, + "stability_window": self.stability_window, + } + ) + + @dataclass(slots=True) class AdaptiveHintsConfig: """Built-in adaptive hints injection settings. @@ -159,12 +218,14 @@ class AdaptiveHintsConfig: break_chain: Whether to stop later request intercepts after this one. inject_header: Whether to inject the adaptive hints HTTP header. inject_body_path: JSON body path used when injecting request-body hints. + governor: Optional topology-aware load-shedding settings. """ priority: int = 100 break_chain: bool = False inject_header: bool = True inject_body_path: str = "nvext.agent_hints" + governor: GovernorConfig | None = None def to_dict(self) -> JsonObject: """Serialize this adaptive-hints config to the canonical JSON object shape.""" @@ -174,6 +235,7 @@ def to_dict(self) -> JsonObject: "break_chain": self.break_chain, "inject_header": self.inject_header, "inject_body_path": self.inject_body_path, + "governor": _normalize(self.governor), } ) @@ -187,14 +249,16 @@ class ToolParallelismConfig: mode: Scheduling mode. ``"observe_only"`` records signals without changing behavior, while other modes enable stronger adaptive scheduling behavior. + drift: Optional topology-aware stale-plan invalidation settings. """ priority: int = 100 mode: Literal["observe_only", "inject_hints", "schedule"] = "observe_only" + drift: DriftConfig | None = None def to_dict(self) -> JsonObject: """Serialize this tool-parallelism config to the canonical JSON object shape.""" - return _normalize_object({"priority": self.priority, "mode": self.mode}) + return _normalize_object({"priority": self.priority, "mode": self.mode, "drift": _normalize(self.drift)}) @dataclass(slots=True) @@ -232,12 +296,14 @@ class AcgConfig: observation_window: Rolling PromptIR observation window size. priority: LLM execution intercept priority. stability_thresholds: Prompt-stability classification thresholds. + convergence: Optional component-scoped topological convergence settings. """ provider: Literal["anthropic", "openai", "passthrough"] = "passthrough" observation_window: int = 100 priority: int = 50 stability_thresholds: AcgStabilityThresholds | None = field(default_factory=AcgStabilityThresholds) + convergence: ConvergenceConfig | None = None def to_dict(self) -> JsonObject: """Serialize this ACG config to the canonical JSON object shape.""" @@ -247,6 +313,7 @@ def to_dict(self) -> JsonObject: "observation_window": self.observation_window, "priority": self.priority, "stability_thresholds": _normalize(self.stability_thresholds), + "convergence": _normalize(self.convergence), } ) @@ -263,6 +330,7 @@ class AdaptiveConfig: adaptive_hints: Built-in LLM hint-injection settings. tool_parallelism: Built-in tool scheduling settings. acg: Adaptive Cache Governor settings. + convergence: Global topological convergence settings. policy: Unsupported-config policy applied within the adaptive config. Behavior: @@ -277,6 +345,7 @@ class AdaptiveConfig: adaptive_hints: AdaptiveHintsConfig | None = None tool_parallelism: ToolParallelismConfig | None = None acg: AcgConfig | None = None + convergence: ConvergenceConfig | None = None policy: ConfigPolicy = field(default_factory=ConfigPolicy) def to_dict(self) -> JsonObject: @@ -289,6 +358,7 @@ def to_dict(self) -> JsonObject: "adaptive_hints": _normalize(self.adaptive_hints), "tool_parallelism": _normalize(self.tool_parallelism), "acg": _normalize(self.acg), + "convergence": _normalize(self.convergence), "policy": self.policy.to_dict(), } @@ -379,15 +449,21 @@ def set_latency_sensitivity(level: int) -> None: "AcgStabilityThresholds", "AdaptiveConfig", "AdaptiveHintsConfig", + "AdaptiveRuntime", "ADAPTIVE_PLUGIN_KIND", "BackendSpec", + "build_cache_telemetry_event", "ConfigDiagnostic", "ConfigPolicy", "ConfigReport", "ComponentSpec", + "ConvergenceConfig", + "DriftConfig", + "GovernorConfig", "StateConfig", "TelemetryConfig", "ToolParallelismConfig", "set_latency_sensitivity", "UnsupportedBehavior", + "validate_config", ] diff --git a/python/nemo_relay/adaptive.pyi b/python/nemo_relay/adaptive.pyi index 892f7260b..454042f7e 100644 --- a/python/nemo_relay/adaptive.pyi +++ b/python/nemo_relay/adaptive.pyi @@ -13,18 +13,18 @@ from typing import Literal, TypedDict from nemo_relay import JsonObject, ScopeHandle, UnsupportedBehavior -class ConfigDiagnostic(TypedDict, total=False): - """One adaptive configuration diagnostic. - - Fields mirror the runtime validation report produced by the Rust adaptive - validator. - """ +__all__: list[str] +class _ConfigDiagnosticRequired(TypedDict): level: Literal["warning", "error"] code: str + message: str + +class ConfigDiagnostic(_ConfigDiagnosticRequired, total=False): + """One adaptive configuration diagnostic.""" + component: str field: str - message: str class ConfigReport(TypedDict): """Validation report returned by adaptive configuration helpers.""" @@ -105,6 +105,40 @@ class TelemetryConfig: """Serialize this telemetry config to the canonical JSON object shape.""" ... +@dataclass(slots=True) +class GovernorConfig: + """Topology-aware load-shedding settings for adaptive hints.""" + + enabled: bool = ... + epsilon: float = ... + + def to_dict(self) -> JsonObject: + """Serialize this governor config to the canonical JSON object shape.""" + ... + +@dataclass(slots=True) +class DriftConfig: + """Topology-aware drift detection settings for tool plans.""" + + enabled: bool = ... + threshold: float = ... + + def to_dict(self) -> JsonObject: + """Serialize this drift config to the canonical JSON object shape.""" + ... + +@dataclass(slots=True) +class ConvergenceConfig: + """Topological convergence detection settings.""" + + enabled: bool = ... + epsilon: float = ... + stability_window: int = ... + + def to_dict(self) -> JsonObject: + """Serialize this convergence config to the canonical JSON object shape.""" + ... + @dataclass(slots=True) class AdaptiveHintsConfig: """Built-in adaptive hints injection settings. @@ -114,12 +148,14 @@ class AdaptiveHintsConfig: break_chain: Whether to stop later request intercepts after this one. inject_header: Whether to inject the adaptive hints HTTP header. inject_body_path: JSON body path used when injecting request-body hints. + governor: Optional topology-aware load-shedding settings. """ priority: int = ... break_chain: bool = ... inject_header: bool = ... inject_body_path: str = ... + governor: GovernorConfig | None = ... def to_dict(self) -> JsonObject: """Serialize this adaptive-hints config to the canonical JSON object shape.""" @@ -133,10 +169,12 @@ class ToolParallelismConfig: priority: Intercept priority. Lower values run first. mode: Scheduling mode. ``"observe_only"`` records signals without changing behavior, while stronger modes allow adaptive scheduling. + drift: Optional topology-aware stale-plan invalidation settings. """ priority: int = ... mode: Literal["observe_only", "inject_hints", "schedule"] = ... + drift: DriftConfig | None = ... def to_dict(self) -> JsonObject: """Serialize this tool-parallelism config to the canonical JSON object shape.""" @@ -168,14 +206,16 @@ class AcgConfig: Args: provider: Provider cache plugin name. observation_window: Rolling PromptIR observation window size. - priority: Request-intercept priority used by ACG. + priority: LLM execution intercept priority used by ACG. stability_thresholds: Prompt-stability classification thresholds. + convergence: Optional component-scoped topological convergence settings. """ provider: Literal["anthropic", "openai", "passthrough"] = ... observation_window: int = ... priority: int = ... stability_thresholds: AcgStabilityThresholds | None = ... + convergence: ConvergenceConfig | None = ... def to_dict(self) -> JsonObject: """Serialize this ACG config to the canonical JSON object shape.""" @@ -193,6 +233,7 @@ class AdaptiveConfig: adaptive_hints: Built-in adaptive request-hints configuration. tool_parallelism: Built-in adaptive tool-scheduling configuration. acg: Adaptive Cache Governor configuration. + convergence: Global topological convergence settings. policy: Policy for unsupported adaptive configuration. """ @@ -203,6 +244,7 @@ class AdaptiveConfig: adaptive_hints: AdaptiveHintsConfig | None = ... tool_parallelism: ToolParallelismConfig | None = ... acg: AcgConfig | None = ... + convergence: ConvergenceConfig | None = ... policy: ConfigPolicy = ... def to_dict(self) -> JsonObject: diff --git a/python/tests/test_adaptive.py b/python/tests/test_adaptive.py index f3fb92547..bd110f397 100644 --- a/python/tests/test_adaptive.py +++ b/python/tests/test_adaptive.py @@ -19,6 +19,9 @@ AdaptiveHintsConfig, BackendSpec, ComponentSpec, + ConvergenceConfig, + DriftConfig, + GovernorConfig, StateConfig, TelemetryConfig, ToolParallelismConfig, @@ -33,6 +36,16 @@ def test_file_covers_native_cache_request_facts_regression(self): assert runtime_call in source assert annotated_request in source + def test_public_exports_include_runtime_and_validation_helpers(self): + assert { + "AdaptiveRuntime", + "validate_config", + "build_cache_telemetry_event", + "ConvergenceConfig", + "DriftConfig", + "GovernorConfig", + }.issubset(set(adaptive_module.__all__)) + def test_backend_helpers(self): assert BackendSpec.in_memory().to_dict() == {"kind": "in_memory", "config": {}} assert BackendSpec.redis("redis://127.0.0.1:6379").to_dict() == { @@ -60,6 +73,14 @@ def test_section_helpers(self): assert TelemetryConfig(learners=["latency_sensitivity"]).to_dict() == {"learners": ["latency_sensitivity"]} assert AdaptiveHintsConfig().to_dict()["priority"] == 100 assert ToolParallelismConfig().to_dict()["mode"] == "observe_only" + assert AdaptiveHintsConfig(governor=GovernorConfig(enabled=True)).to_dict()["governor"] == { + "enabled": True, + "epsilon": 1.0, + } + assert ToolParallelismConfig(drift=DriftConfig(enabled=True)).to_dict()["drift"] == { + "enabled": True, + "threshold": 0.75, + } def test_adaptive_component_wraps_as_plugin_component(self): wrapped = ComponentSpec(AdaptiveConfig()).to_dict() @@ -71,6 +92,16 @@ def test_validate_adaptive_plugin_component_warns_missing_state(self): ) assert any(diag["code"] == "adaptive.section_disabled_missing_state" for diag in report["diagnostics"]) + def test_topology_helper_config_is_accepted_by_plugin_validation(self): + config = AdaptiveConfig( + adaptive_hints=AdaptiveHintsConfig(governor=GovernorConfig(enabled=True)), + tool_parallelism=ToolParallelismConfig(drift=DriftConfig(enabled=True)), + acg=AcgConfig(provider="anthropic", convergence=ConvergenceConfig(enabled=True)), + convergence=ConvergenceConfig(enabled=True), + ) + report = plugin.validate(plugin.PluginConfig(components=[ComponentSpec(config)])) + assert not any(diag["code"] == "adaptive.unknown_field" for diag in report["diagnostics"]) + def test_plugin_component_spec_normalizes_lists_of_dataclasses(self): @dataclass class ExampleConfig: @@ -95,6 +126,7 @@ def test_acg_config_exposes_canonical_threshold_shape(self): "observation_window", "priority", "stability_thresholds", + "convergence", ] assert AcgStabilityThresholds().to_dict() == { "stable_threshold": 0.95, @@ -111,6 +143,11 @@ def test_acg_config_exposes_canonical_threshold_shape(self): "min_observations_for_full_confidence": 20, }, } + assert AcgConfig(convergence=ConvergenceConfig(enabled=True)).to_dict()["convergence"] == { + "enabled": True, + "epsilon": 0.001, + "stability_window": 3, + } class TestAdaptivePluginConfiguration: diff --git a/python/tests/test_adaptive_config.py b/python/tests/test_adaptive_config.py index eec629c4f..89b37e441 100644 --- a/python/tests/test_adaptive_config.py +++ b/python/tests/test_adaptive_config.py @@ -15,6 +15,7 @@ BackendSpec, ComponentSpec, ConfigPolicy, + ConvergenceConfig, StateConfig, TelemetryConfig, ToolParallelismConfig, @@ -143,6 +144,11 @@ def test_openai_acg_config_serializes_without_transport_fields(self): "min_observations_for_full_confidence": 20, }, } + assert AcgConfig(convergence=ConvergenceConfig(enabled=True)).to_dict()["convergence"] == { + "enabled": True, + "epsilon": 0.001, + "stability_window": 3, + } def test_acg_config_allows_threshold_overrides(self): assert AcgConfig( From 87dee541b8b72f18de7dc4b1823d7abd4bc0ff24 Mon Sep 17 00:00:00 2001 From: teerthsharma Date: Sun, 28 Jun 2026 05:35:08 +0530 Subject: [PATCH 2/6] fix: address adaptive topology review feedback Signed-off-by: teerthsharma --- crates/adaptive/src/acg_learner.rs | 5 +- .../adaptive/tests/unit/acg_learner_tests.rs | 37 ++++++++++++++ .../unit/adaptive_hints_intercept_tests.rs | 7 ++- .../unit/tool_parallelism_learner_tests.rs | 22 +++++++++ crates/adaptive/tests/unit/types_tests.rs | 12 ++++- crates/node/adaptive.js | 48 ++++++++++++------- crates/node/tests/adaptive_tests.mjs | 20 ++++++++ 7 files changed, 130 insertions(+), 21 deletions(-) diff --git a/crates/adaptive/src/acg_learner.rs b/crates/adaptive/src/acg_learner.rs index d94559e0d..910af54eb 100644 --- a/crates/adaptive/src/acg_learner.rs +++ b/crates/adaptive/src/acg_learner.rs @@ -226,7 +226,10 @@ impl Learner for AcgLearner { // the normal repair path. Requests whose span topology changed // under the same learning key also reopen learning. if let Some(cached) = existing_stability.as_ref().filter(|stability| { - stability.converged + self.convergence + .as_ref() + .is_some_and(|config| config.enabled) + && stability.converged && stability.total_observations as usize >= stability_window && new_observations.iter().all(|observation| { Self::prompt_topology_matches_stability(stability, observation) diff --git a/crates/adaptive/tests/unit/acg_learner_tests.rs b/crates/adaptive/tests/unit/acg_learner_tests.rs index 8eecaba09..130a83f43 100644 --- a/crates/adaptive/tests/unit/acg_learner_tests.rs +++ b/crates/adaptive/tests/unit/acg_learner_tests.rs @@ -711,6 +711,43 @@ async fn acg_learner_reuses_converged_stability_without_loading_observations() { assert!(guard.acg_stability.as_ref().unwrap().converged); } +#[tokio::test(flavor = "current_thread")] +async fn acg_learner_does_not_reuse_converged_stability_when_convergence_disabled() { + let learner = AcgLearner::new_with_convergence( + "agent-a", + 20, + StabilityThresholds::default(), + Some(ConvergenceConfig { + enabled: false, + epsilon: 0.001, + stability_window: 3, + }), + ); + let request = sample_request("gpt-4o", "Stable system", "Stable prompt"); + let learning_key = derive_acg_learning_key("agent-a", &request); + let seed_observation = build_prompt_ir(&request).unwrap(); + let observations = vec![ + seed_observation.clone(), + seed_observation.clone(), + seed_observation.clone(), + ]; + let mut converged_stability = analyze_stability(&observations, &StabilityThresholds::default()); + converged_stability.converged = true; + + let backend = SeedObservationBackend::new(&learning_key, observations); + backend.seed_stability(&learning_key, converged_stability); + + learner + .process_run(&sample_run(vec![request]), &backend, &empty_cache()) + .await + .unwrap(); + + assert!( + backend.load_observation_count() > 0, + "disabled convergence must not reuse cached converged stability" + ); +} + #[tokio::test(flavor = "current_thread")] async fn acg_learner_reuses_converged_profile_when_suffix_topology_changes() { let learner = AcgLearner::new_with_convergence( diff --git a/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs b/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs index a2b4e92de..b25daefd3 100644 --- a/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs +++ b/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs @@ -383,9 +383,12 @@ fn test_adaptive_hints_governor_sheds_low_sensitivity_hints_but_keeps_manual_ove None, ) .unwrap(); + let manual_latency_sensitivity = + &manual_request.content["nvext"]["agent_hints"]["latency_sensitivity"]; + assert_eq!(manual_latency_sensitivity, &serde_json::json!(11.0)); assert_eq!( - manual_request.content["nvext"]["agent_hints"]["latency_sensitivity"], - serde_json::json!(11.0) + &manual_request.headers[AGENT_HINTS_HEADER_KEY]["latency_sensitivity"], + manual_latency_sensitivity ); reset_root_metadata(); diff --git a/crates/adaptive/tests/unit/tool_parallelism_learner_tests.rs b/crates/adaptive/tests/unit/tool_parallelism_learner_tests.rs index 0aac5cd36..2bfca3339 100644 --- a/crates/adaptive/tests/unit/tool_parallelism_learner_tests.rs +++ b/crates/adaptive/tests/unit/tool_parallelism_learner_tests.rs @@ -356,6 +356,28 @@ async fn process_run_invalidates_existing_plan_when_tool_cohort_topology_drifts( "test".to_string(), ] })); + + let cached_plan = hot_cache + .read() + .unwrap() + .plan + .clone() + .expect("hot cache should be refreshed with the drifted plan"); + assert!( + !cached_plan + .parallel_groups + .iter() + .any(|group| group.group_id == "fanout:existing"), + "drifted cohort topology should invalidate stale cached plan groups" + ); + assert!(cached_plan.parallel_groups.iter().any(|group| { + group.tool_names + == vec![ + "compile".to_string(), + "lint".to_string(), + "test".to_string(), + ] + })); } #[tokio::test] diff --git a/crates/adaptive/tests/unit/types_tests.rs b/crates/adaptive/tests/unit/types_tests.rs index 82e4ab734..3ac43fc8c 100644 --- a/crates/adaptive/tests/unit/types_tests.rs +++ b/crates/adaptive/tests/unit/types_tests.rs @@ -41,9 +41,9 @@ fn sample_stability_result() -> StabilityAnalysisResult { observation_count: 4, }], stable_prefix_length: 1, - stable_prefix_fingerprint: None, + stable_prefix_fingerprint: Some("stable-prefix-sha256".to_string()), total_observations: 4, - converged: false, + converged: true, } } @@ -194,8 +194,16 @@ fn hot_cache_serialization_keeps_acg_field_names_stable() { let decoded: HotCache = serde_json::from_value(encoded).unwrap(); assert_eq!(decoded.acg_profiles["profile-a"].stable_prefix_length, 1); + assert_eq!( + decoded.acg_profiles["profile-a"] + .stable_prefix_fingerprint + .as_deref(), + Some("stable-prefix-sha256") + ); + assert!(decoded.acg_profiles["profile-a"].converged); assert_eq!( decoded.acg_stability.as_ref().unwrap().total_observations, 4 ); + assert!(decoded.acg_stability.as_ref().unwrap().converged); } diff --git a/crates/node/adaptive.js b/crates/node/adaptive.js index 5bc5e205e..4a362bdae 100644 --- a/crates/node/adaptive.js +++ b/crates/node/adaptive.js @@ -80,6 +80,16 @@ function telemetryConfig(config = {}) { }; } +function mergeDefined(defaults, config = {}) { + const merged = { ...defaults }; + for (const [key, value] of Object.entries(config)) { + if (value !== undefined) { + merged[key] = value; + } + } + return merged; +} + /** * Create topology-aware hint load-shedding settings with defaults applied. * @@ -87,11 +97,13 @@ function telemetryConfig(config = {}) { * @returns {object} A normalized governor config object. */ function governorConfig(config = {}) { - return { - enabled: false, - epsilon: 1.0, - ...config, - }; + return mergeDefined( + { + enabled: false, + epsilon: 1.0, + }, + config, + ); } /** @@ -101,11 +113,13 @@ function governorConfig(config = {}) { * @returns {object} A normalized drift config object. */ function driftConfig(config = {}) { - return { - enabled: false, - threshold: 0.75, - ...config, - }; + return mergeDefined( + { + enabled: false, + threshold: 0.75, + }, + config, + ); } /** @@ -115,12 +129,14 @@ function driftConfig(config = {}) { * @returns {object} A normalized convergence config object. */ function convergenceConfig(config = {}) { - return { - enabled: false, - epsilon: 0.001, - stability_window: 3, - ...config, - }; + return mergeDefined( + { + enabled: false, + epsilon: 0.001, + stability_window: 3, + }, + config, + ); } /** diff --git a/crates/node/tests/adaptive_tests.mjs b/crates/node/tests/adaptive_tests.mjs index 44d100559..83cf839e7 100644 --- a/crates/node/tests/adaptive_tests.mjs +++ b/crates/node/tests/adaptive_tests.mjs @@ -218,6 +218,26 @@ describe('adaptive helpers', () => { epsilon: 0.001, stability_window: 3, }); + assert.deepEqual(adaptive.governorConfig({ enabled: true, epsilon: undefined }), { + enabled: true, + epsilon: 1.0, + }); + assert.deepEqual(adaptive.driftConfig({ enabled: true, threshold: undefined }), { + enabled: true, + threshold: 0.75, + }); + assert.deepEqual( + adaptive.convergenceConfig({ + enabled: true, + epsilon: undefined, + stability_window: undefined, + }), + { + enabled: true, + epsilon: 0.001, + stability_window: 3, + }, + ); const config = adaptive.defaultConfig(); config.state = { backend: adaptive.inMemoryBackend() }; From 2bf142c3aaae4863852be2f9e6e60c1d8f57b0c8 Mon Sep 17 00:00:00 2001 From: teerthsharma Date: Sun, 28 Jun 2026 15:56:38 +0530 Subject: [PATCH 3/6] fix: reset reopened acg convergence state Signed-off-by: teerthsharma --- crates/adaptive/src/acg_learner.rs | 21 ++++++-- .../adaptive/tests/unit/acg_learner_tests.rs | 49 +++++++++++++++++++ crates/node/adaptive.js | 2 +- crates/node/tests/adaptive_tests.mjs | 13 +++++ 4 files changed, 81 insertions(+), 4 deletions(-) diff --git a/crates/adaptive/src/acg_learner.rs b/crates/adaptive/src/acg_learner.rs index 910af54eb..ec8bc97ce 100644 --- a/crates/adaptive/src/acg_learner.rs +++ b/crates/adaptive/src/acg_learner.rs @@ -219,6 +219,10 @@ impl Learner for AcgLearner { .as_ref() .map(|config| config.stability_window.max(3)) .unwrap_or(3); + let convergence_enabled = self + .convergence + .as_ref() + .is_some_and(|config| config.enabled); // If the profile has already converged, reuse the cached // stability result and skip loading or adding observations. @@ -226,9 +230,7 @@ impl Learner for AcgLearner { // the normal repair path. Requests whose span topology changed // under the same learning key also reopen learning. if let Some(cached) = existing_stability.as_ref().filter(|stability| { - self.convergence - .as_ref() - .is_some_and(|config| config.enabled) + convergence_enabled && stability.converged && stability.total_observations as usize >= stability_window && new_observations.iter().all(|observation| { @@ -261,6 +263,19 @@ impl Learner for AcgLearner { continue; } + if convergence_enabled + && existing_stability + .as_ref() + .is_some_and(|stability| stability.converged) + { + let mut detectors = self.convergence_detectors.write().map_err(|error| { + AdaptiveError::Internal(format!( + "convergence detector lock poisoned: {error}" + )) + })?; + detectors.remove(&profile_key); + } + let existing = backend.load_observations(&profile_key).await?; let mut window: VecDeque = diff --git a/crates/adaptive/tests/unit/acg_learner_tests.rs b/crates/adaptive/tests/unit/acg_learner_tests.rs index 130a83f43..c1e55bb6e 100644 --- a/crates/adaptive/tests/unit/acg_learner_tests.rs +++ b/crates/adaptive/tests/unit/acg_learner_tests.rs @@ -748,6 +748,55 @@ async fn acg_learner_does_not_reuse_converged_stability_when_convergence_disable ); } +#[tokio::test(flavor = "current_thread")] +async fn acg_learner_resets_convergence_detector_when_cached_profile_reopens() { + let learner = AcgLearner::new_with_convergence( + "agent-a", + 20, + StabilityThresholds::default(), + Some(ConvergenceConfig { + enabled: true, + epsilon: 0.001, + stability_window: 3, + }), + ); + let request = sample_request("gpt-4o", "Stable system", "Stable prompt"); + let learning_key = derive_acg_learning_key("agent-a", &request); + let backend = SeedObservationBackend::empty(); + let hot_cache = empty_cache(); + + for _ in 0..4 { + learner + .process_run(&sample_run(vec![request.clone()]), &backend, &hot_cache) + .await + .unwrap(); + } + + let mut stale_stability = backend + .load_stability(&learning_key) + .await + .unwrap() + .expect("profile should have stored stability"); + assert!(stale_stability.converged); + stale_stability.stable_prefix_fingerprint = Some("stale-prefix-fingerprint".to_string()); + backend.seed_stability(&learning_key, stale_stability); + + learner + .process_run(&sample_run(vec![request]), &backend, &hot_cache) + .await + .unwrap(); + + let reopened_stability = backend + .load_stability(&learning_key) + .await + .unwrap() + .expect("reopened profile should store recomputed stability"); + assert!( + !reopened_stability.converged, + "reopened learning should require a fresh stability window" + ); +} + #[tokio::test(flavor = "current_thread")] async fn acg_learner_reuses_converged_profile_when_suffix_topology_changes() { let learner = AcgLearner::new_with_convergence( diff --git a/crates/node/adaptive.js b/crates/node/adaptive.js index 4a362bdae..3b953f995 100644 --- a/crates/node/adaptive.js +++ b/crates/node/adaptive.js @@ -82,7 +82,7 @@ function telemetryConfig(config = {}) { function mergeDefined(defaults, config = {}) { const merged = { ...defaults }; - for (const [key, value] of Object.entries(config)) { + for (const [key, value] of Object.entries(config ?? {})) { if (value !== undefined) { merged[key] = value; } diff --git a/crates/node/tests/adaptive_tests.mjs b/crates/node/tests/adaptive_tests.mjs index 83cf839e7..1d3bf0b5b 100644 --- a/crates/node/tests/adaptive_tests.mjs +++ b/crates/node/tests/adaptive_tests.mjs @@ -218,6 +218,19 @@ describe('adaptive helpers', () => { epsilon: 0.001, stability_window: 3, }); + assert.deepEqual(adaptive.governorConfig(null), { + enabled: false, + epsilon: 1.0, + }); + assert.deepEqual(adaptive.driftConfig(null), { + enabled: false, + threshold: 0.75, + }); + assert.deepEqual(adaptive.convergenceConfig(null), { + enabled: false, + epsilon: 0.001, + stability_window: 3, + }); assert.deepEqual(adaptive.governorConfig({ enabled: true, epsilon: undefined }), { enabled: true, epsilon: 1.0, From 11697014bed6e3fb41759637554732eab4a03145 Mon Sep 17 00:00:00 2001 From: teerthsharma Date: Sun, 28 Jun 2026 16:44:33 +0530 Subject: [PATCH 4/6] fix: preserve acg aggregate convergence state Signed-off-by: teerthsharma --- crates/adaptive/src/acg_learner.rs | 59 ++++++++------- .../adaptive/tests/unit/acg_learner_tests.rs | 75 +++++++++++++++++++ 2 files changed, 106 insertions(+), 28 deletions(-) diff --git a/crates/adaptive/src/acg_learner.rs b/crates/adaptive/src/acg_learner.rs index ec8bc97ce..473827221 100644 --- a/crates/adaptive/src/acg_learner.rs +++ b/crates/adaptive/src/acg_learner.rs @@ -142,6 +142,18 @@ impl AcgLearner { .all(|(score, block)| score.span_id == block.span_id) } + fn should_replace_aggregate( + candidate: &crate::acg::stability::StabilityAnalysisResult, + current: Option<&crate::acg::stability::StabilityAnalysisResult>, + ) -> bool { + current + .map(|current| { + (candidate.stable_prefix_length, candidate.total_observations) + > (current.stable_prefix_length, current.total_observations) + }) + .unwrap_or(true) + } + /// Update the per-profile topological convergence detector and return /// whether the profile has converged. fn record_stability_epoch( @@ -240,13 +252,8 @@ impl Learner for AcgLearner { profile_counts.insert(profile_key.clone(), cached.total_observations); profile_stability.insert(profile_key.clone(), cached.clone()); - let replace_best = best_aggregate_stability - .as_ref() - .map(|current| { - (cached.stable_prefix_length, cached.total_observations) - > (current.stable_prefix_length, current.total_observations) - }) - .unwrap_or(true); + let replace_best = + Self::should_replace_aggregate(cached, best_aggregate_stability.as_ref()); if replace_best { best_aggregate_stability = Some(cached.clone()); } @@ -263,18 +270,10 @@ impl Learner for AcgLearner { continue; } - if convergence_enabled + let reopen_converged_profile = convergence_enabled && existing_stability .as_ref() - .is_some_and(|stability| stability.converged) - { - let mut detectors = self.convergence_detectors.write().map_err(|error| { - AdaptiveError::Internal(format!( - "convergence detector lock poisoned: {error}" - )) - })?; - detectors.remove(&profile_key); - } + .is_some_and(|stability| stability.converged); let existing = backend.load_observations(&profile_key).await?; @@ -291,8 +290,6 @@ impl Learner for AcgLearner { let observations_vec: Vec = window.into_iter().collect(); let mut stability_result = analyze_stability(&observations_vec, &self.thresholds); - let converged_now = self.record_stability_epoch(&profile_key, &stability_result)?; - // Store the observations that produced this stability result. // On the epoch that first declares convergence these // observations are preserved; on subsequent runs the cached @@ -301,6 +298,17 @@ impl Learner for AcgLearner { .store_observations(&profile_key, &observations_vec) .await?; + if reopen_converged_profile { + let mut detectors = self.convergence_detectors.write().map_err(|error| { + AdaptiveError::Internal(format!( + "convergence detector lock poisoned: {error}" + )) + })?; + detectors.remove(&profile_key); + } + + let converged_now = self.record_stability_epoch(&profile_key, &stability_result)?; + if converged_now { stability_result.converged = true; } @@ -325,15 +333,10 @@ impl Learner for AcgLearner { profile_counts.insert(profile_key.clone(), stability_result.total_observations); profile_stability.insert(profile_key.clone(), stability_result.clone()); - let replace_best = best_profile_seed - .as_ref() - .map(|(_, current)| { - ( - stability_result.stable_prefix_length, - stability_result.total_observations, - ) > (current.stable_prefix_length, current.total_observations) - }) - .unwrap_or(true); + let replace_best = Self::should_replace_aggregate( + &stability_result, + best_aggregate_stability.as_ref(), + ); if replace_best { best_profile_seed = Some((observations_vec.clone(), stability_result.clone())); best_aggregate_stability = Some(stability_result.clone()); diff --git a/crates/adaptive/tests/unit/acg_learner_tests.rs b/crates/adaptive/tests/unit/acg_learner_tests.rs index c1e55bb6e..51eee7b11 100644 --- a/crates/adaptive/tests/unit/acg_learner_tests.rs +++ b/crates/adaptive/tests/unit/acg_learner_tests.rs @@ -311,6 +311,10 @@ impl SeedObservationBackend { self.fail_observation_store.store(true, Ordering::SeqCst); } + fn allow_observation_stores(&self) { + self.fail_observation_store.store(false, Ordering::SeqCst); + } + fn seed_stability( &self, agent_id: &str, @@ -584,6 +588,54 @@ async fn acg_learner_does_not_persist_converged_stability_when_observation_store ); } +#[tokio::test(flavor = "current_thread")] +async fn acg_learner_does_not_advance_convergence_epoch_when_observation_store_fails() { + let learner = AcgLearner::new_with_convergence( + "agent-a", + 20, + StabilityThresholds::default(), + Some(ConvergenceConfig { + enabled: true, + epsilon: 0.001, + stability_window: 3, + }), + ); + let request = sample_request("gpt-4o", "Stable system", "Stable prompt"); + let learning_key = derive_acg_learning_key("agent-a", &request); + let backend = SeedObservationBackend::empty(); + let hot_cache = empty_cache(); + + learner + .process_run(&sample_run(vec![request.clone()]), &backend, &hot_cache) + .await + .unwrap(); + + backend.fail_observation_stores(); + let error = learner + .process_run(&sample_run(vec![request.clone()]), &backend, &hot_cache) + .await + .unwrap_err(); + assert!( + matches!(error, AdaptiveError::Storage(message) if message.contains("forced observation storage failure")) + ); + + backend.allow_observation_stores(); + learner + .process_run(&sample_run(vec![request]), &backend, &hot_cache) + .await + .unwrap(); + + let recovered_stability = backend + .load_stability(&learning_key) + .await + .unwrap() + .expect("stability should be stored after recovery"); + assert!( + !recovered_stability.converged, + "failed observation storage must not advance the in-memory convergence epoch" + ); +} + #[tokio::test(flavor = "current_thread")] async fn acg_learner_repairs_converged_stability_without_observations() { let learner = AcgLearner::new_with_convergence( @@ -711,6 +763,29 @@ async fn acg_learner_reuses_converged_stability_without_loading_observations() { assert!(guard.acg_stability.as_ref().unwrap().converged); } +#[test] +fn acg_learner_keeps_stronger_cached_aggregate_over_weaker_normal_candidate() { + let cached = crate::acg::stability::StabilityAnalysisResult { + scores: vec![stable_score(0), stable_score(1), stable_score(2)], + stable_prefix_length: 3, + stable_prefix_fingerprint: None, + total_observations: 3, + converged: true, + }; + let weaker_normal = crate::acg::stability::StabilityAnalysisResult { + scores: vec![stable_score(0)], + stable_prefix_length: 1, + stable_prefix_fingerprint: None, + total_observations: 20, + converged: false, + }; + + assert!( + !AcgLearner::should_replace_aggregate(&weaker_normal, Some(&cached)), + "normal candidates should compare against the current cached aggregate winner" + ); +} + #[tokio::test(flavor = "current_thread")] async fn acg_learner_does_not_reuse_converged_stability_when_convergence_disabled() { let learner = AcgLearner::new_with_convergence( From 5147bc9766ef083e2bb4a2654471ec4ef0fe49af Mon Sep 17 00:00:00 2001 From: teerthsharma Date: Sun, 28 Jun 2026 16:56:19 +0530 Subject: [PATCH 5/6] fix: preserve converged acg aggregate stability Signed-off-by: teerthsharma --- crates/adaptive/src/acg_learner.rs | 14 ++++++++--- .../adaptive/tests/unit/acg_learner_tests.rs | 23 +++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/crates/adaptive/src/acg_learner.rs b/crates/adaptive/src/acg_learner.rs index 473827221..a33677394 100644 --- a/crates/adaptive/src/acg_learner.rs +++ b/crates/adaptive/src/acg_learner.rs @@ -148,8 +148,15 @@ impl AcgLearner { ) -> bool { current .map(|current| { - (candidate.stable_prefix_length, candidate.total_observations) - > (current.stable_prefix_length, current.total_observations) + ( + candidate.stable_prefix_length, + candidate.converged, + candidate.total_observations, + ) > ( + current.stable_prefix_length, + current.converged, + current.total_observations, + ) }) .unwrap_or(true) } @@ -288,7 +295,6 @@ impl Learner for AcgLearner { } let observations_vec: Vec = window.into_iter().collect(); - let mut stability_result = analyze_stability(&observations_vec, &self.thresholds); // Store the observations that produced this stability result. // On the epoch that first declares convergence these @@ -307,6 +313,8 @@ impl Learner for AcgLearner { detectors.remove(&profile_key); } + let mut stability_result = analyze_stability(&observations_vec, &self.thresholds); + let converged_now = self.record_stability_epoch(&profile_key, &stability_result)?; if converged_now { diff --git a/crates/adaptive/tests/unit/acg_learner_tests.rs b/crates/adaptive/tests/unit/acg_learner_tests.rs index 51eee7b11..f8e937058 100644 --- a/crates/adaptive/tests/unit/acg_learner_tests.rs +++ b/crates/adaptive/tests/unit/acg_learner_tests.rs @@ -786,6 +786,29 @@ fn acg_learner_keeps_stronger_cached_aggregate_over_weaker_normal_candidate() { ); } +#[test] +fn acg_learner_keeps_converged_aggregate_when_prefix_ties_non_converged_candidate() { + let current = crate::acg::stability::StabilityAnalysisResult { + scores: vec![stable_score(0), stable_score(1), stable_score(2)], + stable_prefix_length: 3, + stable_prefix_fingerprint: None, + total_observations: 3, + converged: true, + }; + let candidate = crate::acg::stability::StabilityAnalysisResult { + scores: vec![stable_score(0), stable_score(1), stable_score(2)], + stable_prefix_length: 3, + stable_prefix_fingerprint: None, + total_observations: 20, + converged: false, + }; + + assert!( + !AcgLearner::should_replace_aggregate(&candidate, Some(¤t)), + "a converged aggregate should not regress to a non-converged candidate when the stable prefix ties" + ); +} + #[tokio::test(flavor = "current_thread")] async fn acg_learner_does_not_reuse_converged_stability_when_convergence_disabled() { let learner = AcgLearner::new_with_convergence( From 487635cace0c6016e7c7dcfdfe64926b189bb587 Mon Sep 17 00:00:00 2001 From: teerthsharma Date: Sun, 28 Jun 2026 18:44:09 +0530 Subject: [PATCH 6/6] fix: keep acg convergence persistence transactional Signed-off-by: teerthsharma --- crates/adaptive/src/acg_learner.rs | 61 +++++++++- .../adaptive/tests/unit/acg_learner_tests.rs | 108 ++++++++++++++++++ 2 files changed, 164 insertions(+), 5 deletions(-) diff --git a/crates/adaptive/src/acg_learner.rs b/crates/adaptive/src/acg_learner.rs index a33677394..4cc186236 100644 --- a/crates/adaptive/src/acg_learner.rs +++ b/crates/adaptive/src/acg_learner.rs @@ -192,6 +192,32 @@ impl AcgLearner { let enough_epochs = detector.epoch() as usize >= stability_window; Ok(detector.is_converged() && enough_epochs) } + + fn snapshot_convergence_detector( + &self, + profile_key: &str, + ) -> Result> { + let detectors = self.convergence_detectors.read().map_err(|error| { + AdaptiveError::Internal(format!("convergence detector lock poisoned: {error}")) + })?; + Ok(detectors.get(profile_key).copied()) + } + + fn restore_convergence_detector( + &self, + profile_key: &str, + previous: Option, + ) -> Result<()> { + let mut detectors = self.convergence_detectors.write().map_err(|error| { + AdaptiveError::Internal(format!("convergence detector lock poisoned: {error}")) + })?; + if let Some(detector) = previous { + detectors.insert(profile_key.to_string(), detector); + } else { + detectors.remove(profile_key); + } + Ok(()) + } } impl Learner for AcgLearner { @@ -262,6 +288,9 @@ impl Learner for AcgLearner { let replace_best = Self::should_replace_aggregate(cached, best_aggregate_stability.as_ref()); if replace_best { + // Cached reuse has no fresh aggregate observation seed; clear any + // older seed so it cannot be paired with the cached winner. + best_profile_seed = None; best_aggregate_stability = Some(cached.clone()); } acg_debug::emit( @@ -304,6 +333,12 @@ impl Learner for AcgLearner { .store_observations(&profile_key, &observations_vec) .await?; + let detector_before_epoch = if convergence_enabled { + Some(self.snapshot_convergence_detector(&profile_key)?) + } else { + None + }; + if reopen_converged_profile { let mut detectors = self.convergence_detectors.write().map_err(|error| { AdaptiveError::Internal(format!( @@ -315,15 +350,30 @@ impl Learner for AcgLearner { let mut stability_result = analyze_stability(&observations_vec, &self.thresholds); - let converged_now = self.record_stability_epoch(&profile_key, &stability_result)?; + let converged_now = + match self.record_stability_epoch(&profile_key, &stability_result) { + Ok(converged_now) => converged_now, + Err(error) => { + if let Some(previous) = detector_before_epoch { + self.restore_convergence_detector(&profile_key, previous)?; + } + return Err(error); + } + }; if converged_now { stability_result.converged = true; } - backend + if let Err(error) = backend .store_stability(&profile_key, &stability_result) - .await?; + .await + { + if let Some(previous) = detector_before_epoch { + self.restore_convergence_detector(&profile_key, previous)?; + } + return Err(error); + } acg_debug::emit( "learner_profile_updated", @@ -351,13 +401,14 @@ impl Learner for AcgLearner { } } - if let Some((aggregate_observations, aggregate_stability)) = best_profile_seed.as_ref() - { + if let Some((aggregate_observations, _)) = best_profile_seed.as_ref() { // Persist the runtime seed entry under plain agent_id so registration can // rehydrate HotCache without scanning profile-specific keys. backend .store_observations(&self.agent_id, aggregate_observations) .await?; + } + if let Some(aggregate_stability) = best_aggregate_stability.as_ref() { backend .store_stability(&self.agent_id, aggregate_stability) .await?; diff --git a/crates/adaptive/tests/unit/acg_learner_tests.rs b/crates/adaptive/tests/unit/acg_learner_tests.rs index f8e937058..ece5a0ceb 100644 --- a/crates/adaptive/tests/unit/acg_learner_tests.rs +++ b/crates/adaptive/tests/unit/acg_learner_tests.rs @@ -284,6 +284,7 @@ struct SeedObservationBackend { observations: std::sync::RwLock>>, stability: std::sync::RwLock>, fail_observation_store: AtomicBool, + fail_stability_store: AtomicBool, load_observation_count: AtomicUsize, } @@ -293,6 +294,7 @@ impl SeedObservationBackend { observations: std::sync::RwLock::new(HashMap::new()), stability: std::sync::RwLock::new(HashMap::new()), fail_observation_store: AtomicBool::new(false), + fail_stability_store: AtomicBool::new(false), load_observation_count: AtomicUsize::new(0), } } @@ -315,6 +317,14 @@ impl SeedObservationBackend { self.fail_observation_store.store(false, Ordering::SeqCst); } + fn fail_stability_stores(&self) { + self.fail_stability_store.store(true, Ordering::SeqCst); + } + + fn allow_stability_stores(&self) { + self.fail_stability_store.store(false, Ordering::SeqCst); + } + fn seed_stability( &self, agent_id: &str, @@ -418,6 +428,11 @@ impl StorageBackendDyn for SeedObservationBackend { ) -> Pin> + Send + 'a>> { let result = result.clone(); Box::pin(async move { + if self.fail_stability_store.load(Ordering::SeqCst) { + return Err(AdaptiveError::Storage( + "forced stability storage failure".to_string(), + )); + } self.stability .write() .unwrap() @@ -763,6 +778,45 @@ async fn acg_learner_reuses_converged_stability_without_loading_observations() { assert!(guard.acg_stability.as_ref().unwrap().converged); } +#[tokio::test(flavor = "current_thread")] +async fn acg_learner_persists_cached_winner_as_agent_stability() { + let learner = AcgLearner::new_with_convergence( + "agent-a", + 20, + StabilityThresholds::default(), + Some(ConvergenceConfig { + enabled: true, + epsilon: 0.001, + stability_window: 3, + }), + ); + let request = sample_request("gpt-4o", "Stable system", "Stable prompt"); + let learning_key = derive_acg_learning_key("agent-a", &request); + let seed_observation = build_prompt_ir(&request).unwrap(); + let observations = vec![ + seed_observation.clone(), + seed_observation.clone(), + seed_observation.clone(), + ]; + let mut converged_stability = analyze_stability(&observations, &StabilityThresholds::default()); + converged_stability.converged = true; + + let backend = SeedObservationBackend::new(&learning_key, observations); + backend.seed_stability(&learning_key, converged_stability.clone()); + + learner + .process_run(&sample_run(vec![request]), &backend, &empty_cache()) + .await + .unwrap(); + + let agent_stability = backend + .load_stability("agent-a") + .await + .unwrap() + .expect("cached aggregate winner should be persisted under the base agent id"); + assert_eq!(agent_stability, converged_stability); +} + #[test] fn acg_learner_keeps_stronger_cached_aggregate_over_weaker_normal_candidate() { let cached = crate::acg::stability::StabilityAnalysisResult { @@ -895,6 +949,60 @@ async fn acg_learner_resets_convergence_detector_when_cached_profile_reopens() { ); } +#[tokio::test(flavor = "current_thread")] +async fn acg_learner_rolls_back_convergence_detector_when_stability_store_fails() { + let learner = AcgLearner::new_with_convergence( + "agent-a", + 20, + StabilityThresholds::default(), + Some(ConvergenceConfig { + enabled: true, + epsilon: 0.001, + stability_window: 3, + }), + ); + let request = sample_request("gpt-4o", "Stable system", "Stable prompt"); + let learning_key = derive_acg_learning_key("agent-a", &request); + let backend = SeedObservationBackend::empty(); + let hot_cache = empty_cache(); + + learner + .process_run(&sample_run(vec![request.clone()]), &backend, &hot_cache) + .await + .unwrap(); + + let epoch_before_failure = learner + .convergence_detectors + .read() + .unwrap() + .get(&learning_key) + .expect("first successful run should create a detector") + .epoch(); + assert_eq!(epoch_before_failure, 1); + + backend.fail_stability_stores(); + let error = learner + .process_run(&sample_run(vec![request]), &backend, &hot_cache) + .await + .unwrap_err(); + assert!( + matches!(error, AdaptiveError::Storage(message) if message.contains("forced stability storage failure")) + ); + backend.allow_stability_stores(); + + let epoch_after_failure = learner + .convergence_detectors + .read() + .unwrap() + .get(&learning_key) + .expect("failed stability persistence should restore the previous detector") + .epoch(); + assert_eq!( + epoch_after_failure, epoch_before_failure, + "failed stability persistence must not leave the in-memory detector ahead of storage" + ); +} + #[tokio::test(flavor = "current_thread")] async fn acg_learner_reuses_converged_profile_when_suffix_topology_changes() { let learner = AcgLearner::new_with_convergence(