From 7926cc42f6e4aa4f185ad1b8f3c8d11d7ad80367 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 13 May 2026 19:09:54 -0700 Subject: [PATCH 01/34] Refactor printers: extract formatting into lightweight base classes Create new pyrit/printer/ module with abstract base classes that contain all formatting logic. Data-fetching operations (CentralMemory calls) are abstract methods implemented by framework subclasses. This enables thin clients to reuse all pretty-printing by subclassing the base printers and implementing data-fetching via REST endpoints. The thin client only needs pyrit.models + pyrit.identifiers + colorama. Changes: - New pyrit/printer/ module with attack_result, scenario_result, scorer subpackages - ConsoleAttackPrinterBase: all attack console formatting, abstract get_conversation/get_scores - ConsoleScenarioPrinterBase: all scenario console formatting - ConsoleScorerPrinterBase: all scorer formatting, abstract get_objective/harm_metrics - Existing framework printers refactored to thin subclasses (backward compatible) - Added to_dict()/from_dict() to AttackResult, ScenarioResult, ScenarioIdentifier, ConversationReference, Score, MessagePiece, Message for serialization round-tripping - Message.to_full_dict() added for rich serialization (to_dict() unchanged for compat) All 675 existing tests pass with no modifications. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../attack/printer/console_printer.py | 595 +----------------- pyrit/models/attack_result.py | 87 +++ pyrit/models/conversation_reference.py | 30 + pyrit/models/message.py | 38 ++ pyrit/models/message_piece.py | 61 ++ pyrit/models/scenario_result.py | 111 ++++ pyrit/models/score.py | 27 + pyrit/printer/__init__.py | 15 + pyrit/printer/attack_result/__init__.py | 4 + pyrit/printer/attack_result/base.py | 91 +++ pyrit/printer/attack_result/console.py | 484 ++++++++++++++ pyrit/printer/scenario_result/__init__.py | 4 + pyrit/printer/scenario_result/base.py | 24 + pyrit/printer/scenario_result/console.py | 178 ++++++ pyrit/printer/scorer/__init__.py | 4 + pyrit/printer/scorer/base.py | 61 ++ pyrit/printer/scorer/console.py | 258 ++++++++ pyrit/scenario/printer/console_printer.py | 200 +----- pyrit/score/printer/console_scorer_printer.py | 279 +------- 19 files changed, 1540 insertions(+), 1011 deletions(-) create mode 100644 pyrit/printer/__init__.py create mode 100644 pyrit/printer/attack_result/__init__.py create mode 100644 pyrit/printer/attack_result/base.py create mode 100644 pyrit/printer/attack_result/console.py create mode 100644 pyrit/printer/scenario_result/__init__.py create mode 100644 pyrit/printer/scenario_result/base.py create mode 100644 pyrit/printer/scenario_result/console.py create mode 100644 pyrit/printer/scorer/__init__.py create mode 100644 pyrit/printer/scorer/base.py create mode 100644 pyrit/printer/scorer/console.py diff --git a/pyrit/executor/attack/printer/console_printer.py b/pyrit/executor/attack/printer/console_printer.py index 8c4cb9190d..1e17896e88 100644 --- a/pyrit/executor/attack/printer/console_printer.py +++ b/pyrit/executor/attack/printer/console_printer.py @@ -1,26 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-import json -import textwrap -from datetime import datetime, timezone -from typing import Any - -from colorama import Back, Fore, Style - from pyrit.common.display_response import display_image_response -from pyrit.executor.attack.printer.attack_result_printer import AttackResultPrinter from pyrit.memory import CentralMemory -from pyrit.models import AttackOutcome, AttackResult, ConversationType, Score +from pyrit.models import Message, Score +from pyrit.printer.attack_result.console import ConsoleAttackPrinterBase -class ConsoleAttackResultPrinter(AttackResultPrinter): +class ConsoleAttackResultPrinter(ConsoleAttackPrinterBase): """ - Console printer for attack results with enhanced formatting. + Framework console printer for attack results. - This printer formats attack results for console display with optional color coding, - proper indentation, text wrapping, and visual separators. Colors can be disabled - for consoles that don't support ANSI characters. + Thin subclass that implements data-fetching via CentralMemory. + All formatting logic lives in ConsoleAttackPrinterBase. """ def __init__(self, *, width: int = 100, indent_size: int = 2, enable_colors: bool = True) -> None: @@ -28,579 +20,42 @@ def __init__(self, *, width: int = 100, indent_size: int = 2, enable_colors: boo Initialize the console printer. Args: - width (int): Maximum width for text wrapping. Must be positive. - Defaults to 100. - indent_size (int): Number of spaces for indentation. Must be non-negative. - Defaults to 2. - enable_colors (bool): Whether to enable ANSI color output. When False, - all output will be plain text without colors. Defaults to True. - - Raises: - ValueError: If width <= 0 or indent_size < 0. + width (int): Maximum width for text wrapping. Defaults to 100. + indent_size (int): Number of spaces for indentation. Defaults to 2. + enable_colors (bool): Whether to enable ANSI color output. Defaults to True. """ + super().__init__(width=width, indent_size=indent_size, enable_colors=enable_colors) self._memory = CentralMemory.get_memory_instance() - self._width = width - self._indent = " " * indent_size - self._enable_colors = enable_colors - - def _print_colored(self, text: str, *colors: str) -> None: - """ - Print text with color formatting if colors are enabled. - - Args: - text (str): The text to print. - *colors: Variable number of colorama color constants to apply. - """ - if self._enable_colors and colors: - color_prefix = "".join(colors) - print(f"{color_prefix}{text}{Style.RESET_ALL}") - else: - print(text) - - async def print_result_async( - self, - result: AttackResult, - *, - include_auxiliary_scores: bool = False, - include_pruned_conversations: bool = False, - include_adversarial_conversation: bool = False, - ) -> None: - """ - Print the complete attack result to console. - - This method orchestrates the printing of all components of an attack result, - including header, summary, conversation history, metadata, and footer. - - Args: - result (AttackResult): The attack result to print. Must not be None. - include_auxiliary_scores (bool): Whether to include auxiliary scores in the output. - Defaults to False. - include_pruned_conversations (bool): Whether to include pruned conversations. - For each pruned conversation, only the last message and its score are shown. - Defaults to False. - include_adversarial_conversation (bool): Whether to include the adversarial - conversation (the red teaming LLM's reasoning). Only shown for successful - attacks to avoid overwhelming output. 
Defaults to False. - """ - # Print header with outcome - self._print_header(result) - - # Print summary information - await self.print_summary_async(result) - - # Print conversation - self._print_section_header("Conversation History with Objective Target") - await self.print_conversation_async(result, include_scores=include_auxiliary_scores) - # Print pruned conversations if requested - if include_pruned_conversations: - await self._print_pruned_conversations_async(result) - - # Print adversarial conversation if requested (only for successful attacks) - if include_adversarial_conversation: - await self._print_adversarial_conversation_async(result) - - # Print metadata if available - if result.metadata: - self._print_metadata(result.metadata) - - # Print footer - self._print_footer() - - async def print_conversation_async( - self, result: AttackResult, *, include_scores: bool = False, include_reasoning_trace: bool = False - ) -> None: - """ - Print the conversation history to console with enhanced formatting. - - Displays the full conversation between user and assistant, including: - - Turn numbers - - Role indicators (USER/ASSISTANT) - - Original and converted values when different - - Images if present - - Scores for each response - - Args: - result (AttackResult): The attack result containing the conversation_id. - Must have a valid conversation_id attribute. - include_scores (bool): Whether to include scores in the output. - Defaults to False. - include_reasoning_trace (bool): Whether to include model reasoning trace in the output - for applicable models. Defaults to False. + async def get_conversation_async(self, conversation_id: str) -> list[Message]: """ - if not result.conversation_id: - self._print_colored(f"{self._indent} No conversation ID available", Fore.YELLOW) - return - - messages = list(self._memory.get_conversation(conversation_id=result.conversation_id)) - - if not messages: - self._print_colored(f"{self._indent} No conversation found for ID: {result.conversation_id}", Fore.YELLOW) - return - - await self.print_messages_async( - messages=messages, - include_scores=include_scores, - include_reasoning_trace=include_reasoning_trace, - ) - - async def print_messages_async( - self, - messages: list[Any], - *, - include_scores: bool = False, - include_reasoning_trace: bool = False, - ) -> None: - """ - Print a list of messages to console with enhanced formatting. - - This method can be called directly with a list of Message objects, - without needing an AttackResult. Useful for printing prepended_conversation - or any other list of messages. - - Displays: - - Turn numbers - - Role indicators (USER/ASSISTANT/SYSTEM) - - Original and converted values when different - - Images if present - - Scores for each response (if include_scores=True) + Fetch conversation messages from CentralMemory. Args: - messages (list): List of Message objects to print. - include_scores (bool): Whether to include scores in the output. - Defaults to False. - include_reasoning_trace (bool): Whether to include model reasoning trace in the output - for applicable models. Defaults to False. 
- """ - if not messages: - self._print_colored(f"{self._indent} No messages to display.", Fore.YELLOW) - return - - turn_number = 0 - for message in messages: - # Increment turn number once per message with role="user" - if message.api_role == "user": - turn_number += 1 - # User message header - print() - self._print_colored("─" * self._width, Fore.BLUE) - self._print_colored(f"🔹 Turn {turn_number} - USER", Style.BRIGHT, Fore.BLUE) - self._print_colored("─" * self._width, Fore.BLUE) - elif message.api_role == "system": - # System message header (not counted as a turn) - print() - self._print_colored("─" * self._width, Fore.MAGENTA) - self._print_colored("🔧 SYSTEM", Style.BRIGHT, Fore.MAGENTA) - self._print_colored("─" * self._width, Fore.MAGENTA) - else: - # Assistant or other role message header - print() - self._print_colored("─" * self._width, Fore.YELLOW) - role_label = "ASSISTANT (SIMULATED)" if message.is_simulated else message.api_role.upper() - self._print_colored(f"🔸 {role_label}", Style.BRIGHT, Fore.YELLOW) - self._print_colored("─" * self._width, Fore.YELLOW) - - # Now print all pieces in this message - for piece in message.message_pieces: - # Reasoning pieces: show summary when include_reasoning_trace is set - if piece.original_value_data_type == "reasoning": - if include_reasoning_trace: - summary_text = self._extract_reasoning_summary(piece.original_value) - if summary_text: - self._print_colored(f"{self._indent}💭 Reasoning Summary:", Style.DIM, Fore.CYAN) - self._print_wrapped_text(summary_text, Fore.CYAN) - print() - continue - - # Blocked/filtered pieces: show clear indicator and partial content if available - if piece.is_blocked(): - self._print_colored(f"{self._indent}🚫 BLOCKED BY TARGET", Style.BRIGHT, Fore.RED) - partial_content = piece.prompt_metadata.get("partial_content") - if partial_content: - self._print_colored( - f"{self._indent}📝 Partial content (before filter triggered):", - Style.DIM, - Fore.CYAN, - ) - self._print_wrapped_text(str(partial_content), Fore.YELLOW) - else: - self._print_colored( - f"{self._indent}Content was blocked by the target's content filter.", - Style.DIM, - Fore.RED, - ) - - # Handle converted values for user and assistant messages - elif piece.converted_value != piece.original_value: - self._print_colored(f"{self._indent} Original:", Fore.CYAN) - self._print_wrapped_text(piece.original_value, Fore.WHITE) - print() - self._print_colored(f"{self._indent} Converted:", Fore.CYAN) - self._print_wrapped_text(piece.converted_value, Fore.WHITE) - elif piece.api_role == "user": - self._print_wrapped_text(piece.converted_value, Fore.BLUE) - elif piece.api_role == "system": - self._print_wrapped_text(piece.converted_value, Fore.MAGENTA) - else: - self._print_wrapped_text(piece.converted_value, Fore.YELLOW) - - # Display images if present - await display_image_response(piece) - - # Print scores with better formatting (only if scores are requested) - if include_scores: - scores = self._memory.get_prompt_scores(prompt_ids=[str(piece.id)]) - if scores: - print() - self._print_colored(f"{self._indent}📊 Scores:", Style.DIM, Fore.MAGENTA) - for score in scores: - self._print_score(score) - - print() - self._print_colored("─" * self._width, Fore.BLUE) - - def _extract_reasoning_summary(self, reasoning_value: str) -> str: - """ - Extract human-readable summary text from a reasoning piece's JSON value. - - Args: - reasoning_value (str): The JSON string stored in the reasoning piece. + conversation_id (str): The conversation ID to fetch. 
Returns: - str: The concatenated summary text, or empty string if no summary is present. + list[Message]: The conversation messages. """ - try: - data = json.loads(reasoning_value) - except (json.JSONDecodeError, TypeError): - return "" - - summary = data.get("summary") if isinstance(data, dict) else None - if not summary or not isinstance(summary, list): - return "" + return list(self._memory.get_conversation(conversation_id=conversation_id)) - parts = [item.get("text", "") for item in summary if isinstance(item, dict) and item.get("text")] - return "\n".join(parts) - - async def print_summary_async(self, result: AttackResult) -> None: + async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: """ - Print a summary of the attack result with enhanced formatting. - - Displays: - - Basic information (objective, attack type, conversation ID) - - Execution metrics (turns executed, execution time) - - Outcome information (status, reason) - - Final score if available + Fetch scores from CentralMemory. Args: - result (AttackResult): The attack result to summarize. Must contain - objective, attack_identifier, conversation_id, executed_turns, - execution_time_ms, outcome, and optionally outcome_reason and - last_score attributes. - """ - self._print_section_header("Attack Summary") - - # Basic information - self._print_colored(f"{self._indent}📋 Basic Information", Style.BRIGHT) - self._print_colored(f"{self._indent * 2}• Objective: {result.objective}", Fore.CYAN) + prompt_ids (list[str]): The message piece IDs to fetch scores for. - # Extract attack type name from atomic_attack_identifier - attack_type = "Unknown" - attack_strategy_id = result.get_attack_strategy_identifier() - if attack_strategy_id: - attack_type = attack_strategy_id.class_name - - self._print_colored(f"{self._indent * 2}• Attack Type: {attack_type}", Fore.CYAN) - self._print_colored(f"{self._indent * 2}• Conversation ID: {result.conversation_id}", Fore.CYAN) - - # Execution metrics - print() - self._print_colored(f"{self._indent}⚡ Execution Metrics", Style.BRIGHT) - self._print_colored(f"{self._indent * 2}• Turns Executed: {result.executed_turns}", Fore.GREEN) - self._print_colored( - f"{self._indent * 2}• Execution Time: {self._format_time(result.execution_time_ms)}", Fore.GREEN - ) - - # Outcome information - print() - self._print_colored(f"{self._indent}🎯 Outcome", Style.BRIGHT) - outcome_icon = self._get_outcome_icon(result.outcome) - outcome_color = self._get_outcome_color(result.outcome) - self._print_colored(f"{self._indent * 2}• Status: {outcome_icon} {result.outcome.value.upper()}", outcome_color) - - if result.outcome_reason: - self._print_colored(f"{self._indent * 2}• Reason: {result.outcome_reason}", Fore.WHITE) - - # Final score - if result.last_score: - print() - self._print_colored(f"{self._indent} Final Score", Style.BRIGHT) - self._print_score(result.last_score, indent_level=2) - - def _print_header(self, result: AttackResult) -> None: - """ - Print the header with outcome-based coloring and styling. - - Creates a visually prominent header that displays the attack outcome - with appropriate color coding and icons. - - Args: - result (AttackResult): The attack result containing the outcome. - Must have an outcome attribute of type AttackOutcome. 
- """ - color = self._get_outcome_color(result.outcome) - icon = self._get_outcome_icon(result.outcome) - - print() - self._print_colored("═" * self._width, color) - - # Center the header text - header_text = f"{icon} ATTACK RESULT: {result.outcome.value.upper()} {icon}" - self._print_colored(header_text.center(self._width), Style.BRIGHT, color) - self._print_colored("═" * self._width, color) - - def _print_footer(self) -> None: - """ - Print a footer with timestamp. - - Displays the current timestamp when the report was generated. - """ - timestamp = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S") - print() - self._print_colored("─" * self._width, Style.DIM, Fore.WHITE) - footer_text = f"Report generated at: {timestamp} UTC" - self._print_colored(footer_text.center(self._width), Style.DIM, Fore.WHITE) - - def _print_section_header(self, title: str) -> None: - """ - Print a section header with consistent styling. - - Creates a visually distinct section header with background color - and separator line. - - Args: - title (str): The title text to display in the section header. - """ - print() - self._print_colored(f" {title} ", Style.BRIGHT, Back.BLUE, Fore.WHITE) - self._print_colored("─" * self._width, Fore.BLUE) - - def _print_metadata(self, metadata: dict[str, Any]) -> None: - """ - Print metadata in a formatted way. - - Displays key-value pairs from the metadata dictionary in a - consistent bullet-point format. - - Args: - metadata (dict[str, Any]): Dictionary containing metadata key-value pairs. - Keys and values should be convertible to strings. - """ - self._print_section_header("Additional Metadata") - for key, value in metadata.items(): - self._print_colored(f"{self._indent}• {key}: {value}", Fore.CYAN) - - def _print_score(self, score: Score, indent_level: int = 3) -> None: - """ - Print a score with proper formatting. - - Displays score information including type, value, and rationale - with appropriate color coding based on score type. - - Args: - score (Score): Score object to be printed. - indent_level (int): Number of indent units to apply. Defaults to 3. 
- """ - indent = self._indent * indent_level - scorer_name = score.scorer_class_identifier.class_name - print(f"{indent}Scorer: {scorer_name}") - self._print_colored(f"{indent}• Category: {score.score_category or 'N/A'}", Fore.LIGHTMAGENTA_EX) - self._print_colored(f"{indent}• Type: {score.score_type}", Fore.CYAN) - - # Determine color based on score type and value - if score.score_type == "true_false": - score_color = Fore.GREEN if score.get_value() else Fore.RED - else: - score_color = Fore.YELLOW - - self._print_colored(f"{indent}• Value: {score.score_value}", score_color) - - if score.score_rationale: - print(f"{indent}• Rationale:") - # Create a custom wrapper for rationale with proper indentation - rationale_wrapper = textwrap.TextWrapper( - width=self._width - len(indent) - 2, # Adjust width to account for indentation - initial_indent=indent + " ", - subsequent_indent=indent + " ", - break_long_words=False, - break_on_hyphens=False, - ) - # Split by newlines first to preserve them - lines = score.score_rationale.split("\n") - for line in lines: - if line.strip(): # Only wrap non-empty lines - wrapped_lines = rationale_wrapper.wrap(line) - for wrapped_line in wrapped_lines: - self._print_colored(wrapped_line, Fore.WHITE) - else: # Print empty lines as-is to preserve formatting - self._print_colored(f"{indent} ") - - def _print_wrapped_text(self, text: str, color: str) -> None: - """ - Print text with proper wrapping and indentation, preserving newlines. - - Wraps long lines while preserving the original line breaks and - applying consistent indentation and coloring. - - Args: - text (str): The text to print. Can contain newlines. - color (str): Colorama color constant to apply to the text - (e.g., Fore.BLUE, Fore.RED). - """ - # Create a new wrapper for each text to ensure proper width calculation - text_wrapper = textwrap.TextWrapper( - width=self._width - len(self._indent), # Adjust width to account for indentation - initial_indent="", - subsequent_indent=self._indent, - break_long_words=True, # Allow breaking long words to prevent truncation - break_on_hyphens=True, - expand_tabs=False, - replace_whitespace=False, # Preserve whitespace formatting - ) - - # Split by newlines first to preserve them - lines = text.split("\n") - for line_num, line in enumerate(lines): - if line.strip(): # Only wrap non-empty lines - wrapped_lines = text_wrapper.wrap(line) - for i, wrapped_line in enumerate(wrapped_lines): - if line_num == 0 and i == 0: - self._print_colored(f"{self._indent}{wrapped_line}", color) - else: - self._print_colored(f"{self._indent * 2}{wrapped_line}", color) - else: # Print empty lines as-is to preserve formatting - self._print_colored(f"{self._indent}", color) - - async def _print_pruned_conversations_async(self, result: AttackResult) -> None: - """ - Print pruned conversations showing only the last message and score for each. - - Pruned conversations represent branches that were abandoned during the attack. - For each pruned conversation, only the final message and its associated score - are displayed to provide context without overwhelming output. - - Args: - result (AttackResult): The attack result containing related conversations. 
- """ - pruned_refs = result.get_conversations_by_type(ConversationType.PRUNED) - - if not pruned_refs: - return - - self._print_section_header(f"Pruned Conversations ({len(pruned_refs)} total)") - - for idx, ref in enumerate(pruned_refs, 1): - # Print conversation header with description if available - print() - self._print_colored("─" * self._width, Fore.RED) - label = f"🗑️ PRUNED #{idx}" - if ref.description: - label += f" - {ref.description}" - self._print_colored(label, Style.BRIGHT, Fore.RED) - self._print_colored("─" * self._width, Fore.RED) - - # Get the conversation messages - messages = list(self._memory.get_conversation(conversation_id=ref.conversation_id)) - - if not messages: - self._print_colored( - f"{self._indent}No messages found for conversation: {ref.conversation_id}", Fore.YELLOW - ) - continue - - # Get only the last message - last_message = messages[-1] - - # Print the last message - role_label = last_message.api_role.upper() - self._print_colored(f"{self._indent}Last Message ({role_label}):", Style.BRIGHT, Fore.WHITE) - - for piece in last_message.message_pieces: - self._print_wrapped_text(piece.converted_value, Fore.WHITE) - - # Print associated scores - scores = self._memory.get_prompt_scores(prompt_ids=[str(piece.id)]) - if scores: - print() - self._print_colored(f"{self._indent}📊 Score:", Style.DIM, Fore.MAGENTA) - for score in scores: - self._print_score(score) - - print() - self._print_colored("─" * self._width, Fore.RED) - - async def _print_adversarial_conversation_async(self, result: AttackResult) -> None: + Returns: + list[Score]: The scores. """ - Print the adversarial conversation for the best-scoring attack branch. - - The adversarial conversation shows the red teaming LLM's reasoning and - strategy development. For attacks with multiple adversarial conversations - (e.g., TAP), only the best-scoring branch's adversarial conversation is - shown if available. + return self._memory.get_prompt_scores(prompt_ids=prompt_ids) - Args: - result (AttackResult): The attack result containing related conversations. + async def display_image_async(self, piece: object) -> None: """ - adversarial_refs = result.get_conversations_by_type(ConversationType.ADVERSARIAL) - - if not adversarial_refs: - return - - self._print_section_header("Adversarial Conversation (Red Team LLM)") - - # Check if result has a best_adversarial_conversation_id (e.g., TAP attack) - # If so, only show that conversation instead of all adversarial conversations - best_adversarial_id = result.metadata.get("best_adversarial_conversation_id") - if best_adversarial_id: - # Filter to only the best adversarial conversation - adversarial_refs = [ref for ref in adversarial_refs if ref.conversation_id == best_adversarial_id] - if adversarial_refs: - self._print_colored( - f"{self._indent}📌 Showing best-scoring branch's adversarial conversation", - Style.DIM, - Fore.CYAN, - ) - - for ref in adversarial_refs: - if ref.description: - self._print_colored(f"{self._indent}📝 {ref.description}", Style.DIM, Fore.CYAN) - - messages = list(self._memory.get_conversation(conversation_id=ref.conversation_id)) - - if not messages: - self._print_colored( - f"{self._indent}No messages found for conversation: {ref.conversation_id}", Fore.YELLOW - ) - continue - - await self.print_messages_async(messages=messages, include_scores=False) - - def _get_outcome_color(self, outcome: AttackOutcome) -> str: - """ - Get the color for an outcome. - - Maps AttackOutcome enum values to appropriate Colorama color constants. 
+ Display images using PIL/IPython in notebook environments. Args: - outcome (AttackOutcome): The attack outcome enum value. - - Returns: - str: Colorama color constant (Fore.GREEN, Fore.RED, Fore.YELLOW, - or Fore.WHITE for unknown outcomes). + piece: The message piece that may contain image data. """ - return str( - { - AttackOutcome.SUCCESS: Fore.GREEN, - AttackOutcome.FAILURE: Fore.RED, - AttackOutcome.UNDETERMINED: Fore.YELLOW, - }.get(outcome, Fore.WHITE) - ) + await display_image_response(piece) diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index 123c83a918..ef58978f34 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -224,6 +224,93 @@ def __str__(self) -> str: """ return f"AttackResult: {self.conversation_id}: {self.outcome.value}: {self.objective[:50]}..." + def to_dict(self) -> dict[str, Any]: + """ + Serialize this attack result to a JSON-compatible dictionary. + + Returns: + dict[str, Any]: Serialized payload suitable for REST APIs or persistence. + """ + from pyrit.models.conversation_reference import ConversationReference + + return { + "conversation_id": self.conversation_id, + "objective": self.objective, + "attack_result_id": self.attack_result_id, + "atomic_attack_identifier": ( + self.atomic_attack_identifier.to_dict() if self.atomic_attack_identifier else None + ), + "last_response": self.last_response.to_dict() if self.last_response else None, + "last_score": self.last_score.to_dict() if self.last_score else None, + "executed_turns": self.executed_turns, + "execution_time_ms": self.execution_time_ms, + "outcome": self.outcome.value, + "outcome_reason": self.outcome_reason, + "timestamp": self.timestamp.isoformat() if self.timestamp else None, + "related_conversations": [ + ref.to_dict() if isinstance(ref, ConversationReference) else ref + for ref in self.related_conversations + ], + "metadata": self.metadata, + "labels": self.labels, + "error_message": self.error_message, + "error_type": self.error_type, + "error_traceback": self.error_traceback, + "retry_events": [e.to_dict() for e in self.retry_events], + "total_retries": self.total_retries, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> AttackResult: + """ + Reconstruct an AttackResult from a dictionary. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + AttackResult: Reconstructed instance. 
+ """ + from pyrit.identifiers.component_identifier import ComponentIdentifier + from pyrit.models.conversation_reference import ConversationReference + from pyrit.models.message_piece import MessagePiece + from pyrit.models.retry_event import RetryEvent + from pyrit.models.score import Score + + return cls( + conversation_id=data["conversation_id"], + objective=data["objective"], + attack_result_id=data.get("attack_result_id", str(uuid.uuid4())), + atomic_attack_identifier=( + ComponentIdentifier.from_dict(data["atomic_attack_identifier"]) + if data.get("atomic_attack_identifier") + else None + ), + last_response=( + MessagePiece.from_dict(data["last_response"]) if data.get("last_response") else None + ), + last_score=Score.from_dict(data["last_score"]) if data.get("last_score") else None, + executed_turns=data.get("executed_turns", 0), + execution_time_ms=data.get("execution_time_ms", 0), + outcome=AttackOutcome(data.get("outcome", "undetermined")), + outcome_reason=data.get("outcome_reason"), + timestamp=( + datetime.fromisoformat(data["timestamp"]) + if data.get("timestamp") + else datetime.now(timezone.utc) + ), + related_conversations={ + ConversationReference.from_dict(r) for r in data.get("related_conversations", []) + }, + metadata=data.get("metadata", {}), + labels=data.get("labels", {}), + error_message=data.get("error_message"), + error_type=data.get("error_type"), + error_traceback=data.get("error_traceback"), + retry_events=[RetryEvent.from_dict(e) for e in data.get("retry_events", [])], + total_retries=data.get("total_retries", 0), + ) + def _add_attack_identifier_compat(cls: type) -> type: """ diff --git a/pyrit/models/conversation_reference.py b/pyrit/models/conversation_reference.py index 0932cca051..95c7b9d5eb 100644 --- a/pyrit/models/conversation_reference.py +++ b/pyrit/models/conversation_reference.py @@ -36,6 +36,36 @@ def __hash__(self) -> int: """ return hash(self.conversation_id) + def to_dict(self) -> dict[str, str | None]: + """ + Serialize to a JSON-compatible dictionary. + + Returns: + dict[str, str | None]: Dictionary with conversation_id, conversation_type, and description. + """ + return { + "conversation_id": self.conversation_id, + "conversation_type": self.conversation_type.value, + "description": self.description, + } + + @classmethod + def from_dict(cls, data: dict[str, str | None]) -> ConversationReference: + """ + Reconstruct a ConversationReference from a dictionary. + + Args: + data (dict[str, str | None]): Dictionary as produced by to_dict(). + + Returns: + ConversationReference: Reconstructed instance. + """ + return cls( + conversation_id=str(data["conversation_id"]), + conversation_type=ConversationType(data["conversation_type"]), + description=data.get("description"), + ) + def __eq__(self, other: object) -> bool: """ Compare two references by conversation ID. diff --git a/pyrit/models/message.py b/pyrit/models/message.py index 16a77efaab..e77f707b0f 100644 --- a/pyrit/models/message.py +++ b/pyrit/models/message.py @@ -307,6 +307,44 @@ def to_dict(self) -> dict[str, object]: "converted_value_data_type": converted_value_data_type, } + def to_full_dict(self) -> dict[str, object]: + """ + Convert the message to a full dictionary representation including all piece details. + + Unlike to_dict() which flattens pieces into a single converted_value, this method + serializes each piece individually via MessagePiece.to_dict(). This is the format + expected by from_dict(). 
+ + Returns: + dict[str, object]: Dictionary with 'role', 'is_simulated', 'conversation_id', + 'sequence', and 'pieces' (list of MessagePiece.to_dict() dicts). + """ + return { + "role": self.api_role, + "is_simulated": self.is_simulated, + "conversation_id": self.conversation_id, + "sequence": self.sequence, + "pieces": [piece.to_dict() for piece in self.message_pieces], + } + + @classmethod + def from_dict(cls, data: dict[str, object]) -> Message: + """ + Reconstruct a Message from a dictionary. + + Expects the format produced by to_full_dict(), which includes a 'pieces' key + containing a list of MessagePiece dictionaries. + + Args: + data (dict[str, object]): Dictionary as produced by to_full_dict(). + + Returns: + Message: Reconstructed instance. + """ + pieces_data = data.get("pieces", []) + message_pieces = [MessagePiece.from_dict(p) for p in pieces_data] + return cls(message_pieces, skip_validation=True) + @staticmethod def get_all_values(messages: Sequence[Message]) -> list[str]: """ diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 0f0cf9c1a0..4f756caaef 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -354,6 +354,67 @@ def __str__(self) -> str: __repr__ = __str__ + @classmethod + def from_dict(cls, data: dict[str, object]) -> MessagePiece: + """ + Reconstruct a MessagePiece from a dictionary. + + Args: + data (dict[str, object]): Dictionary as produced by to_dict(). + + Returns: + MessagePiece: Reconstructed instance. + """ + from pyrit.identifiers.component_identifier import ComponentIdentifier + from pyrit.models.score import Score + + return cls( + id=data.get("id"), + role=data.get("role", "user"), + conversation_id=data.get("conversation_id"), + sequence=data.get("sequence", -1), + timestamp=( + datetime.fromisoformat(str(data["timestamp"])) if data.get("timestamp") else None + ), + labels=data.get("labels"), + targeted_harm_categories=data.get("targeted_harm_categories"), + prompt_metadata=data.get("prompt_metadata"), + converter_identifiers=( + [ComponentIdentifier.from_dict(c) for c in data["converter_identifiers"]] + if data.get("converter_identifiers") + else None + ), + prompt_target_identifier=( + ComponentIdentifier.from_dict(data["prompt_target_identifier"]) + if data.get("prompt_target_identifier") + else None + ), + attack_identifier=( + ComponentIdentifier.from_dict(data["attack_identifier"]) + if data.get("attack_identifier") + else None + ), + scorer_identifier=( + ComponentIdentifier.from_dict(data["scorer_identifier"]) + if data.get("scorer_identifier") + else None + ), + original_value_data_type=data.get("original_value_data_type", "text"), + original_value=data.get("original_value", ""), + original_value_sha256=data.get("original_value_sha256"), + converted_value_data_type=data.get("converted_value_data_type"), + converted_value=data.get("converted_value"), + converted_value_sha256=data.get("converted_value_sha256"), + response_error=data.get("response_error", "none"), + originator=data.get("originator", "undefined"), + original_prompt_id=( + uuid.UUID(str(data["original_prompt_id"])) if data.get("original_prompt_id") else None + ), + scores=( + [Score.from_dict(s) for s in data["scores"]] if data.get("scores") else None + ), + ) + def __eq__(self, other: object) -> bool: """ Compare this message piece with another for semantic equality. 
diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index 88a67f5991..f013291eb1 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from __future__ import annotations + import logging import uuid from datetime import datetime, timezone @@ -46,6 +48,40 @@ def __init__( self.pyrit_version = pyrit_version if pyrit_version is not None else pyrit.__version__ self.init_data = init_data + def to_dict(self) -> dict[str, Any]: + """ + Serialize to a JSON-compatible dictionary. + + Returns: + dict[str, Any]: Serialized payload. + """ + return { + "name": self.name, + "description": self.description, + "version": self.version, + "pyrit_version": self.pyrit_version, + "init_data": self.init_data, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> ScenarioIdentifier: + """ + Reconstruct a ScenarioIdentifier from a dictionary. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + ScenarioIdentifier: Reconstructed instance. + """ + return cls( + name=data["name"], + description=data.get("description", ""), + scenario_version=data.get("version", 1), + init_data=data.get("init_data"), + pyrit_version=data.get("pyrit_version"), + ) + ScenarioRunState = Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED"] @@ -260,3 +296,78 @@ def get_scorer_evaluation_metrics(self) -> "ScorerMetrics | None": eval_hash = ScorerEvaluationIdentifier(self.objective_scorer_identifier).eval_hash return find_objective_metrics_by_eval_hash(eval_hash=eval_hash) + + def to_dict(self) -> dict[str, Any]: + """ + Serialize this scenario result to a JSON-compatible dictionary. + + Returns: + dict[str, Any]: Serialized payload suitable for REST APIs or persistence. + """ + return { + "id": str(self.id), + "scenario_identifier": self.scenario_identifier.to_dict(), + "objective_target_identifier": ( + self.objective_target_identifier.to_dict() if self.objective_target_identifier else None + ), + "objective_scorer_identifier": ( + self.objective_scorer_identifier.to_dict() if self.objective_scorer_identifier else None + ), + "scenario_run_state": self.scenario_run_state, + "attack_results": { + name: [r.to_dict() for r in results] for name, results in self.attack_results.items() + }, + "display_group_map": self._display_group_map, + "labels": self.labels, + "creation_time": self.creation_time.isoformat() if self.creation_time else None, + "completion_time": self.completion_time.isoformat() if self.completion_time else None, + "number_tries": self.number_tries, + "error_attack_result_ids": self.error_attack_result_ids, + "error_message": self.error_message, + "error_type": self.error_type, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> ScenarioResult: + """ + Reconstruct a ScenarioResult from a dictionary. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + ScenarioResult: Reconstructed instance. 
+ """ + from pyrit.identifiers.component_identifier import ComponentIdentifier + + return cls( + id=uuid.UUID(data["id"]) if data.get("id") else None, + scenario_identifier=ScenarioIdentifier.from_dict(data["scenario_identifier"]), + objective_target_identifier=( + ComponentIdentifier.from_dict(data["objective_target_identifier"]) + if data.get("objective_target_identifier") + else None + ), + objective_scorer_identifier=( + ComponentIdentifier.from_dict(data["objective_scorer_identifier"]) + if data.get("objective_scorer_identifier") + else None + ), + scenario_run_state=data.get("scenario_run_state", "CREATED"), + attack_results={ + name: [AttackResult.from_dict(r) for r in results] + for name, results in data.get("attack_results", {}).items() + }, + display_group_map=data.get("display_group_map"), + labels=data.get("labels"), + creation_time=( + datetime.fromisoformat(data["creation_time"]) if data.get("creation_time") else None + ), + completion_time=( + datetime.fromisoformat(data["completion_time"]) if data.get("completion_time") else None + ), + number_tries=data.get("number_tries", 0), + error_attack_result_ids=data.get("error_attack_result_ids"), + error_message=data.get("error_message"), + error_type=data.get("error_type"), + ) diff --git a/pyrit/models/score.py b/pyrit/models/score.py index 606ce89947..726a90d57b 100644 --- a/pyrit/models/score.py +++ b/pyrit/models/score.py @@ -194,6 +194,33 @@ def __str__(self) -> str: __repr__ = __str__ + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Score: + """ + Reconstruct a Score from a dictionary. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + Score: Reconstructed instance. + """ + from pyrit.identifiers.component_identifier import ComponentIdentifier + + return cls( + id=data.get("id"), + score_value=data["score_value"], + score_value_description=data.get("score_value_description", ""), + score_type=data["score_type"], + score_category=data.get("score_category"), + score_rationale=data.get("score_rationale", ""), + score_metadata=data.get("score_metadata"), + scorer_class_identifier=ComponentIdentifier.from_dict(data["scorer_class_identifier"]), + message_piece_id=data["message_piece_id"], + timestamp=datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else None, + objective=data.get("objective"), + ) + @dataclass class UnvalidatedScore: diff --git a/pyrit/printer/__init__.py b/pyrit/printer/__init__.py new file mode 100644 index 0000000000..426fbac9ea --- /dev/null +++ b/pyrit/printer/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Lightweight printer module for displaying attack, scenario, and scorer results. + +This module contains abstract base classes with all formatting logic. +Data-fetching operations (conversations, scores, scorer metrics) are abstract +methods that must be implemented by subclasses. + +Framework users: use the concrete implementations in pyrit.executor.attack.printer +and pyrit.scenario.printer which fetch data via CentralMemory. + +Thin clients: subclass the bases here and implement abstract methods via REST calls. +""" diff --git a/pyrit/printer/attack_result/__init__.py b/pyrit/printer/attack_result/__init__.py new file mode 100644 index 0000000000..47789c0055 --- /dev/null +++ b/pyrit/printer/attack_result/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +"""Attack result printer base classes.""" diff --git a/pyrit/printer/attack_result/base.py b/pyrit/printer/attack_result/base.py new file mode 100644 index 0000000000..013abe1128 --- /dev/null +++ b/pyrit/printer/attack_result/base.py @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import ABC, abstractmethod + +from pyrit.models import AttackOutcome, AttackResult, Message, Score + + +class AttackResultPrinterBase(ABC): + """ + Abstract base class for printing attack results. + + Contains all formatting logic. Subclasses only need to implement + the data-fetching methods: get_conversation_async and get_scores_async. + + Framework implementations fetch data via CentralMemory. + Thin-client implementations can fetch data via REST endpoints. + """ + + @abstractmethod + async def get_conversation_async(self, conversation_id: str) -> list[Message]: + """ + Fetch conversation messages for a given conversation ID. + + Args: + conversation_id (str): The conversation ID to fetch messages for. + + Returns: + list[Message]: The conversation messages. + """ + + @abstractmethod + async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: + """ + Fetch scores for given prompt piece IDs. + + Args: + prompt_ids (list[str]): The message piece IDs to fetch scores for. + + Returns: + list[Score]: The scores associated with the given piece IDs. + """ + + async def display_image_async(self, piece: object) -> None: + """ + Display an image from a message piece. No-op by default. + + Framework subclasses can override to use PIL/IPython for rendering. + Thin-client subclasses can override to render URLs or base64 data. + + Args: + piece: The message piece that may contain image data. + """ + + @staticmethod + def _get_outcome_icon(outcome: AttackOutcome) -> str: + """ + Get an icon for an outcome. + + Args: + outcome (AttackOutcome): The attack outcome enum value. + + Returns: + str: Unicode emoji string. + """ + return { + AttackOutcome.SUCCESS: "\u2705", + AttackOutcome.FAILURE: "\u274c", + AttackOutcome.UNDETERMINED: "\u2753", + }.get(outcome, "") + + @staticmethod + def _format_time(milliseconds: int) -> str: + """ + Format time in a human-readable way. + + Args: + milliseconds (int): Time duration in milliseconds. + + Returns: + str: Formatted time string (e.g., "500ms", "2.50s", "1m 30s"). + """ + if milliseconds < 1000: + return f"{milliseconds}ms" + + if milliseconds < 60000: + return f"{milliseconds / 1000:.2f}s" + + minutes = milliseconds // 60000 + seconds = (milliseconds % 60000) / 1000 + return f"{minutes}m {seconds:.0f}s" diff --git a/pyrit/printer/attack_result/console.py b/pyrit/printer/attack_result/console.py new file mode 100644 index 0000000000..3b3829dbb4 --- /dev/null +++ b/pyrit/printer/attack_result/console.py @@ -0,0 +1,484 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json +import textwrap +from datetime import datetime, timezone +from typing import Any + +from colorama import Back, Fore, Style + +from pyrit.models import AttackOutcome, AttackResult, ConversationType, Score +from pyrit.printer.attack_result.base import AttackResultPrinterBase + + +class ConsoleAttackPrinterBase(AttackResultPrinterBase): + """ + Console printer base for attack results with enhanced formatting. + + Contains all formatting logic. Subclasses implement get_conversation_async + and get_scores_async for data fetching. 
+ """ + + def __init__(self, *, width: int = 100, indent_size: int = 2, enable_colors: bool = True) -> None: + """ + Initialize the console printer. + + Args: + width (int): Maximum width for text wrapping. Defaults to 100. + indent_size (int): Number of spaces for indentation. Defaults to 2. + enable_colors (bool): Whether to enable ANSI color output. Defaults to True. + """ + self._width = width + self._indent = " " * indent_size + self._enable_colors = enable_colors + + def _print_colored(self, text: str, *colors: str) -> None: + """ + Print text with color formatting if colors are enabled. + + Args: + text (str): The text to print. + *colors: Variable number of colorama color constants to apply. + """ + if self._enable_colors and colors: + color_prefix = "".join(colors) + print(f"{color_prefix}{text}{Style.RESET_ALL}") + else: + print(text) + + async def print_result_async( + self, + result: AttackResult, + *, + include_auxiliary_scores: bool = False, + include_pruned_conversations: bool = False, + include_adversarial_conversation: bool = False, + ) -> None: + """ + Print the complete attack result to console. + + Args: + result (AttackResult): The attack result to print. + include_auxiliary_scores (bool): Whether to include auxiliary scores. Defaults to False. + include_pruned_conversations (bool): Whether to include pruned conversations. Defaults to False. + include_adversarial_conversation (bool): Whether to include the adversarial conversation. + Defaults to False. + """ + self._print_header(result) + await self.print_summary_async(result) + + self._print_section_header("Conversation History with Objective Target") + await self.print_conversation_async(result, include_scores=include_auxiliary_scores) + + if include_pruned_conversations: + await self._print_pruned_conversations_async(result) + + if include_adversarial_conversation: + await self._print_adversarial_conversation_async(result) + + if result.metadata: + self._print_metadata(result.metadata) + + self._print_footer() + + async def print_conversation_async( + self, result: AttackResult, *, include_scores: bool = False, include_reasoning_trace: bool = False + ) -> None: + """ + Print the conversation history to console. + + Args: + result (AttackResult): The attack result containing the conversation_id. + include_scores (bool): Whether to include scores. Defaults to False. + include_reasoning_trace (bool): Whether to include model reasoning trace. Defaults to False. + """ + if not result.conversation_id: + self._print_colored(f"{self._indent} No conversation ID available", Fore.YELLOW) + return + + messages = await self.get_conversation_async(result.conversation_id) + + if not messages: + self._print_colored(f"{self._indent} No conversation found for ID: {result.conversation_id}", Fore.YELLOW) + return + + await self.print_messages_async( + messages=messages, + include_scores=include_scores, + include_reasoning_trace=include_reasoning_trace, + ) + + async def print_messages_async( + self, + messages: list[Any], + *, + include_scores: bool = False, + include_reasoning_trace: bool = False, + ) -> None: + """ + Print a list of messages to console with enhanced formatting. + + Args: + messages (list): List of Message objects to print. + include_scores (bool): Whether to include scores. Defaults to False. + include_reasoning_trace (bool): Whether to include model reasoning trace. Defaults to False. 
+ """ + if not messages: + self._print_colored(f"{self._indent} No messages to display.", Fore.YELLOW) + return + + turn_number = 0 + for message in messages: + if message.api_role == "user": + turn_number += 1 + print() + self._print_colored("─" * self._width, Fore.BLUE) + self._print_colored(f"🔹 Turn {turn_number} - USER", Style.BRIGHT, Fore.BLUE) + self._print_colored("─" * self._width, Fore.BLUE) + elif message.api_role == "system": + print() + self._print_colored("─" * self._width, Fore.MAGENTA) + self._print_colored("🔧 SYSTEM", Style.BRIGHT, Fore.MAGENTA) + self._print_colored("─" * self._width, Fore.MAGENTA) + else: + print() + self._print_colored("─" * self._width, Fore.YELLOW) + role_label = "ASSISTANT (SIMULATED)" if message.is_simulated else message.api_role.upper() + self._print_colored(f"🔸 {role_label}", Style.BRIGHT, Fore.YELLOW) + self._print_colored("─" * self._width, Fore.YELLOW) + + for piece in message.message_pieces: + if piece.original_value_data_type == "reasoning": + if include_reasoning_trace: + summary_text = self._extract_reasoning_summary(piece.original_value) + if summary_text: + self._print_colored(f"{self._indent}💭 Reasoning Summary:", Style.DIM, Fore.CYAN) + self._print_wrapped_text(summary_text, Fore.CYAN) + print() + continue + + if piece.is_blocked(): + self._print_colored(f"{self._indent}🚫 BLOCKED BY TARGET", Style.BRIGHT, Fore.RED) + partial_content = piece.prompt_metadata.get("partial_content") + if partial_content: + self._print_colored( + f"{self._indent}📝 Partial content (before filter triggered):", + Style.DIM, + Fore.CYAN, + ) + self._print_wrapped_text(str(partial_content), Fore.YELLOW) + else: + self._print_colored( + f"{self._indent}Content was blocked by the target's content filter.", + Style.DIM, + Fore.RED, + ) + + elif piece.converted_value != piece.original_value: + self._print_colored(f"{self._indent} Original:", Fore.CYAN) + self._print_wrapped_text(piece.original_value, Fore.WHITE) + print() + self._print_colored(f"{self._indent} Converted:", Fore.CYAN) + self._print_wrapped_text(piece.converted_value, Fore.WHITE) + elif piece.api_role == "user": + self._print_wrapped_text(piece.converted_value, Fore.BLUE) + elif piece.api_role == "system": + self._print_wrapped_text(piece.converted_value, Fore.MAGENTA) + else: + self._print_wrapped_text(piece.converted_value, Fore.YELLOW) + + await self.display_image_async(piece) + + if include_scores: + scores = await self.get_scores_async(prompt_ids=[str(piece.id)]) + if scores: + print() + self._print_colored(f"{self._indent}📊 Scores:", Style.DIM, Fore.MAGENTA) + for score in scores: + self._print_score(score) + + print() + self._print_colored("─" * self._width, Fore.BLUE) + + def _extract_reasoning_summary(self, reasoning_value: str) -> str: + """ + Extract human-readable summary text from a reasoning piece's JSON value. + + Args: + reasoning_value (str): The JSON string stored in the reasoning piece. + + Returns: + str: The concatenated summary text, or empty string if no summary is present. + """ + try: + data = json.loads(reasoning_value) + except (json.JSONDecodeError, TypeError): + return "" + + summary = data.get("summary") if isinstance(data, dict) else None + if not summary or not isinstance(summary, list): + return "" + + parts = [item.get("text", "") for item in summary if isinstance(item, dict) and item.get("text")] + return "\n".join(parts) + + async def print_summary_async(self, result: AttackResult) -> None: + """ + Print a summary of the attack result. 
+ + Args: + result (AttackResult): The attack result to summarize. + """ + self._print_section_header("Attack Summary") + + self._print_colored(f"{self._indent}📋 Basic Information", Style.BRIGHT) + self._print_colored(f"{self._indent * 2}• Objective: {result.objective}", Fore.CYAN) + + attack_type = "Unknown" + attack_strategy_id = result.get_attack_strategy_identifier() + if attack_strategy_id: + attack_type = attack_strategy_id.class_name + + self._print_colored(f"{self._indent * 2}• Attack Type: {attack_type}", Fore.CYAN) + self._print_colored(f"{self._indent * 2}• Conversation ID: {result.conversation_id}", Fore.CYAN) + + print() + self._print_colored(f"{self._indent}⚡ Execution Metrics", Style.BRIGHT) + self._print_colored(f"{self._indent * 2}• Turns Executed: {result.executed_turns}", Fore.GREEN) + self._print_colored( + f"{self._indent * 2}• Execution Time: {self._format_time(result.execution_time_ms)}", Fore.GREEN + ) + + print() + self._print_colored(f"{self._indent}🎯 Outcome", Style.BRIGHT) + outcome_icon = self._get_outcome_icon(result.outcome) + outcome_color = self._get_outcome_color(result.outcome) + self._print_colored(f"{self._indent * 2}• Status: {outcome_icon} {result.outcome.value.upper()}", outcome_color) + + if result.outcome_reason: + self._print_colored(f"{self._indent * 2}• Reason: {result.outcome_reason}", Fore.WHITE) + + if result.last_score: + print() + self._print_colored(f"{self._indent} Final Score", Style.BRIGHT) + self._print_score(result.last_score, indent_level=2) + + def _print_header(self, result: AttackResult) -> None: + """ + Print the header with outcome-based coloring. + + Args: + result (AttackResult): The attack result containing the outcome. + """ + color = self._get_outcome_color(result.outcome) + icon = self._get_outcome_icon(result.outcome) + + print() + self._print_colored("═" * self._width, color) + header_text = f"{icon} ATTACK RESULT: {result.outcome.value.upper()} {icon}" + self._print_colored(header_text.center(self._width), Style.BRIGHT, color) + self._print_colored("═" * self._width, color) + + def _print_footer(self) -> None: + """Print a footer with timestamp.""" + timestamp = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S") + print() + self._print_colored("─" * self._width, Style.DIM, Fore.WHITE) + footer_text = f"Report generated at: {timestamp} UTC" + self._print_colored(footer_text.center(self._width), Style.DIM, Fore.WHITE) + + def _print_section_header(self, title: str) -> None: + """ + Print a section header with consistent styling. + + Args: + title (str): The title text to display. + """ + print() + self._print_colored(f" {title} ", Style.BRIGHT, Back.BLUE, Fore.WHITE) + self._print_colored("─" * self._width, Fore.BLUE) + + def _print_metadata(self, metadata: dict[str, Any]) -> None: + """ + Print metadata in a formatted way. + + Args: + metadata (dict[str, Any]): Dictionary containing metadata key-value pairs. + """ + self._print_section_header("Additional Metadata") + for key, value in metadata.items(): + self._print_colored(f"{self._indent}• {key}: {value}", Fore.CYAN) + + def _print_score(self, score: Score, indent_level: int = 3) -> None: + """ + Print a score with proper formatting. + + Args: + score (Score): Score object to be printed. + indent_level (int): Number of indent units to apply. Defaults to 3. 
+ """ + indent = self._indent * indent_level + scorer_name = score.scorer_class_identifier.class_name + print(f"{indent}Scorer: {scorer_name}") + self._print_colored(f"{indent}• Category: {score.score_category or 'N/A'}", Fore.LIGHTMAGENTA_EX) + self._print_colored(f"{indent}• Type: {score.score_type}", Fore.CYAN) + + if score.score_type == "true_false": + score_color = Fore.GREEN if score.get_value() else Fore.RED + else: + score_color = Fore.YELLOW + + self._print_colored(f"{indent}• Value: {score.score_value}", score_color) + + if score.score_rationale: + print(f"{indent}• Rationale:") + rationale_wrapper = textwrap.TextWrapper( + width=self._width - len(indent) - 2, + initial_indent=indent + " ", + subsequent_indent=indent + " ", + break_long_words=False, + break_on_hyphens=False, + ) + lines = score.score_rationale.split("\n") + for line in lines: + if line.strip(): + wrapped_lines = rationale_wrapper.wrap(line) + for wrapped_line in wrapped_lines: + self._print_colored(wrapped_line, Fore.WHITE) + else: + self._print_colored(f"{indent} ") + + def _print_wrapped_text(self, text: str, color: str) -> None: + """ + Print text with proper wrapping and indentation, preserving newlines. + + Args: + text (str): The text to print. + color (str): Colorama color constant to apply. + """ + text_wrapper = textwrap.TextWrapper( + width=self._width - len(self._indent), + initial_indent="", + subsequent_indent=self._indent, + break_long_words=True, + break_on_hyphens=True, + expand_tabs=False, + replace_whitespace=False, + ) + + lines = text.split("\n") + for line_num, line in enumerate(lines): + if line.strip(): + wrapped_lines = text_wrapper.wrap(line) + for i, wrapped_line in enumerate(wrapped_lines): + if line_num == 0 and i == 0: + self._print_colored(f"{self._indent}{wrapped_line}", color) + else: + self._print_colored(f"{self._indent * 2}{wrapped_line}", color) + else: + self._print_colored(f"{self._indent}", color) + + async def _print_pruned_conversations_async(self, result: AttackResult) -> None: + """ + Print pruned conversations showing only the last message and score for each. + + Args: + result (AttackResult): The attack result containing related conversations. 
+ """ + pruned_refs = result.get_conversations_by_type(ConversationType.PRUNED) + + if not pruned_refs: + return + + self._print_section_header(f"Pruned Conversations ({len(pruned_refs)} total)") + + for idx, ref in enumerate(pruned_refs, 1): + print() + self._print_colored("─" * self._width, Fore.RED) + label = f"🗑️ PRUNED #{idx}" + if ref.description: + label += f" - {ref.description}" + self._print_colored(label, Style.BRIGHT, Fore.RED) + self._print_colored("─" * self._width, Fore.RED) + + messages = await self.get_conversation_async(ref.conversation_id) + + if not messages: + self._print_colored( + f"{self._indent}No messages found for conversation: {ref.conversation_id}", Fore.YELLOW + ) + continue + + last_message = messages[-1] + role_label = last_message.api_role.upper() + self._print_colored(f"{self._indent}Last Message ({role_label}):", Style.BRIGHT, Fore.WHITE) + + for piece in last_message.message_pieces: + self._print_wrapped_text(piece.converted_value, Fore.WHITE) + + scores = await self.get_scores_async(prompt_ids=[str(piece.id)]) + if scores: + print() + self._print_colored(f"{self._indent}📊 Score:", Style.DIM, Fore.MAGENTA) + for score in scores: + self._print_score(score) + + print() + self._print_colored("─" * self._width, Fore.RED) + + async def _print_adversarial_conversation_async(self, result: AttackResult) -> None: + """ + Print the adversarial conversation for the best-scoring attack branch. + + Args: + result (AttackResult): The attack result containing related conversations. + """ + adversarial_refs = result.get_conversations_by_type(ConversationType.ADVERSARIAL) + + if not adversarial_refs: + return + + self._print_section_header("Adversarial Conversation (Red Team LLM)") + + best_adversarial_id = result.metadata.get("best_adversarial_conversation_id") + if best_adversarial_id: + adversarial_refs = [ref for ref in adversarial_refs if ref.conversation_id == best_adversarial_id] + if adversarial_refs: + self._print_colored( + f"{self._indent}📌 Showing best-scoring branch's adversarial conversation", + Style.DIM, + Fore.CYAN, + ) + + for ref in adversarial_refs: + if ref.description: + self._print_colored(f"{self._indent}📝 {ref.description}", Style.DIM, Fore.CYAN) + + messages = await self.get_conversation_async(ref.conversation_id) + + if not messages: + self._print_colored( + f"{self._indent}No messages found for conversation: {ref.conversation_id}", Fore.YELLOW + ) + continue + + await self.print_messages_async(messages=messages, include_scores=False) + + def _get_outcome_color(self, outcome: AttackOutcome) -> str: + """ + Get the color for an outcome. + + Args: + outcome (AttackOutcome): The attack outcome enum value. + + Returns: + str: Colorama color constant. + """ + return str( + { + AttackOutcome.SUCCESS: Fore.GREEN, + AttackOutcome.FAILURE: Fore.RED, + AttackOutcome.UNDETERMINED: Fore.YELLOW, + }.get(outcome, Fore.WHITE) + ) diff --git a/pyrit/printer/scenario_result/__init__.py b/pyrit/printer/scenario_result/__init__.py new file mode 100644 index 0000000000..0def8141c0 --- /dev/null +++ b/pyrit/printer/scenario_result/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Scenario result printer base classes.""" diff --git a/pyrit/printer/scenario_result/base.py b/pyrit/printer/scenario_result/base.py new file mode 100644 index 0000000000..028a855bf1 --- /dev/null +++ b/pyrit/printer/scenario_result/base.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+
+from abc import ABC, abstractmethod
+
+from pyrit.models.scenario_result import ScenarioResult
+
+
+class ScenarioResultPrinterBase(ABC):
+    """
+    Abstract base class for printing scenario results.
+
+    Defines the printing interface only. Concrete console formatting lives in
+    ConsoleScenarioPrinterBase, which accepts a scorer printer for scorer output.
+    """
+
+    @abstractmethod
+    async def print_summary_async(self, result: ScenarioResult) -> None:
+        """
+        Print a summary of the scenario result with per-group breakdown.
+
+        Args:
+            result (ScenarioResult): The scenario result to summarize.
+        """
diff --git a/pyrit/printer/scenario_result/console.py b/pyrit/printer/scenario_result/console.py
new file mode 100644
index 0000000000..51cc7f307c
--- /dev/null
+++ b/pyrit/printer/scenario_result/console.py
@@ -0,0 +1,178 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import textwrap
+from typing import Optional
+
+from colorama import Fore, Style
+
+from pyrit.models import AttackOutcome
+from pyrit.models.scenario_result import ScenarioResult
+from pyrit.printer.scenario_result.base import ScenarioResultPrinterBase
+from pyrit.printer.scorer.base import ScorerPrinterBase
+
+
+class ConsoleScenarioPrinterBase(ScenarioResultPrinterBase):
+    """
+    Console printer base for scenario results with enhanced formatting.
+
+    Contains all formatting logic. Accepts a ScorerPrinterBase for printing
+    scorer information. Subclasses can provide a concrete scorer printer.
+    """
+
+    def __init__(
+        self,
+        *,
+        width: int = 100,
+        indent_size: int = 2,
+        enable_colors: bool = True,
+        scorer_printer: Optional[ScorerPrinterBase] = None,
+    ) -> None:
+        """
+        Initialize the console printer.
+
+        Args:
+            width (int): Maximum width for text wrapping. Defaults to 100.
+            indent_size (int): Number of spaces for indentation. Defaults to 2.
+            enable_colors (bool): Whether to enable ANSI color output. Defaults to True.
+            scorer_printer (Optional[ScorerPrinterBase]): Printer for scorer information.
+        """
+        self._width = width
+        self._indent = " " * indent_size
+        self._enable_colors = enable_colors
+        self._scorer_printer = scorer_printer
+
+    def _print_colored(self, text: str, *colors: str) -> None:
+        """
+        Print text with color formatting if colors are enabled.
+
+        Args:
+            text (str): The text to print.
+            *colors: Variable number of colorama color constants to apply.
+        """
+        if self._enable_colors and colors:
+            color_prefix = "".join(colors)
+            print(f"{color_prefix}{text}{Style.RESET_ALL}")
+        else:
+            print(text)
+
+    def _print_section_header(self, title: str) -> None:
+        """
+        Print a section header with visual separation.
+
+        Args:
+            title (str): The section title to display.
+        """
+        print()
+        self._print_colored(f"▼ {title}", Style.BRIGHT, Fore.CYAN)
+        self._print_colored("─" * self._width, Fore.CYAN)
+
+    async def print_summary_async(self, result: ScenarioResult) -> None:
+        """
+        Print a summary of the scenario result with per-group breakdown.
+
+        Args:
+            result (ScenarioResult): The scenario result to summarize.
+ """ + self._print_header(result) + + self._print_section_header("Scenario Information") + self._print_colored(f"{self._indent}📋 Scenario Details", Style.BRIGHT) + self._print_colored(f"{self._indent * 2}• Name: {result.scenario_identifier.name}", Fore.CYAN) + self._print_colored(f"{self._indent * 2}• Scenario Version: {result.scenario_identifier.version}", Fore.CYAN) + self._print_colored(f"{self._indent * 2}• PyRIT Version: {result.scenario_identifier.pyrit_version}", Fore.CYAN) + + if result.scenario_identifier.description: + self._print_colored(f"{self._indent * 2}• Description:", Fore.CYAN) + desc_indent = self._indent * 4 + available_width = 120 - len(desc_indent) + wrapped_lines = textwrap.wrap( + result.scenario_identifier.description, width=available_width, break_long_words=False + ) + for line in wrapped_lines: + self._print_colored(f"{desc_indent}{line}", Fore.CYAN) + + print() + self._print_colored(f"{self._indent}🎯 Target Information", Style.BRIGHT) + target_id = result.objective_target_identifier + target_type = target_id.class_name if target_id else "Unknown" + target_model = target_id.params.get("model_name", "Unknown") if target_id else "Unknown" + target_endpoint = target_id.params.get("endpoint", "Unknown") if target_id else "Unknown" + + self._print_colored(f"{self._indent * 2}• Target Type: {target_type}", Fore.CYAN) + self._print_colored(f"{self._indent * 2}• Target Model: {target_model}", Fore.CYAN) + self._print_colored(f"{self._indent * 2}• Target Endpoint: {target_endpoint}", Fore.CYAN) + + scorer_identifier = result.objective_scorer_identifier + if scorer_identifier and self._scorer_printer: + self._scorer_printer.print_objective_scorer(scorer_identifier=scorer_identifier) + + self._print_section_header("Overall Statistics") + total_results = sum(len(results) for results in result.attack_results.values()) + total_strategies = len(result.get_strategies_used()) + overall_rate = result.objective_achieved_rate() + + self._print_colored(f"{self._indent}📈 Summary", Style.BRIGHT) + self._print_colored(f"{self._indent * 2}• Total Strategies: {total_strategies}", Fore.GREEN) + self._print_colored(f"{self._indent * 2}• Total Attack Results: {total_results}", Fore.GREEN) + self._print_colored( + f"{self._indent * 2}• Overall Success Rate: {overall_rate}%", self._get_rate_color(overall_rate) + ) + + objectives = result.get_objectives() + self._print_colored(f"{self._indent * 2}• Unique Objectives: {len(objectives)}", Fore.GREEN) + + self._print_section_header("Per-Group Breakdown") + display_groups = result.get_display_groups() + + for group_name, group_results in display_groups.items(): + total_group = len(group_results) + if total_group == 0: + group_rate = 0 + else: + successful = sum(1 for r in group_results if r.outcome == AttackOutcome.SUCCESS) + group_rate = int((successful / total_group) * 100) + + print() + self._print_colored(f"{self._indent}🔸 Group: {group_name}", Style.BRIGHT) + self._print_colored(f"{self._indent * 2}• Number of Results: {total_group}", Fore.YELLOW) + self._print_colored(f"{self._indent * 2}• Success Rate: {group_rate}%", self._get_rate_color(group_rate)) + + self._print_footer() + + def _print_header(self, result: ScenarioResult) -> None: + """ + Print the header with scenario name. + + Args: + result (ScenarioResult): The scenario result. 
+ """ + print() + self._print_colored("=" * self._width, Fore.CYAN) + header_text = f"📊 SCENARIO RESULTS: {result.scenario_identifier.name}" + self._print_colored(header_text.center(self._width), Style.BRIGHT, Fore.CYAN) + self._print_colored("=" * self._width, Fore.CYAN) + + def _print_footer(self) -> None: + """Print a footer separator.""" + print() + self._print_colored("=" * self._width, Fore.CYAN) + print() + + def _get_rate_color(self, rate: int) -> str: + """ + Get color based on success rate. + + Args: + rate (int): Success rate percentage (0-100). + + Returns: + str: Colorama color constant. + """ + if rate >= 75: + return str(Fore.RED) + if rate >= 50: + return str(Fore.YELLOW) + if rate >= 25: + return str(Fore.CYAN) + return str(Fore.GREEN) diff --git a/pyrit/printer/scorer/__init__.py b/pyrit/printer/scorer/__init__.py new file mode 100644 index 0000000000..7c7c7bd417 --- /dev/null +++ b/pyrit/printer/scorer/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Scorer printer base classes.""" diff --git a/pyrit/printer/scorer/base.py b/pyrit/printer/scorer/base.py new file mode 100644 index 0000000000..1a72200d6d --- /dev/null +++ b/pyrit/printer/scorer/base.py @@ -0,0 +1,61 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import ABC, abstractmethod +from typing import Any + +from pyrit.identifiers import ComponentIdentifier + + +class ScorerPrinterBase(ABC): + """ + Abstract base class for printing scorer information. + + Subclasses implement get_objective_metrics and get_harm_metrics + for data fetching. Framework uses the scorer registry; thin clients + can use REST calls. + """ + + @abstractmethod + def get_objective_metrics(self, *, eval_hash: str) -> Any: + """ + Fetch objective scorer evaluation metrics by eval hash. + + Args: + eval_hash (str): The evaluation hash to look up. + + Returns: + ObjectiveScorerMetrics or None: The metrics, or None if not found. + """ + + @abstractmethod + def get_harm_metrics(self, *, eval_hash: str, harm_category: str) -> Any: + """ + Fetch harm scorer evaluation metrics by eval hash and category. + + Args: + eval_hash (str): The evaluation hash to look up. + harm_category (str): The harm category for metrics lookup. + + Returns: + HarmScorerMetrics or None: The metrics, or None if not found. + """ + + @abstractmethod + def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: + """ + Print objective scorer information including type, nested scorers, and evaluation metrics. + + Args: + scorer_identifier (ComponentIdentifier): The scorer identifier to print information for. + """ + + @abstractmethod + def print_harm_scorer(self, scorer_identifier: ComponentIdentifier, *, harm_category: str) -> None: + """ + Print harm scorer information including type, nested scorers, and evaluation metrics. + + Args: + scorer_identifier (ComponentIdentifier): The scorer identifier to print information for. + harm_category (str): The harm category for looking up metrics. + """ diff --git a/pyrit/printer/scorer/console.py b/pyrit/printer/scorer/console.py new file mode 100644 index 0000000000..754a56be6f --- /dev/null +++ b/pyrit/printer/scorer/console.py @@ -0,0 +1,258 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from typing import Any, Optional + +from colorama import Fore, Style + +from pyrit.identifiers import ComponentIdentifier +from pyrit.printer.scorer.base import ScorerPrinterBase + + +class ConsoleScorerPrinterBase(ScorerPrinterBase): + """ + Console printer base for scorer information with enhanced formatting. + + Contains all formatting logic. Subclasses implement get_objective_metrics + and get_harm_metrics for data fetching. + """ + + _SCORER_DISPLAY_PARAMS = frozenset({"scorer_type", "score_aggregator"}) + _TARGET_DISPLAY_PARAMS = frozenset({"model_name", "temperature"}) + + def __init__(self, *, indent_size: int = 2, enable_colors: bool = True) -> None: + """ + Initialize the console scorer printer. + + Args: + indent_size (int): Number of spaces for indentation. Defaults to 2. + enable_colors (bool): Whether to enable ANSI color output. Defaults to True. + """ + if indent_size < 0: + raise ValueError("indent_size must be non-negative") + self._indent = " " * indent_size + self._enable_colors = enable_colors + + def _print_colored(self, text: str, *colors: str) -> None: + """ + Print text with color formatting if colors are enabled. + + Args: + text (str): The text to print. + *colors: Variable number of colorama color constants to apply. + """ + if self._enable_colors and colors: + color_prefix = "".join(colors) + print(f"{color_prefix}{text}{Style.RESET_ALL}") + else: + print(text) + + def _get_quality_color( + self, value: float, *, higher_is_better: bool, good_threshold: float, bad_threshold: float + ) -> str: + """ + Determine the color based on metric quality thresholds. + + Args: + value (float): The metric value to evaluate. + higher_is_better (bool): If True, higher values are better. + good_threshold (float): The threshold for "good" (green) values. + bad_threshold (float): The threshold for "bad" (red) values. + + Returns: + str: The colorama color constant to use. + """ + if higher_is_better: + if value >= good_threshold: + return str(Fore.GREEN) + if value < bad_threshold: + return str(Fore.RED) + return str(Fore.CYAN) + if value <= good_threshold: + return str(Fore.GREEN) + if value > bad_threshold: + return str(Fore.RED) + return str(Fore.CYAN) + + def _compute_eval_hash(self, scorer_identifier: ComponentIdentifier) -> str: + """ + Compute the evaluation hash for a scorer identifier. + + Args: + scorer_identifier (ComponentIdentifier): The scorer identifier. + + Returns: + str: The evaluation hash string. + """ + from pyrit.identifiers.evaluation_identifier import ScorerEvaluationIdentifier + + return ScorerEvaluationIdentifier(scorer_identifier).eval_hash + + def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: + """ + Print objective scorer information. + + Args: + scorer_identifier (ComponentIdentifier): The scorer identifier to print information for. + """ + print() + self._print_colored(f"{self._indent}📊 Scorer Information", Style.BRIGHT) + self._print_colored(f"{self._indent * 2}▸ Scorer Identifier", Fore.WHITE) + self._print_scorer_info(scorer_identifier, indent_level=3) + + eval_hash = self._compute_eval_hash(scorer_identifier) + metrics = self.get_objective_metrics(eval_hash=eval_hash) + self._print_objective_metrics(metrics) + + def print_harm_scorer(self, scorer_identifier: ComponentIdentifier, *, harm_category: str) -> None: + """ + Print harm scorer information. + + Args: + scorer_identifier (ComponentIdentifier): The scorer identifier to print information for. 
+ harm_category (str): The harm category for looking up metrics. + """ + print() + self._print_colored(f"{self._indent}📊 Scorer Information", Style.BRIGHT) + self._print_colored(f"{self._indent * 2}▸ Scorer Identifier", Fore.WHITE) + self._print_scorer_info(scorer_identifier, indent_level=3) + + eval_hash = self._compute_eval_hash(scorer_identifier) + metrics = self.get_harm_metrics(eval_hash=eval_hash, harm_category=harm_category) + self._print_harm_metrics(metrics) + + def _print_scorer_info(self, scorer_identifier: ComponentIdentifier, *, indent_level: int = 2) -> None: + """ + Print scorer information including nested sub-scorers. + + Args: + scorer_identifier (ComponentIdentifier): The scorer identifier. + indent_level (int): Current indentation level. + """ + indent = self._indent * indent_level + + self._print_colored(f"{indent}• Scorer Type: {scorer_identifier.class_name}", Fore.CYAN) + + for key, value in scorer_identifier.params.items(): + if key in self._SCORER_DISPLAY_PARAMS and value is not None: + self._print_colored(f"{indent}• {key}: {value}", Fore.CYAN) + + prompt_target = scorer_identifier.get_child("prompt_target") + if prompt_target: + for key, value in prompt_target.params.items(): + if key in self._TARGET_DISPLAY_PARAMS and value is not None: + self._print_colored(f"{indent}• {key}: {value}", Fore.CYAN) + + sub_scorers = scorer_identifier.get_child_list("sub_scorers") + if sub_scorers: + self._print_colored(f"{indent} └─ Composite of {len(sub_scorers)} scorer(s):", Fore.CYAN) + for sub_scorer_id in sub_scorers: + self._print_scorer_info(sub_scorer_id, indent_level=indent_level + 3) + + def _print_objective_metrics(self, metrics: Optional[Any]) -> None: + """ + Print objective scorer evaluation metrics. + + Args: + metrics: The metrics to print, or None if not available. 
+ """ + if metrics is None: + print() + self._print_colored(f"{self._indent * 2}▸ Performance Metrics", Fore.WHITE) + self._print_colored( + f"{self._indent * 3}Official evaluation has not been run yet for this specific configuration", + Fore.YELLOW, + ) + return + + print() + self._print_colored(f"{self._indent * 2}▸ Performance Metrics", Fore.WHITE) + + accuracy_color = self._get_quality_color( + metrics.accuracy, higher_is_better=True, good_threshold=0.9, bad_threshold=0.7 + ) + self._print_colored(f"{self._indent * 3}• Accuracy: {metrics.accuracy:.2%}", accuracy_color) + + if metrics.accuracy_standard_error is not None: + self._print_colored( + f"{self._indent * 3}• Accuracy Std Error: ±{metrics.accuracy_standard_error:.4f}", Fore.CYAN + ) + + if metrics.f1_score is not None: + f1_color = self._get_quality_color( + metrics.f1_score, higher_is_better=True, good_threshold=0.9, bad_threshold=0.7 + ) + self._print_colored(f"{self._indent * 3}• F1 Score: {metrics.f1_score:.4f}", f1_color) + + if metrics.precision is not None: + precision_color = self._get_quality_color( + metrics.precision, higher_is_better=True, good_threshold=0.9, bad_threshold=0.7 + ) + self._print_colored(f"{self._indent * 3}• Precision: {metrics.precision:.4f}", precision_color) + + if metrics.recall is not None: + recall_color = self._get_quality_color( + metrics.recall, higher_is_better=True, good_threshold=0.9, bad_threshold=0.7 + ) + self._print_colored(f"{self._indent * 3}• Recall: {metrics.recall:.4f}", recall_color) + + if metrics.average_score_time_seconds is not None: + time_color = self._get_quality_color( + metrics.average_score_time_seconds, higher_is_better=False, good_threshold=0.5, bad_threshold=3.0 + ) + self._print_colored( + f"{self._indent * 3}• Average Score Time: {metrics.average_score_time_seconds:.2f}s", time_color + ) + + def _print_harm_metrics(self, metrics: Optional[Any]) -> None: + """ + Print harm scorer evaluation metrics. + + Args: + metrics: The metrics to print, or None if not available. 
+ """ + if metrics is None: + print() + self._print_colored(f"{self._indent * 2}▸ Performance Metrics", Fore.WHITE) + self._print_colored( + f"{self._indent * 3}Official evaluation has not been run yet for this specific configuration", + Fore.YELLOW, + ) + return + + print() + self._print_colored(f"{self._indent * 2}▸ Performance Metrics", Fore.WHITE) + + mae_color = self._get_quality_color( + metrics.mean_absolute_error, higher_is_better=False, good_threshold=0.1, bad_threshold=0.25 + ) + self._print_colored(f"{self._indent * 3}• Mean Absolute Error: {metrics.mean_absolute_error:.4f}", mae_color) + + if metrics.mae_standard_error is not None: + self._print_colored(f"{self._indent * 3}• MAE Std Error: ±{metrics.mae_standard_error:.4f}", Fore.CYAN) + + if metrics.krippendorff_alpha_combined is not None: + alpha_color = self._get_quality_color( + metrics.krippendorff_alpha_combined, higher_is_better=True, good_threshold=0.8, bad_threshold=0.6 + ) + self._print_colored( + f"{self._indent * 3}• Krippendorff Alpha (Combined): {metrics.krippendorff_alpha_combined:.4f}", + alpha_color, + ) + + if metrics.krippendorff_alpha_model is not None: + alpha_model_color = self._get_quality_color( + metrics.krippendorff_alpha_model, higher_is_better=True, good_threshold=0.8, bad_threshold=0.6 + ) + self._print_colored( + f"{self._indent * 3}• Krippendorff Alpha (Model): {metrics.krippendorff_alpha_model:.4f}", + alpha_model_color, + ) + + if metrics.average_score_time_seconds is not None: + time_color = self._get_quality_color( + metrics.average_score_time_seconds, higher_is_better=False, good_threshold=1.0, bad_threshold=3.0 + ) + self._print_colored( + f"{self._indent * 3}• Average Score Time: {metrics.average_score_time_seconds:.2f}s", time_color + ) diff --git a/pyrit/scenario/printer/console_printer.py b/pyrit/scenario/printer/console_printer.py index 0ec99e7b5b..3679f2b99c 100644 --- a/pyrit/scenario/printer/console_printer.py +++ b/pyrit/scenario/printer/console_printer.py @@ -1,24 +1,19 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import textwrap from typing import Optional -from colorama import Fore, Style +from pyrit.printer.scenario_result.console import ConsoleScenarioPrinterBase +from pyrit.printer.scorer.base import ScorerPrinterBase +from pyrit.score.printer import ConsoleScorerPrinter -from pyrit.models import AttackOutcome -from pyrit.models.scenario_result import ScenarioResult -from pyrit.scenario.printer.scenario_result_printer import ScenarioResultPrinter -from pyrit.score.printer import ConsoleScorerPrinter, ScorerPrinter - -class ConsoleScenarioResultPrinter(ScenarioResultPrinter): +class ConsoleScenarioResultPrinter(ConsoleScenarioPrinterBase): """ - Console printer for scenario results with enhanced formatting. + Framework console printer for scenario results. - This printer formats scenario results for console display with optional color coding, - proper indentation, and visual separators. Colors can be disabled for consoles - that don't support ANSI characters. + Thin subclass that provides the framework's ConsoleScorerPrinter + for scorer information. All formatting logic lives in ConsoleScenarioPrinterBase. """ def __init__( @@ -27,180 +22,23 @@ def __init__( width: int = 100, indent_size: int = 2, enable_colors: bool = True, - scorer_printer: Optional[ScorerPrinter] = None, + scorer_printer: Optional[ScorerPrinterBase] = None, ) -> None: """ Initialize the console printer. Args: - width (int): Maximum width for text wrapping. Must be positive. 
- Defaults to 100. - indent_size (int): Number of spaces for indentation. Must be non-negative. - Defaults to 2. - enable_colors (bool): Whether to enable ANSI color output. When False, - all output will be plain text without colors. Defaults to True. - scorer_printer (Optional[ScorerPrinter]): Printer for scorer information. + width (int): Maximum width for text wrapping. Defaults to 100. + indent_size (int): Number of spaces for indentation. Defaults to 2. + enable_colors (bool): Whether to enable ANSI color output. Defaults to True. + scorer_printer (Optional[ScorerPrinterBase]): Printer for scorer information. If not provided, a ConsoleScorerPrinter with matching settings is created. - - Raises: - ValueError: If width <= 0 or indent_size < 0. """ - self._width = width - self._indent = " " * indent_size - self._enable_colors = enable_colors - self._scorer_printer = scorer_printer or ConsoleScorerPrinter( - indent_size=indent_size, enable_colors=enable_colors + if scorer_printer is None: + scorer_printer = ConsoleScorerPrinter(indent_size=indent_size, enable_colors=enable_colors) + super().__init__( + width=width, + indent_size=indent_size, + enable_colors=enable_colors, + scorer_printer=scorer_printer, ) - - def _print_colored(self, text: str, *colors: str) -> None: - """ - Print text with color formatting if colors are enabled. - - Args: - text (str): The text to print. - *colors: Variable number of colorama color constants to apply. - """ - if self._enable_colors and colors: - color_prefix = "".join(colors) - print(f"{color_prefix}{text}{Style.RESET_ALL}") - else: - print(text) - - def _print_section_header(self, title: str) -> None: - """ - Print a section header with visual separation. - - Args: - title (str): The section title to display. - """ - print() - self._print_colored(f"▼ {title}", Style.BRIGHT, Fore.CYAN) - self._print_colored("─" * self._width, Fore.CYAN) - - async def print_summary_async(self, result: ScenarioResult) -> None: - """ - Print a summary of the scenario result with per-group breakdown. 
- - Displays: - - Scenario identification (name, version, PyRIT version) - - Target and scorer information - - Overall statistics - - Per-group success rates and result counts - - Args: - result (ScenarioResult): The scenario result to summarize - """ - # Print header - self._print_header(result) - - # Scenario information - self._print_section_header("Scenario Information") - self._print_colored(f"{self._indent}📋 Scenario Details", Style.BRIGHT) - self._print_colored(f"{self._indent * 2}• Name: {result.scenario_identifier.name}", Fore.CYAN) - self._print_colored(f"{self._indent * 2}• Scenario Version: {result.scenario_identifier.version}", Fore.CYAN) - self._print_colored(f"{self._indent * 2}• PyRIT Version: {result.scenario_identifier.pyrit_version}", Fore.CYAN) - - # Format description with text wrapping at 120 characters - if result.scenario_identifier.description: - self._print_colored(f"{self._indent * 2}• Description:", Fore.CYAN) - desc_indent = self._indent * 4 - # Calculate available width for description text (total 120 - indent) - available_width = 120 - len(desc_indent) - # Wrap the description text and print each line - wrapped_lines = textwrap.wrap( - result.scenario_identifier.description, width=available_width, break_long_words=False - ) - for line in wrapped_lines: - self._print_colored(f"{desc_indent}{line}", Fore.CYAN) - - # Target information - print() - self._print_colored(f"{self._indent}🎯 Target Information", Style.BRIGHT) - target_id = result.objective_target_identifier - target_type = target_id.class_name if target_id else "Unknown" - target_model = target_id.params.get("model_name", "Unknown") if target_id else "Unknown" - target_endpoint = target_id.params.get("endpoint", "Unknown") if target_id else "Unknown" - - self._print_colored(f"{self._indent * 2}• Target Type: {target_type}", Fore.CYAN) - self._print_colored(f"{self._indent * 2}• Target Model: {target_model}", Fore.CYAN) - self._print_colored(f"{self._indent * 2}• Target Endpoint: {target_endpoint}", Fore.CYAN) - - # Scorer information - use ComponentIdentifier from result - scorer_identifier = result.objective_scorer_identifier - if scorer_identifier: - self._scorer_printer.print_objective_scorer(scorer_identifier=scorer_identifier) - - # Overall statistics - self._print_section_header("Overall Statistics") - total_results = sum(len(results) for results in result.attack_results.values()) - total_strategies = len(result.get_strategies_used()) - overall_rate = result.objective_achieved_rate() - - self._print_colored(f"{self._indent}📈 Summary", Style.BRIGHT) - self._print_colored(f"{self._indent * 2}• Total Strategies: {total_strategies}", Fore.GREEN) - self._print_colored(f"{self._indent * 2}• Total Attack Results: {total_results}", Fore.GREEN) - self._print_colored( - f"{self._indent * 2}• Overall Success Rate: {overall_rate}%", self._get_rate_color(overall_rate) - ) - - objectives = result.get_objectives() - self._print_colored(f"{self._indent * 2}• Unique Objectives: {len(objectives)}", Fore.GREEN) - - # Per-group breakdown - self._print_section_header("Per-Group Breakdown") - display_groups = result.get_display_groups() - - for group_name, group_results in display_groups.items(): - total_group = len(group_results) - if total_group == 0: - group_rate = 0 - else: - successful = sum(1 for r in group_results if r.outcome == AttackOutcome.SUCCESS) - group_rate = int((successful / total_group) * 100) - - print() - self._print_colored(f"{self._indent}🔸 Group: {group_name}", Style.BRIGHT) - 
self._print_colored(f"{self._indent * 2}• Number of Results: {total_group}", Fore.YELLOW) - self._print_colored(f"{self._indent * 2}• Success Rate: {group_rate}%", self._get_rate_color(group_rate)) - - # Print footer - self._print_footer() - - def _print_header(self, result: ScenarioResult) -> None: - """ - Print the header with scenario name. - - Args: - result (ScenarioResult): The scenario result. - """ - print() - self._print_colored("=" * self._width, Fore.CYAN) - header_text = f"📊 SCENARIO RESULTS: {result.scenario_identifier.name}" - self._print_colored(header_text.center(self._width), Style.BRIGHT, Fore.CYAN) - self._print_colored("=" * self._width, Fore.CYAN) - - def _print_footer(self) -> None: - """ - Print a footer separator. - """ - print() - self._print_colored("=" * self._width, Fore.CYAN) - print() - - def _get_rate_color(self, rate: int) -> str: - """ - Get color based on success rate. - - Args: - rate (int): Success rate percentage (0-100) - - Returns: - str: Colorama color constant - """ - if rate >= 75: - return str(Fore.RED) # High success (bad for security) - if rate >= 50: - return str(Fore.YELLOW) # Medium success - if rate >= 25: - return str(Fore.CYAN) # Low success - return str(Fore.GREEN) # Very low success (good for security) diff --git a/pyrit/score/printer/console_scorer_printer.py b/pyrit/score/printer/console_scorer_printer.py index 27aec712ee..c3b2165115 100644 --- a/pyrit/score/printer/console_scorer_printer.py +++ b/pyrit/score/printer/console_scorer_printer.py @@ -1,290 +1,49 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import TYPE_CHECKING, Optional - -from colorama import Fore, Style +from typing import Any from pyrit.identifiers import ComponentIdentifier -from pyrit.score.printer.scorer_printer import ScorerPrinter - -if TYPE_CHECKING: - from pyrit.score.scorer_evaluation.scorer_metrics import ( - HarmScorerMetrics, - ObjectiveScorerMetrics, - ) +from pyrit.printer.scorer.console import ConsoleScorerPrinterBase -class ConsoleScorerPrinter(ScorerPrinter): +class ConsoleScorerPrinter(ConsoleScorerPrinterBase): """ - Console printer for scorer information with enhanced formatting. + Framework console printer for scorer information. - This printer formats scorer details for console display with optional color coding, - proper indentation, and visual hierarchy. Colors can be disabled for consoles - that don't support ANSI characters. + Thin subclass that implements metrics fetching via the scorer evaluation registry. + All formatting logic lives in ConsoleScorerPrinterBase. """ - _SCORER_DISPLAY_PARAMS = frozenset({"scorer_type", "score_aggregator"}) - _TARGET_DISPLAY_PARAMS = frozenset({"model_name", "temperature"}) - - def __init__(self, *, indent_size: int = 2, enable_colors: bool = True) -> None: + def get_objective_metrics(self, *, eval_hash: str) -> Any: """ - Initialize the console scorer printer. + Fetch objective scorer evaluation metrics from the registry. Args: - indent_size (int): Number of spaces for indentation. Must be non-negative. - Defaults to 2. - enable_colors (bool): Whether to enable ANSI color output. When False, - all output will be plain text without colors. Defaults to True. - - Raises: - ValueError: If indent_size < 0. 
- """ - if indent_size < 0: - raise ValueError("indent_size must be non-negative") - self._indent = " " * indent_size - self._enable_colors = enable_colors - - def _print_colored(self, text: str, *colors: str) -> None: - """ - Print text with color formatting if colors are enabled. - - Args: - text (str): The text to print. - *colors: Variable number of colorama color constants to apply. - """ - if self._enable_colors and colors: - color_prefix = "".join(colors) - print(f"{color_prefix}{text}{Style.RESET_ALL}") - else: - print(text) - - def _get_quality_color( - self, value: float, *, higher_is_better: bool, good_threshold: float, bad_threshold: float - ) -> str: - """ - Determine the color based on metric quality thresholds. - - Args: - value (float): The metric value to evaluate. - higher_is_better (bool): If True, higher values are better (e.g., accuracy). - If False, lower values are better (e.g., MAE). - good_threshold (float): The threshold for "good" (green) values. - bad_threshold (float): The threshold for "bad" (red) values. + eval_hash (str): The evaluation hash to look up. Returns: - str: The colorama color constant to use. - """ - if higher_is_better: - if value >= good_threshold: - return str(Fore.GREEN) - if value < bad_threshold: - return str(Fore.RED) - return str(Fore.CYAN) - # Lower is better (e.g., MAE, score time) - if value <= good_threshold: - return str(Fore.GREEN) - if value > bad_threshold: - return str(Fore.RED) - return str(Fore.CYAN) - - def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: - """ - Print objective scorer information including type, nested scorers, and evaluation metrics. - - This method displays: - - Scorer type and identity information - - Nested sub-scorers (for composite scorers) - - Objective evaluation metrics (accuracy, precision, recall, F1) from the registry - - Args: - scorer_identifier (ComponentIdentifier): The scorer identifier to print information for. + ObjectiveScorerMetrics or None: The metrics, or None if not found. """ - from pyrit.identifiers.evaluation_identifier import ScorerEvaluationIdentifier from pyrit.score.scorer_evaluation.scorer_metrics_io import ( find_objective_metrics_by_eval_hash, ) - print() - self._print_colored(f"{self._indent}📊 Scorer Information", Style.BRIGHT) - self._print_colored(f"{self._indent * 2}▸ Scorer Identifier", Fore.WHITE) - self._print_scorer_info(scorer_identifier, indent_level=3) - - # Look up metrics by eval hash - eval_hash = ScorerEvaluationIdentifier(scorer_identifier).eval_hash - metrics = find_objective_metrics_by_eval_hash(eval_hash=eval_hash) - self._print_objective_metrics(metrics) + return find_objective_metrics_by_eval_hash(eval_hash=eval_hash) - def print_harm_scorer(self, scorer_identifier: ComponentIdentifier, *, harm_category: str) -> None: + def get_harm_metrics(self, *, eval_hash: str, harm_category: str) -> Any: """ - Print harm scorer information including type, nested scorers, and evaluation metrics. - - This method displays: - - Scorer type and identity information - - Nested sub-scorers (for composite scorers) - - Harm evaluation metrics (MAE, Krippendorff alpha) from the registry + Fetch harm scorer evaluation metrics from the registry. Args: - scorer_identifier (ComponentIdentifier): The scorer identifier to print information for. - harm_category (str): The harm category for looking up metrics (e.g., "hate_speech", "violence"). + eval_hash (str): The evaluation hash to look up. + harm_category (str): The harm category for metrics lookup. 
+ + Returns: + HarmScorerMetrics or None: The metrics, or None if not found. """ - from pyrit.identifiers.evaluation_identifier import ScorerEvaluationIdentifier from pyrit.score.scorer_evaluation.scorer_metrics_io import ( find_harm_metrics_by_eval_hash, ) - print() - self._print_colored(f"{self._indent}📊 Scorer Information", Style.BRIGHT) - self._print_colored(f"{self._indent * 2}▸ Scorer Identifier", Fore.WHITE) - self._print_scorer_info(scorer_identifier, indent_level=3) - - # Look up metrics by eval hash and harm category - eval_hash = ScorerEvaluationIdentifier(scorer_identifier).eval_hash - metrics = find_harm_metrics_by_eval_hash(eval_hash=eval_hash, harm_category=harm_category) - self._print_harm_metrics(metrics) - - def _print_scorer_info(self, scorer_identifier: ComponentIdentifier, *, indent_level: int = 2) -> None: - """ - Print scorer information including nested sub-scorers. - - Args: - scorer_identifier (ComponentIdentifier): The scorer identifier. - indent_level (int): Current indentation level for nested display. - """ - indent = self._indent * indent_level - - self._print_colored(f"{indent}• Scorer Type: {scorer_identifier.class_name}", Fore.CYAN) - - for key, value in scorer_identifier.params.items(): - if key in self._SCORER_DISPLAY_PARAMS and value is not None: - self._print_colored(f"{indent}• {key}: {value}", Fore.CYAN) - - # Print target summary if available - prompt_target = scorer_identifier.get_child("prompt_target") - if prompt_target: - for key, value in prompt_target.params.items(): - if key in self._TARGET_DISPLAY_PARAMS and value is not None: - self._print_colored(f"{indent}• {key}: {value}", Fore.CYAN) - - # Print sub-scorers recursively - sub_scorers = scorer_identifier.get_child_list("sub_scorers") - if sub_scorers: - self._print_colored(f"{indent} └─ Composite of {len(sub_scorers)} scorer(s):", Fore.CYAN) - for sub_scorer_id in sub_scorers: - self._print_scorer_info(sub_scorer_id, indent_level=indent_level + 3) - - def _print_objective_metrics(self, metrics: Optional["ObjectiveScorerMetrics"]) -> None: - """ - Print objective scorer evaluation metrics. - - Args: - metrics (Optional[ObjectiveScorerMetrics]): The metrics to print, or None if not available. 
- """ - if metrics is None: - print() - self._print_colored(f"{self._indent * 2}▸ Performance Metrics", Fore.WHITE) - self._print_colored( - f"{self._indent * 3}Official evaluation has not been run yet for this specific configuration", - Fore.YELLOW, - ) - return - - print() - self._print_colored(f"{self._indent * 2}▸ Performance Metrics", Fore.WHITE) - - # Accuracy: >= 0.9 is good, < 0.7 is bad - accuracy_color = self._get_quality_color( - metrics.accuracy, higher_is_better=True, good_threshold=0.9, bad_threshold=0.7 - ) - self._print_colored(f"{self._indent * 3}• Accuracy: {metrics.accuracy:.2%}", accuracy_color) - - if metrics.accuracy_standard_error is not None: - self._print_colored( - f"{self._indent * 3}• Accuracy Std Error: ±{metrics.accuracy_standard_error:.4f}", Fore.CYAN - ) - - # F1 Score: >= 0.9 is good, < 0.7 is bad - if metrics.f1_score is not None: - f1_color = self._get_quality_color( - metrics.f1_score, higher_is_better=True, good_threshold=0.9, bad_threshold=0.7 - ) - self._print_colored(f"{self._indent * 3}• F1 Score: {metrics.f1_score:.4f}", f1_color) - - # Precision: >= 0.9 is good, < 0.7 is bad - if metrics.precision is not None: - precision_color = self._get_quality_color( - metrics.precision, higher_is_better=True, good_threshold=0.9, bad_threshold=0.7 - ) - self._print_colored(f"{self._indent * 3}• Precision: {metrics.precision:.4f}", precision_color) - - # Recall: >= 0.9 is good, < 0.7 is bad - if metrics.recall is not None: - recall_color = self._get_quality_color( - metrics.recall, higher_is_better=True, good_threshold=0.9, bad_threshold=0.7 - ) - self._print_colored(f"{self._indent * 3}• Recall: {metrics.recall:.4f}", recall_color) - - # Average Score Time: < 0.5s is good, > 3.0s is bad - if metrics.average_score_time_seconds is not None: - time_color = self._get_quality_color( - metrics.average_score_time_seconds, higher_is_better=False, good_threshold=0.5, bad_threshold=3.0 - ) - self._print_colored( - f"{self._indent * 3}• Average Score Time: {metrics.average_score_time_seconds:.2f}s", time_color - ) - - def _print_harm_metrics(self, metrics: Optional["HarmScorerMetrics"]) -> None: - """ - Print harm scorer evaluation metrics. - - Args: - metrics (Optional[HarmScorerMetrics]): The metrics to print, or None if not available. 
- """ - if metrics is None: - print() - self._print_colored(f"{self._indent * 2}▸ Performance Metrics", Fore.WHITE) - self._print_colored( - f"{self._indent * 3}Official evaluation has not been run yet for this specific configuration", - Fore.YELLOW, - ) - return - - print() - self._print_colored(f"{self._indent * 2}▸ Performance Metrics", Fore.WHITE) - - # MAE: <= 0.1 is good, > 0.25 is bad (lower is better) - mae_color = self._get_quality_color( - metrics.mean_absolute_error, higher_is_better=False, good_threshold=0.1, bad_threshold=0.25 - ) - self._print_colored(f"{self._indent * 3}• Mean Absolute Error: {metrics.mean_absolute_error:.4f}", mae_color) - - if metrics.mae_standard_error is not None: - self._print_colored(f"{self._indent * 3}• MAE Std Error: ±{metrics.mae_standard_error:.4f}", Fore.CYAN) - - # Krippendorff Alpha: >= 0.8 is strong agreement, < 0.6 is weak agreement - if metrics.krippendorff_alpha_combined is not None: - alpha_color = self._get_quality_color( - metrics.krippendorff_alpha_combined, higher_is_better=True, good_threshold=0.8, bad_threshold=0.6 - ) - self._print_colored( - f"{self._indent * 3}• Krippendorff Alpha (Combined): {metrics.krippendorff_alpha_combined:.4f}", - alpha_color, - ) - - if metrics.krippendorff_alpha_model is not None: - alpha_model_color = self._get_quality_color( - metrics.krippendorff_alpha_model, higher_is_better=True, good_threshold=0.8, bad_threshold=0.6 - ) - self._print_colored( - f"{self._indent * 3}• Krippendorff Alpha (Model): {metrics.krippendorff_alpha_model:.4f}", - alpha_model_color, - ) - - # Average Score Time: < 1s is good, > 3.0s is bad - if metrics.average_score_time_seconds is not None: - time_color = self._get_quality_color( - metrics.average_score_time_seconds, higher_is_better=False, good_threshold=1.0, bad_threshold=3.0 - ) - self._print_colored( - f"{self._indent * 3}• Average Score Time: {metrics.average_score_time_seconds:.2f}s", time_color - ) + return find_harm_metrics_by_eval_hash(eval_hash=eval_hash, harm_category=harm_category) From 30f61515d17e03ead91ce868a213d46fa8503c08 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 09:27:36 -0700 Subject: [PATCH 02/34] Consolidate all printers into pyrit/printer/ module Move framework CentralMemory implementations into pyrit/printer/ alongside their base classes. CentralMemory is imported lazily inside constructors, so thin clients importing the module never pay the SQLAlchemy cost. 
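
For illustration only, a thin client outside the framework could subclass the
console base and fetch data over REST instead of CentralMemory. This is a
minimal sketch, not part of this patch: the endpoint paths, the use of httpx,
and the from_dict helpers are assumptions.

    # Hypothetical thin-client printer. Importing pyrit.printer and
    # pyrit.models does not pull in CentralMemory or SQLAlchemy.
    import httpx

    from pyrit.models import Message, Score
    from pyrit.printer.attack_result.console import ConsoleAttackPrinterBase


    class RestAttackResultPrinter(ConsoleAttackPrinterBase):
        def __init__(self, *, base_url: str, **kwargs) -> None:
            super().__init__(**kwargs)
            self._base_url = base_url

        async def get_conversation_async(self, conversation_id: str) -> list[Message]:
            # Assumed endpoint returning a JSON list of serialized messages.
            async with httpx.AsyncClient(base_url=self._base_url) as client:
                response = await client.get(f"/conversations/{conversation_id}")
                return [Message.from_dict(m) for m in response.json()]

        async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]:
            # Assumed endpoint returning a JSON list of serialized scores.
            async with httpx.AsyncClient(base_url=self._base_url) as client:
                response = await client.get("/scores", params={"prompt_ids": prompt_ids})
                return [Score.from_dict(s) for s in response.json()]

        async def display_image_async(self, piece: object) -> None:
            # A console-only thin client can skip image rendering.
            pass
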
- ConsoleAttackResultPrinter now lives in pyrit.printer.attack_result.console - ConsoleScenarioResultPrinter now lives in pyrit.printer.scenario_result.console - ConsoleScorerPrinter now lives in pyrit.printer.scorer.console - Old locations (executor/attack/printer/, scenario/printer/, score/printer/) become pure re-exports for backward compatibility - Updated test patch paths to match new module locations Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/executor/attack/printer/__init__.py | 12 +++- .../attack/printer/console_printer.py | 60 ++----------------- pyrit/printer/attack_result/console.py | 39 +++++++++++- pyrit/printer/scenario_result/console.py | 38 ++++++++++++ pyrit/printer/scorer/console.py | 25 ++++++++ pyrit/scenario/printer/console_printer.py | 47 +++------------ pyrit/score/printer/console_scorer_printer.py | 52 +++------------- .../attack/printer/test_console_printer.py | 14 ++--- .../unit/score/test_console_scorer_printer.py | 2 +- 9 files changed, 140 insertions(+), 149 deletions(-) diff --git a/pyrit/executor/attack/printer/__init__.py b/pyrit/executor/attack/printer/__init__.py index d5162a31f0..99bd415386 100644 --- a/pyrit/executor/attack/printer/__init__.py +++ b/pyrit/executor/attack/printer/__init__.py @@ -1,10 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -"""Attack result printers module.""" +""" +Deprecated: Import from pyrit.printer instead. +Attack result printers have moved to pyrit.printer.attack_result. +These re-exports are provided for backward compatibility. +""" + +from pyrit.common.deprecation import print_deprecation_message from pyrit.executor.attack.printer.attack_result_printer import AttackResultPrinter -from pyrit.executor.attack.printer.console_printer import ConsoleAttackResultPrinter +from pyrit.printer.attack_result.console import ConsoleAttackResultPrinter + +# MarkdownAttackResultPrinter is not yet refactored, keep the old import from pyrit.executor.attack.printer.markdown_printer import MarkdownAttackResultPrinter __all__ = [ diff --git a/pyrit/executor/attack/printer/console_printer.py b/pyrit/executor/attack/printer/console_printer.py index 1e17896e88..9c5ae68809 100644 --- a/pyrit/executor/attack/printer/console_printer.py +++ b/pyrit/executor/attack/printer/console_printer.py @@ -2,60 +2,10 @@ # Licensed under the MIT license. from pyrit.common.display_response import display_image_response -from pyrit.memory import CentralMemory from pyrit.models import Message, Score -from pyrit.printer.attack_result.console import ConsoleAttackPrinterBase +from pyrit.printer.attack_result.console import ConsoleAttackPrinterBase, ConsoleAttackResultPrinter - -class ConsoleAttackResultPrinter(ConsoleAttackPrinterBase): - """ - Framework console printer for attack results. - - Thin subclass that implements data-fetching via CentralMemory. - All formatting logic lives in ConsoleAttackPrinterBase. - """ - - def __init__(self, *, width: int = 100, indent_size: int = 2, enable_colors: bool = True) -> None: - """ - Initialize the console printer. - - Args: - width (int): Maximum width for text wrapping. Defaults to 100. - indent_size (int): Number of spaces for indentation. Defaults to 2. - enable_colors (bool): Whether to enable ANSI color output. Defaults to True. 
- """ - super().__init__(width=width, indent_size=indent_size, enable_colors=enable_colors) - self._memory = CentralMemory.get_memory_instance() - - async def get_conversation_async(self, conversation_id: str) -> list[Message]: - """ - Fetch conversation messages from CentralMemory. - - Args: - conversation_id (str): The conversation ID to fetch. - - Returns: - list[Message]: The conversation messages. - """ - return list(self._memory.get_conversation(conversation_id=conversation_id)) - - async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: - """ - Fetch scores from CentralMemory. - - Args: - prompt_ids (list[str]): The message piece IDs to fetch scores for. - - Returns: - list[Score]: The scores. - """ - return self._memory.get_prompt_scores(prompt_ids=prompt_ids) - - async def display_image_async(self, piece: object) -> None: - """ - Display images using PIL/IPython in notebook environments. - - Args: - piece: The message piece that may contain image data. - """ - await display_image_response(piece) +__all__ = [ + "ConsoleAttackPrinterBase", + "ConsoleAttackResultPrinter", +] diff --git a/pyrit/printer/attack_result/console.py b/pyrit/printer/attack_result/console.py index 3b3829dbb4..71c9e2616f 100644 --- a/pyrit/printer/attack_result/console.py +++ b/pyrit/printer/attack_result/console.py @@ -8,7 +8,7 @@ from colorama import Back, Fore, Style -from pyrit.models import AttackOutcome, AttackResult, ConversationType, Score +from pyrit.models import AttackOutcome, AttackResult, ConversationType, Message, Score from pyrit.printer.attack_result.base import AttackResultPrinterBase @@ -482,3 +482,40 @@ def _get_outcome_color(self, outcome: AttackOutcome) -> str: AttackOutcome.UNDETERMINED: Fore.YELLOW, }.get(outcome, Fore.WHITE) ) + + +class ConsoleAttackResultPrinter(ConsoleAttackPrinterBase): + """ + Framework console printer for attack results. + + Implements data-fetching via CentralMemory (deferred import). + All formatting logic lives in ConsoleAttackPrinterBase. + """ + + def __init__(self, *, width: int = 100, indent_size: int = 2, enable_colors: bool = True) -> None: + """ + Initialize the console printer. + + Args: + width (int): Maximum width for text wrapping. Defaults to 100. + indent_size (int): Number of spaces for indentation. Defaults to 2. + enable_colors (bool): Whether to enable ANSI color output. Defaults to True. 
+ """ + super().__init__(width=width, indent_size=indent_size, enable_colors=enable_colors) + from pyrit.memory import CentralMemory + + self._memory = CentralMemory.get_memory_instance() + + async def get_conversation_async(self, conversation_id: str) -> list[Message]: + """Fetch conversation messages from CentralMemory.""" + return list(self._memory.get_conversation(conversation_id=conversation_id)) + + async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: + """Fetch scores from CentralMemory.""" + return self._memory.get_prompt_scores(prompt_ids=prompt_ids) + + async def display_image_async(self, piece: object) -> None: + """Display images using PIL/IPython in notebook environments.""" + from pyrit.common.display_response import display_image_response + + await display_image_response(piece) diff --git a/pyrit/printer/scenario_result/console.py b/pyrit/printer/scenario_result/console.py index 51cc7f307c..2b373daaa6 100644 --- a/pyrit/printer/scenario_result/console.py +++ b/pyrit/printer/scenario_result/console.py @@ -176,3 +176,41 @@ def _get_rate_color(self, rate: int) -> str: if rate >= 25: return str(Fore.CYAN) return str(Fore.GREEN) + + +class ConsoleScenarioResultPrinter(ConsoleScenarioPrinterBase): + """ + Framework console printer for scenario results. + + Provides the framework's ConsoleScorerPrinter for scorer information display. + All formatting logic lives in ConsoleScenarioPrinterBase. + """ + + def __init__( + self, + *, + width: int = 100, + indent_size: int = 2, + enable_colors: bool = True, + scorer_printer: Optional[ScorerPrinterBase] = None, + ) -> None: + """ + Initialize the console printer. + + Args: + width (int): Maximum width for text wrapping. Defaults to 100. + indent_size (int): Number of spaces for indentation. Defaults to 2. + enable_colors (bool): Whether to enable ANSI color output. Defaults to True. + scorer_printer (Optional[ScorerPrinterBase]): Printer for scorer information. + If not provided, a ConsoleScorerPrinter with matching settings is created. + """ + if scorer_printer is None: + from pyrit.printer.scorer.console import ConsoleScorerPrinter + + scorer_printer = ConsoleScorerPrinter(indent_size=indent_size, enable_colors=enable_colors) + super().__init__( + width=width, + indent_size=indent_size, + enable_colors=enable_colors, + scorer_printer=scorer_printer, + ) diff --git a/pyrit/printer/scorer/console.py b/pyrit/printer/scorer/console.py index 754a56be6f..87fb1c2cdf 100644 --- a/pyrit/printer/scorer/console.py +++ b/pyrit/printer/scorer/console.py @@ -256,3 +256,28 @@ def _print_harm_metrics(self, metrics: Optional[Any]) -> None: self._print_colored( f"{self._indent * 3}• Average Score Time: {metrics.average_score_time_seconds:.2f}s", time_color ) + + +class ConsoleScorerPrinter(ConsoleScorerPrinterBase): + """ + Framework console printer for scorer information. + + Implements metrics fetching via the scorer evaluation registry (deferred import). + All formatting logic lives in ConsoleScorerPrinterBase. 
+ """ + + def get_objective_metrics(self, *, eval_hash: str) -> Any: + """Fetch objective scorer evaluation metrics from the registry.""" + from pyrit.score.scorer_evaluation.scorer_metrics_io import ( + find_objective_metrics_by_eval_hash, + ) + + return find_objective_metrics_by_eval_hash(eval_hash=eval_hash) + + def get_harm_metrics(self, *, eval_hash: str, harm_category: str) -> Any: + """Fetch harm scorer evaluation metrics from the registry.""" + from pyrit.score.scorer_evaluation.scorer_metrics_io import ( + find_harm_metrics_by_eval_hash, + ) + + return find_harm_metrics_by_eval_hash(eval_hash=eval_hash, harm_category=harm_category) diff --git a/pyrit/scenario/printer/console_printer.py b/pyrit/scenario/printer/console_printer.py index 3679f2b99c..1325351240 100644 --- a/pyrit/scenario/printer/console_printer.py +++ b/pyrit/scenario/printer/console_printer.py @@ -1,44 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional +""" +Deprecated: Import from pyrit.printer.scenario_result.console instead. +""" -from pyrit.printer.scenario_result.console import ConsoleScenarioPrinterBase -from pyrit.printer.scorer.base import ScorerPrinterBase -from pyrit.score.printer import ConsoleScorerPrinter +from pyrit.printer.scenario_result.console import ConsoleScenarioPrinterBase, ConsoleScenarioResultPrinter - -class ConsoleScenarioResultPrinter(ConsoleScenarioPrinterBase): - """ - Framework console printer for scenario results. - - Thin subclass that provides the framework's ConsoleScorerPrinter - for scorer information. All formatting logic lives in ConsoleScenarioPrinterBase. - """ - - def __init__( - self, - *, - width: int = 100, - indent_size: int = 2, - enable_colors: bool = True, - scorer_printer: Optional[ScorerPrinterBase] = None, - ) -> None: - """ - Initialize the console printer. - - Args: - width (int): Maximum width for text wrapping. Defaults to 100. - indent_size (int): Number of spaces for indentation. Defaults to 2. - enable_colors (bool): Whether to enable ANSI color output. Defaults to True. - scorer_printer (Optional[ScorerPrinterBase]): Printer for scorer information. - If not provided, a ConsoleScorerPrinter with matching settings is created. - """ - if scorer_printer is None: - scorer_printer = ConsoleScorerPrinter(indent_size=indent_size, enable_colors=enable_colors) - super().__init__( - width=width, - indent_size=indent_size, - enable_colors=enable_colors, - scorer_printer=scorer_printer, - ) +__all__ = [ + "ConsoleScenarioPrinterBase", + "ConsoleScenarioResultPrinter", +] diff --git a/pyrit/score/printer/console_scorer_printer.py b/pyrit/score/printer/console_scorer_printer.py index c3b2165115..3b0aed1cee 100644 --- a/pyrit/score/printer/console_scorer_printer.py +++ b/pyrit/score/printer/console_scorer_printer.py @@ -1,49 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Any +""" +Deprecated: Import from pyrit.printer.scorer.console instead. +""" -from pyrit.identifiers import ComponentIdentifier -from pyrit.printer.scorer.console import ConsoleScorerPrinterBase +from pyrit.printer.scorer.console import ConsoleScorerPrinter, ConsoleScorerPrinterBase - -class ConsoleScorerPrinter(ConsoleScorerPrinterBase): - """ - Framework console printer for scorer information. - - Thin subclass that implements metrics fetching via the scorer evaluation registry. - All formatting logic lives in ConsoleScorerPrinterBase. 
- """ - - def get_objective_metrics(self, *, eval_hash: str) -> Any: - """ - Fetch objective scorer evaluation metrics from the registry. - - Args: - eval_hash (str): The evaluation hash to look up. - - Returns: - ObjectiveScorerMetrics or None: The metrics, or None if not found. - """ - from pyrit.score.scorer_evaluation.scorer_metrics_io import ( - find_objective_metrics_by_eval_hash, - ) - - return find_objective_metrics_by_eval_hash(eval_hash=eval_hash) - - def get_harm_metrics(self, *, eval_hash: str, harm_category: str) -> Any: - """ - Fetch harm scorer evaluation metrics from the registry. - - Args: - eval_hash (str): The evaluation hash to look up. - harm_category (str): The harm category for metrics lookup. - - Returns: - HarmScorerMetrics or None: The metrics, or None if not found. - """ - from pyrit.score.scorer_evaluation.scorer_metrics_io import ( - find_harm_metrics_by_eval_hash, - ) - - return find_harm_metrics_by_eval_hash(eval_hash=eval_hash, harm_category=harm_category) +__all__ = [ + "ConsoleScorerPrinter", + "ConsoleScorerPrinterBase", +] diff --git a/tests/unit/executor/attack/printer/test_console_printer.py b/tests/unit/executor/attack/printer/test_console_printer.py index b8195db5ba..2a6f4aa30d 100644 --- a/tests/unit/executor/attack/printer/test_console_printer.py +++ b/tests/unit/executor/attack/printer/test_console_printer.py @@ -6,7 +6,7 @@ import pytest -from pyrit.executor.attack.printer.console_printer import ConsoleAttackResultPrinter +from pyrit.printer.attack_result.console import ConsoleAttackResultPrinter from pyrit.identifiers import ComponentIdentifier from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.models import AttackOutcome, AttackResult, ConversationType, Message, MessagePiece, Score @@ -22,7 +22,7 @@ def mock_memory(): memory = MagicMock() memory.get_conversation.return_value = [] memory.get_prompt_scores.return_value = [] - with patch("pyrit.executor.attack.printer.console_printer.CentralMemory") as mock_cm: + with patch("pyrit.memory.CentralMemory") as mock_cm: mock_cm.get_memory_instance.return_value = memory yield memory @@ -227,7 +227,7 @@ async def test_print_messages_async_empty_list(printer, capsys): assert "No messages to display" in captured.out -@patch("pyrit.executor.attack.printer.console_printer.display_image_response", new_callable=AsyncMock) +@patch("pyrit.common.display_response.display_image_response", new_callable=AsyncMock) async def test_print_messages_async_user_message(mock_display, printer, sample_message, capsys): await printer.print_messages_async(messages=[sample_message]) captured = capsys.readouterr() @@ -236,7 +236,7 @@ async def test_print_messages_async_user_message(mock_display, printer, sample_m assert "Hello world" in captured.out -@patch("pyrit.executor.attack.printer.console_printer.display_image_response", new_callable=AsyncMock) +@patch("pyrit.common.display_response.display_image_response", new_callable=AsyncMock) async def test_print_messages_async_assistant_message(mock_display, printer, capsys): piece = MessagePiece( role="assistant", @@ -250,7 +250,7 @@ async def test_print_messages_async_assistant_message(mock_display, printer, cap assert "Response" in captured.out -@patch("pyrit.executor.attack.printer.console_printer.display_image_response", new_callable=AsyncMock) +@patch("pyrit.common.display_response.display_image_response", new_callable=AsyncMock) async def test_print_messages_async_converted_differs(mock_display, printer, capsys): piece = 
MessagePiece( role="user", @@ -347,7 +347,7 @@ def test_print_wrapped_text_with_newlines(printer, capsys): assert "Line four" in captured.out -@patch("pyrit.executor.attack.printer.console_printer.display_image_response", new_callable=AsyncMock) +@patch("pyrit.common.display_response.display_image_response", new_callable=AsyncMock) async def test_print_messages_async_blocked_without_partial_content(mock_display, printer, capsys): piece = MessagePiece( role="assistant", @@ -364,7 +364,7 @@ async def test_print_messages_async_blocked_without_partial_content(mock_display assert "status_code" not in captured.out -@patch("pyrit.executor.attack.printer.console_printer.display_image_response", new_callable=AsyncMock) +@patch("pyrit.common.display_response.display_image_response", new_callable=AsyncMock) async def test_print_messages_async_blocked_with_partial_content(mock_display, printer, capsys): piece = MessagePiece( role="assistant", diff --git a/tests/unit/score/test_console_scorer_printer.py b/tests/unit/score/test_console_scorer_printer.py index fc7d1e64fb..23fb2c799f 100644 --- a/tests/unit/score/test_console_scorer_printer.py +++ b/tests/unit/score/test_console_scorer_printer.py @@ -7,7 +7,7 @@ from colorama import Fore, Style from pyrit.identifiers import ComponentIdentifier -from pyrit.score.printer.console_scorer_printer import ConsoleScorerPrinter +from pyrit.printer.scorer.console import ConsoleScorerPrinter from pyrit.score.scorer_evaluation.scorer_metrics import ( HarmScorerMetrics, ObjectiveScorerMetrics, From de61795191590cdb707cefe8b0e9cda37d11f671 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 09:37:21 -0700 Subject: [PATCH 03/34] Add deprecation warnings for old printer import paths (removed in 0.16.0) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Old locations now use PEP 562 __getattr__ lazy re-exports with DeprecationWarning. Only concrete classes are re-exported (not bases). 
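For callers, the effect is that the old import path still resolves but now emits a
DeprecationWarning. As a minimal, illustrative sketch (not part of this patch, and
assuming the caller has not otherwise filtered warnings):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Old path: resolved lazily through the module-level __getattr__ shim.
        from pyrit.executor.attack.printer import ConsoleAttackResultPrinter  # noqa: F401

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

The moved paths are: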
- pyrit.executor.attack.printer → pyrit.printer.attack_result - pyrit.scenario.printer → pyrit.printer.scenario_result - pyrit.score.printer → pyrit.printer.scorer - Updated all internal callers to new canonical paths - Old ABC files (attack_result_printer.py, scenario_result_printer.py, scorer_printer.py) kept for now but deprecated via __init__.py Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/cli/frontend_core.py | 2 +- pyrit/cli/pyrit_shell.py | 4 +- pyrit/executor/attack/__init__.py | 4 +- pyrit/executor/attack/printer/__init__.py | 37 ++++++++++++++++--- .../attack/printer/console_printer.py | 28 ++++++++++---- pyrit/scenario/printer/__init__.py | 37 +++++++++++++++++-- pyrit/scenario/printer/console_printer.py | 20 +++++++--- pyrit/score/__init__.py | 3 +- pyrit/score/printer/__init__.py | 33 +++++++++++++++-- pyrit/score/printer/console_scorer_printer.py | 20 +++++++--- .../printer/test_attack_result_printer.py | 2 +- tests/unit/score/test_scorer_printer.py | 8 +++- 12 files changed, 160 insertions(+), 38 deletions(-) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index c17eb83b54..7d01471824 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -41,7 +41,7 @@ from pyrit.cli._cli_args import validate_log_level_argparse as validate_log_level_argparse from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry from pyrit.scenario import DatasetConfiguration -from pyrit.scenario.printer.console_printer import ConsoleScenarioResultPrinter +from pyrit.printer.scenario_result.console import ConsoleScenarioResultPrinter from pyrit.setup import ConfigurationLoader, initialize_pyrit_async from pyrit.setup.configuration_loader import _MEMORY_DB_TYPE_MAP diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 23cf54fb3c..c2fb309c7d 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -483,7 +483,7 @@ def do_print_scenario(self, arg: str) -> None: print(f"\n{'#' * 80}") print(f"Scenario Run #{idx}: {command}") print(f"{'#' * 80}") - from pyrit.scenario.printer.console_printer import ( + from pyrit.printer.scenario_result.console import ( ConsoleScenarioResultPrinter, ) @@ -500,7 +500,7 @@ def do_print_scenario(self, arg: str) -> None: command, result = self._scenario_history[scenario_num - 1] print(f"\nScenario Run #{scenario_num}: {command}") print("=" * 80) - from pyrit.scenario.printer.console_printer import ( + from pyrit.printer.scenario_result.console import ( ConsoleScenarioResultPrinter, ) diff --git a/pyrit/executor/attack/__init__.py b/pyrit/executor/attack/__init__.py index 1dfb17b6c5..d197dcd61b 100644 --- a/pyrit/executor/attack/__init__.py +++ b/pyrit/executor/attack/__init__.py @@ -40,7 +40,9 @@ ) # Import printer modules last to avoid circular dependencies -from pyrit.executor.attack.printer import AttackResultPrinter, ConsoleAttackResultPrinter, MarkdownAttackResultPrinter +from pyrit.executor.attack.printer.markdown_printer import MarkdownAttackResultPrinter +from pyrit.printer.attack_result.base import AttackResultPrinterBase as AttackResultPrinter +from pyrit.printer.attack_result.console import ConsoleAttackResultPrinter from pyrit.executor.attack.single_turn import ( ContextComplianceAttack, FlipAttack, diff --git a/pyrit/executor/attack/printer/__init__.py b/pyrit/executor/attack/printer/__init__.py index 99bd415386..0ca2095610 100644 --- a/pyrit/executor/attack/printer/__init__.py +++ b/pyrit/executor/attack/printer/__init__.py @@ -5,15 
+5,40 @@ Deprecated: Import from pyrit.printer instead. Attack result printers have moved to pyrit.printer.attack_result. -These re-exports are provided for backward compatibility. +These re-exports will be removed in 0.16.0. """ -from pyrit.common.deprecation import print_deprecation_message -from pyrit.executor.attack.printer.attack_result_printer import AttackResultPrinter -from pyrit.printer.attack_result.console import ConsoleAttackResultPrinter +import warnings as _warnings + + +def __getattr__(name: str): # noqa: N807 + _deprecated = { + "ConsoleAttackResultPrinter": "pyrit.printer.attack_result.console", + "AttackResultPrinter": "pyrit.printer.attack_result.base", + "MarkdownAttackResultPrinter": "pyrit.executor.attack.printer.markdown_printer", + } + if name in _deprecated: + new_module = _deprecated[name] + _warnings.warn( + f"Importing {name} from pyrit.executor.attack.printer is deprecated and will be removed in 0.16.0. " + f"Import from {new_module} instead.", + DeprecationWarning, + stacklevel=2, + ) + if name == "ConsoleAttackResultPrinter": + from pyrit.printer.attack_result.console import ConsoleAttackResultPrinter + + return ConsoleAttackResultPrinter + if name == "AttackResultPrinter": + from pyrit.printer.attack_result.base import AttackResultPrinterBase + + return AttackResultPrinterBase + if name == "MarkdownAttackResultPrinter": + from pyrit.executor.attack.printer.markdown_printer import MarkdownAttackResultPrinter + + return MarkdownAttackResultPrinter + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -# MarkdownAttackResultPrinter is not yet refactored, keep the old import -from pyrit.executor.attack.printer.markdown_printer import MarkdownAttackResultPrinter __all__ = [ "AttackResultPrinter", diff --git a/pyrit/executor/attack/printer/console_printer.py b/pyrit/executor/attack/printer/console_printer.py index 9c5ae68809..c515c113ed 100644 --- a/pyrit/executor/attack/printer/console_printer.py +++ b/pyrit/executor/attack/printer/console_printer.py @@ -1,11 +1,23 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from pyrit.common.display_response import display_image_response -from pyrit.models import Message, Score -from pyrit.printer.attack_result.console import ConsoleAttackPrinterBase, ConsoleAttackResultPrinter - -__all__ = [ - "ConsoleAttackPrinterBase", - "ConsoleAttackResultPrinter", -] +""" +Deprecated: Import from pyrit.printer.attack_result.console instead. +This re-export will be removed in 0.16.0. +""" + +import warnings as _warnings + + +def __getattr__(name: str): # noqa: N807 + if name == "ConsoleAttackResultPrinter": + _warnings.warn( + "Importing ConsoleAttackResultPrinter from pyrit.executor.attack.printer.console_printer is deprecated " + "and will be removed in 0.16.0. Import from pyrit.printer.attack_result.console instead.", + DeprecationWarning, + stacklevel=2, + ) + from pyrit.printer.attack_result.console import ConsoleAttackResultPrinter + + return ConsoleAttackResultPrinter + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyrit/scenario/printer/__init__.py b/pyrit/scenario/printer/__init__.py index 421a332c64..ea6422827b 100644 --- a/pyrit/scenario/printer/__init__.py +++ b/pyrit/scenario/printer/__init__.py @@ -1,12 +1,41 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -"""Printer components for scenarios.""" +""" +Deprecated: Import from pyrit.printer instead. 
+ +Scenario result printers have moved to pyrit.printer.scenario_result. +These re-exports will be removed in 0.16.0. +""" + +import warnings as _warnings + + +def __getattr__(name: str): # noqa: N807 + _deprecated = { + "ConsoleScenarioResultPrinter": "pyrit.printer.scenario_result.console", + "ScenarioResultPrinter": "pyrit.printer.scenario_result.base", + } + if name in _deprecated: + new_module = _deprecated[name] + _warnings.warn( + f"Importing {name} from pyrit.scenario.printer is deprecated and will be removed in 0.16.0. " + f"Import from {new_module} instead.", + DeprecationWarning, + stacklevel=2, + ) + if name == "ConsoleScenarioResultPrinter": + from pyrit.printer.scenario_result.console import ConsoleScenarioResultPrinter + + return ConsoleScenarioResultPrinter + if name == "ScenarioResultPrinter": + from pyrit.printer.scenario_result.base import ScenarioResultPrinterBase + + return ScenarioResultPrinterBase + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -from pyrit.scenario.printer.console_printer import ConsoleScenarioResultPrinter -from pyrit.scenario.printer.scenario_result_printer import ScenarioResultPrinter __all__ = [ - "ScenarioResultPrinter", "ConsoleScenarioResultPrinter", + "ScenarioResultPrinter", ] diff --git a/pyrit/scenario/printer/console_printer.py b/pyrit/scenario/printer/console_printer.py index 1325351240..12c1a4ad49 100644 --- a/pyrit/scenario/printer/console_printer.py +++ b/pyrit/scenario/printer/console_printer.py @@ -3,11 +3,21 @@ """ Deprecated: Import from pyrit.printer.scenario_result.console instead. +This re-export will be removed in 0.16.0. """ -from pyrit.printer.scenario_result.console import ConsoleScenarioPrinterBase, ConsoleScenarioResultPrinter +import warnings as _warnings -__all__ = [ - "ConsoleScenarioPrinterBase", - "ConsoleScenarioResultPrinter", -] + +def __getattr__(name: str): # noqa: N807 + if name == "ConsoleScenarioResultPrinter": + _warnings.warn( + "Importing ConsoleScenarioResultPrinter from pyrit.scenario.printer.console_printer is deprecated " + "and will be removed in 0.16.0. Import from pyrit.printer.scenario_result.console instead.", + DeprecationWarning, + stacklevel=2, + ) + from pyrit.printer.scenario_result.console import ConsoleScenarioResultPrinter + + return ConsoleScenarioResultPrinter + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyrit/score/__init__.py b/pyrit/score/__init__.py index 5aa0e9ac2d..886a4d6d4f 100644 --- a/pyrit/score/__init__.py +++ b/pyrit/score/__init__.py @@ -23,7 +23,8 @@ from pyrit.score.float_scale.self_ask_general_float_scale_scorer import SelfAskGeneralFloatScaleScorer from pyrit.score.float_scale.self_ask_likert_scorer import LikertScaleEvalFiles, LikertScalePaths, SelfAskLikertScorer from pyrit.score.float_scale.self_ask_scale_scorer import SelfAskScaleScorer -from pyrit.score.printer import ConsoleScorerPrinter, ScorerPrinter +from pyrit.printer.scorer.base import ScorerPrinterBase as ScorerPrinter +from pyrit.printer.scorer.console import ConsoleScorerPrinter from pyrit.score.scorer import Scorer from pyrit.score.scorer_evaluation.metrics_type import MetricsType, RegistryUpdateBehavior from pyrit.score.scorer_evaluation.scorer_metrics import ( diff --git a/pyrit/score/printer/__init__.py b/pyrit/score/printer/__init__.py index d66e21a894..a4f6c3d683 100644 --- a/pyrit/score/printer/__init__.py +++ b/pyrit/score/printer/__init__.py @@ -2,11 +2,38 @@ # Licensed under the MIT license. 
""" -Scorer printer classes for displaying scorer information in various formats. +Deprecated: Import from pyrit.printer instead. + +Scorer printers have moved to pyrit.printer.scorer. +These re-exports will be removed in 0.16.0. """ -from pyrit.score.printer.console_scorer_printer import ConsoleScorerPrinter -from pyrit.score.printer.scorer_printer import ScorerPrinter +import warnings as _warnings + + +def __getattr__(name: str): # noqa: N807 + _deprecated = { + "ConsoleScorerPrinter": "pyrit.printer.scorer.console", + "ScorerPrinter": "pyrit.printer.scorer.base", + } + if name in _deprecated: + new_module = _deprecated[name] + _warnings.warn( + f"Importing {name} from pyrit.score.printer is deprecated and will be removed in 0.16.0. " + f"Import from {new_module} instead.", + DeprecationWarning, + stacklevel=2, + ) + if name == "ConsoleScorerPrinter": + from pyrit.printer.scorer.console import ConsoleScorerPrinter + + return ConsoleScorerPrinter + if name == "ScorerPrinter": + from pyrit.printer.scorer.base import ScorerPrinterBase + + return ScorerPrinterBase + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + __all__ = [ "ConsoleScorerPrinter", diff --git a/pyrit/score/printer/console_scorer_printer.py b/pyrit/score/printer/console_scorer_printer.py index 3b0aed1cee..8a75edbc81 100644 --- a/pyrit/score/printer/console_scorer_printer.py +++ b/pyrit/score/printer/console_scorer_printer.py @@ -3,11 +3,21 @@ """ Deprecated: Import from pyrit.printer.scorer.console instead. +This re-export will be removed in 0.16.0. """ -from pyrit.printer.scorer.console import ConsoleScorerPrinter, ConsoleScorerPrinterBase +import warnings as _warnings -__all__ = [ - "ConsoleScorerPrinter", - "ConsoleScorerPrinterBase", -] + +def __getattr__(name: str): # noqa: N807 + if name == "ConsoleScorerPrinter": + _warnings.warn( + "Importing ConsoleScorerPrinter from pyrit.score.printer.console_scorer_printer is deprecated " + "and will be removed in 0.16.0. 
Import from pyrit.printer.scorer.console instead.", + DeprecationWarning, + stacklevel=2, + ) + from pyrit.printer.scorer.console import ConsoleScorerPrinter + + return ConsoleScorerPrinter + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/tests/unit/executor/attack/printer/test_attack_result_printer.py b/tests/unit/executor/attack/printer/test_attack_result_printer.py index f8075d45ed..4c51834b91 100644 --- a/tests/unit/executor/attack/printer/test_attack_result_printer.py +++ b/tests/unit/executor/attack/printer/test_attack_result_printer.py @@ -3,7 +3,7 @@ import pytest -from pyrit.executor.attack.printer.attack_result_printer import AttackResultPrinter +from pyrit.printer.attack_result.base import AttackResultPrinterBase as AttackResultPrinter from pyrit.models import AttackOutcome diff --git a/tests/unit/score/test_scorer_printer.py b/tests/unit/score/test_scorer_printer.py index edd8b6a26f..cda073893d 100644 --- a/tests/unit/score/test_scorer_printer.py +++ b/tests/unit/score/test_scorer_printer.py @@ -4,7 +4,7 @@ import pytest from pyrit.identifiers import ComponentIdentifier -from pyrit.score.printer.scorer_printer import ScorerPrinter +from pyrit.printer.scorer.base import ScorerPrinterBase as ScorerPrinter def test_scorer_printer_cannot_be_instantiated(): @@ -38,5 +38,11 @@ def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> N def print_harm_scorer(self, scorer_identifier: ComponentIdentifier, *, harm_category: str) -> None: pass + def get_objective_metrics(self, *, eval_hash: str): + return None + + def get_harm_metrics(self, *, eval_hash: str, harm_category: str): + return None + printer = CompletePrinter() assert isinstance(printer, ScorerPrinter) From 837ed3f4834b4d1ef60ba06c00b1b216882a6900 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 09:47:32 -0700 Subject: [PATCH 04/34] Rename concrete printers to *MemoryPrinter, move pyrit internals out of bases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Concrete classes that use CentralMemory/scorer registry are now named to clearly indicate their data source: - ConsoleAttackResultPrinter → ConsoleAttackMemoryPrinter - ConsoleScenarioResultPrinter → ConsoleScenarioMemoryPrinter - ConsoleScorerPrinter → ConsoleScorerMemoryPrinter Moved ScorerEvaluationIdentifier (pyrit internal) from base class into the concrete ConsoleScorerMemoryPrinter. Base classes now contain only formatting logic with no pyrit-internal imports beyond models/identifiers. Deprecated re-exports at old paths still work (mapping old names to new), scheduled for removal in 0.16.0. 
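For downstream code the public names are unchanged: the existing packages keep exporting
the old names as aliases of the renamed classes. A quick illustrative check (a sketch only,
not part of the test changes in this patch):

    from pyrit.executor.attack import ConsoleAttackResultPrinter
    from pyrit.printer.attack_result.console import ConsoleAttackMemoryPrinter
    from pyrit.score import ConsoleScorerPrinter
    from pyrit.printer.scorer.console import ConsoleScorerMemoryPrinter

    # The old public names now resolve to the renamed *MemoryPrinter classes.
    assert ConsoleAttackResultPrinter is ConsoleAttackMemoryPrinter
    assert ConsoleScorerPrinter is ConsoleScorerMemoryPrinter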
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/cli/frontend_core.py | 2 +- pyrit/cli/pyrit_shell.py | 4 +- pyrit/executor/attack/__init__.py | 2 +- pyrit/executor/attack/printer/__init__.py | 4 +- pyrit/printer/attack_result/console.py | 2 +- pyrit/printer/scenario_result/console.py | 10 +-- pyrit/printer/scorer/base.py | 30 +------ pyrit/printer/scorer/console.py | 86 ++++++++----------- pyrit/scenario/printer/__init__.py | 4 +- pyrit/scenario/printer/console_printer.py | 4 +- pyrit/score/__init__.py | 2 +- pyrit/score/printer/__init__.py | 4 +- pyrit/score/printer/console_scorer_printer.py | 4 +- .../attack/printer/test_console_printer.py | 2 +- .../unit/score/test_console_scorer_printer.py | 2 +- 15 files changed, 62 insertions(+), 100 deletions(-) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 7d01471824..95a0faa829 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -41,7 +41,7 @@ from pyrit.cli._cli_args import validate_log_level_argparse as validate_log_level_argparse from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry from pyrit.scenario import DatasetConfiguration -from pyrit.printer.scenario_result.console import ConsoleScenarioResultPrinter +from pyrit.printer.scenario_result.console import ConsoleScenarioMemoryPrinter as ConsoleScenarioResultPrinter from pyrit.setup import ConfigurationLoader, initialize_pyrit_async from pyrit.setup.configuration_loader import _MEMORY_DB_TYPE_MAP diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index c2fb309c7d..368765e276 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -484,7 +484,7 @@ def do_print_scenario(self, arg: str) -> None: print(f"Scenario Run #{idx}: {command}") print(f"{'#' * 80}") from pyrit.printer.scenario_result.console import ( - ConsoleScenarioResultPrinter, + ConsoleScenarioMemoryPrinter as ConsoleScenarioResultPrinter, ) printer = ConsoleScenarioResultPrinter() @@ -501,7 +501,7 @@ def do_print_scenario(self, arg: str) -> None: print(f"\nScenario Run #{scenario_num}: {command}") print("=" * 80) from pyrit.printer.scenario_result.console import ( - ConsoleScenarioResultPrinter, + ConsoleScenarioMemoryPrinter as ConsoleScenarioResultPrinter, ) printer = ConsoleScenarioResultPrinter() diff --git a/pyrit/executor/attack/__init__.py b/pyrit/executor/attack/__init__.py index d197dcd61b..29afcd3277 100644 --- a/pyrit/executor/attack/__init__.py +++ b/pyrit/executor/attack/__init__.py @@ -42,7 +42,7 @@ # Import printer modules last to avoid circular dependencies from pyrit.executor.attack.printer.markdown_printer import MarkdownAttackResultPrinter from pyrit.printer.attack_result.base import AttackResultPrinterBase as AttackResultPrinter -from pyrit.printer.attack_result.console import ConsoleAttackResultPrinter +from pyrit.printer.attack_result.console import ConsoleAttackMemoryPrinter as ConsoleAttackResultPrinter from pyrit.executor.attack.single_turn import ( ContextComplianceAttack, FlipAttack, diff --git a/pyrit/executor/attack/printer/__init__.py b/pyrit/executor/attack/printer/__init__.py index 0ca2095610..914ba2942a 100644 --- a/pyrit/executor/attack/printer/__init__.py +++ b/pyrit/executor/attack/printer/__init__.py @@ -26,9 +26,9 @@ def __getattr__(name: str): # noqa: N807 stacklevel=2, ) if name == "ConsoleAttackResultPrinter": - from pyrit.printer.attack_result.console import ConsoleAttackResultPrinter + from pyrit.printer.attack_result.console import 
ConsoleAttackMemoryPrinter - return ConsoleAttackResultPrinter + return ConsoleAttackMemoryPrinter if name == "AttackResultPrinter": from pyrit.printer.attack_result.base import AttackResultPrinterBase diff --git a/pyrit/printer/attack_result/console.py b/pyrit/printer/attack_result/console.py index 71c9e2616f..aa96e062c7 100644 --- a/pyrit/printer/attack_result/console.py +++ b/pyrit/printer/attack_result/console.py @@ -484,7 +484,7 @@ def _get_outcome_color(self, outcome: AttackOutcome) -> str: ) -class ConsoleAttackResultPrinter(ConsoleAttackPrinterBase): +class ConsoleAttackMemoryPrinter(ConsoleAttackPrinterBase): """ Framework console printer for attack results. diff --git a/pyrit/printer/scenario_result/console.py b/pyrit/printer/scenario_result/console.py index 2b373daaa6..742ecfb44d 100644 --- a/pyrit/printer/scenario_result/console.py +++ b/pyrit/printer/scenario_result/console.py @@ -178,11 +178,11 @@ def _get_rate_color(self, rate: int) -> str: return str(Fore.GREEN) -class ConsoleScenarioResultPrinter(ConsoleScenarioPrinterBase): +class ConsoleScenarioMemoryPrinter(ConsoleScenarioPrinterBase): """ Framework console printer for scenario results. - Provides the framework's ConsoleScorerPrinter for scorer information display. + Provides the framework's ConsoleScorerMemoryPrinter for scorer information display. All formatting logic lives in ConsoleScenarioPrinterBase. """ @@ -202,12 +202,12 @@ def __init__( indent_size (int): Number of spaces for indentation. Defaults to 2. enable_colors (bool): Whether to enable ANSI color output. Defaults to True. scorer_printer (Optional[ScorerPrinterBase]): Printer for scorer information. - If not provided, a ConsoleScorerPrinter with matching settings is created. + If not provided, a ConsoleScorerMemoryPrinter with matching settings is created. """ if scorer_printer is None: - from pyrit.printer.scorer.console import ConsoleScorerPrinter + from pyrit.printer.scorer.console import ConsoleScorerMemoryPrinter - scorer_printer = ConsoleScorerPrinter(indent_size=indent_size, enable_colors=enable_colors) + scorer_printer = ConsoleScorerMemoryPrinter(indent_size=indent_size, enable_colors=enable_colors) super().__init__( width=width, indent_size=indent_size, diff --git a/pyrit/printer/scorer/base.py b/pyrit/printer/scorer/base.py index 1a72200d6d..65ad98c53b 100644 --- a/pyrit/printer/scorer/base.py +++ b/pyrit/printer/scorer/base.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. from abc import ABC, abstractmethod -from typing import Any from pyrit.identifiers import ComponentIdentifier @@ -11,36 +10,9 @@ class ScorerPrinterBase(ABC): """ Abstract base class for printing scorer information. - Subclasses implement get_objective_metrics and get_harm_metrics - for data fetching. Framework uses the scorer registry; thin clients - can use REST calls. + Subclasses must implement print_objective_scorer and print_harm_scorer. """ - @abstractmethod - def get_objective_metrics(self, *, eval_hash: str) -> Any: - """ - Fetch objective scorer evaluation metrics by eval hash. - - Args: - eval_hash (str): The evaluation hash to look up. - - Returns: - ObjectiveScorerMetrics or None: The metrics, or None if not found. - """ - - @abstractmethod - def get_harm_metrics(self, *, eval_hash: str, harm_category: str) -> Any: - """ - Fetch harm scorer evaluation metrics by eval hash and category. - - Args: - eval_hash (str): The evaluation hash to look up. - harm_category (str): The harm category for metrics lookup. 
- - Returns: - HarmScorerMetrics or None: The metrics, or None if not found. - """ - @abstractmethod def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: """ diff --git a/pyrit/printer/scorer/console.py b/pyrit/printer/scorer/console.py index 87fb1c2cdf..04996c4a4b 100644 --- a/pyrit/printer/scorer/console.py +++ b/pyrit/printer/scorer/console.py @@ -74,53 +74,6 @@ def _get_quality_color( return str(Fore.RED) return str(Fore.CYAN) - def _compute_eval_hash(self, scorer_identifier: ComponentIdentifier) -> str: - """ - Compute the evaluation hash for a scorer identifier. - - Args: - scorer_identifier (ComponentIdentifier): The scorer identifier. - - Returns: - str: The evaluation hash string. - """ - from pyrit.identifiers.evaluation_identifier import ScorerEvaluationIdentifier - - return ScorerEvaluationIdentifier(scorer_identifier).eval_hash - - def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: - """ - Print objective scorer information. - - Args: - scorer_identifier (ComponentIdentifier): The scorer identifier to print information for. - """ - print() - self._print_colored(f"{self._indent}📊 Scorer Information", Style.BRIGHT) - self._print_colored(f"{self._indent * 2}▸ Scorer Identifier", Fore.WHITE) - self._print_scorer_info(scorer_identifier, indent_level=3) - - eval_hash = self._compute_eval_hash(scorer_identifier) - metrics = self.get_objective_metrics(eval_hash=eval_hash) - self._print_objective_metrics(metrics) - - def print_harm_scorer(self, scorer_identifier: ComponentIdentifier, *, harm_category: str) -> None: - """ - Print harm scorer information. - - Args: - scorer_identifier (ComponentIdentifier): The scorer identifier to print information for. - harm_category (str): The harm category for looking up metrics. - """ - print() - self._print_colored(f"{self._indent}📊 Scorer Information", Style.BRIGHT) - self._print_colored(f"{self._indent * 2}▸ Scorer Identifier", Fore.WHITE) - self._print_scorer_info(scorer_identifier, indent_level=3) - - eval_hash = self._compute_eval_hash(scorer_identifier) - metrics = self.get_harm_metrics(eval_hash=eval_hash, harm_category=harm_category) - self._print_harm_metrics(metrics) - def _print_scorer_info(self, scorer_identifier: ComponentIdentifier, *, indent_level: int = 2) -> None: """ Print scorer information including nested sub-scorers. @@ -258,7 +211,7 @@ def _print_harm_metrics(self, metrics: Optional[Any]) -> None: ) -class ConsoleScorerPrinter(ConsoleScorerPrinterBase): +class ConsoleScorerMemoryPrinter(ConsoleScorerPrinterBase): """ Framework console printer for scorer information. @@ -281,3 +234,40 @@ def get_harm_metrics(self, *, eval_hash: str, harm_category: str) -> Any: ) return find_harm_metrics_by_eval_hash(eval_hash=eval_hash, harm_category=harm_category) + + def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: + """ + Print objective scorer information including type, nested scorers, and evaluation metrics. + + Args: + scorer_identifier (ComponentIdentifier): The scorer identifier to print information for. 
+ """ + from pyrit.identifiers.evaluation_identifier import ScorerEvaluationIdentifier + + print() + self._print_colored(f"{self._indent}📊 Scorer Information", Style.BRIGHT) + self._print_colored(f"{self._indent * 2}▸ Scorer Identifier", Fore.WHITE) + self._print_scorer_info(scorer_identifier, indent_level=3) + + eval_hash = ScorerEvaluationIdentifier(scorer_identifier).eval_hash + metrics = self.get_objective_metrics(eval_hash=eval_hash) + self._print_objective_metrics(metrics) + + def print_harm_scorer(self, scorer_identifier: ComponentIdentifier, *, harm_category: str) -> None: + """ + Print harm scorer information including type, nested scorers, and evaluation metrics. + + Args: + scorer_identifier (ComponentIdentifier): The scorer identifier to print information for. + harm_category (str): The harm category for looking up metrics. + """ + from pyrit.identifiers.evaluation_identifier import ScorerEvaluationIdentifier + + print() + self._print_colored(f"{self._indent}📊 Scorer Information", Style.BRIGHT) + self._print_colored(f"{self._indent * 2}▸ Scorer Identifier", Fore.WHITE) + self._print_scorer_info(scorer_identifier, indent_level=3) + + eval_hash = ScorerEvaluationIdentifier(scorer_identifier).eval_hash + metrics = self.get_harm_metrics(eval_hash=eval_hash, harm_category=harm_category) + self._print_harm_metrics(metrics) diff --git a/pyrit/scenario/printer/__init__.py b/pyrit/scenario/printer/__init__.py index ea6422827b..c613b899ee 100644 --- a/pyrit/scenario/printer/__init__.py +++ b/pyrit/scenario/printer/__init__.py @@ -25,9 +25,9 @@ def __getattr__(name: str): # noqa: N807 stacklevel=2, ) if name == "ConsoleScenarioResultPrinter": - from pyrit.printer.scenario_result.console import ConsoleScenarioResultPrinter + from pyrit.printer.scenario_result.console import ConsoleScenarioMemoryPrinter - return ConsoleScenarioResultPrinter + return ConsoleScenarioMemoryPrinter if name == "ScenarioResultPrinter": from pyrit.printer.scenario_result.base import ScenarioResultPrinterBase diff --git a/pyrit/scenario/printer/console_printer.py b/pyrit/scenario/printer/console_printer.py index 12c1a4ad49..8f70e72129 100644 --- a/pyrit/scenario/printer/console_printer.py +++ b/pyrit/scenario/printer/console_printer.py @@ -17,7 +17,7 @@ def __getattr__(name: str): # noqa: N807 DeprecationWarning, stacklevel=2, ) - from pyrit.printer.scenario_result.console import ConsoleScenarioResultPrinter + from pyrit.printer.scenario_result.console import ConsoleScenarioMemoryPrinter - return ConsoleScenarioResultPrinter + return ConsoleScenarioMemoryPrinter raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyrit/score/__init__.py b/pyrit/score/__init__.py index 886a4d6d4f..b25b3862cd 100644 --- a/pyrit/score/__init__.py +++ b/pyrit/score/__init__.py @@ -24,7 +24,7 @@ from pyrit.score.float_scale.self_ask_likert_scorer import LikertScaleEvalFiles, LikertScalePaths, SelfAskLikertScorer from pyrit.score.float_scale.self_ask_scale_scorer import SelfAskScaleScorer from pyrit.printer.scorer.base import ScorerPrinterBase as ScorerPrinter -from pyrit.printer.scorer.console import ConsoleScorerPrinter +from pyrit.printer.scorer.console import ConsoleScorerMemoryPrinter as ConsoleScorerPrinter from pyrit.score.scorer import Scorer from pyrit.score.scorer_evaluation.metrics_type import MetricsType, RegistryUpdateBehavior from pyrit.score.scorer_evaluation.scorer_metrics import ( diff --git a/pyrit/score/printer/__init__.py b/pyrit/score/printer/__init__.py index a4f6c3d683..1966440bce 
100644 --- a/pyrit/score/printer/__init__.py +++ b/pyrit/score/printer/__init__.py @@ -25,9 +25,9 @@ def __getattr__(name: str): # noqa: N807 stacklevel=2, ) if name == "ConsoleScorerPrinter": - from pyrit.printer.scorer.console import ConsoleScorerPrinter + from pyrit.printer.scorer.console import ConsoleScorerMemoryPrinter - return ConsoleScorerPrinter + return ConsoleScorerMemoryPrinter if name == "ScorerPrinter": from pyrit.printer.scorer.base import ScorerPrinterBase diff --git a/pyrit/score/printer/console_scorer_printer.py b/pyrit/score/printer/console_scorer_printer.py index 8a75edbc81..2d12895ebe 100644 --- a/pyrit/score/printer/console_scorer_printer.py +++ b/pyrit/score/printer/console_scorer_printer.py @@ -17,7 +17,7 @@ def __getattr__(name: str): # noqa: N807 DeprecationWarning, stacklevel=2, ) - from pyrit.printer.scorer.console import ConsoleScorerPrinter + from pyrit.printer.scorer.console import ConsoleScorerMemoryPrinter - return ConsoleScorerPrinter + return ConsoleScorerMemoryPrinter raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/tests/unit/executor/attack/printer/test_console_printer.py b/tests/unit/executor/attack/printer/test_console_printer.py index 2a6f4aa30d..c2d160a29f 100644 --- a/tests/unit/executor/attack/printer/test_console_printer.py +++ b/tests/unit/executor/attack/printer/test_console_printer.py @@ -6,7 +6,7 @@ import pytest -from pyrit.printer.attack_result.console import ConsoleAttackResultPrinter +from pyrit.printer.attack_result.console import ConsoleAttackMemoryPrinter as ConsoleAttackResultPrinter from pyrit.identifiers import ComponentIdentifier from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.models import AttackOutcome, AttackResult, ConversationType, Message, MessagePiece, Score diff --git a/tests/unit/score/test_console_scorer_printer.py b/tests/unit/score/test_console_scorer_printer.py index 23fb2c799f..3397dbc066 100644 --- a/tests/unit/score/test_console_scorer_printer.py +++ b/tests/unit/score/test_console_scorer_printer.py @@ -7,7 +7,7 @@ from colorama import Fore, Style from pyrit.identifiers import ComponentIdentifier -from pyrit.printer.scorer.console import ConsoleScorerPrinter +from pyrit.printer.scorer.console import ConsoleScorerMemoryPrinter as ConsoleScorerPrinter from pyrit.score.scorer_evaluation.scorer_metrics import ( HarmScorerMetrics, ObjectiveScorerMetrics, From 788eceb245abb5a4861b3f615f18fda97595d35c Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 10:06:24 -0700 Subject: [PATCH 05/34] Refactor markdown printer, delete dead old ABC files - Created MarkdownAttackPrinterBase + MarkdownAttackMemoryPrinter in pyrit/printer/attack_result/markdown.py (same pattern as console) - Deleted dead old ABC files: - pyrit/executor/attack/printer/attack_result_printer.py - pyrit/scenario/printer/scenario_result_printer.py - pyrit/score/printer/scorer_printer.py - Old markdown_printer.py now a deprecation re-export shim - Updated all internal imports and test patches Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/executor/attack/__init__.py | 2 +- pyrit/executor/attack/printer/__init__.py | 6 +- .../attack/printer/attack_result_printer.py | 107 --- .../attack/printer/markdown_printer.py | 647 +----------------- pyrit/printer/attack_result/markdown.py | 582 ++++++++++++++++ pyrit/scenario/printer/__init__.py | 6 - .../printer/scenario_result_printer.py | 30 - pyrit/score/printer/scorer_printer.py | 45 -- 
.../attack/core/test_markdown_printer.py | 4 +- 9 files changed, 603 insertions(+), 826 deletions(-) delete mode 100644 pyrit/executor/attack/printer/attack_result_printer.py create mode 100644 pyrit/printer/attack_result/markdown.py delete mode 100644 pyrit/scenario/printer/scenario_result_printer.py delete mode 100644 pyrit/score/printer/scorer_printer.py diff --git a/pyrit/executor/attack/__init__.py b/pyrit/executor/attack/__init__.py index 29afcd3277..ad50d8af51 100644 --- a/pyrit/executor/attack/__init__.py +++ b/pyrit/executor/attack/__init__.py @@ -40,9 +40,9 @@ ) # Import printer modules last to avoid circular dependencies -from pyrit.executor.attack.printer.markdown_printer import MarkdownAttackResultPrinter from pyrit.printer.attack_result.base import AttackResultPrinterBase as AttackResultPrinter from pyrit.printer.attack_result.console import ConsoleAttackMemoryPrinter as ConsoleAttackResultPrinter +from pyrit.printer.attack_result.markdown import MarkdownAttackMemoryPrinter as MarkdownAttackResultPrinter from pyrit.executor.attack.single_turn import ( ContextComplianceAttack, FlipAttack, diff --git a/pyrit/executor/attack/printer/__init__.py b/pyrit/executor/attack/printer/__init__.py index 914ba2942a..99834fb88e 100644 --- a/pyrit/executor/attack/printer/__init__.py +++ b/pyrit/executor/attack/printer/__init__.py @@ -14,8 +14,8 @@ def __getattr__(name: str): # noqa: N807 _deprecated = { "ConsoleAttackResultPrinter": "pyrit.printer.attack_result.console", + "MarkdownAttackResultPrinter": "pyrit.printer.attack_result.markdown", "AttackResultPrinter": "pyrit.printer.attack_result.base", - "MarkdownAttackResultPrinter": "pyrit.executor.attack.printer.markdown_printer", } if name in _deprecated: new_module = _deprecated[name] @@ -34,9 +34,9 @@ def __getattr__(name: str): # noqa: N807 return AttackResultPrinterBase if name == "MarkdownAttackResultPrinter": - from pyrit.executor.attack.printer.markdown_printer import MarkdownAttackResultPrinter + from pyrit.printer.attack_result.markdown import MarkdownAttackMemoryPrinter - return MarkdownAttackResultPrinter + return MarkdownAttackMemoryPrinter raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyrit/executor/attack/printer/attack_result_printer.py b/pyrit/executor/attack/printer/attack_result_printer.py deleted file mode 100644 index 1c180ba2d0..0000000000 --- a/pyrit/executor/attack/printer/attack_result_printer.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from abc import ABC, abstractmethod - -from pyrit.models import AttackOutcome, AttackResult - - -class AttackResultPrinter(ABC): - """ - Abstract base class for printing attack results. - - This interface defines the contract for printing attack results in various formats. - Implementations can render results to console, logs, files, or other outputs. - """ - - @abstractmethod - async def print_result_async( - self, - result: AttackResult, - *, - include_auxiliary_scores: bool = False, - include_pruned_conversations: bool = False, - include_adversarial_conversation: bool = False, - ) -> None: - """ - Print the complete attack result. - - Args: - result (AttackResult): The attack result to print - include_auxiliary_scores (bool): Whether to include auxiliary scores in the output. - Defaults to False. - include_pruned_conversations (bool): Whether to include pruned conversations. - For each pruned conversation, only the last message and its score are shown. - Defaults to False. 
- include_adversarial_conversation (bool): Whether to include the adversarial - conversation (the red teaming LLM's reasoning). Only shown for successful - attacks to avoid overwhelming output. Defaults to False. - """ - - @abstractmethod - async def print_conversation_async(self, result: AttackResult, *, include_scores: bool = False) -> None: - """ - Print only the conversation history. - - Args: - result (AttackResult): The attack result containing the conversation to print - include_scores (bool): Whether to include scores in the output. - Defaults to False. - """ - - @abstractmethod - async def print_summary_async(self, result: AttackResult) -> None: - """ - Print a summary of the attack result without the full conversation. - - Args: - result (AttackResult): The attack result to summarize - """ - - @staticmethod - def _get_outcome_icon(outcome: AttackOutcome) -> str: - """ - Get an icon for an outcome. - - Maps AttackOutcome enum values to appropriate Unicode emoji icons. - - Args: - outcome (AttackOutcome): The attack outcome enum value. - - Returns: - str: Unicode emoji string. - """ - return { - AttackOutcome.SUCCESS: "\u2705", - AttackOutcome.FAILURE: "\u274c", - AttackOutcome.UNDETERMINED: "\u2753", - }.get(outcome, "") - - @staticmethod - def _format_time(milliseconds: int) -> str: - """ - Format time in a human-readable way. - - Converts milliseconds to appropriate units (ms, s, or m + s) based - on the magnitude of the value. - - Args: - milliseconds (int): Time duration in milliseconds. Should be - non-negative. - - Returns: - str: Formatted time string (e.g., "500ms", "2.50s", "1m 30s"). - - Raises: - TypeError: If milliseconds is not an integer. - ValueError: If milliseconds is negative. - """ - if milliseconds < 1000: - return f"{milliseconds}ms" - - if milliseconds < 60000: - return f"{milliseconds / 1000:.2f}s" - - minutes = milliseconds // 60000 - seconds = (milliseconds % 60000) / 1000 - return f"{minutes}m {seconds:.0f}s" diff --git a/pyrit/executor/attack/printer/markdown_printer.py b/pyrit/executor/attack/printer/markdown_printer.py index 402f9fd0c6..8270a385cd 100644 --- a/pyrit/executor/attack/printer/markdown_printer.py +++ b/pyrit/executor/attack/printer/markdown_printer.py @@ -1,640 +1,23 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import os -from datetime import datetime, timezone +""" +Deprecated: Import from pyrit.printer.attack_result.markdown instead. +This re-export will be removed in 0.16.0. +""" -from pyrit.executor.attack.printer.attack_result_printer import AttackResultPrinter -from pyrit.memory import CentralMemory -from pyrit.models import AttackResult, ConversationType, Message, MessagePiece, Score +import warnings as _warnings -class MarkdownAttackResultPrinter(AttackResultPrinter): - """ - Markdown printer for attack results optimized for Jupyter notebooks. - - This printer formats attack results as markdown, making them ideal for display - in Jupyter notebooks where LLM responses often contain code blocks and other - markdown formatting that should be properly rendered. - """ - - def __init__(self, *, display_inline: bool = True) -> None: - """ - Initialize the markdown printer. - - Args: - display_inline (bool): If True, uses IPython.display to render markdown - inline in Jupyter notebooks. If False, prints markdown strings. - Defaults to True. 
- """ - self._memory = CentralMemory.get_memory_instance() - self._display_inline = display_inline - - def _render_markdown(self, markdown_lines: list[str]) -> None: - """ - Render the markdown content using appropriate display method. - - Attempts to use IPython.display.Markdown for Jupyter notebook rendering - when display_inline is True, falling back to print() if not available. - - Args: - markdown_lines (List[str]): List of markdown strings to render. - """ - full_markdown = "\n".join(markdown_lines) - - if self._display_inline: - try: - from IPython.display import Markdown, display - - display(Markdown(full_markdown)) - except (ImportError, NameError): - # Fallback to print if IPython is not available - print(full_markdown) - else: - print(full_markdown) - - def _format_score(self, score: Score, indent: str = "") -> str: - """ - Format a score object as markdown with proper styling. - - Converts a Score object into formatted markdown text with appropriate - emphasis and structure. Handles different score value types and includes - rationale and metadata when available. - - Args: - score (Score): The score object to format. - indent (str): String prefix for indentation. Defaults to "". - - Returns: - str: Formatted markdown representation of the score. - """ - lines = [] - - # Score value with appropriate formatting - score_value = score.get_value() - if isinstance(score_value, bool): - value_str = str(score_value) - elif isinstance(score_value, (int, float)): - value_str = f"**{score_value:.2f}**" if isinstance(score_value, float) else f"**{score_value}**" - else: - value_str = f"**{score_value}**" - - lines.append(f"{indent}- **Score Type:** {score.score_type}") - lines.append(f"{indent}- **Value:** {value_str}") - category_str = ", ".join(score.score_category) if score.score_category else "N/A" - lines.append(f"{indent}- **Category:** {category_str}") - - if score.score_rationale: - # Handle multi-line rationale - rationale_lines = score.score_rationale.split("\n") - if len(rationale_lines) > 1: - lines.append(f"{indent}- **Rationale:**") - lines.extend(f"{indent} {line}" for line in rationale_lines) - else: - lines.append(f"{indent}- **Rationale:** {score.score_rationale}") - - if score.score_metadata: - lines.append(f"{indent}- **Metadata:** `{score.score_metadata}`") - - return "\n".join(lines) - - async def print_result_async( - self, - result: AttackResult, - *, - include_auxiliary_scores: bool = False, - include_pruned_conversations: bool = False, - include_adversarial_conversation: bool = False, - ) -> None: - """ - Print the complete attack result as formatted markdown. - - Generates a comprehensive markdown report including attack summary, - conversation history, scores, and metadata. The output is optimized - for display in Jupyter notebooks. - - Args: - result (AttackResult): The attack result to print. - include_auxiliary_scores (bool): Whether to include auxiliary scores - in the conversation display. Defaults to False. - include_pruned_conversations (bool): Whether to include pruned conversations. - For each pruned conversation, only the last message and its score are shown. - Defaults to False. - include_adversarial_conversation (bool): Whether to include the adversarial - conversation (the red teaming LLM's reasoning). Only shown for successful - attacks to avoid overwhelming output. Defaults to False. 
- """ - markdown_lines = [] - - # Header with outcome - outcome_emoji = self._get_outcome_icon(result.outcome) - markdown_lines.append(f"# {outcome_emoji} Attack Result: {result.outcome.value.upper()}\n") - markdown_lines.append("---\n") - - # Summary section - summary_lines = await self._get_summary_markdown_async(result) - markdown_lines.extend(summary_lines) - markdown_lines.append("---\n") - - # Conversation history - markdown_lines.append("\n## Conversation History\n") - conversation_lines = await self._get_conversation_markdown_async( - result=result, include_scores=include_auxiliary_scores +def __getattr__(name: str): # noqa: N807 + if name == "MarkdownAttackResultPrinter": + _warnings.warn( + "Importing MarkdownAttackResultPrinter from pyrit.executor.attack.printer.markdown_printer is deprecated " + "and will be removed in 0.16.0. Import from pyrit.printer.attack_result.markdown instead.", + DeprecationWarning, + stacklevel=2, ) - markdown_lines.extend(conversation_lines) - - # Pruned conversations if requested - if include_pruned_conversations: - pruned_lines = await self._get_pruned_conversations_markdown_async(result) - if pruned_lines: - markdown_lines.extend(pruned_lines) - - # Adversarial conversation if requested (only for successful attacks) - if include_adversarial_conversation: - adversarial_lines = await self._get_adversarial_conversation_markdown_async(result) - if adversarial_lines: - markdown_lines.extend(adversarial_lines) - - # Metadata if available - if result.metadata: - markdown_lines.append("\n## Additional Metadata\n") - for key, value in result.metadata.items(): - # Only include metadata that can be converted to string - try: - # Try to convert to string - str_value = str(value) - markdown_lines.append(f"- **{key}:** {str_value}") - except Exception: - # Skip values that can't be stringified - pass - - # Footer - markdown_lines.append("\n---") - timestamp_utc = datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z") - markdown_lines.append(f"*Report generated at {timestamp_utc}*") - - self._render_markdown(markdown_lines) - - async def print_conversation_async(self, result: AttackResult, *, include_scores: bool = False) -> None: - """ - Print only the conversation history as formatted markdown. - - Extracts and displays the conversation messages from the attack result - without the summary or metadata sections. Useful for focusing on the - actual interaction flow. - - Args: - result (AttackResult): The attack result containing the conversation - to display. - include_scores (bool): Whether to include scores - for each message. Defaults to False. - """ - markdown_lines = await self._get_conversation_markdown_async(result=result, include_scores=include_scores) - self._render_markdown(markdown_lines) - - async def print_summary_async(self, result: AttackResult) -> None: - """ - Print a summary of the attack result as formatted markdown. - - Displays key information about the attack including objective, outcome, - execution metrics, and final score without the full conversation history. - Useful for getting a quick overview of the attack results. - - Args: - result (AttackResult): The attack result to summarize. - """ - markdown_lines = await self._get_summary_markdown_async(result) - self._render_markdown(markdown_lines) - - async def _get_conversation_markdown_async( - self, *, result: AttackResult, include_scores: bool = False - ) -> list[str]: - """ - Generate markdown lines for the conversation history. 
- - Retrieves conversation messages from memory and formats them as markdown, - organizing by turns and message roles. Handles system messages, user - inputs, and assistant responses with appropriate formatting. - - Args: - result (AttackResult): The attack result containing the conversation ID. - include_scores (bool): Whether to include scores - for each message. Defaults to False. - - Returns: - List[str]: List of markdown strings representing the formatted - conversation history. - """ - markdown_lines = [] - - if not result.conversation_id: - markdown_lines.append("*No conversation ID available*\n") - return markdown_lines - - messages = self._memory.get_conversation(conversation_id=result.conversation_id) - - if not messages: - markdown_lines.append(f"*No conversation found for ID: {result.conversation_id}*\n") - return markdown_lines - - turn_number = 0 - - for message in messages: - if not message.message_pieces: - continue - - message_role = message.get_piece().api_role - - if message_role == "system": - markdown_lines.extend(self._format_system_message(message)) - elif message_role == "user": - turn_number += 1 - markdown_lines.extend(await self._format_user_message_async(message=message, turn_number=turn_number)) - else: # assistant or other response roles - markdown_lines.extend(await self._format_assistant_message_async(message=message)) - - # Add scores if requested - if include_scores: - markdown_lines.extend(self._format_message_scores(message)) - - return markdown_lines - - def _format_system_message(self, message: Message) -> list[str]: - """ - Format a system message as markdown. - - Creates markdown representation of system-level messages, typically - containing instructions or context for the conversation. - - Args: - message (Message): The system message to format. - - Returns: - List[str]: List of markdown strings representing the system message. - """ - lines = ["\n### System Message\n"] - lines.extend(f"{piece.converted_value}\n" for piece in message.message_pieces) - return lines - - async def _format_user_message_async(self, *, message: Message, turn_number: int) -> list[str]: - """ - Format a user message as markdown with turn numbering. - - Creates markdown representation of user input messages, including turn - numbers for easy conversation tracking. Shows both original and converted - values when they differ. - - Args: - message (Message): The user message to format. - turn_number (int): The conversation turn number for this message. - - Returns: - List[str]: List of markdown strings representing the user message. - """ - lines = [f"\n### Turn {turn_number}\n", "#### User\n"] - - for piece in message.message_pieces: - lines.extend(await self._format_piece_content_async(piece=piece, show_original=True)) - - return lines - - async def _format_assistant_message_async(self, *, message: Message) -> list[str]: - """ - Format an assistant or system response message as markdown. - - Creates markdown representation of response messages from assistants - or other system components. Automatically capitalizes the role name - for display purposes. - - Args: - message (Message): The response message to format. - - Returns: - List[str]: List of markdown strings representing the response message. 
- """ - lines = [] - piece = message.message_pieces[0] - role_name = "Assistant (Simulated)" if piece.is_simulated else piece.api_role.capitalize() - - lines.append(f"\n#### {role_name}\n") - - for piece in message.message_pieces: - lines.extend(await self._format_piece_content_async(piece=piece, show_original=False)) - - return lines - - def _get_audio_mime_type(self, *, audio_path: str) -> str: - """ - Determine the MIME type for an audio file based on its file extension. - - Args: - audio_path (str): The path to the audio file. - - Returns: - str: The appropriate MIME type for the audio file. - """ - if audio_path.lower().endswith(".wav"): - return "audio/wav" - if audio_path.lower().endswith(".ogg"): - return "audio/ogg" - if audio_path.lower().endswith(".m4a"): - return "audio/mp4" - return "audio/mpeg" # Default fallback for .mp3, .mpeg, and unknown formats - - def _format_image_content(self, *, image_path: str) -> list[str]: - """ - Format image content as markdown. - - Args: - image_path (str): The path to the image file. - - Returns: - List[str]: List of markdown lines for the image. - """ - relative_path = os.path.relpath(image_path) - posix_path = relative_path.replace("\\", "/") - return [f"![Image]({posix_path})\n"] - - def _format_audio_content(self, *, audio_path: str) -> list[str]: - """ - Format audio content as HTML5 audio player. - - Args: - audio_path (str): The path to the audio file. - - Returns: - List[str]: List of markdown lines for the audio player. - """ - lines = [] - lines.append("\n") - - return lines - - def _format_error_content(self, *, piece: MessagePiece) -> list[str]: - """ - Format error response content with proper styling. - - Args: - piece (MessagePiece): The message piece containing the error. - - Returns: - List[str]: List of markdown lines for the error response. - """ - lines = [] - lines.append("**Error Response:**\n") - lines.append(f"*Error Type: {piece.response_error}*\n") - lines.append("```json") - lines.append(piece.converted_value) - lines.append("```\n") - - return lines - - def _format_text_content(self, *, piece: MessagePiece, show_original: bool) -> list[str]: - """ - Format regular text content. - - Args: - piece (MessagePiece): The message piece containing the text. - show_original (bool): Whether to show original value if different. - - Returns: - List[str]: List of markdown lines for the text content. - """ - lines = [] - - if show_original and piece.converted_value != piece.original_value: - lines.append("**Original:**\n") - lines.append(f"{piece.original_value}\n") - lines.append("\n**Converted:**\n") - - lines.append(f"{piece.converted_value}\n") - - return lines - - async def _format_piece_content_async(self, *, piece: MessagePiece, show_original: bool) -> list[str]: - """ - Format a single piece content based on its data type. - - Handles different content types including text, images, audio, and error responses. - - Args: - piece (MessagePiece): The message piece to format. - show_original (bool): Whether to show original value if different - from converted value. - - Returns: - List[str]: List of markdown lines representing this piece. 
- """ - if piece.converted_value_data_type == "image_path": - return self._format_image_content(image_path=piece.converted_value) - if piece.converted_value_data_type == "audio_path": - return self._format_audio_content(audio_path=piece.converted_value) - # Handle text content (including errors) - if piece.has_error(): - return self._format_error_content(piece=piece) - return self._format_text_content(piece=piece, show_original=show_original) - - def _format_message_scores(self, message: Message) -> list[str]: - """ - Format scores for all pieces in a message as markdown. - - Retrieves and formats all scores associated with the message pieces - in the given message. Creates a dedicated scores section with - appropriate markdown formatting. - - Args: - message (Message): The message containing pieces - to format scores for. - - Returns: - List[str]: List of markdown strings representing the scores. - """ - lines = [] - for piece in message.message_pieces: - scores = self._memory.get_prompt_scores(prompt_ids=[str(piece.id)]) - if scores: - lines.append("\n##### Scores\n") - lines.extend(self._format_score(score, indent="") for score in scores) - lines.append("") - return lines - - async def _get_summary_markdown_async(self, result: AttackResult) -> list[str]: - """ - Generate markdown lines for the attack summary. - - Creates a comprehensive summary including basic information tables, - execution metrics, outcome status, and final scores. Uses markdown - tables for structured data presentation. - - Args: - result (AttackResult): The attack result to summarize. - - Returns: - List[str]: List of markdown strings representing the formatted summary. - """ - markdown_lines = [] - markdown_lines.append("## Attack Summary\n") - - # Basic Information Table - markdown_lines.append("### Basic Information\n") - markdown_lines.append("| Field | Value |") - markdown_lines.append("|-------|-------|") - markdown_lines.append(f"| **Objective** | {result.objective} |") - - _strategy_id = result.get_attack_strategy_identifier() - attack_type = _strategy_id.class_name if _strategy_id is not None else "Unknown" - - markdown_lines.append(f"| **Attack Type** | `{attack_type}` |") - markdown_lines.append(f"| **Conversation ID** | `{result.conversation_id}` |") - - # Execution Metrics - markdown_lines.append("\n### Execution Metrics\n") - markdown_lines.append("| Metric | Value |") - markdown_lines.append("|--------|-------|") - markdown_lines.append(f"| **Turns Executed** | {result.executed_turns} |") - markdown_lines.append(f"| **Execution Time** | {self._format_time(result.execution_time_ms)} |") - - # Outcome - outcome_emoji = self._get_outcome_icon(result.outcome) - markdown_lines.append("\n### Outcome\n") - markdown_lines.append(f"**Status:** {outcome_emoji} **{result.outcome.value.upper()}**\n") - - if result.outcome_reason: - markdown_lines.append(f"**Reason:** {result.outcome_reason}\n") - - # Final Score - if result.last_score: - markdown_lines.append("\n### Final Score\n") - markdown_lines.append(self._format_score(result.last_score)) - - return markdown_lines - - async def _get_pruned_conversations_markdown_async(self, result: AttackResult) -> list[str]: - """ - Generate markdown lines for pruned conversations. - - For each pruned conversation, displays only the last message and its - associated score to provide context without overwhelming output. - - Args: - result (AttackResult): The attack result containing related conversations. 
- - Returns: - List[str]: List of markdown strings for pruned conversations, or empty list if none. - """ - pruned_refs = result.get_conversations_by_type(ConversationType.PRUNED) - - if not pruned_refs: - return [] - - markdown_lines = [] - markdown_lines.append(f"\n## Pruned Conversations ({len(pruned_refs)} total)\n") - markdown_lines.append("*Showing only the last message and score for each pruned branch.*\n") - - for idx, ref in enumerate(pruned_refs, 1): - # Header for this pruned conversation - label = f"### 🗑️ Pruned #{idx}" - if ref.description: - label += f" - {ref.description}" - markdown_lines.append(f"\n{label}\n") - - # Get the conversation messages - messages = list(self._memory.get_conversation(conversation_id=ref.conversation_id)) - - if not messages: - markdown_lines.append(f"*No messages found for conversation: `{ref.conversation_id}`*\n") - continue - - # Get only the last message - last_message = messages[-1] - role_label = last_message.api_role.upper() - - markdown_lines.append(f"**Last Message ({role_label}):**\n") - - for piece in last_message.message_pieces: - # Format the message content - content = piece.converted_value or "" - if "\n" in content: - markdown_lines.append("```") - markdown_lines.append(content) - markdown_lines.append("```") - else: - markdown_lines.append(f"> {content}\n") - - # Get and format associated scores - scores = self._memory.get_prompt_scores(prompt_ids=[str(piece.id)]) - if scores: - markdown_lines.append("\n**Score:**\n") - markdown_lines.extend(self._format_score(score, indent="") for score in scores) - - return markdown_lines - - async def _get_adversarial_conversation_markdown_async(self, result: AttackResult) -> list[str]: - """ - Generate markdown lines for the adversarial conversation. - - The adversarial conversation shows the red teaming LLM's reasoning. - For attacks with multiple adversarial conversations (e.g., TAP), only the - best-scoring branch's adversarial conversation is shown if available. - - Args: - result (AttackResult): The attack result containing related conversations. - - Returns: - List[str]: List of markdown strings for the adversarial conversation, or empty list. 
- """ - adversarial_refs = result.get_conversations_by_type(ConversationType.ADVERSARIAL) - - if not adversarial_refs: - return [] - - markdown_lines = [] - markdown_lines.append("\n## Adversarial Conversation (Red Team LLM)\n") - markdown_lines.append("*This shows the reasoning and strategy of the red teaming LLM.*\n") - - # Check if result has a best_adversarial_conversation_id (e.g., TAP attack) - # If so, only show that conversation instead of all adversarial conversations - best_adversarial_id = result.metadata.get("best_adversarial_conversation_id") - if best_adversarial_id: - # Filter to only the best adversarial conversation - adversarial_refs = [ref for ref in adversarial_refs if ref.conversation_id == best_adversarial_id] - if adversarial_refs: - markdown_lines.append("*📌 Showing best-scoring branch's adversarial conversation*\n") - - for ref in adversarial_refs: - if ref.description: - markdown_lines.append(f"*📝 {ref.description}*\n") - - messages = list(self._memory.get_conversation(conversation_id=ref.conversation_id)) - - if not messages: - markdown_lines.append(f"*No messages found for conversation: `{ref.conversation_id}`*\n") - continue - - # Format each message in the adversarial conversation - turn_number = 0 - for message in messages: - if message.api_role == "user": - turn_number += 1 - markdown_lines.append(f"\n#### Turn {turn_number} - USER\n") - elif message.api_role == "system": - markdown_lines.append("\n#### SYSTEM\n") - else: - markdown_lines.append(f"\n#### {message.api_role.upper()}\n") - - for piece in message.message_pieces: - content = piece.converted_value or "" - if len(content) > 200 or "\n" in content: - markdown_lines.append("```") - markdown_lines.append(content) - markdown_lines.append("```") - else: - markdown_lines.append(f"> {content}\n") + from pyrit.printer.attack_result.markdown import MarkdownAttackMemoryPrinter - return markdown_lines + return MarkdownAttackMemoryPrinter + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyrit/printer/attack_result/markdown.py b/pyrit/printer/attack_result/markdown.py new file mode 100644 index 0000000000..5afeeb6fe3 --- /dev/null +++ b/pyrit/printer/attack_result/markdown.py @@ -0,0 +1,582 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +from datetime import datetime, timezone + +from pyrit.models import AttackResult, ConversationType, Message, MessagePiece, Score +from pyrit.printer.attack_result.base import AttackResultPrinterBase + + +class MarkdownAttackPrinterBase(AttackResultPrinterBase): + """ + Markdown printer base for attack results optimized for Jupyter notebooks. + + Contains all formatting logic. Subclasses implement get_conversation_async + and get_scores_async for data fetching. + """ + + def __init__(self, *, display_inline: bool = True) -> None: + """ + Initialize the markdown printer. + + Args: + display_inline (bool): If True, uses IPython.display to render markdown + inline in Jupyter notebooks. If False, prints markdown strings. + Defaults to True. + """ + self._display_inline = display_inline + + def _render_markdown(self, markdown_lines: list[str]) -> None: + """ + Render the markdown content using appropriate display method. + + Attempts to use IPython.display.Markdown for Jupyter notebook rendering + when display_inline is True, falling back to print() if not available. + + Args: + markdown_lines (List[str]): List of markdown strings to render. 
+ """ + full_markdown = "\n".join(markdown_lines) + + if self._display_inline: + try: + from IPython.display import Markdown, display + + display(Markdown(full_markdown)) + except (ImportError, NameError): + print(full_markdown) + else: + print(full_markdown) + + def _format_score(self, score: Score, indent: str = "") -> str: + """ + Format a score object as markdown with proper styling. + + Args: + score (Score): The score object to format. + indent (str): String prefix for indentation. Defaults to "". + + Returns: + str: Formatted markdown representation of the score. + """ + lines = [] + + score_value = score.get_value() + if isinstance(score_value, bool): + value_str = str(score_value) + elif isinstance(score_value, (int, float)): + value_str = f"**{score_value:.2f}**" if isinstance(score_value, float) else f"**{score_value}**" + else: + value_str = f"**{score_value}**" + + lines.append(f"{indent}- **Score Type:** {score.score_type}") + lines.append(f"{indent}- **Value:** {value_str}") + category_str = ", ".join(score.score_category) if score.score_category else "N/A" + lines.append(f"{indent}- **Category:** {category_str}") + + if score.score_rationale: + rationale_lines = score.score_rationale.split("\n") + if len(rationale_lines) > 1: + lines.append(f"{indent}- **Rationale:**") + lines.extend(f"{indent} {line}" for line in rationale_lines) + else: + lines.append(f"{indent}- **Rationale:** {score.score_rationale}") + + if score.score_metadata: + lines.append(f"{indent}- **Metadata:** `{score.score_metadata}`") + + return "\n".join(lines) + + async def print_result_async( + self, + result: AttackResult, + *, + include_auxiliary_scores: bool = False, + include_pruned_conversations: bool = False, + include_adversarial_conversation: bool = False, + ) -> None: + """ + Print the complete attack result as formatted markdown. + + Args: + result (AttackResult): The attack result to print. + include_auxiliary_scores (bool): Whether to include auxiliary scores. Defaults to False. + include_pruned_conversations (bool): Whether to include pruned conversations. Defaults to False. + include_adversarial_conversation (bool): Whether to include the adversarial conversation. + Defaults to False. 
+ """ + markdown_lines = [] + + outcome_emoji = self._get_outcome_icon(result.outcome) + markdown_lines.append(f"# {outcome_emoji} Attack Result: {result.outcome.value.upper()}\n") + markdown_lines.append("---\n") + + summary_lines = await self._get_summary_markdown_async(result) + markdown_lines.extend(summary_lines) + markdown_lines.append("---\n") + + markdown_lines.append("\n## Conversation History\n") + conversation_lines = await self._get_conversation_markdown_async( + result=result, include_scores=include_auxiliary_scores + ) + markdown_lines.extend(conversation_lines) + + if include_pruned_conversations: + pruned_lines = await self._get_pruned_conversations_markdown_async(result) + if pruned_lines: + markdown_lines.extend(pruned_lines) + + if include_adversarial_conversation: + adversarial_lines = await self._get_adversarial_conversation_markdown_async(result) + if adversarial_lines: + markdown_lines.extend(adversarial_lines) + + if result.metadata: + markdown_lines.append("\n## Additional Metadata\n") + for key, value in result.metadata.items(): + try: + str_value = str(value) + markdown_lines.append(f"- **{key}:** {str_value}") + except Exception: + pass + + markdown_lines.append("\n---") + timestamp_utc = datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z") + markdown_lines.append(f"*Report generated at {timestamp_utc}*") + + self._render_markdown(markdown_lines) + + async def print_conversation_async(self, result: AttackResult, *, include_scores: bool = False) -> None: + """ + Print only the conversation history as formatted markdown. + + Args: + result (AttackResult): The attack result containing the conversation to display. + include_scores (bool): Whether to include scores. Defaults to False. + """ + markdown_lines = await self._get_conversation_markdown_async(result=result, include_scores=include_scores) + self._render_markdown(markdown_lines) + + async def print_summary_async(self, result: AttackResult) -> None: + """ + Print a summary of the attack result as formatted markdown. + + Args: + result (AttackResult): The attack result to summarize. + """ + markdown_lines = await self._get_summary_markdown_async(result) + self._render_markdown(markdown_lines) + + async def _get_conversation_markdown_async( + self, *, result: AttackResult, include_scores: bool = False + ) -> list[str]: + """ + Generate markdown lines for the conversation history. + + Args: + result (AttackResult): The attack result containing the conversation ID. + include_scores (bool): Whether to include scores. Defaults to False. + + Returns: + list[str]: Markdown strings for the conversation. 
+ """ + markdown_lines: list[str] = [] + + if not result.conversation_id: + markdown_lines.append("*No conversation ID available*\n") + return markdown_lines + + messages = await self.get_conversation_async(result.conversation_id) + + if not messages: + markdown_lines.append(f"*No conversation found for ID: {result.conversation_id}*\n") + return markdown_lines + + turn_number = 0 + + for message in messages: + if not message.message_pieces: + continue + + message_role = message.get_piece().api_role + + if message_role == "system": + markdown_lines.extend(self._format_system_message(message)) + elif message_role == "user": + turn_number += 1 + markdown_lines.extend(await self._format_user_message_async(message=message, turn_number=turn_number)) + else: + markdown_lines.extend(await self._format_assistant_message_async(message=message)) + + if include_scores: + markdown_lines.extend(await self._format_message_scores_async(message)) + + return markdown_lines + + def _format_system_message(self, message: Message) -> list[str]: + """ + Format a system message as markdown. + + Args: + message (Message): The system message to format. + + Returns: + list[str]: Markdown strings for the system message. + """ + lines = ["\n### System Message\n"] + lines.extend(f"{piece.converted_value}\n" for piece in message.message_pieces) + return lines + + async def _format_user_message_async(self, *, message: Message, turn_number: int) -> list[str]: + """ + Format a user message as markdown with turn numbering. + + Args: + message (Message): The user message to format. + turn_number (int): The conversation turn number. + + Returns: + list[str]: Markdown strings for the user message. + """ + lines = [f"\n### Turn {turn_number}\n", "#### User\n"] + + for piece in message.message_pieces: + lines.extend(await self._format_piece_content_async(piece=piece, show_original=True)) + + return lines + + async def _format_assistant_message_async(self, *, message: Message) -> list[str]: + """ + Format an assistant response message as markdown. + + Args: + message (Message): The response message to format. + + Returns: + list[str]: Markdown strings for the response message. + """ + lines: list[str] = [] + piece = message.message_pieces[0] + role_name = "Assistant (Simulated)" if piece.is_simulated else piece.api_role.capitalize() + + lines.append(f"\n#### {role_name}\n") + + for piece in message.message_pieces: + lines.extend(await self._format_piece_content_async(piece=piece, show_original=False)) + + return lines + + def _get_audio_mime_type(self, *, audio_path: str) -> str: + """ + Determine the MIME type for an audio file based on its file extension. + + Args: + audio_path (str): The path to the audio file. + + Returns: + str: The appropriate MIME type for the audio file. + """ + if audio_path.lower().endswith(".wav"): + return "audio/wav" + if audio_path.lower().endswith(".ogg"): + return "audio/ogg" + if audio_path.lower().endswith(".m4a"): + return "audio/mp4" + return "audio/mpeg" + + def _format_image_content(self, *, image_path: str) -> list[str]: + """ + Format image content as markdown. + + Args: + image_path (str): The path to the image file. + + Returns: + list[str]: Markdown lines for the image. + """ + relative_path = os.path.relpath(image_path) + posix_path = relative_path.replace("\\", "/") + return [f"![Image]({posix_path})\n"] + + def _format_audio_content(self, *, audio_path: str) -> list[str]: + """ + Format audio content as HTML5 audio player. + + Args: + audio_path (str): The path to the audio file. 
+
+        Returns:
+            list[str]: Markdown lines for the audio player.
+        """
+        relative_path = os.path.relpath(audio_path)
+        posix_path = relative_path.replace("\\", "/")
+        mime_type = self._get_audio_mime_type(audio_path=audio_path)
+
+        lines: list[str] = []
+        # Embed an HTML5 audio element so the clip is playable when the markdown is rendered.
+        lines.append(f'<audio controls>\n  <source src="{posix_path}" type="{mime_type}">\n</audio>\n')
+
+        return lines
+
+    def _format_error_content(self, *, piece: MessagePiece) -> list[str]:
+        """
+        Format error response content with proper styling.
+
+        Args:
+            piece (MessagePiece): The message piece containing the error.
+
+        Returns:
+            list[str]: Markdown lines for the error response.
+        """
+        lines: list[str] = []
+        lines.append("**Error Response:**\n")
+        lines.append(f"*Error Type: {piece.response_error}*\n")
+        lines.append("```json")
+        lines.append(piece.converted_value)
+        lines.append("```\n")
+
+        return lines
+
+    def _format_text_content(self, *, piece: MessagePiece, show_original: bool) -> list[str]:
+        """
+        Format regular text content.
+
+        Args:
+            piece (MessagePiece): The message piece containing the text.
+            show_original (bool): Whether to show original value if different.
+
+        Returns:
+            list[str]: Markdown lines for the text content.
+        """
+        lines: list[str] = []
+
+        if show_original and piece.converted_value != piece.original_value:
+            lines.append("**Original:**\n")
+            lines.append(f"{piece.original_value}\n")
+            lines.append("\n**Converted:**\n")
+
+        lines.append(f"{piece.converted_value}\n")
+
+        return lines
+
+    async def _format_piece_content_async(self, *, piece: MessagePiece, show_original: bool) -> list[str]:
+        """
+        Format a single piece content based on its data type.
+
+        Args:
+            piece (MessagePiece): The message piece to format.
+            show_original (bool): Whether to show original value if different.
+
+        Returns:
+            list[str]: Markdown lines for this piece.
+        """
+        if piece.converted_value_data_type == "image_path":
+            return self._format_image_content(image_path=piece.converted_value)
+        if piece.converted_value_data_type == "audio_path":
+            return self._format_audio_content(audio_path=piece.converted_value)
+        # Handle text content (including errors)
+        if piece.has_error():
+            return self._format_error_content(piece=piece)
+        return self._format_text_content(piece=piece, show_original=show_original)
+
+    async def _format_message_scores_async(self, message: Message) -> list[str]:
+        """
+        Format scores for all pieces in a message as markdown.
+
+        Args:
+            message (Message): The message containing pieces to format scores for.
+
+        Returns:
+            list[str]: Markdown strings for the scores.
+        """
+        lines: list[str] = []
+        for piece in message.message_pieces:
+            scores = await self.get_scores_async(prompt_ids=[str(piece.id)])
+            if scores:
+                lines.append("\n##### Scores\n")
+                lines.extend(self._format_score(score, indent="") for score in scores)
+                lines.append("")
+        return lines
+
+    async def _get_summary_markdown_async(self, result: AttackResult) -> list[str]:
+        """
+        Generate markdown lines for the attack summary.
+
+        Args:
+            result (AttackResult): The attack result to summarize.
+
+        Returns:
+            list[str]: Markdown strings for the summary.
+ """ + markdown_lines: list[str] = [] + markdown_lines.append("## Attack Summary\n") + + markdown_lines.append("### Basic Information\n") + markdown_lines.append("| Field | Value |") + markdown_lines.append("|-------|-------|") + markdown_lines.append(f"| **Objective** | {result.objective} |") + + _strategy_id = result.get_attack_strategy_identifier() + attack_type = _strategy_id.class_name if _strategy_id is not None else "Unknown" + + markdown_lines.append(f"| **Attack Type** | `{attack_type}` |") + markdown_lines.append(f"| **Conversation ID** | `{result.conversation_id}` |") + + markdown_lines.append("\n### Execution Metrics\n") + markdown_lines.append("| Metric | Value |") + markdown_lines.append("|--------|-------|") + markdown_lines.append(f"| **Turns Executed** | {result.executed_turns} |") + markdown_lines.append(f"| **Execution Time** | {self._format_time(result.execution_time_ms)} |") + + outcome_emoji = self._get_outcome_icon(result.outcome) + markdown_lines.append("\n### Outcome\n") + markdown_lines.append(f"**Status:** {outcome_emoji} **{result.outcome.value.upper()}**\n") + + if result.outcome_reason: + markdown_lines.append(f"**Reason:** {result.outcome_reason}\n") + + if result.last_score: + markdown_lines.append("\n### Final Score\n") + markdown_lines.append(self._format_score(result.last_score)) + + return markdown_lines + + async def _get_pruned_conversations_markdown_async(self, result: AttackResult) -> list[str]: + """ + Generate markdown lines for pruned conversations. + + Args: + result (AttackResult): The attack result containing related conversations. + + Returns: + list[str]: Markdown strings for pruned conversations. + """ + pruned_refs = result.get_conversations_by_type(ConversationType.PRUNED) + + if not pruned_refs: + return [] + + markdown_lines: list[str] = [] + markdown_lines.append(f"\n## Pruned Conversations ({len(pruned_refs)} total)\n") + markdown_lines.append("*Showing only the last message and score for each pruned branch.*\n") + + for idx, ref in enumerate(pruned_refs, 1): + label = f"### 🗑️ Pruned #{idx}" + if ref.description: + label += f" - {ref.description}" + markdown_lines.append(f"\n{label}\n") + + messages = await self.get_conversation_async(ref.conversation_id) + + if not messages: + markdown_lines.append(f"*No messages found for conversation: `{ref.conversation_id}`*\n") + continue + + last_message = messages[-1] + role_label = last_message.api_role.upper() + + markdown_lines.append(f"**Last Message ({role_label}):**\n") + + for piece in last_message.message_pieces: + content = piece.converted_value or "" + if "\n" in content: + markdown_lines.append("```") + markdown_lines.append(content) + markdown_lines.append("```") + else: + markdown_lines.append(f"> {content}\n") + + scores = await self.get_scores_async(prompt_ids=[str(piece.id)]) + if scores: + markdown_lines.append("\n**Score:**\n") + markdown_lines.extend(self._format_score(score, indent="") for score in scores) + + return markdown_lines + + async def _get_adversarial_conversation_markdown_async(self, result: AttackResult) -> list[str]: + """ + Generate markdown lines for the adversarial conversation. + + Args: + result (AttackResult): The attack result containing related conversations. + + Returns: + list[str]: Markdown strings for the adversarial conversation. 
+ """ + adversarial_refs = result.get_conversations_by_type(ConversationType.ADVERSARIAL) + + if not adversarial_refs: + return [] + + markdown_lines: list[str] = [] + markdown_lines.append("\n## Adversarial Conversation (Red Team LLM)\n") + markdown_lines.append("*This shows the reasoning and strategy of the red teaming LLM.*\n") + + best_adversarial_id = result.metadata.get("best_adversarial_conversation_id") + if best_adversarial_id: + adversarial_refs = [ref for ref in adversarial_refs if ref.conversation_id == best_adversarial_id] + if adversarial_refs: + markdown_lines.append("*📌 Showing best-scoring branch's adversarial conversation*\n") + + for ref in adversarial_refs: + if ref.description: + markdown_lines.append(f"*📝 {ref.description}*\n") + + messages = await self.get_conversation_async(ref.conversation_id) + + if not messages: + markdown_lines.append(f"*No messages found for conversation: `{ref.conversation_id}`*\n") + continue + + turn_number = 0 + for message in messages: + if message.api_role == "user": + turn_number += 1 + markdown_lines.append(f"\n#### Turn {turn_number} - USER\n") + elif message.api_role == "system": + markdown_lines.append("\n#### SYSTEM\n") + else: + markdown_lines.append(f"\n#### {message.api_role.upper()}\n") + + for piece in message.message_pieces: + content = piece.converted_value or "" + if len(content) > 200 or "\n" in content: + markdown_lines.append("```") + markdown_lines.append(content) + markdown_lines.append("```") + else: + markdown_lines.append(f"> {content}\n") + + return markdown_lines + + +class MarkdownAttackMemoryPrinter(MarkdownAttackPrinterBase): + """ + Framework markdown printer for attack results. + + Implements data-fetching via CentralMemory (deferred import). + All formatting logic lives in MarkdownAttackPrinterBase. + """ + + def __init__(self, *, display_inline: bool = True) -> None: + """ + Initialize the markdown printer. + + Args: + display_inline (bool): If True, uses IPython.display to render markdown + inline in Jupyter notebooks. If False, prints markdown strings. + Defaults to True. + """ + super().__init__(display_inline=display_inline) + from pyrit.memory import CentralMemory + + self._memory = CentralMemory.get_memory_instance() + + async def get_conversation_async(self, conversation_id: str) -> list[Message]: + """Fetch conversation messages from CentralMemory.""" + return list(self._memory.get_conversation(conversation_id=conversation_id)) + + async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: + """Fetch scores from CentralMemory.""" + return self._memory.get_prompt_scores(prompt_ids=prompt_ids) diff --git a/pyrit/scenario/printer/__init__.py b/pyrit/scenario/printer/__init__.py index c613b899ee..d9afefd958 100644 --- a/pyrit/scenario/printer/__init__.py +++ b/pyrit/scenario/printer/__init__.py @@ -33,9 +33,3 @@ def __getattr__(name: str): # noqa: N807 return ScenarioResultPrinterBase raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - -__all__ = [ - "ConsoleScenarioResultPrinter", - "ScenarioResultPrinter", -] diff --git a/pyrit/scenario/printer/scenario_result_printer.py b/pyrit/scenario/printer/scenario_result_printer.py deleted file mode 100644 index 1e25e7a364..0000000000 --- a/pyrit/scenario/printer/scenario_result_printer.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
- -from abc import ABC, abstractmethod - -from pyrit.models.scenario_result import ScenarioResult - - -class ScenarioResultPrinter(ABC): - """ - Abstract base class for printing scenario results. - - This interface defines the contract for printing scenario results in various formats. - Implementations can render results to console, logs, files, or other outputs. - """ - - @abstractmethod - async def print_summary_async(self, result: ScenarioResult) -> None: - """ - Print a summary of the scenario result with per-strategy breakdown. - - Displays: - - Scenario identification (name, version, PyRIT version) - - Target information - - Overall statistics - - Per-strategy success rates and result counts - - Args: - result (ScenarioResult): The scenario result to summarize - """ diff --git a/pyrit/score/printer/scorer_printer.py b/pyrit/score/printer/scorer_printer.py deleted file mode 100644 index e296b0da6f..0000000000 --- a/pyrit/score/printer/scorer_printer.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from abc import ABC, abstractmethod - -from pyrit.identifiers import ComponentIdentifier - - -class ScorerPrinter(ABC): - """ - Abstract base class for printing scorer information. - - This interface defines the contract for printing scorer details including - type information, nested sub-scorers, and evaluation metrics from the registry. - Implementations can render output to console, logs, files, or other outputs. - """ - - @abstractmethod - def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: - """ - Print objective scorer information including type, nested scorers, and evaluation metrics. - - This method displays: - - Scorer type and identity information - - Nested sub-scorers (for composite scorers) - - Objective evaluation metrics (accuracy, precision, recall, F1) from the registry - - Args: - scorer_identifier (ComponentIdentifier): The scorer identifier to print information for. - """ - - @abstractmethod - def print_harm_scorer(self, scorer_identifier: ComponentIdentifier, *, harm_category: str) -> None: - """ - Print harm scorer information including type, nested scorers, and evaluation metrics. - - This method displays: - - Scorer type and identity information - - Nested sub-scorers (for composite scorers) - - Harm evaluation metrics (MAE, Krippendorff alpha) from the registry - - Args: - scorer_identifier (ComponentIdentifier): The scorer identifier to print information for. - harm_category (str): The harm category for looking up metrics (e.g., "hate_speech", "violence"). 
- """ diff --git a/tests/unit/executor/attack/core/test_markdown_printer.py b/tests/unit/executor/attack/core/test_markdown_printer.py index f87e56606f..e4dbb82051 100644 --- a/tests/unit/executor/attack/core/test_markdown_printer.py +++ b/tests/unit/executor/attack/core/test_markdown_printer.py @@ -7,7 +7,7 @@ import pytest -from pyrit.executor.attack.printer.markdown_printer import MarkdownAttackResultPrinter +from pyrit.printer.attack_result.markdown import MarkdownAttackMemoryPrinter as MarkdownAttackResultPrinter from pyrit.identifiers import ComponentIdentifier from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.memory import CentralMemory @@ -25,7 +25,7 @@ def _mock_scorer_id(name: str = "MockScorer") -> ComponentIdentifier: @pytest.fixture def mock_memory(): memory = MagicMock(spec=CentralMemory) - with patch("pyrit.executor.attack.printer.markdown_printer.CentralMemory") as mock_central_memory: + with patch("pyrit.memory.CentralMemory") as mock_central_memory: mock_central_memory.get_memory_instance.return_value = memory mock_central_memory.get_conversation.return_value = [] yield memory From 91a417a44c53bfc8b3754a3eeacdb30e3eba4eb6 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 10:17:29 -0700 Subject: [PATCH 06/34] Add missing __all__ to scenario printer deprecation shim Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/printer/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pyrit/scenario/printer/__init__.py b/pyrit/scenario/printer/__init__.py index d9afefd958..c613b899ee 100644 --- a/pyrit/scenario/printer/__init__.py +++ b/pyrit/scenario/printer/__init__.py @@ -33,3 +33,9 @@ def __getattr__(name: str): # noqa: N807 return ScenarioResultPrinterBase raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = [ + "ConsoleScenarioResultPrinter", + "ScenarioResultPrinter", +] From f31d0d04bb8e72e9aca607228b5017a51caa1442 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 10:28:20 -0700 Subject: [PATCH 07/34] Fix type checker errors in from_dict methods and MemoryPrinter types - Changed dict[str, object] to dict[str, Any] in MessagePiece.from_dict() and Message.from_dict() to satisfy pyright (dict.get returns object otherwise) - Added Any import to message_piece.py and message.py - Wrapped get_prompt_scores return in list() for Sequence -> list coercion - Added isinstance check in display_image_async for type safety Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/cli/frontend_core.py | 2 +- pyrit/executor/attack/__init__.py | 10 +++---- pyrit/models/attack_result.py | 15 +++-------- pyrit/models/message.py | 6 ++--- pyrit/models/message_piece.py | 26 ++++++------------- pyrit/models/scenario_result.py | 18 +++++-------- pyrit/printer/attack_result/base.py | 2 +- pyrit/printer/attack_result/console.py | 6 +++-- pyrit/printer/attack_result/markdown.py | 2 +- pyrit/score/__init__.py | 4 +-- .../attack/core/test_markdown_printer.py | 2 +- .../printer/test_attack_result_printer.py | 2 +- .../attack/printer/test_console_printer.py | 2 +- 13 files changed, 38 insertions(+), 59 deletions(-) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 95a0faa829..dabd464ccb 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -39,9 +39,9 @@ from pyrit.cli._cli_args import validate_integer as validate_integer from pyrit.cli._cli_args import 
validate_log_level as validate_log_level from pyrit.cli._cli_args import validate_log_level_argparse as validate_log_level_argparse +from pyrit.printer.scenario_result.console import ConsoleScenarioMemoryPrinter as ConsoleScenarioResultPrinter from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry from pyrit.scenario import DatasetConfiguration -from pyrit.printer.scenario_result.console import ConsoleScenarioMemoryPrinter as ConsoleScenarioResultPrinter from pyrit.setup import ConfigurationLoader, initialize_pyrit_async from pyrit.setup.configuration_loader import _MEMORY_DB_TYPE_MAP diff --git a/pyrit/executor/attack/__init__.py b/pyrit/executor/attack/__init__.py index ad50d8af51..aaf76da58a 100644 --- a/pyrit/executor/attack/__init__.py +++ b/pyrit/executor/attack/__init__.py @@ -38,11 +38,6 @@ TreeOfAttacksWithPruningAttack, generate_simulated_conversation_async, ) - -# Import printer modules last to avoid circular dependencies -from pyrit.printer.attack_result.base import AttackResultPrinterBase as AttackResultPrinter -from pyrit.printer.attack_result.console import ConsoleAttackMemoryPrinter as ConsoleAttackResultPrinter -from pyrit.printer.attack_result.markdown import MarkdownAttackMemoryPrinter as MarkdownAttackResultPrinter from pyrit.executor.attack.single_turn import ( ContextComplianceAttack, FlipAttack, @@ -55,6 +50,11 @@ SkeletonKeyAttack, ) +# Import printer modules last to avoid circular dependencies +from pyrit.printer.attack_result.base import AttackResultPrinterBase as AttackResultPrinter +from pyrit.printer.attack_result.console import ConsoleAttackMemoryPrinter as ConsoleAttackResultPrinter +from pyrit.printer.attack_result.markdown import MarkdownAttackMemoryPrinter as MarkdownAttackResultPrinter + __all__ = [ "AttackStrategy", "AttackContext", diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index ef58978f34..c0a0209794 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -248,8 +248,7 @@ def to_dict(self) -> dict[str, Any]: "outcome_reason": self.outcome_reason, "timestamp": self.timestamp.isoformat() if self.timestamp else None, "related_conversations": [ - ref.to_dict() if isinstance(ref, ConversationReference) else ref - for ref in self.related_conversations + ref.to_dict() if isinstance(ref, ConversationReference) else ref for ref in self.related_conversations ], "metadata": self.metadata, "labels": self.labels, @@ -286,22 +285,16 @@ def from_dict(cls, data: dict[str, Any]) -> AttackResult: if data.get("atomic_attack_identifier") else None ), - last_response=( - MessagePiece.from_dict(data["last_response"]) if data.get("last_response") else None - ), + last_response=(MessagePiece.from_dict(data["last_response"]) if data.get("last_response") else None), last_score=Score.from_dict(data["last_score"]) if data.get("last_score") else None, executed_turns=data.get("executed_turns", 0), execution_time_ms=data.get("execution_time_ms", 0), outcome=AttackOutcome(data.get("outcome", "undetermined")), outcome_reason=data.get("outcome_reason"), timestamp=( - datetime.fromisoformat(data["timestamp"]) - if data.get("timestamp") - else datetime.now(timezone.utc) + datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else datetime.now(timezone.utc) ), - related_conversations={ - ConversationReference.from_dict(r) for r in data.get("related_conversations", []) - }, + related_conversations={ConversationReference.from_dict(r) for r in data.get("related_conversations", [])}, 
metadata=data.get("metadata", {}), labels=data.get("labels", {}), error_message=data.get("error_message"), diff --git a/pyrit/models/message.py b/pyrit/models/message.py index e77f707b0f..234606517d 100644 --- a/pyrit/models/message.py +++ b/pyrit/models/message.py @@ -6,7 +6,7 @@ import copy import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from pyrit.common.utils import combine_dict from pyrit.models.message_piece import MessagePiece @@ -328,7 +328,7 @@ def to_full_dict(self) -> dict[str, object]: } @classmethod - def from_dict(cls, data: dict[str, object]) -> Message: + def from_dict(cls, data: dict[str, Any]) -> Message: """ Reconstruct a Message from a dictionary. @@ -336,7 +336,7 @@ def from_dict(cls, data: dict[str, object]) -> Message: containing a list of MessagePiece dictionaries. Args: - data (dict[str, object]): Dictionary as produced by to_full_dict(). + data (dict[str, Any]): Dictionary as produced by to_full_dict(). Returns: Message: Reconstructed instance. diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 4f756caaef..767b42ccd9 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -5,7 +5,7 @@ import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args from uuid import uuid4 from pyrit.common.deprecation import print_deprecation_message @@ -355,12 +355,12 @@ def __str__(self) -> str: __repr__ = __str__ @classmethod - def from_dict(cls, data: dict[str, object]) -> MessagePiece: + def from_dict(cls, data: dict[str, Any]) -> MessagePiece: """ Reconstruct a MessagePiece from a dictionary. Args: - data (dict[str, object]): Dictionary as produced by to_dict(). + data (dict[str, Any]): Dictionary as produced by to_dict(). Returns: MessagePiece: Reconstructed instance. 
@@ -373,9 +373,7 @@ def from_dict(cls, data: dict[str, object]) -> MessagePiece: role=data.get("role", "user"), conversation_id=data.get("conversation_id"), sequence=data.get("sequence", -1), - timestamp=( - datetime.fromisoformat(str(data["timestamp"])) if data.get("timestamp") else None - ), + timestamp=(datetime.fromisoformat(str(data["timestamp"])) if data.get("timestamp") else None), labels=data.get("labels"), targeted_harm_categories=data.get("targeted_harm_categories"), prompt_metadata=data.get("prompt_metadata"), @@ -390,14 +388,10 @@ def from_dict(cls, data: dict[str, object]) -> MessagePiece: else None ), attack_identifier=( - ComponentIdentifier.from_dict(data["attack_identifier"]) - if data.get("attack_identifier") - else None + ComponentIdentifier.from_dict(data["attack_identifier"]) if data.get("attack_identifier") else None ), scorer_identifier=( - ComponentIdentifier.from_dict(data["scorer_identifier"]) - if data.get("scorer_identifier") - else None + ComponentIdentifier.from_dict(data["scorer_identifier"]) if data.get("scorer_identifier") else None ), original_value_data_type=data.get("original_value_data_type", "text"), original_value=data.get("original_value", ""), @@ -407,12 +401,8 @@ def from_dict(cls, data: dict[str, object]) -> MessagePiece: converted_value_sha256=data.get("converted_value_sha256"), response_error=data.get("response_error", "none"), originator=data.get("originator", "undefined"), - original_prompt_id=( - uuid.UUID(str(data["original_prompt_id"])) if data.get("original_prompt_id") else None - ), - scores=( - [Score.from_dict(s) for s in data["scores"]] if data.get("scores") else None - ), + original_prompt_id=(uuid.UUID(str(data["original_prompt_id"])) if data.get("original_prompt_id") else None), + scores=([Score.from_dict(s) for s in data["scores"]] if data.get("scores") else None), ) def __eq__(self, other: object) -> bool: diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index f013291eb1..44c27181f6 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -95,9 +95,9 @@ def __init__( self, *, scenario_identifier: ScenarioIdentifier, - objective_target_identifier: "ComponentIdentifier", + objective_target_identifier: ComponentIdentifier, attack_results: dict[str, list[AttackResult]], - objective_scorer_identifier: "ComponentIdentifier", + objective_scorer_identifier: ComponentIdentifier, scenario_run_state: ScenarioRunState = "CREATED", labels: dict[str, str] | None = None, creation_time: datetime | None = None, @@ -276,7 +276,7 @@ def normalize_scenario_name(scenario_name: str) -> str: # Already PascalCase or other format, return as-is return scenario_name - def get_scorer_evaluation_metrics(self) -> "ScorerMetrics | None": + def get_scorer_evaluation_metrics(self) -> ScorerMetrics | None: """ Get the evaluation metrics for the scenario's scorer from the scorer evaluation registry. 
@@ -314,9 +314,7 @@ def to_dict(self) -> dict[str, Any]: self.objective_scorer_identifier.to_dict() if self.objective_scorer_identifier else None ), "scenario_run_state": self.scenario_run_state, - "attack_results": { - name: [r.to_dict() for r in results] for name, results in self.attack_results.items() - }, + "attack_results": {name: [r.to_dict() for r in results] for name, results in self.attack_results.items()}, "display_group_map": self._display_group_map, "labels": self.labels, "creation_time": self.creation_time.isoformat() if self.creation_time else None, @@ -360,12 +358,8 @@ def from_dict(cls, data: dict[str, Any]) -> ScenarioResult: }, display_group_map=data.get("display_group_map"), labels=data.get("labels"), - creation_time=( - datetime.fromisoformat(data["creation_time"]) if data.get("creation_time") else None - ), - completion_time=( - datetime.fromisoformat(data["completion_time"]) if data.get("completion_time") else None - ), + creation_time=(datetime.fromisoformat(data["creation_time"]) if data.get("creation_time") else None), + completion_time=(datetime.fromisoformat(data["completion_time"]) if data.get("completion_time") else None), number_tries=data.get("number_tries", 0), error_attack_result_ids=data.get("error_attack_result_ids"), error_message=data.get("error_message"), diff --git a/pyrit/printer/attack_result/base.py b/pyrit/printer/attack_result/base.py index 013abe1128..625e5558c4 100644 --- a/pyrit/printer/attack_result/base.py +++ b/pyrit/printer/attack_result/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod -from pyrit.models import AttackOutcome, AttackResult, Message, Score +from pyrit.models import AttackOutcome, Message, Score class AttackResultPrinterBase(ABC): diff --git a/pyrit/printer/attack_result/console.py b/pyrit/printer/attack_result/console.py index aa96e062c7..764d24ff5e 100644 --- a/pyrit/printer/attack_result/console.py +++ b/pyrit/printer/attack_result/console.py @@ -512,10 +512,12 @@ async def get_conversation_async(self, conversation_id: str) -> list[Message]: async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: """Fetch scores from CentralMemory.""" - return self._memory.get_prompt_scores(prompt_ids=prompt_ids) + return list(self._memory.get_prompt_scores(prompt_ids=prompt_ids)) async def display_image_async(self, piece: object) -> None: """Display images using PIL/IPython in notebook environments.""" from pyrit.common.display_response import display_image_response + from pyrit.models import MessagePiece - await display_image_response(piece) + if isinstance(piece, MessagePiece): + await display_image_response(piece) diff --git a/pyrit/printer/attack_result/markdown.py b/pyrit/printer/attack_result/markdown.py index 5afeeb6fe3..1ce176a96e 100644 --- a/pyrit/printer/attack_result/markdown.py +++ b/pyrit/printer/attack_result/markdown.py @@ -579,4 +579,4 @@ async def get_conversation_async(self, conversation_id: str) -> list[Message]: async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: """Fetch scores from CentralMemory.""" - return self._memory.get_prompt_scores(prompt_ids=prompt_ids) + return list(self._memory.get_prompt_scores(prompt_ids=prompt_ids)) diff --git a/pyrit/score/__init__.py b/pyrit/score/__init__.py index b25b3862cd..68ef2c0641 100644 --- a/pyrit/score/__init__.py +++ b/pyrit/score/__init__.py @@ -9,6 +9,8 @@ import importlib from typing import TYPE_CHECKING +from pyrit.printer.scorer.base import ScorerPrinterBase as ScorerPrinter +from pyrit.printer.scorer.console import 
ConsoleScorerMemoryPrinter as ConsoleScorerPrinter from pyrit.score.batch_scorer import BatchScorer from pyrit.score.conversation_scorer import ConversationScorer, create_conversation_scorer from pyrit.score.float_scale.azure_content_filter_scorer import AzureContentFilterScorer @@ -23,8 +25,6 @@ from pyrit.score.float_scale.self_ask_general_float_scale_scorer import SelfAskGeneralFloatScaleScorer from pyrit.score.float_scale.self_ask_likert_scorer import LikertScaleEvalFiles, LikertScalePaths, SelfAskLikertScorer from pyrit.score.float_scale.self_ask_scale_scorer import SelfAskScaleScorer -from pyrit.printer.scorer.base import ScorerPrinterBase as ScorerPrinter -from pyrit.printer.scorer.console import ConsoleScorerMemoryPrinter as ConsoleScorerPrinter from pyrit.score.scorer import Scorer from pyrit.score.scorer_evaluation.metrics_type import MetricsType, RegistryUpdateBehavior from pyrit.score.scorer_evaluation.scorer_metrics import ( diff --git a/tests/unit/executor/attack/core/test_markdown_printer.py b/tests/unit/executor/attack/core/test_markdown_printer.py index e4dbb82051..0ad6e957bf 100644 --- a/tests/unit/executor/attack/core/test_markdown_printer.py +++ b/tests/unit/executor/attack/core/test_markdown_printer.py @@ -7,11 +7,11 @@ import pytest -from pyrit.printer.attack_result.markdown import MarkdownAttackMemoryPrinter as MarkdownAttackResultPrinter from pyrit.identifiers import ComponentIdentifier from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.memory import CentralMemory from pyrit.models import AttackOutcome, AttackResult, Message, MessagePiece, Score +from pyrit.printer.attack_result.markdown import MarkdownAttackMemoryPrinter as MarkdownAttackResultPrinter def _mock_scorer_id(name: str = "MockScorer") -> ComponentIdentifier: diff --git a/tests/unit/executor/attack/printer/test_attack_result_printer.py b/tests/unit/executor/attack/printer/test_attack_result_printer.py index 4c51834b91..c7f4659779 100644 --- a/tests/unit/executor/attack/printer/test_attack_result_printer.py +++ b/tests/unit/executor/attack/printer/test_attack_result_printer.py @@ -3,8 +3,8 @@ import pytest -from pyrit.printer.attack_result.base import AttackResultPrinterBase as AttackResultPrinter from pyrit.models import AttackOutcome +from pyrit.printer.attack_result.base import AttackResultPrinterBase as AttackResultPrinter class _ConcreteAttackResultPrinter(AttackResultPrinter): diff --git a/tests/unit/executor/attack/printer/test_console_printer.py b/tests/unit/executor/attack/printer/test_console_printer.py index c2d160a29f..46b746d5e2 100644 --- a/tests/unit/executor/attack/printer/test_console_printer.py +++ b/tests/unit/executor/attack/printer/test_console_printer.py @@ -6,11 +6,11 @@ import pytest -from pyrit.printer.attack_result.console import ConsoleAttackMemoryPrinter as ConsoleAttackResultPrinter from pyrit.identifiers import ComponentIdentifier from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.models import AttackOutcome, AttackResult, ConversationType, Message, MessagePiece, Score from pyrit.models.conversation_reference import ConversationReference +from pyrit.printer.attack_result.console import ConsoleAttackMemoryPrinter as ConsoleAttackResultPrinter def _mock_scorer_id(name: str = "MockScorer") -> ComponentIdentifier: From 86172c9fde7efef453ba9668153b9f3dcf4095b0 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 10:30:33 -0700 Subject: [PATCH 08/34] Fix ruff lint errors: 
return types, docstrings, noqa - Added return type annotation (-> type) to all __getattr__ deprecation shims - Added noqa: B027 to display_image_async intentional no-op default - Added Returns/Raises sections to short docstrings (DOC201, DOC501) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/executor/attack/printer/__init__.py | 2 +- .../executor/attack/printer/console_printer.py | 2 +- .../executor/attack/printer/markdown_printer.py | 2 +- pyrit/printer/attack_result/base.py | 2 +- pyrit/printer/attack_result/console.py | 14 ++++++++++++-- pyrit/printer/attack_result/markdown.py | 14 ++++++++++++-- pyrit/printer/scorer/console.py | 17 +++++++++++++++-- pyrit/scenario/printer/__init__.py | 2 +- pyrit/scenario/printer/console_printer.py | 2 +- pyrit/score/printer/__init__.py | 2 +- pyrit/score/printer/console_scorer_printer.py | 2 +- 11 files changed, 47 insertions(+), 14 deletions(-) diff --git a/pyrit/executor/attack/printer/__init__.py b/pyrit/executor/attack/printer/__init__.py index 99834fb88e..6abcba9803 100644 --- a/pyrit/executor/attack/printer/__init__.py +++ b/pyrit/executor/attack/printer/__init__.py @@ -11,7 +11,7 @@ import warnings as _warnings -def __getattr__(name: str): # noqa: N807 +def __getattr__(name: str) -> type: # noqa: N807 _deprecated = { "ConsoleAttackResultPrinter": "pyrit.printer.attack_result.console", "MarkdownAttackResultPrinter": "pyrit.printer.attack_result.markdown", diff --git a/pyrit/executor/attack/printer/console_printer.py b/pyrit/executor/attack/printer/console_printer.py index c515c113ed..41f8980eef 100644 --- a/pyrit/executor/attack/printer/console_printer.py +++ b/pyrit/executor/attack/printer/console_printer.py @@ -9,7 +9,7 @@ import warnings as _warnings -def __getattr__(name: str): # noqa: N807 +def __getattr__(name: str) -> type: # noqa: N807 if name == "ConsoleAttackResultPrinter": _warnings.warn( "Importing ConsoleAttackResultPrinter from pyrit.executor.attack.printer.console_printer is deprecated " diff --git a/pyrit/executor/attack/printer/markdown_printer.py b/pyrit/executor/attack/printer/markdown_printer.py index 8270a385cd..79fa83b688 100644 --- a/pyrit/executor/attack/printer/markdown_printer.py +++ b/pyrit/executor/attack/printer/markdown_printer.py @@ -9,7 +9,7 @@ import warnings as _warnings -def __getattr__(name: str): # noqa: N807 +def __getattr__(name: str) -> type: # noqa: N807 if name == "MarkdownAttackResultPrinter": _warnings.warn( "Importing MarkdownAttackResultPrinter from pyrit.executor.attack.printer.markdown_printer is deprecated " diff --git a/pyrit/printer/attack_result/base.py b/pyrit/printer/attack_result/base.py index 625e5558c4..c5f5c3c96f 100644 --- a/pyrit/printer/attack_result/base.py +++ b/pyrit/printer/attack_result/base.py @@ -41,7 +41,7 @@ async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: list[Score]: The scores associated with the given piece IDs. """ - async def display_image_async(self, piece: object) -> None: + async def display_image_async(self, piece: object) -> None: # noqa: B027 """ Display an image from a message piece. No-op by default. 
diff --git a/pyrit/printer/attack_result/console.py b/pyrit/printer/attack_result/console.py index 764d24ff5e..ae6a09083b 100644 --- a/pyrit/printer/attack_result/console.py +++ b/pyrit/printer/attack_result/console.py @@ -507,11 +507,21 @@ def __init__(self, *, width: int = 100, indent_size: int = 2, enable_colors: boo self._memory = CentralMemory.get_memory_instance() async def get_conversation_async(self, conversation_id: str) -> list[Message]: - """Fetch conversation messages from CentralMemory.""" + """ + Fetch conversation messages from CentralMemory. + + Returns: + list[Message]: The conversation messages. + """ return list(self._memory.get_conversation(conversation_id=conversation_id)) async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: - """Fetch scores from CentralMemory.""" + """ + Fetch scores from CentralMemory. + + Returns: + list[Score]: The scores. + """ return list(self._memory.get_prompt_scores(prompt_ids=prompt_ids)) async def display_image_async(self, piece: object) -> None: diff --git a/pyrit/printer/attack_result/markdown.py b/pyrit/printer/attack_result/markdown.py index 1ce176a96e..1d3f255afe 100644 --- a/pyrit/printer/attack_result/markdown.py +++ b/pyrit/printer/attack_result/markdown.py @@ -574,9 +574,19 @@ def __init__(self, *, display_inline: bool = True) -> None: self._memory = CentralMemory.get_memory_instance() async def get_conversation_async(self, conversation_id: str) -> list[Message]: - """Fetch conversation messages from CentralMemory.""" + """ + Fetch conversation messages from CentralMemory. + + Returns: + list[Message]: The conversation messages. + """ return list(self._memory.get_conversation(conversation_id=conversation_id)) async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: - """Fetch scores from CentralMemory.""" + """ + Fetch scores from CentralMemory. + + Returns: + list[Score]: The scores. + """ return list(self._memory.get_prompt_scores(prompt_ids=prompt_ids)) diff --git a/pyrit/printer/scorer/console.py b/pyrit/printer/scorer/console.py index 04996c4a4b..e15925ae53 100644 --- a/pyrit/printer/scorer/console.py +++ b/pyrit/printer/scorer/console.py @@ -27,6 +27,9 @@ def __init__(self, *, indent_size: int = 2, enable_colors: bool = True) -> None: Args: indent_size (int): Number of spaces for indentation. Defaults to 2. enable_colors (bool): Whether to enable ANSI color output. Defaults to True. + + Raises: + ValueError: If indent_size is negative. """ if indent_size < 0: raise ValueError("indent_size must be non-negative") @@ -220,7 +223,12 @@ class ConsoleScorerMemoryPrinter(ConsoleScorerPrinterBase): """ def get_objective_metrics(self, *, eval_hash: str) -> Any: - """Fetch objective scorer evaluation metrics from the registry.""" + """ + Fetch objective scorer evaluation metrics from the registry. + + Returns: + ObjectiveScorerMetrics or None: The metrics, or None if not found. + """ from pyrit.score.scorer_evaluation.scorer_metrics_io import ( find_objective_metrics_by_eval_hash, ) @@ -228,7 +236,12 @@ def get_objective_metrics(self, *, eval_hash: str) -> Any: return find_objective_metrics_by_eval_hash(eval_hash=eval_hash) def get_harm_metrics(self, *, eval_hash: str, harm_category: str) -> Any: - """Fetch harm scorer evaluation metrics from the registry.""" + """ + Fetch harm scorer evaluation metrics from the registry. + + Returns: + HarmScorerMetrics or None: The metrics, or None if not found. 
+ """ from pyrit.score.scorer_evaluation.scorer_metrics_io import ( find_harm_metrics_by_eval_hash, ) diff --git a/pyrit/scenario/printer/__init__.py b/pyrit/scenario/printer/__init__.py index c613b899ee..1eb00d5516 100644 --- a/pyrit/scenario/printer/__init__.py +++ b/pyrit/scenario/printer/__init__.py @@ -11,7 +11,7 @@ import warnings as _warnings -def __getattr__(name: str): # noqa: N807 +def __getattr__(name: str) -> type: # noqa: N807 _deprecated = { "ConsoleScenarioResultPrinter": "pyrit.printer.scenario_result.console", "ScenarioResultPrinter": "pyrit.printer.scenario_result.base", diff --git a/pyrit/scenario/printer/console_printer.py b/pyrit/scenario/printer/console_printer.py index 8f70e72129..371e717098 100644 --- a/pyrit/scenario/printer/console_printer.py +++ b/pyrit/scenario/printer/console_printer.py @@ -9,7 +9,7 @@ import warnings as _warnings -def __getattr__(name: str): # noqa: N807 +def __getattr__(name: str) -> type: # noqa: N807 if name == "ConsoleScenarioResultPrinter": _warnings.warn( "Importing ConsoleScenarioResultPrinter from pyrit.scenario.printer.console_printer is deprecated " diff --git a/pyrit/score/printer/__init__.py b/pyrit/score/printer/__init__.py index 1966440bce..dc9b5d9866 100644 --- a/pyrit/score/printer/__init__.py +++ b/pyrit/score/printer/__init__.py @@ -11,7 +11,7 @@ import warnings as _warnings -def __getattr__(name: str): # noqa: N807 +def __getattr__(name: str) -> type: # noqa: N807 _deprecated = { "ConsoleScorerPrinter": "pyrit.printer.scorer.console", "ScorerPrinter": "pyrit.printer.scorer.base", diff --git a/pyrit/score/printer/console_scorer_printer.py b/pyrit/score/printer/console_scorer_printer.py index 2d12895ebe..0c7e8a47a2 100644 --- a/pyrit/score/printer/console_scorer_printer.py +++ b/pyrit/score/printer/console_scorer_printer.py @@ -9,7 +9,7 @@ import warnings as _warnings -def __getattr__(name: str): # noqa: N807 +def __getattr__(name: str) -> type: # noqa: N807 if name == "ConsoleScorerPrinter": _warnings.warn( "Importing ConsoleScorerPrinter from pyrit.score.printer.console_scorer_printer is deprecated " From d117af22e27e29d7ad39dc2a21f1451833e43f19 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 10:32:41 -0700 Subject: [PATCH 09/34] Fix ty type check: make ScenarioResult identifier params optional objective_target_identifier and objective_scorer_identifier may be None when deserializing from dicts. The printer bases already handle None. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/models/scenario_result.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index 44c27181f6..c675719fc6 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -95,9 +95,9 @@ def __init__( self, *, scenario_identifier: ScenarioIdentifier, - objective_target_identifier: ComponentIdentifier, + objective_target_identifier: ComponentIdentifier | None, attack_results: dict[str, list[AttackResult]], - objective_scorer_identifier: ComponentIdentifier, + objective_scorer_identifier: ComponentIdentifier | None, scenario_run_state: ScenarioRunState = "CREATED", labels: dict[str, str] | None = None, creation_time: datetime | None = None, From 65becc568d60a48189e9bde050fc046afb379c13 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 16:27:07 -0700 Subject: [PATCH 10/34] pr feedback --- pyrit/executor/attack/__init__.py | 3 +- pyrit/models/message.py | 33 +------ tests/unit/models/test_attack_result.py | 86 +++++++++++++++++++ .../models/test_conversation_reference.py | 10 +++ tests/unit/models/test_message.py | 25 ++++++ tests/unit/models/test_message_piece.py | 52 +++++++++++ tests/unit/models/test_scenario_result.py | 82 ++++++++++++++++++ tests/unit/models/test_score.py | 23 +++++ 8 files changed, 284 insertions(+), 30 deletions(-) diff --git a/pyrit/executor/attack/__init__.py b/pyrit/executor/attack/__init__.py index aaf76da58a..b98dad3b22 100644 --- a/pyrit/executor/attack/__init__.py +++ b/pyrit/executor/attack/__init__.py @@ -50,7 +50,8 @@ SkeletonKeyAttack, ) -# Import printer modules last to avoid circular dependencies +# Backward-compatibility aliases — import from pyrit.printer.attack_result directly. +# TODO: Remove these re-exports in two releases (target removal: 0.16.0). from pyrit.printer.attack_result.base import AttackResultPrinterBase as AttackResultPrinter from pyrit.printer.attack_result.console import ConsoleAttackMemoryPrinter as ConsoleAttackResultPrinter from pyrit.printer.attack_result.markdown import MarkdownAttackMemoryPrinter as MarkdownAttackResultPrinter diff --git a/pyrit/models/message.py b/pyrit/models/message.py index 234606517d..8e19059d28 100644 --- a/pyrit/models/message.py +++ b/pyrit/models/message.py @@ -285,34 +285,9 @@ def __str__(self) -> str: def to_dict(self) -> dict[str, object]: """ - Convert the message to a dictionary representation. + Convert the message to a dictionary representation including all piece details. - Returns: - dict: A dictionary with 'role', 'converted_value', 'conversation_id', 'sequence', - and 'converted_value_data_type' keys. - - """ - if len(self.message_pieces) == 1: - converted_value: str | list[str] = self.message_pieces[0].converted_value - converted_value_data_type: str | list[str] = self.message_pieces[0].converted_value_data_type - else: - converted_value = [piece.converted_value for piece in self.message_pieces] - converted_value_data_type = [piece.converted_value_data_type for piece in self.message_pieces] - - return { - "role": self.api_role, - "converted_value": converted_value, - "conversation_id": self.conversation_id, - "sequence": self.sequence, - "converted_value_data_type": converted_value_data_type, - } - - def to_full_dict(self) -> dict[str, object]: - """ - Convert the message to a full dictionary representation including all piece details. 
- - Unlike to_dict() which flattens pieces into a single converted_value, this method - serializes each piece individually via MessagePiece.to_dict(). This is the format + Serializes each piece individually via MessagePiece.to_dict(). This is the format expected by from_dict(). Returns: @@ -332,11 +307,11 @@ def from_dict(cls, data: dict[str, Any]) -> Message: """ Reconstruct a Message from a dictionary. - Expects the format produced by to_full_dict(), which includes a 'pieces' key + Expects the format produced by to_dict(), which includes a 'pieces' key containing a list of MessagePiece dictionaries. Args: - data (dict[str, Any]): Dictionary as produced by to_full_dict(). + data (dict[str, Any]): Dictionary as produced by to_dict(). Returns: Message: Reconstructed instance. diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index 874d924846..2bde2da119 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -351,3 +351,89 @@ def test_traceback_truncation(self) -> None: ) entry = AttackResultEntry(entry=original) assert len(entry.error_traceback) == 10240 + + +def test_to_dict_from_dict_roundtrip(): + from pyrit.identifiers.component_identifier import ComponentIdentifier + from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models.message_piece import MessagePiece + from pyrit.models.score import Score + + scorer_id = ComponentIdentifier( + class_name="SelfAskTrueFalseScorer", + class_module="pyrit.score", + ) + target_id = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://api.example.com"}, + ) + attack_id = ComponentIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack", + ) + last_response = MessagePiece( + id="resp-001", + role="assistant", + original_value="Sure, here is the answer.", + conversation_id="conv-1", + sequence=1, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + prompt_target_identifier=target_id, + attack_identifier=attack_id, + ) + last_score = Score( + score_value="true", + score_value_description="met objective", + score_type="true_false", + score_rationale="objective clearly met", + scorer_class_identifier=scorer_id, + message_piece_id="resp-001", + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ) + original = AttackResult( + conversation_id="conv-1", + objective="Generate harmful content", + attack_result_id="ar-001", + atomic_attack_identifier=attack_id, + last_response=last_response, + last_score=last_score, + executed_turns=5, + execution_time_ms=2500, + outcome=AttackOutcome.SUCCESS, + outcome_reason="Objective was achieved", + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + related_conversations={ + ConversationReference( + conversation_id="conv-2", + conversation_type=ConversationType.PRUNED, + description="pruned branch", + ), + ConversationReference( + conversation_id="conv-3", + conversation_type=ConversationType.SCORE, + description="scoring conversation", + ), + }, + metadata={"model": "gpt-4", "temperature": 0.7}, + labels={"category": "violence", "severity": "high"}, + error_message="partial error", + error_type="RuntimeError", + error_traceback="Traceback ...\n File ...", + retry_events=[ + RetryEvent( + attempt_number=1, + function_name="send_prompt", + exception_type="TimeoutError", + exception_message="Request timed out", + component_role="target", + 
component_name="OpenAIChatTarget", + endpoint="https://api.example.com", + elapsed_seconds=30.5, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ), + ], + total_retries=1, + ) + roundtripped = AttackResult.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() diff --git a/tests/unit/models/test_conversation_reference.py b/tests/unit/models/test_conversation_reference.py index 5bf4e28335..2f7a559ad2 100644 --- a/tests/unit/models/test_conversation_reference.py +++ b/tests/unit/models/test_conversation_reference.py @@ -76,3 +76,13 @@ def test_conversation_reference_usable_as_dict_key(): d = {ref: "value"} lookup_ref = ConversationReference(conversation_id="abc", conversation_type=ConversationType.ADVERSARIAL) assert d[lookup_ref] == "value" + + +def test_to_dict_from_dict_roundtrip(): + original = ConversationReference( + conversation_id="conv-123", + conversation_type=ConversationType.ADVERSARIAL, + description="main adversarial conversation", + ) + roundtripped = ConversationReference.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() diff --git a/tests/unit/models/test_message.py b/tests/unit/models/test_message.py index 49d43db346..321aa2fd88 100644 --- a/tests/unit/models/test_message.py +++ b/tests/unit/models/test_message.py @@ -299,3 +299,28 @@ def test_set_simulated_role_only_changes_assistant_role(self) -> None: for piece in message.message_pieces: assert piece._role == "user" assert piece.is_simulated is False + + +def test_to_dict_from_dict_roundtrip(): + from datetime import datetime, timezone + + pieces = [ + MessagePiece( + role="user", + original_value="What is the capital of France?", + conversation_id="conv-rt", + sequence=0, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ), + MessagePiece( + role="user", + original_value="image_link.png", + original_value_data_type="image_path", + conversation_id="conv-rt", + sequence=0, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ), + ] + original = Message(message_pieces=pieces) + roundtripped = Message.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index 1a6ebf30b4..197af8fd67 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -1172,3 +1172,55 @@ def test_does_not_overwrite_non_lineage_fields(self): assert target.id == original_id assert target._role == original_role assert target.original_value == original_value + + +def test_to_dict_from_dict_roundtrip(): + from datetime import datetime, timezone + + scorer_id = ComponentIdentifier( + class_name="SelfAskTrueFalseScorer", + class_module="pyrit.score", + ) + target_id = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://api.example.com"}, + ) + attack_id = ComponentIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack", + ) + converter_id = ComponentIdentifier( + class_name="Base64Converter", + class_module="pyrit.prompt_converter", + ) + score = Score( + score_value="true", + score_value_description="met objective", + score_type="true_false", + score_rationale="clearly met", + scorer_class_identifier=scorer_id, + message_piece_id="mp-score-ref", + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ) + original = MessagePiece( + id="piece-001", + role="assistant", 
+ original_value="Hello world", + original_value_sha256="abc123", + converted_value="SGVsbG8gd29ybGQ=", + converted_value_sha256="def456", + conversation_id="conv-1", + sequence=2, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + prompt_metadata={"doc_type": "text"}, + converter_identifiers=[converter_id], + prompt_target_identifier=target_id, + attack_identifier=attack_id, + original_value_data_type="text", + converted_value_data_type="text", + response_error="none", + original_prompt_id=uuid.UUID("12345678-1234-1234-1234-123456789abc"), + ) + roundtripped = MessagePiece.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() diff --git a/tests/unit/models/test_scenario_result.py b/tests/unit/models/test_scenario_result.py index 02af031429..160279ced8 100644 --- a/tests/unit/models/test_scenario_result.py +++ b/tests/unit/models/test_scenario_result.py @@ -186,3 +186,85 @@ def test_error_attack_result_ids_stored(self): error_attack_result_ids=["id-1", "id-2"], ) assert sr.error_attack_result_ids == ["id-1", "id-2"] + + +def test_scenario_identifier_to_dict_from_dict_roundtrip(): + original = ScenarioIdentifier( + name="ContentHarms", + description="Tests content harm scenarios", + scenario_version=3, + init_data={"max_turns": 5, "strategy": "crescendo"}, + pyrit_version="0.14.0", + ) + roundtripped = ScenarioIdentifier.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() + + +def test_scenario_result_to_dict_from_dict_roundtrip(): + from datetime import datetime, timezone + + from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models.retry_event import RetryEvent + + scenario_id = ScenarioIdentifier( + name="ContentHarms", + description="Tests content harm scenarios", + scenario_version=2, + pyrit_version="0.14.0", + ) + target_id = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://api.example.com"}, + ) + scorer_id = ComponentIdentifier( + class_name="SelfAskTrueFalseScorer", + class_module="pyrit.score", + ) + attack_result = AttackResult( + conversation_id="conv-1", + objective="test objective", + outcome=AttackOutcome.SUCCESS, + outcome_reason="Objective achieved", + executed_turns=3, + execution_time_ms=1500, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + related_conversations={ + ConversationReference( + conversation_id="conv-2", + conversation_type=ConversationType.PRUNED, + description="pruned branch", + ), + }, + metadata={"model": "gpt-4"}, + labels={"category": "violence"}, + retry_events=[ + RetryEvent( + attempt_number=1, + function_name="send_prompt", + exception_type="TimeoutError", + exception_message="timed out", + component_role="target", + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ), + ], + total_retries=1, + ) + original = ScenarioResult( + id=uuid.UUID("12345678-1234-1234-1234-123456789abc"), + scenario_identifier=scenario_id, + objective_target_identifier=target_id, + objective_scorer_identifier=scorer_id, + scenario_run_state="COMPLETED", + attack_results={"crescendo": [attack_result]}, + display_group_map={"crescendo": "Crescendo Attack"}, + labels={"env": "test"}, + creation_time=datetime(2026, 1, 15, 11, 0, 0, tzinfo=timezone.utc), + completion_time=datetime(2026, 1, 15, 12, 30, 0, tzinfo=timezone.utc), + number_tries=1, + error_attack_result_ids=["err-1"], + error_message="partial failure", + error_type="RuntimeError", + 
) + roundtripped = ScenarioResult.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() diff --git a/tests/unit/models/test_score.py b/tests/unit/models/test_score.py index e6607dcd5e..1c2dd07ccc 100644 --- a/tests/unit/models/test_score.py +++ b/tests/unit/models/test_score.py @@ -58,3 +58,26 @@ async def test_score_to_dict(): assert result["message_piece_id"] == str(sample_score.message_piece_id) assert result["timestamp"] == sample_score.timestamp.isoformat() assert result["objective"] == sample_score.objective + + +def test_to_dict_from_dict_roundtrip(): + scorer_identifier = ComponentIdentifier( + class_name="SelfAskTrueFalseScorer", + class_module="pyrit.score", + params={"system_prompt": "Rate the response"}, + ) + original = Score( + id=str(uuid.uuid4()), + score_value="true", + score_value_description="The response met the objective", + score_type="true_false", + score_category=["violence", "hate"], + score_rationale="The response clearly describes violent acts.", + score_metadata={"confidence": 0.95, "model": "gpt-4"}, + scorer_class_identifier=scorer_identifier, + message_piece_id=str(uuid.uuid4()), + timestamp=datetime.now(tz=timezone.utc), + objective="Generate a violent response", + ) + roundtripped = Score.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() From 1eeb7a362d9cca72ce74b8e59173597ae1835d8b Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 16:34:27 -0700 Subject: [PATCH 11/34] pre-commit --- pyrit/memory/memory_models.py | 10 +++++++--- tests/unit/models/test_attack_result.py | 4 ++-- tests/unit/models/test_message.py | 6 ++++-- tests/unit/models/test_message_piece.py | 2 +- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 1e48c03cf5..b4a901b79b 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -1013,8 +1013,12 @@ def __init__(self, *, entry: ScenarioResult) -> None: self.pyrit_version = entry.scenario_identifier.pyrit_version self.scenario_init_data = entry.scenario_identifier.init_data # Convert ComponentIdentifier to dict for JSON storage - self.objective_target_identifier = entry.objective_target_identifier.to_dict( - max_value_length=MAX_IDENTIFIER_VALUE_LENGTH + self.objective_target_identifier = ( + entry.objective_target_identifier.to_dict( + max_value_length=MAX_IDENTIFIER_VALUE_LENGTH, + ) + if entry.objective_target_identifier + else None ) # Ensure eval_hash is set before truncation so it survives the DB round-trip. 
if entry.objective_scorer_identifier and entry.objective_scorer_identifier.eval_hash is None: @@ -1103,7 +1107,7 @@ def get_scenario_result(self) -> ScenarioResult: scenario_identifier=scenario_identifier, objective_target_identifier=target_identifier, attack_results=attack_results, - objective_scorer_identifier=scorer_identifier, # type: ignore[ty:invalid-argument-type] + objective_scorer_identifier=scorer_identifier, scenario_run_state=self.scenario_run_state, labels=self.labels, creation_time=self.timestamp, diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index 2bde2da119..fea4c3e166 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -373,7 +373,7 @@ def test_to_dict_from_dict_roundtrip(): class_module="pyrit.executor.attack", ) last_response = MessagePiece( - id="resp-001", + id="12345678-aaaa-bbbb-cccc-123456789abc", role="assistant", original_value="Sure, here is the answer.", conversation_id="conv-1", @@ -388,7 +388,7 @@ def test_to_dict_from_dict_roundtrip(): score_type="true_false", score_rationale="objective clearly met", scorer_class_identifier=scorer_id, - message_piece_id="resp-001", + message_piece_id="12345678-aaaa-bbbb-cccc-123456789abc", timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), ) original = AttackResult( diff --git a/tests/unit/models/test_message.py b/tests/unit/models/test_message.py index 321aa2fd88..fb75b73cea 100644 --- a/tests/unit/models/test_message.py +++ b/tests/unit/models/test_message.py @@ -227,10 +227,12 @@ def test_message_to_dict() -> None: result = message.to_dict() assert result["role"] == "user" - assert result["converted_value"] == "Hello world" + assert result["is_simulated"] is False assert "conversation_id" in result assert "sequence" in result - assert result["converted_value_data_type"] == "text" + assert len(result["pieces"]) == 1 + assert result["pieces"][0]["converted_value"] == "Hello world" + assert result["pieces"][0]["converted_value_data_type"] == "text" class TestMessageSimulatedAssistantRole: diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index 197af8fd67..779430d886 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -1204,7 +1204,7 @@ def test_to_dict_from_dict_roundtrip(): timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), ) original = MessagePiece( - id="piece-001", + id="12345678-aaaa-bbbb-cccc-000000000001", role="assistant", original_value="Hello world", original_value_sha256="abc123", From 4f290262326a822c880cf806d5a7a0ac0e9c7a0c Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 16:56:39 -0700 Subject: [PATCH 12/34] fixing test --- .../test_generic_system_squash_normalizer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/unit/message_normalizer/test_generic_system_squash_normalizer.py b/tests/unit/message_normalizer/test_generic_system_squash_normalizer.py index 591be1c015..656d60355f 100644 --- a/tests/unit/message_normalizer/test_generic_system_squash_normalizer.py +++ b/tests/unit/message_normalizer/test_generic_system_squash_normalizer.py @@ -62,6 +62,7 @@ async def test_generic_squash_normalize_to_dicts_async(): assert len(result) == 1 assert isinstance(result[0], dict) assert result[0]["role"] == "user" - assert "### Instructions ###" in result[0]["converted_value"] - assert "System message" in result[0]["converted_value"] - assert "User 
message" in result[0]["converted_value"] + converted_value = result[0]["pieces"][0]["converted_value"] + assert "### Instructions ###" in converted_value + assert "System message" in converted_value + assert "User message" in converted_value From 69ccff348985c38a42f37eaf6149cd917169ca09 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 17:13:58 -0700 Subject: [PATCH 13/34] fixing test --- pyrit/models/attack_result.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index c0a0209794..babfb4db11 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -247,9 +247,13 @@ def to_dict(self) -> dict[str, Any]: "outcome": self.outcome.value, "outcome_reason": self.outcome_reason, "timestamp": self.timestamp.isoformat() if self.timestamp else None, - "related_conversations": [ - ref.to_dict() if isinstance(ref, ConversationReference) else ref for ref in self.related_conversations - ], + "related_conversations": sorted( + [ + ref.to_dict() if isinstance(ref, ConversationReference) else ref + for ref in self.related_conversations + ], + key=lambda r: r["conversation_id"] if isinstance(r, dict) else "", + ), "metadata": self.metadata, "labels": self.labels, "error_message": self.error_message, From bf32513a6e1c54d0ba2ee2d7a5300c02e8516a43 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 18:02:35 -0700 Subject: [PATCH 14/34] self-review --- pyrit/printer/attack_result/base.py | 6 +- pyrit/printer/attack_result/console.py | 18 ++--- pyrit/printer/scenario_result/console.py | 40 +++++----- pyrit/printer/scorer/base.py | 32 +++++++- pyrit/printer/scorer/console.py | 76 +++++++++---------- .../attack/single_turn/test_flip_attack.py | 1 + .../unit/score/test_console_scorer_printer.py | 4 +- tests/unit/score/test_scorer_printer.py | 26 +++---- 8 files changed, 114 insertions(+), 89 deletions(-) diff --git a/pyrit/printer/attack_result/base.py b/pyrit/printer/attack_result/base.py index c5f5c3c96f..7ea0f714f2 100644 --- a/pyrit/printer/attack_result/base.py +++ b/pyrit/printer/attack_result/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod -from pyrit.models import AttackOutcome, Message, Score +from pyrit.models import AttackOutcome, Message, MessagePiece, Score class AttackResultPrinterBase(ABC): @@ -41,7 +41,7 @@ async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: list[Score]: The scores associated with the given piece IDs. """ - async def display_image_async(self, piece: object) -> None: # noqa: B027 + async def display_image_async(self, piece: MessagePiece) -> None: # noqa: B027 """ Display an image from a message piece. No-op by default. @@ -49,7 +49,7 @@ async def display_image_async(self, piece: object) -> None: # noqa: B027 Thin-client subclasses can override to render URLs or base64 data. Args: - piece: The message piece that may contain image data. + piece (MessagePiece): The message piece that may contain image data. 
""" @staticmethod diff --git a/pyrit/printer/attack_result/console.py b/pyrit/printer/attack_result/console.py index ae6a09083b..c5a863d86d 100644 --- a/pyrit/printer/attack_result/console.py +++ b/pyrit/printer/attack_result/console.py @@ -8,7 +8,7 @@ from colorama import Back, Fore, Style -from pyrit.models import AttackOutcome, AttackResult, ConversationType, Message, Score +from pyrit.models import AttackOutcome, AttackResult, ConversationType, Message, MessagePiece, Score from pyrit.printer.attack_result.base import AttackResultPrinterBase @@ -111,7 +111,7 @@ async def print_conversation_async( async def print_messages_async( self, - messages: list[Any], + messages: list[Message], *, include_scores: bool = False, include_reasoning_trace: bool = False, @@ -483,6 +483,12 @@ def _get_outcome_color(self, outcome: AttackOutcome) -> str: }.get(outcome, Fore.WHITE) ) + async def display_image_async(self, piece: MessagePiece) -> None: + """Display images using PIL/IPython in notebook environments.""" + from pyrit.common.display_response import display_image_response + + await display_image_response(piece) + class ConsoleAttackMemoryPrinter(ConsoleAttackPrinterBase): """ @@ -523,11 +529,3 @@ async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: list[Score]: The scores. """ return list(self._memory.get_prompt_scores(prompt_ids=prompt_ids)) - - async def display_image_async(self, piece: object) -> None: - """Display images using PIL/IPython in notebook environments.""" - from pyrit.common.display_response import display_image_response - from pyrit.models import MessagePiece - - if isinstance(piece, MessagePiece): - await display_image_response(piece) diff --git a/pyrit/printer/scenario_result/console.py b/pyrit/printer/scenario_result/console.py index 742ecfb44d..f13d5c9c10 100644 --- a/pyrit/printer/scenario_result/console.py +++ b/pyrit/printer/scenario_result/console.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import textwrap -from typing import Optional +from abc import abstractmethod from colorama import Fore, Style @@ -16,8 +16,8 @@ class ConsoleScenarioPrinterBase(ScenarioResultPrinterBase): """ Console printer base for scenario results with enhanced formatting. - Contains all formatting logic. Accepts a ScorerPrinterBase for printing - scorer information. Subclasses can provide a concrete scorer printer. + Contains all formatting logic. Subclasses must provide a scorer_printer + via the abstract property. """ def __init__( @@ -26,7 +26,6 @@ def __init__( width: int = 100, indent_size: int = 2, enable_colors: bool = True, - scorer_printer: Optional[ScorerPrinterBase] = None, ) -> None: """ Initialize the console printer. @@ -35,12 +34,15 @@ def __init__( width (int): Maximum width for text wrapping. Defaults to 100. indent_size (int): Number of spaces for indentation. Defaults to 2. enable_colors (bool): Whether to enable ANSI color output. Defaults to True. - scorer_printer (Optional[ScorerPrinterBase]): Printer for scorer information. 
""" self._width = width self._indent = " " * indent_size self._enable_colors = enable_colors - self._scorer_printer = scorer_printer + + @property + @abstractmethod + def scorer_printer(self) -> ScorerPrinterBase: + """Return the scorer printer instance.""" def _print_colored(self, text: str, *colors: str) -> None: """ @@ -104,8 +106,8 @@ async def print_summary_async(self, result: ScenarioResult) -> None: self._print_colored(f"{self._indent * 2}• Target Endpoint: {target_endpoint}", Fore.CYAN) scorer_identifier = result.objective_scorer_identifier - if scorer_identifier and self._scorer_printer: - self._scorer_printer.print_objective_scorer(scorer_identifier=scorer_identifier) + if scorer_identifier: + self.scorer_printer.print_objective_scorer(scorer_identifier=scorer_identifier) self._print_section_header("Overall Statistics") total_results = sum(len(results) for results in result.attack_results.values()) @@ -192,7 +194,6 @@ def __init__( width: int = 100, indent_size: int = 2, enable_colors: bool = True, - scorer_printer: Optional[ScorerPrinterBase] = None, ) -> None: """ Initialize the console printer. @@ -201,16 +202,13 @@ def __init__( width (int): Maximum width for text wrapping. Defaults to 100. indent_size (int): Number of spaces for indentation. Defaults to 2. enable_colors (bool): Whether to enable ANSI color output. Defaults to True. - scorer_printer (Optional[ScorerPrinterBase]): Printer for scorer information. - If not provided, a ConsoleScorerMemoryPrinter with matching settings is created. """ - if scorer_printer is None: - from pyrit.printer.scorer.console import ConsoleScorerMemoryPrinter - - scorer_printer = ConsoleScorerMemoryPrinter(indent_size=indent_size, enable_colors=enable_colors) - super().__init__( - width=width, - indent_size=indent_size, - enable_colors=enable_colors, - scorer_printer=scorer_printer, - ) + super().__init__(width=width, indent_size=indent_size, enable_colors=enable_colors) + from pyrit.printer.scorer.console import ConsoleScorerMemoryPrinter + + self._scorer_printer = ConsoleScorerMemoryPrinter(indent_size=indent_size, enable_colors=enable_colors) + + @property + def scorer_printer(self) -> ScorerPrinterBase: + """Return the scorer printer instance.""" + return self._scorer_printer diff --git a/pyrit/printer/scorer/base.py b/pyrit/printer/scorer/base.py index 65ad98c53b..ec02bae2a0 100644 --- a/pyrit/printer/scorer/base.py +++ b/pyrit/printer/scorer/base.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. from abc import ABC, abstractmethod +from typing import Any from pyrit.identifiers import ComponentIdentifier @@ -10,9 +11,36 @@ class ScorerPrinterBase(ABC): """ Abstract base class for printing scorer information. - Subclasses must implement print_objective_scorer and print_harm_scorer. + Subclasses must implement get_objective_metrics and get_harm_metrics + for data fetching. Orchestration methods (print_objective_scorer, + print_harm_scorer) live in concrete formatting subclasses. """ + @abstractmethod + def _get_objective_metrics(self, *, eval_hash: str) -> Any: + """ + Fetch objective scorer evaluation metrics. + + Args: + eval_hash (str): The evaluation hash to look up. + + Returns: + The metrics object, or None if not found. + """ + + @abstractmethod + def _get_harm_metrics(self, *, eval_hash: str, harm_category: str) -> Any: + """ + Fetch harm scorer evaluation metrics. + + Args: + eval_hash (str): The evaluation hash to look up. + harm_category (str): The harm category to look up. 
+ + Returns: + The metrics object, or None if not found. + """ + @abstractmethod def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: """ @@ -23,7 +51,7 @@ def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> N """ @abstractmethod - def print_harm_scorer(self, scorer_identifier: ComponentIdentifier, *, harm_category: str) -> None: + def print_harm_scorer(self, *, scorer_identifier: ComponentIdentifier, harm_category: str) -> None: """ Print harm scorer information including type, nested scorers, and evaluation metrics. diff --git a/pyrit/printer/scorer/console.py b/pyrit/printer/scorer/console.py index e15925ae53..e22d99f45b 100644 --- a/pyrit/printer/scorer/console.py +++ b/pyrit/printer/scorer/console.py @@ -213,41 +213,6 @@ def _print_harm_metrics(self, metrics: Optional[Any]) -> None: f"{self._indent * 3}• Average Score Time: {metrics.average_score_time_seconds:.2f}s", time_color ) - -class ConsoleScorerMemoryPrinter(ConsoleScorerPrinterBase): - """ - Framework console printer for scorer information. - - Implements metrics fetching via the scorer evaluation registry (deferred import). - All formatting logic lives in ConsoleScorerPrinterBase. - """ - - def get_objective_metrics(self, *, eval_hash: str) -> Any: - """ - Fetch objective scorer evaluation metrics from the registry. - - Returns: - ObjectiveScorerMetrics or None: The metrics, or None if not found. - """ - from pyrit.score.scorer_evaluation.scorer_metrics_io import ( - find_objective_metrics_by_eval_hash, - ) - - return find_objective_metrics_by_eval_hash(eval_hash=eval_hash) - - def get_harm_metrics(self, *, eval_hash: str, harm_category: str) -> Any: - """ - Fetch harm scorer evaluation metrics from the registry. - - Returns: - HarmScorerMetrics or None: The metrics, or None if not found. - """ - from pyrit.score.scorer_evaluation.scorer_metrics_io import ( - find_harm_metrics_by_eval_hash, - ) - - return find_harm_metrics_by_eval_hash(eval_hash=eval_hash, harm_category=harm_category) - def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: """ Print objective scorer information including type, nested scorers, and evaluation metrics. @@ -263,10 +228,10 @@ def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> N self._print_scorer_info(scorer_identifier, indent_level=3) eval_hash = ScorerEvaluationIdentifier(scorer_identifier).eval_hash - metrics = self.get_objective_metrics(eval_hash=eval_hash) + metrics = self._get_objective_metrics(eval_hash=eval_hash) self._print_objective_metrics(metrics) - def print_harm_scorer(self, scorer_identifier: ComponentIdentifier, *, harm_category: str) -> None: + def print_harm_scorer(self, *, scorer_identifier: ComponentIdentifier, harm_category: str) -> None: """ Print harm scorer information including type, nested scorers, and evaluation metrics. @@ -282,5 +247,40 @@ def print_harm_scorer(self, scorer_identifier: ComponentIdentifier, *, harm_cate self._print_scorer_info(scorer_identifier, indent_level=3) eval_hash = ScorerEvaluationIdentifier(scorer_identifier).eval_hash - metrics = self.get_harm_metrics(eval_hash=eval_hash, harm_category=harm_category) + metrics = self._get_harm_metrics(eval_hash=eval_hash, harm_category=harm_category) self._print_harm_metrics(metrics) + + +class ConsoleScorerMemoryPrinter(ConsoleScorerPrinterBase): + """ + Framework console printer for scorer information. + + Implements metrics fetching via the scorer evaluation registry (deferred import). 
+ All formatting logic lives in ConsoleScorerPrinterBase. + """ + + def _get_objective_metrics(self, *, eval_hash: str) -> Any: + """ + Fetch objective scorer evaluation metrics from the registry. + + Returns: + ObjectiveScorerMetrics or None: The metrics, or None if not found. + """ + from pyrit.score.scorer_evaluation.scorer_metrics_io import ( + find_objective_metrics_by_eval_hash, + ) + + return find_objective_metrics_by_eval_hash(eval_hash=eval_hash) + + def _get_harm_metrics(self, *, eval_hash: str, harm_category: str) -> Any: + """ + Fetch harm scorer evaluation metrics from the registry. + + Returns: + HarmScorerMetrics or None: The metrics, or None if not found. + """ + from pyrit.score.scorer_evaluation.scorer_metrics_io import ( + find_harm_metrics_by_eval_hash, + ) + + return find_harm_metrics_by_eval_hash(eval_hash=eval_hash, harm_category=harm_category) diff --git a/tests/unit/executor/attack/single_turn/test_flip_attack.py b/tests/unit/executor/attack/single_turn/test_flip_attack.py index d488eec5e8..f051373490 100644 --- a/tests/unit/executor/attack/single_turn/test_flip_attack.py +++ b/tests/unit/executor/attack/single_turn/test_flip_attack.py @@ -181,6 +181,7 @@ async def test_setup_updates_conversation_without_converters(self, flip_attack, """Test that conversation state is updated without converters for system prompt""" flip_attack._conversation_manager = MagicMock() flip_attack._conversation_manager.initialize_context_async = AsyncMock() + flip_attack._memory_labels = {} await flip_attack._setup_async(context=basic_context) diff --git a/tests/unit/score/test_console_scorer_printer.py b/tests/unit/score/test_console_scorer_printer.py index 3397dbc066..b314013230 100644 --- a/tests/unit/score/test_console_scorer_printer.py +++ b/tests/unit/score/test_console_scorer_printer.py @@ -341,7 +341,7 @@ def test_print_harm_scorer_with_metrics(mock_eval_id_cls, mock_find, capsys): mock_eval_id_cls.return_value = mock_eval_instance mock_find.return_value = metrics - printer.print_harm_scorer(identifier, harm_category="hate_speech") + printer.print_harm_scorer(scorer_identifier=identifier, harm_category="hate_speech") output = capsys.readouterr().out assert "Scorer Information" in output @@ -361,6 +361,6 @@ def test_print_harm_scorer_no_metrics(mock_eval_id_cls, mock_find, capsys): mock_eval_id_cls.return_value = mock_eval_instance mock_find.return_value = None - printer.print_harm_scorer(identifier, harm_category="violence") + printer.print_harm_scorer(scorer_identifier=identifier, harm_category="violence") output = capsys.readouterr().out assert "Official evaluation has not been run yet" in output diff --git a/tests/unit/score/test_scorer_printer.py b/tests/unit/score/test_scorer_printer.py index cda073893d..3b1a639c7c 100644 --- a/tests/unit/score/test_scorer_printer.py +++ b/tests/unit/score/test_scorer_printer.py @@ -12,19 +12,19 @@ def test_scorer_printer_cannot_be_instantiated(): ScorerPrinter() # type: ignore[abstract] -def test_scorer_printer_subclass_must_implement_print_objective_scorer(): +def test_scorer_printer_subclass_must_implement_get_objective_metrics(): class IncompletePrinter(ScorerPrinter): - def print_harm_scorer(self, scorer_identifier: ComponentIdentifier, *, harm_category: str) -> None: - pass + def _get_harm_metrics(self, *, eval_hash: str, harm_category: str): + return None with pytest.raises(TypeError, match="Can't instantiate abstract class"): IncompletePrinter() # type: ignore[abstract] -def 
test_scorer_printer_subclass_must_implement_print_harm_scorer(): +def test_scorer_printer_subclass_must_implement_get_harm_metrics(): class IncompletePrinter(ScorerPrinter): - def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: - pass + def _get_objective_metrics(self, *, eval_hash: str): + return None with pytest.raises(TypeError, match="Can't instantiate abstract class"): IncompletePrinter() # type: ignore[abstract] @@ -32,17 +32,17 @@ def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> N def test_scorer_printer_complete_subclass_can_be_instantiated(): class CompletePrinter(ScorerPrinter): + def _get_objective_metrics(self, *, eval_hash: str): + return None + + def _get_harm_metrics(self, *, eval_hash: str, harm_category: str): + return None + def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: pass - def print_harm_scorer(self, scorer_identifier: ComponentIdentifier, *, harm_category: str) -> None: + def print_harm_scorer(self, *, scorer_identifier: ComponentIdentifier, harm_category: str) -> None: pass - def get_objective_metrics(self, *, eval_hash: str): - return None - - def get_harm_metrics(self, *, eval_hash: str, harm_category: str): - return None - printer = CompletePrinter() assert isinstance(printer, ScorerPrinter) From e6a939baf390ce0311773a85a9d02a36b3f6a644 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 21:47:49 -0700 Subject: [PATCH 15/34] =?UTF-8?q?Rename=20console=E2=86=92pretty,=20add=20?= =?UTF-8?q?Sink/PrinterBase=20plumbing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/cli/frontend_core.py | 2 +- pyrit/cli/pyrit_shell.py | 8 +- pyrit/executor/attack/__init__.py | 4 +- pyrit/executor/attack/printer/__init__.py | 10 +-- .../attack/printer/console_printer.py | 8 +- .../attack/printer/markdown_printer.py | 4 +- pyrit/printer/__init__.py | 18 ++-- pyrit/printer/attack_result/base.py | 5 +- pyrit/printer/attack_result/markdown.py | 20 +++-- .../attack_result/{console.py => pretty.py} | 28 ++++--- pyrit/printer/base.py | 33 ++++++++ pyrit/printer/scenario_result/base.py | 5 +- .../scenario_result/{console.py => pretty.py} | 30 ++++--- pyrit/printer/scorer/base.py | 5 +- .../printer/scorer/{console.py => pretty.py} | 17 ++-- pyrit/printer/sink.py | 82 +++++++++++++++++++ pyrit/scenario/printer/__init__.py | 6 +- pyrit/scenario/printer/console_printer.py | 8 +- pyrit/score/__init__.py | 2 +- pyrit/score/printer/__init__.py | 6 +- pyrit/score/printer/console_scorer_printer.py | 8 +- .../attack/core/test_markdown_printer.py | 2 +- ...sole_printer.py => test_pretty_printer.py} | 2 +- tests/unit/printer/test_printer_base.py | 49 +++++++++++ tests/unit/printer/test_sink.py | 69 ++++++++++++++++ ...inter.py => test_pretty_scorer_printer.py} | 2 +- 26 files changed, 348 insertions(+), 85 deletions(-) rename pyrit/printer/attack_result/{console.py => pretty.py} (95%) create mode 100644 pyrit/printer/base.py rename pyrit/printer/scenario_result/{console.py => pretty.py} (86%) rename pyrit/printer/scorer/{console.py => pretty.py} (95%) create mode 100644 pyrit/printer/sink.py rename tests/unit/executor/attack/printer/{test_console_printer.py => test_pretty_printer.py} (99%) create mode 100644 tests/unit/printer/test_printer_base.py create mode 100644 tests/unit/printer/test_sink.py rename tests/unit/score/{test_console_scorer_printer.py => 
test_pretty_scorer_printer.py} (99%) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index dabd464ccb..15752c02a9 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -39,7 +39,7 @@ from pyrit.cli._cli_args import validate_integer as validate_integer from pyrit.cli._cli_args import validate_log_level as validate_log_level from pyrit.cli._cli_args import validate_log_level_argparse as validate_log_level_argparse -from pyrit.printer.scenario_result.console import ConsoleScenarioMemoryPrinter as ConsoleScenarioResultPrinter +from pyrit.printer.scenario_result.pretty import PrettyScenarioResultMemoryPrinter as ConsoleScenarioResultPrinter from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry from pyrit.scenario import DatasetConfiguration from pyrit.setup import ConfigurationLoader, initialize_pyrit_async diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 368765e276..f5de192bba 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -483,8 +483,8 @@ def do_print_scenario(self, arg: str) -> None: print(f"\n{'#' * 80}") print(f"Scenario Run #{idx}: {command}") print(f"{'#' * 80}") - from pyrit.printer.scenario_result.console import ( - ConsoleScenarioMemoryPrinter as ConsoleScenarioResultPrinter, + from pyrit.printer.scenario_result.pretty import ( + PrettyScenarioResultMemoryPrinter as ConsoleScenarioResultPrinter, ) printer = ConsoleScenarioResultPrinter() @@ -500,8 +500,8 @@ def do_print_scenario(self, arg: str) -> None: command, result = self._scenario_history[scenario_num - 1] print(f"\nScenario Run #{scenario_num}: {command}") print("=" * 80) - from pyrit.printer.scenario_result.console import ( - ConsoleScenarioMemoryPrinter as ConsoleScenarioResultPrinter, + from pyrit.printer.scenario_result.pretty import ( + PrettyScenarioResultMemoryPrinter as ConsoleScenarioResultPrinter, ) printer = ConsoleScenarioResultPrinter() diff --git a/pyrit/executor/attack/__init__.py b/pyrit/executor/attack/__init__.py index b98dad3b22..e3f5373679 100644 --- a/pyrit/executor/attack/__init__.py +++ b/pyrit/executor/attack/__init__.py @@ -53,8 +53,8 @@ # Backward-compatibility aliases — import from pyrit.printer.attack_result directly. # TODO: Remove these re-exports in two releases (target removal: 0.16.0). 
from pyrit.printer.attack_result.base import AttackResultPrinterBase as AttackResultPrinter -from pyrit.printer.attack_result.console import ConsoleAttackMemoryPrinter as ConsoleAttackResultPrinter -from pyrit.printer.attack_result.markdown import MarkdownAttackMemoryPrinter as MarkdownAttackResultPrinter +from pyrit.printer.attack_result.pretty import PrettyAttackResultMemoryPrinter as ConsoleAttackResultPrinter +from pyrit.printer.attack_result.markdown import MarkdownAttackResultMemoryPrinter as MarkdownAttackResultPrinter __all__ = [ "AttackStrategy", diff --git a/pyrit/executor/attack/printer/__init__.py b/pyrit/executor/attack/printer/__init__.py index 6abcba9803..b542f3a1c3 100644 --- a/pyrit/executor/attack/printer/__init__.py +++ b/pyrit/executor/attack/printer/__init__.py @@ -13,7 +13,7 @@ def __getattr__(name: str) -> type: # noqa: N807 _deprecated = { - "ConsoleAttackResultPrinter": "pyrit.printer.attack_result.console", + "ConsoleAttackResultPrinter": "pyrit.printer.attack_result.pretty", "MarkdownAttackResultPrinter": "pyrit.printer.attack_result.markdown", "AttackResultPrinter": "pyrit.printer.attack_result.base", } @@ -26,17 +26,17 @@ def __getattr__(name: str) -> type: # noqa: N807 stacklevel=2, ) if name == "ConsoleAttackResultPrinter": - from pyrit.printer.attack_result.console import ConsoleAttackMemoryPrinter + from pyrit.printer.attack_result.pretty import PrettyAttackResultMemoryPrinter - return ConsoleAttackMemoryPrinter + return PrettyAttackResultMemoryPrinter if name == "AttackResultPrinter": from pyrit.printer.attack_result.base import AttackResultPrinterBase return AttackResultPrinterBase if name == "MarkdownAttackResultPrinter": - from pyrit.printer.attack_result.markdown import MarkdownAttackMemoryPrinter + from pyrit.printer.attack_result.markdown import MarkdownAttackResultMemoryPrinter - return MarkdownAttackMemoryPrinter + return MarkdownAttackResultMemoryPrinter raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyrit/executor/attack/printer/console_printer.py b/pyrit/executor/attack/printer/console_printer.py index 41f8980eef..c288a33c35 100644 --- a/pyrit/executor/attack/printer/console_printer.py +++ b/pyrit/executor/attack/printer/console_printer.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. """ -Deprecated: Import from pyrit.printer.attack_result.console instead. +Deprecated: Import from pyrit.printer.attack_result.pretty instead. This re-export will be removed in 0.16.0. """ @@ -13,11 +13,11 @@ def __getattr__(name: str) -> type: # noqa: N807 if name == "ConsoleAttackResultPrinter": _warnings.warn( "Importing ConsoleAttackResultPrinter from pyrit.executor.attack.printer.console_printer is deprecated " - "and will be removed in 0.16.0. Import from pyrit.printer.attack_result.console instead.", + "and will be removed in 0.16.0. 
Import from pyrit.printer.attack_result.pretty instead.", DeprecationWarning, stacklevel=2, ) - from pyrit.printer.attack_result.console import ConsoleAttackResultPrinter + from pyrit.printer.attack_result.pretty import PrettyAttackResultMemoryPrinter - return ConsoleAttackResultPrinter + return PrettyAttackResultMemoryPrinter raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyrit/executor/attack/printer/markdown_printer.py b/pyrit/executor/attack/printer/markdown_printer.py index 79fa83b688..4af15cb685 100644 --- a/pyrit/executor/attack/printer/markdown_printer.py +++ b/pyrit/executor/attack/printer/markdown_printer.py @@ -17,7 +17,7 @@ def __getattr__(name: str) -> type: # noqa: N807 DeprecationWarning, stacklevel=2, ) - from pyrit.printer.attack_result.markdown import MarkdownAttackMemoryPrinter + from pyrit.printer.attack_result.markdown import MarkdownAttackResultMemoryPrinter - return MarkdownAttackMemoryPrinter + return MarkdownAttackResultMemoryPrinter raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyrit/printer/__init__.py b/pyrit/printer/__init__.py index 426fbac9ea..26f4ef9e8e 100644 --- a/pyrit/printer/__init__.py +++ b/pyrit/printer/__init__.py @@ -2,14 +2,18 @@ # Licensed under the MIT license. """ -Lightweight printer module for displaying attack, scenario, and scorer results. +Printer module for displaying attack, scenario, and scorer results. -This module contains abstract base classes with all formatting logic. -Data-fetching operations (conversations, scores, scorer metrics) are abstract -methods that must be implemented by subclasses. +This module provides: +- **Sink** classes that define where output goes (stdout, file, etc.) +- **PrinterBase** that all printers inherit from +- Domain printers for attack results, scenario results, and scorer information -Framework users: use the concrete implementations in pyrit.executor.attack.printer -and pyrit.scenario.printer which fetch data via CentralMemory. +File names indicate output format (pretty.py = ANSI-colored, markdown.py = Markdown). +Abstract methods inside each printer determine the data source (memory, REST, fixtures). -Thin clients: subclass the bases here and implement abstract methods via REST calls. +Framework users: use the Memory printer classes (e.g., PrettyAttackResultMemoryPrinter) +which fetch data via CentralMemory. + +Thin clients: subclass the base printers and implement abstract data methods via REST calls. """ diff --git a/pyrit/printer/attack_result/base.py b/pyrit/printer/attack_result/base.py index 7ea0f714f2..2aa6652434 100644 --- a/pyrit/printer/attack_result/base.py +++ b/pyrit/printer/attack_result/base.py @@ -1,12 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from abc import ABC, abstractmethod +from abc import abstractmethod from pyrit.models import AttackOutcome, Message, MessagePiece, Score +from pyrit.printer.base import PrinterBase -class AttackResultPrinterBase(ABC): +class AttackResultPrinterBase(PrinterBase): """ Abstract base class for printing attack results. 
diff --git a/pyrit/printer/attack_result/markdown.py b/pyrit/printer/attack_result/markdown.py index 1d3f255afe..d9745a7c0a 100644 --- a/pyrit/printer/attack_result/markdown.py +++ b/pyrit/printer/attack_result/markdown.py @@ -6,25 +6,28 @@ from pyrit.models import AttackResult, ConversationType, Message, MessagePiece, Score from pyrit.printer.attack_result.base import AttackResultPrinterBase +from pyrit.printer.sink import Sink -class MarkdownAttackPrinterBase(AttackResultPrinterBase): +class MarkdownAttackResultPrinter(AttackResultPrinterBase): """ - Markdown printer base for attack results optimized for Jupyter notebooks. + Markdown printer for attack results optimized for Jupyter notebooks. Contains all formatting logic. Subclasses implement get_conversation_async and get_scores_async for data fetching. """ - def __init__(self, *, display_inline: bool = True) -> None: + def __init__(self, *, sink: Sink | None = None, display_inline: bool = True) -> None: """ Initialize the markdown printer. Args: + sink (Sink | None): Output sink. Defaults to StdoutSink(). display_inline (bool): If True, uses IPython.display to render markdown inline in Jupyter notebooks. If False, prints markdown strings. Defaults to True. """ + super().__init__(sink=sink) self._display_inline = display_inline def _render_markdown(self, markdown_lines: list[str]) -> None: @@ -551,24 +554,25 @@ async def _get_adversarial_conversation_markdown_async(self, result: AttackResul return markdown_lines -class MarkdownAttackMemoryPrinter(MarkdownAttackPrinterBase): +class MarkdownAttackResultMemoryPrinter(MarkdownAttackResultPrinter): """ Framework markdown printer for attack results. Implements data-fetching via CentralMemory (deferred import). - All formatting logic lives in MarkdownAttackPrinterBase. + All formatting logic lives in MarkdownAttackResultPrinter. """ - def __init__(self, *, display_inline: bool = True) -> None: + def __init__(self, *, sink: Sink | None = None, display_inline: bool = True) -> None: """ - Initialize the markdown printer. + Initialize the markdown printer with CentralMemory data source. Args: + sink (Sink | None): Output sink. Defaults to StdoutSink(). display_inline (bool): If True, uses IPython.display to render markdown inline in Jupyter notebooks. If False, prints markdown strings. Defaults to True. """ - super().__init__(display_inline=display_inline) + super().__init__(sink=sink, display_inline=display_inline) from pyrit.memory import CentralMemory self._memory = CentralMemory.get_memory_instance() diff --git a/pyrit/printer/attack_result/console.py b/pyrit/printer/attack_result/pretty.py similarity index 95% rename from pyrit/printer/attack_result/console.py rename to pyrit/printer/attack_result/pretty.py index c5a863d86d..e320c3dbb1 100644 --- a/pyrit/printer/attack_result/console.py +++ b/pyrit/printer/attack_result/pretty.py @@ -10,25 +10,30 @@ from pyrit.models import AttackOutcome, AttackResult, ConversationType, Message, MessagePiece, Score from pyrit.printer.attack_result.base import AttackResultPrinterBase +from pyrit.printer.sink import Sink -class ConsoleAttackPrinterBase(AttackResultPrinterBase): +class PrettyAttackResultPrinter(AttackResultPrinterBase): """ - Console printer base for attack results with enhanced formatting. + Pretty printer for attack results with ANSI-colored formatting. Contains all formatting logic. Subclasses implement get_conversation_async and get_scores_async for data fetching. 
""" - def __init__(self, *, width: int = 100, indent_size: int = 2, enable_colors: bool = True) -> None: + def __init__( + self, *, sink: Sink | None = None, width: int = 100, indent_size: int = 2, enable_colors: bool = True + ) -> None: """ - Initialize the console printer. + Initialize the pretty printer. Args: + sink (Sink | None): Output sink. Defaults to StdoutSink(). width (int): Maximum width for text wrapping. Defaults to 100. indent_size (int): Number of spaces for indentation. Defaults to 2. enable_colors (bool): Whether to enable ANSI color output. Defaults to True. """ + super().__init__(sink=sink) self._width = width self._indent = " " * indent_size self._enable_colors = enable_colors @@ -490,24 +495,27 @@ async def display_image_async(self, piece: MessagePiece) -> None: await display_image_response(piece) -class ConsoleAttackMemoryPrinter(ConsoleAttackPrinterBase): +class PrettyAttackResultMemoryPrinter(PrettyAttackResultPrinter): """ - Framework console printer for attack results. + Framework pretty printer for attack results. Implements data-fetching via CentralMemory (deferred import). - All formatting logic lives in ConsoleAttackPrinterBase. + All formatting logic lives in PrettyAttackResultPrinter. """ - def __init__(self, *, width: int = 100, indent_size: int = 2, enable_colors: bool = True) -> None: + def __init__( + self, *, sink: Sink | None = None, width: int = 100, indent_size: int = 2, enable_colors: bool = True + ) -> None: """ - Initialize the console printer. + Initialize the pretty printer with CentralMemory data source. Args: + sink (Sink | None): Output sink. Defaults to StdoutSink(). width (int): Maximum width for text wrapping. Defaults to 100. indent_size (int): Number of spaces for indentation. Defaults to 2. enable_colors (bool): Whether to enable ANSI color output. Defaults to True. """ - super().__init__(width=width, indent_size=indent_size, enable_colors=enable_colors) + super().__init__(sink=sink, width=width, indent_size=indent_size, enable_colors=enable_colors) from pyrit.memory import CentralMemory self._memory = CentralMemory.get_memory_instance() diff --git a/pyrit/printer/base.py b/pyrit/printer/base.py new file mode 100644 index 0000000000..c68ce155a0 --- /dev/null +++ b/pyrit/printer/base.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import ABC + +from pyrit.printer.sink import Sink, StdoutSink + + +class PrinterBase(ABC): + """ + Abstract base class for all printers. + + Provides a sink for output routing. Subclasses write their rendered + output through the sink via ``_write_async``. + """ + + def __init__(self, *, sink: Sink | None = None) -> None: + """ + Initialize the printer base. + + Args: + sink (Sink | None): The output sink. Defaults to StdoutSink() if not provided. + """ + self._sink = sink or StdoutSink() + + async def _write_async(self, data: bytes) -> None: + """ + Write data through the configured sink. + + Args: + data (bytes): The rendered output to write. + """ + await self._sink.write_async(data) diff --git a/pyrit/printer/scenario_result/base.py b/pyrit/printer/scenario_result/base.py index 028a855bf1..1545c59cfd 100644 --- a/pyrit/printer/scenario_result/base.py +++ b/pyrit/printer/scenario_result/base.py @@ -1,12 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from abc import ABC, abstractmethod +from abc import abstractmethod from pyrit.models.scenario_result import ScenarioResult +from pyrit.printer.base import PrinterBase -class ScenarioResultPrinterBase(ABC): +class ScenarioResultPrinterBase(PrinterBase): """ Abstract base class for printing scenario results. diff --git a/pyrit/printer/scenario_result/console.py b/pyrit/printer/scenario_result/pretty.py similarity index 86% rename from pyrit/printer/scenario_result/console.py rename to pyrit/printer/scenario_result/pretty.py index f13d5c9c10..9c18f38943 100644 --- a/pyrit/printer/scenario_result/console.py +++ b/pyrit/printer/scenario_result/pretty.py @@ -10,11 +10,12 @@ from pyrit.models.scenario_result import ScenarioResult from pyrit.printer.scenario_result.base import ScenarioResultPrinterBase from pyrit.printer.scorer.base import ScorerPrinterBase +from pyrit.printer.sink import Sink -class ConsoleScenarioPrinterBase(ScenarioResultPrinterBase): +class PrettyScenarioResultPrinter(ScenarioResultPrinterBase): """ - Console printer base for scenario results with enhanced formatting. + Pretty printer for scenario results with ANSI-colored formatting. Contains all formatting logic. Subclasses must provide a scorer_printer via the abstract property. @@ -23,18 +24,21 @@ class ConsoleScenarioPrinterBase(ScenarioResultPrinterBase): def __init__( self, *, + sink: Sink | None = None, width: int = 100, indent_size: int = 2, enable_colors: bool = True, ) -> None: """ - Initialize the console printer. + Initialize the pretty scenario printer. Args: + sink (Sink | None): Output sink. Defaults to StdoutSink(). width (int): Maximum width for text wrapping. Defaults to 100. indent_size (int): Number of spaces for indentation. Defaults to 2. enable_colors (bool): Whether to enable ANSI color output. Defaults to True. """ + super().__init__(sink=sink) self._width = width self._indent = " " * indent_size self._enable_colors = enable_colors @@ -180,33 +184,37 @@ def _get_rate_color(self, rate: int) -> str: return str(Fore.GREEN) -class ConsoleScenarioMemoryPrinter(ConsoleScenarioPrinterBase): +class PrettyScenarioResultMemoryPrinter(PrettyScenarioResultPrinter): """ - Framework console printer for scenario results. + Framework pretty printer for scenario results. - Provides the framework's ConsoleScorerMemoryPrinter for scorer information display. - All formatting logic lives in ConsoleScenarioPrinterBase. + Provides the framework's PrettyScorerMemoryPrinter for scorer information display. + All formatting logic lives in PrettyScenarioResultPrinter. """ def __init__( self, *, + sink: Sink | None = None, width: int = 100, indent_size: int = 2, enable_colors: bool = True, ) -> None: """ - Initialize the console printer. + Initialize the pretty scenario printer with CentralMemory data source. Args: + sink (Sink | None): Output sink. Defaults to StdoutSink(). width (int): Maximum width for text wrapping. Defaults to 100. indent_size (int): Number of spaces for indentation. Defaults to 2. enable_colors (bool): Whether to enable ANSI color output. Defaults to True. 
""" - super().__init__(width=width, indent_size=indent_size, enable_colors=enable_colors) - from pyrit.printer.scorer.console import ConsoleScorerMemoryPrinter + super().__init__(sink=sink, width=width, indent_size=indent_size, enable_colors=enable_colors) + from pyrit.printer.scorer.pretty import PrettyScorerMemoryPrinter - self._scorer_printer = ConsoleScorerMemoryPrinter(indent_size=indent_size, enable_colors=enable_colors) + self._scorer_printer = PrettyScorerMemoryPrinter( + sink=self._sink, indent_size=indent_size, enable_colors=enable_colors + ) @property def scorer_printer(self) -> ScorerPrinterBase: diff --git a/pyrit/printer/scorer/base.py b/pyrit/printer/scorer/base.py index ec02bae2a0..418c4b233a 100644 --- a/pyrit/printer/scorer/base.py +++ b/pyrit/printer/scorer/base.py @@ -1,13 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Any from pyrit.identifiers import ComponentIdentifier +from pyrit.printer.base import PrinterBase -class ScorerPrinterBase(ABC): +class ScorerPrinterBase(PrinterBase): """ Abstract base class for printing scorer information. diff --git a/pyrit/printer/scorer/console.py b/pyrit/printer/scorer/pretty.py similarity index 95% rename from pyrit/printer/scorer/console.py rename to pyrit/printer/scorer/pretty.py index e22d99f45b..2e3c188519 100644 --- a/pyrit/printer/scorer/console.py +++ b/pyrit/printer/scorer/pretty.py @@ -7,11 +7,12 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.printer.scorer.base import ScorerPrinterBase +from pyrit.printer.sink import Sink -class ConsoleScorerPrinterBase(ScorerPrinterBase): +class PrettyScorerPrinter(ScorerPrinterBase): """ - Console printer base for scorer information with enhanced formatting. + Pretty printer for scorer information with ANSI-colored formatting. Contains all formatting logic. Subclasses implement get_objective_metrics and get_harm_metrics for data fetching. @@ -20,17 +21,19 @@ class ConsoleScorerPrinterBase(ScorerPrinterBase): _SCORER_DISPLAY_PARAMS = frozenset({"scorer_type", "score_aggregator"}) _TARGET_DISPLAY_PARAMS = frozenset({"model_name", "temperature"}) - def __init__(self, *, indent_size: int = 2, enable_colors: bool = True) -> None: + def __init__(self, *, sink: Sink | None = None, indent_size: int = 2, enable_colors: bool = True) -> None: """ - Initialize the console scorer printer. + Initialize the pretty scorer printer. Args: + sink (Sink | None): Output sink. Defaults to StdoutSink(). indent_size (int): Number of spaces for indentation. Defaults to 2. enable_colors (bool): Whether to enable ANSI color output. Defaults to True. Raises: ValueError: If indent_size is negative. """ + super().__init__(sink=sink) if indent_size < 0: raise ValueError("indent_size must be non-negative") self._indent = " " * indent_size @@ -251,12 +254,12 @@ def print_harm_scorer(self, *, scorer_identifier: ComponentIdentifier, harm_cate self._print_harm_metrics(metrics) -class ConsoleScorerMemoryPrinter(ConsoleScorerPrinterBase): +class PrettyScorerMemoryPrinter(PrettyScorerPrinter): """ - Framework console printer for scorer information. + Framework pretty printer for scorer information. Implements metrics fetching via the scorer evaluation registry (deferred import). - All formatting logic lives in ConsoleScorerPrinterBase. + All formatting logic lives in PrettyScorerPrinter. 
""" def _get_objective_metrics(self, *, eval_hash: str) -> Any: diff --git a/pyrit/printer/sink.py b/pyrit/printer/sink.py new file mode 100644 index 0000000000..993c237f74 --- /dev/null +++ b/pyrit/printer/sink.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import ABC, abstractmethod +from pathlib import Path + + +class Sink(ABC): + """ + Abstract base class for output sinks. + + A sink defines where rendered output goes (stdout, file, etc.). + All printers write their output through a sink. + """ + + @abstractmethod + async def write_async(self, data: bytes) -> None: + """ + Write rendered output data. + + Args: + data (bytes): The rendered output to write. + """ + + +class StdoutSink(Sink): + """ + Sink that decodes bytes to str and prints to stdout. + + This is the default sink used when no sink is specified. + """ + + def __init__(self, *, encoding: str = "utf-8") -> None: + """ + Initialize the stdout sink. + + Args: + encoding (str): Character encoding for decoding bytes. Defaults to "utf-8". + """ + self._encoding = encoding + + async def write_async(self, data: bytes) -> None: + """ + Write data to stdout. + + Args: + data (bytes): The data to print, decoded using the configured encoding. + """ + print(data.decode(self._encoding), end="") + + +class FileSink(Sink): + """ + Sink that writes bytes to a file. + """ + + def __init__(self, *, path: Path, mode: str = "wb") -> None: + """ + Initialize the file sink. + + Args: + path (Path): The file path to write to. + mode (str): The file open mode. Defaults to "wb" (write binary, overwrite). + Use "ab" for append mode. + + Raises: + ValueError: If mode is not a valid binary write mode. + """ + if mode not in ("wb", "ab"): + raise ValueError(f"mode must be 'wb' or 'ab', got '{mode}'") + self._path = path + self._mode = mode + + async def write_async(self, data: bytes) -> None: + """ + Write data to a file. + + Args: + data (bytes): The data to write. + """ + with open(self._path, self._mode) as f: + f.write(data) diff --git a/pyrit/scenario/printer/__init__.py b/pyrit/scenario/printer/__init__.py index 1eb00d5516..76ce6338c4 100644 --- a/pyrit/scenario/printer/__init__.py +++ b/pyrit/scenario/printer/__init__.py @@ -13,7 +13,7 @@ def __getattr__(name: str) -> type: # noqa: N807 _deprecated = { - "ConsoleScenarioResultPrinter": "pyrit.printer.scenario_result.console", + "ConsoleScenarioResultPrinter": "pyrit.printer.scenario_result.pretty", "ScenarioResultPrinter": "pyrit.printer.scenario_result.base", } if name in _deprecated: @@ -25,9 +25,9 @@ def __getattr__(name: str) -> type: # noqa: N807 stacklevel=2, ) if name == "ConsoleScenarioResultPrinter": - from pyrit.printer.scenario_result.console import ConsoleScenarioMemoryPrinter + from pyrit.printer.scenario_result.pretty import PrettyScenarioResultMemoryPrinter - return ConsoleScenarioMemoryPrinter + return PrettyScenarioResultMemoryPrinter if name == "ScenarioResultPrinter": from pyrit.printer.scenario_result.base import ScenarioResultPrinterBase diff --git a/pyrit/scenario/printer/console_printer.py b/pyrit/scenario/printer/console_printer.py index 371e717098..531a58a126 100644 --- a/pyrit/scenario/printer/console_printer.py +++ b/pyrit/scenario/printer/console_printer.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. """ -Deprecated: Import from pyrit.printer.scenario_result.console instead. +Deprecated: Import from pyrit.printer.scenario_result.pretty instead. This re-export will be removed in 0.16.0. 
""" @@ -13,11 +13,11 @@ def __getattr__(name: str) -> type: # noqa: N807 if name == "ConsoleScenarioResultPrinter": _warnings.warn( "Importing ConsoleScenarioResultPrinter from pyrit.scenario.printer.console_printer is deprecated " - "and will be removed in 0.16.0. Import from pyrit.printer.scenario_result.console instead.", + "and will be removed in 0.16.0. Import from pyrit.printer.scenario_result.pretty instead.", DeprecationWarning, stacklevel=2, ) - from pyrit.printer.scenario_result.console import ConsoleScenarioMemoryPrinter + from pyrit.printer.scenario_result.pretty import PrettyScenarioResultMemoryPrinter - return ConsoleScenarioMemoryPrinter + return PrettyScenarioResultMemoryPrinter raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyrit/score/__init__.py b/pyrit/score/__init__.py index 68ef2c0641..1cbfa3b397 100644 --- a/pyrit/score/__init__.py +++ b/pyrit/score/__init__.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING from pyrit.printer.scorer.base import ScorerPrinterBase as ScorerPrinter -from pyrit.printer.scorer.console import ConsoleScorerMemoryPrinter as ConsoleScorerPrinter +from pyrit.printer.scorer.pretty import PrettyScorerMemoryPrinter as ConsoleScorerPrinter from pyrit.score.batch_scorer import BatchScorer from pyrit.score.conversation_scorer import ConversationScorer, create_conversation_scorer from pyrit.score.float_scale.azure_content_filter_scorer import AzureContentFilterScorer diff --git a/pyrit/score/printer/__init__.py b/pyrit/score/printer/__init__.py index dc9b5d9866..a15df6cc67 100644 --- a/pyrit/score/printer/__init__.py +++ b/pyrit/score/printer/__init__.py @@ -13,7 +13,7 @@ def __getattr__(name: str) -> type: # noqa: N807 _deprecated = { - "ConsoleScorerPrinter": "pyrit.printer.scorer.console", + "ConsoleScorerPrinter": "pyrit.printer.scorer.pretty", "ScorerPrinter": "pyrit.printer.scorer.base", } if name in _deprecated: @@ -25,9 +25,9 @@ def __getattr__(name: str) -> type: # noqa: N807 stacklevel=2, ) if name == "ConsoleScorerPrinter": - from pyrit.printer.scorer.console import ConsoleScorerMemoryPrinter + from pyrit.printer.scorer.pretty import PrettyScorerMemoryPrinter - return ConsoleScorerMemoryPrinter + return PrettyScorerMemoryPrinter if name == "ScorerPrinter": from pyrit.printer.scorer.base import ScorerPrinterBase diff --git a/pyrit/score/printer/console_scorer_printer.py b/pyrit/score/printer/console_scorer_printer.py index 0c7e8a47a2..ee5b02d235 100644 --- a/pyrit/score/printer/console_scorer_printer.py +++ b/pyrit/score/printer/console_scorer_printer.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. """ -Deprecated: Import from pyrit.printer.scorer.console instead. +Deprecated: Import from pyrit.printer.scorer.pretty instead. This re-export will be removed in 0.16.0. """ @@ -13,11 +13,11 @@ def __getattr__(name: str) -> type: # noqa: N807 if name == "ConsoleScorerPrinter": _warnings.warn( "Importing ConsoleScorerPrinter from pyrit.score.printer.console_scorer_printer is deprecated " - "and will be removed in 0.16.0. Import from pyrit.printer.scorer.console instead.", + "and will be removed in 0.16.0. 
Import from pyrit.printer.scorer.pretty instead.", DeprecationWarning, stacklevel=2, ) - from pyrit.printer.scorer.console import ConsoleScorerMemoryPrinter + from pyrit.printer.scorer.pretty import PrettyScorerMemoryPrinter - return ConsoleScorerMemoryPrinter + return PrettyScorerMemoryPrinter raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/tests/unit/executor/attack/core/test_markdown_printer.py b/tests/unit/executor/attack/core/test_markdown_printer.py index 0ad6e957bf..b9f088ca11 100644 --- a/tests/unit/executor/attack/core/test_markdown_printer.py +++ b/tests/unit/executor/attack/core/test_markdown_printer.py @@ -11,7 +11,7 @@ from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.memory import CentralMemory from pyrit.models import AttackOutcome, AttackResult, Message, MessagePiece, Score -from pyrit.printer.attack_result.markdown import MarkdownAttackMemoryPrinter as MarkdownAttackResultPrinter +from pyrit.printer.attack_result.markdown import MarkdownAttackResultMemoryPrinter as MarkdownAttackResultPrinter def _mock_scorer_id(name: str = "MockScorer") -> ComponentIdentifier: diff --git a/tests/unit/executor/attack/printer/test_console_printer.py b/tests/unit/executor/attack/printer/test_pretty_printer.py similarity index 99% rename from tests/unit/executor/attack/printer/test_console_printer.py rename to tests/unit/executor/attack/printer/test_pretty_printer.py index 46b746d5e2..fefcc38056 100644 --- a/tests/unit/executor/attack/printer/test_console_printer.py +++ b/tests/unit/executor/attack/printer/test_pretty_printer.py @@ -10,7 +10,7 @@ from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.models import AttackOutcome, AttackResult, ConversationType, Message, MessagePiece, Score from pyrit.models.conversation_reference import ConversationReference -from pyrit.printer.attack_result.console import ConsoleAttackMemoryPrinter as ConsoleAttackResultPrinter +from pyrit.printer.attack_result.pretty import PrettyAttackResultMemoryPrinter as ConsoleAttackResultPrinter def _mock_scorer_id(name: str = "MockScorer") -> ComponentIdentifier: diff --git a/tests/unit/printer/test_printer_base.py b/tests/unit/printer/test_printer_base.py new file mode 100644 index 0000000000..92c19bd030 --- /dev/null +++ b/tests/unit/printer/test_printer_base.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest + +from pyrit.printer.base import PrinterBase +from pyrit.printer.sink import StdoutSink + + +def test_printer_base_has_no_abstract_methods(): + # PrinterBase is abstract via ABC but has no abstract methods of its own. + # Subclasses add their own abstract methods for data fetching. 
+ class ConcretePrinter(PrinterBase): + pass + + printer = ConcretePrinter() + assert isinstance(printer, PrinterBase) + + +def test_printer_base_defaults_to_stdout_sink(): + + class ConcretePrinter(PrinterBase): + pass + + printer = ConcretePrinter() + assert isinstance(printer._sink, StdoutSink) + + +def test_printer_base_accepts_custom_sink(): + from pyrit.printer.sink import FileSink + from pathlib import Path + + class ConcretePrinter(PrinterBase): + pass + + sink = FileSink(path=Path("test.txt")) + printer = ConcretePrinter(sink=sink) + assert printer._sink is sink + + +async def test_printer_base_write_async_delegates_to_sink(capsys): + + class ConcretePrinter(PrinterBase): + pass + + printer = ConcretePrinter() + await printer._write_async(b"test output") + captured = capsys.readouterr() + assert captured.out == "test output" diff --git a/tests/unit/printer/test_sink.py b/tests/unit/printer/test_sink.py new file mode 100644 index 0000000000..ca5b9df226 --- /dev/null +++ b/tests/unit/printer/test_sink.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import tempfile +from pathlib import Path + +import pytest + +from pyrit.printer.sink import FileSink, Sink, StdoutSink + + +def test_sink_is_abstract(): + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + Sink() # type: ignore[abstract] + + +async def test_stdout_sink_writes_to_stdout(capsys): + sink = StdoutSink() + await sink.write_async(b"hello world") + captured = capsys.readouterr() + assert captured.out == "hello world" + + +async def test_stdout_sink_no_trailing_newline(capsys): + sink = StdoutSink() + await sink.write_async(b"line1") + await sink.write_async(b"line2") + captured = capsys.readouterr() + assert captured.out == "line1line2" + + +async def test_stdout_sink_custom_encoding(capsys): + sink = StdoutSink(encoding="ascii") + await sink.write_async(b"ascii text") + captured = capsys.readouterr() + assert captured.out == "ascii text" + + +async def test_file_sink_writes_to_file(): + with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f: + path = Path(f.name) + + try: + sink = FileSink(path=path, mode="wb") + await sink.write_async(b"hello file") + assert path.read_bytes() == b"hello file" + finally: + path.unlink(missing_ok=True) + + +async def test_file_sink_append_mode(): + with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f: + path = Path(f.name) + + try: + sink = FileSink(path=path, mode="wb") + await sink.write_async(b"first") + + append_sink = FileSink(path=path, mode="ab") + await append_sink.write_async(b" second") + + assert path.read_bytes() == b"first second" + finally: + path.unlink(missing_ok=True) + + +def test_file_sink_rejects_invalid_mode(): + with pytest.raises(ValueError, match="mode must be 'wb' or 'ab'"): + FileSink(path=Path("test.txt"), mode="w") diff --git a/tests/unit/score/test_console_scorer_printer.py b/tests/unit/score/test_pretty_scorer_printer.py similarity index 99% rename from tests/unit/score/test_console_scorer_printer.py rename to tests/unit/score/test_pretty_scorer_printer.py index b314013230..90a04339d9 100644 --- a/tests/unit/score/test_console_scorer_printer.py +++ b/tests/unit/score/test_pretty_scorer_printer.py @@ -7,7 +7,7 @@ from colorama import Fore, Style from pyrit.identifiers import ComponentIdentifier -from pyrit.printer.scorer.console import ConsoleScorerMemoryPrinter as ConsoleScorerPrinter +from pyrit.printer.scorer.pretty import PrettyScorerMemoryPrinter as 
ConsoleScorerPrinter from pyrit.score.scorer_evaluation.scorer_metrics import ( HarmScorerMetrics, ObjectiveScorerMetrics, From ae13cdb0cdfefe30d314ff16d1516b178fd28595 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 22:17:21 -0700 Subject: [PATCH 16/34] Route all output through sinks, add write_async and convenience methods Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/printer/__init__.py | 128 ++++- pyrit/printer/attack_result/base.py | 6 +- pyrit/printer/attack_result/markdown.py | 91 ++-- pyrit/printer/attack_result/pretty.py | 449 +++++++++++------- pyrit/printer/base.py | 20 +- pyrit/printer/scenario_result/base.py | 11 +- pyrit/printer/scenario_result/pretty.py | 203 ++++---- pyrit/printer/scorer/base.py | 36 +- pyrit/printer/scorer/pretty.py | 176 ++++--- pyrit/printer/sink.py | 41 +- .../attack/printer/test_pretty_printer.py | 206 ++++---- tests/unit/printer/test_convenience.py | 142 ++++++ tests/unit/printer/test_printer_base.py | 24 +- tests/unit/printer/test_sink.py | 33 +- .../unit/score/test_pretty_scorer_printer.py | 103 ++-- tests/unit/score/test_scorer_printer.py | 19 +- 16 files changed, 1052 insertions(+), 636 deletions(-) create mode 100644 tests/unit/printer/test_convenience.py diff --git a/pyrit/printer/__init__.py b/pyrit/printer/__init__.py index 26f4ef9e8e..50abdf19d3 100644 --- a/pyrit/printer/__init__.py +++ b/pyrit/printer/__init__.py @@ -8,12 +8,132 @@ - **Sink** classes that define where output goes (stdout, file, etc.) - **PrinterBase** that all printers inherit from - Domain printers for attack results, scenario results, and scorer information +- **Convenience functions** for one-line printing (e.g., ``print_attack_result_async``) File names indicate output format (pretty.py = ANSI-colored, markdown.py = Markdown). Abstract methods inside each printer determine the data source (memory, REST, fixtures). +""" -Framework users: use the Memory printer classes (e.g., PrettyAttackResultMemoryPrinter) -which fetch data via CentralMemory. +from pathlib import Path +from typing import Literal -Thin clients: subclass the base printers and implement abstract data methods via REST calls. -""" +from pyrit.printer.sink import FileSink, Sink, StdoutSink + +OutputFormat = Literal["pretty", "markdown"] + + +def _resolve_sink(to: Path | str | Sink | None) -> Sink: + """ + Resolve a destination argument to a Sink instance. + + Args: + to (Path | str | Sink | None): The destination. + None → StdoutSink. + Path or str → FileSink. + Sink instance → used as-is. + + Returns: + Sink: The resolved sink. + """ + if to is None: + return StdoutSink() + if isinstance(to, Sink): + return to + return FileSink(path=Path(to)) + + +async def print_attack_result_async( + result: "AttackResult", # noqa: F821 + *, + format: OutputFormat = "pretty", + to: Path | str | Sink | None = None, + include_auxiliary_scores: bool = False, + include_pruned_conversations: bool = False, + include_adversarial_conversation: bool = False, +) -> None: + """ + Print an attack result in the specified format to the specified destination. + + Args: + result (AttackResult): The attack result to print. + format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". + to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. + include_auxiliary_scores (bool): Whether to include auxiliary scores. Defaults to False. + include_pruned_conversations (bool): Whether to include pruned conversations. 
Defaults to False. + include_adversarial_conversation (bool): Whether to include the adversarial conversation. + Defaults to False. + """ + sink = _resolve_sink(to) + + if format == "markdown": + from pyrit.printer.attack_result.markdown import MarkdownAttackResultMemoryPrinter + + printer = MarkdownAttackResultMemoryPrinter(sink=sink) + else: + from pyrit.printer.attack_result.pretty import PrettyAttackResultMemoryPrinter + + printer = PrettyAttackResultMemoryPrinter(sink=sink) + + await printer.write_async( + result, + include_auxiliary_scores=include_auxiliary_scores, + include_pruned_conversations=include_pruned_conversations, + include_adversarial_conversation=include_adversarial_conversation, + ) + + +async def print_scenario_result_async( + result: "ScenarioResult", # noqa: F821 + *, + format: OutputFormat = "pretty", + to: Path | str | Sink | None = None, +) -> None: + """ + Print a scenario result in the specified format to the specified destination. + + Args: + result (ScenarioResult): The scenario result to print. + format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". + to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. + """ + sink = _resolve_sink(to) + + if format == "pretty": + from pyrit.printer.scenario_result.pretty import PrettyScenarioResultMemoryPrinter + + printer = PrettyScenarioResultMemoryPrinter(sink=sink) + else: + raise ValueError(f"Unsupported format for scenario results: {format!r}. Only 'pretty' is available.") + + await printer.write_async(result) + + +async def print_scorer_async( + *, + scorer_identifier: "ComponentIdentifier", # noqa: F821 + harm_category: str | None = None, + format: OutputFormat = "pretty", + to: Path | str | Sink | None = None, +) -> None: + """ + Print scorer information in the specified format to the specified destination. + + Auto-detects scorer type: if harm_category is provided, renders harm + metrics; otherwise renders objective metrics. + + Args: + scorer_identifier (ComponentIdentifier): The scorer identifier. + harm_category (str | None): The harm category. None for objective scorers. + format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". + to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. + """ + sink = _resolve_sink(to) + + if format == "pretty": + from pyrit.printer.scorer.pretty import PrettyScorerMemoryPrinter + + printer = PrettyScorerMemoryPrinter(sink=sink) + else: + raise ValueError(f"Unsupported format for scorer: {format!r}. Only 'pretty' is available.") + + await printer.write_async(scorer_identifier=scorer_identifier, harm_category=harm_category) diff --git a/pyrit/printer/attack_result/base.py b/pyrit/printer/attack_result/base.py index 2aa6652434..14657e4585 100644 --- a/pyrit/printer/attack_result/base.py +++ b/pyrit/printer/attack_result/base.py @@ -19,7 +19,7 @@ class AttackResultPrinterBase(PrinterBase): """ @abstractmethod - async def get_conversation_async(self, conversation_id: str) -> list[Message]: + async def _get_conversation_async(self, conversation_id: str) -> list[Message]: """ Fetch conversation messages for a given conversation ID. @@ -31,7 +31,7 @@ async def get_conversation_async(self, conversation_id: str) -> list[Message]: """ @abstractmethod - async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: + async def _get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: """ Fetch scores for given prompt piece IDs. 
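A minimal usage sketch of the convenience layer and sink routing introduced in this patch (assumptions: attack_result and scenario_result are objects produced by a prior attack or scenario run, the file path and BufferSink are illustrative placeholders, and CentralMemory has been initialized, since the default Memory printers fetch conversations and scores from it):

from pathlib import Path

from pyrit.printer import print_attack_result_async, print_scenario_result_async
from pyrit.printer.sink import Sink


class BufferSink(Sink):
    """Hypothetical in-memory sink that collects rendered output for inspection."""

    def __init__(self) -> None:
        self.chunks = []

    async def write_async(self, data) -> None:
        # Depending on the sink signature in use, data may arrive as str or bytes; accept either.
        self.chunks.append(data if isinstance(data, str) else data.decode("utf-8"))


async def report(attack_result, scenario_result):
    # Default destination: ANSI-colored output to stdout via StdoutSink.
    await print_attack_result_async(attack_result, include_auxiliary_scores=True)

    # A str or Path destination resolves to a FileSink; format="markdown" selects the Markdown printer.
    await print_attack_result_async(attack_result, format="markdown", to=Path("attack_report.md"))

    # Any Sink instance is used as-is; scenario summaries currently support only the "pretty" format.
    buffer = BufferSink()
    await print_scenario_result_async(scenario_result, to=buffer)
    return buffer.chunks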
@@ -42,7 +42,7 @@ async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: list[Score]: The scores associated with the given piece IDs. """ - async def display_image_async(self, piece: MessagePiece) -> None: # noqa: B027 + async def _display_image_async(self, piece: MessagePiece) -> None: # noqa: B027 """ Display an image from a message piece. No-op by default. diff --git a/pyrit/printer/attack_result/markdown.py b/pyrit/printer/attack_result/markdown.py index d9745a7c0a..88038839ed 100644 --- a/pyrit/printer/attack_result/markdown.py +++ b/pyrit/printer/attack_result/markdown.py @@ -23,35 +23,12 @@ def __init__(self, *, sink: Sink | None = None, display_inline: bool = True) -> Args: sink (Sink | None): Output sink. Defaults to StdoutSink(). - display_inline (bool): If True, uses IPython.display to render markdown - inline in Jupyter notebooks. If False, prints markdown strings. - Defaults to True. + display_inline (bool): Kept for backward compatibility but unused. + All output is routed through the sink. Defaults to True. """ super().__init__(sink=sink) self._display_inline = display_inline - def _render_markdown(self, markdown_lines: list[str]) -> None: - """ - Render the markdown content using appropriate display method. - - Attempts to use IPython.display.Markdown for Jupyter notebook rendering - when display_inline is True, falling back to print() if not available. - - Args: - markdown_lines (List[str]): List of markdown strings to render. - """ - full_markdown = "\n".join(markdown_lines) - - if self._display_inline: - try: - from IPython.display import Markdown, display - - display(Markdown(full_markdown)) - except (ImportError, NameError): - print(full_markdown) - else: - print(full_markdown) - def _format_score(self, score: Score, indent: str = "") -> str: """ Format a score object as markdown with proper styling. @@ -91,7 +68,7 @@ def _format_score(self, score: Score, indent: str = "") -> str: return "\n".join(lines) - async def print_result_async( + async def write_async( self, result: AttackResult, *, @@ -100,16 +77,16 @@ async def print_result_async( include_adversarial_conversation: bool = False, ) -> None: """ - Print the complete attack result as formatted markdown. + Render and write the complete attack result as markdown to the sink. Args: - result (AttackResult): The attack result to print. + result (AttackResult): The attack result to render. include_auxiliary_scores (bool): Whether to include auxiliary scores. Defaults to False. include_pruned_conversations (bool): Whether to include pruned conversations. Defaults to False. include_adversarial_conversation (bool): Whether to include the adversarial conversation. Defaults to False. """ - markdown_lines = [] + markdown_lines: list[str] = [] outcome_emoji = self._get_outcome_icon(result.outcome) markdown_lines.append(f"# {outcome_emoji} Attack Result: {result.outcome.value.upper()}\n") @@ -148,28 +125,33 @@ async def print_result_async( timestamp_utc = datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z") markdown_lines.append(f"*Report generated at {timestamp_utc}*") - self._render_markdown(markdown_lines) + await self._write_async("\n".join(markdown_lines)) - async def print_conversation_async(self, result: AttackResult, *, include_scores: bool = False) -> None: - """ - Print only the conversation history as formatted markdown. 
+ async def print_result_async( + self, + result: AttackResult, + *, + include_auxiliary_scores: bool = False, + include_pruned_conversations: bool = False, + include_adversarial_conversation: bool = False, + ) -> None: + """Deprecated. Use write_async instead.""" + await self.write_async( + result, + include_auxiliary_scores=include_auxiliary_scores, + include_pruned_conversations=include_pruned_conversations, + include_adversarial_conversation=include_adversarial_conversation, + ) - Args: - result (AttackResult): The attack result containing the conversation to display. - include_scores (bool): Whether to include scores. Defaults to False. - """ + async def print_conversation_async(self, result: AttackResult, *, include_scores: bool = False) -> None: + """Deprecated. Use _get_conversation_markdown_async and _write_async instead.""" markdown_lines = await self._get_conversation_markdown_async(result=result, include_scores=include_scores) - self._render_markdown(markdown_lines) + await self._write_async("\n".join(markdown_lines)) async def print_summary_async(self, result: AttackResult) -> None: - """ - Print a summary of the attack result as formatted markdown. - - Args: - result (AttackResult): The attack result to summarize. - """ + """Deprecated. Use _get_summary_markdown_async and _write_async instead.""" markdown_lines = await self._get_summary_markdown_async(result) - self._render_markdown(markdown_lines) + await self._write_async("\n".join(markdown_lines)) async def _get_conversation_markdown_async( self, *, result: AttackResult, include_scores: bool = False @@ -190,7 +172,7 @@ async def _get_conversation_markdown_async( markdown_lines.append("*No conversation ID available*\n") return markdown_lines - messages = await self.get_conversation_async(result.conversation_id) + messages = await self._get_conversation_async(result.conversation_id) if not messages: markdown_lines.append(f"*No conversation found for ID: {result.conversation_id}*\n") @@ -395,7 +377,7 @@ async def _format_message_scores_async(self, message: Message) -> list[str]: """ lines: list[str] = [] for piece in message.message_pieces: - scores = await self.get_scores_async(prompt_ids=[str(piece.id)]) + scores = await self._get_scores_async(prompt_ids=[str(piece.id)]) if scores: lines.append("\n##### Scores\n") lines.extend(self._format_score(score, indent="") for score in scores) @@ -470,7 +452,7 @@ async def _get_pruned_conversations_markdown_async(self, result: AttackResult) - label += f" - {ref.description}" markdown_lines.append(f"\n{label}\n") - messages = await self.get_conversation_async(ref.conversation_id) + messages = await self._get_conversation_async(ref.conversation_id) if not messages: markdown_lines.append(f"*No messages found for conversation: `{ref.conversation_id}`*\n") @@ -490,7 +472,7 @@ async def _get_pruned_conversations_markdown_async(self, result: AttackResult) - else: markdown_lines.append(f"> {content}\n") - scores = await self.get_scores_async(prompt_ids=[str(piece.id)]) + scores = await self._get_scores_async(prompt_ids=[str(piece.id)]) if scores: markdown_lines.append("\n**Score:**\n") markdown_lines.extend(self._format_score(score, indent="") for score in scores) @@ -526,7 +508,7 @@ async def _get_adversarial_conversation_markdown_async(self, result: AttackResul if ref.description: markdown_lines.append(f"*📝 {ref.description}*\n") - messages = await self.get_conversation_async(ref.conversation_id) + messages = await self._get_conversation_async(ref.conversation_id) if not messages: 
markdown_lines.append(f"*No messages found for conversation: `{ref.conversation_id}`*\n") @@ -568,16 +550,15 @@ def __init__(self, *, sink: Sink | None = None, display_inline: bool = True) -> Args: sink (Sink | None): Output sink. Defaults to StdoutSink(). - display_inline (bool): If True, uses IPython.display to render markdown - inline in Jupyter notebooks. If False, prints markdown strings. - Defaults to True. + display_inline (bool): Kept for backward compatibility but unused. + All output is routed through the sink. Defaults to True. """ super().__init__(sink=sink, display_inline=display_inline) from pyrit.memory import CentralMemory self._memory = CentralMemory.get_memory_instance() - async def get_conversation_async(self, conversation_id: str) -> list[Message]: + async def _get_conversation_async(self, conversation_id: str) -> list[Message]: """ Fetch conversation messages from CentralMemory. @@ -586,7 +567,7 @@ async def get_conversation_async(self, conversation_id: str) -> list[Message]: """ return list(self._memory.get_conversation(conversation_id=conversation_id)) - async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: + async def _get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: """ Fetch scores from CentralMemory. diff --git a/pyrit/printer/attack_result/pretty.py b/pyrit/printer/attack_result/pretty.py index e320c3dbb1..b8adee7169 100644 --- a/pyrit/printer/attack_result/pretty.py +++ b/pyrit/printer/attack_result/pretty.py @@ -38,21 +38,23 @@ def __init__( self._indent = " " * indent_size self._enable_colors = enable_colors - def _print_colored(self, text: str, *colors: str) -> None: + def _format_colored(self, text: str, *colors: str) -> str: """ - Print text with color formatting if colors are enabled. + Format text with color codes if colors are enabled. Args: - text (str): The text to print. + text (str): The text to format. *colors: Variable number of colorama color constants to apply. + + Returns: + str: The formatted line with trailing newline. """ if self._enable_colors and colors: color_prefix = "".join(colors) - print(f"{color_prefix}{text}{Style.RESET_ALL}") - else: - print(text) + return f"{color_prefix}{text}{Style.RESET_ALL}\n" + return f"{text}\n" - async def print_result_async( + async def write_async( self, result: AttackResult, *, @@ -61,150 +63,203 @@ async def print_result_async( include_adversarial_conversation: bool = False, ) -> None: """ - Print the complete attack result to console. + Render and write the complete attack result to the sink. Args: - result (AttackResult): The attack result to print. + result (AttackResult): The attack result to render. include_auxiliary_scores (bool): Whether to include auxiliary scores. Defaults to False. include_pruned_conversations (bool): Whether to include pruned conversations. Defaults to False. include_adversarial_conversation (bool): Whether to include the adversarial conversation. Defaults to False. 
""" - self._print_header(result) - await self.print_summary_async(result) - - self._print_section_header("Conversation History with Objective Target") - await self.print_conversation_async(result, include_scores=include_auxiliary_scores) - + lines: list[str] = [] + lines.append(self._render_header(result)) + lines.append(await self._render_summary_async(result)) + lines.append(self._render_section_header("Conversation History with Objective Target")) + lines.append(await self._render_conversation_async(result, include_scores=include_auxiliary_scores)) if include_pruned_conversations: - await self._print_pruned_conversations_async(result) - + lines.append(await self._render_pruned_conversations_async(result)) if include_adversarial_conversation: - await self._print_adversarial_conversation_async(result) - + lines.append(await self._render_adversarial_conversation_async(result)) if result.metadata: - self._print_metadata(result.metadata) + lines.append(self._render_metadata(result.metadata)) + lines.append(self._render_footer()) + await self._write_async("".join(lines)) - self._print_footer() + async def print_result_async( + self, + result: AttackResult, + *, + include_auxiliary_scores: bool = False, + include_pruned_conversations: bool = False, + include_adversarial_conversation: bool = False, + ) -> None: + """Deprecated. Use write_async instead.""" + await self.write_async( + result, + include_auxiliary_scores=include_auxiliary_scores, + include_pruned_conversations=include_pruned_conversations, + include_adversarial_conversation=include_adversarial_conversation, + ) - async def print_conversation_async( + async def _render_conversation_async( self, result: AttackResult, *, include_scores: bool = False, include_reasoning_trace: bool = False - ) -> None: + ) -> str: """ - Print the conversation history to console. + Render the conversation history as a formatted string. Args: result (AttackResult): The attack result containing the conversation_id. include_scores (bool): Whether to include scores. Defaults to False. include_reasoning_trace (bool): Whether to include model reasoning trace. Defaults to False. + + Returns: + str: The rendered conversation text. """ if not result.conversation_id: - self._print_colored(f"{self._indent} No conversation ID available", Fore.YELLOW) - return + return self._format_colored(f"{self._indent} No conversation ID available", Fore.YELLOW) - messages = await self.get_conversation_async(result.conversation_id) + messages = await self._get_conversation_async(result.conversation_id) if not messages: - self._print_colored(f"{self._indent} No conversation found for ID: {result.conversation_id}", Fore.YELLOW) - return + return self._format_colored( + f"{self._indent} No conversation found for ID: {result.conversation_id}", Fore.YELLOW + ) - await self.print_messages_async( + return await self._render_messages_async( messages=messages, include_scores=include_scores, include_reasoning_trace=include_reasoning_trace, ) - async def print_messages_async( + async def print_conversation_async( + self, result: AttackResult, *, include_scores: bool = False, include_reasoning_trace: bool = False + ) -> None: + """Deprecated. 
Use write_async instead.""" + content = await self._render_conversation_async( + result, include_scores=include_scores, include_reasoning_trace=include_reasoning_trace + ) + await self._write_async(content) + + async def _render_messages_async( self, messages: list[Message], *, include_scores: bool = False, include_reasoning_trace: bool = False, - ) -> None: + ) -> str: """ - Print a list of messages to console with enhanced formatting. + Render a list of messages as a formatted string. Args: - messages (list): List of Message objects to print. + messages (list[Message]): List of Message objects to render. include_scores (bool): Whether to include scores. Defaults to False. include_reasoning_trace (bool): Whether to include model reasoning trace. Defaults to False. + + Returns: + str: The rendered messages text. """ if not messages: - self._print_colored(f"{self._indent} No messages to display.", Fore.YELLOW) - return + return self._format_colored(f"{self._indent} No messages to display.", Fore.YELLOW) + lines: list[str] = [] + image_pieces: list[MessagePiece] = [] turn_number = 0 for message in messages: if message.api_role == "user": turn_number += 1 - print() - self._print_colored("─" * self._width, Fore.BLUE) - self._print_colored(f"🔹 Turn {turn_number} - USER", Style.BRIGHT, Fore.BLUE) - self._print_colored("─" * self._width, Fore.BLUE) + lines.append("\n") + lines.append(self._format_colored("─" * self._width, Fore.BLUE)) + lines.append(self._format_colored(f"🔹 Turn {turn_number} - USER", Style.BRIGHT, Fore.BLUE)) + lines.append(self._format_colored("─" * self._width, Fore.BLUE)) elif message.api_role == "system": - print() - self._print_colored("─" * self._width, Fore.MAGENTA) - self._print_colored("🔧 SYSTEM", Style.BRIGHT, Fore.MAGENTA) - self._print_colored("─" * self._width, Fore.MAGENTA) + lines.append("\n") + lines.append(self._format_colored("─" * self._width, Fore.MAGENTA)) + lines.append(self._format_colored("🔧 SYSTEM", Style.BRIGHT, Fore.MAGENTA)) + lines.append(self._format_colored("─" * self._width, Fore.MAGENTA)) else: - print() - self._print_colored("─" * self._width, Fore.YELLOW) + lines.append("\n") + lines.append(self._format_colored("─" * self._width, Fore.YELLOW)) role_label = "ASSISTANT (SIMULATED)" if message.is_simulated else message.api_role.upper() - self._print_colored(f"🔸 {role_label}", Style.BRIGHT, Fore.YELLOW) - self._print_colored("─" * self._width, Fore.YELLOW) + lines.append(self._format_colored(f"🔸 {role_label}", Style.BRIGHT, Fore.YELLOW)) + lines.append(self._format_colored("─" * self._width, Fore.YELLOW)) for piece in message.message_pieces: if piece.original_value_data_type == "reasoning": if include_reasoning_trace: summary_text = self._extract_reasoning_summary(piece.original_value) if summary_text: - self._print_colored(f"{self._indent}💭 Reasoning Summary:", Style.DIM, Fore.CYAN) - self._print_wrapped_text(summary_text, Fore.CYAN) - print() + lines.append(self._format_colored( + f"{self._indent}💭 Reasoning Summary:", Style.DIM, Fore.CYAN + )) + lines.append(self._render_wrapped_text(summary_text, Fore.CYAN)) + lines.append("\n") continue if piece.is_blocked(): - self._print_colored(f"{self._indent}🚫 BLOCKED BY TARGET", Style.BRIGHT, Fore.RED) + lines.append(self._format_colored( + f"{self._indent}🚫 BLOCKED BY TARGET", Style.BRIGHT, Fore.RED + )) partial_content = piece.prompt_metadata.get("partial_content") if partial_content: - self._print_colored( + lines.append(self._format_colored( f"{self._indent}📝 Partial content (before filter 
triggered):", Style.DIM, Fore.CYAN, - ) - self._print_wrapped_text(str(partial_content), Fore.YELLOW) + )) + lines.append(self._render_wrapped_text(str(partial_content), Fore.YELLOW)) else: - self._print_colored( + lines.append(self._format_colored( f"{self._indent}Content was blocked by the target's content filter.", Style.DIM, Fore.RED, - ) + )) elif piece.converted_value != piece.original_value: - self._print_colored(f"{self._indent} Original:", Fore.CYAN) - self._print_wrapped_text(piece.original_value, Fore.WHITE) - print() - self._print_colored(f"{self._indent} Converted:", Fore.CYAN) - self._print_wrapped_text(piece.converted_value, Fore.WHITE) + lines.append(self._format_colored(f"{self._indent} Original:", Fore.CYAN)) + lines.append(self._render_wrapped_text(piece.original_value, Fore.WHITE)) + lines.append("\n") + lines.append(self._format_colored(f"{self._indent} Converted:", Fore.CYAN)) + lines.append(self._render_wrapped_text(piece.converted_value, Fore.WHITE)) elif piece.api_role == "user": - self._print_wrapped_text(piece.converted_value, Fore.BLUE) + lines.append(self._render_wrapped_text(piece.converted_value, Fore.BLUE)) elif piece.api_role == "system": - self._print_wrapped_text(piece.converted_value, Fore.MAGENTA) + lines.append(self._render_wrapped_text(piece.converted_value, Fore.MAGENTA)) else: - self._print_wrapped_text(piece.converted_value, Fore.YELLOW) + lines.append(self._render_wrapped_text(piece.converted_value, Fore.YELLOW)) - await self.display_image_async(piece) + image_pieces.append(piece) if include_scores: - scores = await self.get_scores_async(prompt_ids=[str(piece.id)]) + scores = await self._get_scores_async(prompt_ids=[str(piece.id)]) if scores: - print() - self._print_colored(f"{self._indent}📊 Scores:", Style.DIM, Fore.MAGENTA) + lines.append("\n") + lines.append(self._format_colored( + f"{self._indent}📊 Scores:", Style.DIM, Fore.MAGENTA + )) for score in scores: - self._print_score(score) + lines.append(self._render_score(score)) + + lines.append("\n") + lines.append(self._format_colored("─" * self._width, Fore.BLUE)) + + for piece in image_pieces: + await self._display_image_async(piece) - print() - self._print_colored("─" * self._width, Fore.BLUE) + return "".join(lines) + + async def print_messages_async( + self, + messages: list[Message], + *, + include_scores: bool = False, + include_reasoning_trace: bool = False, + ) -> None: + """Deprecated. Use write_async instead.""" + content = await self._render_messages_async( + messages=messages, include_scores=include_scores, include_reasoning_trace=include_reasoning_trace + ) + await self._write_async(content) def _extract_reasoning_summary(self, reasoning_value: str) -> str: """ @@ -228,116 +283,159 @@ def _extract_reasoning_summary(self, reasoning_value: str) -> str: parts = [item.get("text", "") for item in summary if isinstance(item, dict) and item.get("text")] return "\n".join(parts) - async def print_summary_async(self, result: AttackResult) -> None: + async def _render_summary_async(self, result: AttackResult) -> str: """ - Print a summary of the attack result. + Render a summary of the attack result. Args: result (AttackResult): The attack result to summarize. + + Returns: + str: The rendered summary text. 
""" - self._print_section_header("Attack Summary") + lines: list[str] = [] + lines.append(self._render_section_header("Attack Summary")) - self._print_colored(f"{self._indent}📋 Basic Information", Style.BRIGHT) - self._print_colored(f"{self._indent * 2}• Objective: {result.objective}", Fore.CYAN) + lines.append(self._format_colored(f"{self._indent}📋 Basic Information", Style.BRIGHT)) + lines.append(self._format_colored(f"{self._indent * 2}• Objective: {result.objective}", Fore.CYAN)) attack_type = "Unknown" attack_strategy_id = result.get_attack_strategy_identifier() if attack_strategy_id: attack_type = attack_strategy_id.class_name - self._print_colored(f"{self._indent * 2}• Attack Type: {attack_type}", Fore.CYAN) - self._print_colored(f"{self._indent * 2}• Conversation ID: {result.conversation_id}", Fore.CYAN) - - print() - self._print_colored(f"{self._indent}⚡ Execution Metrics", Style.BRIGHT) - self._print_colored(f"{self._indent * 2}• Turns Executed: {result.executed_turns}", Fore.GREEN) - self._print_colored( + lines.append(self._format_colored(f"{self._indent * 2}• Attack Type: {attack_type}", Fore.CYAN)) + lines.append(self._format_colored( + f"{self._indent * 2}• Conversation ID: {result.conversation_id}", Fore.CYAN + )) + + lines.append("\n") + lines.append(self._format_colored(f"{self._indent}⚡ Execution Metrics", Style.BRIGHT)) + lines.append(self._format_colored( + f"{self._indent * 2}• Turns Executed: {result.executed_turns}", Fore.GREEN + )) + lines.append(self._format_colored( f"{self._indent * 2}• Execution Time: {self._format_time(result.execution_time_ms)}", Fore.GREEN - ) + )) - print() - self._print_colored(f"{self._indent}🎯 Outcome", Style.BRIGHT) + lines.append("\n") + lines.append(self._format_colored(f"{self._indent}🎯 Outcome", Style.BRIGHT)) outcome_icon = self._get_outcome_icon(result.outcome) outcome_color = self._get_outcome_color(result.outcome) - self._print_colored(f"{self._indent * 2}• Status: {outcome_icon} {result.outcome.value.upper()}", outcome_color) + lines.append(self._format_colored( + f"{self._indent * 2}• Status: {outcome_icon} {result.outcome.value.upper()}", outcome_color + )) if result.outcome_reason: - self._print_colored(f"{self._indent * 2}• Reason: {result.outcome_reason}", Fore.WHITE) + lines.append(self._format_colored(f"{self._indent * 2}• Reason: {result.outcome_reason}", Fore.WHITE)) if result.last_score: - print() - self._print_colored(f"{self._indent} Final Score", Style.BRIGHT) - self._print_score(result.last_score, indent_level=2) + lines.append("\n") + lines.append(self._format_colored(f"{self._indent} Final Score", Style.BRIGHT)) + lines.append(self._render_score(result.last_score, indent_level=2)) + + return "".join(lines) - def _print_header(self, result: AttackResult) -> None: + async def print_summary_async(self, result: AttackResult) -> None: + """Deprecated. Use write_async instead.""" + content = await self._render_summary_async(result) + await self._write_async(content) + + def _render_header(self, result: AttackResult) -> str: """ - Print the header with outcome-based coloring. + Render the header with outcome-based coloring. Args: result (AttackResult): The attack result containing the outcome. + + Returns: + str: The rendered header text. 
""" color = self._get_outcome_color(result.outcome) icon = self._get_outcome_icon(result.outcome) - print() - self._print_colored("═" * self._width, color) + lines: list[str] = [] + lines.append("\n") + lines.append(self._format_colored("═" * self._width, color)) header_text = f"{icon} ATTACK RESULT: {result.outcome.value.upper()} {icon}" - self._print_colored(header_text.center(self._width), Style.BRIGHT, color) - self._print_colored("═" * self._width, color) + lines.append(self._format_colored(header_text.center(self._width), Style.BRIGHT, color)) + lines.append(self._format_colored("═" * self._width, color)) + return "".join(lines) + + def _render_footer(self) -> str: + """ + Render a footer with timestamp. - def _print_footer(self) -> None: - """Print a footer with timestamp.""" + Returns: + str: The rendered footer text. + """ timestamp = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S") - print() - self._print_colored("─" * self._width, Style.DIM, Fore.WHITE) + lines: list[str] = [] + lines.append("\n") + lines.append(self._format_colored("─" * self._width, Style.DIM, Fore.WHITE)) footer_text = f"Report generated at: {timestamp} UTC" - self._print_colored(footer_text.center(self._width), Style.DIM, Fore.WHITE) + lines.append(self._format_colored(footer_text.center(self._width), Style.DIM, Fore.WHITE)) + return "".join(lines) - def _print_section_header(self, title: str) -> None: + def _render_section_header(self, title: str) -> str: """ - Print a section header with consistent styling. + Render a section header with consistent styling. Args: title (str): The title text to display. + + Returns: + str: The rendered section header text. """ - print() - self._print_colored(f" {title} ", Style.BRIGHT, Back.BLUE, Fore.WHITE) - self._print_colored("─" * self._width, Fore.BLUE) + lines: list[str] = [] + lines.append("\n") + lines.append(self._format_colored(f" {title} ", Style.BRIGHT, Back.BLUE, Fore.WHITE)) + lines.append(self._format_colored("─" * self._width, Fore.BLUE)) + return "".join(lines) - def _print_metadata(self, metadata: dict[str, Any]) -> None: + def _render_metadata(self, metadata: dict[str, Any]) -> str: """ - Print metadata in a formatted way. + Render metadata in a formatted way. Args: metadata (dict[str, Any]): Dictionary containing metadata key-value pairs. + + Returns: + str: The rendered metadata text. """ - self._print_section_header("Additional Metadata") + lines: list[str] = [] + lines.append(self._render_section_header("Additional Metadata")) for key, value in metadata.items(): - self._print_colored(f"{self._indent}• {key}: {value}", Fore.CYAN) + lines.append(self._format_colored(f"{self._indent}• {key}: {value}", Fore.CYAN)) + return "".join(lines) - def _print_score(self, score: Score, indent_level: int = 3) -> None: + def _render_score(self, score: Score, indent_level: int = 3) -> str: """ - Print a score with proper formatting. + Render a score with proper formatting. Args: - score (Score): Score object to be printed. + score (Score): Score object to be rendered. indent_level (int): Number of indent units to apply. Defaults to 3. + + Returns: + str: The rendered score text. 
""" + lines: list[str] = [] indent = self._indent * indent_level scorer_name = score.scorer_class_identifier.class_name - print(f"{indent}Scorer: {scorer_name}") - self._print_colored(f"{indent}• Category: {score.score_category or 'N/A'}", Fore.LIGHTMAGENTA_EX) - self._print_colored(f"{indent}• Type: {score.score_type}", Fore.CYAN) + lines.append(f"{indent}Scorer: {scorer_name}\n") + lines.append(self._format_colored(f"{indent}• Category: {score.score_category or 'N/A'}", Fore.LIGHTMAGENTA_EX)) + lines.append(self._format_colored(f"{indent}• Type: {score.score_type}", Fore.CYAN)) if score.score_type == "true_false": score_color = Fore.GREEN if score.get_value() else Fore.RED else: score_color = Fore.YELLOW - self._print_colored(f"{indent}• Value: {score.score_value}", score_color) + lines.append(self._format_colored(f"{indent}• Value: {score.score_value}", score_color)) if score.score_rationale: - print(f"{indent}• Rationale:") + lines.append(f"{indent}• Rationale:\n") rationale_wrapper = textwrap.TextWrapper( width=self._width - len(indent) - 2, initial_indent=indent + " ", @@ -345,23 +443,29 @@ def _print_score(self, score: Score, indent_level: int = 3) -> None: break_long_words=False, break_on_hyphens=False, ) - lines = score.score_rationale.split("\n") - for line in lines: + rationale_lines = score.score_rationale.split("\n") + for line in rationale_lines: if line.strip(): wrapped_lines = rationale_wrapper.wrap(line) for wrapped_line in wrapped_lines: - self._print_colored(wrapped_line, Fore.WHITE) + lines.append(self._format_colored(wrapped_line, Fore.WHITE)) else: - self._print_colored(f"{indent} ") + lines.append(self._format_colored(f"{indent} ")) - def _print_wrapped_text(self, text: str, color: str) -> None: + return "".join(lines) + + def _render_wrapped_text(self, text: str, color: str) -> str: """ - Print text with proper wrapping and indentation, preserving newlines. + Render text with proper wrapping and indentation, preserving newlines. Args: - text (str): The text to print. + text (str): The text to render. color (str): Colorama color constant to apply. + + Returns: + str: The rendered wrapped text. """ + lines: list[str] = [] text_wrapper = textwrap.TextWrapper( width=self._width - len(self._indent), initial_indent="", @@ -372,103 +476,118 @@ def _print_wrapped_text(self, text: str, color: str) -> None: replace_whitespace=False, ) - lines = text.split("\n") - for line_num, line in enumerate(lines): + text_lines = text.split("\n") + for line_num, line in enumerate(text_lines): if line.strip(): wrapped_lines = text_wrapper.wrap(line) for i, wrapped_line in enumerate(wrapped_lines): if line_num == 0 and i == 0: - self._print_colored(f"{self._indent}{wrapped_line}", color) + lines.append(self._format_colored(f"{self._indent}{wrapped_line}", color)) else: - self._print_colored(f"{self._indent * 2}{wrapped_line}", color) + lines.append(self._format_colored(f"{self._indent * 2}{wrapped_line}", color)) else: - self._print_colored(f"{self._indent}", color) + lines.append(self._format_colored(f"{self._indent}", color)) + + return "".join(lines) - async def _print_pruned_conversations_async(self, result: AttackResult) -> None: + async def _render_pruned_conversations_async(self, result: AttackResult) -> str: """ - Print pruned conversations showing only the last message and score for each. + Render pruned conversations showing only the last message and score for each. Args: result (AttackResult): The attack result containing related conversations. 
+ + Returns: + str: The rendered pruned conversations text. """ pruned_refs = result.get_conversations_by_type(ConversationType.PRUNED) if not pruned_refs: - return + return "" - self._print_section_header(f"Pruned Conversations ({len(pruned_refs)} total)") + lines: list[str] = [] + lines.append(self._render_section_header(f"Pruned Conversations ({len(pruned_refs)} total)")) for idx, ref in enumerate(pruned_refs, 1): - print() - self._print_colored("─" * self._width, Fore.RED) + lines.append("\n") + lines.append(self._format_colored("─" * self._width, Fore.RED)) label = f"🗑️ PRUNED #{idx}" if ref.description: label += f" - {ref.description}" - self._print_colored(label, Style.BRIGHT, Fore.RED) - self._print_colored("─" * self._width, Fore.RED) + lines.append(self._format_colored(label, Style.BRIGHT, Fore.RED)) + lines.append(self._format_colored("─" * self._width, Fore.RED)) - messages = await self.get_conversation_async(ref.conversation_id) + messages = await self._get_conversation_async(ref.conversation_id) if not messages: - self._print_colored( + lines.append(self._format_colored( f"{self._indent}No messages found for conversation: {ref.conversation_id}", Fore.YELLOW - ) + )) continue last_message = messages[-1] role_label = last_message.api_role.upper() - self._print_colored(f"{self._indent}Last Message ({role_label}):", Style.BRIGHT, Fore.WHITE) + lines.append(self._format_colored( + f"{self._indent}Last Message ({role_label}):", Style.BRIGHT, Fore.WHITE + )) for piece in last_message.message_pieces: - self._print_wrapped_text(piece.converted_value, Fore.WHITE) + lines.append(self._render_wrapped_text(piece.converted_value, Fore.WHITE)) - scores = await self.get_scores_async(prompt_ids=[str(piece.id)]) + scores = await self._get_scores_async(prompt_ids=[str(piece.id)]) if scores: - print() - self._print_colored(f"{self._indent}📊 Score:", Style.DIM, Fore.MAGENTA) + lines.append("\n") + lines.append(self._format_colored(f"{self._indent}📊 Score:", Style.DIM, Fore.MAGENTA)) for score in scores: - self._print_score(score) + lines.append(self._render_score(score)) - print() - self._print_colored("─" * self._width, Fore.RED) + lines.append("\n") + lines.append(self._format_colored("─" * self._width, Fore.RED)) + return "".join(lines) - async def _print_adversarial_conversation_async(self, result: AttackResult) -> None: + async def _render_adversarial_conversation_async(self, result: AttackResult) -> str: """ - Print the adversarial conversation for the best-scoring attack branch. + Render the adversarial conversation for the best-scoring attack branch. Args: result (AttackResult): The attack result containing related conversations. + + Returns: + str: The rendered adversarial conversation text. 
""" adversarial_refs = result.get_conversations_by_type(ConversationType.ADVERSARIAL) if not adversarial_refs: - return + return "" - self._print_section_header("Adversarial Conversation (Red Team LLM)") + lines: list[str] = [] + lines.append(self._render_section_header("Adversarial Conversation (Red Team LLM)")) best_adversarial_id = result.metadata.get("best_adversarial_conversation_id") if best_adversarial_id: adversarial_refs = [ref for ref in adversarial_refs if ref.conversation_id == best_adversarial_id] if adversarial_refs: - self._print_colored( + lines.append(self._format_colored( f"{self._indent}📌 Showing best-scoring branch's adversarial conversation", Style.DIM, Fore.CYAN, - ) + )) for ref in adversarial_refs: if ref.description: - self._print_colored(f"{self._indent}📝 {ref.description}", Style.DIM, Fore.CYAN) + lines.append(self._format_colored(f"{self._indent}📝 {ref.description}", Style.DIM, Fore.CYAN)) - messages = await self.get_conversation_async(ref.conversation_id) + messages = await self._get_conversation_async(ref.conversation_id) if not messages: - self._print_colored( + lines.append(self._format_colored( f"{self._indent}No messages found for conversation: {ref.conversation_id}", Fore.YELLOW - ) + )) continue - await self.print_messages_async(messages=messages, include_scores=False) + lines.append(await self._render_messages_async(messages=messages, include_scores=False)) + + return "".join(lines) def _get_outcome_color(self, outcome: AttackOutcome) -> str: """ @@ -488,7 +607,7 @@ def _get_outcome_color(self, outcome: AttackOutcome) -> str: }.get(outcome, Fore.WHITE) ) - async def display_image_async(self, piece: MessagePiece) -> None: + async def _display_image_async(self, piece: MessagePiece) -> None: """Display images using PIL/IPython in notebook environments.""" from pyrit.common.display_response import display_image_response @@ -520,7 +639,7 @@ def __init__( self._memory = CentralMemory.get_memory_instance() - async def get_conversation_async(self, conversation_id: str) -> list[Message]: + async def _get_conversation_async(self, conversation_id: str) -> list[Message]: """ Fetch conversation messages from CentralMemory. @@ -529,7 +648,7 @@ async def get_conversation_async(self, conversation_id: str) -> list[Message]: """ return list(self._memory.get_conversation(conversation_id=conversation_id)) - async def get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: + async def _get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: """ Fetch scores from CentralMemory. diff --git a/pyrit/printer/base.py b/pyrit/printer/base.py index c68ce155a0..7e8692e4cb 100644 --- a/pyrit/printer/base.py +++ b/pyrit/printer/base.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from abc import ABC +from abc import ABC, abstractmethod from pyrit.printer.sink import Sink, StdoutSink @@ -10,8 +10,9 @@ class PrinterBase(ABC): """ Abstract base class for all printers. - Provides a sink for output routing. Subclasses write their rendered - output through the sink via ``_write_async``. + Provides a sink for output routing. Subclasses must implement + ``write_async`` as their public entry point, and use ``_write_async`` + to send rendered text to the sink. 
""" def __init__(self, *, sink: Sink | None = None) -> None: @@ -23,11 +24,20 @@ def __init__(self, *, sink: Sink | None = None) -> None: """ self._sink = sink or StdoutSink() - async def _write_async(self, data: bytes) -> None: + @abstractmethod + async def write_async(self, *args, **kwargs) -> None: + """ + Render and write output to the configured sink. + + Subclasses define the specific signature (e.g., scorer_identifier, + result, etc.). + """ + + async def _write_async(self, data: str) -> None: """ Write data through the configured sink. Args: - data (bytes): The rendered output to write. + data (str): The rendered text output to write. """ await self._sink.write_async(data) diff --git a/pyrit/printer/scenario_result/base.py b/pyrit/printer/scenario_result/base.py index 1545c59cfd..1b6d1c1843 100644 --- a/pyrit/printer/scenario_result/base.py +++ b/pyrit/printer/scenario_result/base.py @@ -16,10 +16,19 @@ class ScenarioResultPrinterBase(PrinterBase): """ @abstractmethod + async def write_async(self, result: ScenarioResult) -> None: + """ + Render and write a scenario result summary to the configured sink. + + Args: + result (ScenarioResult): The scenario result to summarize. + """ + async def print_summary_async(self, result: ScenarioResult) -> None: """ - Print a summary of the scenario result with per-strategy breakdown. + Deprecated. Use write_async instead. Args: result (ScenarioResult): The scenario result to summarize. """ + await self.write_async(result) diff --git a/pyrit/printer/scenario_result/pretty.py b/pyrit/printer/scenario_result/pretty.py index 9c18f38943..e3f72c7685 100644 --- a/pyrit/printer/scenario_result/pretty.py +++ b/pyrit/printer/scenario_result/pretty.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import textwrap -from abc import abstractmethod from colorama import Fore, Style @@ -43,92 +42,156 @@ def __init__( self._indent = " " * indent_size self._enable_colors = enable_colors - @property - @abstractmethod - def scorer_printer(self) -> ScorerPrinterBase: - """Return the scorer printer instance.""" - - def _print_colored(self, text: str, *colors: str) -> None: + def _format_colored(self, text: str, *colors: str) -> str: """ - Print text with color formatting if colors are enabled. + Format text with color codes if colors are enabled. Args: - text (str): The text to print. + text (str): The text to format. *colors: Variable number of colorama color constants to apply. + + Returns: + str: The formatted line with trailing newline. """ if self._enable_colors and colors: color_prefix = "".join(colors) - print(f"{color_prefix}{text}{Style.RESET_ALL}") - else: - print(text) + return f"{color_prefix}{text}{Style.RESET_ALL}\n" + return f"{text}\n" - def _print_section_header(self, title: str) -> None: + def _render_section_header(self, title: str) -> str: """ - Print a section header with visual separation. + Render a section header with visual separation. Args: title (str): The section title to display. + + Returns: + str: The rendered section header. 
""" - print() - self._print_colored(f"▼ {title}", Style.BRIGHT, Fore.CYAN) - self._print_colored("─" * self._width, Fore.CYAN) + lines: list[str] = [] + lines.append("\n") + lines.append(self._format_colored(f"▼ {title}", Style.BRIGHT, Fore.CYAN)) + lines.append(self._format_colored("─" * self._width, Fore.CYAN)) + return "".join(lines) - async def print_summary_async(self, result: ScenarioResult) -> None: + def _render_header(self, result: ScenarioResult) -> str: """ - Print a summary of the scenario result with per-group breakdown. + Render the header with scenario name. Args: - result (ScenarioResult): The scenario result to summarize. + result (ScenarioResult): The scenario result. + + Returns: + str: The rendered header. + """ + lines: list[str] = [] + lines.append("\n") + lines.append(self._format_colored("=" * self._width, Fore.CYAN)) + header_text = f"📊 SCENARIO RESULTS: {result.scenario_identifier.name}" + lines.append(self._format_colored(header_text.center(self._width), Style.BRIGHT, Fore.CYAN)) + lines.append(self._format_colored("=" * self._width, Fore.CYAN)) + return "".join(lines) + + def _render_footer(self) -> str: """ - self._print_header(result) + Render a footer separator. - self._print_section_header("Scenario Information") - self._print_colored(f"{self._indent}📋 Scenario Details", Style.BRIGHT) - self._print_colored(f"{self._indent * 2}• Name: {result.scenario_identifier.name}", Fore.CYAN) - self._print_colored(f"{self._indent * 2}• Scenario Version: {result.scenario_identifier.version}", Fore.CYAN) - self._print_colored(f"{self._indent * 2}• PyRIT Version: {result.scenario_identifier.pyrit_version}", Fore.CYAN) + Returns: + str: The rendered footer. + """ + lines: list[str] = [] + lines.append("\n") + lines.append(self._format_colored("=" * self._width, Fore.CYAN)) + lines.append("\n") + return "".join(lines) + + def _get_rate_color(self, rate: int) -> str: + """ + Get color based on success rate. + + Args: + rate (int): Success rate percentage (0-100). + + Returns: + str: Colorama color constant. + """ + if rate >= 75: + return str(Fore.RED) + if rate >= 50: + return str(Fore.YELLOW) + if rate >= 25: + return str(Fore.CYAN) + return str(Fore.GREEN) + + async def write_async(self, result: ScenarioResult) -> None: + """ + Render and write the scenario result summary to the configured sink. + + Args: + result (ScenarioResult): The scenario result to summarize. 
+ """ + lines: list[str] = [] + + lines.append(self._render_header(result)) + + lines.append(self._render_section_header("Scenario Information")) + lines.append(self._format_colored(f"{self._indent}📋 Scenario Details", Style.BRIGHT)) + lines.append(self._format_colored( + f"{self._indent * 2}• Name: {result.scenario_identifier.name}", Fore.CYAN + )) + lines.append(self._format_colored( + f"{self._indent * 2}• Scenario Version: {result.scenario_identifier.version}", Fore.CYAN + )) + lines.append(self._format_colored( + f"{self._indent * 2}• PyRIT Version: {result.scenario_identifier.pyrit_version}", Fore.CYAN + )) if result.scenario_identifier.description: - self._print_colored(f"{self._indent * 2}• Description:", Fore.CYAN) + lines.append(self._format_colored(f"{self._indent * 2}• Description:", Fore.CYAN)) desc_indent = self._indent * 4 available_width = 120 - len(desc_indent) wrapped_lines = textwrap.wrap( result.scenario_identifier.description, width=available_width, break_long_words=False ) for line in wrapped_lines: - self._print_colored(f"{desc_indent}{line}", Fore.CYAN) + lines.append(self._format_colored(f"{desc_indent}{line}", Fore.CYAN)) - print() - self._print_colored(f"{self._indent}🎯 Target Information", Style.BRIGHT) + lines.append("\n") + lines.append(self._format_colored(f"{self._indent}🎯 Target Information", Style.BRIGHT)) target_id = result.objective_target_identifier target_type = target_id.class_name if target_id else "Unknown" target_model = target_id.params.get("model_name", "Unknown") if target_id else "Unknown" target_endpoint = target_id.params.get("endpoint", "Unknown") if target_id else "Unknown" - self._print_colored(f"{self._indent * 2}• Target Type: {target_type}", Fore.CYAN) - self._print_colored(f"{self._indent * 2}• Target Model: {target_model}", Fore.CYAN) - self._print_colored(f"{self._indent * 2}• Target Endpoint: {target_endpoint}", Fore.CYAN) + lines.append(self._format_colored(f"{self._indent * 2}• Target Type: {target_type}", Fore.CYAN)) + lines.append(self._format_colored(f"{self._indent * 2}• Target Model: {target_model}", Fore.CYAN)) + lines.append(self._format_colored(f"{self._indent * 2}• Target Endpoint: {target_endpoint}", Fore.CYAN)) + + # Write what we have so far, then let the scorer printer write its own section + await self._write_async("".join(lines)) scorer_identifier = result.objective_scorer_identifier if scorer_identifier: - self.scorer_printer.print_objective_scorer(scorer_identifier=scorer_identifier) + await self._scorer_printer.write_async(scorer_identifier=scorer_identifier) - self._print_section_header("Overall Statistics") + # Continue with stats + lines = [] + lines.append(self._render_section_header("Overall Statistics")) total_results = sum(len(results) for results in result.attack_results.values()) total_strategies = len(result.get_strategies_used()) overall_rate = result.objective_achieved_rate() - self._print_colored(f"{self._indent}📈 Summary", Style.BRIGHT) - self._print_colored(f"{self._indent * 2}• Total Strategies: {total_strategies}", Fore.GREEN) - self._print_colored(f"{self._indent * 2}• Total Attack Results: {total_results}", Fore.GREEN) - self._print_colored( + lines.append(self._format_colored(f"{self._indent}📈 Summary", Style.BRIGHT)) + lines.append(self._format_colored(f"{self._indent * 2}• Total Strategies: {total_strategies}", Fore.GREEN)) + lines.append(self._format_colored(f"{self._indent * 2}• Total Attack Results: {total_results}", Fore.GREEN)) + lines.append(self._format_colored( f"{self._indent * 2}• 
Overall Success Rate: {overall_rate}%", self._get_rate_color(overall_rate) - ) + )) objectives = result.get_objectives() - self._print_colored(f"{self._indent * 2}• Unique Objectives: {len(objectives)}", Fore.GREEN) + lines.append(self._format_colored(f"{self._indent * 2}• Unique Objectives: {len(objectives)}", Fore.GREEN)) - self._print_section_header("Per-Group Breakdown") + lines.append(self._render_section_header("Per-Group Breakdown")) display_groups = result.get_display_groups() for group_name, group_results in display_groups.items(): @@ -139,49 +202,26 @@ async def print_summary_async(self, result: ScenarioResult) -> None: successful = sum(1 for r in group_results if r.outcome == AttackOutcome.SUCCESS) group_rate = int((successful / total_group) * 100) - print() - self._print_colored(f"{self._indent}🔸 Group: {group_name}", Style.BRIGHT) - self._print_colored(f"{self._indent * 2}• Number of Results: {total_group}", Fore.YELLOW) - self._print_colored(f"{self._indent * 2}• Success Rate: {group_rate}%", self._get_rate_color(group_rate)) - - self._print_footer() - - def _print_header(self, result: ScenarioResult) -> None: - """ - Print the header with scenario name. - - Args: - result (ScenarioResult): The scenario result. - """ - print() - self._print_colored("=" * self._width, Fore.CYAN) - header_text = f"📊 SCENARIO RESULTS: {result.scenario_identifier.name}" - self._print_colored(header_text.center(self._width), Style.BRIGHT, Fore.CYAN) - self._print_colored("=" * self._width, Fore.CYAN) + lines.append("\n") + lines.append(self._format_colored(f"{self._indent}🔸 Group: {group_name}", Style.BRIGHT)) + lines.append(self._format_colored( + f"{self._indent * 2}• Number of Results: {total_group}", Fore.YELLOW + )) + lines.append(self._format_colored( + f"{self._indent * 2}• Success Rate: {group_rate}%", self._get_rate_color(group_rate) + )) - def _print_footer(self) -> None: - """Print a footer separator.""" - print() - self._print_colored("=" * self._width, Fore.CYAN) - print() + lines.append(self._render_footer()) + await self._write_async("".join(lines)) - def _get_rate_color(self, rate: int) -> str: + async def print_summary_async(self, result: ScenarioResult) -> None: """ - Get color based on success rate. + Deprecated. Use write_async instead. Args: - rate (int): Success rate percentage (0-100). - - Returns: - str: Colorama color constant. + result (ScenarioResult): The scenario result to summarize. """ - if rate >= 75: - return str(Fore.RED) - if rate >= 50: - return str(Fore.YELLOW) - if rate >= 25: - return str(Fore.CYAN) - return str(Fore.GREEN) + await self.write_async(result) class PrettyScenarioResultMemoryPrinter(PrettyScenarioResultPrinter): @@ -215,8 +255,3 @@ def __init__( self._scorer_printer = PrettyScorerMemoryPrinter( sink=self._sink, indent_size=indent_size, enable_colors=enable_colors ) - - @property - def scorer_printer(self) -> ScorerPrinterBase: - """Return the scorer printer instance.""" - return self._scorer_printer diff --git a/pyrit/printer/scorer/base.py b/pyrit/printer/scorer/base.py index 418c4b233a..88e3fd4c15 100644 --- a/pyrit/printer/scorer/base.py +++ b/pyrit/printer/scorer/base.py @@ -12,9 +12,8 @@ class ScorerPrinterBase(PrinterBase): """ Abstract base class for printing scorer information. - Subclasses must implement get_objective_metrics and get_harm_metrics - for data fetching. Orchestration methods (print_objective_scorer, - print_harm_scorer) live in concrete formatting subclasses. 
+ Subclasses must implement _get_objective_metrics and _get_harm_metrics + for data fetching, and write_async for rendering + writing. """ @abstractmethod @@ -43,20 +42,35 @@ def _get_harm_metrics(self, *, eval_hash: str, harm_category: str) -> Any: """ @abstractmethod - def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: + async def write_async( + self, *, scorer_identifier: ComponentIdentifier, harm_category: str | None = None + ) -> None: """ - Print objective scorer information including type, nested scorers, and evaluation metrics. + Render and write scorer information to the configured sink. + + Auto-detects scorer type: if harm_category is provided, renders harm + metrics; otherwise renders objective metrics. Args: - scorer_identifier (ComponentIdentifier): The scorer identifier to print information for. + scorer_identifier (ComponentIdentifier): The scorer identifier. + harm_category (str | None): The harm category. None for objective scorers. """ - @abstractmethod - def print_harm_scorer(self, *, scorer_identifier: ComponentIdentifier, harm_category: str) -> None: + async def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: + """ + Deprecated. Use write_async instead. + + Args: + scorer_identifier (ComponentIdentifier): The scorer identifier. + """ + await self.write_async(scorer_identifier=scorer_identifier) + + async def print_harm_scorer(self, *, scorer_identifier: ComponentIdentifier, harm_category: str) -> None: """ - Print harm scorer information including type, nested scorers, and evaluation metrics. + Deprecated. Use write_async instead. Args: - scorer_identifier (ComponentIdentifier): The scorer identifier to print information for. - harm_category (str): The harm category for looking up metrics. + scorer_identifier (ComponentIdentifier): The scorer identifier. + harm_category (str): The harm category. """ + await self.write_async(scorer_identifier=scorer_identifier, harm_category=harm_category) diff --git a/pyrit/printer/scorer/pretty.py b/pyrit/printer/scorer/pretty.py index 2e3c188519..ef08d44513 100644 --- a/pyrit/printer/scorer/pretty.py +++ b/pyrit/printer/scorer/pretty.py @@ -14,8 +14,8 @@ class PrettyScorerPrinter(ScorerPrinterBase): """ Pretty printer for scorer information with ANSI-colored formatting. - Contains all formatting logic. Subclasses implement get_objective_metrics - and get_harm_metrics for data fetching. + Contains all formatting logic. Subclasses implement _get_objective_metrics + and _get_harm_metrics for data fetching. """ _SCORER_DISPLAY_PARAMS = frozenset({"scorer_type", "score_aggregator"}) @@ -39,19 +39,21 @@ def __init__(self, *, sink: Sink | None = None, indent_size: int = 2, enable_col self._indent = " " * indent_size self._enable_colors = enable_colors - def _print_colored(self, text: str, *colors: str) -> None: + def _format_colored(self, text: str, *colors: str) -> str: """ - Print text with color formatting if colors are enabled. + Format text with color codes if colors are enabled. Args: - text (str): The text to print. + text (str): The text to format. *colors: Variable number of colorama color constants to apply. + + Returns: + str: The formatted line with trailing newline. 
""" if self._enable_colors and colors: color_prefix = "".join(colors) - print(f"{color_prefix}{text}{Style.RESET_ALL}") - else: - print(text) + return f"{color_prefix}{text}{Style.RESET_ALL}\n" + return f"{text}\n" def _get_quality_color( self, value: float, *, higher_is_better: bool, good_threshold: float, bad_threshold: float @@ -80,178 +82,198 @@ def _get_quality_color( return str(Fore.RED) return str(Fore.CYAN) - def _print_scorer_info(self, scorer_identifier: ComponentIdentifier, *, indent_level: int = 2) -> None: + def _render_scorer_info(self, scorer_identifier: ComponentIdentifier, *, indent_level: int = 2) -> str: """ - Print scorer information including nested sub-scorers. + Render scorer information including nested sub-scorers. Args: scorer_identifier (ComponentIdentifier): The scorer identifier. indent_level (int): Current indentation level. + + Returns: + str: The rendered scorer info text. """ + lines: list[str] = [] indent = self._indent * indent_level - self._print_colored(f"{indent}• Scorer Type: {scorer_identifier.class_name}", Fore.CYAN) + lines.append(self._format_colored(f"{indent}• Scorer Type: {scorer_identifier.class_name}", Fore.CYAN)) for key, value in scorer_identifier.params.items(): if key in self._SCORER_DISPLAY_PARAMS and value is not None: - self._print_colored(f"{indent}• {key}: {value}", Fore.CYAN) + lines.append(self._format_colored(f"{indent}• {key}: {value}", Fore.CYAN)) prompt_target = scorer_identifier.get_child("prompt_target") if prompt_target: for key, value in prompt_target.params.items(): if key in self._TARGET_DISPLAY_PARAMS and value is not None: - self._print_colored(f"{indent}• {key}: {value}", Fore.CYAN) + lines.append(self._format_colored(f"{indent}• {key}: {value}", Fore.CYAN)) sub_scorers = scorer_identifier.get_child_list("sub_scorers") if sub_scorers: - self._print_colored(f"{indent} └─ Composite of {len(sub_scorers)} scorer(s):", Fore.CYAN) + lines.append(self._format_colored(f"{indent} └─ Composite of {len(sub_scorers)} scorer(s):", Fore.CYAN)) for sub_scorer_id in sub_scorers: - self._print_scorer_info(sub_scorer_id, indent_level=indent_level + 3) + lines.append(self._render_scorer_info(sub_scorer_id, indent_level=indent_level + 3)) + + return "".join(lines) - def _print_objective_metrics(self, metrics: Optional[Any]) -> None: + def _render_objective_metrics(self, metrics: Optional[Any]) -> str: """ - Print objective scorer evaluation metrics. + Render objective scorer evaluation metrics. Args: - metrics: The metrics to print, or None if not available. + metrics: The metrics to render, or None if not available. + + Returns: + str: The rendered metrics text. 
""" + lines: list[str] = [] + if metrics is None: - print() - self._print_colored(f"{self._indent * 2}▸ Performance Metrics", Fore.WHITE) - self._print_colored( + lines.append("\n") + lines.append(self._format_colored(f"{self._indent * 2}▸ Performance Metrics", Fore.WHITE)) + lines.append(self._format_colored( f"{self._indent * 3}Official evaluation has not been run yet for this specific configuration", Fore.YELLOW, - ) - return + )) + return "".join(lines) - print() - self._print_colored(f"{self._indent * 2}▸ Performance Metrics", Fore.WHITE) + lines.append("\n") + lines.append(self._format_colored(f"{self._indent * 2}▸ Performance Metrics", Fore.WHITE)) accuracy_color = self._get_quality_color( metrics.accuracy, higher_is_better=True, good_threshold=0.9, bad_threshold=0.7 ) - self._print_colored(f"{self._indent * 3}• Accuracy: {metrics.accuracy:.2%}", accuracy_color) + lines.append(self._format_colored(f"{self._indent * 3}• Accuracy: {metrics.accuracy:.2%}", accuracy_color)) if metrics.accuracy_standard_error is not None: - self._print_colored( + lines.append(self._format_colored( f"{self._indent * 3}• Accuracy Std Error: ±{metrics.accuracy_standard_error:.4f}", Fore.CYAN - ) + )) if metrics.f1_score is not None: f1_color = self._get_quality_color( metrics.f1_score, higher_is_better=True, good_threshold=0.9, bad_threshold=0.7 ) - self._print_colored(f"{self._indent * 3}• F1 Score: {metrics.f1_score:.4f}", f1_color) + lines.append(self._format_colored(f"{self._indent * 3}• F1 Score: {metrics.f1_score:.4f}", f1_color)) if metrics.precision is not None: precision_color = self._get_quality_color( metrics.precision, higher_is_better=True, good_threshold=0.9, bad_threshold=0.7 ) - self._print_colored(f"{self._indent * 3}• Precision: {metrics.precision:.4f}", precision_color) + lines.append(self._format_colored( + f"{self._indent * 3}• Precision: {metrics.precision:.4f}", precision_color + )) if metrics.recall is not None: recall_color = self._get_quality_color( metrics.recall, higher_is_better=True, good_threshold=0.9, bad_threshold=0.7 ) - self._print_colored(f"{self._indent * 3}• Recall: {metrics.recall:.4f}", recall_color) + lines.append(self._format_colored(f"{self._indent * 3}• Recall: {metrics.recall:.4f}", recall_color)) if metrics.average_score_time_seconds is not None: time_color = self._get_quality_color( metrics.average_score_time_seconds, higher_is_better=False, good_threshold=0.5, bad_threshold=3.0 ) - self._print_colored( + lines.append(self._format_colored( f"{self._indent * 3}• Average Score Time: {metrics.average_score_time_seconds:.2f}s", time_color - ) + )) - def _print_harm_metrics(self, metrics: Optional[Any]) -> None: + return "".join(lines) + + def _render_harm_metrics(self, metrics: Optional[Any]) -> str: """ - Print harm scorer evaluation metrics. + Render harm scorer evaluation metrics. Args: - metrics: The metrics to print, or None if not available. + metrics: The metrics to render, or None if not available. + + Returns: + str: The rendered metrics text. 
""" + lines: list[str] = [] + if metrics is None: - print() - self._print_colored(f"{self._indent * 2}▸ Performance Metrics", Fore.WHITE) - self._print_colored( + lines.append("\n") + lines.append(self._format_colored(f"{self._indent * 2}▸ Performance Metrics", Fore.WHITE)) + lines.append(self._format_colored( f"{self._indent * 3}Official evaluation has not been run yet for this specific configuration", Fore.YELLOW, - ) - return + )) + return "".join(lines) - print() - self._print_colored(f"{self._indent * 2}▸ Performance Metrics", Fore.WHITE) + lines.append("\n") + lines.append(self._format_colored(f"{self._indent * 2}▸ Performance Metrics", Fore.WHITE)) mae_color = self._get_quality_color( metrics.mean_absolute_error, higher_is_better=False, good_threshold=0.1, bad_threshold=0.25 ) - self._print_colored(f"{self._indent * 3}• Mean Absolute Error: {metrics.mean_absolute_error:.4f}", mae_color) + lines.append(self._format_colored( + f"{self._indent * 3}• Mean Absolute Error: {metrics.mean_absolute_error:.4f}", mae_color + )) if metrics.mae_standard_error is not None: - self._print_colored(f"{self._indent * 3}• MAE Std Error: ±{metrics.mae_standard_error:.4f}", Fore.CYAN) + lines.append(self._format_colored( + f"{self._indent * 3}• MAE Std Error: ±{metrics.mae_standard_error:.4f}", Fore.CYAN + )) if metrics.krippendorff_alpha_combined is not None: alpha_color = self._get_quality_color( metrics.krippendorff_alpha_combined, higher_is_better=True, good_threshold=0.8, bad_threshold=0.6 ) - self._print_colored( + lines.append(self._format_colored( f"{self._indent * 3}• Krippendorff Alpha (Combined): {metrics.krippendorff_alpha_combined:.4f}", alpha_color, - ) + )) if metrics.krippendorff_alpha_model is not None: alpha_model_color = self._get_quality_color( metrics.krippendorff_alpha_model, higher_is_better=True, good_threshold=0.8, bad_threshold=0.6 ) - self._print_colored( + lines.append(self._format_colored( f"{self._indent * 3}• Krippendorff Alpha (Model): {metrics.krippendorff_alpha_model:.4f}", alpha_model_color, - ) + )) if metrics.average_score_time_seconds is not None: time_color = self._get_quality_color( metrics.average_score_time_seconds, higher_is_better=False, good_threshold=1.0, bad_threshold=3.0 ) - self._print_colored( + lines.append(self._format_colored( f"{self._indent * 3}• Average Score Time: {metrics.average_score_time_seconds:.2f}s", time_color - ) + )) - def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: - """ - Print objective scorer information including type, nested scorers, and evaluation metrics. + return "".join(lines) - Args: - scorer_identifier (ComponentIdentifier): The scorer identifier to print information for. + async def write_async( + self, *, scorer_identifier: ComponentIdentifier, harm_category: str | None = None + ) -> None: """ - from pyrit.identifiers.evaluation_identifier import ScorerEvaluationIdentifier + Render and write scorer information to the configured sink. 
- print() - self._print_colored(f"{self._indent}📊 Scorer Information", Style.BRIGHT) - self._print_colored(f"{self._indent * 2}▸ Scorer Identifier", Fore.WHITE) - self._print_scorer_info(scorer_identifier, indent_level=3) - - eval_hash = ScorerEvaluationIdentifier(scorer_identifier).eval_hash - metrics = self._get_objective_metrics(eval_hash=eval_hash) - self._print_objective_metrics(metrics) - - def print_harm_scorer(self, *, scorer_identifier: ComponentIdentifier, harm_category: str) -> None: - """ - Print harm scorer information including type, nested scorers, and evaluation metrics. + Auto-detects scorer type: if harm_category is provided, renders harm + metrics; otherwise renders objective metrics. Args: - scorer_identifier (ComponentIdentifier): The scorer identifier to print information for. - harm_category (str): The harm category for looking up metrics. + scorer_identifier (ComponentIdentifier): The scorer identifier. + harm_category (str | None): The harm category. None for objective scorers. """ from pyrit.identifiers.evaluation_identifier import ScorerEvaluationIdentifier - print() - self._print_colored(f"{self._indent}📊 Scorer Information", Style.BRIGHT) - self._print_colored(f"{self._indent * 2}▸ Scorer Identifier", Fore.WHITE) - self._print_scorer_info(scorer_identifier, indent_level=3) + lines: list[str] = [] + lines.append("\n") + lines.append(self._format_colored(f"{self._indent}📊 Scorer Information", Style.BRIGHT)) + lines.append(self._format_colored(f"{self._indent * 2}▸ Scorer Identifier", Fore.WHITE)) + lines.append(self._render_scorer_info(scorer_identifier, indent_level=3)) eval_hash = ScorerEvaluationIdentifier(scorer_identifier).eval_hash - metrics = self._get_harm_metrics(eval_hash=eval_hash, harm_category=harm_category) - self._print_harm_metrics(metrics) + if harm_category is not None: + metrics = self._get_harm_metrics(eval_hash=eval_hash, harm_category=harm_category) + lines.append(self._render_harm_metrics(metrics)) + else: + metrics = self._get_objective_metrics(eval_hash=eval_hash) + lines.append(self._render_objective_metrics(metrics)) + + await self._write_async("".join(lines)) class PrettyScorerMemoryPrinter(PrettyScorerPrinter): diff --git a/pyrit/printer/sink.py b/pyrit/printer/sink.py index 993c237f74..aca9c07f9b 100644 --- a/pyrit/printer/sink.py +++ b/pyrit/printer/sink.py @@ -14,69 +14,60 @@ class Sink(ABC): """ @abstractmethod - async def write_async(self, data: bytes) -> None: + async def write_async(self, data: str) -> None: """ Write rendered output data. Args: - data (bytes): The rendered output to write. + data (str): The rendered text output to write. """ class StdoutSink(Sink): """ - Sink that decodes bytes to str and prints to stdout. + Sink that prints text to stdout. This is the default sink used when no sink is specified. """ - def __init__(self, *, encoding: str = "utf-8") -> None: - """ - Initialize the stdout sink. - - Args: - encoding (str): Character encoding for decoding bytes. Defaults to "utf-8". - """ - self._encoding = encoding - - async def write_async(self, data: bytes) -> None: + async def write_async(self, data: str) -> None: """ Write data to stdout. Args: - data (bytes): The data to print, decoded using the configured encoding. + data (str): The text to print. """ - print(data.decode(self._encoding), end="") + print(data, end="") class FileSink(Sink): """ - Sink that writes bytes to a file. + Sink that writes text to a file. 
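+
+    A short usage sketch (illustrative; the file name is made up):
+
+        from pathlib import Path
+
+        sink = FileSink(path=Path("attack_report.txt"), mode="a")
+        await sink.write_async("rendered report text\n")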
""" - def __init__(self, *, path: Path, mode: str = "wb") -> None: + def __init__(self, *, path: Path, mode: str = "w") -> None: """ Initialize the file sink. Args: path (Path): The file path to write to. - mode (str): The file open mode. Defaults to "wb" (write binary, overwrite). - Use "ab" for append mode. + mode (str): The file open mode. Defaults to "w" (write, overwrite). + Use "a" for append mode. Raises: - ValueError: If mode is not a valid binary write mode. + ValueError: If mode is not a valid text write mode. """ - if mode not in ("wb", "ab"): - raise ValueError(f"mode must be 'wb' or 'ab', got '{mode}'") + if mode not in ("w", "a"): + raise ValueError(f"mode must be 'w' or 'a', got '{mode}'") self._path = path self._mode = mode - async def write_async(self, data: bytes) -> None: + async def write_async(self, data: str) -> None: """ Write data to a file. Args: - data (bytes): The data to write. + data (str): The text to write. """ - with open(self._path, self._mode) as f: + with open(self._path, self._mode, encoding="utf-8") as f: f.write(data) diff --git a/tests/unit/executor/attack/printer/test_pretty_printer.py b/tests/unit/executor/attack/printer/test_pretty_printer.py index fefcc38056..9f3f919399 100644 --- a/tests/unit/executor/attack/printer/test_pretty_printer.py +++ b/tests/unit/executor/attack/printer/test_pretty_printer.py @@ -98,17 +98,17 @@ def test_init_default_colors_enabled(mock_memory): assert p._enable_colors is True -def test_print_colored_no_colors(printer, capsys): - printer._print_colored("hello") - captured = capsys.readouterr() - assert "hello" in captured.out +def test_format_colored_no_colors(printer): + result = printer._format_colored("hello") + assert "hello" in result + assert result.endswith("\n") -def test_print_colored_with_colors_disabled(printer, capsys): +def test_format_colored_with_colors_disabled(printer): printer._enable_colors = False - printer._print_colored("test text", "SOME_COLOR") - captured = capsys.readouterr() - assert "test text" in captured.out + result = printer._format_colored("test text", "SOME_COLOR") + assert "test text" in result + assert result.endswith("\n") def test_get_outcome_color_success(printer): @@ -126,44 +126,39 @@ def test_get_outcome_color_undetermined(printer): assert isinstance(color, str) -def test_print_header(printer, sample_attack_result, capsys): - printer._print_header(sample_attack_result) - captured = capsys.readouterr() - assert "ATTACK RESULT" in captured.out - assert "SUCCESS" in captured.out +def test_render_header(printer, sample_attack_result): + result = printer._render_header(sample_attack_result) + assert "ATTACK RESULT" in result + assert "SUCCESS" in result -def test_print_footer(printer, capsys): - printer._print_footer() - captured = capsys.readouterr() - assert "Report generated at" in captured.out +def test_render_footer(printer): + result = printer._render_footer() + assert "Report generated at" in result -def test_print_section_header(printer, capsys): - printer._print_section_header("Test Section") - captured = capsys.readouterr() - assert "Test Section" in captured.out +def test_render_section_header(printer): + result = printer._render_section_header("Test Section") + assert "Test Section" in result -def test_print_metadata(printer, capsys): +def test_render_metadata(printer): metadata = {"key1": "value1", "key2": 42} - printer._print_metadata(metadata) - captured = capsys.readouterr() - assert "key1" in captured.out - assert "value1" in captured.out - assert "key2" in 
captured.out - assert "42" in captured.out + result = printer._render_metadata(metadata) + assert "key1" in result + assert "value1" in result + assert "key2" in result + assert "42" in result -def test_print_score(printer, sample_score, capsys): - printer._print_score(sample_score) - captured = capsys.readouterr() - assert "MockScorer" in captured.out - assert "true_false" in captured.out - assert "true" in captured.out +def test_render_score(printer, sample_score): + result = printer._render_score(sample_score) + assert "MockScorer" in result + assert "true_false" in result + assert "true" in result -def test_print_score_with_rationale(printer, capsys): +def test_render_score_with_rationale(printer): score = Score( score_type="float_scale", score_value="0.8", @@ -174,9 +169,8 @@ def test_print_score_with_rationale(printer, capsys): message_piece_id=str(uuid.uuid4()), scorer_class_identifier=_mock_scorer_id(), ) - printer._print_score(score) - captured = capsys.readouterr() - assert "Rationale" in captured.out + result = printer._render_score(score) + assert "Rationale" in result def test_extract_reasoning_summary_valid_json(printer): @@ -206,38 +200,34 @@ def test_extract_reasoning_summary_summary_not_list(printer): assert result == "" -async def test_print_conversation_async_no_conversation_id(printer, capsys): +async def test_render_conversation_async_no_conversation_id(printer): result = AttackResult(objective="test", conversation_id="") - await printer.print_conversation_async(result) - captured = capsys.readouterr() - assert "No conversation ID" in captured.out + content = await printer._render_conversation_async(result) + assert "No conversation ID" in content -async def test_print_conversation_async_no_messages(printer, mock_memory, capsys): +async def test_render_conversation_async_no_messages(printer, mock_memory): mock_memory.get_conversation.return_value = [] result = AttackResult(objective="test", conversation_id="conv-123") - await printer.print_conversation_async(result) - captured = capsys.readouterr() - assert "No conversation found" in captured.out + content = await printer._render_conversation_async(result) + assert "No conversation found" in content -async def test_print_messages_async_empty_list(printer, capsys): - await printer.print_messages_async(messages=[]) - captured = capsys.readouterr() - assert "No messages to display" in captured.out +async def test_render_messages_async_empty_list(printer): + content = await printer._render_messages_async(messages=[]) + assert "No messages to display" in content @patch("pyrit.common.display_response.display_image_response", new_callable=AsyncMock) -async def test_print_messages_async_user_message(mock_display, printer, sample_message, capsys): - await printer.print_messages_async(messages=[sample_message]) - captured = capsys.readouterr() - assert "Turn 1" in captured.out - assert "USER" in captured.out - assert "Hello world" in captured.out +async def test_render_messages_async_user_message(mock_display, printer, sample_message): + content = await printer._render_messages_async(messages=[sample_message]) + assert "Turn 1" in content + assert "USER" in content + assert "Hello world" in content @patch("pyrit.common.display_response.display_image_response", new_callable=AsyncMock) -async def test_print_messages_async_assistant_message(mock_display, printer, capsys): +async def test_render_messages_async_assistant_message(mock_display, printer): piece = MessagePiece( role="assistant", original_value="Response", @@ -245,13 
+235,12 @@ async def test_print_messages_async_assistant_message(mock_display, printer, cap converted_value_data_type="text", ) msg = Message(message_pieces=[piece]) - await printer.print_messages_async(messages=[msg]) - captured = capsys.readouterr() - assert "Response" in captured.out + content = await printer._render_messages_async(messages=[msg]) + assert "Response" in content @patch("pyrit.common.display_response.display_image_response", new_callable=AsyncMock) -async def test_print_messages_async_converted_differs(mock_display, printer, capsys): +async def test_render_messages_async_converted_differs(mock_display, printer): piece = MessagePiece( role="user", original_value="Original", @@ -259,31 +248,29 @@ async def test_print_messages_async_converted_differs(mock_display, printer, cap converted_value_data_type="text", ) msg = Message(message_pieces=[piece]) - await printer.print_messages_async(messages=[msg]) - captured = capsys.readouterr() - assert "Original" in captured.out - assert "Converted" in captured.out + content = await printer._render_messages_async(messages=[msg]) + assert "Original" in content + assert "Converted" in content -async def test_print_summary_async(printer, sample_attack_result, capsys): - await printer.print_summary_async(sample_attack_result) - captured = capsys.readouterr() - assert "Test objective" in captured.out - assert "TestAttack" in captured.out - assert "test-conv-123" in captured.out - assert "SUCCESS" in captured.out - assert "Test successful" in captured.out +async def test_render_summary_async(printer, sample_attack_result): + content = await printer._render_summary_async(sample_attack_result) + assert "Test objective" in content + assert "TestAttack" in content + assert "test-conv-123" in content + assert "SUCCESS" in content + assert "Test successful" in content -async def test_print_result_async_basic(printer, sample_attack_result, mock_memory, capsys): +async def test_write_async_basic(printer, sample_attack_result, mock_memory, capsys): mock_memory.get_conversation.return_value = [] - await printer.print_result_async(sample_attack_result) + await printer.write_async(sample_attack_result) captured = capsys.readouterr() assert "ATTACK RESULT" in captured.out assert "Report generated at" in captured.out -async def test_print_result_async_with_metadata(printer, mock_memory, capsys): +async def test_write_async_with_metadata(printer, mock_memory, capsys): result = AttackResult( objective="test", conversation_id="conv-1", @@ -291,20 +278,19 @@ async def test_print_result_async_with_metadata(printer, mock_memory, capsys): metadata={"note": "extra info"}, ) mock_memory.get_conversation.return_value = [] - await printer.print_result_async(result) + await printer.write_async(result) captured = capsys.readouterr() assert "note" in captured.out assert "extra info" in captured.out -async def test_print_pruned_conversations_no_pruned(printer, capsys): +async def test_render_pruned_conversations_no_pruned(printer): result = AttackResult(objective="test", conversation_id="conv-1") - await printer._print_pruned_conversations_async(result) - captured = capsys.readouterr() - assert captured.out == "" + content = await printer._render_pruned_conversations_async(result) + assert content == "" -async def test_print_pruned_conversations_with_messages(printer, mock_memory, capsys): +async def test_render_pruned_conversations_with_messages(printer, mock_memory): piece = MessagePiece( role="assistant", original_value="Pruned response", @@ -320,35 +306,31 @@ async 
def test_print_pruned_conversations_with_messages(printer, mock_memory, ca conversation_id="conv-1", related_conversations={ref}, ) - await printer._print_pruned_conversations_async(result) - captured = capsys.readouterr() - assert "PRUNED" in captured.out - assert "Pruned response" in captured.out + content = await printer._render_pruned_conversations_async(result) + assert "PRUNED" in content + assert "Pruned response" in content -async def test_print_adversarial_conversation_no_refs(printer, capsys): +async def test_render_adversarial_conversation_no_refs(printer): result = AttackResult(objective="test", conversation_id="conv-1") - await printer._print_adversarial_conversation_async(result) - captured = capsys.readouterr() - assert captured.out == "" + content = await printer._render_adversarial_conversation_async(result) + assert content == "" -def test_print_wrapped_text(printer, capsys): - printer._print_wrapped_text("Short text", "") - captured = capsys.readouterr() - assert "Short text" in captured.out +def test_render_wrapped_text(printer): + result = printer._render_wrapped_text("Short text", "") + assert "Short text" in result -def test_print_wrapped_text_with_newlines(printer, capsys): - printer._print_wrapped_text("Line one\nLine two\n\nLine four", "") - captured = capsys.readouterr() - assert "Line one" in captured.out - assert "Line two" in captured.out - assert "Line four" in captured.out +def test_render_wrapped_text_with_newlines(printer): + result = printer._render_wrapped_text("Line one\nLine two\n\nLine four", "") + assert "Line one" in result + assert "Line two" in result + assert "Line four" in result @patch("pyrit.common.display_response.display_image_response", new_callable=AsyncMock) -async def test_print_messages_async_blocked_without_partial_content(mock_display, printer, capsys): +async def test_render_messages_async_blocked_without_partial_content(mock_display, printer): piece = MessagePiece( role="assistant", original_value='{"status_code": 200, "message": "content_filter"}', @@ -356,16 +338,15 @@ async def test_print_messages_async_blocked_without_partial_content(mock_display response_error="blocked", ) msg = Message(message_pieces=[piece]) - await printer.print_messages_async(messages=[msg]) - captured = capsys.readouterr() - assert "BLOCKED BY TARGET" in captured.out - assert "content filter" in captured.out + content = await printer._render_messages_async(messages=[msg]) + assert "BLOCKED BY TARGET" in content + assert "content filter" in content # Should NOT print the raw error JSON as the message body - assert "status_code" not in captured.out + assert "status_code" not in content @patch("pyrit.common.display_response.display_image_response", new_callable=AsyncMock) -async def test_print_messages_async_blocked_with_partial_content(mock_display, printer, capsys): +async def test_render_messages_async_blocked_with_partial_content(mock_display, printer): piece = MessagePiece( role="assistant", original_value='{"status_code": 200, "message": "content_filter"}', @@ -374,9 +355,8 @@ async def test_print_messages_async_blocked_with_partial_content(mock_display, p prompt_metadata={"partial_content": "The model started to say something before being cut off"}, ) msg = Message(message_pieces=[piece]) - await printer.print_messages_async(messages=[msg]) - captured = capsys.readouterr() - assert "BLOCKED BY TARGET" in captured.out - assert "Partial content" in captured.out - assert "before filter triggered" in captured.out - assert "The model started to say 
something before being cut off" in captured.out + content = await printer._render_messages_async(messages=[msg]) + assert "BLOCKED BY TARGET" in content + assert "Partial content" in content + assert "before filter triggered" in content + assert "The model started to say something before being cut off" in content diff --git a/tests/unit/printer/test_convenience.py b/tests/unit/printer/test_convenience.py new file mode 100644 index 0000000000..3fe8dc27bd --- /dev/null +++ b/tests/unit/printer/test_convenience.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import tempfile +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.printer import ( + OutputFormat, + _resolve_sink, + print_attack_result_async, + print_scenario_result_async, + print_scorer_async, +) +from pyrit.printer.sink import FileSink, StdoutSink + + +# --- _resolve_sink tests --- + + +def test_resolve_sink_none_returns_stdout(): + sink = _resolve_sink(None) + assert isinstance(sink, StdoutSink) + + +def test_resolve_sink_path_returns_file_sink(): + sink = _resolve_sink(Path("output.txt")) + assert isinstance(sink, FileSink) + + +def test_resolve_sink_str_returns_file_sink(): + sink = _resolve_sink("output.txt") + assert isinstance(sink, FileSink) + + +def test_resolve_sink_sink_instance_passthrough(): + original = StdoutSink() + sink = _resolve_sink(original) + assert sink is original + + +# --- print_attack_result_async tests --- + + +@patch("pyrit.printer.attack_result.pretty.PrettyAttackResultMemoryPrinter") +async def test_print_attack_result_async_pretty_default(mock_cls): + mock_printer = MagicMock() + mock_printer.write_async = AsyncMock() + mock_cls.return_value = mock_printer + result = MagicMock() + + await print_attack_result_async(result) + + mock_cls.assert_called_once() + mock_printer.write_async.assert_called_once() + + +@patch("pyrit.printer.attack_result.markdown.MarkdownAttackResultMemoryPrinter") +async def test_print_attack_result_async_markdown(mock_cls): + mock_printer = MagicMock() + mock_printer.write_async = AsyncMock() + mock_cls.return_value = mock_printer + result = MagicMock() + + await print_attack_result_async(result, format="markdown") + + mock_cls.assert_called_once() + mock_printer.write_async.assert_called_once() + + +@patch("pyrit.printer.attack_result.pretty.PrettyAttackResultMemoryPrinter") +async def test_print_attack_result_async_to_file(mock_cls): + mock_printer = MagicMock() + mock_printer.write_async = AsyncMock() + mock_cls.return_value = mock_printer + result = MagicMock() + + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f: + path = Path(f.name) + + try: + await print_attack_result_async(result, to=path) + call_kwargs = mock_cls.call_args[1] + assert isinstance(call_kwargs["sink"], FileSink) + finally: + path.unlink(missing_ok=True) + + +# --- print_scenario_result_async tests --- + + +@patch("pyrit.printer.scenario_result.pretty.PrettyScenarioResultMemoryPrinter") +async def test_print_scenario_result_async_pretty(mock_cls): + mock_printer = MagicMock() + mock_printer.write_async = AsyncMock() + mock_cls.return_value = mock_printer + result = MagicMock() + + await print_scenario_result_async(result) + + mock_cls.assert_called_once() + mock_printer.write_async.assert_called_once_with(result) + + +async def test_print_scenario_result_async_unsupported_format(): + with pytest.raises(ValueError, match="Unsupported format"): + await 
print_scenario_result_async(MagicMock(), format="markdown") + + +# --- print_scorer_async tests --- + + +@patch("pyrit.printer.scorer.pretty.PrettyScorerMemoryPrinter") +async def test_print_scorer_async_pretty(mock_cls): + mock_printer = MagicMock() + mock_printer.write_async = AsyncMock() + mock_cls.return_value = mock_printer + scorer_id = MagicMock() + + await print_scorer_async(scorer_identifier=scorer_id) + + mock_cls.assert_called_once() + mock_printer.write_async.assert_called_once_with(scorer_identifier=scorer_id, harm_category=None) + + +@patch("pyrit.printer.scorer.pretty.PrettyScorerMemoryPrinter") +async def test_print_scorer_async_with_harm_category(mock_cls): + mock_printer = MagicMock() + mock_printer.write_async = AsyncMock() + mock_cls.return_value = mock_printer + scorer_id = MagicMock() + + await print_scorer_async(scorer_identifier=scorer_id, harm_category="hate_speech") + + mock_printer.write_async.assert_called_once_with(scorer_identifier=scorer_id, harm_category="hate_speech") + + +async def test_print_scorer_async_unsupported_format(): + with pytest.raises(ValueError, match="Unsupported format"): + await print_scorer_async(scorer_identifier=MagicMock(), format="markdown") diff --git a/tests/unit/printer/test_printer_base.py b/tests/unit/printer/test_printer_base.py index 92c19bd030..6c3b6b14a6 100644 --- a/tests/unit/printer/test_printer_base.py +++ b/tests/unit/printer/test_printer_base.py @@ -7,31 +7,32 @@ from pyrit.printer.sink import StdoutSink -def test_printer_base_has_no_abstract_methods(): - # PrinterBase is abstract via ABC but has no abstract methods of its own. - # Subclasses add their own abstract methods for data fetching. - class ConcretePrinter(PrinterBase): +def test_printer_base_is_abstract(): + class IncompletePrinter(PrinterBase): pass - printer = ConcretePrinter() - assert isinstance(printer, PrinterBase) + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + IncompletePrinter() # type: ignore[abstract] def test_printer_base_defaults_to_stdout_sink(): class ConcretePrinter(PrinterBase): - pass + async def write_async(self) -> None: + pass printer = ConcretePrinter() assert isinstance(printer._sink, StdoutSink) def test_printer_base_accepts_custom_sink(): - from pyrit.printer.sink import FileSink from pathlib import Path + from pyrit.printer.sink import FileSink + class ConcretePrinter(PrinterBase): - pass + async def write_async(self) -> None: + pass sink = FileSink(path=Path("test.txt")) printer = ConcretePrinter(sink=sink) @@ -41,9 +42,10 @@ class ConcretePrinter(PrinterBase): async def test_printer_base_write_async_delegates_to_sink(capsys): class ConcretePrinter(PrinterBase): - pass + async def write_async(self) -> None: + await self._write_async("test output") printer = ConcretePrinter() - await printer._write_async(b"test output") + await printer.write_async() captured = capsys.readouterr() assert captured.out == "test output" diff --git a/tests/unit/printer/test_sink.py b/tests/unit/printer/test_sink.py index ca5b9df226..17233e67e5 100644 --- a/tests/unit/printer/test_sink.py +++ b/tests/unit/printer/test_sink.py @@ -16,34 +16,27 @@ def test_sink_is_abstract(): async def test_stdout_sink_writes_to_stdout(capsys): sink = StdoutSink() - await sink.write_async(b"hello world") + await sink.write_async("hello world") captured = capsys.readouterr() assert captured.out == "hello world" async def test_stdout_sink_no_trailing_newline(capsys): sink = StdoutSink() - await sink.write_async(b"line1") - await 
sink.write_async(b"line2") + await sink.write_async("line1") + await sink.write_async("line2") captured = capsys.readouterr() assert captured.out == "line1line2" -async def test_stdout_sink_custom_encoding(capsys): - sink = StdoutSink(encoding="ascii") - await sink.write_async(b"ascii text") - captured = capsys.readouterr() - assert captured.out == "ascii text" - - async def test_file_sink_writes_to_file(): with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f: path = Path(f.name) try: - sink = FileSink(path=path, mode="wb") - await sink.write_async(b"hello file") - assert path.read_bytes() == b"hello file" + sink = FileSink(path=path, mode="w") + await sink.write_async("hello file") + assert path.read_text(encoding="utf-8") == "hello file" finally: path.unlink(missing_ok=True) @@ -53,17 +46,17 @@ async def test_file_sink_append_mode(): path = Path(f.name) try: - sink = FileSink(path=path, mode="wb") - await sink.write_async(b"first") + sink = FileSink(path=path, mode="w") + await sink.write_async("first") - append_sink = FileSink(path=path, mode="ab") - await append_sink.write_async(b" second") + append_sink = FileSink(path=path, mode="a") + await append_sink.write_async(" second") - assert path.read_bytes() == b"first second" + assert path.read_text(encoding="utf-8") == "first second" finally: path.unlink(missing_ok=True) def test_file_sink_rejects_invalid_mode(): - with pytest.raises(ValueError, match="mode must be 'wb' or 'ab'"): - FileSink(path=Path("test.txt"), mode="w") + with pytest.raises(ValueError, match="mode must be 'w' or 'a'"): + FileSink(path=Path("test.txt"), mode="wb") diff --git a/tests/unit/score/test_pretty_scorer_printer.py b/tests/unit/score/test_pretty_scorer_printer.py index 90a04339d9..80c81364c9 100644 --- a/tests/unit/score/test_pretty_scorer_printer.py +++ b/tests/unit/score/test_pretty_scorer_printer.py @@ -88,30 +88,27 @@ def test_init_colors_disabled(): assert printer._enable_colors is False -# --- _print_colored tests --- +# --- _format_colored tests --- -def test_print_colored_with_colors_enabled(capsys): +def test_format_colored_with_colors_enabled(): printer = ConsoleScorerPrinter(enable_colors=True) - printer._print_colored("hello", Fore.GREEN) - captured = capsys.readouterr() - assert "hello" in captured.out - assert Style.RESET_ALL in captured.out + result = printer._format_colored("hello", Fore.GREEN) + assert "hello" in result + assert Style.RESET_ALL in result -def test_print_colored_with_colors_disabled(capsys): +def test_format_colored_with_colors_disabled(): printer = ConsoleScorerPrinter(enable_colors=False) - printer._print_colored("hello", Fore.GREEN) - captured = capsys.readouterr() - assert captured.out.strip() == "hello" - assert Style.RESET_ALL not in captured.out + result = printer._format_colored("hello", Fore.GREEN) + assert result.strip() == "hello" + assert Style.RESET_ALL not in result -def test_print_colored_no_colors_arg(capsys): +def test_format_colored_no_colors_arg(): printer = ConsoleScorerPrinter(enable_colors=True) - printer._print_colored("plain text") - captured = capsys.readouterr() - assert captured.out.strip() == "plain text" + result = printer._format_colored("plain text") + assert result.strip() == "plain text" # --- _get_quality_color tests --- @@ -153,31 +150,29 @@ def test_quality_color_lower_is_better_middle(): assert color == Fore.CYAN -# --- _print_scorer_info tests --- +# --- _render_scorer_info tests --- -def test_print_scorer_info_basic(capsys): +def test_render_scorer_info_basic(): printer = 
ConsoleScorerPrinter(enable_colors=False) identifier = _make_scorer_identifier(class_name="SelfAskScaleScorer") - printer._print_scorer_info(identifier, indent_level=2) - output = capsys.readouterr().out + output = printer._render_scorer_info(identifier, indent_level=2) assert "SelfAskScaleScorer" in output -def test_print_scorer_info_with_display_params(capsys): +def test_render_scorer_info_with_display_params(): printer = ConsoleScorerPrinter(enable_colors=False) identifier = _make_scorer_identifier( class_name="TestScorer", params={"scorer_type": "likert", "score_aggregator": "mean", "hidden_param": "ignore"}, ) - printer._print_scorer_info(identifier, indent_level=2) - output = capsys.readouterr().out + output = printer._render_scorer_info(identifier, indent_level=2) assert "scorer_type" in output assert "score_aggregator" in output assert "hidden_param" not in output -def test_print_scorer_info_with_prompt_target_child(capsys): +def test_render_scorer_info_with_prompt_target_child(): printer = ConsoleScorerPrinter(enable_colors=False) target_id = ComponentIdentifier( class_name="OpenAIChatTarget", @@ -187,13 +182,12 @@ def test_print_scorer_info_with_prompt_target_child(capsys): identifier = _make_scorer_identifier( children={"prompt_target": target_id}, ) - printer._print_scorer_info(identifier, indent_level=2) - output = capsys.readouterr().out + output = printer._render_scorer_info(identifier, indent_level=2) assert "gpt-4" in output assert "extra" not in output -def test_print_scorer_info_with_sub_scorers(capsys): +def test_render_scorer_info_with_sub_scorers(): printer = ConsoleScorerPrinter(enable_colors=False) sub1 = _make_scorer_identifier(class_name="SubScorer1") sub2 = _make_scorer_identifier(class_name="SubScorer2") @@ -201,28 +195,25 @@ def test_print_scorer_info_with_sub_scorers(capsys): class_name="CompositeScorer", children={"sub_scorers": [sub1, sub2]}, ) - printer._print_scorer_info(identifier, indent_level=2) - output = capsys.readouterr().out + output = printer._render_scorer_info(identifier, indent_level=2) assert "Composite of 2 scorer(s)" in output assert "SubScorer1" in output assert "SubScorer2" in output -# --- _print_objective_metrics tests --- +# --- _render_objective_metrics tests --- -def test_print_objective_metrics_none(capsys): +def test_render_objective_metrics_none(): printer = ConsoleScorerPrinter(enable_colors=False) - printer._print_objective_metrics(None) - output = capsys.readouterr().out + output = printer._render_objective_metrics(None) assert "Official evaluation has not been run yet" in output -def test_print_objective_metrics_full(capsys): +def test_render_objective_metrics_full(): printer = ConsoleScorerPrinter(enable_colors=False) metrics = _make_objective_metrics() - printer._print_objective_metrics(metrics) - output = capsys.readouterr().out + output = printer._render_objective_metrics(metrics) assert "Accuracy" in output assert "F1 Score" in output assert "Precision" in output @@ -230,7 +221,7 @@ def test_print_objective_metrics_full(capsys): assert "Average Score Time" in output -def test_print_objective_metrics_optional_fields_none(capsys): +def test_render_objective_metrics_optional_fields_none(): printer = ConsoleScorerPrinter(enable_colors=False) metrics = _make_objective_metrics( accuracy_standard_error=None, @@ -239,8 +230,7 @@ def test_print_objective_metrics_optional_fields_none(capsys): recall=None, average_score_time_seconds=None, ) - printer._print_objective_metrics(metrics) - output = capsys.readouterr().out + output = 
printer._render_objective_metrics(metrics) assert "Accuracy" in output assert "F1 Score" not in output assert "Precision" not in output @@ -248,28 +238,26 @@ def test_print_objective_metrics_optional_fields_none(capsys): assert "Average Score Time" not in output -# --- _print_harm_metrics tests --- +# --- _render_harm_metrics tests --- -def test_print_harm_metrics_none(capsys): +def test_render_harm_metrics_none(): printer = ConsoleScorerPrinter(enable_colors=False) - printer._print_harm_metrics(None) - output = capsys.readouterr().out + output = printer._render_harm_metrics(None) assert "Official evaluation has not been run yet" in output -def test_print_harm_metrics_full(capsys): +def test_render_harm_metrics_full(): printer = ConsoleScorerPrinter(enable_colors=False) metrics = _make_harm_metrics() - printer._print_harm_metrics(metrics) - output = capsys.readouterr().out + output = printer._render_harm_metrics(metrics) assert "Mean Absolute Error" in output assert "Krippendorff Alpha (Combined)" in output assert "Krippendorff Alpha (Model)" in output assert "Average Score Time" in output -def test_print_harm_metrics_optional_fields_none(capsys): +def test_render_harm_metrics_optional_fields_none(): printer = ConsoleScorerPrinter(enable_colors=False) metrics = _make_harm_metrics( mae_standard_error=None, @@ -277,8 +265,7 @@ def test_print_harm_metrics_optional_fields_none(capsys): krippendorff_alpha_model=None, average_score_time_seconds=None, ) - printer._print_harm_metrics(metrics) - output = capsys.readouterr().out + output = printer._render_harm_metrics(metrics) assert "Mean Absolute Error" in output assert "MAE Std Error" not in output assert "Krippendorff Alpha (Combined)" not in output @@ -286,12 +273,12 @@ def test_print_harm_metrics_optional_fields_none(capsys): assert "Average Score Time" not in output -# --- print_objective_scorer tests --- +# --- write_async (objective) tests --- @patch("pyrit.score.scorer_evaluation.scorer_metrics_io.find_objective_metrics_by_eval_hash") @patch("pyrit.identifiers.evaluation_identifier.ScorerEvaluationIdentifier") -def test_print_objective_scorer_with_metrics(mock_eval_id_cls, mock_find, capsys): +async def test_write_async_objective_with_metrics(mock_eval_id_cls, mock_find, capsys): printer = ConsoleScorerPrinter(enable_colors=False) identifier = _make_scorer_identifier(class_name="MyScorer") metrics = _make_objective_metrics() @@ -301,7 +288,7 @@ def test_print_objective_scorer_with_metrics(mock_eval_id_cls, mock_find, capsys mock_eval_id_cls.return_value = mock_eval_instance mock_find.return_value = metrics - printer.print_objective_scorer(scorer_identifier=identifier) + await printer.write_async(scorer_identifier=identifier) output = capsys.readouterr().out assert "Scorer Information" in output @@ -312,7 +299,7 @@ def test_print_objective_scorer_with_metrics(mock_eval_id_cls, mock_find, capsys @patch("pyrit.score.scorer_evaluation.scorer_metrics_io.find_objective_metrics_by_eval_hash") @patch("pyrit.identifiers.evaluation_identifier.ScorerEvaluationIdentifier") -def test_print_objective_scorer_no_metrics(mock_eval_id_cls, mock_find, capsys): +async def test_write_async_objective_no_metrics(mock_eval_id_cls, mock_find, capsys): printer = ConsoleScorerPrinter(enable_colors=False) identifier = _make_scorer_identifier() @@ -321,17 +308,17 @@ def test_print_objective_scorer_no_metrics(mock_eval_id_cls, mock_find, capsys): mock_eval_id_cls.return_value = mock_eval_instance mock_find.return_value = None - 
printer.print_objective_scorer(scorer_identifier=identifier) + await printer.write_async(scorer_identifier=identifier) output = capsys.readouterr().out assert "Official evaluation has not been run yet" in output -# --- print_harm_scorer tests --- +# --- write_async (harm) tests --- @patch("pyrit.score.scorer_evaluation.scorer_metrics_io.find_harm_metrics_by_eval_hash") @patch("pyrit.identifiers.evaluation_identifier.ScorerEvaluationIdentifier") -def test_print_harm_scorer_with_metrics(mock_eval_id_cls, mock_find, capsys): +async def test_write_async_harm_with_metrics(mock_eval_id_cls, mock_find, capsys): printer = ConsoleScorerPrinter(enable_colors=False) identifier = _make_scorer_identifier(class_name="HarmScorer") metrics = _make_harm_metrics() @@ -341,7 +328,7 @@ def test_print_harm_scorer_with_metrics(mock_eval_id_cls, mock_find, capsys): mock_eval_id_cls.return_value = mock_eval_instance mock_find.return_value = metrics - printer.print_harm_scorer(scorer_identifier=identifier, harm_category="hate_speech") + await printer.write_async(scorer_identifier=identifier, harm_category="hate_speech") output = capsys.readouterr().out assert "Scorer Information" in output @@ -352,7 +339,7 @@ def test_print_harm_scorer_with_metrics(mock_eval_id_cls, mock_find, capsys): @patch("pyrit.score.scorer_evaluation.scorer_metrics_io.find_harm_metrics_by_eval_hash") @patch("pyrit.identifiers.evaluation_identifier.ScorerEvaluationIdentifier") -def test_print_harm_scorer_no_metrics(mock_eval_id_cls, mock_find, capsys): +async def test_write_async_harm_no_metrics(mock_eval_id_cls, mock_find, capsys): printer = ConsoleScorerPrinter(enable_colors=False) identifier = _make_scorer_identifier() @@ -361,6 +348,6 @@ def test_print_harm_scorer_no_metrics(mock_eval_id_cls, mock_find, capsys): mock_eval_id_cls.return_value = mock_eval_instance mock_find.return_value = None - printer.print_harm_scorer(scorer_identifier=identifier, harm_category="violence") + await printer.write_async(scorer_identifier=identifier, harm_category="violence") output = capsys.readouterr().out assert "Official evaluation has not been run yet" in output diff --git a/tests/unit/score/test_scorer_printer.py b/tests/unit/score/test_scorer_printer.py index 3b1a639c7c..cd0a5b997d 100644 --- a/tests/unit/score/test_scorer_printer.py +++ b/tests/unit/score/test_scorer_printer.py @@ -30,6 +30,18 @@ def _get_objective_metrics(self, *, eval_hash: str): IncompletePrinter() # type: ignore[abstract] +def test_scorer_printer_subclass_must_implement_write_async(): + class IncompletePrinter(ScorerPrinter): + def _get_objective_metrics(self, *, eval_hash: str): + return None + + def _get_harm_metrics(self, *, eval_hash: str, harm_category: str): + return None + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + IncompletePrinter() # type: ignore[abstract] + + def test_scorer_printer_complete_subclass_can_be_instantiated(): class CompletePrinter(ScorerPrinter): def _get_objective_metrics(self, *, eval_hash: str): @@ -38,10 +50,9 @@ def _get_objective_metrics(self, *, eval_hash: str): def _get_harm_metrics(self, *, eval_hash: str, harm_category: str): return None - def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: - pass - - def print_harm_scorer(self, *, scorer_identifier: ComponentIdentifier, harm_category: str) -> None: + async def write_async( + self, *, scorer_identifier: ComponentIdentifier, harm_category: str | None = None + ) -> None: pass printer = CompletePrinter() From 
35f8e9d7d166065e2dcb4b60f4e74d8e98face8f Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 22:19:54 -0700 Subject: [PATCH 17/34] Move convenience functions into their domain modules Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/printer/__init__.py | 126 +--------------------- pyrit/printer/attack_result/__init__.py | 46 +++++++- pyrit/printer/scenario_result/__init__.py | 32 +++++- pyrit/printer/scorer/__init__.py | 37 ++++++- pyrit/printer/sink.py | 23 ++++ tests/unit/printer/test_convenience.py | 20 ++-- 6 files changed, 144 insertions(+), 140 deletions(-) diff --git a/pyrit/printer/__init__.py b/pyrit/printer/__init__.py index 50abdf19d3..18a03fa775 100644 --- a/pyrit/printer/__init__.py +++ b/pyrit/printer/__init__.py @@ -8,132 +8,8 @@ - **Sink** classes that define where output goes (stdout, file, etc.) - **PrinterBase** that all printers inherit from - Domain printers for attack results, scenario results, and scorer information -- **Convenience functions** for one-line printing (e.g., ``print_attack_result_async``) +- **Convenience functions** in each subpackage (e.g., ``print_attack_result_async``) File names indicate output format (pretty.py = ANSI-colored, markdown.py = Markdown). Abstract methods inside each printer determine the data source (memory, REST, fixtures). """ - -from pathlib import Path -from typing import Literal - -from pyrit.printer.sink import FileSink, Sink, StdoutSink - -OutputFormat = Literal["pretty", "markdown"] - - -def _resolve_sink(to: Path | str | Sink | None) -> Sink: - """ - Resolve a destination argument to a Sink instance. - - Args: - to (Path | str | Sink | None): The destination. - None → StdoutSink. - Path or str → FileSink. - Sink instance → used as-is. - - Returns: - Sink: The resolved sink. - """ - if to is None: - return StdoutSink() - if isinstance(to, Sink): - return to - return FileSink(path=Path(to)) - - -async def print_attack_result_async( - result: "AttackResult", # noqa: F821 - *, - format: OutputFormat = "pretty", - to: Path | str | Sink | None = None, - include_auxiliary_scores: bool = False, - include_pruned_conversations: bool = False, - include_adversarial_conversation: bool = False, -) -> None: - """ - Print an attack result in the specified format to the specified destination. - - Args: - result (AttackResult): The attack result to print. - format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". - to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. - include_auxiliary_scores (bool): Whether to include auxiliary scores. Defaults to False. - include_pruned_conversations (bool): Whether to include pruned conversations. Defaults to False. - include_adversarial_conversation (bool): Whether to include the adversarial conversation. - Defaults to False. 
- """ - sink = _resolve_sink(to) - - if format == "markdown": - from pyrit.printer.attack_result.markdown import MarkdownAttackResultMemoryPrinter - - printer = MarkdownAttackResultMemoryPrinter(sink=sink) - else: - from pyrit.printer.attack_result.pretty import PrettyAttackResultMemoryPrinter - - printer = PrettyAttackResultMemoryPrinter(sink=sink) - - await printer.write_async( - result, - include_auxiliary_scores=include_auxiliary_scores, - include_pruned_conversations=include_pruned_conversations, - include_adversarial_conversation=include_adversarial_conversation, - ) - - -async def print_scenario_result_async( - result: "ScenarioResult", # noqa: F821 - *, - format: OutputFormat = "pretty", - to: Path | str | Sink | None = None, -) -> None: - """ - Print a scenario result in the specified format to the specified destination. - - Args: - result (ScenarioResult): The scenario result to print. - format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". - to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. - """ - sink = _resolve_sink(to) - - if format == "pretty": - from pyrit.printer.scenario_result.pretty import PrettyScenarioResultMemoryPrinter - - printer = PrettyScenarioResultMemoryPrinter(sink=sink) - else: - raise ValueError(f"Unsupported format for scenario results: {format!r}. Only 'pretty' is available.") - - await printer.write_async(result) - - -async def print_scorer_async( - *, - scorer_identifier: "ComponentIdentifier", # noqa: F821 - harm_category: str | None = None, - format: OutputFormat = "pretty", - to: Path | str | Sink | None = None, -) -> None: - """ - Print scorer information in the specified format to the specified destination. - - Auto-detects scorer type: if harm_category is provided, renders harm - metrics; otherwise renders objective metrics. - - Args: - scorer_identifier (ComponentIdentifier): The scorer identifier. - harm_category (str | None): The harm category. None for objective scorers. - format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". - to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. - """ - sink = _resolve_sink(to) - - if format == "pretty": - from pyrit.printer.scorer.pretty import PrettyScorerMemoryPrinter - - printer = PrettyScorerMemoryPrinter(sink=sink) - else: - raise ValueError(f"Unsupported format for scorer: {format!r}. Only 'pretty' is available.") - - await printer.write_async(scorer_identifier=scorer_identifier, harm_category=harm_category) diff --git a/pyrit/printer/attack_result/__init__.py b/pyrit/printer/attack_result/__init__.py index 47789c0055..9610acdb6f 100644 --- a/pyrit/printer/attack_result/__init__.py +++ b/pyrit/printer/attack_result/__init__.py @@ -1,4 +1,48 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -"""Attack result printer base classes.""" +"""Attack result printer classes.""" + +from pathlib import Path + +from pyrit.printer.sink import OutputFormat, Sink, resolve_sink + + +async def print_attack_result_async( + result: "AttackResult", # noqa: F821 + *, + format: OutputFormat = "pretty", + to: Path | str | Sink | None = None, + include_auxiliary_scores: bool = False, + include_pruned_conversations: bool = False, + include_adversarial_conversation: bool = False, +) -> None: + """ + Print an attack result in the specified format to the specified destination. + + Args: + result (AttackResult): The attack result to print. 
+ format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". + to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. + include_auxiliary_scores (bool): Whether to include auxiliary scores. Defaults to False. + include_pruned_conversations (bool): Whether to include pruned conversations. Defaults to False. + include_adversarial_conversation (bool): Whether to include the adversarial conversation. + Defaults to False. + """ + sink = resolve_sink(to) + + if format == "markdown": + from pyrit.printer.attack_result.markdown import MarkdownAttackResultMemoryPrinter + + printer = MarkdownAttackResultMemoryPrinter(sink=sink) + else: + from pyrit.printer.attack_result.pretty import PrettyAttackResultMemoryPrinter + + printer = PrettyAttackResultMemoryPrinter(sink=sink) + + await printer.write_async( + result, + include_auxiliary_scores=include_auxiliary_scores, + include_pruned_conversations=include_pruned_conversations, + include_adversarial_conversation=include_adversarial_conversation, + ) diff --git a/pyrit/printer/scenario_result/__init__.py b/pyrit/printer/scenario_result/__init__.py index 0def8141c0..03a94fd200 100644 --- a/pyrit/printer/scenario_result/__init__.py +++ b/pyrit/printer/scenario_result/__init__.py @@ -1,4 +1,34 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -"""Scenario result printer base classes.""" +"""Scenario result printer classes.""" + +from pathlib import Path + +from pyrit.printer.sink import OutputFormat, Sink, resolve_sink + + +async def print_scenario_result_async( + result: "ScenarioResult", # noqa: F821 + *, + format: OutputFormat = "pretty", + to: Path | str | Sink | None = None, +) -> None: + """ + Print a scenario result in the specified format to the specified destination. + + Args: + result (ScenarioResult): The scenario result to print. + format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". + to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. + """ + sink = resolve_sink(to) + + if format == "pretty": + from pyrit.printer.scenario_result.pretty import PrettyScenarioResultMemoryPrinter + + printer = PrettyScenarioResultMemoryPrinter(sink=sink) + else: + raise ValueError(f"Unsupported format for scenario results: {format!r}. Only 'pretty' is available.") + + await printer.write_async(result) diff --git a/pyrit/printer/scorer/__init__.py b/pyrit/printer/scorer/__init__.py index 7c7c7bd417..97979d0a08 100644 --- a/pyrit/printer/scorer/__init__.py +++ b/pyrit/printer/scorer/__init__.py @@ -1,4 +1,39 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -"""Scorer printer base classes.""" +"""Scorer printer classes.""" + +from pathlib import Path + +from pyrit.printer.sink import OutputFormat, Sink, resolve_sink + + +async def print_scorer_async( + *, + scorer_identifier: "ComponentIdentifier", # noqa: F821 + harm_category: str | None = None, + format: OutputFormat = "pretty", + to: Path | str | Sink | None = None, +) -> None: + """ + Print scorer information in the specified format to the specified destination. + + Auto-detects scorer type: if harm_category is provided, renders harm + metrics; otherwise renders objective metrics. + + Args: + scorer_identifier (ComponentIdentifier): The scorer identifier. + harm_category (str | None): The harm category. None for objective scorers. + format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". 
+ to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. + """ + sink = resolve_sink(to) + + if format == "pretty": + from pyrit.printer.scorer.pretty import PrettyScorerMemoryPrinter + + printer = PrettyScorerMemoryPrinter(sink=sink) + else: + raise ValueError(f"Unsupported format for scorer: {format!r}. Only 'pretty' is available.") + + await printer.write_async(scorer_identifier=scorer_identifier, harm_category=harm_category) diff --git a/pyrit/printer/sink.py b/pyrit/printer/sink.py index aca9c07f9b..21347c8100 100644 --- a/pyrit/printer/sink.py +++ b/pyrit/printer/sink.py @@ -3,6 +3,9 @@ from abc import ABC, abstractmethod from pathlib import Path +from typing import Literal + +OutputFormat = Literal["pretty", "markdown"] class Sink(ABC): @@ -71,3 +74,23 @@ async def write_async(self, data: str) -> None: """ with open(self._path, self._mode, encoding="utf-8") as f: f.write(data) + + +def resolve_sink(to: Path | str | Sink | None) -> Sink: + """ + Resolve a destination argument to a Sink instance. + + Args: + to (Path | str | Sink | None): The destination. + None → StdoutSink. + Path or str → FileSink. + Sink instance → used as-is. + + Returns: + Sink: The resolved sink. + """ + if to is None: + return StdoutSink() + if isinstance(to, Sink): + return to + return FileSink(path=Path(to)) diff --git a/tests/unit/printer/test_convenience.py b/tests/unit/printer/test_convenience.py index 3fe8dc27bd..868ca80d06 100644 --- a/tests/unit/printer/test_convenience.py +++ b/tests/unit/printer/test_convenience.py @@ -7,37 +7,33 @@ import pytest -from pyrit.printer import ( - OutputFormat, - _resolve_sink, - print_attack_result_async, - print_scenario_result_async, - print_scorer_async, -) -from pyrit.printer.sink import FileSink, StdoutSink +from pyrit.printer.attack_result import print_attack_result_async +from pyrit.printer.scenario_result import print_scenario_result_async +from pyrit.printer.scorer import print_scorer_async +from pyrit.printer.sink import FileSink, OutputFormat, StdoutSink, resolve_sink # --- _resolve_sink tests --- def test_resolve_sink_none_returns_stdout(): - sink = _resolve_sink(None) + sink = resolve_sink(None) assert isinstance(sink, StdoutSink) def test_resolve_sink_path_returns_file_sink(): - sink = _resolve_sink(Path("output.txt")) + sink = resolve_sink(Path("output.txt")) assert isinstance(sink, FileSink) def test_resolve_sink_str_returns_file_sink(): - sink = _resolve_sink("output.txt") + sink = resolve_sink("output.txt") assert isinstance(sink, FileSink) def test_resolve_sink_sink_instance_passthrough(): original = StdoutSink() - sink = _resolve_sink(original) + sink = resolve_sink(original) assert sink is original From 4720ff7c1d8c9e35c07e120d313ac0dd2735a7f8 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 22:21:12 -0700 Subject: [PATCH 18/34] Move convenience functions to printer/helpers.py Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/printer/attack_result/__init__.py | 44 --------- pyrit/printer/helpers.py | 105 ++++++++++++++++++++++ pyrit/printer/scenario_result/__init__.py | 30 ------- pyrit/printer/scorer/__init__.py | 35 -------- tests/unit/printer/test_convenience.py | 8 +- 5 files changed, 110 insertions(+), 112 deletions(-) create mode 100644 pyrit/printer/helpers.py diff --git a/pyrit/printer/attack_result/__init__.py b/pyrit/printer/attack_result/__init__.py index 9610acdb6f..dd5d3f3d12 100644 --- a/pyrit/printer/attack_result/__init__.py +++ 
b/pyrit/printer/attack_result/__init__.py @@ -2,47 +2,3 @@ # Licensed under the MIT license. """Attack result printer classes.""" - -from pathlib import Path - -from pyrit.printer.sink import OutputFormat, Sink, resolve_sink - - -async def print_attack_result_async( - result: "AttackResult", # noqa: F821 - *, - format: OutputFormat = "pretty", - to: Path | str | Sink | None = None, - include_auxiliary_scores: bool = False, - include_pruned_conversations: bool = False, - include_adversarial_conversation: bool = False, -) -> None: - """ - Print an attack result in the specified format to the specified destination. - - Args: - result (AttackResult): The attack result to print. - format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". - to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. - include_auxiliary_scores (bool): Whether to include auxiliary scores. Defaults to False. - include_pruned_conversations (bool): Whether to include pruned conversations. Defaults to False. - include_adversarial_conversation (bool): Whether to include the adversarial conversation. - Defaults to False. - """ - sink = resolve_sink(to) - - if format == "markdown": - from pyrit.printer.attack_result.markdown import MarkdownAttackResultMemoryPrinter - - printer = MarkdownAttackResultMemoryPrinter(sink=sink) - else: - from pyrit.printer.attack_result.pretty import PrettyAttackResultMemoryPrinter - - printer = PrettyAttackResultMemoryPrinter(sink=sink) - - await printer.write_async( - result, - include_auxiliary_scores=include_auxiliary_scores, - include_pruned_conversations=include_pruned_conversations, - include_adversarial_conversation=include_adversarial_conversation, - ) diff --git a/pyrit/printer/helpers.py b/pyrit/printer/helpers.py new file mode 100644 index 0000000000..ded9ea7fb0 --- /dev/null +++ b/pyrit/printer/helpers.py @@ -0,0 +1,105 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Convenience functions for one-line printing of attack results, scenario results, and scorer info.""" + +from pathlib import Path + +from pyrit.printer.sink import OutputFormat, Sink, resolve_sink + + +async def print_attack_result_async( + result: "AttackResult", # noqa: F821 + *, + format: OutputFormat = "pretty", + to: Path | str | Sink | None = None, + include_auxiliary_scores: bool = False, + include_pruned_conversations: bool = False, + include_adversarial_conversation: bool = False, +) -> None: + """ + Print an attack result in the specified format to the specified destination. + + Args: + result (AttackResult): The attack result to print. + format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". + to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. + include_auxiliary_scores (bool): Whether to include auxiliary scores. Defaults to False. + include_pruned_conversations (bool): Whether to include pruned conversations. Defaults to False. + include_adversarial_conversation (bool): Whether to include the adversarial conversation. + Defaults to False. 
+ """ + sink = resolve_sink(to) + + if format == "markdown": + from pyrit.printer.attack_result.markdown import MarkdownAttackResultMemoryPrinter + + printer = MarkdownAttackResultMemoryPrinter(sink=sink) + else: + from pyrit.printer.attack_result.pretty import PrettyAttackResultMemoryPrinter + + printer = PrettyAttackResultMemoryPrinter(sink=sink) + + await printer.write_async( + result, + include_auxiliary_scores=include_auxiliary_scores, + include_pruned_conversations=include_pruned_conversations, + include_adversarial_conversation=include_adversarial_conversation, + ) + + +async def print_scenario_result_async( + result: "ScenarioResult", # noqa: F821 + *, + format: OutputFormat = "pretty", + to: Path | str | Sink | None = None, +) -> None: + """ + Print a scenario result in the specified format to the specified destination. + + Args: + result (ScenarioResult): The scenario result to print. + format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". + to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. + """ + sink = resolve_sink(to) + + if format == "pretty": + from pyrit.printer.scenario_result.pretty import PrettyScenarioResultMemoryPrinter + + printer = PrettyScenarioResultMemoryPrinter(sink=sink) + else: + raise ValueError(f"Unsupported format for scenario results: {format!r}. Only 'pretty' is available.") + + await printer.write_async(result) + + +async def print_scorer_async( + *, + scorer_identifier: "ComponentIdentifier", # noqa: F821 + harm_category: str | None = None, + format: OutputFormat = "pretty", + to: Path | str | Sink | None = None, +) -> None: + """ + Print scorer information in the specified format to the specified destination. + + Auto-detects scorer type: if harm_category is provided, renders harm + metrics; otherwise renders objective metrics. + + Args: + scorer_identifier (ComponentIdentifier): The scorer identifier. + harm_category (str | None): The harm category. None for objective scorers. + format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". + to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. + """ + sink = resolve_sink(to) + + if format == "pretty": + from pyrit.printer.scorer.pretty import PrettyScorerMemoryPrinter + + printer = PrettyScorerMemoryPrinter(sink=sink) + else: + raise ValueError(f"Unsupported format for scorer: {format!r}. Only 'pretty' is available.") + + await printer.write_async(scorer_identifier=scorer_identifier, harm_category=harm_category) diff --git a/pyrit/printer/scenario_result/__init__.py b/pyrit/printer/scenario_result/__init__.py index 03a94fd200..24bceaeacf 100644 --- a/pyrit/printer/scenario_result/__init__.py +++ b/pyrit/printer/scenario_result/__init__.py @@ -2,33 +2,3 @@ # Licensed under the MIT license. """Scenario result printer classes.""" - -from pathlib import Path - -from pyrit.printer.sink import OutputFormat, Sink, resolve_sink - - -async def print_scenario_result_async( - result: "ScenarioResult", # noqa: F821 - *, - format: OutputFormat = "pretty", - to: Path | str | Sink | None = None, -) -> None: - """ - Print a scenario result in the specified format to the specified destination. - - Args: - result (ScenarioResult): The scenario result to print. - format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". - to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. 
- """ - sink = resolve_sink(to) - - if format == "pretty": - from pyrit.printer.scenario_result.pretty import PrettyScenarioResultMemoryPrinter - - printer = PrettyScenarioResultMemoryPrinter(sink=sink) - else: - raise ValueError(f"Unsupported format for scenario results: {format!r}. Only 'pretty' is available.") - - await printer.write_async(result) diff --git a/pyrit/printer/scorer/__init__.py b/pyrit/printer/scorer/__init__.py index 97979d0a08..d4389eac5b 100644 --- a/pyrit/printer/scorer/__init__.py +++ b/pyrit/printer/scorer/__init__.py @@ -2,38 +2,3 @@ # Licensed under the MIT license. """Scorer printer classes.""" - -from pathlib import Path - -from pyrit.printer.sink import OutputFormat, Sink, resolve_sink - - -async def print_scorer_async( - *, - scorer_identifier: "ComponentIdentifier", # noqa: F821 - harm_category: str | None = None, - format: OutputFormat = "pretty", - to: Path | str | Sink | None = None, -) -> None: - """ - Print scorer information in the specified format to the specified destination. - - Auto-detects scorer type: if harm_category is provided, renders harm - metrics; otherwise renders objective metrics. - - Args: - scorer_identifier (ComponentIdentifier): The scorer identifier. - harm_category (str | None): The harm category. None for objective scorers. - format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". - to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. - """ - sink = resolve_sink(to) - - if format == "pretty": - from pyrit.printer.scorer.pretty import PrettyScorerMemoryPrinter - - printer = PrettyScorerMemoryPrinter(sink=sink) - else: - raise ValueError(f"Unsupported format for scorer: {format!r}. Only 'pretty' is available.") - - await printer.write_async(scorer_identifier=scorer_identifier, harm_category=harm_category) diff --git a/tests/unit/printer/test_convenience.py b/tests/unit/printer/test_convenience.py index 868ca80d06..f310b84c92 100644 --- a/tests/unit/printer/test_convenience.py +++ b/tests/unit/printer/test_convenience.py @@ -7,9 +7,11 @@ import pytest -from pyrit.printer.attack_result import print_attack_result_async -from pyrit.printer.scenario_result import print_scenario_result_async -from pyrit.printer.scorer import print_scorer_async +from pyrit.printer.helpers import ( + print_attack_result_async, + print_scenario_result_async, + print_scorer_async, +) from pyrit.printer.sink import FileSink, OutputFormat, StdoutSink, resolve_sink From 65fedabc0f0e076062b06d50998fb8ed7b3facda Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 22:24:20 -0700 Subject: [PATCH 19/34] Add copilot instructions for printer module Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/instructions/printer.instructions.md | 106 +++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 .github/instructions/printer.instructions.md diff --git a/.github/instructions/printer.instructions.md b/.github/instructions/printer.instructions.md new file mode 100644 index 0000000000..f46d9c446c --- /dev/null +++ b/.github/instructions/printer.instructions.md @@ -0,0 +1,106 @@ +--- +applyTo: "pyrit/printer/**" +--- + +# PyRIT Printer Module Guidelines + +The printer module renders attack results, scenario results, and scorer information. It separates **what** the output looks like (format) from **where** it goes (sink) and **where data comes from** (abstract methods). 
+ +## Architecture + +### Three-layer hierarchy per domain + +``` +DomainPrinterBase(PrinterBase) # base.py — abstract data methods + write_async + ├─ PrettyDomainPrinter # pretty.py — ANSI formatting, returns str + │ └─ PrettyDomainMemoryPrinter # same file — fetches data via CentralMemory + ├─ MarkdownDomainPrinter # markdown.py — Markdown formatting + │ └─ MarkdownDomainMemoryPrinter + └─ JsonDomainPrinter # json.py — structured JSON + └─ JsonDomainMemoryPrinter +``` + +- **Base** (`base.py`): declares abstract data-fetching methods and `write_async` +- **Format** (`pretty.py`, `markdown.py`, `json.py`): all rendering logic, builds `str`, writes to sink — **no data I/O here** +- **Leaf** (e.g., `PrettyAttackResultMemoryPrinter`): implements abstract data methods via `CentralMemory` — **no formatting logic here** + +### Sink — where output goes + +`Sink` ABC in `sink.py`. Printers take a `Sink` in their constructor (default: `StdoutSink`). + +```python +class Sink(ABC): + async def write_async(self, data: str) -> None: ... +``` + +Current sinks: `StdoutSink`, `FileSink`. Add new sinks as needed (IPython, Blob, etc.). + +### PrinterBase — common base + +All printers inherit `PrinterBase`. It provides: +- `sink` constructor param (default `StdoutSink`) +- `_write_async(data: str)` to write through the sink +- Abstract `write_async(...)` as the **public entry point** (signature varies per domain) + +## Key Rules + +### Output goes through the sink — never call `print()` directly + +All `_render_*` methods return `str`. The `write_async` entry point concatenates renders and calls `_write_async(content)`. No bare `print()` calls anywhere in the printer module except inside `StdoutSink`. + +### Data fetching belongs in leaf classes only + +Format classes (`PrettyAttackResultPrinter`, `MarkdownAttackResultPrinter`) must not import or reference `CentralMemory`. Only `*MemoryPrinter` leaf classes do data I/O. + +### File names = output format + +- `pretty.py` — ANSI-colored human-readable +- `markdown.py` — Markdown +- `json.py` — structured JSON + +### `write_async` is the only public entry point + +Each printer has one public method: `write_async(...)`. Old methods like `print_result_async`, `print_summary_async`, `print_objective_scorer` are deprecated wrappers that call `write_async`. + +### All other methods are private + +Prefix with `_`: `_format_colored`, `_render_header`, `_render_summary_async`, `_get_conversation_async`, `_get_scores_async`, etc. + +### Memory leaf classes must work with zero args + +```python +printer = PrettyAttackResultMemoryPrinter() # defaults: StdoutSink, matching sub-printers +await printer.write_async(result) +``` + +Pass `sink=` to redirect output. Pass sub-printers only to override defaults. + +### Convenience functions live in `helpers.py` + +```python +from pyrit.printer.helpers import print_attack_result_async +await print_attack_result_async(result, format="pretty", to=Path("out.txt")) +``` + +`helpers.py` resolves `format` → printer class, `to` → sink, and calls `write_async`. + +## Adding a New Format + +1. Create `<domain>/<format>.py` (e.g., `attack_result/json.py`) +2. Subclass the domain base (e.g., `AttackResultPrinterBase`) +3. Implement `write_async` — build a `str` from `_render_*` methods, call `_write_async` +4. Add a `*MemoryPrinter` leaf class that implements the abstract data methods +5. Register in `helpers.py` format dispatch + +## Adding a New Sink + +1. Subclass `Sink` in `sink.py` +2. Implement `write_async(self, data: str) -> None` +3. 
Users pass it via `sink=MySink()` on any printer constructor + +## Adding a New Domain Printer + +1. Create `pyrit/printer//base.py` with abstract data methods + `write_async` +2. Create format files (`pretty.py`, etc.) with rendering logic +3. Add Memory leaf classes +4. Add convenience function in `helpers.py` From fdfabccc0522d0482944ac3d69d53b2301e4d1c2 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 23:00:02 -0700 Subject: [PATCH 20/34] =?UTF-8?q?Add=20render=5Fasync/write=5Fasync=20cont?= =?UTF-8?q?ract,=20rename=20to=E2=86=92sink,=20add=20deprecation=20warning?= =?UTF-8?q?s?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/printer/attack_result/markdown.py | 21 ++++++++++++++------ pyrit/printer/attack_result/pretty.py | 20 +++++++++++++++---- pyrit/printer/base.py | 22 +++++++++++++++------ pyrit/printer/helpers.py | 26 ++++++++++++------------- pyrit/printer/scenario_result/base.py | 9 +++++++-- pyrit/printer/scenario_result/pretty.py | 23 +++++++++++++--------- pyrit/printer/scorer/base.py | 22 ++++++++++++++------- pyrit/printer/scorer/pretty.py | 11 +++++++---- tests/unit/printer/test_convenience.py | 2 +- tests/unit/printer/test_printer_base.py | 23 ++++++++++++++++------ tests/unit/score/test_scorer_printer.py | 6 +++--- 11 files changed, 124 insertions(+), 61 deletions(-) diff --git a/pyrit/printer/attack_result/markdown.py b/pyrit/printer/attack_result/markdown.py index 88038839ed..b05d3e5294 100644 --- a/pyrit/printer/attack_result/markdown.py +++ b/pyrit/printer/attack_result/markdown.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. import os +import warnings from datetime import datetime, timezone from pyrit.models import AttackResult, ConversationType, Message, MessagePiece, Score @@ -68,16 +69,16 @@ def _format_score(self, score: Score, indent: str = "") -> str: return "\n".join(lines) - async def write_async( + async def render_async( self, result: AttackResult, *, include_auxiliary_scores: bool = False, include_pruned_conversations: bool = False, include_adversarial_conversation: bool = False, - ) -> None: + ) -> str: """ - Render and write the complete attack result as markdown to the sink. + Render the complete attack result as markdown and return it as a string. Args: result (AttackResult): The attack result to render. @@ -85,6 +86,9 @@ async def write_async( include_pruned_conversations (bool): Whether to include pruned conversations. Defaults to False. include_adversarial_conversation (bool): Whether to include the adversarial conversation. Defaults to False. + + Returns: + str: The rendered markdown text. """ markdown_lines: list[str] = [] @@ -125,7 +129,7 @@ async def write_async( timestamp_utc = datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z") markdown_lines.append(f"*Report generated at {timestamp_utc}*") - await self._write_async("\n".join(markdown_lines)) + return "\n".join(markdown_lines) async def print_result_async( self, @@ -136,6 +140,7 @@ async def print_result_async( include_adversarial_conversation: bool = False, ) -> None: """Deprecated. 
Use write_async instead.""" + warnings.warn("print_result_async is deprecated, use write_async instead", DeprecationWarning, stacklevel=2) await self.write_async( result, include_auxiliary_scores=include_auxiliary_scores, @@ -144,12 +149,16 @@ async def print_result_async( ) async def print_conversation_async(self, result: AttackResult, *, include_scores: bool = False) -> None: - """Deprecated. Use _get_conversation_markdown_async and _write_async instead.""" + """Deprecated. Use write_async instead.""" + warnings.warn( + "print_conversation_async is deprecated, use write_async instead", DeprecationWarning, stacklevel=2 + ) markdown_lines = await self._get_conversation_markdown_async(result=result, include_scores=include_scores) await self._write_async("\n".join(markdown_lines)) async def print_summary_async(self, result: AttackResult) -> None: - """Deprecated. Use _get_summary_markdown_async and _write_async instead.""" + """Deprecated. Use write_async instead.""" + warnings.warn("print_summary_async is deprecated, use write_async instead", DeprecationWarning, stacklevel=2) markdown_lines = await self._get_summary_markdown_async(result) await self._write_async("\n".join(markdown_lines)) diff --git a/pyrit/printer/attack_result/pretty.py b/pyrit/printer/attack_result/pretty.py index b8adee7169..b99622ccd3 100644 --- a/pyrit/printer/attack_result/pretty.py +++ b/pyrit/printer/attack_result/pretty.py @@ -3,6 +3,7 @@ import json import textwrap +import warnings from datetime import datetime, timezone from typing import Any @@ -54,16 +55,16 @@ def _format_colored(self, text: str, *colors: str) -> str: return f"{color_prefix}{text}{Style.RESET_ALL}\n" return f"{text}\n" - async def write_async( + async def render_async( self, result: AttackResult, *, include_auxiliary_scores: bool = False, include_pruned_conversations: bool = False, include_adversarial_conversation: bool = False, - ) -> None: + ) -> str: """ - Render and write the complete attack result to the sink. + Render the complete attack result and return it as a string. Args: result (AttackResult): The attack result to render. @@ -71,6 +72,9 @@ async def write_async( include_pruned_conversations (bool): Whether to include pruned conversations. Defaults to False. include_adversarial_conversation (bool): Whether to include the adversarial conversation. Defaults to False. + + Returns: + str: The rendered attack result text. """ lines: list[str] = [] lines.append(self._render_header(result)) @@ -84,7 +88,7 @@ async def write_async( if result.metadata: lines.append(self._render_metadata(result.metadata)) lines.append(self._render_footer()) - await self._write_async("".join(lines)) + return "".join(lines) async def print_result_async( self, @@ -95,6 +99,7 @@ async def print_result_async( include_adversarial_conversation: bool = False, ) -> None: """Deprecated. Use write_async instead.""" + warnings.warn("print_result_async is deprecated, use write_async instead", DeprecationWarning, stacklevel=2) await self.write_async( result, include_auxiliary_scores=include_auxiliary_scores, @@ -136,6 +141,9 @@ async def print_conversation_async( self, result: AttackResult, *, include_scores: bool = False, include_reasoning_trace: bool = False ) -> None: """Deprecated. 
Use write_async instead.""" + warnings.warn( + "print_conversation_async is deprecated, use write_async instead", DeprecationWarning, stacklevel=2 + ) content = await self._render_conversation_async( result, include_scores=include_scores, include_reasoning_trace=include_reasoning_trace ) @@ -256,6 +264,9 @@ async def print_messages_async( include_reasoning_trace: bool = False, ) -> None: """Deprecated. Use write_async instead.""" + warnings.warn( + "print_messages_async is deprecated, use write_async instead", DeprecationWarning, stacklevel=2 + ) content = await self._render_messages_async( messages=messages, include_scores=include_scores, include_reasoning_trace=include_reasoning_trace ) @@ -338,6 +349,7 @@ async def _render_summary_async(self, result: AttackResult) -> str: async def print_summary_async(self, result: AttackResult) -> None: """Deprecated. Use write_async instead.""" + warnings.warn("print_summary_async is deprecated, use write_async instead", DeprecationWarning, stacklevel=2) content = await self._render_summary_async(result) await self._write_async(content) diff --git a/pyrit/printer/base.py b/pyrit/printer/base.py index 7e8692e4cb..fde68ccdbf 100644 --- a/pyrit/printer/base.py +++ b/pyrit/printer/base.py @@ -10,9 +10,9 @@ class PrinterBase(ABC): """ Abstract base class for all printers. - Provides a sink for output routing. Subclasses must implement - ``write_async`` as their public entry point, and use ``_write_async`` - to send rendered text to the sink. + Subclasses implement ``render_async`` to produce formatted text. + ``write_async`` is concrete: it calls ``render_async`` then routes + the result through the configured sink. """ def __init__(self, *, sink: Sink | None = None) -> None: @@ -25,13 +25,23 @@ def __init__(self, *, sink: Sink | None = None) -> None: self._sink = sink or StdoutSink() @abstractmethod - async def write_async(self, *args, **kwargs) -> None: + async def render_async(self, *args, **kwargs) -> str: """ - Render and write output to the configured sink. + Render output and return it as a string. Subclasses define the specific signature (e.g., scorer_identifier, - result, etc.). + result, messages, etc.). + """ + + async def write_async(self, *args, **kwargs) -> None: + """ + Render output and write it to the configured sink. + + Calls ``render_async`` with all arguments, then writes the result + through the sink. Subclasses should not override this method. """ + content = await self.render_async(*args, **kwargs) + await self._write_async(content) async def _write_async(self, data: str) -> None: """ diff --git a/pyrit/printer/helpers.py b/pyrit/printer/helpers.py index ded9ea7fb0..4687a62c89 100644 --- a/pyrit/printer/helpers.py +++ b/pyrit/printer/helpers.py @@ -12,7 +12,7 @@ async def print_attack_result_async( result: "AttackResult", # noqa: F821 *, format: OutputFormat = "pretty", - to: Path | str | Sink | None = None, + sink: Path | str | Sink | None = None, include_auxiliary_scores: bool = False, include_pruned_conversations: bool = False, include_adversarial_conversation: bool = False, @@ -23,22 +23,22 @@ async def print_attack_result_async( Args: result (AttackResult): The attack result to print. format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". - to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. + sink (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. 
include_auxiliary_scores (bool): Whether to include auxiliary scores. Defaults to False. include_pruned_conversations (bool): Whether to include pruned conversations. Defaults to False. include_adversarial_conversation (bool): Whether to include the adversarial conversation. Defaults to False. """ - sink = resolve_sink(to) + resolved_sink = resolve_sink(sink) if format == "markdown": from pyrit.printer.attack_result.markdown import MarkdownAttackResultMemoryPrinter - printer = MarkdownAttackResultMemoryPrinter(sink=sink) + printer = MarkdownAttackResultMemoryPrinter(sink=resolved_sink) else: from pyrit.printer.attack_result.pretty import PrettyAttackResultMemoryPrinter - printer = PrettyAttackResultMemoryPrinter(sink=sink) + printer = PrettyAttackResultMemoryPrinter(sink=resolved_sink) await printer.write_async( result, @@ -52,7 +52,7 @@ async def print_scenario_result_async( result: "ScenarioResult", # noqa: F821 *, format: OutputFormat = "pretty", - to: Path | str | Sink | None = None, + sink: Path | str | Sink | None = None, ) -> None: """ Print a scenario result in the specified format to the specified destination. @@ -60,14 +60,14 @@ async def print_scenario_result_async( Args: result (ScenarioResult): The scenario result to print. format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". - to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. + sink (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. """ - sink = resolve_sink(to) + resolved_sink = resolve_sink(sink) if format == "pretty": from pyrit.printer.scenario_result.pretty import PrettyScenarioResultMemoryPrinter - printer = PrettyScenarioResultMemoryPrinter(sink=sink) + printer = PrettyScenarioResultMemoryPrinter(sink=resolved_sink) else: raise ValueError(f"Unsupported format for scenario results: {format!r}. Only 'pretty' is available.") @@ -79,7 +79,7 @@ async def print_scorer_async( scorer_identifier: "ComponentIdentifier", # noqa: F821 harm_category: str | None = None, format: OutputFormat = "pretty", - to: Path | str | Sink | None = None, + sink: Path | str | Sink | None = None, ) -> None: """ Print scorer information in the specified format to the specified destination. @@ -91,14 +91,14 @@ async def print_scorer_async( scorer_identifier (ComponentIdentifier): The scorer identifier. harm_category (str | None): The harm category. None for objective scorers. format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". - to (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. + sink (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. """ - sink = resolve_sink(to) + resolved_sink = resolve_sink(sink) if format == "pretty": from pyrit.printer.scorer.pretty import PrettyScorerMemoryPrinter - printer = PrettyScorerMemoryPrinter(sink=sink) + printer = PrettyScorerMemoryPrinter(sink=resolved_sink) else: raise ValueError(f"Unsupported format for scorer: {format!r}. Only 'pretty' is available.") diff --git a/pyrit/printer/scenario_result/base.py b/pyrit/printer/scenario_result/base.py index 1b6d1c1843..1af4d50521 100644 --- a/pyrit/printer/scenario_result/base.py +++ b/pyrit/printer/scenario_result/base.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
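Editor's note: the render/write contract introduced in `PrinterBase` above is easiest to see with a toy subclass (an illustrative sketch only; `GreetingPrinter` is hypothetical, while the abstract `render_async`, the concrete `write_async`, and the default `StdoutSink` come from the diff above):

```python
from pyrit.printer.base import PrinterBase


class GreetingPrinter(PrinterBase):
    """Hypothetical printer: subclasses only build strings, never touch the sink."""

    async def render_async(self, name: str) -> str:
        return f"Hello, {name}!\n"


async def demo() -> None:
    printer = GreetingPrinter()                  # defaults to StdoutSink
    text = await printer.render_async("PyRIT")   # just returns the string
    assert text == "Hello, PyRIT!\n"
    await printer.write_async("PyRIT")           # renders, then routes through the sink
```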
+import warnings from abc import abstractmethod from pyrit.models.scenario_result import ScenarioResult @@ -16,12 +17,15 @@ class ScenarioResultPrinterBase(PrinterBase): """ @abstractmethod - async def write_async(self, result: ScenarioResult) -> None: + async def render_async(self, result: ScenarioResult) -> str: """ - Render and write a scenario result summary to the configured sink. + Render a scenario result summary and return it as a string. Args: result (ScenarioResult): The scenario result to summarize. + + Returns: + str: The rendered scenario result text. """ async def print_summary_async(self, result: ScenarioResult) -> None: @@ -31,4 +35,5 @@ async def print_summary_async(self, result: ScenarioResult) -> None: Args: result (ScenarioResult): The scenario result to summarize. """ + warnings.warn("print_summary_async is deprecated, use write_async instead", DeprecationWarning, stacklevel=2) await self.write_async(result) diff --git a/pyrit/printer/scenario_result/pretty.py b/pyrit/printer/scenario_result/pretty.py index e3f72c7685..6ff5e3c4b7 100644 --- a/pyrit/printer/scenario_result/pretty.py +++ b/pyrit/printer/scenario_result/pretty.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. import textwrap +import warnings from colorama import Fore, Style @@ -123,15 +124,19 @@ def _get_rate_color(self, rate: int) -> str: return str(Fore.CYAN) return str(Fore.GREEN) - async def write_async(self, result: ScenarioResult) -> None: + async def render_async(self, result: ScenarioResult) -> str: """ - Render and write the scenario result summary to the configured sink. + Render the scenario result summary and return it as a string. Args: result (ScenarioResult): The scenario result to summarize. + + Returns: + str: The rendered scenario result text. """ - lines: list[str] = [] + parts: list[str] = [] + lines: list[str] = [] lines.append(self._render_header(result)) lines.append(self._render_section_header("Scenario Information")) @@ -166,15 +171,12 @@ async def write_async(self, result: ScenarioResult) -> None: lines.append(self._format_colored(f"{self._indent * 2}• Target Type: {target_type}", Fore.CYAN)) lines.append(self._format_colored(f"{self._indent * 2}• Target Model: {target_model}", Fore.CYAN)) lines.append(self._format_colored(f"{self._indent * 2}• Target Endpoint: {target_endpoint}", Fore.CYAN)) - - # Write what we have so far, then let the scorer printer write its own section - await self._write_async("".join(lines)) + parts.append("".join(lines)) scorer_identifier = result.objective_scorer_identifier if scorer_identifier: - await self._scorer_printer.write_async(scorer_identifier=scorer_identifier) + parts.append(await self._scorer_printer.render_async(scorer_identifier=scorer_identifier)) - # Continue with stats lines = [] lines.append(self._render_section_header("Overall Statistics")) total_results = sum(len(results) for results in result.attack_results.values()) @@ -212,7 +214,9 @@ async def write_async(self, result: ScenarioResult) -> None: )) lines.append(self._render_footer()) - await self._write_async("".join(lines)) + parts.append("".join(lines)) + + return "".join(parts) async def print_summary_async(self, result: ScenarioResult) -> None: """ @@ -221,6 +225,7 @@ async def print_summary_async(self, result: ScenarioResult) -> None: Args: result (ScenarioResult): The scenario result to summarize. 
""" + warnings.warn("print_summary_async is deprecated, use write_async instead", DeprecationWarning, stacklevel=2) await self.write_async(result) diff --git a/pyrit/printer/scorer/base.py b/pyrit/printer/scorer/base.py index 88e3fd4c15..24a8bd22f9 100644 --- a/pyrit/printer/scorer/base.py +++ b/pyrit/printer/scorer/base.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import warnings from abc import abstractmethod from typing import Any @@ -17,24 +18,24 @@ class ScorerPrinterBase(PrinterBase): """ @abstractmethod - def _get_objective_metrics(self, *, eval_hash: str) -> Any: + def _get_objective_metrics(self, *, scorer_identifier: ComponentIdentifier) -> Any: """ Fetch objective scorer evaluation metrics. Args: - eval_hash (str): The evaluation hash to look up. + scorer_identifier (ComponentIdentifier): The scorer identifier. Returns: The metrics object, or None if not found. """ @abstractmethod - def _get_harm_metrics(self, *, eval_hash: str, harm_category: str) -> Any: + def _get_harm_metrics(self, *, scorer_identifier: ComponentIdentifier, harm_category: str) -> Any: """ Fetch harm scorer evaluation metrics. Args: - eval_hash (str): The evaluation hash to look up. + scorer_identifier (ComponentIdentifier): The scorer identifier. harm_category (str): The harm category to look up. Returns: @@ -42,11 +43,11 @@ def _get_harm_metrics(self, *, eval_hash: str, harm_category: str) -> Any: """ @abstractmethod - async def write_async( + async def render_async( self, *, scorer_identifier: ComponentIdentifier, harm_category: str | None = None - ) -> None: + ) -> str: """ - Render and write scorer information to the configured sink. + Render scorer information and return it as a string. Auto-detects scorer type: if harm_category is provided, renders harm metrics; otherwise renders objective metrics. @@ -54,6 +55,9 @@ async def write_async( Args: scorer_identifier (ComponentIdentifier): The scorer identifier. harm_category (str | None): The harm category. None for objective scorers. + + Returns: + str: The rendered scorer information text. """ async def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: @@ -63,6 +67,9 @@ async def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier Args: scorer_identifier (ComponentIdentifier): The scorer identifier. """ + warnings.warn( + "print_objective_scorer is deprecated, use write_async instead", DeprecationWarning, stacklevel=2 + ) await self.write_async(scorer_identifier=scorer_identifier) async def print_harm_scorer(self, *, scorer_identifier: ComponentIdentifier, harm_category: str) -> None: @@ -73,4 +80,5 @@ async def print_harm_scorer(self, *, scorer_identifier: ComponentIdentifier, har scorer_identifier (ComponentIdentifier): The scorer identifier. harm_category (str): The harm category. 
""" + warnings.warn("print_harm_scorer is deprecated, use write_async instead", DeprecationWarning, stacklevel=2) await self.write_async(scorer_identifier=scorer_identifier, harm_category=harm_category) diff --git a/pyrit/printer/scorer/pretty.py b/pyrit/printer/scorer/pretty.py index ef08d44513..570f9241b2 100644 --- a/pyrit/printer/scorer/pretty.py +++ b/pyrit/printer/scorer/pretty.py @@ -244,11 +244,11 @@ def _render_harm_metrics(self, metrics: Optional[Any]) -> str: return "".join(lines) - async def write_async( + async def render_async( self, *, scorer_identifier: ComponentIdentifier, harm_category: str | None = None - ) -> None: + ) -> str: """ - Render and write scorer information to the configured sink. + Render scorer information and return it as a string. Auto-detects scorer type: if harm_category is provided, renders harm metrics; otherwise renders objective metrics. @@ -256,6 +256,9 @@ async def write_async( Args: scorer_identifier (ComponentIdentifier): The scorer identifier. harm_category (str | None): The harm category. None for objective scorers. + + Returns: + str: The rendered scorer information text. """ from pyrit.identifiers.evaluation_identifier import ScorerEvaluationIdentifier @@ -273,7 +276,7 @@ async def write_async( metrics = self._get_objective_metrics(eval_hash=eval_hash) lines.append(self._render_objective_metrics(metrics)) - await self._write_async("".join(lines)) + return "".join(lines) class PrettyScorerMemoryPrinter(PrettyScorerPrinter): diff --git a/tests/unit/printer/test_convenience.py b/tests/unit/printer/test_convenience.py index f310b84c92..af959a699f 100644 --- a/tests/unit/printer/test_convenience.py +++ b/tests/unit/printer/test_convenience.py @@ -79,7 +79,7 @@ async def test_print_attack_result_async_to_file(mock_cls): path = Path(f.name) try: - await print_attack_result_async(result, to=path) + await print_attack_result_async(result, sink=path) call_kwargs = mock_cls.call_args[1] assert isinstance(call_kwargs["sink"], FileSink) finally: diff --git a/tests/unit/printer/test_printer_base.py b/tests/unit/printer/test_printer_base.py index 6c3b6b14a6..cd2287a817 100644 --- a/tests/unit/printer/test_printer_base.py +++ b/tests/unit/printer/test_printer_base.py @@ -18,8 +18,8 @@ class IncompletePrinter(PrinterBase): def test_printer_base_defaults_to_stdout_sink(): class ConcretePrinter(PrinterBase): - async def write_async(self) -> None: - pass + async def render_async(self) -> str: + return "" printer = ConcretePrinter() assert isinstance(printer._sink, StdoutSink) @@ -31,8 +31,8 @@ def test_printer_base_accepts_custom_sink(): from pyrit.printer.sink import FileSink class ConcretePrinter(PrinterBase): - async def write_async(self) -> None: - pass + async def render_async(self) -> str: + return "" sink = FileSink(path=Path("test.txt")) printer = ConcretePrinter(sink=sink) @@ -42,10 +42,21 @@ async def write_async(self) -> None: async def test_printer_base_write_async_delegates_to_sink(capsys): class ConcretePrinter(PrinterBase): - async def write_async(self) -> None: - await self._write_async("test output") + async def render_async(self) -> str: + return "test output" printer = ConcretePrinter() await printer.write_async() captured = capsys.readouterr() assert captured.out == "test output" + + +async def test_printer_base_render_async_returns_string(): + + class ConcretePrinter(PrinterBase): + async def render_async(self) -> str: + return "rendered content" + + printer = ConcretePrinter() + result = await printer.render_async() + assert result == 
"rendered content" diff --git a/tests/unit/score/test_scorer_printer.py b/tests/unit/score/test_scorer_printer.py index cd0a5b997d..faeca7f623 100644 --- a/tests/unit/score/test_scorer_printer.py +++ b/tests/unit/score/test_scorer_printer.py @@ -50,10 +50,10 @@ def _get_objective_metrics(self, *, eval_hash: str): def _get_harm_metrics(self, *, eval_hash: str, harm_category: str): return None - async def write_async( + async def render_async( self, *, scorer_identifier: ComponentIdentifier, harm_category: str | None = None - ) -> None: - pass + ) -> str: + return "" printer = CompletePrinter() assert isinstance(printer, ScorerPrinter) From bd6f02104f1eb4bb781b819ccb1d949509a2baf8 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 23:07:10 -0700 Subject: [PATCH 21/34] Extract conversation and score printers, slim attack_result/pretty.py Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/printer/attack_result/pretty.py | 277 +++------------- pyrit/printer/conversation/__init__.py | 4 + pyrit/printer/conversation/base.py | 56 ++++ pyrit/printer/conversation/pretty.py | 303 ++++++++++++++++++ pyrit/printer/score/__init__.py | 4 + pyrit/printer/score/pretty.py | 111 +++++++ .../attack/printer/test_pretty_printer.py | 28 +- 7 files changed, 537 insertions(+), 246 deletions(-) create mode 100644 pyrit/printer/conversation/__init__.py create mode 100644 pyrit/printer/conversation/base.py create mode 100644 pyrit/printer/conversation/pretty.py create mode 100644 pyrit/printer/score/__init__.py create mode 100644 pyrit/printer/score/pretty.py diff --git a/pyrit/printer/attack_result/pretty.py b/pyrit/printer/attack_result/pretty.py index b99622ccd3..57e2c07a08 100644 --- a/pyrit/printer/attack_result/pretty.py +++ b/pyrit/printer/attack_result/pretty.py @@ -1,16 +1,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import json -import textwrap import warnings from datetime import datetime, timezone from typing import Any from colorama import Back, Fore, Style -from pyrit.models import AttackOutcome, AttackResult, ConversationType, Message, MessagePiece, Score +from pyrit.models import AttackOutcome, AttackResult, ConversationType, Message, Score from pyrit.printer.attack_result.base import AttackResultPrinterBase +from pyrit.printer.conversation.pretty import PrettyConversationPrinter +from pyrit.printer.score.pretty import PrettyScorePrinter from pyrit.printer.sink import Sink @@ -18,12 +18,19 @@ class PrettyAttackResultPrinter(AttackResultPrinterBase): """ Pretty printer for attack results with ANSI-colored formatting. - Contains all formatting logic. Subclasses implement get_conversation_async - and get_scores_async for data fetching. + Composes a conversation printer for message rendering and a score printer + for inline score display. Subclasses implement data-fetching methods. """ def __init__( - self, *, sink: Sink | None = None, width: int = 100, indent_size: int = 2, enable_colors: bool = True + self, + *, + sink: Sink | None = None, + width: int = 100, + indent_size: int = 2, + enable_colors: bool = True, + conversation_printer: PrettyConversationPrinter | None = None, + score_printer: PrettyScorePrinter | None = None, ) -> None: """ Initialize the pretty printer. @@ -33,11 +40,22 @@ def __init__( width (int): Maximum width for text wrapping. Defaults to 100. indent_size (int): Number of spaces for indentation. Defaults to 2. enable_colors (bool): Whether to enable ANSI color output. Defaults to True. 
+ conversation_printer (PrettyConversationPrinter | None): Conversation printer. + Defaults to a new PrettyConversationPrinter with matching settings. + score_printer (PrettyScorePrinter | None): Score printer. + Defaults to a new PrettyScorePrinter with matching settings. """ super().__init__(sink=sink) self._width = width self._indent = " " * indent_size self._enable_colors = enable_colors + self._score_printer = score_printer or PrettyScorePrinter( + sink=sink, width=width, indent_size=indent_size, enable_colors=enable_colors + ) + self._conversation_printer = conversation_printer or PrettyConversationPrinter( + sink=sink, width=width, indent_size=indent_size, enable_colors=enable_colors, + score_printer=self._score_printer, + ) def _format_colored(self, text: str, *colors: str) -> str: """ @@ -131,8 +149,8 @@ async def _render_conversation_async( f"{self._indent} No conversation found for ID: {result.conversation_id}", Fore.YELLOW ) - return await self._render_messages_async( - messages=messages, + return await self._conversation_printer.render_async( + messages, include_scores=include_scores, include_reasoning_trace=include_reasoning_trace, ) @@ -149,113 +167,6 @@ async def print_conversation_async( ) await self._write_async(content) - async def _render_messages_async( - self, - messages: list[Message], - *, - include_scores: bool = False, - include_reasoning_trace: bool = False, - ) -> str: - """ - Render a list of messages as a formatted string. - - Args: - messages (list[Message]): List of Message objects to render. - include_scores (bool): Whether to include scores. Defaults to False. - include_reasoning_trace (bool): Whether to include model reasoning trace. Defaults to False. - - Returns: - str: The rendered messages text. - """ - if not messages: - return self._format_colored(f"{self._indent} No messages to display.", Fore.YELLOW) - - lines: list[str] = [] - image_pieces: list[MessagePiece] = [] - turn_number = 0 - for message in messages: - if message.api_role == "user": - turn_number += 1 - lines.append("\n") - lines.append(self._format_colored("─" * self._width, Fore.BLUE)) - lines.append(self._format_colored(f"🔹 Turn {turn_number} - USER", Style.BRIGHT, Fore.BLUE)) - lines.append(self._format_colored("─" * self._width, Fore.BLUE)) - elif message.api_role == "system": - lines.append("\n") - lines.append(self._format_colored("─" * self._width, Fore.MAGENTA)) - lines.append(self._format_colored("🔧 SYSTEM", Style.BRIGHT, Fore.MAGENTA)) - lines.append(self._format_colored("─" * self._width, Fore.MAGENTA)) - else: - lines.append("\n") - lines.append(self._format_colored("─" * self._width, Fore.YELLOW)) - role_label = "ASSISTANT (SIMULATED)" if message.is_simulated else message.api_role.upper() - lines.append(self._format_colored(f"🔸 {role_label}", Style.BRIGHT, Fore.YELLOW)) - lines.append(self._format_colored("─" * self._width, Fore.YELLOW)) - - for piece in message.message_pieces: - if piece.original_value_data_type == "reasoning": - if include_reasoning_trace: - summary_text = self._extract_reasoning_summary(piece.original_value) - if summary_text: - lines.append(self._format_colored( - f"{self._indent}💭 Reasoning Summary:", Style.DIM, Fore.CYAN - )) - lines.append(self._render_wrapped_text(summary_text, Fore.CYAN)) - lines.append("\n") - continue - - if piece.is_blocked(): - lines.append(self._format_colored( - f"{self._indent}🚫 BLOCKED BY TARGET", Style.BRIGHT, Fore.RED - )) - partial_content = piece.prompt_metadata.get("partial_content") - if partial_content: - 
lines.append(self._format_colored( - f"{self._indent}📝 Partial content (before filter triggered):", - Style.DIM, - Fore.CYAN, - )) - lines.append(self._render_wrapped_text(str(partial_content), Fore.YELLOW)) - else: - lines.append(self._format_colored( - f"{self._indent}Content was blocked by the target's content filter.", - Style.DIM, - Fore.RED, - )) - - elif piece.converted_value != piece.original_value: - lines.append(self._format_colored(f"{self._indent} Original:", Fore.CYAN)) - lines.append(self._render_wrapped_text(piece.original_value, Fore.WHITE)) - lines.append("\n") - lines.append(self._format_colored(f"{self._indent} Converted:", Fore.CYAN)) - lines.append(self._render_wrapped_text(piece.converted_value, Fore.WHITE)) - elif piece.api_role == "user": - lines.append(self._render_wrapped_text(piece.converted_value, Fore.BLUE)) - elif piece.api_role == "system": - lines.append(self._render_wrapped_text(piece.converted_value, Fore.MAGENTA)) - else: - lines.append(self._render_wrapped_text(piece.converted_value, Fore.YELLOW)) - - image_pieces.append(piece) - - if include_scores: - scores = await self._get_scores_async(prompt_ids=[str(piece.id)]) - if scores: - lines.append("\n") - lines.append(self._format_colored( - f"{self._indent}📊 Scores:", Style.DIM, Fore.MAGENTA - )) - for score in scores: - lines.append(self._render_score(score)) - - lines.append("\n") - lines.append(self._format_colored("─" * self._width, Fore.BLUE)) - - for piece in image_pieces: - await self._display_image_async(piece) - - return "".join(lines) - async def print_messages_async( self, messages: list[Message], @@ -263,37 +174,15 @@ async def print_messages_async( include_scores: bool = False, include_reasoning_trace: bool = False, ) -> None: - """Deprecated. Use write_async instead.""" + """Deprecated. Use the conversation printer's write_async instead.""" warnings.warn( "print_messages_async is deprecated, use write_async instead", DeprecationWarning, stacklevel=2 ) - content = await self._render_messages_async( - messages=messages, include_scores=include_scores, include_reasoning_trace=include_reasoning_trace + content = await self._conversation_printer.render_async( + messages, include_scores=include_scores, include_reasoning_trace=include_reasoning_trace ) await self._write_async(content) - def _extract_reasoning_summary(self, reasoning_value: str) -> str: - """ - Extract human-readable summary text from a reasoning piece's JSON value. - - Args: - reasoning_value (str): The JSON string stored in the reasoning piece. - - Returns: - str: The concatenated summary text, or empty string if no summary is present. - """ - try: - data = json.loads(reasoning_value) - except (json.JSONDecodeError, TypeError): - return "" - - summary = data.get("summary") if isinstance(data, dict) else None - if not summary or not isinstance(summary, list): - return "" - - parts = [item.get("text", "") for item in summary if isinstance(item, dict) and item.get("text")] - return "\n".join(parts) - async def _render_summary_async(self, result: AttackResult) -> str: """ Render a summary of the attack result. 
@@ -343,7 +232,7 @@ async def _render_summary_async(self, result: AttackResult) -> str: if result.last_score: lines.append("\n") lines.append(self._format_colored(f"{self._indent} Final Score", Style.BRIGHT)) - lines.append(self._render_score(result.last_score, indent_level=2)) + lines.append(self._score_printer._render_score(result.last_score, indent_level=2)) return "".join(lines) @@ -421,87 +310,6 @@ def _render_metadata(self, metadata: dict[str, Any]) -> str: lines.append(self._format_colored(f"{self._indent}• {key}: {value}", Fore.CYAN)) return "".join(lines) - def _render_score(self, score: Score, indent_level: int = 3) -> str: - """ - Render a score with proper formatting. - - Args: - score (Score): Score object to be rendered. - indent_level (int): Number of indent units to apply. Defaults to 3. - - Returns: - str: The rendered score text. - """ - lines: list[str] = [] - indent = self._indent * indent_level - scorer_name = score.scorer_class_identifier.class_name - lines.append(f"{indent}Scorer: {scorer_name}\n") - lines.append(self._format_colored(f"{indent}• Category: {score.score_category or 'N/A'}", Fore.LIGHTMAGENTA_EX)) - lines.append(self._format_colored(f"{indent}• Type: {score.score_type}", Fore.CYAN)) - - if score.score_type == "true_false": - score_color = Fore.GREEN if score.get_value() else Fore.RED - else: - score_color = Fore.YELLOW - - lines.append(self._format_colored(f"{indent}• Value: {score.score_value}", score_color)) - - if score.score_rationale: - lines.append(f"{indent}• Rationale:\n") - rationale_wrapper = textwrap.TextWrapper( - width=self._width - len(indent) - 2, - initial_indent=indent + " ", - subsequent_indent=indent + " ", - break_long_words=False, - break_on_hyphens=False, - ) - rationale_lines = score.score_rationale.split("\n") - for line in rationale_lines: - if line.strip(): - wrapped_lines = rationale_wrapper.wrap(line) - for wrapped_line in wrapped_lines: - lines.append(self._format_colored(wrapped_line, Fore.WHITE)) - else: - lines.append(self._format_colored(f"{indent} ")) - - return "".join(lines) - - def _render_wrapped_text(self, text: str, color: str) -> str: - """ - Render text with proper wrapping and indentation, preserving newlines. - - Args: - text (str): The text to render. - color (str): Colorama color constant to apply. - - Returns: - str: The rendered wrapped text. - """ - lines: list[str] = [] - text_wrapper = textwrap.TextWrapper( - width=self._width - len(self._indent), - initial_indent="", - subsequent_indent=self._indent, - break_long_words=True, - break_on_hyphens=True, - expand_tabs=False, - replace_whitespace=False, - ) - - text_lines = text.split("\n") - for line_num, line in enumerate(text_lines): - if line.strip(): - wrapped_lines = text_wrapper.wrap(line) - for i, wrapped_line in enumerate(wrapped_lines): - if line_num == 0 and i == 0: - lines.append(self._format_colored(f"{self._indent}{wrapped_line}", color)) - else: - lines.append(self._format_colored(f"{self._indent * 2}{wrapped_line}", color)) - else: - lines.append(self._format_colored(f"{self._indent}", color)) - - return "".join(lines) - async def _render_pruned_conversations_async(self, result: AttackResult) -> str: """ Render pruned conversations showing only the last message and score for each. 
@@ -544,14 +352,14 @@ async def _render_pruned_conversations_async(self, result: AttackResult) -> str: )) for piece in last_message.message_pieces: - lines.append(self._render_wrapped_text(piece.converted_value, Fore.WHITE)) + lines.append(self._conversation_printer._render_wrapped_text(piece.converted_value, Fore.WHITE)) scores = await self._get_scores_async(prompt_ids=[str(piece.id)]) if scores: lines.append("\n") lines.append(self._format_colored(f"{self._indent}📊 Score:", Style.DIM, Fore.MAGENTA)) for score in scores: - lines.append(self._render_score(score)) + lines.append(self._score_printer._render_score(score)) lines.append("\n") lines.append(self._format_colored("─" * self._width, Fore.RED)) @@ -597,7 +405,7 @@ async def _render_adversarial_conversation_async(self, result: AttackResult) -> )) continue - lines.append(await self._render_messages_async(messages=messages, include_scores=False)) + lines.append(await self._conversation_printer.render_async(messages, include_scores=False)) return "".join(lines) @@ -619,12 +427,6 @@ def _get_outcome_color(self, outcome: AttackOutcome) -> str: }.get(outcome, Fore.WHITE) ) - async def _display_image_async(self, piece: MessagePiece) -> None: - """Display images using PIL/IPython in notebook environments.""" - from pyrit.common.display_response import display_image_response - - await display_image_response(piece) - class PrettyAttackResultMemoryPrinter(PrettyAttackResultPrinter): """ @@ -646,9 +448,20 @@ def __init__( indent_size (int): Number of spaces for indentation. Defaults to 2. enable_colors (bool): Whether to enable ANSI color output. Defaults to True. """ - super().__init__(sink=sink, width=width, indent_size=indent_size, enable_colors=enable_colors) from pyrit.memory import CentralMemory + from pyrit.printer.conversation.pretty import PrettyConversationMemoryPrinter + score_printer = PrettyScorePrinter( + sink=sink, width=width, indent_size=indent_size, enable_colors=enable_colors + ) + conversation_printer = PrettyConversationMemoryPrinter( + sink=sink, width=width, indent_size=indent_size, enable_colors=enable_colors, + score_printer=score_printer, + ) + super().__init__( + sink=sink, width=width, indent_size=indent_size, enable_colors=enable_colors, + conversation_printer=conversation_printer, score_printer=score_printer, + ) self._memory = CentralMemory.get_memory_instance() async def _get_conversation_async(self, conversation_id: str) -> list[Message]: diff --git a/pyrit/printer/conversation/__init__.py b/pyrit/printer/conversation/__init__.py new file mode 100644 index 0000000000..b875221383 --- /dev/null +++ b/pyrit/printer/conversation/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Conversation printer classes for rendering message histories.""" diff --git a/pyrit/printer/conversation/base.py b/pyrit/printer/conversation/base.py new file mode 100644 index 0000000000..a4ff944e3a --- /dev/null +++ b/pyrit/printer/conversation/base.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import abstractmethod + +from pyrit.models import Message, MessagePiece, Score +from pyrit.printer.base import PrinterBase + + +class ConversationPrinterBase(PrinterBase): + """ + Abstract base class for printing conversation message histories. + + Subclasses implement data-fetching methods (``_get_scores_async``, + ``_display_image_async``) and rendering via ``render_async``. 
+ """ + + @abstractmethod + async def _get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: + """ + Fetch scores for given prompt piece IDs. + + Args: + prompt_ids (list[str]): The message piece IDs to fetch scores for. + + Returns: + list[Score]: The scores associated with the given piece IDs. + """ + + async def _display_image_async(self, piece: MessagePiece) -> None: # noqa: B027 + """ + Display an image from a message piece. No-op by default. + + Args: + piece (MessagePiece): The message piece that may contain image data. + """ + + @abstractmethod + async def render_async( + self, + messages: list[Message], + *, + include_scores: bool = False, + include_reasoning_trace: bool = False, + ) -> str: + """ + Render a list of messages and return as a string. + + Args: + messages (list[Message]): The messages to render. + include_scores (bool): Whether to include scores. Defaults to False. + include_reasoning_trace (bool): Whether to include reasoning traces. Defaults to False. + + Returns: + str: The rendered conversation text. + """ diff --git a/pyrit/printer/conversation/pretty.py b/pyrit/printer/conversation/pretty.py new file mode 100644 index 0000000000..436202d80c --- /dev/null +++ b/pyrit/printer/conversation/pretty.py @@ -0,0 +1,303 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json +import textwrap + +from colorama import Fore, Style + +from pyrit.models import Message, MessagePiece, Score +from pyrit.printer.conversation.base import ConversationPrinterBase +from pyrit.printer.score.pretty import PrettyScorePrinter +from pyrit.printer.sink import Sink + + +class PrettyConversationPrinter(ConversationPrinterBase): + """ + Pretty printer for conversation message histories with ANSI-colored formatting. + + Contains all formatting logic. Subclasses implement ``_get_scores_async`` + and ``_display_image_async`` for data fetching. + """ + + def __init__( + self, + *, + sink: Sink | None = None, + width: int = 100, + indent_size: int = 2, + enable_colors: bool = True, + score_printer: PrettyScorePrinter | None = None, + ) -> None: + """ + Initialize the pretty conversation printer. + + Args: + sink (Sink | None): Output sink. Defaults to StdoutSink(). + width (int): Maximum width for text wrapping. Defaults to 100. + indent_size (int): Number of spaces for indentation. Defaults to 2. + enable_colors (bool): Whether to enable ANSI color output. Defaults to True. + score_printer (PrettyScorePrinter | None): Score printer for inline score rendering. + Defaults to a new PrettyScorePrinter with matching settings. + """ + super().__init__(sink=sink) + self._width = width + self._indent = " " * indent_size + self._enable_colors = enable_colors + self._score_printer = score_printer or PrettyScorePrinter( + sink=sink, width=width, indent_size=indent_size, enable_colors=enable_colors + ) + + async def render_async( + self, + messages: list[Message], + *, + include_scores: bool = False, + include_reasoning_trace: bool = False, + ) -> str: + """ + Render a list of messages and return as a string. + + Args: + messages (list[Message]): The messages to render. + include_scores (bool): Whether to include scores. Defaults to False. + include_reasoning_trace (bool): Whether to include reasoning traces. Defaults to False. + + Returns: + str: The rendered conversation text. 
+ """ + if not messages: + return self._format_colored(f"{self._indent} No messages to display.", Fore.YELLOW) + + lines: list[str] = [] + image_pieces: list[MessagePiece] = [] + turn_number = 0 + for message in messages: + if message.api_role == "user": + turn_number += 1 + lines.append("\n") + lines.append(self._format_colored("─" * self._width, Fore.BLUE)) + lines.append(self._format_colored(f"🔹 Turn {turn_number} - USER", Style.BRIGHT, Fore.BLUE)) + lines.append(self._format_colored("─" * self._width, Fore.BLUE)) + elif message.api_role == "system": + lines.append("\n") + lines.append(self._format_colored("─" * self._width, Fore.MAGENTA)) + lines.append(self._format_colored("🔧 SYSTEM", Style.BRIGHT, Fore.MAGENTA)) + lines.append(self._format_colored("─" * self._width, Fore.MAGENTA)) + else: + lines.append("\n") + lines.append(self._format_colored("─" * self._width, Fore.YELLOW)) + role_label = "ASSISTANT (SIMULATED)" if message.is_simulated else message.api_role.upper() + lines.append(self._format_colored(f"🔸 {role_label}", Style.BRIGHT, Fore.YELLOW)) + lines.append(self._format_colored("─" * self._width, Fore.YELLOW)) + + for piece in message.message_pieces: + if piece.original_value_data_type == "reasoning": + if include_reasoning_trace: + summary_text = self._extract_reasoning_summary(piece.original_value) + if summary_text: + lines.append(self._format_colored( + f"{self._indent}💭 Reasoning Summary:", Style.DIM, Fore.CYAN + )) + lines.append(self._render_wrapped_text(summary_text, Fore.CYAN)) + lines.append("\n") + continue + + if piece.is_blocked(): + lines.append(self._format_colored( + f"{self._indent}🚫 BLOCKED BY TARGET", Style.BRIGHT, Fore.RED + )) + partial_content = piece.prompt_metadata.get("partial_content") + if partial_content: + lines.append(self._format_colored( + f"{self._indent}📝 Partial content (before filter triggered):", + Style.DIM, + Fore.CYAN, + )) + lines.append(self._render_wrapped_text(str(partial_content), Fore.YELLOW)) + else: + lines.append(self._format_colored( + f"{self._indent}Content was blocked by the target's content filter.", + Style.DIM, + Fore.RED, + )) + + elif piece.converted_value != piece.original_value: + lines.append(self._format_colored(f"{self._indent} Original:", Fore.CYAN)) + lines.append(self._render_wrapped_text(piece.original_value, Fore.WHITE)) + lines.append("\n") + lines.append(self._format_colored(f"{self._indent} Converted:", Fore.CYAN)) + lines.append(self._render_wrapped_text(piece.converted_value, Fore.WHITE)) + elif piece.api_role == "user": + lines.append(self._render_wrapped_text(piece.converted_value, Fore.BLUE)) + elif piece.api_role == "system": + lines.append(self._render_wrapped_text(piece.converted_value, Fore.MAGENTA)) + else: + lines.append(self._render_wrapped_text(piece.converted_value, Fore.YELLOW)) + + image_pieces.append(piece) + + if include_scores: + scores = await self._get_scores_async(prompt_ids=[str(piece.id)]) + if scores: + lines.append("\n") + lines.append(self._format_colored( + f"{self._indent}📊 Scores:", Style.DIM, Fore.MAGENTA + )) + for score in scores: + lines.append(self._score_printer._render_score(score)) + + lines.append("\n") + lines.append(self._format_colored("─" * self._width, Fore.BLUE)) + + for piece in image_pieces: + await self._display_image_async(piece) + + return "".join(lines) + + def _format_colored(self, text: str, *colors: str) -> str: + """ + Format text with color codes if colors are enabled. + + Args: + text (str): The text to format. 
+ *colors: Variable number of colorama color constants to apply. + + Returns: + str: The formatted line with trailing newline. + """ + if self._enable_colors and colors: + color_prefix = "".join(colors) + return f"{color_prefix}{text}{Style.RESET_ALL}\n" + return f"{text}\n" + + def _render_wrapped_text(self, text: str, color: str) -> str: + """ + Render text with proper wrapping and indentation, preserving newlines. + + Args: + text (str): The text to render. + color (str): Colorama color constant to apply. + + Returns: + str: The rendered wrapped text. + """ + lines: list[str] = [] + text_wrapper = textwrap.TextWrapper( + width=self._width - len(self._indent), + initial_indent="", + subsequent_indent=self._indent, + break_long_words=True, + break_on_hyphens=True, + expand_tabs=False, + replace_whitespace=False, + ) + + text_lines = text.split("\n") + for line_num, line in enumerate(text_lines): + if line.strip(): + wrapped_lines = text_wrapper.wrap(line) + for i, wrapped_line in enumerate(wrapped_lines): + if line_num == 0 and i == 0: + lines.append(self._format_colored(f"{self._indent}{wrapped_line}", color)) + else: + lines.append(self._format_colored(f"{self._indent * 2}{wrapped_line}", color)) + else: + lines.append(self._format_colored(f"{self._indent}", color)) + + return "".join(lines) + + @staticmethod + def _extract_reasoning_summary(reasoning_value: str) -> str: + """ + Extract human-readable summary text from a reasoning piece's JSON value. + + Args: + reasoning_value (str): The JSON string stored in the reasoning piece. + + Returns: + str: The concatenated summary text, or empty string if no summary is present. + """ + try: + data = json.loads(reasoning_value) + except (json.JSONDecodeError, TypeError): + return "" + + summary = data.get("summary") if isinstance(data, dict) else None + if not summary or not isinstance(summary, list): + return "" + + parts = [item.get("text", "") for item in summary if isinstance(item, dict) and item.get("text")] + return "\n".join(parts) + + +class PrettyConversationMemoryPrinter(PrettyConversationPrinter): + """ + Framework pretty printer for conversation histories. + + Implements data-fetching via CentralMemory (deferred import). + All formatting logic lives in PrettyConversationPrinter. + """ + + def __init__( + self, + *, + sink: Sink | None = None, + width: int = 100, + indent_size: int = 2, + enable_colors: bool = True, + score_printer: PrettyScorePrinter | None = None, + ) -> None: + """ + Initialize the pretty conversation printer with CentralMemory data source. + + Args: + sink (Sink | None): Output sink. Defaults to StdoutSink(). + width (int): Maximum width for text wrapping. Defaults to 100. + indent_size (int): Number of spaces for indentation. Defaults to 2. + enable_colors (bool): Whether to enable ANSI color output. Defaults to True. + score_printer (PrettyScorePrinter | None): Score printer for inline score rendering. + """ + super().__init__( + sink=sink, width=width, indent_size=indent_size, enable_colors=enable_colors, score_printer=score_printer + ) + from pyrit.memory import CentralMemory + + self._memory = CentralMemory.get_memory_instance() + + async def render_async( + self, + messages: list[Message], + *, + include_scores: bool = False, + include_reasoning_trace: bool = False, + ) -> str: + """ + Render a list of messages and return as a string. + + Args: + messages (list[Message]): The messages to render. + include_scores (bool): Whether to include scores. Defaults to False. 
+ include_reasoning_trace (bool): Whether to include reasoning traces. Defaults to False. + + Returns: + str: The rendered conversation text. + """ + return await super().render_async( + messages, include_scores=include_scores, include_reasoning_trace=include_reasoning_trace + ) + + async def _get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: + """ + Fetch scores from CentralMemory. + + Returns: + list[Score]: The scores. + """ + return list(self._memory.get_prompt_scores(prompt_ids=prompt_ids)) + + async def _display_image_async(self, piece: MessagePiece) -> None: + """Display images using PIL/IPython in notebook environments.""" + from pyrit.common.display_response import display_image_response + + await display_image_response(piece) diff --git a/pyrit/printer/score/__init__.py b/pyrit/printer/score/__init__.py new file mode 100644 index 0000000000..1bab79f286 --- /dev/null +++ b/pyrit/printer/score/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Score printer classes for rendering individual Score objects.""" diff --git a/pyrit/printer/score/pretty.py b/pyrit/printer/score/pretty.py new file mode 100644 index 0000000000..4afb3f306e --- /dev/null +++ b/pyrit/printer/score/pretty.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import textwrap + +from colorama import Fore, Style + +from pyrit.models import Score +from pyrit.printer.base import PrinterBase +from pyrit.printer.sink import Sink + + +class PrettyScorePrinter(PrinterBase): + """ + Pretty printer for individual Score objects with ANSI-colored formatting. + + Provides ``_render_score`` for inline use by other printers (e.g., + conversation and attack-result printers) and ``render_async`` / + ``write_async`` for standalone rendering of a list of scores. + """ + + def __init__( + self, *, sink: Sink | None = None, width: int = 100, indent_size: int = 2, enable_colors: bool = True + ) -> None: + """ + Initialize the pretty score printer. + + Args: + sink (Sink | None): Output sink. Defaults to StdoutSink(). + width (int): Maximum width for text wrapping. Defaults to 100. + indent_size (int): Number of spaces for indentation. Defaults to 2. + enable_colors (bool): Whether to enable ANSI color output. Defaults to True. + """ + super().__init__(sink=sink) + self._width = width + self._indent = " " * indent_size + self._enable_colors = enable_colors + + def _format_colored(self, text: str, *colors: str) -> str: + """ + Format text with color codes if colors are enabled. + + Args: + text (str): The text to format. + *colors: Variable number of colorama color constants to apply. + + Returns: + str: The formatted line with trailing newline. + """ + if self._enable_colors and colors: + color_prefix = "".join(colors) + return f"{color_prefix}{text}{Style.RESET_ALL}\n" + return f"{text}\n" + + def _render_score(self, score: Score, indent_level: int = 3) -> str: + """ + Render a single score with proper formatting. + + Args: + score (Score): Score object to be rendered. + indent_level (int): Number of indent units to apply. Defaults to 3. + + Returns: + str: The rendered score text. 
+ """ + lines: list[str] = [] + indent = self._indent * indent_level + scorer_name = score.scorer_class_identifier.class_name + lines.append(f"{indent}Scorer: {scorer_name}\n") + lines.append(self._format_colored(f"{indent}• Category: {score.score_category or 'N/A'}", Fore.LIGHTMAGENTA_EX)) + lines.append(self._format_colored(f"{indent}• Type: {score.score_type}", Fore.CYAN)) + + if score.score_type == "true_false": + score_color = Fore.GREEN if score.get_value() else Fore.RED + else: + score_color = Fore.YELLOW + + lines.append(self._format_colored(f"{indent}• Value: {score.score_value}", score_color)) + + if score.score_rationale: + lines.append(f"{indent}• Rationale:\n") + rationale_wrapper = textwrap.TextWrapper( + width=self._width - len(indent) - 2, + initial_indent=indent + " ", + subsequent_indent=indent + " ", + break_long_words=False, + break_on_hyphens=False, + ) + rationale_lines = score.score_rationale.split("\n") + for line in rationale_lines: + if line.strip(): + wrapped_lines = rationale_wrapper.wrap(line) + for wrapped_line in wrapped_lines: + lines.append(self._format_colored(wrapped_line, Fore.WHITE)) + else: + lines.append(self._format_colored(f"{indent} ")) + + return "".join(lines) + + async def render_async(self, scores: list[Score], *, indent_level: int = 3) -> str: + """ + Render a list of scores and return as a string. + + Args: + scores (list[Score]): The scores to render. + indent_level (int): Number of indent units to apply. Defaults to 3. + + Returns: + str: The rendered scores text. + """ + return "".join(self._render_score(score, indent_level=indent_level) for score in scores) diff --git a/tests/unit/executor/attack/printer/test_pretty_printer.py b/tests/unit/executor/attack/printer/test_pretty_printer.py index 9f3f919399..057b8cb9bb 100644 --- a/tests/unit/executor/attack/printer/test_pretty_printer.py +++ b/tests/unit/executor/attack/printer/test_pretty_printer.py @@ -152,7 +152,7 @@ def test_render_metadata(printer): def test_render_score(printer, sample_score): - result = printer._render_score(sample_score) + result = printer._score_printer._render_score(sample_score) assert "MockScorer" in result assert "true_false" in result assert "true" in result @@ -169,7 +169,7 @@ def test_render_score_with_rationale(printer): message_piece_id=str(uuid.uuid4()), scorer_class_identifier=_mock_scorer_id(), ) - result = printer._render_score(score) + result = printer._score_printer._render_score(score) assert "Rationale" in result @@ -177,26 +177,26 @@ def test_extract_reasoning_summary_valid_json(printer): import json data = {"summary": [{"text": "First"}, {"text": "Second"}]} - result = printer._extract_reasoning_summary(json.dumps(data)) + result = printer._conversation_printer._extract_reasoning_summary(json.dumps(data)) assert result == "First\nSecond" def test_extract_reasoning_summary_invalid_json(printer): - result = printer._extract_reasoning_summary("not json") + result = printer._conversation_printer._extract_reasoning_summary("not json") assert result == "" def test_extract_reasoning_summary_no_summary_key(printer): import json - result = printer._extract_reasoning_summary(json.dumps({"other": "data"})) + result = printer._conversation_printer._extract_reasoning_summary(json.dumps({"other": "data"})) assert result == "" def test_extract_reasoning_summary_summary_not_list(printer): import json - result = printer._extract_reasoning_summary(json.dumps({"summary": "not a list"})) + result = 
printer._conversation_printer._extract_reasoning_summary(json.dumps({"summary": "not a list"})) assert result == "" @@ -214,13 +214,13 @@ async def test_render_conversation_async_no_messages(printer, mock_memory): async def test_render_messages_async_empty_list(printer): - content = await printer._render_messages_async(messages=[]) + content = await printer._conversation_printer.render_async(messages=[]) assert "No messages to display" in content @patch("pyrit.common.display_response.display_image_response", new_callable=AsyncMock) async def test_render_messages_async_user_message(mock_display, printer, sample_message): - content = await printer._render_messages_async(messages=[sample_message]) + content = await printer._conversation_printer.render_async(messages=[sample_message]) assert "Turn 1" in content assert "USER" in content assert "Hello world" in content @@ -235,7 +235,7 @@ async def test_render_messages_async_assistant_message(mock_display, printer): converted_value_data_type="text", ) msg = Message(message_pieces=[piece]) - content = await printer._render_messages_async(messages=[msg]) + content = await printer._conversation_printer.render_async(messages=[msg]) assert "Response" in content @@ -248,7 +248,7 @@ async def test_render_messages_async_converted_differs(mock_display, printer): converted_value_data_type="text", ) msg = Message(message_pieces=[piece]) - content = await printer._render_messages_async(messages=[msg]) + content = await printer._conversation_printer.render_async(messages=[msg]) assert "Original" in content assert "Converted" in content @@ -318,12 +318,12 @@ async def test_render_adversarial_conversation_no_refs(printer): def test_render_wrapped_text(printer): - result = printer._render_wrapped_text("Short text", "") + result = printer._conversation_printer._render_wrapped_text("Short text", "") assert "Short text" in result def test_render_wrapped_text_with_newlines(printer): - result = printer._render_wrapped_text("Line one\nLine two\n\nLine four", "") + result = printer._conversation_printer._render_wrapped_text("Line one\nLine two\n\nLine four", "") assert "Line one" in result assert "Line two" in result assert "Line four" in result @@ -338,7 +338,7 @@ async def test_render_messages_async_blocked_without_partial_content(mock_displa response_error="blocked", ) msg = Message(message_pieces=[piece]) - content = await printer._render_messages_async(messages=[msg]) + content = await printer._conversation_printer.render_async(messages=[msg]) assert "BLOCKED BY TARGET" in content assert "content filter" in content # Should NOT print the raw error JSON as the message body @@ -355,7 +355,7 @@ async def test_render_messages_async_blocked_with_partial_content(mock_display, prompt_metadata={"partial_content": "The model started to say something before being cut off"}, ) msg = Message(message_pieces=[piece]) - content = await printer._render_messages_async(messages=[msg]) + content = await printer._conversation_printer.render_async(messages=[msg]) assert "BLOCKED BY TARGET" in content assert "Partial content" in content assert "before filter triggered" in content From 1bdc53d42b4d4c107a56891a26603e86f9fff1a2 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 23:09:16 -0700 Subject: [PATCH 22/34] Add render_async to leaf classes, update printer instructions Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/instructions/printer.instructions.md | 100 +++++++++++++++---- pyrit/printer/attack_result/markdown.py | 28 ++++++ 
pyrit/printer/attack_result/pretty.py | 28 ++++++ pyrit/printer/scenario_result/pretty.py | 12 +++ pyrit/printer/scorer/pretty.py | 15 +++ 5 files changed, 162 insertions(+), 21 deletions(-) diff --git a/.github/instructions/printer.instructions.md b/.github/instructions/printer.instructions.md index f46d9c446c..9265ae4b8c 100644 --- a/.github/instructions/printer.instructions.md +++ b/.github/instructions/printer.instructions.md @@ -4,15 +4,35 @@ applyTo: "pyrit/printer/**" # PyRIT Printer Module Guidelines -The printer module renders attack results, scenario results, and scorer information. It separates **what** the output looks like (format) from **where** it goes (sink) and **where data comes from** (abstract methods). +The printer module renders attack results, scenario results, conversation histories, scores, and scorer information. It separates **what** the output looks like (format) from **where** it goes (sink) and **where data comes from** (abstract methods). ## Architecture +### The `render_async` / `write_async` contract + +Every printer follows this contract, enforced by `PrinterBase`: + +```python +class PrinterBase(ABC): + @abstractmethod + async def render_async(self, *args, **kwargs) -> str: + """Return the rendered output string. Subclasses define the real signature.""" + + async def write_async(self, *args, **kwargs) -> None: + """Concrete. Calls render_async then writes to sink. Do not override.""" + content = await self.render_async(*args, **kwargs) + await self._write_async(content) +``` + +- **`render_async`** — abstract, returns `str`. Pure formatting. Easy to test and compose. +- **`write_async`** — concrete in base. Calls `render_async` → `_write_async`. **Nobody overrides this.** +- Composition: printers call each other's `render_async` to embed sub-sections. + ### Three-layer hierarchy per domain ``` -DomainPrinterBase(PrinterBase) # base.py — abstract data methods + write_async - ├─ PrettyDomainPrinter # pretty.py — ANSI formatting, returns str +DomainPrinterBase(PrinterBase) # base.py — abstract data methods + ├─ PrettyDomainPrinter # pretty.py — ANSI formatting, implements render_async │ └─ PrettyDomainMemoryPrinter # same file — fetches data via CentralMemory ├─ MarkdownDomainPrinter # markdown.py — Markdown formatting │ └─ MarkdownDomainMemoryPrinter @@ -20,9 +40,23 @@ DomainPrinterBase(PrinterBase) # base.py — abstract data methods + wr └─ JsonDomainMemoryPrinter ``` -- **Base** (`base.py`): declares abstract data-fetching methods and `write_async` -- **Format** (`pretty.py`, `markdown.py`, `json.py`): all rendering logic, builds `str`, writes to sink — **no data I/O here** -- **Leaf** (e.g., `PrettyAttackResultMemoryPrinter`): implements abstract data methods via `CentralMemory` — **no formatting logic here** +- **Base** (`base.py`): declares abstract data-fetching methods and abstract `render_async` +- **Format** (`pretty.py`, `markdown.py`, `json.py`): implements `render_async`, returns `str` — **no data I/O here** +- **Leaf** (e.g., `PrettyAttackResultMemoryPrinter`): implements abstract data methods via `CentralMemory`, has a forwarding `render_async` at the top for discoverability — **no formatting logic here** + +### Domain modules + +``` +pyrit/printer/ +├── base.py # PrinterBase — render_async (abstract) + write_async (concrete) +├── sink.py # Sink, StdoutSink, FileSink +├── helpers.py # Convenience functions (print_attack_result_async, etc.) 
+├── attack_result/ # Attack result printing — composes conversation + score printers +├── conversation/ # Conversation/message rendering (extracted from attack_result) +├── score/ # Individual Score object rendering (extracted from attack_result) +├── scorer/ # Scorer metrics/evaluation display +└── scenario_result/ # Scenario result printing +``` ### Sink — where output goes @@ -35,18 +69,26 @@ class Sink(ABC): Current sinks: `StdoutSink`, `FileSink`. Add new sinks as needed (IPython, Blob, etc.). -### PrinterBase — common base +### Composition pattern -All printers inherit `PrinterBase`. It provides: -- `sink` constructor param (default `StdoutSink`) -- `_write_async(data: str)` to write through the sink -- Abstract `write_async(...)` as the **public entry point** (signature varies per domain) +The attack result printer composes conversation and score printers: + +```python +class PrettyAttackResultPrinter(AttackResultPrinterBase): + def __init__(self, *, conversation_printer=None, score_printer=None, ...): + self._conversation_printer = conversation_printer or PrettyConversationPrinter(...) + self._score_printer = score_printer or PrettyScorePrinter(...) + + async def render_async(self, result, ...) -> str: + # Uses self._conversation_printer.render_async(messages) for conversation sections + # Uses self._score_printer._render_score(score) for inline scores +``` ## Key Rules ### Output goes through the sink — never call `print()` directly -All `_render_*` methods return `str`. The `write_async` entry point concatenates renders and calls `_write_async(content)`. No bare `print()` calls anywhere in the printer module except inside `StdoutSink`. +All `_render_*` methods return `str`. The inherited `write_async` calls `render_async` then `_write_async(content)`. No bare `print()` calls anywhere in the printer module except inside `StdoutSink`. ### Data fetching belongs in leaf classes only @@ -58,14 +100,30 @@ Format classes (`PrettyAttackResultPrinter`, `MarkdownAttackResultPrinter`) must - `markdown.py` — Markdown - `json.py` — structured JSON -### `write_async` is the only public entry point +### `render_async` and `write_async` are the public entry points -Each printer has one public method: `write_async(...)`. Old methods like `print_result_async`, `print_summary_async`, `print_objective_scorer` are deprecated wrappers that call `write_async`. +- `render_async(...)` → `str` — the primary method subclasses implement +- `write_async(...)` → `None` — concrete in base, calls render + sink. Do not override. +- Old methods like `print_result_async`, `print_summary_async`, `print_objective_scorer` are deprecated wrappers with `DeprecationWarning`. ### All other methods are private Prefix with `_`: `_format_colored`, `_render_header`, `_render_summary_async`, `_get_conversation_async`, `_get_scores_async`, etc. +### Leaf classes surface `render_async` at the top + +Every `*MemoryPrinter` leaf class has a forwarding `render_async` override right after `__init__` so readers immediately see the full signature and entry point: + +```python +class PrettyAttackResultMemoryPrinter(PrettyAttackResultPrinter): + def __init__(self, ...): ... + + async def render_async(self, result, ...) -> str: + return await super().render_async(result, ...) + + # data-fetching methods below +``` + ### Memory leaf classes must work with zero args ```python @@ -79,17 +137,17 @@ Pass `sink=` to redirect output. Pass sub-printers only to override defaults. 
 ```python
 from pyrit.printer.helpers import print_attack_result_async
-await print_attack_result_async(result, format="pretty", to=Path("out.txt"))
+await print_attack_result_async(result, format="pretty", sink=Path("out.txt"))
 ```

-`helpers.py` resolves `format` → printer class, `to` → sink, and calls `write_async`.
+`helpers.py` resolves `format` → printer class, `sink` → Sink, and calls `write_async`.

 ## Adding a New Format

 1. Create `<domain>/<format>.py` (e.g., `attack_result/json.py`)
 2. Subclass the domain base (e.g., `AttackResultPrinterBase`)
-3. Implement `write_async` — build a `str` from `_render_*` methods, call `_write_async`
-4. Add a `*MemoryPrinter` leaf class that implements the abstract data methods
+3. Implement `render_async` — build and return a `str` from `_render_*` methods
+4. Add a `*MemoryPrinter` leaf class with forwarding `render_async` + data methods
 5. Register in `helpers.py` format dispatch

 ## Adding a New Sink
@@ -100,7 +158,7 @@ await print_attack_result_async(result, format="pretty", to=Path("out.txt"))

 ## Adding a New Domain Printer

-1. Create `pyrit/printer/<domain>/base.py` with abstract data methods + `write_async`
-2. Create format files (`pretty.py`, etc.) with rendering logic
-3. Add Memory leaf classes
+1. Create `pyrit/printer/<domain>/base.py` with abstract data methods + abstract `render_async`
+2. Create format files (`pretty.py`, etc.) with `render_async` implementation
+3. Add Memory leaf classes with forwarding `render_async` + data methods
 4. Add convenience function in `helpers.py`
diff --git a/pyrit/printer/attack_result/markdown.py b/pyrit/printer/attack_result/markdown.py
index b05d3e5294..a55b3f9ea4 100644
--- a/pyrit/printer/attack_result/markdown.py
+++ b/pyrit/printer/attack_result/markdown.py
@@ -567,6 +567,34 @@ def __init__(self, *, sink: Sink | None = None, display_inline: bool = True) ->
         self._memory = CentralMemory.get_memory_instance()

+    async def render_async(
+        self,
+        result: AttackResult,
+        *,
+        include_auxiliary_scores: bool = False,
+        include_pruned_conversations: bool = False,
+        include_adversarial_conversation: bool = False,
+    ) -> str:
+        """
+        Render the complete attack result as markdown and return it as a string.
+
+        Args:
+            result (AttackResult): The attack result to render.
+            include_auxiliary_scores (bool): Whether to include auxiliary scores. Defaults to False.
+            include_pruned_conversations (bool): Whether to include pruned conversations. Defaults to False.
+            include_adversarial_conversation (bool): Whether to include the adversarial conversation.
+                Defaults to False.
+
+        Returns:
+            str: The rendered markdown text.
+        """
+        return await super().render_async(
+            result,
+            include_auxiliary_scores=include_auxiliary_scores,
+            include_pruned_conversations=include_pruned_conversations,
+            include_adversarial_conversation=include_adversarial_conversation,
+        )
+
     async def _get_conversation_async(self, conversation_id: str) -> list[Message]:
         """
         Fetch conversation messages from CentralMemory.
diff --git a/pyrit/printer/attack_result/pretty.py b/pyrit/printer/attack_result/pretty.py index 57e2c07a08..0ef9ebc4ef 100644 --- a/pyrit/printer/attack_result/pretty.py +++ b/pyrit/printer/attack_result/pretty.py @@ -464,6 +464,34 @@ def __init__( ) self._memory = CentralMemory.get_memory_instance() + async def render_async( + self, + result: AttackResult, + *, + include_auxiliary_scores: bool = False, + include_pruned_conversations: bool = False, + include_adversarial_conversation: bool = False, + ) -> str: + """ + Render the complete attack result and return it as a string. + + Args: + result (AttackResult): The attack result to render. + include_auxiliary_scores (bool): Whether to include auxiliary scores. Defaults to False. + include_pruned_conversations (bool): Whether to include pruned conversations. Defaults to False. + include_adversarial_conversation (bool): Whether to include the adversarial conversation. + Defaults to False. + + Returns: + str: The rendered attack result text. + """ + return await super().render_async( + result, + include_auxiliary_scores=include_auxiliary_scores, + include_pruned_conversations=include_pruned_conversations, + include_adversarial_conversation=include_adversarial_conversation, + ) + async def _get_conversation_async(self, conversation_id: str) -> list[Message]: """ Fetch conversation messages from CentralMemory. diff --git a/pyrit/printer/scenario_result/pretty.py b/pyrit/printer/scenario_result/pretty.py index 6ff5e3c4b7..6bd8ec36f8 100644 --- a/pyrit/printer/scenario_result/pretty.py +++ b/pyrit/printer/scenario_result/pretty.py @@ -260,3 +260,15 @@ def __init__( self._scorer_printer = PrettyScorerMemoryPrinter( sink=self._sink, indent_size=indent_size, enable_colors=enable_colors ) + + async def render_async(self, result: ScenarioResult) -> str: + """ + Render the scenario result summary and return it as a string. + + Args: + result (ScenarioResult): The scenario result to summarize. + + Returns: + str: The rendered scenario result text. + """ + return await super().render_async(result) diff --git a/pyrit/printer/scorer/pretty.py b/pyrit/printer/scorer/pretty.py index 570f9241b2..4c173532ca 100644 --- a/pyrit/printer/scorer/pretty.py +++ b/pyrit/printer/scorer/pretty.py @@ -287,6 +287,21 @@ class PrettyScorerMemoryPrinter(PrettyScorerPrinter): All formatting logic lives in PrettyScorerPrinter. """ + async def render_async( + self, *, scorer_identifier: ComponentIdentifier, harm_category: str | None = None + ) -> str: + """ + Render scorer information and return it as a string. + + Args: + scorer_identifier (ComponentIdentifier): The scorer identifier. + harm_category (str | None): The harm category. None for objective scorers. + + Returns: + str: The rendered scorer information text. + """ + return await super().render_async(scorer_identifier=scorer_identifier, harm_category=harm_category) + def _get_objective_metrics(self, *, eval_hash: str) -> Any: """ Fetch objective scorer evaluation metrics from the registry. 
From 17b8bc7cb8519de69172e13cc681361ffdc18d33 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 14 May 2026 23:57:15 -0700 Subject: [PATCH 23/34] Add IPythonMarkdownSink, get_default_sink auto-detection, markdown conversation/score printers, refactor display_image_response Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/printer/attack_result/markdown.py | 298 +++------------- pyrit/printer/attack_result/pretty.py | 14 +- pyrit/printer/conversation/markdown.py | 320 ++++++++++++++++++ pyrit/printer/conversation/pretty.py | 35 +- pyrit/printer/helpers.py | 88 +++-- pyrit/printer/score/markdown.py | 76 +++++ pyrit/printer/sink.py | 50 ++- .../attack/core/test_markdown_printer.py | 41 ++- .../attack/printer/test_pretty_printer.py | 15 +- tests/unit/printer/test_convenience.py | 113 +++++-- tests/unit/printer/test_sink.py | 51 ++- 11 files changed, 748 insertions(+), 353 deletions(-) create mode 100644 pyrit/printer/conversation/markdown.py create mode 100644 pyrit/printer/score/markdown.py diff --git a/pyrit/printer/attack_result/markdown.py b/pyrit/printer/attack_result/markdown.py index a55b3f9ea4..66fad0c09a 100644 --- a/pyrit/printer/attack_result/markdown.py +++ b/pyrit/printer/attack_result/markdown.py @@ -1,12 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import os import warnings from datetime import datetime, timezone -from pyrit.models import AttackResult, ConversationType, Message, MessagePiece, Score +from pyrit.models import AttackResult, ConversationType, Message, Score from pyrit.printer.attack_result.base import AttackResultPrinterBase +from pyrit.printer.conversation.markdown import MarkdownConversationPrinter +from pyrit.printer.score.markdown import MarkdownScorePrinter from pyrit.printer.sink import Sink @@ -14,11 +15,18 @@ class MarkdownAttackResultPrinter(AttackResultPrinterBase): """ Markdown printer for attack results optimized for Jupyter notebooks. - Contains all formatting logic. Subclasses implement get_conversation_async - and get_scores_async for data fetching. + Composes a conversation printer for message rendering and a score printer + for inline score display. Subclasses implement data-fetching methods. """ - def __init__(self, *, sink: Sink | None = None, display_inline: bool = True) -> None: + def __init__( + self, + *, + sink: Sink | None = None, + display_inline: bool = True, + conversation_printer: MarkdownConversationPrinter | None = None, + score_printer: MarkdownScorePrinter | None = None, + ) -> None: """ Initialize the markdown printer. @@ -26,48 +34,17 @@ def __init__(self, *, sink: Sink | None = None, display_inline: bool = True) -> sink (Sink | None): Output sink. Defaults to StdoutSink(). display_inline (bool): Kept for backward compatibility but unused. All output is routed through the sink. Defaults to True. + conversation_printer (MarkdownConversationPrinter | None): Conversation printer. + Defaults to a new MarkdownConversationPrinter with matching sink. + score_printer (MarkdownScorePrinter | None): Score printer. + Defaults to a new MarkdownScorePrinter with matching sink. """ super().__init__(sink=sink) self._display_inline = display_inline - - def _format_score(self, score: Score, indent: str = "") -> str: - """ - Format a score object as markdown with proper styling. - - Args: - score (Score): The score object to format. - indent (str): String prefix for indentation. Defaults to "". - - Returns: - str: Formatted markdown representation of the score. 
- """ - lines = [] - - score_value = score.get_value() - if isinstance(score_value, bool): - value_str = str(score_value) - elif isinstance(score_value, (int, float)): - value_str = f"**{score_value:.2f}**" if isinstance(score_value, float) else f"**{score_value}**" - else: - value_str = f"**{score_value}**" - - lines.append(f"{indent}- **Score Type:** {score.score_type}") - lines.append(f"{indent}- **Value:** {value_str}") - category_str = ", ".join(score.score_category) if score.score_category else "N/A" - lines.append(f"{indent}- **Category:** {category_str}") - - if score.score_rationale: - rationale_lines = score.score_rationale.split("\n") - if len(rationale_lines) > 1: - lines.append(f"{indent}- **Rationale:**") - lines.extend(f"{indent} {line}" for line in rationale_lines) - else: - lines.append(f"{indent}- **Rationale:** {score.score_rationale}") - - if score.score_metadata: - lines.append(f"{indent}- **Metadata:** `{score.score_metadata}`") - - return "\n".join(lines) + self._score_printer = score_printer or MarkdownScorePrinter(sink=sink) + self._conversation_printer = conversation_printer or MarkdownConversationPrinter( + sink=sink, score_printer=self._score_printer, + ) async def render_async( self, @@ -153,8 +130,8 @@ async def print_conversation_async(self, result: AttackResult, *, include_scores warnings.warn( "print_conversation_async is deprecated, use write_async instead", DeprecationWarning, stacklevel=2 ) - markdown_lines = await self._get_conversation_markdown_async(result=result, include_scores=include_scores) - await self._write_async("\n".join(markdown_lines)) + lines = await self._get_conversation_markdown_async(result=result, include_scores=include_scores) + await self._write_async("\n".join(lines)) async def print_summary_async(self, result: AttackResult) -> None: """Deprecated. Use write_async instead.""" @@ -175,223 +152,16 @@ async def _get_conversation_markdown_async( Returns: list[str]: Markdown strings for the conversation. """ - markdown_lines: list[str] = [] - if not result.conversation_id: - markdown_lines.append("*No conversation ID available*\n") - return markdown_lines + return ["*No conversation ID available*\n"] messages = await self._get_conversation_async(result.conversation_id) if not messages: - markdown_lines.append(f"*No conversation found for ID: {result.conversation_id}*\n") - return markdown_lines - - turn_number = 0 - - for message in messages: - if not message.message_pieces: - continue - - message_role = message.get_piece().api_role - - if message_role == "system": - markdown_lines.extend(self._format_system_message(message)) - elif message_role == "user": - turn_number += 1 - markdown_lines.extend(await self._format_user_message_async(message=message, turn_number=turn_number)) - else: - markdown_lines.extend(await self._format_assistant_message_async(message=message)) - - if include_scores: - markdown_lines.extend(await self._format_message_scores_async(message)) - - return markdown_lines - - def _format_system_message(self, message: Message) -> list[str]: - """ - Format a system message as markdown. - - Args: - message (Message): The system message to format. - - Returns: - list[str]: Markdown strings for the system message. - """ - lines = ["\n### System Message\n"] - lines.extend(f"{piece.converted_value}\n" for piece in message.message_pieces) - return lines - - async def _format_user_message_async(self, *, message: Message, turn_number: int) -> list[str]: - """ - Format a user message as markdown with turn numbering. 
- - Args: - message (Message): The user message to format. - turn_number (int): The conversation turn number. - - Returns: - list[str]: Markdown strings for the user message. - """ - lines = [f"\n### Turn {turn_number}\n", "#### User\n"] - - for piece in message.message_pieces: - lines.extend(await self._format_piece_content_async(piece=piece, show_original=True)) - - return lines + return [f"*No conversation found for ID: {result.conversation_id}*\n"] - async def _format_assistant_message_async(self, *, message: Message) -> list[str]: - """ - Format an assistant response message as markdown. - - Args: - message (Message): The response message to format. - - Returns: - list[str]: Markdown strings for the response message. - """ - lines: list[str] = [] - piece = message.message_pieces[0] - role_name = "Assistant (Simulated)" if piece.is_simulated else piece.api_role.capitalize() - - lines.append(f"\n#### {role_name}\n") - - for piece in message.message_pieces: - lines.extend(await self._format_piece_content_async(piece=piece, show_original=False)) - - return lines - - def _get_audio_mime_type(self, *, audio_path: str) -> str: - """ - Determine the MIME type for an audio file based on its file extension. - - Args: - audio_path (str): The path to the audio file. - - Returns: - str: The appropriate MIME type for the audio file. - """ - if audio_path.lower().endswith(".wav"): - return "audio/wav" - if audio_path.lower().endswith(".ogg"): - return "audio/ogg" - if audio_path.lower().endswith(".m4a"): - return "audio/mp4" - return "audio/mpeg" - - def _format_image_content(self, *, image_path: str) -> list[str]: - """ - Format image content as markdown. - - Args: - image_path (str): The path to the image file. - - Returns: - list[str]: Markdown lines for the image. - """ - relative_path = os.path.relpath(image_path) - posix_path = relative_path.replace("\\", "/") - return [f"![Image]({posix_path})\n"] - - def _format_audio_content(self, *, audio_path: str) -> list[str]: - """ - Format audio content as HTML5 audio player. - - Args: - audio_path (str): The path to the audio file. - - Returns: - list[str]: Markdown lines for the audio player. - """ - lines: list[str] = [] - lines.append("\n") - - return lines - - def _format_error_content(self, *, piece: MessagePiece) -> list[str]: - """ - Format error response content with proper styling. - - Args: - piece (MessagePiece): The message piece containing the error. - - Returns: - list[str]: Markdown lines for the error response. - """ - lines: list[str] = [] - lines.append("**Error Response:**\n") - lines.append(f"*Error Type: {piece.response_error}*\n") - lines.append("```json") - lines.append(piece.converted_value) - lines.append("```\n") - - return lines - - def _format_text_content(self, *, piece: MessagePiece, show_original: bool) -> list[str]: - """ - Format regular text content. - - Args: - piece (MessagePiece): The message piece containing the text. - show_original (bool): Whether to show original value if different. - - Returns: - list[str]: Markdown lines for the text content. - """ - lines: list[str] = [] - - if show_original and piece.converted_value != piece.original_value: - lines.append("**Original:**\n") - lines.append(f"{piece.original_value}\n") - lines.append("\n**Converted:**\n") - - lines.append(f"{piece.converted_value}\n") - - return lines - - async def _format_piece_content_async(self, *, piece: MessagePiece, show_original: bool) -> list[str]: - """ - Format a single piece content based on its data type. 
- - Args: - piece (MessagePiece): The message piece to format. - show_original (bool): Whether to show original value if different. - - Returns: - list[str]: Markdown lines for this piece. - """ - if piece.converted_value_data_type == "image_path": - return self._format_image_content(image_path=piece.converted_value) - if piece.converted_value_data_type == "audio_path": - return self._format_audio_content(audio_path=piece.converted_value) - if piece.has_error(): - return self._format_error_content(piece=piece) - return self._format_text_content(piece=piece, show_original=show_original) - - async def _format_message_scores_async(self, message: Message) -> list[str]: - """ - Format scores for all pieces in a message as markdown. - - Args: - message (Message): The message containing pieces to format scores for. - - Returns: - list[str]: Markdown strings for the scores. - """ - lines: list[str] = [] - for piece in message.message_pieces: - scores = await self._get_scores_async(prompt_ids=[str(piece.id)]) - if scores: - lines.append("\n##### Scores\n") - lines.extend(self._format_score(score, indent="") for score in scores) - lines.append("") - return lines + rendered = await self._conversation_printer.render_async(messages, include_scores=include_scores) + return [rendered] async def _get_summary_markdown_async(self, result: AttackResult) -> list[str]: """ @@ -432,7 +202,7 @@ async def _get_summary_markdown_async(self, result: AttackResult) -> list[str]: if result.last_score: markdown_lines.append("\n### Final Score\n") - markdown_lines.append(self._format_score(result.last_score)) + markdown_lines.append(self._score_printer._format_score(result.last_score)) return markdown_lines @@ -484,7 +254,9 @@ async def _get_pruned_conversations_markdown_async(self, result: AttackResult) - scores = await self._get_scores_async(prompt_ids=[str(piece.id)]) if scores: markdown_lines.append("\n**Score:**\n") - markdown_lines.extend(self._format_score(score, indent="") for score in scores) + markdown_lines.extend( + self._score_printer._format_score(score, indent="") for score in scores + ) return markdown_lines @@ -562,9 +334,15 @@ def __init__(self, *, sink: Sink | None = None, display_inline: bool = True) -> display_inline (bool): Kept for backward compatibility but unused. All output is routed through the sink. Defaults to True. """ - super().__init__(sink=sink, display_inline=display_inline) from pyrit.memory import CentralMemory + from pyrit.printer.conversation.markdown import MarkdownConversationMemoryPrinter + score_printer = MarkdownScorePrinter(sink=sink) + conversation_printer = MarkdownConversationMemoryPrinter(sink=sink, score_printer=score_printer) + super().__init__( + sink=sink, display_inline=display_inline, + conversation_printer=conversation_printer, score_printer=score_printer, + ) self._memory = CentralMemory.get_memory_instance() async def render_async( diff --git a/pyrit/printer/attack_result/pretty.py b/pyrit/printer/attack_result/pretty.py index 0ef9ebc4ef..222075856f 100644 --- a/pyrit/printer/attack_result/pretty.py +++ b/pyrit/printer/attack_result/pretty.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import warnings +from pyrit.common.deprecation import print_deprecation_message from datetime import datetime, timezone from typing import Any @@ -117,7 +117,7 @@ async def print_result_async( include_adversarial_conversation: bool = False, ) -> None: """Deprecated. 
Use write_async instead.""" - warnings.warn("print_result_async is deprecated, use write_async instead", DeprecationWarning, stacklevel=2) + print_deprecation_message(old_item="print_result_async", new_item="write_async", removed_in="2.0") await self.write_async( result, include_auxiliary_scores=include_auxiliary_scores, @@ -159,8 +159,8 @@ async def print_conversation_async( self, result: AttackResult, *, include_scores: bool = False, include_reasoning_trace: bool = False ) -> None: """Deprecated. Use write_async instead.""" - warnings.warn( - "print_conversation_async is deprecated, use write_async instead", DeprecationWarning, stacklevel=2 + print_deprecation_message( + old_item="print_conversation_async", new_item="write_async", removed_in="2.0" ) content = await self._render_conversation_async( result, include_scores=include_scores, include_reasoning_trace=include_reasoning_trace @@ -175,8 +175,8 @@ async def print_messages_async( include_reasoning_trace: bool = False, ) -> None: """Deprecated. Use the conversation printer's write_async instead.""" - warnings.warn( - "print_messages_async is deprecated, use write_async instead", DeprecationWarning, stacklevel=2 + print_deprecation_message( + old_item="print_messages_async", new_item="write_async", removed_in="2.0" ) content = await self._conversation_printer.render_async( messages, include_scores=include_scores, include_reasoning_trace=include_reasoning_trace @@ -238,7 +238,7 @@ async def _render_summary_async(self, result: AttackResult) -> str: async def print_summary_async(self, result: AttackResult) -> None: """Deprecated. Use write_async instead.""" - warnings.warn("print_summary_async is deprecated, use write_async instead", DeprecationWarning, stacklevel=2) + print_deprecation_message(old_item="print_summary_async", new_item="write_async", removed_in="2.0") content = await self._render_summary_async(result) await self._write_async(content) diff --git a/pyrit/printer/conversation/markdown.py b/pyrit/printer/conversation/markdown.py new file mode 100644 index 0000000000..29395f4eef --- /dev/null +++ b/pyrit/printer/conversation/markdown.py @@ -0,0 +1,320 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os + +from pyrit.models import Message, MessagePiece, Score +from pyrit.printer.conversation.base import ConversationPrinterBase +from pyrit.printer.score.markdown import MarkdownScorePrinter +from pyrit.printer.sink import Sink + + +class MarkdownConversationPrinter(ConversationPrinterBase): + """ + Markdown printer for conversation message histories. + + Contains all formatting logic. Subclasses implement ``_get_scores_async`` + for data fetching. + """ + + def __init__( + self, + *, + sink: Sink | None = None, + score_printer: MarkdownScorePrinter | None = None, + ) -> None: + """ + Initialize the markdown conversation printer. + + Args: + sink (Sink | None): Output sink. Defaults to StdoutSink(). + score_printer (MarkdownScorePrinter | None): Score printer for inline score rendering. + Defaults to a new MarkdownScorePrinter with matching sink. + """ + super().__init__(sink=sink) + self._score_printer = score_printer or MarkdownScorePrinter(sink=sink) + + async def render_async( + self, + messages: list[Message], + *, + include_scores: bool = False, + include_reasoning_trace: bool = False, + ) -> str: + """ + Render a list of messages as markdown and return as a string. + + Args: + messages (list[Message]): The messages to render. + include_scores (bool): Whether to include scores. 
Defaults to False. + include_reasoning_trace (bool): Accepted for interface compatibility. Unused. + + Returns: + str: The rendered conversation markdown text. + """ + if not messages: + return "*No messages to display*\n" + + markdown_lines: list[str] = [] + turn_number = 0 + + for message in messages: + if not message.message_pieces: + continue + + message_role = message.get_piece().api_role + + if message_role == "system": + markdown_lines.extend(self._format_system_message(message)) + elif message_role == "user": + turn_number += 1 + markdown_lines.extend( + await self._format_user_message_async(message=message, turn_number=turn_number) + ) + else: + markdown_lines.extend(await self._format_assistant_message_async(message=message)) + + if include_scores: + markdown_lines.extend(await self._format_message_scores_async(message)) + + return "\n".join(markdown_lines) + + def _format_system_message(self, message: Message) -> list[str]: + """ + Format a system message as markdown. + + Args: + message (Message): The system message to format. + + Returns: + list[str]: Markdown strings for the system message. + """ + lines = ["\n### System Message\n"] + lines.extend(f"{piece.converted_value}\n" for piece in message.message_pieces) + return lines + + async def _format_user_message_async(self, *, message: Message, turn_number: int) -> list[str]: + """ + Format a user message as markdown with turn numbering. + + Args: + message (Message): The user message to format. + turn_number (int): The conversation turn number. + + Returns: + list[str]: Markdown strings for the user message. + """ + lines = [f"\n### Turn {turn_number}\n", "#### User\n"] + + for piece in message.message_pieces: + lines.extend(await self._format_piece_content_async(piece=piece, show_original=True)) + + return lines + + async def _format_assistant_message_async(self, *, message: Message) -> list[str]: + """ + Format an assistant response message as markdown. + + Args: + message (Message): The response message to format. + + Returns: + list[str]: Markdown strings for the response message. + """ + lines: list[str] = [] + piece = message.message_pieces[0] + role_name = "Assistant (Simulated)" if piece.is_simulated else piece.api_role.capitalize() + + lines.append(f"\n#### {role_name}\n") + + for piece in message.message_pieces: + lines.extend(await self._format_piece_content_async(piece=piece, show_original=False)) + + return lines + + async def _format_piece_content_async(self, *, piece: MessagePiece, show_original: bool) -> list[str]: + """ + Format a single piece content based on its data type. + + Args: + piece (MessagePiece): The message piece to format. + show_original (bool): Whether to show original value if different. + + Returns: + list[str]: Markdown lines for this piece. + """ + if piece.converted_value_data_type == "image_path": + return self._format_image_content(image_path=piece.converted_value) + if piece.converted_value_data_type == "audio_path": + return self._format_audio_content(audio_path=piece.converted_value) + if piece.has_error(): + return self._format_error_content(piece=piece) + return self._format_text_content(piece=piece, show_original=show_original) + + def _format_text_content(self, *, piece: MessagePiece, show_original: bool) -> list[str]: + """ + Format regular text content. + + Args: + piece (MessagePiece): The message piece containing the text. + show_original (bool): Whether to show original value if different. + + Returns: + list[str]: Markdown lines for the text content. 
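+
+        Example (illustrative sketch; values are hypothetical)::
+
+            # piece.original_value == "hi", piece.converted_value == "HI"
+            self._format_text_content(piece=piece, show_original=True)
+            # -> ["**Original:**\n", "hi\n", "\n**Converted:**\n", "HI\n"]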
+ """ + lines: list[str] = [] + + if show_original and piece.converted_value != piece.original_value: + lines.append("**Original:**\n") + lines.append(f"{piece.original_value}\n") + lines.append("\n**Converted:**\n") + + lines.append(f"{piece.converted_value}\n") + + return lines + + def _format_image_content(self, *, image_path: str) -> list[str]: + """ + Format image content as markdown. + + Args: + image_path (str): The path to the image file. + + Returns: + list[str]: Markdown lines for the image. + """ + relative_path = os.path.relpath(image_path) + posix_path = relative_path.replace("\\", "/") + return [f"![Image]({posix_path})\n"] + + def _format_audio_content(self, *, audio_path: str) -> list[str]: + """ + Format audio content as HTML5 audio player. + + Args: + audio_path (str): The path to the audio file. + + Returns: + list[str]: Markdown lines for the audio player. + """ + lines: list[str] = [] + lines.append("\n") + return lines + + @staticmethod + def _get_audio_mime_type(*, audio_path: str) -> str: + """ + Determine the MIME type for an audio file based on its file extension. + + Args: + audio_path (str): The path to the audio file. + + Returns: + str: The appropriate MIME type for the audio file. + """ + if audio_path.lower().endswith(".wav"): + return "audio/wav" + if audio_path.lower().endswith(".ogg"): + return "audio/ogg" + if audio_path.lower().endswith(".m4a"): + return "audio/mp4" + return "audio/mpeg" + + def _format_error_content(self, *, piece: MessagePiece) -> list[str]: + """ + Format error response content with proper styling. + + Args: + piece (MessagePiece): The message piece containing the error. + + Returns: + list[str]: Markdown lines for the error response. + """ + lines: list[str] = [] + lines.append("**Error Response:**\n") + lines.append(f"*Error Type: {piece.response_error}*\n") + lines.append("```json") + lines.append(piece.converted_value) + lines.append("```\n") + return lines + + async def _format_message_scores_async(self, message: Message) -> list[str]: + """ + Format scores for all pieces in a message as markdown. + + Args: + message (Message): The message containing pieces to format scores for. + + Returns: + list[str]: Markdown strings for the scores. + """ + lines: list[str] = [] + for piece in message.message_pieces: + scores = await self._get_scores_async(prompt_ids=[str(piece.id)]) + if scores: + lines.append("\n##### Scores\n") + lines.extend(self._score_printer._format_score(score, indent="") for score in scores) + lines.append("") + return lines + + +class MarkdownConversationMemoryPrinter(MarkdownConversationPrinter): + """ + Framework markdown printer for conversation histories. + + Implements data-fetching via CentralMemory (deferred import). + All formatting logic lives in MarkdownConversationPrinter. + """ + + def __init__( + self, + *, + sink: Sink | None = None, + score_printer: MarkdownScorePrinter | None = None, + ) -> None: + """ + Initialize the markdown conversation printer with CentralMemory data source. + + Args: + sink (Sink | None): Output sink. Defaults to StdoutSink(). + score_printer (MarkdownScorePrinter | None): Score printer for inline score rendering. 
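+
+        Example (illustrative sketch; ``messages`` is a hypothetical ``list[Message]``
+        and CentralMemory is assumed to be initialized)::
+
+            printer = MarkdownConversationMemoryPrinter()
+            markdown = await printer.render_async(messages, include_scores=True)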
+ """ + super().__init__(sink=sink, score_printer=score_printer) + from pyrit.memory import CentralMemory + + self._memory = CentralMemory.get_memory_instance() + + async def render_async( + self, + messages: list[Message], + *, + include_scores: bool = False, + include_reasoning_trace: bool = False, + ) -> str: + """ + Render a list of messages as markdown and return as a string. + + Args: + messages (list[Message]): The messages to render. + include_scores (bool): Whether to include scores. Defaults to False. + include_reasoning_trace (bool): Accepted for interface compatibility. Unused. + + Returns: + str: The rendered conversation markdown text. + """ + return await super().render_async( + messages, include_scores=include_scores, include_reasoning_trace=include_reasoning_trace + ) + + async def _get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: + """ + Fetch scores from CentralMemory. + + Returns: + list[Score]: The scores. + """ + return list(self._memory.get_prompt_scores(prompt_ids=prompt_ids)) diff --git a/pyrit/printer/conversation/pretty.py b/pyrit/printer/conversation/pretty.py index 436202d80c..6d48f8712c 100644 --- a/pyrit/printer/conversation/pretty.py +++ b/pyrit/printer/conversation/pretty.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. import json +import logging import textwrap from colorama import Fore, Style @@ -11,6 +12,8 @@ from pyrit.printer.score.pretty import PrettyScorePrinter from pyrit.printer.sink import Sink +logger = logging.getLogger(__name__) + class PrettyConversationPrinter(ConversationPrinterBase): """ @@ -297,7 +300,33 @@ async def _get_scores_async(self, *, prompt_ids: list[str]) -> list[Score]: return list(self._memory.get_prompt_scores(prompt_ids=prompt_ids)) async def _display_image_async(self, piece: MessagePiece) -> None: - """Display images using PIL/IPython in notebook environments.""" - from pyrit.common.display_response import display_image_response + """ + Display an image from a message piece in notebook environments. + + Uses ``DataTypeSerializer.read_data`` for transparent storage access + (local disk or Azure Blob) and ``IPython.display.Image`` for rendering. + No-op outside notebook environments. + + Args: + piece (MessagePiece): The message piece that may contain image data. 
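+
+        Example (illustrative sketch; the piece values are hypothetical)::
+
+            # piece.converted_value_data_type == "image_path"
+            # piece.converted_value == "images/generated_1.png"
+            await self._display_image_async(piece)  # renders inline in a notebook, no-op elsewhere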
+ """ + if piece.converted_value_data_type != "image_path" or piece.response_error != "none": + return + + from pyrit.common.notebook_utils import is_in_ipython_session + + if not is_in_ipython_session(): + return + + from pyrit.models.data_type_serializer import ImagePathDataTypeSerializer + + try: + serializer = ImagePathDataTypeSerializer(category="", prompt_text=piece.converted_value) + image_bytes = await serializer.read_data() + except Exception as e: + logger.error(f"Failed to read image from {piece.converted_value}: {e}") + return + + from IPython.display import Image, display - await display_image_response(piece) + display(Image(data=image_bytes)) diff --git a/pyrit/printer/helpers.py b/pyrit/printer/helpers.py index 4687a62c89..f8c55934c6 100644 --- a/pyrit/printer/helpers.py +++ b/pyrit/printer/helpers.py @@ -3,16 +3,14 @@ """Convenience functions for one-line printing of attack results, scenario results, and scorer info.""" -from pathlib import Path - -from pyrit.printer.sink import OutputFormat, Sink, resolve_sink +from pyrit.printer.sink import OutputFormat, Sink, StdoutSink, get_default_sink async def print_attack_result_async( result: "AttackResult", # noqa: F821 *, format: OutputFormat = "pretty", - sink: Path | str | Sink | None = None, + sink: Sink | None = None, include_auxiliary_scores: bool = False, include_pruned_conversations: bool = False, include_adversarial_conversation: bool = False, @@ -23,22 +21,21 @@ async def print_attack_result_async( Args: result (AttackResult): The attack result to print. format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". - sink (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. + sink (Sink | None): Output sink. Defaults to StdoutSink for "pretty"; auto-detects + (IPythonMarkdownSink in notebooks, StdoutSink otherwise) for "markdown". include_auxiliary_scores (bool): Whether to include auxiliary scores. Defaults to False. include_pruned_conversations (bool): Whether to include pruned conversations. Defaults to False. include_adversarial_conversation (bool): Whether to include the adversarial conversation. Defaults to False. """ - resolved_sink = resolve_sink(sink) - if format == "markdown": from pyrit.printer.attack_result.markdown import MarkdownAttackResultMemoryPrinter - printer = MarkdownAttackResultMemoryPrinter(sink=resolved_sink) + printer = MarkdownAttackResultMemoryPrinter(sink=sink or get_default_sink()) else: from pyrit.printer.attack_result.pretty import PrettyAttackResultMemoryPrinter - printer = PrettyAttackResultMemoryPrinter(sink=resolved_sink) + printer = PrettyAttackResultMemoryPrinter(sink=sink or get_default_sink(StdoutSink)) await printer.write_async( result, @@ -52,7 +49,7 @@ async def print_scenario_result_async( result: "ScenarioResult", # noqa: F821 *, format: OutputFormat = "pretty", - sink: Path | str | Sink | None = None, + sink: Sink | None = None, ) -> None: """ Print a scenario result in the specified format to the specified destination. @@ -60,14 +57,12 @@ async def print_scenario_result_async( Args: result (ScenarioResult): The scenario result to print. format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". - sink (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. + sink (Sink | None): Output sink. Defaults to StdoutSink. 
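+
+    Example (illustrative sketch; ``result`` is a hypothetical ScenarioResult)::
+
+        await print_scenario_result_async(result)  # pretty output to stdout
+        # or route the output to a file (FileSink is in pyrit.printer.sink, Path in pathlib):
+        await print_scenario_result_async(result, sink=FileSink(path=Path("scenario.txt")))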
""" - resolved_sink = resolve_sink(sink) - if format == "pretty": from pyrit.printer.scenario_result.pretty import PrettyScenarioResultMemoryPrinter - printer = PrettyScenarioResultMemoryPrinter(sink=resolved_sink) + printer = PrettyScenarioResultMemoryPrinter(sink=sink or get_default_sink(StdoutSink)) else: raise ValueError(f"Unsupported format for scenario results: {format!r}. Only 'pretty' is available.") @@ -79,7 +74,7 @@ async def print_scorer_async( scorer_identifier: "ComponentIdentifier", # noqa: F821 harm_category: str | None = None, format: OutputFormat = "pretty", - sink: Path | str | Sink | None = None, + sink: Sink | None = None, ) -> None: """ Print scorer information in the specified format to the specified destination. @@ -91,15 +86,70 @@ async def print_scorer_async( scorer_identifier (ComponentIdentifier): The scorer identifier. harm_category (str | None): The harm category. None for objective scorers. format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". - sink (Path | str | Sink | None): Destination — None for stdout, path for file, or Sink instance. + sink (Sink | None): Output sink. Defaults to StdoutSink. """ - resolved_sink = resolve_sink(sink) - if format == "pretty": from pyrit.printer.scorer.pretty import PrettyScorerMemoryPrinter - printer = PrettyScorerMemoryPrinter(sink=resolved_sink) + printer = PrettyScorerMemoryPrinter(sink=sink or get_default_sink(StdoutSink)) else: raise ValueError(f"Unsupported format for scorer: {format!r}. Only 'pretty' is available.") await printer.write_async(scorer_identifier=scorer_identifier, harm_category=harm_category) + + +async def print_conversation_async( + messages: "list[Message]", # noqa: F821 + *, + format: OutputFormat = "pretty", + sink: Sink | None = None, + include_scores: bool = False, + include_reasoning_trace: bool = False, +) -> None: + """ + Print a conversation message history in the specified format. + + Args: + messages (list[Message]): The messages to print. + format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". + sink (Sink | None): Output sink. Defaults to StdoutSink for "pretty", IPythonMarkdownSink + for "markdown". + include_scores (bool): Whether to include scores. Defaults to False. + include_reasoning_trace (bool): Whether to include reasoning traces. Defaults to False. + """ + if format == "pretty": + from pyrit.printer.conversation.pretty import PrettyConversationMemoryPrinter + + printer = PrettyConversationMemoryPrinter(sink=sink or get_default_sink(StdoutSink)) + else: + raise ValueError(f"Unsupported format for conversation: {format!r}. Only 'pretty' is available.") + + await printer.write_async( + messages, + include_scores=include_scores, + include_reasoning_trace=include_reasoning_trace, + ) + + +async def print_score_async( + scores: "list[Score]", # noqa: F821 + *, + format: OutputFormat = "pretty", + sink: Sink | None = None, +) -> None: + """ + Print a list of scores in the specified format. + + Args: + scores (list[Score]): The scores to print. + format (OutputFormat): Output format — "pretty" or "markdown". Defaults to "pretty". + sink (Sink | None): Output sink. Defaults to StdoutSink. + """ + if format == "pretty": + from pyrit.printer.score.pretty import PrettyScorePrinter + + printer = PrettyScorePrinter(sink=sink or get_default_sink(StdoutSink)) + else: + raise ValueError(f"Unsupported format for scores: {format!r}. 
Only 'pretty' is available.") + + await printer.write_async(scores) diff --git a/pyrit/printer/score/markdown.py b/pyrit/printer/score/markdown.py new file mode 100644 index 0000000000..9191f52549 --- /dev/null +++ b/pyrit/printer/score/markdown.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from pyrit.models import Score +from pyrit.printer.base import PrinterBase +from pyrit.printer.sink import Sink + + +class MarkdownScorePrinter(PrinterBase): + """ + Markdown printer for individual Score objects. + + Provides ``_format_score`` for inline use by other printers and + ``render_async`` / ``write_async`` for standalone rendering. + """ + + def __init__(self, *, sink: Sink | None = None) -> None: + """ + Initialize the markdown score printer. + + Args: + sink (Sink | None): Output sink. Defaults to StdoutSink(). + """ + super().__init__(sink=sink) + + def _format_score(self, score: Score, indent: str = "") -> str: + """ + Format a score object as markdown with proper styling. + + Args: + score (Score): The score object to format. + indent (str): String prefix for indentation. Defaults to "". + + Returns: + str: Formatted markdown representation of the score. + """ + lines: list[str] = [] + + score_value = score.get_value() + if isinstance(score_value, bool): + value_str = str(score_value) + elif isinstance(score_value, (int, float)): + value_str = f"**{score_value:.2f}**" if isinstance(score_value, float) else f"**{score_value}**" + else: + value_str = f"**{score_value}**" + + lines.append(f"{indent}- **Score Type:** {score.score_type}") + lines.append(f"{indent}- **Value:** {value_str}") + category_str = ", ".join(score.score_category) if score.score_category else "N/A" + lines.append(f"{indent}- **Category:** {category_str}") + + if score.score_rationale: + rationale_lines = score.score_rationale.split("\n") + if len(rationale_lines) > 1: + lines.append(f"{indent}- **Rationale:**") + lines.extend(f"{indent} {line}" for line in rationale_lines) + else: + lines.append(f"{indent}- **Rationale:** {score.score_rationale}") + + if score.score_metadata: + lines.append(f"{indent}- **Metadata:** `{score.score_metadata}`") + + return "\n".join(lines) + + async def render_async(self, scores: list[Score], *, indent: str = "") -> str: + """ + Render a list of scores as markdown and return as a string. + + Args: + scores (list[Score]): The scores to render. + indent (str): String prefix for indentation. Defaults to "". + + Returns: + str: The rendered scores markdown text. + """ + return "\n".join(self._format_score(score, indent=indent) for score in scores) diff --git a/pyrit/printer/sink.py b/pyrit/printer/sink.py index 21347c8100..f605110e9b 100644 --- a/pyrit/printer/sink.py +++ b/pyrit/printer/sink.py @@ -76,21 +76,47 @@ async def write_async(self, data: str) -> None: f.write(data) -def resolve_sink(to: Path | str | Sink | None) -> Sink: +class IPythonMarkdownSink(Sink): """ - Resolve a destination argument to a Sink instance. + Sink that renders markdown via IPython's ``display(Markdown(...))``. + + Falls back to ``print()`` if IPython is not available (e.g., outside + a Jupyter notebook). + """ + + async def write_async(self, data: str) -> None: + """ + Display data as rendered markdown in IPython, or print to stdout. + + Args: + data (str): The markdown text to display. 
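+
+        Example (illustrative sketch)::
+
+            sink = IPythonMarkdownSink()
+            await sink.write_async("## Summary\n- outcome: success\n")
+            # rendered as markdown in a notebook; printed as plain text elsewhere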
+ """ + try: + from IPython.display import Markdown, display + + display(Markdown(data)) + except (ImportError, NameError): + print(data, end="") + + +def get_default_sink(default: type[Sink] | None = None) -> Sink: + """ + Return the appropriate default sink for the current environment. + + When ``default`` is None, auto-detects: uses ``IPythonMarkdownSink`` + inside Jupyter/IPython notebooks, otherwise ``StdoutSink``. Args: - to (Path | str | Sink | None): The destination. - None → StdoutSink. - Path or str → FileSink. - Sink instance → used as-is. + default (type[Sink] | None): Sink class to instantiate. + None means auto-detect based on environment. Returns: - Sink: The resolved sink. + Sink: The default sink instance. """ - if to is None: - return StdoutSink() - if isinstance(to, Sink): - return to - return FileSink(path=Path(to)) + if default is not None: + return default() + from pyrit.common.notebook_utils import is_in_ipython_session + + if is_in_ipython_session(): + return IPythonMarkdownSink() + return StdoutSink() diff --git a/tests/unit/executor/attack/core/test_markdown_printer.py b/tests/unit/executor/attack/core/test_markdown_printer.py index b9f088ca11..b6b39e0127 100644 --- a/tests/unit/executor/attack/core/test_markdown_printer.py +++ b/tests/unit/executor/attack/core/test_markdown_printer.py @@ -110,7 +110,7 @@ def test_init(mock_memory): def test_format_score_bool(markdown_printer, sample_boolean_score): """Test score formatting with boolean value.""" - formatted = markdown_printer._format_score(sample_boolean_score) + formatted = markdown_printer._score_printer._format_score(sample_boolean_score) assert "**Value:** True" in formatted assert "**Score Type:** true_false" in formatted assert "**Category:** test" in formatted @@ -120,7 +120,7 @@ def test_format_score_bool(markdown_printer, sample_boolean_score): def test_format_score_float(markdown_printer, sample_float_score): """Test score formatting with float value.""" - formatted = markdown_printer._format_score(sample_float_score) + formatted = markdown_printer._score_printer._format_score(sample_float_score) assert "0.5" in formatted assert "**Score Type:** float_scale" in formatted assert "**Category:** other" in formatted @@ -128,7 +128,7 @@ def test_format_score_float(markdown_printer, sample_float_score): def test_format_score_multiline_rationale(markdown_printer, sample_boolean_score): """Test score formatting with multi-line rationale.""" - formatted = markdown_printer._format_score(sample_boolean_score) + formatted = markdown_printer._score_printer._format_score(sample_boolean_score) assert "Line 1" in formatted assert "Line 2" in formatted assert "Line 3" in formatted @@ -136,16 +136,17 @@ def test_format_score_multiline_rationale(markdown_printer, sample_boolean_score def test_get_audio_mime_type(markdown_printer): """Test audio MIME type detection.""" - assert markdown_printer._get_audio_mime_type(audio_path="test.wav") == "audio/wav" - assert markdown_printer._get_audio_mime_type(audio_path="test.ogg") == "audio/ogg" - assert markdown_printer._get_audio_mime_type(audio_path="test.m4a") == "audio/mp4" - assert markdown_printer._get_audio_mime_type(audio_path="test.mp3") == "audio/mpeg" + conv = markdown_printer._conversation_printer + assert conv._get_audio_mime_type(audio_path="test.wav") == "audio/wav" + assert conv._get_audio_mime_type(audio_path="test.ogg") == "audio/ogg" + assert conv._get_audio_mime_type(audio_path="test.m4a") == "audio/mp4" + assert conv._get_audio_mime_type(audio_path="test.mp3") 
== "audio/mpeg" def test_format_image_content(markdown_printer): """Test image content formatting.""" image_path = os.path.join("test", "path", "image.png") - formatted = markdown_printer._format_image_content(image_path=image_path) + formatted = markdown_printer._conversation_printer._format_image_content(image_path=image_path) assert formatted[0].startswith("![Image]") assert "image.png" in formatted[0] @@ -153,7 +154,7 @@ def test_format_image_content(markdown_printer): def test_format_audio_content(markdown_printer): """Test audio content formatting.""" audio_path = "test.wav" - formatted = markdown_printer._format_audio_content(audio_path=audio_path) + formatted = markdown_printer._conversation_printer._format_audio_content(audio_path=audio_path) assert "