Skip to content
Snippets Groups Projects
Commit 5432f53a authored by Daniel Ecer's avatar Daniel Ecer
Browse files

added test to confirm that loss function supports sample based weights

parent 71081bfd
No related branches found
No related tags found
No related merge requests found
......@@ -14,12 +14,11 @@ def cross_entropy_loss(labels, logits):
)
)
def weighted_cross_entropy_loss(targets, logits, pos_weight):
def weighted_cross_entropy_loss(targets, logits, pos_weight, scalar=True):
with tf.name_scope("weighted_cross_entropy"):
return tf.reduce_mean(
tf.nn.weighted_cross_entropy_with_logits(
logits=logits,
targets=targets,
pos_weight=pos_weight
)
value = tf.nn.weighted_cross_entropy_with_logits(
logits=logits,
targets=targets,
pos_weight=pos_weight
)
return tf.reduce_mean(value) if scalar else value
import logging
from six import raise_from
import tensorflow as tf
......@@ -9,6 +11,9 @@ from sciencebeam_gym.trainer.models.pix2pix.loss import (
weighted_cross_entropy_loss
)
def get_logger():
return logging.getLogger(__name__)
def assert_close(a, b, atol=1.e-8):
try:
assert np.allclose([a], [b], atol=atol)
......@@ -99,3 +104,42 @@ class TestWeightedCrossEntropyLoss(object):
with tf.Session() as session:
loss_1_value, loss_2_value = session.run([loss_1, loss_2])
assert loss_1_value < loss_2_value
def test_should_support_batch_example_pos_weights(self):
batch_size = 3
with tf.Graph().as_default():
labels = tf.constant([[0.0, 1.0]] * batch_size)
logits = tf.constant([[10.0, 10.0]] * batch_size)
pos_weight_1 = tf.constant([
[0.5, 0.5],
[1.0, 1.0],
[1.0, 1.0]
])
pos_weight_2 = tf.constant([
[1.0, 1.0],
[0.5, 0.5],
[1.0, 1.0]
])
loss_1 = weighted_cross_entropy_loss(
labels, logits, pos_weight_1, scalar=False
)
loss_2 = weighted_cross_entropy_loss(
labels, logits, pos_weight_2, scalar=False
)
loss_1_per_example = tf.reduce_mean(loss_1, axis=[-1])
loss_2_per_example = tf.reduce_mean(loss_2, axis=[-1])
with tf.Session() as session:
get_logger().debug('labels=\n%s', labels.eval())
get_logger().debug('logits=\n%s', logits.eval())
loss_1_value, loss_2_value, loss_1_per_example_value, loss_2_per_example_value = (
session.run([loss_1, loss_2, loss_1_per_example, loss_2_per_example])
)
get_logger().debug(
'\nloss_1_value=\n%s\nloss_2_value=\n%s'
'\nloss_1_per_example_value=\n%s\nloss_2_per_example_value=\n%s',
loss_1_value, loss_2_value,
loss_1_per_example_value, loss_2_per_example_value
)
assert loss_1_per_example_value[0] < loss_2_per_example_value[0]
assert loss_1_per_example_value[1] > loss_2_per_example_value[1]
assert loss_1_per_example_value[2] == loss_2_per_example_value[2]
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment