[Feat]: Add Final Norm for vLLM Hidden Extractor #1846
📝 Walkthrough
Adds final-norm reconstruction helpers, threads prenorm metadata through dataset and RDMA paths, updates offline EAGLE/DFlash logit reconstruction, and expands the example streaming data loader to resolve directory inputs into JSONL shards.
Changes
Final-norm reconstruction and streaming flow updates
Estimated code review effort: 4 (Complex) | ~60 minutes
🚥 Pre-merge checks | ✅ 5 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
3540a3f to 9b43796
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 2
🧹 Nitpick comments (4)
modelopt/torch/speculative/plugins/rdma_hidden_states_connector.py (1)
252-259: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value
Justify or reconsider the function-local import.
The import of `nixl_backends_from_env` from `.hf_streaming_dataset` is placed inside `register_kv_caches` without an explicit reason (circular import, optional dependency, or heavy-import deferral). The other local imports in this method guard optional `nixl`/`vllm` deps; this one imports from a trainer-side dataset module, which could indicate a circular-import concern or an intent to avoid pulling trainer-only dependencies into the vllm worker process, but that reasoning isn't stated. As per coding guidelines, "Place imports inside functions only when necessary to resolve circular imports, guard optional dependencies, or defer heavy imports with explicit justification" and CONTRIBUTING.md requires "a brief comment naming the reason" for in-function imports.
♻️ Suggested comment clarification
    - # Backend(s) from NIXL_BACKENDS; shared helper keeps this locked to the trainer-side
    - # agent (they must use the same backend to hand off over RDMA).
    - from .hf_streaming_dataset import nixl_backends_from_env
    + # Backend(s) from NIXL_BACKENDS; shared helper keeps this locked to the trainer-side
    + # agent (they must use the same backend to hand off over RDMA). Imported here (not at
    + # module top) to avoid pulling hf_streaming_dataset's trainer-only deps into the vllm
    + # worker process at import time.
    + from .hf_streaming_dataset import nixl_backends_from_env
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/speculative/plugins/rdma_hidden_states_connector.py` around lines 252 - 259, The function-local import in register_kv_caches for nixl_backends_from_env is not justified, so either move it to the module scope or keep it inside the method with a brief inline comment explaining the specific reason (for example, circular import avoidance or deferring trainer-side dependencies). Use the existing register_kv_caches and nixl_backends_from_env symbols to locate the import, and ensure the rationale matches the other optional-dependency imports in this method.
Source: Coding guidelines
examples/speculative_decoding/eagle_utils.py (1)
94-106: 🎯 Functional Correctness | 🔵 Trivial | 💤 Low value
Directory expansion logic looks correct.
Non-recursive `glob("*.jsonl")` will miss shards nested in subdirectories; if shard directories can have nested layout, consider `rglob("*.jsonl")` instead. Low priority since flat shard directories are the common case.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/eagle_utils.py` around lines 94 - 106, The directory-to-shards expansion in eagle_utils.py only uses the data-loading path logic around load_dataset and Path.glob, so it will miss any .jsonl shards nested inside subdirectories. Update the directory handling in the data_path branch to search recursively for jsonl files, while preserving the existing empty-directory error and print_rank_0 shard-count logging.
modelopt/torch/speculative/plugins/hf_streaming_dataset.py (1)
67-77: 🩺 Stability & Availability | 🔵 Trivial | ⚡ Quick win
`nixl_backends_from_env` doesn't handle empty or whitespace-padded values.
`os.environ.get("NIXL_BACKENDS", "UCX")` only falls back to `"UCX"` when the var is unset, not when it's set to `""`, so `NIXL_BACKENDS=""` yields `[""]`. Values like `"UCX, LIBFABRIC"` also aren't stripped, producing a mismatched backend name (`" LIBFABRIC"`). Since this helper is shared with the producer-side connector, either edge case misconfigures NIXL on both ends.
🔧 Proposed fix
     def nixl_backends_from_env() -> list[str]:
    -    return os.environ.get("NIXL_BACKENDS", "UCX").split(",")
    +    raw = os.environ.get("NIXL_BACKENDS") or "UCX"
    +    return [b.strip() for b in raw.split(",") if b.strip()] or ["UCX"]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/speculative/plugins/hf_streaming_dataset.py` around lines 67 - 77, The nixl_backends_from_env helper currently returns raw comma-split values, so an empty NIXL_BACKENDS becomes [""], and whitespace-padded entries like "UCX, LIBFABRIC" keep invalid spaces. Update nixl_backends_from_env to treat empty or all-whitespace env values as missing and fall back to "UCX", then normalize the parsed list by stripping each backend name and filtering out empty entries so the shared connector logic stays aligned.
modelopt/torch/speculative/eagle/utils.py (1)
197-204: 🗄️ Data Integrity & Integration | 🔵 Trivial | ⚡ Quick win
Batch-level `base_hidden_prenorm` assumes homogeneity, unvalidated.
Taking the flag only from `features[0]` silently assumes every sample in the batch was dumped with the same prenorm setting. If a dataset directory ever mixes dumps from different producer runs (e.g. pre-PR dumps without the key alongside new ones), samples that disagree with `features[0]` get the wrong norm behavior applied/skipped with no error, which is exactly the silent-corruption failure mode `_maybe_apply_base_final_norm` was designed to prevent.
🛡️ Proposed fix: assert batch homogeneity
    - if "base_hidden_prenorm" in features[0]:
    -     base_model_outputs["base_hidden_prenorm"] = features[0]["base_hidden_prenorm"]
    + if "base_hidden_prenorm" in features[0]:
    +     flag = features[0]["base_hidden_prenorm"]
    +     if any(item.get("base_hidden_prenorm", flag) != flag for item in features):
    +         raise ValueError(
    +             "Mixed base_hidden_prenorm values within a batch; dumps must be produced "
    +             "consistently."
    +         )
    +     base_model_outputs["base_hidden_prenorm"] = flag
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/speculative/eagle/utils.py` around lines 197 - 204, The batch-level propagation of base_hidden_prenorm in the utility that builds base_model_outputs assumes every item in features matches features[0], which can silently mix different producer settings. Update the batching logic in the same helper to validate that all samples agree on base_hidden_prenorm before copying it into base_model_outputs, and raise an error if any item differs so _maybe_apply_base_final_norm is never driven by a mixed batch.
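For illustration only, the homogeneity check discussed in this comment could be sketched as a standalone helper. The name `resolve_batch_prenorm` is hypothetical and not part of the PR, and treating a missing key as `False` is an assumption made here (the proposed fix above instead treats a missing key as agreeing with `features[0]`).

```python
from typing import Any


def resolve_batch_prenorm(features: list[dict[str, Any]]) -> bool:
    """Return the single base_hidden_prenorm value shared by every sample.

    Assumption (not confirmed by the PR): a sample without the key is
    treated as a post-norm dump, i.e. False.
    """
    if not features:
        raise ValueError("Cannot resolve base_hidden_prenorm for an empty batch.")
    # Collect the distinct flag values seen across the whole batch.
    flags = {bool(item.get("base_hidden_prenorm", False)) for item in features}
    if len(flags) > 1:
        raise ValueError(
            "Mixed base_hidden_prenorm values within a batch; "
            "dumps must be produced consistently."
        )
    return flags.pop()
```

A collator could call such a helper once per batch and copy the result into its outputs, so a mixed batch fails before any logits are reconstructed.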
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 169-178: The `_base_model_norm` property and final-norm lookup
logic are duplicated between `HFDFlashModel` and `HFEagleModel`; extract the
shared probing behavior into a common helper near `_FINAL_NORM_PATHS` in
`modeling_fakebase.py` and have both `_find_base_model_parts` implementations
call it. Keep `_base_model_norm` as a thin accessor over the resolved
`base_model_norm_path` so both models share the same path resolution and future
changes only need to be made once.
In `@modelopt/torch/speculative/plugins/modeling_fakebase.py`:
- Around line 284-306: The shape validation in the tensor-loading path is
relying on assert, which disappears under optimized Python runs and can
re-enable silent broadcasting bugs. Update the checks in the fakebase model
loading logic, especially around _read, embed_tokens.weight, lm_head.weight, and
the final norm in modeling_fakebase.py, to use explicit if-condition checks that
raise an error when shapes do not match. Keep the same expected shapes derived
from hidden_size, vocab_size, and final_norm_type, but make the failure path
unconditional so corruption is still blocked under python -O.
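A minimal sketch of the assert-to-raise change requested above, assuming shapes are compared as plain tuples. The helper name `check_shape` and the example sizes are hypothetical; the expected shapes follow the review text (`embed_tokens.weight` and `lm_head.weight` as vocab_size by hidden_size, the final norm weight as hidden_size).

```python
def check_shape(name: str, actual: tuple[int, ...], expected: tuple[int, ...]) -> None:
    """Raise unconditionally on a shape mismatch.

    Unlike `assert`, this check is not stripped when Python runs with -O.
    """
    if tuple(actual) != tuple(expected):
        raise ValueError(f"{name}: expected shape {tuple(expected)}, got {tuple(actual)}")


# Illustrative sizes only; real values would come from the model config.
hidden_size, vocab_size = 8, 32
check_shape("embed_tokens.weight", (vocab_size, hidden_size), (vocab_size, hidden_size))
check_shape("lm_head.weight", (vocab_size, hidden_size), (vocab_size, hidden_size))
check_shape("norm.weight", (hidden_size,), (hidden_size,))
```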
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 2fba9c92-4e11-4a46-aec4-f87c19ed3ef9
📒 Files selected for processing (9)
- examples/speculative_decoding/eagle_utils.py
- modelopt/torch/speculative/eagle/utils.py
- modelopt/torch/speculative/plugins/hf_dflash.py
- modelopt/torch/speculative/plugins/hf_eagle.py
- modelopt/torch/speculative/plugins/hf_streaming_dataset.py
- modelopt/torch/speculative/plugins/modeling_fakebase.py
- modelopt/torch/speculative/plugins/modeling_final_norm.py
- modelopt/torch/speculative/plugins/rdma_hidden_states_connector.py
- tests/unit/torch/speculative/plugins/test_modeling_final_norm.py
/claude review
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
a7a1c50 to d5ef662
Claude Review Summary
Findings — CRITICAL: 0 · IMPORTANT: 1 · SUGGESTION: 0
Scope: reviewed all 8 changed modelopt/ and examples/speculative_decoding/ source files and the new unit test. (The diff against the shallow base tip also surfaces unrelated examples/hf_ptq/ and examples/torch_trt/ churn that is base drift, not part of this PR — not reviewed.)
The core fix is sound: re-applying the base final norm before lm_head when the producer declares a pre-norm capture (base_hidden_prenorm) is the right correction; failing loud when the norm cannot be located prevents silent KD-target corruption; and the offline default of False keeps old .pt dumps backward-compatible. _FinalRMSNorm matches the Llama/Qwen/Mistral/DeepSeek/Kimi RMSNorm formula (fp32 reduction, dtype-correct weight), and gpt_oss is correctly excluded pending a matching norm class. The hs_max_tokens guard and the shape checks using raise (surviving python -O) are good hardening.
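As a rough illustration of the RMSNorm arithmetic referenced above, here is a pure-Python sketch. It is not the PR's `_FinalRMSNorm`, which per the review does its reduction in fp32 and handles the weight dtype; the function name `rms_norm` and the default `eps` are assumptions made for this example.

```python
import math


def rms_norm(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """Scale x by the reciprocal of its root mean square, then by a per-channel weight."""
    if len(x) != len(weight):
        raise ValueError("x and weight must have the same length")
    # Mean of squares over the hidden dimension.
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [w * v * inv_rms for v, w in zip(x, weight)]
```

With unit weights and a negligible `eps`, the output has a mean square of about 1, which is the scale `lm_head` expects from a post-norm hidden state.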
--- Most impactful finding ---
[IMPORTANT Compatibility] _tokenize_entry now reads only conversations, dropping the messages key entirely (hf_streaming_dataset.py:267). The stated bug is a priority-order problem (Spec-Decoding-v2 carries a degenerate messages stub alongside the real conversations), but the README documents the canonical user dataset format as messages:[...], which the streaming path also consumes. A messages-only corpus now yields None for every entry, so the resample loop exhausts the corpus and raises the misleading "server likely down" error. Preferring "conversations or messages" fixes the priority bug and preserves the documented format. Inline suggestion posted.
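The "conversations or messages" preference described in this finding could look roughly like the sketch below. `select_turns` is a hypothetical name, and the real `_tokenize_entry` signature is not shown in this thread, so this only illustrates the key-priority logic.

```python
from typing import Any, Optional


def select_turns(entry: dict[str, Any]) -> Optional[list]:
    """Prefer a non-empty `conversations` list, falling back to `messages`."""
    # `or` skips a missing or empty `conversations` value, so a
    # messages-only entry is still consumed.
    turns = entry.get("conversations") or entry.get("messages")
    return turns if turns else None
```

An entry carrying both keys resolves to `conversations`, while a corpus in the README's `messages` format keeps working.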
--- Concurring with CodeRabbit (elevated severity) ---
CodeRabbit flagged the batch-homogeneity assumption in EagleOfflineDataCollator (eagle/utils.py:200, reading features[0]) as trivial. I would rate it higher: OfflineSupervisedDataset always sets the key, so a directory mixing pre-PR dumps (False) with new pre-norm dumps (True) can put disagreeing samples in one batch; the whole batch then follows features[0], silently feeding un-normed hiddens to lm_head — exactly the corruption _maybe_apply_base_final_norm was built to prevent. Asserting batch homogeneity is worth doing.
--- Risk assessment ---
Low-to-moderate. The norm reconstruction is correct and well-guarded. The one blocking item is the messages/conversations narrowing, a straightforward compat regression against the documented dataset format with an easy fix.
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/hf_streaming_dataset.py (1)
66-78: 🩺 Stability & Availability | 🔵 Trivial | ⚡ Quick win
Strip whitespace from parsed backend names.
`os.environ.get("NIXL_BACKENDS", "UCX").split(",")` doesn't strip whitespace, so a value like `"UCX, LIBFABRIC"` yields `["UCX", " LIBFABRIC"]`, where the second entry has a leading space. NIXL backend names are matched as exact strings (e.g. `"backends":["LIBFABRIC"]`), so this would silently break agent creation on a common env-var formatting mistake, and since both trainer and producer use this same helper, the bug is baked into both sides equally.
🛠️ Proposed fix
    - return os.environ.get("NIXL_BACKENDS", "UCX").split(",")
    + return [b.strip() for b in os.environ.get("NIXL_BACKENDS", "UCX").split(",") if b.strip()]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/speculative/plugins/hf_streaming_dataset.py` around lines 66 - 78, The nixl_backends_from_env helper currently returns raw comma-split values, so NIXL_BACKENDS entries with spaces are preserved and can break exact backend matching. Update nixl_backends_from_env to normalize the parsed list by trimming whitespace from each backend name (and dropping any empty entries if needed) while keeping the same default behavior, so both the trainer-side and producer-side paths continue to agree on valid backend strings.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 82bb96a6-31fe-47ec-91ba-1309bf13fac2
📒 Files selected for processing (4)
- modelopt/torch/speculative/eagle/utils.py
- modelopt/torch/speculative/plugins/hf_streaming_dataset.py
- modelopt/torch/speculative/plugins/modeling_final_norm.py
- tests/unit/torch/speculative/plugins/test_modeling_final_norm.py
🚧 Files skipped from review as they are similar to previous changes (3)
- modelopt/torch/speculative/eagle/utils.py
- tests/unit/torch/speculative/plugins/test_modeling_final_norm.py
- modelopt/torch/speculative/plugins/modeling_final_norm.py
/claude review
Claude Review Summary
Findings — CRITICAL: 0 · IMPORTANT: 0 · SUGGESTION: 1
Scope: reviewed all 9 changed source files (7 modelopt/, 1 examples/, 1 new unit test) plus the RDMA connector. Nothing deprioritized.
Assessment: The core fix is correct and well-guarded. vLLM captures the residual stream before the base model final norm, and the offline/streaming self-logit-distillation path previously fed that un-normed hidden straight into lm_head, corrupting the KD target. The fix threads a base_hidden_prenorm declaration from producer to dataset to collator to consumer and re-applies the base final norm in _maybe_apply_base_final_norm before lm_head, only when pre-norm is declared.
Verified:
- Algorithm: _FinalRMSNorm matches the Llama/Qwen/Mistral/DeepSeek/Kimi RMSNorm formula (fp32 reduction, dtype-correct bf16 weight to avoid float32 promotion mismatching the bf16 lm_head). gpt_oss is correctly excluded pending a matching (fp32-weight, multiply-then-cast) norm class.
- Re-norm boundary: the normed hidden is computed into a local out_hiddens used only for logit reconstruction; base_outputs.out_hiddens stays un-normed for the EAGLE draft input, which is intended. DFlash and Domino paths thread it correctly; Domino ignores logits (no target-logit KD), so omitting the norm there is correct.
- Fail-loud guards: pre-norm declared + no norm located raises RuntimeError; known final_norm_type + missing checkpoint key is a hard error; shape checks use raise (survive python -O); undersized hs_max_tokens raises a clear error instead of a hung fetch.
- Backward compat: offline default base_hidden_prenorm=False keeps existing .pt dumps on the prior post-norm path; unknown model_type builds no norm and downstream skips re-norming.
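The fail-loud and backward-compat behavior verified above can be summarized in a small sketch. The name `maybe_apply_final_norm` and its signature are hypothetical stand-ins; the PR's `_maybe_apply_base_final_norm` is not shown in this thread and may differ.

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def maybe_apply_final_norm(
    hidden: T, prenorm: bool, norm: Optional[Callable[[T], T]]
) -> T:
    """Re-apply the base final norm only when the hidden state is declared pre-norm."""
    if not prenorm:
        # Post-norm (or legacy) capture: leave the hidden state untouched.
        return hidden
    if norm is None:
        # Pre-norm declared but no norm located: refuse to build logits.
        raise RuntimeError(
            "base_hidden_prenorm is set but no base final norm was found; "
            "refusing to feed un-normed hidden states to lm_head."
        )
    return norm(hidden)
```

The returned value would feed logit reconstruction only, with the original un-normed hidden state kept for the draft input as the review describes.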
Both findings from the prior Claude review are resolved in this revision: the messages/conversations priority order now prefers conversations with a messages fallback, and EagleOfflineDataCollator now rejects a batch mixing pre-norm and post-norm dumps.
Non-blocking: one SUGGESTION on nixl_backends_from_env whitespace handling around comma-separated NIXL_BACKENDS.
Risk: low. Correct, well-tested, and backward-compatible. LGTM.
        the container lacks, so UCX RDMA dies there.
        """
        return os.environ.get("NIXL_BACKENDS", "UCX").split(",")
[SUGGESTION] `nixl_backends_from_env` splits on `,` without stripping whitespace. The docstring documents this as a comma-separated env var (`NIXL_BACKENDS=LIBFABRIC`), and a user who naturally writes `NIXL_BACKENDS="UCX, LIBFABRIC"` (with a space) would get `["UCX", " LIBFABRIC"]`; the second entry has a leading space and won't match a valid NIXL backend name, silently breaking the RDMA handoff. Since this list must match on both the trainer and producer side, a mismatch here is hard to diagnose. Stripping (and dropping empties) makes it robust:
    return [b.strip() for b in os.environ.get("NIXL_BACKENDS", "UCX").split(",") if b.strip()]
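A self-contained sketch of the more defensive parsing, combining this suggestion with the earlier CodeRabbit proposal (falling back to the default when the variable is empty). The function name `parse_nixl_backends` is hypothetical so it is not confused with the PR's `nixl_backends_from_env`.

```python
import os


def parse_nixl_backends(default: str = "UCX") -> list[str]:
    """Parse NIXL_BACKENDS as a comma-separated list, tolerating spaces and empties."""
    # An unset or empty variable falls back to the default backend.
    raw = os.environ.get("NIXL_BACKENDS") or default
    names = [b.strip() for b in raw.split(",") if b.strip()]
    # A value such as " , " leaves nothing after stripping; fall back again.
    return names or [default]
```

Because trainer and producer would share the same helper, normalizing here keeps both sides agreeing on the backend list.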
🧹 Nitpick comments (2)
tests/unit/torch/speculative/plugins/test_hf_speculative_offline.py (1)
204-213: 🎯 Functional Correctness | 🔵 Trivial | ⚡ Quick win
Assert the actual value of `base_hidden_prenorm`, not just its presence.
The test only checks that `base_hidden_prenorm` is a key in the returned item, but never verifies its derived value. Since correctly deriving this flag from the offline dump is the actual bug this PR fixes, a presence-only assertion doesn't protect against a regression in that derivation logic (e.g., silently defaulting to `False`/`True` regardless of the dump).
✅ Example addition
     assert item["input_ids"].shape == (SEQ_LEN,)
     assert item["attention_mask"].shape == (SEQ_LEN,)
     assert item["labels"].shape == (SEQ_LEN,)
    +assert item["base_hidden_prenorm"] == <expected_value_based_on_dump>
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unit/torch/speculative/plugins/test_hf_speculative_offline.py` around lines 204 - 213, The test around ds[0] only verifies that base_hidden_prenorm exists, but it should also assert the derived boolean value matches the offline dump. Update the spec in test_hf_speculative_offline.py near the ds[0] assertions to check base_hidden_prenorm’s actual value for the returned item, using the same dataset/item setup already in place so regressions in the derivation logic are caught.
tests/unit/torch/speculative/plugins/test_fakebase.py (1)
49-60: 🎯 Functional Correctness | 🔵 Trivial | ⚡ Quick win
Add assertions for the loaded norm weight.
These fixtures now supply `norm.weight` because the new FakeBaseModel logic requires and copies it, but none of the three tests assert that `model.norm.weight` was actually loaded/copied correctly (only lm_head/embed_tokens shapes are checked). Since this PR's core change is about correctly reconstructing/loading the final norm, verifying the norm weight in these tests would catch regressions in exactly the new behavior being added.
✅ Example addition
     def test_fakebase_local_happy_path(fake_checkpoint):
         model = FakeBaseModel.from_source(str(fake_checkpoint))
         assert model.lm_head.weight.shape == torch.Size([_VOCAB_SIZE, _HIDDEN_SIZE])
         assert model.embed_tokens.weight.shape == torch.Size([_VOCAB_SIZE, _HIDDEN_SIZE])
    +    assert model.norm.weight.shape == torch.Size([_HIDDEN_SIZE])
    +    torch.testing.assert_close(model.norm.weight, torch.ones(_HIDDEN_SIZE))
Also applies to: 75-84, 88-98
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unit/torch/speculative/plugins/test_fakebase.py` around lines 49 - 60, The FakeBaseModel tests now provide norm.weight in the fixture, but the assertions in the affected test cases still only verify lm_head and embed_tokens, so the new final-norm loading path is untested. Update the three tests that use fake_checkpoint/fake_config to also assert that model.norm.weight is loaded/copied correctly, using the existing FakeBaseModel/`model.norm.weight` references so the checks cover the new reconstruction behavior.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: f9873e95-05d3-4c30-a637-d7585ecb0c9e
📒 Files selected for processing (3)
- tests/unit/torch/speculative/plugins/test_fakebase.py
- tests/unit/torch/speculative/plugins/test_hf_speculative_offline.py
- tests/unit/torch/speculative/plugins/test_hf_streaming_dataset.py
✅ Files skipped from review due to trivial changes (1)
- tests/unit/torch/speculative/plugins/test_hf_streaming_dataset.py
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #1846 +/- ##
==========================================
- Coverage 70.21% 69.98% -0.24%
==========================================
Files 515 516 +1
Lines 57244 57538 +294
==========================================
+ Hits 40196 40268 +72
- Misses 17048 17270 +222
What does this PR do?
Type of change: Bug fix
vLLM captures the final-layer hidden state before the model's final norm, but the offline/streaming distillation path fed it straight into `lm_head`, so the reconstructed base logits (the KD target) were computed from un-normed hidden states.
This PR re-applies the base model's final norm before `lm_head` when the producer declares a pre-norm capture (`base_hidden_prenorm`), for both DFlash and EAGLE:
- The producer declares `base_hidden_prenorm` (streaming: True; offline: from the dump).
- The consumer (`_maybe_apply_base_final_norm`) re-applies the base final norm, and fails loud if pre-norm is declared but the model's norm type isn't supported (no silent corruption).
- `FakeBaseModel` now loads the base final norm (+ `rope_theta`/`rms_norm_eps`); norm type is gated by an explicit `model_type` allowlist (gpt_oss excluded pending a matching norm class).
Testing
- `tests/unit/torch/speculative/plugins/test_modeling_final_norm.py`; DFlash/EAGLE streaming training verified end-to-end.
- Backward compatible (offline default `base_hidden_prenorm=False` → unchanged)
Summary by CodeRabbit
- … `base_hidden_prenorm` and can configure RDMA backends from environment settings.
- … `base_hidden_prenorm` values within a batch.