diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py
index 93b489b63e77b205e4de2b8c2ecd1d229f32e0b6..bd14254d60e3ade4b8910396e39d73c67adfb1c9 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py
@@ -10,7 +10,7 @@ import tensorflow as tf
 import collections
 
 from sciencebeam_gym.trainer.models.pix2pix.tf_utils import (
-    blank_other_channels
+  blank_other_channels
 )
 
 EPS = 1e-12
@@ -40,337 +40,337 @@ def get_logger():
   return logging.getLogger(__name__)
 
 def lrelu(x, a):
-    with tf.name_scope("lrelu"):
-        # adding these together creates the leak part and linear part
-        # then cancels them out by subtracting/adding an absolute value term
-        # leak: a*x/2 - a*abs(x)/2
-        # linear: x/2 + abs(x)/2
+  with tf.name_scope("lrelu"):
+    # adding these together creates the leak part and linear part
+    # then cancels them out by subtracting/adding an absolute value term
+    # leak: a*x/2 - a*abs(x)/2
+    # linear: x/2 + abs(x)/2
 
-        # this block looks like it has 2 inputs on the graph unless we do this
-        x = tf.identity(x)
-        return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)
+    # this block looks like it has 2 inputs on the graph unless we do this
+    x = tf.identity(x)
+    return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)
 
 def batchnorm(input):
-    with tf.variable_scope("batchnorm"):
-        # this block looks like it has 3 inputs on the graph unless we do this
-        input = tf.identity(input)
-
-        input_shape = input.get_shape()
-        get_logger().debug('batchnorm, input_shape: %s', input_shape)
-        channels = input_shape[-1]
-        offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer())
-        scale = tf.get_variable("scale", [channels], dtype=tf.float32, initializer=tf.random_normal_initializer(1.0, 0.02))
-        mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False)
-        variance_epsilon = 1e-5
-        normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)
-        return normalized
+  with tf.variable_scope("batchnorm"):
+    # this block looks like it has 3 inputs on the graph unless we do this
+    input = tf.identity(input)
+
+    input_shape = input.get_shape()
+    get_logger().debug('batchnorm, input_shape: %s', input_shape)
+    channels = input_shape[-1]
+    offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer())
+    scale = tf.get_variable("scale", [channels], dtype=tf.float32, initializer=tf.random_normal_initializer(1.0, 0.02))
+    mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False)
+    variance_epsilon = 1e-5
+    normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)
+    return normalized
 
 def conv(batch_input, out_channels, stride):
-    with tf.variable_scope("conv"):
-        input_shape = batch_input.get_shape()
-        get_logger().debug('conv, input_shape: %s', input_shape)
-        in_channels = input_shape[-1]
-        filter = tf.get_variable("filter", [4, 4, in_channels, out_channels], dtype=tf.float32, initializer=tf.random_normal_initializer(0, 0.02))
-        # [batch, in_height, in_width, in_channels], [filter_width, filter_height, in_channels, out_channels]
-        # => [batch, out_height, out_width, out_channels]
-        padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT")
-        conv = tf.nn.conv2d(padded_input, filter, [1, stride, stride, 1], padding="VALID")
-        return conv
+  with tf.variable_scope("conv"):
+    input_shape = batch_input.get_shape()
+    get_logger().debug('conv, input_shape: %s', input_shape)
+    in_channels = input_shape[-1]
+    filter = tf.get_variable("filter", [4, 4, in_channels, out_channels], dtype=tf.float32, initializer=tf.random_normal_initializer(0, 0.02))
+    # [batch, in_height, in_width, in_channels], [filter_width, filter_height, in_channels, out_channels]
+    # => [batch, out_height, out_width, out_channels]
+    padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT")
+    conv = tf.nn.conv2d(padded_input, filter, [1, stride, stride, 1], padding="VALID")
+    return conv
 
 def deconv(batch_input, out_channels):
-    with tf.variable_scope("deconv"):
-        input_shape = batch_input.get_shape()
-        get_logger().debug('deconv, input_shape: %s', input_shape)
+  with tf.variable_scope("deconv"):
+    input_shape = batch_input.get_shape()
+    get_logger().debug('deconv, input_shape: %s', input_shape)
 
