from typing import List, Optional, Tuple
import numpy as np
from ctranslate2.specs import common_spec, model_spec, transformer_spec
class WhisperConfig(model_spec.ModelConfig):
"""Configuration for the Whisper model."""
def __init__(
self,
suppress_ids: Optional[List[int]] = None,
suppress_ids_begin: Optional[List[int]] = None,
lang_ids: Optional[List[int]] = None,
alignment_heads: Optional[List[Tuple[int, int]]] = None,
):
super().__init__(
suppress_ids=suppress_ids,
suppress_ids_begin=suppress_ids_begin,
lang_ids=lang_ids,
alignment_heads=alignment_heads,
)
class WhisperSpec(model_spec.LanguageModelSpec):
"""Describes a Whisper model."""
def __init__(
self,
num_encoder_layers,
num_encoder_heads,
num_decoder_layers,
num_decoder_heads,
):
"""Initializes the model specification.
Args:
num_encoder_layers: The number of encoder layers.
num_encoder_heads: The number of encoder attention heads.
num_decoder_layers: The number of decoder layers.
num_decoder_heads: The number of decoder attention heads.
"""
super().__init__()
self.encoder = WhisperEncoderSpec(num_encoder_layers, num_encoder_heads)
self.decoder = transformer_spec.TransformerDecoderSpec(
num_decoder_layers,
num_decoder_heads,
activation=common_spec.Activation.GELU,
)
self.decoder.scale_embeddings = False
@property
def name(self):
return "WhisperSpec"
@property
def revision(self):
return 3
def get_default_config(self):
return WhisperConfig()
def get_vocabulary_size(self):
return self.decoder.embeddings.weight.shape[0]
class WhisperEncoderSpec(model_spec.LayerSpec):
def __init__(self, num_layers, num_heads):
self.num_heads = np.dtype("int16").type(num_heads)
self.conv1 = common_spec.Conv1DSpec()
self.conv2 = common_spec.Conv1DSpec()
self.position_encodings = transformer_spec.PositionEncoderSpec()
self.layer_norm = common_spec.LayerNormSpec()
self.layer = [
transformer_spec.TransformerEncoderLayerSpec() for _ in range(num_layers)
]