Source code for recognite.utils.avg_ref_embs

from typing import Tuple

import torch


def avg_ref_embs(
    embeddings: torch.Tensor, labels: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Groups the embeddings by label and computes the average embedding.

    Args:
        embeddings: The embeddings, one embedding per row. Shape ``(N, ...)``;
            any trailing dimensions are averaged element-wise.
        labels: The label of each embedding. Must contain exactly ``N``
            elements (e.g. shape ``(N,)`` or ``(N, 1)``).

    Returns:
        A tuple ``(avg_embeddings, unique_labels)``. ``unique_labels`` is 1-D
        and sorted in ascending order; ``avg_embeddings[i]`` is the mean of
        all embeddings whose label equals ``unique_labels[i]``.

    Raises:
        ValueError: If the number of labels differs from the number of
            embeddings.
    """
    labels = labels.reshape(-1)
    if labels.numel() != embeddings.shape[0]:
        # Without this check, a shorter ``labels`` tensor would make the
        # accumulation silently skip the trailing embeddings.
        raise ValueError(
            f"Expected one label per embedding, got {labels.numel()} labels "
            f"for {embeddings.shape[0]} embeddings."
        )

    # A 1-D unique is enough: it yields the sorted distinct labels, the group
    # index of every embedding, and the size of every group.
    unique_labels, inv_idxs, labels_count = torch.unique(
        labels, sorted=True, return_inverse=True, return_counts=True
    )
    # Accumulation happens on the embeddings' device.
    inv_idxs = inv_idxs.to(embeddings.device)
    labels_count = labels_count.to(embeddings.device)

    # Sum all embeddings that share the same label index.
    agg_embeddings = embeddings.new_zeros(
        (unique_labels.numel(), *embeddings.shape[1:])
    )
    agg_embeddings.index_add_(0, inv_idxs, embeddings)

    # Divide sums by group sizes. Counts are reshaped so they broadcast over
    # every trailing dimension, and cast to float (as before) so integer
    # embeddings still produce floating-point averages.
    counts = labels_count.float().view(-1, *([1] * (embeddings.dim() - 1)))
    return agg_embeddings / counts, unique_labels