"""Unit tests for the pix2pix model wrapper module (``pix2pix_model``)."""

from collections import namedtuple

from mock import patch
import pytest
from pytest import raises

import tensorflow as tf

from sciencebeam_gym.utils.collection import (
    extend_dict
)

from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core import (
    BaseLoss,
    ALL_BASE_LOSS
)

from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core_test import (
    DEFAULT_ARGS as CORE_DEFAULT_ARGS
)

import sciencebeam_gym.trainer.models.pix2pix.pix2pix_model as pix2pix_model

from sciencebeam_gym.trainer.models.pix2pix.pix2pix_model import (
    parse_color_map,
    color_map_to_labels,
    color_map_to_colors,
    color_map_to_colors_and_labels,
    colors_and_labels_with_unknown_class,
    UNKNOWN_COLOR,
    UNKNOWN_LABEL,
    Model,
    str_to_list,
    model_args_parser,
    class_weights_to_pos_weight
)


COLOR_MAP_FILENAME = 'color_map.conf'
CLASS_WEIGHTS_FILENAME = 'class-weights.json'
DATA_PATH = 'some/where/*.tfrecord'
BATCH_SIZE = 10


def some_color(i):
    """Return a three-component colour tuple with every component set to *i*."""
    return (i,) * 3


SOME_COLORS = [some_color(1), some_color(2), some_color(3)]
SOME_LABELS = ['a', 'b', 'c']
SOME_COLOR_MAP = dict(zip(SOME_LABELS, SOME_COLORS))
SOME_CLASS_WEIGHTS = {
    label: float(index) for index, label in enumerate(SOME_LABELS)
}


def _patch_model(attribute_name):
    """Return a patcher replacing *attribute_name* on the pix2pix_model module."""
    return patch.object(pix2pix_model, attribute_name)


def _some_color_map():
    """Return a fresh copy of the a/b/c colour map so tests cannot share state."""
    return dict(SOME_COLOR_MAP)


class TestParseColorMap(object):
    """Tests for ``parse_color_map``."""

    def test_should_use_fileio_to_load_file_and_pass_to_parser(self):
        """The file is opened via FileIO in 'r' mode and handed to the parser."""
        with _patch_model('FileIO') as file_io_mock, \
                _patch_model('parse_color_map_from_file') as parser_mock:
            parse_color_map(COLOR_MAP_FILENAME)
            file_io_mock.assert_called_with(COLOR_MAP_FILENAME, 'r')
            # the parser receives whatever the FileIO context manager yields
            opened_file = file_io_mock.return_value.__enter__.return_value
            parser_mock.assert_called_with(opened_file)


class TestColorMapToLabels(object):
    """Tests for ``color_map_to_labels``."""

    def test_should_use_color_maps_keys_by_default(self):
        """Without explicit labels the colour map keys a, b, c are returned."""
        expected_labels = ['a', 'b', 'c']
        assert color_map_to_labels(_some_color_map()) == expected_labels

    def test_should_return_specified_labels(self):
        """Explicit labels are returned as given, in the requested order."""
        requested_labels = ['b', 'a']
        actual_labels = color_map_to_labels(_some_color_map(), requested_labels)
        assert actual_labels == ['b', 'a']

    def test_should_raise_error_if_specified_label_not_in_color_map(self):
        """Requesting a label missing from the colour map raises ValueError."""
        color_map_without_b = {
            'a': some_color(1),
            'c': some_color(3)
        }
        with raises(ValueError):
            color_map_to_labels(color_map_without_b, ['a', 'b'])


class TestColorMapToColors(object):
    """Tests for ``color_map_to_colors``."""

    def test_should_return_colors_for_labels(self):
        """The colours of the requested labels are returned in label order."""
        actual_colors = color_map_to_colors(_some_color_map(), ['a', 'b'])
        assert actual_colors == [some_color(1), some_color(2)]


class TestColorMapToColorsAndLabels(object):
    """Tests for ``color_map_to_colors_and_labels``."""

    def test_should_return_color_and_specified_labels(self):
        """Both colours and labels are limited to the requested labels."""
        colors, labels = color_map_to_colors_and_labels(
            _some_color_map(), ['a', 'b']
        )
        assert colors == [some_color(1), some_color(2)]
        assert labels == ['a', 'b']

    def test_should_return_color_using_sorted_keys_if_no_labels_are_specified(self):
        """Passing ``None`` as labels yields all entries in a, b, c order."""
        colors, labels = color_map_to_colors_and_labels(
            _some_color_map(), None
        )
        assert colors == [some_color(1), some_color(2), some_color(3)]
        assert labels == ['a', 'b', 'c']


