Source code for emmental.logging.logging_manager

"""Emmental logging manager."""
import logging
from typing import Dict, Union

from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer

from emmental.logging.checkpointer import Checkpointer
from emmental.logging.json_writer import JsonWriter
from emmental.logging.log_writer import LogWriter
from emmental.logging.tensorboard_writer import TensorBoardWriter
from emmental.logging.wandb_writer import WandbWriter
from emmental.meta import Meta
from emmental.model import EmmentalModel

logger = logging.getLogger(__name__)


class LoggingManager(object):
    """A class to manage logging during training progress.

    Args:
      n_batches_per_epoch: Total number of batches per epoch.
      epoch_count: Initial epoch count when resuming training, defaults to 0.
      batch_count: Initial batch count when resuming training, defaults to 0.
    """

    def __init__(
        self, n_batches_per_epoch: int, epoch_count: int = 0, batch_count: int = 0
    ) -> None:
        """Initialize LoggingManager."""
        self.n_batches_per_epoch = n_batches_per_epoch

        # Set up evaluation/checkpointing counter unit (sample, batch, epoch)
        self.counter_unit = Meta.config["logging_config"]["counter_unit"]
        if self.counter_unit not in ["sample", "batch", "epoch"]:
            raise ValueError(f"Unrecognized unit: {self.counter_unit}")

        # Set up evaluation frequency
        self.evaluation_freq = Meta.config["logging_config"]["evaluation_freq"]
        if Meta.config["meta_config"]["verbose"]:
            logger.info(f"Evaluating every {self.evaluation_freq} {self.counter_unit}.")

        if Meta.config["logging_config"]["checkpointing"]:
            self.checkpointing = True

            # Set up checkpointing frequency (in multiples of the evaluation freq)
            self.checkpointing_freq = int(
                Meta.config["logging_config"]["checkpointer_config"]["checkpoint_freq"]
            )
            if Meta.config["meta_config"]["verbose"]:
                logger.info(
                    f"Checkpointing every "
                    f"{self.checkpointing_freq * self.evaluation_freq} "
                    f"{self.counter_unit}."
                )

            # Set up checkpointer
            self.checkpointer = Checkpointer()
        else:
            self.checkpointing = False
            if Meta.config["meta_config"]["verbose"]:
                logger.info("No checkpointing.")

        # Number of samples seen since the last evaluation/checkpointing, and
        # the total number of samples seen since the start of training
        self.sample_count: int = 0
        self.sample_total: int = 0

        # Number of batches seen since the last evaluation/checkpointing, and
        # the total number of batches seen since the start of training. When
        # resuming, reduce the running count modulo the evaluation interval.
        self.batch_count: int = batch_count
        self.batch_total: int = batch_count
        if self.batch_count != 0:
            if self.counter_unit == "batch":
                while self.batch_count >= self.evaluation_freq:
                    self.batch_count -= self.evaluation_freq
            elif self.counter_unit == "epoch":
                while (
                    self.batch_count
                    >= self.evaluation_freq * self.n_batches_per_epoch
                ):
                    self.batch_count -= self.evaluation_freq * self.n_batches_per_epoch

        # Number of epochs seen since the last evaluation/checkpointing, and
        # the total number of epochs seen since the start of training
        self.epoch_count: Union[float, int] = epoch_count
        self.epoch_total: Union[float, int] = epoch_count
        if self.epoch_count != 0:
            if self.counter_unit == "epoch":
                while self.epoch_count >= self.evaluation_freq:
                    self.epoch_count -= self.evaluation_freq
            elif self.counter_unit == "batch":
                while (
                    self.epoch_count
                    >= self.evaluation_freq / self.n_batches_per_epoch
                ):
                    self.epoch_count -= self.evaluation_freq / self.n_batches_per_epoch

        # Number of units seen since the last evaluation/checkpointing, and
        # the total number of units seen since the start of training
        self.unit_count: Union[float, int] = 0
        self.unit_total: Union[float, int] = 0

        # Number of evaluations triggered since the last checkpointing
        self.trigger_count = 0

        # Set up log writer
        writer_opt = Meta.config["logging_config"]["writer_config"]["writer"]
        if writer_opt is None:
            self.writer = LogWriter()
        elif writer_opt == "json":
            self.writer = JsonWriter()
        elif writer_opt == "tensorboard":
            self.writer = TensorBoardWriter()
        elif writer_opt == "wandb":
            self.writer = WandbWriter()
        else:
            raise ValueError(f"Unrecognized writer option '{writer_opt}'")

        self.log_unit_sanity_check = False
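
# Usage sketch (not part of the module): constructing a LoggingManager. This
# is a minimal illustration that assumes Meta has been initialized first,
# e.g. via emmental.init(); the config values below are illustrative only.
#
#   import emmental
#
#   emmental.init(
#       log_dir="logs",
#       config={
#           "logging_config": {
#               "counter_unit": "batch",
#               "evaluation_freq": 100,
#               "checkpointing": True,
#               "checkpointer_config": {"checkpoint_freq": 2},
#           }
#       },
#   )
#
#   # Evaluate every 100 batches; checkpoint every 2 evaluations (200 batches).
#   manager = LoggingManager(n_batches_per_epoch=500)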

    def update(self, batch_size: int) -> None:
        """Update the counter.

        Args:
          batch_size: The number of samples in the batch.
        """
        # Update number of samples
        self.sample_count += batch_size
        self.sample_total += batch_size

        # Update number of batches
        self.batch_count += 1
        self.batch_total += 1

        # Update number of epochs, collapsing whole-number floats to ints
        self.epoch_count = self.batch_count / self.n_batches_per_epoch
        self.epoch_total = self.batch_total / self.n_batches_per_epoch
        if self.epoch_count == int(self.epoch_count):
            self.epoch_count = int(self.epoch_count)
        if self.epoch_total == int(self.epoch_total):
            self.epoch_total = int(self.epoch_total)

        # Update number of units
        if self.counter_unit == "sample":
            self.unit_count = self.sample_count
            self.unit_total = self.sample_total
        elif self.counter_unit == "batch":
            self.unit_count = self.batch_count
            self.unit_total = self.batch_total
        elif self.counter_unit == "epoch":
            self.unit_count = self.epoch_count
            self.unit_total = self.epoch_total
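
# Worked example of the counter arithmetic above (illustrative values): with
# n_batches_per_epoch = 100 and batch_size = 32, after 150 calls to update()
# the manager holds
#
#   sample_total = 150 * 32 = 4800
#   batch_total  = 150
#   epoch_total  = 150 / 100 = 1.5
#
# With counter_unit == "epoch", unit_total is then the float 1.5, while a
# whole number of epochs (e.g. 200 / 100 = 2.0) is collapsed to the int 2.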

    def trigger_evaluation(self) -> bool:
        """Check whether an evaluation should be triggered."""
        satisfied = self.unit_count >= self.evaluation_freq
        if satisfied:
            self.trigger_count += 1
            self.reset()
        return satisfied

    def trigger_checkpointing(self) -> bool:
        """Check whether a checkpoint should be triggered."""
        if not self.checkpointing:
            return False
        satisfied = self.trigger_count >= self.checkpointing_freq
        if satisfied:
            self.trigger_count = 0
        return satisfied
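
# Usage sketch of the trigger methods inside a training loop (a simplified
# version of what a learner might do; dataloader, evaluate(), model,
# optimizer, and lr_scheduler are hypothetical stand-ins):
#
#   for epoch in range(n_epochs):
#       for batch in dataloader:
#           ...  # forward/backward pass and optimizer step
#           manager.update(batch_size=len(batch))
#           if manager.trigger_evaluation():
#               metric_dict = evaluate(model)  # hypothetical helper
#               manager.write_log(metric_dict)
#               if manager.trigger_checkpointing():
#                   manager.checkpoint_model(
#                       model, optimizer, lr_scheduler, metric_dict
#                   )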

    def reset(self) -> None:
        """Reset the counter."""
        self.sample_count = 0
        self.batch_count = 0
        self.epoch_count = 0
        self.unit_count = 0

    def write_log(self, metric_dict: Dict[str, float]) -> None:
        """Write the metrics to the log.

        Args:
          metric_dict: The metric dict.
        """
        unit_total: Union[float, int] = self.unit_total
        # TensorBoard/W&B only accept integer step values, so when an
        # epoch-based unit count is fractional, fall back to the batch count
        # as the step and warn once.
        if (
            Meta.config["logging_config"]["writer_config"]["writer"]
            in ["tensorboard", "wandb"]
            and self.counter_unit == "epoch"
            and int(self.evaluation_freq) != self.evaluation_freq
        ):
            if not self.log_unit_sanity_check:
                logger.warning(
                    "Cannot use a float evaluation_freq when counter_unit is "
                    "epoch with the tensorboard/wandb writer; switching to "
                    "batch as the counter unit."
                )
                self.log_unit_sanity_check = True
            unit_total = self.batch_total

        self.writer.add_scalar_dict(metric_dict, unit_total)
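
# Example of the fallback above (illustrative values): with counter_unit set
# to "epoch", evaluation_freq = 0.5, and the tensorboard writer, unit_total
# may be a fractional value such as 1.5, which TensorBoard rejects as a step;
# the manager then logs against batch_total (e.g. 150) instead and warns once.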

    def checkpoint_model(
        self,
        model: EmmentalModel,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler,
        metric_dict: Dict[str, float],
    ) -> None:
        """Checkpoint the model.

        Args:
          model: The model to checkpoint.
          optimizer: The optimizer used during the training process.
          lr_scheduler: The learning rate scheduler.
          metric_dict: The metric dict.
        """
        self.checkpointer.checkpoint(
            self.unit_total, model, optimizer, lr_scheduler, metric_dict
        )

    def close(self, model: EmmentalModel) -> EmmentalModel:
        """Close the writer and checkpointer, reloading the model if necessary.

        Args:
          model: The trained model.

        Returns:
          The reloaded best model if checkpointing is enabled, otherwise the
          model unchanged.
        """
        self.writer.close()
        if self.checkpointing:
            model = self.checkpointer.load_best_model(model)
            self.checkpointer.clear()
        return model
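
# Usage sketch at the end of training (a minimal illustration; assumes
# checkpointing was enabled so a best checkpoint exists to reload):
#
#   model = manager.close(model)  # close the writer, reload the best model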