Skip to content

[Feat]: Add Final Norm for vLLM Hidden Extractor#1846

Open
h-guo18 wants to merge 12 commits into
mainfrom
haoguo/fakebase-final-norm
Open

[Feat]: Add Final Norm for vLLM Hidden Extractor#1846
h-guo18 wants to merge 12 commits into
mainfrom
haoguo/fakebase-final-norm

Conversation

@h-guo18

@h-guo18 h-guo18 commented Jun 28, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

Type of change: Bug fix

vLLM captures the final-layer hidden state before the model's final norm, but the
offline/streaming distillation path fed it straight into lm_head, so the reconstructed
base logits (the KD target) were computed from un-normed hidden states.

This PR re-applies the base model's final norm before lm_head when the producer declares
a pre-norm capture (base_hidden_prenorm), for both DFlash and EAGLE:

  • Producer sets base_hidden_prenorm (streaming: True; offline: from the dump).
  • Consumer (_maybe_apply_base_final_norm) re-applies the base final norm, and fails loud
    if pre-norm is declared but the model's norm type isn't supported (no silent corruption).
  • FakeBaseModel now loads the base final norm (+ rope_theta/rms_norm_eps); norm type is
    gated by an explicit model_type allowlist (gpt_oss excluded pending a matching norm class).

Testing

tests/unit/torch/speculative/plugins/test_modeling_final_norm.py; DFlash/EAGLE streaming
training verified end-to-end.

  • Backward compatible?: ✅ (post-norm captures declare base_hidden_prenorm=False → unchanged)
  • New tests: ✅

Summary by CodeRabbit

  • New Features
    • Expanded speculative decoding/distillation to reconstruct missing base-model logits by optionally applying a base model’s final pre–LM-head normalization when pre-norm hidden states are provided.
    • Streaming dataset now emits base_hidden_prenorm and can configure RDMA backends from environment settings.
  • Bug Fixes
    • Rejects mixed base_hidden_prenorm values within a batch.
    • Fails fast on streaming token-length mismatches.
    • Streaming dataset loading supports directory inputs by expanding sorted JSONL shards (and errors if none are found).
  • Tests
    • Added/updated unit and dataset tests for final-norm behavior and the new batch field.

@copy-pr-bot

copy-pr-bot Bot commented Jun 28, 2026

Copy link
Copy Markdown

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Jun 28, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds final-norm reconstruction helpers, threads prenorm metadata through dataset and RDMA paths, updates offline EAGLE/DFlash logit reconstruction, and expands the example streaming data loader to resolve directory inputs into JSONL shards.

Changes

Final-norm reconstruction and streaming flow updates

Layer / File(s) Summary
Final-norm module and tests
modelopt/torch/speculative/plugins/modeling_final_norm.py, tests/unit/torch/speculative/plugins/test_modeling_final_norm.py
Adds the final-norm helper module, type selection logic, conditional norm application, and unit tests covering selection, dtype behavior, no-op behavior, application, and error handling.
FakeBaseModel final-norm loading
modelopt/torch/speculative/plugins/modeling_fakebase.py
Persists final-norm config values, conditionally builds the norm submodule, derives the final-norm type from the source model, and validates/copies embed, lm_head, and norm weights during loading.
Prenorm flags and streaming RDMA metadata
modelopt/torch/speculative/eagle/utils.py, modelopt/torch/speculative/plugins/hf_streaming_dataset.py, modelopt/torch/speculative/plugins/rdma_hidden_states_connector.py, tests/unit/torch/speculative/plugins/test_hf_speculative_offline.py, tests/unit/torch/speculative/plugins/test_hf_streaming_dataset.py
Carries base_hidden_prenorm through offline and streaming batches, makes NIXL backends configurable, restricts tokenization to conversations, adds hs_max_tokens metadata plus a runtime size check, and updates the key assertions in dataset tests.
DFlash offline norm reconstruction
modelopt/torch/speculative/plugins/hf_dflash.py, modelopt/torch/speculative/plugins/modeling_dflash.py
Discovers the optional base final norm and uses it during offline logit reconstruction when logits are absent.
EAGLE offline norm reconstruction
modelopt/torch/speculative/plugins/hf_eagle.py
Discovers the optional base final norm and re-applies it before reconstructing offline logits.
Streaming example shard loading
examples/speculative_decoding/eagle_utils.py
Expands directory inputs into sorted JSONL shard lists before streaming dataset loading.

