Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 106 additions & 27 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import (
_cudnn_frontend_version_supported,
is_glu_activation,
)

from transformer_engine.pytorch.ops.fused import (
Expand Down Expand Up @@ -2480,6 +2481,59 @@ def test_scaled_swiglu(
assert_close_grads(x_test, x_ref, **tols)
assert_close_grads(scales_test, scales_ref, **tols)

@pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("scales_requires_grad", (False, True))
def test_scaled_srelu(
self,
*,
in_shape: Iterable[int],
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
input_requires_grad: bool,
scales_requires_grad: bool,
) -> None:
"""SReLU with post-scale"""

# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=input_requires_grad,
)
scales_ref, scales_test = make_reference_and_test_tensors(
in_shape[:-1],
test_dtype=dtype,
test_device=device,
requires_grad=scales_requires_grad,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)

# Plain PyTorch implementation
y = torch.nn.functional.relu(x_ref).square()
y_ref = scales_ref.unsqueeze(-1) * y
if input_requires_grad or scales_requires_grad:
y_ref.backward(dy_ref)

# Implementation with fusible operation
op = te_ops.ScaledSReLU()
y_test = op(x_test, scales_test)
if input_requires_grad or scales_requires_grad:
y_test.backward(dy_test)

# Check results
tols = dtype_tols(dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
assert_close(y_test, y_ref, **tols)
assert_close_grads(x_test, x_ref, **tols)
assert_close_grads(scales_test, scales_ref, **tols)

def test_interleaved_scaled_swiglu(self):
"""SwiGLU with post-scale and block interleaved input format"""
self.test_scaled_swiglu(
Expand Down Expand Up @@ -3570,7 +3624,9 @@ def test_layernorm_mlp(
@pytest.mark.parametrize("glu_interleave_size", (None, 32))
@pytest.mark.parametrize("delay_wgrad_compute", (False, True))
@pytest.mark.parametrize("hidden_size", (128, 256))
@pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu"))
@pytest.mark.parametrize(
"activation", ("scaled_swiglu", "scaled_clamped_qgeglu", "scaled_srelu")
)
def test_grouped_mlp(
self,
*,
Expand All @@ -3588,7 +3644,7 @@ def test_grouped_mlp(
delay_wgrad_compute: bool,
activation: str,
) -> None:
"""GroupedLinear + ScaledSwiGLU / ScaledClampedQGeGLU + GroupedLinear"""
"""GroupedLinear + scaled activation + GroupedLinear"""

# Split sizes
split_sizes = [split_alignment * (i) for i in range(group_size)]
Expand All @@ -3601,16 +3657,30 @@ def test_grouped_mlp(

# Skip invalid configurations
with_quantization = quantization is not None
if activation == "scaled_swiglu":
scaled_act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
elif activation == "scaled_clamped_qgeglu":
scaled_act = te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size)
elif activation == "scaled_srelu":
scaled_act = te_ops.ScaledSReLU()
else:
raise ValueError(f"Unexpected grouped MLP activation ({activation})")
activation_is_glu = is_glu_activation(scaled_act)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
if single_grouped_weight and quantization != "mxfp8":
pytest.skip("single_grouped_weight is only supported for MXFP8 quantization")
if single_grouped_bias and not bias:
pytest.skip("single_grouped_bias requires bias=True")
if with_quantization and dtype not in (torch.bfloat16, torch.float16):
pytest.skip("Quantized group GEMM is only supported with BF16/FP16")
if not activation_is_glu and quantization != "mxfp8":
pytest.skip("Scaled unary grouped MLP fusion is only supported with MXFP8")
if not activation_is_glu and glu_interleave_size is not None:
pytest.skip("Unary activations do not use GLU interleaving")
if quantization == "nvfp4" and activation == "scaled_clamped_qgeglu" and bias:
# TODO: ksivaman: Need to debug numerics for this case.
pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU")
fc1_out_features = 2 * hidden_size if activation_is_glu else hidden_size

# Random data
x_ref, x_test = make_reference_and_test_tensors(
Expand Down Expand Up @@ -3641,7 +3711,7 @@ def test_grouped_mlp(
fc2_bs_ref, fc2_bs_test = [], []
for _ in range(group_size):
fc1_w_ref, fc1_w_test = make_reference_and_test_tensors(
(2 * hidden_size, hidden_size),
(fc1_out_features, hidden_size),
min=-0.25,
max=0.25,
quantization=quantization,
Expand All @@ -3660,7 +3730,7 @@ def test_grouped_mlp(
fc2_b_ref, fc2_b_test = None, None
if bias:
fc1_b_ref, fc1_b_test = make_reference_and_test_tensors(
(2 * hidden_size,),
(fc1_out_features,),
min=-0.5,
max=0.5,
test_dtype=dtype,
Expand Down Expand Up @@ -3689,7 +3759,7 @@ def test_grouped_mlp(
for group_idx in range(group_size):
x = xs[group_idx]
x = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx])
if glu_interleave_size is not None:
if activation_is_glu and glu_interleave_size is not None:
x = x.reshape(
-1,
2 * hidden_size // (2 * glu_interleave_size),
Expand All @@ -3698,15 +3768,20 @@ def test_grouped_mlp(
)
x = x.transpose(1, 2)
x = x.reshape(-1, 2 * hidden_size)
x1, x2 = x.chunk(2, dim=-1)
if activation == "scaled_swiglu":
x1, x2 = x.chunk(2, dim=-1)
x = torch.nn.functional.silu(x1) * x2
else:
elif activation == "scaled_clamped_qgeglu":
x1, x2 = x.chunk(2, dim=-1)
lim = torch.tensor(7.0, device=x1.device, dtype=x1.dtype)
geglu_alpha = 1.702
x1c = torch.minimum(x1, lim)
x2c = torch.clamp(x2, -lim, lim)
x = (x2c + 1) * (x1c * torch.sigmoid(geglu_alpha * x1c))
elif activation == "scaled_srelu":
x = torch.nn.functional.relu(x).square()
else:
raise ValueError(f"Unexpected grouped MLP activation ({activation})")
x = x * probs[group_idx].unsqueeze(-1)
x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx])
if bias:
Expand All @@ -3717,16 +3792,11 @@ def test_grouped_mlp(

# Construct operations
recipe = make_recipe(quantization)
scaled_act = (
te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
if activation == "scaled_swiglu"
else te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size)
)
with te.quantized_model_init(enabled=with_quantization, recipe=recipe):
fc1 = te_ops.GroupedLinear(
group_size,
hidden_size,
2 * hidden_size,
fc1_out_features,
bias=bias,
device=device,
dtype=dtype,
Expand Down Expand Up @@ -3810,22 +3880,31 @@ def test_grouped_mlp(
if (
quantization == "mxfp8"
and dtype in (torch.bfloat16, torch.float16)
and glu_interleave_size == 32
and (
(not activation_is_glu and glu_interleave_size is None)
or (activation_is_glu and glu_interleave_size == 32)
)
and _cudnn_frontend_version_supported()
):
if te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported():
if activation_is_glu:
forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8
backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8
else:
forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMUnary_MXFP8
backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8
if forward_cls.is_supported():
forward_ops = module._module_groups[0]._forward_ops
assert len(forward_ops) == 1
assert isinstance(
forward_ops[0][0],
te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8,
forward_cls,
)
if te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported():
if backward_cls is not None and backward_cls.is_supported():
backward_ops = module._module_groups[0]._backward_ops
assert len(backward_ops) == 1
assert isinstance(
backward_ops[0][0],
te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8,
backward_cls,
)

# Loose tols for sanity checking
Expand Down Expand Up @@ -3910,9 +3989,9 @@ def test_grouped_mlp_single_weight_numerics(
) -> None:
"""single_grouped_weight=True/False should match exactly for fused MXFP8 grouped MLP."""

if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported():
if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported():
pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system")
if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported():
if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8.is_supported():
pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system")

split_sizes = [split_alignment * (i + 1) for i in range(group_size)]
Expand Down Expand Up @@ -4014,12 +4093,12 @@ def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]:
assert len(forward_ops) == 1
assert isinstance(
forward_ops[0][0],
te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8,
te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8,
)
assert len(backward_ops) == 1
assert isinstance(
backward_ops[0][0],
te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8,
te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8,
)

if single_grouped_weight:
Expand Down Expand Up @@ -4132,9 +4211,9 @@ def test_grouped_mlp_overwrite_main_grad(
that read ``.grad`` don't see stale bytes from the cached dummy).
"""

if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported():
if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported():
pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system")
if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported():
if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8.is_supported():
pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system")

recipe = make_recipe("mxfp8")
Expand Down Expand Up @@ -4266,7 +4345,7 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8(
) -> None:
"""Grouped MLP forward+backward should be CUDA graph capturable (MXFP8)."""

if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported():
if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported():
pytest.skip("MXFP8 fused grouped MLP is not supported on this system")
if dtype not in (torch.bfloat16, torch.float16):
pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16")
Expand Down Expand Up @@ -4408,12 +4487,12 @@ def train_step(
assert len(forward_ops) == 1
assert isinstance(
forward_ops[0][0],
te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8,
te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8,
)
assert len(backward_ops) == 1
assert isinstance(
backward_ops[0][0],
te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8,
te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8,
)

fresh_x = torch.randn_like(static_x)
Expand Down
Loading
Loading