# pix2pix_model.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
import logging
from functools import reduce

import tensorflow as tf

from tensorflow.python.lib.io.file_io import FileIO # pylint: disable=E0611

from sciencebeam_gym.trainer.util import (
  read_examples
)

from sciencebeam_gym.preprocess.color_map import (
  parse_color_map_from_file
)

from sciencebeam_gym.trainer.models.pix2pix.tf_utils import (
  find_nearest_centroid_indices
)

from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core import (
  BaseLoss,
  ALL_BASE_LOSS,
  create_pix2pix_model,
  create_other_summaries
)

from sciencebeam_gym.trainer.models.pix2pix.evaluate import (
  evaluate_separate_channels,
  evaluate_predictions,
  evaluation_summary
)

class GraphMode(object):
  """Integer identifiers for the graph mode passed to `Model.build_graph`."""
  TRAIN, EVALUATE, PREDICT = 1, 2, 3

def get_logger():
  """Return the logger registered under this module's name."""
  module_logger = logging.getLogger(__name__)
  return module_logger


class GraphReferences(object):
  """Holder of base tensors used for training model using common task.

  Every attribute starts out as `None` (or an empty list) and is filled in
  by the code that builds the graph.
  """

  def __init__(self):
    # Input pipeline and training ops.
    self.examples = None
    self.train = None
    self.global_step = None
    # Fresh lists per instance so that holders never share state.
    self.metric_updates = []
    self.metric_values = []
    self.keys = None
    self.predictions = []
    # Input / annotation tensors.
    self.input_jpeg = None
    self.input_uri = None
    self.image_tensor = None
    self.annotation_uri = None
    self.annotation_tensor = None
    self.separate_channel_annotation_tensor = None
    self.class_labels_tensor = None
    # Model outputs.
    self.pred = None
    self.probabilities = None
    # Summaries.
    self.summary = None
    self.summaries = None
    self.image_tensors = None
    # Evaluation.
    self.targets_class_indices = None
    self.outputs_class_indices = None
    self.output_layer_labels = None
    self.evaluation_result = None

def colors_to_dimensions(image_tensor, colors, use_unknown_class=False):
  """Convert a color image tensor into one 0.0/1.0 channel per color.

  For every entry of `colors` a channel is produced that is 1.0 where all
  components along the last axis of `image_tensor` equal that color and 0.0
  elsewhere. With `use_unknown_class` an extra channel is appended that is
  1.0 wherever the sum of the other channels is below 0.5.

  Args:
    image_tensor: tensor whose last axis holds the color components.
    colors: sequence of colors to look for.
    use_unknown_class: whether to append the "no color matched" channel.

  Returns:
    The channels stacked along a new last axis.
  """
  with tf.variable_scope("colors_to_dimensions"):
    ones = tf.fill(image_tensor.shape[0:-1], 1.0, name='ones')
    zeros = tf.fill(ones.shape, 0.0, name='zeros')
    channels = []
    for channel_index, color in enumerate(colors):
      with tf.variable_scope("channel_{}".format(channel_index)):
        component_matches = tf.equal(image_tensor, color)
        is_color = tf.reduce_all(component_matches, axis=-1, name='is_color')
        channels.append(tf.where(is_color, ones, zeros))
    if use_unknown_class:
      with tf.variable_scope("unknown_class"):
        matched_total = tf.add_n(channels)
        nothing_matched = matched_total < 0.5
        channels.append(tf.where(nothing_matched, ones, zeros))
    return tf.stack(channels, axis=-1)

def batch_dimensions_to_colors_list(image_tensor, colors):
  """Build one colored image tensor per channel of `image_tensor`.

  Channel `i` (indexed on the fourth axis) is expanded by a trailing axis and
  multiplied by `colors[i]` scaled down by 255.0.

  Args:
    image_tensor: tensor indexed as `[:, :, :, channel]`.
    colors: one color per channel, each a sequence of numeric components.

  Returns:
    A list with one tensor per entry of `colors`.
  """
  colored_channels = []
  for channel_index, color in enumerate(colors):
    normalized_color = [component / 255.0 for component in color]
    channel = tf.expand_dims(image_tensor[:, :, :, channel_index], axis=-1)
    colored_channels.append(channel * (normalized_color))
  return colored_channels

