diff --git a/eval.sh b/eval.sh
index 3c053334bda03dd69dc65973b802eab0c99acb71..5506d936ee863cb036f021fc1bd20bdb2ad293ca 100755
--- a/eval.sh
+++ b/eval.sh
@@ -21,6 +21,14 @@ COMMON_ARGS=(
   ${TRAINING_ARGS[@]}
 )
 
+if [ ! -z "$QUANTITATIVE_FOLDER_NAME" ]; then
+  COMMON_ARGS=(
+    ${COMMON_ARGS[@]}
+    --quantitative_data_paths "${PREPROC_PATH}/${QUANTITATIVE_FOLDER_NAME}/*tfrecord*"
+    --quantitative_set_size ${QUANTITATIVE_SET_SIZE}
+  )
+fi
+
 if [ $USE_SEPARATE_CHANNELS == true ]; then
   COMMON_ARGS=(
     ${COMMON_ARGS[@]}
diff --git a/prepare-shell.sh b/prepare-shell.sh
index 38fcf1cd303c208ca7e1cfb83f70d0b171d257e4..85e8570c5559a81be578e4e62a521d709adcd08b 100755
--- a/prepare-shell.sh
+++ b/prepare-shell.sh
@@ -14,6 +14,9 @@ export BUCKET="gs://${PROJECT}-ml"
 export COLOR_MAP_FILENAME="color_map.conf"
 export USE_SEPARATE_CHANNELS=true
 export DATASET_SUFFIX=
+export EVAL_SET_SIZE=10
+export QUANTITATIVE_FOLDER_NAME=
+export QUANTITATIVE_SET_SIZE=10
 
 export USE_CLOUD=false
 
diff --git a/sciencebeam_gym/trainer/evaluator.py b/sciencebeam_gym/trainer/evaluator.py
index b1d31ab5498577fc59f662e89df10fd05fe32e03..481769c5730eb2da8b502cb0fcaf9f23e60b3b23 100644
--- a/sciencebeam_gym/trainer/evaluator.py
+++ b/sciencebeam_gym/trainer/evaluator.py
@@ -100,12 +100,19 @@ class Evaluator(object):
 
   def __init__(
       self, args, model,
-      checkpoint_path, data_paths, dataset='eval',
+      checkpoint_path,
+      data_paths,
+      dataset='eval',
+      eval_batch_size=None,
+      eval_set_size=None,
+      quantitative_set_size=None,
       run_async=None):
 
-    self.eval_batch_size = args.eval_batch_size
-    self.num_eval_batches = args.eval_set_size // self.eval_batch_size
-    self.num_detail_eval_batches = min(10, args.eval_set_size) // self.eval_batch_size
+    self.eval_batch_size = eval_batch_size or args.eval_batch_size
+    self.num_eval_batches = (eval_set_size or args.eval_set_size) // self.eval_batch_size
+    self.num_detail_eval_batches = (
+      min((quantitative_set_size or 10), args.eval_set_size) // self.eval_batch_size
+    )
     self.batch_of_examples = []
     self.checkpoint_path = checkpoint_path
     self.output_path = os.path.join(args.output_path, dataset)
diff --git a/sciencebeam_gym/trainer/task.py b/sciencebeam_gym/trainer/task.py
index 13322f5a196fed69bf2c9b3362bba67a4f6eda13..3e4b7a8e904b4f00a48d52d17f353f5458e63606 100644
--- a/sciencebeam_gym/trainer/task.py
+++ b/sciencebeam_gym/trainer/task.py
@@ -56,6 +56,21 @@ class TrainingProgressLogger(object):
     self.last_global_step = global_step
     self.last_local_step = local_step
 
+def get_quantitative_evaluator(args, model, run_async):
+  if args.quantitative_data_paths:
+    return Evaluator(
+      args,
+      model,
+      train_dir(args.output_path),
+      args.quantitative_data_paths,
+      dataset='quantitative_set',
+      eval_set_size=args.quantitative_set_size or args.eval_set_size,
+      quantitative_set_size=args.quantitative_set_size,
+      run_async=run_async
+    )
+  else:
+    return None
+
 class Trainer(object):
   """Performs model training and optionally evaluation."""
 
@@ -81,6 +96,11 @@ class Trainer(object):
       'train_set',
       run_async=run_async
     )
+    self.quantitative_evaluator = get_quantitative_evaluator(
+      self.args,
+      self.model,
+      run_async=run_async
+    )
     self.min_train_eval_rate = args.min_train_eval_rate
     self.global_step = None
     self.last_save = 0
@@ -99,6 +119,8 @@ class Trainer(object):
     logger = get_logger()
     self.train_evaluator.init()
     self.evaluator.init()
+    if self.quantitative_evaluator:
+      self.quantitative_evaluator.init()
     ensure_output_path(self.args.output_path)
     train_path = train_dir(self.args.output_path)
     # model_path = model_dir(self.args.output_path)
@@ -286,6 +308,12 @@ class Trainer(object):
 
   def eval(self, global_step=None):
     """Runs evaluation loop."""
+    if self.quantitative_evaluator:
+      logging.info(
+        'Quantitative Eval, step %s:\n- on quantitative set %s',
+        global_step,
+        self.model.format_metric_values(self.quantitative_evaluator.evaluate())
+      )
     logging.info(
       'Eval, step %s:\n- on eval set %s',
       global_step,
@@ -319,12 +347,23 @@ def write_predictions(args, model, cluster, task):
   logger.info('Starting to write predictions on %s/%d', task.type, task.index)
   pool = Pool(processes=args.pool_size)
   run_async = lambda f, args: pool.apply_async(f, args)
+
+  quantitative_evaluator = get_quantitative_evaluator(
+    args,
+    model,
+    run_async=run_async
+  )
+  if quantitative_evaluator:
+    quantitative_evaluator.init()
+    quantitative_evaluator.write_predictions()
+
   evaluator = Evaluator(
     args, model, train_dir(args.output_path), args.eval_data_paths,
     run_async=run_async
   )
   evaluator.init()
   evaluator.write_predictions()
+
   logger.info('Waiting for background tasks to finish')
   pool.close()
   pool.join()
@@ -414,6 +453,13 @@
     help='The path to the files used for evaluation. '
     'Can be comma separated list of files or glob pattern.'
   )
+  parser.add_argument(
+    '--quantitative_data_paths',
+    type=str,
+    action='append',
+    help='The path to the files used for quantitative evaluation. '
+    'You may choose a different set for the quantitative analysis to keep the results consistent.'
+  )
   parser.add_argument(
     '--output_path',
     type=str,
@@ -433,6 +479,11 @@
   parser.add_argument(
     '--eval_set_size', type=int, help='Number of examples in the eval set.'
   )
+  parser.add_argument(
+    '--quantitative_set_size',
+    type=int,
+    help='Number of examples in the quantitative eval set.'
+  )
   parser.add_argument(
     '--eval_batch_size', type=int, help='Number of examples per eval batch.'
   )
diff --git a/train.sh b/train.sh
index 53d7f0732458318b5059d6675fd2253aacdc6ed7..33d8c9c1e30b819fd1dd7f4cfa3eeeee396c1ca8 100755
--- a/train.sh
+++ b/train.sh
@@ -19,6 +19,14 @@ COMMON_ARGS=(
   ${TRAINING_ARGS[@]}
 )
 
+if [ ! -z "$QUANTITATIVE_FOLDER_NAME" ]; then
+  COMMON_ARGS=(
+    ${COMMON_ARGS[@]}
+    --quantitative_data_paths "${PREPROC_PATH}/${QUANTITATIVE_FOLDER_NAME}/*tfrecord*"
+    --quantitative_set_size ${QUANTITATIVE_SET_SIZE}
+  )
+fi
+
 if [ $USE_CLOUD == true ]; then
   gcloud ml-engine jobs submit training "$JOB_ID" \
     --stream-logs \
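Usage sketch (not part of the patch): with the change applied, quantitative evaluation stays off while QUANTITATIVE_FOLDER_NAME is empty (the new default in prepare-shell.sh) and is enabled by naming a preprocessed folder. The folder name "validation-quantitative" below is a hypothetical example; it is assumed to contain *tfrecord* files under ${PREPROC_PATH}:

    # hypothetical example values, exported before running the scripts
    # (or set in prepare-shell.sh)
    export QUANTITATIVE_FOLDER_NAME=validation-quantitative
    export QUANTITATIVE_SET_SIZE=10

    # train.sh and eval.sh then append --quantitative_data_paths and
    # --quantitative_set_size to COMMON_ARGS for the trainer
    ./train.sh
    ./eval.sh

Keeping the quantitative set separate from the regular eval set means the same fixed examples are scored on every run, so metric changes reflect the model rather than the sample drawn.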