Skip to content
Snippets Groups Projects
Commit 71081bfd authored by Daniel Ecer's avatar Daniel Ecer
Browse files

use default unknown weight and get unknown weight from class weights if present

parent d05fc08e
No related branches found
No related tags found
No related merge requests found
......@@ -39,6 +39,11 @@ from sciencebeam_gym.trainer.models.pix2pix.evaluate import (
)
UNKNOWN_COLOR = (255, 255, 255)
UNKNOWN_LABEL = 'unknown'
DEFAULT_UNKNOWN_CLASS_WEIGHT = 0.1
class GraphMode(object):
TRAIN = 1
EVALUATE = 2
......@@ -185,7 +190,7 @@ def add_model_summary_images(
tensors, 'target', tensors.annotation_tensor
)
if (has_unknown_class or not use_separate_channels) and dimension_labels is not None:
dimension_labels_with_unknown = dimension_labels + ['unknown']
dimension_labels_with_unknown = dimension_labels + [UNKNOWN_LABEL]
dimension_colors_with_unknown = dimension_colors + [(255, 255, 255)]
else:
dimension_labels_with_unknown = dimension_labels
......@@ -260,9 +265,12 @@ def parse_json_file(filename):
with FileIO(filename, 'r') as f:
return json.load(f)
def class_weights_to_pos_weight(class_weights, labels, use_unknown_class):
def class_weights_to_pos_weight(
class_weights, labels,
use_unknown_class, unknown_class_weight=DEFAULT_UNKNOWN_CLASS_WEIGHT):
pos_weight = [class_weights[k] for k in labels]
return pos_weight + [0.0] if use_unknown_class else pos_weight
return pos_weight + [unknown_class_weight] if use_unknown_class else pos_weight
def parse_color_map(color_map_filename):
with FileIO(color_map_filename, 'r') as config_f:
......@@ -283,9 +291,6 @@ def color_map_to_labels(color_map, labels=None):
def color_map_to_colors(color_map, labels):
return [color_map[k] for k in labels]
UNKNOWN_COLOR = (255, 255, 255)
UNKNOWN_LABEL = 'unknown'
def colors_and_labels_with_unknown_class(colors, labels, use_unknown_class):
if use_unknown_class or not colors:
return (
......@@ -330,11 +335,12 @@ class Model(object):
)
logger.debug("dimension_colors: %s", self.dimension_colors)
logger.debug("dimension_labels: %s", self.dimension_labels)
if self.args.class_weights:
if class_weights:
self.pos_weight = class_weights_to_pos_weight(
parse_json_file(self.args.class_weights),
class_weights,
self.dimension_labels,
self.use_unknown_class
self.use_separate_channels,
class_weights.get(UNKNOWN_LABEL, DEFAULT_UNKNOWN_CLASS_WEIGHT)
)
logger.info("pos_weight: %s", self.pos_weight)
......
......@@ -28,6 +28,7 @@ from sciencebeam_gym.trainer.models.pix2pix.pix2pix_model import (
colors_and_labels_with_unknown_class,
UNKNOWN_COLOR,
UNKNOWN_LABEL,
DEFAULT_UNKNOWN_CLASS_WEIGHT,
Model,
str_to_list,
model_args_parser,
......@@ -137,7 +138,9 @@ class TestClassWeightsToPosWeight(object):
'a': 0.1,
'b': 0.2,
'c': 0.3
}, ['a', 'b'], True) == [0.1, 0.2, 0.0]
}, ['a', 'b'], True, DEFAULT_UNKNOWN_CLASS_WEIGHT) == (
[0.1, 0.2, DEFAULT_UNKNOWN_CLASS_WEIGHT]
)
DEFAULT_ARGS = extend_dict(
CORE_DEFAULT_ARGS,
......@@ -205,7 +208,7 @@ class TestModel(object):
model.dimension_colors_with_unknown == [some_color(1), some_color(2), UNKNOWN_COLOR]
)
assert model.dimension_labels_with_unknown == ['a', 'b', UNKNOWN_LABEL]
assert model.pos_weight == [0.1, 0.2, 0.0]
assert model.pos_weight == [0.1, 0.2, DEFAULT_UNKNOWN_CLASS_WEIGHT]
def test_should_only_include_labels_with_non_zero_class_labels_by_default(self):
with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
......@@ -230,7 +233,24 @@ class TestModel(object):
model = Model(args)
assert model.dimension_labels == ['a', 'c']
assert model.dimension_colors == [some_color(1), some_color(3)]
assert model.pos_weight == [0.1, 0.3, 0.0]
assert model.pos_weight == [0.1, 0.3, DEFAULT_UNKNOWN_CLASS_WEIGHT]
def test_should_use_unknown_class_weight_from_configuration(self):
with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file:
parse_color_map_from_file.return_value = SOME_COLOR_MAP
parse_json_file.return_value = extend_dict(SOME_CLASS_WEIGHTS, {
'unknown': 0.99
})
args = create_args(
DEFAULT_ARGS,
color_map=COLOR_MAP_FILENAME,
class_weights=CLASS_WEIGHTS_FILENAME,
use_separate_channels=True,
use_unknown_class=True
)
model = Model(args)
assert model.pos_weight[-1] == 0.99
@pytest.mark.slow
@pytest.mark.very_slow
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment