-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathkernelguard.py
More file actions
10967 lines (9804 loc) · 449 KB
/
Copy pathkernelguard.py
File metadata and controls
10967 lines (9804 loc) · 449 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
"""
KernelGuard — Rule-based GPU kernel hack detector.
Detects cheating/hacking patterns in GPU kernel competition submissions.
Usage:
kernelguard --jsonl /path/to/pairs.jsonl
kernelguard --parquet /path/to/submissions.parquet
kernelguard --audit-rules
"""
import argparse
import ast
import copy
from dataclasses import dataclass, field
import glob
import hashlib
import io
import json
import multiprocessing as mp
import os
import re
import sys
import tomllib
import tokenize
from collections import Counter, defaultdict
from datetime import datetime, timezone
from typing import Any, Optional
CONFIG_VERSION = 1
DEFAULT_PROFILE_NAME = "default"
# ---------------------------------------------------------------------------
# Compiled regex patterns (module-level for performance)
# ---------------------------------------------------------------------------
# Pattern 1: Timer monkey-patching
# Textual signals that the CUDA timing primitives are overwritten or faked:
# assignment / setattr on torch.cuda.Event, Event.elapsed_time or
# torch.cuda.synchronize, plus the class/def shape of a stand-in Event.
# NOTE(review): the `\s*=` tails are not guarded against `==`, so a comparison
# such as `torch.cuda.Event == x` also matches — confirm this is acceptable.
RE_CUDA_EVENT_ASSIGN = re.compile(r'torch\.cuda\.Event\s*=\s*')
RE_CUDA_EVENT_SETATTR = re.compile(r'setattr\s*\(\s*torch\.cuda\s*,\s*[\'"]Event[\'"]\s*,')
RE_ELAPSED_TIME_METHOD_PATCH = re.compile(r'torch\.cuda\.Event\.elapsed_time\s*=')
RE_CUDA_SYNCHRONIZE_PATCH = re.compile(r'torch\.cuda\.synchronize\s*=')
# Case-insensitive; optional leading underscore (e.g. `class _FakeCudaEvent`).
RE_DUMMY_FAKE_EVENT = re.compile(r'class\s+_?(?:Dummy|Fake|Mock|Stub|Noop)\w*Event', re.IGNORECASE)
# Matches any `def elapsed_time(...):` header.
RE_ELAPSED_TIME_FAKE = re.compile(r'def\s+elapsed_time\s*\([^)]*\)\s*:')
# Pattern 2: Print/stdout injection
RE_EMIT_FAKE_BENCH = re.compile(r'def\s+_?emit_fake_benchmark')
RE_BUILTINS_PRINT_ASSIGN = re.compile(r'builtins\.print\s*=')
# Identifier fragments such as `_orig_print` / `_patched_print`.
RE_PATCHED_PRINT = re.compile(r'_(?:patched|real|orig|original)_print')
RE_SYS_STDOUT_REPLACE = re.compile(r'sys\.stdout\s*=')
RE_SYS_STDERR_REPLACE = re.compile(r'sys\.stderr\s*=')
# `[^)]*` confines the search to text before the first `)` after the call's
# opening parenthesis.
RE_RANKED_BENCHMARK_PRINT = re.compile(
    r'(?:print|sys\.stdout\.write)\s*\([^)]*(?:Ranked\s+Benchmark)',
    re.DOTALL,
)
# "Ranked Benchmark" anywhere between two quote characters (DOTALL, so the
# quoted span may cross lines).
RE_RANKED_BENCHMARK_STR = re.compile(r'["\'].*?Ranked\s+Benchmark.*?["\']', re.DOTALL)
# Pattern 3: Evaluator exploitation
RE_MAIN_TIME_KERNEL = re.compile(r'main\.time_kernel\s*=')
RE_MAIN_CLONE_DATA = re.compile(r'main\._clone_data\s*=')
RE_MAIN_CLEAR_L2 = re.compile(r'main\.clear_l2_cache\s*=')
# Either `sys.modules['__main__']` or `sys.modules.get('__main__')`.
RE_SYS_MODULES_MAIN = re.compile(
    r'sys\.modules(?:\s*\[\s*[\'"]__main__[\'"]\s*\]|\.get\s*\(\s*[\'"]__main__[\'"]\s*\))',
)
# Group 1 of the two alias patterns is the local name bound to __main__.
RE_MAIN_ALIAS_FROM_GET = re.compile(
    r'(\w+)\s*=\s*sys\.modules\.get\s*\(\s*[\'"]__main__[\'"]\s*\)',
)
RE_MAIN_ALIAS_FROM_INDEX = re.compile(
    r'(\w+)\s*=\s*sys\.modules\s*\[\s*[\'"]__main__[\'"]\s*\]',
)
# Group 1 is the object being patched, group 2 the evaluator attribute.
RE_MAIN_ATTR_PATCH = re.compile(
    r'(\w+)\.(time_kernel|_clone_data|clear_l2_cache)\s*=',
)
RE_MAIN_SETATTR_PATCH = re.compile(
    r'setattr\s*\(\s*(\w+)\s*,\s*[\'"](time_kernel|_clone_data|clear_l2_cache)[\'"]\s*,',
)
# Pattern 4: torch._scaled_mm reference
RE_SCALED_MM_PYTHON = re.compile(r'torch\._scaled_mm\s*\(')
RE_SCALED_MM_ALIAS = re.compile(r'=\s*torch\._scaled_mm\s*(?!\()$', re.MULTILINE)  # _mm = torch._scaled_mm (not a call)
# Whole-line alias assignment; group 1 is the alias name.
RE_SCALED_MM_ALIAS_ASSIGN = re.compile(r'^\s*(\w+)\s*=\s*torch\._scaled_mm\s*$', re.MULTILINE)
RE_SCALED_MM_CPP = re.compile(r'aten::_scaled_mm')
RE_SCALED_MM_SCHEMA = re.compile(r'findSchemaOrThrow\s*\(\s*["\']aten::_scaled_mm')
# Bare substring match: any occurrence of `_scaled_mm`.
RE_SCALED_MM_ANY = re.compile(r'_scaled_mm')
# Pattern 5: torch.mm/bmm as main compute (Python-only check)
RE_TORCH_MM = re.compile(r'torch\.mm\s*\(')
RE_TORCH_BMM = re.compile(r'torch\.bmm\s*\(')
RE_TORCH_MATMUL = re.compile(r'torch\.matmul\s*\(')
# Pattern 6: Result caching by pointer
RE_WEAK_VALUE_DICT = re.compile(r'WeakValueDictionary')
RE_DECODED_CACHE = re.compile(r'_decoded_cache')
RE_PREPROCESS_CACHE = re.compile(r'_PREPROCESS_CACHE')
RE_RESULT_REUSE = re.compile(r'_result_reuse')
RE_SCALE_CACHE = re.compile(r'_scale_cache')
RE_OUTPUT_CACHE = re.compile(r'_OUTPUT_CACHE')
RE_RESULT_CACHE_GENERAL = re.compile(r'_(?:RESULT|GROUPED_RESULT|COMPUTE|GEMM)_CACHE')
RE_ID_DATA_CACHE = re.compile(r'id\s*\(\s*data\s*\)')
# Two alternatives: a subscript keyed by int(x.data_ptr()), or a data_ptr()
# call followed later on the same line by cache/key/dict/hash (no DOTALL,
# so `.*?` does not cross a newline).
RE_DATA_PTR_CACHE_KEY = re.compile(
    r'(?:\w+\s*\[\s*int\s*\(\s*\w+\.data_ptr\s*\(\)\s*\)\s*\])|'  # cache[int(A.data_ptr())]
    r'(?:\.data_ptr\s*\(\)\s*[,\)].*?(?:cache|key|dict|hash))',
    re.IGNORECASE,
)
RE_VERSION_CACHE = re.compile(r'\._version\b')
# `x = <name containing cache/reuse>.get(`; group 1 is the assigned name,
# group 2 the cache name.
RE_CACHE_GET_ASSIGN = re.compile(
    r'(\w+)\s*=\s*(\w*(?:cache|reuse)\w*)\.get\s*\(',
    re.IGNORECASE,
)
# "return cache[...]" but exclude compiled-kernel / module caches (legitimate)
RE_RETURN_CACHE_INDEX = re.compile(
    r'return\s+(?!_?(?:compiled|kernel|module|func|op)_?\w*cache)'
    r'(?P<cache>\w*(?:cache|reuse)\w*)\s*\[',
    re.IGNORECASE,
)
# A cache/reuse subscript assigned from a name that looks like an output
# (result / output / out / c_ref).
RE_CACHE_STORE_OUTPUT = re.compile(
    r'\w*(?:cache|reuse)\w*\s*\[[^\]]+\]\s*=\s*(?:result|output|out|c_ref)\b',
    re.IGNORECASE,
)
# Cache populated by a `.compile(...)` method call — stores a compiled
# callable (not an output tensor), so a subsequent `return cache[key]`
# returns the callable, not a result. Recognizes the standard tile-DSL
# compile-cache idiom used by TileLang / Triton / CuTeDSL / Pallas:
#
# _KERNEL_CACHE[(M, N, K)] = my_kernel.compile(M=M, N=N, K=K, ...)
# _KCACHE[key] = _gemm.compile(...)
#
# Variable-name allowlists (e.g. RE_RETURN_CACHE_INDEX's
# `compiled|kernel|module|func|op` prefix list) miss abbreviated names
# (e.g. `_KCACHE`, `_CCACHE`, `_TILECACHE`); checking the RHS of the
# store is more robust because the *value* in the cache is what
# distinguishes legitimate compile-cache from output replay.
#
# Deliberately matches only the `.compile(` method-call form, not a
# function-call form like `_compile_kernel(out)` — the latter is too
# easy to forge with an identity passthrough (e.g. attacker defines
# `_compile_passthrough(x): return x` and writes
# `_CACHE[key] = _compile_passthrough(output)`). The method-call form
# requires the cached value to actually have a `.compile()` attribute,
# which PyTorch tensors do not.
# Whitespace tolerance: callers apply this against strip_python_strings_and_comments(...)
# which emits a space between tokens, so source `_gemm.compile(` becomes
# `_gemm . compile (`. The `\s*` around the `.` covers both forms.
RE_CACHE_STORES_COMPILED_CALLABLE = re.compile(
    r'(?P<cache>\w*(?:cache|reuse)\w*)\s*\[[^\]]+\]\s*=\s*'
    r'[^=\n]*?\s*\.\s*compile\s*\(',
    re.IGNORECASE,
)
# C++ side: an `if (...)` whose condition mentions a last/prev-style cache
# field, followed by `return` either directly or within the first 200
# characters of a brace block that contains no nested braces.
RE_CPP_CACHE_FAST_RETURN = re.compile(
    r'if\s*\([^)]*(?:cache\.last|cache\.prev|lastA|lastB|prevA|prevB)[^)]*\)\s*(?:\{[^{}]{0,200}?\breturn\b|return\b)',
    re.DOTALL,
)
# Pattern 7: CUDA graph replay
RE_CTYPES_LIBCUDART = re.compile(r'ctypes\.CDLL\s*\(\s*["\']libcudart')
RE_CUDA_GRAPH_CREATE = re.compile(r'cudaGraphCreate')
RE_CUDA_GRAPH_LAUNCH = re.compile(r'cudaGraphLaunch')
RE_CUDA_GRAPH_INSTANTIATE = re.compile(r'cudaGraphInstantiate')
RE_CUDA_STREAM_CAPTURE = re.compile(r'cudaStreamBeginCapture')
# Pattern 8: Silent fallback — limit scope to ~500 chars after except (roughly 10-15 lines)
# The gap between the `except ...:` header and the torch matmul/mm/bmm call is
# matched with `.` so that, under re.DOTALL, it can span line breaks. The
# previous `[^\n]{0,500}?` could never cross a newline, which restricted the
# detector to one-line `except X: return torch.matmul(...)` handlers and made
# the DOTALL flag a no-op. The lazy, bounded quantifier keeps the window at
# 500 characters and stops at the first fallback call found.
RE_TORCH_MATMUL_FALLBACK = re.compile(
    r'except\s+(?:Exception|RuntimeError|BaseException|TypeError)[^:]*:'
    r'.{0,500}?'
    r'(?:torch\.matmul|torch\.mm\s*\(|torch\.bmm\s*\()',
    re.DOTALL,
)
# Pattern 9: Trivial/probe
RE_RETURN_EMPTY = re.compile(r'return\s+\[\s*\]')
RE_RETURN_NONE = re.compile(r'return\s+None\b')
# Pattern 11: torch.compile
RE_TORCH_COMPILE = re.compile(r'torch\.compile\s*\(')
# Pattern 12: Python CUDAGraph caching (not ctypes — uses torch.cuda.CUDAGraph API)
RE_CUDA_GRAPH_PYTHON = re.compile(r'torch\.cuda\.CUDAGraph\s*\(')
# Any `.replay(` method call, regardless of receiver.
RE_CUDA_GRAPH_REPLAY = re.compile(r'\.replay\s*\(')
# data_ptr() and one of graph/cache/dict in either order; DOTALL lets the
# two be separated by any amount of text, including newlines.
RE_CUDA_GRAPH_CACHE_KEY = re.compile(
    r'data_ptr\s*\(\).*?(?:graph|cache|dict)|'
    r'(?:graph|cache|dict).*?data_ptr\s*\(\)',
    re.IGNORECASE | re.DOTALL,
)
RE_GRAPH_CACHE_NAME = re.compile(r'\b(?:_graph_cache|graph_cache)\b', re.IGNORECASE)
# Pattern 13: Hardcoded benchmark shapes (KERNEL_CONFIGS gated by exact dimensions)
RE_KERNEL_CONFIGS_DICT = re.compile(r'KERNEL_CONFIGS\s*=\s*\{')
# Dict key shaped like `(M, N):` or `(M, N, K):` with 2-5 digit leading dims.
RE_SHAPE_TUPLE_KEY = re.compile(r'\(\s*\d{2,5}\s*,\s*\d{2,5}\s*(?:,\s*\d{1,5}\s*)?\)\s*:')
# `if ... == (` / `if ... in [(` followed by two 3-5 digit integers.
RE_SHAPE_IF_GATE = re.compile(
    r'if\s+.*?(?:==|in)\s*[\[(]?\s*\(?\s*\d{3,5}\s*,\s*\d{3,5}',
)
# Pattern 14: Unsynchronized multi-stream dispatch
RE_GET_STREAM_FROM_POOL = re.compile(r'getStreamFromPool|get_stream_from_pool|torch\.cuda\.Stream\s*\(')
# NOTE(review): despite the name, this matches a `.synchronize()` call on a
# receiver ending in `stream`/`s` plus optional digits — confirm how callers
# interpret a hit.
RE_NO_SYNC_STREAM = re.compile(r'(?:stream|s)\d*\.synchronize\s*\(\)')
RE_STREAM_WAIT_EVENT = re.compile(r'\.wait_event\s*\(')
RE_STREAM_WAIT_STREAM = re.compile(r'\.wait_stream\s*\(')
RE_TORCH_CUDA_SYNCHRONIZE = re.compile(r'torch\.cuda\.synchronize\s*\(')
RE_CPP_STREAM_SYNC = re.compile(
    r'(?:cudaStreamSynchronize|cudaDeviceSynchronize|cudaEventSynchronize|cudaStreamWaitEvent)\s*\(',
)
RE_CPP_METHOD_SYNC = re.compile(r'\.(?:synchronize|wait_event|wait_stream)\s*\(')
# Pattern 15: cudaEventDisableTiming
RE_CUDA_EVENT_DISABLE_TIMING = re.compile(r'cudaEventDisableTiming|disable_timing\s*=\s*True')
# C++/CUDA block markers for stripping
RE_CPP_MARKERS = re.compile(r'#include|__global__|__device__|__host__|extern\s+"C"|asm\s+volatile')
# Optional `r` prefix, then a triple-quoted literal; group 1 is the content.
# The opening and closing delimiters are matched independently, so they are
# not required to use the same quote character.
RE_TRIPLE_QUOTED = re.compile(r'(?:r)?(?:"""|\'\'\')(.*?)(?:"""|\'\'\')', re.DOTALL)
# ---------------------------------------------------------------------------
# Utility: strip C++/CUDA inline source from Python code
# ---------------------------------------------------------------------------
_LARGE_LITERAL_THRESHOLD = 2000  # chars; blobs larger than this are stripped
def strip_cpp_cuda_blocks(code: str) -> str:
    """Blank out embedded C/CUDA sources and oversized literals in Python code.

    Every triple-quoted string found by ``RE_TRIPLE_QUOTED`` is replaced with
    an empty triple-quoted string when either condition holds:

    - its content is longer than ``_LARGE_LITERAL_THRESHOLD`` characters
      (e.g. base64 blobs, embedded source code, lookup tables), or
    - its content matches ``RE_CPP_MARKERS`` (``#include``, ``__global__``, ...).

    Oversized blobs are removed because their long runs of word characters
    cause catastrophic backtracking in patterns like (\\w+)\\s*=\\s*literal.
    Every other triple-quoted string is returned unchanged.
    """

    def _blank_if_unwanted(found):
        body = found.group(1)
        oversized = len(body) > _LARGE_LITERAL_THRESHOLD
        if oversized or RE_CPP_MARKERS.search(body):
            # Keep a syntactically valid (empty) literal in place of the blob.
            return '""""""'
        return found.group(0)

    return RE_TRIPLE_QUOTED.sub(_blank_if_unwanted, code)
def strip_python_comments(code: str) -> str:
    """Return ``code`` with every COMMENT token removed.

    The source is tokenized, comment tokens are dropped, and the remaining
    tokens are reassembled with ``tokenize.untokenize``; ``#`` characters
    inside string literals are therefore left alone. If the source cannot be
    tokenized, the input is returned unchanged.
    """
    kept = []
    try:
        # generate_tokens is lazy, so tokenization errors surface while iterating.
        for token in tokenize.generate_tokens(io.StringIO(code).readline):
            if token.type == tokenize.COMMENT:
                continue
            kept.append(token)
    except (tokenize.TokenError, IndentationError, SyntaxError):
        return code
    return tokenize.untokenize(kept)
def strip_python_strings_and_comments(code: str) -> str:
    """Blank out COMMENT and STRING tokens and re-emit the remaining tokens.

    Used to defeat string-literal / comment decoys: text planted inside a
    docstring or a ``NOTE = "..."`` constant cannot match a detector that is
    run against this stripped form.

    Output shape, token by token:

    - COMMENT / STRING tokens become a run of spaces as long as the token text;
    - NL / NEWLINE tokens become a single ``"\\n"``;
    - INDENT / DEDENT / ENCODING / ENDMARKER tokens emit nothing, so leading
      indentation is not reproduced;
    - every other token is emitted followed by one space, so
      ``a.compile(b)`` becomes ``a . compile ( b ) ``.

    Falls back to returning ``code`` unchanged on tokenize / syntax errors.
    """
    blanked = (tokenize.COMMENT, tokenize.STRING)
    line_breaks = (tokenize.NL, tokenize.NEWLINE)
    silent = (
        tokenize.INDENT, tokenize.DEDENT,
        tokenize.ENCODING, tokenize.ENDMARKER,
    )
    pieces = []
    try:
        for token in tokenize.generate_tokens(io.StringIO(code).readline):
            kind = token.type
            if kind in blanked:
                pieces.append(" " * len(token.string))
            elif kind in line_breaks:
                pieces.append("\n")
            elif kind in silent:
                continue
            else:
                pieces.append(token.string + " ")
    except (tokenize.TokenError, IndentationError, SyntaxError, ValueError):
        return code
    return "".join(pieces)
