|
120 | 120 | LC_AgentMiddleware = Langchain_AgentMiddleware[Any, "InvokeContext", Any] |
121 | 121 | LC_ModelRequest = Langchain_ModelRequest["InvokeContext"] |
122 | 122 |
|
| 123 | +total_token_usage: int = 0 |
| 124 | + |
123 | 125 | # Set to True to enable debugging mode. |
124 | 126 | _DEBUG = False |
125 | 127 |
|
@@ -291,7 +293,6 @@ async def awrap_model_call( |
291 | 293 | request: LC_ModelRequest, |
292 | 294 | handler: Callable[[LC_ModelRequest], Awaitable[LC_ModelCallResult]], |
293 | 295 | ) -> LC_ModelCallResult: |
294 | | - |
295 | 296 | agent_thread_ids: dict[str, set[str]] = {} |
296 | 297 |
|
297 | 298 | # Update the subagent schema definitions to include all thread_ids that the |
@@ -498,6 +499,9 @@ async def awrap_model_call( |
498 | 499 | print("LLM CALL", request) |
499 | 500 | try: |
500 | 501 | resp = await handler(request) |
| 502 | + except LC_StructuredOutputError as e: |
| 503 | + print("LLM FAILURE", e, e.ai_message) |
| 504 | + raise |
501 | 505 | except Exception as e: |
502 | 506 | print("LLM FAILURE", e) |
503 | 507 | raise |
@@ -528,6 +532,45 @@ async def awrap_tool_call( |
528 | 532 | if _DEBUG: |
529 | 533 | lc_middleware.append(_DEBUGMiddleware()) |
530 | 534 |
|
| 535 | + class _TOKENUsage(LC_AgentMiddleware): |
| 536 | + @override |
| 537 | + async def awrap_model_call( |
| 538 | + self, |
| 539 | + request: LC_ModelRequest, |
| 540 | + handler: Callable[[LC_ModelRequest], Awaitable[LC_ModelCallResult]], |
| 541 | + ) -> LC_ModelCallResult: |
| 542 | + global total_token_usage |
| 543 | + |
| 544 | + def _extract_tokens(resp: LC_ModelCallResult) -> int: |
| 545 | + ai_message = resp |
| 546 | + if isinstance(ai_message, LC_ExtendedModelResponse): |
| 547 | + ai_message = ai_message.model_response |
| 548 | + if isinstance(ai_message, LC_ModelResponse): |
| 549 | + ai_message = next( |
| 550 | + ( |
| 551 | + m |
| 552 | + for m in ai_message.result |
| 553 | + if isinstance(m, LC_AIMessage) |
| 554 | + ), |
| 555 | + None, |
| 556 | + ) |
| 557 | + if ai_message is not None and ai_message.usage_metadata: |
| 558 | + return ai_message.usage_metadata.get("total_tokens", 0) |
| 559 | + return 0 |
| 560 | + |
| 561 | + try: |
| 562 | + resp = await handler(request) |
| 563 | + total_token_usage += _extract_tokens(resp) |
| 564 | + return resp |
| 565 | + except LC_StructuredOutputError as e: |
| 566 | + if e.ai_message.usage_metadata: |
| 567 | + total_token_usage += e.ai_message.usage_metadata.get( |
| 568 | + "total_tokens", 0 |
| 569 | + ) |
| 570 | + raise |
| 571 | + |
| 572 | + lc_middleware.append(_TOKENUsage()) |
| 573 | + |
531 | 574 | response_format = None |
532 | 575 | if agent.output_schema is not None: |
533 | 576 | if _supports_provider_strategy(model_impl): |
|
0 commit comments