-        batch, in_height, in_width, in_channels = [int(d) for d in input_shape]
-        filter = tf.get_variable("filter", [4, 4, out_channels, in_channels], dtype=tf.float32, initializer=tf.random_normal_initializer(0, 0.02))
-        # [batch, in_height, in_width, in_channels], [filter_width, filter_height, out_channels, in_channels]
-        # => [batch, out_height, out_width, out_channels]
-        conv = tf.nn.conv2d_transpose(batch_input, filter, [batch, in_height * 2, in_width * 2, out_channels], [1, 2, 2, 1], padding="SAME")
-        return conv
+    batch, in_height, in_width, in_channels = [int(d) for d in input_shape]
+    filter = tf.get_variable("filter", [4, 4, out_channels, in_channels], dtype=tf.float32, initializer=tf.random_normal_initializer(0, 0.02))
+    # [batch, in_height, in_width, in_channels], [filter_width, filter_height, out_channels, in_channels]
+    # => [batch, out_height, out_width, out_channels]
+    conv = tf.nn.conv2d_transpose(batch_input, filter, [batch, in_height * 2, in_width * 2, out_channels], [1, 2, 2, 1], padding="SAME")
+    return conv
 
 def create_generator(generator_inputs, generator_outputs_channels, a):
-    layers = []
+  layers = []
 
   # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
-    with tf.variable_scope("encoder_1"):
-        output = conv(generator_inputs, a.ngf, stride=2)
-        layers.append(output)
-
-    layer_specs = [
-        a.ngf * 2, # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
-        a.ngf * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
-        a.ngf * 8, # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
-        a.ngf * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
-        a.ngf * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
-        a.ngf * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
-        a.ngf * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
-    ]
-
-    for out_channels in layer_specs:
-        with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
-            rectified = lrelu(layers[-1], 0.2)
-            # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
-            convolved = conv(rectified, out_channels, stride=2)
-            output = batchnorm(convolved)
-            layers.append(output)
-
-    layer_specs = [
-        (a.ngf * 8, 0.5), # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
-        (a.ngf * 8, 0.5), # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
-        (a.ngf * 8, 0.5), # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
-        (a.ngf * 8, 0.0), # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
-        (a.ngf * 4, 0.0), # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
-        (a.ngf * 2, 0.0), # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
-        (a.ngf, 0.0), # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
-    ]
-
-    num_encoder_layers = len(layers)
-    for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
-        skip_layer = num_encoder_layers - decoder_layer - 1
-        with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
-            if decoder_layer == 0:
-                # first decoder layer doesn't have skip connections
-                # since it is directly connected to the skip_layer
-                input = layers[-1]
-            else:
-                input = tf.concat([layers[-1], layers[skip_layer]], axis=3)
-
-            rectified = tf.nn.relu(input)
-            # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
-            output = deconv(rectified, out_channels)
-            output = batchnorm(output)
-
-            if dropout > 0.0:
-                output = tf.nn.dropout(output, keep_prob=1 - dropout)
-
-            layers.append(output)
-
-    # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
-    with tf.variable_scope("decoder_1"):
-        input = tf.concat([layers[-1], layers[0]], axis=3)
-        rectified = tf.nn.relu(input)
-        output = deconv(rectified, generator_outputs_channels)
-        layers.append(output)
-
-    return layers[-1]
+  with tf.variable_scope("encoder_1"):
+    output = conv(generator_inputs, a.ngf, stride=2)
+    layers.append(output)
+
+  layer_specs = [
+    a.ngf * 2, # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
+    a.ngf * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
+    a.ngf * 8, # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
+    a.ngf * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
+    a.ngf * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
+    a.ngf * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
+    a.ngf * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
+  ]
+
+  for out_channels in layer_specs:
+    with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
+      rectified = lrelu(layers[-1], 0.2)
+      # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
+      convolved = conv(rectified, out_channels, stride=2)
+      output = batchnorm(convolved)
+      layers.append(output)
+
+  layer_specs = [
+    (a.ngf * 8, 0.5), # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
+    (a.ngf * 8, 0.5), # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
+    (a.ngf * 8, 0.5), # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
+    (a.ngf * 8, 0.0), # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
+    (a.ngf * 4, 0.0), # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
+    (a.ngf * 2, 0.0), # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
+    (a.ngf, 0.0), # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
+  ]
+
+  num_encoder_layers = len(layers)
+  for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
+    skip_layer = num_encoder_layers - decoder_layer - 1
+    with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
+      if decoder_layer == 0:
+        # first decoder layer doesn't have skip connections
+        # since it is directly connected to the skip_layer
+        input = layers[-1]
+      else:
+        input = tf.concat([layers[-1], layers[skip_layer]], axis=3)
+
+      rectified = tf.nn.relu(input)
+      # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
+      output = deconv(rectified, out_channels)
+      output = batchnorm(output)
+
+      if dropout > 0.0:
+        output = tf.nn.dropout(output, keep_prob=1 - dropout)
+
+      layers.append(output)
+
+  # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
+  with tf.variable_scope("decoder_1"):
+    input = tf.concat([layers[-1], layers[0]], axis=3)
+    rectified = tf.nn.relu(input)
+    output = deconv(rectified, generator_outputs_channels)
+    layers.append(output)
+
+  return layers[-1]
 
 def create_discriminator(discrim_inputs, discrim_targets, a, out_channels=1):
