import argparse
import os

import torch
from torch import nn

from voxcpm.model.voxcpm import VoxCPMModel


def remove_weight_norm(module: nn.Module) -> None:
    """Recursively strip weight_norm wrappers for cleaner ONNX graphs.

    Weight-norm reparametrization adds extra nodes to the exported graph;
    removing it beforehand yields an equivalent but simpler ONNX model.
    Only Conv1d / ConvTranspose1d children are touched.
    """
    for child in module.children():
        remove_weight_norm(child)
        if isinstance(child, (nn.Conv1d, nn.ConvTranspose1d)):
            try:
                torch.nn.utils.remove_weight_norm(child)
            except ValueError:
                # Layer was never wrapped with weight_norm; nothing to strip.
                pass


class VAEEncodeWrapper(nn.Module):
    """Expose ``audio_vae.encode`` as a plain forward for ONNX tracing."""

    def __init__(self, audio_vae: nn.Module):
        super().__init__()
        self.audio_vae = audio_vae

    def forward(self, audio_wave: torch.Tensor):
        return self.audio_vae.encode(audio_wave, self.audio_vae.sample_rate)


class VAEDecodeWrapper(nn.Module):
    """Expose ``audio_vae.decode`` as a plain forward for ONNX tracing."""

    def __init__(self, audio_vae: nn.Module):
        super().__init__()
        self.audio_vae = audio_vae

    def forward(self, latent: torch.Tensor):
        return self.audio_vae.decode(latent)


class LocEncWrapper(nn.Module):
    """Expose the local encoder alone (no LM projection)."""

    def __init__(self, locenc: nn.Module):
        super().__init__()
        self.locenc = locenc

    def forward(self, x: torch.Tensor):
        # x: [B, T, P, D]
        return self.locenc(x)


class LocEncLmWrapper(nn.Module):
    """LocEnc with the enc_to_lm projection fused into a single graph."""

    def __init__(self, locenc: nn.Module, proj: nn.Module):
        super().__init__()
        self.locenc = locenc
        self.proj = proj

    def forward(self, x: torch.Tensor):
        # x: [B, T, P, D]
        hidden = self.locenc(x)
        return self.proj(hidden)


class FSQWrapper(nn.Module):
    """Expose the FSQ quantization layer as a standalone graph."""

    def __init__(self, fsq: nn.Module):
        super().__init__()
        self.fsq = fsq

    def forward(self, hidden: torch.Tensor):
        return self.fsq(hidden)


class StopHeadWrapper(nn.Module):
    """Fuse stop projection, activation, and head into one graph."""

    def __init__(self, stop_proj: nn.Linear, stop_actn: nn.Module, stop_head: nn.Linear):
        super().__init__()
        self.stop_proj = stop_proj
        self.stop_actn = stop_actn
        self.stop_head = stop_head

    def forward(self, hidden: torch.Tensor):
        hidden = self.stop_proj(hidden)
        hidden = self.stop_actn(hidden)
        return self.stop_head(hidden)


class CFMWrapper(nn.Module):
    """Wrapper for one diffusion step block.

    Note: the number of diffusion steps (n_timesteps) is fixed at export time.
    Currently unused by ``main`` (the full-sampler export is omitted); kept as
    a public helper for callers that want to export the unrolled sampler.
    """

    def __init__(self, cfm: nn.Module, patch_size: int, n_timesteps: int, cfg_value: float):
        super().__init__()
        self.cfm = cfm
        self.patch_size = patch_size
        self.n_timesteps = n_timesteps
        self.cfg_value = cfg_value

    def forward(self, mu: torch.Tensor, cond: torch.Tensor):
        # mu: [B, H_dit], cond: [B, D_feat, P]
        return self.cfm(
            mu=mu,
            n_timesteps=self.n_timesteps,
            patch_size=self.patch_size,
            cond=cond,
            cfg_value=self.cfg_value,
        )


