"""Emmental model."""
import importlib
import itertools
import logging
import os
from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import numpy as np
import torch
from numpy import ndarray
from torch import Tensor, nn
from torch.nn import ModuleDict
from emmental.data import EmmentalDataLoader
from emmental.meta import Meta
from emmental.scorer import Scorer
from emmental.task import ActionIndex, EmmentalTask
from emmental.utils.utils import (
array_to_numpy,
construct_identifier,
move_to_device,
prob_to_pred,
)
if importlib.util.find_spec("ipywidgets") is not None:
from tqdm.auto import tqdm
else:
from tqdm import tqdm
logger = logging.getLogger(__name__)
[docs]class EmmentalModel(nn.Module):
"""A class to build multi-task model.
Args:
name: Name of the model, defaults to None.
tasks: A task or a list of tasks.
"""
def __init__(
    self,
    name: Optional[str] = None,
    tasks: Optional[Union[EmmentalTask, List[EmmentalTask]]] = None,
) -> None:
    """Create an empty multi-task model and register the given tasks, if any.

    Args:
        name: Name of the model; the class name is used when it is None.
        tasks: A task or a list of tasks to add immediately (optional).
    """
    super().__init__()

    if name is None:
        name = type(self).__name__
    self.name = name

    # Modules shared by all tasks, keyed by module name.
    self.module_pool: ModuleDict = ModuleDict()

    # Per-task registries, all keyed by task name.
    self.task_names: Set[str] = set()
    self.task_flows: Dict[str, Any] = {}  # TODO: make it concrete
    self.loss_funcs: Dict[str, Callable] = {}
    self.output_funcs: Dict[str, Callable] = {}
    self.scorers: Dict[str, Scorer] = {}
    self.action_outputs: Dict[
        str, Optional[List[Union[Tuple[str, str], Tuple[str, int]]]]
    ] = {}

    # Device assignment per module name; modules without an entry fall back
    # to the default device.
    self.module_device: Dict[str, Union[int, str, torch.device]] = {}

    self.task_weights: Dict[str, float] = {}
    self.require_prob_for_evals: Dict[str, bool] = {}
    self.require_pred_for_evals: Dict[str, bool] = {}

    # Build the network from the supplied tasks.
    if tasks is not None:
        self.add_tasks(tasks)

    if Meta.config["meta_config"]["verbose"]:
        logger.info(
            f"Created emmental model {self.name} that contains "
            f"task {self.task_names}."
        )
def _get_default_device(self) -> torch.device:
    """Return the default device configured in ``Meta.config["model_config"]``.

    A configured ``device`` of -1 maps to the CPU; any other value is handed
    to ``torch.device`` unchanged.
    """
    configured = Meta.config["model_config"]["device"]
    if configured == -1:
        return torch.device("cpu")
    return torch.device(configured)
def _move_to_device(self) -> None:
    """Move every module in the pool to its assigned (or the default) device.

    A module assigned to a non-CPU device is moved to the CPU instead when
    CUDA is not available.
    """
    cpu = torch.device("cpu")
    fallback = self._get_default_device()

    for module_name in self.module_pool.keys():
        if module_name in self.module_device:
            target = self.module_device[module_name]
        else:
            target = fallback

        wants_accelerator = target != cpu
        verbose = Meta.config["meta_config"]["verbose"]

        if wants_accelerator and torch.cuda.is_available():
            if verbose:
                logger.info(f"Moving {module_name} module to GPU ({target}).")
            destination = target
        elif wants_accelerator:
            # Requested a non-CPU device, but no CUDA device is usable.
            if verbose:
                logger.info(
                    f"No cuda device available. "
                    f"Switch {module_name} to cpu instead."
                )
            destination = cpu
        else:
            if verbose:
                logger.info(f"Moving {module_name} module to CPU.")
            destination = cpu

        self.module_pool[module_name].to(destination)
def _to_dataparallel(self) -> None:
    """Wrap each module not assigned to the CPU in ``torch.nn.DataParallel``.

    A module's device is its entry in ``module_device`` when present,
    otherwise the default device.
    """
    cpu = torch.device("cpu")
    fallback = self._get_default_device()

    for module_name in self.module_pool.keys():
        if module_name in self.module_device:
            assigned = self.module_device[module_name]
        else:
            assigned = fallback

        if assigned != cpu:
            wrapped = torch.nn.DataParallel(self.module_pool[module_name])
            self.module_pool[module_name] = wrapped