def _compiled_callable_cache_names(scope: str) -> set[str]:
    """Return lower-cased names of caches that store ``.compile(...)`` results.

    A cache qualifies when ``scope`` contains an assignment of the form
    ``<cache>[...] = <receiver>.compile(...)`` where the subscript's root name
    contains "cache" or "reuse" (case-insensitive) and the right-hand side
    does not look like a disguised output:

    - no name inside the call is ``result`` / ``output`` / ``out`` / ``c_ref``;
    - the receiver is not an instance (built inline or assigned earlier) of a
      class defined in ``scope`` that declares its own ``compile`` method;
    - the receiver was not itself assigned from a call mentioning one of the
      output-ish names.

    When ``scope`` does not parse, falls back to
    ``RE_CACHE_STORES_COMPILED_CALLABLE`` over the string/comment-stripped
    text and rejects matches whose right-hand side mentions an output-ish name.
    """
    outputish_names = {"result", "output", "out", "c_ref"}
    try:
        tree = ast.parse(scope)
    except SyntaxError:
        # Regex fallback for code fragments that are not valid modules.
        stripped = strip_python_strings_and_comments(scope)
        found: set[str] = set()
        for hit in RE_CACHE_STORES_COMPILED_CALLABLE.finditer(stripped):
            rhs = hit.group(0).split("=", 1)[-1]
            if re.search(r'\b(?:result|output|out|c_ref)\b', rhs, re.IGNORECASE):
                continue
            found.add(hit.group("cache").lower())
        return found

    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)

    # Classes in this scope that define a `compile` method directly in their body.
    local_compile_classes = set()
    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        if any(isinstance(item, function_types) and item.name == "compile" for item in node.body):
            local_compile_classes.add(node.name)

    local_compile_instances: set[str] = set()
    outputish_instances: set[str] = set()
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign) or not isinstance(node.value, ast.Call):
            continue
        call = node.value
        builds_local_class = (
            isinstance(call.func, ast.Name) and call.func.id in local_compile_classes
        )
        mentions_output = bool(_expr_names(call) & outputish_names)
        for target in node.targets:
            if not isinstance(target, ast.Name):
                continue
            if builds_local_class:
                local_compile_instances.add(target.id)
            if mentions_output:
                outputish_instances.add(target.id)

    def is_safe_compile_call(expr: ast.AST) -> bool:
        """True for `<receiver>.compile(...)` calls that pass every forgery check."""
        is_compile_method_call = (
            isinstance(expr, ast.Call)
            and isinstance(expr.func, ast.Attribute)
            and expr.func.attr == "compile"
        )
        if not is_compile_method_call:
            return False
        if _expr_names(expr) & outputish_names:
            return False
        receiver = expr.func.value
        if isinstance(receiver, ast.Call):
            # `LocalClass(...).compile(...)` — instance built inline.
            inline_ctor = receiver.func
            if isinstance(inline_ctor, ast.Name) and inline_ctor.id in local_compile_classes:
                return False
        if isinstance(receiver, ast.Name):
            if receiver.id in local_compile_instances:
                return False
            if receiver.id in outputish_instances:
                return False
        return True

    compiled_caches: set[str] = set()
    for node in ast.walk(tree):
        if not (isinstance(node, ast.Assign) and is_safe_compile_call(node.value)):
            continue
        for target in node.targets:
            if not isinstance(target, ast.Subscript):
                continue
            root = _ast_root_name(target.value)
            if root and re.search(r'(?:cache|reuse)', root, re.IGNORECASE):
                compiled_caches.add(root.lower())
    return compiled_caches
