Den4ikAI committed on
Commit
f0b5414
·
verified ·
1 Parent(s): 6f5585b

Upload 3 files

Browse files
Files changed (3) hide show
  1. demo.py +59 -0
  2. ebanyvae.pt +3 -0
  3. ebanyvae.py +269 -0
demo.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ import os
5
+ import numpy as np
6
+
7
+ from ebanyvae import EbanyCodec, CodecConfig
8
+
9
# Path to the pretrained codec checkpoint (a state_dict, stored via git-lfs).
WEIGHTS_FILE = "ebanyvae.pt"
# Run on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
def init_engine():
    """Build the codec, load pretrained weights if present, and move it to DEVICE.

    Returns:
        EbanyCodec: an eval-mode codec instance. If the checkpoint file is
        missing or fails to load, the model keeps its random initialisation.
    """
    codec = EbanyCodec()
    if os.path.exists(WEIGHTS_FILE):
        try:
            params = torch.load(WEIGHTS_FILE, map_location="cpu")
            codec.load_state_dict(params, strict=True)
        except Exception as exc:
            # Keep the demo alive, but make the failure visible instead of
            # silently serving an untrained model (the original `pass` hid it).
            print(f"WARNING: failed to load weights from {WEIGHTS_FILE}: {exc}")
    codec.to(DEVICE)
    codec.eval()
    return codec
23
+
24
# Module-level singleton: the codec is built once at import time and shared
# by every Gradio request.
processor = init_engine()
25
+
26
def process_signal(input_file):
    """Round-trip an uploaded audio file through the codec (encode -> decode).

    Args:
        input_file: filesystem path provided by gr.Audio, or None.

    Returns:
        A ``(sample_rate, np.ndarray)`` tuple suitable for gr.Audio output,
        or None when there is no input or processing fails.
    """
    if input_file is None:
        return None
    try:
        signal, fs = torchaudio.load(input_file)
        internal_sr = processor.cfg.sr
        # Resample to the codec's native rate before encoding.
        if fs != internal_sr:
            resampler = torchaudio.transforms.Resample(orig_freq=fs, new_freq=internal_sr)
            signal = resampler(signal)
        # Downmix multi-channel audio to mono; the codec expects one channel.
        if signal.shape[0] > 1:
            signal = signal.mean(dim=0, keepdim=True)
        input_tensor = signal.unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            z = processor.encode(input_tensor, internal_sr)
            out_tensor = processor.decode(z)
        audio_out = out_tensor.squeeze().cpu().float().numpy()
        return (internal_sr, audio_out)
    except Exception as exc:
        # Log instead of swallowing the error so failures are diagnosable;
        # the UI still receives None (empty output) as before.
        print(f"process_signal failed: {exc}")
        return None
45
+
46
# Soft visual theme for the demo page.
theme = gr.themes.Soft()

# Two-column layout: source audio + trigger button on the left,
# reconstructed audio on the right.
with gr.Blocks(theme=theme, title="Neural Audio Processor") as interface:
    gr.Markdown("### Neural Codec Reconstruction Test")
    with gr.Row():
        with gr.Column():
            audio_in = gr.Audio(type="filepath", label="Source Signal")
            run_btn = gr.Button("Process Signal", variant="primary")
        with gr.Column():
            audio_out = gr.Audio(label="Synthesized Output")
    run_btn.click(fn=process_signal, inputs=audio_in, outputs=audio_out)

if __name__ == "__main__":
    # NOTE(review): share=True opens a public Gradio tunnel in addition to
    # binding all interfaces -- confirm this exposure is intended.
    interface.launch(server_name="0.0.0.0", share=True)
ebanyvae.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75640ec86cde3e0ccf2109e49d4b919d6682c5e3458d042311abb432b907c77e
3
+ size 346029598
ebanyvae.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import partial
3
+ from typing import List, Optional, Tuple, Dict
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.nn.utils import weight_norm
9
+ from pydantic import BaseModel
10
+
11
class WeightNormWrapper(nn.Module):
    """Thin namespace around torch's ``weight_norm``.

    NOTE(review): this wrapper appears unused in this file -- the factory
    functions below call ``weight_norm`` directly. Possibly kept for external
    API compatibility; confirm before removing.
    """

    @staticmethod
    def wrap(module):
        # Delegates directly to torch.nn.utils.weight_norm.
        return weight_norm(module)
