Newer
Older
from collections import namedtuple

from mock import patch
import pytest
from pytest import raises

import tensorflow as tf

from sciencebeam_gym.utils.collection import (
    extend_dict
)

from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core import (
    BaseLoss,
    ALL_BASE_LOSS
)

from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core_test import (
    DEFAULT_ARGS as CORE_DEFAULT_ARGS
)

import sciencebeam_gym.trainer.models.pix2pix.pix2pix_model as pix2pix_model

from sciencebeam_gym.trainer.models.pix2pix.pix2pix_model import (
    parse_color_map,
    color_map_to_labels,
    color_map_to_colors,
    color_map_to_colors_and_labels,
    colors_and_labels_with_unknown_class,
    class_weights_to_pos_weight,
    UNKNOWN_COLOR,
    UNKNOWN_LABEL,
    Model,
    str_to_list,
    model_args_parser,
)
# Placeholder file names / paths handed to the code under test.
COLOR_MAP_FILENAME = 'color_map.conf'
CLASS_WEIGHTS_FILENAME = 'class-weights.json'
DATA_PATH = 'some/where/*.tfrecord'
BATCH_SIZE = 10


def some_color(i):
    """Return a 3-tuple whose components are all ``i`` (a distinct dummy colour per ``i``)."""
    return tuple([i] * 3)


SOME_LABELS = ['a', 'b', 'c']
SOME_COLORS = [some_color(n) for n in (1, 2, 3)]

# label -> colour, pairing SOME_LABELS and SOME_COLORS positionally
SOME_COLOR_MAP = dict(zip(SOME_LABELS, SOME_COLORS))

# label -> weight, the weight being the label's position as a float (0.0, 1.0, 2.0)
SOME_CLASS_WEIGHTS = {
    label: float(position) for position, label in enumerate(SOME_LABELS)
}
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
class TestParseColorMap(object):
    """Tests for ``parse_color_map``."""

    def test_should_use_fileio_to_load_file_and_pass_to_parser(self):
        # Both the opener and the parser are mocked: only the wiring is checked,
        # i.e. the file is opened for reading and the entered handle is parsed.
        with patch.object(pix2pix_model, 'FileIO') as file_io_mock, \
                patch.object(pix2pix_model, 'parse_color_map_from_file') as parser_mock:
            parse_color_map(COLOR_MAP_FILENAME)
            file_io_mock.assert_called_with(COLOR_MAP_FILENAME, 'r')
            opened_file = file_io_mock.return_value.__enter__.return_value
            parser_mock.assert_called_with(opened_file)
class TestColorMapToLabels(object):
    """Tests for ``color_map_to_labels``."""

    def test_should_use_color_maps_keys_by_default(self):
        color_map = dict(a=some_color(1), b=some_color(2), c=some_color(3))
        labels = color_map_to_labels(color_map)
        assert labels == ['a', 'b', 'c']

    def test_should_return_specified_labels(self):
        color_map = dict(a=some_color(1), b=some_color(2), c=some_color(3))
        # the caller's order is kept, not the map's
        labels = color_map_to_labels(color_map, ['b', 'a'])
        assert labels == ['b', 'a']

    def test_should_raise_error_if_specified_label_not_in_color_map(self):
        color_map_without_b = dict(a=some_color(1), c=some_color(3))
        with raises(ValueError):
            color_map_to_labels(color_map_without_b, ['a', 'b'])
class TestColorMapToColors(object):
    """Tests for ``color_map_to_colors``."""

    def test_should_return_colors_for_labels(self):
        color_map = dict(a=some_color(1), b=some_color(2), c=some_color(3))
        expected_colors = [some_color(1), some_color(2)]
        assert color_map_to_colors(color_map, ['a', 'b']) == expected_colors
