"""
This file contains functions to load models, tokenizers, etc.
Because not all models are from huggingface and they might not all be installed,
the right imports are directly inside the corresponding loading functions.
"""
import os
#⏳ load tokenizers
[docs]
def load_tokenizer(model_name, tokenizer_type, **kwargs):
"""
Load a tokenizer type for a model.
This function is called inside load_model() for sequence type models.
:param model_name: model name (for huggingface models it should be the same as the loaded model)
:type model_name: str
:param tokenizer_type: the type of tokenizer (valid types - AutoTokenizer and T5Tokenizer)
:type tokenizer_type: str
:return: the tokenizer object
"""
match tokenizer_type:
case "AutoTokenizer":
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, **kwargs)
case "T5Tokenizer":
from transformers import T5Tokenizer
tokenizer = T5Tokenizer.from_pretrained(model_name, trust_remote_code=True, do_lower_case=False, **kwargs)
case "BertTokenizer":
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(model_name, **kwargs)
case _:
raise ValueError(f"tokenizer_type not valid")
return tokenizer
#⏳ load processors
def load_processor(model_name, processor_type, **kwargs):
match processor_type:
#images 🖼️
case "AutoProcessor":
from transformers import AutoImageProcessor
processor = AutoImageProcessor.from_pretrained(model_name)
case "igptProcessor":
from transformers import ImageGPTImageProcessor, ImageGPTModel
processor = ImageGPTImageProcessor.from_pretrained(model_name)
case "convnextProcessor":
from transformers import ConvNextImageProcessor
processor = ConvNextImageProcessor.from_pretrained(model_name)
#multimodal 🖼️/📚
case "CLIP":
from transformers import CLIPProcessor
processor = CLIPProcessor.from_pretrained(model_name)
return processor
#⏳ load models
[docs]
def load_model(model_name, model_type, **kwargs):
"""
Loads a Pytorch model according to the passed model name.
For sequence models, it loads the corresponding tokenizer.
For image models, it loads the image processor.
:param model: the Pytorch model object
:param model_type: A model type (see list of included models).
:type model_type: str
:return: tuple with (model, tokenizer) or (model, processor).
"""
#START OF MATCH#
match model_type:
#DNA models 🧬 ---
case "nucleotide-transformer":
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True, **kwargs)
tokenizer = load_tokenizer(model_name, tokenizer_type='AutoTokenizer', **kwargs)
return model, tokenizer
case 'hyenadna':
from transformers import AutoModel
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, **kwargs)
#model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=True, **kwargs)
tokenizer = load_tokenizer(model_name, tokenizer_type='AutoTokenizer', **kwargs)
return model, tokenizer
case 'evo':
from evo import Evo
evo_model = Evo(model_name.split('/')[1])
model, tokenizer = evo_model.model, evo_model.tokenizer
return model, tokenizer
case 'caduceus':
from transformers import AutoModelForMaskedLM
#model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True, **kwargs)
tokenizer = load_tokenizer(model_name, tokenizer_type='AutoTokenizer', **kwargs)
return model, tokenizer
#protein models 🥩 ---
case 'esm':
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True, **kwargs)
tokenizer = load_tokenizer(model_name, tokenizer_type='AutoTokenizer', **kwargs)
return model, tokenizer
case 'prot_t5' | "prostt5" :
from transformers import T5EncoderModel
model = T5EncoderModel.from_pretrained(model_name, **kwargs)
tokenizer = load_tokenizer(model_name, tokenizer_type='T5Tokenizer', **kwargs)
return model, tokenizer
case "prot_bert":
from transformers import BertForMaskedLM
model = BertForMaskedLM.from_pretrained(model_name, **kwargs)
tokenizer = load_tokenizer(model_name, tokenizer_type="BertTokenizer",
do_lower_case=False, **kwargs)
return model, tokenizer
case "prot_xlnet":
from transformers import XLNetTokenizer, XLNetModel
model = XLNetModel.from_pretrained(model_name,
output_attentions=False)
tokenizer = XLNetTokenizer.from_pretrained(model_name,
do_lower_case=False)
return model, tokenizer
case "prot_electra":
from transformers import (ElectraTokenizer,
ElectraForMaskedLM, ElectraModel,
AutoModel)
from activation_extractor.utils.download import download_file
#crate folder to download models
FolderPath = f"{os.environ['TRANSFORMERS_CACHE']}/electra/{model_name}"
os.makedirs(FolderPath, exist_ok=True)
# # #corresponding urls for each model
# if model_name=="Rostlab/prot_electra_generator_bfd":
# ModelUrl = 'https://www.dropbox.com/s/5x5et5q84y3r01m/pytorch_model.bin?dl=1'
# ConfigUrl = 'https://www.dropbox.com/s/9059fvix18i6why/config.json?dl=1'
# if model_name=="Rostlab/prot_electra_discriminator_bfd":
# ModelUrl = 'https://www.dropbox.com/s/9ptrgtc8ranf0pa/pytorch_model.bin?dl=1'
# ConfigUrl = 'https://www.dropbox.com/s/jq568evzexyla0p/config.json?dl=1'
# #download files
# ModelFilePath = os.path.join(FolderPath, 'pytorch_model.bin')
# ConfigFilePath = os.path.join(FolderPath, 'config.json')
# download_file(ModelUrl, ModelFilePath)
# download_file(ConfigUrl, ConfigFilePath)
# #create model
# #
# if model_name=="Rostlab/prot_electra_generator_bfd":
# model = ElectraForMaskedLM.from_pretrained(FolderPath, output_attentions=False)
# if model_name=="Rostlab/prot_electra_discriminator_bfd":
# model = ElectraModel.from_pretrained(FolderPath, output_attentions=False)
model = AutoModel.from_pretrained(model_name)
#tokenizer
vocabUrl = 'https://www.dropbox.com/s/wck3w1q15bc53s0/vocab.txt?dl=1'
vocabFilePath = f"{FolderPath}/vocab.txt"
download_file(vocabUrl, vocabFilePath)
tokenizer = ElectraTokenizer(vocabFilePath, do_lower_case=False)
#using Prot_Bert tokenizer ❗
# tokenizer = load_tokenizer("Rostlab/prot_bert",
# tokenizer_type='AutoTokenizer',
# **kwargs)
return model, tokenizer
case 'ankh':
from transformers import T5EncoderModel
#model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
model = T5EncoderModel.from_pretrained(model_name, **kwargs) #output_attentions
tokenizer = load_tokenizer(model_name, tokenizer_type='AutoTokenizer', **kwargs)
return model, tokenizer
case "esm3": #🥩/🧱/🌟
from esm.models.esm3 import ESM3
#sequence
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
#structure
from esm.tokenization.structure_tokenizer import StructureTokenizer
from esm.utils.constants.esm3 import VQVAE_SPECIAL_TOKENS
from esm.pretrained import ESM3_structure_encoder_v0
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1")
#sequence
seq_tokenizer = EsmSequenceTokenizer()
#structure
structure_tokenizer = StructureTokenizer(vq_vae_special_tokens=VQVAE_SPECIAL_TOKENS)
structure_encoder = ESM3_structure_encoder_v0()
tokenizer = {
"sequence":seq_tokenizer,
"structure_tokenizer":structure_tokenizer,
"structure_encoder":structure_encoder,
}
return model, tokenizer
#image 🖼️ ---
case "vit":
from transformers import ViTForMaskedImageModeling
processor = load_processor(model_name, processor_type="AutoProcessor")
model = ViTForMaskedImageModeling.from_pretrained(model_name)
return model, processor
case "igpt":
from transformers import ImageGPTModel
processor = load_processor(model_name, processor_type="igptProcessor")
model = ImageGPTModel.from_pretrained(model_name)
return model, processor
case "swin":
from transformers import SwinForMaskedImageModeling
processor = load_processor(model_name, processor_type="AutoProcessor")
model = SwinForMaskedImageModeling.from_pretrained(model_name)
return model, processor
case "convnext":
from transformers import ConvNextForImageClassification
processor = load_processor(model_name, processor_type="convnextProcessor")
model = ConvNextForImageClassification.from_pretrained(model_name)
return model, processor
case "resnet":
from transformers import ResNetForImageClassification
processor = load_processor(model_name, processor_type="AutoProcessor")
model = ResNetForImageClassification.from_pretrained(model_name)
return model, processor
case "timm":
import timm
#load model
model_name = model_name.split("/")[1]
model = timm.create_model(model_name, pretrained=True)
# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
#return
return model, transforms
#text 📚
case "pythia":
from transformers import AutoModelForCausalLM
tokenizer = load_tokenizer(model_name, tokenizer_type='AutoTokenizer', **kwargs)
model = AutoModelForCausalLM.from_pretrained(model_name)
return model, tokenizer
case "mamba":
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = MambaForCausalLM.from_pretrained(model_name)
return model, tokenizer
case "llama":
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name,
device_map="auto" #multiple GPUs
)
return model, tokenizer
case "striped-hyena":
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
return model, tokenizer
#multimodal 🖼️/📚 ---
case "clip":
from transformers import CLIPModel
model = CLIPModel.from_pretrained(model_name)
processor = load_processor(model_name, processor_type="CLIP")
return model, processor
case _:
raise ValueError(f"model_type not valid ")
#END OF MATCH#