# Source code for emmental.schedulers.round_robin_scheduler

"""Emmental round robin scheduler."""
import random
from typing import Iterator, List, Optional, Union

from emmental.data import EmmentalDataLoader
from emmental.model import EmmentalModel
from emmental.schedulers.scheduler import Batch, Scheduler


class RoundRobinScheduler(Scheduler):
    """Generate batch generator from all dataloaders in round robin order.

    Args:
      fillup: Whether fillup to make all dataloader the same size.
    """

    def __init__(self, fillup: bool = False) -> None:
        """Initialize RoundRobinScheduler."""
        super().__init__()
        self.fillup = fillup

    def _get_batch_counts(self, dataloaders: List[EmmentalDataLoader]) -> List[int]:
        """Compute how many batches to draw from each dataloader per epoch.

        Shared by ``get_num_batches`` and ``get_batches`` so the two always
        agree on the per-epoch schedule.

        Args:
          dataloaders: List of dataloaders.

        Returns:
          Per-dataloader batch counts, in the same order as ``dataloaders``.
        """
        batch_counts = [len(dataloader) for dataloader in dataloaders]
        # Guard against an empty dataloader list; max([]) would raise.
        if self.fillup and batch_counts:
            # Pad every dataloader up to the size of the largest one.
            batch_counts = [max(batch_counts)] * len(dataloaders)
        for idx, dataloader in enumerate(dataloaders):
            # An explicit n_batches on a dataloader overrides the computed count.
            if dataloader.n_batches:
                batch_counts[idx] = dataloader.n_batches
        return batch_counts

    def get_num_batches(self, dataloaders: List[EmmentalDataLoader]) -> int:
        """Get total number of batches per epoch.

        Args:
          dataloaders: List of dataloaders.

        Returns:
          Total number of batches per epoch.
        """
        return sum(self._get_batch_counts(dataloaders))

    def get_batches(
        self,
        dataloaders: List[EmmentalDataLoader],
        model: Optional[EmmentalModel] = None,
    ) -> Iterator[Union[Batch, List[Batch]]]:
        """Generate batch generator from all dataloaders for one epoch.

        Batches are interleaved in a uniformly shuffled round-robin order;
        dataloaders that are exhausted before their scheduled count is met
        (e.g. due to ``fillup`` or ``n_batches``) are restarted.

        Args:
          dataloaders: List of dataloaders.
          model: The training model, defaults to None (unused here; kept for
            interface compatibility with other schedulers).

        Returns:
          A generator of all batches.
        """
        task_to_label_dicts = [
            dataloader.task_to_label_dict for dataloader in dataloaders
        ]
        uid_names = [dataloader.uid for dataloader in dataloaders]
        data_names = [dataloader.data_name for dataloader in dataloaders]
        splits = [dataloader.split for dataloader in dataloaders]
        data_loaders = [iter(dataloader) for dataloader in dataloaders]

        # Build the shuffled schedule: each dataloader index appears once per
        # batch it contributes this epoch.
        batch_counts = self._get_batch_counts(dataloaders)
        dataloader_indexer: List[int] = []
        for idx, count in enumerate(batch_counts):
            dataloader_indexer.extend([idx] * count)
        random.shuffle(dataloader_indexer)

        for data_loader_idx in dataloader_indexer:
            uid_name = uid_names[data_loader_idx]
            try:
                batch = next(data_loaders[data_loader_idx])
            except StopIteration:
                # Exhausted early (fillup/n_batches): restart the dataloader.
                data_loaders[data_loader_idx] = iter(dataloaders[data_loader_idx])
                batch = next(data_loaders[data_loader_idx])
            if not isinstance(batch, dict):
                # Batch is an (X_dict, Y_dict) pair.
                X_dict, Y_dict = batch
            else:
                # Unlabeled batch: only an X_dict was produced.
                X_dict = batch
                Y_dict = None
            yield Batch(
                X_dict[uid_name],
                X_dict,
                Y_dict,
                task_to_label_dicts[data_loader_idx],
                data_names[data_loader_idx],
                splits[data_loader_idx],
            )