Source code for emmental.schedulers.sequential_scheduler

"""Emmental sequential scheduler."""
from typing import Iterator, List, Union

from emmental.data import EmmentalDataLoader
from emmental.model import EmmentalModel
from emmental.schedulers.scheduler import Batch, Scheduler


class SequentialScheduler(Scheduler):
    """Generate batch generator from all dataloaders in sequential order.

    Args:
      fillup: Whether to fill up the shorter dataloaders so that all dataloaders
        contribute the same number of batches per epoch.
    """

    def __init__(self, fillup: bool = False) -> None:
        """Initialize SequentialScheduler."""
        super().__init__()
        self.fillup = fillup
    def get_num_batches(self, dataloaders: List[EmmentalDataLoader]) -> int:
        """Get total number of batches per epoch.

        Args:
          dataloaders: List of dataloaders.

        Returns:
          Total number of batches per epoch.
        """
        batch_counts = [len(dataloader) for dataloader in dataloaders]
        if self.fillup:
            batch_counts = [max(batch_counts)] * len(dataloaders)
        for idx in range(len(dataloaders)):
            if dataloaders[idx].n_batches:
                batch_counts[idx] = dataloaders[idx].n_batches
        return sum(batch_counts)
    def get_batches(
        self, dataloaders: List[EmmentalDataLoader], model: EmmentalModel = None
    ) -> Iterator[Union[Batch, List[Batch]]]:
        """Generate batch generator from all dataloaders for one epoch.

        Args:
          dataloaders: List of dataloaders.
          model: The training model, defaults to None.

        Returns:
          A generator of all batches.
        """
        task_to_label_dicts = [
            dataloader.task_to_label_dict for dataloader in dataloaders
        ]
        uid_names = [dataloader.uid for dataloader in dataloaders]
        data_names = [dataloader.data_name for dataloader in dataloaders]
        splits = [dataloader.split for dataloader in dataloaders]
        data_loaders = [iter(dataloader) for dataloader in dataloaders]

        # Calc the batch size for each dataloader
        batch_counts = [len(dataloader) for dataloader in dataloaders]
        if self.fillup:
            batch_counts = [max(batch_counts)] * len(dataloaders)
        for idx in range(len(dataloaders)):
            if dataloaders[idx].n_batches:
                batch_counts[idx] = dataloaders[idx].n_batches

        for (
            data_loader_idx,
            (task_to_label_dict, data_name, batch_count, split, uid_name),
        ) in enumerate(
            zip(task_to_label_dicts, data_names, batch_counts, splits, uid_names)
        ):
            for batch_idx in range(batch_count):
                try:
                    batch = next(data_loaders[data_loader_idx])
                except StopIteration:
                    # Restart an exhausted dataloader when fillup or n_batches
                    # requires more batches than it natively provides.
                    data_loaders[data_loader_idx] = iter(dataloaders[data_loader_idx])
                    batch = next(data_loaders[data_loader_idx])
                if not isinstance(batch, dict):
                    X_dict, Y_dict = batch
                else:
                    # Unlabeled batch: the dict holds features only, no labels.
                    X_dict = batch
                    Y_dict = None
                yield Batch(
                    X_dict[uid_name],
                    X_dict,
                    Y_dict,
                    task_to_label_dict,
                    data_name,
                    split,
                )
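
A minimal usage sketch (not part of the module): it assumes ``dataloaders`` is a hypothetical, already-constructed list of EmmentalDataLoader objects whose ``task_to_label_dict``, ``uid``, ``data_name``, and ``split`` attributes are set as the methods above expect; the batch counts in the comments are illustrative.

    # Hypothetical: `dataloaders` is a list of EmmentalDataLoader objects
    # built elsewhere in the training pipeline.
    scheduler = SequentialScheduler(fillup=True)

    # With fillup=True and two dataloaders of length 3 and 5 (n_batches unset),
    # get_num_batches reports max(3, 5) * 2 = 10 batches per epoch instead of 3 + 5 = 8.
    total_batches = scheduler.get_num_batches(dataloaders)

    # Batches are yielded dataloader by dataloader, in the order the list is given;
    # shorter dataloaders are re-iterated from the start once exhausted.
    for batch in scheduler.get_batches(dataloaders):
        # Each yielded Batch bundles the uids, X_dict, Y_dict, task_to_label_dict,
        # data_name, and split, in the order passed to Batch(...) above.
        ...  # hand the batch to the training loop

The sequential order means one dataloader is fully consumed before the next one starts, which is the main difference from a round-robin or mixed scheduler.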