15
+
16
class SineAct(nn.Module):
    """Snake activation: ``x + (1/alpha) * sin^2(alpha * x)``.

    ``alpha`` is a learnable per-channel frequency, initialised to 1 and
    broadcast over the batch and time dimensions.
    """

    def __init__(self, channels: int):
        super().__init__()
        # One frequency per channel: shape (1, C, 1) broadcasts over (B, C, T).
        self.freq_param = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        # The original reshaped (b, c, t) -> (b, c, -1) and back, which is a
        # no-op for 3-D input; the computation is applied directly instead.
        recip_alpha = 1.0 / (self.freq_param + 1e-9)  # guard against alpha == 0
        sine_part = torch.square(torch.sin(self.freq_param * input_tensor))
        return input_tensor + recip_alpha * sine_part
28
+
29
class TemporalConv(nn.Conv1d):
    """Conv1d variant that applies its own left-only (causal) padding.

    Any ``padding`` keyword supplied by the caller is forced to zero; instead
    the layer pads the input with ``pad_val * 2`` zeros on the left inside
    ``forward``.
    """

    def __init__(self, *args, pad_val: int = 0, **kwargs):
        if 'padding' in kwargs:
            kwargs['padding'] = 0  # padding is handled manually in forward()
        super().__init__(*args, **kwargs)
        self.pad_val = pad_val

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        amount = self.pad_val * 2
        if amount > 0:
            x = F.pad(x, (amount, 0))  # zeros on the left only
        return super().forward(x)
40
+
41
class TemporalTransposeConv(nn.ConvTranspose1d):
    """ConvTranspose1d that emulates padding/output_padding by trimming.

    Caller-supplied ``padding`` / ``output_padding`` are zeroed out; instead
    the layer drops ``pad_val * 2 - out_pad`` samples from the right end of
    the output.
    """

    def __init__(self, *args, pad_val: int = 0, out_pad: int = 0, **kwargs):
        for key in ('padding', 'output_padding'):
            if key in kwargs:
                kwargs[key] = 0  # handled via trimming in forward()
        super().__init__(*args, **kwargs)
        self.pad_val = pad_val
        self.out_pad = out_pad

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        full = super().forward(x)
        excess = self.pad_val * 2 - self.out_pad
        return full[..., :-excess] if excess > 0 else full
57
+
58
def get_normed_conv(in_c, out_c, k, d=1, p=0, g=1, s=1, bias=True):
    """Build a weight-normalised causal ``TemporalConv``.

    ``p`` is forwarded both as the (zeroed-out) native Conv1d padding and as
    the layer's manual left-padding amount (``pad_val``).
    """
    conv = TemporalConv(
        in_c,
        out_c,
        kernel_size=k,
        stride=s,
        padding=p,
        dilation=d,
        groups=g,
        bias=bias,
        pad_val=p,
    )
    return weight_norm(conv)
71
+
72
def get_normed_transpose(in_c, out_c, k, s, p, op):
    """Build a weight-normalised ``TemporalTransposeConv``.

    ``p`` / ``op`` are forwarded both as the (zeroed-out) native padding
    values and as the manual trim parameters (``pad_val`` / ``out_pad``).
    """
    layer = TemporalTransposeConv(
        in_c,
        out_c,
        kernel_size=k,
        stride=s,
        padding=p,
        output_padding=op,
        pad_val=p,
        out_pad=op,
    )
    return weight_norm(layer)
84
+
85
class ResidualUnit(nn.Module):
    """Residual block: Snake -> dilated conv -> Snake -> 1x1 conv, plus skip.

    If the convolutions shorten the sequence, the skip path is centre-cropped
    to match before the addition.
    """

    def __init__(self, channels: int, dilation_rate: int, kernel: int = 7, groups: int = 1):
        super().__init__()
        # Left padding that keeps the dilated conv length-preserving for odd kernels.
        effective_padding = ((kernel - 1) * dilation_rate) // 2
        self.ops = nn.Sequential(
            SineAct(channels),
            get_normed_conv(channels, channels, k=kernel, d=dilation_rate, p=effective_padding, g=groups),
            SineAct(channels),
            get_normed_conv(channels, channels, k=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.ops(x)
        shortfall = x.shape[-1] - out.shape[-1]
        skip = x
        if shortfall > 0:
            # NOTE(review): assumes `shortfall` is even; an odd value would
            # leave the crop one sample short of `out` and break the addition.
            # With kernel=7 the conv is length-preserving, so shortfall is 0.
            crop = shortfall // 2
            skip = skip[..., crop:-crop]
        return skip + out
104
+
105
class EncoderStep(nn.Module):
    """One encoder stage: three dilated residual units, then a strided downsample."""

    def __init__(self, out_ch: int, in_ch: Optional[int] = None, factor: int = 1, groups: int = 1):
        super().__init__()
        # Default input width is half the output width (channels double per stage).
        in_ch = in_ch or out_ch // 2
        blocks = [ResidualUnit(in_ch, dilation_rate=d, groups=groups) for d in (1, 3, 9)]
        blocks.append(SineAct(in_ch))
        blocks.append(
            get_normed_conv(
                in_ch,
                out_ch,
                k=2 * factor,
                s=factor,
                p=math.ceil(factor / 2),
            )
        )
        self.ops = nn.Sequential(*blocks)

    def forward(self, x):
        return self.ops(x)
127
+
128
class LatentEncoder(nn.Module):
    """Strided convolutional encoder producing Gaussian posterior parameters.

    A (B, 1, T) waveform passes through a 7-tap stem conv and one EncoderStep
    per entry in ``ratios`` (channel width doubles at each stage), then is
    projected to per-frame ``mean`` and ``logvar`` maps of width ``z_dim``.
    """

    def __init__(self, base_ch: int = 64, z_dim: int = 32, ratios: Optional[List[int]] = None, is_depthwise: bool = False):
        super().__init__()
        # Fixed: the original used a mutable default argument (ratios=[2,4,8,8]).
        if ratios is None:
            ratios = [2, 4, 8, 8]
        layers = [get_normed_conv(1, base_ch, k=7, p=3)]
        current_ch = base_ch
        for r in ratios:
            current_ch *= 2
            # Depthwise mode groups the downsampling conv by its input width.
            grp = current_ch // 2 if is_depthwise else 1
            layers.append(EncoderStep(out_ch=current_ch, factor=r, groups=grp))
        self.calc_mu = get_normed_conv(current_ch, z_dim, k=3, p=1)
        self.calc_logvar = get_normed_conv(current_ch, z_dim, k=3, p=1)
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Return a dict with the final hidden map ``h`` and posterior ``mean``/``logvar``."""
        h = self.layers(x)
        return {
            "h": h,
            "mean": self.calc_mu(h),
            "logvar": self.calc_logvar(h)
        }
149
+
150
class StochasticInjector(nn.Module):
    """Adds input-conditioned Gaussian noise: ``x + N(0, 1) * proj(x)``."""

    def __init__(self, dim):
        super().__init__()
        # 1x1 conv that learns a per-channel, per-frame noise gain.
        self.proj = weight_norm(
            TemporalConv(dim, dim, kernel_size=1, bias=False, pad_val=0)
        )

    def forward(self, x):
        # Sample a single noise channel and broadcast it across all channels.
        noise = torch.randn_like(x[:, :1, :])
        gain = self.proj(x)
        return x + noise * gain
161
+
162
class DecoderStep(nn.Module):
    """One decoder stage: transposed-conv upsampling, optional noise
    injection, then three dilated residual units."""

    def __init__(self, in_ch: int, out_ch: int, factor: int, groups: int = 1, noise: bool = False):
        super().__init__()
        modules = [
            SineAct(in_ch),
            get_normed_transpose(
                in_ch,
                out_ch,
                k=2 * factor,
                s=factor,
                p=math.ceil(factor / 2),
                op=factor % 2,
            ),
        ]
        if noise:
            modules.append(StochasticInjector(out_ch))
        modules.extend(ResidualUnit(out_ch, dilation_rate=d, groups=groups) for d in (1, 3, 9))
        self.ops = nn.Sequential(*modules)

    def forward(self, x):
        return self.ops(x)
184
+
185
class LatentDecoder(nn.Module):
    """Transposed-convolutional decoder mapping latents back to a waveform.

    Channel width starts at ``start_ch`` and halves at each upsampling stage;
    a final 7-tap conv projects to ``out_channels`` followed by tanh.
    """

    def __init__(self, z_dim, start_ch, ratios, is_depthwise=False, out_channels=1, use_noise=False):
        super().__init__()
        sequence = []
        if is_depthwise:
            # Depthwise 7-tap conv followed by a pointwise channel expansion.
            sequence.extend([
                get_normed_conv(z_dim, z_dim, k=7, p=3, g=z_dim),
                get_normed_conv(z_dim, start_ch, k=1)
            ])
        else:
            sequence.append(get_normed_conv(z_dim, start_ch, k=7, p=3))
        for i, r in enumerate(ratios):
            dim_in = start_ch // (2 ** i)
            dim_out = start_ch // (2 ** (i + 1))
            grp = dim_out if is_depthwise else 1
            sequence.append(
                DecoderStep(dim_in, dim_out, factor=r, groups=grp, noise=use_noise)
            )
        # Fixed: compute the final width directly instead of relying on the
        # loop variable leaking out (which raised NameError for empty ratios).
        final_dim = start_ch // (2 ** len(ratios))
        sequence.extend([
            SineAct(final_dim),
            get_normed_conv(final_dim, out_channels, k=7, p=3),
            nn.Tanh()
        ])
        self.sequence = nn.Sequential(*sequence)

    def forward(self, x):
        return self.sequence(x)
213
+
214
class CodecConfig(BaseModel):
    """Hyper-parameters for EbanyCodec (pydantic-validated).

    ``enc_ratios`` and ``dec_ratios`` are mirror images of each other; their
    product (2*3*6*7*7 = 1764) is the codec hop size in samples per latent
    frame at the native sample rate.
    """

    enc_dim: int = 64                        # encoder stem width
    enc_ratios: List[int] = [2, 3, 6, 7, 7]  # per-stage downsampling factors
    z_dim: int = 64                          # latent channel count
    dec_dim: int = 2048                      # decoder starting width
    dec_ratios: List[int] = [7, 7, 6, 3, 2]  # per-stage upsampling factors
    depthwise_conv: bool = True              # use grouped (depthwise) convolutions
    sr: int = 44100                          # native sample rate in Hz
    noise_injection: bool = False            # enable StochasticInjector in the decoder
223
+
224
class EbanyCodec(nn.Module):
    """Variational waveform codec: LatentEncoder -> mean latents -> LatentDecoder.

    ``encode`` returns the posterior mean only (no sampling / reparameterisation
    here); input audio is right-padded with zeros to a multiple of the total
    hop size before encoding.
    """

    def __init__(self, cfg: Optional[CodecConfig] = None):
        super().__init__()
        self.cfg = cfg if cfg is not None else CodecConfig()
        if self.cfg.z_dim is None:
            # Fall back to the channel width the encoder naturally reaches.
            calc_dim = self.cfg.enc_dim * (2 ** len(self.cfg.enc_ratios))
        else:
            calc_dim = self.cfg.z_dim
        self.encoder = LatentEncoder(
            base_ch=self.cfg.enc_dim,
            z_dim=calc_dim,
            ratios=self.cfg.enc_ratios,
            is_depthwise=self.cfg.depthwise_conv
        )
        self.decoder = LatentDecoder(
            z_dim=calc_dim,
            start_ch=self.cfg.dec_dim,
            ratios=self.cfg.dec_ratios,
            is_depthwise=self.cfg.depthwise_conv,
            use_noise=self.cfg.noise_injection
        )
        # Samples per latent frame: product of all downsampling factors.
        self.hop = math.prod(self.cfg.enc_ratios)

    def _pad_audio(self, wav):
        """Right-pad with zeros so the time length is a multiple of ``self.hop``."""
        remainder = wav.shape[-1] % self.hop
        if remainder != 0:
            wav = F.pad(wav, (0, self.hop - remainder))
        return wav

    def encode(self, wav: torch.Tensor, sr: Optional[int] = None):
        """Encode audio to posterior-mean latents of shape (B, z_dim, frames).

        Accepts (T,), (B, T) or (B, 1, T) input (1-D support is a
        backward-compatible generalisation of the original 2-D/3-D handling).
        NOTE(review): ``sr`` is accepted but ignored -- no resampling happens
        here; callers must resample to ``cfg.sr`` beforehand.
        """
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)
        if wav.ndim == 2:
            wav = wav.unsqueeze(1)
        wav = self._pad_audio(wav)
        res = self.encoder(wav)
        return res["mean"]

    def decode(self, latents: torch.Tensor):
        """Decode latents (B, z_dim, frames) back to a waveform."""
        return self.decoder(latents)

    def forward(self, x, sr=None):
        """Full reconstruction: encode then decode."""
        z = self.encode(x, sr)
        return self.decode(z)