def _to_distributed_dataparallel(self) -> None:
    """Wrap trainable modules in ``DistributedDataParallel``.

    Modules with no parameter requiring gradients are left unwrapped. The
    local rank from ``Meta.config["learner_config"]`` is used both as the
    single device id and as the output device.
    """
    # TODO support multiple device with DistributedDataParallel
    for module_name in self.module_pool.keys():
        module = self.module_pool[module_name]

        # Ensure there is some gradient parameter for DDP
        has_trainable_param = any(p.requires_grad for p in module.parameters())
        if has_trainable_param:
            self.module_pool[
                module_name
            ] = torch.nn.parallel.DistributedDataParallel(  # type: ignore
                self.module_pool[module_name],
                device_ids=[Meta.config["learner_config"]["local_rank"]],
                output_device=Meta.config["learner_config"]["local_rank"],
                find_unused_parameters=True,
            )
def add_tasks(self, tasks: Union[EmmentalTask, List[EmmentalTask]]) -> None:
    """Register one task or several tasks with the MTL network.

    Args:
        tasks: A task or a list of tasks. Anything that is not iterable is
            treated as a single task.
    """
    task_list = tasks if isinstance(tasks, Iterable) else [tasks]
    for single_task in task_list:
        self.add_task(single_task)
def add_task(self, task: EmmentalTask) -> None:
    """Add a single task into the MTL network.

    Args:
        task: A task to add.

    Raises:
        ValueError: If ``task`` is not an ``EmmentalTask`` or a task with the
            same name has already been added.
    """
    if not isinstance(task, EmmentalTask):
        raise ValueError(f"Unrecognized task type {task}.")

    if task.name in self.task_names:
        raise ValueError(
            f"Found duplicate task {task.name}, different task should use "
            f"different task name."
        )

    # Merge module pools: a module name already known to the model is shared
    # (the task is pointed at the model's instance); new names are adopted.
    shared_pool = self.module_pool
    for module_name in task.module_pool.keys():
        if module_name not in shared_pool.keys():
            shared_pool[module_name] = task.module_pool[module_name]
        else:
            task.module_pool[module_name] = shared_pool[module_name]

    task_name = task.name

    # Record everything the model needs to run and evaluate this task.
    self.task_names.add(task_name)
    self.task_flows[task_name] = task.task_flow
    self.loss_funcs[task_name] = task.loss_func
    self.output_funcs[task_name] = task.output_func
    self.action_outputs[task_name] = task.action_outputs
    self.module_device.update(task.module_device)
    self.scorers[task_name] = task.scorer
    self.task_weights[task_name] = task.weight
    self.require_prob_for_evals[task_name] = task.require_prob_for_eval
    self.require_pred_for_evals[task_name] = task.require_pred_for_eval

    # Newly adopted modules must end up on their configured device.
    self._move_to_device()
def update_task(self, task: EmmentalTask) -> None:
    """Update an existing task in the MTL network.

    Every module in the task's module pool replaces the model's module of the
    same name, and all per-task entries (flow, loss/output functions, action
    outputs, scorer, weight, evaluation requirements) are overwritten.

    The task name is also recorded in ``task_names``. For a task that already
    exists this is a no-op; for a task first introduced through this method it
    keeps ``task_names`` consistent with the other per-task registries, which
    ``remove_task`` relies on when it removes the name from that set.

    Args:
        task: A task to update.
    """
    # Replace the model's modules with the task's modules.
    for module_name in task.module_pool.keys():
        self.module_pool[module_name] = task.module_pool[module_name]

    task_name = task.name

    # Idempotent for existing tasks; registers the name otherwise.
    self.task_names.add(task_name)

    self.task_flows[task_name] = task.task_flow
    self.loss_funcs[task_name] = task.loss_func
    self.output_funcs[task_name] = task.output_func
    self.action_outputs[task_name] = task.action_outputs
    self.module_device.update(task.module_device)
    self.scorers[task_name] = task.scorer
    self.task_weights[task_name] = task.weight
    self.require_prob_for_evals[task_name] = task.require_prob_for_eval
    self.require_pred_for_evals[task_name] = task.require_pred_for_eval

    # Replaced modules must end up on their configured device.
    self._move_to_device()
