Skip to content
Snippets Groups Projects
Commit b027e564 authored by Daniel Ecer
Browse files

separate discriminator output channels by blanking out channels

parent ba0ae9c5
No related branches found
No related tags found
No related merge requests found
...@@ -9,6 +9,10 @@ import logging ...@@ -9,6 +9,10 @@ import logging
import tensorflow as tf import tensorflow as tf
import collections import collections
from sciencebeam_gym.trainer.models.pix2pix.tf_utils import (
blank_other_channels
)
EPS = 1e-12 EPS = 1e-12
class BaseLoss(object): class BaseLoss(object):
...@@ -151,7 +155,7 @@ def create_generator(generator_inputs, generator_outputs_channels, a): ...@@ -151,7 +155,7 @@ def create_generator(generator_inputs, generator_outputs_channels, a):
return layers[-1] return layers[-1]
def create_discriminator(discrim_inputs, discrim_targets, a): def create_discriminator(discrim_inputs, discrim_targets, a, out_channels=1):
n_layers = 3 n_layers = 3
layers = [] layers = []
...@@ -169,21 +173,57 @@ def create_discriminator(discrim_inputs, discrim_targets, a): ...@@ -169,21 +173,57 @@ def create_discriminator(discrim_inputs, discrim_targets, a):
# layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8] # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
for i in range(n_layers): for i in range(n_layers):
with tf.variable_scope("layer_%d" % (len(layers) + 1)): with tf.variable_scope("layer_%d" % (len(layers) + 1)):
out_channels = a.ndf * min(2**(i+1), 8) layer_out_channels = a.ndf * min(2**(i+1), 8)
stride = 1 if i == n_layers - 1 else 2 # last layer here has stride 1 stride = 1 if i == n_layers - 1 else 2 # last layer here has stride 1
convolved = conv(layers[-1], out_channels, stride=stride) convolved = conv(layers[-1], layer_out_channels, stride=stride)
normalized = batchnorm(convolved) normalized = batchnorm(convolved)
rectified = lrelu(normalized, 0.2) rectified = lrelu(normalized, 0.2)
layers.append(rectified) layers.append(rectified)
# layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1] # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
with tf.variable_scope("layer_%d" % (len(layers) + 1)): with tf.variable_scope("layer_%d" % (len(layers) + 1)):
convolved = conv(rectified, out_channels=1, stride=1) convolved = conv(rectified, out_channels=out_channels, stride=1)
output = tf.sigmoid(convolved) output = tf.sigmoid(convolved)
layers.append(output) layers.append(output)
return layers[-1] return layers[-1]
def create_separate_channel_discriminator_by_blanking_out_channels(inputs, targets, a):
    """Apply the discriminator once per target channel, each time on blanked targets.

    The discriminator is taught to recognise a real channel by looking at
    that channel alone: for every channel index the targets are passed
    through ``blank_other_channels`` so only that channel is kept, and the
    discriminator is asked for one output channel per target channel.

    Returns:
        A tuple ``(predict_real, predict_real_blanked)``:

        - ``predict_real`` stacks, along the last axis, the output channel
          matching the channel that was kept on each pass (one entry per
          target channel); these are expected to be classified as real.
        - ``predict_real_blanked`` stacks, along the last axis, every other
          output channel of each pass (the ones whose target channel was
          blanked out).
    """
    channel_count = int(targets.shape[-1])
    kept_channel_outputs = []
    blanked_channel_outputs = []
    for channel_index in range(channel_count):
        single_channel_targets = blank_other_channels(targets, channel_index)
        # The first pass creates the discriminator variables, later passes reuse them.
        share_variables = (channel_index > 0)
        with tf.variable_scope("discriminator", reuse=share_variables):
            # 2x [batch, height, width, channels] => [batch, 30, 30, channel_count]
            channel_prediction = create_discriminator(
                inputs, single_channel_targets, a,
                out_channels=channel_count
            )
        # Output for the one channel the discriminator was allowed to see.
        kept_channel_outputs.append(
            channel_prediction[:, :, :, channel_index]
        )
        # Outputs for all the channels that were blanked out on this pass.
        blanked_channel_outputs.extend(
            channel_prediction[:, :, :, other_index]
            for other_index in range(channel_count)
            if other_index != channel_index
        )
    predict_real = tf.stack(
        kept_channel_outputs, axis=-1, name='predict_real'
    )
    predict_real_blanked = tf.stack(
        blanked_channel_outputs, axis=-1, name='predict_real_blanked'
    )
    return predict_real, predict_real_blanked
