"""Emmental precision scorer."""
from typing import Dict, List, Optional
import numpy as np
from numpy import ndarray
from emmental.utils.utils import prob_to_pred
def precision_scorer(
    golds: ndarray,
    probs: Optional[ndarray],
    preds: ndarray,
    uids: Optional[List[str]] = None,
    pos_label: int = 1,
) -> Dict[str, float]:
    """Compute precision for the positive class.

    Precision is TP / (TP + FP): of all samples predicted as ``pos_label``,
    the fraction whose gold label is also ``pos_label``.

    Args:
        golds: Ground truth values. A 2-D array is treated as probabilistic
            labels and converted to hard labels with ``prob_to_pred``.
        probs: Predicted probabilities. Not used by this scorer; accepted so
            the signature matches what callers pass.
        preds: Predicted values (hard labels).
        uids: Unique ids, defaults to None. Not used by this scorer.
        pos_label: The positive class label, defaults to 1.

    Returns:
        A dict with the single key ``"precision"``. The value is 0.0 when
        nothing was predicted as ``pos_label`` (avoids division by zero).
    """
    # Coerce to arrays so that plain sequences compare element-wise; comparing
    # a list to a scalar would collapse to a single False and silently report
    # a precision of 0.0. ndarray inputs pass through unchanged.
    golds = np.asarray(golds)
    preds = np.asarray(preds)

    # Convert probabilistic label to hard label
    if golds.ndim == 2:
        golds = prob_to_pred(golds)

    # The comparisons already yield boolean masks; no np.where needed.
    pred_pos = preds == pos_label
    gt_pos = golds == pos_label

    true_pos = np.sum(pred_pos & gt_pos)
    false_pos = np.sum(pred_pos & np.logical_not(gt_pos))
    predicted_pos = true_pos + false_pos

    precision = true_pos / predicted_pos if predicted_pos > 0 else 0.0
    return {"precision": precision}