diff --git a/crates/core/src/api/llm.rs b/crates/core/src/api/llm.rs index 0af799d04..7910222e5 100644 --- a/crates/core/src/api/llm.rs +++ b/crates/core/src/api/llm.rs @@ -3,12 +3,13 @@ use std::sync::Arc; -use chrono::{DateTime, Utc}; +use chrono::{DateTime, TimeDelta, Utc}; use serde::{Deserialize, Serialize}; use serde_json::json; use typed_builder::TypedBuilder; use uuid::Uuid; +use crate::api::event::{BaseEvent, Event, MarkEvent, PendingMarkSpec}; use crate::api::runtime::NemoRelayContextState; use crate::api::runtime::current_scope_stack; use crate::api::runtime::global_context; @@ -28,7 +29,7 @@ use crate::error::{FlowError, Result}; use crate::json::Json; use crate::stream::LlmStreamWrapper; -pub use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest}; +pub use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest, LlmRequestInterceptOutcome}; /// Runtime-owned handle identifying an active or completed LLM call. #[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)] @@ -298,6 +299,35 @@ fn emit_llm_start( Ok(()) } +fn emit_pending_request_marks(handle: &LlmHandle, marks: Vec) -> Result<()> { + if marks.is_empty() { + return Ok(()); + } + ensure_runtime_owner()?; + let subscribers = { + let scope_stack = current_scope_stack(); + let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); + snapshot_event_subscribers(scope_guard.collect_scope_local_subscribers())? + }; + for (index, mark) in marks.into_iter().enumerate() { + let timestamp = handle.started_at + + TimeDelta::microseconds(i64::try_from(index).unwrap_or_default() + 1); + let event = Event::Mark(MarkEvent::new( + BaseEvent::builder() + .name(mark.name) + .parent_uuid(handle.uuid) + .timestamp(timestamp) + .data_opt(mark.data) + .metadata_opt(mark.metadata) + .build(), + mark.category, + mark.category_profile, + )); + NemoRelayContextState::emit_event(&event, &subscribers); + } + Ok(()) +} + /// Start a manual LLM lifecycle span. /// /// This emits an LLM-start event after applying sanitize-request guardrails to @@ -587,7 +617,7 @@ pub async fn llm_call_execute(params: LlmCallExecuteParams) -> Result { } let request_codec = codec.clone(); - let (intercepted_request, annotated_request) = + let (intercepted_request, annotated_request, pending_marks) = run_request_intercepts_with_codec(&name, request, codec)?; let handle = create_llm_handle( @@ -606,6 +636,7 @@ pub async fn llm_call_execute(params: LlmCallExecuteParams) -> Result { annotated_request.clone(), request_codec.as_deref(), )?; + emit_pending_request_marks(&handle, pending_marks)?; let execution = { let scope_stack = current_scope_stack(); @@ -743,7 +774,7 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu } let request_codec = codec.clone(); - let (intercepted_request, annotated_request) = + let (intercepted_request, annotated_request, pending_marks) = run_request_intercepts_with_codec(&name, request, codec)?; let handle = create_llm_handle( @@ -762,6 +793,7 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu annotated_request, request_codec.as_deref(), )?; + emit_pending_request_marks(&handle, pending_marks)?; let execution = { let scope_stack = current_scope_stack(); @@ -818,6 +850,17 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu /// Conditional guardrails, codecs, and execution intercepts are not run by /// this helper. pub fn llm_request_intercepts(name: &str, request: LlmRequest) -> Result { + Ok(llm_request_intercepts_with_marks(name, request)?.request) +} + +/// Run the LLM request-intercept chain and return pending lifecycle marks. +/// +/// This helper does not emit the returned marks because it does not own an LLM +/// lifecycle. Callers must attach them to the lifecycle they own. +pub fn llm_request_intercepts_with_marks( + name: &str, + request: LlmRequest, +) -> Result { ensure_runtime_owner()?; let entries = { let scope_stack = current_scope_stack(); @@ -830,10 +873,7 @@ pub fn llm_request_intercepts(name: &str, request: LlmRequest) -> Result Result<()> { + register_llm_request_intercept_with_marks( + name, + priority, + break_chain, + std::sync::Arc::new(move |name, request, annotated| { + callable(name, request, annotated).map(Into::into) + }), + ) +} global_execution_registry_api!( /// Register a global LLM execution intercept. /// Execution intercepts can wrap or replace the non-streaming provider @@ -653,15 +669,32 @@ scope_guardrail_registry_api!( LlmConditionalFn ); scope_intercept_registry_api!( - /// Register a scope-local LLM request intercept. - /// Request intercepts can rewrite or annotate LLM requests inside the - /// owning scope. - scope_register_llm_request_intercept, + /// Register a scope-local LLM request intercept that can schedule lifecycle marks. + scope_register_llm_request_intercept_with_marks, /// Deregister a scope-local LLM request intercept. scope_deregister_llm_request_intercept, llm_request_intercepts, - LlmRequestInterceptFn + LlmRequestInterceptWithMarksFn ); + +/// Register a scope-local LLM request intercept without pending marks. +pub fn scope_register_llm_request_intercept( + scope_uuid: &uuid::Uuid, + name: &str, + priority: i32, + break_chain: bool, + callable: LlmRequestInterceptFn, +) -> Result<()> { + scope_register_llm_request_intercept_with_marks( + scope_uuid, + name, + priority, + break_chain, + std::sync::Arc::new(move |name, request, annotated| { + callable(name, request, annotated).map(Into::into) + }), + ) +} scope_execution_registry_api!( /// Register a scope-local LLM execution intercept. /// Execution intercepts can wrap or replace the non-streaming provider diff --git a/crates/core/src/api/runtime.rs b/crates/core/src/api/runtime.rs index 2351ae352..4d0c372c4 100644 --- a/crates/core/src/api/runtime.rs +++ b/crates/core/src/api/runtime.rs @@ -11,9 +11,9 @@ pub mod subscriber_dispatcher; pub use callbacks::{ EventSubscriberFn, LlmCollectorFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, - LlmFinalizerFn, LlmJsonStream, LlmRequestInterceptFn, LlmSanitizeRequestFn, - LlmSanitizeResponseFn, LlmStreamExecutionFn, LlmStreamExecutionNextFn, ToolConditionalFn, - ToolExecutionFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, + LlmFinalizerFn, LlmJsonStream, LlmRequestInterceptFn, LlmRequestInterceptWithMarksFn, + LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, LlmStreamExecutionNextFn, + ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, }; pub use global::global_context; pub use scope_stack::{ diff --git a/crates/core/src/api/runtime/callbacks.rs b/crates/core/src/api/runtime/callbacks.rs index 980c47d74..dfccd2fea 100644 --- a/crates/core/src/api/runtime/callbacks.rs +++ b/crates/core/src/api/runtime/callbacks.rs @@ -15,7 +15,7 @@ use std::sync::Arc; use tokio_stream::Stream; use crate::api::event::Event; -use crate::api::llm::LlmRequest; +use crate::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use crate::codec::request::AnnotatedLlmRequest; use crate::error::Result; use crate::json::Json; @@ -177,6 +177,16 @@ pub type LlmRequestInterceptFn = Arc< + Send + Sync, >; +/// Rewrite or annotate an LLM request and schedule marks under its future scope. +/// +/// This callback has the same inputs as [`LlmRequestInterceptFn`] but returns a +/// structured outcome whose pending marks are emitted after the LLM-start +/// event and before provider execution. +pub type LlmRequestInterceptWithMarksFn = Arc< + dyn Fn(&str, LlmRequest, Option) -> Result + + Send + + Sync, +>; /// Continuation type invoked by non-streaming LLM execution intercepts. /// /// Execution intercepts use this callable to continue the non-streaming LLM diff --git a/crates/core/src/api/runtime/state.rs b/crates/core/src/api/runtime/state.rs index 70276786d..ca03d3b29 100644 --- a/crates/core/src/api/runtime/state.rs +++ b/crates/core/src/api/runtime/state.rs @@ -20,10 +20,10 @@ use crate::api::llm::{CreateLlmHandleParams, EndLlmHandleParams}; use crate::api::llm::{LlmHandle, LlmRequest}; use crate::api::registry::{ExecutionIntercept, Guardrail, Intercept}; use crate::api::runtime::callbacks::{ - EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, LlmRequestInterceptFn, - LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, LlmStreamExecutionNextFn, - LlmStreamExecutionRegistryRefs, ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn, - ToolInterceptFn, ToolSanitizeFn, + EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, + LlmRequestInterceptWithMarksFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, + LlmStreamExecutionFn, LlmStreamExecutionNextFn, LlmStreamExecutionRegistryRefs, + ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, }; use crate::api::runtime::subscriber_dispatcher; use crate::api::scope::{CreateScopeHandleParams, EndScopeHandleParams, ScopeHandle, ScopeType}; @@ -63,7 +63,7 @@ pub struct NemoRelayContextState { /// Global LLM guardrails that can reject execution before the provider callback runs. pub(crate) llm_conditional_execution_guardrails: SortedRegistry>, /// Global LLM request intercepts that can rewrite or annotate requests. - pub(crate) llm_request_intercepts: SortedRegistry>, + pub(crate) llm_request_intercepts: SortedRegistry>, /// Global non-streaming LLM execution intercepts that wrap callback execution. pub(crate) llm_execution_intercepts: SortedRegistry>, /// Global streaming LLM execution intercepts that wrap stream-producing callbacks. @@ -1011,8 +1011,8 @@ impl NemoRelayContextState { /// are released. pub(crate) fn llm_request_intercept_entries( &self, - scope_locals: &[&SortedRegistry>], - ) -> Vec> { + scope_locals: &[&SortedRegistry>], + ) -> Vec> { merge_intercept_entries(&self.llm_request_intercepts, scope_locals) .into_iter() .cloned() @@ -1041,20 +1041,25 @@ impl NemoRelayContextState { name: &str, request: LlmRequest, annotated: Option, - entries: &[Intercept], - ) -> crate::error::Result<(LlmRequest, Option)> { + entries: &[Intercept], + ) -> crate::error::Result { let mut request_value = request; let mut annotated_value = annotated; + let mut pending_marks = Vec::new(); for entry in entries { - let (new_request, new_annotated) = - (entry.payload.callable)(name, request_value, annotated_value)?; - request_value = new_request; - annotated_value = new_annotated; + let outcome = (entry.payload.callable)(name, request_value, annotated_value)?; + request_value = outcome.request; + annotated_value = outcome.annotated_request; + pending_marks.extend(outcome.pending_marks); if entry.payload.break_chain { break; } } - Ok((request_value, annotated_value)) + Ok(crate::api::llm::LlmRequestInterceptOutcome { + request: request_value, + annotated_request: annotated_value, + pending_marks, + }) } /// Build the composed non-streaming LLM execution continuation chain. diff --git a/crates/core/src/api/shared.rs b/crates/core/src/api/shared.rs index 861cd41c2..53f14a599 100644 --- a/crates/core/src/api/shared.rs +++ b/crates/core/src/api/shared.rs @@ -74,7 +74,11 @@ pub(crate) fn run_request_intercepts_with_codec( name: &str, request: LlmRequest, codec: Option>, -) -> Result<(LlmRequest, Option>)> { +) -> Result<( + LlmRequest, + Option>, + Vec, +)> { let original = request.clone(); let annotated = match &codec { Some(codec) => Some(codec.decode(&request)?), @@ -94,18 +98,19 @@ pub(crate) fn run_request_intercepts_with_codec( state.llm_request_intercept_entries(&scope_locals) }; - let (intercepted_request, intercepted_annotated) = + let outcome = crate::api::runtime::NemoRelayContextState::llm_request_intercepts_snapshot_chain( name, request, annotated, &entries, )?; + let pending_marks = outcome.pending_marks; - match (codec, intercepted_annotated) { + match (codec, outcome.annotated_request) { (Some(codec), Some(annotated)) => { let mut encoded = codec.encode(&annotated, &original)?; - encoded.headers = intercepted_request.headers; - Ok((encoded, Some(Arc::new(annotated)))) + encoded.headers = outcome.request.headers; + Ok((encoded, Some(Arc::new(annotated)), pending_marks)) } - _ => Ok((intercepted_request, None)), + (_, annotated) => Ok((outcome.request, annotated.map(Arc::new), pending_marks)), } } diff --git a/crates/core/src/context/registries.rs b/crates/core/src/context/registries.rs index 2a0d2fde9..c6781eee0 100644 --- a/crates/core/src/context/registries.rs +++ b/crates/core/src/context/registries.rs @@ -11,7 +11,7 @@ use std::collections::HashMap; use crate::api::registry::{ExecutionIntercept, Guardrail, Intercept}; use crate::api::runtime::{ - EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmRequestInterceptFn, + EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmRequestInterceptWithMarksFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, ToolConditionalFn, ToolExecutionFn, ToolInterceptFn, ToolSanitizeFn, }; @@ -41,7 +41,7 @@ pub(crate) struct ScopeLocalRegistries { /// LLM guardrails that can reject execution before the provider callback runs. pub(crate) llm_conditional_execution_guardrails: SortedRegistry>, /// LLM request intercepts that can rewrite or annotate requests. - pub(crate) llm_request_intercepts: SortedRegistry>, + pub(crate) llm_request_intercepts: SortedRegistry>, /// Non-streaming LLM execution intercepts that wrap callback execution. pub(crate) llm_execution_intercepts: SortedRegistry>, /// Streaming LLM execution intercepts that wrap stream-producing callbacks. diff --git a/crates/core/src/plugin.rs b/crates/core/src/plugin.rs index 94420f6df..c2934e3d0 100644 --- a/crates/core/src/plugin.rs +++ b/crates/core/src/plugin.rs @@ -27,15 +27,16 @@ use crate::api::registry::{ deregister_tool_request_intercept, deregister_tool_sanitize_request_guardrail, deregister_tool_sanitize_response_guardrail, register_llm_conditional_execution_guardrail, register_llm_execution_intercept, register_llm_request_intercept, - register_llm_sanitize_request_guardrail, register_llm_sanitize_response_guardrail, - register_llm_stream_execution_intercept, register_tool_conditional_execution_guardrail, - register_tool_execution_intercept, register_tool_request_intercept, - register_tool_sanitize_request_guardrail, register_tool_sanitize_response_guardrail, + register_llm_request_intercept_with_marks, register_llm_sanitize_request_guardrail, + register_llm_sanitize_response_guardrail, register_llm_stream_execution_intercept, + register_tool_conditional_execution_guardrail, register_tool_execution_intercept, + register_tool_request_intercept, register_tool_sanitize_request_guardrail, + register_tool_sanitize_response_guardrail, }; use crate::api::runtime::{ EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmRequestInterceptFn, - LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, ToolConditionalFn, - ToolExecutionFn, ToolInterceptFn, ToolSanitizeFn, + LlmRequestInterceptWithMarksFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, + LlmStreamExecutionFn, ToolConditionalFn, ToolExecutionFn, ToolInterceptFn, ToolSanitizeFn, }; use crate::api::subscriber::{deregister_subscriber, register_subscriber}; pub use nemo_relay_types::plugin::{ConfigDiagnostic, DiagnosticLevel}; @@ -350,6 +351,37 @@ impl PluginRegistrationContext { Ok(()) } + /// Registers an LLM request intercept that can schedule lifecycle marks. + pub fn register_llm_request_intercept_with_marks( + &mut self, + name: &str, + priority: i32, + break_chain: bool, + callback: LlmRequestInterceptWithMarksFn, + ) -> Result<()> { + let qualified_name = self.qualify_name(name); + register_llm_request_intercept_with_marks(&qualified_name, priority, break_chain, callback) + .map_err(|err| { + PluginError::RegistrationFailed(format!("llm request intercept: {err}")) + })?; + + let name_owned = qualified_name; + self.registrations.push(PluginRegistration::new( + "plugin", + name_owned.clone(), + Box::new(move || { + deregister_llm_request_intercept(&name_owned) + .map(|_| ()) + .map_err(|err| { + PluginError::RegistrationFailed(format!( + "llm request intercept deregistration failed: {err}" + )) + }) + }), + )); + Ok(()) + } + /// Registers a tool sanitize-request guardrail and records its rollback closure. pub fn register_tool_sanitize_request_guardrail( &mut self, diff --git a/crates/core/src/plugin/dynamic/native.rs b/crates/core/src/plugin/dynamic/native.rs index 0f1c14444..6582bef95 100644 --- a/crates/core/src/plugin/dynamic/native.rs +++ b/crates/core/src/plugin/dynamic/native.rs @@ -26,18 +26,19 @@ use nemo_relay_plugin::{ NemoRelayNativeWithScopeStackCb, NemoRelayStatus, }; use semver::{Version, VersionReq}; +use serde::Deserialize; use serde_json::{Map, Value as Json}; use sha2::{Digest, Sha256}; use tokio::runtime::Runtime; use tokio_stream::{Stream, StreamExt}; -use crate::api::event::Event; -use crate::api::llm::LlmRequest; +use crate::api::event::{Event, PendingMarkSpec}; +use crate::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use crate::api::runtime::{ EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, LlmJsonStream, - LlmRequestInterceptFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, - LlmStreamExecutionNextFn, ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn, - ToolInterceptFn, ToolSanitizeFn, + LlmRequestInterceptWithMarksFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, + LlmStreamExecutionFn, LlmStreamExecutionNextFn, ToolConditionalFn, ToolExecutionFn, + ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, }; use crate::api::runtime::{ ScopeStackHandle, ThreadScopeStackBinding, capture_thread_scope_stack, create_scope_stack, @@ -1332,7 +1333,7 @@ unsafe extern "C" fn native_plugin_context_register_llm_request_intercept( Ok(name) => name, Err(status) => return status, }; - match ctx.register_llm_request_intercept( + match ctx.register_llm_request_intercept_with_marks( &name, priority, break_chain, @@ -1690,7 +1691,7 @@ fn wrap_llm_request_intercept_fn( cb: NemoRelayNativeLlmRequestInterceptCb, user_data: *mut c_void, free_fn: NemoRelayNativeFreeFn, -) -> LlmRequestInterceptFn { +) -> LlmRequestInterceptWithMarksFn { let user_data = make_user_data(instance, user_data, free_fn); Arc::new(move |name, request, annotated| { clear_native_last_error(); @@ -1756,17 +1757,55 @@ fn wrap_llm_request_intercept_fn( let annotated_json = annotated_json?; let request: LlmRequest = serde_json::from_value(request_json) .map_err(|err| FlowError::Internal(format!("invalid LLM request JSON: {err}")))?; - let annotated = annotated_json - .map(|annotated_json| { - serde_json::from_value::(annotated_json).map_err(|err| { - FlowError::Internal(format!("invalid annotated request JSON: {err}")) - }) - }) - .transpose()?; - Ok((request, annotated)) + let (annotated_request, pending_marks) = match annotated_json { + Some(value) + if value.get(nemo_relay_types::api::llm::NATIVE_LLM_INTERCEPT_OUTCOME_FIELD) + == Some(&Json::Bool(true)) => + { + let metadata: NativeLlmRequestInterceptOutcome = serde_json::from_value(value) + .map_err(|err| { + FlowError::Internal(format!("invalid marked LLM outcome JSON: {err}")) + })?; + (metadata.annotated_request, metadata.pending_marks) + } + Some(value) => { + let annotated = + serde_json::from_value::(value).map_err(|err| { + FlowError::Internal(format!("invalid annotated request JSON: {err}")) + })?; + (Some(annotated), Vec::new()) + } + None => (None, Vec::new()), + }; + Ok(LlmRequestInterceptOutcome { + request, + annotated_request, + pending_marks, + }) }) } +#[derive(Deserialize)] +struct NativeLlmRequestInterceptOutcome { + #[serde(rename = "__nemo_relay_llm_intercept_outcome")] + _marked_outcome: bool, + annotated_request: Option, + #[serde(default)] + pending_marks: Vec, +} + +#[cfg(test)] +#[test] +fn native_llm_request_intercept_outcome_defaults_omitted_pending_marks() { + let outcome: NativeLlmRequestInterceptOutcome = serde_json::from_value(serde_json::json!({ + "__nemo_relay_llm_intercept_outcome": true + })) + .unwrap(); + + assert!(outcome.annotated_request.is_none()); + assert!(outcome.pending_marks.is_empty()); +} + fn wrap_llm_execution_fn( instance: Arc, cb: NemoRelayNativeLlmExecutionCb, diff --git a/crates/core/tests/fixtures/native_plugin/src/lib.rs b/crates/core/tests/fixtures/native_plugin/src/lib.rs index 9fc58bd41..4eeebd5ba 100644 --- a/crates/core/tests/fixtures/native_plugin/src/lib.rs +++ b/crates/core/tests/fixtures/native_plugin/src/lib.rs @@ -5,10 +5,10 @@ use std::ffi::c_void; use std::ptr; use nemo_relay_plugin::{ - ConfigDiagnostic, DiagnosticLevel, Event, Json, LlmJsonStream, LlmRequest, - NemoRelayNativeHostApiV1, NemoRelayNativePluginContext, NemoRelayNativePluginV1, - NemoRelayNativeString, NemoRelayStatus, NativePlugin, PluginContext, PluginRuntime, - ScopeCategory, ScopeType, + CategoryProfile, ConfigDiagnostic, DiagnosticLevel, Event, EventCategory, Json, LlmJsonStream, + LlmRequest, LlmRequestInterceptOutcome, NemoRelayNativeHostApiV1, + NemoRelayNativePluginContext, NemoRelayNativePluginV1, NemoRelayNativeString, NemoRelayStatus, + NativePlugin, PendingMarkSpec, PluginContext, PluginRuntime, ScopeCategory, ScopeType, }; use serde_json::{Map, json}; @@ -114,14 +114,26 @@ impl NativePlugin for FixtureNativePlugin { 0, |_request| Ok(None), )?; - ctx.register_llm_request_intercept( + ctx.register_llm_request_intercept_with_marks( "fixture_llm_request_intercept", 0, false, |_name, request, annotated| { - Ok(( + Ok(LlmRequestInterceptOutcome::new( mark_llm_request(request, "native_plugin_llm_request_intercept"), annotated, + ) + .with_pending_mark( + PendingMarkSpec::builder() + .name("fixture.native.llm_request.mark") + .category(EventCategory::custom()) + .category_profile(CategoryProfile { + subtype: Some("fixture.native.pending".into()), + ..CategoryProfile::default() + }) + .data(json!({ "source": "native_request_intercept" })) + .metadata(json!({ "fixture": true })) + .build(), )) }, )?; diff --git a/crates/core/tests/integration/middleware_tests.rs b/crates/core/tests/integration/middleware_tests.rs index 612c5a83d..0394f6c10 100644 --- a/crates/core/tests/integration/middleware_tests.rs +++ b/crates/core/tests/integration/middleware_tests.rs @@ -14,12 +14,14 @@ use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::sync::{Arc, Mutex}; use futures::StreamExt; -use nemo_relay::api::event::{Event, ScopeCategory}; -use nemo_relay::api::llm::LlmRequest; +use nemo_relay::api::event::{ + CategoryProfile, Event, EventCategory, PendingMarkSpec, ScopeCategory, +}; use nemo_relay::api::llm::{ LlmCallExecuteParams, LlmStreamCallExecuteParams, llm_call_execute, llm_request_intercepts, - llm_stream_call_execute, + llm_request_intercepts_with_marks, llm_stream_call_execute, }; +use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay::api::registry::{ deregister_llm_conditional_execution_guardrail, deregister_llm_execution_intercept, deregister_llm_request_intercept, deregister_llm_sanitize_request_guardrail, @@ -28,13 +30,14 @@ use nemo_relay::api::registry::{ deregister_tool_request_intercept, deregister_tool_sanitize_request_guardrail, deregister_tool_sanitize_response_guardrail, register_llm_conditional_execution_guardrail, register_llm_execution_intercept, register_llm_request_intercept, - register_llm_sanitize_request_guardrail, register_llm_sanitize_response_guardrail, - register_llm_stream_execution_intercept, register_tool_conditional_execution_guardrail, - register_tool_execution_intercept, register_tool_request_intercept, - register_tool_sanitize_request_guardrail, register_tool_sanitize_response_guardrail, - scope_register_llm_conditional_execution_guardrail, scope_register_llm_execution_intercept, - scope_register_llm_request_intercept, scope_register_llm_sanitize_request_guardrail, - scope_register_llm_sanitize_response_guardrail, scope_register_llm_stream_execution_intercept, + register_llm_request_intercept_with_marks, register_llm_sanitize_request_guardrail, + register_llm_sanitize_response_guardrail, register_llm_stream_execution_intercept, + register_tool_conditional_execution_guardrail, register_tool_execution_intercept, + register_tool_request_intercept, register_tool_sanitize_request_guardrail, + register_tool_sanitize_response_guardrail, scope_register_llm_conditional_execution_guardrail, + scope_register_llm_execution_intercept, scope_register_llm_request_intercept, + scope_register_llm_sanitize_request_guardrail, scope_register_llm_sanitize_response_guardrail, + scope_register_llm_stream_execution_intercept, scope_register_tool_conditional_execution_guardrail, scope_register_tool_execution_intercept, scope_register_tool_request_intercept, scope_register_tool_sanitize_request_guardrail, scope_register_tool_sanitize_response_guardrail, @@ -2588,6 +2591,195 @@ async fn test_llm_request_intercept_transforms() { deregister_llm_request_intercept("llm_req_i").unwrap(); } +#[test] +fn test_llm_request_intercept_pending_marks_preserve_order_and_break_chain() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + for (name, priority, break_chain, mark_name) in [ + ("pending_first", 1, false, "first"), + ("pending_break", 2, true, "second"), + ("pending_skipped", 3, false, "skipped"), + ] { + register_llm_request_intercept_with_marks( + name, + priority, + break_chain, + Arc::new(move |_name, request, annotated| { + Ok(LlmRequestInterceptOutcome::new(request, annotated) + .with_pending_mark(PendingMarkSpec::builder().name(mark_name).build())) + }), + ) + .unwrap(); + } + + let outcome = llm_request_intercepts_with_marks( + "llm", + LlmRequest { + headers: serde_json::Map::new(), + content: json!({"prompt": "hello"}), + }, + ) + .unwrap(); + + assert_eq!( + outcome + .pending_marks + .iter() + .map(|mark| mark.name.as_str()) + .collect::>(), + ["first", "second"] + ); + assert_eq!(outcome.request.content["prompt"], "hello"); + + for name in ["pending_first", "pending_break", "pending_skipped"] { + deregister_llm_request_intercept(name).unwrap(); + } +} + +#[tokio::test] +async fn test_managed_llm_emits_pending_marks_under_started_scope() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let events = Arc::new(Mutex::new(Vec::::new())); + let captured = events.clone(); + register_subscriber( + "pending_mark_observer", + Arc::new(move |event: &Event| captured.lock().unwrap().push(event.clone())), + ) + .unwrap(); + + register_llm_request_intercept_with_marks( + "pending_managed", + 1, + false, + Arc::new(|_name, request, annotated| { + Ok( + LlmRequestInterceptOutcome::new(request, annotated).with_pending_mark( + PendingMarkSpec::builder() + .name("request.optimized") + .category(EventCategory::custom()) + .category_profile( + CategoryProfile::builder() + .subtype("optimizer.saved_tokens") + .build(), + ) + .data(json!({"saved_tokens": 12})) + .build(), + ), + ) + }), + ) + .unwrap(); + + let provider_request = Arc::new(Mutex::new(None::)); + let captured_request = provider_request.clone(); + llm_call_execute( + LlmCallExecuteParams::builder() + .name("pending-managed-llm") + .request(LlmRequest { + headers: serde_json::Map::new(), + content: json!({"prompt": "hello"}), + }) + .func(Arc::new(move |request| { + *captured_request.lock().unwrap() = Some(request); + Box::pin(async { Ok(json!({"response": "done"})) }) + })) + .build(), + ) + .await + .unwrap(); + + let provider_request = provider_request.lock().unwrap().clone().unwrap(); + assert!( + serde_json::to_value(provider_request) + .unwrap() + .get("__nemo_relay_llm_intercept_outcome") + .is_none() + ); + + let captured = captured_events_snapshot(&events); + let start = captured + .iter() + .find(|event| { + event.name() == "pending-managed-llm" + && event.scope_category() == Some(ScopeCategory::Start) + }) + .unwrap(); + let mark = captured + .iter() + .find(|event| event.name() == "request.optimized") + .unwrap(); + assert_eq!(mark.parent_uuid(), Some(start.uuid())); + assert!(mark.timestamp() > start.timestamp()); + assert_eq!(mark.data().unwrap()["saved_tokens"], 12); + + deregister_llm_request_intercept("pending_managed").unwrap(); + deregister_subscriber("pending_mark_observer").unwrap(); +} + +#[tokio::test] +async fn test_failed_request_intercept_does_not_emit_pending_marks_or_start_scope() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let events = Arc::new(Mutex::new(Vec::::new())); + let captured = events.clone(); + register_subscriber( + "failed_pending_mark_observer", + Arc::new(move |event: &Event| captured.lock().unwrap().push(event.clone())), + ) + .unwrap(); + register_llm_request_intercept_with_marks( + "pending_before_failure", + 1, + false, + Arc::new(|_name, request, annotated| { + Ok(LlmRequestInterceptOutcome::new(request, annotated) + .with_pending_mark(PendingMarkSpec::builder().name("must.not.emit").build())) + }), + ) + .unwrap(); + register_llm_request_intercept_with_marks( + "pending_failure", + 2, + false, + Arc::new(|_name, _request, _annotated| { + Err(FlowError::Internal("request intercept failed".into())) + }), + ) + .unwrap(); + + let provider_called = Arc::new(AtomicBool::new(false)); + let called = provider_called.clone(); + let result = llm_call_execute( + LlmCallExecuteParams::builder() + .name("failed-pending-llm") + .request(LlmRequest { + headers: serde_json::Map::new(), + content: json!({"prompt": "hello"}), + }) + .func(Arc::new(move |_request| { + called.store(true, Ordering::SeqCst); + Box::pin(async { Ok(json!({"response": "unexpected"})) }) + })) + .build(), + ) + .await; + + assert!(result.is_err()); + assert!(!provider_called.load(Ordering::SeqCst)); + assert!(captured_events_snapshot(&events).is_empty()); + + deregister_llm_request_intercept("pending_before_failure").unwrap(); + deregister_llm_request_intercept("pending_failure").unwrap(); + deregister_subscriber("failed_pending_mark_observer").unwrap(); +} + /// LLM execution intercept middleware chain with next(). #[tokio::test] async fn test_llm_execution_intercept_chain() { diff --git a/crates/core/tests/integration/native_plugin_tests.rs b/crates/core/tests/integration/native_plugin_tests.rs index c209a553e..a41aa31a1 100644 --- a/crates/core/tests/integration/native_plugin_tests.rs +++ b/crates/core/tests/integration/native_plugin_tests.rs @@ -342,6 +342,24 @@ async fn sdk_cdylib_registers_tool_request_intercept() { llm_start.input().unwrap()["content"]["native_plugin_llm_request_intercept"], true ); + let pending_mark = find_event(&managed_llm_events, "fixture.native.llm_request.mark", None); + assert_eq!(pending_mark.parent_uuid(), Some(llm_start.uuid())); + assert_eq!( + pending_mark.category().map(|category| category.as_str()), + Some("custom") + ); + assert_eq!( + pending_mark + .category_profile() + .and_then(|profile| profile.subtype.as_deref()), + Some("fixture.native.pending") + ); + assert_eq!( + pending_mark.data().unwrap()["source"], + "native_request_intercept" + ); + assert_eq!(pending_mark.metadata().unwrap()["fixture"], true); + assert!(pending_mark.timestamp() > llm_start.timestamp()); let llm_end = find_event( &managed_llm_events, "native-fixture-llm-execute", @@ -400,6 +418,13 @@ async fn sdk_cdylib_registers_tool_request_intercept() { assert_eq!(*collected_stream_chunks.lock().unwrap(), stream_chunks); flush_subscribers().expect("stream native fixture events should flush"); let stream_events = events.lock().unwrap().clone(); + let stream_start = find_event( + &stream_events, + "native-fixture-llm-stream", + Some(ScopeCategory::Start), + ); + let stream_pending_mark = find_event(&stream_events, "fixture.native.llm_request.mark", None); + assert_eq!(stream_pending_mark.parent_uuid(), Some(stream_start.uuid())); let stream_end = find_event( &stream_events, "native-fixture-llm-stream", diff --git a/crates/core/tests/unit/shared_tests.rs b/crates/core/tests/unit/shared_tests.rs index fe3eeab59..c45ca5a5e 100644 --- a/crates/core/tests/unit/shared_tests.rs +++ b/crates/core/tests/unit/shared_tests.rs @@ -156,25 +156,34 @@ fn test_run_request_intercepts_with_codec_none_and_codec_paths() { Arc::new(|_name, mut request, annotated| { assert!(annotated.is_none()); request.headers.insert("x-no-codec".into(), json!(true)); - Ok((request, None)) + let mut annotated = SharedTestCodec.decode(&request)?; + annotated.model = Some("interceptor-model".into()); + Ok((request, Some(annotated))) }), ) .unwrap(); - let (request_without_codec, annotated_without_codec) = run_request_intercepts_with_codec( - "shared", - LlmRequest { - headers: Map::new(), - content: json!({"prompt": "hello"}), - }, - None, - ) - .unwrap(); + let (request_without_codec, annotated_without_codec, pending_marks_without_codec) = + run_request_intercepts_with_codec( + "shared", + LlmRequest { + headers: Map::new(), + content: json!({"prompt": "hello"}), + }, + None, + ) + .unwrap(); assert_eq!( request_without_codec.headers.get("x-no-codec"), Some(&json!(true)) ); - assert!(annotated_without_codec.is_none()); + assert_eq!( + annotated_without_codec + .as_deref() + .and_then(|annotated| annotated.model.as_deref()), + Some("interceptor-model") + ); + assert!(pending_marks_without_codec.is_empty()); deregister_llm_request_intercept("shared-none").unwrap(); register_llm_request_intercept( @@ -191,15 +200,16 @@ fn test_run_request_intercepts_with_codec_none_and_codec_paths() { .unwrap(); let codec: Arc = Arc::new(SharedTestCodec); - let (request_with_codec, annotated_with_codec) = run_request_intercepts_with_codec( - "shared", - LlmRequest { - headers: Map::new(), - content: json!({"prompt": "hello"}), - }, - Some(codec), - ) - .unwrap(); + let (request_with_codec, annotated_with_codec, pending_marks_with_codec) = + run_request_intercepts_with_codec( + "shared", + LlmRequest { + headers: Map::new(), + content: json!({"prompt": "hello"}), + }, + Some(codec), + ) + .unwrap(); assert_eq!( request_with_codec.headers.get("x-codec"), @@ -215,6 +225,7 @@ fn test_run_request_intercepts_with_codec_none_and_codec_paths() { .and_then(|annotated| annotated.model.as_deref()), Some("intercepted-model") ); + assert!(pending_marks_with_codec.is_empty()); deregister_llm_request_intercept("shared-codec").unwrap(); reset_global(); diff --git a/crates/plugin/src/lib.rs b/crates/plugin/src/lib.rs index 31f0c4e40..007322cf7 100644 --- a/crates/plugin/src/lib.rs +++ b/crates/plugin/src/lib.rs @@ -16,8 +16,10 @@ use std::ptr; use std::sync::Mutex; pub use nemo_relay_types::Json; -pub use nemo_relay_types::api::event::{Event, ScopeCategory}; -pub use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest}; +pub use nemo_relay_types::api::event::{ + CategoryProfile, Event, EventCategory, PendingMarkSpec, ScopeCategory, +}; +pub use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest, LlmRequestInterceptOutcome}; pub use nemo_relay_types::api::scope::{HandleAttributes, ScopeAttributes, ScopeType}; pub use nemo_relay_types::api::tool::ToolAttributes; pub use nemo_relay_types::codec::request::AnnotatedLlmRequest; @@ -1490,6 +1492,39 @@ impl<'a> PluginContext<'a> { finish_typed_registration::(self.host, status, user_data, "llm request intercept") } + /// Registers a typed LLM request intercept that can schedule lifecycle marks. + pub fn register_llm_request_intercept_with_marks( + &mut self, + name: &str, + priority: i32, + break_chain: bool, + callback: F, + ) -> Result<()> + where + F: Fn(&str, LlmRequest, Option) -> Result + + Send + + Sync + + 'static, + { + let user_data = typed_callback_user_data(self.host, callback); + let status = unsafe { + self.register_llm_request_intercept_raw( + name, + priority, + break_chain, + typed_llm_request_intercept_with_marks_trampoline::, + user_data, + Some(drop_typed_callback::), + ) + }; + finish_typed_registration::( + self.host, + status, + user_data, + "LLM request intercept with marks", + ) + } + /// Registers a typed LLM execution intercept. pub fn register_llm_execution_intercept( &mut self, @@ -2189,6 +2224,74 @@ where } } +#[derive(Serialize)] +struct NativeLlmRequestInterceptOutcome<'a> { + #[serde(rename = "__nemo_relay_llm_intercept_outcome")] + marked_outcome: bool, + annotated_request: &'a Option, + pending_marks: &'a [PendingMarkSpec], +} + +unsafe extern "C" fn typed_llm_request_intercept_with_marks_trampoline( + user_data: *mut c_void, + name: *const NemoRelayNativeString, + request_json: *const NemoRelayNativeString, + annotated_json: *const NemoRelayNativeString, + out_request_json: *mut *mut NemoRelayNativeString, + out_annotated_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus +where + F: Fn(&str, LlmRequest, Option) -> Result + + Send + + Sync + + 'static, +{ + if user_data.is_null() || out_request_json.is_null() || out_annotated_json.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { + *out_request_json = ptr::null_mut(); + *out_annotated_json = ptr::null_mut(); + } + let state = unsafe { &*(user_data as *const TypedCallback) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let name = read_required_host_string(&state.host, name, "LLM name")?; + let request: LlmRequest = read_json_value(&state.host, request_json, "LLM request")?; + let annotated: Option = + read_optional_json_value(&state.host, annotated_json, "annotated LLM request")?; + match (state.callback)(&name, request, annotated) { + Ok(outcome) => { + let Some(request) = HostString::from_json(&state.host, &outcome.request) else { + set_last_error(&state.host, "failed to allocate LLM request output"); + return Ok(NemoRelayStatus::Internal); + }; + let metadata = NativeLlmRequestInterceptOutcome { + marked_outcome: true, + annotated_request: &outcome.annotated_request, + pending_marks: &outcome.pending_marks, + }; + let Some(metadata) = HostString::from_json(&state.host, &metadata) else { + set_last_error(&state.host, "failed to allocate marked LLM outcome"); + return Ok(NemoRelayStatus::Internal); + }; + unsafe { + *out_request_json = request.ptr; + *out_annotated_json = metadata.ptr; + } + std::mem::forget(request); + std::mem::forget(metadata); + Ok(NemoRelayStatus::Ok) + } + Err(message) => Ok(callback_error(&state.host, message)), + } + })); + match result { + Ok(Ok(status)) => status, + Ok(Err(status)) => status, + Err(_) => callback_panic(&state.host, "LLM request intercept with marks callback"), + } +} + unsafe extern "C" fn typed_llm_execution_trampoline( user_data: *mut c_void, name: *const NemoRelayNativeString, diff --git a/crates/plugin/tests/typed_callbacks.rs b/crates/plugin/tests/typed_callbacks.rs index d3a502ffb..ae88024b0 100644 --- a/crates/plugin/tests/typed_callbacks.rs +++ b/crates/plugin/tests/typed_callbacks.rs @@ -14,16 +14,16 @@ use std::sync::{ use nemo_relay_plugin::{ AnnotatedLlmRequest, ConfigDiagnostic, DiagnosticLevel, Event, Json, LlmJsonStream, LlmNext, - LlmRequest, LlmStream, LlmStreamNext, NEMO_RELAY_NATIVE_ABI_VERSION, NativePlugin, - NemoRelayNativeEventSubscriberCb, NemoRelayNativeFreeFn, NemoRelayNativeHostApiV1, - NemoRelayNativeJsonCb, NemoRelayNativeLlmConditionalCb, NemoRelayNativeLlmExecutionCb, - NemoRelayNativeLlmRequestCb, NemoRelayNativeLlmRequestInterceptCb, - NemoRelayNativeLlmStreamExecutionCb, NemoRelayNativeLlmStreamV1, NemoRelayNativePluginContext, - NemoRelayNativePluginV1, NemoRelayNativeScopeHandle, NemoRelayNativeScopeStack, - NemoRelayNativeScopeStackBinding, NemoRelayNativeScopeType, NemoRelayNativeString, - NemoRelayNativeToolConditionalCb, NemoRelayNativeToolExecutionCb, NemoRelayNativeToolJsonCb, - NemoRelayNativeWithScopeStackCb, NemoRelayStatus, PluginContext, PluginRuntime, ScopeType, - ToolNext, + LlmRequest, LlmRequestInterceptOutcome, LlmStream, LlmStreamNext, + NEMO_RELAY_NATIVE_ABI_VERSION, NativePlugin, NemoRelayNativeEventSubscriberCb, + NemoRelayNativeFreeFn, NemoRelayNativeHostApiV1, NemoRelayNativeJsonCb, + NemoRelayNativeLlmConditionalCb, NemoRelayNativeLlmExecutionCb, NemoRelayNativeLlmRequestCb, + NemoRelayNativeLlmRequestInterceptCb, NemoRelayNativeLlmStreamExecutionCb, + NemoRelayNativeLlmStreamV1, NemoRelayNativePluginContext, NemoRelayNativePluginV1, + NemoRelayNativeScopeHandle, NemoRelayNativeScopeStack, NemoRelayNativeScopeStackBinding, + NemoRelayNativeScopeType, NemoRelayNativeString, NemoRelayNativeToolConditionalCb, + NemoRelayNativeToolExecutionCb, NemoRelayNativeToolJsonCb, NemoRelayNativeWithScopeStackCb, + NemoRelayStatus, PendingMarkSpec, PluginContext, PluginRuntime, ScopeType, ToolNext, }; use serde_json::{Map, json}; @@ -4388,6 +4388,75 @@ fn typed_llm_request_intercept_round_trips_request_and_annotations() { } } +#[test] +fn typed_llm_request_intercept_with_marks_uses_tagged_annotation_envelope() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_llm_request_intercept_with_marks( + "llm", + 23, + false, + |_name, mut request, annotated| { + request.content["rewritten"] = json!(true); + Ok( + LlmRequestInterceptOutcome::new(request, annotated).with_pending_mark( + PendingMarkSpec::builder() + .name("plugin.request.rewritten") + .data(json!({ "saved_tokens": 7 })) + .build(), + ), + ) + }, + ) + .unwrap(); + + let registration = take_llm_request_intercept_registration(); + assert_eq!(registration.priority, 23); + assert!(!registration.break_chain); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let annotated = json_host_string( + &host, + serde_json::to_value(test_annotated_llm_request()).unwrap(), + ); + let mut out_request = ptr::null_mut(); + let mut out_annotated = ptr::null_mut(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + annotated, + &mut out_request, + &mut out_annotated, + ) + }; + assert_eq!(status, NemoRelayStatus::Ok); + let out_request = read_json_and_free(&host, out_request); + assert_eq!(out_request["content"]["rewritten"], true); + assert!( + out_request + .pointer("/content/__nemo_relay_llm_intercept_outcome") + .is_none() + ); + let metadata = read_json_and_free(&host, out_annotated); + assert_eq!(metadata["__nemo_relay_llm_intercept_outcome"], true); + assert_eq!(metadata["annotated_request"]["messages"], json!([])); + assert_eq!( + metadata["pending_marks"][0]["name"], + "plugin.request.rewritten" + ); + assert_eq!(metadata["pending_marks"][0]["data"]["saved_tokens"], 7); + + unsafe { + (host.string_free)(name); + (host.string_free)(request); + (host.string_free)(annotated); + registration.free(); + } +} + struct DropCounter(Arc); impl Drop for DropCounter { diff --git a/crates/types/src/api/event.rs b/crates/types/src/api/event.rs index af17546ae..d08bf8180 100644 --- a/crates/types/src/api/event.rs +++ b/crates/types/src/api/event.rs @@ -364,6 +364,29 @@ pub struct MarkEvent { pub category_profile: Option, } +/// Mark requested by middleware before its owning runtime scope exists. +/// +/// The runtime assigns the parent UUID, event UUID, and timestamp when it +/// materializes the mark at the appropriate lifecycle boundary. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)] +#[builder(field_defaults(setter(into, strip_option(ignore_invalid, fallback_suffix = "_opt"))))] +pub struct PendingMarkSpec { + /// Human-readable mark name. + pub name: String, + /// Optional semantic category for the mark. + #[builder(default)] + pub category: Option, + /// Optional category-specific typed fields. + #[builder(default)] + pub category_profile: Option, + /// Optional application payload attached to the mark. + #[builder(default)] + pub data: Option, + /// Optional metadata attached to the mark. + #[builder(default)] + pub metadata: Option, +} + impl MarkEvent { /// Construct a mark event from a base envelope and optional category data. /// diff --git a/crates/types/src/api/llm.rs b/crates/types/src/api/llm.rs index caa917a65..c4dba6efb 100644 --- a/crates/types/src/api/llm.rs +++ b/crates/types/src/api/llm.rs @@ -7,6 +7,15 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; use crate::Json; +use crate::api::event::PendingMarkSpec; +use crate::codec::request::AnnotatedLlmRequest; + +/// Private native-ABI tag used for marked LLM request-intercept outcomes. +/// +/// Native plugin authors should return [`LlmRequestInterceptOutcome`] through +/// the plugin SDK instead of reading or writing this field directly. +#[doc(hidden)] +pub const NATIVE_LLM_INTERCEPT_OUTCOME_FIELD: &str = "__nemo_relay_llm_intercept_outcome"; bitflags! { /// Bitflags that modify LLM-call behavior and observability. @@ -20,10 +29,47 @@ bitflags! { } /// JSON-shaped LLM request payload passed through the runtime. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct LlmRequest { /// Provider-specific request headers. pub headers: serde_json::Map, /// Provider-specific request body. pub content: Json, } + +/// Result of an LLM request intercept that can schedule lifecycle marks. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct LlmRequestInterceptOutcome { + /// Rewritten provider request. + pub request: LlmRequest, + /// Optional normalized request annotation to carry forward. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub annotated_request: Option, + /// Ordered marks to emit after Relay creates and starts the LLM scope. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub pending_marks: Vec, +} + +impl LlmRequestInterceptOutcome { + /// Create an outcome without pending marks. + pub fn new(request: LlmRequest, annotated_request: Option) -> Self { + Self { + request, + annotated_request, + pending_marks: Vec::new(), + } + } + + /// Append one pending mark while preserving interceptor order. + #[must_use] + pub fn with_pending_mark(mut self, mark: PendingMarkSpec) -> Self { + self.pending_marks.push(mark); + self + } +} + +impl From<(LlmRequest, Option)> for LlmRequestInterceptOutcome { + fn from((request, annotated_request): (LlmRequest, Option)) -> Self { + Self::new(request, annotated_request) + } +} diff --git a/crates/types/tests/serialization_tests.rs b/crates/types/tests/serialization_tests.rs index 9e899ff7d..fb423ecf9 100644 --- a/crates/types/tests/serialization_tests.rs +++ b/crates/types/tests/serialization_tests.rs @@ -6,10 +6,10 @@ use std::sync::Arc; use nemo_relay_types::api::event::{ - BaseEvent, CategoryProfile, Event, EventCategory, ScopeCategory, ScopeEvent, + BaseEvent, CategoryProfile, Event, EventCategory, PendingMarkSpec, ScopeCategory, ScopeEvent, llm_attributes_to_strings, }; -use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest}; +use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay_types::codec::request::{AnnotatedLlmRequest, Message, MessageContent}; use nemo_relay_types::codec::response::AnnotatedLlmResponse; use serde_json::{Map, json}; @@ -78,3 +78,46 @@ fn event_round_trips_with_annotated_llm_profiles() { Some("resp_1") ); } + +#[test] +fn llm_request_intercept_outcome_round_trips_pending_marks() { + let outcome = LlmRequestInterceptOutcome::new( + LlmRequest { + headers: Map::new(), + content: json!({ "prompt": "hello" }), + }, + None, + ) + .with_pending_mark( + PendingMarkSpec::builder() + .name("request.optimized") + .category(EventCategory::custom()) + .category_profile( + CategoryProfile::builder() + .subtype("optimizer.saved_tokens") + .build(), + ) + .data(json!({ "saved_tokens": 12 })) + .metadata(json!({ "source": "test" })) + .build(), + ); + + let encoded = serde_json::to_value(&outcome).expect("outcome should serialize"); + assert_eq!(encoded["pending_marks"][0]["name"], "request.optimized"); + assert_eq!(encoded["pending_marks"][0]["category"], "custom"); + assert!(encoded.get("annotated_request").is_none()); + + let mut encoded_without_pending_marks = encoded.clone(); + encoded_without_pending_marks + .as_object_mut() + .unwrap() + .remove("pending_marks"); + let decoded_without_pending_marks: LlmRequestInterceptOutcome = + serde_json::from_value(encoded_without_pending_marks) + .expect("outcome without pending marks should deserialize"); + assert!(decoded_without_pending_marks.pending_marks.is_empty()); + + let decoded: LlmRequestInterceptOutcome = + serde_json::from_value(encoded).expect("outcome should deserialize"); + assert_eq!(decoded, outcome); +} diff --git a/examples/rust-native-plugin/README.md b/examples/rust-native-plugin/README.md index 82fa142a6..9c8df8860 100644 --- a/examples/rust-native-plugin/README.md +++ b/examples/rust-native-plugin/README.md @@ -84,9 +84,18 @@ The example registers the following runtime behavior: - Request and execution intercepts for tools that mutate JSON payloads and call continuations. - LLM sanitize request/response guardrails. -- LLM request, execution, and stream execution intercepts. +- An LLM request intercept that rewrites the request and schedules a mark. Relay + emits that mark after the LLM start event with the LLM scope as its parent. +- LLM execution and stream execution intercepts. - Runtime mark and scope events. - A plugin-owned isolated scope stack for non-correlated visibility. Native plugins are not sandboxed. They run in the Relay process and must not unwind across ABI callbacks. + +Request intercepts do not own an LLM lifecycle because they run before Relay +creates the LLM scope. Use `register_llm_request_intercept_with_marks` to return +`PendingMarkSpec` values. Relay emits them in interceptor order after the LLM +start event and before provider execution. The legacy +`register_llm_request_intercept` API remains available for intercepts that only +rewrite requests. diff --git a/examples/rust-native-plugin/src/lib.rs b/examples/rust-native-plugin/src/lib.rs index f18296f54..d7e4d961f 100644 --- a/examples/rust-native-plugin/src/lib.rs +++ b/examples/rust-native-plugin/src/lib.rs @@ -2,8 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 use nemo_relay_plugin::{ - ConfigDiagnostic, DiagnosticLevel, Event, Json, LlmJsonStream, LlmRequest, NativePlugin, - PluginContext, PluginRuntime, ScopeCategory, ScopeType, + CategoryProfile, ConfigDiagnostic, DiagnosticLevel, Event, EventCategory, Json, LlmJsonStream, + LlmRequest, LlmRequestInterceptOutcome, NativePlugin, PendingMarkSpec, PluginContext, + PluginRuntime, ScopeCategory, ScopeType, }; use serde_json::{Map, json}; @@ -232,12 +233,23 @@ impl NativePlugin for ExampleNativePlugin { Ok(block_llms.then(|| "LLM call blocked by Rust native plugin".to_string())) } })?; - ctx.register_llm_request_intercept("example_llm_request", 20, false, { + ctx.register_llm_request_intercept_with_marks("example_llm_request", 20, false, { let tag = config.tag.clone(); move |_name, request, annotated| { - Ok(( + Ok(LlmRequestInterceptOutcome::new( tag_llm_request(request, "native_llm_request_intercept", &tag), annotated, + ) + .with_pending_mark( + PendingMarkSpec::builder() + .name("example.native.llm_request_intercept") + .category(EventCategory::custom()) + .category_profile(CategoryProfile { + subtype: Some("example.native.request_rewrite".into()), + ..CategoryProfile::default() + }) + .data(json!({ "tag": &tag })) + .build(), )) } })?;