Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 48 additions & 8 deletions crates/core/src/api/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@

use std::sync::Arc;

use chrono::{DateTime, Utc};
use chrono::{DateTime, TimeDelta, Utc};
use serde::{Deserialize, Serialize};
use serde_json::json;
use typed_builder::TypedBuilder;
use uuid::Uuid;

use crate::api::event::{BaseEvent, Event, MarkEvent, PendingMarkSpec};
use crate::api::runtime::NemoRelayContextState;
use crate::api::runtime::current_scope_stack;
use crate::api::runtime::global_context;
Expand All @@ -28,7 +29,7 @@ use crate::error::{FlowError, Result};
use crate::json::Json;
use crate::stream::LlmStreamWrapper;

pub use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest};
pub use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest, LlmRequestInterceptOutcome};

/// Runtime-owned handle identifying an active or completed LLM call.
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
Expand Down Expand Up @@ -298,6 +299,35 @@ fn emit_llm_start(
Ok(())
}

/// Emit the marks that request intercepts scheduled for an LLM call.
///
/// Does nothing when `marks` is empty. Otherwise every pending mark becomes a
/// mark event parented to `handle.uuid` and is dispatched to the subscriber
/// snapshot taken from the current scope stack.
///
/// Marks are stamped at `handle.started_at` plus one microsecond per position
/// (the first mark lands at +1µs), so their timestamps keep the order of
/// `marks`.
///
/// # Errors
///
/// Propagates failures from `ensure_runtime_owner` and
/// `snapshot_event_subscribers`.
///
/// # Panics
///
/// Panics if the scope stack lock is poisoned.
fn emit_pending_request_marks(handle: &LlmHandle, marks: Vec<PendingMarkSpec>) -> Result<()> {
    if marks.is_empty() {
        return Ok(());
    }
    ensure_runtime_owner()?;

    // Take the subscriber snapshot inside a block so the read guard is
    // released before any event is emitted.
    let subscribers = {
        let stack = current_scope_stack();
        let guard = stack.read().expect("scope stack lock poisoned");
        let scope_local = guard.collect_scope_local_subscribers();
        snapshot_event_subscribers(scope_local)?
    };

    for (position, pending) in marks.into_iter().enumerate() {
        // Offsets are 1-based so every mark lands strictly after
        // `handle.started_at`.
        let offset_micros = i64::try_from(position).unwrap_or_default() + 1;
        let stamped_at = handle.started_at + TimeDelta::microseconds(offset_micros);

        let base = BaseEvent::builder()
            .name(pending.name)
            .parent_uuid(handle.uuid)
            .timestamp(stamped_at)
            .data_opt(pending.data)
            .metadata_opt(pending.metadata)
            .build();
        let mark_event = MarkEvent::new(base, pending.category, pending.category_profile);

        NemoRelayContextState::emit_event(&Event::Mark(mark_event), &subscribers);
    }
    Ok(())
}

/// Start a manual LLM lifecycle span.
///
/// This emits an LLM-start event after applying sanitize-request guardrails to
Expand Down Expand Up @@ -587,7 +617,7 @@ pub async fn llm_call_execute(params: LlmCallExecuteParams) -> Result<Json> {
}

let request_codec = codec.clone();
let (intercepted_request, annotated_request) =
let (intercepted_request, annotated_request, pending_marks) =
run_request_intercepts_with_codec(&name, request, codec)?;

