diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 7691582f97..a68d2a19b0 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -22,6 +22,7 @@ import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import ( _cudnn_frontend_version_supported, + is_glu_activation, ) from transformer_engine.pytorch.ops.fused import ( @@ -2480,6 +2481,59 @@ def test_scaled_swiglu( assert_close_grads(x_test, x_ref, **tols) assert_close_grads(scales_test, scales_ref, **tols) + @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128))) + @pytest.mark.parametrize("input_requires_grad", (False, True)) + @pytest.mark.parametrize("scales_requires_grad", (False, True)) + def test_scaled_srelu( + self, + *, + in_shape: Iterable[int], + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + input_requires_grad: bool, + scales_requires_grad: bool, + ) -> None: + """SReLU with post-scale""" + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=input_requires_grad, + ) + scales_ref, scales_test = make_reference_and_test_tensors( + in_shape[:-1], + test_dtype=dtype, + test_device=device, + requires_grad=scales_requires_grad, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y = torch.nn.functional.relu(x_ref).square() + y_ref = scales_ref.unsqueeze(-1) * y + if input_requires_grad or scales_requires_grad: + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.ScaledSReLU() + y_test = op(x_test, scales_test) + if input_requires_grad or scales_requires_grad: + y_test.backward(dy_test) + + # Check results + tols = dtype_tols(dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + assert_close(y_test, y_ref, **tols) + assert_close_grads(x_test, x_ref, **tols) + assert_close_grads(scales_test, scales_ref, **tols) + def test_interleaved_scaled_swiglu(self): """SwiGLU with post-scale and block interleaved input format""" self.test_scaled_swiglu( @@ -3570,7 +3624,9 @@ def test_layernorm_mlp( @pytest.mark.parametrize("glu_interleave_size", (None, 32)) @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) @pytest.mark.parametrize("hidden_size", (128, 256)) - @pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu")) + @pytest.mark.parametrize( + "activation", ("scaled_swiglu", "scaled_clamped_qgeglu", "scaled_srelu") + ) def test_grouped_mlp( self, *, @@ -3588,7 +3644,7 @@ def test_grouped_mlp( delay_wgrad_compute: bool, activation: str, ) -> None: - """GroupedLinear + ScaledSwiGLU / ScaledClampedQGeGLU + GroupedLinear""" + """GroupedLinear + scaled activation + GroupedLinear""" # Split sizes split_sizes = [split_alignment * (i) for i in range(group_size)] @@ -3601,6 +3657,15 @@ def test_grouped_mlp( # Skip invalid configurations with_quantization = quantization is not None + if activation == "scaled_swiglu": + scaled_act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + elif activation == "scaled_clamped_qgeglu": + scaled_act = te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + elif activation == "scaled_srelu": + scaled_act = te_ops.ScaledSReLU() + else: + raise ValueError(f"Unexpected grouped MLP activation ({activation})") + activation_is_glu = is_glu_activation(scaled_act) maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) if single_grouped_weight and quantization != "mxfp8": pytest.skip("single_grouped_weight is only supported for MXFP8 quantization") @@ -3608,9 +3673,14 @@ def test_grouped_mlp( pytest.skip("single_grouped_bias requires bias=True") if with_quantization and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + if not activation_is_glu and quantization != "mxfp8": + pytest.skip("Scaled unary grouped MLP fusion is only supported with MXFP8") + if not activation_is_glu and glu_interleave_size is not None: + pytest.skip("Unary activations do not use GLU interleaving") if quantization == "nvfp4" and activation == "scaled_clamped_qgeglu" and bias: # TODO: ksivaman: Need to debug numerics for this case. pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") + fc1_out_features = 2 * hidden_size if activation_is_glu else hidden_size # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -3641,7 +3711,7 @@ def test_grouped_mlp( fc2_bs_ref, fc2_bs_test = [], [] for _ in range(group_size): fc1_w_ref, fc1_w_test = make_reference_and_test_tensors( - (2 * hidden_size, hidden_size), + (fc1_out_features, hidden_size), min=-0.25, max=0.25, quantization=quantization, @@ -3660,7 +3730,7 @@ def test_grouped_mlp( fc2_b_ref, fc2_b_test = None, None if bias: fc1_b_ref, fc1_b_test = make_reference_and_test_tensors( - (2 * hidden_size,), + (fc1_out_features,), min=-0.5, max=0.5, test_dtype=dtype, @@ -3689,7 +3759,7 @@ def test_grouped_mlp( for group_idx in range(group_size): x = xs[group_idx] x = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx]) - if glu_interleave_size is not None: + if activation_is_glu and glu_interleave_size is not None: x = x.reshape( -1, 2 * hidden_size // (2 * glu_interleave_size), @@ -3698,15 +3768,20 @@ def test_grouped_mlp( ) x = x.transpose(1, 2) x = x.reshape(-1, 2 * hidden_size) - x1, x2 = x.chunk(2, dim=-1) if activation == "scaled_swiglu": + x1, x2 = x.chunk(2, dim=-1) x = torch.nn.functional.silu(x1) * x2 - else: + elif activation == "scaled_clamped_qgeglu": + x1, x2 = x.chunk(2, dim=-1) lim = torch.tensor(7.0, device=x1.device, dtype=x1.dtype) geglu_alpha = 1.702 x1c = torch.minimum(x1, lim) x2c = torch.clamp(x2, -lim, lim) x = (x2c + 1) * (x1c * torch.sigmoid(geglu_alpha * x1c)) + elif activation == "scaled_srelu": + x = torch.nn.functional.relu(x).square() + else: + raise ValueError(f"Unexpected grouped MLP activation ({activation})") x = x * probs[group_idx].unsqueeze(-1) x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx]) if bias: @@ -3717,16 +3792,11 @@ def test_grouped_mlp( # Construct operations recipe = make_recipe(quantization) - scaled_act = ( - te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) - if activation == "scaled_swiglu" - else te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) - ) with te.quantized_model_init(enabled=with_quantization, recipe=recipe): fc1 = te_ops.GroupedLinear( group_size, hidden_size, - 2 * hidden_size, + fc1_out_features, bias=bias, device=device, dtype=dtype, @@ -3810,22 +3880,31 @@ def test_grouped_mlp( if ( quantization == "mxfp8" and dtype in (torch.bfloat16, torch.float16) - and glu_interleave_size == 32 + and ( + (not activation_is_glu and glu_interleave_size is None) + or (activation_is_glu and glu_interleave_size == 32) + ) and _cudnn_frontend_version_supported() ): - if te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): + if activation_is_glu: + forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8 + backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8 + else: + forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMUnary_MXFP8 + backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8 + if forward_cls.is_supported(): forward_ops = module._module_groups[0]._forward_ops assert len(forward_ops) == 1 assert isinstance( forward_ops[0][0], - te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + forward_cls, ) - if te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported(): + if backward_cls is not None and backward_cls.is_supported(): backward_ops = module._module_groups[0]._backward_ops assert len(backward_ops) == 1 assert isinstance( backward_ops[0][0], - te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, + backward_cls, ) # Loose tols for sanity checking @@ -3910,9 +3989,9 @@ def test_grouped_mlp_single_weight_numerics( ) -> None: """single_grouped_weight=True/False should match exactly for fused MXFP8 grouped MLP.""" - if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): + if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported(): pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") - if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported(): + if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8.is_supported(): pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") split_sizes = [split_alignment * (i + 1) for i in range(group_size)] @@ -4014,12 +4093,12 @@ def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]: assert len(forward_ops) == 1 assert isinstance( forward_ops[0][0], - te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8, ) assert len(backward_ops) == 1 assert isinstance( backward_ops[0][0], - te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, + te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8, ) if single_grouped_weight: @@ -4132,9 +4211,9 @@ def test_grouped_mlp_overwrite_main_grad( that read ``.grad`` don't see stale bytes from the cached dummy). """ - if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): + if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported(): pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") - if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported(): + if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8.is_supported(): pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") recipe = make_recipe("mxfp8") @@ -4266,7 +4345,7 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8( ) -> None: """Grouped MLP forward+backward should be CUDA graph capturable (MXFP8).""" - if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): + if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported(): pytest.skip("MXFP8 fused grouped MLP is not supported on this system") if dtype not in (torch.bfloat16, torch.float16): pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16") @@ -4408,12 +4487,12 @@ def train_step( assert len(forward_ops) == 1 assert isinstance( forward_ops[0][0], - te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8, ) assert len(backward_ops) == 1 assert isinstance( backward_ops[0][0], - te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, + te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8, ) fresh_x = torch.randn_like(static_x) diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 9325d87ae7..f40beb4f9b 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -33,6 +33,11 @@ def _cudnn_frontend_version_supported() -> bool: return False +def _nvidia_cudnn_frontend_supports_wgrad() -> bool: + """Check cuDNN FE min version for grouped GEMM wgrad kernel.""" + return _cudnn_frontend_version_supported() + + def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool: """Check if tensor is a quantized tensor""" return isinstance(tensor, QuantizedTensorStorage) @@ -182,8 +187,21 @@ def get_dummy_wgrads_for_params( return out -def validate_grouped_mlp_dims(fc1, glu_op, fc2) -> None: - """Validate FC1 / scaled GLU / FC2 dimensions for fused grouped MLP.""" +def is_glu_activation(activation_op) -> bool: + """Whether an activation consumes a GLU-style doubled input.""" + from .basic import ( # pylint: disable=import-outside-toplevel + ScaledClampedQGeGLU, + ScaledSwiGLU, + ) + + return isinstance(activation_op, (ScaledSwiGLU, ScaledClampedQGeGLU)) + + +def validate_grouped_mlp_dims(fc1, activation_op, fc2) -> None: + """Validate FC1 / activation / FC2 dimensions for fused grouped MLP.""" + from .basic import ( # pylint: disable=import-outside-toplevel + ScaledSReLU, + ) if fc1.in_features % 64 != 0 or fc1.out_features % 64 != 0: raise ValueError( @@ -195,17 +213,24 @@ def validate_grouped_mlp_dims(fc1, glu_op, fc2) -> None: f"Unsupported dims for FC2 (num_groups={fc2.num_groups}, " f"in_features={fc2.in_features}, out_features={fc2.out_features})." ) - if fc1.out_features != 2 * fc2.in_features or fc1.num_groups != fc2.num_groups: + if is_glu_activation(activation_op): + expected_fc1_out_features = 2 * fc2.in_features + elif isinstance(activation_op, ScaledSReLU): + expected_fc1_out_features = fc2.in_features + else: + raise TypeError(f"Unsupported grouped MLP activation ({activation_op.__class__.__name__}).") + + if fc1.out_features != expected_fc1_out_features or fc1.num_groups != fc2.num_groups: raise ValueError( f"FC1 (num_groups={fc1.num_groups}, in_features={fc1.in_features}, " f"out_features={fc1.out_features}) " f"and FC2 (num_groups={fc2.num_groups}, in_features={fc2.in_features}, " f"out_features={fc2.out_features}) do not match." ) - if glu_op.glu_interleave_size != 32: + if is_glu_activation(activation_op) and activation_op.glu_interleave_size != 32: raise ValueError( "Fused kernel requires 32-wide GLU interleaving, " - f"but got glu_interleave_size={glu_op.glu_interleave_size}." + f"but got glu_interleave_size={activation_op.glu_interleave_size}." ) @@ -214,8 +239,9 @@ def fuse_grouped_mlp_ops( *, recipe, fused_op_cls, + activation_op_types=None, ): - """Sliding-window fusion for GroupedLinear + scaled GLU + GroupedLinear. + """Sliding-window fusion for GroupedLinear + activation + GroupedLinear. Parameters ---------- @@ -225,9 +251,7 @@ def fuse_grouped_mlp_ops( Quantization recipe. fused_op_cls : type Fused operation class with ``is_supported()`` classmethod and - constructor accepting ``fc1``, ``glu_op``, ``fc2`` keyword args. The - ``glu_op`` must be :class:`~transformer_engine.pytorch.ops.basic.swiglu.ScaledSwiGLU` - or :class:`~transformer_engine.pytorch.ops.basic.swiglu.ScaledClampedQGeGLU`. + constructor accepting ``fc1``, ``activation``, and ``fc2`` keyword args. Returns ------- @@ -244,6 +268,8 @@ def fuse_grouped_mlp_ops( return ops if recipe is None or not recipe.mxfp8(): return ops + if activation_op_types is None: + activation_op_types = (ScaledSwiGLU, ScaledClampedQGeGLU) out = [] window, ops = ops[:3], ops[3:] @@ -252,7 +278,7 @@ def fuse_grouped_mlp_ops( matches_pattern = True if not ( isinstance(window[0], GroupedLinear) - and isinstance(window[1], (ScaledSwiGLU, ScaledClampedQGeGLU)) + and isinstance(window[1], activation_op_types) and isinstance(window[2], GroupedLinear) ): matches_pattern = False @@ -260,22 +286,16 @@ def fuse_grouped_mlp_ops( abs(window[1]._clamped.alpha - 1.702) > 0.001 ): matches_pattern = False - elif window[0].num_groups != window[2].num_groups: - matches_pattern = False - elif ( - window[0].in_features % 64 != 0 - or window[0].out_features % 64 != 0 - or window[2].in_features % 64 != 0 - or window[2].out_features % 64 != 0 - ): - matches_pattern = False - elif window[1].glu_interleave_size != 32: - matches_pattern = False + else: + try: + validate_grouped_mlp_dims(window[0], window[1], window[2]) + except (TypeError, ValueError): + matches_pattern = False if matches_pattern: op = fused_op_cls( fc1=window[0], - swiglu=window[1], + activation=window[1], fc2=window[2], ) window = [op] diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 45c938ede8..6def36ffc7 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -13,6 +13,7 @@ ReLU, ReGLU, SReLU, + ScaledSReLU, SReGLU, SiLU, ) diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 13cb519c19..6c6cad3824 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -6,7 +6,8 @@ from __future__ import annotations import abc -from typing import Optional +from collections.abc import Iterable +from typing import Any, Optional import torch @@ -26,6 +27,7 @@ "ReLU", "ReGLU", "SReLU", + "ScaledSReLU", "SReGLU", "SiLU", ] @@ -345,6 +347,99 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dsrelu(*args, **kwargs) +class ScaledSReLU(BasicOperation): + r"""Squared ReLU with per-row post-scaling. + + If the SReLU output has shape ``(d_1, ..., d_n)``, it is multiplied + with an extra input tensor of shape ``(d_1, ..., d_{n-1})``. + """ + + num_extra_inputs: int = 1 + + def op_forward(self, *args, **kwargs) -> None: + raise RuntimeError( + f"{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs) -> None: + raise RuntimeError( + f"{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], # pylint: disable=unused-argument + next_op_input_quantizer: Optional[Quantizer], # pylint: disable=unused-argument + basic_op_kwargs: list[dict[str, Any]], # pylint: disable=unused-argument + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + extra_input = basic_op_extra_inputs[0][0] + + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + elif isinstance(input_, torch.Tensor): + dtype = input_.dtype + else: + dtype = extra_input.dtype + + x = maybe_dequantize(input_.contiguous(), dtype) + scales = maybe_dequantize(extra_input, dtype) + y = tex.srelu(x, None) * scales.unsqueeze(-1) + + ctx = basic_op_ctxs[0] + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x) + ctx.input_requires_grad = True + ctx.extra_input_requires_grad = extra_input.requires_grad + ctx.dtype = dtype + ctx.save_for_backward(x, scales) + + return y, [()] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + del basic_op_grad_extra_outputs + + ctx = basic_op_ctxs[0] + x, scales = ctx.saved_tensors + x = maybe_dequantize(x.contiguous(), ctx.dtype) + scales = maybe_dequantize(scales, ctx.dtype) + grad_output = maybe_dequantize(grad_output.contiguous(), ctx.dtype) + + grad_input = None + if ctx.input_requires_grad: + grad_srelu_out = grad_output * scales.unsqueeze(-1) + grad_input = tex.dsrelu(grad_srelu_out, x, None) + + grad_extra_input = None + if ctx.extra_input_requires_grad: + srelu_out = tex.srelu(x, None) + grad_extra_input = torch.linalg.vecdot(srelu_out, grad_output) + + clear_tensor_data(ctx.saved_tensors[0]) + + return grad_input, [()], [(grad_extra_input,)] + + class SReGLU(_ActivationOperation): r"""Squared Rectified Gated Linear Unit diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 19a090f121..b29e35814d 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -32,8 +32,10 @@ # Import experimental fusions # Note: Registration logic is non-trivial, so submodule handles it internally. from .forward_grouped_mlp import ( # pylint: disable=wrong-import-position - ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + ForwardGroupedMLP_CuTeGEMMGLU_MXFP8, + ForwardGroupedMLP_CuTeGEMMUnary_MXFP8, ) from .backward_grouped_mlp import ( # pylint: disable=wrong-import-position - BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, + BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8, + BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8, ) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index a11d0505c1..6700da5136 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -7,6 +7,7 @@ from __future__ import annotations from collections.abc import Callable import functools +import inspect import os from typing import Optional @@ -18,7 +19,7 @@ from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...utils import clear_tensor_data, get_cached_ones_tensor, get_device_compute_capability from ...constants import MXFP8_BLOCK_SCALING_SIZE -from ..basic import GroupedLinear, ScaledClampedQGeGLU, ScaledSwiGLU +from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU from ..fuser import register_backward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext from .._common import ( @@ -27,6 +28,7 @@ get_accumulate_flag_in_param, get_dummy_wgrads_for_params, get_main_grad_from_param, + is_glu_activation, maybe_dequantize, view_main_grad_as_grouped_buffer, validate_grouped_mlp_dims, @@ -132,6 +134,21 @@ def _cudnn_compute_wgrad( ) +@functools.lru_cache(maxsize=None) +def _dsrelu_wrapper_has_reuse_arg() -> bool: + """True if cuDNN FE SM100 dSReLU wrapper accepts ``use_dsrelu_reuse``.""" + try: + import cudnn # pylint: disable=import-outside-toplevel + except ImportError: + return False + try: + wrapper = getattr(cudnn, "grouped_gemm_dsrelu_wrapper_sm100") + params = inspect.signature(wrapper).parameters + except (AttributeError, TypeError, ValueError): + return False + return "use_dsrelu_reuse" in params + + def _compute_grad_params( fc_op, ctx, @@ -248,20 +265,17 @@ def _compute_grad_params( return w_list + bias_list -class BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8(FusedOperation): - """Fused op for MXFP8 GroupedLinear + ScaledSwiGLU or ScaledClampedQGeGLU + GroupedLinear +class _BackwardGroupedMLP_CuTeGEMMDBase_MXFP8(FusedOperation): + """Base fused backward op for MXFP8 GroupedLinear + activation + GroupedLinear. Uses experimental CuTe DSL kernel from cuDNN front-end. """ @classmethod - @functools.lru_cache(maxsize=None) - def grouped_gemm_dglu_kernel(cls) -> Callable: - """Fused kernel for grouped GEMM, GLU activation backward, and scale grad.""" - from cudnn import grouped_gemm_dglu_wrapper_sm100 # pylint: disable=no-name-in-module - - return grouped_gemm_dglu_wrapper_sm100 + def grouped_gemm_dactivation_kernel(cls) -> Callable: + """Fused kernel for grouped GEMM, activation backward, and scale grad.""" + raise NotImplementedError @classmethod @functools.lru_cache(maxsize=None) @@ -296,7 +310,7 @@ def is_supported(cls) -> bool: if not _cudnn_frontend_version_supported(): return False try: - cls.grouped_gemm_dglu_kernel() + cls.grouped_gemm_dactivation_kernel() cls.grouped_gemm_quant_kernel() except ImportError: return False @@ -306,19 +320,26 @@ def __init__( self, *, fc1: GroupedLinear, - swiglu: ScaledSwiGLU | ScaledClampedQGeGLU, + activation: Optional[FusibleOperation], fc2: GroupedLinear, ) -> None: - super().__init__((fc1, swiglu, fc2)) + if activation is None: + raise TypeError("Expected a grouped MLP activation op.") + super().__init__((fc1, activation, fc2)) if not self.is_supported(): - self.grouped_gemm_dglu_kernel() # Try triggering import error + self.grouped_gemm_dactivation_kernel() # Try triggering import error raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.") - validate_grouped_mlp_dims(fc1, swiglu, fc2) - # The cuDNN dgeglu implementation corresponds to ScaledClampedQGeGLU. - # The act_func string should be fixed on the cuDNN FE side. - self._cudnn_dact_func: str = ( - "dgeglu" if isinstance(swiglu, ScaledClampedQGeGLU) else "dswiglu" - ) + validate_grouped_mlp_dims(fc1, activation, fc2) + if not is_glu_activation(activation): + # grouped_gemm_dsrelu_wrapper_sm100 is dSReLU-specific and does not + # take the GLU ``act_func`` selector. + self._cudnn_dact_func: Optional[str] = None + else: + # The cuDNN dgeglu implementation corresponds to ScaledClampedQGeGLU. + # The act_func string should be fixed on the cuDNN FE side. + self._cudnn_dact_func = ( + "dgeglu" if isinstance(activation, ScaledClampedQGeGLU) else "dswiglu" + ) def fuser_backward( self, @@ -333,7 +354,7 @@ def fuser_backward( # Get basic operations fc1_op, _, fc2_op = self.basic_ops - fc1_ctx, swiglu_ctx, fc2_ctx = basic_op_ctxs + fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs # Tensor properties fc1_weight_shape = (fc1_op.out_features, fc1_op.in_features) @@ -358,8 +379,11 @@ def fuser_backward( saved_tensors[num_groups:], ) - # Saved tensors from scaled SwiGLU forward - swiglu_in, scales = swiglu_ctx.saved_tensors + # Saved tensors from activation forward + activation_in, scales = activation_ctx.saved_tensors + recompute_fc2_x_from_dsrelu = bool( + getattr(fc2_ctx, "recompute_input_from_dsrelu", False) + ) and bool(fc2_ctx.weight_requires_grad) # Saved tensors from FC2 forward. # Layout: [split_sizes, base_split_offsets, split_points, @@ -446,20 +470,19 @@ def fuser_backward( # Kernel scaling factors alpha_tensor = get_cached_ones_tensor(num_groups, dtype, device) - norm_const_tensor = get_cached_ones_tensor(1, dtype, device) + norm_const_tensor = get_cached_ones_tensor(1, torch.float32, device) current_stream = torch.cuda.current_stream().cuda_stream scales_f32 = scales.detach().to(dtype=torch.float32) scales_tensor = scales_f32.reshape(-1, 1, 1) dscales_tensor = torch.zeros_like(scales_tensor) - fc2_dglu_kwargs = { + fc2_dactivation_kwargs = { "a_tensor": fc2_dy_data, - "c_tensor": swiglu_in.unsqueeze(0).permute(1, 2, 0), + "c_tensor": activation_in.unsqueeze(0).permute(1, 2, 0), "sfa_tensor": fc2_dy_scales, "padded_offsets": split_points, "alpha_tensor": alpha_tensor, - "beta_tensor": alpha_tensor, "prob_tensor": scales_tensor, "dprob_tensor": dscales_tensor, "generate_dbias": fc1_op.has_bias, @@ -469,9 +492,15 @@ def fuser_backward( "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, "current_stream": current_stream, "discrete_col_sfd": True, - "act_func": self._cudnn_dact_func, "use_dynamic_sched": True, } + if self._cudnn_dact_func is not None: + fc2_dactivation_kwargs["beta_tensor"] = alpha_tensor + fc2_dactivation_kwargs["act_func"] = self._cudnn_dact_func + elif _dsrelu_wrapper_has_reuse_arg(): + fc2_dactivation_kwargs["use_dsrelu_reuse"] = ( + os.environ.get("NVTE_CUTEDSL_FUSED_GROUPED_MLP_DSRELU_REUSE", "0") == "1" + ) if fc2_op.single_grouped_weight: # Clone and swizzle scales for GEMM @@ -495,8 +524,8 @@ def fuser_backward( ) fc2_w_scales = fc2_w_scales.permute(3, 4, 1, 5, 2, 0) - fc2_dglu_kwargs["b_tensor"] = fc2_w_data - fc2_dglu_kwargs["sfb_tensor"] = fc2_w_scales + fc2_dactivation_kwargs["b_tensor"] = fc2_w_data + fc2_dactivation_kwargs["sfb_tensor"] = fc2_w_scales else: fc2_b_ptrs, fc2_sfb_ptrs, _fc2_sw = tex.get_device_pointer_for_data_and_scales( [w._columnwise_data for w in grouped_fc2_weight], @@ -505,13 +534,13 @@ def fuser_backward( rowwise=False, data_dtype=grouped_fc2_weight[0]._fp8_dtype, ) - fc2_dglu_kwargs["b_ptrs"] = fc2_b_ptrs - fc2_dglu_kwargs["sfb_ptrs"] = fc2_sfb_ptrs - fc2_dglu_kwargs["n"] = fc2_weight_shape[1] - fc2_dglu_kwargs["b_dtype"] = torch.float8_e4m3fn - fc2_dglu_kwargs["b_major"] = "n" + fc2_dactivation_kwargs["b_ptrs"] = fc2_b_ptrs + fc2_dactivation_kwargs["sfb_ptrs"] = fc2_sfb_ptrs + fc2_dactivation_kwargs["n"] = fc2_weight_shape[1] + fc2_dactivation_kwargs["b_dtype"] = torch.float8_e4m3fn + fc2_dactivation_kwargs["b_major"] = "n" - fc2_dgrad_kernel_out = self.grouped_gemm_dglu_kernel()(**fc2_dglu_kwargs) + fc2_dgrad_kernel_out = self.grouped_gemm_dactivation_kernel()(**fc2_dactivation_kwargs) fc1_dy_row_data = fc2_dgrad_kernel_out["d_row_tensor"] fc1_dy_row_data = fc1_dy_row_data.view(out_shape[0], fc1_weight_shape[0]) @@ -523,6 +552,37 @@ def fuser_backward( fc1_dy_col_scale = fc2_dgrad_kernel_out["sfd_col_tensor"].permute(5, 2, 4, 0, 1, 3).view(-1) grad_scales = fc2_dgrad_kernel_out["dprob_tensor"].view(-1) + if recompute_fc2_x_from_dsrelu: + d_srelu_tensor = fc2_dgrad_kernel_out.get("d_srelu_tensor") + if d_srelu_tensor is None: + raise RuntimeError( + "SReLU recompute is enabled, but the DSReLU kernel did not return " + "the recomputed FC2 input tensor." + ) + + sfd_col_d_srelu_tensor = fc2_dgrad_kernel_out.get("sfd_col_d_srelu_tensor") + if sfd_col_d_srelu_tensor is None: + raise RuntimeError( + "SReLU recompute is enabled, but the DSReLU kernel did not return " + "the recomputed FC2 input column scale tensor." + ) + + fc2_x_col_data = d_srelu_tensor.view(out_shape[0], fc2_weight_shape[1]) + fc2_x_col_scale = sfd_col_d_srelu_tensor.permute(5, 2, 4, 0, 1, 3) + grouped_fc2_x = GroupedTensor( + shape=(out_shape[0], fc2_weight_shape[1]), + dtype=dtype, + num_tensors=num_groups, + quantizer=fc2_ctx.input_quantizers[0], + data=None, + columnwise_data=fc2_x_col_data.reshape(-1), + scale_inv=None, + columnwise_scale_inv=fc2_x_col_scale.reshape(-1), + first_dims=split_sizes, + tensor_offsets=base_split_offsets * fc2_weight_shape[1], + with_gemm_swizzled_scales=True, + ) + fc2_bias_grads: Optional[list[Optional[torch.Tensor]]] = None fc2_bias_grad_packed: Optional[torch.Tensor] = None if scale_bias: @@ -547,7 +607,8 @@ def fuser_backward( else: fc2_bias_grads = [fc2_dbias_packed[idx] for idx in range(num_groups)] - grad_scales = grad_scales.to(dtype=dtype) + if grad_scales is not None: + grad_scales = grad_scales.to(dtype=dtype) fc1_bias_grads: Optional[list[Optional[torch.Tensor]]] = None fc1_bias_grad_packed: Optional[torch.Tensor] = None @@ -618,7 +679,7 @@ def fuser_backward( "a_tensor": fc1_dgrad_a_data, "sfa_tensor": fc1_dgrad_a_scales, "padded_offsets": split_points, - "alpha_tensor": alpha_tensor.float(), + "alpha_tensor": alpha_tensor, "norm_const_tensor": None, "prob_tensor": torch.ones((out_shape[0], 1, 1), dtype=torch.float32, device=device), "acc_dtype": torch.float32, @@ -703,13 +764,38 @@ def fuser_backward( ) fc2_grad_extra = (None, None) if fc2_op._scale_bias else (None,) + activation_grad_extra = (grad_scales,) if grad_scales is not None else () return ( grad_input, [fc1_grad_params, (), fc2_grad_params], - [(None,), (grad_scales,), fc2_grad_extra], + [(None,), activation_grad_extra, fc2_grad_extra], ) +class BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8(_BackwardGroupedMLP_CuTeGEMMDBase_MXFP8): + """Fused backward op for GroupedLinear + scaled GLU + GroupedLinear.""" + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_dactivation_kernel(cls) -> Callable: + """Fused kernel for grouped GEMM, GLU activation backward, and scale grad.""" + from cudnn import grouped_gemm_dglu_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_dglu_wrapper_sm100 + + +class BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8(_BackwardGroupedMLP_CuTeGEMMDBase_MXFP8): + """Fused backward op for GroupedLinear + scaled unary activation + GroupedLinear.""" + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_dactivation_kernel(cls) -> Callable: + """Fused kernel for grouped GEMM and dSReLU activation backward.""" + from cudnn import grouped_gemm_dsrelu_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_dsrelu_wrapper_sm100 + + def fuse_backward_ops( ops: list[FusibleOperation], *, @@ -735,10 +821,28 @@ def fuse_backward_ops( return fuse_grouped_mlp_ops( ops, recipe=recipe, - fused_op_cls=BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, + fused_op_cls=BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8, + ) + + +def fuse_backward_srelu_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument +) -> list[FusibleOperation]: + """Apply GroupedLinear + ScaledSReLU + GroupedLinear fusion for backward pass.""" + + return fuse_grouped_mlp_ops( + ops, + recipe=recipe, + fused_op_cls=BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8, + activation_op_types=(ScaledSReLU,), ) # Register fusion if available -if BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported(): +if BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8.is_supported(): register_backward_fusion(fuse_backward_ops, prepend=True) +if BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8.is_supported(): + register_backward_fusion(fuse_backward_srelu_ops, prepend=True) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 91db2ff9b7..b003ea4b70 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -19,12 +19,14 @@ from ...tensor.grouped_tensor import GroupedTensor from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...constants import MXFP8_BLOCK_SCALING_SIZE -from ..basic import GroupedLinear, ScaledClampedQGeGLU, ScaledSwiGLU +from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU from ..fuser import register_forward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext from .._common import ( _cudnn_frontend_version_supported, + _nvidia_cudnn_frontend_supports_wgrad, fuse_grouped_mlp_ops, + is_glu_activation, is_quantized_tensor, maybe_dequantize, validate_grouped_mlp_dims, @@ -45,20 +47,38 @@ def _pack_grouped_linear_bias_for_cudnn(linear_op: GroupedLinear) -> Optional[to return torch.stack(rows, dim=0).transpose(0, 1) -class ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8(FusedOperation): - """Fused op for MXFP8 GroupedLinear + scaled GLU + GroupedLinear +@functools.lru_cache(maxsize=1) +def _grouped_gemm_dsrelu_backward_supported() -> bool: + """Whether the cuDNN FE grouped GEMM dSReLU backward wrapper is available.""" + if int(os.environ.get("NVTE_CUTEDSL_FUSED_GROUPED_MLP", "0")) <= 0: + return False + if get_device_compute_capability()[0] != 10: + return False + try: + from cudnn import ( + grouped_gemm_dsrelu_wrapper_sm100, + ) # pylint: disable=import-outside-toplevel + except ImportError: + return False + return grouped_gemm_dsrelu_wrapper_sm100 is not None + + +def _srelu_fc2_input_recompute_enabled() -> bool: + """Whether SReLU backward should regenerate the FC2 input instead of saving it.""" + return int(os.environ.get("NVTE_CUTEDSL_FUSED_GROUPED_MLP_SRELU_RECOMPUTE", "1")) > 0 + + +class _ForwardGroupedMLP_CuTeGEMMBase_MXFP8(FusedOperation): + """Base fused op for MXFP8 GroupedLinear + activation + GroupedLinear. Uses experimental CuTe DSL kernel from cuDNN front-end. """ @classmethod - @functools.lru_cache(maxsize=None) - def grouped_gemm_glu_kernel(cls) -> Callable: - """Fused kernel for grouped GEMM, GLU activation, and post-multiplication.""" - from cudnn import grouped_gemm_glu_wrapper_sm100 # pylint: disable=no-name-in-module - - return grouped_gemm_glu_wrapper_sm100 + def grouped_gemm_activation_kernel(cls) -> Callable: + """Fused kernel for grouped GEMM, activation, and post-multiplication.""" + raise NotImplementedError @classmethod @functools.lru_cache(maxsize=None) @@ -79,7 +99,7 @@ def is_supported(cls) -> bool: if not _cudnn_frontend_version_supported(): return False try: - cls.grouped_gemm_glu_kernel() + cls.grouped_gemm_activation_kernel() cls.grouped_gemm_quant_kernel() except ImportError: return False @@ -89,17 +109,26 @@ def __init__( self, *, fc1: GroupedLinear, - swiglu: ScaledSwiGLU | ScaledClampedQGeGLU, + activation: Optional[FusibleOperation], fc2: GroupedLinear, ) -> None: - super().__init__((fc1, swiglu, fc2)) + if activation is None: + raise TypeError("Expected a grouped MLP activation op.") + super().__init__((fc1, activation, fc2)) if not self.is_supported(): - self.grouped_gemm_glu_kernel() # Try triggering import error + self.grouped_gemm_activation_kernel() # Try triggering import error raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.") - validate_grouped_mlp_dims(fc1, swiglu, fc2) - # The cuDNN geglu implementation corresponds to ScaledClampedQGeGLU. - # The act_func string should be fixed on the cuDNN FE side. - self._cudnn_act_func: str = "geglu" if isinstance(swiglu, ScaledClampedQGeGLU) else "swiglu" + validate_grouped_mlp_dims(fc1, activation, fc2) + if not is_glu_activation(activation): + # grouped_gemm_srelu_wrapper_sm100 is SReLU-specific and does not + # take the GLU ``act_func`` selector. + self._cudnn_act_func: Optional[str] = None + else: + # The cuDNN geglu implementation corresponds to ScaledClampedQGeGLU. + # The act_func string should be fixed on the cuDNN FE side. + self._cudnn_act_func = ( + "geglu" if isinstance(activation, ScaledClampedQGeGLU) else "swiglu" + ) def fuser_forward( self, @@ -113,7 +142,7 @@ def fuser_forward( ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: # Get basic operations fc1_op, _, fc2_op = self.basic_ops - fc1_ctx, swiglu_ctx, fc2_ctx = basic_op_ctxs + fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs # Tensor properties fc1_weight_shape = (fc1_op.out_features, fc1_op.in_features) @@ -164,7 +193,7 @@ def fuser_forward( split_points = base_split_offsets[1:].to(dtype=torch.int) fc2_x_tensor_offsets = base_split_offsets * fc2_weight_shape[1] - # Extract post-scales from extra input + # Extract per-row activation probabilities from the middle op. scales = basic_op_extra_inputs[1][0] # Prepare FC1 grouped weight tensor for fused kernels. @@ -281,20 +310,24 @@ def fuser_forward( fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0) alpha_tensor = get_cached_ones_tensor(num_groups, dtype, device) - norm_const_tensor = get_cached_ones_tensor(1, dtype, device) + norm_const_tensor = get_cached_ones_tensor(1, torch.float32, device) current_stream = torch.cuda.current_stream().cuda_stream fc1_bias_packed = _pack_grouped_linear_bias_for_cudnn(fc1_op) fc2_bias_packed = _pack_grouped_linear_bias_for_cudnn(fc2_op) - fc1_glu_kwargs = { + fc1_activation_kwargs = { "a_tensor": fc1_x_data, "sfa_tensor": fc1_x_scales, "padded_offsets": split_points, "alpha_tensor": alpha_tensor, "bias_tensor": fc1_bias_packed, "norm_const_tensor": norm_const_tensor, - "prob_tensor": scales.detach().to(dtype=dtype).reshape(-1, 1, 1), + "prob_tensor": ( + scales.detach().to(dtype=dtype).reshape(-1, 1, 1) + if scales is not None + else torch.ones((in_shape[0], 1, 1), dtype=torch.float32, device=device) + ), "acc_dtype": torch.float32, "c_dtype": torch.bfloat16, "d_dtype": torch.float8_e4m3fn, @@ -302,9 +335,10 @@ def fuser_forward( "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, "current_stream": current_stream, "discrete_col_sfd": True, - "act_func": self._cudnn_act_func, "use_dynamic_sched": True, } + if self._cudnn_act_func is not None: + fc1_activation_kwargs["act_func"] = self._cudnn_act_func if fc1_op.single_grouped_weight: # Clone and swizzle scales for GEMM. @@ -329,8 +363,8 @@ def fuser_forward( ) fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0) - fc1_glu_kwargs["b_tensor"] = fc1_w_data - fc1_glu_kwargs["sfb_tensor"] = fc1_w_scales + fc1_activation_kwargs["b_tensor"] = fc1_w_data + fc1_activation_kwargs["sfb_tensor"] = fc1_w_scales else: # Discrete-weight kernel: per-expert data/scale pointers fc1_b_ptrs, fc1_sfb_ptrs, _fc1_sw = tex.get_device_pointer_for_data_and_scales( @@ -340,13 +374,13 @@ def fuser_forward( rowwise=True, data_dtype=grouped_fc1_weight[0]._fp8_dtype, ) - fc1_glu_kwargs["b_ptrs"] = fc1_b_ptrs - fc1_glu_kwargs["sfb_ptrs"] = fc1_sfb_ptrs - fc1_glu_kwargs["n"] = fc1_weight_shape[0] - fc1_glu_kwargs["b_dtype"] = torch.float8_e4m3fn - fc1_glu_kwargs["b_major"] = "k" + fc1_activation_kwargs["b_ptrs"] = fc1_b_ptrs + fc1_activation_kwargs["sfb_ptrs"] = fc1_sfb_ptrs + fc1_activation_kwargs["n"] = fc1_weight_shape[0] + fc1_activation_kwargs["b_dtype"] = torch.float8_e4m3fn + fc1_activation_kwargs["b_major"] = "k" - fc1_kernel_out = self.grouped_gemm_glu_kernel()(**fc1_glu_kwargs) + fc1_kernel_out = self.grouped_gemm_activation_kernel()(**fc1_activation_kwargs) # Unpack kernel outputs # Note: Fused kernel outputs tensors with non-contiguous @@ -357,8 +391,8 @@ def fuser_forward( # Column-wise data logical shape: (sum(m_splits), k, 1) # Column-wise scale logical shape: (32 (block col), 4 (block col), # k/128, 4 (block row), sum(m_splits)/128, 1) - swiglu_in = fc1_kernel_out["c_tensor"] - swiglu_in = swiglu_in.view(in_shape[0], fc1_weight_shape[0]) + activation_in = fc1_kernel_out["c_tensor"] + activation_in = activation_in.view(in_shape[0], fc1_weight_shape[0]) fc2_in_row_data = fc1_kernel_out["d_tensor"] fc2_in_row_data = fc2_in_row_data.view(in_shape[0], fc2_weight_shape[1]) fc2_in_row_scale = fc1_kernel_out["sfd_row_tensor"] @@ -397,7 +431,7 @@ def fuser_forward( "a_tensor": fc1_kernel_out["d_tensor"], "sfa_tensor": fc1_kernel_out["sfd_row_tensor"], "padded_offsets": split_points, - "alpha_tensor": alpha_tensor.float(), + "alpha_tensor": alpha_tensor, "bias_tensor": fc2_bias_packed, "norm_const_tensor": None, "prob_tensor": fc2_scales_tensor, @@ -450,10 +484,20 @@ def fuser_forward( # Save state for backward pass if requires_grad: - mark_grouped_tensor(grouped_fc1_x, swiglu_in, scales, grouped_fc2_x) + mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x) + activation_op = self.basic_ops[1] + activation_is_srelu = isinstance(activation_op, ScaledSReLU) + recompute_srelu_fc2_x = ( + activation_is_srelu + and weight_requires_grad + and _srelu_fc2_input_recompute_enabled() + and _grouped_gemm_dsrelu_backward_supported() + and _nvidia_cudnn_frontend_supports_wgrad() + ) + saved_grouped_fc2_x = None if recompute_srelu_fc2_x else grouped_fc2_x # Save the input ``GroupedTensor``s themselves for the activations. - for grouped_fc_x in (grouped_fc1_x, grouped_fc2_x): + for grouped_fc_x in (grouped_fc1_x, saved_grouped_fc2_x): if grouped_fc_x is not None: grouped_fc_x.rowwise_data = None grouped_fc_x.scale_inv = None @@ -481,11 +525,11 @@ def fuser_forward( fc1_ctx.input_requires_grad = input_requires_grad fc1_ctx.weight_requires_grad = weight_requires_grad - # Scaled SwiGLU - swiglu_ctx.save_for_backward(swiglu_in, scales) - swiglu_ctx.input_requires_grad = True - swiglu_ctx.extra_input_requires_grad = True - swiglu_ctx.dtype = dtype + # Activation + activation_ctx.save_for_backward(activation_in, scales) + activation_ctx.extra_input_requires_grad = True + activation_ctx.input_requires_grad = True + activation_ctx.dtype = dtype # FC2 saved-tensor layout. Matches the unfused # ``GroupedLinear._fuser_forward_grouped_tensor`` layout so the @@ -504,7 +548,7 @@ def fuser_forward( ] if fc2_op._scale_bias: fc2_saved.append(fc2_scales) - fc2_saved.append(grouped_fc2_x) + fc2_saved.append(saved_grouped_fc2_x) fc2_saved.extend(fc2_weight_tensors) fc2_ctx.save_for_backward(*fc2_saved) fc2_ctx.use_grouped_tensor_path = True @@ -516,10 +560,35 @@ def fuser_forward( fc2_ctx.dtype = dtype fc2_ctx.input_requires_grad = input_requires_grad fc2_ctx.weight_requires_grad = weight_requires_grad + fc2_ctx.recompute_input_from_dsrelu = recompute_srelu_fc2_x return fc2_out, [(), (), ()] +class ForwardGroupedMLP_CuTeGEMMGLU_MXFP8(_ForwardGroupedMLP_CuTeGEMMBase_MXFP8): + """Fused op for MXFP8 GroupedLinear + scaled GLU + GroupedLinear.""" + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_activation_kernel(cls) -> Callable: + """Fused kernel for grouped GEMM, GLU activation, and post-multiplication.""" + from cudnn import grouped_gemm_glu_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_glu_wrapper_sm100 + + +class ForwardGroupedMLP_CuTeGEMMUnary_MXFP8(_ForwardGroupedMLP_CuTeGEMMBase_MXFP8): + """Fused op for MXFP8 GroupedLinear + scaled unary activation + GroupedLinear.""" + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_activation_kernel(cls) -> Callable: + """Fused kernel for grouped GEMM, SReLU activation, and post-multiplication.""" + from cudnn import grouped_gemm_srelu_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_srelu_wrapper_sm100 + + def fuse_forward_ops( ops: list[FusibleOperation], *, @@ -545,10 +614,28 @@ def fuse_forward_ops( return fuse_grouped_mlp_ops( ops, recipe=recipe, - fused_op_cls=ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + fused_op_cls=ForwardGroupedMLP_CuTeGEMMGLU_MXFP8, + ) + + +def fuse_forward_srelu_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument +) -> list[FusibleOperation]: + """Apply GroupedLinear + ScaledSReLU + GroupedLinear fusion for forward pass.""" + + return fuse_grouped_mlp_ops( + ops, + recipe=recipe, + fused_op_cls=ForwardGroupedMLP_CuTeGEMMUnary_MXFP8, + activation_op_types=(ScaledSReLU,), ) # Register fusion if available -if ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): +if ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported(): register_forward_fusion(fuse_forward_ops, prepend=True) +if ForwardGroupedMLP_CuTeGEMMUnary_MXFP8.is_supported(): + register_forward_fusion(fuse_forward_srelu_ops, prepend=True)