# Source code for domid.trainers.zoo_trainer
"""
select trainer
"""
from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.algos.trainers.train_hyper_scheduler import TrainerHyperScheduler
from domid.trainers.trainer_ae import TrainerAE
from domid.trainers.trainer_cluster import TrainerCluster
from domid.trainers.trainer_sdcn import TrainerSDCN
class TrainerChainNodeGetter:
    """
    Chain of Responsibility: node is named in pattern Trainer[XXX] where the string
    after 'Trainer' is the name to be passed to args.trainer.

    Several names may be joined with "_" (for example "ae_sdcn"). The first
    name selects the node that is returned; every following name selects a
    further node that is attached to its predecessor through ``extend``.
    """

    def __init__(self, str_trainer):
        """__init__.

        :param str_trainer: a trainer name, several trainer names joined by
            "_", or None when no trainer was requested explicitly.
        """
        self._list_str_trainer = None
        if str_trainer is not None:
            # The first name is the primary request; the remaining names are
            # kept for chaining in __call__.
            self._list_str_trainer = str_trainer.split("_")
            self.request = self._list_str_trainer.pop(0)
        else:
            self.request = str_trainer

    def __call__(self, lst_candidates=None, default=None, lst_excludes=None):
        """
        1. construct the chain, filter out responsible node,
        create heavy-weight business object
        2. hard code seems to be the best solution

        The stored names are not consumed, so calling the getter again
        resolves the same sequence of names.

        :param lst_candidates: optional collection of allowed names; every
            requested name must be contained in it.
        :param default: name to use when no trainer was requested (the
            request is None). It is resolved before validation so that the
            default is checked like an explicit request.
        :param lst_excludes: optional collection of forbidden names.
        :return: the node selected for the first name, extended with the
            nodes selected for the remaining names (if any).
        :raises RuntimeError: if a name is missing from ``lst_candidates`` or
            present in ``lst_excludes``.
        """
        # Resolve the default first: otherwise a None request would always
        # fail the candidate check and the default could never take effect.
        if self.request is None and default is not None:
            self.request = default

        names = [self.request]
        if self._list_str_trainer:
            names.extend(self._list_str_trainer)

        nodes = []
        for name in names:
            self._check_supported(name, lst_candidates, lst_excludes)
            # A fresh chain per name, so a name that occurs twice yields two
            # distinct nodes instead of a node extended with itself.
            nodes.append(self._build_chain().handle(name))

        # Link from the innermost pair outwards, so that each node is already
        # extended by its successors when it is attached to its predecessor.
        for outer, inner in zip(reversed(nodes[:-1]), reversed(nodes[1:])):
            outer.extend(inner)
        return nodes[0]

    @staticmethod
    def _check_supported(name, lst_candidates, lst_excludes):
        """Raise RuntimeError if name is not a candidate or is excluded."""
        if lst_candidates is not None and name not in lst_candidates:
            raise RuntimeError(
                f"desired {name} is not supported among {lst_candidates}"
            )
        if lst_excludes is not None and name in lst_excludes:
            raise RuntimeError(f"desired {name} is not supported among {lst_excludes}")

    @staticmethod
    def _build_chain():
        """Build a new chain of trainer nodes and return its head."""
        chain = TrainerBasic(None)
        chain = TrainerSDCN(chain)
        chain = TrainerCluster(chain)
        chain = TrainerAE(chain)
        return chain