diff --git a/sciencebeam_gym/trainer/preprocess.py b/sciencebeam_gym/trainer/preprocess.py
index 0a2f2d15a5fb866ea95698a1c303148cf1baab1e..920052a526292d3d22ca717e567a01b0d7573a27 100644
--- a/sciencebeam_gym/trainer/preprocess.py
+++ b/sciencebeam_gym/trainer/preprocess.py
@@ -70,12 +70,12 @@ def parse_color_map(f, section_names=None):
         color_map[parse_color(k)] = parse_color(v)
   return color_map
 
-def map_colors(img, color_map):
+def map_colors(img, color_map, default_color=None):
   if color_map is None or len(color_map) == 0:
     return img
   original_data = img.getdata()
   mapped_data = [
-    color_map.get(color, color)
+    color_map.get(color, default_color or color)
     for color in original_data
   ]
   img.putdata(mapped_data)
@@ -188,7 +188,7 @@ def ReadAndConvertAnnotationImage(image_size, color_map):
   def convert_annotation_image(image):
     image = image_resize_nearest(image, image_size)
     if color_map:
-      image = map_colors(image, color_map)
+      image = map_colors(image, color_map, default_color=(255, 255, 255))
     return image_save_to_bytes(image, 'png')
   return lambda uri: [
     convert_annotation_image(image)