Source code for activation_extractor.model_functions.inference_funs

"""
This file defines an inferencer wrapper for the included models.
"""
import torch

[docs] def define_inference_function(model_type, model, tokenizer, device): """ Define the right function to do inference based on the model type. The resulting function is called as inferencer.inference(). The functions move the tokenized input to ``device`` before performing inference. :param model_type: the model type (from the list in activation_extractor.model_functions.model_types) :type model_type: str :param model_type: the loaded pytorch model :param tokenizer: the loaded tokenizer object :param device: the device (cpu, cuda...) :type device: str :return: the function used to do the inference """ match model_type: # Biological Sequences 🥩🧬 ============================================================= #🥩,🧬 case "esm" | "nucleotide-transformer" : #### start function definition def inference_fun(tokenized_inputs, device, **kwargs): tokens_ids = tokenized_inputs["input_ids"].to(device) attention_mask = tokens_ids != tokenizer.pad_token_id #set default parameters if 'output_hidden_states' in kwargs: output_hidden_states=kwargs['output_hidden_states'] else: output_hidden_states=False #inference outputs = model( tokens_ids, attention_mask=attention_mask, encoder_attention_mask=attention_mask, output_hidden_states=output_hidden_states ) return outputs #### end function definition #🥩, 🥩 case "prot_t5" | "ankh" : #### start function definition def inference_fun(tokenized_inputs, device, **kwargs): tokens_ids = torch.tensor(tokenized_inputs['input_ids']).to(device) attention_mask = torch.tensor(tokenized_inputs['attention_mask']).to(device) outputs = model(input_ids=tokens_ids, attention_mask=attention_mask) return outputs #### end function definition #🥩 📚 case ( "prot_bert" | "prot_xlnet" | "prot_electra" | "striped-hyena" ): #### start function definition def inference_fun(tokenized_inputs, device, **kwargs): for key in tokenized_inputs.keys(): tokenized_inputs[key]=tokenized_inputs[key].to(device) outputs = model(**tokenized_inputs) return outputs #### end function definition case "prostt5": #### start function definition def inference_fun(tokenized_inputs, device, **kwargs): for key in tokenized_inputs.keys(): tokenized_inputs[key]=tokenized_inputs[key].to(device) outputs = model(tokenized_inputs["input_ids"], attention_mask=tokenized_inputs["attention_mask"]) return outputs #### end function definition #default inference for sequences #🧬, 📚 case ( "hyenadna" | "evo" | "caduceus" | "pythia" | "mamba"): #### start function definition def inference_fun(tokenized_inputs, device, **kwargs): tokens_ids = tokenized_inputs["input_ids"].to(device) outputs = model(tokens_ids) return outputs #### end function definition # text 📚 case "llama": #### start function definition def inference_fun(tokenized_inputs, device, **kwargs): outputs = model(**tokenized_inputs) return outputs #### end function definition case "llama": #### start function definition def inference_fun(tokenized_inputs, device, **kwargs): outputs = model(**tokenized_inputs) return outputs #### end function definition # Images 🖼️ ================================================================ #🖼️ case "vit" | "igpt" | "convnext" | "resnet" | "swin": #### start function definition def inference_fun(processed_image, device, **kwargs): for key in processed_image.keys(): processed_image[key]=processed_image[key].to(device) outputs = model(**processed_image) return outputs #### end function definition case "timm": #### start function definition def inference_fun(processed_image, device, **kwargs): processed_image=processed_image.to(device) outputs = model(processed_image) return outputs #### end function definition #🖼️/📚 case "clip": #### start function definition def inference_fun(processed_input, device, **kwargs): for key in processed_input.keys(): processed_input[key]=processed_input[key].to(device) outputs = model( input_ids=processed_input["input_ids"], attention_mask=processed_input["attention_mask"], pixel_values=processed_input["pixel_values"], ) return outputs #### end function definition #🥩/🧱/🌟 case "esm3": #### start function definition def inference_fun(processed_input, device, sequence=True, structure=True, **kwargs): """ ESM3 forward pass. """ #overwrite modalities if sequence is False: processed_input["sequence"]=None if structure is False: processed_input["structure"]=None #move to device for key in processed_input.keys(): if processed_input[key] is not None: processed_input[key]=processed_input[key].to(device) #forward outputs = model.forward( sequence_tokens = processed_input["sequence"], structure_tokens = processed_input["structure"] ) return outputs #### end function definition #return rightly defined inference function return inference_fun