LTX-2 to Wan 2.1 Latent Adapter
This model translates latents from the LTX-2 (LTX-Video) latent space directly into the Wan 2.1 latent space. You can generate a video with LTX-2 and then decode the resulting latents with the high-quality Wan 2.1 VAE, without leaving latent space or performing a full pixel-space round trip (VAE decode followed by re-encode). Note that this adapter is somewhat outdated: the LTX-2.3 release appears to address the underlying issue. Translating latents with this model gives worse results than re-encoding through pixel space, but it is much faster. Possible applications include decoding LTX-2 latents directly, or using Wan 2.1 as a post-processing stage for LTX-2 videos.
Architecture Overview
- Architecture: LatentAdapter
- Parameters: ~112M
- Features: Causal 3D convolutions (no future leakage) and temporal upsampling.
- Input: LTX-2 latents, shape (B, 128, T_ltx, H_ltx, W_ltx)
- Output: Wan 2.1 latents, shape (B, 16, T_wan, H_wan, W_wan)
The model automatically handles the spatial (LTX 32x downscale -> Wan 8x downscale) and temporal (LTX 8x -> Wan 4x) transformation differences between the two VAEs.
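The shape mapping above can be sketched as a small helper (pure arithmetic, no model required; the function name is illustrative and not part of this repository):

```python
def ltx_to_wan_shape(t_ltx: int, h_ltx: int, w_ltx: int) -> tuple:
    """Map an LTX-2 latent grid (T, H, W) to the matching Wan 2.1 latent grid.

    Temporal: both VAEs keep the first frame separate, so the remaining
    (t - 1) frames rescale by 8/4 = 2x. Spatial: 32x vs 8x downscale -> 4x.
    """
    return ((t_ltx - 1) * 2 + 1, h_ltx * 4, w_ltx * 4)


# A 25-frame video at 480x704 encodes to an LTX-2 grid of (4, 15, 22):
print(ltx_to_wan_shape(4, 15, 22))  # -> (7, 60, 88)
```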
Example Video Translation
Below is a demonstration of the adapter in action:
- Left: Real Video
- Center: Standard Pixel-Space Translation (Encode LTX -> Decode LTX -> Encode Wan -> Decode Wan)
- Right: Native Latent Adapter (Decode directly from Adapter Latents)
Standalone Inference Example (Non-ComfyUI, Not tested)
This example demonstrates how to load the adapter and translate a tensor of LTX-2 latents into Wan 2.1 latents directly in PyTorch, independent of the ComfyUI framework.
You will need the adapter_model.py file included in this repository to define the neural network architecture.
```python
import torch

# Ensure adapter_model.py is in the same folder or on your Python path
from adapter_model import LatentAdapter


def run_adapter():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1. Initialize and load the model
    print("Loading Latent Adapter...")
    adapter = LatentAdapter()

    # Load weights (ensure the path points to your downloaded .pt file)
    ckpt_path = "latent_adapter_final.pt"
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)

    # Unwrap checkpoints saved under a "model" key (common for our training saves)
    if "model" in state_dict:
        state_dict = state_dict["model"]

    adapter.load_state_dict(state_dict)
    adapter.to(device)
    adapter.eval()

    # 2. Prepare dummy LTX-2 latents
    # LTX-2 shape: (Batch, Channels=128, Frames, Height, Width)
    # E.g., for a 25-frame video at 480x704:
    #   Temporal dimension: (25 - 1) / 8 + 1 = 4
    #   Spatial dimensions: 480 / 32 = 15 | 704 / 32 = 22
    b, c, t_ltx, h_ltx, w_ltx = 1, 128, 4, 15, 22
    z_ltx = torch.randn(b, c, t_ltx, h_ltx, w_ltx, device=device)
    print(f"Input LTX-2 Latent Shape: {z_ltx.shape}")

    # 3. Calculate the exact Wan 2.1 target shape
    # Temporal: LTX is 8x downscaled, Wan is 4x -> (t_ltx - 1) * 2 + 1
    # Spatial: LTX is 32x downscaled, Wan is 8x -> LTX spatial * 4
    t_wan = (t_ltx - 1) * 2 + 1  # 4 -> 7
    h_wan = h_ltx * 4            # 15 -> 60
    w_wan = w_ltx * 4            # 22 -> 88
    target_shape = (t_wan, h_wan, w_wan)

    # 4. Run the adapter
    with torch.no_grad():
        z_wan = adapter(z_ltx, target_shape=target_shape)

    print(f"Output Wan 2.1 Latent Shape: {z_wan.shape}")
    # Output should be (1, 16, 7, 60, 88)
    return z_wan


if __name__ == "__main__":
    run_adapter()
```
Next Steps (Decoding to Video)
Once you have the z_wan latents from the above script, you can feed them directly into the Wan 2.1 VAE decoder to obtain the final pixel-space video frames.
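A minimal decoding sketch, assuming the adapter's output lives in the Wan VAE's native latent space and that you use the `AutoencoderKLWan` class from diffusers; the checkpoint id and the function names below are assumptions, so substitute your own Wan 2.1 VAE weights. Depending on your pipeline, you may also need to undo diffusers' latent normalization (`vae.config.latents_mean` / `latents_std`) before decoding:

```python
def wan_pixel_shape(t_wan: int, h_wan: int, w_wan: int) -> tuple:
    """Pixel-space (frames, height, width) that a Wan 2.1 latent decodes to.

    The Wan 2.1 VAE downsamples 4x temporally ((frames - 1) / 4 + 1)
    and 8x spatially, mirroring the factors used for the target shape.
    """
    return ((t_wan - 1) * 4 + 1, h_wan * 8, w_wan * 8)


def decode_with_wan_vae(z_wan):
    """Decode (B, 16, T, H, W) Wan latents to pixels (hypothetical helper)."""
    # Imports kept local so the shape helper above stays dependency-free.
    import torch
    from diffusers import AutoencoderKLWan  # assumed: recent diffusers release

    # The repo id is illustrative -- point it at your local Wan 2.1 VAE weights.
    vae = AutoencoderKLWan.from_pretrained(
        "Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae"
    ).to(z_wan.device)
    with torch.no_grad():
        return vae.decode(z_wan).sample  # (B, 3, frames, height, width)


# The (1, 16, 7, 60, 88) latent from the script above corresponds to
# 25 frames at 480x704 pixels:
print(wan_pixel_shape(7, 60, 88))  # -> (25, 480, 704)
```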