-    n_layers = 3
-    layers = []
-
-    # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2]
-    input = tf.concat([discrim_inputs, discrim_targets], axis=3)
-
-    # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
-    with tf.variable_scope("layer_1"):
-        convolved = conv(input, a.ndf, stride=2)
-        rectified = lrelu(convolved, 0.2)
-        layers.append(rectified)
-
-    # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
-    # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
-    # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
-    for i in range(n_layers):
-        with tf.variable_scope("layer_%d" % (len(layers) + 1)):
-            layer_out_channels = a.ndf * min(2**(i+1), 8)
-            stride = 1 if i == n_layers - 1 else 2 # last layer here has stride 1
-            convolved = conv(layers[-1], layer_out_channels, stride=stride)
-            normalized = batchnorm(convolved)
-            rectified = lrelu(normalized, 0.2)
-            layers.append(rectified)
-
-    # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
+  n_layers = 3
+  layers = []
+
+  # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2]
+  input = tf.concat([discrim_inputs, discrim_targets], axis=3)
+
+  # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
+  with tf.variable_scope("layer_1"):
+    convolved = conv(input, a.ndf, stride=2)
+    rectified = lrelu(convolved, 0.2)
+    layers.append(rectified)
+
+  # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
+  # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
+  # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
+  for i in range(n_layers):
     with tf.variable_scope("layer_%d" % (len(layers) + 1)):
-        convolved = conv(rectified, out_channels=out_channels, stride=1)
-        output = tf.sigmoid(convolved)
-        layers.append(output)
+      layer_out_channels = a.ndf * min(2**(i+1), 8)
+      stride = 1 if i == n_layers - 1 else 2 # last layer here has stride 1
+      convolved = conv(layers[-1], layer_out_channels, stride=stride)
+      normalized = batchnorm(convolved)
+      rectified = lrelu(normalized, 0.2)
+      layers.append(rectified)
 
-    return layers[-1]
+  # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
+  with tf.variable_scope("layer_%d" % (len(layers) + 1)):
+    convolved = conv(rectified, out_channels=out_channels, stride=1)
+    output = tf.sigmoid(convolved)
+    layers.append(output)
+
+  return layers[-1]
 
 def create_separate_channel_discriminator_by_blanking_out_channels(inputs, targets, a):
