Source code for domid.algos.builder_m2yd

from domainlab.algos.a_algo_builder import NodeAlgoBuilder
from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor
from domainlab.algos.msels.c_msel_val import MSelValPerf
from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp
from domainlab.utils.utils_cuda import get_device

from domid.algos.observers.b_obvisitor_clustering import ObVisitorClustering
from domid.algos.observers.b_obvisitor_clustering_only import ObVisitorClusteringOnly
from domid.models.model_m2yd import mk_m2yd
from domid.trainers.zoo_trainer import TrainerChainNodeGetter


class NodeAlgoBuilderM2YD(NodeAlgoBuilder):
    """Builder that assembles the M2YD model together with its observer and trainer."""

    def init_business(self, exp):
        """
        Build the M2YD model and wire it up with an observer and a trainer.

        :param exp: experiment object. ``exp.task`` supplies the class labels
            (``list_str_y``) and the input size (``isize.c``, ``isize.h``,
            ``isize.w``). ``exp.args`` supplies the hyper-parameters
            (``zd_dim``, ``gamma_y``, ``es``, ``trainer``) and is also used
            to select the device.
        :return: tuple ``(trainer, model, observer, device)``. Note that the
            device is returned too, not only trainer, model and observer.
        """
        task = exp.task
        args = exp.args
        device = get_device(args)

        # mk_m2yd() returns the model class, which is then instantiated.
        model = mk_m2yd()(
            y_dim=len(task.list_str_y),
            list_str_y=task.list_str_y,
            zd_dim=args.zd_dim,
            gamma_y=args.gamma_y,
            device=device,
            i_c=task.isize.c,
            i_h=task.isize.h,
            i_w=task.isize.w,
        )

        # Model selection is driven by validation performance.
        # NOTE(review): args.es is presumably the early-stopping patience -- confirm.
        # FIXME: may need to be ObVisitorClustering instead of ObVisitorClusteringOnly...
        observer = ObVisitorCleanUp(
            ObVisitorClusteringOnly(exp, MSelOracleVisitor(MSelValPerf(max_es=args.es)), device)
        )

        # Look up the trainer by the name in args.trainer, instantiate it, then bind it
        # to the model, task, observer and device.
        trainer = TrainerChainNodeGetter(args.trainer)()
        trainer.init_business(model, task, observer, device, args)
        return trainer, model, observer, device
def get_node_na():
    """
    Return the algorithm-builder class for M2YD.

    The class object itself is returned, not an instance of it.
    """
    builder_cls = NodeAlgoBuilderM2YD
    return builder_cls