def remove_task(self, task_name: str) -> None:
    """Remove an existing task from the MTL network.

    A task name that has no registered task flow is skipped. Modules are not
    removed from the module pool.

    Args:
        task_name: The task name to remove.
    """
    verbose = Meta.config["meta_config"]["verbose"]

    if task_name not in self.task_flows:
        if verbose:
            logger.info(f"Task ({task_name}) not in the current model, skip...")
        return

    if verbose:
        logger.info(f"Removing Task {task_name}.")

    self.task_names.remove(task_name)

    # Drop the task's entry from every per-task registry.
    per_task_registries = (
        self.task_flows,
        self.loss_funcs,
        self.output_funcs,
        self.action_outputs,
        self.scorers,
        self.task_weights,
        self.require_prob_for_evals,
        self.require_pred_for_evals,
    )
    for registry in per_task_registries:
        del registry[task_name]
    # TODO: remove the modules only associate with that task
def __repr__(self) -> str:
    """Return ``<ClassName>(name=<model name>)``."""
    return "{}(name={})".format(type(self).__name__, self.name)
def _get_data_from_output_dict(
    self, output_dict: Dict[str, Any], index: ActionIndex
) -> Any:
    """Fetch an action's output from ``output_dict``.

    For the valid index, please check the definition of Action.

    Args:
        output_dict: Mapping from action name to that action's output.
        index: Either a plain key (``str`` or ``int``) or a pair
            ``(action_name, sub_index)``. With a pair, ``sub_index`` selects
            an item of a list/tuple output (must be an ``int``) or a key of a
            dict output; for any other output only a ``sub_index`` equal to 0
            is accepted and the output itself is returned.

    Returns:
        The selected output.

    Raises:
        ValueError: If the action is not in ``output_dict``, the sub-index
            does not fit the output's type, or the index cannot be parsed.
    """
    # Plain key: the output is returned as-is.
    if isinstance(index, (str, int)):
        if index in output_dict:
            return output_dict[index]
        raise ValueError(f"Action {index}'s output is not in the output_dict.")

    action_name, sub_index = index[0], index[1]

    # Report a missing action the same way as in the plain-key case instead
    # of leaking a KeyError from the lookup.
    if action_name not in output_dict:
        raise ValueError(
            f"Action {action_name}'s output is not in the output_dict."
        )
    action_output = output_dict[action_name]

    # Sequence output: positional selection only.
    if isinstance(action_output, (list, tuple)):
        if isinstance(sub_index, int):
            return action_output[sub_index]
        raise ValueError(
            f"Action {action_name} output has {type(action_output)} type, "
            f"while index has {type(sub_index)} not int."
        )

    # Mapping output: selection by key.
    if isinstance(action_output, dict):
        if sub_index in action_output:
            return action_output[sub_index]
        raise ValueError(
            f"Action {action_name}'s output doesn't have attribute {sub_index}."
        )

    # Any other output: only sub-index 0 is meaningful. A sub-index that
    # cannot be converted to int is reported as an unparsable index.
    try:
        selects_whole_output = int(sub_index) == 0
    except (TypeError, ValueError):
        selects_whole_output = False
    if selects_whole_output:
        return action_output

    raise ValueError(f"Cannot parse action index {index}.")
def flow(self, X_dict: Dict[str, Any], task_names: List[str]) -> Dict[str, Any]:
    """Forward based on input and task flow.

    Note:
        We assume that all shared modules from all tasks are based on the
        same input.

    Each action is executed at most once: an action whose name is already in
    the output dict (for example because another task shares it) is skipped.

    Args:
        X_dict: The input data.
        task_names: The task names that need to be forwarded.

    Returns:
        The output of all forwarded modules, keyed by action name. The input
        (moved to the default device) is stored under ``"_input_"``.

    Raises:
        ValueError: If the inputs of an action cannot be gathered or moved to
            the action module's device. The underlying exception is attached
            as the cause.
    """
    default_device = self._get_default_device()

    X_dict = move_to_device(X_dict, default_device)

    output_dict = dict(_input_=X_dict)

    # Call forward for each task
    for task_name in task_names:
        for action in self.task_flows[task_name]:
            if action.name in output_dict:
                # Already computed for a previous task.
                continue

            if action.inputs:
                try:
                    action_module_device = self.module_device.get(
                        action.module, default_device
                    )
                    module_inputs = move_to_device(
                        [
                            self._get_data_from_output_dict(output_dict, _input)
                            for _input in action.inputs
                        ],
                        action_module_device,
                    )
                except Exception as err:
                    # Keep the original failure as the explicit cause so the
                    # real problem (missing output, bad index, ...) is visible.
                    raise ValueError(f"Unrecognized action {action}.") from err
                output = self.module_pool[action.module].forward(*module_inputs)
            else:
                # An action without declared inputs receives the whole
                # output dict.
                # TODO: Handle multiple device with not inputs case
                output = self.module_pool[action.module].forward(output_dict)

            output_dict[action.name] = output

    return output_dict
