diff --git a/sciencebeam_gym/trainer/task.py b/sciencebeam_gym/trainer/task.py index 4fb3486864ba4549d221057a45c2f0f67137078e..a8401eda233ba05de481c50f7e32ba088474638b 100644 --- a/sciencebeam_gym/trainer/task.py +++ b/sciencebeam_gym/trainer/task.py @@ -22,7 +22,6 @@ from sciencebeam_gym.trainer.evaluator import Evaluator from sciencebeam_gym.trainer.util import ( CustomSupervisor, SimpleStepScheduler, - override_if_not_in_args, get_graph_size ) @@ -499,14 +498,17 @@ def run(model, argv): parser.add_argument( '--max_steps', type=int, + default=1000 ) parser.add_argument( '--batch_size', type=int, + default=100, help='Number of examples to be processed per mini-batch.' ) parser.add_argument( - '--eval_set_size', type=int, help='Number of examples in the eval set.' + '--eval_set_size', type=int, default=370, + help='Number of examples in the eval set.' ) parser.add_argument( '--qualitative_set_size', @@ -690,12 +692,6 @@ def main(_): model_factory = get_model_factory(args.model) model, task_args = model_factory.create_model(other_args) - override_if_not_in_args('--max_steps', '1000', task_args) - override_if_not_in_args('--batch_size', '100', task_args) - override_if_not_in_args('--eval_set_size', '370', task_args) - override_if_not_in_args('--eval_interval_secs', '2', task_args) - override_if_not_in_args('--log_interval_secs', '2', task_args) - override_if_not_in_args('--min_train_eval_rate', '2', task_args) run(model, task_args) if __name__ == '__main__': diff --git a/sciencebeam_gym/trainer/util.py b/sciencebeam_gym/trainer/util.py index 810082e6a0ea174eec45577dd3fe12303a06808d..57d5dddcd12c58a311b9199067dbdc69d3884f83 100644 --- a/sciencebeam_gym/trainer/util.py +++ b/sciencebeam_gym/trainer/util.py @@ -139,11 +139,6 @@ def read_examples(input_files, shuffle, num_epochs=None): return example_id, encoded_example -def override_if_not_in_args(flag, argument, args): - """Checks if flags is in args, and if not it adds the flag to args.""" - if flag not in args: - args.extend([flag, argument]) - def loss(loss_value): """Calculates aggregated mean loss.""" total_loss = tf.Variable(0.0, False)