class TestColorMapToColorsAndLabels(object):
    """Tests for ``color_map_to_colors_and_labels``."""

    def test_should_return_color_and_specified_labels(self):
        color_map = dict(a=some_color(1), b=some_color(2), c=some_color(3))
        result_colors, result_labels = color_map_to_colors_and_labels(
            color_map, ['a', 'b']
        )
        assert result_colors == [some_color(1), some_color(2)]
        assert result_labels == ['a', 'b']

    def test_should_return_color_using_sorted_keys_if_no_labels_are_specified(self):
        color_map = dict(a=some_color(1), b=some_color(2), c=some_color(3))
        # passing None for the labels selects every key of the map
        result_colors, result_labels = color_map_to_colors_and_labels(color_map, None)
        assert result_colors == [some_color(1), some_color(2), some_color(3)]
        assert result_labels == ['a', 'b', 'c']
class TestColorsAndLabelsWithUnknownClass(object):
    """Tests for ``colors_and_labels_with_unknown_class``."""

    def test_should_not_add_unknown_class_if_not_enabled(self):
        colors, labels = colors_and_labels_with_unknown_class(
            SOME_COLORS, SOME_LABELS, use_unknown_class=False
        )
        assert colors == SOME_COLORS
        assert labels == SOME_LABELS

    def test_should_add_unknown_class_if_enabled(self):
        colors, labels = colors_and_labels_with_unknown_class(
            SOME_COLORS, SOME_LABELS, use_unknown_class=True
        )
        # the unknown entry is appended after the known ones
        assert colors == SOME_COLORS + [UNKNOWN_COLOR]
        assert labels == SOME_LABELS + [UNKNOWN_LABEL]

    def test_should_add_unknown_class_if_colors_are_empty(self):
        # even with the flag off, an empty input still gains the unknown class
        colors, labels = colors_and_labels_with_unknown_class(
            [], [], use_unknown_class=False
        )
        assert colors == [UNKNOWN_COLOR]
        assert labels == [UNKNOWN_LABEL]
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
class TestClassWeightsToPosWeight(object):
    """Tests for ``class_weights_to_pos_weight``."""

    def test_should_extract_selected_weights(self):
        class_weights = dict(a=0.1, b=0.2, c=0.3)
        pos_weight = class_weights_to_pos_weight(class_weights, ['a', 'b'], False)
        assert pos_weight == [0.1, 0.2]

    def test_should_add_zero_if_unknown_class_is_true(self):
        class_weights = dict(a=0.1, b=0.2, c=0.3)
        # the trailing 0.0 is the weight expected for the appended unknown class
        pos_weight = class_weights_to_pos_weight(class_weights, ['a', 'b'], True)
        assert pos_weight == [0.1, 0.2, 0.0]
# The core pix2pix default args, extended (via extend_dict) with this model's
# options, all unset / disabled.
DEFAULT_ARGS = extend_dict(CORE_DEFAULT_ARGS, {
    'color_map': None,
    'class_weights': None,
    'channels': None,
    'use_separate_channels': False,
    'use_unknown_class': False,
})
def create_args(*args, **kwargs):
    """Build an attribute-style args object for ``Model``.

    The positional dicts, followed by ``kwargs``, are combined with
    ``extend_dict`` (presumably later values override earlier ones — confirm).
    The result is an instance of a namedtuple type named ``'args'`` whose
    fields are the merged keys.
    """
    merged = extend_dict(*(args + (kwargs,)))
    args_type = namedtuple('args', merged.keys())
    return args_type(**merged)