def forward(  # type: ignore
    self,
    uids: List[str],
    X_dict: Dict[str, Any],
    Y_dict: Dict[str, Tensor],
    task_to_label_dict: Dict[str, str],
    return_loss=True,
    return_probs=True,
    return_action_outputs=False,
) -> Union[
    Tuple[
        Dict[str, List[str]],
        Dict[str, Tensor],
        Dict[str, Union[ndarray, List[ndarray]]],
        Dict[str, Union[ndarray, List[ndarray]]],
        Dict[str, Dict[str, Union[ndarray, List]]],
    ],
    Tuple[
        Dict[str, List[str]],
        Dict[str, Tensor],
        Dict[str, Union[ndarray, List[ndarray]]],
        Dict[str, Union[ndarray, List[ndarray]]],
    ],
]:
    """Run the task flows on one batch and collect per-task results.

    Args:
        uids: The uids of input data.
        X_dict: The input data.
        Y_dict: The output data (may be None when there are no labels).
        task_to_label_dict: The task to label mapping.
        return_loss: Whether return loss or not, defaults to True.
        return_probs: Whether return probs or not, defaults to True.
        return_action_outputs: Whether return action_outputs or not,
            defaults to False.

    Returns:
        The uids, loss, prob, gold, action_output (optional) in the batch of
        all tasks, in that order. A dict that was not requested (or golds
        when ``Y_dict`` is None) is returned as None.
    """
    has_labels = Y_dict is not None

    uid_dict: Dict[str, List[str]] = defaultdict(list)
    loss_dict: Dict[str, Tensor] = defaultdict(Tensor) if return_loss else None
    gold_dict: Dict[str, Union[ndarray, List[ndarray]]] = (
        defaultdict(list) if has_labels else None
    )
    prob_dict: Dict[str, Union[ndarray, List[ndarray]]] = (
        defaultdict(list) if return_probs else None
    )
    out_dict: Dict[str, Dict[str, Union[ndarray, List]]] = (
        defaultdict(lambda: defaultdict(list)) if return_action_outputs else None
    )

    output_dict = self.flow(X_dict, list(task_to_label_dict.keys()))

    # Calculate logits and loss for each task
    for task_name, label_name in task_to_label_dict.items():
        # Without labels, no task may name a label.
        assert (
            has_labels or label_name is None
        ), f"Task {task_name} has not {label_name} label."

        labeled = has_labels and label_name is not None
        uid_dict[task_name] = uids

        if return_loss:
            loss_func = self.loss_funcs.get(task_name)
            if loss_func is not None:
                target = (
                    move_to_device(
                        Y_dict[label_name],
                        Meta.config["model_config"]["device"],
                    )
                    if labeled
                    else None
                )
                loss_dict[task_name] = loss_func(output_dict, target)

        if return_probs:
            output_func = self.output_funcs.get(task_name)
            if output_func is not None:
                prob_dict[task_name] = (
                    output_func(output_dict).cpu().detach().numpy()
                )

        if labeled:
            gold_dict[task_name] = Y_dict[label_name].cpu().numpy()

        if return_action_outputs:
            requested_outputs = self.action_outputs.get(task_name)
            if requested_outputs is not None:
                for _output in requested_outputs:
                    # A pair index (X, Y) is stored under the key "X_Y".
                    if isinstance(_output, str):
                        output_key = _output
                    else:
                        output_key = f"{_output[0]}_{_output[1]}"
                    out_dict[task_name][output_key] = (
                        self._get_data_from_output_dict(output_dict, _output)
                        .cpu()
                        .detach()
                        .numpy()
                    )

    if return_action_outputs:
        return uid_dict, loss_dict, prob_dict, gold_dict, out_dict
    return uid_dict, loss_dict, prob_dict, gold_dict
