Skip to content
Snippets Groups Projects
Commit fd816b46 authored by Daniel Ecer's avatar Daniel Ecer
Browse files

added explicit quantitative evaluation

parent 3e9037d4
No related branches found
No related tags found
No related merge requests found
......@@ -21,6 +21,14 @@ COMMON_ARGS=(
${TRAINING_ARGS[@]}
)
if [ ! -z "$QUANTITATIVE_FOLDER_NAME" ]; then
COMMON_ARGS=(
${COMMON_ARGS[@]}
--quantitative_data_paths "${PREPROC_PATH}/${QUANTITATIVE_FOLDER_NAME}/*tfrecord*"
--quantitative_set_size ${QUANTITATIVE_SET_SIZE}
)
fi
if [ $USE_SEPARATE_CHANNELS == true ]; then
COMMON_ARGS=(
${COMMON_ARGS[@]}
......
......@@ -14,6 +14,9 @@ export BUCKET="gs://${PROJECT}-ml"
export COLOR_MAP_FILENAME="color_map.conf"
export USE_SEPARATE_CHANNELS=true
export DATASET_SUFFIX=
export EVAL_SET_SIZE=10
export QUANTITATIVE_FOLDER_NAME=
export QUANTITATIVE_SET_SIZE=10
export USE_CLOUD=false
......
......@@ -100,12 +100,19 @@ class Evaluator(object):
def __init__(
self, args, model,
checkpoint_path, data_paths, dataset='eval',
checkpoint_path,
data_paths,
dataset='eval',
eval_batch_size=None,
eval_set_size=None,
quantitative_set_size=None,
run_async=None):
self.eval_batch_size = args.eval_batch_size
self.num_eval_batches = args.eval_set_size // self.eval_batch_size
self.num_detail_eval_batches = min(10, args.eval_set_size) // self.eval_batch_size
self.eval_batch_size = eval_batch_size or args.eval_batch_size
self.num_eval_batches = (eval_set_size or args.eval_set_size) // self.eval_batch_size
self.num_detail_eval_batches = (
min((quantitative_set_size or 10), args.eval_set_size) // self.eval_batch_size
)
self.batch_of_examples = []
self.checkpoint_path = checkpoint_path
self.output_path = os.path.join(args.output_path, dataset)
......
......@@ -56,6 +56,21 @@ class TrainingProgressLogger(object):
self.last_global_step = global_step
self.last_local_step = local_step
def get_quantitative_evaluator(args, model, run_async):
    """Create an Evaluator for the separate quantitative data set, if configured.

    Returns None when ``args.quantitative_data_paths`` is empty or unset,
    i.e. quantitative evaluation was not requested on the command line.
    Falls back to ``args.eval_set_size`` when no quantitative set size is given.
    """
    # Guard clause: no quantitative data paths configured -> no evaluator.
    if not args.quantitative_data_paths:
        return None
    return Evaluator(
        args,
        model,
        train_dir(args.output_path),
        args.quantitative_data_paths,
        dataset='quantitative_set',
        eval_set_size=args.quantitative_set_size or args.eval_set_size,
        quantitative_set_size=args.quantitative_set_size,
        run_async=run_async
    )
class Trainer(object):
"""Performs model training and optionally evaluation."""
......@@ -81,6 +96,11 @@ class Trainer(object):
'train_set',
run_async=run_async
)
self.quantitative_evaluator = get_quantitative_evaluator(
self.args,
self.model,
run_async=run_async
)
self.min_train_eval_rate = args.min_train_eval_rate
self.global_step = None
self.last_save = 0
......@@ -99,6 +119,8 @@ class Trainer(object):
logger = get_logger()
self.train_evaluator.init()
self.evaluator.init()
if self.quantitative_evaluator:
self.quantitative_evaluator.init()
ensure_output_path(self.args.output_path)
train_path = train_dir(self.args.output_path)
# model_path = model_dir(self.args.output_path)
......@@ -286,6 +308,12 @@ class Trainer(object):
def eval(self, global_step=None):
"""Runs evaluation loop."""
if self.quantitative_evaluator:
logging.info(
'Quantitive Eval, step %s:\n- on eval set %s',
global_step,
self.model.format_metric_values(self.quantitative_evaluator.evaluate())
)
logging.info(
'Eval, step %s:\n- on eval set %s',
global_step,
......@@ -319,12 +347,23 @@ def write_predictions(args, model, cluster, task):
logger.info('Starting to write predictions on %s/%d', task.type, task.index)
pool = Pool(processes=args.pool_size)
run_async = lambda f, args: pool.apply_async(f, args)
quantitative_evaluator = get_quantitative_evaluator(
args,
model,
run_async=run_async
)
if quantitative_evaluator:
quantitative_evaluator.init()
quantitative_evaluator.write_predictions()
evaluator = Evaluator(
args, model, train_dir(args.output_path), args.eval_data_paths,
run_async=run_async
)
evaluator.init()
evaluator.write_predictions()
logger.info('Waiting for background tasks to finish')
pool.close()
pool.join()
......@@ -414,6 +453,13 @@ def run(model, argv):
help='The path to the files used for evaluation. '
'Can be comma separated list of files or glob pattern.'
)
parser.add_argument(
'--quantitative_data_paths',
type=str,
action='append',
help='The path to the files used for quantitative evaluation. '
'You may choose a different set for the quantitative analysis to keep the results consistent.'
)
parser.add_argument(
'--output_path',
type=str,
......@@ -433,6 +479,11 @@ def run(model, argv):
parser.add_argument(
'--eval_set_size', type=int, help='Number of examples in the eval set.'
)
parser.add_argument(
'--quantitative_set_size',
type=int,
help='Number of examples in the quantitative eval set.'
)
parser.add_argument(
'--eval_batch_size', type=int, help='Number of examples per eval batch.'
)
......
......@@ -19,6 +19,14 @@ COMMON_ARGS=(
${TRAINING_ARGS[@]}
)
if [ ! -z "$QUANTITATIVE_FOLDER_NAME" ]; then
COMMON_ARGS=(
${COMMON_ARGS[@]}
--quantitative_data_paths "${PREPROC_PATH}/${QUANTITATIVE_FOLDER_NAME}/*tfrecord*"
--quantitative_set_size ${QUANTITATIVE_SET_SIZE}
)
fi
if [ $USE_CLOUD == true ]; then
gcloud ml-engine jobs submit training "$JOB_ID" \
--stream-logs \
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment