diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index a9f4f995540..fd5bb7ef086 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -477,3 +477,31 @@ def _export_config(self): } ) return config + + +class DSparkExporter(DFlashExporter): + """Draft model exporter for DSpark (DFlash backbone + sequential Markov head). + + Same z-lab-compatible format as DFlash, plus the DSpark head weights + (``markov_w1.*`` / ``markov_w2.*`` / ``gate_proj.*`` / ``joint_proj.*`` / + ``confidence_proj.*``, already captured by the inherited ``dflash_module.`` + stripping) and the extra config fields the loader needs to rebuild the head + (``projector_type``, ``markov_rank``, ``markov_head_type``, + ``use_confidence_head``, ``shift_label``). + """ + + def _export_config(self): + """Extend the DFlash config with the DSpark head fields.""" + config = super()._export_config() + draft_config = self.model.dflash_config + + config["dflash_config"].update( + { + "projector_type": getattr(draft_config, "projector_type", "dspark"), + "shift_label": getattr(draft_config, "shift_label", True), + "markov_rank": draft_config.markov_rank, + "markov_head_type": getattr(draft_config, "markov_head_type", "vanilla"), + "use_confidence_head": bool(getattr(draft_config, "use_confidence_head", False)), + } + ) + return config diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 3ff317aa0a4..a4e6de138cb 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -184,6 +184,37 @@ class DFlashConfig(ModeloptBaseConfig): ), ) + dflash_ce_loss_alpha: float = ModeloptField( + default=0.1, + ge=0.0, + description=( + "DSpark only: weight of the cross-entropy term in the three-term loss " + "(ce_alpha*CE + l1_alpha*TVD + conf_alpha*BCE). " + "Ignored unless dflash_architecture_config.projector_type == 'dspark'." + ), + ) + + dflash_l1_loss_alpha: float = ModeloptField( + default=0.9, + ge=0.0, + description=( + "DSpark only: weight of the L1/total-variation distribution-matching term " + "between the corrected draft and the target distribution. " + "Ignored unless dflash_architecture_config.projector_type == 'dspark'." + ), + ) + + dflash_confidence_head_alpha: float = ModeloptField( + default=0.0, + ge=0.0, + description=( + "DSpark only: weight of the confidence-head BCE term (predicts the per-position " + "acceptance probability). 0 disables the term; requires " + "dflash_architecture_config.use_confidence_head=true when > 0. " + "Ignored unless dflash_architecture_config.projector_type == 'dspark'." + ), + ) + @model_validator(mode="after") def _check_dpace_alpha(self) -> "DFlashConfig": # Validate at construction regardless of the active objective, so a bad alpha diff --git a/modelopt/torch/speculative/dflash/conversion.py b/modelopt/torch/speculative/dflash/conversion.py index b5cb82c4db1..0516ab7181d 100644 --- a/modelopt/torch/speculative/dflash/conversion.py +++ b/modelopt/torch/speculative/dflash/conversion.py @@ -29,6 +29,12 @@ # ``dflash_architecture_config.projector_type == "domino"`` and lives in its own # registry so its wrapper (HFDominoModel) does not overwrite HFDFlashModel. DominoDMRegistry = _DMRegistryCls(prefix="Domino") +# DSpark also reuses the dflash mode/config/recipe, converting the base model to a +# DFlash backbone augmented with a lightweight sequential (Markov) head and an +# optional confidence head. Selected via +# ``dflash_architecture_config.projector_type == "dspark"`` and kept in its own +# registry so its wrapper (HFDSparkModel) does not overwrite HFDFlashModel. +DSparkDMRegistry = _DMRegistryCls(prefix="DSpark") def convert_to_dflash_model(model: nn.Module, config: DFlashConfig) -> ConvertReturnType: @@ -45,12 +51,14 @@ def convert_to_dflash_model(model: nn.Module, config: DFlashConfig) -> ConvertRe projector_type = config.dflash_architecture_config.get("projector_type") if projector_type == "domino": registry = DominoDMRegistry + elif projector_type == "dspark": + registry = DSparkDMRegistry elif projector_type in (None, "dflash"): registry = DFlashDMRegistry else: raise ValueError( f"Unsupported dflash_architecture_config.projector_type: {projector_type!r}. " - "Expected 'dflash' (default) or 'domino'." + "Expected 'dflash' (default), 'domino' or 'dspark'." ) original_cls = type(model) diff --git a/modelopt/torch/speculative/plugins/__init__.py b/modelopt/torch/speculative/plugins/__init__.py index ec90b8c0fda..ea789388ad4 100644 --- a/modelopt/torch/speculative/plugins/__init__.py +++ b/modelopt/torch/speculative/plugins/__init__.py @@ -32,5 +32,6 @@ with import_plugin("transformers"): from .hf_dflash import * from .hf_domino import * + from .hf_dspark import * from .hf_eagle import * from .hf_medusa import * diff --git a/modelopt/torch/speculative/plugins/hf_dspark.py b/modelopt/torch/speculative/plugins/hf_dspark.py new file mode 100644 index 00000000000..a8b60784f02 --- /dev/null +++ b/modelopt/torch/speculative/plugins/hf_dspark.py @@ -0,0 +1,485 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from https://github.com/deepseek-ai/DeepSpec/blob/main/deepspec/modeling/dspark/loss.py +# Copyright (c) 2026 The DeepSpec Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 AND MIT +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DSpark speculative decoding plugin for HuggingFace models. + +DSpark reuses the DFlash draft backbone and training pipeline (anchor sampling, +noise/mask construction, KV-injection attention) and adds a lightweight +sequential head (see ``modeling_dspark.DSparkModule``): + +- The backbone produces *base* logits for a full draft block in parallel. +- A Markov head adds a prefix-dependent transition bias ``B_k`` to each block + position, inducing a causal block distribution + ``p_k(x_k | x_0, x_ 0) in " + "dflash_architecture_config (the Markov head's low-rank dimension)." + ) + super().modify(config) + # Three-term loss weights (DSpark only). Defaults follow the DeepSpec recipe + # (L1/TVD-dominant) so a config that only sets the head still trains sensibly. + self.dflash_ce_loss_alpha = getattr(config, "dflash_ce_loss_alpha", 0.1) + self.dflash_l1_loss_alpha = getattr(config, "dflash_l1_loss_alpha", 0.9) + self.dflash_confidence_head_alpha = getattr(config, "dflash_confidence_head_alpha", 0.0) + if self.dflash_confidence_head_alpha > 0 and not self.dflash_module.use_confidence_head: + raise ValueError( + "dflash_confidence_head_alpha > 0 but the confidence head was not built; " + "set dflash_architecture_config.use_confidence_head=true." + ) + + def get_exporter(self): + """Get the exporter for the DSpark draft model.""" + from modelopt.torch.export.plugins.hf_spec_export import DSparkExporter + + return DSparkExporter(self) + + def _apply_markov_head(self, hidden, backbone_logits, input_ids, anchor_positions, n_blocks): + """Add the Markov transition bias to the backbone base logits. + + Returns ``(final_logits [B, N, bs, V], confidence_logits [B, N, bs] | None)``. + """ + bsz, seq_len = input_ids.shape + bs = self.dflash_block_size + device = input_ids.device + + hidden4d = hidden.reshape(bsz, n_blocks, bs, hidden.size(-1)) + base4d = backbone_logits.reshape(bsz, n_blocks, bs, -1) + + # Teacher-forced previous token for block position k: the real token at + # anchor+k (so position 0's predecessor is the anchor itself). + prev_offsets = torch.arange(bs, device=device).view(1, 1, -1) + prev_idx = (anchor_positions.unsqueeze(-1) + prev_offsets).clamp(max=seq_len - 1) + prev_ids = torch.gather(input_ids.unsqueeze(1).expand(-1, n_blocks, -1), 2, prev_idx) + + bias = self.dflash_module.compute_markov_bias(prev_ids, hidden4d) + final4d = base4d + bias + + confidence_logits = None + if self.dflash_module.use_confidence_head: + confidence_logits = self.dflash_module.compute_confidence_logits(prev_ids, hidden4d) + return final4d, confidence_logits + + def _compute_dspark_loss( + self, + backbone_logits, + final_logits, + confidence_logits, + input_ids, + anchor_positions, + block_keep_mask, + loss_mask, + target_model_logits, + ): + """Compute the three-term DSpark loss (CE + TVD + confidence BCE) and metrics. + + Uses next-token (shift_label) alignment: block position k predicts the token + at anchor+k+1; the aligned target distribution is the base model's own + next-token distribution at position anchor+k (= label index - 1). + """ + bsz, seq_len = input_ids.shape + bs = self.dflash_block_size + n_blocks = anchor_positions.shape[1] + device = input_ids.device + vocab = final_logits.size(-1) + + # shift_label=True: label for block position k is the token at anchor+k+1. + label_offsets = torch.arange(1, 1 + bs, device=device).view(1, 1, -1) + label_indices = anchor_positions.unsqueeze(-1) + label_offsets + valid_label = label_indices < seq_len + safe_label_indices = label_indices.clamp(max=seq_len - 1) + + target_ids = torch.gather( + input_ids.unsqueeze(1).expand(-1, n_blocks, -1), 2, safe_label_indices + ) + + # Weight mask: valid block * in bounds * loss_mask (no pos-0 exclusion). + weight_mask = block_keep_mask.unsqueeze(-1).expand(-1, -1, bs).float() + weight_mask = weight_mask * valid_label.float() + orig_loss_mask = torch.gather( + loss_mask.unsqueeze(1).expand(-1, n_blocks, -1), 2, safe_label_indices + ) + weight_mask = weight_mask * orig_loss_mask + + binary_eval_mask = weight_mask.view(-1) + + # Exponential position decay (exp(-k/gamma); position 0 gets weight 1). + if self.dflash_loss_decay_factor > 0: + k = torch.arange(bs, device=device).view(1, 1, -1) + decay = torch.exp(-k.float() / self.dflash_loss_decay_factor) + weight_mask = weight_mask * decay + + flat_final = final_logits.reshape(-1, vocab) + flat_base = backbone_logits.reshape(-1, vocab) + flat_targets = target_ids.reshape(-1) + flat_weights = weight_mask.reshape(-1) + valid_count = flat_weights.sum() + 1e-6 + + # Aligned target distribution: base-model logits that predict token anchor+k+1 + # sit at position anchor+k (= label index - 1). + teacher_indices = (safe_label_indices - 1).clamp(min=0) + teacher_logits = torch.gather( + target_model_logits.unsqueeze(1).expand(-1, n_blocks, -1, -1), + 2, + teacher_indices.unsqueeze(-1).expand(-1, -1, -1, vocab), + ) + flat_teacher = teacher_logits.reshape(-1, vocab).detach() + + if valid_count <= 1.0: + loss = flat_final.sum() * 0.0 + metrics = {"ce_loss": 0.0, "l1_loss": 0.0, "confidence_loss": 0.0, "base_accuracy": 0.0} + return loss, 0.0, metrics + + # Term 1: cross-entropy on the corrected (final) logits. + ce_per_token = F.cross_entropy(flat_final, flat_targets, reduction="none") + ce_loss = (ce_per_token * flat_weights).sum() / valid_count + + # Term 2: total-variation distance between the corrected draft and target. + # Chunked + checkpointed to avoid materializing two [N, vocab] softmaxes at once. + l1_per_token = _tvd_per_token(flat_final, flat_teacher) + l1_loss = (l1_per_token * flat_weights).sum() / valid_count + + # Term 3: confidence head BCE against the analytical accept rate c* = 1 - 0.5*TVD. + confidence_loss = ce_loss.new_zeros(()) + if confidence_logits is not None and self.dflash_confidence_head_alpha > 0: + accept_rate = (1.0 - 0.5 * l1_per_token).clamp(0.0, 1.0).detach() + conf_bce = F.binary_cross_entropy_with_logits( + confidence_logits.reshape(-1).float(), accept_rate, reduction="none" + ) + confidence_loss = (conf_bce * flat_weights).sum() / valid_count + + loss = ( + self.dflash_ce_loss_alpha * ce_loss + + self.dflash_l1_loss_alpha * l1_loss + + self.dflash_confidence_head_alpha * confidence_loss + ) + + with torch.no_grad(): + eval_count = binary_eval_mask.sum() + 1e-6 + keep = binary_eval_mask > 0.5 + accuracy = ( + ((flat_final.argmax(dim=-1) == flat_targets) & keep).sum().float() / eval_count + ).item() + base_accuracy = ( + ((flat_base.argmax(dim=-1) == flat_targets) & keep).sum().float() / eval_count + ).item() + metrics = { + "ce_loss": ce_loss.detach().item(), + "l1_loss": l1_loss.detach().item(), + "confidence_loss": float(confidence_loss.detach().item()), + "base_accuracy": base_accuracy, + } + return loss, accuracy, metrics + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + cache_position=None, + **kwargs, + ): + """DSpark training forward: DFlash backbone + Markov head + three-term loss. + + Mirrors ``HFDFlashModel.forward`` for data preparation (reusing the inherited + anchor/noise/mask/position helpers), then applies the Markov head and the + DSpark loss. Eval/offline-eval is delegated to the DFlash parent. + """ + if not self.training: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + bsz, seq_len = input_ids.shape + block_size = self.dflash_block_size + device = input_ids.device + + if seq_len % block_size != 0: + raise ValueError( + f"seq_len ({seq_len}) must be divisible by block_size ({block_size}). " + f"Adjust training_seq_len or use padding." + ) + + # 1. Target hidden states AND target-model logits (DSpark's L1/confidence + # terms both need the base model's next-token distribution). + if self.dflash_offline: + assert "base_model_outputs" in kwargs + base_outputs = DFlashBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"]) + if base_outputs.logits is None: + out_hiddens = kwargs["base_model_outputs"]["base_model_hidden_states"] + base_outputs.logits = self._base_model_lm_head(out_hiddens) + target_hidden = base_outputs.target_hidden + target_model_logits = base_outputs.logits + else: + # Call the inner base model directly (NOT super().forward(), which during + # training runs the full DFlash pipeline). Compute target-model logits via + # the lm_head — DSpark's TVD/confidence terms need the base distribution. + with torch.no_grad(): + base_out = self._base_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + target_model_logits = self._base_model_lm_head(base_out.last_hidden_state) + offset = 1 + selected = [base_out.hidden_states[lid + offset] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) # [B, seq, num_layers * H] + + # 2. Build loss mask (same convention as DFlash/Domino). + if labels is not None: + loss_mask = (labels != LabelSmoother.ignore_index).float() + elif attention_mask is not None: + loss_mask = attention_mask.float() + else: + loss_mask = torch.ones(bsz, seq_len, device=device) + if kwargs.get("loss_mask") is not None: + loss_mask = loss_mask * kwargs["loss_mask"] + + # 3. Random anchor sampling (inherited). + anchor_positions, block_keep_mask = self._sample_anchor_positions( + seq_len, loss_mask, device + ) + n_blocks = anchor_positions.shape[1] + + if n_blocks == 0 or not block_keep_mask.any(): + # Zero loss that still flows through all draft params for DDP sync. + dummy = sum(p.sum() for p in self.dflash_module.parameters()) * 0.0 + return ModelOutput(loss=dummy, logits=None, train_acc=[[0.0]]) + + # 4. Build draft inputs (inherited helpers). + noise_embedding = self._build_noise_embedding( + input_ids, anchor_positions, block_keep_mask, n_blocks + ) + full_pos = self._build_position_ids(seq_len, anchor_positions, device) + attn_mask = self._build_draft_attention_mask( + seq_len, anchor_positions, block_keep_mask, n_blocks, target_hidden.dtype, device + ) + + # 5. Draft backbone forward. + hidden = self.dflash_module( + noise_embedding=noise_embedding, + target_hidden=target_hidden, + position_ids=full_pos, + attention_mask=attn_mask, + ) + + # 6. Backbone logits → Markov correction → three-term loss. + backbone_logits = self._base_model_lm_head(hidden).reshape(bsz, n_blocks, block_size, -1) + final_logits, confidence_logits = self._apply_markov_head( + hidden, backbone_logits, input_ids, anchor_positions, n_blocks + ) + loss, accuracy, metrics = self._compute_dspark_loss( + backbone_logits, + final_logits, + confidence_logits, + input_ids, + anchor_positions, + block_keep_mask, + loss_mask, + target_model_logits, + ) + + return ModelOutput(loss=loss, logits=None, train_acc=[[accuracy]], dspark_metrics=metrics) + + @torch.no_grad() + def pseudo_speculative_generate(self, input_ids, steps=1): + """Generate draft tokens for AR validation, with the Markov correction applied. + + Mirrors ``HFDFlashModel.pseudo_speculative_generate`` for the backbone pass, + then samples the block left-to-right applying the Markov transition bias + autoregressively (DSpark's semi-autoregressive generation). Uses next-token + (shift) alignment: the anchor is block position 0 and predicts the first draft + token; position k's bias conditions on the token decoded at position k-1 + (the anchor for k=0). Without this override DSpark would fall back to the + DFlash backbone-only generate, under-reporting acceptance length. + """ + if self.dflash_offline: + raise RuntimeError( + "DSpark offline model cannot run AR validation / pseudo_speculative_generate — " + "base model layers were deleted during offline conversion. Reload the full " + "base model before running AR validation." + ) + model_output = self._base_model(input_ids=input_ids, output_hidden_states=True) + base_logits = self._base_model_lm_head(model_output.last_hidden_state) + base_token = base_logits[:, -1:, :].argmax(dim=-1).to(input_ids.device) + + if steps < 1: + return base_token, None + + hid_offset = 1 + selected = [model_output.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) + + block_size = self.dflash_block_size + bsz = input_ids.shape[0] + device = input_ids.device + + # Block input: anchor at position 0, mask tokens elsewhere (parallel backbone). + block_ids = torch.full( + (bsz, block_size), self.mask_token_id, dtype=torch.long, device=device + ) + block_ids[:, 0] = base_token.squeeze(-1) + noise_embedding = self._base_model_embeddings(block_ids) + + ctx_len = target_hidden.shape[1] + ctx_positions = torch.arange(ctx_len, device=device) + block_positions = torch.arange(ctx_len, ctx_len + block_size, device=device) + pos_ids = torch.cat([ctx_positions, block_positions]).unsqueeze(0).expand(bsz, -1) + + draft_hidden = self.dflash_module( + noise_embedding=noise_embedding, + target_hidden=target_hidden, + position_ids=pos_ids, + attention_mask=None, + ) + backbone_logits = self._base_model_lm_head(draft_hidden) # [B, block_size, V] + + # Autoregressive Markov sampling over the block. + m = self.dflash_module + num_tokens = min(steps, block_size) + prev_token = base_token.squeeze(-1) # anchor precedes block position 0 + state = None + draft_tokens = [] + for k in range(num_tokens): + bias, state = m.markov_step(prev_token, draft_hidden[:, k, :], state) + tok = (backbone_logits[:, k, :] + bias).argmax(dim=-1) + draft_tokens.append(tok) + prev_token = tok + return base_token, torch.stack(draft_tokens, dim=1) diff --git a/modelopt/torch/speculative/plugins/modeling_dspark.py b/modelopt/torch/speculative/plugins/modeling_dspark.py new file mode 100644 index 00000000000..75aa58f2b1f --- /dev/null +++ b/modelopt/torch/speculative/plugins/modeling_dspark.py @@ -0,0 +1,233 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from https://github.com/deepseek-ai/DeepSpec/blob/main/deepspec/modeling/dspark/markov_head.py +# Copyright (c) 2026 The DeepSpec Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 AND MIT +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DSpark draft module — DFlash backbone plus a lightweight sequential (Markov) head. + +DSpark (DeepSeek-AI, "DSpark: Confidence-Scheduled Speculative Decoding with +Semi-Autoregressive Generation") shares Domino's idea: keep the parallel DFlash +backbone for speed and add a lightweight sequential head that injects the +intra-block causal dependency the parallel backbone lacks (mitigating suffix +acceptance decay). Where Domino uses a GRU over base-model token embeddings, +DSpark adds a *prefix-dependent transition bias* ``B_k`` to the backbone's base +logits, inducing a causal block distribution +``p_k(x_k | x_0, x_rank``, + ``markov_w2: rank->vocab``). Cheapest; uses neither the backbone hidden nor + recurrence. +- ``gated``: gates the previous-token embedding by the backbone hidden before + projecting: ``B = W2(sigmoid(gate_proj([h_k; W1[x_{k-1}]])) * W1[x_{k-1}])``. +- ``rnn``: a GRU-like recurrent head carrying a state ``s_k`` across positions in + the block, so position ``k`` sees the full prefix ``x_ 0, got {self.markov_rank}.") + if self.markov_head_type not in ("vanilla", "gated", "rnn"): + raise ValueError( + f"Unsupported markov_head_type: {self.markov_head_type!r}. " + "Expected 'vanilla', 'gated' or 'rnn'." + ) + + hidden_size = config.hidden_size + vocab_size = config.vocab_size + r = self.markov_rank + + # Low-rank first-order transition: W1 is an embedding lookup over the + # previous token, W2 projects the rank-r state back to a vocab logit bias. + self.markov_w1 = nn.Embedding(vocab_size, r) + self.markov_w2 = nn.Linear(r, vocab_size, bias=False) + if self.markov_head_type == "gated": + self.gate_proj = nn.Linear(hidden_size + r, r) + elif self.markov_head_type == "rnn": + # Joint [gate; candidate; output] projection over [s_{k-1}; W1[x_{k-1}]; h_k]. + self.joint_proj = nn.Linear(2 * r + hidden_size, 3 * r) + + # Optional confidence head: predicts the per-position acceptance probability + # (supervised in the wrapper by the DSpark confidence BCE loss). + self.use_confidence_head = bool(getattr(config, "use_confidence_head", False)) + if self.use_confidence_head: + self.confidence_proj = nn.Linear(hidden_size + r, 1) + + # DFlashModule.__init__ already ran _init_weights before these modules + # existed, so initialize the new layers explicitly. + self._init_head_weights(config) + + def _init_head_weights(self, config): + """Initialize the head Linear/Embedding layers (matching HF _init_weights std).""" + std = getattr(config, "initializer_range", 0.02) + modules = [self.markov_w1, self.markov_w2] + if self.markov_head_type == "gated": + modules.append(self.gate_proj) + elif self.markov_head_type == "rnn": + modules.append(self.joint_proj) + if self.use_confidence_head: + modules.append(self.confidence_proj) + for module in modules: + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=std) + + def prev_token_embeddings(self, prev_ids: torch.Tensor) -> torch.Tensor: + """Look up the Markov embedding ``W1[x_{k-1}]`` of the teacher-forced prev tokens.""" + return self.markov_w1(prev_ids.long()) + + def compute_markov_bias(self, prev_ids: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor: + """Compute the transition bias ``B_k`` added to the backbone base logits. + + Args: + prev_ids: Teacher-forced previous-token ids per block position [B, N, block_size]. + hidden: Backbone hidden states [B, N, block_size, H] (used by gated/rnn heads). + + Returns: + Logit bias [B, N, block_size, vocab]. + """ + prev_emb = self.prev_token_embeddings(prev_ids) # [B, N, bs, r] + + if self.markov_head_type == "vanilla": + return self.markov_w2(prev_emb) + + if self.markov_head_type == "gated": + gate = torch.sigmoid(self.gate_proj(torch.cat([hidden, prev_emb], dim=-1))) + return self.markov_w2(gate.to(prev_emb.dtype) * prev_emb) + + # rnn: unroll the gated recurrence over the block dimension. + block_size = prev_ids.shape[-1] + leading = prev_emb.shape[:-2] # [B, N] + state = torch.zeros(*leading, self.markov_rank, device=prev_emb.device, dtype=hidden.dtype) + biases = [] + for k in range(block_size): + state, bias = self._rnn_step(state, prev_emb[..., k, :], hidden[..., k, :]) + biases.append(bias) + return torch.stack(biases, dim=-2) + + def _rnn_step(self, state, prev_emb, hidden): + """One GRU-like recurrent step. Returns (new_state [.., r], bias [.., vocab]).""" + z = torch.cat([state, prev_emb, hidden], dim=-1) + gate_raw, candidate_raw, output_raw = self.joint_proj(z).chunk(3, dim=-1) + gate = torch.sigmoid(gate_raw) + candidate = torch.tanh(candidate_raw) + new_state = gate * state + (1.0 - gate) * candidate + bias = self.markov_w2(torch.tanh(output_raw)) + return new_state, bias + + def markov_step(self, prev_token: torch.Tensor, hidden: torch.Tensor, state=None): + """One autoregressive Markov step (inference): bias for a single position. + + Args: + prev_token: Previously decoded token ids [B]. + hidden: Backbone hidden at this position [B, H] (used by gated/rnn). + state: Recurrent state [B, r] (rnn head only); None initializes to zero. + + Returns: + (bias [B, vocab], new_state [B, r] | None) — ``new_state`` is the input + ``state`` unchanged for the memoryless heads. + """ + prev_emb = self.prev_token_embeddings(prev_token) + if self.markov_head_type == "vanilla": + return self.markov_w2(prev_emb), state + if self.markov_head_type == "gated": + gate = torch.sigmoid(self.gate_proj(torch.cat([hidden, prev_emb], dim=-1))) + return self.markov_w2(gate.to(prev_emb.dtype) * prev_emb), state + # rnn + if state is None: + state = torch.zeros( + prev_emb.shape[0], self.markov_rank, device=prev_emb.device, dtype=hidden.dtype + ) + state, bias = self._rnn_step(state, prev_emb, hidden) + return bias, state + + def compute_confidence_logits( + self, prev_ids: torch.Tensor, hidden: torch.Tensor + ) -> torch.Tensor: + """Per-position acceptance-probability logits ``c_k = w^T[h_k; W1[x_{k-1}]]``. + + Returns logits [B, N, block_size] (pass through sigmoid for a probability). + """ + prev_emb = self.prev_token_embeddings(prev_ids) + return self.confidence_proj(torch.cat([hidden, prev_emb], dim=-1)).squeeze(-1) diff --git a/modelopt_recipes/general/speculative_decoding/dspark.yaml b/modelopt_recipes/general/speculative_decoding/dspark.yaml new file mode 100644 index 00000000000..cb7b09f40e9 --- /dev/null +++ b/modelopt_recipes/general/speculative_decoding/dspark.yaml @@ -0,0 +1,96 @@ +# DSpark speculative-decoding training recipe. +# +# DSpark reuses the DFlash mode/pipeline and adds a lightweight sequential +# (Markov) head plus an optional confidence head, selected via +# dflash_architecture_config.projector_type=dspark. The Markov head adds a +# prefix-dependent transition bias to the backbone base logits, inducing a causal +# block distribution (semi-autoregressive generation). Trained with a three-term +# loss: ce_alpha*CE + l1_alpha*TVD + conf_alpha*confidence_BCE. Online training is +# the default path (data.mode=online). Override fields via an OmegaConf dotlist. + +metadata: + recipe_type: speculative_dflash + description: DSpark training recipe (DFlash backbone + Markov head + confidence head). + +# maps to ModelArguments (main.py) +model: + model_name_or_path: + trust_remote_code: false + use_fake_base_for_offline: false + +# maps to DataArguments (main.py) +data: + mode: online + data_path: + offline_data_path: + # Jinja chat template with {% generation %} tags for answer_only_loss. + chat_template: + +# maps to TrainingArguments (main.py) +training: + # --- commonly modified --- + output_dir: + num_train_epochs: 6 + per_device_train_batch_size: 1 + learning_rate: 6.0e-4 + warmup_ratio: 0.04 + training_seq_len: 3072 + logging_steps: 50 + save_steps: 2000 + cp_size: 1 + dp_shard_size: 1 + disable_tqdm: true + # Keep off: eval runs the DFlash backbone only (Markov head not applied yet), + # so AR would reflect the backbone alone, not the trained model. Compare via + # export + the offline acceptance-length harness instead. + estimate_ar: false + ar_validate_steps: 0 + answer_only_loss: true + + # --- rarely modified --- + do_eval: false + lr_scheduler_type: linear + save_strategy: steps + weight_decay: 0.0 + max_grad_norm: 1.0 + dataloader_drop_last: true + bf16: true + tf32: true + remove_unused_columns: false + # Safe default: the confidence head params are unused when + # dflash_confidence_head_alpha == 0, which would otherwise trip DDP. + ddp_find_unused_parameters: true + ddp_timeout: 1800 + report_to: tensorboard + +# maps to DFlashConfig (modelopt/torch/speculative/config.py). +dflash: + dflash_block_size: 16 + dflash_num_anchors: 256 + dflash_use_torch_compile: false + # DSpark computes the target distribution internally for its TVD/confidence + # terms; the DFlash KD path is unused. + dflash_self_logit_distillation: false + # gamma for exponential loss decay (block_size=16 -> 7). + dflash_loss_decay_factor: 7.0 + # Qwen3 has no native mask token; 151669 is an unused id used by the reference. + dflash_mask_token_id: 151669 + # Three-term loss weights (DeepSpec defaults: L1/TVD-dominant). + dflash_ce_loss_alpha: 0.1 + dflash_l1_loss_alpha: 0.9 + dflash_confidence_head_alpha: 1.0 + dflash_architecture_config: + num_hidden_layers: 5 + # Draft attention/MLP dims — set explicitly (the draft is an independent + # Qwen3 model and does NOT inherit these from the base). GQA: 8 KV heads. + num_attention_heads: 32 + num_key_value_heads: 8 + head_dim: 128 + intermediate_size: 12288 + projector_type: dspark + # Markov head: low-rank first-order transition bias. 'vanilla' (memoryless), + # 'gated' (hidden-gated), or 'rnn' (recurrent, closest to Domino's GRU). + markov_rank: 256 + markov_head_type: vanilla + # Build + train the confidence head (per-position acceptance predictor). + use_confidence_head: true diff --git a/tests/unit/torch/speculative/plugins/test_hf_dspark.py b/tests/unit/torch/speculative/plugins/test_hf_dspark.py new file mode 100644 index 00000000000..4fd1f539955 --- /dev/null +++ b/tests/unit/torch/speculative/plugins/test_hf_dspark.py @@ -0,0 +1,257 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU unit tests for the DSpark speculative decoding plugin. + +DSpark reuses the DFlash mode/pipeline and adds a lightweight sequential (Markov) +head plus an optional confidence head. These tests cover conversion routing for +the three head variants, the three-term training forward (CE + TVD + confidence +BCE), and the export format (head weights + config) against the z-lab-compatible +layout (``markov_w1.*`` / ``markov_w2.*`` / ``gate_proj.*`` / ``joint_proj.*`` / +``confidence_proj.*``). +""" + +import json +from copy import deepcopy + +import pytest +import torch +from _test_utils.torch.transformers_models import get_tiny_llama +from safetensors.torch import load_file + +import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.config import DFLASH_DEFAULT_CFG +from modelopt.torch.speculative.plugins.hf_dflash import HFDFlashModel +from modelopt.torch.speculative.plugins.hf_dspark import HFDSparkModel +from modelopt.torch.speculative.plugins.modeling_dflash import DFlashModule +from modelopt.torch.speculative.plugins.modeling_dspark import DSparkModule + +BLOCK_SIZE = 4 +NUM_DRAFT_LAYERS = 2 +SEQ_LEN = 16 # must be a multiple of BLOCK_SIZE +MARKOV_RANK = 16 + +HEAD_TYPES = ["vanilla", "gated", "rnn"] + + +def _get_dspark_config( + head_type="vanilla", + use_confidence_head=False, + confidence_alpha=0.0, + block_size=BLOCK_SIZE, + num_layers=NUM_DRAFT_LAYERS, +): + """Create a DSpark config for testing (dflash mode + projector_type=dspark).""" + config = deepcopy(DFLASH_DEFAULT_CFG["config"]) + config["dflash_block_size"] = block_size + config["dflash_use_torch_compile"] = False + config["dflash_mask_token_id"] = 0 # token 0 as mask for the tiny model + config["dflash_self_logit_distillation"] = False + config["dflash_confidence_head_alpha"] = confidence_alpha + config["dflash_architecture_config"] = { + "num_hidden_layers": num_layers, + "projector_type": "dspark", + "markov_rank": MARKOV_RANK, + "markov_head_type": head_type, + "use_confidence_head": use_confidence_head, + "pure_draft_prefix_len": 1, + "shift_label": True, + } + return config + + +class TestDSparkConvert: + """Test DSpark model conversion routing and head construction.""" + + def test_convert_creates_dspark_model(self): + """projector_type=dspark routes to HFDSparkModel (a HFDFlashModel subclass).""" + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model, [("dflash", _get_dspark_config())]) + assert isinstance(model, HFDSparkModel) + assert isinstance(model, HFDFlashModel) + assert isinstance(model.dflash_module, DSparkModule) + + @pytest.mark.parametrize("head_type", HEAD_TYPES) + def test_head_modules_per_type(self, head_type): + """The Markov head builds the right submodules for each variant.""" + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model, [("dflash", _get_dspark_config(head_type=head_type))]) + head = model.dflash_module + vocab = model.dflash_config.vocab_size + + # Low-rank transition shared by all variants; markov_w2 has no bias. + assert isinstance(head.markov_w1, torch.nn.Embedding) + assert head.markov_w1.embedding_dim == MARKOV_RANK + assert head.markov_w2.in_features == MARKOV_RANK + assert head.markov_w2.out_features == vocab + assert head.markov_w2.bias is None + + # Variant-specific projections. + assert hasattr(head, "gate_proj") == (head_type == "gated") + assert hasattr(head, "joint_proj") == (head_type == "rnn") + + def test_confidence_head_built_when_enabled(self): + """use_confidence_head=true attaches a confidence_proj; otherwise absent.""" + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model, [("dflash", _get_dspark_config(use_confidence_head=True))]) + assert hasattr(model.dflash_module, "confidence_proj") + assert model.dflash_module.confidence_proj.out_features == 1 + + model2 = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model2, [("dflash", _get_dspark_config(use_confidence_head=False))]) + assert not hasattr(model2.dflash_module, "confidence_proj") + + def test_head_params_trainable(self): + """The Markov head parameters are trainable.""" + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model, [("dflash", _get_dspark_config())]) + head = [(n, p) for n, p in model.named_parameters() if "markov_w" in n] + assert len(head) >= 2 # markov_w1.weight, markov_w2.weight + assert all(p.requires_grad for _, p in head) + + def test_missing_markov_rank_raises(self): + """projector_type=dspark without markov_rank is a configuration error.""" + config = _get_dspark_config() + del config["dflash_architecture_config"]["markov_rank"] + model = get_tiny_llama(num_hidden_layers=4) + with pytest.raises(ValueError, match="markov_rank"): + mtsp.convert(model, [("dflash", config)]) + + def test_dflash_mode_still_creates_plain_dflash(self): + """Without projector_type=dspark, conversion still yields a plain DFlash model.""" + config = deepcopy(DFLASH_DEFAULT_CFG["config"]) + config["dflash_mask_token_id"] = 0 + config["dflash_architecture_config"] = {"num_hidden_layers": NUM_DRAFT_LAYERS} + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model, [("dflash", config)]) + assert isinstance(model, HFDFlashModel) + assert not isinstance(model, HFDSparkModel) + assert type(model.dflash_module) is DFlashModule + + +class TestDSparkForward: + """Test the DSpark training forward (online path on CPU).""" + + def _make_batch(self, vocab_size): + torch.manual_seed(0) + input_ids = torch.randint(1, vocab_size, (2, SEQ_LEN)) + attention_mask = torch.ones_like(input_ids) + labels = input_ids.clone() + return input_ids, attention_mask, labels + + @pytest.mark.parametrize("head_type", HEAD_TYPES) + def test_forward_loss_metrics_and_grads(self, head_type): + """Forward returns a scalar loss + metrics; backward fills head + backbone grads.""" + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model, [("dflash", _get_dspark_config(head_type=head_type))]) + model.train() + + input_ids, attention_mask, labels = self._make_batch(model.dflash_config.vocab_size) + out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + + assert out.loss.requires_grad + assert out.loss.dim() == 0 + # three-term loss bookkeeping + for key in ("ce_loss", "l1_loss", "confidence_loss", "base_accuracy"): + assert key in out.dspark_metrics + assert out.dspark_metrics["confidence_loss"] == 0.0 # no confidence head here + + out.loss.backward() + head_grad = model.dflash_module.markov_w2.weight.grad + backbone_grad = model.dflash_module.fc.weight.grad + assert head_grad is not None and torch.isfinite(head_grad).all() + assert head_grad.abs().sum() > 0 # head actually participates in the loss + assert backbone_grad is not None and torch.isfinite(backbone_grad).all() + + def test_confidence_head_contributes_grads(self): + """With the confidence head + alpha>0, confidence_proj receives gradients.""" + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert( + model, + [("dflash", _get_dspark_config(use_confidence_head=True, confidence_alpha=1.0))], + ) + model.train() + + input_ids, attention_mask, labels = self._make_batch(model.dflash_config.vocab_size) + out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + assert out.dspark_metrics["confidence_loss"] > 0.0 + + out.loss.backward() + conf_grad = model.dflash_module.confidence_proj.weight.grad + assert conf_grad is not None and torch.isfinite(conf_grad).all() + assert conf_grad.abs().sum() > 0 + + def test_confidence_alpha_without_head_raises(self): + """confidence_head_alpha>0 but no confidence head is a configuration error.""" + model = get_tiny_llama(num_hidden_layers=4) + with pytest.raises(ValueError, match="confidence"): + mtsp.convert( + model, + [("dflash", _get_dspark_config(use_confidence_head=False, confidence_alpha=1.0))], + ) + + +class TestDSparkExporter: + """Test the DSpark checkpoint export format (z-lab-compatible layout).""" + + def _export(self, tmp_path, head_type="vanilla", use_confidence_head=False): + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert( + model, + [ + ( + "dflash", + _get_dspark_config( + head_type=head_type, use_confidence_head=use_confidence_head + ), + ) + ], + ) + export_dir = tmp_path / "exported" + model.get_exporter().export(export_dir) + return export_dir + + @pytest.mark.parametrize("head_type", HEAD_TYPES) + def test_export_weight_keys_match_reference(self, tmp_path, head_type): + """Exported weights carry the head tensors under reference names, no prefix.""" + sd = load_file(str(self._export(tmp_path, head_type=head_type) / "model.safetensors")) + for key in sd: + assert "dflash_module." not in key + assert "rotary_emb" not in key + assert "markov_w1.weight" in sd + assert "markov_w2.weight" in sd + assert ("gate_proj.weight" in sd) == (head_type == "gated") + assert ("joint_proj.weight" in sd) == (head_type == "rnn") + + def test_export_includes_confidence_weights(self, tmp_path): + """The confidence head weights are exported when enabled.""" + sd = load_file(str(self._export(tmp_path, use_confidence_head=True) / "model.safetensors")) + assert "confidence_proj.weight" in sd + + def test_export_config_has_dspark_fields(self, tmp_path): + """config.json carries the dflash_config DSpark head fields.""" + export_dir = self._export(tmp_path, head_type="gated") + with open(export_dir / "config.json") as f: + cfg = json.load(f) + + assert cfg["architectures"] == ["DFlashDraftModel"] + dc = cfg["dflash_config"] + assert dc["projector_type"] == "dspark" + assert dc["markov_rank"] == MARKOV_RANK + assert dc["markov_head_type"] == "gated" + assert dc["use_confidence_head"] is False + assert dc["shift_label"] is True + assert "mask_token_id" in dc + assert "target_layer_ids" in dc