Source code for activation_extractor.model_functions.embedding_to_numpy

import torch
from transformers.modeling_outputs import (MaskedLMOutput, 
                                            BaseModelOutputWithNoAttention, 
                                            BaseModelOutputWithPastAndCrossAttentions)

from transformers.models.xlnet.modeling_xlnet import XLNetModelOutput

try:
    from esm.models.esm3 import ESMOutput
except:
    print("ESM library not installed")

[docs] def embedding_to_numpy(embeddings): """ Converts different types of module outputs to a numpy array. Handles different cases for the different models. Additionally, moves from GPU to CPU. :param embedding: Intermediate output object from a pytorch model layer/module. :return: intermediate output as a numpy array :rtype: numpy array """ #convert list to array (for protT5 tokens) if isinstance(embeddings, list): embeddings = np.array(embeddings) return embeddings #convert output object to tensor ------------------------- #intermediate activations if isinstance(embeddings, tuple): embeddings = embeddings[0] #NT/ESM output object if isinstance(embeddings, MaskedLMOutput): embeddings = embeddings['logits'] #HYENA or ANKH/PROTT5 or PROTBERT output object if (isinstance(embeddings, BaseModelOutputWithNoAttention) or isinstance(embeddings, BaseModelOutputWithPastAndCrossAttentions) or isinstance(embeddings, XLNetModelOutput)): embeddings = embeddings['last_hidden_state'] #ESM3 Output try: if isinstance(embeddings, ESMOutput): embeddings = embeddings.sequence_logits except Exception as e: pass #convert output object to tensor ------------------------- #float conversions for tensors if torch.is_tensor(embeddings): if embeddings.dtype==torch.bfloat16: embeddings = embeddings.float() #to CPU numpy array embeddings = embeddings.cpu().detach().numpy() return embeddings