Source code for domid.dsets.dset_unittest
import os
import shutil
import numpy as np
import pandas as pd
import torch
from domainlab.dsets.utils_data import mk_fun_label2onehot
from domainlab.utils.utils_class import store_args
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets, transforms
[docs]class DsetUnitTest(Dataset):
"""
This dataset is solely used for unit testing of loss values.
The images contain tensors of one with the dimension of 1x16x16, the label is a random integer.
"""
[docs] @store_args
def __init__(self, digit, args, subset_step=1, list_transforms=None):
dpath = os.path.normpath(args.dpath)
self.digit = digit
if not os.path.exists(dpath):
self.create_the_dataset(dpath)
self.images = torch.load(os.path.join(dpath, "images.pt"))
self.labels = torch.load(os.path.join(dpath, "labels.pt")).squeeze(1)
self.images = self.images[self.labels == digit]
self.args = args
self.inject_variable = args.inject_var
[docs] def create_the_dataset(self, dpath):
# Check if the directory exists
seed = 42
torch.manual_seed(seed)
if not os.path.exists(dpath):
os.makedirs(dpath)
dummy_images = torch.ones(7000, 3, 16, 16)
dummy_labels = torch.randint(0, 10, (7000, 1))
torch.save(dummy_images, os.path.join(dpath, "images.pt"))
torch.save(dummy_labels, os.path.join(dpath, "labels.pt"))
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
image = self.images[idx]
label = self.labels[idx]
label = mk_fun_label2onehot(10)(label)
inject_tensor = []
img_id = 0
return image, label, inject_tensor, img_id