Source code for activation_extractor.model_functions.tokenize_funs

"""
Defines a tokenizer wrapper function for the models included by default.
"""
import re

[docs] def define_tokenize_function(model_type, tokenizer, device=None): """ Define the right function to tokenize the inputs based on the model type. This function is called inside inferencer.tokenizer(). :param model_type: the model type (from the list in activation_extractor.model_functions.model_types) :type model_type: str :param tokenizer: the loaded tokenizer object :return: the function used to tokenize the inputs """ match model_type: #🥩,🧬,🧬,🧬 case ("esm" | "nucleotide-transformer" | "hyenadna" | "caduceus" | "dnabert" ): #### start function definition def tokenize_fun(sequence_inputs, **kwargs): tokenized = tokenizer.batch_encode_plus(sequence_inputs, return_tensors="pt", padding="longest", **kwargs ) return tokenized #### end function definition #🧬 case "evo": from evo.scoring import prepare_batch #### start function definition def tokenize_fun(sequence_inputs, **kwargs): tokens_ids, seq_lengths = prepare_batch( sequence_inputs, tokenizer, prepend_bos=False, device=device, ) tokenized=dict() tokenized["input_ids"]=tokens_ids return tokenized #### end function definition #🥩 case "prot_t5" | "prot_bert" | "prot_xlnet" | "prot_electra" : #### start function definition def tokenize_fun(sequence_inputs, **kwargs): # replace all rare/ambiguous amino acids by X and introduce white-space between all amino acids sequence_inputs = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequence_inputs] tokenized = tokenizer.batch_encode_plus(sequence_inputs, add_special_tokens=True, padding="longest", return_tensors='pt') return tokenized #### end function definition #🥩 protein sequence, ⛓️🥩 3di structural sequence case "prostt5" : #### start function definition def tokenize_fun(sequence_inputs, sequence_type=None, **kwargs): # replace all rare/ambiguous amino acids by X and introduce white-space between all amino acids sequence_inputs = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequence_inputs] if sequence_type is None: # add pre-fixes accordingly (this already expects 3Di-sequences to be lower-case) # if you go from AAs to 3Di (or if you want to embed AAs), you need to prepend "<AA2fold>" # if you go from 3Di to AAs (or if you want to embed 3Di), you need to prepend "<fold2AA>" sequence_inputs = [ "<AA2fold>" + " " + s if s.isupper() else "<fold2AA>" + " " + s for s in sequence_inputs ] else: #if sequence type is given, overwrite the above sequence_inputs = [ "<AA2fold>" + " " + s if sequence_type=="protein" else "<fold2AA>" + " " + s for s in sequence_inputs ] tokenized = tokenizer.batch_encode_plus(sequence_inputs, add_special_tokens=True, padding="longest", return_tensors='pt') return tokenized #### end function definition #🥩 case "ankh": #### start function definition def tokenize_fun(sequence_inputs, **kwargs): sequence_inputs = [list(seq) for seq in sequence_inputs] tokenized = tokenizer.batch_encode_plus(sequence_inputs, add_special_tokens=True, padding="longest", is_split_into_words=True, return_tensors="pt") return tokenized #### end function definition #🥩/🧱/🌟 case "esm3": #### start function definition def tokenize_fun(inputs): """ ESM3 """ #sequence tokenization from esm.utils.encoding import tokenize_sequence #structure tokenization from esm.sdk.api import ESMProtein from esm.utils.encoding import tokenize_structure #sequence if "sequence" in inputs.keys(): tokenized_sequence = tokenize_sequence(sequence=inputs['sequence'][0], #input is a string not list sequence_tokenizer=tokenizer["sequence"], add_special_tokens=True ).unsqueeze(0) #shape to (1,L) tokenized_structure = None #structure elif "pdb" in inputs.keys(): protein = ESMProtein.from_pdb(inputs['pdb'][0]) tokenized_sequence = tokenize_sequence(sequence=protein.sequence, #input is a string not list sequence_tokenizer=tokenizer["sequence"], add_special_tokens=True ).unsqueeze(0) #shape to (1,L) tokenized_structure = tokenize_structure( coordinates = protein.coordinates, structure_encoder = tokenizer["structure_encoder"], structure_tokenizer= tokenizer["structure_tokenizer"], )[2].unsqueeze(0) tokenized = { "sequence": tokenized_sequence, "structure": tokenized_structure, } return tokenized #### end function definition #text 📚 case "pythia" | "mamba" | "llama" | "striped-hyena": #### start function definition def tokenize_fun(sequence_inputs): tokenized = tokenizer(sequence_inputs, return_tensors="pt") return tokenized #### end function definition #return the rightly defined tokenizer function return tokenize_fun