Estimated code review effort: 4 (Complex) | ~60 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 74.36% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately reflects the main change: adding final-norm handling for vLLM hidden-state extraction.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed PR diff only updates tests; no forbidden torch.load(weights_only=False), numpy.load(allow_pickle=True), trust_remote_code=True, eval/exec, or new # nosec appear.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch haoguo/fakebase-final-norm

Comment @coderabbitai help to get the list of available commands.

@github-actions

github-actions Bot commented Jun 28, 2026

Copy link
Copy Markdown
Contributor
PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1846/

Built to branch gh-pages at 2026-07-03 01:00 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@h-guo18 h-guo18 changed the title final norm [Fix]: Applying Final Norm for vLLM Hiddens Extractor Jun 28, 2026
@h-guo18 h-guo18 changed the title [Fix]: Applying Final Norm for vLLM Hiddens Extractor [Fix]: Add Final Norm for vLLM HIddens Extracter Jun 29, 2026
@h-guo18 h-guo18 self-assigned this Jul 1, 2026
@h-guo18 h-guo18 changed the title [Fix]: Add Final Norm for vLLM HIddens Extracter [Fix]: Add Final Norm for vLLM Hidden Extractor Jul 1, 2026
@h-guo18 h-guo18 force-pushed the haoguo/fakebase-final-norm branch from 3540a3f to 9b43796 Compare July 1, 2026 00:44
@copy-pr-bot

copy-pr-bot Bot commented Jul 1, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@h-guo18 h-guo18 marked this pull request as ready for review July 1, 2026 00:44
@h-guo18 h-guo18 requested review from a team as code owners July 1, 2026 00:44
@h-guo18 h-guo18 requested a review from yeyu-nvidia July 1, 2026 00:44
Comment thread tests/unit/torch/speculative/plugins/test_modeling_final_norm.py
Comment thread modelopt/torch/speculative/plugins/rdma_hidden_states_connector.py Outdated
Comment thread modelopt/torch/speculative/plugins/modeling_final_norm.py Outdated
Comment thread modelopt/torch/speculative/plugins/modeling_final_norm.py Outdated

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 2

🧹 Nitpick comments (4)
modelopt/torch/speculative/plugins/rdma_hidden_states_connector.py (1)

252-259: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

Justify or reconsider the function-local import.

The import of nixl_backends_from_env from .hf_streaming_dataset is placed inside register_kv_caches without an explicit reason (circular import, optional dependency, or heavy-import deferral). The other local imports in this method guard optional nixl/vllm deps; this one imports from a trainer-side dataset module, which could indicate a circular-import concern or an intent to avoid pulling trainer-only dependencies into the vllm worker process — but that reasoning isn't stated. As per coding guidelines, "Place imports inside functions only when necessary to resolve circular imports, guard optional dependencies, or defer heavy imports with explicit justification" and CONTRIBUTING.md requires "a brief comment naming the reason" for in-function imports.

♻️ Suggested comment clarification
-        # Backend(s) from NIXL_BACKENDS; shared helper keeps this locked to the trainer-side
-        # agent (they must use the same backend to hand off over RDMA).
-        from .hf_streaming_dataset import nixl_backends_from_env
+        # Backend(s) from NIXL_BACKENDS; shared helper keeps this locked to the trainer-side
+        # agent (they must use the same backend to hand off over RDMA). Imported here (not at
+        # module top) to avoid pulling hf_streaming_dataset's trainer-only deps into the vllm
+        # worker process at import time.
+        from .hf_streaming_dataset import nixl_backends_from_env
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/speculative/plugins/rdma_hidden_states_connector.py` around
lines 252 - 259, The function-local import in register_kv_caches for
nixl_backends_from_env is not justified, so either move it to the module scope
or keep it inside the method with a brief inline comment explaining the specific
reason (for example, circular import avoidance or deferring trainer-side
dependencies). Use the existing register_kv_caches and nixl_backends_from_env
symbols to locate the import, and ensure the rationale matches the other
optional-dependency imports in this method.

Source: Coding guidelines

examples/speculative_decoding/eagle_utils.py (1)

94-106: 🎯 Functional Correctness | 🔵 Trivial | 💤 Low value

Directory expansion logic looks correct.

Non-recursive glob("*.jsonl") will miss shards nested in subdirectories; if shard directories can have nested layout, consider rglob("*.jsonl") instead. Low priority since flat shard directories are the common case.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/speculative_decoding/eagle_utils.py` around lines 94 - 106, The
directory-to-shards expansion in eagle_utils.py only uses the data-loading path
logic around load_dataset and Path.glob, so it will miss any .jsonl shards
nested inside subdirectories. Update the directory handling in the data_path
branch to search recursively for jsonl files, while preserving the existing
empty-directory error and print_rank_0 shard-count logging.
modelopt/torch/speculative/plugins/hf_streaming_dataset.py (1)

