diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py index 4c14ef2dfeb043a8cdddc8182f4e1a42dbf7147d..88d59347bfed88d8cfb07e4ef68eeeac04fbcb82 100644 --- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py +++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py @@ -162,6 +162,9 @@ def combine_image(batch_images, replace_black_with_white=False): combined_image = replace_black_with_white_color(combined_image) return combined_image +def remove_last(a): + return a[:-1] + def add_model_summary_images( tensors, dimension_colors, dimension_labels, use_separate_channels=False, @@ -190,7 +193,11 @@ def add_model_summary_images( outputs, dimension_colors_with_unknown ) - batch_images_excluding_unknown = batch_images[:-2] if has_unknown_class else batch_images + batch_images_excluding_unknown = ( + remove_last(batch_images) + if has_unknown_class + else batch_images + ) for i, (batch_image, dimension_label) in enumerate(zip( batch_images, dimension_labels_with_unknown)):