From b027e5649dacf7d862f14131e4323ab8f1279414 Mon Sep 17 00:00:00 2001
From: Daniel Ecer <de-code@users.noreply.github.com>
Date: Thu, 3 Aug 2017 16:37:16 +0100
Subject: [PATCH] separate discriminator output channels by blanking out
 channels

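When --use_separate_discriminator_channels is enabled, the discriminator
emits one output channel per target channel. For each target channel i,
all other channels are blanked out before the real pair is fed to the
discriminator; the discriminator is then expected to classify only
channel i as real, while the blanked-out channels and every channel of
the generator output contribute to the fake side of the loss.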
---
 .../trainer/models/pix2pix/pix2pix_core.py    | 79 ++++++++++++++++---
 .../trainer/models/pix2pix/pix2pix_model.py   |  6 ++
 .../trainer/models/pix2pix/tf_utils.py        | 18 +++++
 .../trainer/models/pix2pix/tf_utils_test.py   | 40 +++++++++-
 4 files changed, 133 insertions(+), 10 deletions(-)
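
Note: a minimal NumPy sketch of the behaviour blank_other_channels
(tf_utils.py) implements via tf.slice and tf.pad, matching the
expectations in tf_utils_test.py. The helper blank_other_channels_np
is illustrative only and not part of this change.

    import numpy as np

    def blank_other_channels_np(arr, keep_index):
        # Zero every channel along the last axis except keep_index.
        out = np.zeros_like(arr)
        out[..., keep_index] = arr[..., keep_index]
        return out

    x = np.array([[5, 5, 5, 5],
                  [6, 6, 6, 6]])
    print(blank_other_channels_np(x, 1))
    # [[0 5 0 0]
    #  [0 6 0 0]]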

diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py
index 8f4c61a..93b489b 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py
@@ -9,6 +9,10 @@ import logging
 import tensorflow as tf
 import collections
 
+from sciencebeam_gym.trainer.models.pix2pix.tf_utils import (
+    blank_other_channels
+)
+
 EPS = 1e-12
 
 class BaseLoss(object):
@@ -151,7 +155,7 @@ def create_generator(generator_inputs, generator_outputs_channels, a):
 
     return layers[-1]
 
-def create_discriminator(discrim_inputs, discrim_targets, a):
+def create_discriminator(discrim_inputs, discrim_targets, a, out_channels=1):
     n_layers = 3
     layers = []
 
@@ -169,21 +173,57 @@ def create_discriminator(discrim_inputs, discrim_targets, a):
     # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
     for i in range(n_layers):
         with tf.variable_scope("layer_%d" % (len(layers) + 1)):
-            out_channels = a.ndf * min(2**(i+1), 8)
+            layer_out_channels = a.ndf * min(2**(i+1), 8)
             stride = 1 if i == n_layers - 1 else 2  # last layer here has stride 1
-            convolved = conv(layers[-1], out_channels, stride=stride)
+            convolved = conv(layers[-1], layer_out_channels, stride=stride)
             normalized = batchnorm(convolved)
             rectified = lrelu(normalized, 0.2)
             layers.append(rectified)
 
     # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
     with tf.variable_scope("layer_%d" % (len(layers) + 1)):
-        convolved = conv(rectified, out_channels=1, stride=1)
+        convolved = conv(rectified, out_channels=out_channels, stride=1)
         output = tf.sigmoid(convolved)
         layers.append(output)
 
     return layers[-1]
 
+def create_separate_channel_discriminator_by_blanking_out_channels(inputs, targets, a):
+    # Teach the discriminator to recognise each real channel by letting it
+    # see only that channel. For each channel:
+    # - blank out all other channels so the discriminator sees only the current one
+    # - expect the output for the kept channel to be classified as real,
+    #   and the outputs for the blanked-out channels to be classified as fake
+    n_targets_channels = int(targets.shape[-1])
+    predict_real_channels = []
+    predict_real_blanked_list = []
+    for i in range(n_targets_channels):
+        masked_targets = blank_other_channels(
+            targets,
+            i
+        )
+        with tf.variable_scope("discriminator", reuse=(i > 0)):
+            # 2x [batch, height, width, channels] => [batch, 30, 30, n_targets_channels]
+            predict_real_i = create_discriminator(
+                inputs, masked_targets, a,
+                out_channels=n_targets_channels
+            )
+            predict_real_channels.append(predict_real_i[:, :, :, i])
+            for j in range(n_targets_channels):
+                if j != i:
+                    predict_real_blanked_list.append(predict_real_i[:, :, :, j])
+    predict_real = tf.stack(
+        predict_real_channels,
+        axis=-1,
+        name='predict_real'
+    )
+    predict_real_blanked = tf.stack(
+        predict_real_blanked_list,
+        axis=-1,
+        name='predict_real_blanked'
+    )
+    return predict_real, predict_real_blanked
+
 
 def create_pix2pix_model(inputs, targets, a):
 
@@ -196,23 +236,44 @@ def create_pix2pix_model(inputs, targets, a):
         else:
             outputs = tf.tanh(outputs)
 
+    targets_channels = int(targets.shape[-1])
+    discrim_out_channels = (
+        targets_channels
+        if a.use_separate_discriminator_channels
+        else 1
+    )
+    get_logger().info('discrim_out_channels: %s', discrim_out_channels)
+
     # create two copies of discriminator, one for real pairs and one for fake pairs
     # they share the same underlying variables
     with tf.name_scope("real_discriminator"):
-        with tf.variable_scope("discriminator"):
-            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
-            predict_real = create_discriminator(inputs, targets, a)
+        if discrim_out_channels > 1:
+            predict_real, predict_real_blanked = (
+                create_separate_channel_discriminator_by_blanking_out_channels(
+                    inputs, targets, a
+                )
+            )
+        else:
+            with tf.variable_scope("discriminator"):
+                # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
+                predict_real = create_discriminator(inputs, targets, a)
 
     with tf.name_scope("fake_discriminator"):
         with tf.variable_scope("discriminator", reuse=True):
-            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
-            predict_fake = create_discriminator(inputs, outputs, a)
+            # 2x [batch, height, width, channels] => [batch, 30, 30, discrim_out_channels]
+            # No need to split the channels here; the discriminator should classify them all as fake
+            predict_fake = create_discriminator(
+                inputs, outputs, a,
+                out_channels=discrim_out_channels
+            )
 
     with tf.name_scope("discriminator_loss"):
         # minimizing -tf.log will try to get inputs to 1
         # predict_real => 1
         # predict_fake => 0
         discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))
+        if discrim_out_channels > 1:
+            discrim_loss += tf.reduce_mean(-tf.log(1 - tf.reshape(predict_real_blanked, [-1]) + EPS))
 
     with tf.name_scope("generator_loss"):
         # predict_fake => 1
diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py
index 00ba381..d155524 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py
@@ -488,6 +488,12 @@ def model_args_parser():
     default=False,
     help='The separate output channels per annotation (if color map is provided)'
   )
+  parser.add_argument(
+    '--use_separate_discriminator_channels',
+    type=str_to_bool,
+    default=False,
+    help='Use separate discriminator output channels per annotation (if color map is provided)'
+  )
   parser.add_argument(
     '--base_loss',
     type=str,
diff --git a/sciencebeam_gym/trainer/models/pix2pix/tf_utils.py b/sciencebeam_gym/trainer/models/pix2pix/tf_utils.py
index 192f82c..ded5806 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/tf_utils.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/tf_utils.py
@@ -24,3 +24,21 @@ def find_nearest_centroid(predictions, centroids):
     params=centroids,
     indices=find_nearest_centroid_indices(predictions, centroids)
   )
+
+def blank_other_channels(tensor, keep_index):
+  tensor_shape = tensor.shape
+  n_channels = int(tensor_shape[-1])
+  rank = len(tensor_shape)
+  tensor_slice = tf.slice(
+    tensor,
+    begin=[0] * (rank - 1) + [keep_index],
+    size=[-1] * (rank - 1) + [1]
+  )
+  paddings = tf.constant(
+    [[0, 0]] * (rank - 1) +
+    [[keep_index, n_channels - keep_index - 1]]
+  )
+  padded = tf.pad(
+    tensor_slice, paddings, "CONSTANT"
+  )
+  return padded
diff --git a/sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py b/sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py
index 18d39b1..27fae1a 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py
@@ -5,7 +5,8 @@ import tensorflow as tf
 import numpy as np
 
 from sciencebeam_gym.trainer.models.pix2pix.tf_utils import (
-  find_nearest_centroid
+  find_nearest_centroid,
+  blank_other_channels
 )
 
 def test_find_nearest_centroid():
@@ -47,3 +48,40 @@ def test_find_nearest_centroid1():
         ]
       ]
     )
+
+def test_blank_other_channels():
+  tensor = tf.constant([
+    [
+      [5, 5, 5, 5],
+      [6, 6, 6, 6],
+      [7, 7, 7, 7],
+      [8, 8, 8, 8]
+    ],
+    [
+      [5, 5, 5, 5],
+      [6, 6, 6, 6],
+      [7, 7, 7, 7],
+      [8, 8, 8, 8]
+    ]
+  ])
+  padded = blank_other_channels(
+    tensor, 1
+  )
+  with tf.Session() as session:
+    assert np.allclose(
+      session.run(padded),
+      [
+        [
+          [0, 5, 0, 0],
+          [0, 6, 0, 0],
+          [0, 7, 0, 0],
+          [0, 8, 0, 0]
+        ],
+        [
+          [0, 5, 0, 0],
+          [0, 6, 0, 0],
+          [0, 7, 0, 0],
+          [0, 8, 0, 0]
+        ]
+      ]
+    )
-- 
GitLab