-    # We need to teach the discriminator to detect the real channels,
-    # by just looking at the real channel.
-    # For each channel:
-    # - let the discriminator only see the current channel, blank out all other channels
-    # - expect output to not be fake for the not blanked out channel
-    n_targets_channels = int(targets.shape[-1])
-    predict_real_channels = []
-    predict_real_blanked_list = []
-    for i in range(n_targets_channels):
-        masked_targets = blank_other_channels(
-            targets,
-            i
-        )
-        with tf.variable_scope("discriminator", reuse=(i > 0)):
-            # 2x [batch, height, width, channels] => [batch, 30, 30, n_targets_channels]
-            predict_real_i = create_discriminator(
-                inputs, masked_targets, a,
-                out_channels=n_targets_channels
-            )
-            predict_real_channels.append(predict_real_i[:, :, :, i])
-            for j in range(n_targets_channels):
-                if j != i:
-                    predict_real_blanked_list.append(predict_real_i[:, :, :, j])
-    predict_real = tf.stack(
-        predict_real_channels,
-        axis=-1,
-        name='predict_real'
+  # We need to teach the discriminator to detect the real channels,
+  # by just looking at the real channel.
+  # For each channel:
+  # - let the discriminator only see the current channel, blank out all other channels
+  # - expect output to not be fake for the not blanked out channel
+  n_targets_channels = int(targets.shape[-1])
+  predict_real_channels = []
+  predict_real_blanked_list = []
+  for i in range(n_targets_channels):
+    masked_targets = blank_other_channels(
+      targets,
+      i
     )
-    predict_real_blanked = tf.stack(
-        predict_real_blanked_list,
-        axis=-1,
-        name='predict_real_blanked'
-    )
-    return predict_real, predict_real_blanked
+    with tf.variable_scope("discriminator", reuse=(i > 0)):
+      # 2x [batch, height, width, channels] => [batch, 30, 30, n_targets_channels]
+      predict_real_i = create_discriminator(
+        inputs, masked_targets, a,
+        out_channels=n_targets_channels
+      )
+      predict_real_channels.append(predict_real_i[:, :, :, i])
+      for j in range(n_targets_channels):
+        if j != i:
+          predict_real_blanked_list.append(predict_real_i[:, :, :, j])
+  predict_real = tf.stack(
+    predict_real_channels,
+    axis=-1,
+    name='predict_real'
+  )
+  predict_real_blanked = tf.stack(
+    predict_real_blanked_list,
+    axis=-1,
+    name='predict_real_blanked'
+  )
+  return predict_real, predict_real_blanked
 
 def create_pix2pix_model(inputs, targets, a):
-    with tf.variable_scope("generator") as scope:
-        out_channels = int(targets.get_shape()[-1])
-        outputs = create_generator(inputs, out_channels, a)
-        if a.base_loss == BaseLoss.CROSS_ENTROPY:
-            output_logits = outputs
-            outputs = tf.nn.softmax(output_logits)
-        else:
-            outputs = tf.tanh(outputs)
-
-    targets_channels = int(targets.shape[-1])
-    discrim_out_channels = (
-        targets_channels
-        if a.use_separate_discriminator_channels
-        else 1
-    )
-    get_logger().info('discrim_out_channels: %s', discrim_out_channels)
-
-    # create two copies of discriminator, one for real pairs and one for fake pairs
-    # they share the same underlying variables
-    with tf.name_scope("real_discriminator"):
-        if discrim_out_channels > 1:
-            predict_real, predict_real_blanked = (
-                create_separate_channel_discriminator_by_blanking_out_channels(
-                    inputs, targets, a
-                )
-            )
-        else:
-            with tf.variable_scope("discriminator"):
-                # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
-                predict_real = create_discriminator(inputs, targets, a)
-
-    with tf.name_scope("fake_discriminator"):
-        with tf.variable_scope("discriminator", reuse=True):
-            # 2x [batch, height, width, channels] => [batch, 30, 30, discrim_out_channels]
-            # We don't need to split the channels, the discriminator should detect them all as fake
-            predict_fake = create_discriminator(
-                inputs, outputs, a,
-                out_channels=discrim_out_channels
-            )
-
-    with tf.name_scope("discriminator_loss"):
-        # minimizing -tf.log will try to get inputs to 1
-        # predict_real => 1
-        # predict_fake => 0
-        discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))
-        if discrim_out_channels > 1:
-            discrim_loss += tf.reduce_mean(-tf.log(1 - tf.reshape(predict_real_blanked, [-1]) + EPS))
-
-    with tf.name_scope("generator_loss"):
-        # predict_fake => 1
-        gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
-        if a.base_loss == BaseLoss.CROSS_ENTROPY:
-            get_logger().info('using cross entropy loss function')
-            # TODO change variable name
-            gen_loss_L1 = tf.reduce_mean(
-                tf.nn.softmax_cross_entropy_with_logits(
-                    logits=output_logits,
-                    labels=targets,
-                    name='softmax_cross_entropy_with_logits'
-                )
-            )
-        else:
-            get_logger().info('using L1 loss function')
-            # abs(targets - outputs) => 0
-            gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))
-        gen_loss = gen_loss_GAN * a.gan_weight + gen_loss_L1 * a.l1_weight
-
-    with tf.name_scope("discriminator_train"):
-        discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
-        discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
-        discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
-        discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)
-
-    with tf.name_scope("generator_train"):
-        with tf.control_dependencies([discrim_train]):
-            gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
-            gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
-            gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
-            gen_train = gen_optim.apply_gradients(gen_grads_and_vars)
-
-    ema = tf.train.ExponentialMovingAverage(decay=0.99)
-    update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1])
-
-    global_step = tf.contrib.framework.get_or_create_global_step()
-    incr_global_step = tf.assign(global_step, global_step+1)
-
-    return Pix2PixModel(
-        inputs=inputs,
-        targets=targets,
-        predict_real=predict_real,
-        predict_fake=predict_fake,
-        discrim_loss=ema.average(discrim_loss),
-        discrim_grads_and_vars=discrim_grads_and_vars,
-        gen_loss_GAN=ema.average(gen_loss_GAN),
-        gen_loss_L1=ema.average(gen_loss_L1),
-        gen_grads_and_vars=gen_grads_and_vars,
-        outputs=outputs,
-        global_step=global_step,
-        train=tf.group(update_losses, incr_global_step, gen_train),
-    )
+  with tf.variable_scope("generator") as scope:
+    out_channels = int(targets.get_shape()[-1])
+    outputs = create_generator(inputs, out_channels, a)
+    if a.base_loss == BaseLoss.CROSS_ENTROPY:
+      output_logits = outputs
+      outputs = tf.nn.softmax(output_logits)
+    else:
+      outputs = tf.tanh(outputs)
+
+  targets_channels = int(targets.shape[-1])
+  discrim_out_channels = (
+    targets_channels
+    if a.use_separate_discriminator_channels
+    else 1
+  )
+  get_logger().info('discrim_out_channels: %s', discrim_out_channels)
+
+  # create two copies of discriminator, one for real pairs and one for fake pairs
+  # they share the same underlying variables
+  with tf.name_scope("real_discriminator"):
+    if discrim_out_channels > 1:
+      predict_real, predict_real_blanked = (
+        create_separate_channel_discriminator_by_blanking_out_channels(
+          inputs, targets, a
+        )
+      )
+    else:
+      with tf.variable_scope("discriminator"):
+        # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
+        predict_real = create_discriminator(inputs, targets, a)
+
+  with tf.name_scope("fake_discriminator"):
+    with tf.variable_scope("discriminator", reuse=True):
+      # 2x [batch, height, width, channels] => [batch, 30, 30, discrim_out_channels]
+      # We don't need to split the channels, the discriminator should detect them all as fake
+      predict_fake = create_discriminator(
+        inputs, outputs, a,
+        out_channels=discrim_out_channels
+      )
+
+  with tf.name_scope("discriminator_loss"):
+    # minimizing -tf.log will try to get inputs to 1
+    # predict_real => 1
+    # predict_fake => 0
+    discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))
+    if discrim_out_channels > 1:
+      discrim_loss += tf.reduce_mean(-tf.log(1 - tf.reshape(predict_real_blanked, [-1]) + EPS))
+
+  with tf.name_scope("generator_loss"):
+    # predict_fake => 1
+    gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
+    if a.base_loss == BaseLoss.CROSS_ENTROPY:
+      get_logger().info('using cross entropy loss function')
+      # TODO change variable name
+      gen_loss_L1 = tf.reduce_mean(
+        tf.nn.softmax_cross_entropy_with_logits(
+          logits=output_logits,
+          labels=targets,
+          name='softmax_cross_entropy_with_logits'
+        )
+      )
+    else:
+      get_logger().info('using L1 loss function')
+      # abs(targets - outputs) => 0
+      gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))
+    gen_loss = gen_loss_GAN * a.gan_weight + gen_loss_L1 * a.l1_weight
+
+  with tf.name_scope("discriminator_train"):
+    discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
+    discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
+    discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
+    discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)
+
+  with tf.name_scope("generator_train"):
+    with tf.control_dependencies([discrim_train]):
+      gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
+      gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
+      gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
+      gen_train = gen_optim.apply_gradients(gen_grads_and_vars)
+
+  ema = tf.train.ExponentialMovingAverage(decay=0.99)
+  update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1])
+
+  global_step = tf.contrib.framework.get_or_create_global_step()
+  incr_global_step = tf.assign(global_step, global_step+1)
+
+  return Pix2PixModel(
+    inputs=inputs,
+    targets=targets,
+    predict_real=predict_real,
+    predict_fake=predict_fake,
+    discrim_loss=ema.average(discrim_loss),
+    discrim_grads_and_vars=discrim_grads_and_vars,
+    gen_loss_GAN=ema.average(gen_loss_GAN),
+    gen_loss_L1=ema.average(gen_loss_L1),
+    gen_grads_and_vars=gen_grads_and_vars,
+    outputs=outputs,
+    global_step=global_step,
+    train=tf.group(update_losses, incr_global_step, gen_train),
+  )
 
 def create_image_summaries(model):
-    def convert(image):
-        return tf.image.convert_image_dtype(
-            image,
-            dtype=tf.uint8,
-            saturate=True
-        )
-    summaries = {}
+  def convert(image):
+    return tf.image.convert_image_dtype(
+      image,
+      dtype=tf.uint8,
+      saturate=True
+    )
+  summaries = {}
 
-    # reverse any processing on images so they can be written to disk or displayed to user
-    with tf.name_scope("convert_inputs"):
-        converted_inputs = convert(model.inputs)
+  # reverse any processing on images so they can be written to disk or displayed to user
+  with tf.name_scope("convert_inputs"):
+    converted_inputs = convert(model.inputs)
 
-    with tf.name_scope("convert_targets"):
-        converted_targets = convert(model.targets)
+  with tf.name_scope("convert_targets"):
+    converted_targets = convert(model.targets)
 
-    with tf.name_scope("convert_outputs"):
-        converted_outputs = convert(model.outputs)
+  with tf.name_scope("convert_outputs"):
+    converted_outputs = convert(model.outputs)
 
-    with tf.name_scope("inputs_summary"):
-        tf.summary.image("inputs", converted_inputs)
+  with tf.name_scope("inputs_summary"):
+    tf.summary.image("inputs", converted_inputs)
 
-    with tf.name_scope("targets_summary"):
-        tf.summary.image("targets", converted_targets)
+  with tf.name_scope("targets_summary"):
+    tf.summary.image("targets", converted_targets)
 
-    with tf.name_scope("outputs_summary"):
-        tf.summary.image("outputs", converted_outputs)
-    summaries['output_image'] = converted_outputs
+  with tf.name_scope("outputs_summary"):
+    tf.summary.image("outputs", converted_outputs)
+  summaries['output_image'] = converted_outputs
 
-    with tf.name_scope("predict_real_summary"):
-        tf.summary.image("predict_real", tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))
+  with tf.name_scope("predict_real_summary"):
+    tf.summary.image("predict_real", tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))
 
-    with tf.name_scope("predict_fake_summary"):
-        tf.summary.image("predict_fake", tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8))
-    return summaries
+  with tf.name_scope("predict_fake_summary"):
+    tf.summary.image("predict_fake", tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8))
+  return summaries
 
 def create_other_summaries(model):
-    tf.summary.scalar("discriminator_loss", model.discrim_loss)
-    tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
-    tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
+  tf.summary.scalar("discriminator_loss", model.discrim_loss)
+  tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
+  tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
 
-    for var in tf.trainable_variables():
-        tf.summary.histogram(var.op.name + "/values", var)
+  for var in tf.trainable_variables():
+    tf.summary.histogram(var.op.name + "/values", var)
 
-    for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars:
-        tf.summary.histogram(var.op.name + "/gradients", grad)
+  for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars:
+    tf.summary.histogram(var.op.name + "/gradients", grad)