Source code for recognite.eval.pr_metrics

from typing import List, Union, Tuple, Dict

import torch
from torch import Tensor


[docs]def pr_metrics( scores: Tensor, gallery_labels: Tensor, query_labels: Tensor ) -> Dict[str, Union[List[Tensor], Tensor]]: """Computes the precision-recall curves and related metrics. Args: scores: The scores for each query (rows) and each gallery item (columns). gallery_labels: The labels of the items in the gallery (columns of ``scores``). query_labels: The true label of each query (rows of ``scores``). Returns: A dictionary containing the metrics. The dict contains the following keys: - ``'Precisions'``: The precision values of the PR-curve of each query. - ``'Recalls'``: The recall values of the PR-curve of each query. - ``'Thresholds'``: The thresholds that correspond to each point on the PR curve. - ``'AP'``: The average precision of each query. - ``'mAP'``: The mean of the average precisions of all queries. - ``'P@maxF1'``: The precision at the point on the PR curve with the maximum F1 score for each query. - ``'R@maxF1'``: The recall at the point on the PR curve with the maximum F1 score for each query. - ``'T@maxF1'``: The threshold at the point on the PR curve with the maximum F1 score for each query. - ``'maxF1'``: The maximum F1 score for each query. """ if not scores.shape[0] == len(query_labels): raise ValueError( 'First dim of scores should equal the number of queries' ) if not scores.shape[1] == len(gallery_labels): raise ValueError( 'Second dim of scores should equal the number of gallery items' ) targets = query_labels[:, None] == gallery_labels p_per_q, r_per_q, th_per_q = _pr_curve_per_q(scores, targets) ap = _ap_from_pr( precision=p_per_q, recall=r_per_q, num_classes=len(p_per_q) ) mAP = ap.mean() p_at_max_f1, r_at_max_f1, th_at_max_f1, max_f1 = \ _prtf_at_max_f1(p_per_q, r_per_q, th_per_q) return { 'Precisions': p_per_q, 'Recalls': r_per_q, 'Thresholds': th_per_q, 'AP': ap, 'mAP': mAP, 'P@maxF1': p_at_max_f1, 'R@maxF1': r_at_max_f1, 'T@maxF1': th_at_max_f1, 'maxF1': max_f1, }
def _pr_curve_per_q( scores: Tensor, targets: Tensor ) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]: """Computes the PR curve for each query. Args: scores: The scores for each query (rows) and each gallery item (columns). targets: Boolean tensor of the same shape as ``scores`` indicating which query-gallery pairs belong to the same class. Returns: A tuple of three lists: precisions, recalls and thresholds. The i-th element of each list corresponds to the i-th query. """ tps, fps, thresh = _binary_clf_curve_per_q(scores, targets) precision = tps / (tps + fps) recall = tps / tps.nan_to_num(nan=0).max(dim=1)[0][..., None] precision = torch.hstack([ torch.fliplr(precision), torch.ones(len(precision)).type_as(precision)[..., None] ]) recall = torch.hstack([ torch.fliplr(recall), torch.zeros(len(recall)).type_as(recall)[..., None] ]) thresh = torch.fliplr(thresh) nan_masks = precision.isnan() p_per_q = [] r_per_q = [] th_per_q = [] idxs_with_no_positives = torch.nonzero(~targets.any(dim=1), as_tuple=True)[0] for i, (p, r, th, nan_mask) in enumerate(zip(precision, recall, thresh, nan_masks)): p_per_q.append(p[~nan_mask]) r = r[~nan_mask] if i in idxs_with_no_positives: # If a query has no positives, the first recall score will be NaN # due to a 0 / 0 division (i.e. tps / max(tps) = 0 / 0). In that # case, we define the recall to be 1 (precision will be 0 there) assert r[0].isnan() r[0] = 1. r_per_q.append(r) th_per_q.append(th[~nan_mask[:-1]]) return p_per_q, r_per_q, th_per_q def _binary_clf_curve_per_q( scores: Tensor, targets: Tensor ) -> Tuple[Tensor, Tensor, Tensor]: """ Gets the number of true positives and false positives for each possible threshold for each query. Adapted from torchmetrics/functional/classification/precision_recall_curve.py Args: scores: The scores for each query (rows) and each gallery item (columns). targets: Boolean tensor of the same shape as ``scores`` indicating which query-gallery pairs belong to the same class. Returns: A tuple containing the true positives, false positives and scores for each query. The tensors are sorted according to descending scores. """ idxs = torch.argsort(scores, descending=True) scores = torch.gather(scores, 1, idxs) targets = torch.gather(targets, 1, idxs) tps = torch.cumsum(targets, dim=1).to(torch.float) fps = (torch.arange(1, targets.size(1) + 1) .type_as(tps) .repeat(targets.size(0), 1) - tps) # Extract indices associated with distinct values. distinct_val_idxs = torch.where(scores[:, :-1] - scores[:, 1:]) distinct_val_mask = torch.zeros(scores.shape).to(torch.bool) distinct_val_mask[distinct_val_idxs] = True # "Concatenate" a value for the end of the curve distinct_val_mask[:, -1] = True tps[~distinct_val_mask] = float('nan') fps[~distinct_val_mask] = float('nan') scores[~distinct_val_mask] = float('nan') # stop when full recall is attained last_inds = torch.argmax( (tps == tps[:, -1].reshape(-1, 1)).to(torch.float), dim=1 ) mask = torch.zeros(scores.shape).to(torch.bool) mask[torch.arange(0, len(tps)), last_inds] = True mask = torch.cumsum(mask, dim=1).roll(1).to(torch.bool) mask[:, 0] = False tps[mask] = float('nan') fps[mask] = float('nan') scores[mask] = float('nan') return tps, fps, scores def _prtf_at_max_f1( precisions: Tensor, recalls: Tensor, thresholds: Tensor ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """ Returns the precision, recall, threshold and F1-score at the max F1-score for each query. Args: precisions: The precision values for each query. recalls: The recall values for each query. thresholds: The respective threshold values. Returns: A tuple containing the precision, recall, threshold and F1-score for each query. """ f1s = [ 2 * (p * r)/(p + r) for p, r in zip(precisions, recalls) ] max_f1_idxs = [torch.argmax(torch.nan_to_num(f1)) for f1 in f1s] prtfs_at_max_f1 = [ (p[idx], r[idx], th[idx], f1[idx]) for idx, p, r, th, f1 in zip( max_f1_idxs, precisions, recalls, thresholds, f1s, ) ] (p_at_max_f1, r_at_max_f1, th_at_max_f1, max_f1) = list(zip(*prtfs_at_max_f1)) return (torch.tensor(p_at_max_f1).nan_to_num(), torch.tensor(r_at_max_f1).nan_to_num(), torch.tensor(th_at_max_f1).nan_to_num(), torch.tensor(max_f1).nan_to_num()) def _ap_from_pr( precision: Tensor, recall: Tensor, num_classes: int, ) -> Union[List[Tensor], Tensor]: """Computes the average precision score from precision and recall. Copied from torchmetrics.functional.classification.average_precision Args: precision: precision values recall: recall values num_classes: integer with number of classes. """ # Return the step function integral # The following works because the last entry of precision is # guaranteed to be 1, as returned by precision_recall_curve if num_classes == 1: return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) res = [] for p, r in zip(precision, recall): res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1])) return torch.tensor(res)