def extract_function_block(code: str, func_name: str) -> str:
    """Best-effort, line-based extraction of ``def func_name(...)`` from source.

    Starts at the first line that begins (after optional whitespace) with
    ``def func_name(`` and collects lines until a non-blank line at the same
    or a shallower indentation starts another ``def`` or ``class``. Other
    dedented statements do not terminate the block. Returns the collected
    lines joined with newlines, or ``""`` when no such ``def`` exists.
    """
    header_re = re.compile(rf'^\s*def\s+{re.escape(func_name)}\s*\(')
    all_lines = code.splitlines()

    start = next(
        (idx for idx, text in enumerate(all_lines) if header_re.match(text)),
        None,
    )
    if start is None:
        return ""

    header = all_lines[start]
    header_indent = len(header) - len(header.lstrip())
    collected = [header]
    for text in all_lines[start + 1:]:
        indent = len(text) - len(text.lstrip())
        starts_sibling = (
            bool(text.strip())
            and indent <= header_indent
            and re.match(r'^\s*(def|class)\s+\w+', text)
        )
        if starts_sibling:
            break
        collected.append(text)
    return "\n".join(collected)
# `import __main__ as X` (group 1 is X) and `from __main__ import ...`.
RE_IMPORT_MAIN_AS = re.compile(r'import\s+__main__\s+as\s+(\w+)')
RE_FROM_IMPORT_MAIN = re.compile(r'from\s+__main__\s+import\s+')


def find_main_aliases(code: str) -> set[str]:
    """Collect the names through which ``code`` can reach the ``__main__`` module.

    The result always contains ``"main"`` and adds every name assigned from
    ``sys.modules.get('__main__')`` or ``sys.modules['__main__']`` as well as
    every ``X`` in ``import __main__ as X``.
    """
    alias_patterns = (
        RE_MAIN_ALIAS_FROM_GET,
        RE_MAIN_ALIAS_FROM_INDEX,
        RE_IMPORT_MAIN_AS,
    )
    found = {"main"}
    for pattern in alias_patterns:
        found.update(hit.group(1) for hit in pattern.finditer(code))
    return found
