import torch
import torch.optim as optim
from domainlab.algos.trainers.a_trainer import AbstractTrainer
from tensorboardX import SummaryWriter

from domid.compos.predict_basic import Prediction
from domid.compos.tensorboard_fun import tensorboard_write
from domid.trainers.pretraining_KMeans import Pretraining
from domid.utils.perf_cluster import PerfCluster
from domid.utils.storing import Storing
class TrainerAE(AbstractTrainer):
    """
    Trainer with an optional pretraining phase followed by training on the model's own loss.

    While ``epoch < aconf.pre_tr`` the loss comes from ``Pretraining.pretrain_loss``;
    on the first later epoch the optimizer is re-created and ``model.cal_loss`` is
    used from then on.
    """

    def init_business(self, model, task, observer, device, aconf, flag_accept=True):
        """
        Set up the optimizer, pretraining state, result storage and tensorboard writer.

        :param model: model to train
        :param task: task to train on; ``task.isize`` and ``task.loader_tr`` are read here
        :param observer: observer to notify
        :param device: device to use
        :param aconf: configuration parameters; this method reads ``pre_tr`` (number of
            pretraining epochs), ``lr`` and ``model``
        :param flag_accept: not used by this trainer; kept for interface compatibility
        """
        super().__init__()
        super().init_business(model, task, observer, device, aconf)
        print(model)

        # Pretraining is enabled only for a positive number of pretraining epochs.
        self.pretrain = aconf.pre_tr > 0
        self.pretraining_finished = not self.pretrain
        self.lr = aconf.lr
        self.warmup_beta = 0.1

        # The same optimizer configuration is used with or without pretraining;
        # it is re-created once pretraining ends (see _finish_pretraining).
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        if not self.pretraining_finished:
            banner = "#" * 60
            print(banner + "\nPretraining initialized.\n" + banner)

        self.epo_loss_tr = None
        self.thres = aconf.pre_tr  # number of epochs for pretraining
        self.i_h, self.i_w = task.isize.h, task.isize.w
        self.args = aconf
        self.storage = Storing(self.args)
        self.writer = SummaryWriter(logdir="debug/" + self.storage.experiment_name)
        # NOTE(review): the "validation" loader is the task's training loader.
        self.loader_val = task.loader_tr
        self.aname = aconf.model

    def _inject_tensor_from(self, other_vars):
        """
        Extract the injected tensor from the optional trailing elements of a batch.

        :param other_vars: batch elements after ``(x, y, d)``; either empty or
            a pair ``(inject_tensor, image_id)``
        :return: the injected tensor moved to ``self.device`` if it is non-empty,
            the injected value unchanged if it is empty, or ``[]`` when the batch
            carries no extra elements
        """
        if len(other_vars) == 0:
            # No extras in this batch: use an empty value instead of leaving the
            # name unbound or silently reusing the previous batch's tensor.
            return []
        inject_tensor, _image_id = other_vars
        if len(inject_tensor) > 0:
            inject_tensor = inject_tensor.to(self.device)
        return inject_tensor

    def _finish_pretraining(self, epoch):
        """
        Switch from the pretraining loss to the full model loss.

        Marks pretraining as finished, sets ``model.counter`` to 1 and re-creates
        the Adam optimizer so training continues with fresh optimizer state.

        :param epoch: epoch number, used only for the log message
        """
        self.pretraining_finished = True
        self.model.counter = 1
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        banner = "#" * 60
        print(banner)
        print("Epoch {}: Finished pretraining and starting to use the full model loss.".format(epoch))
        print(banner)

    def tr_epoch(self, epoch):
        """
        Run one training epoch and one validation pass, then log and store the results.

        :param epoch: epoch number
        :return: the flag returned by ``self.observer.update(epoch)``
        """
        if self.pretraining_finished:
            print("Epoch {}.".format(epoch))
        else:
            print("Epoch {}. Pretraining.".format(epoch))

        self.model.train()
        self.epo_loss_tr = 0

        pretrain = Pretraining(self.model, self.device, self.loader_tr, self.loader_val, self.i_h, self.i_w, self.args)
        prediction = Prediction(
            self.model, self.device, self.loader_tr, self.loader_val, self.i_h, self.i_w, self.args.bs
        )
        acc_tr_y, _, acc_tr_d, _ = prediction.epoch_tr_acc()
        acc_val_y, _, acc_val_d, _ = prediction.epoch_val_acc()

        # Correlation scores are only computed for the "her2" task.
        r_score_tr = "None"
        r_score_te = "None"
        if self.args.task == "her2":
            r_score_tr = prediction.epoch_tr_correlation()
            r_score_te = prediction.epoch_val_correlation()  # validation set is used as a test set

        # ___________Define warm-up for ELBO loss_________
        if self.warmup_beta < 1 and self.pretraining_finished:
            self.warmup_beta = self.warmup_beta + 0.01

        # _____________one training epoch: start_______________________
        # Only the inputs (and the optional injected tensor) feed the loss, so
        # the label tensors are not moved to the device.
        for i, (tensor_x, _vec_y, _vec_d, *other_vars) in enumerate(self.loader_tr):
            if i == 0:
                self.model.batch_zero = True
            inject_tensor = self._inject_tensor_from(other_vars)
            tensor_x = tensor_x.to(self.device)

            self.optimizer.zero_grad()
            # __________________Pretrain/ELBO loss____________
            if epoch < self.thres and not self.pretraining_finished:
                loss = pretrain.pretrain_loss(tensor_x, inject_tensor)
            else:
                if not self.pretraining_finished:
                    # First batch after pretraining: flip the flag and reset the optimizer.
                    self._finish_pretraining(epoch)
                # Pass the warm-up weight exactly as the validation pass does, so
                # both losses are computed the same way.
                loss = self.model.cal_loss(tensor_x, inject_tensor, self.warmup_beta)

            loss = loss.sum()
            loss.backward()
            self.optimizer.step()
            self.epo_loss_tr += loss.item()

        # After all batches of a pretraining epoch the pretraining helper is refit.
        # Original author's note: this recalculates the GMM and updates pi and mu_c
        # (mu_c is the mean of a Gaussian mixture cluster; mu alone is the mean of
        # the decoded pixels).
        if not self.pretraining_finished:
            pretrain.model_fit()

        # __________________Validation_____________________
        # No parameter update happens here, so gradient tracking is disabled.
        # NOTE(review): loss_val holds the loss of the last validation batch only.
        with torch.no_grad():
            for tensor_x_val, _vec_y_val, _vec_d_val, *other_vars in self.loader_val:
                inject_tensor_val = self._inject_tensor_from(other_vars)
                tensor_x_val = tensor_x_val.to(self.device)
                if epoch < self.thres and not self.pretraining_finished:
                    loss_val = pretrain.pretrain_loss(tensor_x_val, inject_tensor_val)
                else:
                    loss_val = self.model.cal_loss(tensor_x_val, inject_tensor_val, self.warmup_beta)

        # Tensorboard receives the loss and inputs of the last training batch.
        if self.writer is not None:
            tensorboard_write(
                self.writer,
                self.model,
                epoch,
                self.lr,
                self.warmup_beta,
                acc_tr_y,
                loss,
                self.pretraining_finished,
                tensor_x,
                inject_tensor,
            )

        # _____storing results and Z space__________
        self.storage.storing(
            epoch, acc_tr_y, acc_tr_d, self.epo_loss_tr, acc_val_y, acc_val_d, loss_val, r_score_tr, r_score_te
        )
        # Every second epoch: store the latent projections and save the model.
        if epoch % 2 == 0:
            _, z_proj, predictions, vec_y_labels, vec_d_labels, image_id_labels = prediction.mk_prediction()
            self.storage.storing_z_space(z_proj, predictions, vec_y_labels, vec_d_labels, image_id_labels)
            self.storage.saving_model(self.model)

        flag_stop = self.observer.update(epoch)  # notify observer
        return flag_stop

    def before_tr(self):
        """
        Check the performance of the randomly initialized weights.
        """
        # FIXME: accuracy is measured on the training loader; change tr to te.
        acc = PerfCluster.cal_acc(self.model, self.loader_tr, self.device)
        print("before training, model accuracy:", acc)

    def post_tr(self):
        """
        Report the end of training and let the observer run its final step.
        """
        print("training is done")
        self.observer.after_all()