Source code for domid.dsets.dset_mnist

"""
MNIST
"""

import os

import numpy as np
import torch
from domainlab.dsets.utils_data import mk_fun_label2onehot
from domainlab.utils.utils_class import store_args
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets, transforms


class DsetMNIST(Dataset):
    """
    MNIST Dataset Loading
    - subdomains: MNIST digit value
    - structure: each subdomain contains all images of a given digit
    """

    @store_args
    def __init__(self, digit, args, list_transforms=None, raw_split="train"):
        """
        :param digit: an integer value from 0 to 9; only images of this digit will be kept.
        :param args: configuration object. The attributes read by this class are ``dpath``
            (disk storage directory), ``subset_step`` (used to subsample the dataset; every
            ``subset_step``-th image of the digit is kept), ``inject_var`` (whether an injected
            variable is generated per sample) and, only when ``inject_var`` is truthy,
            ``dim_inject_y`` (number of categories of the injected variable).
        :param list_transforms: optional list of callables applied in order to the RGB PIL image
            in ``__getitem__``; when ``None``, ``transforms.ToTensor()`` is applied instead.
        :param raw_split: which part of MNIST to load; ``"train"`` (default) uses the training
            part, any other value uses the test part.
        """
        dpath = os.path.normpath(args.dpath)
        # Honor raw_split instead of always loading the training part.
        flag_train = raw_split == "train"
        # No transform is handed to torchvision: only the raw ``data``/``targets`` tensors are
        # read below, and ``list_transforms`` is a list, not a single callable transform.
        dataset = datasets.MNIST(root=dpath, train=flag_train, download=True)

        # keep only images of specified digit
        subset_step = args.subset_step
        self.images = dataset.data[dataset.targets == digit]
        inds_subset = list(range(0, self.images.shape[0], subset_step))
        self.images = self.images[inds_subset]
        n_img = self.images.shape[0]

        # dummy class labels (should not be used; included for consistency with DomainLab)
        self.labels = torch.randint(10, (n_img,), dtype=torch.int32)

        self.args = args
        # Assigned explicitly so that __getitem__ does not depend on the @store_args decorator
        # for this attribute.
        self.list_transforms = list_transforms
        self.inject_variable = args.inject_var

    def __len__(self):
        """
        :return: number of images kept for the selected digit after subsampling
        """
        return len(self.images)

    def __getitem__(self, idx):
        """
        :param idx: index of the sample within the kept images
        :return: tuple ``(image, label, inject_tensor, location)``: the transformed image, a
            dummy one-hot class label, the one-hot injected variable (an empty list when
            injection is disabled) and a placeholder location string
        """
        image = self.images[idx].numpy()
        image = Image.fromarray(image)
        image = image.convert("RGB")

        if self.list_transforms is not None:
            for trans in self.list_transforms:
                image = trans(image)
        else:
            image = transforms.ToTensor()(image)  # range of pixel [0,1]

        # dummy class labels (should not be used; included for consistency with DomainLab)
        label = self.labels[idx]
        label = mk_fun_label2onehot(10)(label)

        if self.inject_variable:
            # A fresh random category is drawn on every access, so the injected value of a
            # given sample is not stable across calls.
            inject_tensor = np.random.randint(0, self.args.dim_inject_y, size=(1,))[0]
            # NOTE(review): the draw lies in [0, dim_inject_y), so after subtracting 1 the value
            # passed on lies in [-1, dim_inject_y - 2]; presumably the one-hot helper indexes in
            # a way that lets -1 wrap around to the last category -- confirm.
            inject_tensor = mk_fun_label2onehot(self.args.dim_inject_y)(inject_tensor - 1)
        else:
            inject_tensor = []

        # dummy image locations; included for consistency with code that uses inject_domain.
        # FIXME: remove location and another_label here, and adjust the code elsewhere that only
        # needs inject_domain but still expects location and another_label.
        location = "dummy_placeholder"

        # if self.args.path_to_domain:
        #     inject_domain = np.loadtxt(os.path.join(self.args.path_to_domain, "domain_labels.txt"))[idx]
        #     # FIXME: no need to hardcode the name of the file as "domain_labels.txt"
        #     inject_domain = mk_fun_label2onehot(self.args.d_dim)(int(inject_domain) - 1)
        #     # FIXME: no need to hardcode the number of domains as d_dim
        # else:
        #     inject_domain = np.array([])

        return (image, label, inject_tensor, location)  # FIXME for mnist color as well