def find_scaled_mm_aliases(code: str) -> set[str]:
    """Return every name assigned directly from ``torch._scaled_mm``.

    Only whole-line assignments of the bare attribute count
    (``alias = torch._scaled_mm``); calls such as ``torch._scaled_mm(...)``
    do not create an alias.
    """
    return {
        hit.group(1)
        for hit in re.finditer(r'^\s*(\w+)\s*=\s*torch\._scaled_mm\s*$', code, re.MULTILINE)
    }
def function_uses_scaled_mm(func_body: str, alias_names: set[str]) -> bool:
    """Tell whether ``func_body`` calls ``torch._scaled_mm`` directly or via an alias.

    ``alias_names`` are matched as whole words followed by an opening
    parenthesis, so only call sites of an alias count.
    """
    if RE_SCALED_MM_PYTHON.search(func_body) is not None:
        return True
    return any(
        re.search(rf'\b{re.escape(alias)}\s*\(', func_body)
        for alias in alias_names
    )
# Built-in entrypoint name(s); ENTRYPOINT_NAMES is the active tuple that the
# helpers below read at call time.
BASE_ENTRYPOINT_NAMES = ("custom_kernel",)
ENTRYPOINT_NAMES = BASE_ENTRYPOINT_NAMES


def _entrypoint_candidates() -> tuple[str, ...]:
    """Return the currently active entrypoint names, in priority order."""
    return ENTRYPOINT_NAMES


def is_entrypoint_name(name: str) -> bool:
    """Return True when ``name`` is one of the active entrypoint names."""
    return name in ENTRYPOINT_NAMES


def entrypoint_label(name: Optional[str] = None) -> str:
    """Return a display label: ``name`` if truthy, else the first active
    entrypoint name, else the literal ``"entrypoint"``."""
    if name:
        return name
    if ENTRYPOINT_NAMES:
        return ENTRYPOINT_NAMES[0]
    return "entrypoint"
@dataclass
class SubmissionFacts:
    """Normalized source views and cached AST summaries for one submission.

    Built once per submission so that detectors can share the same derived
    text views, alias sets and AST indices instead of recomputing them.
    """

    raw_code: str
    python_only: str
    python_active: str
    ast_tree: Optional[ast.AST]
    main_aliases: set[str]
    scaled_mm_aliases: set[str]
    trusted_aliases: dict[str, str]
    entrypoint_name: Optional[str]
    custom_kernel_pos: Optional[int]
    code_before_custom_kernel: str
    code_from_custom_kernel: str
    custom_kernel_code: str
    custom_kernel_active: str
    # Per-function caches of extracted source, keyed by function name:
    # raw text and comment-stripped text respectively.
    _function_blocks: dict[str, str] = field(default_factory=dict)
    _active_function_blocks: dict[str, str] = field(default_factory=dict)
    # --- AST indices, filled in by _build_ast_index ---
    # id() of nodes tagged as `.data_ptr()` calls
    _nodes_with_data_ptr: set[int] = field(default_factory=set)
    # id() of nodes tagged as `._version` attribute accesses
    _nodes_with_version: set[int] = field(default_factory=set)
    # Names of non-entrypoint functions whose body contains data_ptr / _version
    _data_ptr_helpers: set[str] = field(default_factory=set)
    _version_helpers: set[str] = field(default_factory=set)
    # Module-level variables initialized to None
    _none_inited: set[str] = field(default_factory=set)
    # All assignments: {target_name: [value_node, ...]}
    _assignments_by_target: dict[str, list] = field(default_factory=dict)
    # Import / from-import statement nodes
    _imports: list = field(default_factory=list)
    _import_froms: list = field(default_factory=list)
    # Class definition nodes
    _class_defs: list = field(default_factory=list)

    def get_function_block(self, func_name: str) -> str:
        """Return the source block of ``func_name`` from ``raw_code``, memoized."""
        cached = self._function_blocks.get(func_name)
        if cached is not None:
            return cached
        extracted = extract_function_block(self.raw_code, func_name)
        self._function_blocks[func_name] = extracted
        return extracted

    def get_active_function_block(self, func_name: str) -> str:
        """Return the comment-stripped source block of ``func_name``, memoized."""
        cached = self._active_function_blocks.get(func_name)
        if cached is not None:
            return cached
        active = strip_python_comments(self.get_function_block(func_name))
        self._active_function_blocks[func_name] = active
        return active
def build_submission_facts(code: str) -> SubmissionFacts:
    """Derive every shared view of a submission once, for reuse by all detectors.

    Produces the C/CUDA-stripped and comment-stripped text views, parses the
    AST, locates the first configured entrypoint ``def`` in the raw source,
    splits the source around it, collects alias sets, and finally populates
    the AST indices via ``_build_ast_index``.
    """
    python_only = strip_cpp_cuda_blocks(code)
    python_active = strip_python_comments(python_only)
    tree = _safe_ast_parse(code)

    # The first candidate name with a matching `def` in the raw source wins.
    entrypoint_name = None
    custom_kernel_pos = None
    for candidate_name in _entrypoint_candidates():
        header = re.search(rf'^\s*def\s+{re.escape(candidate_name)}\s*\(', code, re.MULTILINE)
        if header is None:
            continue
        entrypoint_name = candidate_name
        custom_kernel_pos = header.start()
        break

    if custom_kernel_pos is None:
        # No entrypoint found: both views fall back to the whole source.
        code_before_custom_kernel = code
        code_from_custom_kernel = code
    else:
        code_before_custom_kernel = code[:custom_kernel_pos]
        code_from_custom_kernel = code[custom_kernel_pos:]

    custom_kernel_code = extract_function_block(code, entrypoint_name or entrypoint_label())
    custom_kernel_active = strip_python_comments(custom_kernel_code)
    trusted_aliases = {} if tree is None else _collect_trusted_aliases(tree)

    facts = SubmissionFacts(
        raw_code=code,
        python_only=python_only,
        python_active=python_active,
        ast_tree=tree,
        main_aliases=find_main_aliases(python_only),
        scaled_mm_aliases=find_scaled_mm_aliases(code_before_custom_kernel),
        trusted_aliases=trusted_aliases,
        entrypoint_name=entrypoint_name,
        custom_kernel_pos=custom_kernel_pos,
        code_before_custom_kernel=code_before_custom_kernel,
        code_from_custom_kernel=code_from_custom_kernel,
        custom_kernel_code=custom_kernel_code,
        custom_kernel_active=custom_kernel_active,
    )

    # Seed the per-function caches so later lookups skip re-extraction; the
    # entrypoint block is also registered under the fixed "custom_kernel" key.
    seeded_keys = [entrypoint_name] if entrypoint_name else []
    seeded_keys.append("custom_kernel")
    for key in seeded_keys:
        facts._function_blocks[key] = custom_kernel_code
        facts._active_function_blocks[key] = custom_kernel_active

    _build_ast_index(facts)
    return facts