def batch_dimensions_to_most_likely_colors_list(image_tensor, colors):
  """Map every position to the color of its highest-valued channel.

  Args:
    image_tensor: tensor whose axis 3 holds the per-channel values.
    colors: one color per channel; stored as a uint8 constant.

  Returns:
    The rows of `colors` gathered by the argmax over axis 3.
  """
  with tf.variable_scope("batch_dimensions_to_most_likely_colors_list"):
    palette = tf.constant(colors, dtype=tf.uint8, name='colors')
    best_channel = tf.argmax(image_tensor, 3)
    most_likely_colors = tf.gather(params=palette, indices=best_channel)
    return most_likely_colors
def add_summary_image(tensors, name, image):
  """Store `image` in `tensors.image_tensors[name]` and add an image summary named `name`."""
  tensors.image_tensors[name] = image
  tf.summary.image(name, image)

def convert_image(image_tensor):
  """Convert `image_tensor` to a saturated uint8 image via `tf.image.convert_image_dtype`."""
  as_uint8 = tf.image.convert_image_dtype(image_tensor, dtype=tf.uint8, saturate=True)
  return as_uint8

def add_simple_summary_image(tensors, name, image_tensor):
  """Convert `image_tensor` with `convert_image` and register it as summary image `name`.

  The conversion and the summary are created inside a name scope called `name`.
  """
  with tf.name_scope(name):
    uint8_image = convert_image(image_tensor)
    add_summary_image(tensors, name, uint8_image)

def replace_black_with_white_color(image_tensor):
  """Return `image_tensor` with every (0, 0, 0) pixel replaced by 255 in all components.

  A pixel counts as black when all components along the last axis equal
  (0, 0, 0); the resulting mask is repeated three times along a new last axis
  so that it lines up with the color components.
  """
  pixel_is_black = tf.reduce_all(
    tf.equal(image_tensor, (0, 0, 0)),
    axis=-1
  )
  black_mask = tf.stack(
    [pixel_is_black, pixel_is_black, pixel_is_black],
    axis=-1
  )
  white = 255 * tf.ones_like(image_tensor)
  return tf.where(black_mask, white, image_tensor)

def combine_image(batch_images, replace_black_with_white=False):
  """Combine a list of image tensors into a single uint8 image.

  Each image is clipped to the range [0.0, 1.0], the clipped images are added
  up element-wise and the sum is converted with `convert_image`.

  Args:
    batch_images: non-empty list of image tensors of compatible shapes.
    replace_black_with_white: when true, pixels that end up as (0, 0, 0) are
      replaced using `replace_black_with_white_color`.

  Returns:
    The combined image tensor.
  """
  clipped_batch_images = [
    tf.clip_by_value(batch_image, 0.0, 1.0)
    for batch_image in batch_images
  ]
  # NOTE(review): the `reduce(` call and its sequence argument were missing
  # from the reviewed source (only the lambda survived); summing the clipped
  # images is the reconstruction - confirm against version control.
  combined_image = convert_image(
    reduce(
      lambda a, b: a + b,
      clipped_batch_images
    )
  )
  if replace_black_with_white:
    combined_image = replace_black_with_white_color(combined_image)
  return combined_image

def remove_last(a):
  """Return `a` without its final element (i.e. the slice `a[:-1]`)."""
  all_but_last = slice(None, -1)
  return a[all_but_last]

