Skip to content

Save hf checkpoint at every valitation iteration during distillation.#1897

Open
danielkorzekwa wants to merge 9 commits into
mainfrom
dkorzekwa/save_hf_checkpoint_for_every_val_iter_during_distill
Open

Save hf checkpoint at every valitation iteration during distillation.#1897
danielkorzekwa wants to merge 9 commits into
mainfrom
dkorzekwa/save_hf_checkpoint_for_every_val_iter_during_distill

Conversation

@danielkorzekwa

@danielkorzekwa danielkorzekwa commented Jul 3, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

Save hf checkpoint at every valitation iteration during distillation.

Usage

  • examples/megatron_bridge/distill.py
  • examples/megatron_bridge/README.md (line 228)

Testing

  • tests/examples/megatron_bridge/test_distill.py

Before your PR is "Ready for review"

  • Is this change backward compatible?: ✅
  • Did you write any new necessary tests?: ✅

Summary by CodeRabbit

  • New Features
    • Added --validate_only mode for distillation (evaluate the student at iteration 0 without training).
    • Added --hf_validation_export_path to export student HuggingFace artifacts after each validation stage.
    • Introduced prepare_data_blend.py to generate token-budgeted Megatron data blends from YAML.
    • Added --max_tokens to stop Megatron preprocessing after a token budget.
  • Bug Fixes
    • Validation exports now save only student HuggingFace model artifacts and preserve student architecture/config.
  • Documentation
    • Expanded researcher and tutorial guidance for iterative distillation, validation exports, and token-budgeted blend workflows.
  • Tests
    • Updated/added coverage for validate_only, validation exports, blend preparation, and max_tokens truncation.

Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
@danielkorzekwa danielkorzekwa requested review from a team as code owners July 3, 2026 17:02
@danielkorzekwa danielkorzekwa requested a review from jenchen13 July 3, 2026 17:02
@danielkorzekwa

Copy link
Copy Markdown
Contributor Author

/claude review

@coderabbitai

coderabbitai Bot commented Jul 3, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: c638ba64-3ba0-459c-b266-a336f358150a

📥 Commits

Reviewing files that changed from the base of the PR and between d9aead1 and fe0e72c.

📒 Files selected for processing (1)
  • examples/researcher_guide/README.md
✅ Files skipped from review due to trivial changes (1)
  • examples/researcher_guide/README.md

📝 Walkthrough

Walkthrough

Adds per-validation HuggingFace exports, token-budgeted preprocessing and blend generation, plus tests and documentation for the new iterative workflows.

Changes

Iterative distillation tooling

Layer / File(s) Summary
Validation export callback and wiring
examples/megatron_bridge/distill.py
Adds the validation-end HuggingFace export callback, CLI flag, argument validation, and control-flow switch to run pretrain(..., callbacks=[callback]) when validation exports are enabled.
Validation export test and docs
tests/examples/megatron_bridge/test_distill.py, examples/megatron_bridge/README.md
Extends the distillation integration test to assert per-iteration HuggingFace exports and documents the workflow in the Megatron bridge README.
Token-capped preprocessing
modelopt/torch/utils/plugins/megatron_preprocess_data.py
Adds max_tokens support to Megatron preprocessing for JSONL and Hugging Face inputs, with bounded batching, early stopping, output tagging, and CLI wiring.
Blend builder and tests
examples/dataset/prepare_data_blend.py, tests/examples/dataset/test_prepare_data_blend.py
Adds the token-budgeted blend builder CLI, helper functions, and an end-to-end test covering downloads, output prefixes, blend weights, and copied config.
Blend workflow docs
examples/dataset/MEGATRON_DATA_PREP.md, examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md, examples/pruning/minitron/NVIDIA-Nemotron-Nano-9B-v2/README.md, examples/researcher_guide/README.md
Updates dataset, tutorial, pruning, and researcher guide documentation to describe token-budgeted data blends and the validation-export workflow.

Estimated code review effort: 4 (Complex) | ~60 minutes

Sequence Diagram(s)

sequenceDiagram
  participant distill.py CLI
  participant pretrain
  participant _HFValidationExportCallback
  participant AutoConfig
  participant Filesystem

  distill.py CLI->>pretrain: pretrain(config, callbacks=[callback])
  pretrain->>_HFValidationExportCallback: on_validation_end(iteration)
  _HFValidationExportCallback->>Filesystem: create iter_<iteration>/ export path
  _HFValidationExportCallback->>AutoConfig: save_pretrained(...)
  AutoConfig->>Filesystem: write config.json
  _HFValidationExportCallback->>pretrain: torch.distributed.barrier()
