# Source code for domid.dsets.dset_usps
import os
import numpy as np
import pandas as pd
import torch
from domainlab.dsets.utils_data import mk_fun_label2onehot
from domainlab.utils.utils_class import store_args
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets, transforms
class DsetUSPS(Dataset):
    """Single-digit subset of the torchvision USPS training split.

    Only samples whose target equals ``digit`` are kept. Items are returned as
    ``(image, label, inject_tensor, img_id)``.
    """

    @store_args
    def __init__(self, digit, args, subset_step=1, list_transforms=None):
        """
        :param digit: class label to keep; samples with any other target are dropped.
        :param args: experiment namespace; ``args.dpath`` is the USPS root/download
            directory and ``args.inject_var`` is stored as ``self.inject_variable``.
        :param subset_step: accepted for interface compatibility; not used in this class.
        :param list_transforms: optional sequence of callables applied in order to each
            PIL image in ``__getitem__``; when ``None`` images go through
            ``transforms.ToTensor()``.
        """
        # NOTE(review): ``self.list_transforms`` (read in ``__getitem__``) is never
        # assigned here, so ``store_args`` presumably copies the constructor
        # arguments onto ``self`` -- confirm against domainlab.
        dpath = os.path.normpath(args.dpath)
        self.digit = digit
        # torchvision's ``transform`` must be a single callable; a bare list would
        # make ``self.dataset[i]`` raise. ``__getitem__`` below does not go through
        # ``self.dataset[i]`` -- it applies ``list_transforms`` itself.
        dataset_transform = (
            transforms.Compose(list_transforms) if list_transforms is not None else None
        )
        self.dataset = datasets.USPS(
            root=dpath, train=True, download=True, transform=dataset_transform
        )
        # Build the digit mask once and reuse it for both images and labels.
        targets = np.asarray(self.dataset.targets)
        mask = targets == digit
        self.images = self.dataset.data[mask]
        self.labels = torch.as_tensor(targets[mask], dtype=torch.long)
        self.args = args
        self.inject_variable = args.inject_var

    def get_original_indicies(self):
        """Return the positions in ``self.dataset`` of samples labelled ``self.digit``.

        :return: 1-D ``torch.long`` tensor of indices into the full USPS training set.
        """
        targets = torch.as_tensor(self.dataset.targets)
        return (targets == self.digit).nonzero().flatten()

    def __len__(self):
        """Number of samples for the selected digit."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label, inject_tensor, img_id)`` for sample ``idx``.

        ``image`` is the stored array converted to a PIL image and then passed
        through ``list_transforms`` (or ``ToTensor`` when none were given);
        ``label`` is the output of ``mk_fun_label2onehot(10)`` for the stored label.
        """
        image = Image.fromarray(self.images[idx])
        if self.list_transforms is not None:
            for trans in self.list_transforms:
                image = trans(image)
        else:
            image = transforms.ToTensor()(image)
        label = mk_fun_label2onehot(10)(self.labels[idx])
        # NOTE(review): no injected variables are produced (empty list), and
        # ``img_id`` is the one-hot label rather than a per-image identifier --
        # confirm this is what downstream consumers expect.
        inject_tensor = []
        img_id = label
        return image, label, inject_tensor, img_id