|
import random |
|
import typing |
|
import torch |
|
from torch.utils.data import Dataset |
|
import torch.nn.functional as F |
|
|
|
class FewShotDataset(Dataset): |
|
""" A custom Dataset class for Few-Shot Learning tasks. |
|
This dataset can operate in two modes: "support" (for prototype calculation) and "query" (for evaluation). """ |
|
def __init__(self, dataset, indices: list, classes: list, transform:typing.Callable, mode="support"): |
|
""" Args: |
|
dataset (list): List of (feature, label) pairs. |
|
indices (list): List of indices to be used for the dataset. |
|
transform (callable): Transform to be applied to the features. |
|
mode (str): Mode of operation, either "support" or "query". Default is "support". """ |
|
assert mode in ["support", "query"], "Invalid mode. Must be either 'support' or 'query'." |
|
assert dataset and indices and classes is not None, "Dataset or indices cannot be None." |
|
|
|
self.dataset, self.indices, self.classes = dataset, indices, classes |
|
self.mode, self.transform = mode, transform |
|
|
|
|
|
def __getitem__(self, index: int): |
|
""" Returns a sample from the dataset at the given index. |
|
Args: index of the sample to be retrieved. |
|
Returns: tuple of the transformed feature and the label. """ |
|
if index >= len(self.indices): |
|
raise IndexError("Index out of bounds") |
|
feature, label = self.dataset[self.indices[index]] |
|
|
|
feature = self.transform(feature) |
|
if self.mode == "query": |
|
label = F.one_hot(torch.tensor(self.classes.index(label)), num_classes=len(self.classes)).float() |
|
return feature, label |
|
|
|
|
|
def __len__(self): return len(self.indices) |
|
|
|
|
|
class FewShotEpisoder: |
|
""" A class to generate episodes for Few-Shot Learning. |
|
Each episode consists of a support set and a query set. """ |
|
def __init__(self, dataset, classes: list, k_shot: int, n_query: int, transform: typing.Callable): |
|
""" Args: |
|
dataset (Dataset): The base dataset to generate episodes from. |
|
k_shot (int): Number of support samples per class. |
|
n_query (int): Number of query samples per class. |
|
transform (callable): Transform to be applied to the features. """ |
|
assert k_shot > 0 and n_query > 0, "k_shot and n_query must be greater than 0." |
|
|
|
self.k_shot, self.n_query, self.classes = k_shot, n_query, classes |
|
self.dataset, self.transform = dataset, transform |
|
self.indices_c = self.get_class_indices() |
|
|
|
|
|
def get_class_indices(self) -> dict: |
|
""" Initialize the class indices for the dataset. |
|
Returns: tuple of Number of classes and a list of indices grouped by class. """ |
|
indices_c = {label: [] for label in range(self.classes.__len__())} |
|
for index, (_, label) in enumerate(self.dataset): |
|
if label in self.classes: indices_c[self.classes.index(label)].append(index) |
|
for label, _indices_c in indices_c.items(): |
|
indices_c[label] = random.sample(_indices_c, self.k_shot + self.n_query) |
|
return indices_c |
|
|
|
|
|
def get_episode(self) -> tuple: |
|
""" Generate an episode consisting of a support set and a query set. |
|
Returns: tuple of A FewShotDataset for the support set and a FewShotDataset for the query set. """ |
|
|
|
support_examples, query_examples = [], [] |
|
for class_label in range(self.classes.__len__()): |
|
if len(self.indices_c[class_label]) < self.k_shot + self.n_query: continue |
|
selected_indices = random.sample(self.indices_c[class_label], self.k_shot + self.n_query) |
|
support_examples.extend(selected_indices[:self.k_shot]) |
|
query_examples.extend(selected_indices) |
|
|
|
|
|
support_set = FewShotDataset(self.dataset, support_examples, self.classes, self.transform, "support") |
|
query_set = FewShotDataset(self.dataset, query_examples, self.classes, self.transform, "query") |
|
|
|
return support_set, query_set |
|
|
|
|