# Source code for emmental.task

"""Emmental task."""
import logging
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import torch
from torch import nn
from torch.nn.modules.container import ModuleDict

from emmental.meta import Meta
from emmental.scorer import Scorer

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Reference to (part of) one action's output: either the action name alone,
# or a (name, key) pair where key is a str or an int selecting one output.
ActionIndex = Union[str, Tuple[str, str], Tuple[str, int]]
# Accepted forms for Action's ``inputs``: one name or a sequence of references.
ActionInputs = Union[str, Sequence[ActionIndex]]
# Accepted forms for EmmentalTask's ``action_outputs``: same shape as the inputs.
ActionOutputs = Union[str, Sequence[ActionIndex]]


class Action:
    """One step of an EmmentalTask ``task_flow``.

    A task_flow is a sequence of Action objects. Every Action stores three
    attributes: ``name`` (the name of the action), ``module`` (the name of the
    module the action uses) and ``inputs`` (what is fed to the action). Having
    a dedicated class standardizes how a task_flow step is specified while
    allowing several formats for ``inputs``:

    1. A str (e.g., ``inputs="input1"``): use input1's output as the input of
       this action.
    2. None: use all modules' output as input.
    3. A list whose items may each be one of:

       a) ``x`` (str): the whole output of ``x``, so every output of one module
          can be forwarded without listing them one by one.
       b) ``(x, y)`` with ``y`` an int: the y-th output of ``x``.
       c) ``(x, y)`` with ``y`` a str: the output of ``x`` named ``y``.

    Any ``inputs`` value that is neither None nor a list is stored wrapped in a
    one-element list.

    Args:
      name: The name of the action.
      module: The name of the module used in this action.
      inputs: The inputs of the action. Details can be found above.
    """

    def __init__(
        self,
        name: str,
        module: str,
        inputs: Optional[ActionInputs] = None,
    ) -> None:
        """Initialize Action."""
        self.name = name
        self.module = module

        # None and lists are kept untouched; any other value counts as one input.
        if inputs is not None and not isinstance(inputs, list):
            inputs = [inputs]
        self.inputs = inputs

    def __repr__(self) -> str:
        """Represent the action as a string."""
        pieces = [
            f"Action(name={self.name}, ",
            f"module={self.module}, ",
            f"inputs={self.inputs})",
        ]
        return "".join(pieces)
class EmmentalTask(object):
    """Task class to define task in Emmental model.

    Args:
      name: The name of the task (Primary key).
      module_pool: A dict of modules that uses in the task.
      task_flow: The task flow among modules to define how the data flows.
      loss_func: The function to calculate the loss.
      output_func: The function to generate the output.
      scorer: The class of metrics to evaluate the task, defaults to None.
      action_outputs: The action outputs need to output, defaults to None. A
        single str is treated as one output; any other sequence is
        de-duplicated while keeping the first occurrence of each item.
      module_device: The dict of module device specification, defaults to None
        (no per-module device). A value of -1 selects the CPU.
      weight: The weight of the task, defaults to 1.0.
      require_prob_for_eval: Whether require prob for evaluation,
        defaults to True.
      require_pred_for_eval: Whether require pred for evaluation,
        defaults to True.
    """

    def __init__(
        self,
        name: str,
        module_pool: ModuleDict,
        task_flow: Sequence[Action],
        loss_func: Callable,
        output_func: Callable,
        scorer: Optional[Scorer] = None,
        action_outputs: Optional[ActionOutputs] = None,
        module_device: Optional[Dict[str, Union[int, str, torch.device]]] = None,
        weight: Union[float, int] = 1.0,
        require_prob_for_eval: bool = True,
        require_pred_for_eval: bool = True,
    ) -> None:
        """Initialize EmmentalTask."""
        self.name = name

        # Kept as an assert so callers still see AssertionError on a bad pool.
        assert isinstance(
            module_pool, nn.ModuleDict
        ), "module_pool must be a torch.nn.ModuleDict"
        self.module_pool = module_pool

        self.task_flow = task_flow
        self.loss_func = loss_func
        self.output_func = output_func
        self.scorer = scorer

        if action_outputs is None:
            self.action_outputs = None
        else:
            # A bare str names exactly one output. Wrap it before de-duplicating,
            # otherwise iterating it would split the name into characters.
            if isinstance(action_outputs, str):
                action_outputs = [action_outputs]
            # dict.fromkeys drops duplicates but, unlike set(), keeps the
            # caller's order so the result is the same on every run.
            self.action_outputs = list(dict.fromkeys(action_outputs))

        self.module_device = {}
        for module_name, device in (module_device or {}).items():
            if module_name not in self.module_pool:
                logger.warning(
                    f"Module {module_name} from module_device doesn't in module_pool, "
                    "skip..."
                )
                continue

            if device == -1:
                self.module_device[module_name] = torch.device("cpu")
            else:
                self.module_device[module_name] = torch.device(device)

        self.require_prob_for_eval = require_prob_for_eval
        self.require_pred_for_eval = require_pred_for_eval
        self.weight = weight

        if Meta.config["meta_config"]["verbose"]:
            logger.info(f"Created task: {self.name}")

    def __repr__(self) -> str:
        """Represent the task as a string."""
        cls_name = type(self).__name__
        return f"{cls_name}(name={self.name})"