Skip to content
Snippets Groups Projects
Commit 7bf05aff authored by Daniel Ecer's avatar Daniel Ecer
Browse files

optionally include unknown in class weights

parent 4400522a
No related branches found
No related tags found
No related merge requests found
......@@ -21,31 +21,49 @@ from sciencebeam_gym.utils.tfrecord import (
def get_logger():
return logging.getLogger(__name__)
def color_frequency(image, color):
return tf.reduce_sum(
tf.cast(
tf.reduce_all(
tf.equal(image, color),
axis=-1,
name='is_color'
),
tf.float32
)
def color_equals_mask(image, color):
return tf.reduce_all(
tf.equal(image, color),
axis=-1,
name='is_color'
)
def color_equals_mask_as_float(image, color):
return tf.cast(color_equals_mask(image, color), tf.float32)
def color_frequency(image, color):
return tf.reduce_sum(color_equals_mask_as_float(image, color))
def get_shape(x):
try:
return x.shape
except AttributeError:
return tf.constant(x).shape
def calculate_sample_frequencies(image, colors):
return [
color_frequency(image, color)
def calculate_sample_frequencies(image, colors, use_unknown_class=False):
color_masks = [
color_equals_mask_as_float(image, color)
for color in colors
]
if use_unknown_class:
shape = tf.shape(color_masks[0])
ones = tf.fill(shape, 1.0, name='ones')
zeros = tf.fill(shape, 0.0, name='zeros')
color_masks.append(
tf.where(
tf.add_n(color_masks) < 0.5,
ones,
zeros
)
)
return [
tf.reduce_sum(color_mask)
for color_mask in color_masks
]
def iter_calculate_sample_frequencies(
images, colors, image_shape=None, image_format=None, use_unknown_class=False):
def iter_calculate_sample_frequencies(images, colors, image_shape=None, image_format=None):
with tf.Graph().as_default():
if image_format == 'png':
image_tensor = tf.placeholder(tf.string, shape=[], name='image')
......@@ -56,7 +74,9 @@ def iter_calculate_sample_frequencies(images, colors, image_shape=None, image_fo
image_tensor = tf.placeholder(tf.uint8, shape=image_shape, name='image')
decoded_image_tensor = image_tensor
get_logger().debug('decoded_image_tensor: %s', decoded_image_tensor)
frequency_tensors = calculate_sample_frequencies(decoded_image_tensor, colors)
frequency_tensors = calculate_sample_frequencies(
decoded_image_tensor, colors, use_unknown_class=use_unknown_class
)
with tf.Session() as session:
for image in images:
frequencies = session.run(frequency_tensors, {
......@@ -117,7 +137,7 @@ def iter_images_for_tfrecord_paths(tfrecord_paths, image_key, progress=False):
yield d[image_key]
def calculate_median_class_weights_for_tfrecord_paths_and_colors(
tfrecord_paths, image_key, colors, progress=False):
tfrecord_paths, image_key, colors, use_unknown_class=False, progress=False):
get_logger().debug('colors: %s', colors)
get_logger().info('loading tfrecords: %s', tfrecord_paths)
......@@ -125,7 +145,9 @@ def calculate_median_class_weights_for_tfrecord_paths_and_colors(
if progress:
images = list(images)
images = tqdm(images, 'analysing images', leave=False)
frequency_list = list(iter_calculate_sample_frequencies(images, colors, image_format='png'))
frequency_list = list(iter_calculate_sample_frequencies(
images, colors, image_format='png', use_unknown_class=use_unknown_class
))
get_logger().debug('frequency_list: %s', frequency_list)
frequencies = transpose(frequency_list)
get_logger().debug('frequencies: %s', frequencies)
......@@ -133,7 +155,9 @@ def calculate_median_class_weights_for_tfrecord_paths_and_colors(
return class_weights
def calculate_median_class_weights_for_tfrecord_paths_and_color_map(
tfrecord_paths, image_key, color_map, channels=None, progress=False):
tfrecord_paths, image_key, color_map, channels=None,
use_unknown_class=False, unknown_class_label='unknown',
progress=False):
if not channels:
channels = sorted(color_map.keys())
colors = [color_map[k] for k in channels]
......@@ -141,12 +165,18 @@ def calculate_median_class_weights_for_tfrecord_paths_and_color_map(
tfrecord_paths,
image_key,
colors,
progress=progress
progress=progress,
use_unknown_class=use_unknown_class
)
if use_unknown_class:
channels += [unknown_class_label]
return {
k: class_weight for k, class_weight in zip(channels, class_weights)
}
def str_to_bool(s):
return s.lower() in ('yes', 'true', '1')
def str_to_list(s):
s = s.strip()
if not s:
......@@ -179,6 +209,12 @@ def get_args_parser():
type=str_to_list,
help='The channels to use (subset of color map), otherwise all of the labels will be used'
)
parser.add_argument(
'--use-unknown-class',
type=str_to_bool,
default=True,
help='Use unknown class channel'
)
parser.add_argument(
'--out',
required=False,
......@@ -200,6 +236,7 @@ def main(argv=None):
args.image_key,
color_map,
channels=args.channels,
use_unknown_class=args.use_unknown_class,
progress=True
)
get_logger().info('class_weights: %s', class_weights_map)
......
......@@ -63,6 +63,12 @@ class TestCalculateSampleFrequencies(object):
COLOR_1, COLOR_1, COLOR_2
]], [COLOR_1, COLOR_2])) == [2.0, 1.0]
def test_should_include_unknown_class_count_if_enabled(self):
with tf.Session() as session:
assert session.run(calculate_sample_frequencies([[
COLOR_1, COLOR_2, COLOR_3
]], [COLOR_1], use_unknown_class=True)) == [1.0, 2.0]
def encode_png(data):
out = BytesIO()
data = np.array(data, dtype=np.uint8)
......@@ -90,6 +96,20 @@ class TestIterCalculateSampleFrequencies(object):
]]
], [COLOR_1])) == [[0.0]]
def test_should_include_unknown_class_if_enabled(self):
assert list(iter_calculate_sample_frequencies([
[[
COLOR_0
]]
], [COLOR_1], image_shape=(1, 1, 3), use_unknown_class=True)) == [[0.0, 1.0]]
def test_should_include_unknown_class_if_enabled_and_infer_shape(self):
assert list(iter_calculate_sample_frequencies([
[[
COLOR_0
]]
], [COLOR_1], use_unknown_class=True)) == [[0.0, 1.0]]
def test_should_return_total_count_for_multiple_mixed_color(self):
assert list(iter_calculate_sample_frequencies([
[[
......@@ -119,6 +139,13 @@ class TestIterCalculateSampleFrequencies(object):
]])
], [COLOR_1], image_format='png')) == [[1.0]]
def _test_should_infer_shape_when_decoding_png_and_include_unknown_class(self):
assert list(iter_calculate_sample_frequencies([
encode_png([[
COLOR_1, COLOR_2, COLOR_3
]])
], [COLOR_1], image_format='png', use_unknown_class=True)) == [[1.0, 2.0]]
class TestCalculateMedianClassWeight(object):
def test_should_return_median_frequency_balanced_for_same_frequencies(self):
assert calculate_median_class_weight([3, 3, 3]) == 1 / 3
......@@ -249,3 +276,22 @@ class TestCalculateMedianClassWeightsForFfrecordPathsAndColorMap(object):
}
)
assert set(class_weights_map.keys()) == {'color1', 'color2'}
def test_should_include_unknown_class_if_enabled(self):
with TemporaryDirectory() as path:
tfrecord_filename = os.path.join(path, 'data.tfrecord')
get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename)
write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({
'image': encode_png([[
COLOR_0, COLOR_1, COLOR_2, COLOR_3
]])
})])
class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map(
[tfrecord_filename], 'image', {
'color1': COLOR_1,
'color2': COLOR_2
},
use_unknown_class=True,
unknown_class_label='unknown'
)
assert set(class_weights_map.keys()) == {'color1', 'color2', 'unknown'}
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment