Source code for recognite.data.k_fold

from typing import Tuple

import numpy as np
import pandas as pd


[docs]def k_fold_trainval_split( df: pd.DataFrame, num_folds: int = 5, val_fold: int = 0, seed: int = 0, label_key='label' ) -> Tuple[pd.DataFrame, pd.DataFrame]: """Splits the given DataFrame into a train and validation subset. The subsets are composed by shuffling the labels in the DataFrame, using random seed ``seed``. The labels are then split into ``num_folds`` folds, where each label can only be in a single fold. We then choose one fold for validation (as given by ``val_fold``) and the other ``num_folds - 1`` folds for training. Args: df: The DataFrame to split. num_folds: The number of folds to use for splitting the dataset into training and validation. Note that the folds are label-based, not sample-based. val_fold: The index of the fold to use for validation. The others will be used for training. seed: The random seed to use for k-fold splitting. label_key: The column in the CSV file that contains the label of each image. Returns: A tuple with the training DataFrame the validation DataFrame. """ assert val_fold < num_folds labels = df[label_key].unique() np.random.seed(seed) np.random.shuffle(labels) label_folds = np.array_split(labels, num_folds) val_labels = label_folds.pop(val_fold) train_labels = np.concatenate(label_folds) df_val = df[df[label_key].isin(val_labels)].reset_index(drop=True).copy() df_train = df[df[label_key].isin(train_labels)].reset_index(drop=True) return df_train, df_val