From 71081bfdc47107d63e168681ee2c3edf7b599304 Mon Sep 17 00:00:00 2001 From: Daniel Ecer <de-code@users.noreply.github.com> Date: Thu, 7 Dec 2017 18:42:42 +0000 Subject: [PATCH] use default unknown weight and get unknown weight from class weights if present --- .../trainer/models/pix2pix/pix2pix_model.py | 24 ++++++++++------- .../models/pix2pix/pix2pix_model_test.py | 26 ++++++++++++++++--- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py index 8cfe085..79e3801 100644 --- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py +++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py @@ -39,6 +39,11 @@ from sciencebeam_gym.trainer.models.pix2pix.evaluate import ( ) +UNKNOWN_COLOR = (255, 255, 255) +UNKNOWN_LABEL = 'unknown' + +DEFAULT_UNKNOWN_CLASS_WEIGHT = 0.1 + class GraphMode(object): TRAIN = 1 EVALUATE = 2 @@ -185,7 +190,7 @@ def add_model_summary_images( tensors, 'target', tensors.annotation_tensor ) if (has_unknown_class or not use_separate_channels) and dimension_labels is not None: - dimension_labels_with_unknown = dimension_labels + ['unknown'] + dimension_labels_with_unknown = dimension_labels + [UNKNOWN_LABEL] dimension_colors_with_unknown = dimension_colors + [(255, 255, 255)] else: dimension_labels_with_unknown = dimension_labels @@ -260,9 +265,12 @@ def parse_json_file(filename): with FileIO(filename, 'r') as f: return json.load(f) -def class_weights_to_pos_weight(class_weights, labels, use_unknown_class): +def class_weights_to_pos_weight( + class_weights, labels, + use_unknown_class, unknown_class_weight=DEFAULT_UNKNOWN_CLASS_WEIGHT): + pos_weight = [class_weights[k] for k in labels] - return pos_weight + [0.0] if use_unknown_class else pos_weight + return pos_weight + [unknown_class_weight] if use_unknown_class else pos_weight def parse_color_map(color_map_filename): with FileIO(color_map_filename, 'r') as config_f: @@ -283,9 +291,6 @@ def color_map_to_labels(color_map, labels=None): def color_map_to_colors(color_map, labels): return [color_map[k] for k in labels] -UNKNOWN_COLOR = (255, 255, 255) -UNKNOWN_LABEL = 'unknown' - def colors_and_labels_with_unknown_class(colors, labels, use_unknown_class): if use_unknown_class or not colors: return ( @@ -330,11 +335,12 @@ class Model(object): ) logger.debug("dimension_colors: %s", self.dimension_colors) logger.debug("dimension_labels: %s", self.dimension_labels) - if self.args.class_weights: + if class_weights: self.pos_weight = class_weights_to_pos_weight( - parse_json_file(self.args.class_weights), + class_weights, self.dimension_labels, - self.use_unknown_class + self.use_separate_channels, + class_weights.get(UNKNOWN_LABEL, DEFAULT_UNKNOWN_CLASS_WEIGHT) ) logger.info("pos_weight: %s", self.pos_weight) diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model_test.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model_test.py index af2c849..1cbd457 100644 --- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model_test.py +++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model_test.py @@ -28,6 +28,7 @@ from sciencebeam_gym.trainer.models.pix2pix.pix2pix_model import ( colors_and_labels_with_unknown_class, UNKNOWN_COLOR, UNKNOWN_LABEL, + DEFAULT_UNKNOWN_CLASS_WEIGHT, Model, str_to_list, model_args_parser, @@ -137,7 +138,9 @@ class TestClassWeightsToPosWeight(object): 'a': 0.1, 'b': 0.2, 'c': 0.3 - }, ['a', 'b'], True) == [0.1, 0.2, 0.0] + }, ['a', 'b'], True, DEFAULT_UNKNOWN_CLASS_WEIGHT) == ( + [0.1, 0.2, DEFAULT_UNKNOWN_CLASS_WEIGHT] + ) DEFAULT_ARGS = extend_dict( CORE_DEFAULT_ARGS, @@ -205,7 +208,7 @@ class TestModel(object): model.dimension_colors_with_unknown == [some_color(1), some_color(2), UNKNOWN_COLOR] ) assert model.dimension_labels_with_unknown == ['a', 'b', UNKNOWN_LABEL] - assert model.pos_weight == [0.1, 0.2, 0.0] + assert model.pos_weight == [0.1, 0.2, DEFAULT_UNKNOWN_CLASS_WEIGHT] def test_should_only_include_labels_with_non_zero_class_labels_by_default(self): with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file: @@ -230,7 +233,24 @@ class TestModel(object): model = Model(args) assert model.dimension_labels == ['a', 'c'] assert model.dimension_colors == [some_color(1), some_color(3)] - assert model.pos_weight == [0.1, 0.3, 0.0] + assert model.pos_weight == [0.1, 0.3, DEFAULT_UNKNOWN_CLASS_WEIGHT] + + def test_should_use_unknown_class_weight_from_configuration(self): + with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file: + with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file: + parse_color_map_from_file.return_value = SOME_COLOR_MAP + parse_json_file.return_value = extend_dict(SOME_CLASS_WEIGHTS, { + 'unknown': 0.99 + }) + args = create_args( + DEFAULT_ARGS, + color_map=COLOR_MAP_FILENAME, + class_weights=CLASS_WEIGHTS_FILENAME, + use_separate_channels=True, + use_unknown_class=True + ) + model = Model(args) + assert model.pos_weight[-1] == 0.99 @pytest.mark.slow @pytest.mark.very_slow -- GitLab