Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 163 additions & 0 deletions docs/source/en/optimization/tpu.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
<!--Copyright 2026 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# TorchTPU

[TorchTPU](https://github.com/google-pytorch/torch_tpu/) provides a PyTorch backend for Google's Tensor Processing Units (TPUs), enabling you to run diffusers pipelines on Google Cloud TPUs (v6e, v5p, …) with minimal code changes.

Four execution modes are available:

| Mode | Constant | How to activate | Notes |
|---|---|---|---|
| **Strict Eager** (default) | `EagerMode.DEFER_NEVER` | just `import torch_tpu` | Operations dispatched one at a time, asynchronous |
| **Debug Eager** | `EagerMode.DEFER_NEVER_AND_LAUNCH_BLOCKING` | `set_eager_mode(EagerMode.DEFER_NEVER_AND_LAUNCH_BLOCKING)` or `TPU_LAUNCH_BLOCKING=1` | Synchronous execution; useful for pinpointing errors |
| **Fused Eager** | `EagerMode.DEFER_AND_FUSE` | `set_eager_mode(EagerMode.DEFER_AND_FUSE)` or `TPU_DEFER_AND_FUSE=1` | Groups multiple ops for XLA fusion; best throughput in eager mode |
| **Compile** | — | `pipe.enable_tpu_compile()` | AOT compilation with `TpuBackend` |

## Installation

Follow the [TorchTPU installation guide](https://github.com/google-pytorch/torch_tpu/). After installation,
`import torch_tpu` registers the `"tpu"` device automatically.

## Basic usage (strict eager mode)

```python
import torch
import torch_tpu # noqa: F401 — registers torch.tpu

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
)

# Move only the denoising components to TPU; text encoders stay on CPU.
pipe.transformer.to("tpu")
pipe.vae.to("tpu")

# _execution_device is now "tpu" automatically.
image = pipe(
prompt="a golden retriever surfing a wave, photorealistic",
height=1024,
width=1024,
num_inference_steps=4,
guidance_scale=0.0,
).images[0]

image.save("output.png")
```

## Compiled mode (recommended for production)

`torch.compile` with `TpuBackend` traces the transformer statically. The first call (warmup)
is slow because it triggers compilation; subsequent calls reuse the compiled graph.

> [!IMPORTANT]
> TorchTPU requires **static shapes** — `torch.compile` is called with `dynamic=False`
> internally. Every time `height`, `width`, or `num_inference_steps` changes, the graph is
> recompiled from scratch. Keep these values constant across all calls after warmup, or call
> `tpu_warmup` again before changing them.

```python
import torch
import torch_tpu # noqa: F401

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
)
pipe.transformer.to("tpu")
pipe.vae.to("tpu")

# Compile TPU components with TpuBackend.
# Also applies AttnProcessor to replace SDP-based attention (required for XLA).
pipe.enable_tpu_compile()

# Warmup — triggers static graph compilation.
pipe.tpu_warmup(
prompt="warmup",
height=1024,
width=1024,
num_inference_steps=4,
guidance_scale=0.0,
)

# Timed inference reuses the compiled graph.
image = pipe(
prompt="a golden retriever surfing a wave, photorealistic",
height=1024,
width=1024,
num_inference_steps=4,
guidance_scale=0.0,
).images[0]

image.save("output.png")
```

## Eager mode

TorchTPU defaults to **Strict Eager** (`EagerMode.DEFER_NEVER`): operations are dispatched one
at a time asynchronously, matching standard PyTorch GPU behaviour. Two alternative eager modes
are available:

**Debug Eager** — synchronous execution; every op blocks until the TPU finishes. Useful for
pinpointing the exact line that raises an error. Equivalent to `CUDA_LAUNCH_BLOCKING=1` on GPU.

```python
from torch_tpu._internal import execution_mode as em

# Globally for the session
em.eager_mode = em.EagerMode.DEFER_NEVER_AND_LAUNCH_BLOCKING

# Or via environment variable (before importing torch_tpu):
# TPU_LAUNCH_BLOCKING=1
```

**Fused Eager** — defers ops and lets the XLA compiler fuse across operation boundaries,
reducing memory traffic and dispatch overhead without full AOT compilation.

```python
from torch_tpu._internal import execution_mode as em

# Globally for the session
em.eager_mode = em.EagerMode.DEFER_AND_FUSE

# Or via environment variable (before importing torch_tpu):
# TPU_DEFER_AND_FUSE=1
```

Use `set_eager_mode` as a context manager to switch modes for a single block:

```python
from torch_tpu._internal import execution_mode as em

with em.set_eager_mode(em.EagerMode.DEFER_NEVER_AND_LAUNCH_BLOCKING):
# synchronous — pinpoints the exact failing line
output = model(input_data)
```

> [!TIP]
> For the best production throughput, prefer `torch.compile` via `pipe.enable_tpu_compile()`,
> which uses an Ahead-of-Time (AOT) strategy more aggressive than Fused Eager.

## API reference

### `enable_tpu_compile`

[[autodoc]] diffusers.DiffusionPipeline.enable_tpu_compile

### `tpu_warmup`

[[autodoc]] diffusers.DiffusionPipeline.tpu_warmup
15 changes: 12 additions & 3 deletions src/diffusers/models/unets/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
deprecate,
logging,
)
from ...utils.torch_utils import is_compiled_module
from ..activations import get_activation
from ..attention import AttentionMixin
from ..attention_processor import (
Expand Down Expand Up @@ -855,18 +856,26 @@ def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float |
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
is_tpu = sample.device.type == "tpu"
if isinstance(timestep, float):
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
dtype = torch.float32 if (is_mps or is_npu or is_tpu) else torch.float64
else:
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
dtype = torch.int32 if (is_mps or is_npu or is_tpu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])

t_emb = self.time_proj(timesteps)
# On TPU in eager/lazy mode, torch.cat([sin, cos], dim=-1) inside time_proj
# lands at an unaligned offset in the XLA DUS fusion emitter → crash.
# torch.compile with TpuBackend handles this internally, so skip the CPU
# workaround when we're inside a compiled graph.
if sample.device.type == "tpu" and not torch.compiler.is_compiling():
t_emb = self.time_proj(timesteps.cpu()).to(sample.device)
else:
t_emb = self.time_proj(timesteps)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _enhance_prompt_with_pe(
tokenize=False,
add_generation_prompt=False, # "Output:" is already in the user block
)
inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(device)
inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(self.pe.device)
output_ids = self.pe.generate(
**inputs,
max_new_tokens=self.pe_tokenizer.model_max_length,
Expand Down Expand Up @@ -155,7 +155,7 @@ def encode_prompt(
else:
ids = [0]

input_ids = torch.tensor([ids], device=device)
input_ids = torch.tensor([ids], device=self.text_encoder.device)
with torch.no_grad():
outputs = self.text_encoder(
input_ids=input_ids,
Expand Down
6 changes: 4 additions & 2 deletions src/diffusers/pipelines/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,8 @@ def _get_t5_prompt_embeds(
f" {max_sequence_length} tokens: {removed_text}"
)

prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
model_device = self.text_encoder_2.device
prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0]

dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
Expand Down Expand Up @@ -296,7 +297,8 @@ def _get_clip_prompt_embeds(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
model_device = self.text_encoder.device
prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False)

# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
Expand Down
6 changes: 4 additions & 2 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,8 @@ def _get_t5_prompt_embeds(
f" {max_sequence_length} tokens: {removed_text}"
)

prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
model_device = self.text_encoder_2.device
prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0]

dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
Expand Down Expand Up @@ -327,7 +328,8 @@ def _get_clip_prompt_embeds(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
model_device = self.text_encoder.device
prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False)

# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
Expand Down
6 changes: 4 additions & 2 deletions src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ def _get_t5_prompt_embeds(
f" {max_sequence_length} tokens: {removed_text}"
)

prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
model_device = self.text_encoder_2.device
prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0]

dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
Expand Down Expand Up @@ -321,7 +322,8 @@ def _get_clip_prompt_embeds(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
model_device = self.text_encoder.device
prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False)

# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
Expand Down
6 changes: 4 additions & 2 deletions src/diffusers/pipelines/flux/pipeline_flux_kontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ def _get_t5_prompt_embeds(
f" {max_sequence_length} tokens: {removed_text}"
)

prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
model_device = self.text_encoder_2.device
prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0]

dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
Expand Down Expand Up @@ -343,7 +344,8 @@ def _get_clip_prompt_embeds(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
model_device = self.text_encoder.device
prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False)

# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
Expand Down
6 changes: 4 additions & 2 deletions src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,8 @@ def _get_t5_prompt_embeds(
f" {max_sequence_length} tokens: {removed_text}"
)

prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
model_device = self.text_encoder_2.device
prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0]

dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
Expand Down Expand Up @@ -376,7 +377,8 @@ def _get_clip_prompt_embeds(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
model_device = self.text_encoder.device
prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False)

# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
Expand Down
5 changes: 3 additions & 2 deletions src/diffusers/pipelines/flux2/pipeline_flux2_klein.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,9 @@ def _get_qwen3_prompt_embeds(
all_input_ids.append(inputs["input_ids"])
all_attention_masks.append(inputs["attention_mask"])

input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
model_device = text_encoder.device
input_ids = torch.cat(all_input_ids, dim=0).to(model_device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(model_device)

# Forward pass through the model
output = text_encoder(
Expand Down
Loading
Loading