from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import argparse
import tensorflow as tf
import six
from six.moves.configparser import ConfigParser
from tensorflow.python.lib.io.file_io import FileIO
from sciencebeam_gym.trainer.util import (
read_examples
)
from sciencebeam_gym.tools.colorize_image import (
parse_color_map_from_configparser
)
from sciencebeam_gym.trainer.models.pix2pix.tf_utils import (
find_nearest_centroid_indices
)
from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core import (
BaseLoss,
create_pix2pix_model,
create_other_summaries
)
from sciencebeam_gym.trainer.models.pix2pix.evaluate import (
evaluate_separate_channels,
evaluate_predictions,
evaluation_summary
)
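
# Modes the graph can be built for; the same graph building code is shared
# between training, evaluation and prediction.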
class GraphMode(object):
TRAIN = 1
EVALUATE = 2
PREDICT = 3
def get_logger():
return logging.getLogger(__name__)
class GraphReferences(object):
"""Holder of base tensors used for training model using common task."""
def __init__(self):
self.examples = None
self.train = None
self.global_step = None
self.metric_updates = []
self.metric_values = []
self.keys = None
self.predictions = []
self.input_jpeg = None
self.input_uri = None
self.image_tensor = None
self.annotation_uri = None
self.annotation_tensor = None
self.separate_channel_annotation_tensor = None
self.class_labels_tensor = None
self.pred = None
self.probabilities = None
self.summary = None
self.summaries = None
self.image_tensors = None
self.targets_class_indices = None
self.outputs_class_indices = None
self.output_layer_labels = None
self.evaluation_result = None
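
# Converts an RGB annotation image into one binary channel per colour in
# `colors`; with use_unknown_class=True an extra channel is appended that is
# set wherever no known colour matched.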
def colors_to_dimensions(image_tensor, colors, use_unknown_class=False):
with tf.variable_scope("colors_to_dimensions"):
single_label_tensors = []
ones = tf.fill(image_tensor.shape[0:-1], 1.0, name='ones')
zeros = tf.fill(ones.shape, 0.0, name='zeros')
for single_label_color in colors:
i = len(single_label_tensors)
with tf.variable_scope("channel_{}".format(i)):
is_color = tf.reduce_all(
tf.equal(image_tensor, single_label_color),
axis=-1,
name='is_color'
)
single_label_tensor = tf.where(
is_color,
ones,
zeros
)
single_label_tensors.append(single_label_tensor)
if use_unknown_class:
with tf.variable_scope("unknown_class"):
single_label_tensors.append(
tf.where(
tf.add_n(single_label_tensors) < 0.5,
ones,
zeros
)
)
return tf.stack(single_label_tensors, axis=-1)
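
# Turns a batched multi-channel tensor back into a list of images, one per
# channel, each tinted with its corresponding colour (for summary images).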
def batch_dimensions_to_colors_list(image_tensor, colors):
batch_images = []
for i, single_label_color in enumerate(colors):
batch_images.append(
tf.expand_dims(
image_tensor[:, :, :, i],
axis=-1
) * ([x / 255.0 for x in single_label_color])
)
return batch_images
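
# Maps every pixel to the colour of its most likely (argmax) channel.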
def batch_dimensions_to_most_likely_colors_list(image_tensor, colors):
with tf.variable_scope("batch_dimensions_to_most_likely_colors_list"):
colors_tensor = tf.constant(colors, dtype=tf.uint8, name='colors')
most_likely_class_index = tf.argmax(image_tensor, axis=3)
return tf.gather(params=colors_tensor, indices=most_likely_class_index)
def add_summary_image(tensors, name, image):
tensors.image_tensors[name] = image
tf.summary.image(name, image)
def convert_image(image_tensor):
return tf.image.convert_image_dtype(
image_tensor,
dtype=tf.uint8,
saturate=True
)
def add_simple_summary_image(tensors, name, image_tensor):
with tf.name_scope(name):
add_summary_image(
tensors,
name,
convert_image(image_tensor)
)
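
# Replaces pure black pixels with white, so that combined summary images are
# rendered on a white background.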
def replace_black_with_white_color(image_tensor):
is_black = tf.reduce_all(
tf.equal(image_tensor, (0, 0, 0)),
axis=-1
)
is_black = tf.stack([is_black] * 3, axis=-1)
return tf.where(
is_black,
255 * tf.ones_like(image_tensor),
image_tensor
)
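
# Combines the per-channel colour images into a single image by clipping each
# channel to [0, 1] and summing them.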
def combine_image(batch_images, replace_black_with_white=False):
clipped_batch_images = [
tf.clip_by_value(batch_image, 0.0, 1.0)
for batch_image in batch_images
]
combined_image = convert_image(
six.moves.reduce(
lambda a, b: a + b,
clipped_batch_images
)
)
if replace_black_with_white:
combined_image = replace_black_with_white_color(combined_image)
return combined_image
def remove_last(a):
return a[:-1]
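
# Adds TensorBoard images for the input, the target and the prediction; with
# separate channels enabled, each channel is also shown individually as well
# as combined (optionally excluding the unknown class channel).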
def add_model_summary_images(
tensors, dimension_colors, dimension_labels,
use_separate_channels=False,
has_unknown_class=False):
tensors.summaries = {}
add_simple_summary_image(
tensors, 'input', tensors.image_tensor
)
add_simple_summary_image(
tensors, 'target', tensors.annotation_tensor
)
if (has_unknown_class or not use_separate_channels) and dimension_labels is not None:
dimension_labels_with_unknown = dimension_labels + ['unknown']
dimension_colors_with_unknown = dimension_colors + [(255, 255, 255)]
else:
dimension_labels_with_unknown = dimension_labels
dimension_colors_with_unknown = dimension_colors
if use_separate_channels:
for name, outputs in [
('targets', tensors.separate_channel_annotation_tensor),
('outputs', tensors.pred)
]:
batch_images = batch_dimensions_to_colors_list(
outputs,
dimension_colors_with_unknown
)
batch_images_excluding_unknown = (
remove_last(batch_images)
if has_unknown_class
else batch_images
)
for i, (batch_image, dimension_label) in enumerate(zip(
batch_images, dimension_labels_with_unknown)):
suffix = "_{}_{}".format(
i, dimension_label if dimension_label else 'unknown_label'
)
add_simple_summary_image(
tensors, name + suffix, batch_image
)
with tf.name_scope(name + "_combined"):
combined_image = combine_image(batch_images_excluding_unknown)
if name == 'outputs':
tensors.summaries['output_image'] = combined_image
add_summary_image(
tensors,
name + "_combined",
combined_image
)
if name == 'outputs':
with tf.name_scope(name + "_most_likely"):
add_summary_image(
tensors,
name + "_most_likely",
batch_dimensions_to_most_likely_colors_list(
outputs,
dimension_colors_with_unknown)
)
else:
add_simple_summary_image(
tensors,
"output",
tensors.pred
)
if tensors.outputs_class_indices is not None:
outputs = tensors.pred
with tf.name_scope("outputs_most_likely"):
colors_tensor = tf.constant(
dimension_colors_with_unknown,
dtype=tf.uint8, name='colors'
)
add_summary_image(
tensors,
"outputs_most_likely",
tf.gather(
params=colors_tensor,
indices=tensors.outputs_class_indices
)
)
tensors.summaries['output_image'] = tensors.image_tensors['output']
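
# Wraps the pix2pix graph behind the interface expected by the generic
# trainer task (build_train_graph / build_eval_graph / format_metric_values).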
class Model(object):
def __init__(self, args):
self.args = args
self.image_width = 256
self.image_height = 256
self.color_map = None
self.dimension_colors = None
self.dimension_labels = None
self.use_unknown_class = args.use_unknown_class
self.use_separate_channels = args.use_separate_channels and self.args.color_map is not None
logger = get_logger()
logger.info('use_separate_channels: %s', self.use_separate_channels)
if self.args.color_map:
color_map_config = ConfigParser()
with FileIO(self.args.color_map, 'r') as config_f:
color_map_config.readfp(config_f)  # deprecated alias of read_file, kept for Python 2 compatibility
self.color_map = parse_color_map_from_configparser(color_map_config)
color_label_map = {
(int(k), int(k), int(k)): v
for k, v in color_map_config.items('color_labels')
}
sorted_keys = sorted(six.iterkeys(self.color_map))
self.dimension_colors = [self.color_map[k] for k in sorted_keys]
self.dimension_labels = [color_label_map.get(k) for k in sorted_keys]
logger.debug("dimension_colors: %s", self.dimension_colors)
logger.debug("dimension_labels: %s", self.dimension_labels)
if self.use_unknown_class or not self.dimension_colors:
self.dimension_labels_with_unknown = self.dimension_labels + ['unknown']
self.dimension_colors_with_unknown = self.dimension_colors + [(255, 255, 255)]
else:
self.dimension_labels_with_unknown = self.dimension_labels
self.dimension_colors_with_unknown = self.dimension_colors
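
# Builds the complete graph: input pipeline, pix2pix model, evaluation ops
# and summaries; returns the tensors collected in a GraphReferences.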
def build_graph(self, data_paths, batch_size, graph_mode):
logger = get_logger()
logger.debug('batch_size: %s', batch_size)
tensors = GraphReferences()
is_training = (
graph_mode == GraphMode.TRAIN or
graph_mode == GraphMode.EVALUATE
)
if data_paths:
tensors.keys, tensors.examples = read_examples(
data_paths,
shuffle=(graph_mode == GraphMode.TRAIN),
num_epochs=None if is_training else 2
)
else:
tensors.examples = tf.placeholder(tf.string, name='input', shape=(None,))
with tf.name_scope('inputs'):
feature_map = {
'input_uri':
tf.FixedLenFeature(
shape=[], dtype=tf.string, default_value=['']
),
'annotation_uri':
tf.FixedLenFeature(
shape=[], dtype=tf.string, default_value=['']
),
'input_image':
tf.FixedLenFeature(
shape=[], dtype=tf.string
),
'annotation_image':
tf.FixedLenFeature(
shape=[], dtype=tf.string
)
}
logger.info('tensors.examples: %s', tensors.examples)
parsed = tf.parse_single_example(tensors.examples, features=feature_map)
tensors.image_tensors = {}
tensors.input_uri = tf.squeeze(parsed['input_uri'])
tensors.annotation_uri = tf.squeeze(parsed['annotation_uri'])
raw_input_image = tf.squeeze(parsed['input_image'])
logger.info('raw_input_image: %s', raw_input_image)
raw_annotation_image = tf.squeeze(parsed['annotation_image'])
tensors.image_tensor = tf.image.decode_png(raw_input_image, channels=3)
tensors.annotation_tensor = tf.image.decode_png(raw_annotation_image, channels=3)
# TODO resize_images and tf.cast did not work on input image
# but did work on annotation image
tensors.image_tensor = tf.image.resize_image_with_crop_or_pad(
tensors.image_tensor, self.image_height, self.image_width
)
tensors.image_tensor = tf.image.convert_image_dtype(tensors.image_tensor, tf.float32)
tensors.annotation_tensor = tf.image.resize_image_with_crop_or_pad(
tensors.annotation_tensor, self.image_height, self.image_width
)
if self.use_separate_channels:
tensors.separate_channel_annotation_tensor = colors_to_dimensions(
tensors.annotation_tensor,
self.dimension_colors,
use_unknown_class=self.use_unknown_class
)
else:
tensors.annotation_tensor = tf.image.convert_image_dtype(tensors.annotation_tensor, tf.float32)
tensors.separate_channel_annotation_tensor = tensors.annotation_tensor
(
tensors.input_uri,
tensors.annotation_uri,
tensors.image_tensor,
tensors.annotation_tensor,
tensors.separate_channel_annotation_tensor
) = tf.train.batch(
[
tensors.input_uri,
tensors.annotation_uri,
tensors.image_tensor,
tensors.annotation_tensor,
tensors.separate_channel_annotation_tensor
],
batch_size=batch_size
)
pix2pix_model = create_pix2pix_model(
tensors.image_tensor,
tensors.separate_channel_annotation_tensor,
self.args
)
if self.use_separate_channels:
with tf.name_scope("evaluation"):
tensors.output_layer_labels = tf.constant(self.dimension_labels)
evaluation_result = evaluate_separate_channels(
targets=pix2pix_model.targets,
outputs=pix2pix_model.outputs,
has_unknown_class=self.use_unknown_class
)
tensors.evaluation_result = evaluation_result
evaluation_summary(evaluation_result, self.dimension_labels)
else:
with tf.name_scope('evaluation'):
if self.dimension_colors:
tensors.output_layer_labels = tf.constant(self.dimension_labels)
colors_tensor = tf.constant(
self.dimension_colors_with_unknown,
dtype=tf.float32
) / 255.0
tensors.outputs_class_indices = find_nearest_centroid_indices(
predictions=pix2pix_model.outputs,
centroids=colors_tensor
)
tensors.targets_class_indices = find_nearest_centroid_indices(
predictions=pix2pix_model.targets,
centroids=colors_tensor
)
evaluation_result = evaluate_predictions(
labels=tensors.targets_class_indices,
predictions=tensors.outputs_class_indices,
n_classes=len(self.dimension_colors_with_unknown),
has_unknown_class=self.use_unknown_class
)
tensors.evaluation_result = evaluation_result
evaluation_summary(evaluation_result, self.dimension_labels)
tensors.global_step = pix2pix_model.global_step
tensors.train = pix2pix_model.train
tensors.class_labels_tensor = tensors.annotation_tensor
tensors.pred = pix2pix_model.outputs
tensors.probabilities = pix2pix_model.outputs
tensors.metric_values = [pix2pix_model.discrim_loss]
add_model_summary_images(
tensors,
self.dimension_colors,
self.dimension_labels,
use_separate_channels=self.use_separate_channels,
has_unknown_class=self.use_unknown_class
)
# tensors.summaries = create_summaries(pix2pix_model)
create_other_summaries(pix2pix_model)
tensors.summary = tf.summary.merge_all()
return tensors
def build_train_graph(self, data_paths, batch_size):
return self.build_graph(data_paths, batch_size, GraphMode.TRAIN)
def build_eval_graph(self, data_paths, batch_size):
return self.build_graph(data_paths, batch_size, GraphMode.EVALUATE)
def initialize(self, session):
pass
def format_metric_values(self, metric_values):
"""Formats metric values - used for logging purpose."""
# Early in training, metric_values may actually be None.
loss_str = 'N/A'
accuracy_str = 'N/A'
try:
loss_str = '%.3f' % metric_values[0]
accuracy_str = '%.3f' % metric_values[1]
except (TypeError, IndexError):
pass
return '%s, %s' % (loss_str, accuracy_str)
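
# argparse-friendly boolean parser; a plain `type=bool` would treat any
# non-empty string (including 'false') as True.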
def str_to_bool(s):
return s.lower() in ('yes', 'true', '1')
def model_args_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--ngf", type=int, default=64, help="number of generator filters in first conv layer")
parser.add_argument("--ndf", type=int, default=64, help="number of discriminator filters in first conv layer")
parser.add_argument("--lr", type=float, default=0.0002, help="initial learning rate for adam")
parser.add_argument("--beta1", type=float, default=0.5, help="momentum term of adam")
parser.add_argument("--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient")
parser.add_argument("--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient")
parser.add_argument(
'--color_map',
type=str,
help='The path to the color map configuration.'
)
parser.add_argument(
'--use_unknown_class',
type=str_to_bool,
default=True,
help='Use an unknown class channel (if a color map is provided)'
)
parser.add_argument(
'--use_separate_channels',
type=str_to_bool,
default=False,
help='Use separate output channels per annotation class (if a color map is provided)'
)
parser.add_argument(
'--use_separate_discriminator_channels',
type=str_to_bool,
default=False,
help='Use separate discriminator channels per annotation class (if a color map is provided)'
)
parser.add_argument(
'--base_loss',
type=str,
default=BaseLoss.L1,
choices=[BaseLoss.L1, BaseLoss.CROSS_ENTROPY],
help='The base loss function to use'
)
return parser
def create_model(argv=None):
"""Factory method that creates model to be used by generic task.py."""
parser = model_args_parser()
args, task_args = parser.parse_known_args(argv)
return Model(args), task_args
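
# Usage sketch (the paths and flag values below are hypothetical, assuming
# TFRecord files with 'input_image' / 'annotation_image' features):
#
#   model, task_args = create_model([
#       '--color_map=path/to/color_map.conf',
#       '--use_separate_channels=true'
#   ])
#   tensors = model.build_train_graph(['path/to/train.tfrecord'], batch_size=4)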