From d05fc08e64e52a45a8d13abb372a543737304b77 Mon Sep 17 00:00:00 2001 From: Daniel Ecer <de-code@users.noreply.github.com> Date: Thu, 7 Dec 2017 18:41:47 +0000 Subject: [PATCH] removed limited and unnecessary override_if_not_in_args --- sciencebeam_gym/trainer/task.py | 12 ++++-------- sciencebeam_gym/trainer/util.py | 5 ----- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/sciencebeam_gym/trainer/task.py b/sciencebeam_gym/trainer/task.py index 4fb3486..a8401ed 100644 --- a/sciencebeam_gym/trainer/task.py +++ b/sciencebeam_gym/trainer/task.py @@ -22,7 +22,6 @@ from sciencebeam_gym.trainer.evaluator import Evaluator from sciencebeam_gym.trainer.util import ( CustomSupervisor, SimpleStepScheduler, - override_if_not_in_args, get_graph_size ) @@ -499,14 +498,17 @@ def run(model, argv): parser.add_argument( '--max_steps', type=int, + default=1000 ) parser.add_argument( '--batch_size', type=int, + default=100, help='Number of examples to be processed per mini-batch.' ) parser.add_argument( - '--eval_set_size', type=int, help='Number of examples in the eval set.' + '--eval_set_size', type=int, default=370, + help='Number of examples in the eval set.' ) parser.add_argument( '--qualitative_set_size', @@ -690,12 +692,6 @@ def main(_): model_factory = get_model_factory(args.model) model, task_args = model_factory.create_model(other_args) - override_if_not_in_args('--max_steps', '1000', task_args) - override_if_not_in_args('--batch_size', '100', task_args) - override_if_not_in_args('--eval_set_size', '370', task_args) - override_if_not_in_args('--eval_interval_secs', '2', task_args) - override_if_not_in_args('--log_interval_secs', '2', task_args) - override_if_not_in_args('--min_train_eval_rate', '2', task_args) run(model, task_args) if __name__ == '__main__': diff --git a/sciencebeam_gym/trainer/util.py b/sciencebeam_gym/trainer/util.py index 810082e..57d5ddd 100644 --- a/sciencebeam_gym/trainer/util.py +++ b/sciencebeam_gym/trainer/util.py @@ -139,11 +139,6 @@ def read_examples(input_files, shuffle, num_epochs=None): return example_id, encoded_example -def override_if_not_in_args(flag, argument, args): - """Checks if flags is in args, and if not it adds the flag to args.""" - if flag not in args: - args.extend([flag, argument]) - def loss(loss_value): """Calculates aggregated mean loss.""" total_loss = tf.Variable(0.0, False) -- GitLab