From 3d14964f018b0c3d5e10c2fc62a0054950322ca9 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Tue, 23 Jun 2026 16:45:00 +0000 Subject: [PATCH] Init commit for transfer capability to Cosmos3 pipeline --- docs/source/en/api/pipelines/cosmos3.md | 105 +++- examples/cosmos3/inference_cosmos3.py | 116 +++- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 515 ++++++++++++++++-- 3 files changed, 698 insertions(+), 38 deletions(-) diff --git a/docs/source/en/api/pipelines/cosmos3.md b/docs/source/en/api/pipelines/cosmos3.md index 1ac8f36457a4..221922fcd4f0 100644 --- a/docs/source/en/api/pipelines/cosmos3.md +++ b/docs/source/en/api/pipelines/cosmos3.md @@ -32,7 +32,7 @@ From one model you can: - Generate physically plausible video worlds from text, images, or action inputs (image-to-video, text-to-video, action-conditioned video generation). - Reason about physical properties like motion, causality, and spatial relationships. - Predict future video and action sequences from the current state. -- Transfer scenes across viewpoints and conditions with structural control *(coming soon)*. +- Transfer scenes across viewpoints and conditions with structural control (edge, blur, depth, segmentation, world-scenario maps). Under the hood, a single `Cosmos3OmniTransformer` runs a Qwen-style language model in parallel with a diffusion generation pathway: text tokens flow through a causal "understanding" stream while video and sound latents flow through a bi-directionally-attended "generation" stream, joined by a 3D multimodal RoPE. See the [Cosmos World Foundation Model Platform paper](https://huggingface.co/papers/2501.03575) for the architectural background. @@ -371,6 +371,109 @@ export_to_video(result.video, "cosmos3_v2v.mp4", fps=24, macro_block_size=1) +## Transfer (structural control) + +Transfer generates a target clip that follows a **precomputed control video** (a spatial control signal): edge (Canny), blur, depth, segmentation, or a world-scenario map (WSM). Pass it through `control_videos=` as a mapping from hint name to a loaded video. The control map is resized, temporally padded, normalized, and VAE-encoded into a clean conditioning item placed before the noisy target; the model then generates the target to match it. Transfer is video-only (no `image`, `video`, `action`, or `enable_sound`), and the prompt is a pre-upsampled JSON caption (see [Prompt upsampling](#prompt-upsampling)). + +Diffusers does not ship the control assets. Ready-made ones (a control video + matching `prompt.json` per hint, plus a shared `negative_prompt.json`) live in the [Cosmos cookbook](https://github.com/NVIDIA/cosmos/tree/main/cookbooks/cosmos3/generator/transfer/assets). For the edge example below, download them into a local `assets/` folder: + +```bash +base=https://github.com/NVIDIA/cosmos/raw/refs/heads/main/cookbooks/cosmos3/generator/transfer/assets +mkdir -p assets/edge +curl -sL "$base/edge/control_edge.mp4" -o assets/edge/control_edge.mp4 +curl -sL "$base/edge/prompt.json" -o assets/edge/prompt.json +curl -sL "$base/negative_prompt.json" -o assets/negative_prompt.json +``` + +Guidance uses a nested control/text classifier-free-guidance blend. `guidance_scale` is the usual text CFG; `control_guidance` (`!= 1.0`) additionally amplifies the control signal. Recommended starting values per hint (matching the Cosmos Framework defaults): + +| Hint | `guidance_scale` | `control_guidance` | `flow_shift` | Geometry | +| --- | --- | --- | --- | --- | +| Edge / Blur / Depth | 3.0 | 1.5 | 10.0 | 121 frames @ 30 FPS | +| Segmentation | 3.0 | 2.0 | 10.0 | 121 frames @ 30 FPS | +| World scenario (WSM) | 1.0 | 3.0 | 10.0 | 101 frames @ 10 FPS | + +Depth, segmentation, and WSM control maps must be precomputed by external models; edge/blur maps can be produced offline with any Canny/blur tool. The shipped cookbook configs use a single hint each; passing several entries in `control_videos` to combine hints is supported by the pipeline but is not a tuned/validated cookbook path (set `guidance_scale` / `control_guidance` explicitly, since the per-hint defaults above assume a single hint). Long clips are generated autoregressively in chunks of `num_video_frames_per_chunk` and stitched automatically. + + + + +```python +import json +import torch +from diffusers import Cosmos3OmniPipeline +from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler +from diffusers.utils import export_to_video, load_video + +# Downloaded into assets/ from the Cosmos cookbook (see the curl snippet above). +json_prompt = json.load(open("assets/edge/prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt.json")) +control_edge = load_video("assets/edge/control_edge.mp4") + +pipe = Cosmos3OmniPipeline.from_pretrained( + "nvidia/Cosmos3-Nano", torch_dtype=torch.bfloat16, device_map="cuda" +) +pipe.scheduler = UniPCMultistepScheduler.from_config( + pipe.scheduler.config, flow_shift=10.0, use_karras_sigmas=False +) + +result = pipe( + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), + control_videos={"edge": control_edge}, + num_frames=121, + height=720, + width=1280, + fps=30.0, + num_inference_steps=35, + guidance_scale=3.0, + control_guidance=1.5, +) +# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). +export_to_video(result.video, "cosmos3_transfer_edge.mp4", fps=30, macro_block_size=1) +``` + + + + +```python +import json +import torch +from diffusers import Cosmos3OmniPipeline +from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler +from diffusers.utils import export_to_video, load_video + +# Downloaded into assets/ from the Cosmos cookbook (see the curl snippet above). +json_prompt = json.load(open("assets/edge/prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt.json")) +control_edge = load_video("assets/edge/control_edge.mp4") + +pipe = Cosmos3OmniPipeline.from_pretrained( + "nvidia/Cosmos3-Super", torch_dtype=torch.bfloat16, device_map="cuda" +) +pipe.scheduler = UniPCMultistepScheduler.from_config( + pipe.scheduler.config, flow_shift=10.0, use_karras_sigmas=False +) + +result = pipe( + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), + control_videos={"edge": control_edge}, + num_frames=121, + height=720, + width=1280, + fps=30.0, + num_inference_steps=35, + guidance_scale=3.0, + control_guidance=1.5, +) +# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). +export_to_video(result.video, "cosmos3_transfer_edge.mp4", fps=30, macro_block_size=1) +``` + + + + ## Video-to-video with sound When the checkpoint carries a `sound_tokenizer`, add `enable_sound=True` to the video-to-video call to jointly generate a synchronized audio track. The waveform is returned alongside the video and can be muxed into the MP4 with [`~utils.encode_video`]. diff --git a/examples/cosmos3/inference_cosmos3.py b/examples/cosmos3/inference_cosmos3.py index 62388c8d1288..16014dabaaec 100644 --- a/examples/cosmos3/inference_cosmos3.py +++ b/examples/cosmos3/inference_cosmos3.py @@ -21,6 +21,13 @@ Video-to-video: python inference_cosmos3.py --prompt "..." --video-path /path/to/video.mp4 +Transfer (ready-made control_*.mp4 + prompt.json are hosted in the Cosmos cookbook; --control-path / --prompt +accept URLs or local paths: https://github.com/NVIDIA/cosmos/tree/main/cookbooks/cosmos3/generator/transfer/assets): + base=https://github.com/NVIDIA/cosmos/raw/refs/heads/main/cookbooks/cosmos3/generator/transfer/assets + python inference_cosmos3.py --prompt "$(curl -sL $base/edge/prompt.json)" \ + --transfer-hint edge --control-path $base/edge/control_edge.mp4 \ + --guidance-scale 3.0 --control-guidance 1.5 --flow-shift 10.0 --num-frames 121 --fps 30 + Text-to-video-with-sound (requires a sound-capable checkpoint): python inference_cosmos3.py --prompt "..." --enable-sound """ @@ -62,6 +69,11 @@ def _load_action(path: str | None): def main(): parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) parser.add_argument("--prompt", required=True, help="Text prompt.") + parser.add_argument( + "--negative-prompt", + default=None, + help="Optional negative prompt text.", + ) parser.add_argument( "--model", choices=sorted(HF_REPOS), @@ -89,6 +101,60 @@ def main(): default="first", help="Take the video-to-video conditioning frames from the first or last of the source clip (default: first).", ) + parser.add_argument( + "--transfer-hint", + action="append", + choices=["edge", "blur", "depth", "seg", "wsm"], + default=None, + help="Enable transfer with a control hint. Repeat (paired with --control-path) to combine multiple hints.", + ) + parser.add_argument( + "--control-path", + action="append", + default=None, + help="URL or local path to a precomputed control video, paired in order with each --transfer-hint.", + ) + parser.add_argument( + "--control-guidance", + type=float, + default=1.0, + help="Transfer control-CFG scale (recommended 1.5 for edge/blur/depth, 2.0 for seg, 3.0 for wsm).", + ) + parser.add_argument( + "--control-guidance-interval", + default=None, + help="Comma-separated [lo,hi] timestep window for control guidance (default: applied at every step).", + ) + parser.add_argument( + "--guidance-interval", + default=None, + help="Comma-separated [lo,hi] timestep window for text guidance in transfer (default: every step).", + ) + parser.add_argument( + "--num-conditional-frames", + type=int, + default=1, + help="Frames carried over from the previous chunk as conditioning (transfer multi-chunk).", + ) + parser.add_argument( + "--num-first-chunk-conditional-frames", + type=int, + default=0, + help="Leading frames of --video-path used to condition the first transfer chunk (requires --video-path).", + ) + parser.add_argument( + "--num-video-frames-per-chunk", + type=int, + default=None, + help="Max frames generated per autoregressive transfer chunk (default: whole clip in one chunk).", + ) + parser.add_argument( + "--no-share-vision-temporal-positions", + dest="share_vision_temporal_positions", + action="store_false", + default=True, + help="Give control maps and the target distinct temporal mRoPE positions instead of sharing them (transfer).", + ) parser.add_argument("--output", default=".", help="Directory to save generated video/image/audio files.") parser.add_argument( "--height", @@ -198,7 +264,52 @@ def main(): output_dir.mkdir(parents=True, exist_ok=True) generator = torch.Generator().manual_seed(args.seed) if args.seed is not None else None - if args.action_mode is not None: + def _parse_interval(value): + if value is None: + return None + parts = [float(v) for v in value.split(",") if v.strip()] + if len(parts) != 2: + raise ValueError(f"Expected a comma-separated [lo,hi] interval, got {value!r}.") + return (parts[0], parts[1]) + + if args.transfer_hint is not None: + control_paths = args.control_path or [] + if len(control_paths) != len(args.transfer_hint): + raise ValueError("Pass one --control-path per --transfer-hint, in matching order.") + control_videos = {hint: load_video(path) for hint, path in zip(args.transfer_hint, control_paths)} + # `--video-path` is an OPTIONAL RGB prefix that only seeds the first chunk, and is consulted solely when + # --num-first-chunk-conditional-frames > 0. It is unrelated to the control hints (which always drive transfer). + conditioning_video = None + if args.num_first_chunk_conditional_frames > 0: + if args.video_path is None: + raise ValueError( + "--num-first-chunk-conditional-frames > 0 requires --video-path (an RGB prefix clip)." + ) + conditioning_video = load_video(args.video_path) + elif args.video_path is not None: + print("Ignoring --video-path: it only applies when --num-first-chunk-conditional-frames > 0.") + result = pipeline( + prompt=args.prompt, + negative_prompt=args.negative_prompt, + control_videos=control_videos, + video=conditioning_video, + num_frames=args.num_frames if args.num_frames != 189 else None, + height=args.height, + width=args.width, + fps=args.fps, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + control_guidance=args.control_guidance, + control_guidance_interval=_parse_interval(args.control_guidance_interval), + guidance_interval=_parse_interval(args.guidance_interval), + num_conditional_frames=args.num_conditional_frames, + num_first_chunk_conditional_frames=args.num_first_chunk_conditional_frames, + num_video_frames_per_chunk=args.num_video_frames_per_chunk, + share_vision_temporal_positions=args.share_vision_temporal_positions, + generator=generator, + enable_safety_check=not args.no_safety_check, + ) + elif args.action_mode is not None: if args.vision_path is None: raise ValueError("--vision-path must point to a conditioning video for action modes.") if args.action_chunk_size is None: @@ -207,6 +318,7 @@ def main(): raw_actions = _load_action(args.action_path) if args.action_mode == "forward_dynamics" else None result = pipeline( prompt=args.prompt, + negative_prompt=args.negative_prompt, action=CosmosActionCondition( mode=args.action_mode, chunk_size=args.action_chunk_size, @@ -234,6 +346,7 @@ def main(): ) result = pipeline( prompt=args.prompt, + negative_prompt=args.negative_prompt, video=video, condition_frame_indexes_vision=condition_frame_indexes_vision, condition_video_keep=args.condition_video_keep, @@ -253,6 +366,7 @@ def main(): image = load_image(args.vision_path) if args.vision_path is not None else None result = pipeline( prompt=args.prompt, + negative_prompt=args.negative_prompt, image=image, num_frames=args.num_frames, height=args.height, diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 538b553d478d..bc6a0456eddb 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -135,6 +135,10 @@ def get_3d_mrope_ids_vae_tokens( _SYSTEM_PROMPT_IMAGE = "You are a helpful assistant who will generate images from a give prompt." _SYSTEM_PROMPT_VIDEO = "You are a helpful assistant who will generate videos from a give prompt." +_SYSTEM_PROMPT_TRANSFER = ( + "You are a helpful assistant that generates images or videos following the user's instructions" + " and control signals (edge maps, blur, depth, or segmentation)." +) _ACTION_RESOLUTION_BINS = { "256": { @@ -502,62 +506,109 @@ def _prepare_text_segment( def _prepare_vision_segment( self, - input_vision_tokens: torch.Tensor, + input_vision_tokens: torch.Tensor | list[torch.Tensor], has_image_condition: bool, mrope_offset: int | float, vision_fps: float | None, curr: int, device: torch.device | str, - condition_frame_indexes: list[int] | None = None, + condition_frame_indexes: list[int] | list[list[int] | None] | None = None, + clean_item_flags: list[bool] | None = None, + share_vision_temporal_positions: bool = False, ) -> dict[str, Any]: """Build the static portion of the vision segment of the joint sequence. Step-varying fields (``vision_tokens`` and ``vision_timesteps``) are NOT included here — the caller splices them in inside the denoising loop. The method is called once per (cond/uncond) prompt before the loop, since everything else only depends on the prompt length and the vision shape. + + For transfer, multiple vision items are packed in order ``[ctrl_1, ..., ctrl_N, target]``: control items are + marked clean via ``clean_item_flags`` (all frames conditioned, no noisy positions, no MSE-loss positions), so + the transformer treats them as fixed context and only predicts the (noisy) target frames. When + ``share_vision_temporal_positions`` is ``True`` every item reuses the same temporal mRoPE offset (the control + maps and the target are temporally aligned) instead of advancing the offset per item. """ config = self.transformer.config latent_patch_size = config.latent_patch_size - _, _, latent_t, latent_h, latent_w = input_vision_tokens.shape - patch_h = math.ceil(latent_h / latent_patch_size) - patch_w = math.ceil(latent_w / latent_patch_size) - num_vision_tokens = latent_t * patch_h * patch_w - - if condition_frame_indexes is None: - condition_frame_indexes = [0] if has_image_condition else [] - cond_frames = {idx for idx in condition_frame_indexes if 0 <= idx < latent_t} - noisy_frame_indexes = torch.tensor( - [idx for idx in range(latent_t) if idx not in cond_frames], device=device, dtype=torch.long - ) - frame_token_stride = patch_h * patch_w - mse_loss_indexes: list[int] = [] - for frame_idx in noisy_frame_indexes.tolist(): - frame_start = curr + frame_idx * frame_token_stride - mse_loss_indexes.extend(range(frame_start, frame_start + frame_token_stride)) + # Normalize to per-item lists so the single-item (non-transfer) path and the multi-item transfer path share + # one implementation. A single tensor with a flat condition_frame_indexes list reproduces the old behavior. + if isinstance(input_vision_tokens, torch.Tensor): + items = [input_vision_tokens] + per_item_condition: list[list[int] | None] = [condition_frame_indexes] # type: ignore[list-item] + else: + items = list(input_vision_tokens) + if condition_frame_indexes is None: + per_item_condition = [None] * len(items) + else: + per_item_condition = list(condition_frame_indexes) # type: ignore[arg-type] + if clean_item_flags is None: + clean_item_flags = [False] * len(items) effective_fps = vision_fps if config.enable_fps_modulation else None - vision_mrope_ids, _ = get_3d_mrope_ids_vae_tokens( - grid_t=latent_t, - grid_h=patch_h, - grid_w=patch_w, - temporal_offset=mrope_offset, - reset_spatial_indices=config.unified_3d_mrope_reset_spatial_ids, - fps=effective_fps, - base_fps=float(config.base_fps), - temporal_compression_factor=self.vae.config.scale_factor_temporal, - ) + token_shapes: list[tuple[int, int, int]] = [] + sequence_index_parts: list[torch.Tensor] = [] + mse_loss_indexes: list[int] = [] + noisy_frame_indexes_per_item: list[torch.Tensor] = [] + mrope_id_parts: list[torch.Tensor] = [] + num_vision_tokens = 0 + num_noisy_vision_tokens = 0 + item_curr = curr + item_mrope_offset: int | float = mrope_offset + + for item, item_condition, is_clean in zip(items, per_item_condition, clean_item_flags): + _, _, latent_t, latent_h, latent_w = item.shape + patch_h = math.ceil(latent_h / latent_patch_size) + patch_w = math.ceil(latent_w / latent_patch_size) + item_num_tokens = latent_t * patch_h * patch_w + frame_token_stride = patch_h * patch_w + + if is_clean: + cond_frames = set(range(latent_t)) + else: + item_condition = item_condition if item_condition is not None else ([0] if has_image_condition else []) + cond_frames = {idx for idx in item_condition if 0 <= idx < latent_t} + noisy_frame_indexes = torch.tensor( + [idx for idx in range(latent_t) if idx not in cond_frames], device=device, dtype=torch.long + ) + + for frame_idx in noisy_frame_indexes.tolist(): + frame_start = item_curr + frame_idx * frame_token_stride + mse_loss_indexes.extend(range(frame_start, frame_start + frame_token_stride)) + + item_mrope_ids, next_mrope_offset = get_3d_mrope_ids_vae_tokens( + grid_t=latent_t, + grid_h=patch_h, + grid_w=patch_w, + temporal_offset=item_mrope_offset, + reset_spatial_indices=config.unified_3d_mrope_reset_spatial_ids, + fps=effective_fps, + base_fps=float(config.base_fps), + temporal_compression_factor=self.vae.config.scale_factor_temporal, + ) + + token_shapes.append((latent_t, patch_h, patch_w)) + sequence_index_parts.append( + torch.arange(item_curr, item_curr + item_num_tokens, dtype=torch.long, device=device) + ) + noisy_frame_indexes_per_item.append(noisy_frame_indexes) + mrope_id_parts.append(item_mrope_ids.to(device)) + num_vision_tokens += item_num_tokens + num_noisy_vision_tokens += len(noisy_frame_indexes) * frame_token_stride + item_curr += item_num_tokens + if not share_vision_temporal_positions: + item_mrope_offset = next_mrope_offset return { # Transformer-facing fields (vision_tokens and vision_timesteps spliced per step). - "vision_token_shapes": [(latent_t, patch_h, patch_w)], - "vision_sequence_indexes": torch.arange(curr, curr + num_vision_tokens, dtype=torch.long, device=device), + "vision_token_shapes": token_shapes, + "vision_sequence_indexes": torch.cat(sequence_index_parts, dim=0), "vision_mse_loss_indexes": torch.tensor(mse_loss_indexes, dtype=torch.long, device=device), - "vision_noisy_frame_indexes": [noisy_frame_indexes], + "vision_noisy_frame_indexes": noisy_frame_indexes_per_item, # Assembly helpers (consumed inline before the transformer call). - "vision_mrope_ids": vision_mrope_ids.to(device), + "vision_mrope_ids": torch.cat(mrope_id_parts, dim=1), "num_vision_tokens": num_vision_tokens, - "num_noisy_vision_tokens": len(noisy_frame_indexes) * frame_token_stride, + "num_noisy_vision_tokens": num_noisy_vision_tokens, } def _prepare_sound_segment( @@ -959,11 +1010,42 @@ def check_inputs( action: "CosmosActionCondition | None" = None, video: list[Image.Image] | torch.Tensor | np.ndarray | None = None, condition_frame_indexes_vision: Iterable[int] = (0, 1), + control_videos: dict[str, Any] | None = None, + num_first_chunk_conditional_frames: int = 0, ) -> None: if not isinstance(prompt, (str, list)) or ( isinstance(prompt, list) and not all(isinstance(p, str) for p in prompt) ): raise ValueError(f"`prompt` must be a str or list of str, got {type(prompt).__name__}.") + + if control_videos is not None: + # Transfer mode: validate the hint mapping and reject combinations the model does not support. + # The supported hints (edge, blur, depth, seg, wsm) are listed in canonical packing order. + supported_hints = ["edge", "blur", "depth", "seg", "wsm"] + if not isinstance(control_videos, dict) or not control_videos: + raise ValueError("`control_videos` must be a non-empty dict mapping hint name -> control video.") + unknown = [k for k in control_videos if k not in supported_hints] + if unknown: + raise ValueError( + f"`control_videos` has unknown hint(s) {unknown}; expected keys from {supported_hints}." + ) + if any(v is None for v in control_videos.values()): + raise ValueError("`control_videos` entries must be loaded videos, not None.") + if action is not None: + raise ValueError("Transfer (`control_videos`) cannot be combined with `action`.") + if image is not None: + raise ValueError("Transfer (`control_videos`) cannot be combined with `image`.") + if enable_sound: + raise ValueError( + "Transfer (`control_videos`) is video-only and cannot be combined with `enable_sound`." + ) + if num_first_chunk_conditional_frames > 0 and video is None: + raise ValueError( + "`num_first_chunk_conditional_frames` > 0 requires a `video` for first-chunk conditioning." + ) + if num_frames is not None and num_frames < 1: + raise ValueError(f"`num_frames` must be >= 1, got {num_frames}.") + return if negative_prompt is not None and not isinstance(negative_prompt, (str, list)): raise ValueError( f"`negative_prompt` must be a str, list of str, or None, got {type(negative_prompt).__name__}." @@ -1085,6 +1167,7 @@ def tokenize_prompt( add_duration_template: bool = True, action_mode: str | None = None, action_view_point: str | None = None, + transfer_mode: bool = False, ) -> tuple[list[int], list[int]]: """Apply prompt-augmentation templates and tokenize cond/uncond prompts via the Qwen2 chat template. @@ -1099,6 +1182,9 @@ def tokenize_prompt( was trained on (see :meth:`_build_action_json_prompt`), using ``action_view_point`` for the framing field; the flat metadata templates are skipped because the JSON already carries duration/fps/resolution/aspect_ratio. + When ``transfer_mode`` is set, the transfer system prompt is used and the prompt / negative prompt are passed + through verbatim (they are pre-upsampled JSON captions), again skipping the flat metadata templates. + Returns: ``(cond_input_ids, uncond_input_ids)`` — token-id lists for this sample. """ @@ -1128,7 +1214,10 @@ def _apply_templates(text: str, is_negative: bool = False) -> str: def _tokenize(text: str) -> BatchEncoding: conversations = [] if use_system_prompt: - system_prompt = _SYSTEM_PROMPT_IMAGE if is_image else _SYSTEM_PROMPT_VIDEO + if transfer_mode: + system_prompt = _SYSTEM_PROMPT_TRANSFER + else: + system_prompt = _SYSTEM_PROMPT_IMAGE if is_image else _SYSTEM_PROMPT_VIDEO conversations.append({"role": "system", "content": system_prompt}) conversations.append({"role": "user", "content": text}) return self.text_tokenizer.apply_chat_template( @@ -1150,6 +1239,11 @@ def _add_special_tokens(input_ids: list[int]) -> list[int]: prompt, view_point=action_view_point, num_frames=num_frames, fps=fps, height=height, width=width ) uncond_text = negative_prompt + elif transfer_mode: + # Transfer prompts are pre-upsampled JSON captions that already carry duration/fps/resolution; pass them + # through verbatim (the metadata templates would corrupt the JSON), mirroring the action-mode branch. + cond_text = prompt + uncond_text = negative_prompt else: cond_text = _apply_templates(prompt) uncond_text = _apply_templates(negative_prompt, is_negative=True) @@ -1234,6 +1328,286 @@ def _apply_video_safety_check(self, video: Any, output_type: str, device: torch. # output_type == "pt" return torch.from_numpy(checked.astype(np.float32) / 255.0).permute(0, 3, 1, 2) + @torch.no_grad() + def _generate_transfer( + self, + *, + prompt: str, + negative_prompt: str | None, + control_videos: dict[str, Any], + video: Any, + num_frames: int | None, + height: int, + width: int, + fps: float, + num_inference_steps: int, + guidance_scale: float, + control_guidance: float, + control_guidance_interval: tuple[float, float] | None, + guidance_interval: tuple[float, float] | None, + num_conditional_frames: int, + num_first_chunk_conditional_frames: int, + num_video_frames_per_chunk: int | None, + share_vision_temporal_positions: bool, + generator: torch.Generator | None, + output_type: str, + return_dict: bool, + use_system_prompt: bool, + enable_safety_check: bool, + device: torch.device, + dtype: torch.dtype, + ) -> "Cosmos3OmniPipelineOutput | tuple": + """Run video transfer: generate a target clip that follows one or more precomputed control hints. + + Control maps are packed as clean (fully conditioned) vision items before the noisy target — sequence layout + ``[ctrl_1, ..., ctrl_N, target]`` — and guidance uses a nested control/text classifier-free-guidance blend + (see the per-step branch selection below). Long clips are produced autoregressively chunk-by-chunk, with each + chunk conditioned on the tail of the previous one, then stitched back together. + """ + if output_type == "latent": + raise ValueError( + "Transfer decodes and stitches chunks in pixel space; `output_type='latent'` is unsupported." + ) + + tcf = int(self.vae.config.scale_factor_temporal) + sf = int(self.vae.config.scale_factor_spatial) + if height % sf != 0 or width % sf != 0: + raise ValueError(f"`height` and `width` must be multiples of {sf}, got ({height}, {width}).") + + def _pad_temporal(frames: torch.Tensor, target_t: int) -> torch.Tensor: + # frames: [1, 3, T, H, W]. Reflect-pad along time up to target_t, falling back to repeating the last + # frame, mirroring the native Cosmos Framework `pad_temporal_frames`. No truncation (callers slice first). + if frames.shape[2] >= target_t: + return frames + while frames.shape[2] < target_t: + pad_len = min(frames.shape[2] - 1, target_t - frames.shape[2]) + if pad_len <= 0: + pad_frame = frames[:, :, -1:].repeat(1, 1, target_t - frames.shape[2], 1, 1) + frames = torch.cat([frames, pad_frame], dim=2) + break + frames = torch.cat([frames, frames.flip(dims=[2])[:, :, :pad_len]], dim=2) + return frames + + def _decode_to_pixel(latent: torch.Tensor) -> torch.Tensor: + vae_dtype = self.vae.dtype + mean = self._vae_latents_mean.to(device=latent.device, dtype=vae_dtype) + inv_std = self._vae_latents_inv_std.to(device=latent.device, dtype=vae_dtype) + z_raw = latent.to(vae_dtype) / inv_std.view(1, -1, 1, 1, 1) + mean.view(1, -1, 1, 1, 1) + return self.vae.decode(z_raw).sample.to(torch.float32).clamp(-1, 1) + + def _active_at(t: torch.Tensor, interval: tuple[float, float] | None) -> bool: + if interval is None: + return True + lo, hi = float(interval[0]), float(interval[1]) + return lo <= float(t.item()) <= hi + + # Canonical hint order, then preprocess every control map to [1, 3, T, H, W] in [-1, 1] at the target geometry. + hint_keys = [k for k in ["edge", "blur", "depth", "seg", "wsm"] if k in control_videos] + control_frames = { + key: self.video_processor.preprocess_video(control_videos[key], height=height, width=width).to( + device=device, dtype=dtype + ) + for key in hint_keys + } + input_frames = None + if video is not None: + input_frames = self.video_processor.preprocess_video(video, height=height, width=width).to( + device=device, dtype=dtype + ) + + # Output frame count / chunking come from the (first) control video, optionally capped by num_frames. + total_frames = next(iter(control_frames.values())).shape[2] + if num_frames is not None: + total_frames = min(total_frames, num_frames) + total_frames = max(1, total_frames) + + per_chunk = num_video_frames_per_chunk if num_video_frames_per_chunk is not None else total_frames + chunk_frames = 1 if total_frames == 1 else per_chunk + chunk_frames = math.ceil((chunk_frames - 1) / tcf) * tcf + 1 + + if total_frames <= chunk_frames: + num_chunks, stride = 1, chunk_frames + else: + stride = chunk_frames - num_conditional_frames + if stride <= 0: + raise ValueError("`num_conditional_frames` must be smaller than `num_video_frames_per_chunk`.") + remaining = total_frames - chunk_frames + num_chunks = 1 + (remaining // stride + (1 if remaining % stride else 0)) + + padded = max(total_frames, chunk_frames) + control_frames = {key: _pad_temporal(frames, padded) for key, frames in control_frames.items()} + if input_frames is not None: + input_frames = _pad_temporal(input_frames, padded) + + # Text packing is invariant across chunks and denoising steps; build it once. Transfer prompts are passed + # through verbatim (pre-upsampled JSON) under the transfer system prompt. + cond_input_ids, uncond_input_ids = self.tokenize_prompt( + prompt, + negative_prompt, + num_frames=chunk_frames, + height=height, + width=width, + fps=fps, + use_system_prompt=use_system_prompt, + transfer_mode=True, + ) + cond_text_segment = self._prepare_text_segment(cond_input_ids, device=device) + uncond_text_segment = self._prepare_text_segment(uncond_input_ids, device=device) + num_hints = len(hint_keys) + + output_chunks: list[torch.Tensor] = [] + previous_output: torch.Tensor | None = None + + for chunk_id in range(num_chunks): + start_frame = chunk_id * stride + end_frame = min(start_frame + chunk_frames, total_frames) + chunk_controls = [ + _pad_temporal(control_frames[key][:, :, start_frame:end_frame], chunk_frames) for key in hint_keys + ] + + # Seed the target with conditioning frames (first chunk from the input video, later chunks from the + # previous chunk's tail), repeat-padding the remaining frames so the whole clip is well-defined. + target = torch.zeros(1, 3, chunk_frames, height, width, device=device, dtype=dtype) + current_conditional_frames = 0 + if chunk_id == 0 and num_first_chunk_conditional_frames > 0 and input_frames is not None: + current_conditional_frames = min( + num_first_chunk_conditional_frames, input_frames.shape[2], chunk_frames + ) + if current_conditional_frames > 0: + target[:, :, :current_conditional_frames] = input_frames[:, :, :current_conditional_frames] + elif chunk_id > 0 and previous_output is not None: + current_conditional_frames = min(num_conditional_frames, previous_output.shape[2], chunk_frames) + if current_conditional_frames > 0: + target[:, :, :current_conditional_frames] = previous_output[:, :, -current_conditional_frames:].to( + device=device, dtype=dtype + ) + if 0 < current_conditional_frames < chunk_frames: + fill = target[:, :, current_conditional_frames - 1 : current_conditional_frames] + target[:, :, current_conditional_frames:] = fill.expand( + -1, -1, chunk_frames - current_conditional_frames, -1, -1 + ) + + # Encode controls as clean latents and build the noisy target latents + conditioning mask. + control_latents = [self._encode_video(ctrl).contiguous().float() for ctrl in chunk_controls] + target_x0 = self._encode_video(target).contiguous().float() + latent_t = target_x0.shape[2] + condition_mask = torch.zeros((latent_t, 1, 1), device=device, dtype=dtype) + latent_condition_frames = 0 + if current_conditional_frames > 0: + latent_condition_frames = (current_conditional_frames - 1) // tcf + 1 + condition_mask[:latent_condition_frames] = 1.0 + noise = randn_tensor(tuple(target_x0.shape), generator=generator, device=device, dtype=dtype) + latents = condition_mask * target_x0 + (1.0 - condition_mask) * noise + velocity_mask = 1.0 - condition_mask + condition_latents = condition_mask * target_x0 + + target_condition_indexes = list(range(latent_condition_frames)) + + # Pre-pack the three CFG sequence variants. cond_full / uncond_full carry every control item; the + # no-control branch drops them (only [text, target]) so the control axis can be amplified. + def _vision_pack(text_segment: dict[str, Any], include_controls: bool) -> dict[str, Any]: + if include_controls: + vision_items = [*control_latents, latents] + condition_indexes = [None] * num_hints + [target_condition_indexes] + clean_flags = [True] * num_hints + [False] + else: + vision_items = [latents] + condition_indexes = [target_condition_indexes] + clean_flags = [False] + vision_segment = self._prepare_vision_segment( + input_vision_tokens=vision_items, + has_image_condition=False, + mrope_offset=text_segment["vision_start_temporal_offset"], + vision_fps=fps, + curr=text_segment["und_len"], + device=device, + condition_frame_indexes=condition_indexes, + clean_item_flags=clean_flags, + share_vision_temporal_positions=share_vision_temporal_positions, + ) + return { + **text_segment, + **vision_segment, + "position_ids": torch.cat( + [text_segment["text_mrope_ids"], vision_segment["vision_mrope_ids"]], dim=1 + ), + "sequence_length": text_segment["und_len"] + vision_segment["num_vision_tokens"], + } + + cond_full_static = _vision_pack(cond_text_segment, include_controls=True) + cond_no_control_static = _vision_pack(cond_text_segment, include_controls=False) + uncond_full_static = _vision_pack(uncond_text_segment, include_controls=True) + num_noisy_vision_tokens = cond_full_static["num_noisy_vision_tokens"] + + def _run(static: dict[str, Any], vision_tokens: list[torch.Tensor], vision_timesteps: torch.Tensor): + preds_vision, _, _ = self.transformer( + input_ids=static["input_ids"], + text_indexes=static["text_indexes"], + position_ids=static["position_ids"], + und_len=static["und_len"], + sequence_length=static["sequence_length"], + vision_tokens=vision_tokens, + vision_token_shapes=static["vision_token_shapes"], + vision_sequence_indexes=static["vision_sequence_indexes"], + vision_mse_loss_indexes=static["vision_mse_loss_indexes"], + vision_timesteps=vision_timesteps, + vision_noisy_frame_indexes=static["vision_noisy_frame_indexes"], + ) + # The target is the last vision item; control items return zeros (no MSE positions). + return preds_vision[-1] + + self.scheduler.set_timesteps(num_inference_steps, device=device) + for t in self.progress_bar(self.scheduler.timesteps): + self._current_timestep = t + timestep = t.item() + vision_tokens_full = [c.to(device=device, dtype=dtype) for c in control_latents] + [ + latents.to(device=device, dtype=dtype) + ] + vision_tokens_target = [latents.to(device=device, dtype=dtype)] + vision_timesteps = torch.full((num_noisy_vision_tokens,), timestep, device=device) + + step_guidance = guidance_scale if _active_at(t, guidance_interval) else 1.0 + step_control = control_guidance if _active_at(t, control_guidance_interval) else 1.0 + needs_text_cfg = step_guidance > 1.0 + needs_control_cfg = step_control != 1.0 + + cond_full = _run(cond_full_static, vision_tokens_full, vision_timesteps) + if needs_control_cfg and needs_text_cfg: + cond_no_control = _run(cond_no_control_static, vision_tokens_target, vision_timesteps) + uncond_full = _run(uncond_full_static, vision_tokens_full, vision_timesteps) + control_cond = cond_no_control + step_control * (cond_full - cond_no_control) + velocity = uncond_full + step_guidance * (control_cond - uncond_full) + elif needs_control_cfg: + cond_no_control = _run(cond_no_control_static, vision_tokens_target, vision_timesteps) + velocity = cond_no_control + step_control * (cond_full - cond_no_control) + elif needs_text_cfg: + uncond_full = _run(uncond_full_static, vision_tokens_full, vision_timesteps) + velocity = uncond_full + step_guidance * (cond_full - uncond_full) + else: + velocity = cond_full + + velocity = velocity * velocity_mask + latents = self.scheduler.step(velocity.unsqueeze(0), t, latents.unsqueeze(0), return_dict=False)[ + 0 + ].squeeze(0) + latents = velocity_mask * latents + (1.0 - velocity_mask) * condition_latents + + output_video = _decode_to_pixel(latents) + previous_output = output_video + # Chunks after the first overlap the previous chunk by the conditioning frames; drop them when stitching. + output_chunks.append(output_video if chunk_id == 0 else output_video[:, :, current_conditional_frames:]) + + self._current_timestep = None + decoded = torch.cat(output_chunks, dim=2)[:, :, :total_frames] + video_out = self.video_processor.postprocess_video(decoded, output_type=output_type)[0] + if enable_safety_check and isinstance(self.safety_checker, CosmosSafetyChecker): + video_out = self._apply_video_safety_check(video_out, output_type=output_type, device=device) + + self.maybe_free_model_hooks() + if not return_dict: + return (video_out, None) + return Cosmos3OmniPipelineOutput(video=video_out, sound=None, action=None) + @property def current_timestep(self): return self._current_timestep @@ -1267,6 +1641,14 @@ def __call__( sound_latents: torch.Tensor | None = None, action_latents: torch.Tensor | None = None, action: CosmosActionCondition | None = None, + control_videos: dict[str, list[Image.Image] | torch.Tensor | np.ndarray] | None = None, + control_guidance: float = 1.0, + control_guidance_interval: tuple[float, float] | None = None, + guidance_interval: tuple[float, float] | None = None, + num_conditional_frames: int = 1, + num_first_chunk_conditional_frames: int = 0, + num_video_frames_per_chunk: int | None = None, + share_vision_temporal_positions: bool = True, output_type: str = "pil", return_dict: bool = True, use_system_prompt: bool = True, @@ -1347,6 +1729,33 @@ def __call__( `action_gen=True`. When set, passing the top-level `image` argument raises; `height` / `width` / `num_frames` must be `None`, since resolution comes from `action.resolution_tier` and frame count from `action.chunk_size`. See [`CosmosActionCondition`]. + control_videos (`dict[str, video]`, *optional*): + Enables video transfer. A mapping from control-hint name (`"edge"`, `"blur"`, `"depth"`, `"seg"`, + `"wsm"`) to a precomputed control video (anything accepted by `video=`). Each control map is resized, + temporally padded, normalized and VAE-encoded into a clean conditioning item; the target clip is then + generated to follow them. Multiple hints can be combined. Transfer is video-only and cannot be combined + with `image`, `video`, `action`, or `enable_sound`. The prompt should be a pre-upsampled JSON caption. + control_guidance (`float`, *optional*, defaults to `1.0`): + Control classifier-free guidance scale for transfer. Values `!= 1.0` amplify the control signal by + blending a "with-control" prediction against a "without-control" prediction (nested with the text + `guidance_scale`). `1.0` disables the control axis (control maps still condition both text branches). + control_guidance_interval (`tuple[float, float]`, *optional*): + Optional `[lo, hi]` timestep window (in scheduler timestep units) outside which control guidance is + skipped. When `None`, control guidance is applied at every step. + guidance_interval (`tuple[float, float]`, *optional*): + Optional `[lo, hi]` timestep window outside which text guidance is skipped (transfer only). + num_conditional_frames (`int`, *optional*, defaults to `1`): + Number of frames carried over from the previous chunk as conditioning at the start of each subsequent + autoregressive chunk (transfer only). + num_first_chunk_conditional_frames (`int`, *optional*, defaults to `0`): + Number of leading frames of `video` used to condition the first transfer chunk. Requires `video` to be + passed alongside `control_videos`; `0` means the first chunk is fully generated. + num_video_frames_per_chunk (`int`, *optional*): + Maximum number of frames generated per autoregressive chunk (transfer only). When `None`, the whole + clip is generated in a single chunk. Longer clips are produced chunk-by-chunk and stitched. + share_vision_temporal_positions (`bool`, *optional*, defaults to `True`): + When `True`, the control maps and the target share the same temporal mRoPE positions (they are + temporally aligned). When `False`, each vision item advances the temporal offset (transfer only). output_type (`str`, *optional*, defaults to `"pil"`): Output format for the video. One of `"pil"` (list of `PIL.Image.Image`), `"np"` (`np.ndarray`, `[T, H, W, C]`), `"pt"` (`torch.Tensor`, `[T, C, H, W]`), or `"latent"` (raw vision latents). @@ -1383,12 +1792,14 @@ def __call__( callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs if action is None: - if num_frames is None: - num_frames = 189 if height is None: height = 720 if width is None: width = 1280 + # For transfer, num_frames defaults to the control video length (resolved in _generate_transfer); for the + # other modes it falls back to the standard ~7.9s clip. + if num_frames is None and control_videos is None: + num_frames = 189 # 1. Check inputs self.check_inputs( @@ -1404,6 +1815,8 @@ def __call__( action, video=video, condition_frame_indexes_vision=condition_frame_indexes_vision, + control_videos=control_videos, + num_first_chunk_conditional_frames=num_first_chunk_conditional_frames, ) # `action_mode` is the only action field consumed directly in __call__ (prompt template + output slicing); @@ -1446,6 +1859,36 @@ def __call__( finally: self.safety_checker.to("cpu") + # Transfer is a distinct mode (autoregressive multi-chunk + nested control/text CFG over multiple vision + # items), so it runs through its own self-contained routine rather than the shared single-clip path below. + if control_videos is not None: + return self._generate_transfer( + prompt=prompt, + negative_prompt=negative_prompt, + control_videos=control_videos, + video=video, + num_frames=num_frames, + height=height, + width=width, + fps=fps, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + control_guidance=control_guidance, + control_guidance_interval=control_guidance_interval, + guidance_interval=guidance_interval, + num_conditional_frames=num_conditional_frames, + num_first_chunk_conditional_frames=num_first_chunk_conditional_frames, + num_video_frames_per_chunk=num_video_frames_per_chunk, + share_vision_temporal_positions=share_vision_temporal_positions, + generator=generator, + output_type=output_type, + return_dict=return_dict, + use_system_prompt=use_system_prompt, + enable_safety_check=enable_safety_check, + device=device, + dtype=dtype, + ) + # 2. Tokenize prompt (applies metadata templates and selects mode-specific default negative prompt) cond_input_ids, uncond_input_ids = self.tokenize_prompt( prompt,