@torch.no_grad()
def predict(
    self,
    dataloader: EmmentalDataLoader,
    return_loss: bool = True,
    return_probs: bool = True,
    return_preds: bool = False,
    return_action_outputs: bool = False,
) -> Dict[str, Any]:
    """Run the model over a dataloader and gather the results per task.

    The model is switched to eval mode. Scalar batch losses are accumulated
    weighted by batch size and averaged over all uids at the end; non-scalar
    batch losses are concatenated into a list.

    Args:
        dataloader: The dataloader to predict.
        return_loss: Whether return loss or not, defaults to True.
        return_probs: Whether return probs or not, defaults to True.
        return_preds: Whether return predictions or not, defaults to False.
        return_action_outputs: Whether return action_outputs or not,
            defaults to False.

    Returns:
        The result dict with keys "uids", "golds" and "losses", plus "probs",
        "preds" and "outputs" when the corresponding flag is set.
    """
    self.eval()

    # A dataset item that is a plain dict is treated as inputs only, i.e.
    # there is no Y_dict.
    has_y_dict = not isinstance(dataloader.dataset[0], dict)

    uid_dict: Dict[str, List[str]] = defaultdict(list)
    prob_dict: Dict[str, Union[ndarray, List[ndarray]]] = (
        defaultdict(list) if return_probs else None
    )
    pred_dict: Dict[str, Union[ndarray, List[ndarray]]] = (
        defaultdict(list) if return_preds else None
    )
    out_dict: Dict[str, Dict[str, List[Union[ndarray, int, float]]]] = (
        defaultdict(lambda: defaultdict(list)) if return_action_outputs else None
    )
    loss_dict: Dict[str, Union[ndarray, float]] = (
        defaultdict(list) if return_loss else None  # type: ignore
    )
    gold_dict: Dict[str, List[Union[ndarray, int, float]]] = (
        defaultdict(list) if has_y_dict else None
    )

    with torch.no_grad():
        batches = tqdm(
            dataloader,
            total=len(dataloader),
            desc=f"Evaluating {dataloader.data_name} ({dataloader.split})",
        )
        for bdict in batches:
            if has_y_dict:
                X_bdict, Y_bdict = bdict
            else:
                X_bdict, Y_bdict = bdict, None

            # Predictions are derived from probabilities, so probabilities
            # are requested from forward whenever either is wanted.
            batch_result = self.forward(  # type: ignore
                X_bdict[dataloader.uid],
                X_bdict,
                Y_bdict,
                dataloader.task_to_label_dict,
                return_loss=return_loss,
                return_action_outputs=return_action_outputs,
                return_probs=return_probs or return_preds,
            )
            if return_action_outputs:
                (
                    uid_bdict,
                    loss_bdict,
                    prob_bdict,
                    gold_bdict,
                    out_bdict,
                ) = batch_result
            else:
                uid_bdict, loss_bdict, prob_bdict, gold_bdict = batch_result
                out_bdict = None

            for task_name in uid_bdict.keys():
                batch_uids = uid_bdict[task_name]
                uid_dict[task_name].extend(batch_uids)

                if return_loss:
                    batch_loss = loss_bdict[task_name]
                    if len(batch_loss.size()) == 0:
                        # Scalar loss: keep a running sum weighted by the
                        # batch size (the accumulator starts as an empty
                        # list from the defaultdict).
                        if loss_dict[task_name] == []:
                            loss_dict[task_name] = 0
                        loss_dict[task_name] += batch_loss.item() * len(
                            batch_uids
                        )
                    else:
                        loss_dict[task_name].extend(  # type: ignore
                            batch_loss.cpu().numpy()
                        )

                if return_probs:
                    prob_dict[task_name].extend(  # type: ignore
                        prob_bdict[task_name]
                    )

                if return_preds:
                    pred_dict[task_name].extend(  # type: ignore
                        prob_to_pred(prob_bdict[task_name])
                    )

                if has_y_dict:
                    gold_dict[task_name].extend(gold_bdict[task_name])

            if return_action_outputs and out_bdict:
                for task_name in out_bdict.keys():
                    for action_name in out_bdict[task_name].keys():
                        out_dict[task_name][action_name].extend(
                            out_bdict[task_name][action_name]
                        )

    # Calculate average loss (only for the scalar running sums; list-valued
    # losses are left as collected).
    if return_loss:
        for task_name in uid_dict.keys():
            if not isinstance(loss_dict[task_name], list):
                loss_dict[task_name] /= len(uid_dict[task_name])

    result = {
        "uids": uid_dict,
        "golds": gold_dict,
        "losses": loss_dict,
    }

    if return_probs:
        for task_name in prob_dict.keys():
            prob_dict[task_name] = array_to_numpy(prob_dict[task_name])
        result["probs"] = prob_dict

    if return_preds:
        for task_name in pred_dict.keys():
            pred_dict[task_name] = array_to_numpy(pred_dict[task_name])
        result["preds"] = pred_dict

    if return_action_outputs:
        result["outputs"] = out_dict

    return result
@torch.no_grad()
def score(
    self,
    dataloaders: Union[EmmentalDataLoader, List[EmmentalDataLoader]],
    return_average: bool = True,
) -> Dict[str, float]:
    """Evaluate the model on one or more dataloaders.

    For every task of every dataloader the mean loss and the scorer's metrics
    are stored under identifiers built by ``construct_identifier``. With
    ``return_average`` the per-task average, per-split and overall
    micro/macro averages and average losses are added as well. Finally, any
    metric configured in ``global_evaluation_metric_dict`` is computed from
    the collected scores.

    Args:
        dataloaders: The dataloaders to score.
        return_average: Whether to return average score.

    Returns:
        Score dict.
    """
    self.eval()

    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    metric_score_dict = dict()

    if return_average:
        # All three are keyed by split name.
        micro_score_dict: defaultdict = defaultdict(list)
        macro_score_dict: defaultdict = defaultdict(list)
        macro_loss_dict: defaultdict = defaultdict(list)

    for dataloader in dataloaders:
        data_name = dataloader.data_name
        split = dataloader.split

        # Ask predict only for what at least one task's evaluation requires.
        return_probs = False
        return_preds = False
        for task_name in dataloader.task_to_label_dict:
            return_probs = return_probs or self.require_prob_for_evals[task_name]
            return_preds = return_preds or self.require_pred_for_evals[task_name]

        predictions = self.predict(
            dataloader,
            return_probs=return_probs,
            return_preds=return_preds,
            return_action_outputs=False,
        )

        for task_name in predictions["uids"].keys():
            # Store the loss
            loss_key = construct_identifier(task_name, data_name, split, "loss")
            task_loss = np.mean(predictions["losses"][task_name])  # type: ignore
            metric_score_dict[loss_key] = task_loss
            if return_average:
                macro_loss_dict[split].append(task_loss)

            # Store the task specific metric score
            scorer = self.scorers[task_name]
            if not scorer:
                continue

            metric_score = scorer.score(
                predictions["golds"][task_name],
                predictions["probs"][task_name] if return_probs else None,
                predictions["preds"][task_name] if return_preds else None,
                predictions["uids"][task_name],
            )
            for metric_name, metric_value in metric_score.items():
                metric_key = construct_identifier(
                    task_name, data_name, split, metric_name
                )
                metric_score_dict[metric_key] = metric_value

            if return_average:
                # Collect average score
                metric_values = list(metric_score.values())
                average_key = construct_identifier(
                    task_name, data_name, split, "average"
                )
                task_average = np.mean(metric_values)  # type: ignore
                metric_score_dict[average_key] = task_average
                micro_score_dict[split].extend(metric_values)
                macro_score_dict[split].append(task_average)

    if return_average:
        # Collect split-wise micro/macro average score
        for split in micro_score_dict.keys():
            micro_key = construct_identifier("model", "all", split, "micro_average")
            metric_score_dict[micro_key] = np.mean(  # type: ignore
                micro_score_dict[split]
            )
            macro_key = construct_identifier("model", "all", split, "macro_average")
            metric_score_dict[macro_key] = np.mean(  # type: ignore
                macro_score_dict[split]
            )
        for split in macro_loss_dict.keys():
            split_loss_key = construct_identifier("model", "all", split, "loss")
            metric_score_dict[split_loss_key] = np.mean(  # type: ignore
                macro_loss_dict[split]
            )

        # Collect overall micro/macro average score/loss
        overall_sources = (
            ("micro_average", micro_score_dict),
            ("macro_average", macro_score_dict),
            ("loss", macro_loss_dict),
        )
        for overall_name, per_split_values in overall_sources:
            if per_split_values:
                overall_key = construct_identifier(
                    "model", "all", "all", overall_name
                )
                metric_score_dict[overall_key] = np.mean(  # type: ignore
                    list(itertools.chain.from_iterable(per_split_values.values()))
                )

    # TODO: have a better to handle global evaluation metric
    global_evaluation_metric_dict = Meta.config["learner_config"][
        "global_evaluation_metric_dict"
    ]
    if global_evaluation_metric_dict:
        for metric_name, metric in global_evaluation_metric_dict.items():
            metric_score_dict[metric_name] = metric(metric_score_dict)

    return metric_score_dict
