Source code for emmental.logging.checkpointer

"""Emmental checkpointer."""
import glob
import logging
import os
from shutil import copyfile
from typing import Dict, List, Set, Union

import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer

from emmental.meta import Meta
from emmental.model import EmmentalModel

logger = logging.getLogger(__name__)


class Checkpointer(object):
    """Checkpointing class to log training information."""

    def __init__(self) -> None:
        """Initialize the checkpointer."""
        # Set up checkpoint directory
        self.checkpoint_path = Meta.config["logging_config"]["checkpointer_config"][
            "checkpoint_path"
        ]
        if self.checkpoint_path is None:
            self.checkpoint_path = Meta.log_path

        # Create checkpoint directory if necessary
        if not os.path.exists(self.checkpoint_path):
            os.makedirs(self.checkpoint_path)

        # Set up checkpoint frequency
        self.checkpoint_freq = (
            Meta.config["logging_config"]["evaluation_freq"]
            * Meta.config["logging_config"]["checkpointer_config"]["checkpoint_freq"]
        )
        if self.checkpoint_freq <= 0:
            raise ValueError(
                f"Invalid checkpoint freq {self.checkpoint_freq}, "
                f"must be greater than 0."
            )

        # Set up checkpoint unit
        self.checkpoint_unit = Meta.config["logging_config"]["counter_unit"]

        logger.info(
            f"Save checkpoints at {self.checkpoint_path} every "
            f"{self.checkpoint_freq} {self.checkpoint_unit}"
        )

        # Set up checkpoint metric
        self.checkpoint_metric = Meta.config["logging_config"]["checkpointer_config"][
            "checkpoint_metric"
        ]
        self.checkpoint_all_metrics = Meta.config["logging_config"][
            "checkpointer_config"
        ]["checkpoint_task_metrics"]

        # Collect all metrics to checkpoint
        if self.checkpoint_all_metrics is None:
            self.checkpoint_all_metrics = dict()

        if self.checkpoint_metric:
            self.checkpoint_all_metrics.update(self.checkpoint_metric)

        # Check evaluation metric mode
        for metric, mode in self.checkpoint_all_metrics.items():
            if mode not in ["min", "max"]:
                raise ValueError(
                    f"Unrecognized checkpoint metric mode {mode} for metric {metric}, "
                    f"must be 'min' or 'max'."
                )

        self.checkpoint_runway = Meta.config["logging_config"]["checkpointer_config"][
            "checkpoint_runway"
        ]
        logger.info(
            f"No checkpoints saved before {self.checkpoint_runway} "
            f"{self.checkpoint_unit}."
        )

        self.checkpoint_all = Meta.config["logging_config"]["checkpointer_config"][
            "checkpoint_all"
        ]
        logger.info(f"Keeping all checkpoints: {self.checkpoint_all}.")

        self.checkpoint_paths: List[str] = []

        # Set up checkpoint clear
        self.clear_intermediate_checkpoints = Meta.config["logging_config"][
            "checkpointer_config"
        ]["clear_intermediate_checkpoints"]
        self.clear_all_checkpoints = Meta.config["logging_config"][
            "checkpointer_config"
        ]["clear_all_checkpoints"]

        # Set up checkpoint flag
        self.checkpoint_condition_met = False

        self.best_metric_dict: Dict[str, float] = dict()
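    # A hedged sketch (not part of the class) of the config keys that
    # __init__ reads above; the surrounding structure of Meta.config and
    # the default values shown here are assumptions:
    #
    #   logging_config = {
    #       "counter_unit": "epoch",      # unit used in log messages
    #       "evaluation_freq": 1,         # multiplied by checkpoint_freq
    #       "checkpointer_config": {
    #           "checkpoint_path": None,  # None falls back to Meta.log_path
    #           "checkpoint_freq": 1,     # product with evaluation_freq must be > 0
    #           "checkpoint_metric": {"model/train/all/loss": "min"},
    #           "checkpoint_task_metrics": None,  # extra {metric: "min"|"max"} pairs
    #           "checkpoint_runway": 0,   # no checkpoints before this point
    #           "checkpoint_all": False,  # keep only the most recent checkpoint
    #           "clear_intermediate_checkpoints": True,
    #           "clear_all_checkpoints": False,
    #       },
    #   }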
    def checkpoint(
        self,
        iteration: Union[float, int],
        model: EmmentalModel,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler,
        metric_dict: Dict[str, float],
    ) -> None:
        """Checkpoint the current model, optimizer, and scheduler state.

        Args:
          iteration: The current iteration.
          model: The model to checkpoint.
          optimizer: The optimizer used during training process.
          lr_scheduler: Learning rate scheduler.
          metric_dict: The metric dict.
        """
        # Check whether the checkpoint_runway condition is met
        if iteration < self.checkpoint_runway:
            return
        elif not self.checkpoint_condition_met and iteration >= self.checkpoint_runway:
            self.checkpoint_condition_met = True
            logger.info(
                "checkpoint_runway condition has been met. Start checkpointing."
            )

        # Save model state
        model_path = f"{self.checkpoint_path}/checkpoint_{iteration}.model.pth"
        model.save(model_path, verbose=False)
        logger.info(
            f"Save checkpoint of {iteration} {self.checkpoint_unit} "
            f"at {model_path}."
        )

        # Save optimizer state
        optimizer_path = f"{self.checkpoint_path}/checkpoint_{iteration}.optimizer.pth"
        optimizer_dict = {
            "optimizer": optimizer.state_dict(),
        }
        torch.save(optimizer_dict, optimizer_path)

        # Save lr_scheduler state
        scheduler_path = f"{self.checkpoint_path}/checkpoint_{iteration}.scheduler.pth"
        scheduler_dict = {
            "lr_scheduler": lr_scheduler.state_dict() if lr_scheduler else None
        }
        torch.save(scheduler_dict, scheduler_path)

        # If not keeping all checkpoints, remove the previously saved ones
        if self.checkpoint_all is False:
            for path in self.checkpoint_paths:
                if os.path.exists(path):
                    os.remove(path)

        self.checkpoint_paths.extend([model_path, optimizer_path, scheduler_path])

        # Copy the latest checkpoint as the best one for any metric that improved
        if not set(self.checkpoint_all_metrics.keys()).isdisjoint(
            set(metric_dict.keys())
        ):
            new_best_metrics = self.is_new_best(metric_dict)
            for metric in new_best_metrics:
                best_metric_model_path = (
                    f"{self.checkpoint_path}/best_model_"
                    f"{metric.replace('/', '_')}.model.pth"
                )
                copyfile(model_path, best_metric_model_path)
                logger.info(
                    f"Save best model of metric {metric} to {best_metric_model_path}"
                )

                best_metric_optimizer_path = (
                    f"{self.checkpoint_path}/best_model_"
                    f"{metric.replace('/', '_')}.optimizer.pth"
                )
                copyfile(optimizer_path, best_metric_optimizer_path)

                best_metric_scheduler_path = (
                    f"{self.checkpoint_path}/best_model_"
                    f"{metric.replace('/', '_')}.scheduler.pth"
                )
                copyfile(scheduler_path, best_metric_scheduler_path)
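    # Hypothetical call site (illustration only): the training loop, the
    # model/optimizer/scheduler objects, and the metric name below are all
    # placeholders supplied by the caller, not part of this module.
    #
    #   checkpointer = Checkpointer()
    #   for epoch in range(num_epochs):
    #       train_loss = run_one_epoch(...)  # assumed training/eval step
    #       checkpointer.checkpoint(
    #           iteration=epoch,
    #           model=model,
    #           optimizer=optimizer,
    #           lr_scheduler=lr_scheduler,
    #           metric_dict={"model/train/all/loss": train_loss},
    #       )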
    def is_new_best(self, metric_dict: Dict[str, float]) -> Set[str]:
        """Update the best score for each tracked metric.

        Args:
          metric_dict: The current metric dict.

        Returns:
          The set of metrics whose best score was updated.
        """
        best_metric = set()

        for metric in metric_dict:
            if metric not in self.checkpoint_all_metrics:
                continue
            if metric not in self.best_metric_dict:
                self.best_metric_dict[metric] = metric_dict[metric]
                best_metric.add(metric)
            elif (
                self.checkpoint_all_metrics[metric] == "max"
                and metric_dict[metric] > self.best_metric_dict[metric]
            ):
                self.best_metric_dict[metric] = metric_dict[metric]
                best_metric.add(metric)
            elif (
                self.checkpoint_all_metrics[metric] == "min"
                and metric_dict[metric] < self.best_metric_dict[metric]
            ):
                self.best_metric_dict[metric] = metric_dict[metric]
                best_metric.add(metric)

        return best_metric
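    # Illustration only, with assumed values: if best_metric_dict holds
    # {"task/val/accuracy": 0.90} and the metric's mode is "max", then
    #
    #   checkpointer.is_new_best({"task/val/accuracy": 0.92})
    #
    # returns {"task/val/accuracy"} and bumps the stored best to 0.92,
    # while a value of 0.88 would return an empty set and leave it unchanged.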
    def clear(self) -> None:
        """Clear checkpoints."""
        if self.clear_all_checkpoints:
            logger.info("Clear all checkpoints.")
            file_list = glob.glob(f"{self.checkpoint_path}/*.pth")
            for file in file_list:
                os.remove(file)
        elif self.clear_intermediate_checkpoints:
            logger.info("Clear all intermediate checkpoints.")
            file_list = glob.glob(f"{self.checkpoint_path}/checkpoint_*.pth")
            for file in file_list:
                os.remove(file)
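    # Note the two glob patterns above: clear_all_checkpoints removes every
    # *.pth file (including the best_model_*.pth copies), whereas
    # clear_intermediate_checkpoints matches only checkpoint_*.pth and so
    # preserves the best-model files.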
    def load_best_model(self, model: EmmentalModel) -> EmmentalModel:
        """Load the best model from the checkpoint.

        Args:
          model: The current model.

        Returns:
          The best model loaded from the checkpoint.
        """
        if list(self.checkpoint_metric.keys())[0] not in self.best_metric_dict:
            logger.info("No best model found, use the original model.")
        else:
            # Load the best model of checkpoint_metric
            metric = list(self.checkpoint_metric.keys())[0]
            best_model_path = (
                f"{self.checkpoint_path}/best_model_"
                f"{metric.replace('/', '_')}.model.pth"
            )
            model.load(best_model_path, verbose=False)
            logger.info(f"Loading the best model from {best_model_path}.")

        return model
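A hedged end-to-end sketch of how the class is typically driven. The init call, the training-loop details, and the model object here are illustrative assumptions, not part of this module; only the Checkpointer calls mirror the code above.

# Hedged sketch; setup and training details below are illustrative only.
import emmental
from emmental.logging.checkpointer import Checkpointer

emmental.init(log_dir="logs")  # populates Meta.config and Meta.log_path

checkpointer = Checkpointer()
# ... build model/optimizer/scheduler and train, calling
# checkpointer.checkpoint(...) each evaluation step as sketched above ...

# Restore the weights that scored best on the first checkpoint_metric,
# then delete leftover .pth files according to the clear_* config flags.
model = checkpointer.load_best_model(model)
checkpointer.clear()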