def add_model_summary_images(
  tensors, dimension_colors, dimension_labels,
  use_separate_channels=False,
  has_unknown_class=False):
  """Add the input, target and output image summaries for the model.

  Resets `tensors.summaries` and stores the image that represents the model
  output under `tensors.summaries['output_image']`.

  Args:
    tensors: graph references; reads `image_tensor`, `annotation_tensor`,
      `separate_channel_annotation_tensor`, `pred` and `outputs_class_indices`
      and writes `summaries` / `image_tensors`.
    dimension_colors: list of colors, one per channel (may be None when
      `use_separate_channels` is false and no class indices are available).
    dimension_labels: list of labels matching `dimension_colors` (or None).
    use_separate_channels: whether targets/outputs have one channel per label.
    has_unknown_class: whether the last channel is the unknown class.
  """
  tensors.summaries = {}
  add_simple_summary_image(
    tensors, 'input', tensors.image_tensor
  )
  add_simple_summary_image(
    tensors, 'target', tensors.annotation_tensor
  )
  if (has_unknown_class or not use_separate_channels) and dimension_labels is not None:
    dimension_labels_with_unknown = dimension_labels + ['unknown']
    dimension_colors_with_unknown = dimension_colors + [(255, 255, 255)]
  else:
    dimension_labels_with_unknown = dimension_labels
    dimension_colors_with_unknown = dimension_colors
  if use_separate_channels:
    for name, outputs in [
        ('targets', tensors.separate_channel_annotation_tensor),
        ('outputs', tensors.pred)
      ]:

      batch_images = batch_dimensions_to_colors_list(
        outputs,
        dimension_colors_with_unknown
      )
      # The combined image leaves out the unknown channel.
      batch_images_excluding_unknown = (
        remove_last(batch_images)
        if has_unknown_class
        else batch_images
      )
      # One summary image per channel (including the unknown channel).
      for i, (batch_image, dimension_label) in enumerate(zip(
        batch_images, dimension_labels_with_unknown)):

        suffix = "_{}_{}".format(
          i, dimension_label if dimension_label else 'unknown_label'
        )
        add_simple_summary_image(
          tensors, name + suffix, batch_image
        )
      with tf.name_scope(name + "_combined"):
        combined_image = combine_image(batch_images_excluding_unknown)
        if name == 'outputs':
          tensors.summaries['output_image'] = combined_image
        add_summary_image(
          tensors,
          name + "_combined",
          combined_image
        )
        # NOTE(review): the `add_summary_image(` / `tensors,` lines of this
        # call were missing from the reviewed source and were reconstructed
        # from the surviving argument indentation - confirm.
        with tf.name_scope(name + "_most_likely"):
          add_summary_image(
            tensors,
            name + "_most_likely",
            batch_dimensions_to_most_likely_colors_list(
              outputs,
              dimension_colors_with_unknown)
          )
  else:
    add_simple_summary_image(
      tensors,
      "output",
      tensors.pred
    )
    if tensors.outputs_class_indices is not None:
      with tf.name_scope("outputs_most_likely"):
        colors_tensor = tf.constant(
          dimension_colors_with_unknown,
          dtype=tf.uint8, name='colors'
        )
        add_summary_image(
          tensors,
          "outputs_most_likely",
          tf.gather(
            params=colors_tensor,
            indices=tensors.outputs_class_indices
          )
        )
    tensors.summaries['output_image'] = tensors.image_tensors['output']

def parse_json_file(filename):
  """Open `filename` with `FileIO` and return its parsed JSON content."""
  with FileIO(filename, 'r') as json_file:
    parsed = json.load(json_file)
  return parsed

def class_weights_to_pos_weight(class_weights, labels, use_unknown_class):
  """Return the weights of `labels` in order, plus 0.0 for the unknown class if used.

  Args:
    class_weights: mapping from label to weight.
    labels: labels to look up, in output order.
    use_unknown_class: when true, a trailing 0.0 is appended.

  Raises:
    KeyError: if a label has no entry in `class_weights`.
  """
  pos_weight = []
  for label in labels:
    pos_weight.append(class_weights[label])
  if use_unknown_class:
    return pos_weight + [0.0]
  return pos_weight

def parse_color_map(color_map_filename):
  """Open `color_map_filename` with `FileIO` and parse it with `parse_color_map_from_file`."""
  with FileIO(color_map_filename, 'r') as config_f:
    color_map = parse_color_map_from_file(config_f)
  return color_map

def color_map_to_labels(color_map, labels=None):
  """Return the labels to use for the given color map.

  Args:
    color_map: mapping from label to color.
    labels: optional explicit list of labels; when given (and non-empty) it is
      validated against `color_map` and returned unchanged.

  Returns:
    `labels` if provided, otherwise all keys of `color_map` in sorted order.

  Raises:
    ValueError: if any of the requested labels is not a key of `color_map`.
  """
  if labels:
    # `in` instead of `dict.has_key`, which no longer exists on Python 3.
    if not all(k in color_map for k in labels):
      raise ValueError(
        'not all lables found in color map, labels=%s, available keys=%s' %
        (labels, color_map.keys())
      )
    return labels
  return sorted(color_map.keys())

def color_map_to_colors(color_map, labels):
  """Return the colors of `labels` from `color_map`, in the order of `labels`."""
  colors = []
  for label in labels:
    colors.append(color_map[label])
  return colors

def color_map_to_colors_and_labels(color_map, labels=None):
  """Return `(colors, labels)` for `color_map`.

  The labels are resolved with `color_map_to_labels`; the colors are looked up
  in the same order with `color_map_to_colors`.
  """
  resolved_labels = color_map_to_labels(color_map, labels)
  return color_map_to_colors(color_map, resolved_labels), resolved_labels

