From b027e5649dacf7d862f14131e4323ab8f1279414 Mon Sep 17 00:00:00 2001
From: Daniel Ecer <de-code@users.noreply.github.com>
Date: Thu, 3 Aug 2017 16:37:16 +0100
Subject: [PATCH] separate discriminator output channels by blanking out
 channels

---
 .../trainer/models/pix2pix/pix2pix_core.py  | 79 ++++++++++++++++---
 .../trainer/models/pix2pix/pix2pix_model.py |  6 ++
 .../trainer/models/pix2pix/tf_utils.py      | 18 +++++
 .../trainer/models/pix2pix/tf_utils_test.py | 40 +++++++++-
 4 files changed, 133 insertions(+), 10 deletions(-)

diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py
index 8f4c61a..93b489b 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py
@@ -9,6 +9,10 @@ import logging
 import tensorflow as tf
 import collections
 
+from sciencebeam_gym.trainer.models.pix2pix.tf_utils import (
+  blank_other_channels
+)
+
 EPS = 1e-12
 
 class BaseLoss(object):
@@ -151,7 +155,7 @@ def create_generator(generator_inputs, generator_outputs_channels, a):
 
     return layers[-1]
 
-def create_discriminator(discrim_inputs, discrim_targets, a):
+def create_discriminator(discrim_inputs, discrim_targets, a, out_channels=1):
   n_layers = 3
   layers = []
 
@@ -169,21 +173,57 @@ def create_discriminator(discrim_inputs, discrim_targets, a):
   # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
   for i in range(n_layers):
     with tf.variable_scope("layer_%d" % (len(layers) + 1)):
-      out_channels = a.ndf * min(2**(i+1), 8)
+      layer_out_channels = a.ndf * min(2**(i+1), 8)
       stride = 1 if i == n_layers - 1 else 2  # last layer here has stride 1
-      convolved = conv(layers[-1], out_channels, stride=stride)
+      convolved = conv(layers[-1], layer_out_channels, stride=stride)
       normalized = batchnorm(convolved)
       rectified = lrelu(normalized, 0.2)
       layers.append(rectified)
 
   # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
   with tf.variable_scope("layer_%d" % (len(layers) + 1)):
-    convolved = conv(rectified, out_channels=1, stride=1)
+    convolved = conv(rectified, out_channels=out_channels, stride=1)
     output = tf.sigmoid(convolved)
     layers.append(output)
 
   return layers[-1]
 
+def create_separate_channel_discriminator_by_blanking_out_channels(inputs, targets, a):
+  # We need to teach the discriminator to detect the real channels
+  # by letting it look at only one channel at a time.
+  # For each channel:
+  # - let the discriminator see only the current channel, blanking out all other channels
+  # - expect the output for the channel that was not blanked out to be classified as real
+  n_targets_channels = int(targets.shape[-1])
+  predict_real_channels = []
+  predict_real_blanked_list = []
+  for i in range(n_targets_channels):
+    masked_targets = blank_other_channels(
+      targets,
+      i
+    )
+    with tf.variable_scope("discriminator", reuse=(i > 0)):
+      # 2x [batch, height, width, channels] => [batch, 30, 30, n_targets_channels]
+      predict_real_i = create_discriminator(
+        inputs, masked_targets, a,
+        out_channels=n_targets_channels
+      )
+      predict_real_channels.append(predict_real_i[:, :, :, i])
+      for j in range(n_targets_channels):
+        if j != i:
+          predict_real_blanked_list.append(predict_real_i[:, :, :, j])
+  predict_real = tf.stack(
+    predict_real_channels,
+    axis=-1,
+    name='predict_real'
+  )
+  predict_real_blanked = tf.stack(
+    predict_real_blanked_list,
+    axis=-1,
+    name='predict_real_blanked'
+  )
+  return predict_real, predict_real_blanked
+
 def create_pix2pix_model(inputs, targets, a):
@@ -196,23 +236,44 @@ def create_pix2pix_model(inputs, targets, a):
   else:
     outputs = tf.tanh(outputs)
 
+  targets_channels = int(targets.shape[-1])
+  discrim_out_channels = (
+    targets_channels
+    if a.use_separate_discriminator_channels
+    else 1
+  )
+  get_logger().info('discrim_out_channels: %s', discrim_out_channels)
+
   # create two copies of discriminator, one for real pairs and one for fake pairs
   # they share the same underlying variables
   with tf.name_scope("real_discriminator"):
-    with tf.variable_scope("discriminator"):
-      # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
-      predict_real = create_discriminator(inputs, targets, a)
+    if discrim_out_channels > 1:
+      predict_real, predict_real_blanked = (
+        create_separate_channel_discriminator_by_blanking_out_channels(
+          inputs, targets, a
+        )
+      )
+    else:
+      with tf.variable_scope("discriminator"):
+        # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
+        predict_real = create_discriminator(inputs, targets, a)
 
   with tf.name_scope("fake_discriminator"):
     with tf.variable_scope("discriminator", reuse=True):
-      # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
-      predict_fake = create_discriminator(inputs, outputs, a)
+      # 2x [batch, height, width, channels] => [batch, 30, 30, discrim_out_channels]
+      # We don't need to split the channels; the discriminator should detect them all as fake
+      predict_fake = create_discriminator(
+        inputs, outputs, a,
+        out_channels=discrim_out_channels
+      )
 
   with tf.name_scope("discriminator_loss"):
     # minimizing -tf.log will try to get inputs to 1
     # predict_real => 1
     # predict_fake => 0
     discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))
+    if discrim_out_channels > 1:
+      discrim_loss += tf.reduce_mean(-tf.log(1 - tf.reshape(predict_real_blanked, [-1]) + EPS))
 
   with tf.name_scope("generator_loss"):
     # predict_fake => 1
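
Note on the loss change above: a minimal NumPy sketch (illustrative only; the
function name below is made up and is not part of the patch) of how the
per-channel discriminator loss combines the three predictions. Real pairs are
pushed towards 1, fake pairs towards 0, and the discriminator outputs for the
blanked-out channels are also pushed towards 0 ("fake"):

import numpy as np

EPS = 1e-12

def sketch_discrim_loss(predict_real, predict_fake, predict_real_blanked):
  # mirrors discrim_loss in pix2pix_core.py: -log(x) drives x towards 1,
  # -log(1 - x) drives x towards 0; all inputs are sigmoid outputs in (0, 1)
  loss = np.mean(-(np.log(predict_real + EPS) + np.log(1 - predict_fake + EPS)))
  loss += np.mean(-np.log(1 - predict_real_blanked + EPS))
  return loss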
diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py
index 00ba381..d155524 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py
@@ -488,6 +488,12 @@ def model_args_parser():
     default=False,
     help='The separate output channels per annotation (if color map is provided)'
   )
+  parser.add_argument(
+    '--use_separate_discriminator_channels',
+    type=str_to_bool,
+    default=False,
+    help='Use separate discriminator channels per annotation (if color map is provided)'
+  )
   parser.add_argument(
     '--base_loss',
     type=str,
diff --git a/sciencebeam_gym/trainer/models/pix2pix/tf_utils.py b/sciencebeam_gym/trainer/models/pix2pix/tf_utils.py
index 192f82c..ded5806 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/tf_utils.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/tf_utils.py
@@ -24,3 +24,21 @@ def find_nearest_centroid(predictions, centroids):
     params=centroids,
     indices=find_nearest_centroid_indices(predictions, centroids)
   )
+
+def blank_other_channels(tensor, keep_index):
+  tensor_shape = tensor.shape
+  n_channels = int(tensor_shape[-1])
+  rank = len(tensor_shape)
+  tensor_slice = tf.slice(
+    tensor,
+    begin=[0] * (rank - 1) + [keep_index],
+    size=[-1] * (rank - 1) + [1]
+  )
+  paddings = tf.constant(
+    [[0, 0]] * (rank - 1) +
+    [[keep_index, n_channels - keep_index - 1]]
+  )
+  padded = tf.pad(
+    tensor_slice, paddings, "CONSTANT"
+  )
+  return padded
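
Note on blank_other_channels above: it keeps only the channel at keep_index by
slicing it out and zero-padding it back to the original channel count. An
equivalent formulation (a sketch, not part of the patch; the function name is
made up) multiplies by a one-hot mask broadcast over the last axis:

import tensorflow as tf

def blank_other_channels_via_mask(tensor, keep_index):
  # zero out every channel except keep_index
  n_channels = int(tensor.shape[-1])
  mask = tf.one_hot(keep_index, n_channels, dtype=tensor.dtype)
  return tensor * mask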
diff --git a/sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py b/sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py
index 18d39b1..27fae1a 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py
@@ -5,7 +5,8 @@ import tensorflow as tf
 import numpy as np
 
 from sciencebeam_gym.trainer.models.pix2pix.tf_utils import (
-  find_nearest_centroid
+  find_nearest_centroid,
+  blank_other_channels
 )
 
 def test_find_nearest_centroid():
@@ -47,3 +48,40 @@ def test_find_nearest_centroid1():
       ]
     ]
   )
+
+def test_blank_other_channels():
+  tensor = tf.constant([
+    [
+      [5, 5, 5, 5],
+      [6, 6, 6, 6],
+      [7, 7, 7, 7],
+      [8, 8, 8, 8]
+    ],
+    [
+      [5, 5, 5, 5],
+      [6, 6, 6, 6],
+      [7, 7, 7, 7],
+      [8, 8, 8, 8]
+    ]
+  ])
+  padded = blank_other_channels(
+    tensor, 1
+  )
+  with tf.Session() as session:
+    assert np.allclose(
+      session.run(padded),
+      [
+        [
+          [0, 5, 0, 0],
+          [0, 6, 0, 0],
+          [0, 7, 0, 0],
+          [0, 8, 0, 0]
+        ],
+        [
+          [0, 5, 0, 0],
+          [0, 6, 0, 0],
+          [0, 7, 0, 0],
+          [0, 8, 0, 0]
+        ]
+      ]
+    )
--
GitLab
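
Note: with the patch applied, the new behaviour is opt-in via
--use_separate_discriminator_channels=true (parsed by str_to_bool, as declared
in pix2pix_model.py above). The added test can be run on its own; assuming the
project's tests are run with pytest:

pytest sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py -k test_blank_other_channels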