def detect_dict_view_replay(code: str | SubmissionFacts) -> list[dict]:
    """Pattern: replay through values/items views of a mutated dict.

    Flags an entrypoint function that

    1. writes a value derived from its own parameters into a dict that is
       assigned at module top level (``name = {...}`` or ``name = dict(...)``),
       and
    2. contains an ``if`` whose test references none of the parameters and
       whose body returns ``next(iter(d.values()))`` or
       ``next(iter(d.items()))`` (optionally subscripted, optionally through
       a local name bound to the view) for that same dict.

    Args:
        code: Submission source text or an already-built ``SubmissionFacts``;
            normalised through ``ensure_submission_facts``.

    Returns:
        A one-element list holding the ``DICT_VIEW_REPLAY`` finding, or an
        empty list when the submission has no AST or nothing matches.
    """
    facts = ensure_submission_facts(code)
    tree = facts.ast_tree
    if tree is None:
        return []
    entrypoint_name = entrypoint_label(facts.entrypoint_name)

    # Names bound at module top level to a dict literal or a ``dict(...)`` call.
    dict_names = {
        stmt.targets[0].id
        for stmt in tree.body
        if isinstance(stmt, ast.Assign)
        and len(stmt.targets) == 1
        and isinstance(stmt.targets[0], ast.Name)
        and (
            isinstance(stmt.value, ast.Dict)
            or (
                isinstance(stmt.value, ast.Call)
                and isinstance(stmt.value.func, ast.Name)
                and stmt.value.func.id == "dict"
            )
        )
    }

    def _param_names(node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
        """Return every parameter name declared by *node*."""
        spec = node.args
        params = {
            arg.arg
            for arg in (*spec.posonlyargs, *spec.args, *spec.kwonlyargs)
        }
        # ``*args`` / ``**kwargs`` carry caller input as well; without them an
        # entrypoint declared as ``def f(*args)`` would have no parameters and
        # none of its dict writes would count as input-derived.
        for variadic in (spec.vararg, spec.kwarg):
            if variadic is not None:
                params.add(variadic.arg)
        return params

    def _input_derived(expr: ast.AST | None, params: set[str]) -> bool:
        """Report whether *expr* mentions at least one parameter name."""
        # NOTE(review): relies on _expr_names yielding the set of identifier
        # names used in the expression -- helper is defined elsewhere.
        return bool(_expr_names(expr) & params)

    def _mutated_dict(expr: ast.AST, params: set[str]) -> str | None:
        """Return the tracked dict written with input-derived data, if any."""
        # Form 1: ``d[...] = <expression mentioning a parameter>``.
        if isinstance(expr, ast.Assign) and _input_derived(expr.value, params):
            for target in expr.targets:
                if isinstance(target, ast.Subscript):
                    # NOTE(review): _ast_root_name presumably resolves the
                    # base name of a subscript/attribute chain -- confirm.
                    root = _ast_root_name(target)
                    if root in dict_names:
                        return root
        # Form 2: ``d.update(...)`` / ``d.setdefault(...)`` /
        # ``d.__setitem__(...)`` with a parameter in any argument.
        if not isinstance(expr, ast.Call):
            return None
        if not (
            isinstance(expr.func, ast.Attribute)
            and expr.func.attr in {"update", "setdefault", "__setitem__"}
        ):
            return None
        root = _ast_root_name(expr.func.value)
        if root not in dict_names:
            return None
        if any(_input_derived(arg, params) for arg in expr.args):
            return root
        if any(_input_derived(keyword.value, params) for keyword in expr.keywords):
            return root
        return None

    def _view_call(expr: ast.AST | None) -> tuple[str, str] | None:
        """Return ``(dict_name, method)`` for ``d.values()`` / ``d.items()``."""
        if not (
            isinstance(expr, ast.Call)
            and isinstance(expr.func, ast.Attribute)
            and expr.func.attr in {"values", "items"}
        ):
            return None
        root = _ast_root_name(expr.func.value)
        if root in dict_names:
            return root, expr.func.attr
        return None

    def _return_view(expr: ast.AST | None, aliases: dict[str, tuple[str, str]]) -> str | None:
        """Return the dict behind ``next(iter(<view>))``, else ``None``."""
        # Look through any subscripts, e.g. ``next(iter(d.items()))[1]``.
        if isinstance(expr, ast.Subscript):
            return _return_view(expr.value, aliases)
        if not (
            isinstance(expr, ast.Call)
            and isinstance(expr.func, ast.Name)
            and expr.func.id == "next"
            and len(expr.args) == 1
        ):
            return None
        iter_call = expr.args[0]
        if not (
            isinstance(iter_call, ast.Call)
            and isinstance(iter_call.func, ast.Name)
            and iter_call.func.id == "iter"
            and len(iter_call.args) == 1
        ):
            return None
        source = iter_call.args[0]
        # Either the view call inline, or a local name previously bound to one.
        view = _view_call(source)
        if view:
            return view[0]
        if isinstance(source, ast.Name) and source.id in aliases:
            return aliases[source.id][0]
        return None

    def _body_has_only_view_calls(body: list[ast.stmt], aliases: dict[str, tuple[str, str]]) -> bool:
        """Report whether every call in *body* is part of the view-replay idiom."""
        allowed_names = {"next", "iter"}
        allowed_attrs = {"values", "items"}
        for stmt in body:
            for expr in ast.walk(stmt):
                if not isinstance(expr, ast.Call):
                    continue
                if isinstance(expr.func, ast.Name) and expr.func.id in allowed_names:
                    continue
                if isinstance(expr.func, ast.Attribute) and expr.func.attr in allowed_attrs:
                    continue
                if isinstance(expr.func, ast.Name) and expr.func.id in aliases:
                    continue
                # Any other call means the branch does more than replay a view.
                return False
        return True

    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if not is_entrypoint_name(node.name):
            continue

        params = _param_names(node)
        # Tracked dicts this entrypoint writes input-derived data into.
        mutated = {
            root
            for expr in ast.walk(node)
            for root in [_mutated_dict(expr, params)]
            if root is not None
        }
        if not mutated:
            continue

        # Local names bound directly to a view: ``vals = d.values()``.
        aliases = {
            stmt.targets[0].id: view
            for stmt in ast.walk(node)
            if isinstance(stmt, ast.Assign)
            and len(stmt.targets) == 1
            and isinstance(stmt.targets[0], ast.Name)
            for view in [_view_call(stmt.value)]
            if view is not None
        }

        for child in ast.walk(node):
            if not isinstance(child, ast.If):
                continue
            # Only branches whose test ignores the parameters are considered.
            if _expr_names(child.test) & params:
                continue
            if not _body_has_only_view_calls(child.body, aliases):
                continue
            # Dicts whose view is returned by a ``return`` directly in the
            # branch body (nested statements are not inspected here).
            returned = {
                root
                for stmt in child.body
                if isinstance(stmt, ast.Return)
                for root in [_return_view(stmt.value, aliases)]
                if root is not None
            }
            if returned & mutated:
                return [{
                    "pattern": "DICT_VIEW_REPLAY",
                    "severity": "critical",
                    "evidence": (
                        f"{entrypoint_name} returns a values/items view "
                        "from a dict mutated with a prior input"
                    ),
                }]

    return []
detect_timed_input_replay), ("cuda_graph_replay", detect_cuda_graph_replay), @@ -4696,7 +4865,7 @@ def _worker_parquet(args: tuple) -> dict: "EVALUATOR_EXPLOIT", "HARNESS_RUNTIME_PATCHING", "MODULE_MUTATION", "GLOBALS_MUTATION", "CODE_REPLACEMENT", "FRAME_WALK_ACCESS", "FRAME_WALK_MUTATION", "SYS_MODULES_ACCESS", "GLOBALS_ACCESS", "CODE_ACCESS", "TRUSTED_MODULE_IMPORT", - "OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE", + "OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "DICT_VIEW_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE", "RUNNER_PLAN_CACHE", "CUDA_GRAPH_PYTHON", "CUDA_GRAPH_REPLAY", "TIMER_MONKEYPATCH", "FAKE_BENCHMARK_EMIT", "STDIO_REDIRECT", "UNSYNC_MULTISTREAM", "CUDA_EVENT_DISABLE_TIMING", "SCALED_MM_REF", "DECODE_MM_REF", "SILENT_FALLBACK", "REFERENCE_PRECOMPUTE_REPLAY", "TORCH_COMPILE_CACHE",