def _build_ast_index(facts: SubmissionFacts) -> None:
    """Populate the AST index fields of ``facts`` in place.

    Does nothing when ``facts.ast_tree`` is None. Otherwise records:

    - ``id()`` of every ``.data_ptr()`` call node and ``._version`` attribute node;
    - all Import / ImportFrom / ClassDef nodes;
    - module-level names assigned the constant ``None``;
    - names of non-entrypoint functions whose subtree contains a tagged
      ``data_ptr`` call or ``_version`` access.
    """
    tree = facts.ast_tree
    if tree is None:
        return

    data_ptr_ids: set[int] = set()
    version_ids: set[int] = set()
    imports: list = []
    import_froms: list = []
    class_defs: list = []

    # Tag interesting nodes by identity and bucket imports / classes.
    for node in ast.walk(tree):
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and node.func.attr == "data_ptr"
        ):
            data_ptr_ids.add(id(node))
        if isinstance(node, ast.Attribute) and node.attr == "_version":
            version_ids.add(id(node))
        if isinstance(node, ast.Import):
            imports.append(node)
        elif isinstance(node, ast.ImportFrom):
            import_froms.append(node)
        elif isinstance(node, ast.ClassDef):
            class_defs.append(node)

    # Module-level `name = None` initializers.
    none_inited: set[str] = set()
    for stmt in tree.body:
        if not isinstance(stmt, ast.Assign):
            continue
        value = stmt.value
        if not (isinstance(value, ast.Constant) and value.value is None):
            continue
        for target in stmt.targets:
            owner = _ast_root_name(target)
            if owner:
                none_inited.add(owner)

    # Helper functions (anything that is not an entrypoint) that touch
    # data_ptr / _version anywhere in their subtree.
    data_ptr_helpers: set[str] = set()
    version_helpers: set[str] = set()
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if is_entrypoint_name(node.name):
            continue
        uses_data_ptr = False
        uses_version = False
        for child in ast.walk(node):
            child_id = id(child)
            uses_data_ptr = uses_data_ptr or child_id in data_ptr_ids
            uses_version = uses_version or child_id in version_ids
            if uses_data_ptr and uses_version:
                break  # nothing more to learn about this function
        if uses_data_ptr:
            data_ptr_helpers.add(node.name)
        if uses_version:
            version_helpers.add(node.name)

    facts._nodes_with_data_ptr = data_ptr_ids
    facts._nodes_with_version = version_ids
    facts._data_ptr_helpers = data_ptr_helpers
    facts._version_helpers = version_helpers
    facts._none_inited = none_inited
    facts._imports = imports
    facts._import_froms = import_froms
    facts._class_defs = class_defs
def _expr_has_data_ptr_fast(expr: ast.AST | None, index: set[int]) -> bool:
"""O(subtree) check using pre-computed index — avoids full ast.walk per call."""
if expr is None:
return False
for node in ast.walk(expr):
if id(node) in index:
return True
return False
def _expr_has_version_fast(expr: ast.AST | None, index: set[int]) -> bool:
if expr is None:
return False
for node in ast.walk(expr):
if id(node) in index:
return True
return False
def ensure_submission_facts(code_or_facts: str | SubmissionFacts) -> SubmissionFacts:
    """Return a ``SubmissionFacts``: the argument itself when it already is
    one, otherwise a fresh one built from the raw code string."""
    already_built = isinstance(code_or_facts, SubmissionFacts)
    return code_or_facts if already_built else build_submission_facts(code_or_facts)
def _ast_root_name(expr: ast.AST | None) -> Optional[str]:
"""Return the left-most name that owns an expression, when present."""
cur = expr
while cur is not None:
if isinstance(cur, ast.Name):
return cur.id
if isinstance(cur, ast.Attribute):
cur = cur.value
continue
if isinstance(cur, ast.Subscript):
cur = cur.value
continue
break
return None
def _ast_dotted_name(expr: ast.AST | None) -> Optional[str]:
"""Return a dotted name such as torch.linalg.householder_product."""
parts: list[str] = []
cur = expr
while cur is not None:
if isinstance(cur, ast.Name):
parts.append(cur.id)
return ".".join(reversed(parts))
if isinstance(cur, ast.Attribute):
parts.append(cur.attr)
cur = cur.value
continue
break
return None
def _expr_names(expr: ast.AST | None) -> set[str]:
if expr is None:
return set()
return {
node.id
for node in ast.walk(expr)
if isinstance(node, ast.Name)
}
def _target_names(target: ast.AST | None) -> set[str]:
"""Return all simple names assigned by a target expression."""
if target is None:
return set()
if isinstance(target, ast.Name):
return {target.id}
if isinstance(target, (ast.Tuple, ast.List)):
names: set[str] = set()
for elt in target.elts:
names.update(_target_names(elt))
return names
if isinstance(target, ast.Starred):
return _target_names(target.value)
root = _ast_root_name(target)
return {root} if root else set()
def _expr_has_data_ptr(expr: ast.AST | None) -> bool:
if expr is None:
return False
return any(
isinstance(node, ast.Call)
and isinstance(node.func, ast.Attribute)
and node.func.attr == "data_ptr"
for node in ast.walk(expr)
)
def _expr_has_tensor_version(expr: ast.AST | None) -> bool:
if expr is None:
return False
return any(
isinstance(node, ast.Attribute) and node.attr == "_version"
for node in ast.walk(expr)
)
# Method names treated as dummy GPU work rather than real compute.
_TRIVIAL_GPU_OPS = frozenset({
    "fill_", "zero_", "copy_", "fill", "zero", "record",
})


