# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...configuration_utils import PretrainedConfig
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
replace_return_docstrings,
)
from ..glm.modeling_glm import GlmAttention, GlmRotaryEmbedding, apply_rotary_pos_emb
from ..llama.modeling_llama import LlamaDecoderLayer, LlamaModel, eager_attention_forward
from ..whisper.modeling_whisper import WhisperModel, shift_tokens_right
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "UsefulSensors/moonshine-tiny"
_CONFIG_FOR_DOC = "MoonshineConfig"
class MoonshineConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MoonshineModel`]. It is used to instantiate a Moonshine
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Moonshine
[UsefulSensors/moonshine-tiny](https://huggingface.co/UsefulSensors/moonshine-tiny).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32768):
Vocabulary size of the Moonshine model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`MoonshineModel`].
hidden_size (`int`, *optional*, defaults to 288):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 1152):
Dimension of the MLP representations.
encoder_num_hidden_layers (`int`, *optional*, defaults to 6):
Number of hidden layers in the Transformer encoder.
decoder_num_hidden_layers (`int`, *optional*, defaults to 6):
Number of hidden layers in the Transformer decoder.
encoder_num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer decoder.
encoder_num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`encoder_num_key_value_heads=encoder_num_attention_heads`, the model will use Multi Head Attention (MHA); if
`encoder_num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, it will default to
`encoder_num_attention_heads`.
decoder_num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`decoder_num_key_value_heads=decoder_num_attention_heads`, the model will use Multi Head Attention (MHA); if
`decoder_num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, it will default to
`decoder_num_attention_heads`.
pad_head_dim_to_multiple_of (`int`, *optional*):
Pad head dimension in encoder and decoder to the next multiple of this value. Necessary for using certain
optimized attention implementations.
encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder.
decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 512):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
decoder_start_token_id (`int`, *optional*, defaults to 1):
Corresponds to the "<|startoftranscript|>" token, which is automatically used when no `decoder_input_ids`
are provided to the `generate` function. It is used to guide the model's generation process depending on
the task.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
and expect the model to work on longer `max_position_embeddings`, we recommend updating this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
partial_rotary_factor (`float`, *optional*, defaults to 0.9):
Fraction of the query and key dimensions that rotary position embeddings are applied to.
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
Whether the model is used as an encoder/decoder or not.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the beginning-of-sequence token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the end-of-sequence token.
Example:
```python
>>> from transformers import MoonshineModel, MoonshineConfig
>>> # Initializing a Moonshine style configuration
>>> configuration = MoonshineConfig()
>>> # Initializing a model from the configuration
>>> model = MoonshineModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
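>>> # A hedged sketch with illustrative (non-default) values: giving the decoder fewer key/value
>>> # heads than attention heads enables grouped-query attention, and `rope_scaling` configures
>>> # RoPE scaling as described above
>>> custom_configuration = MoonshineConfig(
...     decoder_num_attention_heads=8,
...     decoder_num_key_value_heads=4,
...     rope_scaling={"rope_type": "dynamic", "factor": 2.0},
... )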
```"""
model_type = "moonshine"
keys_to_ignore_at_inference = ["past_key_values"]
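# Generic attribute names resolve to the encoder-specific ones, so shared utilities can read
# `num_attention_heads`, `num_key_value_heads` and `num_hidden_layers` transparently.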
attribute_map = {
"num_key_value_heads": "encoder_num_key_value_heads",
"num_attention_heads": "encoder_num_attention_heads",
"num_hidden_layers": "encoder_num_hidden_layers",
}
def __init__(
self,
vocab_size=32768,
hidden_size=288,
intermediate_size=1152,
encoder_num_hidden_layers=6,
decoder_num_hidden_layers=6,
encoder_num_attention_heads=8,
decoder_num_attention_heads=8,
encoder_num_key_value_heads=None,
decoder_num_key_value_heads=None,
pad_head_dim_to_multiple_of=None,
encoder_hidden_act="gelu",
decoder_hidden_act="silu",
max_position_embeddings=512,
initializer_range=0.02,
decoder_start_token_id=1,
use_cache=True,
rope_theta=10000.0,
rope_scaling=None,
partial_rotary_factor=0.9,
is_encoder_decoder=True,
attention_bias=False,
attention_dropout=0.0,
bos_token_id=1,
eos_token_id=2,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.encoder_num_hidden_layers = encoder_num_hidden_layers
self.decoder_num_hidden_layers = decoder_num_hidden_layers
self.encoder_num_attention_heads = encoder_num_attention_heads
self.decoder_num_attention_heads = decoder_num_attention_heads
if encoder_num_key_value_heads is None:
encoder_num_key_value_heads = encoder_num_attention_heads
self.encoder_num_key_value_heads = encoder_num_key_value_heads
if decoder_num_key_value_heads is None:
decoder_num_key_value_heads = decoder_num_attention_heads
self.decoder_num_key_value_heads = decoder_num_key_value_heads
self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
self.encoder_hidden_act = encoder_hidden_act
self.decoder_hidden_act = decoder_hidden_act
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.decoder_start_token_id = decoder_start_token_id
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.partial_rotary_factor = partial_rotary_factor
self.is_encoder_decoder = is_encoder_decoder
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
rope_config_validation(self)
super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
**kwargs,
)
class MoonshineEncoderMLP(nn.Module):
def __init__(self, config, hidden_act):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class MoonshineDecoderMLP(nn.Module):
def __init__(self, config, hidden_act):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size * 2)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
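# fc1 projects to 2 * intermediate_size; split into a value half and a gate half (SwiGLU-style gating)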
hidden_states, gate = hidden_states.chunk(2, dim=-1)
hidden_states = self.activation_fn(gate) * hidden_states
hidden_states = self.fc2(hidden_states)
return hidden_states
class MoonshineAttention(GlmAttention):
def __init__(
self,
config: MoonshineConfig,
layer_idx: int,
is_causal: bool,
num_attention_heads: int,
num_key_value_heads: int,
):
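# Override the generic head-count attributes on the shared config so the parent attention
# __init__ uses this module's encoder- or decoder-specific values.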
config.update({"num_attention_heads": num_attention_heads, "num_key_value_heads": num_key_value_heads})
super().__init__(config, layer_idx)
self.is_causal = is_causal
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
# Pad head dimension to the next specified multiple.
if self.config.pad_head_dim_to_multiple_of is not None:
target_multiple = self.config.pad_head_dim_to_multiple_of
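# Integer ceil division: round head_dim up to the nearest multiple of target_multiple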
target_head_dim = target_multiple * ((self.head_dim + target_multiple - 1) // target_multiple)
self.head_dim_padding = target_head_dim - self.head_dim
else:
self.head_dim_padding = 0
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
key_value_states: Optional[torch.Tensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
bsz, q_len = hidden_states.shape[:-1]
query_states = (
self.q_proj(hidden_states).view(bsz, q_len, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
)
is_cross_attention = key_value_states is not None
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_value.is_updated[self.layer_idx] = True
past_key_value = past_key_value.cross_attention_cache
else:
past_key_value = past_key_value.self_attention_cache
# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
else:
key_states = (
self.k_proj(current_states)
.view(bsz, -1, self.config.num_key_value_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(current_states)
.view(bsz, -1, self.config.num_key_value_heads, self.head_dim)
.transpose(1, 2)
)
if is_cross_attention and past_key_value is not None:
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
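# Rotary embeddings are only applied in self-attention; cross-attention keys come from the
# encoder output and are not rotated.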
if not is_cross_attention:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
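# Certain optimized attention implementations require head_dim to be a multiple of a fixed value
# (see `pad_head_dim_to_multiple_of`): zero-pad q/k/v here and crop the output back below.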
if self.head_dim_padding > 0:
query_states = torch.nn.functional.pad(query_states, (0, self.head_dim_padding))
key_states = torch.nn.functional.pad(key_states, (0, self.head_dim_padding))
value_states = torch.nn.functional.pad(value_states, (0, self.head_dim_padding))
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
is_causal=is_causal,
**kwargs,
)
if self.head_dim_padding > 0:
attn_output = attn_output[..., : -self.head_dim_padding]
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class MoonshineRotaryEmbedding(GlmRotaryEmbedding):
pass
class MoonshineEncoderLayer(LlamaDecoderLayer):
def __init__(self, config: MoonshineConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.self_attn = MoonshineAttention(
config=config,
layer_idx=layer_idx,
is_causal=False,
num_attention_heads=config.encoder_num_attention_heads,
num_key_value_heads=config.encoder_num_key_value_heads,
)
self.mlp = MoonshineEncoderMLP(config, config.encoder_hidden_act)
self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
class MoonshineDecoderLayer(nn.Module):
def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = MoonshineAttention(
config=config,
layer_idx=layer_idx,
is_causal=True,
num_attention_heads=config.decoder_num_attention_heads,
num_key_value_heads=config.decoder_num_key_value_heads,
)
self.encoder_attn = MoonshineAttention(
config=config,
layer_idx=layer_idx,
is_causal=False,
num_attention_heads=config.decoder_num_attention_heads,
num_key_value_heads=config.decoder_num_key_value_heads,
)
self.mlp = MoonshineDecoderMLP(config, config.decoder_hidden_act)
self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
self.final_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
encoder_position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
encoder_position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Cross-Attention Block
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, cross_attn_weights = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.final_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
MOONSHINE_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.
Parameters:
config ([`MoonshineConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Moonshine Model outputting raw hidden-states without any specific head on top.",
MOONSHINE_START_DOCSTRING,
)
class MoonshinePreTrainedModel(PreTrainedModel):
config_class = MoonshineConfig
base_model_prefix = "model"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_no_split_modules = ["MoonshineEncoderLayer", "MoonshineDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv1d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
"""
Computes the output length of the convolutional layers
"""
output_conv1_length = int((input_lengths - 127) / 64 + 1)
output_conv2_length = int((output_conv1_length - 7) / 3 + 1)
output_conv3_length = int((output_conv2_length - 3) / 2 + 1)
return output_conv3_length
class MoonshineEncoder(MoonshinePreTrainedModel):
"""
Transformer encoder consisting of *config.encoder_num_hidden_layers* layers. Each layer is a [`MoonshineEncoderLayer`].
Args:
config: MoonshineConfig
"""
main_input_name = "input_values"
def __init__(self, config: MoonshineConfig):
super().__init__(config)
self.config = config
embed_dim = config.hidden_size
self.conv1 = nn.Conv1d(1, embed_dim, kernel_size=127, stride=64, bias=False)
self.conv2 = nn.Conv1d(embed_dim, 2 * embed_dim, kernel_size=7, stride=3)
self.conv3 = nn.Conv1d(2 * embed_dim, embed_dim, kernel_size=3, stride=2)
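# The conv stem downsamples by a combined stride of 64 * 3 * 2 = 384 input samples per output frame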
self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=1e-5)
self.rotary_emb = MoonshineRotaryEmbedding(config=config)
self.layers = nn.ModuleList(
[MoonshineEncoderLayer(config, idx) for idx in range(config.encoder_num_hidden_layers)]
)
self.layer_norm = nn.LayerNorm(embed_dim, bias=False)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.conv1
def set_input_embeddings(self, value: nn.Module):
self.conv1 = value
@can_return_tuple
def forward(
self,
input_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
Float values of the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_values is None:
raise ValueError("You must specify input_values.")
# conv downsampling
input_values = input_values.unsqueeze(1)
hidden_states = nn.functional.tanh(self.conv1(input_values))
hidden_states = self.groupnorm(hidden_states)
hidden_states = nn.functional.gelu(self.conv2(hidden_states))
hidden_states = nn.functional.gelu(self.conv3(hidden_states))
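# (batch_size, hidden_size, num_frames) -> (batch_size, num_frames, hidden_size)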
hidden_states = hidden_states.permute(0, 2, 1)
# attention mask downsampling
if attention_mask is not None:
mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1])
downsample_stride = 64 * 3 * 2 # conv strides
attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len]
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask == 0.0).any() else None
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
elif self.config._attn_implementation == "sdpa" and not output_attentions:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, hidden_states.dtype)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
# create position embeddings to be shared across the encoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# encoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for encoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
None,
output_attentions,
False,
None,
position_embeddings,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.layer_norm(hidden_states)
# add hidden states from the last encoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class MoonshineDecoder(LlamaModel):
main_input_name = "input_ids"
def __init__(self, config: MoonshineConfig):
super().__init__(config)
self.norm = nn.LayerNorm(config.hidden_size, bias=False)
self.layers = nn.ModuleList(
[MoonshineDecoderLayer(config, idx) for idx in range(config.decoder_num_hidden_layers)]
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
"""
Args:
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
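# Lazily create an EncoderDecoderCache that keeps the self-attention and cross-attention caches separate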
self_attention_cache = DynamicCache()
cross_attention_cache = DynamicCache()
past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
# attention mask downsampling
if encoder_attention_mask is not None:
mask_len = encoder_hidden_states.shape[-2]
downsample_stride = 64 * 3 * 2 # conv strides
encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len]
if self.config._attn_implementation == "flash_attention_2":
encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
elif self.config._attn_implementation == "sdpa" and not output_attentions:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
)
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
encoder_hidden_states,
encoder_attention_mask,
position_ids,
None,  # encoder_position_ids
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
encoder_attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
MOONSHINE_MODEL_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
Float values of the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
decoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids`
of shape `(batch_size, sequence_length)`.
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `decoder_position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
"The bare Moonshine Model outputting raw hidden-states without any specific head on top.",
MOONSHINE_START_DOCSTRING,
)
class MoonshineModel(WhisperModel):
@can_return_tuple
@add_start_docstrings_to_model_forward(MOONSHINE_MODEL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Seq2SeqModelOutput:
r"""
```python
>>> import torch
>>> from transformers import AutoFeatureExtractor, MoonshineModel
>>> from datasets import load_dataset
>>> model = MoonshineModel.from_pretrained("UsefulSensors/moonshine-tiny")
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("UsefulSensors/moonshine-tiny")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_values = inputs.input_values
>>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
>>> last_hidden_state = model(input_values, decoder_input_ids=decoder_input_ids).last_hidden_state
>>> list(last_hidden_state.shape)
[1, 2, 288]
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if encoder_outputs is None:
encoder_outputs: BaseModelOutput = self.encoder(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# If the user passed a tuple for encoder_outputs, wrap it in a BaseModelOutput
elif not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs: BaseModelOutputWithPastAndCrossAttentions = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_attention_mask=attention_mask,
encoder_hidden_states=encoder_outputs.last_hidden_state,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
position_ids=decoder_position_ids,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
)
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
@add_start_docstrings(
"The Moonshine Model with a language modeling head. Can be used for automatic speech recognition.",
MOONSHINE_START_DOCSTRING,
)
class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["proj_out.weight"]
def __init__(self, config: MoonshineConfig):
super().__init__(config)
self.model = MoonshineModel(config)
self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_encoder(self):
return self.model.get_encoder()
def get_decoder(self):
return self.model.get_decoder()
def get_output_embeddings(self):
return self.proj_out
def set_output_embeddings(self, new_embeddings):
self.proj_out = new_embeddings
def get_input_embeddings(self) -> nn.Module:
return self.model.get_input_embeddings()
@can_return_tuple
@add_start_docstrings_to_model_forward(MOONSHINE_MODEL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
) -> Seq2SeqLMOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> import torch
>>> from transformers import AutoProcessor, MoonshineForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny")
>>> model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_values = inputs.input_values
>>> generated_ids = model.generate(input_values, max_new_tokens=100)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> transcription
'Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
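>>> # A hedged sketch of the training path (illustrative targets only): with teacher-forced
>>> # `decoder_input_ids` and `labels`, the model also returns a language-modeling loss
>>> loss = model(input_values, decoder_input_ids=generated_ids[:, :-1], labels=generated_ids[:, 1:]).loss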
```"""
if labels is not None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
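# Teacher forcing: derive decoder inputs by shifting the labels right and prepending decoder_start_token_id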
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
outputs: Seq2SeqModelOutput = self.model(
input_values,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
decoder_inputs_embeds=decoder_inputs_embeds,
decoder_position_ids=decoder_position_ids,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
)
logits = self.proj_out(outputs.last_hidden_state)
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
return Seq2SeqLMOutput(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
__all__ = [
"MoonshineConfig",
"MoonshineModel",
"MoonshinePreTrainedModel",
"MoonshineForConditionalGeneration",
]