# Source code for emmental.metrics.accuracy

"""Emmental accuracy scorer."""
from typing import Dict, List, Optional, Union

import numpy as np
from numpy import ndarray

from emmental.utils.utils import prob_to_pred


def accuracy_scorer(
    golds: ndarray,
    probs: Optional[ndarray],
    preds: Optional[ndarray],
    uids: Optional[List[str]] = None,
    normalize: bool = True,
    topk: int = 1,
) -> Dict[str, Union[float, int]]:
    """Accuracy classification score.

    Args:
      golds: Ground truth values. A 2-D array is treated as probabilistic
        labels and converted to hard labels with ``prob_to_pred``.
      probs: Predicted probabilities, indexed as (sample, class). Required
        when ``topk > 1`` or when ``preds`` is None.
      preds: Predicted values. Used only when ``topk == 1``.
      uids: Unique ids, defaults to None. Not used by this scorer.
      normalize: Normalize the results or not, defaults to True.
      topk: Top K accuracy, defaults to 1. Must be >= 1.

    Returns:
      A single-entry dict keyed ``"accuracy"`` (``topk == 1``) or
      ``"accuracy@{topk}"``. If normalize is True, the value is the fraction
      of correctly predicted samples (float; ``nan`` when ``golds`` is
      empty, since the fraction is undefined). Otherwise it is the number of
      correctly predicted samples (int).

    Raises:
      ValueError: If ``topk < 1``, or if the probability-based path is needed
        (``topk > 1`` or ``preds`` is None) but ``probs`` is None.
    """
    # topk == 0 would turn ``[:, -0:]`` into "all classes" and report a
    # perfect score; negative values are meaningless.
    if topk < 1:
        raise ValueError(f"topk must be a positive integer, got {topk}.")

    # Convert probabilistic label to hard label
    if len(golds.shape) == 2:
        golds = prob_to_pred(golds)

    metric_name = "accuracy" if topk == 1 else f"accuracy@{topk}"

    if topk == 1 and preds is not None:
        n_matches = int(np.sum(golds == preds))
    else:
        if probs is None:
            raise ValueError(
                "probs is required for top-k accuracy or when preds is None."
            )
        # argsort is ascending, so the last ``topk`` columns hold the indices
        # of the ``topk`` most probable classes. Their order does not matter
        # for a membership test.
        topk_preds = probs.argsort(axis=1)[:, -topk:]
        n_matches = int(np.any(topk_preds == golds.reshape(-1, 1), axis=1).sum())

    if not normalize:
        return {metric_name: n_matches}

    n_samples = golds.shape[0]
    if n_samples == 0:
        # The fraction is undefined for an empty set; report nan explicitly
        # instead of dividing by zero.
        return {metric_name: float("nan")}
    return {metric_name: n_matches / n_samples}