
Commit 3b3c891

CI token usage

1 parent 8f0b229 commit 3b3c891

2 files changed: 53 additions & 1 deletion

File tree:
  splunklib/ai/engines/langchain.py
  tests/ai_testlib.py

splunklib/ai/engines/langchain.py (44 additions & 1 deletion)
```diff
@@ -120,6 +120,8 @@
 LC_AgentMiddleware = Langchain_AgentMiddleware[Any, "InvokeContext", Any]
 LC_ModelRequest = Langchain_ModelRequest["InvokeContext"]
 
+total_token_usage: int = 0
+
 # Set to True to enable debugging mode.
 _DEBUG = False
 
@@ -291,7 +293,6 @@ async def awrap_model_call(
         request: LC_ModelRequest,
         handler: Callable[[LC_ModelRequest], Awaitable[LC_ModelCallResult]],
     ) -> LC_ModelCallResult:
-
         agent_thread_ids: dict[str, set[str]] = {}
 
         # Update the subagent schema definitions to include all thread_ids that the
@@ -498,6 +499,9 @@ async def awrap_model_call(
         print("LLM CALL", request)
         try:
             resp = await handler(request)
+        except LC_StructuredOutputError as e:
+            print("LLM FAILURE", e, e.ai_message)
+            raise
         except Exception as e:
             print("LLM FAILURE", e)
             raise
@@ -528,6 +532,45 @@ async def awrap_tool_call(
     if _DEBUG:
         lc_middleware.append(_DEBUGMiddleware())
 
+    class _TOKENUsage(LC_AgentMiddleware):
+        @override
+        async def awrap_model_call(
+            self,
+            request: LC_ModelRequest,
+            handler: Callable[[LC_ModelRequest], Awaitable[LC_ModelCallResult]],
+        ) -> LC_ModelCallResult:
+            global total_token_usage
+
+            def _extract_tokens(resp: LC_ModelCallResult) -> int:
+                ai_message = resp
+                if isinstance(ai_message, LC_ExtendedModelResponse):
+                    ai_message = ai_message.model_response
+                if isinstance(ai_message, LC_ModelResponse):
+                    ai_message = next(
+                        (
+                            m
+                            for m in ai_message.result
+                            if isinstance(m, LC_AIMessage)
+                        ),
+                        None,
+                    )
+                if ai_message is not None and ai_message.usage_metadata:
+                    return ai_message.usage_metadata.get("total_tokens", 0)
+                return 0
+
+            try:
+                resp = await handler(request)
+                total_token_usage += _extract_tokens(resp)
+                return resp
+            except LC_StructuredOutputError as e:
+                if e.ai_message.usage_metadata:
+                    total_token_usage += e.ai_message.usage_metadata.get(
+                        "total_tokens", 0
+                    )
+                raise
+
+    lc_middleware.append(_TOKENUsage())
+
     response_format = None
     if agent.output_schema is not None:
         if _supports_provider_strategy(model_impl):
```
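For reference, the unwrapping that `_extract_tokens` performs can be exercised in isolation. The stub dataclasses below are assumptions reconstructed only from the attribute accesses visible in the diff (`.model_response`, `.result`, `.usage_metadata`); they are not the real LangChain types aliased as `LC_ExtendedModelResponse`, `LC_ModelResponse`, and `LC_AIMessage`.

```python
# Standalone sketch of the _extract_tokens unwrapping above. The stub
# classes are assumptions modeled on the attribute accesses in the diff;
# they are NOT the real LangChain types.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class AIMessage:  # stands in for LC_AIMessage
    usage_metadata: dict[str, int] | None = None


@dataclass
class ModelResponse:  # stands in for LC_ModelResponse
    result: list[Any] = field(default_factory=list)


@dataclass
class ExtendedModelResponse:  # stands in for LC_ExtendedModelResponse
    model_response: ModelResponse | None = None


def extract_tokens(resp: Any) -> int:
    # Unwrap an extended response down to the inner model response.
    if isinstance(resp, ExtendedModelResponse):
        resp = resp.model_response
    # Pull the first AIMessage out of the response's result list.
    if isinstance(resp, ModelResponse):
        resp = next((m for m in resp.result if isinstance(m, AIMessage)), None)
    # Read total_tokens from its usage metadata, defaulting to 0.
    if isinstance(resp, AIMessage) and resp.usage_metadata:
        return resp.usage_metadata.get("total_tokens", 0)
    return 0


if __name__ == "__main__":
    ok = ExtendedModelResponse(ModelResponse([AIMessage({"total_tokens": 42})]))
    assert extract_tokens(ok) == 42
    assert extract_tokens(ModelResponse([AIMessage(None)])) == 0
    print("sketch checks pass")
```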

tests/ai_testlib.py (9 additions & 0 deletions)

```diff
@@ -1,15 +1,18 @@
 from typing import override
+import splunklib.ai.engines.langchain as langchain_engine
 from splunklib.ai.model import PredefinedModel
 from tests.ai_test_model import InternalAIModel, TestLLMSettings, create_model
 from tests.testlib import SDKTestCase
 
 
 class AITestCase(SDKTestCase):
     _model: PredefinedModel | None = None
+    _token_usage_before: int = 0
 
     @override
     def setUp(self) -> None:
         super().setUp()
+        self._token_usage_before = langchain_engine.total_token_usage
 
         # Our tests don't expect this app to be installed, if needed it is
         # installed on demand.
@@ -18,6 +21,12 @@ def setUp(self) -> None:
             app.delete()
         self.restart_splunk()
 
+    @override
+    def tearDown(self) -> None:
+        tokens_used = langchain_engine.total_token_usage - self._token_usage_before
+        print(f"\n[token usage] {self.id()}: {tokens_used} tokens")
+        super().tearDown()
+
     @property
     def test_llm_settings(self) -> TestLLMSettings:
         client_id: str = self.opts.kwargs["internal_ai_client_id"]
```
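With this change, `setUp` snapshots the module-level counter and `tearDown` prints the per-test delta, so any test inheriting from `AITestCase` gets a token report without extra code. A minimal sketch, assuming a hypothetical subclass (the names below are illustrative, not part of this commit):

```python
# Hypothetical usage sketch; ExampleAgentTest and test_simple_prompt are
# illustrative names, not part of this commit.
from tests.ai_testlib import AITestCase


class ExampleAgentTest(AITestCase):
    def test_simple_prompt(self) -> None:
        # ... drive an agent through the langchain engine here ...
        # Each model call bumps langchain_engine.total_token_usage via the
        # _TOKENUsage middleware, and tearDown then prints a line like:
        #   [token usage] <test id>: <tokens used> tokens
        ...
```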
