import argparse
import os

import torch
from torch import nn

from voxcpm.model.voxcpm import VoxCPMModel


def remove_weight_norm(module: nn.Module):
    """Strip weight_norm wrappers for cleaner ONNX graphs."""
    for child in module.children():
        remove_weight_norm(child)
        if isinstance(child, (nn.Conv1d, nn.ConvTranspose1d)):
            try:
                torch.nn.utils.remove_weight_norm(child)
            except ValueError:
                # Raised when the module carries no weight_norm hook; nothing to strip.
                pass


class VAEEncodeWrapper(nn.Module):
    def __init__(self, audio_vae: nn.Module):
        super().__init__()
        self.audio_vae = audio_vae

    def forward(self, audio_wave: torch.Tensor):
        return self.audio_vae.encode(audio_wave, self.audio_vae.sample_rate)


class VAEDecodeWrapper(nn.Module):
    def __init__(self, audio_vae: nn.Module):
        super().__init__()
        self.audio_vae = audio_vae

    def forward(self, latent: torch.Tensor):
        return self.audio_vae.decode(latent)


class LocEncWrapper(nn.Module):
    def __init__(self, locenc: nn.Module):
        super().__init__()
        self.locenc = locenc

    def forward(self, x: torch.Tensor):
        # x: (batch, seq_len, patch_size, feat_dim)
        return self.locenc(x)


class LocEncLmWrapper(nn.Module):
    """LocEnc with the enc_to_lm projection fused into a single graph."""

    def __init__(self, locenc: nn.Module, proj: nn.Module):
        super().__init__()
        self.locenc = locenc
        self.proj = proj

    def forward(self, x: torch.Tensor):
        # x: (batch, seq_len, patch_size, feat_dim)
        hidden = self.locenc(x)
        return self.proj(hidden)


class FSQWrapper(nn.Module):
    def __init__(self, fsq: nn.Module):
        super().__init__()
        self.fsq = fsq

    def forward(self, hidden: torch.Tensor):
        return self.fsq(hidden)


class StopHeadWrapper(nn.Module):
    def __init__(self, stop_proj: nn.Linear, stop_actn: nn.Module, stop_head: nn.Linear):
        super().__init__()
        self.stop_proj = stop_proj
        self.stop_actn = stop_actn
        self.stop_head = stop_head

    def forward(self, hidden: torch.Tensor):
        hidden = self.stop_proj(hidden)
        hidden = self.stop_actn(hidden)
        return self.stop_head(hidden)


class CFMWrapper(nn.Module):
    """
    Wrapper for the full multi-step CFM sampling pass (all diffusion steps
    in one graph).

    Note: the number of diffusion steps (n_timesteps) and the CFG value are
    fixed at export time; re-export to change them.
    """

    def __init__(self, cfm: nn.Module, patch_size: int, n_timesteps: int, cfg_value: float):
        super().__init__()
        self.cfm = cfm
        self.patch_size = patch_size
        self.n_timesteps = n_timesteps
        self.cfg_value = cfg_value

    def forward(self, mu: torch.Tensor, cond: torch.Tensor):
        # n_timesteps, patch_size, and cfg_value are baked into the graph as constants.
        return self.cfm(
            mu=mu,
            n_timesteps=self.n_timesteps,
            patch_size=self.patch_size,
            cond=cond,
            cfg_value=self.cfg_value,
        )


class DiTStepWrapper(nn.Module):
    """
    Wrapper for a single VoxCPMLocDiT forward pass (one diffusion score
    estimation step). Inputs match VoxCPMLocDiT.forward: x, mu, t, cond, dt.
    """

    def __init__(self, dit: nn.Module):
        super().__init__()
        self.dit = dit

    def forward(self, x: torch.Tensor, mu: torch.Tensor, t: torch.Tensor, cond: torch.Tensor, dt: torch.Tensor):
        return self.dit(x, mu, t, cond, dt)


def export(model: nn.Module, inputs, path: str, dynamic_axes: dict, opset: int):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.onnx.export(
        model,
        inputs,
        path,
        opset_version=opset,
        do_constant_folding=True,
        # The dynamic_axes keys double as the graph's input names.
        input_names=list(dynamic_axes.keys()),
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
    print(f"Saved: {path}")


def main():
    parser = argparse.ArgumentParser(description="Export VoxCPM submodules to ONNX (LLM excluded).")
    parser.add_argument("--model-dir", required=True, help="Path to VoxCPM model directory (config/weights).")
    parser.add_argument("--out-dir", default="onnx_exports", help="Output directory for ONNX files.")
    parser.add_argument("--opset", type=int, default=18, help="ONNX opset version.")
    parser.add_argument("--audio-samples", type=int, default=1280, help="Dummy audio length for encoder export.")
    parser.add_argument("--latent-steps", type=int, default=6, help="Dummy latent steps for decoder export.")
    parser.add_argument("--seq-len", type=int, default=4, help="Dummy sequence length for LocEnc/FSQ export.")
    parser.add_argument("--dit-step-t", type=float, default=0.5, help="Dummy diffusion time for DiT step export.")
    parser.add_argument("--force-fp32", action="store_true", help="Force submodules to float32 for ONNX export.")
    parser.add_argument("--dump-embeddings", action="store_true", help="Dump base_lm.embed_tokens weights to npy.")
    args = parser.parse_args()

    device = torch.device("cpu")

    # Load the full model on CPU, cast to fp32 if needed, and strip weight_norm hooks.
    full_model = VoxCPMModel.from_local(args.model_dir, optimize=False).to(device).eval()
    if args.force_fp32 or full_model.config.dtype != "float32":
        full_model.config.dtype = "float32"
        full_model = full_model.to(torch.float32)
        full_model.audio_vae = full_model.audio_vae.to(torch.float32)
    remove_weight_norm(full_model)

    # --- AudioVAE encoder ---
    vae_enc = VAEEncodeWrapper(full_model.audio_vae).to(device).eval()
    dummy_audio = torch.randn(1, 1, args.audio_samples, device=device)
    export(
        vae_enc,
        dummy_audio,
        os.path.join(args.out_dir, "audio_vae_encode.onnx"),
        dynamic_axes={"audio_wave": {0: "batch", 2: "samples"}},
        opset=args.opset,
    )

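    # A parity check is worth running after each export. Minimal sketch,
    # assuming onnxruntime and numpy are installed and that encode() returns
    # a single tensor:
    #
    #   import onnxruntime as ort
    #   sess = ort.InferenceSession(os.path.join(args.out_dir, "audio_vae_encode.onnx"))
    #   (got,) = sess.run(None, {"audio_wave": dummy_audio.numpy()})
    #   ref = vae_enc(dummy_audio).detach().numpy()
    #   assert np.allclose(ref, got, atol=1e-4)
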
    # --- AudioVAE decoder ---
    vae_dec = VAEDecodeWrapper(full_model.audio_vae).to(device).eval()
    dummy_latent = torch.randn(1, full_model.audio_vae.latent_dim, args.latent_steps, device=device)
    export(
        vae_dec,
        dummy_latent,
        os.path.join(args.out_dir, "audio_vae_decode.onnx"),
        dynamic_axes={"latent": {0: "batch", 2: "latent_steps"}},
        opset=args.opset,
    )

    # --- Local encoder fused with the enc_to_lm projection ---
    locenc = LocEncLmWrapper(full_model.feat_encoder, full_model.enc_to_lm_proj).to(device).eval()
    dummy_seq = torch.randn(1, args.seq_len, full_model.patch_size, full_model.feat_dim, device=device)
    export(
        locenc,
        dummy_seq,
        os.path.join(args.out_dir, "locenc.onnx"),
        dynamic_axes={"x": {0: "batch", 1: "seq_len"}},
        opset=args.opset,
    )

    # --- FSQ layer ---
    fsq = FSQWrapper(full_model.fsq_layer).to(device).eval()
    hidden_size = full_model.config.lm_config.hidden_size
    dummy_hidden = torch.randn(1, args.seq_len, hidden_size, device=device)
    export(
        fsq,
        dummy_hidden,
        os.path.join(args.out_dir, "fsq_layer.onnx"),
        dynamic_axes={"hidden": {0: "batch", 1: "seq_len"}},
        opset=args.opset,
    )

    # --- Stop head (stop_proj -> stop_actn -> stop_head) ---
    stop = StopHeadWrapper(full_model.stop_proj, full_model.stop_actn, full_model.stop_head).to(device).eval()
    dummy_stop_inp = torch.randn(1, hidden_size, device=device)
    export(
        stop,
        dummy_stop_inp,
        os.path.join(args.out_dir, "stop_head.onnx"),
        dynamic_axes={"hidden": {0: "batch"}},
        opset=args.opset,
    )

    # --- LM-to-DiT and residual-to-DiT projections ---
    lm_hidden = torch.randn(1, hidden_size, device=device)
    export(
        full_model.lm_to_dit_proj,
        lm_hidden,
        os.path.join(args.out_dir, "lm_to_dit_proj.onnx"),
        dynamic_axes={"input": {0: "batch"}},
        opset=args.opset,
    )
    export(
        full_model.res_to_dit_proj,
        lm_hidden,
        os.path.join(args.out_dir, "res_to_dit_proj.onnx"),
        dynamic_axes={"input": {0: "batch"}},
        opset=args.opset,
    )

    # --- Single DiT step (one score-estimation call) ---
    dit_step = DiTStepWrapper(full_model.feat_decoder.estimator).to(device).eval()
    dummy_x = torch.randn(1, full_model.feat_dim, full_model.patch_size, device=device)
    dummy_mu = torch.randn(1, full_model.config.dit_config.hidden_dim, device=device)
    dummy_t = torch.full((1,), args.dit_step_t, device=device)
    dummy_dt = torch.full((1,), 0.0, device=device)
    dummy_cond = torch.randn(1, full_model.feat_dim, full_model.patch_size, device=device)
    export(
        dit_step,
        (dummy_x, dummy_mu, dummy_t, dummy_cond, dummy_dt),
        os.path.join(args.out_dir, "dit_step.onnx"),
        dynamic_axes={
            "x": {0: "batch"},
            "mu": {0: "batch"},
            "t": {0: "batch"},
            "cond": {0: "batch"},
            "dt": {0: "batch"},
        },
        opset=args.opset,
    )

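    # The exported dit_step graph evaluates only one step; the sampling loop
    # lives in the runtime. A minimal sketch of a fixed-step Euler driver,
    # assuming onnxruntime, a uniform timestep schedule, and mu/cond produced
    # by the projections above (the actual VoxCPM scheduler and CFG handling
    # may differ):
    #
    #   sess = ort.InferenceSession(os.path.join(args.out_dir, "dit_step.onnx"))
    #   x = np.random.randn(1, feat_dim, patch_size).astype(np.float32)
    #   dt = np.full((1,), 1.0 / n_timesteps, dtype=np.float32)
    #   for i in range(n_timesteps):
    #       t = np.full((1,), i / n_timesteps, dtype=np.float32)
    #       (v,) = sess.run(None, {"x": x, "mu": mu, "t": t, "cond": cond, "dt": dt})
    #       x = x + dt * v
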
    # Note: CFMWrapper (the full multi-step CFM graph) is defined above but is
    # not exported here; only the per-step DiT graph is.

    if args.dump_embeddings and hasattr(full_model.base_lm, "embed_tokens"):
        # Local import: numpy is only needed for this optional dump.
        import numpy as np

        emb = full_model.base_lm.embed_tokens.weight.detach().cpu().numpy()
        os.makedirs(args.out_dir, exist_ok=True)
        out_path = os.path.join(args.out_dir, "embed_tokens.npy")
        np.save(out_path, emb)
        print(f"Saved: {out_path}")

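    # At inference time the dumped table can stand in for the LM's embedding
    # layer as a plain row lookup, e.g. (illustrative):
    #
    #   table = np.load("onnx_exports/embed_tokens.npy")
    #   token_embs = table[token_ids]  # gather rows by token id
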
    print("Done.")


if __name__ == "__main__":
    main()
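
# Example invocation (script name and paths are illustrative):
#   python export_voxcpm_onnx.py --model-dir /path/to/VoxCPM --out-dir onnx_exports --opset 18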