def parse_color_map_to_colors_and_labels(color_map_filename, labels=None):
  """Parse the color map file and return its `(colors, labels)` pair."""
  return color_map_to_colors_and_labels(
    parse_color_map(color_map_filename),
    labels
  )

# Color and label appended for the "unknown" class.
UNKNOWN_COLOR = (255, 255, 255)
UNKNOWN_LABEL = 'unknown'

def colors_and_labels_with_unknown_class(colors, labels, use_unknown_class):
  """Return `(colors, labels)`, extended by the unknown class where applicable.

  The unknown color and label are appended (as new lists) when
  `use_unknown_class` is true or when `colors` is empty; otherwise the inputs
  are returned as they are.
  """
  if not use_unknown_class and colors:
    return colors, labels
  extended_colors = colors + [UNKNOWN_COLOR]
  extended_labels = labels + [UNKNOWN_LABEL]
  return (extended_colors, extended_labels)

class Model(object):
  """Wraps the pix2pix model so that it can be built by the generic trainer task.

  The constructor derives the channel colors / labels and the optional class
  weights from the parsed arguments; `build_graph` creates the input
  pipeline, the pix2pix model, the evaluation ops and the image summaries.
  """

  def __init__(self, args):
    """Configure the model from parsed command line arguments.

    Args:
      args: namespace as produced by `model_args_parser()`. The attributes
        read here are `use_unknown_class`, `use_separate_channels`,
        `color_map`, `channels` and `class_weights`; the namespace is also
        passed on to `create_pix2pix_model`.
    """
    self.args = args
    self.image_width = 256
    self.image_height = 256
    self.color_map = None
    self.pos_weight = None
    self.dimension_colors = None
    self.dimension_labels = None
    # Always defined, so that they can be read even without a color map.
    self.dimension_colors_with_unknown = None
    self.dimension_labels_with_unknown = None
    self.use_unknown_class = args.use_unknown_class
    # Separate channels need the colors of the color map; an empty path counts
    # as "no color map", consistent with the check that loads it below.
    self.use_separate_channels = bool(
      args.use_separate_channels and self.args.color_map
    )
    logger = get_logger()
    logger.info('use_separate_channels: %s', self.use_separate_channels)
    if self.args.color_map:
      self.dimension_colors, self.dimension_labels = parse_color_map_to_colors_and_labels(
        args.color_map,
        args.channels
      )
      self.dimension_colors_with_unknown, self.dimension_labels_with_unknown = (
        colors_and_labels_with_unknown_class(
          self.dimension_colors,
          self.dimension_labels,
          self.use_unknown_class
        )
      )
      logger.debug("dimension_colors: %s", self.dimension_colors)
      logger.debug("dimension_labels: %s", self.dimension_labels)
      if self.args.class_weights:
        self.pos_weight = class_weights_to_pos_weight(
          parse_json_file(self.args.class_weights),
          self.dimension_labels,
          self.use_unknown_class
        )
        logger.info("pos_weight: %s", self.pos_weight)

  def build_graph(self, data_paths, batch_size, graph_mode):
    """Build the graph for the given mode and return its tensor references.

    Args:
      data_paths: paths to read serialized examples from; when empty, a string
        placeholder named 'input' is used instead.
      batch_size: batch size passed to `tf.train.batch`.
      graph_mode: one of the `GraphMode` values.

    Returns:
      A `GraphReferences` instance holding the created tensors.
    """
    logger = get_logger()
    logger.debug('batch_size: %s', batch_size)
    tensors = GraphReferences()
    # Both TRAIN and EVALUATE read the examples without an epoch limit.
    is_training = (
      graph_mode == GraphMode.TRAIN or
      graph_mode == GraphMode.EVALUATE
    )
    if data_paths:
      logger.info('reading examples from %s', data_paths)
      tensors.keys, tensors.examples = read_examples(
        data_paths,
        shuffle=(graph_mode == GraphMode.TRAIN),
        num_epochs=None if is_training else 2
      )
    else:
      tensors.examples = tf.placeholder(tf.string, name='input', shape=(None,))
    with tf.name_scope('inputs'):
      feature_map = {
        'input_uri':
          tf.FixedLenFeature(
            shape=[], dtype=tf.string, default_value=['']
          ),
        'annotation_uri':
          tf.FixedLenFeature(
            shape=[], dtype=tf.string, default_value=['']
          ),
        'input_image':
          tf.FixedLenFeature(
            shape=[], dtype=tf.string
          ),
        'annotation_image':
          tf.FixedLenFeature(
            shape=[], dtype=tf.string
          )
      }
      logger.info('tensors.examples: %s', tensors.examples)
    parsed = tf.parse_single_example(tensors.examples, features=feature_map)

    tensors.image_tensors = {}

    tensors.input_uri = tf.squeeze(parsed['input_uri'])
    tensors.annotation_uri = tf.squeeze(parsed['annotation_uri'])
    raw_input_image = tf.squeeze(parsed['input_image'])
    logger.info('raw_input_image: %s', raw_input_image)
    raw_annotation_image = tf.squeeze(parsed['annotation_image'])
    tensors.image_tensor = tf.image.decode_png(raw_input_image, channels=3)
    tensors.annotation_tensor = tf.image.decode_png(raw_annotation_image, channels=3)

    # TODO resize_images and tf.cast did not work on input image
    #   but did work on annotation image
    tensors.image_tensor = tf.image.resize_image_with_crop_or_pad(
      tensors.image_tensor, self.image_height, self.image_width
    )

    tensors.image_tensor = tf.image.convert_image_dtype(tensors.image_tensor, tf.float32)

    tensors.annotation_tensor = tf.image.resize_image_with_crop_or_pad(
      tensors.annotation_tensor, self.image_height, self.image_width
    )

    if self.use_separate_channels:
      # The annotation stays uint8 here so that it can be compared against
      # the colors of the color map.
      tensors.separate_channel_annotation_tensor = colors_to_dimensions(
        tensors.annotation_tensor,
        self.dimension_colors,
        use_unknown_class=self.use_unknown_class
      )
    else:
      tensors.annotation_tensor = tf.image.convert_image_dtype(tensors.annotation_tensor, tf.float32)
      tensors.separate_channel_annotation_tensor = tensors.annotation_tensor

    (
      tensors.input_uri,
      tensors.annotation_uri,
      tensors.image_tensor,
      tensors.annotation_tensor,
      tensors.separate_channel_annotation_tensor
    ) = tf.train.batch(
      [
        tensors.input_uri,
        tensors.annotation_uri,
        tensors.image_tensor,
        tensors.annotation_tensor,
        tensors.separate_channel_annotation_tensor
      ],
      batch_size=batch_size
    )

    pix2pix_model = create_pix2pix_model(
      tensors.image_tensor,
      tensors.separate_channel_annotation_tensor,
      self.args,
      pos_weight=self.pos_weight
    )

    if self.use_separate_channels:
      with tf.name_scope("evaluation"):
        tensors.output_layer_labels = tf.constant(self.dimension_labels)
        evaluation_result = evaluate_separate_channels(
          targets=pix2pix_model.targets,
          outputs=pix2pix_model.outputs,
          has_unknown_class=self.use_unknown_class
        )
        tensors.evaluation_result = evaluation_result
        evaluation_summary(evaluation_result, self.dimension_labels)
    else:
      with tf.name_scope('evaluation'):
        if self.dimension_colors:
          tensors.output_layer_labels = tf.constant(self.dimension_labels)
          # Colors scaled to [0, 1] serve as centroids for mapping the
          # outputs and targets to class indices.
          colors_tensor = tf.constant(
            self.dimension_colors_with_unknown,
            dtype=tf.float32
          ) / 255.0
          tensors.outputs_class_indices = find_nearest_centroid_indices(
            predictions=pix2pix_model.outputs,
            centroids=colors_tensor
          )
          tensors.targets_class_indices = find_nearest_centroid_indices(
            predictions=pix2pix_model.targets,
            centroids=colors_tensor
          )
          evaluation_result = evaluate_predictions(
            labels=tensors.targets_class_indices,
            predictions=tensors.outputs_class_indices,
            n_classes=len(self.dimension_colors_with_unknown),
            has_unknown_class=self.use_unknown_class
          )
          tensors.evaluation_result = evaluation_result
          evaluation_summary(evaluation_result, self.dimension_labels)

    tensors.global_step = pix2pix_model.global_step
    tensors.train = pix2pix_model.train
    tensors.class_labels_tensor = tensors.annotation_tensor
    tensors.pred = pix2pix_model.outputs
    tensors.probabilities = pix2pix_model.outputs
    tensors.metric_values = [pix2pix_model.discrim_loss]

    add_model_summary_images(
      tensors,
      self.dimension_colors,
      self.dimension_labels,
      use_separate_channels=self.use_separate_channels,
      has_unknown_class=self.use_unknown_class
    )

    # tensors.summaries = create_summaries(pix2pix_model)
    create_other_summaries(pix2pix_model)

    tensors.summary = tf.summary.merge_all()
    return tensors

  def build_train_graph(self, data_paths, batch_size):
    """Build the graph in `GraphMode.TRAIN` mode."""
    return self.build_graph(data_paths, batch_size, GraphMode.TRAIN)

  def build_eval_graph(self, data_paths, batch_size):
    """Build the graph in `GraphMode.EVALUATE` mode."""
    return self.build_graph(data_paths, batch_size, GraphMode.EVALUATE)

  def initialize(self, session):
    """Hook called with the session; this model has nothing to initialize."""
    pass

  def format_metric_values(self, metric_values):
    """Formats metric values - used for logging purpose.

    Args:
      metric_values: sequence of `[loss, accuracy]`; may be None or shorter.

    Returns:
      The string '<loss>, <accuracy>' with three decimals each; a value that
      is not available is shown as 'N/A'.
    """

    # Early in training, metric_values may actually be None.
    loss_str = 'N/A'
    accuracy_str = 'N/A'
    try:
      loss_str = '%.3f' % metric_values[0]
      accuracy_str = '%.3f' % metric_values[1]
    except (TypeError, IndexError):
      # Keep whatever could be formatted before the lookup failed.
      pass

    return '%s, %s' % (loss_str, accuracy_str)

def str_to_bool(s):
  """Return True if `s` is 'yes', 'true' or '1' (ignoring case), False otherwise."""
  truthy_values = ('yes', 'true', '1')
  return s.lower() in truthy_values

def str_to_list(s):
  """Split a comma separated string into a list of stripped items.

  A string that is empty after stripping yields an empty list.
  """
  stripped = s.strip()
  return [item.strip() for item in stripped.split(',')] if stripped else []

def model_args_parser():
  """Create the argument parser for the model specific command line options.

  Returns:
    An `argparse.ArgumentParser` with the generator / discriminator, optimizer,
    color map, channel and loss options registered.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
    "--ngf", type=int, default=64, help="number of generator filters in first conv layer"
  )
  parser.add_argument(
    "--ndf", type=int, default=64, help="number of discriminator filters in first conv layer"
  )
  parser.add_argument(
    "--lr", type=float, default=0.0002, help="initial learning rate for adam"
  )
  parser.add_argument(
    "--beta1", type=float, default=0.5, help="momentum term of adam"
  )
  parser.add_argument(
    "--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient"
  )
  parser.add_argument(
    "--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient"
  )

  parser.add_argument(
    '--color_map',
    type=str,
    help='The path to the color map configuration.'
  )
  parser.add_argument(
    '--class_weights',
    type=str,
    help='The path to the class weights configuration.'
  )
  parser.add_argument(
    '--channels',
    type=str_to_list,
    help='The channels to use (subset of color map), otherwise all of the labels will be used'
  )
  # Parsed with str_to_bool like the other flags; without a type the value
  # would stay a string and e.g. 'false' would be truthy.
  parser.add_argument(
    '--use_unknown_class',
    type=str_to_bool,
    default=True,
    help='Use unknown class channel (if color map is provided)'
  )
  parser.add_argument(
    '--use_separate_channels',
    type=str_to_bool,
    default=False,
    help='The separate output channels per annotation (if color map is provided)'
  )
  parser.add_argument(
    '--use_separate_discriminator_channels',
    type=str_to_bool,
    default=False,
    help='The separate discriminator channels per annotation (if color map is provided)'
  )
  parser.add_argument(
    '--use_separate_discriminators',
    type=str_to_bool,
    default=False,
    help='The separate discriminators per annotation (if color map is provided)'
  )
  parser.add_argument(
    '--base_loss',
    type=str,
    default=BaseLoss.L1,
    choices=ALL_BASE_LOSS,
    help='The base loss function to use'
  )
  return parser


def create_model(argv=None):
  """Factory method that creates model to be used by generic task.py.

  Returns:
    A tuple of the `Model` built from the recognised arguments and the list of
    remaining (unrecognised) arguments.
  """
  known_args, remaining_args = model_args_parser().parse_known_args(argv)
  model = Model(known_args)
  return model, remaining_args