"""Source code for domid.compos.tensorboard_fun."""

import torch


def tensorboard_write(
    writer,
    model,
    epoch,
    lr,
    warmup_beta,
    acc_tr,
    loss,
    pretraining_finished,
    tensor_x,
    inject_tensor=None,
    other_info=None,
):
    """Log one epoch of scalars and a batch of decoder outputs to TensorBoard.

    Scalars written (all with ``epoch`` as the global step):

    * ``"learning rate"`` -- only when ``lr > 0``.
    * ``"warmup"`` -- ``warmup_beta``.
    * ``"Pretraining"`` / ``"Pretraining Loss"`` while
      ``pretraining_finished`` is falsy, otherwise ``"Training acc"`` /
      ``"Loss"``.
    * ``"KL"``, ``"CE"``, ``"RE"`` -- only when ``other_info`` is given and
      ``epoch > 3``.

    Afterwards ``model.infer_d_v_2`` is run on ``tensor_x`` and up to the
    first 8 inputs followed by up to the first 8 outputs are written as a
    single image batch.

    Args:
        writer: Object exposing ``add_scalar(tag, value, step)`` and
            ``add_images(tag, images, step)`` (presumably a
            ``torch.utils.tensorboard.SummaryWriter`` -- confirm with caller).
        model: Object exposing ``infer_d_v_2``; the last element of its
            return value is used as the images to display.
        epoch: Step index attached to every logged value.
        lr: Current learning rate; skipped when not strictly positive.
        warmup_beta: Value logged under the ``"warmup"`` tag.
        acc_tr: Accuracy value logged under the phase-dependent tag.
        loss: Loss value logged under the phase-dependent tag.
        pretraining_finished: Selects the training (truthy) or pretraining
            (falsy) tag names.
        tensor_x: Input batch; indexed with four dimensions here.
        inject_tensor: Optional second positional argument forwarded to
            ``model.infer_d_v_2``.
        other_info: Optional 3-item iterable ``(kl_total, ce_total,
            re_total)``.
    """
    # NOTE(review): the nesting below was reconstructed from a source whose
    # whitespace had been collapsed; confirm that "warmup" is meant to be
    # logged regardless of ``lr``.
    if lr > 0:
        writer.add_scalar("learning rate", lr, epoch)
    writer.add_scalar("warmup", warmup_beta, epoch)

    # Pick all phase-dependent tags in one place.
    if pretraining_finished:
        acc_tag = "Training acc"
        loss_tag = "Loss"
        image_tag = "Output of the decoder training"
    else:
        acc_tag = "Pretraining"
        loss_tag = "Pretraining Loss"
        image_tag = "Output of the decoder pretraining"
    writer.add_scalar(acc_tag, acc_tr, epoch)
    writer.add_scalar(loss_tag, loss, epoch)

    # Loss components are only reported after the first few epochs.
    if other_info is not None and epoch > 3:
        kl_total, ce_total, re_total = other_info
        writer.add_scalar("KL", kl_total, epoch)
        writer.add_scalar("CE", ce_total, epoch)
        writer.add_scalar("RE", re_total, epoch)

    # Forward ``inject_tensor`` only when the caller supplied one.
    extra_args = () if inject_tensor is None else (inject_tensor,)
    *_, x_pro = model.infer_d_v_2(tensor_x, *extra_args)

    # An output with fewer than 3 dims is given the input's trailing dims so
    # that it can be concatenated with ``tensor_x``.
    if x_pro.dim() < 3:
        x_pro = torch.reshape(
            x_pro,
            (x_pro.shape[0], tensor_x.shape[1], tensor_x.shape[2], tensor_x.shape[3]),
        )

    # Inputs first, model outputs second, at most 8 of each.
    imgs = torch.cat((tensor_x[:8, :, :, :], x_pro[:8, :, :, :]), 0)
    writer.add_images(image_tag, imgs, epoch)