from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
import logging

from six.moves import reduce

import tensorflow as tf

from tensorflow.python.lib.io.file_io import FileIO  # pylint: disable=E0611

from sciencebeam_gym.trainer.util import (
  read_examples
)
from sciencebeam_gym.preprocess.color_map import (
  parse_color_map_from_file
)
from sciencebeam_gym.trainer.models.pix2pix.tf_utils import (
  find_nearest_centroid_indices
)
from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core import (
  BaseLoss,
  create_pix2pix_model,
  create_other_summaries
)
from sciencebeam_gym.trainer.models.pix2pix.evaluate import (
  evaluate_separate_channels,
  evaluate_predictions,
  evaluation_summary
)
class GraphMode(object):
TRAIN = 1
EVALUATE = 2
PREDICT = 3
def get_logger():
return logging.getLogger(__name__)
class GraphReferences(object):
"""Holder of base tensors used for training model using common task."""
def __init__(self):
self.examples = None
self.train = None
self.global_step = None
self.metric_updates = []
self.metric_values = []
self.keys = None
self.predictions = []
self.input_jpeg = None
self.input_uri = None
self.image_tensor = None
self.annotation_uri = None
self.annotation_tensor = None
self.separate_channel_annotation_tensor = None
self.class_labels_tensor = None
self.pred = None
self.probabilities = None
self.summary = None
self.summaries = None
self.image_tensors = None
self.targets_class_indices = None
self.outputs_class_indices = None
def colors_to_dimensions(image_tensor, colors, use_unknown_class=False):
with tf.variable_scope("colors_to_dimensions"):
single_label_tensors = []
ones = tf.fill(image_tensor.shape[0:-1], 1.0, name='ones')
zeros = tf.fill(ones.shape, 0.0, name='zeros')
for single_label_color in colors:
i = len(single_label_tensors)
with tf.variable_scope("channel_{}".format(i)):
is_color = tf.reduce_all(
tf.equal(image_tensor, single_label_color),
axis=-1,
name='is_color'
)
single_label_tensor = tf.where(
is_color,
ones,
zeros
)
single_label_tensors.append(single_label_tensor)
if use_unknown_class:
with tf.variable_scope("unknown_class"):
single_label_tensors.append(
tf.where(
tf.add_n(single_label_tensors) < 0.5,
ones,
zeros
)
)
return tf.stack(single_label_tensors, axis=-1)
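# Illustrative sketch, not part of the original module: how colors_to_dimensions
# turns an RGB annotation into one float channel per color, plus an optional
# catch-all channel. Assumes TF 1.x graph mode; the colors are hypothetical.
def _example_colors_to_dimensions():
  annotation = tf.constant(
    [[[255, 0, 0], [0, 255, 0]],
     [[0, 0, 255], [255, 0, 0]]],
    dtype=tf.uint8
  )
  channels = colors_to_dimensions(
    annotation, [(255, 0, 0), (0, 255, 0)], use_unknown_class=True
  )
  # channels has shape (2, 2, 3): a red channel, a green channel, and an
  # unknown channel that is 1.0 wherever no listed color matched (the blue pixel)
  with tf.Session() as session:
    return session.run(channels)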
def batch_dimensions_to_colors_list(image_tensor, colors):
batch_images = []
for i, single_label_color in enumerate(colors):
batch_images.append(
tf.expand_dims(
image_tensor[:, :, :, i],
axis=-1
) * ([x / 255.0 for x in single_label_color])
)
return batch_images
def batch_dimensions_to_most_likely_colors_list(image_tensor, colors):
with tf.variable_scope("batch_dimensions_to_most_likely_colors_list"):
colors_tensor = tf.constant(colors, dtype=tf.uint8, name='colors')
most_likely_class_index = tf.argmax(image_tensor, 3)
return tf.gather(params=colors_tensor, indices=most_likely_class_index)
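# Illustrative sketch, not in the original module: mapping per-channel scores
# back to the most likely color per pixel. The scores and colors are hypothetical.
def _example_most_likely_colors():
  # batch of one 1x2 image with two channels: pixel 0 favours channel 0,
  # pixel 1 favours channel 1
  scores = tf.constant([[[[0.9, 0.1], [0.2, 0.8]]]])
  colors = [(255, 0, 0), (0, 255, 0)]
  most_likely = batch_dimensions_to_most_likely_colors_list(scores, colors)
  with tf.Session() as session:
    # -> [[[[255, 0, 0], [0, 255, 0]]]]
    return session.run(most_likely)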
def add_summary_image(tensors, name, image):
tensors.image_tensors[name] = image
tf.summary.image(name, image)
def convert_image(image_tensor):
return tf.image.convert_image_dtype(
image_tensor,
dtype=tf.uint8,
saturate=True
)
def add_simple_summary_image(tensors, name, image_tensor):
with tf.name_scope(name):
add_summary_image(
tensors,
name,
convert_image(image_tensor)
)
def replace_black_with_white_color(image_tensor):
is_black = tf.reduce_all(
tf.equal(image_tensor, (0, 0, 0)),
axis=-1
)
is_black = tf.stack([is_black] * 3, axis=-1)
return tf.where(
is_black,
255 * tf.ones_like(image_tensor),
image_tensor
)
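# Illustrative sketch, not part of the original module (TF 1.x): black pixels
# become white, all other pixels are kept unchanged.
def _example_replace_black_with_white():
  image = tf.constant([[[0, 0, 0], [10, 20, 30]]], dtype=tf.uint8)
  with tf.Session() as session:
    # -> [[[255, 255, 255], [10, 20, 30]]]
    return session.run(replace_black_with_white_color(image))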
def combine_image(batch_images, replace_black_with_white=False):
  clipped_batch_images = [
    tf.clip_by_value(batch_image, 0.0, 1.0)
    for batch_image in batch_images
  ]
  combined_image = convert_image(
    reduce(
      lambda a, b: a + b,
      clipped_batch_images
    )
  )
  if replace_black_with_white:
    combined_image = replace_black_with_white_color(combined_image)
  return combined_image
def remove_last(a):
return a[:-1]
def add_model_summary_images(
tensors, dimension_colors, dimension_labels,
use_separate_channels=False,
has_unknown_class=False):
tensors.summaries = {}
add_simple_summary_image(
tensors, 'input', tensors.image_tensor
)
add_simple_summary_image(
tensors, 'target', tensors.annotation_tensor
)
if (has_unknown_class or not use_separate_channels) and dimension_labels is not None:
dimension_labels_with_unknown = dimension_labels + ['unknown']
dimension_colors_with_unknown = dimension_colors + [(255, 255, 255)]
else:
dimension_labels_with_unknown = dimension_labels
dimension_colors_with_unknown = dimension_colors
if use_separate_channels:
for name, outputs in [
('targets', tensors.separate_channel_annotation_tensor),
('outputs', tensors.pred)
]:
      batch_images = batch_dimensions_to_colors_list(
        outputs,
        dimension_colors_with_unknown
      )
batch_images_excluding_unknown = (
remove_last(batch_images)
if has_unknown_class
else batch_images
)
for i, (batch_image, dimension_label) in enumerate(zip(
batch_images, dimension_labels_with_unknown)):
suffix = "_{}_{}".format(
i, dimension_label if dimension_label else 'unknown_label'
)
add_simple_summary_image(
tensors, name + suffix, batch_image
)
with tf.name_scope(name + "_combined"):
combined_image = combine_image(batch_images_excluding_unknown)
if name == 'outputs':
tensors.summaries['output_image'] = combined_image
add_summary_image(
tensors,
name + "_combined",
combined_image
)
if name == 'outputs':
with tf.name_scope(name + "_most_likely"):
add_summary_image(
tensors,
name + "_most_likely",
batch_dimensions_to_most_likely_colors_list(
outputs,
dimension_colors_with_unknown)
)
else:
add_simple_summary_image(
tensors,
"output",
tensors.pred
)
if tensors.outputs_class_indices is not None:
outputs = tensors.pred
with tf.name_scope("outputs_most_likely"):
colors_tensor = tf.constant(
dimension_colors_with_unknown,
dtype=tf.uint8, name='colors'
)
add_summary_image(
tensors,
"outputs_most_likely",
tf.gather(
params=colors_tensor,
indices=tensors.outputs_class_indices
)
)
tensors.summaries['output_image'] = tensors.image_tensors['output']
def parse_json_file(filename):
with FileIO(filename, 'r') as f:
return json.load(f)
def class_weights_to_pos_weight(class_weights, labels, use_unknown_class):
pos_weight = [class_weights[k] for k in labels]
return pos_weight + [0.0] if use_unknown_class else pos_weight
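# Illustrative sketch with hypothetical labels and weights: the weights file maps
# label -> weight, and the optional unknown class is appended with weight 0.0 so
# it does not contribute to the weighted loss.
def _example_class_weights_to_pos_weight():
  assert class_weights_to_pos_weight(
    {'title': 1.0, 'author': 2.5}, ['title', 'author'], True
  ) == [1.0, 2.5, 0.0]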
def parse_color_map(color_map_filename):
with FileIO(color_map_filename, 'r') as config_f:
return parse_color_map_from_file(
config_f
)
def color_map_to_labels(color_map, labels=None):
if labels:
    if not all(k in color_map for k in labels):
      raise ValueError(
        'not all labels found in color map, labels=%s, available keys=%s' %
(labels, color_map.keys())
)
return labels
return sorted(color_map.keys())
def color_map_to_colors(color_map, labels):
return [color_map[k] for k in labels]
def color_map_to_colors_and_labels(color_map, labels=None):
labels = color_map_to_labels(color_map, labels)
colors = color_map_to_colors(color_map, labels)
return colors, labels
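# Illustrative sketch with a hypothetical in-memory color map: when no labels are
# passed, the sorted color map keys are used, and colors follow the label order.
def _example_color_map_to_colors_and_labels():
  color_map = {'title': (255, 0, 0), 'author': (0, 255, 0)}
  colors, labels = color_map_to_colors_and_labels(color_map)
  assert labels == ['author', 'title']
  assert colors == [(0, 255, 0), (255, 0, 0)]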
def parse_color_map_to_colors_and_labels(color_map_filename, labels=None):
color_map = parse_color_map(color_map_filename)
return color_map_to_colors_and_labels(color_map, labels)
UNKNOWN_COLOR = (255, 255, 255)
UNKNOWN_LABEL = 'unknown'
def colors_and_labels_with_unknown_class(colors, labels, use_unknown_class):
if use_unknown_class or not colors:
return (
colors + [UNKNOWN_COLOR],
labels + [UNKNOWN_LABEL]
)
else:
return colors, labels
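# Illustrative sketch with hypothetical colors: when use_unknown_class is set
# (or no colors were configured), a catch-all white 'unknown' entry is appended.
def _example_colors_and_labels_with_unknown_class():
  colors, labels = colors_and_labels_with_unknown_class(
    [(255, 0, 0)], ['title'], True
  )
  assert colors == [(255, 0, 0), UNKNOWN_COLOR]
  assert labels == ['title', UNKNOWN_LABEL]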
class Model(object):
def __init__(self, args):
self.args = args
self.image_width = 256
self.image_height = 256
self.color_map = None
self.dimension_colors = None
self.dimension_labels = None
    self.dimension_colors_with_unknown = None
    self.dimension_labels_with_unknown = None
    self.pos_weight = None
    self.use_unknown_class = args.use_unknown_class
    self.use_separate_channels = args.use_separate_channels and self.args.color_map is not None
    logger = get_logger()
    logger.info('use_separate_channels: %s', self.use_separate_channels)
    if self.args.color_map:
      self.dimension_colors, self.dimension_labels = parse_color_map_to_colors_and_labels(
        args.color_map,
        args.channels
      )
      self.dimension_colors_with_unknown, self.dimension_labels_with_unknown = (
        colors_and_labels_with_unknown_class(
          self.dimension_colors,
          self.dimension_labels,
          self.use_unknown_class
        )
      )
      logger.debug('dimension_colors: %s', self.dimension_colors)
      logger.debug('dimension_labels: %s', self.dimension_labels)
      if self.args.class_weights:
        self.pos_weight = class_weights_to_pos_weight(
          parse_json_file(self.args.class_weights),
          self.dimension_labels,
          self.use_unknown_class
        )
        logger.info('pos_weight: %s', self.pos_weight)
def build_graph(self, data_paths, batch_size, graph_mode):
logger = get_logger()
logger.debug('batch_size: %s', batch_size)
tensors = GraphReferences()
is_training = (
graph_mode == GraphMode.TRAIN or
graph_mode == GraphMode.EVALUATE
)
if data_paths:
get_logger().info('reading examples from %s', data_paths)
tensors.keys, tensors.examples = read_examples(
data_paths,
shuffle=(graph_mode == GraphMode.TRAIN),
num_epochs=None if is_training else 2
)
else:
tensors.examples = tf.placeholder(tf.string, name='input', shape=(None,))
with tf.name_scope('inputs'):
feature_map = {
'input_uri':
tf.FixedLenFeature(
shape=[], dtype=tf.string, default_value=['']
),
'annotation_uri':
tf.FixedLenFeature(
shape=[], dtype=tf.string, default_value=['']
),
'input_image':
tf.FixedLenFeature(
shape=[], dtype=tf.string
),
'annotation_image':
tf.FixedLenFeature(
shape=[], dtype=tf.string
)
}
logging.info('tensors.examples: %s', tensors.examples)
parsed = tf.parse_single_example(tensors.examples, features=feature_map)
tensors.image_tensors = {}
tensors.input_uri = tf.squeeze(parsed['input_uri'])
tensors.annotation_uri = tf.squeeze(parsed['annotation_uri'])
raw_input_image = tf.squeeze(parsed['input_image'])
logging.info('raw_input_image: %s', raw_input_image)
raw_annotation_image = tf.squeeze(parsed['annotation_image'])
tensors.image_tensor = tf.image.decode_png(raw_input_image, channels=3)
tensors.annotation_tensor = tf.image.decode_png(raw_annotation_image, channels=3)
# TODO resize_images and tf.cast did not work on input image
# but did work on annotation image
tensors.image_tensor = tf.image.resize_image_with_crop_or_pad(
tensors.image_tensor, self.image_height, self.image_width
)
tensors.image_tensor = tf.image.convert_image_dtype(tensors.image_tensor, tf.float32)
tensors.annotation_tensor = tf.image.resize_image_with_crop_or_pad(
tensors.annotation_tensor, self.image_height, self.image_width
)
if self.use_separate_channels:
tensors.separate_channel_annotation_tensor = colors_to_dimensions(
tensors.annotation_tensor,
self.dimension_colors,
use_unknown_class=self.use_unknown_class
)
else:
tensors.annotation_tensor = tf.image.convert_image_dtype(tensors.annotation_tensor, tf.float32)
tensors.separate_channel_annotation_tensor = tensors.annotation_tensor
(
tensors.input_uri,
tensors.annotation_uri,
tensors.image_tensor,
tensors.annotation_tensor,
tensors.separate_channel_annotation_tensor
) = tf.train.batch(
[
tensors.input_uri,
tensors.annotation_uri,
tensors.image_tensor,
tensors.annotation_tensor,
tensors.separate_channel_annotation_tensor
],
batch_size=batch_size
)
    pix2pix_model = create_pix2pix_model(
      tensors.image_tensor,
      tensors.separate_channel_annotation_tensor,
      self.args,
      is_training=is_training,
      pos_weight=self.pos_weight
    )
    if self.use_separate_channels:
      tensors.output_layer_labels = tf.constant(self.dimension_labels)
      evaluation_result = evaluate_separate_channels(
        targets=pix2pix_model.targets,
        outputs=pix2pix_model.outputs,
        has_unknown_class=self.use_unknown_class
      )
      tensors.evaluation_result = evaluation_result
      evaluation_summary(evaluation_result, self.dimension_labels)
else:
with tf.name_scope('evaluation'):
if self.dimension_colors:
tensors.output_layer_labels = tf.constant(self.dimension_labels)
colors_tensor = tf.constant(
self.dimension_colors_with_unknown,
dtype=tf.float32
) / 255.0
tensors.outputs_class_indices = find_nearest_centroid_indices(
predictions=pix2pix_model.outputs,
centroids=colors_tensor
)
tensors.targets_class_indices = find_nearest_centroid_indices(
predictions=pix2pix_model.targets,
centroids=colors_tensor
)
evaluation_result = evaluate_predictions(
labels=tensors.targets_class_indices,
predictions=tensors.outputs_class_indices,
n_classes=len(self.dimension_colors_with_unknown),
has_unknown_class=self.use_unknown_class
)
tensors.evaluation_result = evaluation_result
evaluation_summary(evaluation_result, self.dimension_labels)
tensors.global_step = pix2pix_model.global_step
tensors.train = pix2pix_model.train
tensors.class_labels_tensor = tensors.annotation_tensor
tensors.pred = pix2pix_model.outputs
tensors.probabilities = pix2pix_model.outputs
tensors.metric_values = [pix2pix_model.discrim_loss]
add_model_summary_images(
tensors,
self.dimension_colors,
self.dimension_labels,
use_separate_channels=self.use_separate_channels,
has_unknown_class=self.use_unknown_class
)
# tensors.summaries = create_summaries(pix2pix_model)
create_other_summaries(pix2pix_model)
tensors.summary = tf.summary.merge_all()
return tensors
def build_train_graph(self, data_paths, batch_size):
return self.build_graph(data_paths, batch_size, GraphMode.TRAIN)
def build_eval_graph(self, data_paths, batch_size):
return self.build_graph(data_paths, batch_size, GraphMode.EVALUATE)
def initialize(self, session):
pass
def format_metric_values(self, metric_values):
"""Formats metric values - used for logging purpose."""
# Early in training, metric_values may actually be None.
loss_str = 'N/A'
accuracy_str = 'N/A'
try:
loss_str = '%.3f' % metric_values[0]
accuracy_str = '%.3f' % metric_values[1]
except (TypeError, IndexError):
pass
return '%s, %s' % (loss_str, accuracy_str)
def str_to_bool(s):
return s.lower() in ('yes', 'true', '1')
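# Illustrative sketch, not part of the original module: the accepted truthy
# spellings for boolean CLI flags; anything else parses as False.
def _example_str_to_bool():
  assert str_to_bool('TRUE') and str_to_bool('yes') and str_to_bool('1')
  assert not str_to_bool('false')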
def str_to_list(s):
s = s.strip()
if not s:
return []
return [x.strip() for x in s.split(',')]
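# Illustrative sketch, not part of the original module: comma-separated CLI
# values become a list of stripped names; an empty string yields an empty list.
def _example_str_to_list():
  assert str_to_list('title, author') == ['title', 'author']
  assert str_to_list('') == []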
def model_args_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--ngf", type=int, default=64, help="number of generator filters in first conv layer"
)
parser.add_argument(
"--ndf", type=int, default=64, help="number of discriminator filters in first conv layer"
)
parser.add_argument(
"--lr", type=float, default=0.0002, help="initial learning rate for adam"
)
parser.add_argument(
"--beta1", type=float, default=0.5, help="momentum term of adam"
)
parser.add_argument(
"--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient"
)
parser.add_argument(
"--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient"
)
parser.add_argument(
'--color_map',
type=str,
help='The path to the color map configuration.'
)
parser.add_argument(
'--class_weights',
type=str,
help='The path to the class weights configuration.'
)
parser.add_argument(
'--channels',
type=str_to_list,
help='The channels to use (subset of color map), otherwise all of the labels will be used'
)
  parser.add_argument(
    '--use_unknown_class',
    type=str_to_bool,
    default=True,
    help='Use unknown class channel (if color map is provided)'
  )
parser.add_argument(
'--use_separate_channels',
type=str_to_bool,
default=False,
help='The separate output channels per annotation (if color map is provided)'
)
parser.add_argument(
'--use_separate_discriminator_channels',
type=str_to_bool,
default=False,
help='The separate discriminator channels per annotation (if color map is provided)'
)
parser.add_argument(
'--use_separate_discriminators',
type=str_to_bool,
default=False,
help='The separate discriminators per annotation (if color map is provided)'
)
parser.add_argument(
'--base_loss',
type=str,
default=BaseLoss.L1,
help='The base loss function to use'
)
return parser
def create_model(argv=None):
"""Factory method that creates model to be used by generic task.py."""
parser = model_args_parser()
args, task_args = parser.parse_known_args(argv)
return Model(args), task_args
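# Illustrative usage sketch, not part of the original module; the flag values
# and paths below are hypothetical:
#
#   model, task_args = create_model([
#     '--color_map=path/to/color_map.conf',
#     '--use_separate_channels=true'
#   ])
#   tensors = model.build_train_graph(['path/to/train-*.tfrecord'], batch_size=10)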