"""Module ``recognite.data.data_frame_dataset``: an image dataset backed by a pandas DataFrame."""

from typing import Dict, Callable, Optional

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class DataFrameDataset(Dataset):
    """A dataset based on a pandas DataFrame.

    The provided DataFrame contains the path of each image and the
    corresponding label. Indexing the dataset returns an ``(image, label)``
    pair. ``label`` is the integer that ``label_to_int`` assigns to the row's
    label.

    Args:
        df: The DataFrame containing the image paths and labels.
        label_key: The column in the DataFrame that contains the label of
            each image.
        image_key: The column in the DataFrame that contains the image path
            of each image.
        label_to_int: A dictionary that maps the label to a unique integer.
        transform: A transform to apply to the image before returning it.
        loader: Optional callable that receives the value of the
            ``image_key`` column and returns the image. When omitted, the
            image is opened with PIL. Its pixel data is read eagerly, so the
            underlying file handle is closed before the image is returned.

    Attributes:
        df: The DataFrame with the image paths and labels.
        label_key: The column in the DataFrame that contains the label of
            each image.
        image_key: The column in the DataFrame that contains the image path
            of each image.
        label_to_int: The dictionary that maps the label to a unique integer.
        transform: The transform applied to each image before returning it.
        loader: The custom loader passed at construction, or ``None`` when
            the default PIL loader is used.
        unique_labels: The set of keys of ``label_to_int``. It is derived
            from the mapping rather than from ``df``, so it may contain labels
            that do not occur in this dataset's rows.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        label_key: str,
        image_key: str,
        label_to_int: Dict[str, int],
        transform: Optional[Callable] = None,
        loader: Optional[Callable] = None,
    ):
        self.df = df
        self.transform = transform
        self.label_key = label_key
        self.image_key = image_key
        self.label_to_int = label_to_int
        # Store the user's loader as given (possibly None) instead of a
        # reference to the default. This keeps the instance free of extra
        # callables when it is pickled for DataLoader workers.
        self.loader = loader
        self.unique_labels = set(self.label_to_int.keys())

    def __len__(self) -> int:
        """Return the number of rows (images) in the DataFrame."""
        return len(self.df)

    def __getitem__(self, idx) -> tuple:
        """Return the ``(image, label)`` pair stored at position ``idx``.

        ``idx`` is positional (it is passed to ``DataFrame.iloc``).

        Raises:
            KeyError: If the row's label is not a key of ``label_to_int``.
        """
        row = self.df.iloc[idx]
        # Resolve the label first so an unknown label fails before any file
        # is touched.
        label = self.label_to_int[row[self.label_key]]
        load = self.loader if self.loader is not None else self._open_image
        im = load(row[self.image_key])
        if self.transform is not None:
            im = self.transform(im)
        return im, label

    @staticmethod
    def _open_image(path):
        """Open ``path`` with PIL, read its pixels, and close the file."""
        # Image.open is lazy and keeps the file open until the data is read.
        # Loading inside the ``with`` block lets the context manager close
        # the file handle while the decoded pixels stay usable afterwards.
        with Image.open(path) as image:
            image.load()
        return image