67-77: 🩺 Stability & Availability | 🔵 Trivial | ⚡ Quick win

nixl_backends_from_env doesn't handle empty or whitespace-padded values.

os.environ.get("NIXL_BACKENDS", "UCX") only falls back to "UCX" when the var is unset, not when it's set to "", so NIXL_BACKENDS="" yields [""]. Values like "UCX, LIBFABRIC" also aren't stripped, producing a mismatched backend name (" LIBFABRIC"). Since this helper is shared with the producer-side connector, either edge case misconfigures NIXL on both ends.

🔧 Proposed fix
 def nixl_backends_from_env() -> list[str]:
-    return os.environ.get("NIXL_BACKENDS", "UCX").split(",")
+    raw = os.environ.get("NIXL_BACKENDS") or "UCX"
+    return [b.strip() for b in raw.split(",") if b.strip()] or ["UCX"]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/speculative/plugins/hf_streaming_dataset.py` around lines 67 -
77, The nixl_backends_from_env helper currently returns raw comma-split values,
so an empty NIXL_BACKENDS becomes [""], and whitespace-padded entries like "UCX,
LIBFABRIC" keep invalid spaces. Update nixl_backends_from_env to treat empty or
all-whitespace env values as missing and fall back to "UCX", then normalize the
parsed list by stripping each backend name and filtering out empty entries so
the shared connector logic stays aligned.
modelopt/torch/speculative/eagle/utils.py (1)

197-204: 🗄️ Data Integrity & Integration | 🔵 Trivial | ⚡ Quick win

Batch-level base_hidden_prenorm assumes homogeneity, unvalidated.

Taking the flag only from features[0] silently assumes every sample in the batch was dumped with the same prenorm setting. If a dataset directory ever mixes dumps from different producer runs (e.g. pre-PR dumps without the key alongside new ones), samples that disagree with features[0] get the wrong norm behavior applied/skipped with no error — exactly the silent-corruption failure mode _maybe_apply_base_final_norm was designed to prevent.

🛡️ Proposed fix: assert batch homogeneity
-        if "base_hidden_prenorm" in features[0]:
-            base_model_outputs["base_hidden_prenorm"] = features[0]["base_hidden_prenorm"]
+        if "base_hidden_prenorm" in features[0]:
+            flag = features[0]["base_hidden_prenorm"]
+            if any(item.get("base_hidden_prenorm", flag) != flag for item in features):
+                raise ValueError(
+                    "Mixed base_hidden_prenorm values within a batch; dumps must be produced "
+                    "consistently."
+                )
+            base_model_outputs["base_hidden_prenorm"] = flag
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/speculative/eagle/utils.py` around lines 197 - 204, The
batch-level propagation of base_hidden_prenorm in the utility that builds
base_model_outputs assumes every item in features matches features[0], which can
silently mix different producer settings. Update the batching logic in the same
helper to validate that all samples agree on base_hidden_prenorm before copying
it into base_model_outputs, and raise an error if any item differs so
_maybe_apply_base_final_norm is never driven by a mixed batch.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 169-178: The `_base_model_norm` property and final-norm lookup
logic are duplicated between `HFDFlashModel` and `HFEagleModel`; extract the
shared probing behavior into a common helper near `_FINAL_NORM_PATHS` in
`modeling_fakebase.py` and have both `_find_base_model_parts` implementations
call it. Keep `_base_model_norm` as a thin accessor over the resolved
`base_model_norm_path` so both models share the same path resolution and future
changes only need to be made once.

In `@modelopt/torch/speculative/plugins/modeling_fakebase.py`:
- Around line 284-306: The shape validation in the tensor-loading path is
relying on assert, which disappears under optimized Python runs and can
re-enable silent broadcasting bugs. Update the checks in the fakebase model
loading logic, especially around _read, embed_tokens.weight, lm_head.weight, and
the final norm in modeling_fakebase.py, to use explicit if-condition checks that
raise an error when shapes do not match. Keep the same expected shapes derived
from hidden_size, vocab_size, and final_norm_type, but make the failure path
unconditional so corruption is still blocked under python -O.

