[JAX] Improve JAX tutorial documentation #2976
base: main
Changes from all commits
9fceea0
2cb3cb2
143391d
4f543b5
5432ec6
7c74aaf
89d1cef
aa7d624
74e9c58
15f10b8
168cc63
4c1fec9
73ab760
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| .. | ||
| Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
|
|
||
| See LICENSE for license information. | ||
|
|
||
| JAX: Attention with TransformerEngine | ||
| ===================================== | ||
|
|
||
| **TODO — Coming soon.** | ||
|
|
||
| `← Back to the JAX integration overview <../te_jax_integration.html>`_ | ||
|
Comment on lines +1 to +11
Collaborator: Unrelated to attention but looks like you are renaming the dir to
Collaborator (Author): Good point, updated to |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| .. | ||
| Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
|
|
||
| See LICENSE for license information. | ||
|
|
||
| JAX: Collective GEMMs with TransformerEngine | ||
| ============================================= | ||
|
|
||
| **TODO — Coming soon.** | ||
|
|
||
| `← Back to the JAX integration overview <../te_jax_integration.html>`_ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| # Numbers below are illustrative (captured on a GB200). Regenerate with: | ||
| # python3 docs/examples/jax/dense.py > dense.out | ||
|
|
||
| # SINGLE_GPU_OUTPUT_START | ||
| Variable collections: ['params'] | ||
| {'params': {'Dense_0': {'kernel': ((8192, 32768), dtype('float32'))}}} | ||
|
|
||
| bf16 baseline: | ||
| Mean time: 18.056 ms | ||
|
|
||
| TE MXFP8BlockScaling: | ||
| Mean time: 11.260 ms | ||
| # SINGLE_GPU_OUTPUT_END | ||
|
|
||
| # MULTI_GPU_OUTPUT_START | ||
| bf16 DP=2/TP=2: | ||
| Mean time: 5.516 ms | ||
|
|
||
| TE MXFP8BlockScaling DP=2/TP=2: | ||
| Mean time: 3.712 ms | ||
| # MULTI_GPU_OUTPUT_END |
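The ratios implied by these illustrative timings can be checked with a few lines of Python. This is only a reading aid, assuming the reported mean times are directly comparable; the dictionary and function names are hypothetical and not part of the PR.

```python
# Mean step times in milliseconds, copied from dense.out above.
timings_ms = {
    "single_gpu": {"bf16": 18.056, "mxfp8": 11.260},
    "dp2_tp2": {"bf16": 5.516, "mxfp8": 3.712},
}


def speedup(baseline_ms, candidate_ms):
    """Ratio of baseline time to candidate time (>1 means the candidate is faster)."""
    return baseline_ms / candidate_ms


speedups = {name: speedup(t["bf16"], t["mxfp8"]) for name, t in timings_ms.items()}
for name, ratio in speedups.items():
    print(f"{name}: {ratio:.2f}x")
```

With these numbers the MXFP8 path comes out at about 1.6× on one GPU and about 1.5× on the DP=2/TP=2 mesh.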
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,180 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| """JAX: Dense GEMMs with TransformerEngine. | ||
|
|
||
| Companion source for ``dense.rst``. Code blocks between ``# DENSE_*_START`` / | ||
| ``# DENSE_*_END`` markers are pulled into the RST via ``literalinclude``. | ||
|
|
||
| Run as a script to exercise the example end-to-end: | ||
|
|
||
| python docs/examples/jax/dense.py | ||
|
|
||
| Pytest tests live in ``test_dense.py``; the multi-GPU section auto-skips when | ||
| fewer than 4 GPUs are visible. | ||
| """ | ||
|
|
||
| # DENSE_IMPORTS_START | ||
| import jax | ||
| import jax.numpy as jnp | ||
| from flax import linen as nn | ||
|
|
||
| import quickstart_jax_utils as utils | ||
|
|
||
| # DENSE_IMPORTS_END | ||
|
|
||
|
|
||
| # DENSE_BASELINE_MODEL_START | ||
| class FlaxDenseBlock(nn.Module): | ||
| """One linear layer. ``dot_general_cls`` lets us swap the GEMM impl.""" | ||
|
|
||
| features: int | ||
| dtype: jnp.dtype = jnp.bfloat16 | ||
| dot_general_cls: callable = lambda: None | ||
|
|
||
| @nn.compact | ||
| def __call__(self, x): | ||
| return nn.Dense( | ||
| features=self.features, | ||
| use_bias=False, | ||
| dtype=self.dtype, | ||
| dot_general=self.dot_general_cls(), | ||
| )(x) | ||
|
|
||
|
|
||
| # DENSE_BASELINE_MODEL_END | ||
|
|
||
|
|
||
| # DENSE_INPUTS_SETUP_START | ||
| batch, seq, hidden, out_features = 8, 2048, 8192, 32768 | ||
| dtype = jnp.bfloat16 | ||
|
|
||
| key = jax.random.PRNGKey(0) | ||
| k_init, k_x, k_dy = jax.random.split(key, 3) | ||
| x = jax.random.normal(k_x, (batch, seq, hidden)).astype(dtype) | ||
| dy = jax.random.normal(k_dy, (batch, seq, out_features)).astype(dtype) | ||
|
|
||
| baseline = FlaxDenseBlock(features=out_features) | ||
| baseline_vars = baseline.init(k_init, x) | ||
| # DENSE_INPUTS_SETUP_END | ||
|
|
||
|
|
||
| # DENSE_TE_SETUP_START | ||
| from transformer_engine.jax import flax as te_flax | ||
| from transformer_engine.common.recipe import MXFP8BlockScaling | ||
|
|
||
| recipe = MXFP8BlockScaling() | ||
| te_dot_general_cls = te_flax.make_dot_general_cls(recipe) | ||
|
|
||
| te_model = FlaxDenseBlock(features=out_features, dot_general_cls=te_dot_general_cls) | ||
| te_vars = te_model.init(k_init, x) | ||
|
|
||
| print("Variable collections:", list(te_vars.keys())) | ||
| print(jax.tree_util.tree_map(lambda a: (a.shape, a.dtype), te_vars)) | ||
| # DENSE_TE_SETUP_END | ||
|
|
||
|
|
||
| # DENSE_SINGLE_GPU_BENCH_START | ||
| def run_single_gpu_bench(): | ||
| print("bf16 baseline:") | ||
| utils.speedometer( | ||
| model_apply_fn=baseline.apply, | ||
| variables=baseline_vars, | ||
| input=x, | ||
| output_grad=dy, | ||
| ) | ||
|
|
||
| print(f"\nTE {type(recipe).__name__}:") | ||
| utils.speedometer( | ||
| model_apply_fn=te_model.apply, | ||
| variables=te_vars, | ||
| input=x, | ||
| output_grad=dy, | ||
|
|
||
| ) | ||
|
|
||
|
|
||
| # DENSE_SINGLE_GPU_BENCH_END | ||
|
|
||
|
|
||
| # DENSE_MULTI_GPU_MESH_SETUP_START | ||
| from jax.sharding import Mesh, NamedSharding, PartitionSpec as P | ||
| from jax.experimental import mesh_utils | ||
| from transformer_engine.jax.sharding import MeshResource, global_shard_guard | ||
|
|
||
|
|
||
| def build_dp_tp_mesh(): | ||
| # 2x2 mesh: DP on one axis, TP on the other. | ||
| devices = mesh_utils.create_device_mesh((2, 2)) | ||
| mesh = Mesh(devices, axis_names=("dp", "tp")) | ||
|
|
||
| # Tell TE which mesh axis is which. This is a *global* setting, established | ||
| # outside JIT, so TE's GEMM primitives can plan comms accordingly. | ||
| mesh_resource = MeshResource(dp_resource="dp", tp_resource="tp") | ||
| return mesh, mesh_resource | ||
|
|
||
|
|
||
| # DENSE_MULTI_GPU_MESH_SETUP_END | ||
|
|
||
|
|
||
| # DENSE_MULTI_GPU_SHARD_SETUP_START | ||
| def shard_variables(mesh, variables_dict): | ||
| kernel_sharding = NamedSharding(mesh, P(None, "tp")) | ||
|
|
||
| def _shard(variables): | ||
| params = variables["params"] | ||
| sharded = jax.device_put(params["Dense_0"]["kernel"], kernel_sharding) | ||
| return { | ||
| **variables, | ||
| "params": { | ||
| **params, | ||
| "Dense_0": {**params["Dense_0"], "kernel": sharded}, | ||
| }, | ||
| } | ||
|
|
||
| input_sharding = NamedSharding(mesh, P("dp", None, None)) | ||
| output_grad_sharding = NamedSharding(mesh, P("dp", None, "tp")) | ||
|
|
||
| return { | ||
| "x": jax.device_put(x, input_sharding), | ||
| "dy": jax.device_put(dy, output_grad_sharding), | ||
| **{name: _shard(vars_) for name, vars_ in variables_dict.items()}, | ||
| } | ||
|
|
||
|
|
||
| # DENSE_MULTI_GPU_SHARD_SETUP_END | ||
|
|
||
|
|
||
| # DENSE_MULTI_GPU_BENCH_START | ||
| def run_multi_gpu_bench(): | ||
| mesh, mesh_resource = build_dp_tp_mesh() | ||
| sharded = shard_variables(mesh, {"baseline": baseline_vars, "te": te_vars}) | ||
|
|
||
| with jax.set_mesh(mesh), global_shard_guard(mesh_resource): | ||
| print("bf16 DP=2/TP=2:") | ||
| utils.speedometer( | ||
| model_apply_fn=baseline.apply, | ||
| variables=sharded["baseline"], | ||
| input=sharded["x"], | ||
| output_grad=sharded["dy"], | ||
| ) | ||
|
|
||
| print(f"\nTE {type(recipe).__name__} DP=2/TP=2:") | ||
| utils.speedometer( | ||
| model_apply_fn=te_model.apply, | ||
| variables=sharded["te"], | ||
| input=sharded["x"], | ||
| output_grad=sharded["dy"], | ||
| ) | ||
|
|
||
|
|
||
| # DENSE_MULTI_GPU_BENCH_END | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| run_single_gpu_bench() | ||
| if len(jax.devices()) >= 4: | ||
| print() | ||
| run_multi_gpu_bench() | ||
| else: | ||
| print("\n[skipped multi-GPU section: <4 devices visible]") | ||
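``quickstart_jax_utils.speedometer`` is not shown in this diff. As a rough sketch of what a forward+backward timing loop with warmup could look like, the block below times a jitted step built on ``jax.vjp``. The helper ``time_fwd_bwd``, its parameters, and the stand-in ``toy_apply`` are hypothetical and are not the actual helper's implementation.

```python
import time

import jax


def time_fwd_bwd(apply_fn, variables, x, dy, warmup=2, iters=5):
    """Mean wall-clock milliseconds for one forward+backward step (hypothetical helper)."""

    @jax.jit
    def step(variables, x, dy):
        # Forward pass, then pull the output gradient dy back to the
        # variables and the input through the VJP.
        out, vjp_fn = jax.vjp(apply_fn, variables, x)
        grads = vjp_fn(dy)
        return out, grads

    # Warmup triggers compilation so it is excluded from the timed loop.
    for _ in range(warmup):
        jax.block_until_ready(step(variables, x, dy))

    start = time.perf_counter()
    for _ in range(iters):
        # Block on the result so asynchronous dispatch does not hide the cost.
        jax.block_until_ready(step(variables, x, dy))
    return (time.perf_counter() - start) * 1e3 / iters


# Tiny stand-in for model.apply so the sketch runs without Flax or TE.
def toy_apply(variables, x):
    return x @ variables["kernel"]


key_k, key_x, key_dy = jax.random.split(jax.random.PRNGKey(0), 3)
toy_vars = {"kernel": jax.random.normal(key_k, (64, 128))}
toy_x = jax.random.normal(key_x, (4, 16, 64))
toy_dy = jax.random.normal(key_dy, (4, 16, 128))

mean_ms = time_fwd_bwd(toy_apply, toy_vars, toy_x, toy_dy)
print(f"Mean time: {mean_ms:.3f} ms")
```

Blocking on the result inside the timed loop matters because JAX dispatches work asynchronously; without it the loop would mostly measure enqueue time.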
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,168 @@ | ||
| .. | ||
| Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
|
|
||
| See LICENSE for license information. | ||
|
|
||
| JAX: Dense GEMMs with TransformerEngine | ||
| ======================================= | ||
|
|
||
| This document walks through replacing a plain ``flax.linen.Dense``'s GEMM with | ||
| TransformerEngine's quantized GEMM. | ||
|
|
||
| **Recipe.** We use ``MXFP8BlockScaling`` in this tutorial. ``MXFP8BlockScaling`` and | ||
| ``NVFP4BlockScaling`` require a Blackwell-class GPU; on Hopper, swap in | ||
| ``DelayedScaling`` or ``Float8CurrentScaling``. | ||
|
|
||
| `← Back to the JAX integration overview <../te_jax_integration.html>`_ | ||
|
|
||
| 1. Baseline: a plain Flax Dense block | ||
| ------------------------------------- | ||
|
|
||
| We isolate the optimization to a single linear layer so it's clear what's | ||
| changing. ``dot_general_cls`` is exposed as a constructor argument so we can swap | ||
| in TE later without touching the model definition. | ||
|
|
||
| .. literalinclude:: dense.py | ||
| :language: python | ||
| :start-after: # DENSE_BASELINE_MODEL_START | ||
| :end-before: # DENSE_BASELINE_MODEL_END | ||
|
|
||
| .. literalinclude:: dense.py | ||
| :language: python | ||
| :start-after: # DENSE_INPUTS_SETUP_START | ||
| :end-before: # DENSE_INPUTS_SETUP_END | ||
|
|
||
|
|
||
| 2. Quantized Dense via ``make_dot_general_cls`` | ||
| ----------------------------------------------- | ||
|
|
||
| TE exposes a helper, ``te_flax.make_dot_general_cls(recipe)``, that returns a Flax | ||
| module class. Instantiate it and pass the instance to ``nn.Dense(..., dot_general=...)``. | ||
|
|
||
| With this API, TE doesn't create the ``kernel`` params; it only wraps the GEMM. | ||
| All your initialization, sharding annotations, and optimizer state stay where | ||
| they were. | ||
|
|
||
| .. literalinclude:: dense.py | ||
| :language: python | ||
| :start-after: # DENSE_TE_SETUP_START | ||
| :end-before: # DENSE_TE_SETUP_END | ||
|
|
||
| .. note:: | ||
|
|
||
| **What about DelayedScaling state?** | ||
|
|
||
| Most recipes are stateless — scaling factors are computed from each tensor | ||
| as it flows through the GEMM, so there is nothing to persist across steps. | ||
| However, if you swap in ``DelayedScaling`` instead, ``init`` will produce a | ||
| second variable collection, ``_overwrite_with_gradient``, holding | ||
| ``kernel_amax_history``, ``kernel_scale``, ``x_amax_history``, ``x_scale``, | ||
| etc. These are **not** model parameters — they are Flax variables that TE | ||
| updates each step to compute per-tensor scales from a rolling amax window. | ||
|
|
||
| If you use ``DelayedScaling``, you must thread the *entire* variable dict returned by | ||
| ``init`` (``te_vars`` in this example, not just its ``params`` entry) through your | ||
| training loop so the history persists across steps. ``MXFP8BlockScaling``, | ||
| ``NVFP4BlockScaling``, and ``Float8CurrentScaling`` do not require this. | ||
|
|
||
|
|
||
| 3. Single-GPU performance | ||
| ------------------------- | ||
|
|
||
| ``speedometer`` runs a JIT-compiled forward+backward loop with warmup, on the | ||
| same input for both models. | ||
|
|
||
| .. literalinclude:: dense.py | ||
| :language: python | ||
| :start-after: # DENSE_SINGLE_GPU_BENCH_START | ||
| :end-before: # DENSE_SINGLE_GPU_BENCH_END | ||
|
|
||
| .. raw:: html | ||
|
|
||
| <div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;"> | ||
| Output: | ||
| </div> | ||
|
|
||
| .. container:: program-output | ||
|
|
||
| .. literalinclude:: dense.out | ||
| :language: text | ||
| :start-after: # SINGLE_GPU_OUTPUT_START | ||
| :end-before: # SINGLE_GPU_OUTPUT_END | ||
|
|
||
| On a single GB200, that's roughly **1.6× faster** (18.056 ms vs. 11.260 ms) for the | ||
| fwd+bwd of one large Dense, and the only code change was passing | ||
| ``dot_general=te_dot_general_cls()`` into ``nn.Dense``. | ||
|
|
||
| The speedup depends on shape: large GEMMs benefit most. Very small GEMMs may | ||
| not benefit at all because the cast + scale overhead can dominate. | ||
|
|
||
| .. warning:: | ||
|
|
||
| **Remat / activation checkpointing.** If your training loop uses | ||
| ``jax.checkpoint_policies.checkpoint_dots`` (or any policy that matches | ||
| ``jax.lax.dot_general``), swap it for | ||
| ``transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms``. | ||
| Otherwise the policy does not match TE's quantized GEMM primitives, so their outputs | ||
| are recomputed in the backward pass instead of being saved, and your performance | ||
| comparison will not be accurate. | ||
|
|
||
|
|
||
| 4. Multi-GPU: DP=2 / TP=2 on a single Dense | ||
| ------------------------------------------- | ||
|
|
||
| **Prerequisite:** this section requires four GPUs. | ||
|
|
||
| Keeping the same ``FlaxDenseBlock`` from the rest of the document, we run it on | ||
| a 2×2 mesh with **data parallelism** on one axis and **tensor parallelism** | ||
| (column-parallel: shard the kernel's output dim) on the other. | ||
|
|
||
| Two pieces wire this up: | ||
|
|
||
| 1. A ``jax.sharding.Mesh`` you build once, outside JIT (``build_dp_tp_mesh`` below). | ||
| 2. TE's ``MeshResource``, set globally via ``global_shard_guard``, which tells | ||
| TE which mesh axes are DP and TP. | ||
|
|
||
| .. literalinclude:: dense.py | ||
| :language: python | ||
| :start-after: # DENSE_MULTI_GPU_MESH_SETUP_START | ||
| :end-before: # DENSE_MULTI_GPU_MESH_SETUP_END | ||
|
|
||
| **Sharding plan:** | ||
|
|
||
| .. csv-table:: | ||
| :header: "Tensor", "Shape", "PartitionSpec" | ||
| :widths: 30, 40, 30 | ||
|
|
||
| "Kernel (column-parallel)", "``(hidden, out_features)``", "``P(None, 'tp')``" | ||
| "Input activations", "``(batch, seq, hidden)``", "``P('dp', None, None)``" | ||
| "Gradient on output", "``(batch, seq, out_features)``", "``P('dp', None, 'tp')``" | ||
|
|
||
| .. literalinclude:: dense.py | ||
| :language: python | ||
| :start-after: # DENSE_MULTI_GPU_SHARD_SETUP_START | ||
| :end-before: # DENSE_MULTI_GPU_SHARD_SETUP_END | ||
|
|
||
| .. literalinclude:: dense.py | ||
| :language: python | ||
| :start-after: # DENSE_MULTI_GPU_BENCH_START | ||
| :end-before: # DENSE_MULTI_GPU_BENCH_END | ||
|
|
||
| .. raw:: html | ||
|
|
||
| <div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;"> | ||
| Output: | ||
| </div> | ||
|
|
||
| .. container:: program-output | ||
|
|
||
| .. literalinclude:: dense.out | ||
| :language: text | ||
| :start-after: # MULTI_GPU_OUTPUT_START | ||
| :end-before: # MULTI_GPU_OUTPUT_END | ||
|
|
||
|
|
||
| Next steps | ||
| ---------- | ||
|
|
||
| * `Collective GEMM <collective_gemm.html>`_: further speedups by communicating between devices inside the GEMM. | ||
| * `← Hub <../te_jax_integration.html>`_ |
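To make the sharding plan in section 4 concrete, the per-device shapes on the 2×2 mesh can be worked out by dividing each sharded dimension by the size of its mesh axis. The sketch below is plain Python and assumes even splits; ``local_shape``, ``MESH_AXES``, and ``PLAN`` are hypothetical names used only for illustration, not JAX or TE APIs.

```python
# Mesh axis sizes for the DP=2 / TP=2 example.
MESH_AXES = {"dp": 2, "tp": 2}

# Global shapes and partition specs copied from the sharding plan table.
# None means the dimension is replicated rather than sharded.
PLAN = {
    "kernel": ((8192, 32768), (None, "tp")),
    "x": ((8, 2048, 8192), ("dp", None, None)),
    "dy": ((8, 2048, 32768), ("dp", None, "tp")),
}


def local_shape(global_shape, spec, mesh_axes):
    """Per-device shape: divide each sharded dim by its mesh axis size."""
    shape = []
    for dim, axis in zip(global_shape, spec):
        size = 1 if axis is None else mesh_axes[axis]
        # Assumption of this sketch: uneven splits are simply rejected.
        if dim % size:
            raise ValueError(f"dim {dim} is not divisible by mesh axis {axis!r}")
        shape.append(dim // size)
    return tuple(shape)


local_shapes = {
    name: local_shape(shape, spec, MESH_AXES) for name, (shape, spec) in PLAN.items()
}
for name, shape in local_shapes.items():
    print(f"{name}: {shape}")
```

Each device therefore holds half of the kernel's output columns and half of the batch, which is why the output gradient is split along both axes.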