Instructions to use CAMB-AI/MARS5-TTS with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MARS5-TTS
How to use CAMB-AI/MARS5-TTS with MARS5-TTS:
# Install from https://github.com/Camb-ai/MARS5-TTS
from inference import Mars5TTS
mars5 = Mars5TTS.from_pretrained("CAMB-AI/MARS5-TTS")
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Discrete multinomial diffusion code adapted from https://github.com/RF5/transfusion-asr, | |
| which in turn is adapted from https://github.com/ehoogeboom/multinomial_diffusion. | |
| Please see the original repo (https://github.com/ehoogeboom/multinomial_diffusion) and paper for full | |
| details on how multinomial diffusion works -- thanks to the original authors! | |
| """ | |
| import torch | |
| from torch import Tensor | |
| from torch.functional import F | |
| import numpy as np | |
| from dataclasses import dataclass | |
| from typing import Union | |
# -------------- Multinomial utility functions -----------
# Smallest value allowed inside torch.log() when building log one-hot
# encodings and Gumbel noise (the upstream code used 1e-40).
# log(1e-7) is roughly -16.1.
MIN_LOG_ARG = 1e-7
def log_1_min_a(a):
    """Return log(1 - exp(a)) for log-probabilities `a`.

    The argument of the log is clamped below at 1e-30 so the result stays finite
    when exp(a) == 1.
    """
    one_minus_prob = torch.clamp(1 - torch.exp(a), min=1e-30)
    return torch.log(one_minus_prob)
def log_add_exp(a, b):
    """Numerically stable elementwise log(exp(a) + exp(b)), with broadcasting.

    The usual trick shifts both inputs by their maximum. Where that maximum is
    not finite (both inputs -inf, or an input is +inf), the shift would compute
    ``inf - inf = nan``. Those positions are shifted by 0 instead, which yields
    the correct -inf / +inf. Everywhere else the computation is unchanged.
    """
    maximum = torch.max(a, b)
    # Only shift by finite maxima; a non-finite shift would produce NaN.
    shift = torch.where(torch.isfinite(maximum), maximum, torch.zeros_like(maximum))
    return shift + torch.log(torch.exp(a - shift) + torch.exp(b - shift))
def extract(a: Tensor, t, x_shape):
    """Look up per-timestep schedule values and shape them for broadcasting.

    Gathers `a[t]` for every entry of the index tensor `t` (shape (bs,)) and
    returns it with shape (bs, 1, ..., 1): one trailing singleton axis for
    each non-batch dimension of `x_shape`.
    """
    batch_size, *_ = t.shape
    singleton_axes = (1,) * (len(x_shape) - 1)
    picked = a.gather(-1, t)
    return picked.reshape(batch_size, *singleton_axes)
def index_to_log_onehot(x, num_classes, dim=-1, dtype=torch.float32):
    """Convert integer indices to approximate one-hot log-probabilities.

    Args:
        x: integer index tensor of shape (bs, ...); every index must be < num_classes.
        num_classes: number of categories K.
        dim: position of the class axis. -1 (default) gives (bs, ..., K);
            1 gives (bs, K, ...). Other values leave the class axis last.
        dtype: floating dtype of the result.

    Returns:
        Log-probabilities that are 0 at the hot index and log(MIN_LOG_ARG)
        (about -16.1 for MIN_LOG_ARG = 1e-7) everywhere else.
    """
    # `x.max()` raises on a tensor with no elements, so only range-check non-empty input.
    if x.numel() > 0:
        assert x.max().item() < num_classes, \
            f'Error: {x.max().item()} >= {num_classes}'
    x_onehot = F.one_hot(x, num_classes)
    if dim == 1:
        # Move the trailing class axis to position 1: (bs, ..., K) -> (bs, K, ...).
        permute_order = (0, -1) + tuple(range(1, x.dim()))
        x_onehot = x_onehot.permute(permute_order)
    # Clamp the zeros away so the log stays finite.
    return torch.log(x_onehot.to(dtype).clamp(min=MIN_LOG_ARG))
def sum_except_batch(x: Tensor, num_dims=1) -> Tensor:
    """Sum over every axis that follows the leading `num_dims` batch axes.

    Args:
        x: Tensor of shape (batch_size, ...).
        num_dims: number of leading batch axes to keep (default 1).

    Returns:
        Tensor of shape ``x.shape[:num_dims]``, e.g. (batch_size,) by default.
    """
    batch_shape = x.shape[:num_dims]
    flattened = x.reshape(*batch_shape, -1)
    return flattened.sum(dim=-1)
| # -------------- Multinomial diffusion class ------------- | |
class MultinomialDiffusion():
    """Discrete (multinomial) diffusion over `num_classes` categories.

    Precomputes the log-domain noise schedule (log alpha_t, log(1 - alpha_t),
    log bar{alpha}_t, log(1 - bar{alpha}_t)) from a cosine schedule. Exposes
    the forward process q, the posterior q(x_{t-1} | x_t, x_0), the model
    posterior p(x_{t-1} | x_t) and the variational-bound loss terms.
    Probability tensors are log-probabilities with the class axis last,
    i.e. shape (bs, ..., K).
    """

    def __init__(self, num_classes, timesteps=100, diffusion_s=0.008,
                 loss_type='vb_stochastic', parametrization='x0',
                 dtype=torch.float32,
                 device='cpu'):
        """
        Args:
            num_classes: number of categories K.
            timesteps: number of diffusion steps T.
            diffusion_s: offset `s` of the cosine schedule.
            loss_type: only 'vb_stochastic' is accepted.
            parametrization: 'x0' or 'direct' (stored; not otherwise used in this class).
            dtype: dtype of the stored schedule tensors (they are computed in float64 first).
            device: device of the stored schedule tensors.
        """
        super().__init__()
        assert loss_type in ('vb_stochastic',)
        assert parametrization in ('x0', 'direct')
        self.num_classes = num_classes
        self.loss_type = loss_type
        self.num_timesteps = timesteps
        self.parametrization = parametrization

        # Compute the schedule in float64 so the consistency checks below are meaningful.
        alphas = self.cosine_beta_schedule(timesteps, diffusion_s).to(torch.float64)
        log_alpha = alphas.log()
        log_cumprod_alpha = torch.cumsum(log_alpha, dim=-1)  # log(bar{alpha}_t)
        log_1_min_alpha = log_1_min_a(log_alpha)  # = log(betas)
        log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)  # = log(1 - bar{alpha}_t)

        # Sanity checks: alpha + (1 - alpha) == 1, i.e. log_add_exp(...) == 0 everywhere.
        assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
        assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5

        # Store in the requested dtype/device (plain attributes; this class is not an nn.Module).
        self.log_alpha = log_alpha.to(dtype).to(device)
        self.log_1_min_alpha = log_1_min_alpha.to(dtype).to(device)
        self.log_cumprod_alpha = log_cumprod_alpha.to(dtype).to(device)
        self.log_1_min_cumprod_alpha = log_1_min_cumprod_alpha.to(dtype).to(device)

    @staticmethod
    def cosine_beta_schedule(timesteps, s=0.008) -> Tensor:
        """Cosine schedule as proposed in https://arxiv.org/abs/2102.09672 .

        Returns per-step alpha parameters, NOT betas: the ratios
        bar{alpha}_t / bar{alpha}_{t-1}, clipped to [0.001, 1.0] and then
        square-rooted. Shape (timesteps,).
        """
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])
        alphas = torch.clamp(alphas, 0.001, 1.0)
        return torch.sqrt(alphas)

    def multinomial_kl(self, log_prob1: Tensor, log_prob2: Tensor, dim=-1) -> Tensor:
        """KL divergence between two categorical distributions given as log-probs.

        The probability axis is `dim`, i.e. ``log_prob1.exp().sum(dim=dim)`` should
        be all ones. The result has axis `dim` reduced away.
        """
        return (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=dim)

    def q_pred_one_timestep(self, log_x_t: Tensor, t: Tensor) -> Tensor:
        """Compute q(x_t | x_{t-1}) = C(x_t | alpha_t * x_{t-1} + (1 - alpha_t)/K) in the log domain.

        `log_x_t` is the log one-hot encoding of x_t. Due to the symmetry property
        this can be computed using x_t instead of x_{t-1} (see appendix A of
        https://arxiv.org/pdf/2102.05379.pdf).
        """
        dt = log_x_t.dtype
        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape).to(dt)
        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape).to(dt)
        # alpha_t * E[xt] + (1 - alpha_t) 1 / K
        return log_add_exp(
            log_x_t + log_alpha_t,
            log_1_min_alpha_t - np.log(self.num_classes)
        )

    def q_pred_one_timestep_scaled(self, log_x_t: Tensor, t: Tensor, c: int, jump_len: int) -> Tensor:
        """Position-dependent variant of `q_pred_one_timestep` (sequentially progressive diffusion).

        Uses the same transition as `q_pred_one_timestep`, but the noise strength
        is modulated along axis 1 (sequence positions) by a sigmoid ramp. The
        ramp equals 0.5 at position ``seq_len * c / jump_len - 20``. Positions
        well before that point get alpha_t -> 1 (almost no noise); positions
        well after it keep the unmodified alpha_t.
        """
        dt = log_x_t.dtype
        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape).to(dt)
        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape).to(dt)

        seq_len = log_x_t.shape[1]
        positions = torch.arange(seq_len, device=log_x_t.device)
        frontier = seq_len * (c / jump_len)
        sig = 1 / (1 + torch.exp(-(positions - frontier + 20) / 8))
        # Lay the ramp along axis 1 and broadcast it over every remaining axis. This
        # works for both (bs, seq, K) and (bs, seq, n_quant, K) inputs, whereas a
        # fixed `[None, :, None]` only lines up with axis 1 for 3-D inputs.
        sig = sig.reshape((1, seq_len) + (1,) * (log_x_t.dim() - 2))
        # alpha'_t = min(alpha_t / sig, 1);  (1 - alpha_t)' = sig * (1 - alpha_t)
        log_alpha_t = (torch.log(1 / sig) + log_alpha_t).clamp(-torch.inf, 0)
        log_1_min_alpha_t = torch.log(sig) + log_1_min_alpha_t
        # alpha_t * E[xt] + (1 - alpha_t) 1 / K
        return log_add_exp(
            log_x_t + log_alpha_t,
            log_1_min_alpha_t - np.log(self.num_classes)
        )

    def q_pred(self, log_x_start: Tensor, t) -> Tensor:
        """Compute q(x_t | x_0) = C(x_t | bar{alpha}_t * x_0 + (1 - bar{alpha}_t)/K) in the log domain.

        `log_x_start` holds log-probabilities of x_0 with shape (bs, ..., K).
        """
        dt = log_x_start.dtype
        log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape).to(dt)
        log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape).to(dt)
        return log_add_exp(
            log_x_start + log_cumprod_alpha_t,
            log_1_min_cumprod_alpha - np.log(self.num_classes)
        )

    def q_posterior(self, log_x_start, log_x_t, t):
        """Compute q(x_{t-1} | x_t, x_0) = q(x_t | x_{t-1}, x_0) * q(x_{t-1} | x_0) / q(x_t | x_0).

        Here q(x_t | x_{t-1}, x_0) = q(x_t | x_{t-1}). The denominator is replaced by
        normalizing over the class axis.
        """
        # Clamp t-1 at 0; positions with t == 0 are overwritten just below anyway.
        t_minus_1 = (t - 1).clamp(min=0)
        log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)  # log( q(x_{t-1} | x_0) )

        # If t == 0, then log( q(x_0 | x_0) ) = log( one_hot(x_0) ); it is not random at that
        # point. So where t == 0, replace with the log one-hot encoding of x0.
        # Note: _NOT_ x_tmin1, which is how the formula is typically used!
        num_axes = (1,) * (len(log_x_start.size()) - 1)
        t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start)  # broadcast to non-batch axes
        log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)

        # log_EV_qxtmin_x0 ~ q(x_{t-1} | x_0)
        # q_pred_one_timestep(log_x_t, t) ~ q(x_t | x_{t-1}) (which due to symmetry can be computed using x_t)
        unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)  # numerator of Bayes
        # Approximate the denominator with a normalizing sum over classes.
        return unnormed_logprobs - torch.logsumexp(unnormed_logprobs, dim=-1, keepdim=True)

    def p_pred(self, log_x_t, t, log_x0_pred):
        """Predict p(x_{t-1} | x_t) as q(x_{t-1} | x_t, hat{x0}).

        `log_x0_pred` holds the model's log-probabilities for x0, shape (bs, ..., K).
        `log_x_t` is the log one-hot encoding of x_t, shape (bs, ..., K).
        """
        return self.q_posterior(log_x_start=log_x0_pred, log_x_t=log_x_t, t=t)

    def log_sample_categorical(self, logprobs: Tensor, dim=-1) -> Tensor:
        """Sample from the categorical `logprobs` (bs, ..., probs) with the Gumbel-max trick.

        `dim` is the position of the probability axis. Returns long indices with
        that axis removed, e.g. shape (bs, ...).
        """
        uniform = torch.rand_like(logprobs)
        gumbel_noise = -torch.log((-torch.log(uniform.clamp_(min=MIN_LOG_ARG))).clamp_(min=MIN_LOG_ARG))
        return (gumbel_noise + logprobs).argmax(dim=dim)

    def q_sample(self, log_x_start, t):
        """Draw x_t ~ q(x_t | x_0).

        `log_x_start` has shape (bs, ..., K). Returns sampled long indices of shape
        (bs, ...), not log one-hot vectors.
        """
        log_EV_qxt_x0 = self.q_pred(log_x_start, t)
        return self.log_sample_categorical(log_EV_qxt_x0)

    def compute_Lt(self, log_x_start: Tensor, log_x_t: Tensor, log_x0_pred: Tensor, t,
                   detach_mean=False, include_kl_prior=True):
        """Get the loss from one-hot log x_0, one-hot log x_t, t and the model prediction `log_x0_pred`.

        Parameters:
            - `log_x_start`: ground-truth input x0, converted to log one-hot (bs, ..., K)
            - `log_x_t`: sampled noisy input at `x_t`, converted to log one-hot (bs, ..., K)
            - `t`: diffusion timestep (bs,)
            - `log_x0_pred`: model prediction of log probabilities of x0, i.e. hat{x0}.
            - `detach_mean`: detach the model posterior before computing the loss.
            - `include_kl_prior`: add the prior term KL(q(x_T | x_0) || uniform) to the loss
              (constant w.r.t. the model, so it does not change the optimization problem).

        Returns:
            Per-example loss of shape (bs,).
        """
        dtype = log_x_start.dtype
        log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t)
        log_model_prob = self.p_pred(log_x_t=log_x_t, t=t, log_x0_pred=log_x0_pred)
        if detach_mean:
            log_model_prob = log_model_prob.detach()

        kl = sum_except_batch(self.multinomial_kl(log_true_prob, log_model_prob))
        # L_0 = -log p(x_0 | x_1)
        decoder_nll = sum_except_batch(-(log_x_start.exp() * log_model_prob).sum(dim=-1))

        mask = (t == torch.zeros_like(t)).to(dtype)
        loss = mask * decoder_nll + (1. - mask) * kl  # only use L0 where t == 0.
        if include_kl_prior:
            # Add the prior on top of the per-timestep loss rather than replacing it,
            # so the L_0 term for t == 0 is kept.
            loss = loss + self.kl_prior(log_x_start)
        return loss

    def kl_prior(self, log_x_start: Tensor) -> Tensor:
        """Compute KL(q(x_T | x_0) || Uniform(K)) = -H_q(x_T | x_0) + log K per example.

        Given `log_x_start` (bs, ..., probs), returns the KL prior of shape (bs,).
        """
        b = log_x_start.size(0)
        device = log_x_start.device
        ones = torch.ones(b, device=device, dtype=torch.long)
        log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)  # q(x_T | x_0)
        # log(1/K), broadcast to the q(x_T | x_0) shape
        log_uniform_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob))
        kl_prior = self.multinomial_kl(log_qxT_prob, log_uniform_prob)
        return sum_except_batch(kl_prior)
def index2logit(x: Tensor, vocab_size: int, dtype=torch.float32):
    """Map integer indices to zero-sum one-hot style vectors.

    Each index becomes a length-`vocab_size` vector with 1 at the hot position
    and -1/(vocab_size - 1) elsewhere, so every vector sums to 0.
    """
    onehot = F.one_hot(x, num_classes=vocab_size).to(dtype)
    scale = vocab_size / (vocab_size - 1)
    offset = 1 / (vocab_size - 1)
    return onehot * scale - offset
# ------------------------------
# Inference-time sampling helpers (RePaint-style resampling schedule and
# forward/reverse diffusion steps).
@dataclass
class DSH():
    """Diffusion Sampling Hyperparameters [DSH] (Section 4).

    A dataclass, so a custom configuration can be built with e.g.
    ``DSH(jump_len=10, guidance_w=1.5)``. The field defaults also remain
    readable as class attributes, so passing the class itself (``dsh=DSH``)
    keeps working.
    """
    jump_len: int = 1  # j in RePaint paper [default 10] (Section 4.1)
    jump_n_sample: int = 1  # r in RePaint paper [default 10] (Section 4.1)
    last_greedy: bool = False  # whether to not sample at t=0, but take argmax prediction. [default False]
    x_0_temp: float = 1.0  # reweight temp for model prediction of x0
    guidance_w: float = 1.0  # classifier free guidance weight [default 1.5] (Section 4.3)
    enable_kevin_scaled_inference: bool = True  # sequentially progressive diffusion [default True] (Section 4.2)
    T_override: Union[None, int] = None  # allow variable transcription sizes during inference (Section 4.4)
    deep_clone: bool = False  # whether to do deep clone.
    q0_override_steps: int = 0  # number of steps that we allow overriding the input quant level 0 inputs.
    progress: bool = False  # whether to show progress bar
def get_schedule(t_T, jump_len=10, jump_n_sample=10):
    """Build the RePaint resampling time schedule.

    The walk goes from t_T - 1 down to 0. At each multiple of `jump_len` that is
    below ``t_T - jump_len`` it may jump forward by `jump_len` steps, and it does
    so ``jump_n_sample - 1`` times per such point. The list always ends with -1.
    """
    remaining_jumps = {j: jump_n_sample - 1 for j in range(0, t_T - jump_len, jump_len)}
    schedule = []
    t = t_T
    while t >= 1:
        t -= 1
        schedule.append(t)
        if remaining_jumps.get(t, 0) > 0:
            remaining_jumps[t] -= 1
            # Re-noise forward: t+1, t+2, ..., t+jump_len.
            schedule.extend(range(t + 1, t + jump_len + 1))
            t += jump_len
    schedule.append(-1)
    return schedule
def forward_diffusion(diff: MultinomialDiffusion, dtype, x, t, c=None, dsh=DSH):
    """Apply one forward (noising) step to the integer codes `x` at timestep `t`.

    When `c` is given, the position-dependent scaled transition is used with
    progress `c` and `dsh.jump_len`; otherwise the plain one-step transition is
    used. Returns newly sampled integer codes.
    """
    log_x_t = index_to_log_onehot(x, diff.num_classes, dtype=dtype)
    if c is None:
        log_probs = diff.q_pred_one_timestep(log_x_t, t)
    else:
        log_probs = diff.q_pred_one_timestep_scaled(log_x_t, t, c, dsh.jump_len)
    return diff.log_sample_categorical(log_probs)
def reverse_diffusion(diff: MultinomialDiffusion, model, batch, x_known=None, m=None,
                      last_greedy=False, temperature=1.0, alphas=None, ensemble_size=1, dsh=DSH):
    """Reverse diffusion step: predict x_{t-1} given x, t, x_known and m (RePaint Eq. 8).

    Optionally, at t=0, do not sample the model output but take the greedy argmax
    (`last_greedy`).

    Args:
        diff: the MultinomialDiffusion instance.
        model: called as ``model(*batch)`` (and with ``drop_cond=True`` for guidance).
        batch: tuple whose element 4 is the current codes x_t and whose last element is t (bs,).
        x_known: codes to keep where `m` is set (defaults to zeros).
        m: mask of known positions (defaults to zeros, i.e. nothing known).
        last_greedy: at t == 0, take the argmax instead of sampling.
        temperature: divides the (guided) model logits.
        alphas: optional per-timestep ensemble mixing weights, indexed by t[0].
        ensemble_size: number of consecutive batch items mixed together.
        dsh: sampling hyperparameters; `guidance_w` != 1 enables classifier-free guidance.

    Returns:
        (x_{t-1}, the guided and temperature-scaled model logits).
    """
    x = batch[4]
    t = batch[-1]
    if x_known is None:
        x_known = torch.zeros_like(x)
    if m is None:
        m = torch.zeros_like(x)

    # Equation 8b
    x_0_pred = model(*batch)  # (bs, seq_len, logit_dim, n_quant)
    x_0_pred = x_0_pred.permute(0, 1, 3, 2)  # (bs, seq_len, n_quant, dim)
    if dsh.guidance_w != 1:
        uncond_x_0_pred = model(*(c.clone() if c is not None else None for c in batch), drop_cond=True)
        uncond_x_0_pred = uncond_x_0_pred.permute(0, 1, 3, 2)
        x_0_pred = dsh.guidance_w * x_0_pred + (1 - dsh.guidance_w) * uncond_x_0_pred
    x_0_pred = x_0_pred / temperature
    log_x_0_pred = F.log_softmax(x_0_pred, dim=-1)
    log_x_t = index_to_log_onehot(x, diff.num_classes, dtype=x_0_pred.dtype)
    log_model_pred = diff.p_pred(log_x_t, t, log_x_0_pred)  # p(x_{t-1} | x_{t})

    # Ensemble mixing (not in the paper). mat[i, j] is the weight of item j's
    # prediction in item i's mixture; the matrix is block-diagonal per ensemble.
    a_t = alphas[t[0]] if alphas is not None else 0
    mat = torch.eye(ensemble_size, device=x.device) * (1 - a_t)
    mat += 1 / ensemble_size * a_t
    mat = torch.block_diag(*([mat] * (x.shape[0] // ensemble_size)))  # (bs, bs)
    # Reshape to (bs, bs, 1, ..., 1) so the weights broadcast against
    # log_model_pred[None] (1, bs, ..., K) for any number of trailing axes.
    # Cast the log-weights to the prediction's floating dtype: `x` holds integer
    # codes, and an integer cast would truncate fractional log-weights and turn
    # -inf into an arbitrary integer.
    log_mat = mat.log().reshape(*mat.shape, *((1,) * (log_model_pred.dim() - 1)))
    log_mat = log_mat.to(log_model_pred.dtype)
    log_model_pred = torch.logsumexp(log_mat + log_model_pred[None], dim=1)

    if (t == 0).all() and last_greedy:  # Do not sample at t=0
        x_tm1_unknown = log_model_pred.argmax(dim=-1)
    else:
        x_tm1_unknown = diff.log_sample_categorical(log_model_pred)

    # Equation 8a
    x_known_log = index_to_log_onehot(x_known, diff.num_classes, dtype=x_0_pred.dtype)
    if (t == 0).all():  # Do not sample at t=0
        x_tm1_known = x_known
    else:
        x_tm1_known = diff.q_sample(x_known_log, t)

    # Equation 8c: take known codes where m is set, generated codes elsewhere.
    x_tm1 = x_tm1_known * m.long() + x_tm1_unknown * (1 - m.long())
    return x_tm1, x_0_pred
def perform_simple_inference(model: torch.nn.Module, batch: tuple, diff: MultinomialDiffusion, T, dtype=torch.float16,
                             retain_quant0: bool = True, dsh=DSH):
    """Generate codes with RePaint-style discrete diffusion, conditioned on quant level 0.

    If `retain_quant0`, then do not sample quant0 in each forward or reverse diffusion
    step: it is re-imposed after every step whose `t_last` exceeds `dsh.q0_override_steps`.

    Args:
        model: network called by `reverse_diffusion`.
        batch: (c_text, c_codes, c_text_lengths, c_codes_lengths, x, x_padding_mask).
            Only ``x[..., 0]`` is used from `x`; the other levels start from uniform noise.
        diff: the MultinomialDiffusion instance.
        T: number of diffusion timesteps for the schedule.
        dtype: dtype of the log one-hot encodings used in forward (re-noising) steps.
        retain_quant0: see above.
        dsh: sampling hyperparameters (the DSH class or an instance of it).

    Returns:
        The sampled codes `x`, with the prompt frames cropped off when `dsh.deep_clone`.
    """
    # (bs=1, N), (bs, seq_len2, 8), (bs,)
    c_text, c_codes, c_text_lengths, c_codes_lengths, x, x_padding_mask = batch
    device = c_text.device
    bs = c_text.shape[0]
    x_quant0 = x[..., 0].clone()  # (bs, seq_len) 0th quant level
    x = torch.randint(0, diff.num_classes, x.shape, dtype=x.dtype, device=device)
    # CRITICAL LINE: override quantization level 0 with provided quant0 level.
    x[..., 0] = x_quant0

    # RePaint paper resample scheduling
    times = get_schedule(T, jump_n_sample=dsh.jump_n_sample, jump_len=dsh.jump_len)

    x_known = torch.zeros_like(x)
    x_known[..., 0] = x[..., 0]  # override L0 codes
    m = torch.zeros_like(x).bool()  # (bs, seq_len, 8)
    m[..., 0] = True

    offset = 0
    if dsh.deep_clone:
        print(f"Note: using deep clone. Assuming input `c_phones` is concatenated prompt and output phones.",
              "Also assuming no padded indices in `c_codes`.")
        prompt = c_codes
        x = torch.cat((prompt, x), dim=1)  # (bs=1, sl1 + sl2, 8)
        x_known = torch.cat((prompt, x_known), dim=1)
        x_padding_mask = torch.cat((
            torch.zeros(x_padding_mask.shape[0], c_codes_lengths[0], dtype=torch.bool, device=x_padding_mask.device),
            x_padding_mask), dim=-1
        )
        # (bs=1, :up to prompt duration, all 8 codebooks) = True/masked.
        m = torch.cat((torch.ones_like(prompt), m), dim=1)
        x_quant0 = torch.cat((prompt[..., 0], x_quant0), dim=-1)
        offset = c_codes_lengths[0]
        print(f"New x: {x.shape} | new x_known: {x_known.shape} . Base prompt: {prompt.shape}. New padding mask: {x_padding_mask.shape} | m shape: {m.shape}")

    c = 0  # sequentially progressive diffusion offset (Section 4.2)
    # ensemble bs (not in paper)
    alphas = torch.linspace(1, 0, T).to(device)
    pb = zip(times[:-1], times[1:])
    if dsh.progress:
        from fastprogress import progress_bar
        pb = progress_bar(pb, total=len(times) - 1)

    # See RePaint paper algorithm
    for t_last, t_cur in pb:
        t = torch.ones((bs,), dtype=torch.long, device=x.device) * (t_last)
        if t_cur < t_last:
            if c > dsh.jump_n_sample:
                c = 0
            c += 1 / dsh.jump_len
            # Reverse (denoising) step. Forward `dsh.last_greedy` so the DSH option
            # actually controls argmax-vs-sample at t == 0.
            cbatch = (c_text, c_codes, c_text_lengths, c_codes_lengths, x, x_padding_mask, t)
            x, _ = reverse_diffusion(diff, model, cbatch, x_known, m, last_greedy=dsh.last_greedy,
                                     temperature=dsh.x_0_temp, alphas=alphas, ensemble_size=1, dsh=dsh)
        else:
            # Forward (re-noising) step, optionally position-scaled.
            scaled_c = c if dsh.enable_kevin_scaled_inference else None
            x = forward_diffusion(diff, dtype, x, t, c=scaled_c, dsh=dsh)
        if retain_quant0 and dsh.q0_override_steps < t_last:
            x[..., 0] = x_quant0

    # crop offset:
    x = x[:, offset:]
    return x