Source code for recognite.eval.hard_pos_neg_scores

from typing import Dict

import torch
from torch import Tensor
import math


[docs]def hard_pos_neg_scores( scores: Tensor, gallery_labels: Tensor, query_labels: Tensor, ) -> Dict[str, Tensor]: """ Computes the similarity scores between each query and the hardest negative and hardest positive in the gallery. Args: scores: The scores for each query (rows) and each gallery item (columns). query_labels: The true label of each query (rows of ``scores``). gallery_labels: The labels of the items in the gallery (columns of ``scores``). Returns: A dictionary with the following items - ``'HardPosScores'``: The score of the hardest positive of each query. - ``'HardNegScores'``: The score of the hardest negative of each query. """ # Subtract infinity from the positive labels, so we can find the # closest negative pos_mask = query_labels[:, None] == gallery_labels[None, :] hard_neg_scoremat = torch.clone(scores) hard_neg_scoremat[pos_mask] -= math.inf hardest_neg_scores = hard_neg_scoremat.max(dim=1)[0] # Note: an item of hardest_neg_scores will be -inf if there are no # negatives for that query hardest_neg_scores = hardest_neg_scores[~torch.isinf(hardest_neg_scores)] hard_pos_scoremat = torch.clone(scores) hard_pos_scoremat[~pos_mask] += math.inf hardest_pos_scores = hard_pos_scoremat.min(dim=1)[0] # Note: an item of hardest_pos_scores will be +inf if there are no # positives for that query hardest_pos_scores = hardest_pos_scores[~torch.isinf(hardest_pos_scores)] return { 'HardPosScores': hardest_pos_scores, 'HardNegScores': hardest_neg_scores }