class DiTStepWrapper(nn.Module):
    """Wrapper for a single VoxCPMLocDiT forward (one diffusion score estimation step).

    Inputs match VoxCPMLocDiT.forward: x, mu, t, cond, dt.
    """

    def __init__(self, dit: nn.Module):
        super().__init__()
        self.dit = dit

    def forward(self, x: torch.Tensor, mu: torch.Tensor, t: torch.Tensor,
                cond: torch.Tensor, dt: torch.Tensor):
        return self.dit(x, mu, t, cond, dt)


def export(model: nn.Module, inputs, path: str, dynamic_axes: dict, opset: int) -> None:
    """Export *model* to ONNX at *path* and print a confirmation.

    Args:
        model: traced module (already on the right device, in eval mode).
        inputs: example input tensor or tuple of tensors for tracing.
        path: destination ``.onnx`` file; parent directories are created.
        dynamic_axes: mapping input-name -> {axis-index: symbolic-name};
            the dict's keys double as the graph's ``input_names``, so the
            order must match the order of *inputs*.
        opset: ONNX opset version.
    """
    # `or "."` guards paths without a directory component (makedirs("") raises).
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.onnx.export(
        model,
        inputs,
        path,
        opset_version=opset,
        do_constant_folding=True,
        input_names=list(dynamic_axes.keys()),
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
    print(f"Saved: {path}")


def main() -> None:
    """CLI entry point: export each VoxCPM submodule (LLM excluded) to ONNX."""
    parser = argparse.ArgumentParser(description="Export VoxCPM submodules to ONNX (LLM excluded).")
    parser.add_argument("--model-dir", required=True,
                        help="Path to VoxCPM model directory (config/weights).")
    parser.add_argument("--out-dir", default="onnx_exports",
                        help="Output directory for ONNX files.")
    parser.add_argument("--opset", type=int, default=18,
                        help="ONNX opset version.")
    parser.add_argument("--audio-samples", type=int, default=1280,
                        help="Dummy audio length for encoder export.")
    parser.add_argument("--latent-steps", type=int, default=6,
                        help="Dummy latent steps for decoder export.")
    parser.add_argument("--seq-len", type=int, default=4,
                        help="Dummy sequence length for LocEnc/FSQ export.")
    parser.add_argument("--dit-step-t", type=float, default=0.5,
                        help="Dummy diffusion time for DiT step export.")
    parser.add_argument("--force-fp32", action="store_true",
                        help="Force submodules to float32 for ONNX export.")
    parser.add_argument("--dump-embeddings", action="store_true",
                        help="Dump base_lm.embed_tokens weights to npy.")
    args = parser.parse_args()

    device = torch.device("cpu")

    # Load full model once, then peel submodules; keep optimize disabled.
    full_model = VoxCPMModel.from_local(args.model_dir, optimize=False).to(device).eval()

    # Cast to fp32 whenever requested or whenever the checkpoint is not
    # already fp32 (mixed precision is unfriendly to ONNX export).
    if args.force_fp32 or full_model.config.dtype != "float32":
        full_model.config.dtype = "float32"
        full_model = full_model.to(torch.float32)
        full_model.audio_vae = full_model.audio_vae.to(torch.float32)

    remove_weight_norm(full_model)

    # Audio VAE encode
    vae_enc = VAEEncodeWrapper(full_model.audio_vae).to(device).eval()
    dummy_audio = torch.randn(1, 1, args.audio_samples, device=device)
    export(
        vae_enc,
        dummy_audio,
        os.path.join(args.out_dir, "audio_vae_encode.onnx"),
        dynamic_axes={"audio_wave": {0: "batch", 2: "samples"}},
        opset=args.opset,
    )

    # Audio VAE decode
    vae_dec = VAEDecodeWrapper(full_model.audio_vae).to(device).eval()
    dummy_latent = torch.randn(1, full_model.audio_vae.latent_dim, args.latent_steps, device=device)
    export(
        vae_dec,
        dummy_latent,
        os.path.join(args.out_dir, "audio_vae_decode.onnx"),
        dynamic_axes={"latent": {0: "batch", 2: "latent_steps"}},
        opset=args.opset,
    )

    # LocEnc with enc_to_lm projection fused
    locenc = LocEncLmWrapper(full_model.feat_encoder, full_model.enc_to_lm_proj).to(device).eval()
    dummy_seq = torch.randn(1, args.seq_len, full_model.patch_size, full_model.feat_dim, device=device)
    export(
        locenc,
        dummy_seq,
        os.path.join(args.out_dir, "locenc.onnx"),
        dynamic_axes={"x": {0: "batch", 1: "seq_len"}},
        opset=args.opset,
    )

    # FSQ layer
    fsq = FSQWrapper(full_model.fsq_layer).to(device).eval()
    hidden_size = full_model.config.lm_config.hidden_size
    dummy_hidden = torch.randn(1, args.seq_len, hidden_size, device=device)
    export(
        fsq,
        dummy_hidden,
        os.path.join(args.out_dir, "fsq_layer.onnx"),
        dynamic_axes={"hidden": {0: "batch", 1: "seq_len"}},
        opset=args.opset,
    )

    # Stop head
    stop = StopHeadWrapper(full_model.stop_proj, full_model.stop_actn, full_model.stop_head).to(device).eval()
    dummy_stop_inp = torch.randn(1, hidden_size, device=device)
    export(
        stop,
        dummy_stop_inp,
        os.path.join(args.out_dir, "stop_head.onnx"),
        dynamic_axes={"hidden": {0: "batch"}},
        opset=args.opset,
    )

    # Projection layers (enc_to_lm_proj is already fused into locenc.onnx above).
    lm_hidden = torch.randn(1, hidden_size, device=device)
    export(
        full_model.lm_to_dit_proj,
        lm_hidden,
        os.path.join(args.out_dir, "lm_to_dit_proj.onnx"),
        dynamic_axes={"input": {0: "batch"}},
        opset=args.opset,
    )
    export(
        full_model.res_to_dit_proj,
        lm_hidden,
        os.path.join(args.out_dir, "res_to_dit_proj.onnx"),
        dynamic_axes={"input": {0: "batch"}},
        opset=args.opset,
    )

    # VoxCPMLocDiT single step (score function).
    # NOTE: a full UnifiedCFM sampler export (unrolled with fixed n_timesteps,
    # see CFMWrapper) is intentionally omitted; only the per-step score
    # function is exported so the sampling loop can run on the host side.
    dit_step = DiTStepWrapper(full_model.feat_decoder.estimator).to(device).eval()
    dummy_x = torch.randn(1, full_model.feat_dim, full_model.patch_size, device=device)
    dummy_mu = torch.randn(1, full_model.config.dit_config.hidden_dim, device=device)
    dummy_t = torch.full((1,), args.dit_step_t, device=device)
    dummy_dt = torch.full((1,), 0.0, device=device)
    dummy_cond = torch.randn(1, full_model.feat_dim, full_model.patch_size, device=device)
    export(
        dit_step,
        (dummy_x, dummy_mu, dummy_t, dummy_cond, dummy_dt),
        os.path.join(args.out_dir, "dit_step.onnx"),
        dynamic_axes={
            "x": {0: "batch"},
            "mu": {0: "batch"},
            "t": {0: "batch"},
            "cond": {0: "batch"},
            "dt": {0: "batch"},
        },
        opset=args.opset,
    )

    # Token embedding table, dumped raw for host-side lookup.
    if args.dump_embeddings and hasattr(full_model.base_lm, "embed_tokens"):
        import numpy as np

        emb = full_model.base_lm.embed_tokens.weight.detach().cpu().numpy()
        os.makedirs(args.out_dir, exist_ok=True)
        np.save(os.path.join(args.out_dir, "embed_tokens.npy"), emb)
        print(f"Saved: {os.path.join(args.out_dir, 'embed_tokens.npy')}")

    print("Done.")


if __name__ == "__main__":
    main()