From 71081bfdc47107d63e168681ee2c3edf7b599304 Mon Sep 17 00:00:00 2001
From: Daniel Ecer <de-code@users.noreply.github.com>
Date: Thu, 7 Dec 2017 18:42:42 +0000
Subject: [PATCH] use default unknown weight and get unknown weight from class
 weights if present

---
 .../trainer/models/pix2pix/pix2pix_model.py   | 24 ++++++++++-------
 .../models/pix2pix/pix2pix_model_test.py      | 26 ++++++++++++++++---
 2 files changed, 38 insertions(+), 12 deletions(-)

diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py
index 8cfe085..79e3801 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py
@@ -39,6 +39,11 @@ from sciencebeam_gym.trainer.models.pix2pix.evaluate import (
 )
 
 
+UNKNOWN_COLOR = (255, 255, 255)
+UNKNOWN_LABEL = 'unknown'  # also the class weights key looked up for the unknown weight
+
+DEFAULT_UNKNOWN_CLASS_WEIGHT = 0.1  # fallback when class weights have no UNKNOWN_LABEL entry
+
 class GraphMode(object):
   TRAIN = 1
   EVALUATE = 2
@@ -185,7 +190,7 @@ def add_model_summary_images(
     tensors, 'target', tensors.annotation_tensor
   )
   if (has_unknown_class or not use_separate_channels) and dimension_labels is not None:
-    dimension_labels_with_unknown = dimension_labels + ['unknown']
+    dimension_labels_with_unknown = dimension_labels + [UNKNOWN_LABEL]
     dimension_colors_with_unknown = dimension_colors + [(255, 255, 255)]
   else:
     dimension_labels_with_unknown = dimension_labels
@@ -260,9 +265,12 @@ def parse_json_file(filename):
   with FileIO(filename, 'r') as f:
     return json.load(f)
 
-def class_weights_to_pos_weight(class_weights, labels, use_unknown_class):
+def class_weights_to_pos_weight(
+  class_weights, labels,
+  use_unknown_class, unknown_class_weight=DEFAULT_UNKNOWN_CLASS_WEIGHT):
+  """Return the weights of `labels`, plus `unknown_class_weight` if `use_unknown_class`."""
   pos_weight = [class_weights[k] for k in labels]
-  return pos_weight + [0.0] if use_unknown_class else pos_weight
+  return pos_weight + [unknown_class_weight] if use_unknown_class else pos_weight
 
 def parse_color_map(color_map_filename):
   with FileIO(color_map_filename, 'r') as config_f:
@@ -283,9 +291,6 @@ def color_map_to_labels(color_map, labels=None):
 def color_map_to_colors(color_map, labels):
   return [color_map[k] for k in labels]
 
-UNKNOWN_COLOR = (255, 255, 255)
-UNKNOWN_LABEL = 'unknown'
-
 def colors_and_labels_with_unknown_class(colors, labels, use_unknown_class):
   if use_unknown_class or not colors:
     return (
@@ -330,11 +335,12 @@ class Model(object):
       )
       logger.debug("dimension_colors: %s", self.dimension_colors)
       logger.debug("dimension_labels: %s", self.dimension_labels)
-      if self.args.class_weights:
+      if class_weights:
         self.pos_weight = class_weights_to_pos_weight(
-          parse_json_file(self.args.class_weights),
+          class_weights,
           self.dimension_labels,
-          self.use_unknown_class
+          self.use_separate_channels,
+          class_weights.get(UNKNOWN_LABEL, DEFAULT_UNKNOWN_CLASS_WEIGHT)
         )
         logger.info("pos_weight: %s", self.pos_weight)
 
diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model_test.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model_test.py
index af2c849..1cbd457 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model_test.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model_test.py
@@ -28,6 +28,7 @@ from sciencebeam_gym.trainer.models.pix2pix.pix2pix_model import (
   colors_and_labels_with_unknown_class,
   UNKNOWN_COLOR,
   UNKNOWN_LABEL,
+  DEFAULT_UNKNOWN_CLASS_WEIGHT,
   Model,
   str_to_list,
   model_args_parser,
@@ -137,7 +138,9 @@ class TestClassWeightsToPosWeight(object):
       'a': 0.1,
       'b': 0.2,
       'c': 0.3
-    }, ['a', 'b'], True) == [0.1, 0.2, 0.0]
+    }, ['a', 'b'], True, DEFAULT_UNKNOWN_CLASS_WEIGHT) == (
+      [0.1, 0.2, DEFAULT_UNKNOWN_CLASS_WEIGHT]
+    )
 
 DEFAULT_ARGS = extend_dict(
   CORE_DEFAULT_ARGS,
@@ -205,7 +208,7 @@ class TestModel(object):
           model.dimension_colors_with_unknown == [some_color(1), some_color(2), UNKNOWN_COLOR]
         )
         assert model.dimension_labels_with_unknown == ['a', 'b', UNKNOWN_LABEL]
-        assert model.pos_weight == [0.1, 0.2, 0.0]
+        assert model.pos_weight == [0.1, 0.2, DEFAULT_UNKNOWN_CLASS_WEIGHT]
 
   def test_should_only_include_labels_with_non_zero_class_labels_by_default(self):
     with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
@@ -230,7 +233,24 @@ class TestModel(object):
         model = Model(args)
         assert model.dimension_labels == ['a', 'c']
         assert model.dimension_colors == [some_color(1), some_color(3)]
-        assert model.pos_weight == [0.1, 0.3, 0.0]
+        assert model.pos_weight == [0.1, 0.3, DEFAULT_UNKNOWN_CLASS_WEIGHT]
+
+  def test_should_use_unknown_class_weight_from_configuration(self):
+    with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
+      with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file:
+        parse_color_map_from_file.return_value = SOME_COLOR_MAP
+        parse_json_file.return_value = extend_dict(SOME_CLASS_WEIGHTS, {
+          'unknown': 0.99  # explicit weight under the UNKNOWN_LABEL key
+        })
+        args = create_args(
+          DEFAULT_ARGS,
+          color_map=COLOR_MAP_FILENAME,
+          class_weights=CLASS_WEIGHTS_FILENAME,
+          use_separate_channels=True,
+          use_unknown_class=True
+        )
+        model = Model(args)
+        assert model.pos_weight[-1] == 0.99  # the unknown weight is appended last
 
 @pytest.mark.slow
 @pytest.mark.very_slow
-- 
GitLab