Loading
sequenceDiagram
  participant CLI
  participant megatron_preprocess_data
  participant process_hf_split
  participant process_json_file
  participant _encode_docs

  CLI->>megatron_preprocess_data: --max_tokens
  megatron_preprocess_data->>process_hf_split: remaining_tokens
  megatron_preprocess_data->>process_json_file: remaining_tokens
  process_hf_split->>_encode_docs: may_stop_early=True
  process_json_file->>_encode_docs: may_stop_early=True
  process_hf_split-->>megatron_preprocess_data: stop at token budget
  process_json_file-->>megatron_preprocess_data: stop at token budget
Loading

Possibly related PRs

  • NVIDIA/Model-Optimizer#1872: Introduces the researcher guide content that this PR extends with token-budgeted blend and validation-export workflows.

Suggested reviewers: cjluo-nv, meenchen, Edwardf0t1

🚥 Pre-merge checks | ✅ 6
✅ Passed checks (6 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly matches the main change: exporting HF checkpoints at each validation iteration during distillation.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed Scanned the PR’s touched Python files and diff: no unsafe torch/numpy loads, no eval/exec or nosec, trust_remote_code is only CLI-controlled, and no dependency changes.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch dkorzekwa/save_hf_checkpoint_for_every_val_iter_during_distill
⚔️ Resolve merge conflicts
  • Resolve merge conflict in branch dkorzekwa/save_hf_checkpoint_for_every_val_iter_during_distill

Comment @coderabbitai help to get the list of available commands.

)
# TODO: Use distill(..., callbacks=[callback]) once Megatron-Bridge supports callbacks.
pretrain(config, forward_step_modelopt, callbacks=[callback])
else:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] When --hf_validation_export_path is set, the training entrypoint switches from distill(config) to pretrain(config, forward_step_modelopt, callbacks=[callback]). This creates two distinct training code paths that must stay behaviorally identical — otherwise enabling validation export silently changes how the model is trained, not just whether checkpoints are dumped.

If distill() does any distillation-specific setup beyond pretrain + forward_step_modelopt (e.g. loss-balancer wiring, KD config injection, provider hooks), that setup would be skipped on the export path. The assert isinstance(config.model, DistillationProvider) guard suggests you've considered this, but it would be worth (1) confirming distill() is genuinely just pretrain(config, forward_step_modelopt, ...) under the hood, and (2) leaving a one-line comment near the fork noting that equivalence, so a future change to distill() doesn't quietly diverge the two paths. The existing TODO about distill(..., callbacks=...) partially covers this, but the equivalence risk is the part worth calling out explicitly.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 1

🧹 Nitpick comments (1)
examples/megatron_bridge/distill.py (1)

209-214: 🚀 Performance & Scalability | 🔵 Trivial | ⚡ Quick win

Reload of AutoConfig on every validation export; duplicated with the final-export path.

AutoConfig.from_pretrained(self.student_hf_path, ...) is re-executed on every on_eval_end call even though student_hf_path/trust_remote_code never change across the run — this repeats file/Hub I/O unnecessarily. The same load-then-save_pretrained pattern (with an identical comment) is also duplicated in main() at Lines 557-560.

Consider loading the config once in __init__ and reusing it in on_eval_end, and/or extracting a small shared helper used by both this callback and the final --hf_export_path block to avoid drift between the two copies.

♻️ Proposed fix to cache the config once
     def __init__(
         self,
         export_dir: str,
         student_hf_model: str,
         student_hf_path: str,
         trust_remote_code: bool,
     ) -> None:
         self.export_dir = Path(export_dir)
-        self.student_hf_path = student_hf_path
-        self.trust_remote_code = trust_remote_code
         self._last_exported_iteration: int | None = None
         self.bridge = AutoBridge.from_hf_pretrained(
             student_hf_model, trust_remote_code=trust_remote_code
         )
+        self._student_config = AutoConfig.from_pretrained(
+            student_hf_path, trust_remote_code=trust_remote_code
+        )
         if dist.rank() == 0:
