"""export_onnx.py: export VoxCPM submodules to ONNX (LLM excluded)."""
import argparse
import os
import torch
from torch import nn
from voxcpm.model.voxcpm import VoxCPMModel
def remove_weight_norm(module: nn.Module):
"""Strip weight_norm wrappers for cleaner ONNX graphs."""
for name, child in module.named_children():
remove_weight_norm(child)
if isinstance(child, (nn.Conv1d, nn.ConvTranspose1d)):
try:
torch.nn.utils.remove_weight_norm(child)
except ValueError:
# not wrapped, skip
pass
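# Optional sanity check after stripping (a sketch, not wired into the script):
# weight_norm splits a weight into "weight_g" / "weight_v" parameters, so none
# should remain once the wrappers are removed:
#   leftovers = [n for n, _ in model.named_parameters()
#                if n.endswith(("weight_g", "weight_v"))]
#   assert not leftovers, f"weight_norm still present: {leftovers}"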
class VAEEncodeWrapper(nn.Module):
def __init__(self, audio_vae: nn.Module):
super().__init__()
self.audio_vae = audio_vae
    def forward(self, audio_wave: torch.Tensor):
        # audio_wave: [B, 1, samples]
        return self.audio_vae.encode(audio_wave, self.audio_vae.sample_rate)
class VAEDecodeWrapper(nn.Module):
def __init__(self, audio_vae: nn.Module):
super().__init__()
self.audio_vae = audio_vae
    def forward(self, latent: torch.Tensor):
        # latent: [B, latent_dim, latent_steps]
        return self.audio_vae.decode(latent)
class LocEncWrapper(nn.Module):
    """LocEnc alone, without the LM projection (unused below; kept for reference,
    the fused LocEncLmWrapper is what actually gets exported)."""
def __init__(self, locenc: nn.Module):
super().__init__()
self.locenc = locenc
def forward(self, x: torch.Tensor):
# x: [B, T, P, D]
return self.locenc(x)
class LocEncLmWrapper(nn.Module):
"""LocEnc with enc_to_lm projection fused in a single graph."""
def __init__(self, locenc: nn.Module, proj: nn.Module):
super().__init__()
self.locenc = locenc
self.proj = proj
def forward(self, x: torch.Tensor):
# x: [B, T, P, D]
hidden = self.locenc(x)
return self.proj(hidden)
class FSQWrapper(nn.Module):
    """Thin wrapper so the FSQ layer exports as a standalone graph."""
    def __init__(self, fsq: nn.Module):
        super().__init__()
        self.fsq = fsq
    def forward(self, hidden: torch.Tensor):
        # hidden: [B, T, hidden_size]
        return self.fsq(hidden)
class StopHeadWrapper(nn.Module):
    """Stop-decision head (projection, activation, then classifier) fused into one graph."""
def __init__(self, stop_proj: nn.Linear, stop_actn: nn.Module, stop_head: nn.Linear):
super().__init__()
self.stop_proj = stop_proj
self.stop_actn = stop_actn
self.stop_head = stop_head
def forward(self, hidden: torch.Tensor):
hidden = self.stop_proj(hidden)
hidden = self.stop_actn(hidden)
return self.stop_head(hidden)
class CFMWrapper(nn.Module):
"""
Wrapper for one diffusion step block.
Note: the number of diffusion steps (n_timesteps) is fixed at export time.
"""
def __init__(self, cfm: nn.Module, patch_size: int, n_timesteps: int, cfg_value: float):
super().__init__()
self.cfm = cfm
self.patch_size = patch_size
self.n_timesteps = n_timesteps
self.cfg_value = cfg_value
def forward(self, mu: torch.Tensor, cond: torch.Tensor):
# mu: [B, H_dit], cond: [B, D_feat, P]
return self.cfm(
mu=mu,
n_timesteps=self.n_timesteps,
patch_size=self.patch_size,
cond=cond,
cfg_value=self.cfg_value,
)
class DiTStepWrapper(nn.Module):
"""
Wrapper for a single VoxCPMLocDiT forward (one diffusion score estimation step).
Inputs match VoxCPMLocDiT.forward: x, mu, t, cond, dt.
"""
def __init__(self, dit: nn.Module):
super().__init__()
self.dit = dit
def forward(self, x: torch.Tensor, mu: torch.Tensor, t: torch.Tensor, cond: torch.Tensor, dt: torch.Tensor):
return self.dit(x, mu, t, cond, dt)
def export(model: nn.Module, inputs, path: str, dynamic_axes: dict, opset: int):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # input_names is derived from dynamic_axes, so the dict must name every
    # graph input, in forward() argument order.
    torch.onnx.export(
model,
inputs,
path,
opset_version=opset,
do_constant_folding=True,
input_names=list(dynamic_axes.keys()),
output_names=["output"],
dynamic_axes=dynamic_axes,
)
print(f"Saved: {path}")
def main():
parser = argparse.ArgumentParser(description="Export VoxCPM submodules to ONNX (LLM excluded).")
parser.add_argument("--model-dir", required=True, help="Path to VoxCPM model directory (config/weights).")
parser.add_argument("--out-dir", default="onnx_exports", help="Output directory for ONNX files.")
parser.add_argument("--opset", type=int, default=18, help="ONNX opset version.")
parser.add_argument("--audio-samples", type=int, default=1280, help="Dummy audio length for encoder export.")
parser.add_argument("--latent-steps", type=int, default=6, help="Dummy latent steps for decoder export.")
parser.add_argument("--seq-len", type=int, default=4, help="Dummy sequence length for LocEnc/FSQ export.")
parser.add_argument("--dit-step-t", type=float, default=0.5, help="Dummy diffusion time for DiT step export.")
parser.add_argument("--force-fp32", action="store_true", help="Force submodules to float32 for ONNX export.")
parser.add_argument("--dump-embeddings", action="store_true", help="Dump base_lm.embed_tokens weights to npy.")
args = parser.parse_args()
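    # Typical invocation (paths are placeholders):
    #   python export_onnx.py --model-dir ./VoxCPM-0.5B --out-dir onnx_exports --opset 18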
device = torch.device("cpu")
# Load full model once, then peel submodules; keep optimize disabled.
full_model = VoxCPMModel.from_local(args.model_dir, optimize=False).to(device).eval()
    # ONNX export is most reliable in float32; cast the model (and the audio VAE,
    # which may carry its own dtype) unless everything is already fp32.
    if args.force_fp32 or full_model.config.dtype != "float32":
full_model.config.dtype = "float32"
full_model = full_model.to(torch.float32)
full_model.audio_vae = full_model.audio_vae.to(torch.float32)
remove_weight_norm(full_model)
# Audio VAE encode
vae_enc = VAEEncodeWrapper(full_model.audio_vae).to(device).eval()
dummy_audio = torch.randn(1, 1, args.audio_samples, device=device)
export(
vae_enc,
dummy_audio,
os.path.join(args.out_dir, "audio_vae_encode.onnx"),
dynamic_axes={"audio_wave": {0: "batch", 2: "samples"}},
opset=args.opset,
)
# Audio VAE decode
vae_dec = VAEDecodeWrapper(full_model.audio_vae).to(device).eval()
dummy_latent = torch.randn(1, full_model.audio_vae.latent_dim, args.latent_steps, device=device)
export(
vae_dec,
dummy_latent,
os.path.join(args.out_dir, "audio_vae_decode.onnx"),
dynamic_axes={"latent": {0: "batch", 2: "latent_steps"}},
opset=args.opset,
)
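    # Round-trip sketch for the two VAE graphs (assumes onnxruntime; also assumes
    # the encoder's first output is the latent the decoder consumes, which should
    # be verified against the torch modules):
    #   import onnxruntime as ort
    #   enc = ort.InferenceSession(os.path.join(args.out_dir, "audio_vae_encode.onnx"))
    #   dec = ort.InferenceSession(os.path.join(args.out_dir, "audio_vae_decode.onnx"))
    #   z = enc.run(None, {"audio_wave": dummy_audio.numpy()})[0]
    #   wav = dec.run(None, {"latent": z})[0]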
# LocEnc with enc_to_lm projection fused
locenc = LocEncLmWrapper(full_model.feat_encoder, full_model.enc_to_lm_proj).to(device).eval()
dummy_seq = torch.randn(1, args.seq_len, full_model.patch_size, full_model.feat_dim, device=device)
export(
locenc,
dummy_seq,
os.path.join(args.out_dir, "locenc.onnx"),
dynamic_axes={"x": {0: "batch", 1: "seq_len"}},
opset=args.opset,
)
# FSQ layer
fsq = FSQWrapper(full_model.fsq_layer).to(device).eval()
hidden_size = full_model.config.lm_config.hidden_size
dummy_hidden = torch.randn(1, args.seq_len, hidden_size, device=device)
export(
fsq,
dummy_hidden,
os.path.join(args.out_dir, "fsq_layer.onnx"),
dynamic_axes={"hidden": {0: "batch", 1: "seq_len"}},
opset=args.opset,
)
# Stop head
stop = StopHeadWrapper(full_model.stop_proj, full_model.stop_actn, full_model.stop_head).to(device).eval()
dummy_stop_inp = torch.randn(1, hidden_size, device=device)
export(
stop,
dummy_stop_inp,
os.path.join(args.out_dir, "stop_head.onnx"),
dynamic_axes={"hidden": {0: "batch"}},
opset=args.opset,
)
    # Projection layers. enc_to_lm_proj is already fused into locenc.onnx above,
    # so its standalone export below stays commented out.
# export(
# full_model.enc_to_lm_proj,
# dummy_hidden,
# os.path.join(args.out_dir, "enc_to_lm_proj.onnx"),
# dynamic_axes={"input": {0: "batch", 1: "seq_len"}},
# opset=args.opset,
# )
    lm_hidden = torch.randn(1, hidden_size, device=device)
export(
full_model.lm_to_dit_proj,
lm_hidden,
os.path.join(args.out_dir, "lm_to_dit_proj.onnx"),
dynamic_axes={"input": {0: "batch"}},
opset=args.opset,
)
export(
full_model.res_to_dit_proj,
lm_hidden,
os.path.join(args.out_dir, "res_to_dit_proj.onnx"),
dynamic_axes={"input": {0: "batch"}},
opset=args.opset,
)
# VoxCPMLocDiT single step (score function)
dit_step = DiTStepWrapper(full_model.feat_decoder.estimator).to(device).eval()
dummy_x = torch.randn(1, full_model.feat_dim, full_model.patch_size, device=device)
dummy_mu = torch.randn(1, full_model.config.dit_config.hidden_dim, device=device)
dummy_t = torch.full((1,), args.dit_step_t, device=device)
dummy_dt = torch.full((1,), 0.0, device=device)
dummy_cond = torch.randn(1, full_model.feat_dim, full_model.patch_size, device=device)
export(
dit_step,
(dummy_x, dummy_mu, dummy_t, dummy_cond, dummy_dt),
os.path.join(args.out_dir, "dit_step.onnx"),
dynamic_axes={
"x": {0: "batch"},
"mu": {0: "batch"},
"t": {0: "batch"},
"cond": {0: "batch"},
"dt": {0: "batch"},
},
opset=args.opset,
)
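    # dit_step.onnx is a single score evaluation; the sampling loop is expected
    # to run on the host. A minimal Euler sketch (uniform schedule assumed here;
    # the real UnifiedCFM schedule and CFG handling may differ, and feat_dim,
    # patch_size, mu, cond, n_steps are placeholders):
    #   import numpy as np, onnxruntime as ort
    #   sess = ort.InferenceSession(os.path.join(args.out_dir, "dit_step.onnx"))
    #   x = np.random.randn(1, feat_dim, patch_size).astype(np.float32)
    #   ts = np.linspace(0.0, 1.0, n_steps + 1, dtype=np.float32)
    #   for t0, t1 in zip(ts[:-1], ts[1:]):
    #       feed = {"x": x, "mu": mu, "t": np.full((1,), t0, np.float32),
    #               "cond": cond, "dt": np.full((1,), t1 - t0, np.float32)}
    #       x = x + (t1 - t0) * sess.run(None, feed)[0]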
    # UnifiedCFM + VoxCPMLocDiT (sampler unrolled with a fixed n_timesteps).
    # Left disabled: it bakes the step count into the graph and would need
    # --cfm-steps / --cfg-value CLI flags, which are not defined above.
# cfm = CFMWrapper(
# full_model.feat_decoder,
# patch_size=full_model.patch_size,
# n_timesteps=args.cfm_steps,
# cfg_value=args.cfg_value,
# ).to(device).eval()
# dummy_mu = torch.randn(1, full_model.config.dit_config.hidden_dim, device=device)
# dummy_cond = torch.randn(1, full_model.feat_dim, full_model.patch_size, device=device)
# export(
# cfm,
# (dummy_mu, dummy_cond),
# os.path.join(args.out_dir, "cfm_step.onnx"),
# dynamic_axes={"mu": {0: "batch"}, "cond": {0: "batch"}},
# opset=args.opset,
# )
if args.dump_embeddings and hasattr(full_model.base_lm, "embed_tokens"):
        import numpy as np  # local import keeps numpy an optional dependency
emb = full_model.base_lm.embed_tokens.weight.detach().cpu().numpy()
os.makedirs(args.out_dir, exist_ok=True)
np.save(os.path.join(args.out_dir, "embed_tokens.npy"), emb)
print(f"Saved: {os.path.join(args.out_dir, 'embed_tokens.npy')}")
print("Done.")
if __name__ == "__main__":
main()