def _body_has_calls(body: list[ast.stmt]) -> bool:
    """Return True if the body contains non-trivial function calls.

    Method calls whose attribute name is in ``_TRIVIAL_GPU_OPS`` — e.g.
    ``_tiny.fill_(0)`` or ``_anchor.copy_(_anchor)`` — are common dummy work
    used to keep CUDA timers non-zero; they don't count as real compute and
    should not prevent replay detection. Any other call, at any nesting
    depth (including inside a trivial call's arguments), does count.
    """

    def is_trivial(call: ast.Call) -> bool:
        # obj.fill_(0), obj.copy_(obj), etc.
        return isinstance(call.func, ast.Attribute) and call.func.attr in _TRIVIAL_GPU_OPS

    return any(
        isinstance(node, ast.Call) and not is_trivial(node)
        for stmt in body
        for node in ast.walk(stmt)
    )
def _looks_stateful_name(name: str) -> bool:
    """Return True when ``name`` contains, case-insensitively, one of the
    markers last / prev / cache / saved / memo."""
    lowered = name.lower()
    for marker in ("last", "prev", "cache", "saved", "memo"):
        if marker in lowered:
            return True
    return False
_ENTRYPOINT_METHOD_NAMES = ("__call__", "forward", "run", "solve")
def _iter_non_nested_nodes(node: ast.AST):
    """Yield descendants without descending into nested function/class scopes.

    Nodes are produced in depth-first pre-order. Nested function, async
    function, class and lambda nodes are themselves yielded, but their
    contents are not visited.
    """
    scope_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)
    # Explicit stack; children are pushed reversed so they pop in source order.
    pending = list(ast.iter_child_nodes(node))
    pending.reverse()
    while pending:
        current = pending.pop()
        yield current
        if isinstance(current, scope_types):
            continue
        children = list(ast.iter_child_nodes(current))
        children.reverse()
        pending.extend(children)
def _function_input_names(fn: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
    """Return the parameter names of ``fn``.

    Covers positional-only, regular, keyword-only, ``*args`` and ``**kwargs``
    parameters. For functions named like an entrypoint method
    (``_ENTRYPOINT_METHOD_NAMES``), a leading ``self`` / ``cls`` is left out.
    """
    spec = fn.args
    params = [*spec.posonlyargs, *spec.args, *spec.kwonlyargs]
    has_receiver = (
        fn.name in _ENTRYPOINT_METHOD_NAMES
        and bool(params)
        and params[0].arg in {"self", "cls"}
    )
    if has_receiver:
        params = params[1:]
    names = {param.arg for param in params}
    for variadic in (spec.vararg, spec.kwarg):
        if variadic is not None:
            names.add(variadic.arg)
    return names
def _method_from_class(cls: ast.ClassDef, preferred: tuple[str, ...] = _ENTRYPOINT_METHOD_NAMES):
    """Return the first method of ``cls`` whose name appears in ``preferred``.

    ``preferred`` is searched in order; only functions defined directly in
    the class body are considered. Returns None when none of them exists.
    """
    methods = {}
    for member in cls.body:
        if isinstance(member, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # A later definition with the same name replaces an earlier one.
            methods[member.name] = member
    return next((methods[name] for name in preferred if name in methods), None)
def _factory_returned_function(fn: ast.FunctionDef | ast.AsyncFunctionDef):
nested = {
child.name: child
for child in fn.body
if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef))
}
for stmt in fn.body:
if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.Name):
returned = nested.get(stmt.value.id)
if returned is not None:
return returned
return None
def _entrypoint_function_nodes(facts: SubmissionFacts) -> list[ast.FunctionDef | ast.AsyncFunctionDef]:
    """Resolve simple Python callable exports for entrypoint-scoped detectors.

    Only module-level statements of ``facts.ast_tree`` are inspected. The
    result lists, without duplicates and in discovery order, the function
    nodes that an entrypoint name can refer to:

    - a module-level ``def`` / ``async def`` with an entrypoint name;
    - the preferred method of a module-level class with an entrypoint name;
    - the target of a module-level assignment ``<entrypoint> = ...`` whose
      value is a function or class name (possibly through aliases), a class
      instantiation, a ``partial(...)`` over a known function, a call to a
      known function (its returned nested function when detectable, else the
      function itself), or a method taken from a class / instance.

    Returns an empty list when the submission has no parsed AST.
    """
    tree = facts.ast_tree
    if tree is None:
        return []
    # Module-level symbol tables, filled in statement order.
    functions: dict[str, ast.FunctionDef | ast.AsyncFunctionDef] = {}
    classes: dict[str, ast.ClassDef] = {}
    instances: dict[str, str] = {}  # variable name -> class name
    aliases: dict[str, str] = {}  # variable name -> resolved name it was assigned from
    resolved: list[ast.FunctionDef | ast.AsyncFunctionDef] = []
    seen: set[int] = set()  # id() of nodes already placed in `resolved`
    # First pass: collect every module-level def and class up front, so that
    # assignments can refer to definitions appearing later in the file.
    for stmt in tree.body:
        if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
            functions[stmt.name] = stmt
        elif isinstance(stmt, ast.ClassDef):
            classes[stmt.name] = stmt

    def add(fn: ast.FunctionDef | ast.AsyncFunctionDef | None) -> None:
        # Append once; None (nothing resolved) is ignored.
        if fn is None or id(fn) in seen:
            return
        seen.add(id(fn))
        resolved.append(fn)

    def resolve_name(name: str) -> str:
        # Follow the alias chain until a name with no further alias (or a
        # self-alias) is reached.
        while name in aliases and aliases[name] != name:
            name = aliases[name]
        return name

    # Second pass: direct definitions plus module-level assignments.
    for stmt in tree.body:
        if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)) and is_entrypoint_name(stmt.name):
            add(stmt)
        elif isinstance(stmt, ast.ClassDef) and is_entrypoint_name(stmt.name):
            add(_method_from_class(stmt))
        if not isinstance(stmt, ast.Assign):
            continue
        # Only plain-name targets are tracked (no tuples, attributes, subscripts).
        target_names = [t.id for t in stmt.targets if isinstance(t, ast.Name)]
        if not target_names:
            continue
        value = stmt.value
        if isinstance(value, ast.Name):
            # `target = other_name`: record the alias; an entrypoint target
            # resolves to the named function and/or the class's preferred method.
            value_name = resolve_name(value.id)
            for target in target_names:
                aliases[target] = value_name
                if is_entrypoint_name(target):
                    add(functions.get(value_name))
                    add(_method_from_class(classes[value_name]) if value_name in classes else None)
                elif value_name in classes:
                    instances[target] = value_name
        elif isinstance(value, ast.Call):
            callee = value.func
            if isinstance(callee, ast.Name):
                callee_name = resolve_name(callee.id)
                if callee_name in classes:
                    # `target = SomeClass(...)`: remember the instance's class.
                    for target in target_names:
                        instances[target] = callee_name
                        if is_entrypoint_name(target):
                            add(_method_from_class(classes[callee_name]))
                elif callee_name == "partial" and value.args and isinstance(value.args[0], ast.Name):
                    # `target = partial(fn, ...)`: resolves to `fn`.
                    fn = functions.get(resolve_name(value.args[0].id))
                    for target in target_names:
                        if is_entrypoint_name(target):
                            add(fn)
                elif callee_name in functions:
                    # `target = factory(...)`: prefer the nested function the
                    # factory returns; fall back to the factory itself.
                    for target in target_names:
                        if is_entrypoint_name(target):
                            add(_factory_returned_function(functions[callee_name]) or functions[callee_name])
            elif isinstance(callee, ast.Attribute):
                # `target = <module>.partial(fn, ...)`
                if callee.attr == "partial" and value.args and isinstance(value.args[0], ast.Name):
                    fn = functions.get(resolve_name(value.args[0].id))
                    for target in target_names:
                        if is_entrypoint_name(target):
                            add(fn)
                # `target = SomeClass(...).method(...)` for an entrypoint-style method.
                owner = callee.value
                if isinstance(owner, ast.Call) and isinstance(owner.func, ast.Name):
                    class_name = resolve_name(owner.func.id)
                    if class_name in classes and callee.attr in _ENTRYPOINT_METHOD_NAMES:
                        for target in target_names:
                            if is_entrypoint_name(target):
                                add(_method_from_class(classes[class_name], (callee.attr,)))
        elif isinstance(value, ast.Attribute):
            # `target = obj.method` (no call) where obj is a known instance
            # or a class name.
            owner = value.value
            if value.attr in _ENTRYPOINT_METHOD_NAMES and isinstance(owner, ast.Name):
                owner_name = resolve_name(owner.id)
                class_name = instances.get(owner_name, owner_name if owner_name in classes else "")
                if class_name in classes:
                    for target in target_names:
                        if is_entrypoint_name(target):
                            add(_method_from_class(classes[class_name], (value.attr,)))
    return resolved