-            # Preserve the student architecture from student_hf_path, including heterogeneous
-            # layer changes; AutoConfig supports both local paths and Hugging Face model IDs.
-            AutoConfig.from_pretrained(
-                self.student_hf_path, trust_remote_code=self.trust_remote_code
-            ).save_pretrained(output_path)
+            self._student_config.save_pretrained(output_path)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/megatron_bridge/distill.py` around lines 209 - 214,
`DistillCallback.on_eval_end` is reloading the student config on every
validation export even though `student_hf_path` and `trust_remote_code` are
ثابت, and the same `AutoConfig.from_pretrained(...).save_pretrained(...)` logic
is duplicated in `main()`. Cache the loaded config once in
`DistillCallback.__init__` (or a shared helper) and reuse it in `on_eval_end`,
then have the final `--hf_export_path` export path call the same helper to keep
the behavior in sync and avoid repeated I/O.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/megatron_bridge/distill.py`:
- Around line 167-218: The _HFValidationExportCallback currently saves a new
Hugging Face export on every validation with no cleanup, which can cause
unbounded disk growth. Add a retention setting such as keep_last_n to the
callback, and in on_eval_end prune older iter_* export directories after a
successful save, keeping only the most recent exports. Use the existing export
flow in _HFValidationExportCallback and mirror the retention approach already
used for main checkpointing (most_recent_k) so the logic is easy to locate and
consistent.

---

Nitpick comments:
In `@examples/megatron_bridge/distill.py`:
- Around line 209-214: `DistillCallback.on_eval_end` is reloading the student
config on every validation export even though `student_hf_path` and
`trust_remote_code` are ثابت, and the same
`AutoConfig.from_pretrained(...).save_pretrained(...)` logic is duplicated in
`main()`. Cache the loaded config once in `DistillCallback.__init__` (or a
shared helper) and reuse it in `on_eval_end`, then have the final
`--hf_export_path` export path call the same helper to keep the behavior in sync
and avoid repeated I/O.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: c41374ad-ecce-4dc2-b44d-66aedfb9cdc0

📥 Commits

Reviewing files that changed from the base of the PR and between b0ee953 and ed5b737.

📒 Files selected for processing (3)
  • examples/megatron_bridge/README.md
  • examples/megatron_bridge/distill.py
  • tests/examples/megatron_bridge/test_distill.py

Comment on lines +167 to +218
class _HFValidationExportCallback(Callback):
"""Export the live student to Hugging Face format after each validation stage."""

def __init__(
self,
export_dir: str,
student_hf_model: str,
student_hf_path: str,
trust_remote_code: bool,
) -> None:
self.export_dir = Path(export_dir)
self.student_hf_path = student_hf_path
self.trust_remote_code = trust_remote_code
self._last_exported_iteration: int | None = None
self.bridge = AutoBridge.from_hf_pretrained(
student_hf_model, trust_remote_code=trust_remote_code
)

def on_eval_end(self, context) -> None:
"""Export the student at the iteration that was just validated."""
iteration = context.state.train_state.step
# The final iteration can be validated both on its regular interval and after training.
# Avoid exporting and overwriting the same Hugging Face checkpoint twice.
if iteration == self._last_exported_iteration:
return
output_path = self.export_dir / f"iter_{iteration:07d}"
print_rank_0(f"Exporting validation checkpoint {iteration} to {output_path}")

# DistillationModel is the student with teacher and KD-loss modules attached. Hide the
# auxiliary modules temporarily so the Hugging Face export contains only student weights.
with contextlib.ExitStack() as stack:
for model_chunk in unwrap_model(context.model):
if isinstance(model_chunk, mtd.DistillationModel):
stack.enter_context(model_chunk.hide_teacher_model())
stack.enter_context(model_chunk.hide_loss_modules())
self.bridge.save_hf_pretrained(
context.model,
output_path,
show_progress=True,
strict=True,
)

if dist.rank() == 0:
# Preserve the student architecture from student_hf_path, including heterogeneous
# layer changes; AutoConfig supports both local paths and Hugging Face model IDs.
AutoConfig.from_pretrained(
self.student_hf_path, trust_remote_code=self.trust_remote_code
).save_pretrained(output_path)
torch.distributed.barrier()
self._last_exported_iteration = iteration


Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚀 Performance & Scalability | 🟠 Major | ⚡ Quick win

No retention policy for per-validation HF exports — unbounded disk growth risk.

