Source code for emmental.learner

"""Emmental learner."""
import collections
import copy
import importlib
import logging
import math
import time
from collections import defaultdict
from functools import partial
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import wandb
from numpy import ndarray
from torch import optim as optim
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, DistributedSampler

from emmental.data import EmmentalDataLoader
from emmental.logging import LoggingManager
from emmental.meta import Meta
from emmental.model import EmmentalModel
from emmental.optimizers.bert_adam import BertAdam
from emmental.schedulers import SCHEDULERS
from emmental.schedulers.scheduler import Scheduler
from emmental.utils.utils import construct_identifier, prob_to_pred

if importlib.util.find_spec("ipywidgets") is not None:
    from tqdm.auto import tqdm
else:
    from tqdm import tqdm

logger = logging.getLogger(__name__)


class EmmentalLearner(object):
    """A class for emmental multi-task learning.

    Args:
      name: Name of the learner, defaults to None.
    """

    def __init__(self, name: Optional[str] = None) -> None:
        """Initialize EmmentalLearner."""
        self.name = name if name is not None else type(self).__name__

    def _set_logging_manager(self) -> None:
        """Set logging manager."""
        if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
            if self.use_step_base_counter:
                self.logging_manager = LoggingManager(
                    self.n_batches_per_epoch, 0, self.start_step
                )
            else:
                self.logging_manager = LoggingManager(
                    self.n_batches_per_epoch,
                    self.start_epoch,
                    self.start_epoch * self.n_batches_per_epoch,
                )

    def _set_optimizer(self, model: EmmentalModel) -> None:
        """Set optimizer for learning process.

        Args:
          model: The model to set up the optimizer.
        """
        optimizer_config = Meta.config["learner_config"]["optimizer_config"]
        opt = optimizer_config["optimizer"]

        # If Meta.config["learner_config"]["optimizer_config"]["parameters"] is None,
        # create a parameter group with all parameters in the model, else load user
        # specified parameter groups.
        if optimizer_config["parameters"] is None:
            parameters = filter(lambda p: p.requires_grad, model.parameters())
        else:
            parameters = optimizer_config["parameters"](model)

        optim_dict = {
            # PyTorch optimizer
            "asgd": optim.ASGD,
            "adadelta": optim.Adadelta,
            "adagrad": optim.Adagrad,
            "adam": optim.Adam,
            "adamw": optim.AdamW,
            "adamax": optim.Adamax,
            "lbfgs": optim.LBFGS,
            "rms_prop": optim.RMSprop,
            "r_prop": optim.Rprop,
            "sgd": optim.SGD,
            "sparse_adam": optim.SparseAdam,
            # Customized optimizer
            "bert_adam": BertAdam,
        }

        if opt in ["lbfgs", "r_prop", "sparse_adam"]:
            optimizer = optim_dict[opt](
                parameters,
                lr=optimizer_config["lr"],
                **optimizer_config[f"{opt}_config"],
            )
        elif opt in optim_dict.keys():
            optimizer = optim_dict[opt](
                parameters,
                lr=optimizer_config["lr"],
                weight_decay=optimizer_config["l2"],
                **optimizer_config[f"{opt}_config"],
            )
        elif (isinstance(opt, type) and issubclass(opt, optim.Optimizer)) or (
            isinstance(opt, partial)
            and issubclass(opt.func, optim.Optimizer)  # type: ignore
        ):
            optimizer = opt(parameters)  # type: ignore
        else:
            raise ValueError(f"Unrecognized optimizer option '{opt}'")

        self.optimizer = optimizer

        if Meta.config["meta_config"]["verbose"]:
            logger.info(f"Using optimizer {self.optimizer}")

        if Meta.config["learner_config"]["optimizer_path"]:
            try:
                self.optimizer.load_state_dict(
                    torch.load(
                        Meta.config["learner_config"]["optimizer_path"],
                        map_location=torch.device("cpu"),
                    )["optimizer"]
                )
                logger.info(
                    f"Optimizer state loaded from "
                    f"{Meta.config['learner_config']['optimizer_path']}"
                )
            except BaseException:
                logger.error(
                    f"Loading failed... Cannot load optimizer state from "
                    f"{Meta.config['learner_config']['optimizer_path']}, "
                    f"continuing anyway."
                )
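    # Illustrative sketch (not part of this module): `optimizer_config["parameters"]`
    # may be a callable that takes the model and returns parameter groups, which the
    # code above invokes instead of collecting all trainable parameters. The helper
    # name and values below are made-up examples, and the nested-dict update via
    # Meta.update_config is assumed:
    #
    #     from emmental import Meta
    #
    #     def no_decay_for_bias(model):
    #         decay, no_decay = [], []
    #         for name, param in model.named_parameters():
    #             if not param.requires_grad:
    #                 continue
    #             (no_decay if "bias" in name else decay).append(param)
    #         return [
    #             {"params": decay},
    #             {"params": no_decay, "weight_decay": 0.0},
    #         ]
    #
    #     Meta.update_config(
    #         config={
    #             "learner_config": {
    #                 "optimizer_config": {
    #                     "optimizer": "adamw",
    #                     "lr": 2e-5,
    #                     "parameters": no_decay_for_bias,
    #                 }
    #             }
    #         }
    #     )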
""" # Set warmup scheduler self._set_warmup_scheduler(model) # Set lr scheduler lr_scheduler_dict = { "exponential": optim.lr_scheduler.ExponentialLR, "plateau": optim.lr_scheduler.ReduceLROnPlateau, "step": optim.lr_scheduler.StepLR, "multi_step": optim.lr_scheduler.MultiStepLR, "cyclic": optim.lr_scheduler.CyclicLR, "one_cycle": optim.lr_scheduler.OneCycleLR, # type: ignore "cosine_annealing": optim.lr_scheduler.CosineAnnealingLR, } opt = Meta.config["learner_config"]["lr_scheduler_config"]["lr_scheduler"] lr_scheduler_config = Meta.config["learner_config"]["lr_scheduler_config"] if opt is None: lr_scheduler = None elif opt == "linear": linear_decay_func = lambda x: (self.total_steps - self.warmup_steps - x) / ( self.total_steps - self.warmup_steps ) lr_scheduler = optim.lr_scheduler.LambdaLR( self.optimizer, linear_decay_func ) elif opt in ["exponential", "step", "multi_step", "cyclic"]: lr_scheduler = lr_scheduler_dict[opt]( self.optimizer, **lr_scheduler_config[f"{opt}_config"] ) elif opt == "one_cycle": lr_scheduler = lr_scheduler_dict[opt]( self.optimizer, total_steps=self.total_steps, epochs=Meta.config["learner_config"]["n_epochs"] if not self.use_step_base_counter else 1, steps_per_epoch=self.n_batches_per_epoch if not self.use_step_base_counter else self.total_steps, **lr_scheduler_config[f"{opt}_config"], ) elif opt == "cosine_annealing": lr_scheduler = lr_scheduler_dict[opt]( self.optimizer, self.total_steps, eta_min=lr_scheduler_config["min_lr"], **lr_scheduler_config[f"{opt}_config"], ) elif opt == "plateau": plateau_config = copy.deepcopy(lr_scheduler_config["plateau_config"]) del plateau_config["metric"] lr_scheduler = lr_scheduler_dict[opt]( self.optimizer, verbose=Meta.config["meta_config"]["verbose"], min_lr=lr_scheduler_config["min_lr"], **plateau_config, ) elif isinstance(opt, _LRScheduler): lr_scheduler = opt(self.optimizer) # type: ignore else: raise ValueError(f"Unrecognized lr scheduler option '{opt}'") self.lr_scheduler = lr_scheduler self.lr_scheduler_step_unit = Meta.config["learner_config"][ "lr_scheduler_config" ]["lr_scheduler_step_unit"] self.lr_scheduler_step_freq = Meta.config["learner_config"][ "lr_scheduler_config" ]["lr_scheduler_step_freq"] if Meta.config["meta_config"]["verbose"]: logger.info( f"Using lr_scheduler {repr(self.lr_scheduler)} with step every " f"{self.lr_scheduler_step_freq} {self.lr_scheduler_step_unit}." ) if Meta.config["learner_config"]["scheduler_path"]: try: scheduler_state = torch.load( Meta.config["learner_config"]["scheduler_path"] )["lr_scheduler"] if scheduler_state: self.lr_scheduler.load_state_dict(scheduler_state) logger.info( f"Lr scheduler state loaded from " f"{Meta.config['learner_config']['scheduler_path']}" ) except BaseException: logger.error( f"Loading failed... Cannot load lr scheduler state from " f"{Meta.config['learner_config']['scheduler_path']}, " f"continuing anyway." ) def _set_warmup_scheduler(self, model: EmmentalModel) -> None: """Set warmup learning rate scheduler for learning process. Args: model: The model to set up warmup scheduler. 
""" self.warmup_steps = 0 if Meta.config["learner_config"]["lr_scheduler_config"]["warmup_steps"]: warmup_steps = Meta.config["learner_config"]["lr_scheduler_config"][ "warmup_steps" ] if warmup_steps < 0: raise ValueError("warmup_steps must greater than 0.") warmup_unit = Meta.config["learner_config"]["lr_scheduler_config"][ "warmup_unit" ] if warmup_unit == "epoch": self.warmup_steps = int(warmup_steps * self.n_batches_per_epoch) elif warmup_unit == "batch": self.warmup_steps = int(warmup_steps) else: raise ValueError( f"warmup_unit must be 'batch' or 'epoch', but {warmup_unit} found." ) linear_warmup_func = lambda x: x / self.warmup_steps warmup_scheduler = optim.lr_scheduler.LambdaLR( self.optimizer, linear_warmup_func ) if Meta.config["meta_config"]["verbose"]: logger.info(f"Warmup {self.warmup_steps} batchs.") elif Meta.config["learner_config"]["lr_scheduler_config"]["warmup_percentage"]: warmup_percentage = Meta.config["learner_config"]["lr_scheduler_config"][ "warmup_percentage" ] self.warmup_steps = math.ceil(warmup_percentage * self.total_steps) linear_warmup_func = lambda x: x / self.warmup_steps warmup_scheduler = optim.lr_scheduler.LambdaLR( self.optimizer, linear_warmup_func ) if Meta.config["meta_config"]["verbose"]: logger.info(f"Warmup {self.warmup_steps} batchs.") else: warmup_scheduler = None self.warmup_scheduler = warmup_scheduler def _update_lr_scheduler( self, model: EmmentalModel, step: int, metric_dict: Dict[str, float] ) -> None: """Update the lr using lr_scheduler with each batch. Args: model: The model to update lr scheduler. step: The current step. """ cur_lr = self.optimizer.param_groups[0]["lr"] if self.warmup_scheduler and step < self.warmup_steps: self.warmup_scheduler.step() elif self.lr_scheduler is not None: lr_step_cnt = ( self.lr_scheduler_step_freq if self.lr_scheduler_step_unit == "batch" else self.lr_scheduler_step_freq * self.n_batches_per_epoch ) if (step + 1) % lr_step_cnt == 0: if ( Meta.config["learner_config"]["lr_scheduler_config"]["lr_scheduler"] != "plateau" ): self.lr_scheduler.step() elif ( Meta.config["learner_config"]["lr_scheduler_config"][ "plateau_config" ]["metric"] in metric_dict ): self.lr_scheduler.step( metric_dict[ # type: ignore Meta.config["learner_config"]["lr_scheduler_config"][ "plateau_config" ]["metric"] ] ) min_lr = Meta.config["learner_config"]["lr_scheduler_config"]["min_lr"] if min_lr and self.optimizer.param_groups[0]["lr"] < min_lr: self.optimizer.param_groups[0]["lr"] = min_lr if ( Meta.config["learner_config"]["lr_scheduler_config"]["reset_state"] and cur_lr != self.optimizer.param_groups[0]["lr"] ): logger.info("Reset the state of the optimizer.") self.optimizer.state = collections.defaultdict(dict) # Reset state def _set_task_scheduler(self) -> None: """Set task scheduler for learning process.""" opt = Meta.config["learner_config"]["task_scheduler_config"]["task_scheduler"] if opt in ["sequential", "round_robin", "mixed"]: self.task_scheduler = SCHEDULERS[opt]( # type: ignore **Meta.config["learner_config"]["task_scheduler_config"][ f"{opt}_scheduler_config" ] ) elif isinstance(opt, Scheduler): self.task_scheduler = opt else: raise ValueError(f"Unrecognized task scheduler option '{opt}'") def _evaluate( self, model: EmmentalModel, dataloaders: List[EmmentalDataLoader], split: Union[List[str], str], ) -> Dict[str, float]: """Evaluate the model. Args: model: The model to evaluate. dataloaders: The data to evaluate. split: The split to evaluate. Returns: The score dict. 
""" if not isinstance(split, list): valid_split = [split] else: valid_split = split valid_dataloaders = [ dataloader for dataloader in dataloaders if dataloader.split in valid_split ] return model.score(valid_dataloaders) def _logging( self, model: EmmentalModel, dataloaders: List[EmmentalDataLoader], batch_size: int, ) -> Dict[str, float]: """Check if it's time to evaluting or checkpointing. Args: model: The model to log. dataloaders: The data to evaluate. batch_size: Batch size. Returns: The score dict. """ # Switch to eval mode for evaluation model.eval() metric_dict = dict() self.logging_manager.update(batch_size) trigger_evaluation = self.logging_manager.trigger_evaluation() # Log the loss and lr metric_dict.update( self._aggregate_running_metrics( model, trigger_evaluation and Meta.config["learner_config"]["online_eval"], ) ) # Evaluate the model and log the metric if trigger_evaluation: # Log task specific metric metric_dict.update( self._evaluate( model, dataloaders, Meta.config["learner_config"]["valid_split"] ) ) self.logging_manager.write_log(metric_dict) self._reset_losses() elif Meta.config["logging_config"]["writer_config"]["write_loss_per_step"]: self.logging_manager.write_log(metric_dict) # Log metric dict every trigger evaluation time or full epoch if Meta.config["meta_config"]["verbose"] and ( trigger_evaluation or self.logging_manager.epoch_total == int(self.logging_manager.epoch_total) ): logger.info( f"{self.logging_manager.counter_unit.capitalize()}: " f"{self.logging_manager.unit_total:.2f} {metric_dict}" ) # Checkpoint the model if self.logging_manager.trigger_checkpointing(): self.logging_manager.checkpoint_model( model, self.optimizer, self.lr_scheduler, metric_dict ) self.logging_manager.write_log(metric_dict) self._reset_losses() # Switch to train mode model.train() return metric_dict def _aggregate_running_metrics( self, model: EmmentalModel, calc_running_scores: bool = False ) -> Dict[str, float]: """Calculate the running overall and task specific metrics. Args: model: The model to evaluate. calc_running_scores: Whether to calc running scores Returns: The score dict. 
""" metric_dict: Dict[str, float] = dict() total_count = 0 # Log task specific loss for identifier in self.running_uids.keys(): count = len(self.running_uids[identifier]) if count > 0: metric_dict[identifier + "/loss"] = float( self.running_losses[identifier] / count ) total_count += count # Calculate average micro loss if total_count > 0: total_loss = sum(self.running_losses.values()) metric_dict["model/all/train/loss"] = float(total_loss / total_count) if calc_running_scores: micro_score_dict: Dict[str, List[float]] = defaultdict(list) macro_score_dict: Dict[str, List[float]] = defaultdict(list) # Calculate training metric for identifier in self.running_uids.keys(): task_name, data_name, split = identifier.split("/") if ( model.scorers[task_name] and self.running_golds[identifier] and self.running_probs[identifier] ): metric_score = model.scorers[task_name].score( self.running_golds[identifier], self.running_probs[identifier], prob_to_pred(self.running_probs[identifier]), self.running_uids[identifier], ) for metric_name, metric_value in metric_score.items(): metric_dict[f"{identifier}/{metric_name}"] = metric_value # Collect average score identifier = construct_identifier( task_name, data_name, split, "average" ) metric_dict[identifier] = np.mean(list(metric_score.values())) micro_score_dict[split].extend( list(metric_score.values()) # type: ignore ) macro_score_dict[split].append(metric_dict[identifier]) # Collect split-wise micro/macro average score for split in micro_score_dict.keys(): identifier = construct_identifier( "model", "all", split, "micro_average" ) metric_dict[identifier] = np.mean( micro_score_dict[split] # type: ignore ) identifier = construct_identifier( "model", "all", split, "macro_average" ) metric_dict[identifier] = np.mean( macro_score_dict[split] # type: ignore ) # Log the learning rate metric_dict["model/all/train/lr"] = self.optimizer.param_groups[0]["lr"] return metric_dict def _set_learning_counter(self) -> None: if Meta.config["learner_config"]["n_steps"]: if Meta.config["learner_config"]["skip_learned_data"]: self.start_epoch = 0 self.start_step = 0 self.start_train_epoch = 0 self.start_train_step = Meta.config["learner_config"]["steps_learned"] else: self.start_epoch = 0 self.start_step = Meta.config["learner_config"]["steps_learned"] self.start_train_epoch = 0 self.start_train_step = Meta.config["learner_config"]["steps_learned"] self.end_epoch = 1 self.end_step = Meta.config["learner_config"]["n_steps"] self.use_step_base_counter = True self.total_steps = Meta.config["learner_config"]["n_steps"] else: if Meta.config["learner_config"]["skip_learned_data"]: self.start_epoch = 0 self.start_step = 0 self.start_train_epoch = Meta.config["learner_config"]["epochs_learned"] self.start_train_step = Meta.config["learner_config"]["steps_learned"] else: self.start_epoch = Meta.config["learner_config"]["epochs_learned"] self.start_step = Meta.config["learner_config"]["steps_learned"] self.start_train_epoch = Meta.config["learner_config"]["epochs_learned"] self.start_train_step = Meta.config["learner_config"]["steps_learned"] self.end_epoch = Meta.config["learner_config"]["n_epochs"] self.end_step = self.n_batches_per_epoch self.use_step_base_counter = False self.total_steps = ( Meta.config["learner_config"]["n_epochs"] * self.n_batches_per_epoch ) def _reset_losses(self) -> None: """Reset running logs.""" self.running_uids: Dict[str, List[str]] = defaultdict(list) self.running_losses: Dict[str, ndarray] = defaultdict(float) # type: ignore self.running_probs: 
Dict[str, List[ndarray]] = defaultdict(list) self.running_golds: Dict[str, List[ndarray]] = defaultdict(list)
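    # Worked example of the gradient accumulation condition used in `learn` below
    # (values are made up): with gradient_accumulation_steps=4, the optimizer steps
    # when (total_step_num + 1) % 4 == 0, i.e. after steps 3, 7, 11, ..., plus once
    # at the very last step so no accumulated gradients are dropped. The effective
    # batch size is therefore 4x the per-step batch size.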
    def learn(
        self, model: EmmentalModel, dataloaders: List[EmmentalDataLoader]
    ) -> None:
        """Learning procedure of emmental MTL.

        Args:
          model: The emmental model that needs to learn.
          dataloaders: A list of dataloaders used to learn the model.
        """
        start_time = time.time()

        # Generate the list of dataloaders for learning process
        train_split = Meta.config["learner_config"]["train_split"]
        if isinstance(train_split, str):
            train_split = [train_split]

        train_dataloaders = [
            dataloader for dataloader in dataloaders if dataloader.split in train_split
        ]

        if not train_dataloaders:
            raise ValueError(
                f"Cannot find the specified train_split "
                f'{Meta.config["learner_config"]["train_split"]} in dataloaders.'
            )

        # Set up task_scheduler
        self._set_task_scheduler()

        # Calculate the total number of batches per epoch
        self.n_batches_per_epoch: int = self.task_scheduler.get_num_batches(
            train_dataloaders
        )

        if self.n_batches_per_epoch == 0:
            logger.info("No batches in training dataloaders, exiting...")
            return

        # Set up learning counter
        self._set_learning_counter()
        # Set up logging manager
        self._set_logging_manager()

        # Set up wandb watch model
        if (
            Meta.config["logging_config"]["writer_config"]["writer"] == "wandb"
            and Meta.config["logging_config"]["writer_config"]["wandb_watch_model"]
        ):
            if Meta.config["logging_config"]["writer_config"][
                "wandb_model_watch_freq"
            ]:
                wandb.watch(
                    model,
                    log_freq=Meta.config["logging_config"]["writer_config"][
                        "wandb_model_watch_freq"
                    ],
                )
            else:
                wandb.watch(model)

        # Set up optimizer
        self._set_optimizer(model)
        # Set up lr_scheduler
        self._set_lr_scheduler(model)

        if Meta.config["learner_config"]["fp16"]:
            try:
                from apex import amp  # type: ignore
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to "
                    "use fp16 training."
                )
            logger.info(
                f"Model training with 16-bit (mixed) precision "
                f"and {Meta.config['learner_config']['fp16_opt_level']} opt level."
            )
            model, self.optimizer = amp.initialize(
                model,
                self.optimizer,
                opt_level=Meta.config["learner_config"]["fp16_opt_level"],
            )

        # Multi-gpu training (after apex fp16 initialization)
        if (
            Meta.config["learner_config"]["local_rank"] == -1
            and Meta.config["model_config"]["dataparallel"]
        ):
            model._to_dataparallel()

        # Distributed training (after apex fp16 initialization)
        if Meta.config["learner_config"]["local_rank"] != -1:
            model._to_distributed_dataparallel()

        # Set to training mode
        model.train()

        if Meta.config["meta_config"]["verbose"]:
            logger.info("Start learning...")

        self.metrics: Dict[str, float] = dict()
        self._reset_losses()

        # Set gradients of all model parameters to zero
        self.optimizer.zero_grad()

        batch_iterator = self.task_scheduler.get_batches(train_dataloaders, model)

        for epoch_num in range(self.start_epoch, self.end_epoch):
            for train_dataloader in train_dataloaders:
                # Set epoch for distributed sampler
                if isinstance(train_dataloader, DataLoader) and isinstance(
                    train_dataloader.sampler, DistributedSampler
                ):
                    train_dataloader.sampler.set_epoch(epoch_num)

            step_pbar = tqdm(
                range(self.start_step, self.end_step),
                desc=f"Step {self.start_step + 1}/{self.end_step}"
                if self.use_step_base_counter
                else f"Epoch {epoch_num + 1}/{self.end_epoch}",
                disable=not Meta.config["meta_config"]["verbose"]
                or Meta.config["learner_config"]["local_rank"] not in [-1, 0],
            )

            for step_num in step_pbar:
                if self.use_step_base_counter:
                    step_pbar.set_description(
                        f"Step {step_num + 1}/{self.total_steps}"
                    )
                    step_pbar.refresh()
                try:
                    batch = next(batch_iterator)
                except StopIteration:
                    batch_iterator = self.task_scheduler.get_batches(
                        train_dataloaders, model
                    )
                    batch = next(batch_iterator)

                # Check whether to skip the current batch
                if epoch_num < self.start_train_epoch or (
                    epoch_num == self.start_train_epoch
                    and step_num < self.start_train_step
                ):
                    continue

                # Convert a single batch into a batch list
                if not isinstance(batch, list):
                    batch = [batch]

                total_step_num = epoch_num * self.n_batches_per_epoch + step_num
                batch_size = 0

                for _batch in batch:
                    batch_size += len(_batch.uids)

                    # Perform forward pass and calculate the loss and count
                    uid_dict, loss_dict, prob_dict, gold_dict = model(
                        _batch.uids,
                        _batch.X_dict,
                        _batch.Y_dict,
                        _batch.task_to_label_dict,
                        return_probs=Meta.config["learner_config"]["online_eval"],
                        return_action_outputs=False,
                    )

                    # Update running loss and count
                    for task_name in uid_dict.keys():
                        identifier = f"{task_name}/{_batch.data_name}/{_batch.split}"
                        self.running_uids[identifier].extend(uid_dict[task_name])
                        self.running_losses[identifier] += (
                            loss_dict[task_name].item() * len(uid_dict[task_name])
                            if len(loss_dict[task_name].size()) == 0
                            else torch.sum(loss_dict[task_name]).item()
                        ) * model.task_weights[task_name]
                        if (
                            Meta.config["learner_config"]["online_eval"]
                            and prob_dict
                            and gold_dict
                        ):
                            self.running_probs[identifier].extend(
                                prob_dict[task_name]
                            )
                            self.running_golds[identifier].extend(
                                gold_dict[task_name]
                            )

                    # Calculate the average loss
                    loss = sum(
                        [
                            model.task_weights[task_name] * task_loss
                            if len(task_loss.size()) == 0
                            else torch.mean(
                                model.task_weights[task_name] * task_loss
                            )
                            for task_name, task_loss in loss_dict.items()
                        ]
                    )

                    # Perform backward pass to calculate gradients
                    if Meta.config["learner_config"]["fp16"]:
                        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()  # type: ignore

                if (total_step_num + 1) % Meta.config["learner_config"][
                    "optimizer_config"
                ]["gradient_accumulation_steps"] == 0 or (
                    step_num + 1 == self.end_step and epoch_num + 1 == self.end_epoch
                ):
                    # Clip gradient norm
                    if Meta.config["learner_config"]["optimizer_config"]["grad_clip"]:
                        if Meta.config["learner_config"]["fp16"]:
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer),
                                Meta.config["learner_config"]["optimizer_config"][
                                    "grad_clip"
                                ],
                            )
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                model.parameters(),
                                Meta.config["learner_config"]["optimizer_config"][
                                    "grad_clip"
                                ],
                            )

                    # Update the parameters
                    self.optimizer.step()

                    # Set gradients of all model parameters to zero
                    self.optimizer.zero_grad()

                if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
                    self.metrics.update(self._logging(model, dataloaders, batch_size))
                    step_pbar.set_postfix(self.metrics)

                # Update lr using lr scheduler
                self._update_lr_scheduler(model, total_step_num, self.metrics)

            step_pbar.close()

        if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
            model = self.logging_manager.close(model)

        logger.info(f"Total learning time: {time.time() - start_time} seconds.")
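
# Minimal end-to-end usage sketch (illustrative only; `model` and the
# dataloaders are placeholders assumed to be built with the usual emmental
# APIs, and the config keys match the ones read above):
#
#     import emmental
#     from emmental.learner import EmmentalLearner
#
#     emmental.init("logs")
#     emmental.Meta.update_config(
#         config={"learner_config": {"n_epochs": 3, "valid_split": "valid"}}
#     )
#
#     learner = EmmentalLearner()
#     learner.learn(model, [train_dataloader, valid_dataloader])
#     print(learner.metrics)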