Source code for domid.dsets.dset_wsi
import os
import numpy as np
import pandas as pd
import torch
from domainlab.dsets.utils_data import mk_fun_label2onehot
from domainlab.utils.utils_class import store_args
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.io import read_image
[docs]class DsetWSI(Dataset):
"""
Dataset of WEAH stained digital microscopy images.
As currently implemented, the subdomains are the HER2 diagnostic classes 1, 2, and 3.
There are also 4 data collection site/machine combinations.
"""
[docs] @store_args
def __init__(self, class_num, path, args, path_to_domain=None, transform=None):
"""
:param class_num: a integer value from 0 to 2, only images of this class will be kept.Note: that actual classes are from 1-3 (therefore, 1 is added in line 28)
:param path: path to root storage directory
:param d_dim: number of clusters for the clustering task
:param path_to_domain: if inject previously predicted domain labels, the path needs to be specified.domain_labels.txt must be inside the directory, containing to-be-injected labels.
:param transform: torch transformations
"""
self.dpath = args.dpath
self.img_dir = args.dpath # os.path.join(path, "class" + str(class_num + 1) + "jpg")
self.images = path # os.listdir(self.img_dir)
self.class_num = class_num
self.transform = transform
self.total_imgs = len(self.images)
self.path_to_domain = path_to_domain
self.d_dim = args.d_dim
self.df = pd.read_csv(args.meta_data_csv)
print("the data is loading from the csv:", args.meta_data_csv)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
# print(self.images[idx])
# print(self.images[idx])
# import pdb; pdb.set_trace()
img_loc = os.path.join(self.dpath, self.images[idx])
# print(img_loc)
image = Image.open(img_loc)
# print(image)
if self.transform:
for trans in self.transform:
image = trans(image)
image = transforms.ToTensor()(image)
resp_label = int(self.df.loc[self.df["path"] == self.images[idx]]["resp"])
cah_label = int(self.df.loc[self.df["path"] == self.images[idx]]["ann"])
label_dict = {"01": 0, "02": 1, "11": 2, "12": 3, "03": 4, "13": 5}
encod_label = label_dict[str(resp_label) + str(cah_label)]
label = mk_fun_label2onehot(6)(encod_label)
inject_tensor = []
img_id = img_loc
return image, label, inject_tensor, img_id