---

Nitpick comments:
In `@examples/speculative_decoding/eagle_utils.py`:
- Around line 94-106: The directory-to-shards expansion in eagle_utils.py only
uses the data-loading path logic around load_dataset and Path.glob, so it will
miss any .jsonl shards nested inside subdirectories. Update the directory
handling in the data_path branch to search recursively for jsonl files, while
preserving the existing empty-directory error and print_rank_0 shard-count
logging.

In `@modelopt/torch/speculative/eagle/utils.py`:
- Around line 197-204: The batch-level propagation of base_hidden_prenorm in the
utility that builds base_model_outputs assumes every item in features matches
features[0], which can silently mix different producer settings. Update the
batching logic in the same helper to validate that all samples agree on
base_hidden_prenorm before copying it into base_model_outputs, and raise an
error if any item differs so _maybe_apply_base_final_norm is never driven by a
mixed batch.

In `@modelopt/torch/speculative/plugins/hf_streaming_dataset.py`:
- Around line 67-77: The nixl_backends_from_env helper currently returns raw
comma-split values, so an empty NIXL_BACKENDS becomes [""], and
whitespace-padded entries like "UCX, LIBFABRIC" keep invalid spaces. Update
nixl_backends_from_env to treat empty or all-whitespace env values as missing
and fall back to "UCX", then normalize the parsed list by stripping each backend
name and filtering out empty entries so the shared connector logic stays
aligned.

In `@modelopt/torch/speculative/plugins/rdma_hidden_states_connector.py`:
- Around line 252-259: The function-local import in register_kv_caches for
nixl_backends_from_env is not justified, so either move it to the module scope
or keep it inside the method with a brief inline comment explaining the specific
reason (for example, circular import avoidance or deferring trainer-side
dependencies). Use the existing register_kv_caches and nixl_backends_from_env
symbols to locate the import, and ensure the rationale matches the other
optional-dependency imports in this method.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 2fba9c92-4e11-4a46-aec4-f87c19ed3ef9

📥 Commits

Reviewing files that changed from the base of the PR and between 84fc1f9 and 9b43796.

📒 Files selected for processing (9)
  • examples/speculative_decoding/eagle_utils.py
  • modelopt/torch/speculative/eagle/utils.py
  • modelopt/torch/speculative/plugins/hf_dflash.py
  • modelopt/torch/speculative/plugins/hf_eagle.py
  • modelopt/torch/speculative/plugins/hf_streaming_dataset.py
  • modelopt/torch/speculative/plugins/modeling_fakebase.py
  • modelopt/torch/speculative/plugins/modeling_final_norm.py
  • modelopt/torch/speculative/plugins/rdma_hidden_states_connector.py
  • tests/unit/torch/speculative/plugins/test_modeling_final_norm.py

Comment thread modelopt/torch/speculative/plugins/hf_dflash.py
Comment thread modelopt/torch/speculative/plugins/modeling_fakebase.py
Comment thread modelopt/torch/speculative/plugins/hf_dflash.py Outdated
Comment thread modelopt/torch/speculative/plugins/hf_streaming_dataset.py Outdated
@h-guo18 h-guo18 marked this pull request as draft July 2, 2026 22:10
Comment thread modelopt/torch/speculative/plugins/hf_streaming_dataset.py Outdated
@h-guo18 h-guo18 marked this pull request as ready for review July 2, 2026 23:19
@h-guo18

h-guo18 commented Jul 2, 2026

Copy link
Copy Markdown
Contributor Author

/claude review

h-guo18 added 4 commits July 2, 2026 23:28
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
h-guo18 added 5 commits July 2, 2026 23:28
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@h-guo18 h-guo18 force-pushed the haoguo/fakebase-final-norm branch from a7a1c50 to d5ef662 Compare July 2, 2026 23:30
@h-guo18 h-guo18 changed the title [Fix]: Add Final Norm for vLLM Hidden Extractor [Feat]: Add Final Norm for vLLM Hidden Extractor Jul 2, 2026
Comment thread modelopt/torch/speculative/plugins/hf_streaming_dataset.py Outdated

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Review Summary

Findings — CRITICAL: 0 · IMPORTANT: 1 · SUGGESTION: 0