def _expr_is_none(expr: ast.AST | None) -> bool:
return isinstance(expr, ast.Constant) and expr.value is None
def _static_string(expr: ast.AST | None) -> Optional[str]:
if isinstance(expr, ast.Constant) and isinstance(expr.value, str):
return expr.value
if isinstance(expr, ast.JoinedStr):
parts: list[str] = []
for value in expr.values:
if not isinstance(value, ast.Constant) or not isinstance(value.value, str):
return None
parts.append(value.value)
return "".join(parts)
if isinstance(expr, ast.BinOp) and isinstance(expr.op, ast.Add):
left = _static_string(expr.left)
right = _static_string(expr.right)
if left is not None and right is not None:
return left + right
if (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Attribute)
and expr.func.attr == "join"
and _static_string(expr.func.value) is not None
and len(expr.args) == 1
and isinstance(expr.args[0], (ast.List, ast.Tuple))
):
parts = [_static_string(elt) for elt in expr.args[0].elts]
if all(part is not None for part in parts):
return _static_string(expr.func.value).join(parts) # type: ignore[arg-type]
return None
def _assembled_static_string(expr: ast.AST | None) -> Optional[str]:
"""Resolve a string only when it is *assembled* from fragments.
Returns the concatenated value for ``"a" + "b"`` / multi-fragment f-strings /
``"".join([...])``, and ``None`` for a plain string literal. Assembling an
otherwise-static attribute name (e.g. ``'elapsed' + '_time'``) is a deliberate
obfuscation signal: legitimate kernels never split a fixed attribute name like
that, so it is a high-precision marker for an evasion attempt.
"""
if isinstance(expr, ast.Constant):
return None
if isinstance(expr, ast.JoinedStr):
if sum(1 for value in expr.values if isinstance(value, ast.Constant)) <= 1:
return None
return _static_string(expr)
if isinstance(expr, ast.BinOp) and isinstance(expr.op, ast.Add):
return _static_string(expr)
if (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Attribute)
and expr.func.attr == "join"
):
return _static_string(expr)
return None
def _expr_has_benchmark_literal(expr: ast.AST | None) -> bool:
if expr is None:
return False
for node in ast.walk(expr):
if isinstance(node, ast.Constant) and isinstance(node.value, str):
if re.search(r'Ranked\s+Benchmark|BENCHMARK_PASSED|\bbenchmark\b\s*[:=]?|score\s*[:=]', node.value, re.IGNORECASE):
return True
return False
def _expr_has_decode_like_call(expr: ast.AST | None, helper_names: set[str] | None = None) -> bool:
if expr is None:
return False
helper_names = helper_names or set()
decode_names = {
"decode", "decompress", "b64decode", "b32decode", "b16decode",
"urlsafe_b64decode", "decodebytes", "decodestring", "unhexlify",
"a2b_hex", "a2b_base64", "bytes", "bytearray", "chr",
}
for node in ast.walk(expr):
if not isinstance(node, ast.Call):
continue
if isinstance(node.func, ast.Name) and (node.func.id in decode_names or node.func.id in helper_names):
return True
if isinstance(node.func, ast.Attribute) and node.func.attr in decode_names:
return True
return False
def _expr_contains_input_derived_call(expr: ast.AST | None, input_names: set[str]) -> bool:
if expr is None:
return False
for node in ast.walk(expr):
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Attribute) and _ast_root_name(node.func.value) in input_names:
return True
if isinstance(node.func, ast.Name) and any(_expr_names(arg) & input_names for arg in node.args):
return True
return False
def _is_input_float_call(expr: ast.AST | None, input_names: set[str]) -> bool:
return (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Attribute)
and expr.func.attr == "float"
and isinstance(expr.func.value, ast.Name)
and expr.func.value.id in input_names
and not expr.args
and not expr.keywords
)
def _is_input_attr_float_call(expr: ast.AST | None, owner_name: str) -> Optional[str]:
if not (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Attribute)
and expr.func.attr == "float"
and isinstance(expr.func.value, ast.Attribute)
and isinstance(expr.func.value.value, ast.Name)
and expr.func.value.value.id == owner_name
and not expr.args
and not expr.keywords
):