Skip to content
Snippets Groups Projects
pix2pix_model_test.py 9.62 KiB
Newer Older
from collections import namedtuple
from mock import patch

Daniel Ecer's avatar
Daniel Ecer committed
import pytest
Daniel Ecer's avatar
Daniel Ecer committed
import tensorflow as tf

from sciencebeam_gym.utils.collection import (
  extend_dict
)

from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core import (
  BaseLoss,
  ALL_BASE_LOSS
)

from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core_test import (
  DEFAULT_ARGS as CORE_DEFAULT_ARGS
)

import sciencebeam_gym.trainer.models.pix2pix.pix2pix_model as pix2pix_model

from sciencebeam_gym.trainer.models.pix2pix.pix2pix_model import (
  parse_color_map,
  color_map_to_labels,
  color_map_to_colors,
  color_map_to_colors_and_labels,
  colors_and_labels_with_unknown_class,
  UNKNOWN_COLOR,
  UNKNOWN_LABEL,
  Model,
  str_to_list,
Daniel Ecer's avatar
Daniel Ecer committed
  model_args_parser,
  class_weights_to_pos_weight
)

# Names of the (mocked) input files handed to ``Model`` via its args.
COLOR_MAP_FILENAME = 'color_map.conf'
CLASS_WEIGHTS_FILENAME = 'class-weights.json'
# Data glob and batch size handed to ``Model.build_train_graph``.
DATA_PATH = 'some/where/*.tfrecord'
BATCH_SIZE = 10

def some_color(i):
  """Return a colour 3-tuple with every component set to ``i``."""
  return (i, i, i)

SOME_COLORS = [some_color(1), some_color(2), some_color(3)]
SOME_LABELS = ['a', 'b', 'c']

# Pairs each label with the colour at the same position,
# i.e. {'a': (1, 1, 1), 'b': (2, 2, 2), 'c': (3, 3, 3)}.
SOME_COLOR_MAP = dict(zip(SOME_LABELS, SOME_COLORS))
# Weight of each label is its (float) index in SOME_LABELS,
# i.e. {'a': 0.0, 'b': 1.0, 'c': 2.0}.
SOME_CLASS_WEIGHTS = {
  label: float(index) for index, label in enumerate(SOME_LABELS)
}

class TestParseColorMap(object):
  """Tests for ``parse_color_map``."""

  def test_should_use_fileio_to_load_file_and_pass_to_parser(self):
    with patch.object(pix2pix_model, 'FileIO') as file_io_mock, \
        patch.object(pix2pix_model, 'parse_color_map_from_file') as parser_mock:
      parse_color_map(COLOR_MAP_FILENAME)
      # The object bound by ``with FileIO(...) as f`` inside the code under test.
      opened_file = file_io_mock.return_value.__enter__.return_value
      file_io_mock.assert_called_with(COLOR_MAP_FILENAME, 'r')
      parser_mock.assert_called_with(opened_file)

class TestColorMapToLabels(object):
  """Tests for ``color_map_to_labels``."""

  def test_should_use_color_maps_keys_by_default(self):
    color_map = {
      'a': some_color(1),
      'b': some_color(2),
      'c': some_color(3)
    }
    assert color_map_to_labels(color_map) == ['a', 'b', 'c']

  def test_should_return_specified_labels(self):
    color_map = {
      'a': some_color(1),
      'b': some_color(2),
      'c': some_color(3)
    }
    # The requested order is kept, not the colour map's key order.
    assert color_map_to_labels(color_map, ['b', 'a']) == ['b', 'a']

  def test_should_raise_error_if_specified_label_not_in_color_map(self):
    color_map = {
      'a': some_color(1),
      'c': some_color(3)
    }
    # 'b' is requested but is not a key of the colour map.
    with pytest.raises(ValueError):
      color_map_to_labels(color_map, ['a', 'b'])

class TestColorMapToColors(object):
  """Tests for ``color_map_to_colors``."""

  def test_should_return_colors_for_labels(self):
    # Builds {'a': some_color(1), 'b': some_color(2), 'c': some_color(3)}.
    color_map = {
      label: some_color(index)
      for index, label in enumerate(['a', 'b', 'c'], 1)
    }
    actual = color_map_to_colors(color_map, ['a', 'b'])
    expected = [some_color(1), some_color(2)]
    assert actual == expected

