From f682ff2e2f4ed113f634a5a1852ffa6e0aeb6fb4 Mon Sep 17 00:00:00 2001 From: Daniel Ecer <de-code@users.noreply.github.com> Date: Mon, 31 Jul 2017 11:16:55 +0100 Subject: [PATCH] run actual evaluation as part of the training loop at configured intervals --- sciencebeam_gym/trainer/task.py | 35 ++++++++++++++------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/sciencebeam_gym/trainer/task.py b/sciencebeam_gym/trainer/task.py index 6d45996..13322f5 100644 --- a/sciencebeam_gym/trainer/task.py +++ b/sciencebeam_gym/trainer/task.py @@ -206,21 +206,22 @@ class Trainer(object): last_run=start_time ) - eval_scheduler = SimpleStepScheduler( - lambda: self.eval(), - min_interval=save_interval, - min_freq=self.args.save_freq, - step=global_step, - last_run=start_time - ) - schedulers = [ log_scheduler, save_scheduler, - eval_train_scheduler, - eval_scheduler + eval_train_scheduler ] + if is_master: + eval_scheduler = SimpleStepScheduler( + lambda: self.eval(global_step=global_step), + min_interval=save_interval, + min_freq=self.args.save_freq, + step=global_step, + last_run=start_time + ) + schedulers = schedulers + [eval_scheduler] + summary_op = sv.summary_op if tensors.summary is None else tensors.summary if summary_op is not None: schedulers.append(SimpleStepScheduler( @@ -270,11 +271,6 @@ class Trainer(object): # Ask for all the services to stop. sv.stop() - if is_master: - logging.info('evaluate...') - self.eval() - logging.info('evaluate done') - def eval_train(self, session, tensors, global_step): """Runs evaluation loop.""" logging.info( @@ -288,13 +284,12 @@ class Trainer(object): ) ) - def eval(self): + def eval(self, global_step=None): """Runs evaluation loop.""" logging.info( - 'Eval:\n- on train set %s\n-- on eval set %s', - self.model.format_metric_values(self.train_evaluator.evaluate()), - None - # self.model.format_metric_values(self.evaluator.evaluate()) + 'Eval, step %s:\n- on eval set %s', + global_step, + self.model.format_metric_values(self.evaluator.evaluate()) ) def copy_data_to_tmp(input_files): -- GitLab