Scope: reviewed all 8 changed modelopt/ and examples/speculative_decoding/ source files and the new unit test. (The diff against the shallow base tip also surfaces unrelated examples/hf_ptq/ and examples/torch_trt/ churn that is base drift, not part of this PR — not reviewed.)

The core fix is sound: re-applying the base final norm before lm_head when the producer declares a pre-norm capture (base_hidden_prenorm) is the right correction; failing loud when the norm cannot be located prevents silent KD-target corruption; and the offline default of False keeps old .pt dumps backward-compatible. _FinalRMSNorm matches the Llama/Qwen/Mistral/DeepSeek/Kimi RMSNorm formula (fp32 reduction, dtype-correct weight), and gpt_oss is correctly excluded pending a matching norm class. The hs_max_tokens guard and the shape checks using raise (surviving python -O) are good hardening.

--- Most impactful finding ---

[IMPORTANT Compatibility] _tokenize_entry now reads only conversations, dropping the messages key entirely (hf_streaming_dataset.py:267). The stated bug is a priority-order problem (Spec-Decoding-v2 carries a degenerate messages stub alongside the real conversations), but the README documents the canonical user dataset format as messages:[...], which the streaming path also consumes. A messages-only corpus now yields None for every entry, so the resample loop exhausts the corpus and raises the misleading "server likely down" error. Preferring "conversations or messages" fixes the priority bug and preserves the documented format. Inline suggestion posted.

--- Concurring with CodeRabbit (elevated severity) ---

CodeRabbit flagged the batch-homogeneity assumption in EagleOfflineDataCollator (eagle/utils.py:200, reading features[0]) as trivial. I would rate it higher: OfflineSupervisedDataset always sets the key, so a directory mixing pre-PR dumps (False) with new pre-norm dumps (True) can put disagreeing samples in one batch; the whole batch then follows features[0], silently feeding un-normed hiddens to lm_head — exactly the corruption _maybe_apply_base_final_norm was built to prevent. Asserting batch homogeneity is worth doing.

--- Risk assessment ---

Low-to-moderate. The norm reconstruction is correct and well-guarded. The one blocking item is the messages/conversations narrowing, a straightforward compat regression against the documented dataset format with an easy fix.

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/hf_streaming_dataset.py (1)

66-78: 🩺 Stability & Availability | 🔵 Trivial | ⚡ Quick win

Strip whitespace from parsed backend names.

os.environ.get("NIXL_BACKENDS", "UCX").split(",") doesn't strip whitespace, so a value like "UCX, LIBFABRIC" yields [" LIBFABRIC"] with a leading space. NIXL backend names are matched as exact strings (e.g. "backends":["LIBFABRIC"]), so this would silently break agent creation on a common env-var formatting mistake, and since both trainer and producer use this same helper, the bug is baked into both sides equally.

🛠️ Proposed fix
-    return os.environ.get("NIXL_BACKENDS", "UCX").split(",")
+    return [b.strip() for b in os.environ.get("NIXL_BACKENDS", "UCX").split(",") if b.strip()]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/speculative/plugins/hf_streaming_dataset.py` around lines 66 -
78, The nixl_backends_from_env helper currently returns raw comma-split values,
so NIXL_BACKENDS entries with spaces are preserved and can break exact backend
matching. Update nixl_backends_from_env to normalize the parsed list by trimming
whitespace from each backend name (and dropping any empty entries if needed)
while keeping the same default behavior, so both the trainer-side and
producer-side paths continue to agree on valid backend strings.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@modelopt/torch/speculative/plugins/hf_streaming_dataset.py`:
- Around line 66-78: The nixl_backends_from_env helper currently returns raw
comma-split values, so NIXL_BACKENDS entries with spaces are preserved and can
break exact backend matching. Update nixl_backends_from_env to normalize the
parsed list by trimming whitespace from each backend name (and dropping any
empty entries if needed) while keeping the same default behavior, so both the
trainer-side and producer-side paths continue to agree on valid backend strings.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 82bb96a6-31fe-47ec-91ba-1309bf13fac2

📥 Commits

Reviewing files that changed from the base of the PR and between d5ef662 and 8e85239.

