import numpy as np
def fuse_linear(spec, layers):
    """Concatenates the weights (and biases, if any) of multiple linear layers
    into a single fused linear layer stored on ``spec``.
    """
    if not layers:
        raise ValueError("Cannot fuse linear layers: at least one layer is required")

    # Pick the concatenation and zero-initialization functions matching the
    # weight type (NumPy arrays or PyTorch tensors).
    if isinstance(layers[0].weight, np.ndarray):
        concatenate = np.concatenate
        zeros = np.zeros
    else:
        import torch

        concatenate = torch.cat
        zeros = torch.zeros

    spec.weight = concatenate([layer.weight for layer in layers])

    # Use the dtype of the first bias found, so that layers without a bias can
    # be padded with zeros of a consistent dtype.
    bias_dtype = None
    for layer in layers:
        if layer.has_bias():
            bias_dtype = layer.bias.dtype
            break

    if bias_dtype is not None:
        spec.bias = concatenate(
            [
                (
                    layer.bias
                    if layer.has_bias()
                    else zeros([layer.weight.shape[0]], dtype=bias_dtype)
                )
                for layer in layers
            ]
        )
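
# A minimal usage sketch of fuse_linear, e.g. merging separate query/key/value
# projections into a single fused projection. _FakeLinear and _FakeSpec are
# hypothetical stand-ins for the converter's layer and spec objects; they are
# not part of the CTranslate2 API. Runs only when this module is executed
# directly.
if __name__ == "__main__":

    class _FakeLinear:
        def __init__(self, weight, bias=None):
            self.weight = weight
            self.bias = bias

        def has_bias(self):
            return self.bias is not None

    class _FakeSpec:
        pass

    q = _FakeLinear(
        np.ones((4, 8), dtype=np.float32), bias=np.ones(4, dtype=np.float32)
    )
    k = _FakeLinear(np.full((4, 8), 2.0, dtype=np.float32))  # no bias
    v = _FakeLinear(
        np.full((4, 8), 3.0, dtype=np.float32), bias=np.ones(4, dtype=np.float32)
    )

    fused = _FakeSpec()
    fuse_linear(fused, [q, k, v])

    # Weights are stacked along the output dimension; the missing bias of `k`
    # is padded with zeros of the first available bias dtype.
    assert fused.weight.shape == (12, 8)
    assert fused.bias.shape == (12,)
    assert not fused.bias[4:8].any()
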
def fuse_linear_prequant(spec, layers, axis):
    """Concatenates the pre-quantized weights and their quantization parameters
    (scale and zero point) of multiple linear layers along ``axis``.
    """
    if not layers:
        raise ValueError("Cannot fuse linear layers: at least one layer is required")

    params = ["weight", "weight_scale", "weight_zero"]

    # Pick the concatenation function matching the weight type.
    if isinstance(layers[0].weight, np.ndarray):
        concatenate = np.concatenate
    else:
        import torch

        concatenate = torch.cat

    for param in params:
        setattr(
            spec,
            param,
            concatenate([getattr(layer, param) for layer in layers], axis=axis),
        )
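
# A similarly minimal sketch of fuse_linear_prequant with NumPy arrays. The
# SimpleNamespace objects below are hypothetical stand-ins for pre-quantized
# layer/spec objects carrying a weight together with its quantization scale
# and zero point.
if __name__ == "__main__":
    import types

    def _make_quant_layer(value):
        return types.SimpleNamespace(
            weight=np.full((4, 8), value, dtype=np.int8),
            weight_scale=np.full((4,), 0.1, dtype=np.float32),
            weight_zero=np.zeros((4,), dtype=np.int8),
        )

    quant_spec = types.SimpleNamespace()
    fuse_linear_prequant(
        quant_spec, [_make_quant_layer(1), _make_quant_layer(2)], axis=0
    )

    # The weight and its quantization parameters are all concatenated along axis 0.
    assert quant_spec.weight.shape == (8, 8)
    assert quant_spec.weight_scale.shape == (8,)
    assert quant_spec.weight_zero.shape == (8,)
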
def permute_for_sliced_rotary(weight, num_heads, rotary_dim=None):
"""Permutes the weight to use the sliced rotary implementation."""
if rotary_dim is not None:
weight = weight.reshape(num_heads, weight.shape[0] // num_heads, -1)
rotary_weight = weight[:, :rotary_dim]
rotary_weight = permute_for_sliced_rotary(
rotary_weight.reshape(num_heads * rotary_dim, -1), num_heads
).reshape(num_heads, rotary_dim, -1)
weight[:, :rotary_dim] = rotary_weight
return weight.reshape(-1, weight.shape[-1])
return (
weight.reshape(num_heads, weight.shape[0] // num_heads // 2, 2, weight.shape[1])
.swapaxes(1, 2)
.reshape(weight.shape[0], weight.shape[1])
)
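
# A small numeric illustration of the permutation performed by
# permute_for_sliced_rotary: within each head, the interleaved rotary layout
# (pairs of adjacent rows) is rearranged into the sliced layout (all even rows,
# then all odd rows). The shapes and values below are synthetic.
if __name__ == "__main__":
    num_heads = 2
    head_dim = 4
    hidden_size = 3
    weight = np.arange(num_heads * head_dim * hidden_size, dtype=np.float32).reshape(
        num_heads * head_dim, hidden_size
    )

    permuted = permute_for_sliced_rotary(weight.copy(), num_heads)

    # Row order per head: [0, 2, 1, 3] for head 0 and [4, 6, 5, 7] for head 1.
    expected_rows = [0, 2, 1, 3, 4, 6, 5, 7]
    assert np.array_equal(permuted, weight[expected_rows])
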
def smooth_activation(layer_norm, linear, activation_scales):
"""Applies the activation smoothing technique described in
https://github.com/mit-han-lab/smoothquant.
"""
if not isinstance(linear.weight, np.ndarray):
linear_weight = linear.weight.numpy()
activation_scales = activation_scales.numpy()
else:
linear_weight = linear.weight
weight_scales = np.amax(np.absolute(linear_weight), axis=0)
weight_scales = np.maximum(weight_scales, 1e-5)
activation_scales = activation_scales.astype(weight_scales.dtype)
scales = np.sqrt(activation_scales / weight_scales)
scales = np.maximum(scales, 1e-5)
if not isinstance(linear.weight, np.ndarray):
import torch
scales = torch.from_numpy(scales)
layer_norm.gamma /= scales
layer_norm.beta /= scales
linear.weight *= scales.reshape(1, -1)
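
# A minimal numeric check of the smoothing identity: the per-channel scale
# divided out of the layer norm is folded back into the linear weight, so the
# composition of the two layers is unchanged. The SimpleNamespace objects are
# hypothetical stand-ins for the converter's layer norm and linear specs.
if __name__ == "__main__":
    import types

    rng = np.random.default_rng(0)
    hidden_size = 4
    layer_norm = types.SimpleNamespace(
        gamma=np.ones(hidden_size, dtype=np.float32),
        beta=np.zeros(hidden_size, dtype=np.float32),
    )
    linear = types.SimpleNamespace(
        weight=rng.standard_normal((8, hidden_size)).astype(np.float32)
    )
    activation_scales = np.abs(rng.standard_normal(hidden_size)).astype(np.float32) + 1.0

    weight_before = linear.weight.copy()
    smooth_activation(layer_norm, linear, activation_scales)

    # With an initial gamma of ones, gamma[j] * weight[:, j] must be preserved.
    assert np.allclose(layer_norm.gamma.reshape(1, -1) * linear.weight, weight_before)
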
def raise_unsupported(reasons):
    message = (
        "The model you are trying to convert is not supported by CTranslate2. "
        "We identified the following reasons:\n"
    )
    for reason in reasons:
        message += "\n- " + reason

    raise ValueError(message)
class ConfigurationChecker:
    """Collects the reasons why a model configuration is not supported and
    reports them all at once.
    """

    def __init__(self):
        self._unsupported_reasons = []

    def __call__(self, assert_condition, error_message):
        # Record the error message when the condition does not hold.
        if not assert_condition:
            self._unsupported_reasons.append(error_message)

    def validate(self):
        if self._unsupported_reasons:
            raise_unsupported(self._unsupported_reasons)
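
# A short usage sketch of ConfigurationChecker: failed conditions are
# accumulated and validate() reports them all at once instead of failing on
# the first one. The option names are made up for illustration.
if __name__ == "__main__":
    check = ConfigurationChecker()
    check(True, "this condition holds, so nothing is recorded")
    check(False, "Option 'foo' is not supported")
    check(False, "Option 'bar' must be a positive integer")

    try:
        check.validate()
    except ValueError as error:
        print(error)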