import torch.nn as nn
import transformers
from transformers import BertPreTrainedModel, BertModel, AutoTokenizer, AutoModel, AutoConfig
from transformers import RobertaModel, RobertaPreTrainedModel
from transformers import XLMRobertaModel, XLMRobertaConfig
from transformers import ElectraModel, ElectraPreTrainedModel
from transformers import DebertaV2Model, DebertaV2PreTrainedModel
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from colbert.utils.utils import torch_load_dnn
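
# Dynamically builds a ColBERT wrapper (HF_ColBERT) around a HuggingFace encoder:
# class_factory() resolves the matching *PreTrainedModel / *Model classes for a given
# checkpoint and returns an HF_ColBERT subclass that adds the ColBERT projection layer.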
class XLMRobertaPreTrainedModel(RobertaPreTrainedModel):
    """
    This class overrides [`RobertaPreTrainedModel`]. Please check the superclass for the appropriate
    documentation alongside usage examples.
    """

    config_class = XLMRobertaConfig
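
# Fallback mappings for checkpoints whose *PreTrainedModel / *Model classes cannot be
# resolved by name via find_class_names() below.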
base_class_mapping = {
    "roberta-base": RobertaPreTrainedModel,
    "google/electra-base-discriminator": ElectraPreTrainedModel,
    "xlm-roberta-base": XLMRobertaPreTrainedModel,
    "xlm-roberta-large": XLMRobertaPreTrainedModel,
    "bert-base-uncased": BertPreTrainedModel,
    "bert-large-uncased": BertPreTrainedModel,
    "microsoft/mdeberta-v3-base": DebertaV2PreTrainedModel,
    "bert-base-multilingual-uncased": BertPreTrainedModel,
}

model_object_mapping = {
    "roberta-base": RobertaModel,
    "google/electra-base-discriminator": ElectraModel,
    "xlm-roberta-base": XLMRobertaModel,
    "xlm-roberta-large": XLMRobertaModel,
    "bert-base-uncased": BertModel,
    "bert-large-uncased": BertModel,
    "microsoft/mdeberta-v3-base": DebertaV2Model,
    "bert-base-multilingual-uncased": BertModel,
}
transformers_module = dir(transformers)
def find_class_names(model_type, class_type):
    model_type = model_type.replace("-", "").lower()
    for item in transformers_module:
        if model_type + class_type == item.lower():
            return item
    return None
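
# Example: for model_type "xlm-roberta" and class_type "model", the lookup key becomes
# "xlmrobertamodel", which matches "XLMRobertaModel" in dir(transformers).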
def class_factory(name_or_path):
    loadedConfig = AutoConfig.from_pretrained(name_or_path, trust_remote_code=True)

    if getattr(loadedConfig, "auto_map", None) is None:
        model_type = loadedConfig.model_type
        pretrained_class = find_class_names(model_type, 'pretrainedmodel')
        model_class = find_class_names(model_type, 'model')

        if pretrained_class is not None:
            pretrained_class_object = getattr(transformers, pretrained_class)
        elif model_type == 'xlm-roberta':
            pretrained_class_object = XLMRobertaPreTrainedModel
        elif base_class_mapping.get(name_or_path) is not None:
            pretrained_class_object = base_class_mapping.get(name_or_path)
        else:
            raise ValueError(f"Could not find the correct pretrained class for the model type {model_type} in the transformers library")

        if model_class is not None:
            model_class_object = getattr(transformers, model_class)
        elif model_object_mapping.get(name_or_path) is not None:
            model_class_object = model_object_mapping.get(name_or_path)
        else:
            raise ValueError(f"Could not find the correct model class for the model type {model_type} in the transformers library")
    else:
        assert "AutoModel" in loadedConfig.auto_map, "The custom model should have an AutoModel class in config.auto_map"
        model_class = loadedConfig.auto_map["AutoModel"]
        assert model_class.endswith("Model")
        pretrained_class = model_class.replace("Model", "PreTrainedModel")
        model_class_object = get_class_from_dynamic_module(model_class, name_or_path)
        pretrained_class_object = get_class_from_dynamic_module(pretrained_class, name_or_path)
    class HF_ColBERT(pretrained_class_object):
        """
        Shallow wrapper around HuggingFace transformers. All new parameters should be defined at this level.
        This makes sure `{from,save}_pretrained` and `init_weights` are applied to new parameters correctly.
        """

        _keys_to_ignore_on_load_unexpected = [r"cls"]

        def __init__(self, config, colbert_config):
            super().__init__(config)

            self.config = config
            self.dim = colbert_config.dim
            self.linear = nn.Linear(config.hidden_size, colbert_config.dim, bias=False)
            setattr(self, self.base_model_prefix, model_class_object(config))

            # if colbert_config.relu:
            #     self.score_scaler = nn.Linear(1, 1)

            self.init_weights()

            # if colbert_config.relu:
            #     self.score_scaler.weight.data.fill_(1.0)
            #     self.score_scaler.bias.data.fill_(-8.0)

        @property
        def LM(self):
            base_model_prefix = getattr(self, "base_model_prefix")
            return getattr(self, base_model_prefix)

        @classmethod
        def from_pretrained(cls, name_or_path, colbert_config):
            if name_or_path.endswith('.dnn'):
                dnn = torch_load_dnn(name_or_path)
                base = dnn.get('arguments', {}).get('model', 'bert-base-uncased')

                obj = super().from_pretrained(base, state_dict=dnn['model_state_dict'], colbert_config=colbert_config)
                obj.base = base

                return obj

            obj = super().from_pretrained(name_or_path, colbert_config=colbert_config)
            obj.base = name_or_path

            return obj

        @staticmethod
        def raw_tokenizer_from_pretrained(name_or_path):
            if name_or_path.endswith('.dnn'):
                dnn = torch_load_dnn(name_or_path)
                base = dnn.get('arguments', {}).get('model', 'bert-base-uncased')

                obj = AutoTokenizer.from_pretrained(base)
                obj.base = base

                return obj

            obj = AutoTokenizer.from_pretrained(name_or_path)
            obj.base = name_or_path

            return obj

    return HF_ColBERT
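
# Minimal usage sketch (the ColBERTConfig import path and dim value below are assumptions
# for illustration, not something this module defines):
#
#   from colbert.infra.config import ColBERTConfig
#
#   colbert_config = ColBERTConfig(dim=128)
#   HF_ColBERT = class_factory("bert-base-uncased")
#   colbert = HF_ColBERT.from_pretrained("bert-base-uncased", colbert_config=colbert_config)
#   tokenizer = HF_ColBERT.raw_tokenizer_from_pretrained("bert-base-uncased")
#   encoder = colbert.LM  # the underlying BertModel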