class TestColorsAndLabelsWithUnknownClass(object):
    """Tests for ``colors_and_labels_with_unknown_class``."""

    def test_should_not_add_unknown_class_if_not_enabled(self):
        """Colours and labels are returned unchanged when the flag is off."""
        result_colors, result_labels = colors_and_labels_with_unknown_class(
            SOME_COLORS, SOME_LABELS, use_unknown_class=False
        )
        assert result_colors == SOME_COLORS
        assert result_labels == SOME_LABELS

    def test_should_add_unknown_class_if_enabled(self):
        """The unknown colour and label are appended when the flag is on."""
        result_colors, result_labels = colors_and_labels_with_unknown_class(
            SOME_COLORS, SOME_LABELS, use_unknown_class=True
        )
        assert result_colors == SOME_COLORS + [UNKNOWN_COLOR]
        assert result_labels == SOME_LABELS + [UNKNOWN_LABEL]

    def test_should_add_unknown_class_if_colors_are_empty(self):
        """With empty inputs the unknown class is expected even if the flag is off."""
        result_colors, result_labels = colors_and_labels_with_unknown_class(
            [], [], use_unknown_class=False
        )
        assert result_colors == [UNKNOWN_COLOR]
        assert result_labels == [UNKNOWN_LABEL]


class TestClassWeightsToPosWeight(object):
    """Tests for ``class_weights_to_pos_weight``."""

    def test_should_extract_selected_weights(self):
        """Only the weights of the selected labels are returned."""
        class_weights = {
            'a': 0.1,
            'b': 0.2,
            'c': 0.3
        }
        pos_weight = class_weights_to_pos_weight(
            class_weights, ['a', 'b'], False
        )
        assert pos_weight == [0.1, 0.2]

    def test_should_add_zero_if_unknown_class_is_true(self):
        """A trailing 0.0 is expected when the unknown class flag is true."""
        class_weights = {
            'a': 0.1,
            'b': 0.2,
            'c': 0.3
        }
        pos_weight = class_weights_to_pos_weight(
            class_weights, ['a', 'b'], True
        )
        assert pos_weight == [0.1, 0.2, 0.0]


DEFAULT_ARGS = extend_dict(
    CORE_DEFAULT_ARGS,
    {
        'color_map': None,
        'class_weights': None,
        'channels': None,
        'use_separate_channels': False,
        'use_unknown_class': False
    }
)


def create_args(*args, **kwargs):
    """Merge the given dicts and keyword overrides into an ``args`` namedtuple.

    The positional dicts and the keyword arguments (passed last) are combined
    with ``extend_dict``; the result is exposed as attributes of a namedtuple.
    """
    merged = extend_dict(*(list(args) + [kwargs]))
    args_type = namedtuple('args', merged.keys())
    return args_type(**merged)


class TestModel(object):
    """Tests for the attributes ``Model`` derives from its arguments."""

    def test_parse_separate_channels_with_color_map_without_class_weights(self):
        """With a colour map but no class weights, ``pos_weight`` is None."""
        with _patch_model('parse_color_map_from_file') as color_map_parser:
            color_map_parser.return_value = _some_color_map()
            args = create_args(
                DEFAULT_ARGS,
                color_map=COLOR_MAP_FILENAME,
                class_weights=None,
                channels=['a', 'b'],
                use_separate_channels=True,
                use_unknown_class=True
            )
            model = Model(args)
            selected_colors = [some_color(1), some_color(2)]
            selected_labels = ['a', 'b']
            assert model.dimension_colors == selected_colors
            assert model.dimension_labels == selected_labels
            assert (
                model.dimension_colors_with_unknown ==
                selected_colors + [UNKNOWN_COLOR]
            )
            assert (
                model.dimension_labels_with_unknown ==
                selected_labels + [UNKNOWN_LABEL]
            )
            assert model.pos_weight is None

    def test_parse_separate_channels_with_color_map_and_class_weights(self):
        """With class weights, ``pos_weight`` holds the selected weights plus 0.0."""
        with _patch_model('parse_color_map_from_file') as color_map_parser, \
                _patch_model('parse_json_file') as json_parser:
            color_map_parser.return_value = _some_color_map()
            json_parser.return_value = {
                'a': 0.1,
                'b': 0.2,
                'c': 0.3
            }
            args = create_args(
                DEFAULT_ARGS,
                color_map=COLOR_MAP_FILENAME,
                class_weights=CLASS_WEIGHTS_FILENAME,
                channels=['a', 'b'],
                use_separate_channels=True,
                use_unknown_class=True
            )
            model = Model(args)
            selected_colors = [some_color(1), some_color(2)]
            selected_labels = ['a', 'b']
            assert model.dimension_colors == selected_colors
            assert model.dimension_labels == selected_labels
            assert (
                model.dimension_colors_with_unknown ==
                selected_colors + [UNKNOWN_COLOR]
            )
            assert (
                model.dimension_labels_with_unknown ==
                selected_labels + [UNKNOWN_LABEL]
            )
            assert model.pos_weight == [0.1, 0.2, 0.0]