class TestColorMapToColorsAndLabels(object):
  """Tests for ``color_map_to_colors_and_labels``."""

  @staticmethod
  def _abc_color_map():
    """Return a fresh colour map for the labels 'a', 'b' and 'c'."""
    return {
      'a': some_color(1),
      'b': some_color(2),
      'c': some_color(3)
    }

  def test_should_return_color_and_specified_labels(self):
    result_colors, result_labels = color_map_to_colors_and_labels(
      self._abc_color_map(), ['a', 'b']
    )
    assert result_colors == [some_color(1), some_color(2)]
    assert result_labels == ['a', 'b']

  def test_should_return_color_using_sorted_keys_if_no_labels_are_specified(self):
    # With labels=None every key of the colour map is expected, in sorted order.
    result_colors, result_labels = color_map_to_colors_and_labels(
      self._abc_color_map(), None
    )
    assert result_colors == [some_color(1), some_color(2), some_color(3)]
    assert result_labels == ['a', 'b', 'c']

class TestColorsAndLabelsWithUnknownClass(object):
  """Tests for ``colors_and_labels_with_unknown_class``."""

  def test_should_not_add_unknown_class_if_not_enabled(self):
    colors, labels = colors_and_labels_with_unknown_class(
      SOME_COLORS, SOME_LABELS, use_unknown_class=False
    )
    assert colors == SOME_COLORS
    assert labels == SOME_LABELS

  def test_should_add_unknown_class_if_enabled(self):
    colors, labels = colors_and_labels_with_unknown_class(
      SOME_COLORS, SOME_LABELS, use_unknown_class=True
    )
    # The unknown class is appended after the existing entries.
    assert colors == SOME_COLORS + [UNKNOWN_COLOR]
    assert labels == SOME_LABELS + [UNKNOWN_LABEL]

  def test_should_add_unknown_class_if_colors_are_empty(self):
    # Even with use_unknown_class=False, an empty input gains the unknown class.
    colors, labels = colors_and_labels_with_unknown_class(
      [], [], use_unknown_class=False
    )
    assert colors == [UNKNOWN_COLOR]
    assert labels == [UNKNOWN_LABEL]

Daniel Ecer's avatar
Daniel Ecer committed
class TestClassWeightsToPosWeight(object):
  """Tests for ``class_weights_to_pos_weight``."""

  def test_should_extract_selected_weights(self):
    class_weights = {'a': 0.1, 'b': 0.2, 'c': 0.3}
    pos_weight = class_weights_to_pos_weight(class_weights, ['a', 'b'], False)
    assert pos_weight == [0.1, 0.2]

  def test_should_add_zero_if_unknown_class_is_true(self):
    class_weights = {'a': 0.1, 'b': 0.2, 'c': 0.3}
    pos_weight = class_weights_to_pos_weight(class_weights, ['a', 'b'], True)
    # A trailing 0.0 is appended for the unknown class.
    assert pos_weight == [0.1, 0.2, 0.0]

# Default arguments for ``Model`` in these tests: the pix2pix core defaults
# extended with the colour-map / class-weight / channel options.
DEFAULT_ARGS = extend_dict(CORE_DEFAULT_ARGS, {
  'color_map': None,
  'class_weights': None,
  'channels': None,
  'use_separate_channels': False,
  'use_unknown_class': False
})

def create_args(*args, **kwargs):
  """Merge the given dicts and keyword arguments into an ``args`` namedtuple.

  Positional dicts come first and ``kwargs`` last, all passed to
  ``extend_dict`` (presumably later entries override earlier ones — confirm
  against ``extend_dict``). Fields are accessed by attribute, like parsed
  command-line args.
  """
  sources = list(args)
  sources.append(kwargs)
  merged = extend_dict(*sources)
  ArgsTuple = namedtuple('args', merged.keys())
  return ArgsTuple(**merged)
