# Source code for activation_extractor.extractors.intermediateExtractorBase

import numpy as np
import os, shutil
from collections import OrderedDict

import torch

class IntermediateExtractorBase:
    """
    Extractor for intermediate model outputs.

    Captures the outputs of a specified list of modules of a pytorch model
    during inference by attaching forward hooks to those modules.

    :param model: the Pytorch model object
    :param layer_list: list of module names to get outputs from
    :type layer_list: list of strings
    """

    def __init__(self, model, layer_list):
        self.model = model
        self.layer_list = layer_list
        # module name -> handle returned by ``register_forward_hook``
        self.hook_handles = {}
        # module name -> output captured by that module's hook
        self.intermediate_outputs = {}

    def create_hook(self, layer_name):
        """
        Creates a pytorch hook that saves the output of a given module/layer.

        A pytorch forward hook is a function that is executed after the
        module is called. The returned hook stores the module output in
        ``intermediate_outputs`` under ``layer_name``, overwriting whatever
        was stored there before.

        :param layer_name: name of the module/layer
        :type layer_name: str
        :return: the corresponding hook function
        :rtype: function
        """

        def hook(module, inputs, output):
            # Forward hooks are invoked positionally, so only the parameter
            # order matters; ``module`` and ``inputs`` are unused here.
            self.intermediate_outputs[layer_name] = output

        return hook

    def register_hooks(self):
        """
        Registers a forward hook on every module whose name is in
        ``layer_list`` and saves the hook handles to ``hook_handles``.

        If this extractor already holds a handle for a module, that hook is
        removed first, so calling this method repeatedly never stacks
        duplicate hooks that could no longer be detached.
        """
        for name, module in self.model.named_modules():
            if name not in self.layer_list:
                continue
            previous = self.hook_handles.pop(name, None)
            if previous is not None:
                previous.remove()
            self.hook_handles[name] = module.register_forward_hook(
                self.create_hook(name)
            )

    def detach_hooks(self):
        """
        Detaches all the hooks saved in ``hook_handles`` and empties it.

        Captured outputs in ``intermediate_outputs`` are left untouched.
        """
        for hook_handle in self.hook_handles.values():
            hook_handle.remove()
        # The handles are spent once removed; do not keep them around.
        self.hook_handles.clear()

    def clear_all_hooks(self):
        """
        Clears ALL the forward hooks of the modules listed in ``layer_list``.

        Every forward hook on those modules is dropped, including hooks that
        were not registered through this extractor. Modules whose name is
        not in ``layer_list`` are not modified.
        """
        for name, module in self.model.named_modules():
            if name in self.layer_list:
                # NOTE: relies on the private ``_forward_hooks`` attribute of
                # the module; it is replaced with a fresh empty mapping.
                module._forward_hooks = OrderedDict()
                # The handle for this module now points at a discarded hook.
                self.hook_handles.pop(name, None)

    def get_outputs(self):
        """
        Returns the intermediate activation outputs.

        The returned object is the extractor's own ``intermediate_outputs``
        dictionary, not a copy.

        :return: dictionary with intermediate outputs for each specified
            module/layer.
        :rtype: dictionary
        """
        return self.intermediate_outputs