let handle = create_llm_handle(
Expand All @@ -606,6 +636,7 @@ pub async fn llm_call_execute(params: LlmCallExecuteParams) -> Result<Json> {
annotated_request.clone(),
request_codec.as_deref(),
)?;
emit_pending_request_marks(&handle, pending_marks)?;

let execution = {
let scope_stack = current_scope_stack();
Expand Down Expand Up @@ -743,7 +774,7 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu
}

let request_codec = codec.clone();
let (intercepted_request, annotated_request) =
let (intercepted_request, annotated_request, pending_marks) =
run_request_intercepts_with_codec(&name, request, codec)?;

let handle = create_llm_handle(
Expand All @@ -762,6 +793,7 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu
annotated_request,
request_codec.as_deref(),
)?;
emit_pending_request_marks(&handle, pending_marks)?;

let execution = {
let scope_stack = current_scope_stack();
Expand Down Expand Up @@ -818,6 +850,17 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu
/// Conditional guardrails, codecs, and execution intercepts are not run by
/// this helper.
pub fn llm_request_intercepts(name: &str, request: LlmRequest) -> Result<LlmRequest> {
    // Only the rewritten request is handed back; the rest of the outcome,
    // including any pending marks, is dropped here.
    llm_request_intercepts_with_marks(name, request).map(|outcome| outcome.request)
}

/// Run the LLM request-intercept chain and return pending lifecycle marks.
///
/// This helper does not emit the returned marks because it does not own an LLM
/// lifecycle. Callers must attach them to the lifecycle they own.
pub fn llm_request_intercepts_with_marks(
name: &str,
request: LlmRequest,
) -> Result<LlmRequestInterceptOutcome> {
ensure_runtime_owner()?;
let entries = {
let scope_stack = current_scope_stack();
Expand All @@ -830,10 +873,7 @@ pub fn llm_request_intercepts(name: &str, request: LlmRequest) -> Result<LlmRequ
.map_err(|error| FlowError::Internal(error.to_string()))?;
state.llm_request_intercept_entries(&scope_locals)
};
let (request, _) = NemoRelayContextState::llm_request_intercepts_snapshot_chain(
name, request, None, &entries,
)?;
Ok(request)
NemoRelayContextState::llm_request_intercepts_snapshot_chain(name, request, None, &entries)
}

/// Run only the LLM conditional-execution guardrail chain.
Expand Down
57 changes: 45 additions & 12 deletions crates/core/src/api/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
//! intercepts, and subscribers.

use crate::api::runtime::{
LlmConditionalFn, LlmExecutionFn, LlmRequestInterceptFn, LlmSanitizeRequestFn,
LlmSanitizeResponseFn, LlmStreamExecutionFn, ToolConditionalFn, ToolExecutionFn,
ToolInterceptFn, ToolSanitizeFn,
LlmConditionalFn, LlmExecutionFn, LlmRequestInterceptFn, LlmRequestInterceptWithMarksFn,
LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, ToolConditionalFn,
ToolExecutionFn, ToolInterceptFn, ToolSanitizeFn,
};
use crate::api::runtime::{current_scope_stack, global_context};
use crate::api::shared::ensure_runtime_owner;
Expand Down Expand Up @@ -547,14 +547,30 @@ global_guardrail_registry_api!(
LlmConditionalFn
);
global_intercept_registry_api!(
/// Register a global LLM request intercept.
/// Request intercepts can rewrite or annotate the outgoing LLM request.
register_llm_request_intercept,
/// Register a global LLM request intercept that can schedule lifecycle marks.
register_llm_request_intercept_with_marks,
/// Deregister a global LLM request intercept.
deregister_llm_request_intercept,
llm_request_intercepts,
LlmRequestInterceptFn
LlmRequestInterceptWithMarksFn
);

/// Register a global LLM request intercept without pending marks.
pub fn register_llm_request_intercept(
name: &str,
priority: i32,
break_chain: bool,
callable: LlmRequestInterceptFn,
) -> Result<()> {
register_llm_request_intercept_with_marks(
name,
priority,
break_chain,
std::sync::Arc::new(move |name, request, annotated| {
callable(name, request, annotated).map(Into::into)
}),
)
}
global_execution_registry_api!(
/// Register a global LLM execution intercept.
/// Execution intercepts can wrap or replace the non-streaming provider
Expand Down Expand Up @@ -653,15 +669,32 @@ scope_guardrail_registry_api!(
LlmConditionalFn
);
scope_intercept_registry_api!(
/// Register a scope-local LLM request intercept.
/// Request intercepts can rewrite or annotate LLM requests inside the
/// owning scope.
scope_register_llm_request_intercept,
/// Register a scope-local LLM request intercept that can schedule lifecycle marks.
scope_register_llm_request_intercept_with_marks,
/// Deregister a scope-local LLM request intercept.
scope_deregister_llm_request_intercept,
llm_request_intercepts,
LlmRequestInterceptFn
LlmRequestInterceptWithMarksFn
);

/// Register a scope-local LLM request intercept without pending marks.
pub fn scope_register_llm_request_intercept(
scope_uuid: &uuid::Uuid,
name: &str,
priority: i32,
break_chain: bool,
callable: LlmRequestInterceptFn,
) -> Result<()> {
scope_register_llm_request_intercept_with_marks(
scope_uuid,
name,
priority,
break_chain,
std::sync::Arc::new(move |name, request, annotated| {
callable(name, request, annotated).map(Into::into)
}),
)
}
scope_execution_registry_api!(
/// Register a scope-local LLM execution intercept.
/// Execution intercepts can wrap or replace the non-streaming provider
Expand Down
6 changes: 3 additions & 3 deletions crates/core/src/api/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ pub mod subscriber_dispatcher;

pub use callbacks::{
EventSubscriberFn, LlmCollectorFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn,
LlmFinalizerFn, LlmJsonStream, LlmRequestInterceptFn, LlmSanitizeRequestFn,
LlmSanitizeResponseFn, LlmStreamExecutionFn, LlmStreamExecutionNextFn, ToolConditionalFn,
ToolExecutionFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn,
LlmFinalizerFn, LlmJsonStream, LlmRequestInterceptFn, LlmRequestInterceptWithMarksFn,
LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, LlmStreamExecutionNextFn,
ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn,
};
pub use global::global_context;
pub use scope_stack::{
Expand Down
12 changes: 11 additions & 1 deletion crates/core/src/api/runtime/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::sync::Arc;
use tokio_stream::Stream;

use crate::api::event::Event;
use crate::api::llm::LlmRequest;
use crate::api::llm::{LlmRequest, LlmRequestInterceptOutcome};
use crate::codec::request::AnnotatedLlmRequest;
use crate::error::Result;
use crate::json::Json;
Expand Down Expand Up @@ -177,6 +177,16 @@ pub type LlmRequestInterceptFn = Arc<
+ Send
+ Sync,
>;
/// Rewrite or annotate an LLM request and schedule marks under its future scope.
///
/// This callback has the same inputs as [`LlmRequestInterceptFn`] but returns a
/// structured outcome whose pending marks are emitted after the LLM-start
/// event and before provider execution.
///
/// The arguments are, in order, the LLM call name, the request to rewrite, and
/// the annotated form of that request (`None` when no annotated request was
/// produced for the call).
///
/// When several intercepts run as a chain, each one receives the request and
/// annotation returned by the previous intercept, and the pending marks of all
/// intercepts that ran are collected in chain order. An error returned by any
/// intercept aborts the chain.
pub type LlmRequestInterceptWithMarksFn = Arc<
dyn Fn(&str, LlmRequest, Option<AnnotatedLlmRequest>) -> Result<LlmRequestInterceptOutcome>
+ Send
+ Sync,
>;
/// Continuation type invoked by non-streaming LLM execution intercepts.
///
/// Execution intercepts use this callable to continue the non-streaming LLM
Expand Down
33 changes: 19 additions & 14 deletions crates/core/src/api/runtime/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ use crate::api::llm::{CreateLlmHandleParams, EndLlmHandleParams};
use crate::api::llm::{LlmHandle, LlmRequest};
use crate::api::registry::{ExecutionIntercept, Guardrail, Intercept};
use crate::api::runtime::callbacks::{
EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, LlmRequestInterceptFn,
LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, LlmStreamExecutionNextFn,
LlmStreamExecutionRegistryRefs, ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn,
ToolInterceptFn, ToolSanitizeFn,
EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn,
LlmRequestInterceptWithMarksFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn,
LlmStreamExecutionFn, LlmStreamExecutionNextFn, LlmStreamExecutionRegistryRefs,
ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn,
};
use crate::api::runtime::subscriber_dispatcher;
use crate::api::scope::{CreateScopeHandleParams, EndScopeHandleParams, ScopeHandle, ScopeType};
Expand Down Expand Up @@ -63,7 +63,7 @@ pub struct NemoRelayContextState {
/// Global LLM guardrails that can reject execution before the provider callback runs.
pub(crate) llm_conditional_execution_guardrails: SortedRegistry<Guardrail<LlmConditionalFn>>,
/// Global LLM request intercepts that can rewrite or annotate requests.
pub(crate) llm_request_intercepts: SortedRegistry<Intercept<LlmRequestInterceptFn>>,
pub(crate) llm_request_intercepts: SortedRegistry<Intercept<LlmRequestInterceptWithMarksFn>>,
/// Global non-streaming LLM execution intercepts that wrap callback execution.
pub(crate) llm_execution_intercepts: SortedRegistry<ExecutionIntercept<LlmExecutionFn>>,
/// Global streaming LLM execution intercepts that wrap stream-producing callbacks.
Expand Down Expand Up @@ -1011,8 +1011,8 @@ impl NemoRelayContextState {
/// are released.
pub(crate) fn llm_request_intercept_entries(
&self,
scope_locals: &[&SortedRegistry<Intercept<LlmRequestInterceptFn>>],
) -> Vec<Intercept<LlmRequestInterceptFn>> {
scope_locals: &[&SortedRegistry<Intercept<LlmRequestInterceptWithMarksFn>>],
) -> Vec<Intercept<LlmRequestInterceptWithMarksFn>> {
merge_intercept_entries(&self.llm_request_intercepts, scope_locals)
.into_iter()
.cloned()
Expand Down Expand Up @@ -1041,20 +1041,25 @@ impl NemoRelayContextState {
name: &str,
request: LlmRequest,
annotated: Option<AnnotatedLlmRequest>,
entries: &[Intercept<LlmRequestInterceptFn>],
) -> crate::error::Result<(LlmRequest, Option<AnnotatedLlmRequest>)> {
entries: &[Intercept<LlmRequestInterceptWithMarksFn>],
) -> crate::error::Result<crate::api::llm::LlmRequestInterceptOutcome> {
let mut request_value = request;
let mut annotated_value = annotated;
let mut pending_marks = Vec::new();
for entry in entries {
let (new_request, new_annotated) =
(entry.payload.callable)(name, request_value, annotated_value)?;
request_value = new_request;
annotated_value = new_annotated;
let outcome = (entry.payload.callable)(name, request_value, annotated_value)?;
request_value = outcome.request;
annotated_value = outcome.annotated_request;
pending_marks.extend(outcome.pending_marks);
if entry.payload.break_chain {
break;
}
}
Ok((request_value, annotated_value))
Ok(crate::api::llm::LlmRequestInterceptOutcome {
request: request_value,
annotated_request: annotated_value,
pending_marks,
})
}

/// Build the composed non-streaming LLM execution continuation chain.
Expand Down
17 changes: 11 additions & 6 deletions crates/core/src/api/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ pub(crate) fn run_request_intercepts_with_codec(
name: &str,
request: LlmRequest,
codec: Option<Arc<dyn LlmCodec>>,
) -> Result<(LlmRequest, Option<Arc<AnnotatedLlmRequest>>)> {
) -> Result<(
LlmRequest,
Option<Arc<AnnotatedLlmRequest>>,
Vec<crate::api::event::PendingMarkSpec>,
)> {
let original = request.clone();
let annotated = match &codec {
Some(codec) => Some(codec.decode(&request)?),
Expand All @@ -94,18 +98,19 @@ pub(crate) fn run_request_intercepts_with_codec(
state.llm_request_intercept_entries(&scope_locals)
};

let (intercepted_request, intercepted_annotated) =
let outcome =
crate::api::runtime::NemoRelayContextState::llm_request_intercepts_snapshot_chain(
name, request, annotated, &entries,
)?;
let pending_marks = outcome.pending_marks;

match (codec, intercepted_annotated) {
match (codec, outcome.annotated_request) {
(Some(codec), Some(annotated)) => {
let mut encoded = codec.encode(&annotated, &original)?;
encoded.headers = intercepted_request.headers;
Ok((encoded, Some(Arc::new(annotated))))
encoded.headers = outcome.request.headers;
Ok((encoded, Some(Arc::new(annotated)), pending_marks))
}
_ => Ok((intercepted_request, None)),
(_, annotated) => Ok((outcome.request, annotated.map(Arc::new), pending_marks)),
}
}

Expand Down
4 changes: 2 additions & 2 deletions crates/core/src/context/registries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::collections::HashMap;

use crate::api::registry::{ExecutionIntercept, Guardrail, Intercept};
use crate::api::runtime::{
EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmRequestInterceptFn,
EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmRequestInterceptWithMarksFn,
LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, ToolConditionalFn,
ToolExecutionFn, ToolInterceptFn, ToolSanitizeFn,
};
Expand Down Expand Up @@ -41,7 +41,7 @@ pub(crate) struct ScopeLocalRegistries {
/// LLM guardrails that can reject execution before the provider callback runs.
pub(crate) llm_conditional_execution_guardrails: SortedRegistry<Guardrail<LlmConditionalFn>>,
/// LLM request intercepts that can rewrite or annotate requests.
pub(crate) llm_request_intercepts: SortedRegistry<Intercept<LlmRequestInterceptFn>>,
pub(crate) llm_request_intercepts: SortedRegistry<Intercept<LlmRequestInterceptWithMarksFn>>,
/// Non-streaming LLM execution intercepts that wrap callback execution.
pub(crate) llm_execution_intercepts: SortedRegistry<ExecutionIntercept<LlmExecutionFn>>,
/// Streaming LLM execution intercepts that wrap stream-producing callbacks.
Expand Down
Loading
Loading