class TestModel(object):
    """Tests for the dimensions and ``pos_weight`` that ``Model`` derives from its args."""

    def test_parse_separate_channels_with_color_map_without_class_weights(self):
        with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
            parse_color_map_from_file.return_value = {
                'a': some_color(1),
                'b': some_color(2),
                'c': some_color(3)
            }
            args = create_args(
                DEFAULT_ARGS,
                color_map=COLOR_MAP_FILENAME,
                class_weights=None,
                channels=['a', 'b'],
                use_separate_channels=True,
                use_unknown_class=True
            )
            model = Model(args)
            assert model.dimension_colors == [some_color(1), some_color(2)]
            assert model.dimension_labels == ['a', 'b']
            assert model.dimension_colors_with_unknown == [
                some_color(1), some_color(2), UNKNOWN_COLOR
            ]
            assert model.dimension_labels_with_unknown == ['a', 'b', UNKNOWN_LABEL]
            # no class weights requested -> no positive-class weighting
            assert model.pos_weight is None

    def test_parse_separate_channels_with_color_map_and_class_weights(self):
        with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
            with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file:
                parse_color_map_from_file.return_value = {
                    'a': some_color(1),
                    'b': some_color(2),
                    'c': some_color(3)
                }
                parse_json_file.return_value = {
                    'a': 0.1,
                    'b': 0.2,
                    'c': 0.3
                }
                args = create_args(
                    # Start from the full defaults, as the test above does, so that
                    # every field Model may read exists on the args namedtuple.
                    DEFAULT_ARGS,
                    color_map=COLOR_MAP_FILENAME,
                    # Class weights must actually be requested: with class_weights=None
                    # pos_weight stays None (see the test above), so the mocked weights
                    # would never reach the model.
                    class_weights=CLASS_WEIGHTS_FILENAME,
                    channels=['a', 'b'],
                    use_separate_channels=True,
                    use_unknown_class=True
                )
                model = Model(args)
                assert model.dimension_colors == [some_color(1), some_color(2)]
                assert model.dimension_labels == ['a', 'b']
                assert (
                    model.dimension_colors_with_unknown == [some_color(1), some_color(2), UNKNOWN_COLOR]
                )
                assert model.dimension_labels_with_unknown == ['a', 'b', UNKNOWN_LABEL]
                # weights for the selected channels 'a', 'b', then 0.0 for the unknown class
                assert model.pos_weight == [0.1, 0.2, 0.0]
@pytest.mark.slow
@pytest.mark.very_slow
class TestModelBuildGraph(object):
    """Smoke tests: the training graph builds in a fresh graph with the readers mocked out."""

    def test_should_build_train_graph_with_defaults(self):
        with tf.Graph().as_default(), \
                patch.object(pix2pix_model, 'parse_color_map_from_file') as color_map_parser, \
                patch.object(pix2pix_model, 'read_examples') as examples_reader:
            color_map_parser.return_value = SOME_COLOR_MAP
            examples_reader.return_value = (tf.constant(1), b'dummy')
            model = Model(create_args(DEFAULT_ARGS))
            model.build_train_graph(DATA_PATH, BATCH_SIZE)

    def test_should_build_train_graph_with_class_weights(self):
        with tf.Graph().as_default(), \
                patch.object(pix2pix_model, 'parse_color_map_from_file') as color_map_parser, \
                patch.object(pix2pix_model, 'parse_json_file') as json_parser, \
                patch.object(pix2pix_model, 'read_examples') as examples_reader:
            color_map_parser.return_value = SOME_COLOR_MAP
            json_parser.return_value = SOME_CLASS_WEIGHTS
            examples_reader.return_value = (tf.constant(1), b'dummy')
            weighted_args = create_args(
                DEFAULT_ARGS,
                base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY,
                color_map=COLOR_MAP_FILENAME,
                class_weights=CLASS_WEIGHTS_FILENAME,
                channels=['a', 'b'],
                use_separate_channels=True,
                use_unknown_class=True
            )
            model = Model(weighted_args)
            model.build_train_graph(DATA_PATH, BATCH_SIZE)
class TestStrToList(object):
    """Tests for ``str_to_list`` and for the related options of ``model_args_parser``."""

    def test_should_parse_empty_string_as_empty_list(self):
        result = str_to_list('')
        assert result == []

    def test_should_parse_blank_string_as_empty_list(self):
        result = str_to_list(' ')
        assert result == []

    def test_should_parse_comma_separated_list(self):
        result = str_to_list('a,b,c')
        assert result == ['a', 'b', 'c']

    def test_should_ignore_white_space_around_values(self):
        result = str_to_list(' a , b , c ')
        assert result == ['a', 'b', 'c']

    def test_should_parse_channels(self):
        parsed = model_args_parser().parse_args(['--channels', 'a,b,c'])
        assert parsed.channels == ['a', 'b', 'c']

    def test_should_set_channels_to_none_by_default(self):
        parsed = model_args_parser().parse_args([])
        assert parsed.channels is None

    def test_should_allow_all_base_loss_options(self):
        # a fresh parser per option, so each value is parsed independently
        for option in ALL_BASE_LOSS:
            parsed = model_args_parser().parse_args(['--base_loss', option])
            assert parsed.base_loss == option