📒 Files selected for processing (4)
  • modelopt/torch/speculative/eagle/utils.py
  • modelopt/torch/speculative/plugins/hf_streaming_dataset.py
  • modelopt/torch/speculative/plugins/modeling_final_norm.py
  • tests/unit/torch/speculative/plugins/test_modeling_final_norm.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • modelopt/torch/speculative/eagle/utils.py
  • tests/unit/torch/speculative/plugins/test_modeling_final_norm.py
  • modelopt/torch/speculative/plugins/modeling_final_norm.py

@h-guo18

h-guo18 commented Jul 2, 2026

Copy link
Copy Markdown
Contributor Author

/claude review

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Review Summary

Findings — CRITICAL: 0 · IMPORTANT: 0 · SUGGESTION: 1

Scope: reviewed all 9 changed source files (7 modelopt/, 1 examples/, 1 new unit test) plus the RDMA connector. Nothing deprioritized.

Assessment: The core fix is correct and well-guarded. vLLM captures the residual stream before the base model final norm, and the offline/streaming self-logit-distillation path previously fed that un-normed hidden straight into lm_head, corrupting the KD target. The fix threads a base_hidden_prenorm declaration from producer to dataset to collator to consumer and re-applies the base final norm in _maybe_apply_base_final_norm before lm_head, only when pre-norm is declared.

Verified:

  • Algorithm: _FinalRMSNorm matches the Llama/Qwen/Mistral/DeepSeek/Kimi RMSNorm formula (fp32 reduction, dtype-correct bf16 weight to avoid float32 promotion mismatching the bf16 lm_head). gpt_oss is correctly excluded pending a matching (fp32-weight, multiply-then-cast) norm class.
  • Re-norm boundary: the normed hidden is computed into a local out_hiddens used only for logit reconstruction; base_outputs.out_hiddens stays un-normed for the EAGLE draft input, which is intended. DFlash and Domino paths thread it correctly; Domino ignores logits (no target-logit KD), so omitting the norm there is correct.
  • Fail-loud guards: pre-norm declared + no norm located raises RuntimeError; known final_norm_type + missing checkpoint key is a hard error; shape checks use raise (survive python -O); undersized hs_max_tokens raises a clear error instead of a hung fetch.
  • Backward compat: offline default base_hidden_prenorm=False keeps existing .pt dumps on the prior post-norm path; unknown model_type builds no norm and downstream skips re-norming.

Both findings from the prior Claude review are resolved in this revision: the messages/conversations priority order now prefers conversations with a messages fallback, and EagleOfflineDataCollator now rejects a batch mixing pre-norm and post-norm dumps.

Non-blocking: one SUGGESTION on nixl_backends_from_env whitespace handling around comma-separated NIXL_BACKENDS.

Risk: low. Correct, well-tested, and backward-compatible. LGTM.

the container lacks, so UCX RDMA dies there.
"""
return os.environ.get("NIXL_BACKENDS", "UCX").split(",")

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] nixl_backends_from_env splits on , without stripping whitespace. The docstring documents this as a comma-separated env var (NIXL_BACKENDS=LIBFABRIC), and a user who naturally writes NIXL_BACKENDS="UCX, LIBFABRIC" (with a space) would get ["UCX", " LIBFABRIC"] — the second entry has a leading space and won't match a valid NIXL backend name, silently breaking the RDMA handoff. Since this list must match on both the trainer and producer side, a mismatch here is hard to diagnose. Stripping (and dropping empties) makes it robust:

return [b.strip() for b in os.environ.get("NIXL_BACKENDS", "UCX").split(",") if b.strip()]

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (2)
tests/unit/torch/speculative/plugins/test_hf_speculative_offline.py (1)

204-213: 🎯 Functional Correctness | 🔵 Trivial | ⚡ Quick win

Assert the actual value of base_hidden_prenorm, not just its presence.

The test only checks that base_hidden_prenorm is a key in the returned item, but never verifies its derived value. Since correctly deriving this flag from the offline dump is the actual bug this PR fixes, a value-only assertion doesn't protect against a regression in that derivation logic (e.g., silently defaulting to False/True regardless of the dump).

✅ Example addition
     assert item["input_ids"].shape == (SEQ_LEN,)
     assert item["attention_mask"].shape == (SEQ_LEN,)
     assert item["labels"].shape == (SEQ_LEN,)
+    assert item["base_hidden_prenorm"] == <expected_value_based_on_dump>
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/torch/speculative/plugins/test_hf_speculative_offline.py` around
lines 204 - 213, The test around ds[0] only verifies that base_hidden_prenorm
exists, but it should also assert the derived boolean value matches the offline
dump. Update the spec in test_hf_speculative_offline.py near the ds[0]
assertions to check base_hidden_prenorm’s actual value for the returned item,
using the same dataset/item setup already in place so regressions in the
derivation logic are caught.
tests/unit/torch/speculative/plugins/test_fakebase.py (1)

49-60: 🎯 Functional Correctness | 🔵 Trivial | ⚡ Quick win

Add assertions for the loaded norm weight.

These fixtures now supply norm.weight because the new FakeBaseModel logic requires and copies it, but none of the three tests assert that model.norm.weight was actually loaded/copied correctly (only lm_head/embed_tokens shapes are checked). Since this PR's core change is about correctly reconstructing/loading the final norm, verifying the norm weight in these tests would catch regressions in exactly the new behavior being added.

✅ Example addition
 def test_fakebase_local_happy_path(fake_checkpoint):
     model = FakeBaseModel.from_source(str(fake_checkpoint))
     assert model.lm_head.weight.shape == torch.Size([_VOCAB_SIZE, _HIDDEN_SIZE])
     assert model.embed_tokens.weight.shape == torch.Size([_VOCAB_SIZE, _HIDDEN_SIZE])
+    assert model.norm.weight.shape == torch.Size([_HIDDEN_SIZE])
+    torch.testing.assert_close(model.norm.weight, torch.ones(_HIDDEN_SIZE))

Also applies to: 75-84, 88-98

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/torch/speculative/plugins/test_fakebase.py` around lines 49 - 60,
The FakeBaseModel tests now provide norm.weight in the fixture, but the
assertions in the affected test cases still only verify lm_head and
embed_tokens, so the new final-norm loading path is untested. Update the three
tests that use fake_checkpoint/fake_config to also assert that model.norm.weight
is loaded/copied correctly, using the existing FakeBaseModel/`model.norm.weight`
references so the checks cover the new reconstruction behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@tests/unit/torch/speculative/plugins/test_fakebase.py`:
- Around line 49-60: The FakeBaseModel tests now provide norm.weight in the
fixture, but the assertions in the affected test cases still only verify lm_head
and embed_tokens, so the new final-norm loading path is untested. Update the
three tests that use fake_checkpoint/fake_config to also assert that
model.norm.weight is loaded/copied correctly, using the existing
FakeBaseModel/`model.norm.weight` references so the checks cover the new
reconstruction behavior.

In `@tests/unit/torch/speculative/plugins/test_hf_speculative_offline.py`:
- Around line 204-213: The test around ds[0] only verifies that
base_hidden_prenorm exists, but it should also assert the derived boolean value
matches the offline dump. Update the spec in test_hf_speculative_offline.py near
the ds[0] assertions to check base_hidden_prenorm’s actual value for the
returned item, using the same dataset/item setup already in place so regressions
in the derivation logic are caught.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: f9873e95-05d3-4c30-a637-d7585ecb0c9e

📥 Commits

Reviewing files that changed from the base of the PR and between 8e85239 and 558a956.

📒 Files selected for processing (3)
  • tests/unit/torch/speculative/plugins/test_fakebase.py
  • tests/unit/torch/speculative/plugins/test_hf_speculative_offline.py
  • tests/unit/torch/speculative/plugins/test_hf_streaming_dataset.py
✅ Files skipped from review due to trivial changes (1)
  • tests/unit/torch/speculative/plugins/test_hf_streaming_dataset.py

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@codecov

codecov Bot commented Jul 3, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 76.34409% with 22 lines in your changes missing coverage. Please review.
✅ Project coverage is 69.98%. Comparing base (4b9225b) to head (e4a541a).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
...elopt/torch/speculative/plugins/modeling_dflash.py 22.22% 7 Missing ⚠️
modelopt/torch/speculative/plugins/hf_eagle.py 62.50% 6 Missing ⚠️
modelopt/torch/speculative/plugins/hf_dflash.py 64.28% 5 Missing ⚠️
...peculative/plugins/rdma_hidden_states_connector.py 0.00% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1846      +/-   ##
==========================================
- Coverage   70.21%   69.98%   -0.24%     
==========================================
  Files         515      516       +1     
  Lines       57244    57538     +294     
==========================================
+ Hits        40196    40268      +72     
- Misses      17048    17270     +222     
Flag Coverage Δ
unit 54.95% <76.34%> (+0.06%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants