From 7dbd86d0fe0722e357c4e10c42fb6dbed043a000 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Wed, 10 Jun 2026 11:08:03 +0000 Subject: [PATCH 1/7] refactor(moe): Merge MoE implementation into modalities core --- .../config_lorem_ipsum_long_moe_ep_fsdp2.yaml | 382 +++++++++++++ moe/modalities_moe/__init__.py | 0 moe/modalities_moe/config/__init__.py | 0 moe/modalities_moe/config/config.py | 22 - moe/modalities_moe/models/__init__.py | 0 moe/modalities_moe/models/moe/__init__.py | 0 moe/modalities_moe/models/moe/moe_model.py | 537 ------------------ moe/modalities_moe/optimizers/__init__.py | 0 moe/modalities_moe/training/__init__.py | 0 .../training/gradient_clipping/__init__.py | 0 moe/scripts/train_ep.py | 52 +- src/modalities/config/config.py | 30 + src/modalities/models/moe/__init__.py | 10 + .../modalities/models/moe}/loss_functions.py | 3 +- .../modalities/models/moe}/model_factory.py | 16 +- .../modalities}/models/moe/qwen_model.py | 69 +-- .../modalities}/optimizers/ep_adamw.py | 35 +- src/modalities/registry/components.py | 13 + .../gradient_clipping/ep_gradient_clipper.py | 5 +- 19 files changed, 464 insertions(+), 710 deletions(-) create mode 100644 config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml delete mode 100644 moe/modalities_moe/__init__.py delete mode 100644 moe/modalities_moe/config/__init__.py delete mode 100644 moe/modalities_moe/config/config.py delete mode 100644 moe/modalities_moe/models/__init__.py delete mode 100644 moe/modalities_moe/models/moe/__init__.py delete mode 100644 moe/modalities_moe/models/moe/moe_model.py delete mode 100644 moe/modalities_moe/optimizers/__init__.py delete mode 100644 moe/modalities_moe/training/__init__.py delete mode 100644 moe/modalities_moe/training/gradient_clipping/__init__.py create mode 100644 src/modalities/models/moe/__init__.py rename {moe/modalities_moe => src/modalities/models/moe}/loss_functions.py (92%) rename {moe/modalities_moe/models => src/modalities/models/moe}/model_factory.py (86%) rename {moe/modalities_moe => src/modalities}/models/moe/qwen_model.py (92%) rename {moe/modalities_moe => src/modalities}/optimizers/ep_adamw.py (80%) rename {moe/modalities_moe => src/modalities}/training/gradient_clipping/ep_gradient_clipper.py (91%) diff --git a/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml b/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml new file mode 100644 index 000000000..577656b3b --- /dev/null +++ b/config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml @@ -0,0 +1,382 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpoint_saving_path: data/checkpoints + train_dataset_path: ./data/lorem_ipsum_long.pbin + test_dataset_path: ./data/lorem_ipsum.pbin + experiments_root_path: ${modalities_env:experiments_root_path} + intervals: + training_log_interval_in_steps: 1 + checkpointing_interval_in_steps: 32 + evaluation_interval_in_steps: 32 + consistency_enforcement: + enforce_tokens_per_step_consistency: false + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 1 + sequence_length: 256 + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_packed_mem_map_dataset_continuous + config: + dataset_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_target_steps: + component_key: number_conversion + variant_key: num_steps_from_num_tokens + config: + dp_degree: + instance_key: dp_degree + pass_type: BY_REFERENCE + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + global_num_tokens: ${settings.training_target.num_target_tokens} + sequence_length: ${settings.step_profile.sequence_length} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +test_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.test_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +test_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: test + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + drop_last: true + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: test_dataloader + pass_type: BY_REFERENCE + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: moe_cross_entropy + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + model: + instance_key: model_raw + pass_type: BY_REFERENCE + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: -1 + tensor_parallel_degree: 4 + world_size: ${settings.cuda_env.world_size} + +dp_degree: + component_key: number_conversion + variant_key: parallel_degree + config: + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + parallelism_methods: [dp_shard, dp_replicate] + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.num_layers} + multi_device_generator_policy: error + +ep_model: + component_key: model + variant_key: ep_wrapped + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + ep_mesh_dim_name: tp + block_names: [TransformerBlock] + +ac_model: + component_key: model + variant_key: activation_checkpointed + config: + model: + instance_key: ep_model + pass_type: BY_REFERENCE + ac_variant: full_activation_checkpointing + layers_fqn: layers + ac_fun_params: + ac_freq: 1 + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: ac_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + reshard_after_forward: true + block_names: [TransformerBlock] + +model_raw: + component_key: model + variant_key: moe + config: + sample_key: ${settings.referencing_keys.sample_key} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 + max_seq_len: ${settings.step_profile.sequence_length} + d_model: 128 + n_heads: 8 + n_kv_heads: 4 + num_layers: 2 + d_ff: 128 + attn_dropout: 0.0 + ffn_dropout: 0.0 + tie_embeddings: false + norm_eps: 1e-6 + rope_base: 1000000.0 + moe_num_experts: 8 + moe_top_k: 2 + moe_d_ff: 128 + moe_capacity_factor: 1.25 + moe_min_capacity: 4 + moe_overflow_policy: residual + moe_router_noise_std: 0.0 + moe_router_temperature: 1.0 + moe_router_dropout: 0.0 + moe_aux_loss_coef: 0.001 + moe_z_loss_coef: 0.0 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: ep_adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: ep + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + global_rank: ${settings.cuda_env.global_rank} + project: modalities_dcp_tests + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: wandb_storage + config_file_path: ${settings.config_file_path} + +mfu_calculator: + component_key: mfu_calculator + variant_key: gpt2 + config: + n_layer: ${model_raw.config.num_layers} + sequence_length: ${settings.step_profile.sequence_length} + n_embd: ${model_raw.config.d_model} + world_size: ${settings.cuda_env.world_size} + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE diff --git a/moe/modalities_moe/__init__.py b/moe/modalities_moe/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/moe/modalities_moe/config/__init__.py b/moe/modalities_moe/config/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/moe/modalities_moe/config/config.py b/moe/modalities_moe/config/config.py deleted file mode 100644 index 0ab14372a..000000000 --- a/moe/modalities_moe/config/config.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Any - -from pydantic import BaseModel - -from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticPytorchModuleOrListType - - -class MoECrossEntropyLossConfig(BaseModel): - target_key: str - prediction_key: str - model: Any - tag: str = "MoECrossEntropyLoss" - - class Config: - arbitrary_types_allowed = True - - -class EPWrappedModelConfig(BaseModel): - model: PydanticPytorchModuleOrListType - block_names: list[str] - device_mesh: PydanticDeviceMeshIFType - ep_mesh_dim_name: str | None = None diff --git a/moe/modalities_moe/models/__init__.py b/moe/modalities_moe/models/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/moe/modalities_moe/models/moe/__init__.py b/moe/modalities_moe/models/moe/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/moe/modalities_moe/models/moe/moe_model.py b/moe/modalities_moe/models/moe/moe_model.py deleted file mode 100644 index a1412d410..000000000 --- a/moe/modalities_moe/models/moe/moe_model.py +++ /dev/null @@ -1,537 +0,0 @@ -import math -from dataclasses import dataclass -from typing import Literal, Optional, overload - -import torch -import torch.nn as nn -import torch.nn.functional as F -from pydantic import BaseModel - -# TODO reolve this import -try: - from torch.distributed.tensor import DTensor -except Exception: - DTensor = None - - -class MoEModelConfig(BaseModel): - # model config - vocab_size: int - max_seq_len: int - d_model: int - n_heads: int - n_kv_heads: int - num_layers: int - d_ff: int - sample_key: str = "input_ids" - prediction_key: str = "logits" - attn_dropout: float = 0.0 - ffn_dropout: float = 0.0 - tie_embeddings: bool = False - moe_every_n_layers: int = 1 - moe_num_experts: int = 8 - moe_top_k: int = 2 - moe_capacity_factor: float = 1.25 - moe_aux_loss_coef: float = 0.01 - moe_z_loss_coef: float = 0.0 - moe_router_noise_std: float = 0.0 - - -@dataclass -class MoEArguments: - # Model hyperparameters - d_model: int - d_ff: int - - # MoE hyperparameters - num_experts: int - top_k: int - capacity_factor: float = 1.25 - min_capacity: int = 4 - overflow_policy: Literal["drop", "residual"] = "residual" - - # Router configuration - router_noise_std: float = 0.0 - router_temperature: float = 1.0 - router_dropout: float = 0.0 - - # Auxiliary loss coefficients - aux_loss_coef: float = 0.01 - z_loss_coef: float = 0.0 - - # Training configuration - dropout: float = 0.0 - - -class RMSNorm(nn.Module): - def __init__(self, d_model, eps=1e-8): - super(RMSNorm, self).__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(d_model)) - - def forward(self, x): - norm_x = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - return norm_x * self.weight - - -class Expert(nn.Module): - def __init__(self, d_model, d_ff, dropout=0.0): - super(Expert, self).__init__() - self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() - self.w1 = nn.Linear(d_model, d_ff) - self.w2 = nn.Linear(d_model, d_ff) - self.w3 = nn.Linear(d_ff, d_model) - - def forward(self, x): - x1 = self.w1(x) - x2 = self.w2(x) - x = torch.nn.functional.silu(x1) * x2 - x = self.w3(x) - return self.dropout(x) - - -class GroupedExperts(nn.Module): - """Grouped experts for torchtitan compatibility.""" - - def __init__(self, config: MoEArguments): - super().__init__() - self.num_experts = config.num_experts - self.d_model = config.d_model - self.d_ff = config.d_ff - self.dropout = nn.Dropout(config.dropout) if config.dropout > 0 else nn.Identity() - - self.w1 = nn.Parameter(torch.empty(self.num_experts, self.d_ff, self.d_model)) - self.b1 = nn.Parameter(torch.empty(self.num_experts, self.d_ff)) - self.w2 = nn.Parameter(torch.empty(self.num_experts, self.d_ff, self.d_model)) - self.b2 = nn.Parameter(torch.empty(self.num_experts, self.d_ff)) - self.w3 = nn.Parameter(torch.empty(self.num_experts, self.d_model, self.d_ff)) - self.b3 = nn.Parameter(torch.empty(self.num_experts, self.d_model)) - - self.initialize() - - def initialize(self): - nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) - bound_w1 = 1 / math.sqrt(self.d_model) - nn.init.uniform_(self.b1, -bound_w1, bound_w1) - - nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5)) - bound_w2 = 1 / math.sqrt(self.d_model) - nn.init.uniform_(self.b2, -bound_w2, bound_w2) - - nn.init.kaiming_uniform_(self.w3, a=math.sqrt(5)) - bound_w3 = 1 / math.sqrt(self.d_ff) - nn.init.uniform_(self.b3, -bound_w3, bound_w3) - - def _forward_local(self, routed_input, num_tokens_per_expert) -> torch.Tensor: - outputs: list[torch.Tensor] = [] - start = 0 - - # ExpertParallel may convert parameters to DTensor. Local expert compute - # expects plain tensors, so we materialize local shards when needed. - w1 = self.w1.to_local() if DTensor is not None and isinstance(self.w1, DTensor) else self.w1 - b1 = self.b1.to_local() if DTensor is not None and isinstance(self.b1, DTensor) else self.b1 - w2 = self.w2.to_local() if DTensor is not None and isinstance(self.w2, DTensor) else self.w2 - b2 = self.b2.to_local() if DTensor is not None and isinstance(self.b2, DTensor) else self.b2 - w3 = self.w3.to_local() if DTensor is not None and isinstance(self.w3, DTensor) else self.w3 - b3 = self.b3.to_local() if DTensor is not None and isinstance(self.b3, DTensor) else self.b3 - - local_num_tokens = ( - num_tokens_per_expert.to_local() - if DTensor is not None and isinstance(num_tokens_per_expert, DTensor) - else num_tokens_per_expert - ) - - total_rows = routed_input.shape[0] - for expert_idx, num_tokens in enumerate(local_num_tokens.tolist()): - requested_tokens = int(num_tokens) - end = start + requested_tokens - - # EP alignment can request padded tokens; only a subset may exist in routed_input. - local_end = min(end, total_rows) - expert_input = routed_input[start:local_end] - real_tokens = int(expert_input.shape[0]) - - out_parts: list[torch.Tensor] = [] - if real_tokens > 0: - x1 = torch.nn.functional.linear(expert_input, w1[expert_idx], b1[expert_idx]) - x2 = torch.nn.functional.linear(expert_input, w2[expert_idx], b2[expert_idx]) - hidden = torch.nn.functional.silu(x1) * x2 - out_real = torch.nn.functional.linear(hidden, w3[expert_idx], b3[expert_idx]) - out_parts.append(self.dropout(out_real)) - - pad_tokens = requested_tokens - real_tokens - if pad_tokens > 0: - out_parts.append(routed_input.new_zeros((pad_tokens, self.d_model))) - - if len(out_parts) > 0: - outputs.append(torch.cat(out_parts, dim=0) if len(out_parts) > 1 else out_parts[0]) - - start = end - - if len(outputs) == 0: - return routed_input.new_zeros((0, self.d_model)) - - out = torch.cat(outputs, dim=0) - - # EP permute may append extra global padding slots beyond per-expert aligned sizes. - # output_fn(_unpermute) expects the same row count as routed_input. - if out.shape[0] < total_rows: - out = torch.cat( - [out, routed_input.new_zeros((total_rows - out.shape[0], self.d_model))], - dim=0, - ) - elif out.shape[0] > total_rows: - out = out[:total_rows] - - return out - - def forward(self, routed_input, num_tokens_per_expert) -> torch.Tensor: - # routed_input: (M, D), sorted/grouped by expert id - # num_tokens_per_expert: (E_local,) for local compute, or global counts before EP input_fn - return self._forward_local(routed_input, num_tokens_per_expert) - - -class MoEBlock(nn.Module): - def __init__(self, config: MoEArguments): - super(MoEBlock, self).__init__() - self.config = config - self.num_experts = config.num_experts - self.router = nn.Linear(config.d_model, self.num_experts) - self.router_dropout = nn.Dropout(config.router_dropout) if config.router_dropout > 0 else nn.Identity() - self.experts = GroupedExperts(config) - - self.last_aux_loss: Optional[torch.Tensor] = None - - def forward(self, x): - B, T, D = x.size() - E = self.config.num_experts - K = self.config.top_k - N = B * T - - x_flat = x.view(N, D) - - # Router logits - logits = self.router(self.router_dropout(x_flat)) # (N, E) - if self.config.router_noise_std > 0 and self.training: - noise = torch.randn_like(logits) * self.config.router_noise_std - logits = logits + noise - logits = logits / self.config.router_temperature - probs = torch.softmax(logits, dim=-1) # (N, E) - - # top-k - topk_val, topk_idx = torch.topk(probs, k=K, dim=-1) # (N, K) - topk_w = topk_val / (topk_val.sum(dim=-1, keepdim=True) + 1e-9) # (N, K) - - # capacity per expert - capacity = math.ceil(self.config.capacity_factor * N / E) - capacity = max(capacity, self.config.min_capacity) - - # dispatch mask - preserve dtype of input - dispatch_mask = torch.nn.functional.one_hot(topk_idx, num_classes=E).to(x_flat.dtype) # (N, K, E) - - # token assignment - expert_mask = dispatch_mask.sum(dim=1) # (N, E) - positions = torch.cumsum(expert_mask, dim=0) # (N, E) - capacity_mask = (positions <= capacity).to(x_flat.dtype) # (N, E) - final_mask = dispatch_mask * capacity_mask.unsqueeze(1) # (N, K, E) - combine_weights = final_mask * topk_w.unsqueeze(-1) # (N, K, E) - - combine_weights.sum(dim=1) # (N, E) - - # count actual assignments per expert - load = final_mask.sum(dim=[0, 1]) # (E,) - importance = probs.sum(dim=0) # (E,) - - # Build routed token stream. - valid_mask = capacity_mask.gather(1, topk_idx).bool() # (N, K) - token_ids = torch.arange(N, device=x.device).unsqueeze(1).expand(N, K) - - flat_valid = valid_mask.reshape(-1) - flat_token_ids = token_ids.reshape(-1)[flat_valid] - flat_expert_ids = topk_idx.reshape(-1)[flat_valid] - flat_weights = topk_w.reshape(-1)[flat_valid] - - if flat_expert_ids.numel() > 0: - sort_idx = torch.argsort(flat_expert_ids) - token_ids_sorted = flat_token_ids[sort_idx] - expert_ids_sorted = flat_expert_ids[sort_idx] - weights_sorted = flat_weights[sort_idx] - - routed_input = x_flat[token_ids_sorted] - num_tokens_per_expert = torch.bincount(expert_ids_sorted, minlength=E) - - routed_output = self.experts(routed_input, num_tokens_per_expert) - weighted_output = routed_output * weights_sorted.unsqueeze(-1) - - out = x_flat.new_zeros((N, D)) - out.index_add_(0, token_ids_sorted, weighted_output) - - assigned = x_flat.new_zeros((N,)) - assigned.index_add_(0, token_ids_sorted, weights_sorted) - else: - out = x_flat.new_zeros((N, D)) - assigned = x_flat.new_zeros((N,)) - - # Overflow handling: tokens not assigned to any expert - not_assigned = assigned < 1e-6 - - if not_assigned.any(): - if self.config.overflow_policy == "residual": - out[not_assigned] = x_flat[not_assigned] - # if 'drop', out is already zero for those positions - - # auxiliary loss - aux = None - if self.config.aux_loss_coef > 0: - imp = importance / (importance.sum() + 1e-9) - ld = load / (load.sum() + 1e-9) - lb = E * torch.sum(imp * ld) - aux = self.config.aux_loss_coef * lb - - if self.config.z_loss_coef > 0: - z = torch.logsumexp(logits, dim=-1) - z_loss = torch.mean(z**2) - aux = (aux if aux is not None else 0.0) + self.config.z_loss_coef * z_loss - - self.last_aux_loss = aux - return out.view(B, T, D) - - -class GroupedQueryAttention(nn.Module): - def __init__(self, d_model, num_heads, num_kv_heads): - super(GroupedQueryAttention, self).__init__() - self.d_model = d_model - self.n_heads = num_heads - self.n_kv_heads = num_kv_heads - self.head_dim = d_model // num_heads - self.q_proj = nn.Linear(d_model, num_heads * self.head_dim) - self.k_proj = nn.Linear(d_model, num_kv_heads * self.head_dim) - self.v_proj = nn.Linear(d_model, num_kv_heads * self.head_dim) - self.out_proj = nn.Linear(num_heads * self.head_dim, d_model) - - def forward(self, query, key, value, mask=None): - Q = self.q_proj(query).view(query.size(0), -1, self.n_heads, self.head_dim) - K = self.k_proj(key).view(key.size(0), -1, self.n_kv_heads, self.head_dim) - V = self.v_proj(value).view(value.size(0), -1, self.n_kv_heads, self.head_dim) - Q = Q.permute(0, 2, 1, 3) - K = K.permute(0, 2, 1, 3) - V = V.permute(0, 2, 1, 3) - # Compute attention scores - attn_scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) / (self.head_dim**0.5) - if mask is not None: - attn_scores += mask - attn_weights = F.softmax(attn_scores, dim=-1) - attn_output = torch.einsum("bhqk,bkhd->bqhd", attn_weights, V) - attn_output = attn_output.contiguous().view(query.size(0), -1, self.n_heads * self.head_dim) - return self.out_proj(attn_output), None - - -class TransformerBlock(nn.Module): - """Transformer block with MoE""" - - def __init__(self, d_model, d_ff, num_heads, num_kv_heads, moe_config: MoEArguments): - super(TransformerBlock, self).__init__() - self.d_model = d_model - self.d_ff = d_ff - self.n_heads = num_heads - self.n_kv_heads = num_kv_heads - self.attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True) - self.pre_attn_norm = RMSNorm(d_model) - self.pre_ffn_norm = RMSNorm(d_model) - - if moe_config is not None: - self.ffn = MoEBlock(moe_config) - self.is_moe = True - else: - self.ffn = Expert(d_model, d_ff) - self.is_moe = False - - def forward(self, x): - x_norm = self.pre_attn_norm(x) - attn_output, _ = self.attention(x_norm, x_norm, x_norm) - x = x + attn_output - - # Pre-MoE norm - x_norm = self.pre_ffn_norm(x) - moe_output = self.ffn(x_norm) - x = x + moe_output - - return x - - @property # TODO: AUX LOSS IN FORWARD - def aux_loss(self): - if self.is_moe and hasattr(self.ffn, "last_aux_loss"): - return self.ffn.last_aux_loss - return None - - -class MoEModel(nn.Module): - def __init__( - self, - vocab_size: int, - max_seq_len: int, - d_model: int, - n_heads: int, - n_kv_heads: int, - d_ff: int, - num_layers: int, - sample_key: str = "input_ids", - prediction_key: str = "logits", - attn_dropout: float = 0.0, - ffn_dropout: float = 0.0, - tie_embeddings: bool = True, - moe_every_n_layers: int = 1, - moe_num_experts: int = 8, - moe_top_k: int = 2, - moe_capacity_factor: float = 1.25, - moe_aux_loss_coef: float = 0.01, - moe_z_loss_coef: float = 0.0, - moe_router_noise_std: float = 0.0, - ): - super(MoEModel, self).__init__() - self.sample_key = sample_key - self.prediction_key = prediction_key - self.vocab_size = vocab_size - self.max_seq_len = max_seq_len - self.d_model = d_model - self.n_heads = n_heads - self.n_kv_heads = n_kv_heads - self.num_layers = num_layers - self.d_ff = d_ff - self.attn_dropout = attn_dropout - self.ffn_dropout = ffn_dropout - self.tie_embeddings = tie_embeddings - self.moe_every_n_layers = moe_every_n_layers - self.moe_num_experts = moe_num_experts - self.moe_top_k = moe_top_k - self.moe_capacity_factor = moe_capacity_factor - self.moe_aux_loss_coef = moe_aux_loss_coef - self.moe_z_loss_coef = moe_z_loss_coef - self.moe_router_noise_std = moe_router_noise_std - - self.token_emb = nn.Embedding(self.vocab_size, self.d_model) - self.pos_emb = nn.Embedding(self.max_seq_len, self.d_model) - - moe_config = MoEArguments( - d_model=self.d_model, - d_ff=self.d_ff, - num_experts=self.moe_num_experts, - top_k=self.moe_top_k, - capacity_factor=self.moe_capacity_factor, - aux_loss_coef=self.moe_aux_loss_coef, - z_loss_coef=self.moe_z_loss_coef, - router_noise_std=self.moe_router_noise_std, - dropout=self.ffn_dropout, - ) - - self.layers = nn.ModuleDict() - for i in range(self.num_layers): - if i % self.moe_every_n_layers == 0: - block = TransformerBlock(self.d_model, self.d_ff, self.n_heads, self.n_kv_heads, moe_config) - else: - block = TransformerBlock(self.d_model, self.d_ff, self.n_heads, self.n_kv_heads, None) # No MoE - self.layers[str(i)] = block - self.final_norm = RMSNorm(self.d_model) - self.lm_head = nn.Linear(self.d_model, self.vocab_size, bias=False) - if self.tie_embeddings: - self.lm_head.weight = self.token_emb.weight - - @property - def weight_decay_groups(self): - return { - "linear": ["attention", "router", "w1", "w2", "w3", "b1", "b2", "b3", "lm_head"], - "embedding": ["token_emb", "pos_emb"], - "layernorm": ["pre_attn_norm", "pre_ffn_norm", "final_norm"], - } - - @overload - def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - """ - Forward pass of the MoE module. - - Args: - inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. - - sample_key (str): Key for the input tensor containing token ids. - - Returns: - dict[str, torch.Tensor]: A dictionary containing output tensors. - - prediction_key (str): Key for the output tensor containing logits. - """ - ... - - @overload - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the module. - - Args: - inputs (torch.Tensor): A tensor containing input token ids. - - Returns: - torch.Tensor: A tensor containing output logits. - """ - ... - - def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor: - """ - Forward pass of the module. - - Args: - inputs (dict[str, torch.Tensor] | torch.Tensor): Input data. - - Returns: - dict[str, torch.Tensor] | torch.Tensor: Model output. - """ - if isinstance(inputs, dict): - return {self.prediction_key: self.forward_impl(inputs[self.sample_key])} - else: - return self.forward_impl(inputs) - - def forward_impl(self, input_ids: torch.Tensor) -> torch.Tensor: - B, T = input_ids.size() - assert T <= self.max_seq_len, f"Sequence length {T} exceeds model's max_seq_len {self.max_seq_len}" - device = input_ids.device - - # Token and position embeddings - token_embeddings = self.token_emb(input_ids) # (B, T, D) - positions = torch.arange(T, device=device).unsqueeze(0).expand(B, T) - pos_embeddings = self.pos_emb(positions) # (B, T, D) - x = token_embeddings + pos_embeddings # (B, T, D) - - # Transformer blocks - for i, layer in enumerate(self.layers.values()): - x = layer(x) - - x = self.final_norm(x) - logits = self.lm_head(x) # (B, T, vocab_size) - - return logits - - -if __name__ == "__main__": # sanity test - torch.manual_seed(0) - - model = MoEModel( - vocab_size=32064, - max_seq_len=32768, - d_model=4096, - n_heads=32, - n_kv_heads=8, - num_layers=32, - d_ff=14336, - moe_every_n_layers=1, - moe_num_experts=8, - moe_top_k=2, - ) - - # Print number of trainable parameters - num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f"Number of trainable parameters: {num_params:,}") - - x = torch.randint(0, model.vocab_size, (2, 64)) - logits = model(x) - - print("logits:", logits.shape) - loss = logits.mean() - loss.backward() - print("backward OK") diff --git a/moe/modalities_moe/optimizers/__init__.py b/moe/modalities_moe/optimizers/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/moe/modalities_moe/training/__init__.py b/moe/modalities_moe/training/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/moe/modalities_moe/training/gradient_clipping/__init__.py b/moe/modalities_moe/training/gradient_clipping/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/moe/scripts/train_ep.py b/moe/scripts/train_ep.py index 232ef287f..7c99eee03 100644 --- a/moe/scripts/train_ep.py +++ b/moe/scripts/train_ep.py @@ -1,27 +1,17 @@ # ruff: noqa: E402 import os -import sys from pathlib import Path - -MOE_ROOT = Path(__file__).resolve().parent.parent -sys.path.insert(0, str(MOE_ROOT)) +from typing import cast import torch import torch.distributed as dist -from modalities_moe.config.config import EPWrappedModelConfig, MoECrossEntropyLossConfig -from modalities_moe.loss_functions import MoECrossEntropyLoss -from modalities_moe.models.model_factory import get_ep_wrapped_model -from modalities_moe.models.moe.qwen_model import QwenModel, QwenModelConfig -from modalities_moe.optimizers.ep_adamw import EPAdamWConfig, get_ep_adam_w -from modalities_moe.training.gradient_clipping.ep_gradient_clipper import EPGradientClipper from torch.distributed.tensor import DTensor from modalities.__main__ import Main from modalities.config.config import ProcessGroupBackendType from modalities.config.instantiation_models import TrainingComponentsInstantiationModel from modalities.running_env.cuda_env import CudaEnv -from modalities.training.gradient_clipping.fsdp_gradient_clipper_config import FSDP2GradientClipperConfig cwd = Path(__file__).resolve().parent.parent os.chdir(cwd) @@ -102,8 +92,8 @@ def _generate_permute_indices_no_triton( kernels.generate_permute_indices = _generate_permute_indices_no_triton moe_utils.generate_permute_indices = _generate_permute_indices_no_triton - kernels._modalities_fallback_enabled = True - kernels._modalities_generate_permute_indices_original = _orig_generate_permute_indices + setattr(kernels, "_modalities_fallback_enabled", True) + setattr(kernels, "_modalities_generate_permute_indices_original", _orig_generate_permute_indices) def debug_ep(model): @@ -132,40 +122,10 @@ def main(): config_path=CONFIG_FILE_PATH, experiments_root_path=EXPERIMENTS_ROOT_PATH, ) - modalities_main.add_custom_component( - component_key="model", - variant_key="ep_wrapped", - custom_component=get_ep_wrapped_model, - custom_config=EPWrappedModelConfig, - ) - - modalities_main.add_custom_component( - component_key="model", variant_key="moe", custom_component=QwenModel, custom_config=QwenModelConfig - ) - - modalities_main.add_custom_component( - component_key="gradient_clipper", - variant_key="ep", - custom_component=EPGradientClipper, - custom_config=FSDP2GradientClipperConfig, - ) - - modalities_main.add_custom_component( - component_key="loss", - variant_key="moe_cross_entropy", - custom_component=MoECrossEntropyLoss, - custom_config=MoECrossEntropyLossConfig, - ) - - modalities_main.add_custom_component( - component_key="optimizer", - variant_key="ep_adam_w", - custom_component=get_ep_adam_w, - custom_config=EPAdamWConfig, - ) - components: TrainingComponentsInstantiationModel = modalities_main.build_components( - components_model_type=TrainingComponentsInstantiationModel + components = cast( + TrainingComponentsInstantiationModel, + modalities_main.build_components(components_model_type=TrainingComponentsInstantiationModel), ) # WORKAROUNDS (wip) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 42a19b99a..af280a0a4 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -83,6 +83,16 @@ class CLMCrossEntropyLossConfig(BaseModel): prediction_key: str +class MoECrossEntropyLossConfig(BaseModel): + target_key: str + prediction_key: str + model: Any + tag: str = "MoECrossEntropyLoss" + + class Config: + arbitrary_types_allowed = True + + # Checkpointing class SaveEveryKStepsCheckpointingStrategyConfig(BaseModel): k: PositiveInt @@ -167,6 +177,19 @@ class AdamWOptimizerConfig(BaseModel): fused: bool | None = None +class EPAdamWConfig(BaseModel): + wrapped_model: PydanticPytorchModuleOrListType + device_mesh: PydanticDeviceMeshIFType + lr: float + betas: tuple[float, float] + eps: float + weight_decay: float + weight_decay_groups_excluded: list[str] + + class Config: + arbitrary_types_allowed = True + + class DummyLRSchedulerConfig(BaseModel): optimizer: PydanticOptimizerIFType @@ -311,6 +334,13 @@ def validate_dp_mesh_existence(self): return self +class EPWrappedModelConfig(BaseModel): + model: PydanticPytorchModuleOrListType + block_names: list[str] + device_mesh: PydanticDeviceMeshIFType + ep_mesh_dim_name: str | None = None + + class DebuggingEnrichedModelConfig(BaseModel): model: PydanticPytorchModuleOrListType logging_dir_path: Path diff --git a/src/modalities/models/moe/__init__.py b/src/modalities/models/moe/__init__.py new file mode 100644 index 000000000..5e55327a1 --- /dev/null +++ b/src/modalities/models/moe/__init__.py @@ -0,0 +1,10 @@ +from modalities.models.moe.loss_functions import MoECrossEntropyLoss +from modalities.models.moe.model_factory import get_ep_wrapped_model +from modalities.models.moe.qwen_model import QwenModel, QwenModelConfig + +__all__ = [ + "MoECrossEntropyLoss", + "QwenModel", + "QwenModelConfig", + "get_ep_wrapped_model", +] diff --git a/moe/modalities_moe/loss_functions.py b/src/modalities/models/moe/loss_functions.py similarity index 92% rename from moe/modalities_moe/loss_functions.py rename to src/modalities/models/moe/loss_functions.py index 654677efb..642efb47a 100644 --- a/moe/modalities_moe/loss_functions.py +++ b/src/modalities/models/moe/loss_functions.py @@ -6,7 +6,7 @@ class MoECrossEntropyLoss(Loss): - """Cross Entropy Loss with auxiliary loss support for router balancing""" + """Cross entropy loss with optional MoE auxiliary losses from model layers.""" def __init__( self, @@ -31,7 +31,6 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: labels.contiguous().long().view(-1), ) - # Aux loss for layer in self.model.layers.values(): if hasattr(layer, "aux_loss") and layer.aux_loss is not None: loss = loss + layer.aux_loss.to(loss.dtype) diff --git a/moe/modalities_moe/models/model_factory.py b/src/modalities/models/moe/model_factory.py similarity index 86% rename from moe/modalities_moe/models/model_factory.py rename to src/modalities/models/moe/model_factory.py index 65dbe8e3f..406da1964 100644 --- a/moe/modalities_moe/models/model_factory.py +++ b/src/modalities/models/moe/model_factory.py @@ -10,10 +10,7 @@ from modalities.util import get_module_class_from_name -# TODO refactor these funtions into a utils -def _resolve_ep_mesh( - device_mesh: DeviceMesh, ep_mesh_dim_name: str | None -) -> DeviceMesh: # devicemesh not supporting EP +def _resolve_ep_mesh(device_mesh: DeviceMesh, ep_mesh_dim_name: str | None) -> DeviceMesh: mesh_dim_names = tuple(device_mesh.mesh_dim_names or ()) if ep_mesh_dim_name is not None: @@ -72,15 +69,6 @@ def _apply_torchtitan_ep(module, ep_mesh) -> None: setattr(module.experts, "_ep_enabled", True) -def debug_forward_hook(module, input): - for name, param in module.named_parameters(recurse=False): - if hasattr(param, "_local_tensor"): - # still dTensor - print(f"[EP forward] {name}: still DTensor, local={param._local_tensor.shape}") - else: - print(f"[EP forward] {name}: plain tensor shape={param.shape}") - - def get_ep_wrapped_model( model, block_names: list[str], @@ -89,7 +77,6 @@ def get_ep_wrapped_model( mp_param_dtype=torch.bfloat16, mp_reduce_dtype=torch.bfloat16, ) -> nn.Module: - # Warn for unresolved names, but still wrap any block types that can be resolved. block_types = [] missing_block_names = [] for name in block_names: @@ -111,7 +98,6 @@ def get_ep_wrapped_model( raise ValueError(f"None of the requested MoE block names were found: {block_names}") ep_mesh = _resolve_ep_mesh(device_mesh, ep_mesh_dim_name) - device_mesh["dp_shard"] MixedPrecisionPolicy(param_dtype=mp_param_dtype, reduce_dtype=mp_reduce_dtype) wrapped_blocks = 0 diff --git a/moe/modalities_moe/models/moe/qwen_model.py b/src/modalities/models/moe/qwen_model.py similarity index 92% rename from moe/modalities_moe/models/moe/qwen_model.py rename to src/modalities/models/moe/qwen_model.py index 3a5ec2d61..47a810722 100644 --- a/moe/modalities_moe/models/moe/qwen_model.py +++ b/src/modalities/models/moe/qwen_model.py @@ -1,11 +1,13 @@ import math -from typing import Literal, Optional +from typing import Literal, Optional, overload import torch import torch.nn as nn import torch.nn.functional as F from pydantic import BaseModel +from modalities.models.model import NNModel + try: from torch.distributed.tensor import DTensor except Exception: @@ -13,7 +15,6 @@ class QwenModelConfig(BaseModel): - # Model vocab_size: int max_seq_len: int d_model: int @@ -311,18 +312,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out.view(B, T, D) -class DenseMLP(nn.Module): - def __init__(self, d_model, d_ff, ffn_dropout): - super().__init__() - self.w1 = nn.Linear(d_model, d_ff, bias=False) - self.w2 = nn.Linear(d_model, d_ff, bias=False) - self.w3 = nn.Linear(d_ff, d_model, bias=False) - self.dropout = nn.Dropout(ffn_dropout) if ffn_dropout > 0 else nn.Identity() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.dropout(self.w3(F.silu(self.w1(x)) * self.w2(x))) - - class TransformerBlock(nn.Module): def __init__( self, @@ -385,7 +374,7 @@ def aux_loss(self) -> Optional[torch.Tensor]: return getattr(self.ffn, "last_aux_loss", None) -class QwenModel(nn.Module): +class QwenModel(NNModel): def __init__( self, vocab_size: int, @@ -414,7 +403,12 @@ def __init__( moe_aux_loss_coef: float = 0.001, moe_z_loss_coef: float = 0.0, ): - super().__init__() + weight_decay_groups = { + "linear": ["q_proj", "k_proj", "v_proj", "o_proj", "lm_head", "router", "w1", "w2", "w3"], + "embedding": ["token_emb"], + "layernorm": ["pre_attn_norm", "pre_ffn_norm", "final_norm", "q_norm", "k_norm"], + } + super().__init__(weight_decay_groups=weight_decay_groups) self.sample_key = sample_key self.prediction_key = prediction_key @@ -454,15 +448,15 @@ def __init__( if tie_embeddings: self.lm_head.weight = self.token_emb.weight - @property - def weight_decay_groups(self): - return { - "linear": ["q_proj", "k_proj", "v_proj", "o_proj", "lm_head", "router", "w1", "w2", "w3"], - "embedding": ["token_emb"], - "layernorm": ["pre_attn_norm", "pre_ffn_norm", "final_norm", "q_norm", "k_norm"], - } + @overload + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + ... + + @overload + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + ... - def forward(self, inputs): + def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor: if isinstance(inputs, dict): return {self.prediction_key: self.forward_impl(inputs[self.sample_key])} return self.forward_impl(inputs) @@ -472,30 +466,3 @@ def forward_impl(self, input_ids: torch.Tensor) -> torch.Tensor: for layer in self.layers.values(): x = layer(x) return self.lm_head(self.final_norm(x)) - - -if __name__ == "__main__": - torch.manual_seed(0) - - model = QwenModel( - vocab_size=151936, - max_seq_len=4096, - d_model=2048, - n_heads=32, - n_kv_heads=8, - d_ff=6144, - moe_d_ff=768, - num_layers=48, - moe_num_experts=128, - moe_top_k=8, - ) - num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f"Parametri: {num_params/1e9:.2f}B") - - x = torch.randint(0, 151936, (2, 64)) - logits = model(x) - print(f"Output: {logits.shape}") - - loss = logits.mean() - loss.backward() - print("Backward OK") diff --git a/moe/modalities_moe/optimizers/ep_adamw.py b/src/modalities/optimizers/ep_adamw.py similarity index 80% rename from moe/modalities_moe/optimizers/ep_adamw.py rename to src/modalities/optimizers/ep_adamw.py index d7d19fe9c..2b5e72aae 100644 --- a/moe/modalities_moe/optimizers/ep_adamw.py +++ b/src/modalities/optimizers/ep_adamw.py @@ -1,27 +1,12 @@ import torch import torch.distributed as dist -from pydantic import BaseModel from torch.distributed.tensor import DTensor from torch.nn import Module from torch.optim import AdamW, Optimizer -from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticPytorchModuleOrListType from modalities.optimizers.optimizer_factory import _build_optimizer_groups_via_weight_decay_split -class EPAdamWConfig(BaseModel): - wrapped_model: PydanticPytorchModuleOrListType - device_mesh: PydanticDeviceMeshIFType - lr: float - betas: tuple[float, float] - eps: float - weight_decay: float - weight_decay_groups_excluded: list[str] - - class Config: - arbitrary_types_allowed = True - - def _get_ep_param_ids(model: Module) -> set: return {id(p) for m in model.modules() if getattr(m, "_ep_enabled", False) for p in m.parameters(recurse=False)} @@ -39,14 +24,6 @@ def _get_dense_optimizer_groups(model, ep_param_ids, weight_decay, weight_decay_ class EPAdamW(Optimizer): - """ - ZeRO stage-1 for EP (DTensor) params + standard AdamW for dense params. - - Each dp_shard rank stores optimizer states for 1/dp_shard of the EP params. - After each step, updated EP param values are broadcast from owner to all ranks. - Dense params are handled by a separate AdamW (FSDP2 shards them independently). - """ - def __init__( self, model: Module, @@ -65,7 +42,6 @@ def __init__( ep_param_ids = _get_ep_param_ids(model) self._all_ep_params = [p for p in model.parameters() if id(p) in ep_param_ids] - # rank r owns params[r::dp_size] self._owned_ep_params = self._all_ep_params[self._dp_rank :: self._dp_size] dense_groups = _get_dense_optimizer_groups(model, ep_param_ids, weight_decay, weight_decay_groups_excluded) @@ -76,8 +52,6 @@ def __init__( self._ep_adamw = None self._dense_adamw = AdamW(dense_groups, lr=lr, betas=betas, eps=eps) - # unified param groups for lr_scheduler compatibility: - # group 0 = all EP params, groups 1+ = dense weight-decay split ep_group = {"params": self._all_ep_params, "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay} all_groups = [ep_group] + [{**g, "lr": lr, "betas": betas, "eps": eps} for g in dense_groups] super().__init__(all_groups, {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay}) @@ -89,7 +63,6 @@ def step(self, closure=None): with torch.enable_grad(): loss = closure() - # all-reduce for p in self._all_ep_params: if p.grad is None: continue @@ -101,20 +74,16 @@ def step(self, closure=None): dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=self._dp_group) p.grad.div_(self._dp_size) - # Sync lr if self._ep_adamw is not None: self._ep_adamw.param_groups[0]["lr"] = self.param_groups[0]["lr"] - for i, g in enumerate(self._dense_adamw.param_groups): - g["lr"] = self.param_groups[i + 1]["lr"] + for i, group in enumerate(self._dense_adamw.param_groups): + group["lr"] = self.param_groups[i + 1]["lr"] - # Update ep params if self._ep_adamw is not None: self._ep_adamw.step() - # Update dense params self._dense_adamw.step() - # broadcast updated EP param local tensors for i, p in enumerate(self._all_ep_params): owner_local_rank = i % self._dp_size owner_global_rank = dist.get_global_rank(self._dp_group, owner_local_rank) diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 26df9b432..71eb2c8ad 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -36,6 +36,8 @@ DummyLRSchedulerConfig, DummyProgressSubscriberConfig, DummyResultSubscriberConfig, + EPAdamWConfig, + EPWrappedModelConfig, EvaluationResultToDiscSubscriberConfig, FSDP1ActivationCheckpointedModelConfig, FSDP1CheckpointedModelConfig, @@ -51,6 +53,7 @@ LinearWarmupCosineAnnealingLRSchedulerConfig, LLMDataLoaderConfig, MemMapDatasetConfig, + MoECrossEntropyLossConfig, OneCycleLRSchedulerConfig, PackedMemMapDatasetContinuousConfig, PackedMemMapDatasetMegatronConfig, @@ -96,6 +99,9 @@ from modalities.models.gpt2.llama3_like_initialization import Llama3Initializer, Llama3InitializerConfig from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig from modalities.models.model_factory import GPT2ModelFactory, ModelFactory +from modalities.models.moe.loss_functions import MoECrossEntropyLoss +from modalities.models.moe.model_factory import get_ep_wrapped_model +from modalities.models.moe.qwen_model import QwenModel, QwenModelConfig from modalities.models.parallelism.pipeline_parallelism import ComponentSelectorFromPipeline, PipelineFactory from modalities.models.parallelism.pipeline_parallelism_configs import ( ComponentSelectorFromPipelineConfig, @@ -109,12 +115,14 @@ ComposedInitializationRoutines, ComposedModelInitializationConfig, ) +from modalities.optimizers.ep_adamw import get_ep_adam_w from modalities.optimizers.lr_schedulers import DummyLRScheduler, LRSchedulerFactory from modalities.optimizers.optimizer_factory import OptimizerFactory from modalities.optimizers.optimizer_list import OptimizersList from modalities.optimizers.scheduler_list import SchedulerList from modalities.running_env.fsdp.device_mesh import DeviceMeshConfig, get_device_mesh, get_parallel_degree from modalities.tokenization.tokenizer_wrapper import PreTrainedHFTokenizer, PreTrainedSPTokenizer +from modalities.training.gradient_clipping.ep_gradient_clipper import EPGradientClipper from modalities.training.gradient_clipping.fsdp_gradient_clipper import ( FSDP1GradientClipper, FSDP1LoggingOnlyGradientClipper, @@ -187,6 +195,8 @@ class ComponentEntity: COMPONENTS = [ # models ComponentEntity("model", "gpt2", GPT2ModelFactory.get_gpt2_model, GPT2LLMConfig), + ComponentEntity("model", "moe", QwenModel, QwenModelConfig), + ComponentEntity("model", "ep_wrapped", get_ep_wrapped_model, EPWrappedModelConfig), ComponentEntity( "model", "gpt2_tp", maybe_model_list(GPT2ModelFactory.get_gpt2_tensor_parallelized_model), GPT2ModelTPConfig ), @@ -250,6 +260,7 @@ class ComponentEntity: ), # losses ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig), + ComponentEntity("loss", "moe_cross_entropy", MoECrossEntropyLoss, MoECrossEntropyLossConfig), # optimizers ComponentEntity( "optimizer", "adam", maybe_model_list_for_optimizer(OptimizerFactory.get_adam), AdamOptimizerConfig @@ -257,6 +268,7 @@ class ComponentEntity: ComponentEntity( "optimizer", "adam_w", maybe_model_list_for_optimizer(OptimizerFactory.get_adam_w), AdamWOptimizerConfig ), + ComponentEntity("optimizer", "ep_adam_w", maybe_model_list_for_optimizer(get_ep_adam_w), EPAdamWConfig), ComponentEntity( "optimizer", "fsdp1_checkpointed", @@ -402,6 +414,7 @@ class ComponentEntity: "gradient_clipper", "fsdp1_logging_only", FSDP1LoggingOnlyGradientClipper, FSDP1DummyGradientClipperConfig ), ComponentEntity("gradient_clipper", "fsdp2", FSDP2GradientClipper, FSDP2GradientClipperConfig), + ComponentEntity("gradient_clipper", "ep", EPGradientClipper, FSDP2GradientClipperConfig), ComponentEntity( "gradient_clipper", "fsdp2_logging_only", FSDP2LoggingOnlyGradientClipper, FSDP2DummyGradientClipperConfig ), diff --git a/moe/modalities_moe/training/gradient_clipping/ep_gradient_clipper.py b/src/modalities/training/gradient_clipping/ep_gradient_clipper.py similarity index 91% rename from moe/modalities_moe/training/gradient_clipping/ep_gradient_clipper.py rename to src/modalities/training/gradient_clipping/ep_gradient_clipper.py index 0581e3634..a2b6f25b8 100644 --- a/moe/modalities_moe/training/gradient_clipping/ep_gradient_clipper.py +++ b/src/modalities/training/gradient_clipping/ep_gradient_clipper.py @@ -16,7 +16,7 @@ class EPGradientClipper(FSDP2GradientClipper): - """FSDP2 clipper wrapper for EP adaptation""" + """FSDP2 clipper wrapper that handles EP DTensor gradients safely.""" def __init__( self, @@ -54,7 +54,6 @@ def clip_gradients(self) -> torch.Tensor: for grad in grads: grad_norm = torch.linalg.vector_norm(grad, ord=norm_type_val) if isinstance(grad_norm, DTensor): - # Reduce each partial norm inside its own mesh before aggregation. grad_norm = grad_norm.full_tensor() norm_scalars.append(grad_norm.to(first_device)) @@ -79,8 +78,6 @@ def clip_gradients(self) -> torch.Tensor: dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) total_norm **= 1.0 / self.norm_type.value - # do not use torch.nn.utils.clip_grads_with_norm_ here: it batches grads with - # torch._foreach_mul_, which fails when the list mixes DTensors from different meshes. clip_coef = self.max_norm / (total_norm + 1e-6) clip_coef_clamped = torch.clamp(clip_coef, max=1.0) From b70ec51df1aad54c15aaedd6777f2b3f1156f448 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Wed, 10 Jun 2026 11:12:08 +0000 Subject: [PATCH 2/7] fix(moe): Fix dtype and state_dict mismatch --- .../checkpointing/stateful/app_state.py | 16 ++++++++++++++++ src/modalities/models/moe/qwen_model.py | 6 ++++++ 2 files changed, 22 insertions(+) diff --git a/src/modalities/checkpointing/stateful/app_state.py b/src/modalities/checkpointing/stateful/app_state.py index 2da3ab236..25c1efc91 100644 --- a/src/modalities/checkpointing/stateful/app_state.py +++ b/src/modalities/checkpointing/stateful/app_state.py @@ -184,6 +184,16 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None: class OptimizerStateRetriever(StateRetrieverIF): + @staticmethod + def _uses_standard_optimizer_state_dict(app_state: AppState) -> bool: + """Checks whether the optimizer state dict follows the standard torch Optimizer schema. + + Standard optimizer state dicts contain top-level "state" and "param_groups" keys, + which are required by distributed optimizer checkpoint utilities. + """ + state_dict = app_state.optimizer.state_dict() + return isinstance(state_dict, dict) and "state" in state_dict and "param_groups" in state_dict + @staticmethod def get_state_dict(app_state: AppState) -> dict[str, Any]: """Returns the state dict of the optimizer in the AppState object. @@ -196,6 +206,10 @@ def get_state_dict(app_state: AppState) -> dict[str, Any]: """ if isinstance(app_state.optimizer, OptimizersList): sd = app_state.optimizer.state_dict() + elif not OptimizerStateRetriever._uses_standard_optimizer_state_dict(app_state): + # Custom optimizers (e.g. EP wrappers) may not expose the standard torch + # optimizer format expected by get_optimizer_state_dict. + sd = app_state.optimizer.state_dict() else: assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer." sd = get_optimizer_state_dict( @@ -217,6 +231,8 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None: """ if isinstance(app_state.optimizer, OptimizersList): app_state.optimizer.load_state_dict(state_dict) + elif not OptimizerStateRetriever._uses_standard_optimizer_state_dict(app_state): + app_state.optimizer.load_state_dict(state_dict) else: assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer." set_optimizer_state_dict( diff --git a/src/modalities/models/moe/qwen_model.py b/src/modalities/models/moe/qwen_model.py index 47a810722..ac4cab752 100644 --- a/src/modalities/models/moe/qwen_model.py +++ b/src/modalities/models/moe/qwen_model.py @@ -166,6 +166,12 @@ def _forward_local(self, routed_input: torch.Tensor, num_tokens_per_expert: torc w1 = self.w1.to_local() if DTensor is not None and isinstance(self.w1, DTensor) else self.w1 w2 = self.w2.to_local() if DTensor is not None and isinstance(self.w2, DTensor) else self.w2 w3 = self.w3.to_local() if DTensor is not None and isinstance(self.w3, DTensor) else self.w3 + # F.linear requires matching dtypes between inputs and weights. Under mixed precision, + # routed_input can be BF16 while local expert weights remain FP32. + if routed_input.dtype != w1.dtype: + w1 = w1.to(dtype=routed_input.dtype) + w2 = w2.to(dtype=routed_input.dtype) + w3 = w3.to(dtype=routed_input.dtype) local_num_tokens = ( num_tokens_per_expert.to_local() if DTensor is not None and isinstance(num_tokens_per_expert, DTensor) From e5edeeb715c644368ba0096a8985d7a2f3ad219f Mon Sep 17 00:00:00 2001 From: rrutmann Date: Wed, 10 Jun 2026 11:15:46 +0000 Subject: [PATCH 3/7] chore: Remove outdated files --- moe/config/moe_ep_config.yaml | 357 ---------------------- moe/config/qwen_config.yaml | 365 ----------------------- moe/config/tokenization_config.yaml | 18 -- moe/scripts/train_ep.py | 155 ---------- {moe/scripts => scripts}/monitor_gpus.sh | 0 5 files changed, 895 deletions(-) delete mode 100644 moe/config/moe_ep_config.yaml delete mode 100644 moe/config/qwen_config.yaml delete mode 100644 moe/config/tokenization_config.yaml delete mode 100644 moe/scripts/train_ep.py rename {moe/scripts => scripts}/monitor_gpus.sh (100%) diff --git a/moe/config/moe_ep_config.yaml b/moe/config/moe_ep_config.yaml deleted file mode 100644 index 883e0dffb..000000000 --- a/moe/config/moe_ep_config.yaml +++ /dev/null @@ -1,357 +0,0 @@ -settings: - experiment_id: ${modalities_env:experiment_id} - config_file_path: ${modalities_env:config_file_path} - referencing_keys: - sample_key: input_ids - target_key: target_ids - prediction_key: logits - cuda_env: - local_rank: ${cuda_env:LOCAL_RANK} - global_rank: ${cuda_env:RANK} - world_size: ${cuda_env:WORLD_SIZE} - paths: - experiments_root_path: /leonardo/home/userexternal/gesposit/projects/modalities/moe/experiments - experiment_folder_path: ${settings.paths.experiments_root_path}/${settings.experiment_id} - checkpoint_saving_path: /leonardo_scratch/large/userexternal/gesposit/modalities/checkpoints - train_dataset_path: /leonardo_scratch/large/userexternal/gesposit/modalities/data/processed/fineweb_edu_num_docs_483606.pbin - intervals: - training_log_interval_in_steps: 1 - checkpointing_interval_in_steps: 1001 - evaluation_interval_in_steps: 1001 - consistency_enforcement: - enforce_tokens_per_step_consistency: true - enforce_last_step_logged: false - enforce_last_step_evaluated: false - enforce_last_step_checkpointed: false - step_profile: - gradient_accumulation_steps: 4 - local_train_micro_batch_size: 1 - sequence_length: 512 - dp_degree: - instance_key: dp_degree - pass_type: BY_REFERENCE - training_target: - num_target_tokens: - component_key: number_conversion - variant_key: num_tokens_from_num_steps - config: - num_steps: ${settings.training_target.num_target_steps} - dp_degree: - instance_key: dp_degree - pass_type: BY_REFERENCE - local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} - sequence_length: ${settings.step_profile.sequence_length} - gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} - num_target_steps: 10 - training_progress: - global_num_seen_tokens: 0 - num_seen_steps: 0 - num_seen_samples: 0 - last_step: -1 - -collate_fn: - component_key: collate_fn - variant_key: gpt_2_llm_collator - config: - sample_key: ${settings.referencing_keys.sample_key} - target_key: ${settings.referencing_keys.target_key} - -train_dataset: - component_key: dataset - variant_key: packed_mem_map_dataset_continuous - config: - raw_data_path: ${settings.paths.train_dataset_path} - sequence_length: ${settings.step_profile.sequence_length} - sample_key: ${settings.referencing_keys.sample_key} - -train_dataloader: - component_key: data_loader - variant_key: default - config: - # we set num_workers to 0 so that the the data is loaded in the main process - # this is required to track how often the collator has been called - # in the library tutorials. Otherwise the collator will be copied for each worker - # and the number of call is out of scope. - num_workers: 0 - pin_memory: true - dataloader_tag: train - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: ${settings.step_profile.local_train_micro_batch_size} - drop_last: true - sampler: - component_key: sampler - variant_key: resumable_distributed_sampler - config: - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - shuffle: true - seed: 42 - drop_last: true - skip_num_global_samples: ${settings.training_progress.num_seen_samples} - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -eval_dataloaders: [] - -checkpoint_saving: - component_key: checkpoint_saving - variant_key: default - config: - checkpoint_saving_strategy: - component_key: checkpoint_saving_strategy - variant_key: save_k_most_recent_checkpoints_strategy - config: - k: -1 # -1 to save all checkpoints - checkpoint_saving_execution: - component_key: checkpoint_saving_execution - variant_key: dcp - config: - checkpoint_path: ${settings.paths.experiment_folder_path} - global_rank: ${settings.cuda_env.global_rank} - experiment_id: ${settings.experiment_id} - -loss_fn: - component_key: loss - variant_key: moe_cross_entropy - config: - target_key: ${settings.referencing_keys.target_key} - prediction_key: ${settings.referencing_keys.prediction_key} - model: - instance_key: model_raw - pass_type: BY_REFERENCE - -device_mesh: - component_key: device_mesh - variant_key: default - config: - device_type: cuda - data_parallel_replicate_degree: 1 - # Keep FSDP sharding on dp_shard and reserve tp for expert parallel. - data_parallel_shard_degree: -1 - tensor_parallel_degree: 32 - world_size: ${settings.cuda_env.world_size} - -dp_degree: - component_key: number_conversion - variant_key: parallel_degree - config: # get the parallel degree from the device mesh - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - parallelism_methods: [dp_shard, dp_replicate] - -app_state: - component_key: app_state - variant_key: raw - config: - model: - instance_key: initialized_model - pass_type: BY_REFERENCE - optimizer: - instance_key: optimizer - pass_type: BY_REFERENCE - lr_scheduler: - instance_key: lr_scheduler - pass_type: BY_REFERENCE - -initialized_model: - component_key: model - variant_key: model_initialized - config: - model: - instance_key: fsdp_model - pass_type: BY_REFERENCE - model_initializer: - component_key: model_initialization - variant_key: composed - config: - model_type: gpt2 - weight_init_type: scaled - mean: 0.0 - std: 0.02 - num_layers: ${model_raw.config.num_layers} - -ep_model: - component_key: model - variant_key: ep_wrapped - config: - model: - instance_key: model_raw # Bypass torch.compile - MoE routing is incompatible - pass_type: BY_REFERENCE - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - ep_mesh_dim_name: tp - block_names: [TransformerBlock] - -ac_model: - component_key: model - variant_key: activation_checkpointed # using modalities fsdp2 ac. should do to job also for ep layers - config: - model: - instance_key: ep_model - pass_type: BY_REFERENCE - ac_variant: full_activation_checkpointing - layers_fqn: layers - ac_fun_params: - ac_freq: 1 - -fsdp_model: - component_key: model - variant_key: fsdp2_wrapped - config: - model: - instance_key: ac_model - pass_type: BY_REFERENCE - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - mixed_precision_settings: - param_dtype: BF_16 - reduce_dtype: BF_16 - reshard_after_forward: true - block_names: [TransformerBlock] - -compiled_model: - component_key: model - variant_key: compiled - config: - model: - instance_key: model_raw - pass_type: BY_REFERENCE - block_names: [TransformerBlock] - -model_raw: - component_key: model - variant_key: moe - config: - vocab_size: 32064 # to match a pretrained tokenizer - max_seq_len: 4096 - d_model: 4096 - n_heads: 32 - n_kv_heads: 8 - num_layers: 32 - d_ff: 14336 - moe_every_n_layers: 1 - moe_num_experts: 16 - moe_top_k: 2 - -lr_scheduler: - component_key: scheduler - variant_key: onecycle_lr - config: - optimizer: - instance_key: optimizer - pass_type: BY_REFERENCE - max_lr: 6e-4 - div_factor: 10 - final_div_factor: 1 - total_steps: ${settings.training_target.num_target_steps} - pct_start: 0.02 - anneal_strategy: cos - last_epoch: ${settings.training_progress.last_step} - -optimizer: - component_key: optimizer - variant_key: adam_w - config: - lr: 0.0001 - betas: [0.9, 0.95] - eps: 1e-8 - weight_decay: 1e-1 - weight_decay_groups_excluded: [embedding, layernorm] - wrapped_model: - instance_key: initialized_model - pass_type: BY_REFERENCE - -gradient_clipper: - component_key: gradient_clipper - variant_key: ep - config: - wrapped_model: - instance_key: initialized_model - pass_type: BY_REFERENCE - norm_type: P2_NORM - max_norm: 1.0 - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - -progress_subscriber: - component_key: progress_subscriber - variant_key: rich - config: - global_rank: ${settings.cuda_env.global_rank} - num_seen_steps: ${settings.training_progress.num_seen_steps} - num_target_steps: ${settings.training_target.num_target_steps} - train_dataloader_tag: ${train_dataloader.config.dataloader_tag} - eval_dataloaders: - instance_key: eval_dataloaders - pass_type: BY_REFERENCE - -evaluation_subscriber: - component_key: results_subscriber - variant_key: to_disc - config: - output_file_path: ${settings.paths.experiment_folder_path}/evaluation_results.jsonl - -mfu_calculator: - component_key: mfu_calculator - variant_key: gpt2 - config: - n_layer: ${model_raw.config.num_layers} - sequence_length: ${settings.step_profile.sequence_length} - n_embd: ${model_raw.config.d_model} - world_size: ${settings.cuda_env.world_size} - wrapped_model: - instance_key: initialized_model - pass_type: BY_REFERENCE - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - -# profiler: -# component_key: steppable_profiler -# variant_key: combined -# config: -# profilers: -# - instance_key: kernel_profiler -# pass_type: BY_REFERENCE -# # - instance_key: memory_profiler -# # pass_type: BY_REFERENCE - -kernel_profiler: - component_key: steppable_profiler - variant_key: kernel_tracing - config: - num_wait_steps: 1 - num_warmup_steps: 1 - num_active_steps: 3 - profiler_activities: [CUDA] - profile_memory: true - record_shapes: true - with_stack: true - with_flops: true - with_modules: true - tracked_ranks: [0] - output_folder_path: ${settings.paths.experiment_folder_path}/profiling - -memory_profiler: - component_key: steppable_profiler - variant_key: memory_tracing - config: - memory_snapshot_folder_path: ${settings.paths.experiment_folder_path}/profiling - num_wait_steps: 1 - num_warmup_steps: 1 - num_active_steps: 3 - tracked_ranks: [0] \ No newline at end of file diff --git a/moe/config/qwen_config.yaml b/moe/config/qwen_config.yaml deleted file mode 100644 index 46b233dec..000000000 --- a/moe/config/qwen_config.yaml +++ /dev/null @@ -1,365 +0,0 @@ -settings: - experiment_id: ${modalities_env:experiment_id} - config_file_path: ${modalities_env:config_file_path} - referencing_keys: - sample_key: input_ids - target_key: target_ids - prediction_key: logits - cuda_env: - local_rank: ${cuda_env:LOCAL_RANK} - global_rank: ${cuda_env:RANK} - world_size: ${cuda_env:WORLD_SIZE} - paths: - experiments_root_path: /raid/s3/opengptx/user/richard-rutmann/experiments/modalities/moe_fsdp2 - experiment_folder_path: ${settings.paths.experiments_root_path}/${settings.experiment_id} - checkpoint_saving_path: /raid/s3/opengptx/user/richard-rutmann/experiments/modalities/moe_fsdp2/checkpoints - train_dataset_path: /raid/s3/opengptx/user/richard-rutmann/data/modalities/gpt2_tokenized/000_00000.pbin - intervals: - training_log_interval_in_steps: 1 - checkpointing_interval_in_steps: 1001 - evaluation_interval_in_steps: 1001 - consistency_enforcement: - enforce_tokens_per_step_consistency: true - enforce_last_step_logged: false - enforce_last_step_evaluated: false - enforce_last_step_checkpointed: false - step_profile: - gradient_accumulation_steps: 4 - local_train_micro_batch_size: 2 - sequence_length: 4096 - dp_degree: - instance_key: dp_degree - pass_type: BY_REFERENCE - training_target: - num_target_tokens: - component_key: number_conversion - variant_key: num_tokens_from_num_steps - config: - num_steps: ${settings.training_target.num_target_steps} - dp_degree: - instance_key: dp_degree - pass_type: BY_REFERENCE - local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} - sequence_length: ${settings.step_profile.sequence_length} - gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} - num_target_steps: 10 - training_progress: - global_num_seen_tokens: 0 - num_seen_steps: 0 - num_seen_samples: 0 - last_step: -1 - -collate_fn: - component_key: collate_fn - variant_key: gpt_2_llm_collator - config: - sample_key: ${settings.referencing_keys.sample_key} - target_key: ${settings.referencing_keys.target_key} - -train_dataset: - component_key: dataset - variant_key: packed_mem_map_dataset_continuous - config: - raw_data_path: ${settings.paths.train_dataset_path} - sequence_length: ${settings.step_profile.sequence_length} - sample_key: ${settings.referencing_keys.sample_key} - -train_dataloader: - component_key: data_loader - variant_key: default - config: - # we set num_workers to 0 so that the the data is loaded in the main process - # this is required to track how often the collator has been called - # in the library tutorials. Otherwise the collator will be copied for each worker - # and the number of call is out of scope. - num_workers: 0 - pin_memory: true - dataloader_tag: train - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: ${settings.step_profile.local_train_micro_batch_size} - drop_last: true - sampler: - component_key: sampler - variant_key: resumable_distributed_sampler - config: - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - shuffle: true - seed: 42 - drop_last: true - skip_num_global_samples: ${settings.training_progress.num_seen_samples} - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -eval_dataloaders: [] - -checkpoint_saving: - component_key: checkpoint_saving - variant_key: default - config: - checkpoint_saving_strategy: - component_key: checkpoint_saving_strategy - variant_key: save_k_most_recent_checkpoints_strategy - config: - k: -1 # -1 to save all checkpoints - checkpoint_saving_execution: - component_key: checkpoint_saving_execution - variant_key: dcp - config: - checkpoint_path: ${settings.paths.experiment_folder_path} - global_rank: ${settings.cuda_env.global_rank} - experiment_id: ${settings.experiment_id} - -loss_fn: - component_key: loss - variant_key: moe_cross_entropy - config: - target_key: ${settings.referencing_keys.target_key} - prediction_key: ${settings.referencing_keys.prediction_key} - model: - instance_key: model_raw - pass_type: BY_REFERENCE - -device_mesh: - component_key: device_mesh - variant_key: default - config: - device_type: cuda - data_parallel_replicate_degree: 1 - # Keep FSDP sharding on dp_shard and reserve tp for expert parallel. - data_parallel_shard_degree: -1 - tensor_parallel_degree: 4 - world_size: ${settings.cuda_env.world_size} - -dp_degree: - component_key: number_conversion - variant_key: parallel_degree - config: # get the parallel degree from the device mesh - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - parallelism_methods: [dp_shard, dp_replicate] - -app_state: - component_key: app_state - variant_key: raw - config: - model: - instance_key: initialized_model - pass_type: BY_REFERENCE - optimizer: - instance_key: optimizer - pass_type: BY_REFERENCE - lr_scheduler: - instance_key: lr_scheduler - pass_type: BY_REFERENCE - -initialized_model: - component_key: model - variant_key: model_initialized - config: - model: - instance_key: fsdp_model - pass_type: BY_REFERENCE - model_initializer: - component_key: model_initialization - variant_key: composed - config: - model_type: gpt2 - weight_init_type: scaled - mean: 0.0 - std: 0.02 - num_layers: ${model_raw.config.num_layers} - -ep_model: - component_key: model - variant_key: ep_wrapped - config: - model: - instance_key: model_raw # Bypass torch.compile - MoE routing is incompatible - pass_type: BY_REFERENCE - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - ep_mesh_dim_name: tp - block_names: [TransformerBlock] - -ac_model: - component_key: model - variant_key: activation_checkpointed # using modalities fsdp2 ac. should do to job also for ep layers - config: - model: - instance_key: ep_model - pass_type: BY_REFERENCE - ac_variant: full_activation_checkpointing - layers_fqn: layers - ac_fun_params: - ac_freq: 1 - -fsdp_model: - component_key: model - variant_key: fsdp2_wrapped - config: - model: - instance_key: ac_model - pass_type: BY_REFERENCE - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - mixed_precision_settings: - param_dtype: BF_16 - reduce_dtype: BF_16 - reshard_after_forward: true - block_names: [TransformerBlock] - -compiled_model: - component_key: model - variant_key: compiled - config: - model: - instance_key: model_raw - pass_type: BY_REFERENCE - block_names: [TransformerBlock] - -model_raw: - component_key: model - variant_key: moe - config: - vocab_size: 50257 # to match a pretrained tokenizer, tochange - max_seq_len: 4096 - d_model: 2048 - d_ff: 6144 - n_heads: 32 - n_kv_heads: 8 - num_layers: 8 - attn_dropout: 0.0 - ffn_dropout: 0.0 - tie_embeddings: false - norm_eps: 1e-06 - rope_base: 1000000.0 - moe_num_experts: 128 - moe_d_ff: 768 - moe_top_k: 8 - -lr_scheduler: - component_key: scheduler - variant_key: onecycle_lr - config: - optimizer: - instance_key: optimizer - pass_type: BY_REFERENCE - max_lr: 6e-4 - div_factor: 10 - final_div_factor: 1 - total_steps: ${settings.training_target.num_target_steps} - pct_start: 0.02 - anneal_strategy: cos - last_epoch: ${settings.training_progress.last_step} - -optimizer: - component_key: optimizer - variant_key: ep_adam_w - config: - lr: 0.0001 - betas: [0.9, 0.95] - eps: 1e-8 - weight_decay: 1e-1 - weight_decay_groups_excluded: [embedding, layernorm] - wrapped_model: - instance_key: initialized_model - pass_type: BY_REFERENCE - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - -gradient_clipper: - component_key: gradient_clipper - variant_key: ep - config: - wrapped_model: - instance_key: initialized_model - pass_type: BY_REFERENCE - norm_type: P2_NORM - max_norm: 1.0 - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - -progress_subscriber: - component_key: progress_subscriber - variant_key: rich - config: - global_rank: ${settings.cuda_env.global_rank} - num_seen_steps: ${settings.training_progress.num_seen_steps} - num_target_steps: ${settings.training_target.num_target_steps} - train_dataloader_tag: ${train_dataloader.config.dataloader_tag} - eval_dataloaders: - instance_key: eval_dataloaders - pass_type: BY_REFERENCE - -evaluation_subscriber: - component_key: results_subscriber - variant_key: to_disc - config: - output_file_path: ${settings.paths.experiment_folder_path}/evaluation_results.jsonl - -mfu_calculator: - component_key: mfu_calculator - variant_key: gpt2 - config: - n_layer: ${model_raw.config.num_layers} - sequence_length: ${settings.step_profile.sequence_length} - n_embd: ${model_raw.config.d_model} - world_size: ${settings.cuda_env.world_size} - wrapped_model: - instance_key: initialized_model - pass_type: BY_REFERENCE - device_mesh: - instance_key: device_mesh - pass_type: BY_REFERENCE - -# profiler: -# component_key: steppable_profiler -# variant_key: combined -# config: -# profilers: -# - instance_key: kernel_profiler -# pass_type: BY_REFERENCE -# # - instance_key: memory_profiler -# # pass_type: BY_REFERENCE - -kernel_profiler: - component_key: steppable_profiler - variant_key: kernel_tracing - config: - num_wait_steps: 1 - num_warmup_steps: 1 - num_active_steps: 3 - profiler_activities: [CUDA] - profile_memory: true - record_shapes: true - with_stack: true - with_flops: true - with_modules: true - tracked_ranks: [0] - output_folder_path: ${settings.paths.experiment_folder_path}/profiling - -memory_profiler: - component_key: steppable_profiler - variant_key: memory_tracing - config: - memory_snapshot_folder_path: ${settings.paths.experiment_folder_path}/profiling - num_wait_steps: 1 - num_warmup_steps: 1 - num_active_steps: 3 - tracked_ranks: [0] \ No newline at end of file diff --git a/moe/config/tokenization_config.yaml b/moe/config/tokenization_config.yaml deleted file mode 100644 index 5a4b8b781..000000000 --- a/moe/config/tokenization_config.yaml +++ /dev/null @@ -1,18 +0,0 @@ -settings: - src_path: data/raw/fineweb_edu_num_docs_483606.jsonl - dst_path: data/preprocessed/fineweb_edu_num_docs_483606.pbin - index_path: data/preprocessed/fineweb_edu_num_docs_483606.idx - jq_pattern: .text - num_cpus: ${node_env:num_cpus} - eod_token: <|endoftext|> - processing_batch_size: 10 - raw_samples_queue_size: 300 - processed_samples_queue_size: 300 - -tokenizer: - component_key: tokenizer - variant_key: pretrained_hf_tokenizer - config: - pretrained_model_name_or_path: data/tokenizer - padding: false - truncation: false \ No newline at end of file diff --git a/moe/scripts/train_ep.py b/moe/scripts/train_ep.py deleted file mode 100644 index 7c99eee03..000000000 --- a/moe/scripts/train_ep.py +++ /dev/null @@ -1,155 +0,0 @@ -# ruff: noqa: E402 - -import os -from pathlib import Path -from typing import cast - -import torch -import torch.distributed as dist -from torch.distributed.tensor import DTensor - -from modalities.__main__ import Main -from modalities.config.config import ProcessGroupBackendType -from modalities.config.instantiation_models import TrainingComponentsInstantiationModel -from modalities.running_env.cuda_env import CudaEnv - -cwd = Path(__file__).resolve().parent.parent -os.chdir(cwd) -CONFIG_FILE_PATH = cwd / "config" / "qwen_config.yaml" -EXPERIMENTS_ROOT_PATH = cwd / "results" / "debug" - - -# TODO solve this -def _enable_torchtitan_moe_permute_fallback() -> ( - None -): # VIBECODATA because of Triton C error with Python headers don't know what that is - """Avoid Triton JIT build for MoE permute indices on systems without Python dev headers.""" - try: - import torchtitan.models.moe.kernels as kernels - import torchtitan.models.moe.utils as moe_utils - except Exception: - return - - if getattr(kernels, "_modalities_fallback_enabled", False): - return - - def _fill_indices_torch( - tokens_per_expert_group: torch.Tensor, - start_index_values: torch.Tensor, - write_offsets: torch.Tensor, - experts_per_rank: int, - num_ranks: int, - max_len: int, - ) -> torch.Tensor: - device = tokens_per_expert_group.device - permuted_indices = torch.full((max_len,), -1, dtype=torch.int32, device=device) - - for e in range(experts_per_rank): - write_start = int(write_offsets[e].item()) - for r in range(num_ranks): - i = r * experts_per_rank + e - start_index = int(start_index_values[i].item()) - length = int(tokens_per_expert_group[i].item()) - if length > 0: - end_idx = min(write_start + length, max_len) - permuted_indices[write_start:end_idx] = torch.arange( - start_index, - start_index + (end_idx - write_start), - dtype=torch.int32, - device=device, - ) - write_start += length - - return permuted_indices - - _orig_generate_permute_indices = kernels.generate_permute_indices - - def _generate_permute_indices_no_triton( - tokens_per_expert_group: torch.Tensor, - experts_per_rank: int, - num_ranks: int, - max_len: int, - alignment: int, - use_cpu: bool = False, - ): - del use_cpu - start_index_values = torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group - total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0) - total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment) - m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to(torch.int32) - m_offsets = torch.cumsum(m_sizes, 0) - write_offsets = m_offsets - m_sizes - - permuted_indices = _fill_indices_torch( - tokens_per_expert_group=tokens_per_expert_group, - start_index_values=start_index_values, - write_offsets=write_offsets, - experts_per_rank=experts_per_rank, - num_ranks=num_ranks, - max_len=max_len, - ) - return permuted_indices, m_sizes, m_offsets.to(torch.int32) - - kernels.generate_permute_indices = _generate_permute_indices_no_triton - moe_utils.generate_permute_indices = _generate_permute_indices_no_triton - setattr(kernels, "_modalities_fallback_enabled", True) - setattr(kernels, "_modalities_generate_permute_indices_original", _orig_generate_permute_indices) - - -def debug_ep(model): - # Stima memoria teorica - total_params = sum(p.numel() for p in model.parameters()) - ep_params = sum( - p.numel() for m in model.modules() if getattr(m, "_ep_enabled", False) for p in m.parameters(recurse=False) - ) - dense_params = total_params - ep_params - - print(f"Params totali: {total_params/1e6:.0f}M") - print(f"Params EP (non shardati): {ep_params/1e6:.0f}M") - print(f"Params densi (shardati su dp_shard): {dense_params/1e6:.0f}M") - - rank = dist.get_rank() - free, total = torch.cuda.mem_get_info() - print(f"[rank{rank}] Memoria dopo init: {(total-free)/1e9:.1f} GB usati") - - -def main(): - _enable_torchtitan_moe_permute_fallback() - EXPERIMENTS_ROOT_PATH.mkdir(parents=True, exist_ok=True) - - with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): - modalities_main = Main( - config_path=CONFIG_FILE_PATH, - experiments_root_path=EXPERIMENTS_ROOT_PATH, - ) - - components = cast( - TrainingComponentsInstantiationModel, - modalities_main.build_components(components_model_type=TrainingComponentsInstantiationModel), - ) - - # WORKAROUNDS (wip) - # TODO implement those into moe code - # 1. some parameters remain on cpu - device = torch.device(f"cuda:{torch.cuda.current_device()}") - for name, param in components.model_raw.named_parameters(): - if param.device.type == "cpu": - param.data = param.data.to(device) - - # 2. cast EP params to bf16 — FSDP2 skips them via ignored_params, so they stay - # fp32 from model init. Cast here to match the MixedPrecisionPolicy applied to - # dense params (param_dtype=BF_16). Halves EP memory: 29 GB → 14.5 GB at tp=4. - for mod in components.model_raw.modules(): - if getattr(mod, "_ep_enabled", False): - for pname, p in list(mod._parameters.items()): - if isinstance(p, DTensor) and p.dtype != torch.bfloat16: - bf16_local = p.to_local().to(torch.bfloat16) - bf16_p = DTensor.from_local(bf16_local, p.device_mesh, p.placements, run_check=False) - mod._parameters[pname] = torch.nn.Parameter(bf16_p, requires_grad=p.requires_grad) - - debug_ep(components.model_raw) - modalities_main.run(components) - - -if __name__ == "__main__": - main() diff --git a/moe/scripts/monitor_gpus.sh b/scripts/monitor_gpus.sh similarity index 100% rename from moe/scripts/monitor_gpus.sh rename to scripts/monitor_gpus.sh From e09aa06f5f514d78c47a6d19751c8cd25a05b1f4 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Thu, 11 Jun 2026 09:19:42 +0000 Subject: [PATCH 4/7] docs: Add removed comments --- src/modalities/models/moe/loss_functions.py | 1 + src/modalities/models/moe/model_factory.py | 1 + src/modalities/optimizers/ep_adamw.py | 16 ++++++++++++++++ .../gradient_clipping/ep_gradient_clipper.py | 2 ++ 4 files changed, 20 insertions(+) diff --git a/src/modalities/models/moe/loss_functions.py b/src/modalities/models/moe/loss_functions.py index 642efb47a..57f30da69 100644 --- a/src/modalities/models/moe/loss_functions.py +++ b/src/modalities/models/moe/loss_functions.py @@ -31,6 +31,7 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: labels.contiguous().long().view(-1), ) + # Aux loss for layer in self.model.layers.values(): if hasattr(layer, "aux_loss") and layer.aux_loss is not None: loss = loss + layer.aux_loss.to(loss.dtype) diff --git a/src/modalities/models/moe/model_factory.py b/src/modalities/models/moe/model_factory.py index 406da1964..d5b95d9bb 100644 --- a/src/modalities/models/moe/model_factory.py +++ b/src/modalities/models/moe/model_factory.py @@ -10,6 +10,7 @@ from modalities.util import get_module_class_from_name +# TODO refactor these funtions into a utils def _resolve_ep_mesh(device_mesh: DeviceMesh, ep_mesh_dim_name: str | None) -> DeviceMesh: mesh_dim_names = tuple(device_mesh.mesh_dim_names or ()) diff --git a/src/modalities/optimizers/ep_adamw.py b/src/modalities/optimizers/ep_adamw.py index 2b5e72aae..006f9faf9 100644 --- a/src/modalities/optimizers/ep_adamw.py +++ b/src/modalities/optimizers/ep_adamw.py @@ -24,6 +24,14 @@ def _get_dense_optimizer_groups(model, ep_param_ids, weight_decay, weight_decay_ class EPAdamW(Optimizer): + """ + ZeRO stage-1 for EP (DTensor) params + standard AdamW for dense params. + + Each dp_shard rank stores optimizer states for 1/dp_shard of the EP params. + After each step, updated EP param values are broadcast from owner to all ranks. + Dense params are handled by a separate AdamW (FSDP2 shards them independently). + """ + def __init__( self, model: Module, @@ -42,6 +50,7 @@ def __init__( ep_param_ids = _get_ep_param_ids(model) self._all_ep_params = [p for p in model.parameters() if id(p) in ep_param_ids] + # rank r owns params[r::dp_size] self._owned_ep_params = self._all_ep_params[self._dp_rank :: self._dp_size] dense_groups = _get_dense_optimizer_groups(model, ep_param_ids, weight_decay, weight_decay_groups_excluded) @@ -52,6 +61,8 @@ def __init__( self._ep_adamw = None self._dense_adamw = AdamW(dense_groups, lr=lr, betas=betas, eps=eps) + # unified param groups for lr_scheduler compatibility: + # group 0 = all EP params, groups 1+ = dense weight-decay split ep_group = {"params": self._all_ep_params, "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay} all_groups = [ep_group] + [{**g, "lr": lr, "betas": betas, "eps": eps} for g in dense_groups] super().__init__(all_groups, {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay}) @@ -63,6 +74,7 @@ def step(self, closure=None): with torch.enable_grad(): loss = closure() + # all-reduce for p in self._all_ep_params: if p.grad is None: continue @@ -74,16 +86,20 @@ def step(self, closure=None): dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=self._dp_group) p.grad.div_(self._dp_size) + # Sync lr if self._ep_adamw is not None: self._ep_adamw.param_groups[0]["lr"] = self.param_groups[0]["lr"] for i, group in enumerate(self._dense_adamw.param_groups): group["lr"] = self.param_groups[i + 1]["lr"] + # Update ep params if self._ep_adamw is not None: self._ep_adamw.step() + # Update dense params self._dense_adamw.step() + # broadcast updated EP param local tensors for i, p in enumerate(self._all_ep_params): owner_local_rank = i % self._dp_size owner_global_rank = dist.get_global_rank(self._dp_group, owner_local_rank) diff --git a/src/modalities/training/gradient_clipping/ep_gradient_clipper.py b/src/modalities/training/gradient_clipping/ep_gradient_clipper.py index a2b6f25b8..2efc5ed58 100644 --- a/src/modalities/training/gradient_clipping/ep_gradient_clipper.py +++ b/src/modalities/training/gradient_clipping/ep_gradient_clipper.py @@ -78,6 +78,8 @@ def clip_gradients(self) -> torch.Tensor: dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) total_norm **= 1.0 / self.norm_type.value + # do not use torch.nn.utils.clip_grads_with_norm_ here: it batches grads with + # torch._foreach_mul_, which fails when the list mixes DTensors from different meshes. clip_coef = self.max_norm / (total_norm + 1e-6) clip_coef_clamped = torch.clamp(clip_coef, max=1.0) From baf94e945e94f69c37e4fb15db9b3203b5a2025c Mon Sep 17 00:00:00 2001 From: rrutmann Date: Thu, 11 Jun 2026 09:33:38 +0000 Subject: [PATCH 5/7] test: Add tests for MoE components --- tests/models/moe/__init__.py | 0 tests/models/moe/test_loss_functions.py | 59 ++++++++++++ tests/models/moe/test_qwen_model.py | 60 ++++++++++++ tests/optimizers/test_ep_adamw.py | 92 +++++++++++++++++++ .../test_ep_gradient_clipper.py | 50 ++++++++++ 5 files changed, 261 insertions(+) create mode 100644 tests/models/moe/__init__.py create mode 100644 tests/models/moe/test_loss_functions.py create mode 100644 tests/models/moe/test_qwen_model.py create mode 100644 tests/optimizers/test_ep_adamw.py create mode 100644 tests/training/gradient_clipping/test_ep_gradient_clipper.py diff --git a/tests/models/moe/__init__.py b/tests/models/moe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/models/moe/test_loss_functions.py b/tests/models/moe/test_loss_functions.py new file mode 100644 index 000000000..346b69818 --- /dev/null +++ b/tests/models/moe/test_loss_functions.py @@ -0,0 +1,59 @@ +import torch +from torch.nn import CrossEntropyLoss + +from modalities.batch import InferenceResultBatch +from modalities.models.moe.loss_functions import MoECrossEntropyLoss + + +class DummyLayer: + def __init__(self, aux_loss): + self.aux_loss = aux_loss + + +class DummyModel: + def __init__(self, aux_losses: list[torch.Tensor | None]): + self.layers = {str(i): DummyLayer(aux) for i, aux in enumerate(aux_losses)} + + +def test_moe_cross_entropy_loss_adds_aux_losses(): + logits = torch.tensor( + [ + [[1.2, 0.3, -0.5], [0.1, 1.8, -0.3]], + [[0.5, -0.4, 1.1], [0.7, 0.2, -0.1]], + ], + dtype=torch.float32, + ) + targets = torch.tensor([[0, 1], [2, 0]], dtype=torch.long) + + batch = InferenceResultBatch( + targets={"targets": targets}, + predictions={"logits": logits}, + ) + + aux_1 = torch.tensor(0.2) + aux_2 = torch.tensor(0.3) + model = DummyModel(aux_losses=[aux_1, None, aux_2]) + loss_fn = MoECrossEntropyLoss(target_key="targets", prediction_key="logits", model=model) + + loss = loss_fn(batch) + base_ce = CrossEntropyLoss(reduction="mean")(logits.view(-1, logits.size(-1)), targets.view(-1)) + + assert torch.allclose(loss, base_ce + aux_1 + aux_2) + + +def test_moe_cross_entropy_loss_without_aux_matches_plain_ce(): + logits = torch.randn(2, 3, 5) + targets = torch.randint(0, 5, (2, 3), dtype=torch.long) + + batch = InferenceResultBatch( + targets={"labels": targets}, + predictions={"pred": logits}, + ) + + model = DummyModel(aux_losses=[None, None]) + loss_fn = MoECrossEntropyLoss(target_key="labels", prediction_key="pred", model=model) + + loss = loss_fn(batch) + expected = CrossEntropyLoss(reduction="mean")(logits.view(-1, logits.size(-1)), targets.view(-1)) + + assert torch.allclose(loss, expected) diff --git a/tests/models/moe/test_qwen_model.py b/tests/models/moe/test_qwen_model.py new file mode 100644 index 000000000..d4d90b592 --- /dev/null +++ b/tests/models/moe/test_qwen_model.py @@ -0,0 +1,60 @@ +import torch + +from modalities.models.moe.qwen_model import GroupedExperts, QwenModel + + +def _build_tiny_qwen_model() -> QwenModel: + return QwenModel( + vocab_size=32, + max_seq_len=16, + d_model=16, + n_heads=4, + n_kv_heads=2, + d_ff=32, + num_layers=1, + moe_d_ff=24, + moe_num_experts=4, + moe_top_k=2, + moe_capacity_factor=1.25, + moe_min_capacity=1, + moe_overflow_policy="residual", + moe_aux_loss_coef=0.01, + moe_z_loss_coef=0.0, + ) + + +def test_qwen_model_forward_dict_output_shape(): + torch.manual_seed(0) + model = _build_tiny_qwen_model() + batch_size, seq_len = 2, 5 + + input_ids = torch.randint(0, 32, (batch_size, seq_len), dtype=torch.long) + output = model({"input_ids": input_ids}) + + assert "logits" in output + assert output["logits"].shape == (batch_size, seq_len, 32) + + +def test_grouped_experts_forward_local_preserves_input_dtype(): + experts = GroupedExperts(num_experts=2, d_model=8, d_ff=12, ffn_dropout=0.0) + experts.reset_parameters() + + # Input in bf16 while expert weights are initialized in fp32. + routed_input = torch.randn(4, 8, dtype=torch.bfloat16) + num_tokens_per_expert = torch.tensor([2, 2], dtype=torch.long) + + out = experts._forward_local(routed_input=routed_input, num_tokens_per_expert=num_tokens_per_expert) + + assert out.shape == routed_input.shape + assert out.dtype == routed_input.dtype + + +def test_transformer_block_exposes_aux_loss_after_forward(): + torch.manual_seed(1) + model = _build_tiny_qwen_model() + input_ids = torch.randint(0, 32, (2, 4), dtype=torch.long) + + _ = model({"input_ids": input_ids}) + + first_layer = next(iter(model.layers.values())) + assert first_layer.aux_loss is not None diff --git a/tests/optimizers/test_ep_adamw.py b/tests/optimizers/test_ep_adamw.py new file mode 100644 index 000000000..bb366627b --- /dev/null +++ b/tests/optimizers/test_ep_adamw.py @@ -0,0 +1,92 @@ +import torch +import torch.nn as nn + +from modalities.models.model import NNModel +from modalities.optimizers.ep_adamw import EPAdamW + + +class DummyDPShardMesh: + def __init__(self): + self._group = object() + + def get_group(self): + return self._group + + +class EPSubmodule(nn.Module): + def __init__(self): + super().__init__() + self.ep_weight = nn.Parameter(torch.tensor([1.0, -1.0])) + self._ep_enabled = True + + +class TinyModel(NNModel): + def __init__(self): + super().__init__(weight_decay_groups={"linear": ["linear"], "embedding": [], "layernorm": ["norm"]}) + self.linear = nn.Linear(2, 2, bias=False) + self.norm = nn.LayerNorm(2) + self.experts = EPSubmodule() + + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + x = inputs["x"] + return {"y": self.linear(x)} + + +def _patch_distributed_ops(monkeypatch): + from modalities.optimizers import ep_adamw as ep_adamw_module + + monkeypatch.setattr(ep_adamw_module.dist, "get_rank", lambda group=None: 0) + monkeypatch.setattr(ep_adamw_module.dist, "get_world_size", lambda group=None: 1) + monkeypatch.setattr(ep_adamw_module.dist, "all_reduce", lambda tensor, op=None, group=None: tensor) + monkeypatch.setattr(ep_adamw_module.dist, "broadcast", lambda tensor, src=0, group=None: tensor) + monkeypatch.setattr(ep_adamw_module.dist, "get_global_rank", lambda group, group_rank: group_rank) + + +def test_ep_adamw_state_dict_and_load_state_dict(monkeypatch): + _patch_distributed_ops(monkeypatch) + + model = TinyModel() + optimizer = EPAdamW( + model=model, + device_mesh={"dp_shard": DummyDPShardMesh()}, + lr=1e-2, + betas=(0.9, 0.95), + eps=1e-8, + weight_decay=0.1, + weight_decay_groups_excluded=["layernorm"], + ) + + state = optimizer.state_dict() + assert "ep_adamw" in state + assert "dense_adamw" in state + + optimizer.load_state_dict(state) + + +def test_ep_adamw_step_updates_parameters_and_zero_grad(monkeypatch): + _patch_distributed_ops(monkeypatch) + + model = TinyModel() + optimizer = EPAdamW( + model=model, + device_mesh={"dp_shard": DummyDPShardMesh()}, + lr=1e-2, + betas=(0.9, 0.95), + eps=1e-8, + weight_decay=0.1, + weight_decay_groups_excluded=["layernorm"], + ) + + before = [p.detach().clone() for p in model.parameters()] + for p in model.parameters(): + p.grad = torch.ones_like(p) + + optimizer.step() + after = list(model.parameters()) + + for p_before, p_after in zip(before, after): + assert not torch.allclose(p_before, p_after) + + optimizer.zero_grad(set_to_none=True) + for p in model.parameters(): + assert p.grad is None diff --git a/tests/training/gradient_clipping/test_ep_gradient_clipper.py b/tests/training/gradient_clipping/test_ep_gradient_clipper.py new file mode 100644 index 000000000..322ece24f --- /dev/null +++ b/tests/training/gradient_clipping/test_ep_gradient_clipper.py @@ -0,0 +1,50 @@ +import pytest +import torch +import torch.nn as nn + +from modalities.training.gradient_clipping.ep_gradient_clipper import EPGradientClipper +from modalities.training.gradient_clipping.fsdp_gradient_clipper import GradientClippingMode + + +class MockModel(nn.Module): + def __init__(self): + super().__init__() + self.param1 = nn.Parameter(torch.tensor([1.0, 2.0])) + self.param2 = nn.Parameter(torch.tensor([3.0, 4.0])) + + +def test_ep_gradient_clipper_clips_gradients(): + model = MockModel() + model.param1.grad = torch.tensor([1.0, 1.0]) + model.param2.grad = torch.tensor([1.0, 1.0]) + + clipper = EPGradientClipper(model_parts=model, max_norm=1.0, norm_type=GradientClippingMode.P2_NORM) + total_norm = clipper.clip_gradients() + + assert torch.allclose(total_norm, torch.tensor(2.0)) + assert torch.allclose(model.param1.grad, torch.tensor([0.5, 0.5]), atol=1e-6) + assert torch.allclose(model.param2.grad, torch.tensor([0.5, 0.5]), atol=1e-6) + + +def test_ep_gradient_clipper_returns_zero_for_no_gradients(): + model = MockModel() + + clipper = EPGradientClipper(model_parts=model, max_norm=1.0, norm_type=GradientClippingMode.P2_NORM) + total_norm = clipper.clip_gradients() + + assert torch.allclose(total_norm.cpu(), torch.tensor(0.0)) + + +def test_ep_gradient_clipper_raises_for_nonfinite_norm(): + model = MockModel() + model.param1.grad = torch.tensor([float("nan"), 1.0]) + + clipper = EPGradientClipper( + model_parts=model, + max_norm=1.0, + norm_type=GradientClippingMode.P2_NORM, + error_if_nonfinite=True, + ) + + with pytest.raises(RuntimeError, match="non-finite"): + clipper.clip_gradients() From 2a3b81a36378f6e84eab07edaa8fbfe8f505739a Mon Sep 17 00:00:00 2001 From: rrutmann Date: Thu, 11 Jun 2026 09:58:49 +0000 Subject: [PATCH 6/7] test: Add e2e moe test --- tests/end2end_tests/test_moe_ep_fsdp2_e2e.py | 142 +++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 tests/end2end_tests/test_moe_ep_fsdp2_e2e.py diff --git a/tests/end2end_tests/test_moe_ep_fsdp2_e2e.py b/tests/end2end_tests/test_moe_ep_fsdp2_e2e.py new file mode 100644 index 000000000..70e7203e9 --- /dev/null +++ b/tests/end2end_tests/test_moe_ep_fsdp2_e2e.py @@ -0,0 +1,142 @@ +import logging +import multiprocessing as py_mp +import os +import traceback +from pathlib import Path +from typing import Any + +import pytest +import torch +import torch.multiprocessing as mp + +from modalities.__main__ import Main, load_app_config_dict +from modalities.batch import EvaluationResultBatch +from modalities.config.config import ProcessGroupBackendType +from modalities.config.instantiation_models import TrainingComponentsInstantiationModel +from modalities.logging_broker.messages import Message +from tests.end2end_tests.custom_components import ( + MultiProcessingCudaEnv, + SaveAllResultSubscriber, + SaveAllResultSubscriberConfig, +) +from tests.utility import find_free_port, monitor_child_processes + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="This E2E test requires 4 CUDA devices.") +class TestMoEEPFSDP2E2E: + @staticmethod + def _patch_for_short_test_run(config_dict: dict[str, Any], checkpoint_root_path: Path) -> None: + # Keep runtime short while preserving EP + FSDP2 wiring. + config_dict["settings"]["intervals"]["training_log_interval_in_steps"] = 1 + config_dict["settings"]["intervals"]["checkpointing_interval_in_steps"] = 1 + config_dict["settings"]["intervals"]["evaluation_interval_in_steps"] = 1000 + + config_dict["settings"]["step_profile"]["sequence_length"] = 64 + config_dict["settings"]["step_profile"]["local_train_micro_batch_size"] = 1 + config_dict["settings"]["step_profile"]["gradient_accumulation_steps"] = 1 + + config_dict["settings"]["training_target"]["num_target_tokens"] = 512 + config_dict["settings"]["training_target"]["num_target_steps"] = 2 + config_dict["lr_scheduler"]["config"]["total_steps"] = 2 + + config_dict["train_dataset"]["config"]["sequence_length"] = 64 + config_dict["test_dataset"]["config"]["sequence_length"] = 64 + config_dict["train_dataloader"]["config"]["num_workers"] = 0 + config_dict["test_dataloader"]["config"]["num_workers"] = 0 + config_dict["train_dataloader"]["config"]["pin_memory"] = False + config_dict["test_dataloader"]["config"]["pin_memory"] = False + + config_dict["settings"]["paths"]["checkpoint_saving_path"] = checkpoint_root_path + config_dict["checkpoint_saving"]["config"]["checkpoint_saving_execution"]["config"][ + "checkpoint_path" + ] = checkpoint_root_path + + @staticmethod + def _worker_wrapper( + process_id: int, + world_size: int, + rdvz_port: int, + config_file_path: Path, + tmp_path: Path, + error_queue: Any, + ) -> None: + with MultiProcessingCudaEnv( + process_group_backend=ProcessGroupBackendType.nccl, + global_rank=process_id, + local_rank=process_id, + world_size=world_size, + rdvz_port=rdvz_port, + ): + try: + TestMoEEPFSDP2E2E._worker_impl( + process_id=process_id, + config_file_path=config_file_path, + tmp_path=tmp_path, + ) + except Exception as exc: + tb = traceback.format_exc() + logging.error(f"Process {process_id} failed: {exc}\n{tb}") + try: + error_queue.put((process_id, tb)) + except Exception: + logging.error("Failed to write child exception to queue.") + os._exit(1) + + @staticmethod + def _worker_impl(process_id: int, config_file_path: Path, tmp_path: Path) -> None: + experiment_id = "moe-ep-fsdp2-e2e" + checkpoint_root_path = tmp_path / experiment_id / "checkpoints" + cfg = load_app_config_dict( + config_file_path=config_file_path, experiments_root_path=tmp_path, experiment_id=experiment_id + ) + TestMoEEPFSDP2E2E._patch_for_short_test_run(cfg, checkpoint_root_path) + + main_obj = Main(config_file_path, experiments_root_path=tmp_path, experiment_id=experiment_id) + main_obj.config_dict = cfg + main_obj.add_custom_component( + component_key="results_subscriber", + variant_key="save_all", + custom_component=SaveAllResultSubscriber, + custom_config=SaveAllResultSubscriberConfig, + ) + main_obj.config_dict["evaluation_subscriber"]["variant_key"] = "save_all" + main_obj.config_dict["evaluation_subscriber"]["config"] = {} + + components: TrainingComponentsInstantiationModel = main_obj.build_components( + components_model_type=TrainingComponentsInstantiationModel + ) + + assert getattr(components.model_raw, "_ep_wrapped", False), "Expected EP wrapping marker on raw model." + first_layer = next(iter(components.model_raw.layers.values())) + assert getattr(first_layer.ffn.experts, "_ep_enabled", False), "Expected experts to be EP-enabled." + + main_obj.run(components) + + result_messages: list[Message[EvaluationResultBatch]] = components.evaluation_subscriber.message_list + assert len(result_messages) > 0, "Expected training messages in evaluation subscriber." + for message in result_messages: + loss_value = message.payload.losses["train loss avg"].value + assert torch.isfinite(loss_value), f"Found non-finite train loss: {loss_value}" + + if process_id == 0: + checkpoint_info_file_path = checkpoint_root_path / "last_checkpoint_info.json" + assert checkpoint_info_file_path.exists(), "Expected checkpoint info file from DCP save." + + @staticmethod + def test_moe_ep_fsdp2_training_and_checkpointing(tmp_path: Path) -> None: + repo_root = Path(__file__).resolve().parents[2] + config_file_path = repo_root / "config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml" + + world_size = 4 + rdvz_port = find_free_port() + + manager = py_mp.Manager() + error_queue = manager.Queue() + proc_ctx = mp.spawn( + TestMoEEPFSDP2E2E._worker_wrapper, + args=(world_size, rdvz_port, config_file_path, tmp_path, error_queue), + nprocs=world_size, + join=False, + ) + + monitor_child_processes(manager, error_queue, proc_ctx) From f7f87b1ab07735079efad5bc105878fad84b8fb5 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Thu, 11 Jun 2026 11:53:34 +0000 Subject: [PATCH 7/7] refactor: Merge shared rotary embedding logic --- .../models/components/rotary_embedding.py | 126 +++++++++++++++ src/modalities/models/gpt2/gpt2_model.py | 149 ++++-------------- src/modalities/models/moe/qwen_model.py | 38 +++-- 3 files changed, 184 insertions(+), 129 deletions(-) create mode 100644 src/modalities/models/components/rotary_embedding.py diff --git a/src/modalities/models/components/rotary_embedding.py b/src/modalities/models/components/rotary_embedding.py new file mode 100644 index 000000000..c569a787e --- /dev/null +++ b/src/modalities/models/components/rotary_embedding.py @@ -0,0 +1,126 @@ +import math +from typing import Optional + +import torch + + +def compute_default_inv_freq(dim_model: int, base_freq: float, device: Optional[torch.device] = None) -> torch.Tensor: + return 1.0 / (base_freq ** (torch.arange(0, dim_model, 2, device=device).float() / dim_model)) + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seq_length_dim: int) -> torch.Tensor: + cos = cos[:, :, : x.shape[seq_length_dim], :] + sin = sin[:, :, : x.shape[seq_length_dim], :] + return (x * cos) + (rotate_half(x) * sin) + + +def update_cos_sin_tables( + x: torch.Tensor, + inv_freq: torch.Tensor, + attention_scaling: float, + seq_length_dim: int, + seq_len_cached: Optional[int], + cos_cached: Optional[torch.Tensor], + sin_cached: Optional[torch.Tensor], +) -> tuple[int, torch.Tensor, torch.Tensor]: + seq_len = x.shape[seq_length_dim] + + if ( + seq_len != seq_len_cached + or cos_cached is None + or sin_cached is None + or cos_cached.device != x.device + or cos_cached.dtype != x.dtype + ): + t = torch.arange(seq_len, device=x.device, dtype=torch.float32) + freqs = torch.einsum("i,j->ij", t, inv_freq.to(x.dtype)) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + cos_cached = (emb.cos() * attention_scaling)[None, None, :, :].to(x.dtype) + sin_cached = (emb.sin() * attention_scaling)[None, None, :, :].to(x.dtype) + seq_len_cached = seq_len + + return seq_len_cached, cos_cached, sin_cached + + +def compute_yarn_inv_freq_and_attention_scaling( + dim_model: int, + base_freq: float, + max_position_embeddings: int, + original_max_position_embeddings: int, + factor: Optional[float], + attention_factor: Optional[float], + mscale: Optional[float], + mscale_all_dim: Optional[float], + beta_fast: float, + beta_slow: float, + truncate: bool, + device: Optional[torch.device] = None, +) -> tuple[torch.Tensor, float]: + factor_float = ( + float(factor) if factor is not None else float(max_position_embeddings / original_max_position_embeddings) + ) + + def get_mscale(scale: float, mscale_value: float = 1.0) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale_value * math.log(scale) + 1.0 + + if attention_factor is None: + if mscale is not None and mscale_all_dim is not None: + attention_factor = float( + get_mscale(factor_float, float(mscale)) / get_mscale(factor_float, float(mscale_all_dim)) + ) + else: + attention_factor = get_mscale(factor_float) + + def find_correction_dim(num_rotations: float, dim: int, base: float, max_pos_emb: int) -> float: + return (dim * math.log(max_pos_emb / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range( + low_rot: float, + high_rot: float, + dim: int, + base: float, + max_pos_emb: int, + do_truncate: bool, + ) -> tuple[float, float]: + low = find_correction_dim(low_rot, dim, base, max_pos_emb) + high = find_correction_dim(high_rot, dim, base, max_pos_emb) + if do_truncate: + low = math.floor(low) + high = math.ceil(high) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Tensor: + if min_value == max_value: + max_value += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - min_value) / (max_value - min_value) + return torch.clamp(linear_func, 0, 1) + + pos_freqs = base_freq ** (torch.arange(0, dim_model, 2, device=device, dtype=torch.float) / dim_model) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor_float * pos_freqs) + + low, high = find_correction_range( + beta_fast, + beta_slow, + dim_model, + base_freq, + original_max_position_embeddings, + bool(truncate), + ) + + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim_model // 2).to( + device=device, dtype=torch.float + ) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) + + return inv_freq, float(attention_factor) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index f43e6e87b..2e93b0be1 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -3,7 +3,7 @@ from abc import abstractmethod from enum import Enum from numbers import Real -from typing import Annotated, Literal, Optional, overload +from typing import Annotated, Literal, Optional, cast, overload import torch import torch.nn as nn @@ -17,6 +17,13 @@ RMSLayerNorm, RMSLayerNormConfig, ) +from modalities.models.components.rotary_embedding import ( + apply_rotary_pos_emb, + compute_default_inv_freq, + compute_yarn_inv_freq_and_attention_scaling, + rotate_half, + update_cos_sin_tables, +) from modalities.models.model import ActivationType, NNModel, SwiGLU from modalities.util import parse_enum_by_name @@ -221,9 +228,7 @@ def reset_parameters(self): if rope_type == "yarn": inv_freq, self.attention_scaling = self._compute_yarn_parameters(device=device) else: - inv_freq = 1.0 / ( - self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model) - ) + inv_freq = compute_default_inv_freq(dim_model=self.dim_model, base_freq=self.base_freq, device=device) self.attention_scaling = 1.0 self.register_buffer("inv_freq", inv_freq) @@ -243,8 +248,7 @@ def rotate_half(self, x: torch.Tensor): torch.Tensor: The output tensor. """ - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) + return rotate_half(x) def apply_rotary_pos_emb(self, x, cos, sin): """ @@ -258,16 +262,7 @@ def apply_rotary_pos_emb(self, x, cos, sin): Returns: torch.Tensor: Tensor after applying rotary positional embedding. """ - # NOTE: This could probably be moved to Triton - - # Handle a possible sequence length mismatch in between q and k - cos = cos[:, :, : x.shape[self.seq_length_dim], :] - sin = sin[:, :, : x.shape[self.seq_length_dim], :] - - # the rotation is not really a rotation in higher dimensions, - # It merely swaps and negates certain dimensions to make - # the rotation below work - return (x * cos) + (self.rotate_half(x) * sin) + return apply_rotary_pos_emb(x=x, cos=cos, sin=sin, seq_length_dim=self.seq_length_dim) def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor @@ -297,109 +292,31 @@ def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.T if self.max_position_embeddings is None: raise ValueError("YaRN requires max_position_embeddings to be set.") - original_max_position_embeddings = self.rope_scaling.original_max_position_embeddings - factor = self.rope_scaling.factor - if factor is None: - factor = self.max_position_embeddings / original_max_position_embeddings - factor_float = float(factor) - - attention_factor = self.rope_scaling.attention_factor - mscale_pair = None - if self.rope_scaling.mscale is not None and self.rope_scaling.mscale_all_dim is not None: - mscale_pair = (self.rope_scaling.mscale, self.rope_scaling.mscale_all_dim) - - beta_fast = self.rope_scaling.beta_fast - beta_slow = self.rope_scaling.beta_slow - truncate = self.rope_scaling.truncate - - def get_mscale(scale: float, mscale: float = 1.0) -> float: - """Return the YaRN mscale coefficient for a given scaling factor.""" - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - if attention_factor is None: - if mscale_pair is not None: - mscale, mscale_all_dim = mscale_pair - attention_factor = float( - get_mscale(factor_float, float(mscale)) / get_mscale(factor_float, float(mscale_all_dim)) - ) - else: - attention_factor = get_mscale(factor_float) - - def find_correction_dim(num_rotations: float, dim: int, base: int, max_position_embeddings: int) -> float: - """Map a target number of rotations to a rotary dimension index.""" - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - def find_correction_range( - low_rot: float, - high_rot: float, - dim: int, - base: int, - max_position_embeddings: int, - truncate: bool, - ) -> tuple[float, float]: - """Compute the lower and upper rotary-dimension correction bounds for YaRN.""" - low = find_correction_dim(low_rot, dim, base, max_position_embeddings) - high = find_correction_dim(high_rot, dim, base, max_position_embeddings) - if truncate: - low = math.floor(low) - high = math.ceil(high) - return max(low, 0), min(high, dim - 1) - - def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Tensor: - """Create a clamped linear ramp used to blend interpolation and extrapolation.""" - if min_value == max_value: - max_value += 0.001 - linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - min_value) / (max_value - min_value) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - dim = self.dim_model - base = self.base_freq - - pos_freqs = base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (factor_float * pos_freqs) - - low, high = find_correction_range( - beta_fast, - beta_slow, - dim, - base, - original_max_position_embeddings, - bool(truncate), - ) - inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float) - inv_freq = ( - inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) - + inv_freq_extrapolation * inv_freq_extrapolation_factor + return compute_yarn_inv_freq_and_attention_scaling( + dim_model=self.dim_model, + base_freq=self.base_freq, + max_position_embeddings=self.max_position_embeddings, + original_max_position_embeddings=self.rope_scaling.original_max_position_embeddings, + factor=self.rope_scaling.factor, + attention_factor=self.rope_scaling.attention_factor, + mscale=self.rope_scaling.mscale, + mscale_all_dim=self.rope_scaling.mscale_all_dim, + beta_fast=self.rope_scaling.beta_fast, + beta_slow=self.rope_scaling.beta_slow, + truncate=self.rope_scaling.truncate, + device=device, ) - return inv_freq, float(attention_factor) - def _update_cos_sin_tables(self, x): - # Update the cosine and sine tables. - seq_len = x.shape[self.seq_length_dim] - - # Reset the tables if the sequence length has changed, - # or if we're on a new device (possibly due to tracing for instance) - if ( - seq_len != self._seq_len_cached - or self._cos_cached is None - or self._sin_cached is None - or self._cos_cached.device != x.device - or self._cos_cached.dtype != x.dtype - ): - self._seq_len_cached = seq_len - t = torch.arange(x.shape[self.seq_length_dim], device=x.device, dtype=torch.float32) - freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype)) - emb = torch.cat((freqs, freqs), dim=-1).to( - x.device - ) # here, we combine the two matrices (not zipping them). - self._cos_cached = (emb.cos() * self.attention_scaling)[None, None, :, :].to(x.dtype) - self._sin_cached = (emb.sin() * self.attention_scaling)[None, None, :, :].to(x.dtype) - + self._seq_len_cached, self._cos_cached, self._sin_cached = update_cos_sin_tables( + x=x, + inv_freq=cast(torch.Tensor, self.inv_freq), + attention_scaling=self.attention_scaling, + seq_length_dim=self.seq_length_dim, + seq_len_cached=self._seq_len_cached, + cos_cached=self._cos_cached, + sin_cached=self._sin_cached, + ) return self._cos_cached, self._sin_cached diff --git a/src/modalities/models/moe/qwen_model.py b/src/modalities/models/moe/qwen_model.py index ac4cab752..b20dd05ef 100644 --- a/src/modalities/models/moe/qwen_model.py +++ b/src/modalities/models/moe/qwen_model.py @@ -6,6 +6,11 @@ import torch.nn.functional as F from pydantic import BaseModel +from modalities.models.components.rotary_embedding import ( + apply_rotary_pos_emb, + compute_default_inv_freq, + update_cos_sin_tables, +) from modalities.models.model import NNModel try: @@ -62,33 +67,40 @@ def __init__(self, head_dim: int, max_seq_len: int, base: float = 1000000.0): self.head_dim = head_dim self.max_seq_len = max_seq_len self.base = base + self.register_buffer("inv_freq", None, persistent=False) self.register_buffer("cos_cached", None, persistent=False) self.register_buffer("sin_cached", None, persistent=False) + self._seq_len_cached: Optional[int] = None def _compute_cache(self, device): - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim)) - t = torch.arange(self.max_seq_len, device=device).float() - freqs = torch.outer(t, inv_freq) - emb = torch.cat([freqs, freqs], dim=-1) - self.cos_cached = emb.cos()[None, None, :, :] - self.sin_cached = emb.sin()[None, None, :, :] + self.inv_freq = compute_default_inv_freq(dim_model=self.head_dim, base_freq=self.base, device=device) + self._seq_len_cached = None + self.cos_cached = None + self.sin_cached = None def forward(self, x: torch.Tensor, seq_len: int): - if self.cos_cached is None: + if self.inv_freq is None: self._compute_cache(x.device) + self._seq_len_cached, self.cos_cached, self.sin_cached = update_cos_sin_tables( + x=x, + inv_freq=self.inv_freq, + attention_scaling=1.0, + seq_length_dim=-2, + seq_len_cached=self._seq_len_cached, + cos_cached=self.cos_cached, + sin_cached=self.sin_cached, + ) return ( self.cos_cached[:, :, :seq_len, :].to(x.dtype), self.sin_cached[:, :, :seq_len, :].to(x.dtype), ) -def rotate_half(x): - x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] - return torch.cat([-x2, x1], dim=-1) - - def apply_rotary_emb(q, k, cos, sin): - return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + return ( + apply_rotary_pos_emb(x=q, cos=cos, sin=sin, seq_length_dim=-2), + apply_rotary_pos_emb(x=k, cos=cos, sin=sin, seq_length_dim=-2), + ) class GroupedQueryAttention(nn.Module):