@pytest.mark.slow
@pytest.mark.very_slow
class TestModelBuildGraph(object):
    """Smoke tests: building the train graph must complete without raising."""

    def test_should_build_train_graph_with_defaults(self):
        """The train graph builds from the default arguments."""
        with tf.Graph().as_default(), \
                _patch_model('parse_color_map_from_file') as color_map_parser, \
                _patch_model('read_examples') as read_examples_mock:
            color_map_parser.return_value = SOME_COLOR_MAP
            # stand-in for the value pair the model gets from read_examples
            read_examples_mock.return_value = (
                tf.constant(1), b'dummy'
            )
            model = Model(create_args(DEFAULT_ARGS))
            model.build_train_graph(DATA_PATH, BATCH_SIZE)

    def test_should_build_train_graph_with_class_weights(self):
        """The train graph builds with weighted cross entropy and class weights."""
        with tf.Graph().as_default(), \
                _patch_model('parse_color_map_from_file') as color_map_parser, \
                _patch_model('parse_json_file') as json_parser, \
                _patch_model('read_examples') as read_examples_mock:
            color_map_parser.return_value = SOME_COLOR_MAP
            json_parser.return_value = SOME_CLASS_WEIGHTS
            # stand-in for the value pair the model gets from read_examples
            read_examples_mock.return_value = (
                tf.constant(1), b'dummy'
            )
            args = create_args(
                DEFAULT_ARGS,
                base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY,
                color_map=COLOR_MAP_FILENAME,
                class_weights=CLASS_WEIGHTS_FILENAME,
                channels=['a', 'b'],
                use_separate_channels=True,
                use_unknown_class=True
            )
            model = Model(args)
            model.build_train_graph(DATA_PATH, BATCH_SIZE)


class TestStrToList(object):
    """Tests for ``str_to_list``."""

    def test_should_parse_empty_string_as_empty_list(self):
        """An empty string yields an empty list."""
        parsed = str_to_list('')
        assert parsed == []

    def test_should_parse_blank_string_as_empty_list(self):
        """A whitespace-only string yields an empty list."""
        parsed = str_to_list(' ')
        assert parsed == []

    def test_should_parse_comma_separated_list(self):
        """Comma separated values become individual list items."""
        parsed = str_to_list('a,b,c')
        assert parsed == ['a', 'b', 'c']

    def test_should_ignore_white_space_around_values(self):
        """Whitespace surrounding each value is stripped."""
        parsed = str_to_list(' a , b , c ')
        assert parsed == ['a', 'b', 'c']


class TestModelArgsParser(object):
    """Tests for the parser returned by ``model_args_parser``."""

    def test_should_parse_channels(self):
        """``--channels`` is parsed into a list of channel names."""
        parsed_args = model_args_parser().parse_args(['--channels', 'a,b,c'])
        assert parsed_args.channels == ['a', 'b', 'c']

    def test_should_set_channels_to_none_by_default(self):
        """Without ``--channels`` the ``channels`` attribute is None."""
        parsed_args = model_args_parser().parse_args([])
        assert parsed_args.channels is None

    def test_should_allow_all_base_loss_options(self):
        """Every value in ALL_BASE_LOSS is accepted by ``--base_loss``."""
        for loss_name in ALL_BASE_LOSS:
            # a fresh parser per option, as each option is checked in isolation
            parsed_args = model_args_parser().parse_args(
                ['--base_loss', loss_name]
            )
            assert parsed_args.base_loss == loss_name