diff --git a/sciencebeam_gym/trainer/preprocess.py b/sciencebeam_gym/trainer/preprocess.py index 0a2f2d15a5fb866ea95698a1c303148cf1baab1e..920052a526292d3d22ca717e567a01b0d7573a27 100644 --- a/sciencebeam_gym/trainer/preprocess.py +++ b/sciencebeam_gym/trainer/preprocess.py @@ -70,12 +70,12 @@ def parse_color_map(f, section_names=None): color_map[parse_color(k)] = parse_color(v) return color_map -def map_colors(img, color_map): +def map_colors(img, color_map, default_color=None): if color_map is None or len(color_map) == 0: return img original_data = img.getdata() mapped_data = [ - color_map.get(color, color) + color_map.get(color, default_color or color) for color in original_data ] img.putdata(mapped_data) @@ -188,7 +188,7 @@ def ReadAndConvertAnnotationImage(image_size, color_map): def convert_annotation_image(image): image = image_resize_nearest(image, image_size) if color_map: - image = map_colors(image, color_map) + image = map_colors(image, color_map, default_color=(255, 255, 255)) return image_save_to_bytes(image, 'png') return lambda uri: [ convert_annotation_image(image)