def create_pix2pix_model(inputs, targets, a): def create_pix2pix_model(inputs, targets, a):
...@@ -196,23 +236,44 @@ def create_pix2pix_model(inputs, targets, a): ...@@ -196,23 +236,44 @@ def create_pix2pix_model(inputs, targets, a):
else: else:
outputs = tf.tanh(outputs) outputs = tf.tanh(outputs)
targets_channels = int(targets.shape[-1])
discrim_out_channels = (
targets_channels
if a.use_separate_discriminator_channels
else 1
)
get_logger().info('discrim_out_channels: %s', discrim_out_channels)
# create two copies of discriminator, one for real pairs and one for fake pairs # create two copies of discriminator, one for real pairs and one for fake pairs
# they share the same underlying variables # they share the same underlying variables
with tf.name_scope("real_discriminator"): with tf.name_scope("real_discriminator"):
with tf.variable_scope("discriminator"): if discrim_out_channels > 1:
# 2x [batch, height, width, channels] => [batch, 30, 30, 1] predict_real, predict_real_blanked = (
predict_real = create_discriminator(inputs, targets, a) create_separate_channel_discriminator_by_blanking_out_channels(
inputs, targets, a
)
)
else:
with tf.variable_scope("discriminator"):
# 2x [batch, height, width, channels] => [batch, 30, 30, 1]
predict_real = create_discriminator(inputs, targets, a)
with tf.name_scope("fake_discriminator"): with tf.name_scope("fake_discriminator"):
with tf.variable_scope("discriminator", reuse=True): with tf.variable_scope("discriminator", reuse=True):
# 2x [batch, height, width, channels] => [batch, 30, 30, 1] # 2x [batch, height, width, channels] => [batch, 30, 30, discrim_out_channels]
predict_fake = create_discriminator(inputs, outputs, a) # We don't need to split the channels, the discriminator should detect them all as fake
predict_fake = create_discriminator(
inputs, outputs, a,
out_channels=discrim_out_channels
)
with tf.name_scope("discriminator_loss"): with tf.name_scope("discriminator_loss"):
# minimizing -tf.log will try to get inputs to 1 # minimizing -tf.log will try to get inputs to 1
# predict_real => 1 # predict_real => 1
# predict_fake => 0 # predict_fake => 0
discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS))) discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))
if discrim_out_channels > 1:
discrim_loss += tf.reduce_mean(-tf.log(1 - tf.reshape(predict_real_blanked, [-1]) + EPS))
with tf.name_scope("generator_loss"): with tf.name_scope("generator_loss"):
# predict_fake => 1 # predict_fake => 1
......
...@@ -488,6 +488,12 @@ def model_args_parser(): ...@@ -488,6 +488,12 @@ def model_args_parser():
default=False, default=False,
help='The separate output channels per annotation (if color map is provided)' help='The separate output channels per annotation (if color map is provided)'
) )
parser.add_argument(
'--use_separate_discriminator_channels',
type=str_to_bool,
default=False,
help='The separate discriminator channels per annotation (if color map is provided)'
)
parser.add_argument( parser.add_argument(
'--base_loss', '--base_loss',
type=str, type=str,
......
...@@ -24,3 +24,21 @@ def find_nearest_centroid(predictions, centroids): ...@@ -24,3 +24,21 @@ def find_nearest_centroid(predictions, centroids):
params=centroids, params=centroids,
indices=find_nearest_centroid_indices(predictions, centroids) indices=find_nearest_centroid_indices(predictions, centroids)
) )
def blank_other_channels(tensor, keep_index):
    """Zero out every channel of *tensor* except the one at *keep_index*.

    The kept channel is sliced out of the last axis and then padded with
    zeros back to the original number of channels, so it stays at its
    original position in the last axis.

    Args:
        tensor: a tensor whose rank and last dimension are statically known
            (both are read from ``tensor.shape``).
        keep_index: zero-based Python integer index of the channel to keep.

    Returns:
        A tensor with the same number of channels as *tensor*, in which all
        channels other than *keep_index* are zero.

    Raises:
        ValueError: if *keep_index* is not within ``[0, n_channels)``.
    """
    tensor_shape = tensor.shape
    n_channels = int(tensor_shape[-1])
    rank = len(tensor_shape)
    # Reject bad indices up front: otherwise they turn into a negative slice
    # offset or a negative padding amount further down.
    if not 0 <= keep_index < n_channels:
        raise ValueError(
            'keep_index must be within [0, %d), got: %r' % (n_channels, keep_index)
        )
    # Keep the full extent of every leading axis, but only one channel.
    tensor_slice = tf.slice(
        tensor,
        begin=[0] * (rank - 1) + [keep_index],
        size=[-1] * (rank - 1) + [1]
    )
    # Pad only the last axis: keep_index zero channels before the kept one,
    # and the remaining channels after it.
    paddings = tf.constant(
        [[0, 0]] * (rank - 1) +
        [[keep_index, n_channels - keep_index - 1]]
    )
    padded = tf.pad(
        tensor_slice, paddings, "CONSTANT"
    )
    return padded
...@@ -5,7 +5,8 @@ import tensorflow as tf ...@@ -5,7 +5,8 @@ import tensorflow as tf
import numpy as np import numpy as np
from sciencebeam_gym.trainer.models.pix2pix.tf_utils import ( from sciencebeam_gym.trainer.models.pix2pix.tf_utils import (
find_nearest_centroid find_nearest_centroid,
blank_other_channels
) )
def test_find_nearest_centroid(): def test_find_nearest_centroid():
...@@ -47,3 +48,40 @@ def test_find_nearest_centroid1(): ...@@ -47,3 +48,40 @@ def test_find_nearest_centroid1():
] ]
] ]
) )
def test_blank_other_channels():
    """Keeping channel 1 should zero channels 0, 2 and 3 in every row."""
    row_values = (5, 6, 7, 8)
    # Two identical groups of four rows; each row repeats its value across 4 channels.
    input_group = [[value, value, value, value] for value in row_values]
    tensor = tf.constant([input_group, input_group])
    # Only the channel at index 1 keeps its value.
    expected_group = [[0, value, 0, 0] for value in row_values]
    expected = [expected_group, expected_group]
    padded = blank_other_channels(tensor, 1)
    with tf.Session() as session:
        actual = session.run(padded)
        assert np.allclose(actual, expected)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment