From 9fceea0391993a45ed3a5de0efe42d85ed5287d6 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 11 May 2026 14:00:24 -0700 Subject: [PATCH 01/12] wip Signed-off-by: Jeremy Berchtold --- docs/examples/jax_examples/attention.ipynb | 31 ++ docs/examples/jax_examples/dense.ipynb | 535 +++++++++++++++++++++ docs/examples/jax_examples/moe.ipynb | 31 ++ docs/examples/quickstart_jax_utils.py | 53 ++ docs/examples/te_jax_integration.ipynb | 445 ++--------------- 5 files changed, 681 insertions(+), 414 deletions(-) create mode 100644 docs/examples/jax_examples/attention.ipynb create mode 100644 docs/examples/jax_examples/dense.ipynb create mode 100644 docs/examples/jax_examples/moe.ipynb diff --git a/docs/examples/jax_examples/attention.ipynb b/docs/examples/jax_examples/attention.ipynb new file mode 100644 index 0000000000..122595ec31 --- /dev/null +++ b/docs/examples/jax_examples/attention.ipynb @@ -0,0 +1,31 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "title", + "metadata": {}, + "source": [ + "# JAX: Attention with TransformerEngine\n", + "\n", + "**TODO — Coming soon.**\n", + "\n", + "This notebook will cover replacing `flax.linen.dot_product_attention` with TE's fused attention path (`transformer_engine.jax.flax.DotProductAttention`), including:\n", + "\n", + "- FP8/MXFP8 attention numerics and when it's safe to enable,\n", + "- causal, sliding-window, and packed/THD layouts,\n", + "- context parallelism (all-gather and ring variants).\n", + "\n", + "[← Back to the JAX integration overview](../te_jax_integration.ipynb)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/jax_examples/dense.ipynb b/docs/examples/jax_examples/dense.ipynb new file mode 100644 index 0000000000..ff574a3cb2 --- /dev/null +++ b/docs/examples/jax_examples/dense.ipynb @@ -0,0 +1,535 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "intro-md", + "metadata": {}, + "source": [ + "# JAX: Dense GEMMs with TransformerEngine\n", + "\n", + "This notebook walks through replacing a plain `flax.linen.Dense`'s GEMM with TransformerEngine's quantized GEMM. We start with the headline single-GPU speedup on a single Dense layer, look at the numerical cost, then scale the same Dense layer to a DP=2/TP=2 mesh.\n", + "\n", + "**Audience.** You're comfortable with Flax Linen: `nn.Module`, `init`/`apply`, variable collections, and RNG handling. We won't re-introduce those.\n", + "\n", + "**Recipe.** We use `MXFP8BlockScaling` as the default — it has no extra runtime state to manage. MXFP8 requires a Blackwell-class GPU; on Hopper, swap in `DelayedScaling` or `Float8CurrentScaling`.\n", + "\n", + "[← Back to the JAX integration overview](../te_jax_integration.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "s1-md", + "metadata": {}, + "source": [ + "## 1. Baseline: a plain Flax Dense block\n", + "\n", + "We isolate the optimization to a single linear layer so it's clear what's changing. `dot_general_cls` is exposed as a constructor argument so we can swap in TE later without touching the model definition." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "s1-imports", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-11T20:59:17.120315Z", + "iopub.status.busy": "2026-05-11T20:59:17.120200Z", + "iopub.status.idle": "2026-05-11T20:59:17.920419Z", + "shell.execute_reply": "2026-05-11T20:59:17.919880Z" + } + }, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"..\") # so we can import quickstart_jax_utils from docs/examples/\n", + "\n", + "import warnings\n", + "# Silence a benign TE warning about tpsp_resource; tp_resource and tpsp_resource\n", + "# are mutually exclusive in TE's MeshResource, and tp_resource is what we want.\n", + "warnings.filterwarnings(\"ignore\", message=\"Tensor sequence parallelism is detected\")\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from flax import linen as nn\n", + "import quickstart_jax_utils as utils" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "s1-model", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-11T20:59:17.921750Z", + "iopub.status.busy": "2026-05-11T20:59:17.921564Z", + "iopub.status.idle": "2026-05-11T20:59:23.193250Z", + "shell.execute_reply": "2026-05-11T20:59:23.192560Z" + } + }, + "outputs": [], + "source": [ + "class FlaxDenseBlock(nn.Module):\n", + " \"\"\"One linear layer. `dot_general_cls` lets us swap the GEMM impl.\"\"\"\n", + " features: int\n", + " dot_general_cls: callable = lambda: None\n", + "\n", + " @nn.compact\n", + " def __call__(self, x):\n", + " return nn.Dense(\n", + " features=self.features,\n", + " use_bias=False,\n", + " dot_general=self.dot_general_cls(),\n", + " )(x)\n", + "\n", + "# Shapes chosen to be large enough that quantization actually pays off.\n", + "batch, seq, hidden, out_features = 4, 2048, 4096, 16384\n", + "dtype = jnp.bfloat16\n", + "\n", + "key = jax.random.PRNGKey(0)\n", + "k_init, k_x, k_dy = jax.random.split(key, 3)\n", + "x = jax.random.normal(k_x, (batch, seq, hidden)).astype(dtype)\n", + "dy = jax.random.normal(k_dy, (batch, seq, out_features)).astype(dtype)\n", + "\n", + "baseline = FlaxDenseBlock(features=out_features)\n", + "baseline_vars = baseline.init(k_init, x)" + ] + }, + { + "cell_type": "markdown", + "id": "s2-md", + "metadata": {}, + "source": [ + "## 2. Quantized Dense via `make_dot_general_cls`\n", + "\n", + "TE exposes a single helper, `te_flax.make_dot_general_cls(recipe)`, that returns a Flax module class you pass directly to `nn.Dense(..., dot_general=...)`. Internally it routes the GEMM through `transformer_engine.jax.dense.dense`, which handles the cast, the quantized matmul, and the VJP.\n", + "\n", + "Crucially, **your model parameters are still yours**. TE doesn't create the `kernel`; it only wraps the multiply. Initialization, sharding annotations, and optimizer state stay where they were." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "s2-wire", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-11T20:59:23.194639Z", + "iopub.status.busy": "2026-05-11T20:59:23.194505Z", + "iopub.status.idle": "2026-05-11T20:59:23.715045Z", + "shell.execute_reply": "2026-05-11T20:59:23.714507Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Variable collections: ['params']\n", + "{'params': {'Dense_0': {'kernel': ((4096, 16384), dtype('float32'))}}}\n" + ] + } + ], + "source": [ + "from transformer_engine.jax import flax as te_flax\n", + "from transformer_engine.common.recipe import MXFP8BlockScaling\n", + "\n", + "recipe = MXFP8BlockScaling()\n", + "te_dot_general_cls = te_flax.make_dot_general_cls(recipe)\n", + "\n", + "te_model = FlaxDenseBlock(features=out_features, dot_general_cls=te_dot_general_cls)\n", + "te_vars = te_model.init(k_init, x)\n", + "\n", + "print(\"Variable collections:\", list(te_vars.keys()))\n", + "print(jax.tree_util.tree_map(lambda a: (a.shape, a.dtype), te_vars))" + ] + }, + { + "cell_type": "markdown", + "id": "s3-md", + "metadata": {}, + "source": [ + "## 3. Does it actually help? — single-GPU performance\n", + "\n", + "Let's measure first, ask questions later. `speedometer` runs a JIT-compiled forward+backward loop with warmup, on the same input for both models." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "s3-bench", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-11T20:59:23.716295Z", + "iopub.status.busy": "2026-05-11T20:59:23.716158Z", + "iopub.status.idle": "2026-05-11T20:59:25.886104Z", + "shell.execute_reply": "2026-05-11T20:59:25.885457Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "bf16 baseline:\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean time: 4.13301944732666 ms\n", + "\n", + "TE MXFP8BlockScaling:\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "E0511 13:59:25.678780 152730 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean time: 1.6977357864379883 ms\n" + ] + } + ], + "source": [ + "print(\"bf16 baseline:\")\n", + "utils.speedometer(\n", + " model_apply_fn=baseline.apply,\n", + " variables=baseline_vars,\n", + " input=x, output_grad=dy,\n", + ")\n", + "\n", + "print(f\"\\nTE {type(recipe).__name__}:\")\n", + "utils.speedometer(\n", + " model_apply_fn=te_model.apply,\n", + " variables=te_vars,\n", + " input=x, output_grad=dy,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "s3-note-md", + "metadata": {}, + "source": [ + "On a single GB200, that's roughly **2.5× faster** for the fwd+bwd of one large Dense — and the only code change was passing `dot_general=te_dot_general_cls()` into `nn.Dense`.\n", + "\n", + "The speedup depends on shape: large GEMMs benefit most. Very small GEMMs may not benefit at all because the cast + scale overhead can dominate.\n", + "\n", + "
\n", + "Remat / activation checkpointing. If your training loop uses jax.checkpoint_policies.checkpoint_dots (or any policy that matches jax.lax.dot_general), swap it for transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms. Otherwise TE's quantized GEMM primitives won't be checkpointed correctly and your performance comparison will be unfair.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "s4-md", + "metadata": {}, + "source": [ + "## 4. What does it cost? — numerical comparison\n", + "\n", + "Quantization is lossy by construction. `compare_fwd_bwd` in `quickstart_jax_utils` runs both apply functions through `jax.value_and_grad` on the same input and reports max abs/rel diff on the output, the input gradient, and the kernel gradient." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "s4-compare", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-11T20:59:25.887578Z", + "iopub.status.busy": "2026-05-11T20:59:25.887438Z", + "iopub.status.idle": "2026-05-11T20:59:28.782009Z", + "shell.execute_reply": "2026-05-11T20:59:28.781350Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y: max_abs=2.245e-01 max_rel=3.881e-02\n", + "dx: max_abs=4.766e-01 max_rel=4.144e-02\n", + "dW: max_abs=1.874e+01 max_rel=3.676e-02\n" + ] + } + ], + "source": [ + "# Make the TE model's kernel identical to the baseline's so we measure quantization\n", + "# error specifically, not init differences.\n", + "te_vars_matched = {**te_vars, \"params\": baseline_vars[\"params\"]}\n", + "\n", + "diffs = utils.compare_fwd_bwd(\n", + " apply_a=baseline.apply, variables_a=baseline_vars,\n", + " apply_b=te_model.apply, variables_b=te_vars_matched,\n", + " input=x, output_grad=dy,\n", + ")\n", + "for name, d in diffs.items():\n", + " print(f\"{name}: max_abs={d['max_abs']:.3e} max_rel={d['max_rel']:.3e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "s4-note-md", + "metadata": {}, + "source": [ + "Max relative error around a few percent is the expected MXFP8 fidelity envelope. In real training, gradient noise dominates this error." + ] + }, + { + "cell_type": "markdown", + "id": "s-state-md", + "metadata": {}, + "source": [ + "## 5. A note on recipe state\n", + "\n", + "MXFP8 is stateless — scaling factors are computed from each tensor as it flows through the GEMM, so there is nothing to persist across steps. If you swap in `DelayedScaling` instead, `init` will produce a second variable collection, `_overwrite_with_gradient`, holding `kernel_amax_history`, `kernel_scale`, `x_amax_history`, `x_scale`, etc. These are **not** model parameters — they are Flax variables that TE updates each step to compute per-tensor scales from a rolling amax window.\n", + "\n", + "If you use a stateful recipe, you must thread the *entire* `var_collect` through your training loop (not just `params`) so the history persists across steps. `MXFP8BlockScaling` avoids this entirely by scaling locally per 32-element block." + ] + }, + { + "cell_type": "markdown", + "id": "s6-md", + "metadata": {}, + "source": [ + "## 6. Multi-GPU: DP=2 / TP=2 on a single Dense\n", + "\n", + "**Prerequisite:** this section requires four GPUs.\n", + "\n", + "Keeping the same `FlaxDenseBlock` from the rest of the notebook, we run it on a 2×2 mesh with **data parallelism** on one axis and **tensor parallelism** (column-parallel: shard the kernel's output dim) on the other.\n", + "\n", + "Two pieces wire this up:\n", + "\n", + "1. A `jax.sharding.Mesh` you build once at module scope (outside JIT).\n", + "2. TE's `MeshResource`, set globally via `global_shard_guard`, which tells TE which mesh axes are DP and TP." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "s6-check", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-11T20:59:28.783268Z", + "iopub.status.busy": "2026-05-11T20:59:28.783059Z", + "iopub.status.idle": "2026-05-11T20:59:28.785499Z", + "shell.execute_reply": "2026-05-11T20:59:28.785006Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Visible devices: 4\n" + ] + } + ], + "source": [ + "n_devices = len(jax.devices())\n", + "print(f\"Visible devices: {n_devices}\")\n", + "assert n_devices >= 4, \"This section requires 4 GPUs for DP=2/TP=2.\"" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "s6-mesh", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-11T20:59:28.786463Z", + "iopub.status.busy": "2026-05-11T20:59:28.786330Z", + "iopub.status.idle": "2026-05-11T20:59:28.789103Z", + "shell.execute_reply": "2026-05-11T20:59:28.788745Z" + } + }, + "outputs": [], + "source": [ + "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n", + "from jax.experimental import mesh_utils\n", + "from transformer_engine.jax.sharding import MeshResource, global_shard_guard\n", + "\n", + "# 2x2 mesh: DP on one axis, TP on the other.\n", + "devices = mesh_utils.create_device_mesh((2, 2))\n", + "mesh = Mesh(devices, axis_names=(\"dp\", \"tp\"))\n", + "\n", + "# Tell TE which mesh axis is which. This is a *global* setting, established\n", + "# outside JIT, so TE's GEMM primitives can plan comms accordingly.\n", + "mesh_resource = MeshResource(dp_resource=\"dp\", tp_resource=\"tp\")" + ] + }, + { + "cell_type": "markdown", + "id": "s6-shard-md", + "metadata": {}, + "source": [ + "**Sharding plan:**\n", + "\n", + "| Tensor | Shape | PartitionSpec |\n", + "|---|---|---|\n", + "| Kernel (column-parallel) | `(hidden, out_features)` | `P(None, \"tp\")` |\n", + "| Input activations | `(batch, seq, hidden)` | `P(\"dp\", None, None)` |\n", + "| Gradient on output | `(batch, seq, out_features)` | `P(\"dp\", None, \"tp\")` |" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "s6-shard", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-11T20:59:28.790047Z", + "iopub.status.busy": "2026-05-11T20:59:28.789928Z", + "iopub.status.idle": "2026-05-11T20:59:30.177416Z", + "shell.execute_reply": "2026-05-11T20:59:30.176795Z" + } + }, + "outputs": [], + "source": [ + "kernel_sharding = NamedSharding(mesh, P(None, \"tp\"))\n", + "input_sharding = NamedSharding(mesh, P(\"dp\", None, None))\n", + "output_grad_sharding = NamedSharding(mesh, P(\"dp\", None, \"tp\"))\n", + "\n", + "def shard_kernel(variables):\n", + " params = variables[\"params\"]\n", + " sharded = jax.device_put(params[\"Dense_0\"][\"kernel\"], kernel_sharding)\n", + " return {**variables, \"params\": {**params,\n", + " \"Dense_0\": {**params[\"Dense_0\"], \"kernel\": sharded}}}\n", + "\n", + "x_mp_s = jax.device_put(x, input_sharding)\n", + "dy_mp_s = jax.device_put(dy, output_grad_sharding)\n", + "baseline_vars_s = shard_kernel(baseline_vars)\n", + "te_vars_s = shard_kernel(te_vars)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "s6-bench", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-11T20:59:30.178701Z", + "iopub.status.busy": "2026-05-11T20:59:30.178579Z", + "iopub.status.idle": "2026-05-11T20:59:32.572397Z", + "shell.execute_reply": "2026-05-11T20:59:32.571759Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "bf16 DP=2/TP=2:\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean time: 1.7351531982421875 ms\n", + "\n", + "TE MXFP8BlockScaling DP=2/TP=2:\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean time: 0.9787273406982422 ms\n" + ] + } + ], + "source": [ + "with jax.set_mesh(mesh), global_shard_guard(mesh_resource):\n", + " print(\"bf16 DP=2/TP=2:\")\n", + " utils.speedometer(\n", + " model_apply_fn=baseline.apply,\n", + " variables=baseline_vars_s,\n", + " input=x_mp_s, output_grad=dy_mp_s,\n", + " )\n", + "\n", + " print(f\"\\nTE {type(recipe).__name__} DP=2/TP=2:\")\n", + " utils.speedometer(\n", + " model_apply_fn=te_model.apply,\n", + " variables=te_vars_s,\n", + " input=x_mp_s, output_grad=dy_mp_s,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "s6-why-md", + "metadata": {}, + "source": [ + "**Why this works.** Column-parallel TP shards the kernel along its *output* dim. The input activations are replicated along the TP axis, so the GEMM needs no input all-gather; it just runs locally on each TP shard and produces an output that's already sharded along the TP axis. The forward has no TP comms. The backward needs an all-reduce on the input-gradient (`dx`), which is partly hidden under the next layer's compute in a real model.\n", + "\n", + "With TE, the bf16→FP8 cast and the GEMM both run on the local shard. Since the column-parallel pattern keeps comms out of the forward critical path, the GEMM-compute speedup translates directly to wall-clock.\n", + "\n", + "The DP axis just shards the batch — it adds no comms in the forward; the gradient all-reduce on the backward is what gets DP-ified.\n", + "\n", + "
\n", + "Flax logical axes. In larger models, prefer Flax's logical axis rules (set via flax.linen.logical_axis_rules) and annotate kernels with kernel_axes=(...) on TE modules. The manual approach above keeps this notebook small; the logical-axes approach scales to many layers without per-layer plumbing.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "s7-md", + "metadata": {}, + "source": [ + "## 7. Collective GEMM (placeholder)\n", + "\n", + "*Coming soon.*\n", + "\n", + "A further optimization is to *overlap* the TP all-gather / all-reduce with the GEMM itself — a \"collective GEMM\". TE's JAX dense path exposes this via the `collective_op_set` argument on `transformer_engine.jax.dense.dense`. A future revision of this notebook will walk through:\n", + "\n", + "- enabling overlapped AG+GEMM and reduce-scatter+GEMM,\n", + "- the throughput delta at large TP world sizes,\n", + "- when the overlap pays off and when it doesn't (small layers, slow interconnects)." + ] + }, + { + "cell_type": "markdown", + "id": "recap-md", + "metadata": {}, + "source": [ + "## Recap\n", + "\n", + "- One line of code (`dot_general=te_flax.make_dot_general_cls(recipe)()`) swaps in a quantized GEMM.\n", + "- Your `params` tree is unchanged; stateful recipes add a `_overwrite_with_gradient` collection that must be threaded through training.\n", + "- Numerical error from MXFP8 is small enough to be dominated by gradient noise in real training.\n", + "- On a single GB200, one large Dense fwd+bwd is **~2.5× faster** with MXFP8.\n", + "- The same Dense under DP=2/TP=2 stays faster with FP8 — column-parallel TP keeps comms out of the forward critical path, so the GEMM-compute win carries through.\n", + "\n", + "**Next:** [Attention](./attention.ipynb) · [Mixture of Experts](./moe.ipynb) · [← Hub](../te_jax_integration.ipynb)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/jax_examples/moe.ipynb b/docs/examples/jax_examples/moe.ipynb new file mode 100644 index 0000000000..2cf421509a --- /dev/null +++ b/docs/examples/jax_examples/moe.ipynb @@ -0,0 +1,31 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "title", + "metadata": {}, + "source": [ + "# JAX: Mixture of Experts with TransformerEngine\n", + "\n", + "**TODO — Coming soon.**\n", + "\n", + "This notebook will cover MoE layers using TE's grouped GEMM (`te_flax.make_grouped_dense_cls`, MXFP8 only), including:\n", + "\n", + "- routing + token dispatch,\n", + "- grouped quantized GEMM for the expert FFN,\n", + "- expert-parallel sharding considerations.\n", + "\n", + "[← Back to the JAX integration overview](../te_jax_integration.ipynb)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/quickstart_jax_utils.py b/docs/examples/quickstart_jax_utils.py index 0c5ec5295e..5253fb14e4 100644 --- a/docs/examples/quickstart_jax_utils.py +++ b/docs/examples/quickstart_jax_utils.py @@ -88,6 +88,59 @@ def fwd_bwd_fn(*args, **kwargs): return jax.jit(fwd_bwd_fn) +def compare_fwd_bwd( + apply_a: Callable, + variables_a: Any, + apply_b: Callable, + variables_b: Any, + input: jnp.ndarray, + output_grad: jnp.ndarray, + forward_kwargs_a: Dict[str, Any] = None, + forward_kwargs_b: Dict[str, Any] = None, + rngs_a: Dict[str, jax.random.PRNGKey] = None, + rngs_b: Dict[str, jax.random.PRNGKey] = None, +) -> Dict[str, Dict[str, float]]: + """Run fwd+bwd on two apply functions and report max abs/rel diff on y, dx, and dW.""" + forward_kwargs_a = forward_kwargs_a or {} + forward_kwargs_b = forward_kwargs_b or {} + rngs_a = rngs_a or {} + rngs_b = rngs_b or {} + + def run(apply_fn, variables, forward_kwargs, rngs): + def loss_fn(variables, inp): + out = apply_fn(variables, inp, rngs=rngs, **forward_kwargs) + return jnp.vdot(out, output_grad), out + + (_, out), (param_grads, dx) = jax.value_and_grad(loss_fn, argnums=(0, 1), has_aux=True)( + variables, input + ) + return out, dx, param_grads + + y_a, dx_a, gp_a = run(apply_a, variables_a, forward_kwargs_a, rngs_a) + y_b, dx_b, gp_b = run(apply_b, variables_b, forward_kwargs_b, rngs_b) + + kernel_leaves_a = [ + leaf for path, leaf in jax.tree_util.tree_leaves_with_path(gp_a) if "kernel" in jax.tree_util.keystr(path) + ] + kernel_leaves_b = [ + leaf for path, leaf in jax.tree_util.tree_leaves_with_path(gp_b) if "kernel" in jax.tree_util.keystr(path) + ] + dW_a = kernel_leaves_a[0] if kernel_leaves_a else None + dW_b = kernel_leaves_b[0] if kernel_leaves_b else None + + def diffs(a, b): + a = a.astype(jnp.float32) + b = b.astype(jnp.float32) + abs_diff = float(jnp.max(jnp.abs(a - b))) + denom = float(jnp.max(jnp.abs(a))) + 1e-12 + return {"max_abs": abs_diff, "max_rel": abs_diff / denom} + + result = {"y": diffs(y_a, y_b), "dx": diffs(dx_a, dx_b)} + if dW_a is not None and dW_b is not None: + result["dW"] = diffs(dW_a, dW_b) + return result + + def _split_step_rngs( rngs: Dict[str, jax.random.PRNGKey], ) -> Tuple[Dict[str, jax.random.PRNGKey], Dict[str, jax.random.PRNGKey]]: diff --git a/docs/examples/te_jax_integration.ipynb b/docs/examples/te_jax_integration.ipynb index 66d16ed52f..76e56161e7 100644 --- a/docs/examples/te_jax_integration.ipynb +++ b/docs/examples/te_jax_integration.ipynb @@ -2,452 +2,69 @@ "cells": [ { "cell_type": "markdown", - "id": "962d87bb", + "id": "title", "metadata": {}, "source": [ + "# JAX: Integrating TransformerEngine into an existing framework\n", "\n", + "This is the landing page for a series of focused notebooks on bringing TransformerEngine into a JAX+Flax codebase one optimization at a time. Each linked notebook isolates a single feature so you can see exactly what changes, what state it introduces, what it costs in numerical precision, and what it buys in speed.\n", "\n", - "# JAX: Integrating TE into an existing framework\n", + "**Audience.** Intermediate Flax users. You're already comfortable with `nn.Module`, `init`/`apply`, variable collections, and RNG handling \u2014 we won't re-introduce those.\n", "\n", - "This tutorial will cover how to integrate TransformerEngine into an existing JAX model framework, such as [MaxText's TE integration](https://github.com/AI-Hypercomputer/maxtext/blob/ed517cf80d9aa81f76e236c5516dacebfe39e96d/src/MaxText/layers/quantizations.py#L753) or your own model framework. \n" + "**Philosophy.** TE's JAX integration is designed to be **non-invasive**: your model parameters, initialization, optimizer, and sharding annotations stay yours. The only surface area you touch is the GEMM (and, for attention, the attention kernel). Everything else is unchanged. See [MaxText's TE integration](https://github.com/AI-Hypercomputer/maxtext/blob/ed517cf80d9aa81f76e236c5516dacebfe39e96d/src/MaxText/layers/quantizations.py#L753) for a production example." ] }, { "cell_type": "markdown", - "id": "b36876bb", + "id": "topics", "metadata": {}, "source": [ - "Let's start with a standard JAX+Flax Transformer layer" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "d5284a38", - "metadata": {}, - "outputs": [], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "from flax import linen as nn\n", - "import quickstart_jax_utils as utils\n", - "from typing import Optional" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "a4d1cfdc", - "metadata": {}, - "outputs": [], - "source": [ - "class FlaxMLP(nn.Module):\n", - " \"\"\"Feed-forward network in Transformer layer\n", - " Built with plain Flax modules.\n", - " \"\"\"\n", - " hidden_size: int\n", - " ffn_hidden_size: int\n", - " dot_general_cls: callable = lambda: None\n", - "\n", - " @nn.compact\n", - " def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n", - " x = nn.Dense(features=self.ffn_hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n", - " x = nn.gelu(x, approximate=True) # equivalent to tanh approximation\n", - " x = nn.Dense(features=self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n", - " return x\n", + "## Pick a topic\n", "\n", - "class FlaxTransformerLayer(nn.Module):\n", - " \"\"\"Basic Transformer layer using plain Flax modules\"\"\"\n", - " hidden_size: int\n", - " ffn_hidden_size: int\n", - " num_attention_heads: int\n", - " layernorm_eps: float = 1e-5\n", - " attention_dropout: float = 0.1\n", - " dot_general_cls: callable = lambda: None\n", - " \n", - " def setup(self):\n", - " self.kv_channels = self.hidden_size // self.num_attention_heads\n", + "| Notebook | Status | Covers |\n", + "|---|---|---|\n", + "| [Dense GEMMs](./jax_examples/dense.ipynb) | **Available** | `nn.Dense` \u2192 quantized GEMM via `dot_general_cls`; fwd/bwd numerics; single-GPU speedup; FSDP quantize-before-all-gather; Collective GEMM (placeholder) |\n", + "| [Attention](./jax_examples/attention.ipynb) | *Coming soon* | Fused attention via `te.flax.DotProductAttention`; FP8 attention; context parallelism |\n", + "| [Mixture of Experts](./jax_examples/moe.ipynb) | *Coming soon* | Grouped GEMM for expert FFNs; routing; expert-parallel sharding |\n", "\n", - " @nn.compact\n", - " def __call__(\n", - " self, \n", - " x: jnp.ndarray, \n", - " attention_mask: Optional[jnp.ndarray] = None,\n", - " deterministic: bool = False\n", - " ) -> jnp.ndarray:\n", - " # Create causal mask if not provided\n", - " if attention_mask is None:\n", - " attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n", - " \n", - " res = x\n", - " x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n", - " \n", - " # Fused QKV projection\n", - " qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n", - " qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n", - " q, k, v = jnp.split(qkv, 3, axis=3)\n", - " \n", - " # q, k, v now have shape [batch, seq_len, num_heads, kv_channels]\n", - " # which is the correct format for dot_product_attention\n", - " \n", - " # Apply dot product attention\n", - " # Note: dot_product_attention expects mask to be broadcastable to \n", - " # [batch, num_heads, q_length, kv_length], but attention_mask from \n", - " # nn.make_causal_mask has shape [batch, 1, seq_len, seq_len]\n", - " \n", - " # Generate dropout RNG key when needed (not deterministic and dropout_rate > 0)\n", - " dropout_rng = None\n", - " if not deterministic and self.attention_dropout > 0:\n", - " dropout_rng = self.make_rng('dropout')\n", - " \n", - " # See quickstart_jax.ipynb for details on using TE's faster fused attention\n", - " x = nn.dot_product_attention(\n", - " query=q,\n", - " key=k,\n", - " value=v,\n", - " mask=attention_mask,\n", - " dropout_rng=dropout_rng,\n", - " dropout_rate=self.attention_dropout,\n", - " deterministic=deterministic,\n", - " broadcast_dropout=True,\n", - " )\n", - " \n", - " # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]\n", - " x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)\n", - "\n", - " # Output projection\n", - " x = nn.Dense(features=self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n", - " \n", - " x = res + x\n", - " \n", - " # Second residual connection\n", - " res = x\n", - " x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n", - " \n", - " # MLP\n", - " mlp = FlaxMLP(\n", - " hidden_size=self.hidden_size,\n", - " ffn_hidden_size=self.ffn_hidden_size,\n", - " dot_general_cls=self.dot_general_cls,\n", - " )\n", - " x = mlp(x)\n", - " \n", - " return x + res\n" - ] - }, - { - "cell_type": "markdown", - "id": "db16bf70", - "metadata": {}, - "source": [ - "We've exposed `dot_general_cls` here so we can test out different GEMM implementations later. By default, Flax's `nn.Dense` will use JAX's GEMM `jax.lax.dot_general` when `dot_general` is `None`." + "Start with **Dense GEMMs** \u2014 it introduces the `dot_general_cls` pattern and recipe mechanics that the other notebooks build on." ] }, { "cell_type": "markdown", - "id": "fbc3510b", + "id": "recipes", "metadata": {}, "source": [ - "## Testing Performance\n", + "## Quantization recipes at a glance\n", "\n", - "Now let's test the performance of our FlaxTransformerLayer:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "8b44649d", - "metadata": {}, - "outputs": [], - "source": [ - "# Layer configuration\n", - "hidden_size = 4096\n", - "sequence_length = 2048\n", - "batch_size = 4\n", - "ffn_hidden_size = 16384\n", - "num_attention_heads = 32\n", - "dtype = jnp.bfloat16\n", + "TE exposes its quantization choices as **recipes**. You pick one and pass it to `te_flax.make_dot_general_cls(recipe)` (or to the equivalent helper for attention / grouped GEMM).\n", "\n", - "# Synthetic data\n", - "key, dropout_key = jax.random.split(jax.random.PRNGKey(42))\n", - "x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n", - "dy = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "e44ed26d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Pure Flax FlaxTransformerLayer initialized successfully!\n", - "Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n" - ] - } - ], - "source": [ - "# Initialize the FlaxTransformerLayer\n", - "flax_transformer = FlaxTransformerLayer(\n", - " hidden_size=hidden_size,\n", - " ffn_hidden_size=ffn_hidden_size,\n", - " num_attention_heads=num_attention_heads,\n", - ")\n", + "| Recipe | Hardware | State | When to use |\n", + "|---|---|---|---|\n", + "| `DelayedScaling` | Hopper+ | amax history (Flax variables) | Stable per-tensor FP8 training; the most battle-tested recipe |\n", + "| `Float8CurrentScaling` | Hopper+ | none | Per-tensor FP8 without an amax history; simpler bookkeeping |\n", + "| `MXFP8BlockScaling` | Blackwell+ | none | Block-scaled FP8 (32-element blocks); no state to thread through training |\n", + "| `NVFP4BlockScaling` | Blackwell+ | requires `sr_rng` | FP4 with 2D block scaling and stochastic rounding |\n", "\n", - "# Initialize parameters\n", - "params = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n", - "\n", - "print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n", - "print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, params)}\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "de91af7a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Input shape: (4, 2048, 4096)\n", - "Output shape: (4, 2048, 4096)\n", - "Output dtype: float32\n", - "Forward pass completed successfully!\n" - ] - } - ], - "source": [ - "# Example usage of forward pass\n", - "y = flax_transformer.apply(params, x, attention_mask=None, deterministic=True)\n", - "print(f\"Input shape: {x.shape}\")\n", - "print(f\"Output shape: {y.shape}\")\n", - "print(f\"Output dtype: {y.dtype}\")\n", - "print(\"Forward pass completed successfully!\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "037bc8d9", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 18.83516788482666 ms\n" - ] - } - ], - "source": [ - "import importlib\n", - "import quickstart_jax_utils\n", - "importlib.reload(quickstart_jax_utils)\n", - "\n", - "utils.speedometer(\n", - " model_apply_fn=flax_transformer.apply,\n", - " variables=params,\n", - " input=x,\n", - " output_grad=dy,\n", - " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", - " rngs={\"dropout\": dropout_key},\n", - ")" + "Import them from `transformer_engine.common.recipe`." ] }, { "cell_type": "markdown", - "id": "5e9310c9", - "metadata": {}, - "source": [ - "## Transformer Engine" - ] - }, - { - "cell_type": "markdown", - "id": "1f8e213e", - "metadata": {}, - "source": [ - "TransformerEngine/JAX is currently using Flax Linen. However, it is easily compatible with Flax NNX or Haiku.\n", - "* [Use Flax NNX and Linen together](https://flax.readthedocs.io/en/latest/guides/bridge_guide.html)\n", - "* [Haiku and Flax interop](https://dm-haiku.readthedocs.io/en/latest/notebooks/flax.html)\n", - "\n", - "Additionally, with the tutorial below, no model parameters need to be managed by TransformerEngine. You can keep all your existing model parameters, initialization, and sharding the same. The only change required is to call TE's dot_general_cls instead of the default Dense dot_general implementation. TE's dot_general_cls is a small module that performs a quantized dense VJP and stores some small recipe-specific state." - ] - }, - { - "cell_type": "markdown", - "id": "4477d4e9", - "metadata": {}, - "source": [ - "Now we'll select a recipe. `DelayedScaling` and `CurrentScaling` use per-tensor scaling and are supported on Hopper and Blackwell. `MXFP8BlockScaling` and `NVFP4BlockScaling` use block scaling or a combination of both per-tensor and block scaling and are supported on Blackwell.\n", - "\n", - "If you would like to customize the recipe further, various options can be changed by passing args to the recipe's constructor." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "5ddf41e7", + "id": "conventions", "metadata": {}, - "outputs": [], "source": [ - "from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling, NVFP4BlockScaling\n", - "from transformer_engine.jax import flax as te_flax \n", + "## Conventions used across these notebooks\n", "\n", - "# Choose a quantization recipe. This can be modified to any of the recipes imported above.\n", - "quantization_recipe = DelayedScaling()\n", - "\n", - "te_dot_general_cls = te_flax.make_dot_general_cls(quantization_recipe)\n", - "\n", - "rngs = {'dropout': dropout_key}\n", - "if isinstance(quantization_recipe, NVFP4BlockScaling):\n", - " # The NVFP4 recipe requires a Flax RNG for stochastic rounding\n", - " rngs['sr_rng'] = jax.random.PRNGKey(0)\n" - ] - }, - { - "cell_type": "markdown", - "id": "c8769655", - "metadata": {}, - "source": [ - "Now using this quantized dense in our model is as simple as passing in `dot_general_fn=te_dot_general`. Let's try it out!\n", + "- **Framework.** Flax Linen. (TE/JAX uses Linen; see [Flax NNX/Linen interop](https://flax.readthedocs.io/en/latest/guides/bridge_guide.html) and [Haiku/Flax interop](https://dm-haiku.readthedocs.io/en/latest/notebooks/flax.html) if you're on a different stack.)\n", + "- **Baseline dtype.** bf16 for inputs and parameters.\n", + "- **Benchmarking.** `quickstart_jax_utils.speedometer` runs a JIT-compiled fwd+bwd loop with warmup.\n", + "- **Numerical check.** `quickstart_jax_utils.compare_fwd_bwd` reports max abs/rel diff on output, input grad, and weight grad.\n", "\n", "
\n", - "\n", - "Important: Remat Policy\n", - "\n", - "TE's quantization uses specialized TE quantized GEMM primitives. If you are using any built-in JAX checkpoint policies that look for JAX GEMMs (dots), such as `jax.checkpoint_policies.checkpoint_dots`, please replace the policy with `transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms` or similar policies to ensure TE's quantized GEMM primitives are checkpointed correctly.\n", - "\n", - "If this is not performed, TE GEMMs will be rematerialized introducing an incorrect performance comparison.\n", - "\n", + "Remat / activation checkpointing. If your training loop uses jax.checkpoint_policies.checkpoint_dots (or any policy that matches jax.lax.dot_general), swap it for transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms. Otherwise TE's quantized GEMM primitives won't be checkpointed correctly and your performance comparison will be misleading.\n", "
" ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "8407d2ea", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Pure Flax FlaxTransformerLayer initialized successfully!\n", - "Parameter shapes: {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}\n", - "Additional state: {'_overwrite_with_gradient': {'FlaxMLP_0': {'TEWrapper_dot_general_0': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}, 'TEWrapper_dot_general_1': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}}, 'TEWrapper_dot_general_0': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}, 'TEWrapper_dot_general_1': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}}}\n" - ] - } - ], - "source": [ - "# Initialize the FlaxTransformerLayer\n", - "flax_transformer = FlaxTransformerLayer(\n", - " hidden_size=hidden_size,\n", - " ffn_hidden_size=ffn_hidden_size,\n", - " num_attention_heads=num_attention_heads,\n", - " dot_general_cls=te_dot_general_cls,\n", - ")\n", - "\n", - "# Initialize parameters\n", - "var_collect = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n", - "\n", - "print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n", - "print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, var_collect['params'])}\")\n", - "print(f\"Additional state: {jax.tree_util.tree_map(lambda x: x.shape, {k: v for k, v in var_collect.items() if k != 'params'})}\")" - ] - }, - { - "cell_type": "markdown", - "id": "abe27237", - "metadata": {}, - "source": [ - "If using a recipe that stores additional state, such as `DelayedScaling`, you'll see this additional state stored as Flax variables. It is important to maintain and pass the whole state of Flax variables `var_collect` across training steps, not just the model params, for proper usage of stateful recipes like `DelayedScaling`.\n", - "\n", - "For example, above inside `Additional state: ` you'll see the `amax_history` of each quantization which is used to compute the per-tensor scale in the `DelayedScaling` recipe." - ] - }, - { - "cell_type": "markdown", - "id": "5ab72935", - "metadata": {}, - "source": [ - "The reason we need `te_dot_general_cls` as a Flax module instead of a module-less function like `jax.lax.dot_general` is for some quantization recipes to track internal state separate from model parameters.\n", - "\n", - "Flax modules can manage 3 things:\n", - "1. Model parameters/weights, e.g. your Dense \"kernel\", \"bias\", etc.\n", - "2. RNGs for dropout, stochastic rounding, etc.\n", - "3. Flax variables. These are additional state variables that are used across training steps but are distinct from model params in that you don't take gradients or optimize them. Currently, we only use this for DelayedScaling's amax_history state\n", - "\n", - "With the simplest quantization integration shown in this tutorial, we want users to keep their existing model param setup so they don't need to worry about preserving the sharding, init distribution, etc.. So we don't need point 1 since we don't do model param creation in this codepath with dot_general_cls, but we still do need `te_dot_general_cls()` to produce a Flax module since we potentially need to do points 2 or 3 which need to be in a Flax module." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "3b6b344b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Input shape: (4, 2048, 4096)\n", - "Output shape: (4, 2048, 4096)\n", - "Output dtype: float32\n", - "Forward pass completed successfully!\n" - ] - } - ], - "source": [ - "# Example usage of forward pass\n", - "y = flax_transformer.apply(var_collect, x, attention_mask=None, deterministic=True, rngs=rngs)\n", - "print(f\"Input shape: {x.shape}\")\n", - "print(f\"Output shape: {y.shape}\")\n", - "print(f\"Output dtype: {y.dtype}\")\n", - "print(\"Forward pass completed successfully!\")\n" - ] - }, - { - "cell_type": "markdown", - "id": "d178f247", - "metadata": {}, - "source": [ - "Now let's measure the performance!" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "5cc6c2a7", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 10.553865432739258 ms\n" - ] - } - ], - "source": [ - "import importlib\n", - "import quickstart_jax_utils\n", - "importlib.reload(quickstart_jax_utils)\n", - "\n", - "utils.speedometer(\n", - " model_apply_fn=flax_transformer.apply,\n", - " variables=var_collect,\n", - " input=x,\n", - " output_grad=dy,\n", - " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", - " rngs=rngs,\n", - ")" - ] } ], "metadata": { @@ -459,4 +76,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file From 2cb3cb25bba0c4c6acdb64e37af79f914b445981 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 11 May 2026 14:46:08 -0700 Subject: [PATCH 02/12] wip Signed-off-by: Jeremy Berchtold --- docs/examples/jax_examples/attention.ipynb | 6 - docs/examples/jax_examples/dense.ipynb | 211 ++++++--------------- docs/examples/jax_examples/moe.ipynb | 5 +- docs/examples/te_jax_integration.ipynb | 33 ++-- 4 files changed, 74 insertions(+), 181 deletions(-) diff --git a/docs/examples/jax_examples/attention.ipynb b/docs/examples/jax_examples/attention.ipynb index 122595ec31..6d8e8c3f46 100644 --- a/docs/examples/jax_examples/attention.ipynb +++ b/docs/examples/jax_examples/attention.ipynb @@ -9,12 +9,6 @@ "\n", "**TODO — Coming soon.**\n", "\n", - "This notebook will cover replacing `flax.linen.dot_product_attention` with TE's fused attention path (`transformer_engine.jax.flax.DotProductAttention`), including:\n", - "\n", - "- FP8/MXFP8 attention numerics and when it's safe to enable,\n", - "- causal, sliding-window, and packed/THD layouts,\n", - "- context parallelism (all-gather and ring variants).\n", - "\n", "[← Back to the JAX integration overview](../te_jax_integration.ipynb)" ] } diff --git a/docs/examples/jax_examples/dense.ipynb b/docs/examples/jax_examples/dense.ipynb index ff574a3cb2..e03eb48389 100644 --- a/docs/examples/jax_examples/dense.ipynb +++ b/docs/examples/jax_examples/dense.ipynb @@ -7,11 +7,9 @@ "source": [ "# JAX: Dense GEMMs with TransformerEngine\n", "\n", - "This notebook walks through replacing a plain `flax.linen.Dense`'s GEMM with TransformerEngine's quantized GEMM. We start with the headline single-GPU speedup on a single Dense layer, look at the numerical cost, then scale the same Dense layer to a DP=2/TP=2 mesh.\n", + "This notebook walks through replacing a plain `flax.linen.Dense`'s GEMM with TransformerEngine's quantized GEMM.\n", "\n", - "**Audience.** You're comfortable with Flax Linen: `nn.Module`, `init`/`apply`, variable collections, and RNG handling. We won't re-introduce those.\n", - "\n", - "**Recipe.** We use `MXFP8BlockScaling` as the default — it has no extra runtime state to manage. MXFP8 requires a Blackwell-class GPU; on Hopper, swap in `DelayedScaling` or `Float8CurrentScaling`.\n", + "**Recipe.** We use `MXFP8BlockScaling` in this tutorial. `MXFP8BlockScaling` and `NVFP4BlockScaling` require a Blackwell-class GPU; on Hopper, swap in `DelayedScaling` or `Float8CurrentScaling`.\n", "\n", "[← Back to the JAX integration overview](../te_jax_integration.ipynb)" ] @@ -28,14 +26,14 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "s1-imports", "metadata": { "execution": { - "iopub.execute_input": "2026-05-11T20:59:17.120315Z", - "iopub.status.busy": "2026-05-11T20:59:17.120200Z", - "iopub.status.idle": "2026-05-11T20:59:17.920419Z", - "shell.execute_reply": "2026-05-11T20:59:17.919880Z" + "iopub.execute_input": "2026-05-11T21:02:32.861403Z", + "iopub.status.busy": "2026-05-11T21:02:32.861276Z", + "iopub.status.idle": "2026-05-11T21:02:33.571664Z", + "shell.execute_reply": "2026-05-11T21:02:33.571149Z" } }, "outputs": [], @@ -44,9 +42,6 @@ "sys.path.append(\"..\") # so we can import quickstart_jax_utils from docs/examples/\n", "\n", "import warnings\n", - "# Silence a benign TE warning about tpsp_resource; tp_resource and tpsp_resource\n", - "# are mutually exclusive in TE's MeshResource, and tp_resource is what we want.\n", - "warnings.filterwarnings(\"ignore\", message=\"Tensor sequence parallelism is detected\")\n", "\n", "import jax\n", "import jax.numpy as jnp\n", @@ -56,14 +51,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "s1-model", "metadata": { "execution": { - "iopub.execute_input": "2026-05-11T20:59:17.921750Z", - "iopub.status.busy": "2026-05-11T20:59:17.921564Z", - "iopub.status.idle": "2026-05-11T20:59:23.193250Z", - "shell.execute_reply": "2026-05-11T20:59:23.192560Z" + "iopub.execute_input": "2026-05-11T21:02:33.573025Z", + "iopub.status.busy": "2026-05-11T21:02:33.572835Z", + "iopub.status.idle": "2026-05-11T21:02:38.767025Z", + "shell.execute_reply": "2026-05-11T21:02:38.766312Z" } }, "outputs": [], @@ -81,7 +76,6 @@ " dot_general=self.dot_general_cls(),\n", " )(x)\n", "\n", - "# Shapes chosen to be large enough that quantization actually pays off.\n", "batch, seq, hidden, out_features = 4, 2048, 4096, 16384\n", "dtype = jnp.bfloat16\n", "\n", @@ -101,9 +95,9 @@ "source": [ "## 2. Quantized Dense via `make_dot_general_cls`\n", "\n", - "TE exposes a single helper, `te_flax.make_dot_general_cls(recipe)`, that returns a Flax module class you pass directly to `nn.Dense(..., dot_general=...)`. Internally it routes the GEMM through `transformer_engine.jax.dense.dense`, which handles the cast, the quantized matmul, and the VJP.\n", + "TE exposes a helper, `te_flax.make_dot_general_cls(recipe)`, that returns a Flax module class you pass directly to `nn.Dense(..., dot_general=...)`.\n", "\n", - "Crucially, **your model parameters are still yours**. TE doesn't create the `kernel`; it only wraps the multiply. Initialization, sharding annotations, and optimizer state stay where they were." + "With this API, TE doesn't create the `kernel` params; it only wraps the GEMM. All your initialization, sharding annotations, and optimizer state stay where they were." ] }, { @@ -112,10 +106,10 @@ "id": "s2-wire", "metadata": { "execution": { - "iopub.execute_input": "2026-05-11T20:59:23.194639Z", - "iopub.status.busy": "2026-05-11T20:59:23.194505Z", - "iopub.status.idle": "2026-05-11T20:59:23.715045Z", - "shell.execute_reply": "2026-05-11T20:59:23.714507Z" + "iopub.execute_input": "2026-05-11T21:02:38.768455Z", + "iopub.status.busy": "2026-05-11T21:02:38.768315Z", + "iopub.status.idle": "2026-05-11T21:02:39.271141Z", + "shell.execute_reply": "2026-05-11T21:02:39.270566Z" } }, "outputs": [ @@ -142,14 +136,28 @@ "print(jax.tree_util.tree_map(lambda a: (a.shape, a.dtype), te_vars))" ] }, + { + "cell_type": "markdown", + "id": "670f6e90", + "metadata": {}, + "source": [ + "
\n", + "

What about DelayedScaling state?

\n", + "\n", + "Most recipes are stateless — scaling factors are computed from each tensor as it flows through the GEMM, so there is nothing to persist across steps. However, if you swap in `DelayedScaling` instead, `init` will produce a second variable collection, `_overwrite_with_gradient`, holding `kernel_amax_history`, `kernel_scale`, `x_amax_history`, `x_scale`, etc. These are **not** model parameters — they are Flax variables that TE updates each step to compute per-tensor scales from a rolling amax window.\n", + "\n", + "If you use `DelayedScaling`, you must thread the *entire* `var_collect` through your training loop (not just `params`) so the history persists across steps. `MXFP8BlockScaling`, `NVFP4BlockScaling`, and `Float8CurrentScaling` do not require this.\n", + "
" + ] + }, { "cell_type": "markdown", "id": "s3-md", "metadata": {}, "source": [ - "## 3. Does it actually help? — single-GPU performance\n", + "## 3. Single-GPU performance\n", "\n", - "Let's measure first, ask questions later. `speedometer` runs a JIT-compiled forward+backward loop with warmup, on the same input for both models." + "`speedometer` runs a JIT-compiled forward+backward loop with warmup, on the same input for both models." ] }, { @@ -158,10 +166,10 @@ "id": "s3-bench", "metadata": { "execution": { - "iopub.execute_input": "2026-05-11T20:59:23.716295Z", - "iopub.status.busy": "2026-05-11T20:59:23.716158Z", - "iopub.status.idle": "2026-05-11T20:59:25.886104Z", - "shell.execute_reply": "2026-05-11T20:59:25.885457Z" + "iopub.execute_input": "2026-05-11T21:02:39.272256Z", + "iopub.status.busy": "2026-05-11T21:02:39.272128Z", + "iopub.status.idle": "2026-05-11T21:02:41.426450Z", + "shell.execute_reply": "2026-05-11T21:02:41.425748Z" } }, "outputs": [ @@ -176,7 +184,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Mean time: 4.13301944732666 ms\n", + "Mean time: 4.126272201538086 ms\n", "\n", "TE MXFP8BlockScaling:\n" ] @@ -185,14 +193,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "E0511 13:59:25.678780 152730 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.\n" + "E0511 14:02:41.217310 154511 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Mean time: 1.6977357864379883 ms\n" + "Mean time: 1.689767837524414 ms\n" ] } ], @@ -222,77 +230,10 @@ "The speedup depends on shape: large GEMMs benefit most. Very small GEMMs may not benefit at all because the cast + scale overhead can dominate.\n", "\n", "
\n", - "Remat / activation checkpointing. If your training loop uses jax.checkpoint_policies.checkpoint_dots (or any policy that matches jax.lax.dot_general), swap it for transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms. Otherwise TE's quantized GEMM primitives won't be checkpointed correctly and your performance comparison will be unfair.\n", + "Remat / activation checkpointing. If your training loop uses jax.checkpoint_policies.checkpoint_dots (or any policy that matches jax.lax.dot_general), swap it for transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms. Otherwise TE's quantized GEMM primitives won't be checkpointed correctly and your performance comparison will not be accurate.\n", "
" ] }, - { - "cell_type": "markdown", - "id": "s4-md", - "metadata": {}, - "source": [ - "## 4. What does it cost? — numerical comparison\n", - "\n", - "Quantization is lossy by construction. `compare_fwd_bwd` in `quickstart_jax_utils` runs both apply functions through `jax.value_and_grad` on the same input and reports max abs/rel diff on the output, the input gradient, and the kernel gradient." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "s4-compare", - "metadata": { - "execution": { - "iopub.execute_input": "2026-05-11T20:59:25.887578Z", - "iopub.status.busy": "2026-05-11T20:59:25.887438Z", - "iopub.status.idle": "2026-05-11T20:59:28.782009Z", - "shell.execute_reply": "2026-05-11T20:59:28.781350Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "y: max_abs=2.245e-01 max_rel=3.881e-02\n", - "dx: max_abs=4.766e-01 max_rel=4.144e-02\n", - "dW: max_abs=1.874e+01 max_rel=3.676e-02\n" - ] - } - ], - "source": [ - "# Make the TE model's kernel identical to the baseline's so we measure quantization\n", - "# error specifically, not init differences.\n", - "te_vars_matched = {**te_vars, \"params\": baseline_vars[\"params\"]}\n", - "\n", - "diffs = utils.compare_fwd_bwd(\n", - " apply_a=baseline.apply, variables_a=baseline_vars,\n", - " apply_b=te_model.apply, variables_b=te_vars_matched,\n", - " input=x, output_grad=dy,\n", - ")\n", - "for name, d in diffs.items():\n", - " print(f\"{name}: max_abs={d['max_abs']:.3e} max_rel={d['max_rel']:.3e}\")" - ] - }, - { - "cell_type": "markdown", - "id": "s4-note-md", - "metadata": {}, - "source": [ - "Max relative error around a few percent is the expected MXFP8 fidelity envelope. In real training, gradient noise dominates this error." - ] - }, - { - "cell_type": "markdown", - "id": "s-state-md", - "metadata": {}, - "source": [ - "## 5. A note on recipe state\n", - "\n", - "MXFP8 is stateless — scaling factors are computed from each tensor as it flows through the GEMM, so there is nothing to persist across steps. If you swap in `DelayedScaling` instead, `init` will produce a second variable collection, `_overwrite_with_gradient`, holding `kernel_amax_history`, `kernel_scale`, `x_amax_history`, `x_scale`, etc. These are **not** model parameters — they are Flax variables that TE updates each step to compute per-tensor scales from a rolling amax window.\n", - "\n", - "If you use a stateful recipe, you must thread the *entire* `var_collect` through your training loop (not just `params`) so the history persists across steps. `MXFP8BlockScaling` avoids this entirely by scaling locally per 32-element block." - ] - }, { "cell_type": "markdown", "id": "s6-md", @@ -316,10 +257,10 @@ "id": "s6-check", "metadata": { "execution": { - "iopub.execute_input": "2026-05-11T20:59:28.783268Z", - "iopub.status.busy": "2026-05-11T20:59:28.783059Z", - "iopub.status.idle": "2026-05-11T20:59:28.785499Z", - "shell.execute_reply": "2026-05-11T20:59:28.785006Z" + "iopub.execute_input": "2026-05-11T21:02:44.352297Z", + "iopub.status.busy": "2026-05-11T21:02:44.352166Z", + "iopub.status.idle": "2026-05-11T21:02:44.354494Z", + "shell.execute_reply": "2026-05-11T21:02:44.354073Z" } }, "outputs": [ @@ -343,10 +284,10 @@ "id": "s6-mesh", "metadata": { "execution": { - "iopub.execute_input": "2026-05-11T20:59:28.786463Z", - "iopub.status.busy": "2026-05-11T20:59:28.786330Z", - "iopub.status.idle": "2026-05-11T20:59:28.789103Z", - "shell.execute_reply": "2026-05-11T20:59:28.788745Z" + "iopub.execute_input": "2026-05-11T21:02:44.355538Z", + "iopub.status.busy": "2026-05-11T21:02:44.355424Z", + "iopub.status.idle": "2026-05-11T21:02:44.358077Z", + "shell.execute_reply": "2026-05-11T21:02:44.357662Z" } }, "outputs": [], @@ -384,10 +325,10 @@ "id": "s6-shard", "metadata": { "execution": { - "iopub.execute_input": "2026-05-11T20:59:28.790047Z", - "iopub.status.busy": "2026-05-11T20:59:28.789928Z", - "iopub.status.idle": "2026-05-11T20:59:30.177416Z", - "shell.execute_reply": "2026-05-11T20:59:30.176795Z" + "iopub.execute_input": "2026-05-11T21:02:44.359044Z", + "iopub.status.busy": "2026-05-11T21:02:44.358932Z", + "iopub.status.idle": "2026-05-11T21:02:45.715164Z", + "shell.execute_reply": "2026-05-11T21:02:45.714580Z" } }, "outputs": [], @@ -414,10 +355,10 @@ "id": "s6-bench", "metadata": { "execution": { - "iopub.execute_input": "2026-05-11T20:59:30.178701Z", - "iopub.status.busy": "2026-05-11T20:59:30.178579Z", - "iopub.status.idle": "2026-05-11T20:59:32.572397Z", - "shell.execute_reply": "2026-05-11T20:59:32.571759Z" + "iopub.execute_input": "2026-05-11T21:02:45.716696Z", + "iopub.status.busy": "2026-05-11T21:02:45.716566Z", + "iopub.status.idle": "2026-05-11T21:02:48.119943Z", + "shell.execute_reply": "2026-05-11T21:02:48.119290Z" } }, "outputs": [ @@ -432,7 +373,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Mean time: 1.7351531982421875 ms\n", + "Mean time: 1.7258834838867188 ms\n", "\n", "TE MXFP8BlockScaling DP=2/TP=2:\n" ] @@ -441,7 +382,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Mean time: 0.9787273406982422 ms\n" + "Mean time: 0.9692144393920898 ms\n" ] } ], @@ -462,22 +403,6 @@ " )" ] }, - { - "cell_type": "markdown", - "id": "s6-why-md", - "metadata": {}, - "source": [ - "**Why this works.** Column-parallel TP shards the kernel along its *output* dim. The input activations are replicated along the TP axis, so the GEMM needs no input all-gather; it just runs locally on each TP shard and produces an output that's already sharded along the TP axis. The forward has no TP comms. The backward needs an all-reduce on the input-gradient (`dx`), which is partly hidden under the next layer's compute in a real model.\n", - "\n", - "With TE, the bf16→FP8 cast and the GEMM both run on the local shard. Since the column-parallel pattern keeps comms out of the forward critical path, the GEMM-compute speedup translates directly to wall-clock.\n", - "\n", - "The DP axis just shards the batch — it adds no comms in the forward; the gradient all-reduce on the backward is what gets DP-ified.\n", - "\n", - "
\n", - "Flax logical axes. In larger models, prefer Flax's logical axis rules (set via flax.linen.logical_axis_rules) and annotate kernels with kernel_axes=(...) on TE modules. The manual approach above keeps this notebook small; the logical-axes approach scales to many layers without per-layer plumbing.\n", - "
" - ] - }, { "cell_type": "markdown", "id": "s7-md", @@ -485,13 +410,7 @@ "source": [ "## 7. Collective GEMM (placeholder)\n", "\n", - "*Coming soon.*\n", - "\n", - "A further optimization is to *overlap* the TP all-gather / all-reduce with the GEMM itself — a \"collective GEMM\". TE's JAX dense path exposes this via the `collective_op_set` argument on `transformer_engine.jax.dense.dense`. A future revision of this notebook will walk through:\n", - "\n", - "- enabling overlapped AG+GEMM and reduce-scatter+GEMM,\n", - "- the throughput delta at large TP world sizes,\n", - "- when the overlap pays off and when it doesn't (small layers, slow interconnects)." + "*Coming soon.*" ] }, { @@ -499,14 +418,6 @@ "id": "recap-md", "metadata": {}, "source": [ - "## Recap\n", - "\n", - "- One line of code (`dot_general=te_flax.make_dot_general_cls(recipe)()`) swaps in a quantized GEMM.\n", - "- Your `params` tree is unchanged; stateful recipes add a `_overwrite_with_gradient` collection that must be threaded through training.\n", - "- Numerical error from MXFP8 is small enough to be dominated by gradient noise in real training.\n", - "- On a single GB200, one large Dense fwd+bwd is **~2.5× faster** with MXFP8.\n", - "- The same Dense under DP=2/TP=2 stays faster with FP8 — column-parallel TP keeps comms out of the forward critical path, so the GEMM-compute win carries through.\n", - "\n", "**Next:** [Attention](./attention.ipynb) · [Mixture of Experts](./moe.ipynb) · [← Hub](../te_jax_integration.ipynb)" ] } diff --git a/docs/examples/jax_examples/moe.ipynb b/docs/examples/jax_examples/moe.ipynb index 2cf421509a..11898857c4 100644 --- a/docs/examples/jax_examples/moe.ipynb +++ b/docs/examples/jax_examples/moe.ipynb @@ -9,10 +9,9 @@ "\n", "**TODO — Coming soon.**\n", "\n", - "This notebook will cover MoE layers using TE's grouped GEMM (`te_flax.make_grouped_dense_cls`, MXFP8 only), including:\n", + "This notebook will cover TE's `MoEBlock` layer which utilizes TE's optimized routing, permutation and grouped GEMM\n", "\n", - "- routing + token dispatch,\n", - "- grouped quantized GEMM for the expert FFN,\n", + "- single-GPU MoEBlock usage vs jax.lax.ragged_dot\n", "- expert-parallel sharding considerations.\n", "\n", "[← Back to the JAX integration overview](../te_jax_integration.ipynb)" diff --git a/docs/examples/te_jax_integration.ipynb b/docs/examples/te_jax_integration.ipynb index 76e56161e7..dc9f5acbc0 100644 --- a/docs/examples/te_jax_integration.ipynb +++ b/docs/examples/te_jax_integration.ipynb @@ -7,11 +7,7 @@ "source": [ "# JAX: Integrating TransformerEngine into an existing framework\n", "\n", - "This is the landing page for a series of focused notebooks on bringing TransformerEngine into a JAX+Flax codebase one optimization at a time. Each linked notebook isolates a single feature so you can see exactly what changes, what state it introduces, what it costs in numerical precision, and what it buys in speed.\n", - "\n", - "**Audience.** Intermediate Flax users. You're already comfortable with `nn.Module`, `init`/`apply`, variable collections, and RNG handling \u2014 we won't re-introduce those.\n", - "\n", - "**Philosophy.** TE's JAX integration is designed to be **non-invasive**: your model parameters, initialization, optimizer, and sharding annotations stay yours. The only surface area you touch is the GEMM (and, for attention, the attention kernel). Everything else is unchanged. See [MaxText's TE integration](https://github.com/AI-Hypercomputer/maxtext/blob/ed517cf80d9aa81f76e236c5516dacebfe39e96d/src/MaxText/layers/quantizations.py#L753) for a production example." + "This is the landing page for a series of focused notebooks on bringing TransformerEngine into a JAX+Flax codebase one optimization at a time. Each linked notebook isolates a single feature so you can see exactly what changes are required and what are the performance benefits." ] }, { @@ -23,11 +19,9 @@ "\n", "| Notebook | Status | Covers |\n", "|---|---|---|\n", - "| [Dense GEMMs](./jax_examples/dense.ipynb) | **Available** | `nn.Dense` \u2192 quantized GEMM via `dot_general_cls`; fwd/bwd numerics; single-GPU speedup; FSDP quantize-before-all-gather; Collective GEMM (placeholder) |\n", - "| [Attention](./jax_examples/attention.ipynb) | *Coming soon* | Fused attention via `te.flax.DotProductAttention`; FP8 attention; context parallelism |\n", - "| [Mixture of Experts](./jax_examples/moe.ipynb) | *Coming soon* | Grouped GEMM for expert FFNs; routing; expert-parallel sharding |\n", - "\n", - "Start with **Dense GEMMs** \u2014 it introduces the `dot_general_cls` pattern and recipe mechanics that the other notebooks build on." + "| [Dense GEMMs](./jax_examples/dense.ipynb) | **Available** | `nn.Dense` → quantized GEMM; single-GPU speedup; multi-GPU speedup; Collective GEMM |\n", + "| [Attention](./jax_examples/attention.ipynb) | *Coming soon* \n", + "| [Mixture of Experts](./jax_examples/moe.ipynb) | *Coming soon*" ] }, { @@ -37,14 +31,14 @@ "source": [ "## Quantization recipes at a glance\n", "\n", - "TE exposes its quantization choices as **recipes**. You pick one and pass it to `te_flax.make_dot_general_cls(recipe)` (or to the equivalent helper for attention / grouped GEMM).\n", + "TE exposes its quantization choices as **recipes**. Please see Low-precision Training for a more detailed description of each recipe.\n", "\n", "| Recipe | Hardware | State | When to use |\n", "|---|---|---|---|\n", - "| `DelayedScaling` | Hopper+ | amax history (Flax variables) | Stable per-tensor FP8 training; the most battle-tested recipe |\n", - "| `Float8CurrentScaling` | Hopper+ | none | Per-tensor FP8 without an amax history; simpler bookkeeping |\n", - "| `MXFP8BlockScaling` | Blackwell+ | none | Block-scaled FP8 (32-element blocks); no state to thread through training |\n", - "| `NVFP4BlockScaling` | Blackwell+ | requires `sr_rng` | FP4 with 2D block scaling and stochastic rounding |\n", + "| `DelayedScaling` | Hopper+ | amax history (Flax variables) | Per-tensor FP8 with amax history\n", + "| `Float8CurrentScaling` | Hopper+ | none | Per-tensor FP8 without an amax history |\n", + "| `MXFP8BlockScaling` | Blackwell+ | none | Block-scaled FP8 (32-element blocks) |\n", + "| `NVFP4BlockScaling` | Blackwell+ | requires a Flax RNG `sr_rng` | FP4 with 2D block scaling and stochastic rounding |\n", "\n", "Import them from `transformer_engine.common.recipe`." ] @@ -58,12 +52,7 @@ "\n", "- **Framework.** Flax Linen. (TE/JAX uses Linen; see [Flax NNX/Linen interop](https://flax.readthedocs.io/en/latest/guides/bridge_guide.html) and [Haiku/Flax interop](https://dm-haiku.readthedocs.io/en/latest/notebooks/flax.html) if you're on a different stack.)\n", "- **Baseline dtype.** bf16 for inputs and parameters.\n", - "- **Benchmarking.** `quickstart_jax_utils.speedometer` runs a JIT-compiled fwd+bwd loop with warmup.\n", - "- **Numerical check.** `quickstart_jax_utils.compare_fwd_bwd` reports max abs/rel diff on output, input grad, and weight grad.\n", - "\n", - "
\n", - "Remat / activation checkpointing. If your training loop uses jax.checkpoint_policies.checkpoint_dots (or any policy that matches jax.lax.dot_general), swap it for transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms. Otherwise TE's quantized GEMM primitives won't be checkpointed correctly and your performance comparison will be misleading.\n", - "
" + "- **Benchmarking.** `quickstart_jax_utils.speedometer` runs a JIT-compiled fwd+bwd loop with warmup." ] } ], @@ -76,4 +65,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} From 143391d13c50307f6bf14420d826cb9b76f19f3c Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 11 May 2026 14:47:44 -0700 Subject: [PATCH 03/12] revert compare util function Signed-off-by: Jeremy Berchtold --- docs/examples/quickstart_jax_utils.py | 53 --------------------------- 1 file changed, 53 deletions(-) diff --git a/docs/examples/quickstart_jax_utils.py b/docs/examples/quickstart_jax_utils.py index 5253fb14e4..0c5ec5295e 100644 --- a/docs/examples/quickstart_jax_utils.py +++ b/docs/examples/quickstart_jax_utils.py @@ -88,59 +88,6 @@ def fwd_bwd_fn(*args, **kwargs): return jax.jit(fwd_bwd_fn) -def compare_fwd_bwd( - apply_a: Callable, - variables_a: Any, - apply_b: Callable, - variables_b: Any, - input: jnp.ndarray, - output_grad: jnp.ndarray, - forward_kwargs_a: Dict[str, Any] = None, - forward_kwargs_b: Dict[str, Any] = None, - rngs_a: Dict[str, jax.random.PRNGKey] = None, - rngs_b: Dict[str, jax.random.PRNGKey] = None, -) -> Dict[str, Dict[str, float]]: - """Run fwd+bwd on two apply functions and report max abs/rel diff on y, dx, and dW.""" - forward_kwargs_a = forward_kwargs_a or {} - forward_kwargs_b = forward_kwargs_b or {} - rngs_a = rngs_a or {} - rngs_b = rngs_b or {} - - def run(apply_fn, variables, forward_kwargs, rngs): - def loss_fn(variables, inp): - out = apply_fn(variables, inp, rngs=rngs, **forward_kwargs) - return jnp.vdot(out, output_grad), out - - (_, out), (param_grads, dx) = jax.value_and_grad(loss_fn, argnums=(0, 1), has_aux=True)( - variables, input - ) - return out, dx, param_grads - - y_a, dx_a, gp_a = run(apply_a, variables_a, forward_kwargs_a, rngs_a) - y_b, dx_b, gp_b = run(apply_b, variables_b, forward_kwargs_b, rngs_b) - - kernel_leaves_a = [ - leaf for path, leaf in jax.tree_util.tree_leaves_with_path(gp_a) if "kernel" in jax.tree_util.keystr(path) - ] - kernel_leaves_b = [ - leaf for path, leaf in jax.tree_util.tree_leaves_with_path(gp_b) if "kernel" in jax.tree_util.keystr(path) - ] - dW_a = kernel_leaves_a[0] if kernel_leaves_a else None - dW_b = kernel_leaves_b[0] if kernel_leaves_b else None - - def diffs(a, b): - a = a.astype(jnp.float32) - b = b.astype(jnp.float32) - abs_diff = float(jnp.max(jnp.abs(a - b))) - denom = float(jnp.max(jnp.abs(a))) + 1e-12 - return {"max_abs": abs_diff, "max_rel": abs_diff / denom} - - result = {"y": diffs(y_a, y_b), "dx": diffs(dx_a, dx_b)} - if dW_a is not None and dW_b is not None: - result["dW"] = diffs(dW_a, dW_b) - return result - - def _split_step_rngs( rngs: Dict[str, jax.random.PRNGKey], ) -> Tuple[Dict[str, jax.random.PRNGKey], Dict[str, jax.random.PRNGKey]]: From 4f543b59f25d7df39350afbcee023dc82c739055 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 12 May 2026 11:37:14 -0700 Subject: [PATCH 04/12] Refactor notebook docs to .rst and independent .py file for testability Signed-off-by: Jeremy Berchtold --- docs/examples/jax_examples/attention.ipynb | 25 -- docs/examples/jax_examples/attention.rst | 11 + docs/examples/jax_examples/conftest.py | 17 + docs/examples/jax_examples/dense.ipynb | 446 --------------------- docs/examples/jax_examples/dense.out | 22 + docs/examples/jax_examples/dense.py | 227 +++++++++++ docs/examples/jax_examples/dense.rst | 175 ++++++++ docs/examples/jax_examples/moe.ipynb | 30 -- docs/examples/jax_examples/moe.rst | 17 + docs/examples/te_jax_integration.ipynb | 68 ---- docs/examples/te_jax_integration.rst | 91 +++++ docs/index.rst | 2 +- qa/L0_jax_unittest/test.sh | 5 + qa/L1_jax_distributed_unittest/test.sh | 4 + 14 files changed, 570 insertions(+), 570 deletions(-) delete mode 100644 docs/examples/jax_examples/attention.ipynb create mode 100644 docs/examples/jax_examples/attention.rst create mode 100644 docs/examples/jax_examples/conftest.py delete mode 100644 docs/examples/jax_examples/dense.ipynb create mode 100644 docs/examples/jax_examples/dense.out create mode 100644 docs/examples/jax_examples/dense.py create mode 100644 docs/examples/jax_examples/dense.rst delete mode 100644 docs/examples/jax_examples/moe.ipynb create mode 100644 docs/examples/jax_examples/moe.rst delete mode 100644 docs/examples/te_jax_integration.ipynb create mode 100644 docs/examples/te_jax_integration.rst diff --git a/docs/examples/jax_examples/attention.ipynb b/docs/examples/jax_examples/attention.ipynb deleted file mode 100644 index 6d8e8c3f46..0000000000 --- a/docs/examples/jax_examples/attention.ipynb +++ /dev/null @@ -1,25 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "title", - "metadata": {}, - "source": [ - "# JAX: Attention with TransformerEngine\n", - "\n", - "**TODO — Coming soon.**\n", - "\n", - "[← Back to the JAX integration overview](../te_jax_integration.ipynb)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/examples/jax_examples/attention.rst b/docs/examples/jax_examples/attention.rst new file mode 100644 index 0000000000..c9f84da634 --- /dev/null +++ b/docs/examples/jax_examples/attention.rst @@ -0,0 +1,11 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +JAX: Attention with TransformerEngine +===================================== + +**TODO — Coming soon.** + +`← Back to the JAX integration overview <../te_jax_integration.html>`_ diff --git a/docs/examples/jax_examples/conftest.py b/docs/examples/jax_examples/conftest.py new file mode 100644 index 0000000000..a584e7392e --- /dev/null +++ b/docs/examples/jax_examples/conftest.py @@ -0,0 +1,17 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Pytest conftest for docs/examples/jax_examples. + +Adds ``docs/examples/`` to ``sys.path`` so the example modules can do +``import quickstart_jax_utils`` regardless of the directory pytest was invoked +from. +""" +import os +import sys + +_HERE = os.path.dirname(os.path.abspath(__file__)) +_EXAMPLES_ROOT = os.path.dirname(_HERE) +if _EXAMPLES_ROOT not in sys.path: + sys.path.insert(0, _EXAMPLES_ROOT) diff --git a/docs/examples/jax_examples/dense.ipynb b/docs/examples/jax_examples/dense.ipynb deleted file mode 100644 index e03eb48389..0000000000 --- a/docs/examples/jax_examples/dense.ipynb +++ /dev/null @@ -1,446 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "intro-md", - "metadata": {}, - "source": [ - "# JAX: Dense GEMMs with TransformerEngine\n", - "\n", - "This notebook walks through replacing a plain `flax.linen.Dense`'s GEMM with TransformerEngine's quantized GEMM.\n", - "\n", - "**Recipe.** We use `MXFP8BlockScaling` in this tutorial. `MXFP8BlockScaling` and `NVFP4BlockScaling` require a Blackwell-class GPU; on Hopper, swap in `DelayedScaling` or `Float8CurrentScaling`.\n", - "\n", - "[← Back to the JAX integration overview](../te_jax_integration.ipynb)" - ] - }, - { - "cell_type": "markdown", - "id": "s1-md", - "metadata": {}, - "source": [ - "## 1. Baseline: a plain Flax Dense block\n", - "\n", - "We isolate the optimization to a single linear layer so it's clear what's changing. `dot_general_cls` is exposed as a constructor argument so we can swap in TE later without touching the model definition." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "s1-imports", - "metadata": { - "execution": { - "iopub.execute_input": "2026-05-11T21:02:32.861403Z", - "iopub.status.busy": "2026-05-11T21:02:32.861276Z", - "iopub.status.idle": "2026-05-11T21:02:33.571664Z", - "shell.execute_reply": "2026-05-11T21:02:33.571149Z" - } - }, - "outputs": [], - "source": [ - "import sys\n", - "sys.path.append(\"..\") # so we can import quickstart_jax_utils from docs/examples/\n", - "\n", - "import warnings\n", - "\n", - "import jax\n", - "import jax.numpy as jnp\n", - "from flax import linen as nn\n", - "import quickstart_jax_utils as utils" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "s1-model", - "metadata": { - "execution": { - "iopub.execute_input": "2026-05-11T21:02:33.573025Z", - "iopub.status.busy": "2026-05-11T21:02:33.572835Z", - "iopub.status.idle": "2026-05-11T21:02:38.767025Z", - "shell.execute_reply": "2026-05-11T21:02:38.766312Z" - } - }, - "outputs": [], - "source": [ - "class FlaxDenseBlock(nn.Module):\n", - " \"\"\"One linear layer. `dot_general_cls` lets us swap the GEMM impl.\"\"\"\n", - " features: int\n", - " dot_general_cls: callable = lambda: None\n", - "\n", - " @nn.compact\n", - " def __call__(self, x):\n", - " return nn.Dense(\n", - " features=self.features,\n", - " use_bias=False,\n", - " dot_general=self.dot_general_cls(),\n", - " )(x)\n", - "\n", - "batch, seq, hidden, out_features = 4, 2048, 4096, 16384\n", - "dtype = jnp.bfloat16\n", - "\n", - "key = jax.random.PRNGKey(0)\n", - "k_init, k_x, k_dy = jax.random.split(key, 3)\n", - "x = jax.random.normal(k_x, (batch, seq, hidden)).astype(dtype)\n", - "dy = jax.random.normal(k_dy, (batch, seq, out_features)).astype(dtype)\n", - "\n", - "baseline = FlaxDenseBlock(features=out_features)\n", - "baseline_vars = baseline.init(k_init, x)" - ] - }, - { - "cell_type": "markdown", - "id": "s2-md", - "metadata": {}, - "source": [ - "## 2. Quantized Dense via `make_dot_general_cls`\n", - "\n", - "TE exposes a helper, `te_flax.make_dot_general_cls(recipe)`, that returns a Flax module class you pass directly to `nn.Dense(..., dot_general=...)`.\n", - "\n", - "With this API, TE doesn't create the `kernel` params; it only wraps the GEMM. All your initialization, sharding annotations, and optimizer state stay where they were." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "s2-wire", - "metadata": { - "execution": { - "iopub.execute_input": "2026-05-11T21:02:38.768455Z", - "iopub.status.busy": "2026-05-11T21:02:38.768315Z", - "iopub.status.idle": "2026-05-11T21:02:39.271141Z", - "shell.execute_reply": "2026-05-11T21:02:39.270566Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Variable collections: ['params']\n", - "{'params': {'Dense_0': {'kernel': ((4096, 16384), dtype('float32'))}}}\n" - ] - } - ], - "source": [ - "from transformer_engine.jax import flax as te_flax\n", - "from transformer_engine.common.recipe import MXFP8BlockScaling\n", - "\n", - "recipe = MXFP8BlockScaling()\n", - "te_dot_general_cls = te_flax.make_dot_general_cls(recipe)\n", - "\n", - "te_model = FlaxDenseBlock(features=out_features, dot_general_cls=te_dot_general_cls)\n", - "te_vars = te_model.init(k_init, x)\n", - "\n", - "print(\"Variable collections:\", list(te_vars.keys()))\n", - "print(jax.tree_util.tree_map(lambda a: (a.shape, a.dtype), te_vars))" - ] - }, - { - "cell_type": "markdown", - "id": "670f6e90", - "metadata": {}, - "source": [ - "
\n", - "

What about DelayedScaling state?

\n", - "\n", - "Most recipes are stateless — scaling factors are computed from each tensor as it flows through the GEMM, so there is nothing to persist across steps. However, if you swap in `DelayedScaling` instead, `init` will produce a second variable collection, `_overwrite_with_gradient`, holding `kernel_amax_history`, `kernel_scale`, `x_amax_history`, `x_scale`, etc. These are **not** model parameters — they are Flax variables that TE updates each step to compute per-tensor scales from a rolling amax window.\n", - "\n", - "If you use `DelayedScaling`, you must thread the *entire* `var_collect` through your training loop (not just `params`) so the history persists across steps. `MXFP8BlockScaling`, `NVFP4BlockScaling`, and `Float8CurrentScaling` do not require this.\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "s3-md", - "metadata": {}, - "source": [ - "## 3. Single-GPU performance\n", - "\n", - "`speedometer` runs a JIT-compiled forward+backward loop with warmup, on the same input for both models." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "s3-bench", - "metadata": { - "execution": { - "iopub.execute_input": "2026-05-11T21:02:39.272256Z", - "iopub.status.busy": "2026-05-11T21:02:39.272128Z", - "iopub.status.idle": "2026-05-11T21:02:41.426450Z", - "shell.execute_reply": "2026-05-11T21:02:41.425748Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "bf16 baseline:\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 4.126272201538086 ms\n", - "\n", - "TE MXFP8BlockScaling:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "E0511 14:02:41.217310 154511 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 1.689767837524414 ms\n" - ] - } - ], - "source": [ - "print(\"bf16 baseline:\")\n", - "utils.speedometer(\n", - " model_apply_fn=baseline.apply,\n", - " variables=baseline_vars,\n", - " input=x, output_grad=dy,\n", - ")\n", - "\n", - "print(f\"\\nTE {type(recipe).__name__}:\")\n", - "utils.speedometer(\n", - " model_apply_fn=te_model.apply,\n", - " variables=te_vars,\n", - " input=x, output_grad=dy,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "s3-note-md", - "metadata": {}, - "source": [ - "On a single GB200, that's roughly **2.5× faster** for the fwd+bwd of one large Dense — and the only code change was passing `dot_general=te_dot_general_cls()` into `nn.Dense`.\n", - "\n", - "The speedup depends on shape: large GEMMs benefit most. Very small GEMMs may not benefit at all because the cast + scale overhead can dominate.\n", - "\n", - "
\n", - "Remat / activation checkpointing. If your training loop uses jax.checkpoint_policies.checkpoint_dots (or any policy that matches jax.lax.dot_general), swap it for transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms. Otherwise TE's quantized GEMM primitives won't be checkpointed correctly and your performance comparison will not be accurate.\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "s6-md", - "metadata": {}, - "source": [ - "## 6. Multi-GPU: DP=2 / TP=2 on a single Dense\n", - "\n", - "**Prerequisite:** this section requires four GPUs.\n", - "\n", - "Keeping the same `FlaxDenseBlock` from the rest of the notebook, we run it on a 2×2 mesh with **data parallelism** on one axis and **tensor parallelism** (column-parallel: shard the kernel's output dim) on the other.\n", - "\n", - "Two pieces wire this up:\n", - "\n", - "1. A `jax.sharding.Mesh` you build once at module scope (outside JIT).\n", - "2. TE's `MeshResource`, set globally via `global_shard_guard`, which tells TE which mesh axes are DP and TP." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "s6-check", - "metadata": { - "execution": { - "iopub.execute_input": "2026-05-11T21:02:44.352297Z", - "iopub.status.busy": "2026-05-11T21:02:44.352166Z", - "iopub.status.idle": "2026-05-11T21:02:44.354494Z", - "shell.execute_reply": "2026-05-11T21:02:44.354073Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Visible devices: 4\n" - ] - } - ], - "source": [ - "n_devices = len(jax.devices())\n", - "print(f\"Visible devices: {n_devices}\")\n", - "assert n_devices >= 4, \"This section requires 4 GPUs for DP=2/TP=2.\"" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "s6-mesh", - "metadata": { - "execution": { - "iopub.execute_input": "2026-05-11T21:02:44.355538Z", - "iopub.status.busy": "2026-05-11T21:02:44.355424Z", - "iopub.status.idle": "2026-05-11T21:02:44.358077Z", - "shell.execute_reply": "2026-05-11T21:02:44.357662Z" - } - }, - "outputs": [], - "source": [ - "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n", - "from jax.experimental import mesh_utils\n", - "from transformer_engine.jax.sharding import MeshResource, global_shard_guard\n", - "\n", - "# 2x2 mesh: DP on one axis, TP on the other.\n", - "devices = mesh_utils.create_device_mesh((2, 2))\n", - "mesh = Mesh(devices, axis_names=(\"dp\", \"tp\"))\n", - "\n", - "# Tell TE which mesh axis is which. This is a *global* setting, established\n", - "# outside JIT, so TE's GEMM primitives can plan comms accordingly.\n", - "mesh_resource = MeshResource(dp_resource=\"dp\", tp_resource=\"tp\")" - ] - }, - { - "cell_type": "markdown", - "id": "s6-shard-md", - "metadata": {}, - "source": [ - "**Sharding plan:**\n", - "\n", - "| Tensor | Shape | PartitionSpec |\n", - "|---|---|---|\n", - "| Kernel (column-parallel) | `(hidden, out_features)` | `P(None, \"tp\")` |\n", - "| Input activations | `(batch, seq, hidden)` | `P(\"dp\", None, None)` |\n", - "| Gradient on output | `(batch, seq, out_features)` | `P(\"dp\", None, \"tp\")` |" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "s6-shard", - "metadata": { - "execution": { - "iopub.execute_input": "2026-05-11T21:02:44.359044Z", - "iopub.status.busy": "2026-05-11T21:02:44.358932Z", - "iopub.status.idle": "2026-05-11T21:02:45.715164Z", - "shell.execute_reply": "2026-05-11T21:02:45.714580Z" - } - }, - "outputs": [], - "source": [ - "kernel_sharding = NamedSharding(mesh, P(None, \"tp\"))\n", - "input_sharding = NamedSharding(mesh, P(\"dp\", None, None))\n", - "output_grad_sharding = NamedSharding(mesh, P(\"dp\", None, \"tp\"))\n", - "\n", - "def shard_kernel(variables):\n", - " params = variables[\"params\"]\n", - " sharded = jax.device_put(params[\"Dense_0\"][\"kernel\"], kernel_sharding)\n", - " return {**variables, \"params\": {**params,\n", - " \"Dense_0\": {**params[\"Dense_0\"], \"kernel\": sharded}}}\n", - "\n", - "x_mp_s = jax.device_put(x, input_sharding)\n", - "dy_mp_s = jax.device_put(dy, output_grad_sharding)\n", - "baseline_vars_s = shard_kernel(baseline_vars)\n", - "te_vars_s = shard_kernel(te_vars)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "s6-bench", - "metadata": { - "execution": { - "iopub.execute_input": "2026-05-11T21:02:45.716696Z", - "iopub.status.busy": "2026-05-11T21:02:45.716566Z", - "iopub.status.idle": "2026-05-11T21:02:48.119943Z", - "shell.execute_reply": "2026-05-11T21:02:48.119290Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "bf16 DP=2/TP=2:\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 1.7258834838867188 ms\n", - "\n", - "TE MXFP8BlockScaling DP=2/TP=2:\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 0.9692144393920898 ms\n" - ] - } - ], - "source": [ - "with jax.set_mesh(mesh), global_shard_guard(mesh_resource):\n", - " print(\"bf16 DP=2/TP=2:\")\n", - " utils.speedometer(\n", - " model_apply_fn=baseline.apply,\n", - " variables=baseline_vars_s,\n", - " input=x_mp_s, output_grad=dy_mp_s,\n", - " )\n", - "\n", - " print(f\"\\nTE {type(recipe).__name__} DP=2/TP=2:\")\n", - " utils.speedometer(\n", - " model_apply_fn=te_model.apply,\n", - " variables=te_vars_s,\n", - " input=x_mp_s, output_grad=dy_mp_s,\n", - " )" - ] - }, - { - "cell_type": "markdown", - "id": "s7-md", - "metadata": {}, - "source": [ - "## 7. Collective GEMM (placeholder)\n", - "\n", - "*Coming soon.*" - ] - }, - { - "cell_type": "markdown", - "id": "recap-md", - "metadata": {}, - "source": [ - "**Next:** [Attention](./attention.ipynb) · [Mixture of Experts](./moe.ipynb) · [← Hub](../te_jax_integration.ipynb)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/examples/jax_examples/dense.out b/docs/examples/jax_examples/dense.out new file mode 100644 index 0000000000..d6628b2f82 --- /dev/null +++ b/docs/examples/jax_examples/dense.out @@ -0,0 +1,22 @@ +# Numbers below are illustrative (captured on a GB200). Regenerate with: +# python3 docs/examples/jax_examples/dense.py > dense.out +# after substantial code changes. + +# SINGLE_GPU_OUTPUT_START +Variable collections: ['params'] +{'params': {'Dense_0': {'kernel': ((4096, 16384), dtype('float32'))}}} + +bf16 baseline: +Mean time: 4.126 ms + +TE MXFP8BlockScaling: +Mean time: 1.690 ms +# SINGLE_GPU_OUTPUT_END + +# MULTI_GPU_OUTPUT_START +bf16 DP=2/TP=2: +Mean time: 1.726 ms + +TE MXFP8BlockScaling DP=2/TP=2: +Mean time: 0.969 ms +# MULTI_GPU_OUTPUT_END diff --git a/docs/examples/jax_examples/dense.py b/docs/examples/jax_examples/dense.py new file mode 100644 index 0000000000..2e63d218a8 --- /dev/null +++ b/docs/examples/jax_examples/dense.py @@ -0,0 +1,227 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""JAX: Dense GEMMs with TransformerEngine. + +Companion source for ``dense.rst``. Code blocks between ``# DENSE_*_START`` / +``# DENSE_*_END`` markers are pulled into the RST via ``literalinclude``. + +Run as a pytest module to exercise the example end-to-end: + + pytest -v docs/examples/jax_examples/dense.py + +The multi-GPU section auto-skips when fewer than 4 GPUs are visible. +""" + +# DENSE_IMPORTS_START +import sys + +sys.path.append("..") # so we can import quickstart_jax_utils from docs/examples/ + +import jax +import jax.numpy as jnp +from flax import linen as nn + +import quickstart_jax_utils as utils +# DENSE_IMPORTS_END + + +# DENSE_BASELINE_MODEL_START +class FlaxDenseBlock(nn.Module): + """One linear layer. ``dot_general_cls`` lets us swap the GEMM impl.""" + + features: int + dot_general_cls: callable = lambda: None + + @nn.compact + def __call__(self, x): + return nn.Dense( + features=self.features, + use_bias=False, + dot_general=self.dot_general_cls(), + )(x) +# DENSE_BASELINE_MODEL_END + + +# DENSE_INPUTS_SETUP_START +batch, seq, hidden, out_features = 4, 2048, 4096, 16384 +dtype = jnp.bfloat16 + +key = jax.random.PRNGKey(0) +k_init, k_x, k_dy = jax.random.split(key, 3) +x = jax.random.normal(k_x, (batch, seq, hidden)).astype(dtype) +dy = jax.random.normal(k_dy, (batch, seq, out_features)).astype(dtype) + +baseline = FlaxDenseBlock(features=out_features) +baseline_vars = baseline.init(k_init, x) +# DENSE_INPUTS_SETUP_END + + +# DENSE_TE_SETUP_START +from transformer_engine.jax import flax as te_flax +from transformer_engine.common.recipe import MXFP8BlockScaling + +recipe = MXFP8BlockScaling() +te_dot_general_cls = te_flax.make_dot_general_cls(recipe) + +te_model = FlaxDenseBlock(features=out_features, dot_general_cls=te_dot_general_cls) +te_vars = te_model.init(k_init, x) + +print("Variable collections:", list(te_vars.keys())) +print(jax.tree_util.tree_map(lambda a: (a.shape, a.dtype), te_vars)) +# DENSE_TE_SETUP_END + + +# DENSE_SINGLE_GPU_BENCH_START +def run_single_gpu_bench(): + print("bf16 baseline:") + utils.speedometer( + model_apply_fn=baseline.apply, + variables=baseline_vars, + input=x, + output_grad=dy, + ) + + print(f"\nTE {type(recipe).__name__}:") + utils.speedometer( + model_apply_fn=te_model.apply, + variables=te_vars, + input=x, + output_grad=dy, + ) +# DENSE_SINGLE_GPU_BENCH_END + + +# DENSE_MULTI_GPU_MESH_SETUP_START +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +from jax.experimental import mesh_utils +from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + +def build_dp_tp_mesh(): + # 2x2 mesh: DP on one axis, TP on the other. + devices = mesh_utils.create_device_mesh((2, 2)) + mesh = Mesh(devices, axis_names=("dp", "tp")) + + # Tell TE which mesh axis is which. This is a *global* setting, established + # outside JIT, so TE's GEMM primitives can plan comms accordingly. + mesh_resource = MeshResource(dp_resource="dp", tp_resource="tp") + return mesh, mesh_resource +# DENSE_MULTI_GPU_MESH_SETUP_END + + +# DENSE_MULTI_GPU_SHARD_SETUP_START +def shard_variables(mesh, variables_dict): + kernel_sharding = NamedSharding(mesh, P(None, "tp")) + + def _shard(variables): + params = variables["params"] + sharded = jax.device_put(params["Dense_0"]["kernel"], kernel_sharding) + return { + **variables, + "params": { + **params, + "Dense_0": {**params["Dense_0"], "kernel": sharded}, + }, + } + + input_sharding = NamedSharding(mesh, P("dp", None, None)) + output_grad_sharding = NamedSharding(mesh, P("dp", None, "tp")) + + return { + "x": jax.device_put(x, input_sharding), + "dy": jax.device_put(dy, output_grad_sharding), + **{name: _shard(vars_) for name, vars_ in variables_dict.items()}, + } +# DENSE_MULTI_GPU_SHARD_SETUP_END + + +# DENSE_MULTI_GPU_BENCH_START +def run_multi_gpu_bench(): + mesh, mesh_resource = build_dp_tp_mesh() + sharded = shard_variables(mesh, {"baseline": baseline_vars, "te": te_vars}) + + with jax.set_mesh(mesh), global_shard_guard(mesh_resource): + print("bf16 DP=2/TP=2:") + utils.speedometer( + model_apply_fn=baseline.apply, + variables=sharded["baseline"], + input=sharded["x"], + output_grad=sharded["dy"], + ) + + print(f"\nTE {type(recipe).__name__} DP=2/TP=2:") + utils.speedometer( + model_apply_fn=te_model.apply, + variables=sharded["te"], + input=sharded["x"], + output_grad=sharded["dy"], + ) +# DENSE_MULTI_GPU_BENCH_END + + +# ----------------------------------------------------------------------------- +# Pytest entry points (not pulled into docs). +# +# These run the same code shown in the snippets above and add numeric / smoke +# assertions so CI catches regressions. +# ----------------------------------------------------------------------------- + +import pytest +from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode + +_mxfp8_supported, _mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) +requires_mxfp8 = pytest.mark.skipif( + not _mxfp8_supported, reason=f"MXFP8 not supported on this device: {_mxfp8_reason}" +) + + +def test_baseline_runs(): + out = baseline.apply(baseline_vars, x) + assert out.shape == (batch, seq, out_features) + assert out.dtype == dtype + + +@requires_mxfp8 +def test_te_dense_runs(): + out = te_model.apply(te_vars, x) + assert out.shape == (batch, seq, out_features) + + +@requires_mxfp8 +def test_te_matches_baseline(): + """TE quantized Dense should match the bf16 baseline within MXFP8 tolerance.""" + diffs = utils.compare_fwd_bwd( + baseline.apply, + baseline_vars, + te_model.apply, + te_vars, + input=x, + output_grad=dy, + ) + # MXFP8 quantizes activations / weights, so we accept noticeable rel diff vs bf16. + # Tune these in follow-ups once we have real CI numbers. + assert diffs["y"]["max_rel"] < 0.20, diffs + assert diffs["dx"]["max_rel"] < 0.20, diffs + assert diffs["dW"]["max_rel"] < 0.30, diffs + + +@requires_mxfp8 +def test_single_gpu_benchmark(): + run_single_gpu_bench() + + +@requires_mxfp8 +@pytest.mark.skipif(len(jax.devices()) < 4, reason="needs 4 GPUs for DP=2/TP=2") +def test_multi_gpu_benchmark(): + run_multi_gpu_bench() + + +if __name__ == "__main__": + run_single_gpu_bench() + if len(jax.devices()) >= 4: + print() + run_multi_gpu_bench() + else: + print("\n[skipped multi-GPU section: <4 devices visible]") diff --git a/docs/examples/jax_examples/dense.rst b/docs/examples/jax_examples/dense.rst new file mode 100644 index 0000000000..e04326c5f5 --- /dev/null +++ b/docs/examples/jax_examples/dense.rst @@ -0,0 +1,175 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +JAX: Dense GEMMs with TransformerEngine +======================================= + +This document walks through replacing a plain ``flax.linen.Dense``'s GEMM with +TransformerEngine's quantized GEMM. + +**Recipe.** We use ``MXFP8BlockScaling`` in this tutorial. ``MXFP8BlockScaling`` and +``NVFP4BlockScaling`` require a Blackwell-class GPU; on Hopper, swap in +``DelayedScaling`` or ``Float8CurrentScaling``. + +`← Back to the JAX integration overview <../te_jax_integration.html>`_ + +1. Baseline: a plain Flax Dense block +------------------------------------- + +We isolate the optimization to a single linear layer so it's clear what's +changing. ``dot_general_cls`` is exposed as a constructor argument so we can swap +in TE later without touching the model definition. + +.. literalinclude:: dense.py + :language: python + :start-after: # DENSE_BASELINE_MODEL_START + :end-before: # DENSE_BASELINE_MODEL_END + +.. literalinclude:: dense.py + :language: python + :start-after: # DENSE_INPUTS_SETUP_START + :end-before: # DENSE_INPUTS_SETUP_END + + +2. Quantized Dense via ``make_dot_general_cls`` +----------------------------------------------- + +TE exposes a helper, ``te_flax.make_dot_general_cls(recipe)``, that returns a Flax +module class you pass directly to ``nn.Dense(..., dot_general=...)``. + +With this API, TE doesn't create the ``kernel`` params; it only wraps the GEMM. +All your initialization, sharding annotations, and optimizer state stay where +they were. + +.. literalinclude:: dense.py + :language: python + :start-after: # DENSE_TE_SETUP_START + :end-before: # DENSE_TE_SETUP_END + +.. note:: + + **What about DelayedScaling state?** + + Most recipes are stateless — scaling factors are computed from each tensor + as it flows through the GEMM, so there is nothing to persist across steps. + However, if you swap in ``DelayedScaling`` instead, ``init`` will produce a + second variable collection, ``_overwrite_with_gradient``, holding + ``kernel_amax_history``, ``kernel_scale``, ``x_amax_history``, ``x_scale``, + etc. These are **not** model parameters — they are Flax variables that TE + updates each step to compute per-tensor scales from a rolling amax window. + + If you use ``DelayedScaling``, you must thread the *entire* ``var_collect`` + through your training loop (not just ``params``) so the history persists + across steps. ``MXFP8BlockScaling``, ``NVFP4BlockScaling``, and + ``Float8CurrentScaling`` do not require this. + + +3. Single-GPU performance +------------------------- + +``speedometer`` runs a JIT-compiled forward+backward loop with warmup, on the +same input for both models. + +.. literalinclude:: dense.py + :language: python + :start-after: # DENSE_SINGLE_GPU_BENCH_START + :end-before: # DENSE_SINGLE_GPU_BENCH_END + +.. raw:: html + +
+ Output: +
+ +.. container:: program-output + + .. literalinclude:: dense.out + :language: text + :start-after: # SINGLE_GPU_OUTPUT_START + :end-before: # SINGLE_GPU_OUTPUT_END + +On a single GB200, that's roughly **2.5× faster** for the fwd+bwd of one large +Dense — and the only code change was passing ``dot_general=te_dot_general_cls()`` +into ``nn.Dense``. + +The speedup depends on shape: large GEMMs benefit most. Very small GEMMs may +not benefit at all because the cast + scale overhead can dominate. + +.. warning:: + + **Remat / activation checkpointing.** If your training loop uses + ``jax.checkpoint_policies.checkpoint_dots`` (or any policy that matches + ``jax.lax.dot_general``), swap it for + ``transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms``. + Otherwise TE's quantized GEMM primitives won't be checkpointed correctly + and your performance comparison will not be accurate. + + +6. Multi-GPU: DP=2 / TP=2 on a single Dense +------------------------------------------- + +**Prerequisite:** this section requires four GPUs. + +Keeping the same ``FlaxDenseBlock`` from the rest of the document, we run it on +a 2×2 mesh with **data parallelism** on one axis and **tensor parallelism** +(column-parallel: shard the kernel's output dim) on the other. + +Two pieces wire this up: + +1. A ``jax.sharding.Mesh`` you build once at module scope (outside JIT). +2. TE's ``MeshResource``, set globally via ``global_shard_guard``, which tells + TE which mesh axes are DP and TP. + +.. literalinclude:: dense.py + :language: python + :start-after: # DENSE_MULTI_GPU_MESH_SETUP_START + :end-before: # DENSE_MULTI_GPU_MESH_SETUP_END + +**Sharding plan:** + +.. csv-table:: + :header: "Tensor", "Shape", "PartitionSpec" + :widths: 30, 40, 30 + + "Kernel (column-parallel)", "``(hidden, out_features)``", "``P(None, 'tp')``" + "Input activations", "``(batch, seq, hidden)``", "``P('dp', None, None)``" + "Gradient on output", "``(batch, seq, out_features)``", "``P('dp', None, 'tp')``" + +.. literalinclude:: dense.py + :language: python + :start-after: # DENSE_MULTI_GPU_SHARD_SETUP_START + :end-before: # DENSE_MULTI_GPU_SHARD_SETUP_END + +.. literalinclude:: dense.py + :language: python + :start-after: # DENSE_MULTI_GPU_BENCH_START + :end-before: # DENSE_MULTI_GPU_BENCH_END + +.. raw:: html + +
+ Output: +
+ +.. container:: program-output + + .. literalinclude:: dense.out + :language: text + :start-after: # MULTI_GPU_OUTPUT_START + :end-before: # MULTI_GPU_OUTPUT_END + + +7. Collective GEMM (placeholder) +-------------------------------- + +*Coming soon.* + + +Next steps +---------- + +* `Attention `_ +* `Mixture of Experts `_ +* `← Hub <../te_jax_integration.html>`_ diff --git a/docs/examples/jax_examples/moe.ipynb b/docs/examples/jax_examples/moe.ipynb deleted file mode 100644 index 11898857c4..0000000000 --- a/docs/examples/jax_examples/moe.ipynb +++ /dev/null @@ -1,30 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "title", - "metadata": {}, - "source": [ - "# JAX: Mixture of Experts with TransformerEngine\n", - "\n", - "**TODO — Coming soon.**\n", - "\n", - "This notebook will cover TE's `MoEBlock` layer which utilizes TE's optimized routing, permutation and grouped GEMM\n", - "\n", - "- single-GPU MoEBlock usage vs jax.lax.ragged_dot\n", - "- expert-parallel sharding considerations.\n", - "\n", - "[← Back to the JAX integration overview](../te_jax_integration.ipynb)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/examples/jax_examples/moe.rst b/docs/examples/jax_examples/moe.rst new file mode 100644 index 0000000000..fb1c8496ba --- /dev/null +++ b/docs/examples/jax_examples/moe.rst @@ -0,0 +1,17 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +JAX: Mixture of Experts with TransformerEngine +============================================== + +**TODO — Coming soon.** + +This document will cover TE's ``MoEBlock`` layer which utilizes TE's optimized +routing, permutation and grouped GEMM: + +* single-GPU ``MoEBlock`` usage vs ``jax.lax.ragged_dot`` +* expert-parallel sharding considerations. + +`← Back to the JAX integration overview <../te_jax_integration.html>`_ diff --git a/docs/examples/te_jax_integration.ipynb b/docs/examples/te_jax_integration.ipynb deleted file mode 100644 index dc9f5acbc0..0000000000 --- a/docs/examples/te_jax_integration.ipynb +++ /dev/null @@ -1,68 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "title", - "metadata": {}, - "source": [ - "# JAX: Integrating TransformerEngine into an existing framework\n", - "\n", - "This is the landing page for a series of focused notebooks on bringing TransformerEngine into a JAX+Flax codebase one optimization at a time. Each linked notebook isolates a single feature so you can see exactly what changes are required and what are the performance benefits." - ] - }, - { - "cell_type": "markdown", - "id": "topics", - "metadata": {}, - "source": [ - "## Pick a topic\n", - "\n", - "| Notebook | Status | Covers |\n", - "|---|---|---|\n", - "| [Dense GEMMs](./jax_examples/dense.ipynb) | **Available** | `nn.Dense` → quantized GEMM; single-GPU speedup; multi-GPU speedup; Collective GEMM |\n", - "| [Attention](./jax_examples/attention.ipynb) | *Coming soon* \n", - "| [Mixture of Experts](./jax_examples/moe.ipynb) | *Coming soon*" - ] - }, - { - "cell_type": "markdown", - "id": "recipes", - "metadata": {}, - "source": [ - "## Quantization recipes at a glance\n", - "\n", - "TE exposes its quantization choices as **recipes**. Please see Low-precision Training for a more detailed description of each recipe.\n", - "\n", - "| Recipe | Hardware | State | When to use |\n", - "|---|---|---|---|\n", - "| `DelayedScaling` | Hopper+ | amax history (Flax variables) | Per-tensor FP8 with amax history\n", - "| `Float8CurrentScaling` | Hopper+ | none | Per-tensor FP8 without an amax history |\n", - "| `MXFP8BlockScaling` | Blackwell+ | none | Block-scaled FP8 (32-element blocks) |\n", - "| `NVFP4BlockScaling` | Blackwell+ | requires a Flax RNG `sr_rng` | FP4 with 2D block scaling and stochastic rounding |\n", - "\n", - "Import them from `transformer_engine.common.recipe`." - ] - }, - { - "cell_type": "markdown", - "id": "conventions", - "metadata": {}, - "source": [ - "## Conventions used across these notebooks\n", - "\n", - "- **Framework.** Flax Linen. (TE/JAX uses Linen; see [Flax NNX/Linen interop](https://flax.readthedocs.io/en/latest/guides/bridge_guide.html) and [Haiku/Flax interop](https://dm-haiku.readthedocs.io/en/latest/notebooks/flax.html) if you're on a different stack.)\n", - "- **Baseline dtype.** bf16 for inputs and parameters.\n", - "- **Benchmarking.** `quickstart_jax_utils.speedometer` runs a JIT-compiled fwd+bwd loop with warmup." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/examples/te_jax_integration.rst b/docs/examples/te_jax_integration.rst new file mode 100644 index 0000000000..a6dd0d401e --- /dev/null +++ b/docs/examples/te_jax_integration.rst @@ -0,0 +1,91 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +JAX: Integrating TransformerEngine into an existing framework +============================================================= + +This is the landing page for a series of focused documents on bringing +TransformerEngine into a JAX+Flax codebase one optimization at a time. Each +linked page isolates a single feature so you can see exactly what changes are +required and what are the performance benefits. + +Pick a topic +------------ + +.. list-table:: + :header-rows: 1 + :widths: 25, 15, 60 + + * - Document + - Status + - Covers + * - `Dense GEMMs `_ + - **Available** + - ``nn.Dense`` → quantized GEMM; single-GPU speedup; multi-GPU speedup; + Collective GEMM + * - `Attention `_ + - *Coming soon* + - + * - `Mixture of Experts `_ + - *Coming soon* + - + + +Quantization recipes at a glance +-------------------------------- + +TE exposes its quantization choices as **recipes**. Please see +`Low-precision Training +`_ +for a more detailed description of each recipe. + +.. list-table:: + :header-rows: 1 + :widths: 25, 15, 30, 30 + + * - Recipe + - Hardware + - State + - When to use + * - ``DelayedScaling`` + - Hopper+ + - amax history (Flax variables) + - Per-tensor FP8 with amax history + * - ``Float8CurrentScaling`` + - Hopper+ + - none + - Per-tensor FP8 without an amax history + * - ``MXFP8BlockScaling`` + - Blackwell+ + - none + - Block-scaled FP8 (32-element blocks) + * - ``NVFP4BlockScaling`` + - Blackwell+ + - requires a Flax RNG ``sr_rng`` + - FP4 with 2D block scaling and stochastic rounding + +Import them from ``transformer_engine.common.recipe``. + + +Conventions used across these documents +--------------------------------------- + +* **Framework.** Flax Linen. (TE/JAX uses Linen; see + `Flax NNX/Linen interop + `_ and + `Haiku/Flax interop + `_ if you're on + a different stack.) +* **Baseline dtype.** bf16 for inputs and parameters. +* **Benchmarking.** ``quickstart_jax_utils.speedometer`` runs a JIT-compiled + fwd+bwd loop with warmup. + + +.. toctree:: + :hidden: + + jax_examples/dense + jax_examples/attention + jax_examples/moe diff --git a/docs/index.rst b/docs/index.rst index 7389553679..53c4b0e37e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -57,7 +57,7 @@ Transformer Engine documentation examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb examples/te_gemma/tutorial_generation_gemma_with_te.ipynb examples/onnx/onnx_export.ipynb - examples/te_jax_integration.ipynb + examples/te_jax_integration.rst examples/op_fuser/op_fuser.rst .. toctree:: diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 3453e35d2c..9cd171f896 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -42,6 +42,11 @@ python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/py export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_without_custom_call.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" +# Exercise the docs/examples/jax_examples tutorials. The multi-GPU tests are +# skipped at runtime when fewer than 4 devices are visible, so this is safe on +# single-GPU runners. +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax.xml $TE_PATH/docs/examples/jax_examples/ || test_fail "docs/examples/jax_examples" + if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" exit 1 diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index 4f92d1c783..ea33828f53 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -37,6 +37,10 @@ XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_nccl_comm_splitting=false" python3 -m pyt python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_fused_attn.xml $TE_PATH/tests/jax/test_distributed_fused_attn.py || test_fail "test_distributed_fused_attn.py" +# Exercise the multi-GPU tutorial in docs/examples/jax_examples (needs >= 4 GPUs; +# auto-skips otherwise). +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax_distributed.xml -k multi_gpu $TE_PATH/docs/examples/jax_examples/ || test_fail "docs/examples/jax_examples (multi-GPU)" + # TODO(Phuong): add this test back after it is verified # SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh || test_fail "test_multi_process_distributed_grouped_gemm.py" From 5432ec6892d1b7f1b356f3bf5474615b2a832a51 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 15 May 2026 09:42:59 -0700 Subject: [PATCH 05/12] Address comments Signed-off-by: Jeremy Berchtold --- .../{jax_examples => jax}/attention.rst | 0 docs/examples/jax/collective_gemm.rst | 11 +++++++++++ .../{jax_examples => jax}/conftest.py | 0 docs/examples/{jax_examples => jax}/dense.out | 1 - docs/examples/{jax_examples => jax}/dense.py | 0 docs/examples/{jax_examples => jax}/dense.rst | 9 ++------- .../moe.rst => jax/expert_parallelism.rst} | 8 +------- docs/examples/te_jax_integration.rst | 19 +++++++++++-------- 8 files changed, 25 insertions(+), 23 deletions(-) rename docs/examples/{jax_examples => jax}/attention.rst (100%) create mode 100644 docs/examples/jax/collective_gemm.rst rename docs/examples/{jax_examples => jax}/conftest.py (100%) rename docs/examples/{jax_examples => jax}/dense.out (93%) rename docs/examples/{jax_examples => jax}/dense.py (100%) rename docs/examples/{jax_examples => jax}/dense.rst (97%) rename docs/examples/{jax_examples/moe.rst => jax/expert_parallelism.rst} (50%) diff --git a/docs/examples/jax_examples/attention.rst b/docs/examples/jax/attention.rst similarity index 100% rename from docs/examples/jax_examples/attention.rst rename to docs/examples/jax/attention.rst diff --git a/docs/examples/jax/collective_gemm.rst b/docs/examples/jax/collective_gemm.rst new file mode 100644 index 0000000000..05b39ea011 --- /dev/null +++ b/docs/examples/jax/collective_gemm.rst @@ -0,0 +1,11 @@ +.. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +JAX: Collective GEMMs with TransformerEngine +============================================= + +**TODO — Coming soon.** + +`← Back to the JAX integration overview <../te_jax_integration.html>`_ diff --git a/docs/examples/jax_examples/conftest.py b/docs/examples/jax/conftest.py similarity index 100% rename from docs/examples/jax_examples/conftest.py rename to docs/examples/jax/conftest.py diff --git a/docs/examples/jax_examples/dense.out b/docs/examples/jax/dense.out similarity index 93% rename from docs/examples/jax_examples/dense.out rename to docs/examples/jax/dense.out index d6628b2f82..990c4104b6 100644 --- a/docs/examples/jax_examples/dense.out +++ b/docs/examples/jax/dense.out @@ -1,6 +1,5 @@ # Numbers below are illustrative (captured on a GB200). Regenerate with: # python3 docs/examples/jax_examples/dense.py > dense.out -# after substantial code changes. # SINGLE_GPU_OUTPUT_START Variable collections: ['params'] diff --git a/docs/examples/jax_examples/dense.py b/docs/examples/jax/dense.py similarity index 100% rename from docs/examples/jax_examples/dense.py rename to docs/examples/jax/dense.py diff --git a/docs/examples/jax_examples/dense.rst b/docs/examples/jax/dense.rst similarity index 97% rename from docs/examples/jax_examples/dense.rst rename to docs/examples/jax/dense.rst index e04326c5f5..5e64cc4850 100644 --- a/docs/examples/jax_examples/dense.rst +++ b/docs/examples/jax/dense.rst @@ -107,7 +107,7 @@ not benefit at all because the cast + scale overhead can dominate. and your performance comparison will not be accurate. -6. Multi-GPU: DP=2 / TP=2 on a single Dense +4. Multi-GPU: DP=2 / TP=2 on a single Dense ------------------------------------------- **Prerequisite:** this section requires four GPUs. @@ -161,15 +161,10 @@ Two pieces wire this up: :end-before: # MULTI_GPU_OUTPUT_END -7. Collective GEMM (placeholder) --------------------------------- - -*Coming soon.* - - Next steps ---------- +* `Collective GEMM `_: further speedups by communicating between devices inside the GEMM. * `Attention `_ * `Mixture of Experts `_ * `← Hub <../te_jax_integration.html>`_ diff --git a/docs/examples/jax_examples/moe.rst b/docs/examples/jax/expert_parallelism.rst similarity index 50% rename from docs/examples/jax_examples/moe.rst rename to docs/examples/jax/expert_parallelism.rst index fb1c8496ba..5e94e1d298 100644 --- a/docs/examples/jax_examples/moe.rst +++ b/docs/examples/jax/expert_parallelism.rst @@ -3,15 +3,9 @@ See LICENSE for license information. -JAX: Mixture of Experts with TransformerEngine +JAX: Expert Parallelism with TransformerEngine ============================================== **TODO — Coming soon.** -This document will cover TE's ``MoEBlock`` layer which utilizes TE's optimized -routing, permutation and grouped GEMM: - -* single-GPU ``MoEBlock`` usage vs ``jax.lax.ragged_dot`` -* expert-parallel sharding considerations. - `← Back to the JAX integration overview <../te_jax_integration.html>`_ diff --git a/docs/examples/te_jax_integration.rst b/docs/examples/te_jax_integration.rst index a6dd0d401e..2602b3bbf3 100644 --- a/docs/examples/te_jax_integration.rst +++ b/docs/examples/te_jax_integration.rst @@ -21,14 +21,16 @@ Pick a topic * - Document - Status - Covers - * - `Dense GEMMs `_ + * - `Dense GEMMs `_ - **Available** - ``nn.Dense`` → quantized GEMM; single-GPU speedup; multi-GPU speedup; - Collective GEMM - * - `Attention `_ + * - `Collective GEMMs `_ - *Coming soon* - - * - `Mixture of Experts `_ + * - `Attention `_ + - *Coming soon* + - + * - `Expert Parallelism `_ - *Coming soon* - @@ -80,12 +82,13 @@ Conventions used across these documents a different stack.) * **Baseline dtype.** bf16 for inputs and parameters. * **Benchmarking.** ``quickstart_jax_utils.speedometer`` runs a JIT-compiled - fwd+bwd loop with warmup. + fwd+bwd loop with warmup .. toctree:: :hidden: - jax_examples/dense - jax_examples/attention - jax_examples/moe + jax/dense + jax/collective_gemm + jax/attention + jax/expert_parallelism From 89d1cefc95e8ea1dca1c180522ebab760ec4c301 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 May 2026 16:44:04 +0000 Subject: [PATCH 06/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/examples/jax/dense.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/docs/examples/jax/dense.py b/docs/examples/jax/dense.py index 2e63d218a8..bba341e074 100644 --- a/docs/examples/jax/dense.py +++ b/docs/examples/jax/dense.py @@ -24,6 +24,7 @@ from flax import linen as nn import quickstart_jax_utils as utils + # DENSE_IMPORTS_END @@ -41,6 +42,8 @@ def __call__(self, x): use_bias=False, dot_general=self.dot_general_cls(), )(x) + + # DENSE_BASELINE_MODEL_END @@ -90,6 +93,8 @@ def run_single_gpu_bench(): input=x, output_grad=dy, ) + + # DENSE_SINGLE_GPU_BENCH_END @@ -108,6 +113,8 @@ def build_dp_tp_mesh(): # outside JIT, so TE's GEMM primitives can plan comms accordingly. mesh_resource = MeshResource(dp_resource="dp", tp_resource="tp") return mesh, mesh_resource + + # DENSE_MULTI_GPU_MESH_SETUP_END @@ -134,6 +141,8 @@ def _shard(variables): "dy": jax.device_put(dy, output_grad_sharding), **{name: _shard(vars_) for name, vars_ in variables_dict.items()}, } + + # DENSE_MULTI_GPU_SHARD_SETUP_END @@ -158,6 +167,8 @@ def run_multi_gpu_bench(): input=sharded["x"], output_grad=sharded["dy"], ) + + # DENSE_MULTI_GPU_BENCH_END From aa7d624eeba4aa417c2580cb24a95e9cb3e3cec7 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Fri, 15 May 2026 09:48:09 -0700 Subject: [PATCH 07/12] Update qa/L1_jax_distributed_unittest/test.sh Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> --- qa/L1_jax_distributed_unittest/test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index ea33828f53..b3b5762d98 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -39,7 +39,7 @@ python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/py # Exercise the multi-GPU tutorial in docs/examples/jax_examples (needs >= 4 GPUs; # auto-skips otherwise). -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax_distributed.xml -k multi_gpu $TE_PATH/docs/examples/jax_examples/ || test_fail "docs/examples/jax_examples (multi-GPU)" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax_distributed.xml -k multi_gpu $TE_PATH/docs/examples/jax/ || test_fail "docs/examples/jax (multi-GPU)" # TODO(Phuong): add this test back after it is verified # SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh || test_fail "test_multi_process_distributed_grouped_gemm.py" From 74e9c5812928e9daf07988d7b9297a424961d2a7 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Fri, 15 May 2026 09:48:18 -0700 Subject: [PATCH 08/12] Update qa/L0_jax_unittest/test.sh Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> --- qa/L0_jax_unittest/test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 9cd171f896..ccf10b8843 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -45,7 +45,7 @@ NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini # Exercise the docs/examples/jax_examples tutorials. The multi-GPU tests are # skipped at runtime when fewer than 4 devices are visible, so this is safe on # single-GPU runners. -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax.xml $TE_PATH/docs/examples/jax_examples/ || test_fail "docs/examples/jax_examples" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax.xml $TE_PATH/docs/examples/jax/ || test_fail "docs/examples/jax" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" From 15f10b873938082290959dcbaeb8747d260aa48e Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 15 May 2026 10:43:54 -0700 Subject: [PATCH 09/12] Fix tests Signed-off-by: Jeremy Berchtold --- docs/examples/jax/conftest.py | 17 ---- docs/examples/jax/dense.py | 70 ++------------ docs/examples/jax/dense.rst | 2 - .../{ => jax}/quickstart_jax_utils.py | 46 ++++++++++ docs/examples/jax/test_dense.py | 91 +++++++++++++++++++ 5 files changed, 143 insertions(+), 83 deletions(-) delete mode 100644 docs/examples/jax/conftest.py rename docs/examples/{ => jax}/quickstart_jax_utils.py (64%) create mode 100644 docs/examples/jax/test_dense.py diff --git a/docs/examples/jax/conftest.py b/docs/examples/jax/conftest.py deleted file mode 100644 index a584e7392e..0000000000 --- a/docs/examples/jax/conftest.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Pytest conftest for docs/examples/jax_examples. - -Adds ``docs/examples/`` to ``sys.path`` so the example modules can do -``import quickstart_jax_utils`` regardless of the directory pytest was invoked -from. -""" -import os -import sys - -_HERE = os.path.dirname(os.path.abspath(__file__)) -_EXAMPLES_ROOT = os.path.dirname(_HERE) -if _EXAMPLES_ROOT not in sys.path: - sys.path.insert(0, _EXAMPLES_ROOT) diff --git a/docs/examples/jax/dense.py b/docs/examples/jax/dense.py index bba341e074..9bf55b8c46 100644 --- a/docs/examples/jax/dense.py +++ b/docs/examples/jax/dense.py @@ -7,18 +7,15 @@ Companion source for ``dense.rst``. Code blocks between ``# DENSE_*_START`` / ``# DENSE_*_END`` markers are pulled into the RST via ``literalinclude``. -Run as a pytest module to exercise the example end-to-end: +Run as a script to exercise the example end-to-end: - pytest -v docs/examples/jax_examples/dense.py + python docs/examples/jax/dense.py -The multi-GPU section auto-skips when fewer than 4 GPUs are visible. +Pytest tests live in ``test_dense.py``; the multi-GPU section auto-skips when +fewer than 4 GPUs are visible. """ # DENSE_IMPORTS_START -import sys - -sys.path.append("..") # so we can import quickstart_jax_utils from docs/examples/ - import jax import jax.numpy as jnp from flax import linen as nn @@ -33,6 +30,7 @@ class FlaxDenseBlock(nn.Module): """One linear layer. ``dot_general_cls`` lets us swap the GEMM impl.""" features: int + dtype: jnp.dtype = jnp.bfloat16 dot_general_cls: callable = lambda: None @nn.compact @@ -40,6 +38,7 @@ def __call__(self, x): return nn.Dense( features=self.features, use_bias=False, + dtype=self.dtype, dot_general=self.dot_general_cls(), )(x) @@ -172,63 +171,6 @@ def run_multi_gpu_bench(): # DENSE_MULTI_GPU_BENCH_END -# ----------------------------------------------------------------------------- -# Pytest entry points (not pulled into docs). -# -# These run the same code shown in the snippets above and add numeric / smoke -# assertions so CI catches regressions. -# ----------------------------------------------------------------------------- - -import pytest -from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode - -_mxfp8_supported, _mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) -requires_mxfp8 = pytest.mark.skipif( - not _mxfp8_supported, reason=f"MXFP8 not supported on this device: {_mxfp8_reason}" -) - - -def test_baseline_runs(): - out = baseline.apply(baseline_vars, x) - assert out.shape == (batch, seq, out_features) - assert out.dtype == dtype - - -@requires_mxfp8 -def test_te_dense_runs(): - out = te_model.apply(te_vars, x) - assert out.shape == (batch, seq, out_features) - - -@requires_mxfp8 -def test_te_matches_baseline(): - """TE quantized Dense should match the bf16 baseline within MXFP8 tolerance.""" - diffs = utils.compare_fwd_bwd( - baseline.apply, - baseline_vars, - te_model.apply, - te_vars, - input=x, - output_grad=dy, - ) - # MXFP8 quantizes activations / weights, so we accept noticeable rel diff vs bf16. - # Tune these in follow-ups once we have real CI numbers. - assert diffs["y"]["max_rel"] < 0.20, diffs - assert diffs["dx"]["max_rel"] < 0.20, diffs - assert diffs["dW"]["max_rel"] < 0.30, diffs - - -@requires_mxfp8 -def test_single_gpu_benchmark(): - run_single_gpu_bench() - - -@requires_mxfp8 -@pytest.mark.skipif(len(jax.devices()) < 4, reason="needs 4 GPUs for DP=2/TP=2") -def test_multi_gpu_benchmark(): - run_multi_gpu_bench() - - if __name__ == "__main__": run_single_gpu_bench() if len(jax.devices()) >= 4: diff --git a/docs/examples/jax/dense.rst b/docs/examples/jax/dense.rst index 5e64cc4850..93fbb864ad 100644 --- a/docs/examples/jax/dense.rst +++ b/docs/examples/jax/dense.rst @@ -165,6 +165,4 @@ Next steps ---------- * `Collective GEMM `_: further speedups by communicating between devices inside the GEMM. -* `Attention `_ -* `Mixture of Experts `_ * `← Hub <../te_jax_integration.html>`_ diff --git a/docs/examples/quickstart_jax_utils.py b/docs/examples/jax/quickstart_jax_utils.py similarity index 64% rename from docs/examples/quickstart_jax_utils.py rename to docs/examples/jax/quickstart_jax_utils.py index 0c5ec5295e..a8e1202d38 100644 --- a/docs/examples/quickstart_jax_utils.py +++ b/docs/examples/jax/quickstart_jax_utils.py @@ -4,6 +4,7 @@ import jax import jax.numpy as jnp +import numpy as np import time from typing import Callable, Any, Dict, Optional, Tuple @@ -99,3 +100,48 @@ def _split_step_rngs( new_rngs[name] = new_key step_rngs[name] = step_key return new_rngs, step_rngs + + +def compare_fwd_bwd( + ref_apply_fn: Callable, + ref_variables: Any, + test_apply_fn: Callable, + test_variables: Any, + *, + input: jnp.ndarray, + output_grad: jnp.ndarray, + rtol: float = 1e-5, + atol: float = 1e-8, + rtol_dW: Optional[float] = None, + atol_dW: Optional[float] = None, +) -> None: + """Compare forward outputs and VJP gradients between two models. + + Runs ``y, vjp_fn = jax.vjp(apply_fn, variables, input)`` for each model, + then applies ``vjp_fn(output_grad)`` to get gradients wrt both the + parameters (``dW``) and the input (``dx``). Calls + ``numpy.testing.assert_allclose`` on each tensor (``y``, ``dx``, and every + leaf of ``dW``). ``rtol_dW`` / ``atol_dW`` override ``rtol`` / ``atol`` + for the params-grad comparison. + """ + rtol_dW = rtol if rtol_dW is None else rtol_dW + atol_dW = atol if atol_dW is None else atol_dW + + def _run(apply_fn: Callable) -> Callable: + @jax.jit + def go(variables, inp, dy): + y, vjp_fn = jax.vjp(apply_fn, variables, inp) + dvars, dx = vjp_fn(dy.astype(y.dtype)) + return y, dvars["params"], dx + + return go + + y_ref, dW_ref, dx_ref = _run(ref_apply_fn)(ref_variables, input, output_grad) + y_test, dW_test, dx_test = _run(test_apply_fn)(test_variables, input, output_grad) + + np.testing.assert_allclose(y_test, y_ref, rtol=rtol, atol=atol, err_msg="forward output (y) mismatch") + np.testing.assert_allclose(dx_test, dx_ref, rtol=rtol, atol=atol, err_msg="input grad (dx) mismatch") + for ref_leaf, test_leaf in zip(jax.tree_util.tree_leaves(dW_ref), jax.tree_util.tree_leaves(dW_test)): + np.testing.assert_allclose( + test_leaf, ref_leaf, rtol=rtol_dW, atol=atol_dW, err_msg="params grad (dW) mismatch" + ) diff --git a/docs/examples/jax/test_dense.py b/docs/examples/jax/test_dense.py new file mode 100644 index 0000000000..db600f7ed5 --- /dev/null +++ b/docs/examples/jax/test_dense.py @@ -0,0 +1,91 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Pytest entry points for ``dense.py``. + +These run the same code shown in ``dense.py`` and add numeric / smoke +assertions so CI catches regressions. + +Run with: + + pytest -v docs/examples/jax/test_dense.py + +The multi-GPU section auto-skips when fewer than 4 GPUs are visible. +""" + +import jax +import jax.numpy as jnp +import pytest + +import quickstart_jax_utils as utils +from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode + +from dense import ( + baseline, + baseline_vars, + batch, + dtype, + dy, + out_features, + run_multi_gpu_bench, + run_single_gpu_bench, + seq, + te_model, + te_vars, + x, +) + +_mxfp8_supported, _mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) +requires_mxfp8 = pytest.mark.skipif( + not _mxfp8_supported, reason=f"MXFP8 not supported on this device: {_mxfp8_reason}" +) + +# MXFP8 quantization noise is ~FP8 epsilon (~5%) of the per-tensor magnitude. +# ``atol`` covers near-zero ref values where the rtol fraction is too tight. +# y and dx are O(1) under Flax's lecun_normal init (in/out scaling cancels); +# dW has no init scaling and accumulates batch*seq products, so it grows as +# sqrt(batch*seq). +_FP8_REL_NOISE = float(jnp.finfo(jnp.float8_e4m3fn).eps) # 0.125 +_ATOL_FWD = 10.0 * _FP8_REL_NOISE # ~1.25; covers Gaussian-tail |y|, |dx| +_ATOL_DW = _ATOL_FWD * jnp.sqrt(batch * seq).item() # ~113; covers Gaussian-tail |dW| + + +def test_baseline_runs(): + out = baseline.apply(baseline_vars, x) + assert out.shape == (batch, seq, out_features) + assert out.dtype == dtype + + +@requires_mxfp8 +def test_te_dense_runs(): + out = te_model.apply(te_vars, x) + assert out.shape == (batch, seq, out_features) + + +@requires_mxfp8 +def test_te_matches_baseline(): + """TE quantized Dense should match the bf16 baseline within MXFP8 tolerance.""" + utils.compare_fwd_bwd( + baseline.apply, + baseline_vars, + te_model.apply, + te_vars, + input=x, + output_grad=dy, + rtol=_FP8_REL_NOISE, + atol=_ATOL_FWD, + rtol_dW=_FP8_REL_NOISE, + atol_dW=_ATOL_DW, + ) + + +@requires_mxfp8 +def test_single_gpu_benchmark(): + run_single_gpu_bench() + + +@requires_mxfp8 +@pytest.mark.skipif(len(jax.devices()) < 4, reason="needs 4 GPUs for DP=2/TP=2") +def test_multi_gpu_benchmark(): + run_multi_gpu_bench() From 168cc636756d1f0e54fda70e21c71adbbcd53f1b Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 15 May 2026 11:36:41 -0700 Subject: [PATCH 10/12] Fixes Signed-off-by: Jeremy Berchtold --- docs/examples/jax/dense.out | 12 ++++++------ docs/examples/jax/dense.py | 2 +- docs/examples/jax/test_dense.py | 11 +++-------- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/docs/examples/jax/dense.out b/docs/examples/jax/dense.out index 990c4104b6..22b93ff04e 100644 --- a/docs/examples/jax/dense.out +++ b/docs/examples/jax/dense.out @@ -1,21 +1,21 @@ # Numbers below are illustrative (captured on a GB200). Regenerate with: -# python3 docs/examples/jax_examples/dense.py > dense.out +# python3 docs/examples/jax/dense.py > dense.out # SINGLE_GPU_OUTPUT_START Variable collections: ['params'] -{'params': {'Dense_0': {'kernel': ((4096, 16384), dtype('float32'))}}} +{'params': {'Dense_0': {'kernel': ((8192, 32768), dtype('float32'))}}} bf16 baseline: -Mean time: 4.126 ms +Mean time: 18.056 ms TE MXFP8BlockScaling: -Mean time: 1.690 ms +Mean time: 11.260 ms # SINGLE_GPU_OUTPUT_END # MULTI_GPU_OUTPUT_START bf16 DP=2/TP=2: -Mean time: 1.726 ms +Mean time: 5.516 ms TE MXFP8BlockScaling DP=2/TP=2: -Mean time: 0.969 ms +Mean time: 3.712 ms # MULTI_GPU_OUTPUT_END diff --git a/docs/examples/jax/dense.py b/docs/examples/jax/dense.py index 9bf55b8c46..9ddc5a9e8e 100644 --- a/docs/examples/jax/dense.py +++ b/docs/examples/jax/dense.py @@ -47,7 +47,7 @@ def __call__(self, x): # DENSE_INPUTS_SETUP_START -batch, seq, hidden, out_features = 4, 2048, 4096, 16384 +batch, seq, hidden, out_features = 8, 2048, 8192, 32768 dtype = jnp.bfloat16 key = jax.random.PRNGKey(0) diff --git a/docs/examples/jax/test_dense.py b/docs/examples/jax/test_dense.py index db600f7ed5..c5363c75ce 100644 --- a/docs/examples/jax/test_dense.py +++ b/docs/examples/jax/test_dense.py @@ -41,14 +41,9 @@ not _mxfp8_supported, reason=f"MXFP8 not supported on this device: {_mxfp8_reason}" ) -# MXFP8 quantization noise is ~FP8 epsilon (~5%) of the per-tensor magnitude. -# ``atol`` covers near-zero ref values where the rtol fraction is too tight. -# y and dx are O(1) under Flax's lecun_normal init (in/out scaling cancels); -# dW has no init scaling and accumulates batch*seq products, so it grows as -# sqrt(batch*seq). -_FP8_REL_NOISE = float(jnp.finfo(jnp.float8_e4m3fn).eps) # 0.125 -_ATOL_FWD = 10.0 * _FP8_REL_NOISE # ~1.25; covers Gaussian-tail |y|, |dx| -_ATOL_DW = _ATOL_FWD * jnp.sqrt(batch * seq).item() # ~113; covers Gaussian-tail |dW| +_FP8_REL_NOISE = float(jnp.finfo(jnp.float8_e4m3fn).eps) +_ATOL_FWD = 10.0 * _FP8_REL_NOISE +_ATOL_DW = _ATOL_FWD * jnp.sqrt(batch * seq).item() def test_baseline_runs(): From 4c1fec958b7354e0da29a5b4974b675794afda7d Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 15 May 2026 12:02:18 -0700 Subject: [PATCH 11/12] Guard tests by arch Signed-off-by: Jeremy Berchtold --- docs/examples/jax/test_dense.py | 44 ++++++++++++++++----------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/docs/examples/jax/test_dense.py b/docs/examples/jax/test_dense.py index c5363c75ce..049a7c9566 100644 --- a/docs/examples/jax/test_dense.py +++ b/docs/examples/jax/test_dense.py @@ -21,32 +21,20 @@ import quickstart_jax_utils as utils from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode -from dense import ( - baseline, - baseline_vars, - batch, - dtype, - dy, - out_features, - run_multi_gpu_bench, - run_single_gpu_bench, - seq, - te_model, - te_vars, - x, -) +# Imports from ``dense`` are intentionally deferred into each test body. dense.py +# runs ``te_vars = te_model.init(k_init, x)`` at module scope, which raises on +# devices without MXFP8 support (Hopper or older). A top-level import would fire +# that before pytest can apply the @requires_mxfp8 skip marks. _mxfp8_supported, _mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) requires_mxfp8 = pytest.mark.skipif( not _mxfp8_supported, reason=f"MXFP8 not supported on this device: {_mxfp8_reason}" ) -_FP8_REL_NOISE = float(jnp.finfo(jnp.float8_e4m3fn).eps) -_ATOL_FWD = 10.0 * _FP8_REL_NOISE -_ATOL_DW = _ATOL_FWD * jnp.sqrt(batch * seq).item() - def test_baseline_runs(): + from dense import baseline, baseline_vars, batch, dtype, out_features, seq, x + out = baseline.apply(baseline_vars, x) assert out.shape == (batch, seq, out_features) assert out.dtype == dtype @@ -54,6 +42,8 @@ def test_baseline_runs(): @requires_mxfp8 def test_te_dense_runs(): + from dense import batch, out_features, seq, te_model, te_vars, x + out = te_model.apply(te_vars, x) assert out.shape == (batch, seq, out_features) @@ -61,6 +51,12 @@ def test_te_dense_runs(): @requires_mxfp8 def test_te_matches_baseline(): """TE quantized Dense should match the bf16 baseline within MXFP8 tolerance.""" + from dense import baseline, baseline_vars, batch, dy, seq, te_model, te_vars, x + + fp8_rel_noise = float(jnp.finfo(jnp.float8_e4m3fn).eps) + atol_fwd = 10.0 * fp8_rel_noise + atol_dw = atol_fwd * jnp.sqrt(batch * seq).item() + utils.compare_fwd_bwd( baseline.apply, baseline_vars, @@ -68,19 +64,23 @@ def test_te_matches_baseline(): te_vars, input=x, output_grad=dy, - rtol=_FP8_REL_NOISE, - atol=_ATOL_FWD, - rtol_dW=_FP8_REL_NOISE, - atol_dW=_ATOL_DW, + rtol=fp8_rel_noise, + atol=atol_fwd, + rtol_dW=fp8_rel_noise, + atol_dW=atol_dw, ) @requires_mxfp8 def test_single_gpu_benchmark(): + from dense import run_single_gpu_bench + run_single_gpu_bench() @requires_mxfp8 @pytest.mark.skipif(len(jax.devices()) < 4, reason="needs 4 GPUs for DP=2/TP=2") def test_multi_gpu_benchmark(): + from dense import run_multi_gpu_bench + run_multi_gpu_bench() From 73ab760c9857db904244b1ff7bc71afa4dd5636d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 May 2026 19:03:23 +0000 Subject: [PATCH 12/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/examples/jax/quickstart_jax_utils.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/examples/jax/quickstart_jax_utils.py b/docs/examples/jax/quickstart_jax_utils.py index a8e1202d38..6547a5ff1a 100644 --- a/docs/examples/jax/quickstart_jax_utils.py +++ b/docs/examples/jax/quickstart_jax_utils.py @@ -139,9 +139,15 @@ def go(variables, inp, dy): y_ref, dW_ref, dx_ref = _run(ref_apply_fn)(ref_variables, input, output_grad) y_test, dW_test, dx_test = _run(test_apply_fn)(test_variables, input, output_grad) - np.testing.assert_allclose(y_test, y_ref, rtol=rtol, atol=atol, err_msg="forward output (y) mismatch") - np.testing.assert_allclose(dx_test, dx_ref, rtol=rtol, atol=atol, err_msg="input grad (dx) mismatch") - for ref_leaf, test_leaf in zip(jax.tree_util.tree_leaves(dW_ref), jax.tree_util.tree_leaves(dW_test)): + np.testing.assert_allclose( + y_test, y_ref, rtol=rtol, atol=atol, err_msg="forward output (y) mismatch" + ) + np.testing.assert_allclose( + dx_test, dx_ref, rtol=rtol, atol=atol, err_msg="input grad (dx) mismatch" + ) + for ref_leaf, test_leaf in zip( + jax.tree_util.tree_leaves(dW_ref), jax.tree_util.tree_leaves(dW_test) + ): np.testing.assert_allclose( test_leaf, ref_leaf, rtol=rtol_dW, atol=atol_dW, err_msg="params grad (dW) mismatch" )