Every validation stage writes a full HuggingFace checkpoint to iter_<n>/ with no cap on how many are kept. For a realistic run (e.g. the README's --train_iters 15000 --eval_interval 100 example), that's ~150 full model copies with no cleanup — unlike the main Megatron checkpoint config in this same file, which bounds retention via most_recent_k=5 (Line 505). For multi-billion-parameter students this can exhaust disk and fail the job mid-training.

Consider adding a retention parameter (e.g. keep_last_n) to _HFValidationExportCallback that prunes older iter_* export directories, mirroring the most_recent_k pattern already used for the main checkpoint.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/megatron_bridge/distill.py` around lines 167 - 218, The
_HFValidationExportCallback currently saves a new Hugging Face export on every
validation with no cleanup, which can cause unbounded disk growth. Add a
retention setting such as keep_last_n to the callback, and in on_eval_end prune
older iter_* export directories after a successful save, keeping only the most
recent exports. Use the existing export flow in _HFValidationExportCallback and
mirror the retention approach already used for main checkpointing
(most_recent_k) so the logic is easy to locate and consistent.

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude review passed — no blocking issues found. LGTM

Scope: Examples-only PR (3 files: examples/megatron_bridge/distill.py, its README, and tests/examples/megatron_bridge/test_distill.py). megatron.bridge is an external dependency not installed in this environment, so distill()/pretrain() internals could not be introspected; review is based on the diff plus ModelOpt's DistillationModel API.

Findings: CRITICAL: 0, IMPORTANT: 0, SUGGESTION: 1

What I verified

  • The export path correctly strips auxiliary modules before saving: hide_teacher_model() and hide_loss_modules() exist in modelopt/torch/distill/distillation_model.py and are the same mechanism used by the minimal state-dict export path.
  • Distributed ordering is sound: all ranks run the collective save_hf_pretrained, only rank 0 writes config.json, followed by a barrier().
  • The _last_exported_iteration guard sensibly prevents double-export/overwrite when the final iteration is validated both on its interval and post-training.
  • Argument validation was correctly widened to require --student_hf_model when either HF export flag is set.

Suggestion (non-blocking)

  • Enabling --hf_validation_export_path switches training from distill(config) to pretrain(config, forward_step_modelopt, callbacks=[...]). Worth confirming these two paths are behaviorally identical and leaving a comment noting the equivalence, so a future change to distill() doesn't silently diverge the export path from the normal one.

Overall risk: low — additive, opt-in feature gated behind a new flag, backward compatible, with test coverage added.

@codecov

codecov Bot commented Jul 3, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 2.56410% with 38 lines in your changes missing coverage. Please review.
✅ Project coverage is 61.17%. Comparing base (b0ee953) to head (fe0e72c).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
...pt/torch/utils/plugins/megatron_preprocess_data.py 2.56% 38 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (b0ee953) and HEAD (fe0e72c). Click for more details.

HEAD has 3 uploads less than BASE
Flag BASE (b0ee953) HEAD (fe0e72c)
gpu 3 2
examples 12 10
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1897       +/-   ##
===========================================
- Coverage   75.20%   61.17%   -14.03%     
===========================================
  Files         515      515               
  Lines       57245    57274       +29     
===========================================
- Hits        43050    35038     -8012     
- Misses      14195    22236     +8041     
Flag Coverage Δ
examples 32.58% <2.56%> (-10.55%) ⬇️
gpu 20.51% <0.00%> (-29.52%) ⬇️
regression 14.82% <2.56%> (+0.06%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
@danielkorzekwa danielkorzekwa requested review from a team as code owners July 3, 2026 21:36
@danielkorzekwa danielkorzekwa requested a review from jingyu-ml July 3, 2026 21:36

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/utils/plugins/megatron_preprocess_data.py (1)

308-368: 🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win

Add max_tokens to file-backed dataset prefixes
process_json_file() returns the same .bin/.idx prefix regardless of token budget, so reruns with a different target_tokens can silently reuse stale files for "files" sources. This also leaves the file-backed path out of sync with process_hf_split() and the blend test’s _tokens expectation. Consider appending the same token_tag there.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/utils/plugins/megatron_preprocess_data.py` around lines 308 -
368, `process_json_file()` currently builds output prefixes only from the input
stem, so file-backed runs with different token budgets can reuse the same
`.bin/.idx` artifacts. Update `process_json_file` in
`megatron_preprocess_data.py` to append the same `token_tag` used by
`process_hf_split()` when constructing `output_prefix`/`prefixes`, so the
`process_json_file` and `process_hf_split` paths produce consistent,
budget-specific dataset names. Ensure the new naming is derived from the
existing `max_tokens`/token budget logic and is applied before checking for
existing builders or returning skipped prefixes.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/dataset/prepare_data_blend.py`:
- Around line 88-141: The weight-to-token allocation in prepare_data_blend can
produce zero or negative max_tokens for the last source when YAML weights are
misconfigured, leading to an empty prefixes list and a later ZeroDivisionError.
Add upfront validation in prepare_data_blend/load_config that source weights sum
to about 100 (or otherwise ensure source_tokens stays positive), and also guard
the prefix_weight calculation after megatron_preprocess_data so an empty
prefixes result raises a clear configuration error instead of dividing by zero.

---

Outside diff comments:
In `@modelopt/torch/utils/plugins/megatron_preprocess_data.py`:
- Around line 308-368: `process_json_file()` currently builds output prefixes
only from the input stem, so file-backed runs with different token budgets can
reuse the same `.bin/.idx` artifacts. Update `process_json_file` in
`megatron_preprocess_data.py` to append the same `token_tag` used by
`process_hf_split()` when constructing `output_prefix`/`prefixes`, so the
`process_json_file` and `process_hf_split` paths produce consistent,
budget-specific dataset names. Ensure the new naming is derived from the
existing `max_tokens`/token budget logic and is applied before checking for
existing builders or returning skipped prefixes.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 9f8a8b4f-8080-4109-952c-6c7d31f8c071

📥 Commits

Reviewing files that changed from the base of the PR and between ed5b737 and b27265f.

📒 Files selected for processing (8)
  • examples/dataset/MEGATRON_DATA_PREP.md
  • examples/dataset/prepare_data_blend.py
  • examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md
  • examples/pruning/minitron/NVIDIA-Nemotron-Nano-9B-v2/README.md
  • examples/researcher_guide/README.md
  • modelopt/torch/utils/plugins/megatron_preprocess_data.py
  • tests/examples/dataset/test_prepare_data_blend.py
  • tests/gpu_megatron/torch/utils/plugins/test_megatron_preprocess_data.py
✅ Files skipped from review due to trivial changes (3)
  • examples/megatron_bridge/tutorials/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md
  • examples/researcher_guide/README.md
  • examples/dataset/MEGATRON_DATA_PREP.md

Comment on lines +88 to +141
for index, source in enumerate(sources):
weight = float(source["weight"])
if total_tokens is None:
source_tokens = None
elif index == len(sources) - 1:
source_tokens = total_tokens - allocated_tokens
else:
source_tokens = round(total_tokens * weight / 100)
allocated_tokens += source_tokens

dataset = source["hf_dataset"]
source_dir = output_dir / f"{index:02d}_{dataset.replace('/', '--')}"
content_field = source["content_field"]
input_args: dict[str, Any]
if "files" in source:
raw_dir = output_dir.parent / "raw" / dataset.replace("/", "--")
paths = [
hf_hub_download(
repo_id=dataset,
filename=file,
repo_type="dataset",
local_dir=raw_dir,
)
for file in source["files"]
]
input_args = {"jsonl_paths": paths}
else:
input_args = {
"hf_dataset": dataset,
"hf_name": source.get("config"),
"hf_split": source["split"],
"hf_max_samples_per_split": source.get("max_samples"),
"hf_streaming": True,
}

# Each prefix is the path shared by a tokenized Megatron .bin/.idx file pair.
prefixes = megatron_preprocess_data(
**input_args,
output_dir=source_dir,
tokenizer_name_or_path=tokenizer,
json_keys=content_field,
# Plain text lacks chat-template boundary tokens, so terminate each document with EOS.
append_eod=content_field == "text",
# Join lines in text documents by replacing each newline with a space.
strip_newlines=content_field == "text",
reasoning_content="inline" if content_field == "messages" else "strip",
# Guard against pathological records by capping each tokenized document at 256K tokens.
max_sequence_length=256_000,
max_tokens=source_tokens,
workers=workers,
)
prefix_weight = weight / len(prefixes)
blend.extend((prefix_weight, prefix) for prefix in prefixes)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Unvalidated weight allocation can crash with ZeroDivisionError on misconfigured YAML.

source_tokens for non-last sources is round(total_tokens * weight / 100) (Line 95), with the last source getting total_tokens - allocated_tokens (Line 93). If the configured weights don't sum to ~100 (or accumulate enough rounding error), the last source's source_tokens can end up <= 0. Inside megatron_preprocess_data, a non-positive max_tokens on the very first split/file causes an immediate break with zero prefixes returned (Lines 600-601, 628-629 of megatron_preprocess_data.py). Back here, prefix_weight = weight / len(prefixes) (Line 139) then raises ZeroDivisionError.

Since load_config/prepare_data_blend is the interface boundary for this user-authored YAML, consider validating that sources weights sum to 100 (within tolerance) up front, and/or guarding against an empty prefixes result with a clear error message instead of a raw ZeroDivisionError.

🛡️ Proposed validation
 def prepare_data_blend(config_path: Path) -> list[tuple[float, str]]:
     """Download and tokenize the configured weighted data sources."""
     config = load_config(config_path)
     output_dir = Path(config["output_dir"])
     output_dir.mkdir(parents=True, exist_ok=True)
     target_tokens = config.get("target_tokens")
     total_tokens = None if target_tokens is None else int(target_tokens)
     tokenizer = str(config["tokenizer"])
+
+    total_weight = sum(float(source["weight"]) for source in config["sources"])
+    if not math.isclose(total_weight, 100, abs_tol=0.5):
+        raise ValueError(f"Source weights must sum to 100, got {total_weight}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for index, source in enumerate(sources):
weight = float(source["weight"])
if total_tokens is None:
source_tokens = None
elif index == len(sources) - 1:
source_tokens = total_tokens - allocated_tokens
else:
source_tokens = round(total_tokens * weight / 100)
allocated_tokens += source_tokens
dataset = source["hf_dataset"]
source_dir = output_dir / f"{index:02d}_{dataset.replace('/', '--')}"
content_field = source["content_field"]
input_args: dict[str, Any]
if "files" in source:
raw_dir = output_dir.parent / "raw" / dataset.replace("/", "--")
paths = [
hf_hub_download(
repo_id=dataset,
filename=file,
repo_type="dataset",
local_dir=raw_dir,
)
for file in source["files"]
]
input_args = {"jsonl_paths": paths}
else:
input_args = {
"hf_dataset": dataset,
"hf_name": source.get("config"),
"hf_split": source["split"],
"hf_max_samples_per_split": source.get("max_samples"),
"hf_streaming": True,
}
# Each prefix is the path shared by a tokenized Megatron .bin/.idx file pair.
prefixes = megatron_preprocess_data(
**input_args,
output_dir=source_dir,
tokenizer_name_or_path=tokenizer,
json_keys=content_field,
# Plain text lacks chat-template boundary tokens, so terminate each document with EOS.
append_eod=content_field == "text",
# Join lines in text documents by replacing each newline with a space.
strip_newlines=content_field == "text",
reasoning_content="inline" if content_field == "messages" else "strip",
# Guard against pathological records by capping each tokenized document at 256K tokens.
max_sequence_length=256_000,
max_tokens=source_tokens,
workers=workers,
)
prefix_weight = weight / len(prefixes)
blend.extend((prefix_weight, prefix) for prefix in prefixes)
tokenizer = str(config["tokenizer"])
total_weight = sum(float(source["weight"]) for source in config["sources"])
if not math.isclose(total_weight, 100, abs_tol=0.5):
raise ValueError(f"Source weights must sum to 100, got {total_weight}")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/dataset/prepare_data_blend.py` around lines 88 - 141, The
weight-to-token allocation in prepare_data_blend can produce zero or negative
max_tokens for the last source when YAML weights are misconfigured, leading to
an empty prefixes list and a later ZeroDivisionError. Add upfront validation in
prepare_data_blend/load_config that source weights sum to about 100 (or
otherwise ensure source_tokens stays positive), and also guard the prefix_weight
calculation after megatron_preprocess_data so an empty prefixes result raises a
clear configuration error instead of dividing by zero.

Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
@kevalmorabia97 kevalmorabia97 requested review from AAnoosheh and kevalmorabia97 and removed request for jingyu-ml July 4, 2026 16:15
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant