[feat]: support wan2.2 s2v, support dist infer, pose-audio #1113
Code Review
This pull request introduces support for the Wan2.2 S2V (Speech-to-Video) model, adding configuration files, a dedicated runner, networks, and utilities for audio encoding, audio injection, and frame packing. The review feedback highlights several key improvements: fixing an AttributeError in weight synchronization within casual_audio.py, replacing a float comparison with an integer frame count check in audio_encoder.py to avoid precision issues, dynamically determining latent channels in framepack.py, ensuring multi-device portability in pre_infer.py by avoiding hardcoded device strings, removing an unused variable in transformer_infer.py, and correcting a type annotation syntax error in wan_causal_audio_module.py.
dst.final_linear.weight.data.copy_(src.final_linear.weight.t())
dst.final_linear.bias.data.copy_(src.final_linear.bias)
src.final_linear is registered as an MMWeight module, which wraps the underlying PyTorch tensors in WeightTensor objects. Accessing .weight and .bias directly on src.final_linear returns these wrapper objects rather than raw PyTorch tensors, which will cause an AttributeError when calling .t() or copying data. You should access .tensor on them, matching how other weights are copied in this function.
- dst.final_linear.weight.data.copy_(src.final_linear.weight.t())
- dst.final_linear.bias.data.copy_(src.final_linear.bias)
+ dst.final_linear.weight.data.copy_(src.final_linear.weight.tensor.t())
+ dst.final_linear.bias.data.copy_(src.final_linear.bias.tensor)
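To illustrate why the unwrapping matters, here is a minimal sketch that uses stand-in classes instead of the real ones. `FakeTensor` and this `WeightTensor` are hypothetical placeholders written for the example. The sketch assumes the real wrapper stores the raw tensor in `.tensor` and does not forward tensor methods such as `.t()`, which is what the comment above describes.

```python
class FakeTensor:
    """Stand-in for a torch.Tensor; only models .t() on a 2-D list of rows."""

    def __init__(self, rows):
        self.rows = rows

    def t(self):
        # Transpose rows and columns, like Tensor.t() does for a 2-D tensor.
        return FakeTensor([list(col) for col in zip(*self.rows)])


class WeightTensor:
    """Hypothetical wrapper: holds the raw tensor in .tensor, forwards nothing."""

    def __init__(self, tensor):
        self.tensor = tensor


wrapped = WeightTensor(FakeTensor([[1, 2, 3], [4, 5, 6]]))

# Calling a tensor method on the wrapper fails, because the wrapper has no .t().
try:
    wrapped.t()
    direct_call_failed = False
except AttributeError:
    direct_call_failed = True

# Unwrapping first reaches the raw tensor, so the transpose works.
transposed = wrapped.tensor.t().rows
```

Under these assumptions only the unwrapped access succeeds, which matches the `.tensor` pattern used in the suggestion.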
if required_duration > total_frames / original_fps:
    raise ValueError("required_duration must be less than video length")
Comparing float durations (required_duration > total_frames / original_fps) can lead to precision issues (e.g., when they are mathematically equal but float representation makes one slightly larger), causing unexpected ValueError exceptions. Comparing the integer frame counts (required_origin_frames > total_frames) is much more robust.
- if required_duration > total_frames / original_fps:
-     raise ValueError("required_duration must be less than video length")
+ if required_origin_frames > total_frames:
+     raise ValueError("required_duration must be less than video length")
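A small sketch of the failure mode, using values chosen for illustration rather than taken from the PR. It assumes `required_origin_frames` is obtained by rounding `required_duration * original_fps`; how the actual code derives it is not shown here, and the two helper functions are hypothetical.

```python
def exceeds_by_duration(required_duration, total_frames, original_fps):
    """Original-style check: compare two float durations in seconds."""
    return required_duration > total_frames / original_fps


def exceeds_by_frames(required_origin_frames, total_frames):
    """Suggested-style check: compare integer frame counts."""
    return required_origin_frames > total_frames


original_fps = 10
total_frames = 3  # a clip that is exactly 0.3 s long

# A duration that is mathematically 0.3 s but was accumulated in floats.
required_duration = 0.1 + 0.2  # stored as 0.30000000000000004

# Assumption: the frame count is derived by rounding to the nearest frame.
required_origin_frames = round(required_duration * original_fps)

float_check_rejects = exceeds_by_duration(required_duration, total_frames, original_fps)
frame_check_rejects = exceeds_by_frames(required_origin_frames, total_frames)
```

Here the float comparison rejects a request that fits the clip exactly, while the frame-count comparison accepts it.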
for m in motion_latents:
    lat_height, lat_width = m.shape[2], m.shape[3]
    padd_lat = torch.zeros(16, zip_frame_buckets.sum(), lat_height, lat_width, device=m.device, dtype=m.dtype)
The number of latent channels is hardcoded to 16. To make the frame packing logic robust to different VAE architectures or configurations, it is better to dynamically use m.shape[0] instead of a magic number.
- padd_lat = torch.zeros(16, zip_frame_buckets.sum(), lat_height, lat_width, device=m.device, dtype=m.dtype)
+ padd_lat = torch.zeros(m.shape[0], zip_frame_buckets.sum(), lat_height, lat_width, device=m.device, dtype=m.dtype)
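The same idea as a dependency-free sketch that works on shape tuples only. `padded_shape` is a hypothetical helper, not part of the PR. It assumes each motion latent is laid out as (channels, frames, height, width), which is consistent with the snippet reading height and width from `m.shape[2]` and `m.shape[3]`.

```python
def padded_shape(latent_shape, total_bucket_frames):
    """Return the shape of the zero-padding buffer for one motion latent.

    The channel count is read from the latent itself instead of being fixed
    at 16, so a VAE with a different latent width still works.
    """
    channels, _frames, lat_height, lat_width = latent_shape
    return (channels, total_bucket_frames, lat_height, lat_width)


# A 16-channel latent and a hypothetical 48-channel latent, padded to 19 frames.
shape_16 = padded_shape((16, 5, 60, 104), 19)
shape_48 = padded_shape((48, 5, 30, 52), 19)
```

With the hardcoded 16, the second case would allocate a buffer whose channel dimension does not match the latent.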
from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer
from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)
class MotionEncoder_tc(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, need_global=True, dtype=None, device=None):
In the method signature, num_heads=int uses the assignment operator = instead of a colon : for the type annotation. This makes the default value of num_heads the int class itself, which will cause runtime errors (e.g., in rearrange) if the argument is ever omitted. It should be written as num_heads: int.
-     def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, need_global=True, dtype=None, device=None):
+     def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, need_global=True, dtype=None, device=None):
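A quick standalone check of what `num_heads=int` actually does. The functions `buggy` and `fixed` are hypothetical stand-ins for the constructor, and the integer division is only an illustrative use of `num_heads`, not the module's real computation.

```python
import inspect


def buggy(in_dim: int, hidden_dim: int, num_heads=int):
    # `=int` sets the DEFAULT VALUE to the int class; it is not an annotation.
    return hidden_dim // num_heads


def fixed(in_dim: int, hidden_dim: int, num_heads: int):
    # `: int` is an annotation; the parameter is now required.
    return hidden_dim // num_heads


buggy_default = inspect.signature(buggy).parameters["num_heads"].default
fixed_default = inspect.signature(fixed).parameters["num_heads"].default

# Omitting num_heads silently passes the int class into the arithmetic.
try:
    buggy(8, 64)
    failure = None
except TypeError as exc:
    failure = type(exc).__name__
```

With the annotation form, omitting `num_heads` fails immediately at the call site with a missing-argument error instead of surfacing later inside the module.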
24d4f7e to cd75cec
No description provided.