# coding=utf-8
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers DAC model."""

import math
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from .configuration_dac import DacConfig


# General docstring
_CONFIG_FOR_DOC = "DacConfig"


@dataclass
class DacOutput(ModelOutput):
    """
    Args:
        loss (`torch.Tensor`):
            Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
        audio_values (`torch.Tensor` of shape `(batch_size, input_length)`):
            Reconstructed audio data.
        quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
            Quantized continuous representation of input.
        audio_codes (`torch.LongTensor` of shape `(batch_size, num_codebooks, time_steps)`):
            Codebook indices for each codebook (quantized discrete representation of input).
        projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
            Projected latents (continuous representation of input before quantization).
    """

    loss: Optional[torch.FloatTensor] = None
    audio_values: Optional[torch.FloatTensor] = None
    quantized_representation: Optional[torch.FloatTensor] = None
    audio_codes: Optional[torch.LongTensor] = None
    projected_latents: Optional[torch.FloatTensor] = None


@dataclass
class DacEncoderOutput(ModelOutput):
    """
    Args:
        loss (`torch.Tensor`):
            Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
        quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`, *optional*):
            Quantized continuous representation of input.
        audio_codes (`torch.LongTensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
            Codebook indices for each codebook (quantized discrete representation of input).
        projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`, *optional*):
            Projected latents (continuous representation of input before quantization).
    """

    loss: Optional[torch.FloatTensor] = None
    quantized_representation: Optional[torch.FloatTensor] = None
    # Codebook indices are integer-valued, so this is annotated like `DacOutput.audio_codes`.
    audio_codes: Optional[torch.LongTensor] = None
    projected_latents: Optional[torch.FloatTensor] = None


@dataclass
# Copied from transformers.models.encodec.modeling_encodec.EncodecDecoderOutput with Encodec->Dac, segment_length->input_length
class DacDecoderOutput(ModelOutput):
    """
    Args:
        audio_values (`torch.FloatTensor`  of shape `(batch_size, input_length)`, *optional*):
            Decoded audio values, obtained using the decoder part of Dac.
    """

    audio_values: Optional[torch.FloatTensor] = None


class Snake1d(nn.Module):
    """
    A 1-dimensional Snake activation function module.

    Computes `x + sin(alpha * x) ** 2 / (alpha + 1e-9)` with one learnable `alpha` per channel.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, hidden_dim, 1))

    def forward(self, hidden_states):
        shape = hidden_states.shape
        # Flatten every trailing dimension so `alpha` (1, hidden_dim, 1) broadcasts per channel.
        hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
        # The 1e-9 offset keeps the reciprocal finite when alpha is exactly zero.
        hidden_states = hidden_states + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * hidden_states).pow(2)
        hidden_states = hidden_states.reshape(shape)
        return hidden_states


class DacVectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo (https://github.com/karpathy/deep-vector-quantization)

    Additionally uses following tricks from improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
            for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
            improves training stability
    """

    def __init__(self, config: DacConfig):
        super().__init__()

        # `DacResidualVectorQuantize.from_latents` reads this attribute on every quantizer to
        # locate each codebook's slice of the concatenated latents.
        self.codebook_dim = config.codebook_dim

        self.in_proj = nn.Conv1d(config.hidden_size, config.codebook_dim, kernel_size=1)
        self.out_proj = nn.Conv1d(config.codebook_dim, config.hidden_size, kernel_size=1)
        self.codebook = nn.Embedding(config.codebook_size, config.codebook_dim)

    def forward(self, hidden_state):
        """
        Quantizes the input tensor using a fixed codebook and returns the corresponding codebook vectors.

        Args:
            hidden_state (`torch.FloatTensor` of shape `(batch_size, dimension, time_steps)`):
                Input tensor.

        Returns:
            quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized continuous representation of input.
            commitment_loss (scalar `torch.FloatTensor`):
                Commitment loss to train encoder to predict vectors closer to codebook entries.
            codebook_loss (scalar `torch.FloatTensor`):
                Codebook loss to update the codebook.
            audio_codes (`torch.LongTensor` of shape `(batch_size, time_steps)`):
                Codebook indices for this codebook, quantized discrete representation of input.
            projected_latents (`torch.FloatTensor` of shape `(batch_size, codebook_dim, time_steps)`):
                Projected latents (continuous representation of input before quantization).
        """

        projected_latents = self.in_proj(hidden_state)
        quantized_representation, audio_codes = self.decode_latents(projected_latents)

        # Both losses are the same MSE; they differ only in which side receives gradients.
        commitment_loss = F.mse_loss(projected_latents, quantized_representation.detach(), reduction="mean")
        codebook_loss = F.mse_loss(quantized_representation, projected_latents.detach(), reduction="mean")
        # noop in forward pass, straight-through gradient estimator in backward pass
        quantized_representation = projected_latents + (quantized_representation - projected_latents).detach()
        quantized_representation = self.out_proj(quantized_representation)

        return quantized_representation, commitment_loss, codebook_loss, audio_codes, projected_latents

    def decode_latents(self, hidden_states):
        """
        Looks up the best-scoring codebook entry for every time step of the projected latents.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, hidden_dim, sequence_length)`):
                Latents in the codebook space.

        Returns:
            quantized_representation (`torch.FloatTensor` of shape `(batch_size, hidden_dim, sequence_length)`):
                Codebook vectors selected for each time step.
            indices (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Index of the selected codebook entry for each time step.
        """
        batch_size, hidden_dim, sequence_length = hidden_states.shape
        encodings = hidden_states.permute(0, 2, 1).reshape(batch_size * sequence_length, hidden_dim)
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        l2_norm = encodings.pow(2).sum(1, keepdim=True)
        # NOTE(review): the codebook-norm term is added outside the negation, so this is not
        # exactly the negated squared distance. After the L2 normalization above that term is
        # (up to rounding) the same for every non-zero codebook row, so it should not change
        # which entry wins. The arithmetic is kept as-is to avoid altering selected codes.
        dist = -(l2_norm - 2 * encodings @ codebook.t()) + codebook.pow(2).sum(1, keepdim=True).t()

        indices = dist.max(1)[1]
        indices = indices.reshape(hidden_states.size(0), -1)
        # The lookup uses the raw (un-normalized) embedding table.
        quantized_representation = self.codebook(indices).transpose(1, 2)
        return quantized_representation, indices


class DacResidualUnit(nn.Module):
    """
    A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
    """

    def __init__(self, dimension: int = 16, dilation: int = 1):
        super().__init__()
        # Padding for the kernel-size-7 dilated convolution.
        pad = ((7 - 1) * dilation) // 2

        self.snake1 = Snake1d(dimension)
        self.conv1 = nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)
        self.snake2 = Snake1d(dimension)
        self.conv2 = nn.Conv1d(dimension, dimension, kernel_size=1)

    def forward(self, hidden_state):
        """
        Forward pass through the residual unit.

        Args:
            hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
                Input tensor .

        Returns:
            output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
                Input tensor after passing through the residual unit.
        """
        output_tensor = hidden_state
        output_tensor = self.conv1(self.snake1(output_tensor))
        output_tensor = self.conv2(self.snake2(output_tensor))

        # If the convolutions shortened the sequence, center-crop the skip path to match.
        padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
        if padding > 0:
            hidden_state = hidden_state[..., padding:-padding]
        output_tensor = hidden_state + output_tensor
        return output_tensor


class DacEncoderBlock(nn.Module):
    """Encoder block used in DAC encoder."""

    def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
        super().__init__()

        # `dimension` is the block's output width; the residual units run at half of it.
        dimension = config.encoder_hidden_size * 2**stride_index
        self.res_unit1 = DacResidualUnit(dimension // 2, dilation=1)
        self.res_unit2 = DacResidualUnit(dimension // 2, dilation=3)
        self.res_unit3 = DacResidualUnit(dimension // 2, dilation=9)
        self.snake1 = Snake1d(dimension // 2)
        self.conv1 = nn.Conv1d(
            dimension // 2, dimension, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)
        )

    def forward(self, hidden_state):
        hidden_state = self.res_unit1(hidden_state)
        hidden_state = self.res_unit2(hidden_state)
        hidden_state = self.snake1(self.res_unit3(hidden_state))
        hidden_state = self.conv1(hidden_state)

        return hidden_state


class DacDecoderBlock(nn.Module):
    """Decoder block used in DAC decoder."""

    def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
        super().__init__()

        # Channel width halves with every decoder block.
        input_dim = config.decoder_hidden_size // 2**stride_index
        output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
        self.snake1 = Snake1d(input_dim)
        self.conv_t1 = nn.ConvTranspose1d(
            input_dim,
            output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2),
        )

        self.res_unit1 = DacResidualUnit(output_dim, dilation=1)
        self.res_unit2 = DacResidualUnit(output_dim, dilation=3)
        self.res_unit3 = DacResidualUnit(output_dim, dilation=9)

    def forward(self, hidden_state):
        hidden_state = self.snake1(hidden_state)
        hidden_state = self.conv_t1(hidden_state)
        hidden_state = self.res_unit1(hidden_state)
        hidden_state = self.res_unit2(hidden_state)
        hidden_state = self.res_unit3(hidden_state)

        return hidden_state


class DacResidualVectorQuantize(nn.Module):
    """
    ResidualVectorQuantize block - Introduced in SoundStream: An end2end neural audio codec (https://arxiv.org/abs/2107.03312)
    """

    def __init__(self, config: DacConfig):
        super().__init__()

        n_codebooks = config.n_codebooks
        quantizer_dropout = config.quantizer_dropout

        self.n_codebooks = n_codebooks

        self.quantizers = nn.ModuleList([DacVectorQuantize(config) for i in range(config.n_codebooks)])
        self.quantizer_dropout = quantizer_dropout

    def forward(self, hidden_state, n_quantizers: Optional[int] = None):
        """
        Quantizes the input tensor using a fixed set of codebooks and returns corresponding codebook vectors.

        Args:
            hidden_state (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Input tensor to be quantized.
            n_quantizers (`int`, *optional*):
                Number of quantizers to use; defaults to all of them. Only honored in eval mode: in training
                mode the value is replaced by per-sample counts, where the first
                `int(batch_size * quantizer_dropout)` samples get a random count in `[1, n_codebooks]` and the
                remaining samples use every quantizer.

        Returns:
            quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized continuous representation of input.
            audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
                Codebook indices for each codebook (quantized discrete representation of input).
            projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
                Projected latents (continuous representation of input before quantization).
            commitment_loss (`torch.Tensor` of shape `(batch_size,)`):
                Commitment loss to train the encoder to predict vectors closer to codebook entries. Each
                quantizer's scalar loss is multiplied by the per-sample mask before being summed.
            codebook_loss (`torch.Tensor` of shape `(batch_size,)`):
                Codebook loss to update the codebook, accumulated the same way as `commitment_loss`.
        """

        quantized_representation = 0
        residual = hidden_state
        commitment_loss = 0
        codebook_loss = 0

        audio_codes = []
        projected_latents = []

        n_quantizers = n_quantizers if n_quantizers is not None else self.n_codebooks
        if self.training:
            # `n_codebooks + 1` is larger than any quantizer index, i.e. "keep every quantizer".
            n_quantizers = torch.ones((hidden_state.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (hidden_state.shape[0],))
            n_dropout = int(hidden_state.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(hidden_state.device)

        for i, quantizer in enumerate(self.quantizers):
            # In eval mode `n_quantizers` is a plain int, so the loop can stop early.
            if self.training is False and i >= n_quantizers:
                break

            quantized_representation_i, commitment_loss_i, codebook_loss_i, indices_i, projected_latents_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = torch.full((hidden_state.shape[0],), fill_value=i, device=hidden_state.device) < n_quantizers
            quantized_representation = quantized_representation + quantized_representation_i * mask[:, None, None]
            # The residual is updated for every sample, whether or not it is masked out.
            residual = residual - quantized_representation_i

            # Sum losses
            commitment_loss += commitment_loss_i * mask
            codebook_loss += codebook_loss_i * mask

            audio_codes.append(indices_i)
            projected_latents.append(projected_latents_i)

        audio_codes = torch.stack(audio_codes, dim=1)
        projected_latents = torch.cat(projected_latents, dim=1)

        return quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss

    def from_codes(self, audio_codes: torch.Tensor):
        """
        Reconstructs the continuous representation from quantized codes.

        Args:
            audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
                Quantized discrete representation of input.

        Returns:
            quantized_representation (`torch.Tensor`):
                Quantized continuous representation of input.
            projected_latents (`torch.Tensor`):
                Codebook vectors looked up for each codebook, concatenated along the channel dimension.
            audio_codes (`torch.Tensor`):
                Codebook indices for each codebook (the input, returned unchanged).
        """
        quantized_representation = 0.0
        projected_latents = []
        n_codebooks = audio_codes.shape[1]
        for i in range(n_codebooks):
            projected_latents_i = self.quantizers[i].codebook(audio_codes[:, i, :]).transpose(1, 2)
            projected_latents.append(projected_latents_i)
            quantized_representation += self.quantizers[i].out_proj(projected_latents_i)
        return quantized_representation, torch.cat(projected_latents, dim=1), audio_codes

    def from_latents(self, latents: torch.Tensor):
        """Reconstructs the quantized representation from unquantized latents.

        Args:
            latents (`torch.Tensor` of shape `(batch_size, total_latent_dimension, time_steps)`):
                Continuous representation of input after projection.

        Returns:
            quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized representation of the full-projected space.
            quantized_latents (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized representation of the latent space (continuous representation before quantization).
        """
        quantized_representation = 0
        quantized_latents = []
        # Cumulative channel offsets: codebook `i` owns latents[:, dims[i]:dims[i + 1], :].
        codebook_dims_tensor = torch.tensor([0] + [q.codebook_dim for q in self.quantizers])
        dims = torch.cumsum(codebook_dims_tensor, dim=0)

        # Number of whole codebook slices that fit in the provided latents.
        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
        for i in range(n_codebooks):
            hidden_dim_j, hidden_dim_k = dims[i], dims[i + 1]
            # The codebook indices returned alongside the latents are not needed here.
            quantized_latents_i, _ = self.quantizers[i].decode_latents(latents[:, hidden_dim_j:hidden_dim_k, :])
            quantized_latents.append(quantized_latents_i)

            quantized_representation_i = self.quantizers[i].out_proj(quantized_latents_i)
            quantized_representation = quantized_representation + quantized_representation_i

        return quantized_representation, torch.cat(quantized_latents, dim=1)


class DacDecoder(nn.Module):
    """DAC Decoder"""

    def __init__(self, config: DacConfig):
        super().__init__()

        input_channel = config.hidden_size
        channels = config.decoder_hidden_size
        strides = config.upsampling_ratios

        # Add first conv layer
        self.conv1 = nn.Conv1d(input_channel, channels, kernel_size=7, padding=3)

        # Add upsampling + MRF blocks
        self.block = nn.ModuleList(
            [DacDecoderBlock(config, stride, stride_index) for stride_index, stride in enumerate(strides)]
        )

        # Width after the last block. Derived from the number of blocks rather than from the
        # loop variable, so an empty `upsampling_ratios` no longer raises an unbound-name error.
        output_dim = config.decoder_hidden_size // 2 ** len(strides)

        self.snake1 = Snake1d(output_dim)
        self.conv2 = nn.Conv1d(output_dim, 1, kernel_size=7, padding=3)
        self.tanh = nn.Tanh()

    def forward(self, hidden_state):
        hidden_state = self.conv1(hidden_state)

        for layer in self.block:
            hidden_state = layer(hidden_state)

        hidden_state = self.snake1(hidden_state)
        hidden_state = self.conv2(hidden_state)
        hidden_state = self.tanh(hidden_state)

        return hidden_state


class DacEncoder(nn.Module):
    """DAC Encoder"""

    def __init__(self, config: DacConfig):
        super().__init__()

        strides = config.downsampling_ratios
        # Create first convolution
        self.conv1 = nn.Conv1d(1, config.encoder_hidden_size, kernel_size=7, padding=3)

        # Create EncoderBlocks that double channels as they downsample by `stride`
        # (block indices are 1-based so the first block already doubles the width).
        self.block = nn.ModuleList(
            [
                DacEncoderBlock(config, stride=stride, stride_index=stride_index)
                for stride_index, stride in enumerate(strides, start=1)
            ]
        )

        # Width after the last block. Derived from the number of blocks rather than from the
        # loop variable, so an empty `downsampling_ratios` no longer raises an unbound-name error.
        d_model = config.encoder_hidden_size * 2 ** len(strides)
        self.snake1 = Snake1d(d_model)
        self.conv2 = nn.Conv1d(d_model, config.hidden_size, kernel_size=3, padding=1)

    def forward(self, hidden_state):
        hidden_state = self.conv1(hidden_state)

        for module in self.block:
            hidden_state = module(hidden_state)

        hidden_state = self.snake1(hidden_state)
        hidden_state = self.conv2(hidden_state)

        return hidden_state


class DacPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
    """

    config_class = DacConfig
    base_model_prefix = "dac"
    main_input_name = "input_values"

    def _init_weights(self, module):
        if isinstance(module, nn.Conv1d):
            nn.init.trunc_normal_(module.weight, std=0.02)
            # A convolution built with `bias=False` has no bias tensor to reset.
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    @staticmethod
    def _residual_unit_convs(block):
        """Return the six residual-unit convolutions of an encoder or decoder block."""
        return [
            conv
            for unit in (block.res_unit1, block.res_unit2, block.res_unit3)
            for conv in (unit.conv1, unit.conv2)
        ]

    def _weight_norm_layers(self):
        """Return every layer that `apply_weight_norm` / `remove_weight_norm` operate on."""
        layers = []
        for quantizer in self.quantizer.quantizers:
            layers += [quantizer.in_proj, quantizer.out_proj]

        layers += [self.encoder.conv1, self.encoder.conv2]
        for block in self.encoder.block:
            layers.append(block.conv1)
            layers += self._residual_unit_convs(block)

        layers += [self.decoder.conv1, self.decoder.conv2]
        for block in self.decoder.block:
            layers.append(block.conv_t1)
            layers += self._residual_unit_convs(block)

        return layers

    def apply_weight_norm(self):
        """Apply weight normalization to the quantizer projections and all encoder/decoder convolutions."""
        # Prefer the parametrization-based implementation when this torch build provides it.
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        for layer in self._weight_norm_layers():
            weight_norm(layer)

    def remove_weight_norm(self):
        """
        Undo `apply_weight_norm` on the same set of layers.

        Weight norm applied through `nn.utils.parametrizations.weight_norm` is stored as a parametrization
        and cannot be removed by the legacy `nn.utils.remove_weight_norm`, so each layer is handled according
        to how its weight is currently normalized.
        """
        from torch.nn.utils import parametrize

        for layer in self._weight_norm_layers():
            if parametrize.is_parametrized(layer, "weight"):
                # Keep the current effective weight as a plain parameter.
                parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)
            else:
                nn.utils.remove_weight_norm(layer)


DAC_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`DacConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

DAC_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`).
            Audio data to encode,
        n_quantizers (`int`, *optional*):
            Number of quantizers to use. If `None`, all quantizers are used. Default is `None`.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The DAC (Descript Audio Codec) model.",
    DAC_START_DOCSTRING,
)
class DacModel(DacPreTrainedModel):
    def __init__(self, config: DacConfig):
        super().__init__(config)
        self.config = config

        self.encoder = DacEncoder(config)
        self.decoder = DacDecoder(config)

        self.quantizer = DacResidualVectorQuantize(config)

        self.bits_per_codebook = int(math.log2(self.config.codebook_size))
        if 2**self.bits_per_codebook != self.config.codebook_size:
            raise ValueError("The codebook_size must be a power of 2.")

        # Initialize weights and apply final processing
        self.post_init()

    @replace_return_docstrings(output_type=DacEncoderOutput, config_class=_CONFIG_FOR_DOC)
    def encode(
        self,
        input_values: torch.Tensor,
        n_quantizers: Optional[int] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Encode given audio data and return quantized latent codes

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`):
                Input audio data to encode,
            n_quantizers (int, *optional*):
                Number of quantizers to use. If None, all quantizers are used. Default is None.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        Returns:

        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        quantized_representation = self.encoder(input_values)
        quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss = self.quantizer(
            quantized_representation, n_quantizers
        )

        loss = self.config.commitment_loss_weight * commitment_loss + self.config.codebook_loss_weight * codebook_loss

        if not return_dict:
            return (loss, quantized_representation, audio_codes, projected_latents)

        return DacEncoderOutput(loss, quantized_representation, audio_codes, projected_latents)

    @replace_return_docstrings(output_type=DacDecoderOutput, config_class=_CONFIG_FOR_DOC)
    def decode(
        self,
        quantized_representation: Optional[torch.Tensor] = None,
        audio_codes: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """Decode given latent codes and return audio data

        Args:
            quantized_representation (torch.Tensor of shape `(batch_size, dimension, time_steps)`, *optional*):
                Quantized continuous representation of input.
            audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
                The codebook indices for each codebook, representing the quantized discrete
                representation of the input. This parameter should be provided if you want
                to decode directly from the audio codes (it will overwrite quantized_representation).
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:

        """

        if quantized_representation is None and audio_codes is None:
            raise ValueError("Either `quantized_representation` or `audio_codes` must be provided.")

        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Codes take precedence over an explicitly passed quantized representation.
        if audio_codes is not None:
            quantized_representation = self.quantizer.from_codes(audio_codes)[0]

        audio_values = self.decoder(quantized_representation).squeeze(1)

        if not return_dict:
            return (audio_values,)

        return DacDecoderOutput(audio_values)

    @add_start_docstrings_to_model_forward(DAC_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DacOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_values: torch.Tensor,
        n_quantizers: Optional[int] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Returns:

        Examples:

        ```python
        >>> from datasets import load_dataset, Audio
        >>> from transformers import DacModel, AutoProcessor
        >>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

        >>> model = DacModel.from_pretrained("descript/dac_16khz")
        >>> processor = AutoProcessor.from_pretrained("descript/dac_16khz")
        >>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
        >>> audio_sample = librispeech_dummy[-1]["audio"]["array"]
        >>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")

        >>> encoder_outputs = model.encode(inputs["input_values"])
        >>> # Get the intermediate audio codes
        >>> audio_codes = encoder_outputs.audio_codes
        >>> # Reconstruct the audio from its quantized representation
        >>> audio_values = model.decode(encoder_outputs.quantized_representation)
        >>> # or the equivalent with a forward pass
        >>> audio_values = model(inputs["input_values"]).audio_values
        ```"""

        return_dict = return_dict if return_dict is not None else self.config.return_dict
        length = input_values.shape[-1]
        loss, quantized_representation, audio_codes, projected_latents = self.encode(
            input_values, n_quantizers, return_dict=False
        )
        # Trim the reconstruction so it is never longer than the input.
        audio_values = self.decode(quantized_representation, return_dict=False)[0][..., :length]

        if not return_dict:
            return (loss, audio_values, quantized_representation, audio_codes, projected_latents)

        return DacOutput(loss, audio_values, quantized_representation, audio_codes, projected_latents)


__all__ = ["DacModel", "DacPreTrainedModel"]
Memory