Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -7886,6 +7886,10 @@ class RulePolicy:
"INPUT_STRUCTURE_TRUNCATION_DISPATCH", "fixture_specialization", "hard", AUTO_FILTER, (),
(), "keep",
),
"BENCHMARK_RANK_TRUNCATION": RulePolicy(
"BENCHMARK_RANK_TRUNCATION", "fixture_specialization", "hard", AUTO_FILTER, (),
(), "keep",
),
"UNSYNC_MULTISTREAM": RulePolicy(
"UNSYNC_MULTISTREAM", "timing_manipulation", "telemetry", TELEMETRY_ONLY, (),
(), "downgrade",
Expand Down Expand Up @@ -8120,6 +8124,75 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
# All detectors list
# ---------------------------------------------------------------------------


def detect_benchmark_rank_truncation(code):
"""Compile-method-INDEPENDENT fixture-specialization / fixed-limit truncation detector.

Targets the general hack of inspecting the INPUT's numerical structure and then doing LESS work
than the full problem (truncating output columns/rows, zeroing an output tail, stopping early)
with a HARDCODED or classifier-derived bound. Problem-agnostic: the value-independent
fingerprints (output tail-zeroing + `limit<n` + loop bounded by `limit`) catch arbitrary fixed
limits on any problem, not just the qr_v2 ranks {256,384,768}.

Why needed: BENCHMARK_UNSAFE_ALGO_DISPATCH requires `raw_native_cache` (literal `nvcc`), so
`load_inline` builds evade it; INPUT_STRUCTURE_TRUNCATION_DISPATCH requires gather/scatter
`subset_set>=2`, so uniform classify-then-truncate evades it.

Precision: a legitimate full QR has none of these signals; a legitimate rank-revealing QR
derives its bound ADAPTIVELY at runtime (matrix_rank / `(diag>tol).sum()` / count_nonzero), so
the value-independent `fixed_limit` rule is suppressed when an adaptive-rank computation is
present. Validated zero-FP on confirmed-clean full-QR and adaptive-RRQR samples.
"""
facts = ensure_submission_facts(code)
combined = facts.python_active
low = combined.lower()
classifier = set(re.findall(r"\b(?:classify_512|classify_1024|detect_label|colnorm|col_norm|zerofrac|zero_frac|neardiff|near_diff|tail_max|colrange|col_range|rank_probe|cond_probe|estimate_rank|detect_rank|matrix_kind|input_kind|_s3_collinear)\b", low))
family = set(re.findall(r"\b(?:rankdef|rank_def|clustered|nearrank|near_rank|nearcol|low_rank|lowrank|degenerate|near_singular|deficient)\b", low))
trunc_named = set(re.findall(r"\b(?:_cqr_blocked_limit|_blocked_qr_trunc|zero_tail|inactive_tail|stop_at|switch_at|truncate|panel_tail_partial)\b", low))
tailzero = re.findall(r"[A-Za-z_]\w*\[[^\]]*:\s*\]\s*=\s*0\.?\b", combined)
limit_lt_n = re.findall(r"\blimit\s*<\s*n\b", low)
range_limit = re.findall(r"range\(\s*0\s*,\s*limit\b", low)
trunc_general = len(trunc_named) + len(tailzero) + len(limit_lt_n) + len(range_limit)
permatrix = bool(re.search(r"\b(?:active_limits|limits_ptr|limits\s*\[)\b|__restrict__\s*limits", low))
rankcaps = set(re.findall(r"(?:limit|trunc|stop|partial|ncol|cqr_blocked_limit|blocked_qr_trunc|==)\D{0,30}\b(256|384|768)\b", low))
rankcaps |= set(re.findall(r"\b(256|384|768)\b\D{0,30}(?:limit|trunc|stop|partial)", low))
fixed_limit = (bool(re.search(r"limit\s*=\s*\d{2,}", low))
or bool(re.search(r"_(?:cqr_blocked_limit|blocked_qr_trunc)\([^)]*,\s*\d{2,}", low))
or len(rankcaps) > 0)
adaptive_rank = (bool(re.search(r"matrix_rank|count_nonzero", low))
or bool(re.search(r"\.sum\(\)[^\n]{0,40}(?:tol|>|<)", low))
or bool(re.search(r"(?:rank|limit)\s*=\s*\([^)]*(?:>|<)[^)]*\)\.sum", low)))
strong_truncation = (
len(trunc_named) >= 1
or len(limit_lt_n) >= 1
or len(range_limit) >= 1
or len(rankcaps) >= 1
or len(tailzero) >= 2
)
cond_classifier = len(classifier) >= 1 and strong_truncation and not adaptive_rank
cond_label = (
len(family) >= 1
and trunc_general >= 1
and (len(tailzero) >= 1 or len(limit_lt_n) >= 1)
and not adaptive_rank
)
cond_permatrix = permatrix and trunc_general >= 1 and len(rankcaps) >= 1 and not adaptive_rank
cond_rankcaps = len(rankcaps) >= 2 and trunc_general >= 1 and not adaptive_rank
cond_fixedlimit = (len(tailzero) >= 1 and (len(limit_lt_n) >= 1 or len(range_limit) >= 1)
and fixed_limit and not adaptive_rank)
fired = [n for n, v in (("classifier", cond_classifier), ("label", cond_label),
("per_matrix", cond_permatrix), ("rank_caps", cond_rankcaps),
("fixed_limit", cond_fixedlimit)) if v]
if fired:
return [{"pattern": "BENCHMARK_RANK_TRUNCATION", "severity": "high",
"evidence": (f"fixture-specialization / fixed-limit truncation [{','.join(fired)}]: "
f"classifiers={len(classifier)}, family={len(family)}, "
f"trunc(named={len(trunc_named)},tailzero={len(tailzero)},limit<n={len(limit_lt_n)},"
f"range_limit={len(range_limit)}), per-matrix-limits={permatrix}, "
f"rank-caps={len(rankcaps)}, fixed_limit={fixed_limit}, adaptive_rank={adaptive_rank}")}]
return []


CODE_DETECTORS = [
# Regex-based detectors (Layer 1)
detect_timer_monkeypatch,
Expand Down Expand Up @@ -8159,6 +8232,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_hardcoded_shapes,
detect_benchmark_unsafe_algo_dispatch,
detect_input_structure_truncation_dispatch,
detect_benchmark_rank_truncation,
detect_unsync_multistream,
detect_cuda_event_disable_timing,
detect_token_paste_cuda_api,
Expand Down Expand Up @@ -8222,6 +8296,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("hardcoded_shapes", detect_hardcoded_shapes),
("benchmark_unsafe_algo_dispatch", detect_benchmark_unsafe_algo_dispatch),
("input_structure_truncation_dispatch", detect_input_structure_truncation_dispatch),
("benchmark_rank_truncation", detect_benchmark_rank_truncation),
("unsync_multistream", detect_unsync_multistream),
("cuda_event_disable_timing", detect_cuda_event_disable_timing),
("token_paste_cuda_api", detect_token_paste_cuda_api),
Expand Down