Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions modelopt/torch/export/plugins/hf_spec_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,3 +477,31 @@ def _export_config(self):
}
)
return config


class DSparkExporter(DFlashExporter):
"""Draft model exporter for DSpark (DFlash backbone + sequential Markov head).

Same z-lab-compatible format as DFlash, plus the DSpark head weights
(``markov_w1.*`` / ``markov_w2.*`` / ``gate_proj.*`` / ``joint_proj.*`` /
``confidence_proj.*``, already captured by the inherited ``dflash_module.``
stripping) and the extra config fields the loader needs to rebuild the head
(``projector_type``, ``markov_rank``, ``markov_head_type``,
``use_confidence_head``, ``shift_label``).
"""

def _export_config(self):
"""Extend the DFlash config with the DSpark head fields."""
config = super()._export_config()
draft_config = self.model.dflash_config

config["dflash_config"].update(
{
"projector_type": getattr(draft_config, "projector_type", "dspark"),
"shift_label": getattr(draft_config, "shift_label", True),
"markov_rank": draft_config.markov_rank,
"markov_head_type": getattr(draft_config, "markov_head_type", "vanilla"),
"use_confidence_head": bool(getattr(draft_config, "use_confidence_head", False)),
}
)
return config
31 changes: 31 additions & 0 deletions modelopt/torch/speculative/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,37 @@ class DFlashConfig(ModeloptBaseConfig):
),
)

dflash_ce_loss_alpha: float = ModeloptField(
default=0.1,
ge=0.0,
description=(
"DSpark only: weight of the cross-entropy term in the three-term loss "
"(ce_alpha*CE + l1_alpha*TVD + conf_alpha*BCE). "
"Ignored unless dflash_architecture_config.projector_type == 'dspark'."
),
)

dflash_l1_loss_alpha: float = ModeloptField(
default=0.9,
ge=0.0,
description=(
"DSpark only: weight of the L1/total-variation distribution-matching term "
"between the corrected draft and the target distribution. "
"Ignored unless dflash_architecture_config.projector_type == 'dspark'."
),
)

dflash_confidence_head_alpha: float = ModeloptField(
default=0.0,
ge=0.0,
description=(
"DSpark only: weight of the confidence-head BCE term (predicts the per-position "
"acceptance probability). 0 disables the term; requires "
"dflash_architecture_config.use_confidence_head=true when > 0. "
"Ignored unless dflash_architecture_config.projector_type == 'dspark'."
),
)

@model_validator(mode="after")
def _check_dpace_alpha(self) -> "DFlashConfig":
# Validate at construction regardless of the active objective, so a bad alpha
Expand Down
10 changes: 9 additions & 1 deletion modelopt/torch/speculative/dflash/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
# ``dflash_architecture_config.projector_type == "domino"`` and lives in its own
# registry so its wrapper (HFDominoModel) does not overwrite HFDFlashModel.
DominoDMRegistry = _DMRegistryCls(prefix="Domino")
# DSpark also reuses the dflash mode/config/recipe, converting the base model to a
# DFlash backbone augmented with a lightweight sequential (Markov) head and an
# optional confidence head. Selected via
# ``dflash_architecture_config.projector_type == "dspark"`` and kept in its own
# registry so its wrapper (HFDSparkModel) does not overwrite HFDFlashModel.
DSparkDMRegistry = _DMRegistryCls(prefix="DSpark")


def convert_to_dflash_model(model: nn.Module, config: DFlashConfig) -> ConvertReturnType:
Expand All @@ -45,12 +51,14 @@ def convert_to_dflash_model(model: nn.Module, config: DFlashConfig) -> ConvertRe
projector_type = config.dflash_architecture_config.get("projector_type")
if projector_type == "domino":
registry = DominoDMRegistry
elif projector_type == "dspark":
registry = DSparkDMRegistry
elif projector_type in (None, "dflash"):
registry = DFlashDMRegistry
else:
raise ValueError(
f"Unsupported dflash_architecture_config.projector_type: {projector_type!r}. "
"Expected 'dflash' (default) or 'domino'."
"Expected 'dflash' (default), 'domino' or 'dspark'."
)

original_cls = type(model)
Expand Down
1 change: 1 addition & 0 deletions modelopt/torch/speculative/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@
with import_plugin("transformers"):
from .hf_dflash import *
from .hf_domino import *
from .hf_dspark import *
from .hf_eagle import *
from .hf_medusa import *
Loading
Loading