import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_outputs import (
    BaseModelOutput,
    CausalLMOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
    Wav2Vec2BaseModelOutput,
    XVectorOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from ..wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2FeatureProjection,
    Wav2Vec2FeedForward,
    Wav2Vec2ForAudioFrameClassification,
    Wav2Vec2ForCTC,
    Wav2Vec2ForSequenceClassification,
    Wav2Vec2ForXVector,
    Wav2Vec2Model,
    Wav2Vec2PositionalConvEmbedding,
    Wav2Vec2PreTrainedModel,
)
from .configuration_wavlm import WavLMConfig


logger = logging.get_logger(__name__)

# Values interpolated into the generated code-sample docstrings below.
_CONFIG_FOR_DOC = "WavLMConfig"

_CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]

_CTC_EXPECTED_OUTPUT = "'mister quilter is the aposle of the middle classes and we are glad to welcome his gospel'"
_CTC_EXPECTED_LOSS = 12.51

_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd"
_FRAME_EXPECTED_OUTPUT = [0, 0]

_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv"
_XVECTOR_EXPECTED_OUTPUT = 0.97


class WavLMPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding):
    """WavLM-named alias of `Wav2Vec2PositionalConvEmbedding`; adds no behavior of its own."""

    pass


class WavLMFeatureProjection(Wav2Vec2FeatureProjection):
    """WavLM-named alias of `Wav2Vec2FeatureProjection`; adds no behavior of its own."""

    pass


class WavLMAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, with a gated relative position bias.

    Args:
        embed_dim: Size of the last dimension of the input; must be divisible by `num_heads`.
        num_heads: Number of attention heads.
        dropout: Dropout probability forwarded to `F.multi_head_attention_forward`.
        num_buckets: Number of relative-position buckets (rows of `rel_attn_embed`).
        max_distance: Distance used to scale the logarithmic part of the bucketing.
        has_relative_position_bias: When `True`, the module owns the `rel_attn_embed` table and can compute the
            position bias itself; otherwise `forward` must be given a `position_bias`.

    Raises:
        ValueError: If `embed_dim` is not divisible by `num_heads`.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        num_buckets: int = 320,
        max_distance: int = 800,
        has_relative_position_bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # NOTE: stored but not read anywhere in this class; the torch attention kernel applies its own scaling.
        self.scaling = self.head_dim**-0.5

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.num_buckets = num_buckets
        self.max_distance = max_distance

        # Parameters of the gate that rescales the position bias per head and per query position.
        self.gru_rel_pos_const = nn.Parameter(torch.ones(1, self.num_heads, 1, 1))
        self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)

        if has_relative_position_bias:
            self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_bias: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        index=0,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Attention layer with relative attention

        Args:
            hidden_states: 3-D tensor `(batch, seq_len, embed_dim)`.
            attention_mask: Optional mask; positions whose value is not 1 are treated as padding.
            position_bias: Un-gated bias of shape `(batch * num_heads, seq_len, seq_len)`. When `None` it is
                computed here, which requires the module to own `rel_attn_embed`.
            output_attentions: Whether attention weights are requested from the torch kernel.
            index: Unused in this method; kept so callers may keep passing it.

        Returns:
            `(attn_output, attn_weights, position_bias)`. `position_bias` is the *un-gated* bias so that it can
            be handed to the next layer, which applies its own gate.
        """
        bsz, tgt_len, _ = hidden_states.size()

        # first pass of attention layer creates position bias
        if position_bias is None:
            position_bias = self.compute_bias(tgt_len, tgt_len)
            position_bias = (
                position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, tgt_len)
            )

        # Compute relative position bias:
        # 1) get reshape hidden_states
        gated_hidden_states = hidden_states.view(hidden_states.shape[:-1] + (self.num_heads, -1))
        gated_hidden_states = gated_hidden_states.permute(0, 2, 1, 3)

        # 2) project hidden states
        relative_position_proj = self.gru_rel_pos_linear(gated_hidden_states)
        # the 8 projected features are folded into 2 values (one per gate) by summing groups of 4
        relative_position_proj = relative_position_proj.view(gated_hidden_states.shape[:-1] + (2, 4)).sum(-1)

        # 3) compute gate for position bias from projected hidden states
        gate_a, gate_b = torch.sigmoid(relative_position_proj).chunk(2, dim=-1)
        gate_output = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0

        # 4) apply gate to position bias to compute gated position_bias
        gated_position_bias = gate_output.view(bsz * self.num_heads, -1, 1) * position_bias
        gated_position_bias = gated_position_bias.view((-1, tgt_len, tgt_len))

        attn_output, attn_weights = self.torch_multi_head_self_attention(
            hidden_states, attention_mask, gated_position_bias, output_attentions
        )

        return attn_output, attn_weights, position_bias

    def torch_multi_head_self_attention(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Union[torch.LongTensor, torch.BoolTensor],
        gated_position_bias: torch.FloatTensor,
        output_attentions: bool,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """simple wrapper around torch's multi_head_attention_forward function

        Args:
            hidden_states: Batch-first input; it is transposed to sequence-first for the torch kernel.
            attention_mask: Optional mask; entries different from 1 become the key padding mask.
            gated_position_bias: Passed to the kernel as its `attn_mask` argument.
            output_attentions: Passed to the kernel as its `need_weights` argument.

        Returns:
            `(attn_output, attn_weights)` with `attn_output` batch-first again. `attn_weights` is whatever the
            kernel returned, broadcast over a new head dimension when it is not `None`.
        """
        # self-attention assumes q = k = v
        query = key = value = hidden_states.transpose(0, 1)
        key_padding_mask = attention_mask.ne(1) if attention_mask is not None else None

        # disable bias and add_zero_attn
        bias_k = bias_v = None
        add_zero_attn = False

        # PyTorch 1.3.0 has F.multi_head_attention_forward defined
        # so no problem with backwards compatibility
        attn_output, attn_weights = F.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            # placeholder for the fused in-projection weight; separate q/k/v weights are supplied below
            torch.empty([0]),
            torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
            bias_k,
            bias_v,
            add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            self.training,
            key_padding_mask,
            output_attentions,
            gated_position_bias,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )

        # [Seq_Len, Batch Size, ...] -> [Batch Size, Seq_Len, ...]
        attn_output = attn_output.transpose(0, 1)

        if attn_weights is not None:
            # IMPORTANT: Attention weights are averaged weights
            # here which should not be the case. This is an open issue
            # on PyTorch: https://github.com/pytorch/pytorch/issues/32590
            attn_weights = attn_weights[:, None].broadcast_to(
                attn_weights.shape[:1] + (self.num_heads,) + attn_weights.shape[1:]
            )

        return attn_output, attn_weights

    def compute_bias(self, query_length: int, key_length: int) -> torch.FloatTensor:
        """Look up the per-head bias for every (query, key) offset.

        Returns:
            Tensor of shape `(num_heads, query_length, key_length)` taken from `rel_attn_embed`.
        """
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_positions_bucket(relative_position)
        relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
        values = self.rel_attn_embed(relative_position_bucket)
        # (query, key, head) -> (head, query, key)
        values = values.permute([2, 0, 1])
        return values

    def _relative_positions_bucket(self, relative_positions: torch.LongTensor) -> torch.LongTensor:
        """Map signed offsets to bucket indices in `[0, self.num_buckets)`.

        Half of the buckets serve positive offsets and half serve the rest. Within each half, small absolute
        offsets get one bucket each and larger ones share logarithmically spaced buckets, clamped to the last
        bucket of that half.
        """
        num_buckets = self.num_buckets // 2

        # positive offsets are shifted into the upper half of the bucket range
        relative_buckets = (relative_positions > 0).to(torch.long) * num_buckets
        relative_positions = torch.abs(relative_positions)

        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        # log(0) yields -inf for offset 0, but that entry is discarded by the `torch.where` below
        relative_positions_if_large = torch.log(relative_positions.float() / max_exact)
        relative_positions_if_large = relative_positions_if_large / math.log(self.max_distance / max_exact)
        relative_positions_if_large = relative_positions_if_large * (num_buckets - max_exact)
        relative_position_if_large = (max_exact + relative_positions_if_large).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
        return relative_buckets


class WavLMFeedForward(Wav2Vec2FeedForward):
    """WavLM-named alias of `Wav2Vec2FeedForward`; adds no behavior of its own."""

    pass


class WavLMEncoderLayer(nn.Module):
    """Encoder layer that applies layer norm *after* each residual connection (attention, then feed-forward)."""

    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
        """Run attention and feed-forward sub-blocks.

        Returns:
            `(hidden_states, position_bias)`, plus `attn_weights` as a third element when `output_attentions`
            is truthy.
        """
        attn_residual = hidden_states
        hidden_states, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            index=index,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states

        hidden_states = self.layer_norm(hidden_states)

        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states, position_bias)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class WavLMEncoderLayerStableLayerNorm(nn.Module):
    """Encoder layer that applies layer norm *before* the attention and feed-forward sub-blocks."""

    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
        """Run the pre-norm attention and feed-forward sub-blocks.

        Returns:
            `(hidden_states, position_bias)`, plus `attn_weights` as a third element when `output_attentions`
            is truthy.
        """
        attn_residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))

        outputs = (hidden_states, position_bias)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class WavLMEncoder(nn.Module):
    """Stack of `WavLMEncoderLayer`s; only the first layer owns the relative position embedding table."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        # layer 0 computes the position bias; later layers receive it through `position_bias`
        self.layers = nn.ModuleList(
            [WavLMEncoderLayer(config, has_relative_position_bias=(i == 0)) for i in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Apply positional embedding, layer norm, dropout and every encoder layer (subject to LayerDrop).

        Returns:
            A `BaseModelOutput` when `return_dict` is truthy, otherwise a tuple of the non-`None` values among
            `(hidden_states, all_hidden_states, all_self_attentions)`.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens output 0
            # NOTE: this writes into the caller's tensor in place; assumes a boolean mask - TODO confirm
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expand_attention_mask] = 0

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
        position_bias = None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            # the first layer is never skipped: it is the one that produces `position_bias`
            skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    # `_gradient_checkpointing_func` is not defined in this class; presumably installed by the
                    # owning model when gradient checkpointing is enabled - verify against caller
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        position_bias,
                        output_attentions,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        position_bias=position_bias,
                        output_attentions=output_attentions,
                        index=i,
                    )

                hidden_states, position_bias = layer_outputs[:2]

            if skip_the_layer:
                layer_outputs = (None, None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class WavLMEncoderStableLayerNorm(nn.Module):
    """Stack of `WavLMEncoderLayerStableLayerNorm`s with a single layer norm applied after the last layer."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        # layer 0 computes the position bias; later layers receive it through `position_bias`
        self.layers = nn.ModuleList(
            [
                WavLMEncoderLayerStableLayerNorm(config, has_relative_position_bias=(i == 0))
                for i in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Apply positional embedding, dropout, every encoder layer (subject to LayerDrop) and a final layer norm.

        Returns:
            A `BaseModelOutput` when `return_dict` is truthy, otherwise a tuple of the non-`None` values among
            `(hidden_states, all_hidden_states, all_self_attentions)`.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens are not attended to
            # NOTE: this writes into the caller's tensor in place; assumes a boolean mask - TODO confirm
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expand_attention_mask] = 0

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.dropout(hidden_states)

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
        position_bias = None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            # the first layer is never skipped: it is the one that produces `position_bias`
            skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
                if self.gradient_checkpointing and self.training:
                    # `_gradient_checkpointing_func` is not defined in this class; presumably installed by the
                    # owning model when gradient checkpointing is enabled - verify against caller
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        position_bias,
                        output_attentions,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        output_attentions=output_attentions,
                        position_bias=position_bias,
                    )
                hidden_states, position_bias = layer_outputs[:2]

            if skip_the_layer:
                layer_outputs = (None, None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
        )


class WavLMGumbelVectorQuantizer(nn.Module):
    """
    Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
    GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
    """

    def __init__(self, config):
        super().__init__()
        self.num_groups = config.num_codevector_groups
        self.num_vars = config.num_codevectors_per_group

        if config.codevector_dim % self.num_groups != 0:
            raise ValueError(
                f"`config.codevector_dim {config.codevector_dim} must be divisible"
                f" by `config.num_codevector_groups` {self.num_groups} "
                "for concatenation."
            )

        # storage for codebook variables (codewords)
        # (allocated uninitialized here; `WavLMPreTrainedModel._init_weights` fills it)
        self.codevectors = nn.Parameter(
            torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
        )
        self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)

        # can be decayed for training
        self.temperature = 2

    @staticmethod
    def _compute_perplexity(probs):
        """Return the summed exponential of the entropy of `probs` averaged over its first dimension."""
        marginal_probs = probs.mean(dim=0)
        # the 1e-7 offset keeps log() finite for zero probabilities
        perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
        return perplexity

    def forward(self, hidden_states):
        """Quantize `hidden_states`.

        Args:
            hidden_states: 3-D tensor `(batch_size, sequence_length, hidden_size)`.

        Returns:
            `(codevectors, perplexity)` where `codevectors` has shape `(batch_size, sequence_length, -1)` and
            `perplexity` is a scalar tensor.
        """
        batch_size, sequence_length, hidden_size = hidden_states.shape

        # project to codevector dim
        hidden_states = self.weight_proj(hidden_states)
        hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)

        if self.training:
            # sample code vector probs via gumbel in differentiable way
            codevector_probs = nn.functional.gumbel_softmax(hidden_states.float(), tau=self.temperature, hard=True)
            codevector_probs = codevector_probs.type_as(hidden_states)

            # compute perplexity
            codevector_soft_dist = torch.softmax(
                hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
            )
            perplexity = self._compute_perplexity(codevector_soft_dist)
        else:
            # take argmax in non-differentiable way
            # compute hard codevector distribution (one hot)
            codevector_idx = hidden_states.argmax(dim=-1)
            codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
                -1, codevector_idx.view(-1, 1), 1.0
            )
            codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)

            perplexity = self._compute_perplexity(codevector_probs)

        codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
        # use probs to retrieve codevectors
        codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
        codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
        codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)

        return codevectors, perplexity


# NOTE(review): `PreTrainedModel` is listed before `Wav2Vec2PreTrainedModel`. If the latter subclasses the former,
# plain Python cannot linearize these bases; presumably this definition is consumed by a code generator rather than
# imported directly - confirm before changing the base order.
class WavLMPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = WavLMConfig
    base_model_prefix = "wavlm"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = False
    _supports_sdpa = False

    def _init_weights(self, module):
        """Initialize the weights"""
        # gumbel softmax requires special init
        if isinstance(module, WavLMGumbelVectorQuantizer):
            module.weight_proj.weight.data.normal_(mean=0.0, std=1)
            module.weight_proj.bias.data.zero_()
            nn.init.uniform_(module.codevectors)
        elif isinstance(module, WavLMPositionalConvEmbedding):
            nn.init.normal_(
                module.conv.weight,
                mean=0,
                std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
            )
            nn.init.constant_(module.conv.bias, 0)
        elif isinstance(module, WavLMFeatureProjection):
            k = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-k, b=k)
            nn.init.uniform_(module.projection.bias, a=-k, b=k)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)

            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)

    # The three adapter hooks below always raise: adapters are declared as not applicable to WavLM.
    def _get_adapters(self):
        raise AttributeError("Not needed for WavLM")

    def init_adapter_layers(self):
        raise AttributeError("Not needed for WavLM")

    def load_adapter(self):
        raise AttributeError("Not needed for WavLM")


WAVLM_START_DOCSTRING = r"""
    WavLM was proposed in [WavLM: Unified Speech Representation Learning with Labeled and Unlabeled
    Data](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo
    Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian,
    Jian Wu, Michael Zeng, Xiangzhan Yu, Furu Wei.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving etc.).

    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`WavLMConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


WAVLM_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
            soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
            conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
            1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            <Tip warning={true}>

            `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
            True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should
            **not** be passed to avoid degraded performance when doing batched inference. For such models
            `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these
            models also yield slightly different results depending on whether `input_values` is padded or not.

            </Tip>

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

WavLMBaseModelOutput = Wav2Vec2BaseModelOutput


@add_start_docstrings(
    "The bare WavLM Model transformer outputting raw hidden-states without any specific head on top.",
    WAVLM_START_DOCSTRING,
)
class WavLMModel(Wav2Vec2Model):
    @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=WavLMBaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(self, **super_kwargs):
        return super().forward(**super_kwargs)


@add_start_docstrings(
    """WavLM Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
    WAVLM_START_DOCSTRING,
)
class WavLMForCTC(Wav2Vec2ForCTC):
    @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_CTC_EXPECTED_OUTPUT,
        expected_loss=_CTC_EXPECTED_LOSS,
    )
    def forward(self, **super_kwargs):
        # propagate the parent's result instead of silently returning None
        return super().forward(**super_kwargs)


@add_start_docstrings(
    """
    WavLM Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
    SUPERB Keyword Spotting.
    """,
    WAVLM_START_DOCSTRING,
)
class WavLMForSequenceClassification(Wav2Vec2ForSequenceClassification):
    @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(self, **super_kwargs):
        # propagate the parent's result instead of silently returning None
        return super().forward(**super_kwargs)


@add_start_docstrings(
    """
    WavLM Model with a frame classification head on top for tasks like Speaker Diarization.
    """,
    WAVLM_START_DOCSTRING,
)
class WavLMForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
    @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_FRAME_CLASS_CHECKPOINT,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_FRAME_EXPECTED_OUTPUT,
    )
    def forward(self, **super_kwargs):
        # propagate the parent's result instead of silently returning None
        return super().forward(**super_kwargs)


@add_start_docstrings(
    """
    WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    """,
    WAVLM_START_DOCSTRING,
)
class WavLMForXVector(Wav2Vec2ForXVector):
    @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_XVECTOR_CHECKPOINT,
        output_type=XVectorOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_XVECTOR_EXPECTED_OUTPUT,
    )
    def forward(self, **super_kwargs):
        # propagate the parent's result instead of silently returning None
        return super().forward(**super_kwargs)


__all__ = [
    "WavLMForAudioFrameClassification",
    "WavLMForCTC",
    "WavLMForSequenceClassification",
    "WavLMForXVector",
    "WavLMModel",
    "WavLMPreTrainedModel",
]
# Memory