class TestModel(object):
  """Tests for the dimensions and ``pos_weight`` that ``Model`` derives from its args."""

  def test_parse_separate_channels_with_color_map_without_class_weights(self):
    with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
      parse_color_map_from_file.return_value = {
        'a': some_color(1),
        'b': some_color(2),
        'c': some_color(3)
      }
      args = create_args(
        DEFAULT_ARGS,
        color_map=COLOR_MAP_FILENAME,
        class_weights=None,
        channels=['a', 'b'],
        use_separate_channels=True,
        use_unknown_class=True
      )
      model = Model(args)
      # Only the selected channels are used; 'c' is dropped.
      assert model.dimension_colors == [some_color(1), some_color(2)]
      assert model.dimension_labels == ['a', 'b']
      assert model.dimension_colors_with_unknown == [some_color(1), some_color(2), UNKNOWN_COLOR]
      assert model.dimension_labels_with_unknown == ['a', 'b', UNKNOWN_LABEL]
      # No class weights file was given.
      assert model.pos_weight is None

  def test_parse_separate_channels_with_color_map_and_class_weights(self):
    with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
      with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file:
        parse_color_map_from_file.return_value = {
          'a': some_color(1),
          'b': some_color(2),
          'c': some_color(3)
        }
        # Stands in for the parsed contents of CLASS_WEIGHTS_FILENAME.
        parse_json_file.return_value = {
          'a': 0.1,
          'b': 0.2,
          'c': 0.3
        }
        args = create_args(
          DEFAULT_ARGS,
          color_map=COLOR_MAP_FILENAME,
          class_weights=CLASS_WEIGHTS_FILENAME,
          channels=['a', 'b'],
          use_separate_channels=True,
          use_unknown_class=True
        )
        model = Model(args)
        assert model.dimension_colors == [some_color(1), some_color(2)]
        assert model.dimension_labels == ['a', 'b']
        assert (
          model.dimension_colors_with_unknown == [some_color(1), some_color(2), UNKNOWN_COLOR]
        )
        assert model.dimension_labels_with_unknown == ['a', 'b', UNKNOWN_LABEL]
        # Weights of the selected channels, plus 0.0 for the unknown class.
        assert model.pos_weight == [0.1, 0.2, 0.0]

@pytest.mark.slow
@pytest.mark.very_slow
class TestModelBuildGraph(object):
  """Smoke tests: ``Model.build_train_graph`` runs against mocked inputs."""

  def test_should_build_train_graph_with_defaults(self):
    with tf.Graph().as_default(), \
        patch.object(pix2pix_model, 'parse_color_map_from_file') as color_map_parser, \
        patch.object(pix2pix_model, 'read_examples') as examples_reader:
      color_map_parser.return_value = SOME_COLOR_MAP
      examples_reader.return_value = (tf.constant(1), b'dummy')
      model = Model(create_args(DEFAULT_ARGS))
      model.build_train_graph(DATA_PATH, BATCH_SIZE)

  def test_should_build_train_graph_with_class_weights(self):
    with tf.Graph().as_default(), \
        patch.object(pix2pix_model, 'parse_color_map_from_file') as color_map_parser, \
        patch.object(pix2pix_model, 'parse_json_file') as json_parser, \
        patch.object(pix2pix_model, 'read_examples') as examples_reader:
      color_map_parser.return_value = SOME_COLOR_MAP
      json_parser.return_value = SOME_CLASS_WEIGHTS
      examples_reader.return_value = (tf.constant(1), b'dummy')
      model = Model(create_args(
        DEFAULT_ARGS,
        base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY,
        color_map=COLOR_MAP_FILENAME,
        class_weights=CLASS_WEIGHTS_FILENAME,
        channels=['a', 'b'],
        use_separate_channels=True,
        use_unknown_class=True
      ))
      model.build_train_graph(DATA_PATH, BATCH_SIZE)

class TestStrToList(object):
  """Tests for ``str_to_list``."""

  def test_should_parse_empty_string_as_empty_list(self):
    parsed = str_to_list('')
    assert parsed == []

  def test_should_parse_blank_string_as_empty_list(self):
    parsed = str_to_list(' ')
    assert parsed == []

  def test_should_parse_comma_separated_list(self):
    parsed = str_to_list('a,b,c')
    assert parsed == ['a', 'b', 'c']

  def test_should_ignore_white_space_around_values(self):
    parsed = str_to_list(' a , b , c ')
    assert parsed == ['a', 'b', 'c']

Daniel Ecer's avatar
Daniel Ecer committed
class TestModelArgsParser(object):
  """Tests for the argument parser returned by ``model_args_parser``."""

  def test_should_parse_channels(self):
    args = model_args_parser().parse_args(['--channels', 'a,b,c'])
    # The comma-separated value is split into a list of channel names.
    assert args.channels == ['a', 'b', 'c']

  def test_should_set_channels_to_none_by_default(self):
    args = model_args_parser().parse_args([])
    assert args.channels is None

  # One test case per option, so a rejected option is reported by name.
  @pytest.mark.parametrize('base_loss', ALL_BASE_LOSS)
  def test_should_allow_all_base_loss_options(self, base_loss):
    args = model_args_parser().parse_args(['--base_loss', base_loss])
    assert args.base_loss == base_loss