def save(
    self,
    model_path: str,
    iteration: Optional[Union[float, int]] = None,
    metric_dict: Optional[Dict[str, float]] = None,
    verbose: bool = True,
) -> None:
    """Save model.

    Saving is best-effort: if ``torch.save`` fails, a warning (with the
    traceback) is logged and the method returns without raising.

    Args:
        model_path: Saved model path.
        iteration: The iteration of the model, defaults to `None`.
        metric_dict: The metric dict, defaults to `None`.
        verbose: Whether log the info, defaults to `True`.
    """
    # Create the target directory when needed. A bare filename has an empty
    # dirname, which os.makedirs rejects, so there is nothing to create then.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    # NOTE: only the module weights are persisted; task names, task flows,
    # loss/output functions and scorers are not part of the checkpoint.
    state_dict = {
        "model": {
            "name": self.name,
            "module_pool": self.collect_state_dict(),
        },
        "iteration": iteration,
        "metric_dict": metric_dict,
    }

    try:
        torch.save(state_dict, model_path)
    except Exception:
        # Keep going on ordinary failures, but let KeyboardInterrupt and
        # SystemExit propagate.
        logger.warning("Saving failed... continuing anyway.", exc_info=True)
        return

    # Reached only when the checkpoint was actually written.
    if Meta.config["meta_config"]["verbose"] and verbose:
        logger.info(f"[{self.name}] Model saved in {model_path}")
def load(
    self,
    model_path: str,
    verbose: bool = True,
) -> None:
    """Load model state_dict from file and reinitialize the model weights.

    The checkpoint is loaded onto the CPU first; afterwards the modules are
    moved to their configured devices.

    Args:
        model_path: Saved model path.
        verbose: Whether log the info, defaults to `True`.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        Exception: Whatever ``torch.load`` raises when the file cannot be
            read; it is logged and re-raised unchanged.
    """
    # Fail fast with a clear error instead of continuing into torch.load.
    if not os.path.exists(model_path):
        logger.error("Loading failed... Model does not exist.")
        raise FileNotFoundError(f"Model file does not exist: {model_path}")

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    try:
        checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
    except Exception:
        logger.error(f"Loading failed... Cannot load model from {model_path}")
        raise

    self.load_state_dict(checkpoint["model"]["module_pool"])

    if Meta.config["meta_config"]["verbose"] and verbose:
        logger.info(f"[{self.name}] Model loaded from {model_path}")

    # Move model to specified device
    self._move_to_device()
def collect_state_dict(self) -> Dict[str, Any]:
    """Collect the state dict of every module in the pool, keyed by module name.

    For a module that has a ``module`` attribute, the state dict of that
    inner module is collected instead.
    """
    collected: Dict[str, Any] = defaultdict(list)

    for module_name, module in self.module_pool.items():
        source = module.module if hasattr(module, "module") else module
        collected[module_name] = source.state_dict()  # type: ignore

    return collected
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:  # type: ignore
    """Load per-module state dicts into the module pool.

    Entries whose module name is not in the pool are skipped with a log
    message. For a pooled module that has a ``module`` attribute, the state
    is loaded into that inner module.

    Args:
        state_dict: The state dict to load, keyed by module name.
    """
    for module_name, module_state_dict in state_dict.items():
        if module_name not in self.module_pool:
            logger.info(f"Missing {module_name} in module_pool, skip it..")
            continue

        target = self.module_pool[module_name]
        if hasattr(target, "module"):
            target = target.module
        target.load_state_dict(module_state_dict)