diff --git a/docs/source/en/optimization/tpu.md b/docs/source/en/optimization/tpu.md new file mode 100644 index 000000000000..8018fb9a8bbf --- /dev/null +++ b/docs/source/en/optimization/tpu.md @@ -0,0 +1,163 @@ + + +# TorchTPU + +[TorchTPU](https://github.com/google-pytorch/torch_tpu/) provides a PyTorch backend for Google's Tensor Processing Units (TPUs), enabling you to run diffusers pipelines on Google Cloud TPUs (v6e, v5p, …) with minimal code changes. + +Four execution modes are available: + +| Mode | Constant | How to activate | Notes | +|---|---|---|---| +| **Strict Eager** (default) | `EagerMode.DEFER_NEVER` | just `import torch_tpu` | Operations dispatched one at a time, asynchronous | +| **Debug Eager** | `EagerMode.DEFER_NEVER_AND_LAUNCH_BLOCKING` | `set_eager_mode(EagerMode.DEFER_NEVER_AND_LAUNCH_BLOCKING)` or `TPU_LAUNCH_BLOCKING=1` | Synchronous execution; useful for pinpointing errors | +| **Fused Eager** | `EagerMode.DEFER_AND_FUSE` | `set_eager_mode(EagerMode.DEFER_AND_FUSE)` or `TPU_DEFER_AND_FUSE=1` | Groups multiple ops for XLA fusion; best throughput in eager mode | +| **Compile** | — | `pipe.enable_tpu_compile()` | AOT compilation with `TpuBackend` | + +## Installation + +Follow the [TorchTPU installation guide](https://github.com/google-pytorch/torch_tpu/). After installation, +`import torch_tpu` registers the `"tpu"` device automatically. + +## Basic usage (strict eager mode) + +```python +import torch +import torch_tpu # noqa: F401 — registers torch.tpu + +from diffusers import FluxPipeline + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + torch_dtype=torch.bfloat16, +) + +# Move only the denoising components to TPU; text encoders stay on CPU. +pipe.transformer.to("tpu") +pipe.vae.to("tpu") + +# _execution_device is now "tpu" automatically. 
+image = pipe( + prompt="a golden retriever surfing a wave, photorealistic", + height=1024, + width=1024, + num_inference_steps=4, + guidance_scale=0.0, +).images[0] + +image.save("output.png") +``` + +## Compiled mode (recommended for production) + +`torch.compile` with `TpuBackend` traces the transformer statically. The first call (warmup) +is slow because it triggers compilation; subsequent calls reuse the compiled graph. + +> [!IMPORTANT] +> TorchTPU requires **static shapes** — `torch.compile` is called with `dynamic=False` +> internally. Every time `height`, `width`, or `num_inference_steps` changes, the graph is +> recompiled from scratch. Keep these values constant across all calls after warmup, or call +> `tpu_warmup` again before changing them. + +```python +import torch +import torch_tpu # noqa: F401 + +from diffusers import FluxPipeline + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + torch_dtype=torch.bfloat16, +) +pipe.transformer.to("tpu") +pipe.vae.to("tpu") + +# Compile TPU components with TpuBackend. +# Also applies AttnProcessor to replace SDP-based attention (required for XLA). +pipe.enable_tpu_compile() + +# Warmup — triggers static graph compilation. +pipe.tpu_warmup( + prompt="warmup", + height=1024, + width=1024, + num_inference_steps=4, + guidance_scale=0.0, +) + +# Timed inference reuses the compiled graph. +image = pipe( + prompt="a golden retriever surfing a wave, photorealistic", + height=1024, + width=1024, + num_inference_steps=4, + guidance_scale=0.0, +).images[0] + +image.save("output.png") +``` + +## Eager mode + +TorchTPU defaults to **Strict Eager** (`EagerMode.DEFER_NEVER`): operations are dispatched one +at a time asynchronously, matching standard PyTorch GPU behaviour. Two alternative eager modes +are available: + +**Debug Eager** — synchronous execution; every op blocks until the TPU finishes. Useful for +pinpointing the exact line that raises an error. 
Equivalent to `CUDA_LAUNCH_BLOCKING=1` on GPU. + +```python +from torch_tpu._internal import execution_mode as em + +# Globally for the session +em.eager_mode = em.EagerMode.DEFER_NEVER_AND_LAUNCH_BLOCKING + +# Or via environment variable (before importing torch_tpu): +# TPU_LAUNCH_BLOCKING=1 +``` + +**Fused Eager** — defers ops and lets the XLA compiler fuse across operation boundaries, +reducing memory traffic and dispatch overhead without full AOT compilation. + +```python +from torch_tpu._internal import execution_mode as em + +# Globally for the session +em.eager_mode = em.EagerMode.DEFER_AND_FUSE + +# Or via environment variable (before importing torch_tpu): +# TPU_DEFER_AND_FUSE=1 +``` + +Use `set_eager_mode` as a context manager to switch modes for a single block: + +```python +from torch_tpu._internal import execution_mode as em + +with em.set_eager_mode(em.EagerMode.DEFER_NEVER_AND_LAUNCH_BLOCKING): + # synchronous — pinpoints the exact failing line + output = model(input_data) +``` + +> [!TIP] +> For the best production throughput, prefer `torch.compile` via `pipe.enable_tpu_compile()`, +> which uses an Ahead-of-Time (AOT) strategy more aggressive than Fused Eager. 
+ +## API reference + +### `enable_tpu_compile` + +[[autodoc]] diffusers.DiffusionPipeline.enable_tpu_compile + +### `tpu_warmup` + +[[autodoc]] diffusers.DiffusionPipeline.tpu_warmup diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 38a41a3dc93f..31b69f23f616 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -864,7 +864,14 @@ def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) - t_emb = self.time_proj(timesteps) + # On TPU in eager/lazy mode, torch.cat([sin, cos], dim=-1) inside time_proj + # lands at an unaligned offset in the XLA DUS fusion emitter → crash. + # torch.compile with TpuBackend handles this internally, so skip the CPU + # workaround when we're inside a compiled graph. + if sample.device.type == "tpu" and not torch.compiler.is_compiling(): + t_emb = self.time_proj(timesteps.cpu()).to(sample.device) + else: + t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. 
diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 11fce6a204bf..df5b27ace653 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -114,7 +114,7 @@ def _enhance_prompt_with_pe( tokenize=False, add_generation_prompt=False, # "Output:" is already in the user block ) - inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(device) + inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(self.pe.device) output_ids = self.pe.generate( **inputs, max_new_tokens=self.pe_tokenizer.model_max_length, @@ -155,7 +155,7 @@ def encode_prompt( else: ids = [0] - input_ids = torch.tensor([ids], device=device) + input_ids = torch.tensor([ids], device=self.text_encoder.device) with torch.no_grad(): outputs = self.text_encoder( input_ids=input_ids, diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 34cbf0faa667..b79522252249 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -251,7 +251,8 @@ def _get_t5_prompt_embeds( f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + model_device = self.text_encoder_2.device + prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0] dtype = self.text_encoder_2.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -296,7 +297,8 @@ def _get_clip_prompt_embeds( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + model_device = self.text_encoder.device + prompt_embeds = 
self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index e7792d667f16..7bedf12a7c7c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -282,7 +282,8 @@ def _get_t5_prompt_embeds( f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + model_device = self.text_encoder_2.device + prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0] dtype = self.text_encoder_2.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -327,7 +328,8 @@ def _get_clip_prompt_embeds( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 4c35ffefe088..8935ae114c78 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -275,7 +275,8 @@ def _get_t5_prompt_embeds( f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + model_device = self.text_encoder_2.device + prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), 
output_hidden_states=False)[0] dtype = self.text_encoder_2.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -321,7 +322,8 @@ def _get_clip_prompt_embeds( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py index e32bfecfcdad..3f6e95b0a6b8 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py @@ -297,7 +297,8 @@ def _get_t5_prompt_embeds( f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + model_device = self.text_encoder_2.device + prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0] dtype = self.text_encoder_2.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -343,7 +344,8 @@ def _get_clip_prompt_embeds( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py 
b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py index c85299eedcd3..8bf645cab04c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py @@ -330,7 +330,8 @@ def _get_t5_prompt_embeds( f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + model_device = self.text_encoder_2.device + prompt_embeds = self.text_encoder_2(text_input_ids.to(model_device), output_hidden_states=False)[0] dtype = self.text_encoder_2.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -376,7 +377,8 @@ def _get_clip_prompt_embeds( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), output_hidden_states=False) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py index d768e6127f26..98a87eacbbeb 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -241,8 +241,9 @@ def _get_qwen3_prompt_embeds( all_input_ids.append(inputs["input_ids"]) all_attention_masks.append(inputs["attention_mask"]) - input_ids = torch.cat(all_input_ids, dim=0).to(device) - attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + model_device = text_encoder.device + input_ids = torch.cat(all_input_ids, dim=0).to(model_device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(model_device) # Forward pass through the model output = text_encoder( diff --git 
a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 1fa4db90d995..b55d3fed29ad 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -22,7 +22,7 @@ import types from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Union, get_args, get_origin, get_type_hints +from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin, get_type_hints import httpx import numpy as np @@ -64,10 +64,11 @@ is_bitsandbytes_version, is_hpu_available, is_torch_npu_available, is_torch_version, is_transformers_version, logging, numpy_to_pil, + requires_backends, ) from ..utils.distributed_utils import is_torch_dist_rank_zero from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card @@ -1162,6 +1163,15 @@ def _execution_device(self): except ValueError: pass + # When text encoders are offloaded to CPU while the denoising backbone + # (unet, transformer, vae) runs on an accelerator, self.device returns CPU + # (first component). Prefer any non-CPU, non-meta component device so that + # scheduler and latent tensors land on the accelerator. This covers TPU, + # NPU (npu), Intel GPU (xpu), Habana (hpu), and any other backend. 
+ for name, model in self.components.items(): + if isinstance(model, torch.nn.Module) and model.device.type not in ("cpu", "meta"): + return model.device + for name, model in self.components.items(): if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: continue @@ -2387,3 +2397,101 @@ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): else: self.vae.unfuse_qkv_projections() self.fusing_vae = False + + def enable_tpu_compile( + self, + model_names: Optional[List[str]] = None, + **compile_kwargs, + ) -> None: + """Compile pipeline components that are on TPU using ``torch.compile`` with the ``TpuBackend``. + + Before compiling, each ``AttnProcessor2_0`` (SDP-based) found on a component is replaced with + ``AttnProcessor``, because the SDP path triggers XLA fusion-emitter crashes in eager/lazy mode. + Model-specific or custom processors are left in place: the generic ``AttnProcessor`` is not a + drop-in replacement for them. ``TpuBackend`` handles the resulting ``torch.cat`` layout + internally during static tracing, so no additional wrapper is needed at compile time. + + Args: + model_names (`list[str]`, *optional*): + Names of pipeline components to compile. Defaults to all ``torch.nn.Module`` + components currently resident on a TPU device. + **compile_kwargs: + Extra keyword arguments forwarded to ``torch.compile``. ``backend`` defaults to + ``TpuBackend()`` and ``dynamic`` defaults to ``False`` (required for static tracing). 
+ + Example: + ```python + import torch + import torch_tpu # noqa: F401 + + pipe.transformer.to("tpu") + pipe.vae.to("tpu") + pipe.enable_tpu_compile() + ``` + """ + requires_backends(self, "torch_tpu") + from torch_tpu._internal.compile import TpuBackend + + from ..models.attention_processor import AttnProcessor, AttnProcessor2_0 + + if model_names is None: + model_names = [ + name + for name, comp in self.components.items() + if isinstance(comp, torch.nn.Module) and comp.device.type == "tpu" + ] + + # Defaults are loop-invariant: resolve them once so every component shares one backend. + compile_kwargs.setdefault("backend", TpuBackend()) + compile_kwargs.setdefault("dynamic", False) + + for name in model_names: + component = getattr(self, name, None) + if not isinstance(component, torch.nn.Module): + logger.warning(f"`enable_tpu_compile`: component '{name}' is not a nn.Module, skipping.") + continue + if is_compiled_module(component): + logger.warning(f"`enable_tpu_compile`: component '{name}' is already compiled, skipping.") + continue + if hasattr(component, "attn_processors") and hasattr(component, "set_attn_processor"): + processors = component.attn_processors + # Only the generic SDP processor has a safe generic fallback; blindly + # installing it on models with custom processors would break their forward pass. + if any(isinstance(proc, AttnProcessor2_0) for proc in processors.values()): + component.set_attn_processor( + { + key: AttnProcessor() if isinstance(proc, AttnProcessor2_0) else proc + for key, proc in processors.items() + } + ) + logger.info(f"Compiling '{name}' with TpuBackend.") + setattr(self, name, torch.compile(component, **compile_kwargs)) + + def tpu_warmup(self, *args, **kwargs) -> None: + """Run a single forward pass to trigger XLA / ``TpuBackend`` compilation. + + Call this after ``enable_tpu_compile`` and before timed inference. The warmup + pass compiles the static computation graphs; subsequent calls reuse the compiled + graphs and run at full speed. + + Args: + *args: Positional arguments forwarded to the pipeline ``__call__``. + **kwargs: Keyword arguments forwarded to the pipeline ``__call__``. 
+ + Example: + ```python + pipe.tpu_warmup( + prompt="warmup", + height=1024, + width=1024, + num_inference_steps=4, + guidance_scale=0.0, + ) + ``` + """ + logger.info("Running TPU warmup pass to trigger XLA compilation...") + with torch.no_grad(): + self(*args, **kwargs) + if hasattr(torch, "tpu") and hasattr(torch.tpu, "synchronize"): + torch.tpu.synchronize() + logger.info("TPU warmup complete.") diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index d08b6c5a5973..978124b9c87d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -402,7 +402,7 @@ def encode_prompt( f" {tokenizer.model_max_length} tokens: {removed_text}" ) - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: @@ -463,7 +463,7 @@ def encode_prompt( ) negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), + uncond_input.input_ids.to(text_encoder.device), output_hidden_states=True, ) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 19ccfab3de0a..eadaf543b9d0 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -420,7 +420,7 @@ def encode_prompt( f" {tokenizer.model_max_length} tokens: {removed_text}" ) - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + prompt_embeds = 
text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: @@ -481,7 +481,7 @@ def encode_prompt( ) negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), + uncond_input.input_ids.to(text_encoder.device), output_hidden_states=True, ) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 7382d597102c..f03d01c90ec1 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -524,7 +524,7 @@ def encode_prompt( f" {tokenizer.model_max_length} tokens: {removed_text}" ) - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: @@ -585,7 +585,7 @@ def encode_prompt( ) negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), + uncond_input.input_ids.to(text_encoder.device), output_hidden_states=True, ) diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index be2d53f17932..b7e8a71cd50c 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -182,7 +182,8 @@ def _get_t5_prompt_embeds( text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask seq_lens = mask.gt(0).sum(dim=1).long() - prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + model_device = self.text_encoder.device + prompt_embeds 
= self.text_encoder(text_input_ids.to(model_device), mask.to(model_device)).last_hidden_state prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( @@ -654,7 +655,7 @@ def __call__( self._current_timestep = None if not output_type == "latent": - latents = latents.to(self.vae.dtype) + latents = latents.to(self.vae.device, dtype=self.vae.dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_animate.py b/src/diffusers/pipelines/wan/pipeline_wan_animate.py index 5806032c0142..91960fad7d36 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_animate.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_animate.py @@ -259,7 +259,8 @@ def _get_t5_prompt_embeds( text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask seq_lens = mask.gt(0).sum(dim=1).long() - prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), mask.to(model_device)).last_hidden_state prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 8061f67ab6b9..59a844e088ae 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -223,7 +223,8 @@ def _get_t5_prompt_embeds( text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask seq_lens = mask.gt(0).sum(dim=1).long() - prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + model_device = self.text_encoder.device + prompt_embeds = 
self.text_encoder(text_input_ids.to(model_device), mask.to(model_device)).last_hidden_state prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py index b0896d382d67..8c72adf09d6a 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py @@ -228,7 +228,8 @@ def _get_t5_prompt_embeds( text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask seq_lens = mask.gt(0).sum(dim=1).long() - prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), mask.to(model_device)).last_hidden_state prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py index 8993475a2851..9064000ab35b 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py @@ -246,7 +246,8 @@ def _get_t5_prompt_embeds( text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask seq_lens = mask.gt(0).sum(dim=1).long() - prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + model_device = self.text_encoder.device + prompt_embeds = self.text_encoder(text_input_ids.to(model_device), mask.to(model_device)).last_hidden_state prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( diff --git a/src/diffusers/utils/__init__.py 
b/src/diffusers/utils/__init__.py index 5cd6885e0364..10508e39d599 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -117,6 +117,7 @@ is_torch_mlu_available, is_torch_neuronx_available, is_torch_npu_available, + is_torch_tpu_available, is_torch_version, is_torch_xla_available, is_torch_xla_version, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index a0fa882d2705..4c390496fc31 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -193,6 +193,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> tuple[b _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla") _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu") _torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu") +_torch_tpu_available, _torch_tpu_version = _is_package_available("torch_tpu") _torch_neuronx_available, _torch_neuronx_version = _is_package_available("torch_neuronx") _transformers_available, _transformers_version = _is_package_available("transformers") _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub") @@ -253,6 +254,10 @@ def is_torch_mlu_available(): return _torch_mlu_available +def is_torch_tpu_available(): + return _torch_tpu_available + + def is_torch_neuronx_available(): return _torch_neuronx_available @@ -594,6 +599,11 @@ def is_av_available(): torchao` """ +TORCH_TPU_IMPORT_ERROR = """ +{0} requires the torch_tpu library but it was not found in your environment. Please follow the installation +instructions at https://github.com/google-pytorch/torch_tpu/ +""" + QUANTO_IMPORT_ERROR = """ {0} requires the optimum-quanto library but it was not found in your environment. 
You can install it with pip: `pip install optimum-quanto` @@ -650,6 +660,7 @@ def is_av_available(): ("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)), ("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)), ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)), + ("torch_tpu", (is_torch_tpu_available, TORCH_TPU_IMPORT_ERROR)), ("torch_neuronx", (is_torch_neuronx_available, TORCH_NEURONX_IMPORT_ERROR)), ] ) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 263334dce8cd..4b641eaf8abc 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -45,6 +45,7 @@ "cpu": True, "mps": False, "neuron": False, + "tpu": False, "default": True, } BACKEND_EMPTY_CACHE = { @@ -52,6 +53,7 @@ "xpu": torch.xpu.empty_cache, "cpu": None, "mps": torch.mps.empty_cache, + "tpu": getattr(getattr(torch, "tpu", None), "empty_cache", None), "neuron": None, "default": None, } @@ -60,6 +62,7 @@ "xpu": torch.xpu.device_count, "cpu": lambda: 0, "mps": lambda: 0, + "tpu": lambda: getattr(getattr(torch, "tpu", None), "device_count", lambda: 0)(), "neuron": lambda: getattr(getattr(torch, "neuron", None), "device_count", lambda: 0)(), "default": 0, } @@ -68,6 +71,9 @@ "xpu": torch.xpu.manual_seed, "cpu": torch.manual_seed, "mps": torch.mps.manual_seed, + # TPU latents are always generated on CPU (TPU RNG has unaligned DUS bug), + # so CPU seeding is the correct behaviour here. 
+ "tpu": torch.manual_seed, "neuron": torch.manual_seed, "default": torch.manual_seed, } @@ -76,6 +82,7 @@ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), "cpu": None, "mps": None, + "tpu": None, "neuron": None, "default": None, } @@ -84,6 +91,7 @@ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), "cpu": None, "mps": None, + "tpu": None, "neuron": None, "default": None, } @@ -92,6 +100,7 @@ "xpu": getattr(torch.xpu, "max_memory_allocated", None), "cpu": 0, "mps": 0, + "tpu": 0, "neuron": 0, "default": 0, } @@ -100,6 +109,7 @@ "xpu": getattr(torch.xpu, "synchronize", None), "cpu": None, "mps": None, + "tpu": getattr(getattr(torch, "tpu", None), "synchronize", None), "neuron": getattr(getattr(torch, "neuron", None), "synchronize", None), "default": None, } @@ -197,6 +207,11 @@ def randn_tensor( rand_device = device batch_size = shape[0] + # TPU RNG has an unaligned DUS (dynamic-update-slice) bug — generate on CPU + # and move to TPU via the existing .to(device) call at the end. + if device is not None and device.type == "tpu": + rand_device = torch.device("cpu") + layout = layout or torch.strided device = device or torch.device("cpu")