From 91e1c0d0b16ff2dfc6f9dd8e6cee2e975639b53a Mon Sep 17 00:00:00 2001
From: Daniel Ecer <de-code@users.noreply.github.com>
Date: Fri, 2 Nov 2018 16:45:33 +0000
Subject: [PATCH] pylint and flake8 checking (#39)

* added pylint check

* added pylintrc to docker image

* reduced excessive apache beam debug logging

* configured pylint, addressed linting

* enabled flake8 checks

* downgrade pycodestyle to 2.3.1 due to error

* switch to 4 spaces indent

* autopep8

* more flake8

* added new line to .flake8
---
 .flake8                                       |    4 +
 .pylintrc                                     |   34 +-
 Dockerfile                                    |    3 +
 project_tests.sh                              |    8 +
 requirements.dev.txt                          |    2 +
 sciencebeam_gym/conftest.py                   |    6 +-
 .../convert/conversion_pipeline.py            |  901 ++++-----
 .../convert/conversion_pipeline_test.py       |  856 ++++----
 .../convert/cv_conversion_utils.py            |   67 +-
 .../convert/cv_conversion_utils_test.py       |   74 +-
 .../convert/grobid/grobid_service.py          |  134 +-
 .../convert/grobid/grobid_service_test.py     |   90 +-
 .../convert/grobid/grobid_service_wrapper.py  |  208 +-
 .../convert/grobid/grobid_xml_enhancer.py     |  157 +-
 .../grobid/grobid_xml_enhancer_test.py        |  319 +--
 sciencebeam_gym/inference_model/__init__.py   |  161 +-
 .../inference_model/__init___test.py          |  106 +-
 .../annotate_using_predictions.py             |  289 +--
 .../annotate_using_predictions_test.py        |  321 +--
 .../extract_from_annotated_document.py        |  161 +-
 .../extract_from_annotated_document_test.py   |  423 ++--
 .../inference_model/extract_to_xml.py         |  308 +--
 .../inference_model/extract_to_xml_test.py    |  285 +--
 sciencebeam_gym/model_utils/channels.py       |   53 +-
 .../text/crf/annotate_using_predictions.py    |  202 +-
 .../crf/annotate_using_predictions_test.py    |  162 +-
 .../models/text/crf/crfsuite_model.py         |   23 +-
 .../models/text/crf/crfsuite_model_test.py    |  178 +-
 .../text/crf/crfsuite_training_pipeline.py    |  326 ++--
 .../crf/crfsuite_training_pipeline_test.py    |  476 +++--
 .../models/text/feature_extractor.py          |  192 +-
 .../models/text/feature_extractor_test.py     |  573 +++---
 sciencebeam_gym/pdf/__init__.py               |    4 +-
 sciencebeam_gym/pdf/pdf_to_lxml_wrapper.py    |  243 +--
 sciencebeam_gym/pdf/pdf_to_png.py             |  108 +-
 sciencebeam_gym/pdf/pdf_to_png_test.py        |   85 +-
 .../annotation/annotation_evaluation.py       |   81 +-
 .../annotation/annotation_evaluation_test.py  |   89 +-
 .../preprocess/annotation/annotator.py        |   46 +-
 .../preprocess/annotation/annotator_test.py   |   47 +-
 .../preprocess/annotation/find_line_number.py |   91 +-
 .../annotation/find_line_numbers_test.py      |  187 +-
 .../preprocess/annotation/fuzzy_match.py      |  502 ++---
 .../preprocess/annotation/fuzzy_match_test.py |  331 ++--
 .../annotation/matching_annotator.py          | 1166 +++++------
 .../annotation/matching_annotator_test.py     | 1735 +++++++++--------
 .../annotation/target_annotation.py           |  695 +++----
 .../annotation/target_annotation_test.py      | 1185 +++++------
 .../preprocess/blockify_annotations.py        |  446 +++--
 .../preprocess/blockify_annotations_test.py   |  488 ++---
 sciencebeam_gym/preprocess/color_map.py       |   55 +-
 sciencebeam_gym/preprocess/color_map_test.py  |   21 +-
 sciencebeam_gym/preprocess/lxml_to_svg.py     |  388 ++--
 .../preprocess/lxml_to_svg_test.py            |  339 ++--
 .../preprocess/preprocessing_pipeline.py      |  923 ++++-----
 .../preprocess/preprocessing_pipeline_test.py |  714 +++----
 .../preprocess/preprocessing_transforms.py    |   50 +-
 .../preprocessing_transforms_test.py          |   71 +-
 .../preprocess/preprocessing_utils.py         |  275 ++-
 .../preprocess/preprocessing_utils_test.py    |   92 +-
 .../preprocess/visualize_svg_annotation.py    |  118 +-
 .../visualize_svg_annotation_test.py          |  138 +-
 .../structured_document/__init__.py           |  380 ++--
 .../structured_document/__init___test.py      |  206 +-
 sciencebeam_gym/structured_document/lxml.py   |   77 +-
 .../structured_document/lxml_test.py          |  275 +--
 .../structured_document_loader.py             |   71 +-
 .../structured_document_loader_test.py        |  175 +-
 .../structured_document_saver.py              |   31 +-
 .../structured_document_saver_test.py         |   71 +-
 sciencebeam_gym/structured_document/svg.py    |  142 +-
 .../structured_document/svg_test.py           |  337 ++--
 .../tools/calculate_class_weights.py          |  412 ++--
 .../tools/calculate_class_weights_test.py     |  611 +++---
 sciencebeam_gym/tools/colorize_image.py       |  142 +-
 sciencebeam_gym/tools/inspect_tfrecords.py    |  164 +-
 sciencebeam_gym/tools/resize_image.py         |   81 +-
 sciencebeam_gym/trainer/checkpoint.py         |   54 +-
 sciencebeam_gym/trainer/data/examples.py      |  180 +-
 sciencebeam_gym/trainer/data/examples_test.py |  135 +-
 sciencebeam_gym/trainer/evaluator.py          |  847 ++++----
 sciencebeam_gym/trainer/evaluator_test.py     |  184 +-
 .../trainer/models/pix2pix/evaluate.py        |  182 +-
 .../trainer/models/pix2pix/evaluate_test.py   |   75 +-
 .../trainer/models/pix2pix/loss.py            |   61 +-
 .../trainer/models/pix2pix/loss_test.py       |  233 +--
 .../trainer/models/pix2pix/pix2pix_core.py    |  996 +++++-----
 .../models/pix2pix/pix2pix_core_test.py       |  342 ++--
 .../trainer/models/pix2pix/pix2pix_model.py   | 1165 +++++------
 .../models/pix2pix/pix2pix_model_test.py      |  716 ++++---
 .../trainer/models/pix2pix/tf_utils.py        |   79 +-
 .../trainer/models/pix2pix/tf_utils_test.py   |  132 +-
 sciencebeam_gym/trainer/predict.py            |   64 +-
 sciencebeam_gym/trainer/saver.py              |   14 +-
 sciencebeam_gym/trainer/task.py               | 1361 ++++++-------
 sciencebeam_gym/trainer/util.py               |  217 ++-
 sciencebeam_gym/utils/bounding_box.py         |  192 +-
 sciencebeam_gym/utils/bounding_box_test.py    |   62 +-
 sciencebeam_gym/utils/pages_zip.py            |   47 +-
 sciencebeam_gym/utils/pyplot.py               |    6 +
 sciencebeam_gym/utils/tf.py                   |   27 +-
 sciencebeam_gym/utils/tfrecord.py             |   80 +-
 sciencebeam_gym/utils/tfrecord_test.py        |   43 +-
 setup.py                                      |  140 +-
 104 files changed, 14774 insertions(+), 13850 deletions(-)
 create mode 100644 .flake8
 create mode 100644 sciencebeam_gym/utils/pyplot.py

diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..c184d2f
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,4 @@
+[flake8]
+max-line-length = 100
+ignore =
+    E731 # do not assign a lambda expression, use a def
diff --git a/.pylintrc b/.pylintrc
index 4242187..1d99acd 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -1,11 +1,29 @@
-[FORMAT]
-indent-string='  '
-indent-after-paren=2
-
 [MESSAGES CONTROL]
-# Disable C0111 (docstrings) - use useful method names instead
-# Disable C0103 (invalid variable name) - in very short blocks it is okay to use generic names
-disable=C0111,C0103,W0108,E1101,C0330,E1129,C1801
+disable=
+    missing-docstring,
+    too-few-public-methods,
+    no-self-use,
+    invalid-name,
+    too-many-arguments,
+    duplicate-code,
+    too-many-locals,
+    too-many-instance-attributes,
+    too-many-statements,
+    too-many-branches,
+    too-many-public-methods,
+    no-else-return,
+    unnecessary-lambda,
+    len-as-condition,
+    fixme,
+    bad-continuation
 
 [TYPECHECK]
-ignored-modules=sciencelab.alignment.AlignmentMatrix
+ignored-classes=
+    lxml.builder.ElementMaker,
+    lxml.builder.E
+extension-pkg-whitelist=
+    lxml
+ignored-modules=
+    sciencelab.alignment.AlignmentMatrix,
+    tensorflow,
+    apache_beam
diff --git a/Dockerfile b/Dockerfile
index 181da08..4528f26 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -16,3 +16,6 @@ RUN venv/bin/pip install -r requirements.txt
 
 COPY sciencebeam_gym ${PROJECT_HOME}/sciencebeam_gym
 COPY *.conf *.sh *.in *.txt *.py ${PROJECT_HOME}/
+
+# tests
+COPY .pylintrc .flake8 ${PROJECT_HOME}/
diff --git a/project_tests.sh b/project_tests.sh
index 9584518..27f2114 100755
--- a/project_tests.sh
+++ b/project_tests.sh
@@ -4,3 +4,11 @@ set -e
 pip install -r requirements.dev.txt
 
 pytest sciencebeam_gym
+
+echo "running pylint"
+pylint sciencebeam_gym setup.py
+
+echo "running flake8"
+flake8 sciencebeam_gym setup.py
+
+echo "done"
diff --git a/requirements.dev.txt b/requirements.dev.txt
index 220de02..88ffb2a 100644
--- a/requirements.dev.txt
+++ b/requirements.dev.txt
@@ -1,5 +1,7 @@
+flake8==3.5.0
 Cython==0.28.1
 nose==1.3.7
+pycodestyle==2.3.1
 pylint==1.8.3
 pytest==3.4.2
 pytest-watch==4.1.0
diff --git a/sciencebeam_gym/conftest.py b/sciencebeam_gym/conftest.py
index b12470c..8102c5b 100644
--- a/sciencebeam_gym/conftest.py
+++ b/sciencebeam_gym/conftest.py
@@ -2,7 +2,9 @@ import logging
 
 import pytest
 
+
 @pytest.fixture(scope='session', autouse=True)
 def setup_logging():
-  logging.root.handlers = []
-  logging.basicConfig(level='DEBUG')
+    logging.root.handlers = []
+    logging.basicConfig(level='WARNING')
+    logging.getLogger('sciencebeam_gym').setLevel('DEBUG')
diff --git a/sciencebeam_gym/convert/conversion_pipeline.py b/sciencebeam_gym/convert/conversion_pipeline.py
index 3a47ef3..fc5ea0a 100644
--- a/sciencebeam_gym/convert/conversion_pipeline.py
+++ b/sciencebeam_gym/convert/conversion_pipeline.py
@@ -13,557 +13,584 @@ from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions
 from lxml import etree
 
 from sciencebeam_utils.beam_utils.utils import (
-  TransformAndCount,
-  TransformAndLog,
-  MapOrLog,
-  PreventFusion
+    TransformAndCount,
+    TransformAndLog,
+    MapOrLog,
+    PreventFusion
 )
 
 from sciencebeam_utils.beam_utils.files import (
-  ReadFileList,
-  FindFiles
+    ReadFileList,
+    FindFiles
 )
 
 from sciencebeam_utils.beam_utils.io import (
-  read_all_from_path,
-  save_file_content
+    read_all_from_path,
+    save_file_content
 )
 
 from sciencebeam_utils.beam_utils.main import (
-  add_cloud_args,
-  process_cloud_args,
-  process_sciencebeam_gym_dep_args
+    add_cloud_args,
+    process_cloud_args,
+    process_sciencebeam_gym_dep_args
 )
 
 from sciencebeam_utils.utils.collection import (
-  extend_dict,
-  remove_keys_from_dict
+    extend_dict,
+    remove_keys_from_dict
 )
 
 from sciencebeam_utils.utils.file_path import (
-  join_if_relative_path,
-  get_output_file
+    join_if_relative_path,
+    get_output_file
 )
 
 from sciencebeam_gym.structured_document.structured_document_loader import (
-  load_structured_document
+    load_structured_document
 )
 
 from sciencebeam_gym.structured_document.lxml import (
-  LxmlStructuredDocument
+    LxmlStructuredDocument
 )
 
 from sciencebeam_gym.preprocess.preprocessing_utils import (
-  convert_pdf_bytes_to_lxml,
-  parse_page_range,
-  save_pages,
-  pdf_bytes_to_png_pages
+    convert_pdf_bytes_to_lxml,
+    parse_page_range,
+    save_pages,
+    pdf_bytes_to_png_pages
 )
 
 from sciencebeam_gym.inference_model.extract_to_xml import (
-  extract_structured_document_to_xml
+    extract_structured_document_to_xml
 )
 
 from sciencebeam_gym.models.text.crf.annotate_using_predictions import (
-  predict_and_annotate_structured_document,
-  CRF_TAG_SCOPE
+    predict_and_annotate_structured_document,
+    CRF_TAG_SCOPE
 )
 
 from sciencebeam_gym.inference_model.annotate_using_predictions import (
-  annotate_structured_document_using_predicted_images,
-  AnnotatedImage,
-  CV_TAG_SCOPE
+    annotate_structured_document_using_predicted_images,
+    AnnotatedImage,
+    CV_TAG_SCOPE
 )
 
 from .grobid.grobid_xml_enhancer import (
-  GrobidXmlEnhancer
+    GrobidXmlEnhancer
 )
 
 from .cv_conversion_utils import (
-  InferenceModelWrapper,
-  image_data_to_png
+    InferenceModelWrapper,
+    image_data_to_png
 )
 
 from .grobid.grobid_service import (
-  grobid_service,
-  GrobidApiPaths
+    grobid_service,
+    GrobidApiPaths
 )
 
 
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 class MetricCounters(object):
-  FILES = 'files'
-  READ_LXML_ERROR = 'read_lxml_error_count'
-  CONVERT_PDF_TO_LXML_ERROR = 'ConvertPdfToLxml_error_count'
-  CONVERT_PDF_TO_PNG_ERROR = 'ConvertPdfToPng_error_count'
-  CONVERT_LXML_TO_SVG_ANNOT_ERROR = 'ConvertPdfToSvgAnnot_error_count'
-  CV_PREDICTION_ERROR = 'ComputerVisionPrediction_error_count'
-  ANNOTATE_USING_PREDICTION_ERROR = 'AnnotateLxmlUsingPrediction_error_count'
-  EXTRACT_TO_XML_ERROR = 'ExtractToXml_error_count'
-  GROBID_ERROR = 'Grobid_error_count'
+    FILES = 'files'
+    READ_LXML_ERROR = 'read_lxml_error_count'
+    CONVERT_PDF_TO_LXML_ERROR = 'ConvertPdfToLxml_error_count'
+    CONVERT_PDF_TO_PNG_ERROR = 'ConvertPdfToPng_error_count'
+    CONVERT_LXML_TO_SVG_ANNOT_ERROR = 'ConvertPdfToSvgAnnot_error_count'
+    CV_PREDICTION_ERROR = 'ComputerVisionPrediction_error_count'
+    ANNOTATE_USING_PREDICTION_ERROR = 'AnnotateLxmlUsingPrediction_error_count'
+    EXTRACT_TO_XML_ERROR = 'ExtractToXml_error_count'
+    GROBID_ERROR = 'Grobid_error_count'
+
 
 class OutputExt(object):
-  CRF_ANNOT_LXML = '.crf.lxml.gz'
-  CRF_CV_ANNOT_LXML = '.crf-cv.lxml.gz'
-  CV_ANNOT_LXML = '.cv.lxml.gz'
-  CV_PNG = '.cv-png.zip'
+    CRF_ANNOT_LXML = '.crf.lxml.gz'
+    CRF_CV_ANNOT_LXML = '.crf-cv.lxml.gz'
+    CV_ANNOT_LXML = '.cv.lxml.gz'
+    CV_PNG = '.cv-png.zip'
+
 
 class DataProps(object):
-  SOURCE_FILENAME = 'source_filename'
-  PDF_CONTENT = 'pdf_content'
-  STRUCTURED_DOCUMENT = 'structured_document'
-  PDF_PNG_PAGES = 'pdf_png_pages'
-  CV_PREDICTION_PNG_PAGES = 'cv_prediction_png_pages'
-  COLOR_MAP = 'color_map'
-  EXTRACTED_XML = 'extracted_xml'
+    SOURCE_FILENAME = 'source_filename'
+    PDF_CONTENT = 'pdf_content'
+    STRUCTURED_DOCUMENT = 'structured_document'
+    PDF_PNG_PAGES = 'pdf_png_pages'
+    CV_PREDICTION_PNG_PAGES = 'cv_prediction_png_pages'
+    COLOR_MAP = 'color_map'
+    EXTRACTED_XML = 'extracted_xml'
+
 
 def convert_pdf_bytes_to_structured_document(pdf_content, path=None, page_range=None):
-  return LxmlStructuredDocument(etree.parse(BytesIO(
-    convert_pdf_bytes_to_lxml(pdf_content, path=path, page_range=page_range)
-  )))
+    return LxmlStructuredDocument(etree.parse(BytesIO(
+        convert_pdf_bytes_to_lxml(pdf_content, path=path, page_range=page_range)
+    )))
+
 
 def annotate_structured_document_using_predicted_image_data(
-  structured_document, prediction_images, color_map, tag_scope=None):
+        structured_document, prediction_images, color_map, tag_scope=None):
+
+    return annotate_structured_document_using_predicted_images(
+        structured_document, (
+            AnnotatedImage(prediction_image, color_map)
+            for prediction_image in prediction_images
+        ), tag_scope=tag_scope
+    )
 
-  return annotate_structured_document_using_predicted_images(
-    structured_document, (
-      AnnotatedImage(prediction_image, color_map)
-      for prediction_image in prediction_images
-    ), tag_scope=tag_scope
-  )
 
 def extract_annotated_structured_document_to_xml(structured_document, tag_scope=None):
-  xml_root = extract_structured_document_to_xml(structured_document, tag_scope=tag_scope)
-  return etree.tostring(xml_root, pretty_print=True)
+    xml_root = extract_structured_document_to_xml(structured_document, tag_scope=tag_scope)
+    return etree.tostring(xml_root, pretty_print=True)
+
 
 def load_crf_model(path):
-  with FileSystems.open(path) as crf_model_f:
-    return pickle.load(crf_model_f)
+    with FileSystems.open(path) as crf_model_f:
+        return pickle.load(crf_model_f)
+
 
 def save_structured_document(filename, structured_document):
-  # only support saving lxml for now
-  assert isinstance(structured_document, LxmlStructuredDocument)
-  save_file_content(filename, etree.tostring(structured_document.root, pretty_print=True))
-  return filename
+    # only support saving lxml for now
+    assert isinstance(structured_document, LxmlStructuredDocument)
+    save_file_content(filename, etree.tostring(structured_document.root, pretty_print=True))
+    return filename
+
 
 def get_annot_lxml_ext(crf_enabled, cv_enabled):
-  if crf_enabled and cv_enabled:
-    return OutputExt.CRF_CV_ANNOT_LXML
-  if crf_enabled:
-    return OutputExt.CRF_ANNOT_LXML
-  if cv_enabled:
-    return OutputExt.CV_ANNOT_LXML
-  raise AssertionError('at least one of crf or cv need to be enabled')
+    if crf_enabled and cv_enabled:
+        return OutputExt.CRF_CV_ANNOT_LXML
+    if crf_enabled:
+        return OutputExt.CRF_ANNOT_LXML
+    if cv_enabled:
+        return OutputExt.CV_ANNOT_LXML
+    raise AssertionError('at least one of crf or cv need to be enabled')
+
 
 def PdfUrlSource(opt):
-  if opt.pdf_file_list:
-    return ReadFileList(opt.pdf_file_list, column=opt.pdf_file_column, limit=opt.limit)
-  else:
-    return FindFiles(join_if_relative_path(opt.base_data_path, opt.pdf_path))
+    if opt.pdf_file_list:
+        return ReadFileList(opt.pdf_file_list, column=opt.pdf_file_column, limit=opt.limit)
+    else:
+        return FindFiles(join_if_relative_path(opt.base_data_path, opt.pdf_path))
+
 
 def ReadPdfContent():
-  return "ReadPdfContent" >> TransformAndCount(
-    beam.Map(lambda pdf_url: {
-      DataProps.SOURCE_FILENAME: pdf_url,
-      DataProps.PDF_CONTENT: read_all_from_path(pdf_url)
-    }),
-    MetricCounters.FILES
-  )
+    return "ReadPdfContent" >> TransformAndCount(
+        beam.Map(lambda pdf_url: {
+            DataProps.SOURCE_FILENAME: pdf_url,
+            DataProps.PDF_CONTENT: read_all_from_path(pdf_url)
+        }),
+        MetricCounters.FILES
+    )
+
 
 def add_read_pdfs_to_annotated_lxml_pipeline_steps(p, opt, get_pipeline_output_file):
-  page_range = opt.pages
-
-  cv_enabled = opt.cv_model_export_dir
-
-  extract_tag_scope = None
-
-  pdf_urls = p | PdfUrlSource(opt)
-
-  lxml_content = (
-    pdf_urls |
-    PreventFusion() |
-
-    ReadPdfContent() |
-
-    "ConvertPdfToLxml" >> MapOrLog(lambda v: extend_dict(v, {
-      DataProps.STRUCTURED_DOCUMENT: convert_pdf_bytes_to_structured_document(
-        v[DataProps.PDF_CONTENT], path=v[DataProps.SOURCE_FILENAME],
-        page_range=page_range
-      )
-    }), log_fn=lambda e, v: (
-      get_logger().warning(
-        'caught exception (ignoring item): %s, pdf: %s',
-        e, v[DataProps.SOURCE_FILENAME], exc_info=e
-      )
-    ), error_count=MetricCounters.CONVERT_PDF_TO_LXML_ERROR)
-  )
-
-  if cv_enabled:
-    image_size = (
-      (opt.image_width, opt.image_height)
-      if opt.image_width and opt.image_height
-      else None
+    page_range = opt.pages
+
+    cv_enabled = opt.cv_model_export_dir
+
+    extract_tag_scope = None
+
+    pdf_urls = p | PdfUrlSource(opt)
+
+    lxml_content = (
+        pdf_urls |
+        PreventFusion() |
+
+        ReadPdfContent() |
+
+        "ConvertPdfToLxml" >> MapOrLog(lambda v: extend_dict(v, {
+            DataProps.STRUCTURED_DOCUMENT: convert_pdf_bytes_to_structured_document(
+                v[DataProps.PDF_CONTENT], path=v[DataProps.SOURCE_FILENAME],
+                page_range=page_range
+            )
+        }), log_fn=lambda e, v: (
+            get_logger().warning(
+                'caught exception (ignoring item): %s, pdf: %s',
+                e, v[DataProps.SOURCE_FILENAME], exc_info=e
+            )
+        ), error_count=MetricCounters.CONVERT_PDF_TO_LXML_ERROR)
     )
-    inference_model_wrapper = InferenceModelWrapper(opt.cv_model_export_dir)
-
-    cv_predictions = (
-      lxml_content |
-
-      "ConvertPdfToPng" >> MapOrLog(lambda v: remove_keys_from_dict(
-        extend_dict(v, {
-          DataProps.PDF_PNG_PAGES:  list(pdf_bytes_to_png_pages(
-            v[DataProps.PDF_CONTENT],
-            dpi=90, # not used if the image is scaled
-            image_size=image_size,
-            page_range=page_range
-          ))
-        }),
-        keys_to_remove={DataProps.PDF_CONTENT}
-      ), error_count=MetricCounters.CONVERT_PDF_TO_PNG_ERROR) |
 
-      "ComputerVisionPrediction" >> MapOrLog(lambda v: remove_keys_from_dict(
-        extend_dict(v, {
-          DataProps.CV_PREDICTION_PNG_PAGES: inference_model_wrapper(v[DataProps.PDF_PNG_PAGES]),
-          DataProps.COLOR_MAP: inference_model_wrapper.get_color_map()
-        }),
-        keys_to_remove={DataProps.PDF_PNG_PAGES}
-      ), error_count=MetricCounters.CV_PREDICTION_ERROR)
+    if cv_enabled:
+        image_size = (
+            (opt.image_width, opt.image_height)
+            if opt.image_width and opt.image_height
+            else None
+        )
+        inference_model_wrapper = InferenceModelWrapper(opt.cv_model_export_dir)
+
+        cv_predictions = (
+            lxml_content |
+
+            "ConvertPdfToPng" >> MapOrLog(lambda v: remove_keys_from_dict(
+                extend_dict(v, {
+                    DataProps.PDF_PNG_PAGES: list(pdf_bytes_to_png_pages(
+                        v[DataProps.PDF_CONTENT],
+                        dpi=90,  # not used if the image is scaled
+                        image_size=image_size,
+                        page_range=page_range
+                    ))
+                }),
+                keys_to_remove={DataProps.PDF_CONTENT}
+            ), error_count=MetricCounters.CONVERT_PDF_TO_PNG_ERROR) |
+
+            "ComputerVisionPrediction" >> MapOrLog(lambda v: remove_keys_from_dict(
+                extend_dict(v, {
+                    DataProps.CV_PREDICTION_PNG_PAGES: inference_model_wrapper(
+                        v[DataProps.PDF_PNG_PAGES]
+                    ),
+                    DataProps.COLOR_MAP: inference_model_wrapper.get_color_map()
+                }),
+                keys_to_remove={DataProps.PDF_PNG_PAGES}
+            ), error_count=MetricCounters.CV_PREDICTION_ERROR)
+        )
+
+        if opt.save_cv_output:
+            _ = (
+                cv_predictions |
+                "SaveComputerVisionOutput" >> TransformAndLog(
+                    beam.Map(lambda v: save_pages(
+                        get_pipeline_output_file(
+                            v[DataProps.SOURCE_FILENAME],
+                            OutputExt.CV_PNG
+                        ),
+                        '.png',
+                        [image_data_to_png(image_data)
+                            for image_data in v[DataProps.CV_PREDICTION_PNG_PAGES]]
+                    )),
+                    log_fn=lambda x: get_logger().info('saved cv output: %s', x)
+                )
+            )
+
+        cv_annotated_lxml = (
+            cv_predictions |
+            "AnnotateLxmlUsingCvPrediction" >> MapOrLog(lambda v: remove_keys_from_dict(
+                extend_dict(v, {
+                    DataProps.STRUCTURED_DOCUMENT: (
+                        annotate_structured_document_using_predicted_image_data(
+                            v[DataProps.STRUCTURED_DOCUMENT],
+                            v[DataProps.CV_PREDICTION_PNG_PAGES],
+                            v[DataProps.COLOR_MAP],
+                            tag_scope=CV_TAG_SCOPE
+                        )
+                    )
+                }),
+                keys_to_remove={DataProps.PDF_PNG_PAGES}
+            ), error_count=MetricCounters.ANNOTATE_USING_PREDICTION_ERROR)
+        )
+
+        lxml_content = cv_annotated_lxml
+        extract_tag_scope = CV_TAG_SCOPE
+
+    if opt.crf_model:
+        model = load_crf_model(opt.crf_model)
+        crf_annotated_lxml = (
+            lxml_content |
+            "AnnotateLxmlUsingCrfPrediction" >> MapOrLog(lambda v: extend_dict(v, {
+                DataProps.STRUCTURED_DOCUMENT: predict_and_annotate_structured_document(
+                    v[DataProps.STRUCTURED_DOCUMENT], model
+                )
+            }), error_count=MetricCounters.ANNOTATE_USING_PREDICTION_ERROR)
+        )
+
+        lxml_content = crf_annotated_lxml
+        extract_tag_scope = CRF_TAG_SCOPE
+
+    if opt.save_annot_lxml:
+        _ = (  # flake8: noqa
+            lxml_content |
+            "SaveAnnotLxml" >> TransformAndLog(
+                beam.Map(lambda v: save_structured_document(
+                    get_pipeline_output_file(
+                        v[DataProps.SOURCE_FILENAME],
+                        get_annot_lxml_ext(
+                            crf_enabled=opt.crf_model,
+                            cv_enabled=cv_enabled
+                        )
+                    ),
+                    v[DataProps.STRUCTURED_DOCUMENT]
+                )),
+                log_fn=lambda x: get_logger().info('saved annoted lxml to: %s', x)
+            )
+        )
+    return lxml_content, extract_tag_scope
+
+
+def add_read_pdfs_to_grobid_xml_pipeline_steps(p, opt):
+    grobid_transformer = grobid_service(
+        opt.grobid_url, opt.grobid_action, start_service=opt.start_grobid_service
+    )
+
+    return (
+        p |
+        PdfUrlSource(opt) |
+        PreventFusion() |
+        ReadPdfContent() |
+        "Grobid" >> MapOrLog(lambda v: extend_dict(v, {
+            DataProps.EXTRACTED_XML: grobid_transformer(
+                (v[DataProps.SOURCE_FILENAME], v[DataProps.PDF_CONTENT])
+            )[1]
+        }), error_count=MetricCounters.GROBID_ERROR)
     )
 
-    if opt.save_cv_output:
-      _ = (
-        cv_predictions |
-        "SaveComputerVisionOutput" >> TransformAndLog(
-          beam.Map(lambda v: save_pages(
-            get_pipeline_output_file(
-              v[DataProps.SOURCE_FILENAME],
-              OutputExt.CV_PNG
-            ),
-            '.png',
-            [image_data_to_png(image_data) for image_data in v[DataProps.CV_PREDICTION_PNG_PAGES]]
-          )),
-          log_fn=lambda x: get_logger().info('saved cv output: %s', x)
+
+def add_read_source_to_extracted_xml_pipeline_steps(p, opt, get_pipeline_output_file):
+    if opt.lxml_file_list:
+        lxml_urls = p | ReadFileList(
+            opt.lxml_file_list, column=opt.lxml_file_column, limit=opt.limit)
+
+        annotated_lxml = (
+            lxml_urls |
+            PreventFusion() |
+
+            "ReadLxmlContent" >> TransformAndCount(
+                MapOrLog(lambda url: {
+                    DataProps.SOURCE_FILENAME: url,
+                    DataProps.STRUCTURED_DOCUMENT: load_structured_document(url)
+                }, error_count=MetricCounters.READ_LXML_ERROR),
+                MetricCounters.FILES
+            )
         )
-      )
-
-    cv_annotated_lxml = (
-      cv_predictions |
-      "AnnotateLxmlUsingCvPrediction" >> MapOrLog(lambda v: remove_keys_from_dict(
-        extend_dict(v, {
-          DataProps.STRUCTURED_DOCUMENT: annotate_structured_document_using_predicted_image_data(
-            v[DataProps.STRUCTURED_DOCUMENT],
-            v[DataProps.CV_PREDICTION_PNG_PAGES],
-            v[DataProps.COLOR_MAP],
-            tag_scope=CV_TAG_SCOPE
-          )
-        }),
-        keys_to_remove={DataProps.PDF_PNG_PAGES}
-      ), error_count=MetricCounters.ANNOTATE_USING_PREDICTION_ERROR)
+
+        extract_tag_scope = None
+    else:
+        annotated_lxml, extract_tag_scope = add_read_pdfs_to_annotated_lxml_pipeline_steps(
+            p, opt, get_pipeline_output_file
+        )
+
+    extracted_xml = (
+        annotated_lxml |
+        "ExtractToXml" >> MapOrLog(lambda v: remove_keys_from_dict(
+            extend_dict(v, {
+                DataProps.EXTRACTED_XML: extract_annotated_structured_document_to_xml(
+                    v[DataProps.STRUCTURED_DOCUMENT],
+                    tag_scope=extract_tag_scope
+                )
+            }),
+            keys_to_remove={DataProps.STRUCTURED_DOCUMENT}
+        ), error_count=MetricCounters.EXTRACT_TO_XML_ERROR)
     )
 
-    lxml_content = cv_annotated_lxml
-    extract_tag_scope = CV_TAG_SCOPE
+    if opt.use_grobid:
+        enhancer = GrobidXmlEnhancer(
+            opt.grobid_url, start_service=opt.start_grobid_service
+        )
+        extracted_xml = (
+            extracted_xml |
+            "GrobidEnhanceXml" >> MapOrLog(lambda v: extend_dict(v, {
+                DataProps.EXTRACTED_XML: enhancer(
+                    v[DataProps.EXTRACTED_XML]
+                )
+            }), error_count=MetricCounters.GROBID_ERROR)
+        )
+    return extracted_xml
+
 
-  if opt.crf_model:
-    model = load_crf_model(opt.crf_model)
-    crf_annotated_lxml = (
-      lxml_content |
-      "AnnotateLxmlUsingCrfPrediction" >> MapOrLog(lambda v: extend_dict(v, {
-        DataProps.STRUCTURED_DOCUMENT: predict_and_annotate_structured_document(
-          v[DataProps.STRUCTURED_DOCUMENT], model
+def configure_pipeline(p, opt):
+    def get_pipeline_output_file(source_url, ext):
+        return get_output_file(
+            source_url,
+            opt.base_data_path,
+            opt.output_path,
+            ext
+        )
+
+    if (
+        opt.use_grobid and not opt.crf_model and
+        not opt.cv_model_export_dir and not opt.lxml_file_list
+    ):
+        extracted_xml = add_read_pdfs_to_grobid_xml_pipeline_steps(p, opt)
+    else:
+        extracted_xml = add_read_source_to_extracted_xml_pipeline_steps(
+            p, opt, get_pipeline_output_file
+        )
+
+    _ = (  # flake8: noqa
+        extracted_xml |
+        "WriteXml" >> TransformAndLog(
+            beam.Map(lambda v: save_file_content(
+                get_pipeline_output_file(
+                    v[DataProps.SOURCE_FILENAME],
+                    opt.output_suffix
+                ),
+                v[DataProps.EXTRACTED_XML]
+            )),
+            log_fn=lambda x: get_logger().info('saved xml to: %s', x)
         )
-      }), error_count=MetricCounters.ANNOTATE_USING_PREDICTION_ERROR)
     )
 
-    lxml_content = crf_annotated_lxml
-    extract_tag_scope = CRF_TAG_SCOPE
-
-  if opt.save_annot_lxml:
-    _ = (
-      lxml_content |
-      "SaveAnnotLxml" >> TransformAndLog(
-        beam.Map(lambda v: save_structured_document(
-          get_pipeline_output_file(
-            v[DataProps.SOURCE_FILENAME],
-            get_annot_lxml_ext(
-              crf_enabled=opt.crf_model,
-              cv_enabled=cv_enabled
-            )
-          ),
-          v[DataProps.STRUCTURED_DOCUMENT]
-        )),
-        log_fn=lambda x: get_logger().info('saved annoted lxml to: %s', x)
-      )
+
+def add_main_args(parser):
+    parser.add_argument(
+        '--data-path', type=str, required=True,
+        help='base data path'
     )
-  return lxml_content, extract_tag_scope
 
-def add_read_pdfs_to_grobid_xml_pipeline_steps(p, opt):
-  grobid_transformer = grobid_service(
-    opt.grobid_url, opt.grobid_action, start_service=opt.start_grobid_service
-  )
-
-  return (
-    p |
-    PdfUrlSource(opt) |
-    PreventFusion() |
-    ReadPdfContent() |
-    "Grobid" >> MapOrLog(lambda v: extend_dict(v, {
-      DataProps.EXTRACTED_XML: grobid_transformer(
-        (v[DataProps.SOURCE_FILENAME], v[DataProps.PDF_CONTENT])
-      )[1]
-    }), error_count=MetricCounters.GROBID_ERROR)
-  )
+    source_group = parser.add_argument_group('source')
+    source_one_of_group = source_group.add_mutually_exclusive_group(required=True)
+    source_one_of_group.add_argument(
+        '--pdf-path', type=str, required=False,
+        help='path to pdf file(s), relative to data-path'
+    )
+    source_one_of_group.add_argument(
+        '--pdf-file-list', type=str, required=False,
+        help='path to pdf csv/tsv file list'
+    )
+    source_group.add_argument(
+        '--pdf-file-column', type=str, required=False, default='pdf_url',
+        help='the column of the pdf file list to use'
+    )
+    source_one_of_group.add_argument(
+        '--lxml-file-list', type=str, required=False,
+        help='path to annotated lxml or svg pages zip file list'
+        '; (CRF and CV models are not supported in this mode)'
+    )
+    source_group.add_argument(
+        '--lxml-file-column', type=str, required=False, default='url',
+        help='the column of the lxml file list to use'
+    )
 
-def add_read_source_to_extracted_xml_pipeline_steps(p, opt, get_pipeline_output_file):
-  if opt.lxml_file_list:
-    lxml_urls = p | ReadFileList(opt.lxml_file_list, column=opt.lxml_file_column, limit=opt.limit)
-
-    annotated_lxml = (
-      lxml_urls |
-      PreventFusion() |
-
-      "ReadLxmlContent" >> TransformAndCount(
-        MapOrLog(lambda url: {
-          DataProps.SOURCE_FILENAME: url,
-          DataProps.STRUCTURED_DOCUMENT: load_structured_document(url)
-        }, error_count=MetricCounters.READ_LXML_ERROR),
-        MetricCounters.FILES
-      )
+    parser.add_argument(
+        '--limit', type=int, required=False,
+        help='limit the number of file pairs to process'
     )
 
-    extract_tag_scope = None
-  else:
-    annotated_lxml, extract_tag_scope = add_read_pdfs_to_annotated_lxml_pipeline_steps(
-      p, opt, get_pipeline_output_file
+    output_group = parser.add_argument_group('output')
+    output_group.add_argument(
+        '--output-path', required=False,
+        help='Output directory to write results to.'
+    )
+    output_group.add_argument(
+        '--output-suffix', required=False, default='.crf.xml',
+        help='Output file suffix to add to the filename (excluding the file extension).'
     )
 
-  extracted_xml = (
-    annotated_lxml |
-    "ExtractToXml" >> MapOrLog(lambda v: remove_keys_from_dict(
-      extend_dict(v, {
-        DataProps.EXTRACTED_XML: extract_annotated_structured_document_to_xml(
-          v[DataProps.STRUCTURED_DOCUMENT],
-          tag_scope=extract_tag_scope
-        )
-      }),
-      keys_to_remove={DataProps.STRUCTURED_DOCUMENT}
-    ), error_count=MetricCounters.EXTRACT_TO_XML_ERROR)
-  )
-
-  if opt.use_grobid:
-    enhancer = GrobidXmlEnhancer(
-      opt.grobid_url, start_service=opt.start_grobid_service
+    parser.add_argument(
+        '--save-annot-lxml', action='store_true', default=False,
+        help='enable saving of annotated lxml'
     )
-    extracted_xml = (
-      extracted_xml |
-      "GrobidEnhanceXml" >> MapOrLog(lambda v: extend_dict(v, {
-        DataProps.EXTRACTED_XML: enhancer(
-          v[DataProps.EXTRACTED_XML]
-        )
-      }), error_count=MetricCounters.GROBID_ERROR)
+
+    grobid_group = parser.add_argument_group('Grobid')
+    grobid_group.add_argument(
+        '--use-grobid', action='store_true', default=False,
+        help='enable the use of grobid'
+    )
+    grobid_group.add_argument(
+        '--grobid-url', required=False, default=None,
+        help='Base URL to the Grobid service'
+    )
+    parser.add_argument(
+        '--grobid-action', required=False,
+        default=GrobidApiPaths.PROCESS_HEADER_DOCUMENT,
+        help='Name of the Grobid action (if Grobid is used without CRF or CV model)'
     )
-  return extracted_xml
 
-def configure_pipeline(p, opt):
-  get_pipeline_output_file = lambda source_url, ext: get_output_file(
-    source_url,
-    opt.base_data_path,
-    opt.output_path,
-    ext
-  )
-
-  if (
-    opt.use_grobid and not opt.crf_model and
-    not opt.cv_model_export_dir and not opt.lxml_file_list
-  ):
-    extracted_xml = add_read_pdfs_to_grobid_xml_pipeline_steps(p, opt)
-  else:
-    extracted_xml = add_read_source_to_extracted_xml_pipeline_steps(
-      p, opt, get_pipeline_output_file
+    parser.add_argument(
+        '--debug', action='store_true', default=False,
+        help='enable debug output'
     )
 
-  _ = (
-    extracted_xml |
-    "WriteXml" >> TransformAndLog(
-      beam.Map(lambda v: save_file_content(
-        get_pipeline_output_file(
-          v[DataProps.SOURCE_FILENAME],
-          opt.output_suffix
-        ),
-        v[DataProps.EXTRACTED_XML]
-      )),
-      log_fn=lambda x: get_logger().info('saved xml to: %s', x)
+    parser.add_argument(
+        '--pages', type=parse_page_range, default=None,
+        help='only processes the selected pages'
     )
-  )
 
+    crf_group = parser.add_argument_group('CRF')
+    crf_group.add_argument(
+        '--crf-model', type=str, required=False,
+        help='path to saved crf model'
+    )
+
+    cv_group = parser.add_argument_group('CV')
+    cv_group.add_argument(
+        '--cv-model-export-dir', type=str, required=False,
+        help='path to cv model export dir'
+    )
+    cv_group.add_argument(
+        '--image-width', type=int, required=False,
+        default=256,
+        help='image width of resulting PNGs'
+    )
+    cv_group.add_argument(
+        '--image-height', type=int, required=False,
+        default=256,
+        help='image height of resulting PNGs'
+    )
+    cv_group.add_argument(
+        '--save-cv-output', action='store_true', default=False,
+        help='enable saving of computer vision output (png pages)'
+    )
 
-def add_main_args(parser):
-  parser.add_argument(
-    '--data-path', type=str, required=True,
-    help='base data path'
-  )
-
-  source_group = parser.add_argument_group('source')
-  source_one_of_group = source_group.add_mutually_exclusive_group(required=True)
-  source_one_of_group.add_argument(
-    '--pdf-path', type=str, required=False,
-    help='path to pdf file(s), relative to data-path'
-  )
-  source_one_of_group.add_argument(
-    '--pdf-file-list', type=str, required=False,
-    help='path to pdf csv/tsv file list'
-  )
-  source_group.add_argument(
-    '--pdf-file-column', type=str, required=False, default='pdf_url',
-    help='the column of the pdf file list to use'
-  )
-  source_one_of_group.add_argument(
-    '--lxml-file-list', type=str, required=False,
-    help='path to annotated lxml or svg pages zip file list'
-    '; (CRF and CV models are not supported in this mode)'
-  )
-  source_group.add_argument(
-    '--lxml-file-column', type=str, required=False, default='url',
-    help='the column of the lxml file list to use'
-  )
-
-  parser.add_argument(
-    '--limit', type=int, required=False,
-    help='limit the number of file pairs to process'
-  )
-
-  output_group = parser.add_argument_group('output')
-  output_group.add_argument(
-    '--output-path', required=False,
-    help='Output directory to write results to.'
-  )
-  output_group.add_argument(
-    '--output-suffix', required=False, default='.crf.xml',
-    help='Output file suffix to add to the filename (excluding the file extension).'
-  )
-
-  parser.add_argument(
-    '--save-annot-lxml', action='store_true', default=False,
-    help='enable saving of annotated lxml'
-  )
-
-  grobid_group = parser.add_argument_group('Grobid')
-  grobid_group.add_argument(
-    '--use-grobid', action='store_true', default=False,
-    help='enable the use of grobid'
-  )
-  grobid_group.add_argument(
-    '--grobid-url', required=False, default=None,
-    help='Base URL to the Grobid service'
-  )
-  parser.add_argument(
-    '--grobid-action', required=False,
-    default=GrobidApiPaths.PROCESS_HEADER_DOCUMENT,
-    help='Name of the Grobid action (if Grobid is used without CRF or CV model)'
-  )
-
-  parser.add_argument(
-    '--debug', action='store_true', default=False,
-    help='enable debug output'
-  )
-
-  parser.add_argument(
-    '--pages', type=parse_page_range, default=None,
-    help='only processes the selected pages'
-  )
-
-  crf_group = parser.add_argument_group('CRF')
-  crf_group.add_argument(
-    '--crf-model', type=str, required=False,
-    help='path to saved crf model'
-  )
-
-  cv_group = parser.add_argument_group('CV')
-  cv_group.add_argument(
-    '--cv-model-export-dir', type=str, required=False,
-    help='path to cv model export dir'
-  )
-  cv_group.add_argument(
-    '--image-width', type=int, required=False,
-    default=256,
-    help='image width of resulting PNGs'
-  )
-  cv_group.add_argument(
-    '--image-height', type=int, required=False,
-    default=256,
-    help='image height of resulting PNGs'
-  )
-  cv_group.add_argument(
-    '--save-cv-output', action='store_true', default=False,
-    help='enable saving of computer vision output (png pages)'
-  )
 
 def process_main_args(args, parser):
-  args.base_data_path = args.data_path.replace('/*/', '/')
+    args.base_data_path = args.data_path.replace('/*/', '/')
 
-  if not args.output_path:
-    args.output_path = os.path.join(
-      os.path.dirname(args.base_data_path),
-      os.path.basename(args.base_data_path + '-results')
-    )
+    if not args.output_path:
+        args.output_path = os.path.join(
+            os.path.dirname(args.base_data_path),
+            os.path.basename(args.base_data_path + '-results')
+        )
+
+    if args.lxml_file_list:
+        if args.crf_model:
+            parser.error('--crf-model cannot be used in conjunction with --lxml-file-list')
+
+        if args.cv_model_export_dir:
+            parser.error(
+                '--cv-model-export-dir cannot be used in conjunction with --lxml-file-list'
+            )
+    else:
+        if not args.crf_model and not args.cv_model_export_dir and not args.use_grobid:
+            parser.error(
+                '--crf-model, --cv-model-export-dir or --use-grobid required in conjunction'
+                ' with --pdf-file-list or --pdf-path'
+            )
+
+    if args.use_grobid and not args.grobid_url:
+        args.grobid_url = 'http://localhost:8080/api'
+        args.start_grobid_service = True
+    else:
+        args.start_grobid_service = False
 
-  if args.lxml_file_list:
-    if args.crf_model:
-      parser.error('--crf-model cannot be used in conjunction with --lxml-file-list')
-
-    if args.cv_model_export_dir:
-      parser.error('--crf-model-export-dir cannot be used in conjunction with --lxml-file-list')
-  else:
-    if not args.crf_model and not args.cv_model_export_dir and not args.use_grobid:
-      parser.error(
-        '--crf-model, --cv-model-export-dir or --use-grobid required in conjunction'
-        ' with --pdf-file-list or --pdf-path'
-      )
-
-  if args.use_grobid and not args.grobid_url:
-    args.grobid_url = 'http://localhost:8080/api'
-    args.start_grobid_service = True
-  else:
-    args.start_grobid_service = False
 
 def parse_args(argv=None):
-  parser = argparse.ArgumentParser()
-  add_main_args(parser)
-  add_cloud_args(parser)
+    parser = argparse.ArgumentParser()
+    add_main_args(parser)
+    add_cloud_args(parser)
 
-  args = parser.parse_args(argv)
+    args = parser.parse_args(argv)
 
-  if args.debug:
-    logging.getLogger().setLevel('DEBUG')
+    if args.debug:
+        logging.getLogger().setLevel('DEBUG')
+
+    process_main_args(args, parser)
+    process_cloud_args(
+        args, args.output_path,
+        name='sciencebeam-convert'
+    )
+    process_sciencebeam_gym_dep_args(args)
 
-  process_main_args(args, parser)
-  process_cloud_args(
-    args, args.output_path,
-    name='sciencebeam-convert'
-  )
-  process_sciencebeam_gym_dep_args(args)
+    get_logger().info('args: %s', args)
 
-  get_logger().info('args: %s', args)
+    return args
 
-  return args
 
 def run(argv=None):
-  args = parse_args(argv)
+    args = parse_args(argv)
 
-  # We use the save_main_session option because one or more DoFn's in this
-  # workflow rely on global context (e.g., a module imported at module level).
-  pipeline_options = PipelineOptions.from_dictionary(vars(args))
-  pipeline_options.view_as(SetupOptions).save_main_session = True
+    # We use the save_main_session option because one or more DoFn's in this
+    # workflow rely on global context (e.g., a module imported at module level).
+    pipeline_options = PipelineOptions.from_dictionary(vars(args))
+    pipeline_options.view_as(SetupOptions).save_main_session = True
 
-  with beam.Pipeline(args.runner, options=pipeline_options) as p:
-    configure_pipeline(p, args)
+    with beam.Pipeline(args.runner, options=pipeline_options) as p:
+        configure_pipeline(p, args)
 
-    # Execute the pipeline and wait until it is completed.
+        # Execute the pipeline and wait until it is completed.
 
 
 if __name__ == '__main__':
-  logging.basicConfig(level='INFO')
+    logging.basicConfig(level='INFO')
 
-  run()
+    run()
diff --git a/sciencebeam_gym/convert/conversion_pipeline_test.py b/sciencebeam_gym/convert/conversion_pipeline_test.py
index 4bf728d..9247cfd 100644
--- a/sciencebeam_gym/convert/conversion_pipeline_test.py
+++ b/sciencebeam_gym/convert/conversion_pipeline_test.py
@@ -6,18 +6,18 @@ import pytest
 import apache_beam as beam
 
 from sciencebeam_utils.beam_utils.testing import (
-  BeamTest,
-  TestPipeline
+    BeamTest,
+    TestPipeline
 )
 
 from . import conversion_pipeline as conversion_pipeline
 from .conversion_pipeline import (
-  get_annot_lxml_ext,
-  configure_pipeline,
-  parse_args,
-  OutputExt,
-  CV_TAG_SCOPE,
-  CRF_TAG_SCOPE
+    get_annot_lxml_ext,
+    configure_pipeline,
+    parse_args,
+    OutputExt,
+    CV_TAG_SCOPE,
+    CRF_TAG_SCOPE
 )
 
 
@@ -41,439 +41,449 @@ PDF_CONTENT_1 = b'pdf content'
 LXML_CONTENT_1 = b'<LXML>lxml content</LXML>'
 TEI_XML_CONTENT_1 = b'<TEI>tei content</TEI>'
 
-fake_pdf_png_page = lambda i=0: 'fake pdf png page: %d' % i
+
+def fake_pdf_png_page(i=0):
+    return 'fake pdf png page: %d' % i
+
 
 MIN_ARGV = [
-  '--data-path=' + BASE_DATA_PATH,
-  '--pdf-path=' + PDF_PATH,
-  '--crf-model=' + MODEL_EXPORT_DIR
+    '--data-path=' + BASE_DATA_PATH,
+    '--pdf-path=' + PDF_PATH,
+    '--crf-model=' + MODEL_EXPORT_DIR
 ]
 
+
 def setup_module():
-  logging.basicConfig(level='DEBUG')
+    logging.basicConfig(level='DEBUG')
+
 
 def get_default_args():
-  return parse_args(MIN_ARGV)
+    return parse_args(MIN_ARGV)
+
 
 def patch_conversion_pipeline(**kwargs):
-  always_mock = {
-    'read_all_from_path',
-    'load_structured_document',
-    'convert_pdf_bytes_to_lxml',
-    'convert_pdf_bytes_to_structured_document',
-    'load_crf_model',
-    'ReadFileList',
-    'FindFiles',
-    'predict_and_annotate_structured_document',
-    'extract_annotated_structured_document_to_xml',
-    'save_structured_document',
-    'save_file_content',
-    'GrobidXmlEnhancer',
-    'grobid_service',
-    'pdf_bytes_to_png_pages',
-    'InferenceModelWrapper',
-    'annotate_structured_document_using_predicted_image_data'
-  }
-
-  return patch.multiple(
-    conversion_pipeline,
-    **{
-      k: kwargs.get(k, DEFAULT)
-      for k in always_mock
+    always_mock = {
+        'read_all_from_path',
+        'load_structured_document',
+        'convert_pdf_bytes_to_lxml',
+        'convert_pdf_bytes_to_structured_document',
+        'load_crf_model',
+        'ReadFileList',
+        'FindFiles',
+        'predict_and_annotate_structured_document',
+        'extract_annotated_structured_document_to_xml',
+        'save_structured_document',
+        'save_file_content',
+        'GrobidXmlEnhancer',
+        'grobid_service',
+        'pdf_bytes_to_png_pages',
+        'InferenceModelWrapper',
+        'annotate_structured_document_using_predicted_image_data'
     }
-  )
 
-def _setup_mocks_for_pages(mocks, page_no_list, file_count=1):
-  mocks['pdf_bytes_to_png_pages'].return_value = [
-    fake_pdf_png_page(i) for i in page_no_list
-  ]
+    return patch.multiple(
+        conversion_pipeline,
+        **{
+            k: kwargs.get(k, DEFAULT)
+            for k in always_mock
+        }
+    )
+
+
+def _setup_mocks_for_pages(mocks, page_no_list):
+    mocks['pdf_bytes_to_png_pages'].return_value = [
+        fake_pdf_png_page(i) for i in page_no_list
+    ]
+
 
 class TestGetAnnotLxmlExt(object):
-  def test_should_return_crf_cv_annot_lxml(self):
-    assert get_annot_lxml_ext(crf_enabled=True, cv_enabled=True) == OutputExt.CRF_CV_ANNOT_LXML
+    def test_should_return_crf_cv_annot_lxml(self):
+        assert get_annot_lxml_ext(crf_enabled=True, cv_enabled=True) == OutputExt.CRF_CV_ANNOT_LXML
 
-  def test_should_return_crf_annot_lxml(self):
-    assert get_annot_lxml_ext(crf_enabled=True, cv_enabled=False) == OutputExt.CRF_ANNOT_LXML
+    def test_should_return_crf_annot_lxml(self):
+        assert get_annot_lxml_ext(crf_enabled=True, cv_enabled=False) == OutputExt.CRF_ANNOT_LXML
 
-  def test_should_return_cv_annot_lxml(self):
-    assert get_annot_lxml_ext(crf_enabled=False, cv_enabled=True) == OutputExt.CV_ANNOT_LXML
+    def test_should_return_cv_annot_lxml(self):
+        assert get_annot_lxml_ext(crf_enabled=False, cv_enabled=True) == OutputExt.CV_ANNOT_LXML
+
+    def test_should_raise_error_if_neither_crf_or_cv_are_enabled(self):
+        with pytest.raises(AssertionError):
+            get_annot_lxml_ext(crf_enabled=False, cv_enabled=False)
 
-  def test_should_raise_error_if_neither_crf_or_cv_are_enabled(self):
-    with pytest.raises(AssertionError):
-      get_annot_lxml_ext(crf_enabled=False, cv_enabled=False)
 
 @pytest.mark.slow
 class TestConfigurePipeline(BeamTest):
-  def test_should_pass_pdf_pattern_to_find_files_and_read_pdf_file(self):
-    with patch_conversion_pipeline() as mocks:
-      opt = get_default_args()
-      opt.base_data_path = BASE_DATA_PATH
-      opt.pdf_path = PDF_PATH
-      opt.pdf_file_list = None
-      with TestPipeline() as p:
-        mocks['FindFiles'].return_value = beam.Create([PDF_FILE_1])
-        configure_pipeline(p, opt)
-
-      mocks['FindFiles'].assert_called_with(
-        BASE_DATA_PATH + '/' + PDF_PATH
-      )
-      mocks['read_all_from_path'].assert_called_with(
-        PDF_FILE_1
-      )
-
-  def test_should_pass_pdf_file_list_and_limit_to_read_file_list_and_read_pdf_file(self):
-    with patch_conversion_pipeline() as mocks:
-      opt = get_default_args()
-      opt.base_data_path = BASE_DATA_PATH
-      opt.pdf_path = None
-      opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv'
-      opt.limit = 100
-      with TestPipeline() as p:
-        mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1])
-        configure_pipeline(p, opt)
-
-      mocks['ReadFileList'].assert_called_with(
-        opt.pdf_file_list, column='pdf_url', limit=opt.limit
-      )
-      mocks['read_all_from_path'].assert_called_with(
-        PDF_FILE_1
-      )
-
-  def test_should_pass_around_values_with_default_pipeline(self):
-    with patch_conversion_pipeline() as mocks:
-      opt = get_default_args()
-      opt.base_data_path = BASE_DATA_PATH
-      opt.pdf_path = None
-      opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv'
-      opt.output_path = OUTPUT_PATH
-      opt.output_suffix = OUTPUT_SUFFIX
-      with TestPipeline() as p:
-        mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1])
-        mocks['read_all_from_path'].return_value = PDF_CONTENT_1
-        configure_pipeline(p, opt)
-
-      mocks['convert_pdf_bytes_to_structured_document'].assert_called_with(
-        PDF_CONTENT_1, page_range=None, path=PDF_FILE_1
-      )
-      mocks['predict_and_annotate_structured_document'].assert_called_with(
-        mocks['convert_pdf_bytes_to_structured_document'].return_value,
-        mocks['load_crf_model'].return_value
-      )
-      mocks['extract_annotated_structured_document_to_xml'].assert_called_with(
-        mocks['predict_and_annotate_structured_document'].return_value,
-        tag_scope=CRF_TAG_SCOPE
-      )
-      mocks['save_file_content'].assert_called_with(
-        OUTPUT_XML_FILE_1,
-        mocks['extract_annotated_structured_document_to_xml'].return_value
-      )
-
-  def test_should_save_annotated_lxml_if_enabled(self):
-    with patch_conversion_pipeline() as mocks:
-      opt = get_default_args()
-      opt.base_data_path = BASE_DATA_PATH
-      opt.pdf_path = None
-      opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv'
-      opt.output_path = OUTPUT_PATH
-      opt.output_suffix = OUTPUT_SUFFIX
-      opt.save_annot_lxml = True
-      with TestPipeline() as p:
-        mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1])
-        configure_pipeline(p, opt)
-
-      mocks['save_structured_document'].assert_called_with(
-        OUTPUT_PATH + '/' + REL_PDF_FILE_WITHOUT_EXT_1 + OutputExt.CRF_ANNOT_LXML,
-        mocks['predict_and_annotate_structured_document'].return_value
-      )
-
-  def test_should_use_lxml_file_list_if_provided_and_load_structured_documents(self):
-    with patch_conversion_pipeline() as mocks:
-      opt = get_default_args()
-      opt.base_data_path = BASE_DATA_PATH
-      opt.pdf_path = None
-      opt.pdf_file_list = None
-      opt.lxml_file_list = BASE_DATA_PATH + '/file-list.tsv'
-      opt.output_path = OUTPUT_PATH
-      opt.output_suffix = OUTPUT_SUFFIX
-      with TestPipeline() as p:
-        mocks['ReadFileList'].return_value = beam.Create([LXML_FILE_1])
-        configure_pipeline(p, opt)
-
-      mocks['extract_annotated_structured_document_to_xml'].assert_called_with(
-        mocks['load_structured_document'].return_value,
-        tag_scope=None
-      )
-      mocks['save_file_content'].assert_called_with(
-        OUTPUT_XML_FILE_1,
-        mocks['extract_annotated_structured_document_to_xml'].return_value
-      )
-
-  def test_should_use_crf_model_with_cv_model_if_enabled(self):
-    with patch_conversion_pipeline() as mocks:
-      inference_model_wrapper = mocks['InferenceModelWrapper'].return_value
-      opt = get_default_args()
-      opt.base_data_path = BASE_DATA_PATH
-      opt.pdf_path = None
-      opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv'
-      opt.output_path = OUTPUT_PATH
-      opt.output_suffix = OUTPUT_SUFFIX
-      opt.cv_model_export_dir = CV_MODEL_EXPORT_DIR
-      with TestPipeline() as p:
-        mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1])
-        mocks['read_all_from_path'].return_value = PDF_CONTENT_1
-        _setup_mocks_for_pages(mocks, [1, 2])
-        configure_pipeline(p, opt)
-
-      mocks['convert_pdf_bytes_to_structured_document'].assert_called_with(
-        PDF_CONTENT_1, page_range=None, path=PDF_FILE_1
-      )
-
-      # cv model
-      inference_model_wrapper.assert_called_with(
-        [fake_pdf_png_page(i) for i in [1, 2]]
-      )
-      mocks['annotate_structured_document_using_predicted_image_data'].assert_called_with(
-        mocks['convert_pdf_bytes_to_structured_document'].return_value,
-        inference_model_wrapper.return_value,
-        inference_model_wrapper.get_color_map.return_value,
-        tag_scope=CV_TAG_SCOPE
-      )
-
-      # crf model should receive output from cv model
-      mocks['predict_and_annotate_structured_document'].assert_called_with(
-        mocks['annotate_structured_document_using_predicted_image_data'].return_value,
-        mocks['load_crf_model'].return_value
-      )
-      mocks['extract_annotated_structured_document_to_xml'].assert_called_with(
-        mocks['predict_and_annotate_structured_document'].return_value,
-        tag_scope=CRF_TAG_SCOPE
-      )
-
-  def test_should_use_cv_model_only_if_enabled(self):
-    with patch_conversion_pipeline() as mocks:
-      inference_model_wrapper = mocks['InferenceModelWrapper'].return_value
-      opt = get_default_args()
-      opt.base_data_path = BASE_DATA_PATH
-      opt.pdf_path = None
-      opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv'
-      opt.output_path = OUTPUT_PATH
-      opt.output_suffix = OUTPUT_SUFFIX
-      opt.crf_model = None
-      opt.cv_model_export_dir = CV_MODEL_EXPORT_DIR
-      with TestPipeline() as p:
-        mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1])
-        mocks['read_all_from_path'].return_value = PDF_CONTENT_1
-        _setup_mocks_for_pages(mocks, [1, 2])
-        configure_pipeline(p, opt)
-
-      mocks['convert_pdf_bytes_to_structured_document'].assert_called_with(
-        PDF_CONTENT_1, page_range=None, path=PDF_FILE_1
-      )
-
-      # cv model
-      inference_model_wrapper.assert_called_with(
-        [fake_pdf_png_page(i) for i in [1, 2]]
-      )
-      mocks['annotate_structured_document_using_predicted_image_data'].assert_called_with(
-        mocks['convert_pdf_bytes_to_structured_document'].return_value,
-        inference_model_wrapper.return_value,
-        inference_model_wrapper.get_color_map.return_value,
-        tag_scope=CV_TAG_SCOPE
-      )
-      mocks['extract_annotated_structured_document_to_xml'].assert_called_with(
-        mocks['annotate_structured_document_using_predicted_image_data'].return_value,
-        tag_scope=CV_TAG_SCOPE
-      )
-
-      # crf model not be called
-      mocks['predict_and_annotate_structured_document'].assert_not_called()
-
-  def test_should_use_grobid_if_enabled(self):
-    with patch_conversion_pipeline() as mocks:
-      grobid_xml_enhancer = mocks['GrobidXmlEnhancer'].return_value
-      opt = get_default_args()
-      opt.base_data_path = BASE_DATA_PATH
-      opt.pdf_path = None
-      opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv'
-      opt.output_path = OUTPUT_PATH
-      opt.output_suffix = OUTPUT_SUFFIX
-      opt.use_grobid = True
-      opt.grobid_url = 'http://test/api'
-      with TestPipeline() as p:
-        mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1])
-        mocks['read_all_from_path'].return_value = PDF_CONTENT_1
-        mocks['convert_pdf_bytes_to_lxml'].return_value = LXML_CONTENT_1
-        configure_pipeline(p, opt)
-
-      mocks['GrobidXmlEnhancer'].assert_called_with(
-        opt.grobid_url,
-        start_service=opt.start_grobid_service
-      )
-      grobid_xml_enhancer.assert_called_with(
-        mocks['extract_annotated_structured_document_to_xml'].return_value
-      )
-      mocks['save_file_content'].assert_called_with(
-        OUTPUT_XML_FILE_1,
-        grobid_xml_enhancer.return_value
-      )
-
-  def test_should_use_grobid_with_lxml_file_list_if_enabled(self):
-    with patch_conversion_pipeline() as mocks:
-      grobid_xml_enhancer = mocks['GrobidXmlEnhancer'].return_value
-      opt = get_default_args()
-      opt.base_data_path = BASE_DATA_PATH
-      opt.pdf_path = None
-      opt.pdf_file_list = None
-      opt.lxml_file_list = BASE_DATA_PATH + '/file-list.tsv'
-      opt.output_path = OUTPUT_PATH
-      opt.output_suffix = OUTPUT_SUFFIX
-      opt.crf_model = None
-      opt.cv_model_export_dir = None
-      opt.use_grobid = True
-      opt.grobid_url = 'http://test/api'
-      with TestPipeline() as p:
-        mocks['ReadFileList'].return_value = beam.Create([LXML_FILE_1])
-        configure_pipeline(p, opt)
-
-      mocks['extract_annotated_structured_document_to_xml'].assert_called_with(
-        mocks['load_structured_document'].return_value,
-        tag_scope=None
-      )
-      mocks['save_file_content'].assert_called_with(
-        OUTPUT_XML_FILE_1,
-        grobid_xml_enhancer.return_value
-      )
-
-  def test_should_use_grobid_only_if_crf_or_cv_model_are_not_enabled(self):
-    with patch_conversion_pipeline() as mocks:
-      opt = get_default_args()
-      opt.base_data_path = BASE_DATA_PATH
-      opt.pdf_path = None
-      opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv'
-      opt.output_path = OUTPUT_PATH
-      opt.output_suffix = OUTPUT_SUFFIX
-      opt.crf_model = None
-      opt.cv_model_export_dir = None
-      opt.use_grobid = True
-      opt.grobid_url = 'http://test/api'
-      with TestPipeline() as p:
-        mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1])
-        mocks['read_all_from_path'].return_value = PDF_CONTENT_1
-        mocks['grobid_service'].return_value = lambda x: (
-          PDF_FILE_1, TEI_XML_CONTENT_1
-        )
-        configure_pipeline(p, opt)
-
-      mocks['grobid_service'].assert_called_with(
-        opt.grobid_url,
-        opt.grobid_action,
-        start_service=opt.start_grobid_service
-      )
-      mocks['save_file_content'].assert_called_with(
-        OUTPUT_XML_FILE_1,
-        TEI_XML_CONTENT_1
-      )
+    def test_should_pass_pdf_pattern_to_find_files_and_read_pdf_file(self):
+        with patch_conversion_pipeline() as mocks:
+            opt = get_default_args()
+            opt.base_data_path = BASE_DATA_PATH
+            opt.pdf_path = PDF_PATH
+            opt.pdf_file_list = None
+            with TestPipeline() as p:
+                mocks['FindFiles'].return_value = beam.Create([PDF_FILE_1])
+                configure_pipeline(p, opt)
+
+            mocks['FindFiles'].assert_called_with(
+                BASE_DATA_PATH + '/' + PDF_PATH
+            )
+            mocks['read_all_from_path'].assert_called_with(
+                PDF_FILE_1
+            )
+
+    def test_should_pass_pdf_file_list_and_limit_to_read_file_list_and_read_pdf_file(self):
+        with patch_conversion_pipeline() as mocks:
+            opt = get_default_args()
+            opt.base_data_path = BASE_DATA_PATH
+            opt.pdf_path = None
+            opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv'
+            opt.limit = 100
+            with TestPipeline() as p:
+                mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1])
+                configure_pipeline(p, opt)
+
+            mocks['ReadFileList'].assert_called_with(
+                opt.pdf_file_list, column='pdf_url', limit=opt.limit
+            )
+            mocks['read_all_from_path'].assert_called_with(
+                PDF_FILE_1
+            )
+
+    def test_should_pass_around_values_with_default_pipeline(self):
+        with patch_conversion_pipeline() as mocks:
+            opt = get_default_args()
+            opt.base_data_path = BASE_DATA_PATH
+            opt.pdf_path = None
+            opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv'
+            opt.output_path = OUTPUT_PATH
+            opt.output_suffix = OUTPUT_SUFFIX
+            with TestPipeline() as p:
+                mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1])
+                mocks['read_all_from_path'].return_value = PDF_CONTENT_1
+                configure_pipeline(p, opt)
+
+            mocks['convert_pdf_bytes_to_structured_document'].assert_called_with(
+                PDF_CONTENT_1, page_range=None, path=PDF_FILE_1
+            )
+            mocks['predict_and_annotate_structured_document'].assert_called_with(
+                mocks['convert_pdf_bytes_to_structured_document'].return_value,
+                mocks['load_crf_model'].return_value
+            )
+            mocks['extract_annotated_structured_document_to_xml'].assert_called_with(
+                mocks['predict_and_annotate_structured_document'].return_value,
+                tag_scope=CRF_TAG_SCOPE
+            )
+            mocks['save_file_content'].assert_called_with(
+                OUTPUT_XML_FILE_1,
+                mocks['extract_annotated_structured_document_to_xml'].return_value
+            )
+
+    def test_should_save_annotated_lxml_if_enabled(self):
+        with patch_conversion_pipeline() as mocks:
+            opt = get_default_args()
+            opt.base_data_path = BASE_DATA_PATH
+            opt.pdf_path = None
+            opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv'
+            opt.output_path = OUTPUT_PATH
+            opt.output_suffix = OUTPUT_SUFFIX
+            opt.save_annot_lxml = True
+            with TestPipeline() as p:
+                mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1])
+                configure_pipeline(p, opt)
+
+            mocks['save_structured_document'].assert_called_with(
+                OUTPUT_PATH + '/' + REL_PDF_FILE_WITHOUT_EXT_1 + OutputExt.CRF_ANNOT_LXML,
+                mocks['predict_and_annotate_structured_document'].return_value
+            )
+
+    def test_should_use_lxml_file_list_if_provided_and_load_structured_documents(self):
+        with patch_conversion_pipeline() as mocks:
+            opt = get_default_args()
+            opt.base_data_path = BASE_DATA_PATH
+            opt.pdf_path = None
+            opt.pdf_file_list = None
+            opt.lxml_file_list = BASE_DATA_PATH + '/file-list.tsv'
+            opt.output_path = OUTPUT_PATH
+            opt.output_suffix = OUTPUT_SUFFIX
+            with TestPipeline() as p:
+                mocks['ReadFileList'].return_value = beam.Create([LXML_FILE_1])
+                configure_pipeline(p, opt)
+
+            mocks['extract_annotated_structured_document_to_xml'].assert_called_with(
+                mocks['load_structured_document'].return_value,
+                tag_scope=None
+            )
+            mocks['save_file_content'].assert_called_with(
+                OUTPUT_XML_FILE_1,
+                mocks['extract_annotated_structured_document_to_xml'].return_value
+            )
+
+    def test_should_use_crf_model_with_cv_model_if_enabled(self):
+        with patch_conversion_pipeline() as mocks:
+            inference_model_wrapper = mocks['InferenceModelWrapper'].return_value
+            opt = get_default_args()
+            opt.base_data_path = BASE_DATA_PATH
+            opt.pdf_path = None
+            opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv'
+            opt.output_path = OUTPUT_PATH
+            opt.output_suffix = OUTPUT_SUFFIX
+            opt.cv_model_export_dir = CV_MODEL_EXPORT_DIR
+            with TestPipeline() as p:
+                mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1])
+                mocks['read_all_from_path'].return_value = PDF_CONTENT_1
+                _setup_mocks_for_pages(mocks, [1, 2])
+                configure_pipeline(p, opt)
+
+            mocks['convert_pdf_bytes_to_structured_document'].assert_called_with(
+                PDF_CONTENT_1, page_range=None, path=PDF_FILE_1
+            )
+
+            # cv model
+            inference_model_wrapper.assert_called_with(
+                [fake_pdf_png_page(i) for i in [1, 2]]
+            )
+            mocks['annotate_structured_document_using_predicted_image_data'].assert_called_with(
+                mocks['convert_pdf_bytes_to_structured_document'].return_value,
+                inference_model_wrapper.return_value,
+                inference_model_wrapper.get_color_map.return_value,
+                tag_scope=CV_TAG_SCOPE
+            )
+
+            # crf model should receive output from cv model
+            mocks['predict_and_annotate_structured_document'].assert_called_with(
+                mocks['annotate_structured_document_using_predicted_image_data'].return_value,
+                mocks['load_crf_model'].return_value
+            )
+            mocks['extract_annotated_structured_document_to_xml'].assert_called_with(
+                mocks['predict_and_annotate_structured_document'].return_value,
+                tag_scope=CRF_TAG_SCOPE
+            )
+
+    def test_should_use_cv_model_only_if_enabled(self):
+        with patch_conversion_pipeline() as mocks:
+            inference_model_wrapper = mocks['InferenceModelWrapper'].return_value
+            opt = get_default_args()
+            opt.base_data_path = BASE_DATA_PATH
+            opt.pdf_path = None
+            opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv'
+            opt.output_path = OUTPUT_PATH
+            opt.output_suffix = OUTPUT_SUFFIX
+            opt.crf_model = None
+            opt.cv_model_export_dir = CV_MODEL_EXPORT_DIR
+            with TestPipeline() as p:
+                mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1])
+                mocks['read_all_from_path'].return_value = PDF_CONTENT_1
+                _setup_mocks_for_pages(mocks, [1, 2])
+                configure_pipeline(p, opt)
+
+            mocks['convert_pdf_bytes_to_structured_document'].assert_called_with(
+                PDF_CONTENT_1, page_range=None, path=PDF_FILE_1
+            )
+
+            # cv model
+            inference_model_wrapper.assert_called_with(
+                [fake_pdf_png_page(i) for i in [1, 2]]
+            )
+            mocks['annotate_structured_document_using_predicted_image_data'].assert_called_with(
+                mocks['convert_pdf_bytes_to_structured_document'].return_value,
+                inference_model_wrapper.return_value,
+                inference_model_wrapper.get_color_map.return_value,
+                tag_scope=CV_TAG_SCOPE
+            )
+            mocks['extract_annotated_structured_document_to_xml'].assert_called_with(
+                mocks['annotate_structured_document_using_predicted_image_data'].return_value,
+                tag_scope=CV_TAG_SCOPE
+            )
+
+            # crf model should not be called
+            mocks['predict_and_annotate_structured_document'].assert_not_called()
+
+    def test_should_use_grobid_if_enabled(self):
+        with patch_conversion_pipeline() as mocks:
+            grobid_xml_enhancer = mocks['GrobidXmlEnhancer'].return_value
+            opt = get_default_args()
+            opt.base_data_path = BASE_DATA_PATH
+            opt.pdf_path = None
+            opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv'
+            opt.output_path = OUTPUT_PATH
+            opt.output_suffix = OUTPUT_SUFFIX
+            opt.use_grobid = True
+            opt.grobid_url = 'http://test/api'
+            with TestPipeline() as p:
+                mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1])
+                mocks['read_all_from_path'].return_value = PDF_CONTENT_1
+                mocks['convert_pdf_bytes_to_lxml'].return_value = LXML_CONTENT_1
+                configure_pipeline(p, opt)
+
+            mocks['GrobidXmlEnhancer'].assert_called_with(
+                opt.grobid_url,
+                start_service=opt.start_grobid_service
+            )
+            grobid_xml_enhancer.assert_called_with(
+                mocks['extract_annotated_structured_document_to_xml'].return_value
+            )
+            mocks['save_file_content'].assert_called_with(
+                OUTPUT_XML_FILE_1,
+                grobid_xml_enhancer.return_value
+            )
+
+    def test_should_use_grobid_with_lxml_file_list_if_enabled(self):
+        with patch_conversion_pipeline() as mocks:
+            grobid_xml_enhancer = mocks['GrobidXmlEnhancer'].return_value
+            opt = get_default_args()
+            opt.base_data_path = BASE_DATA_PATH
+            opt.pdf_path = None
+            opt.pdf_file_list = None
+            opt.lxml_file_list = BASE_DATA_PATH + '/file-list.tsv'
+            opt.output_path = OUTPUT_PATH
+            opt.output_suffix = OUTPUT_SUFFIX
+            opt.crf_model = None
+            opt.cv_model_export_dir = None
+            opt.use_grobid = True
+            opt.grobid_url = 'http://test/api'
+            with TestPipeline() as p:
+                mocks['ReadFileList'].return_value = beam.Create([LXML_FILE_1])
+                configure_pipeline(p, opt)
+
+            mocks['extract_annotated_structured_document_to_xml'].assert_called_with(
+                mocks['load_structured_document'].return_value,
+                tag_scope=None
+            )
+            mocks['save_file_content'].assert_called_with(
+                OUTPUT_XML_FILE_1,
+                grobid_xml_enhancer.return_value
+            )
+
+    def test_should_use_grobid_only_if_crf_or_cv_model_are_not_enabled(self):
+        with patch_conversion_pipeline() as mocks:
+            opt = get_default_args()
+            opt.base_data_path = BASE_DATA_PATH
+            opt.pdf_path = None
+            opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv'
+            opt.output_path = OUTPUT_PATH
+            opt.output_suffix = OUTPUT_SUFFIX
+            opt.crf_model = None
+            opt.cv_model_export_dir = None
+            opt.use_grobid = True
+            opt.grobid_url = 'http://test/api'
+            with TestPipeline() as p:
+                mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1])
+                mocks['read_all_from_path'].return_value = PDF_CONTENT_1
+                mocks['grobid_service'].return_value = lambda x: (
+                    PDF_FILE_1, TEI_XML_CONTENT_1
+                )
+                configure_pipeline(p, opt)
+
+            mocks['grobid_service'].assert_called_with(
+                opt.grobid_url,
+                opt.grobid_action,
+                start_service=opt.start_grobid_service
+            )
+            mocks['save_file_content'].assert_called_with(
+                OUTPUT_XML_FILE_1,
+                TEI_XML_CONTENT_1
+            )
+
 
 class TestParseArgs(object):
-  def test_should_parse_minimum_number_of_arguments(self):
-    parse_args(MIN_ARGV)
-
-  def test_should_raise_error_if_no_source_argument_was_provided(self):
-    with pytest.raises(SystemExit):
-      parse_args([
-        '--data-path=' + BASE_DATA_PATH,
-        '--crf-model=' + MODEL_EXPORT_DIR
-      ])
-
-  def test_should_allow_pdf_path_to_be_specified(self):
-    args = parse_args([
-      '--data-path=' + BASE_DATA_PATH,
-      '--pdf-path=' + PDF_PATH,
-      '--crf-model=' + MODEL_EXPORT_DIR
-    ])
-    assert args.pdf_path == PDF_PATH
-
-  def test_should_allow_pdf_file_list_and_column_to_be_specified(self):
-    args = parse_args([
-      '--data-path=' + BASE_DATA_PATH,
-      '--pdf-file-list=' + FILE_LIST_PATH,
-      '--pdf-file-column=' + FILE_COLUMN,
-      '--crf-model=' + MODEL_EXPORT_DIR
-    ])
-    assert args.pdf_file_list == FILE_LIST_PATH
-    assert args.pdf_file_column == FILE_COLUMN
-
-  def test_should_allow_lxml_file_list_and_column_to_be_specified(self):
-    args = parse_args([
-      '--data-path=' + BASE_DATA_PATH,
-      '--lxml-file-list=' + FILE_LIST_PATH,
-      '--lxml-file-column=' + FILE_COLUMN
-    ])
-    assert args.lxml_file_list == FILE_LIST_PATH
-    assert args.lxml_file_column == FILE_COLUMN
-
-  def test_should_not_allow_crf_model_with_lxml_file_list(self):
-    with pytest.raises(SystemExit):
-      parse_args([
-        '--data-path=' + BASE_DATA_PATH,
-        '--lxml-file-list=' + FILE_LIST_PATH,
-        '--lxml-file-column=' + FILE_COLUMN,
-        '--crf-model=' + MODEL_EXPORT_DIR
-      ])
-
-  def test_should_not_allow_cv_model_with_lxml_file_list(self):
-    with pytest.raises(SystemExit):
-      parse_args([
-        '--data-path=' + BASE_DATA_PATH,
-        '--lxml-file-list=' + FILE_LIST_PATH,
-        '--lxml-file-column=' + FILE_COLUMN,
-        '--cv-model-export-dir=' + CV_MODEL_EXPORT_DIR
-      ])
-
-  def test_should_require_crf_or_cv_model_with_pdf_file_list(self):
-    with pytest.raises(SystemExit):
-      parse_args([
-        '--data-path=' + BASE_DATA_PATH,
-        '--pdf-file-list=' + FILE_LIST_PATH,
-        '--pdf-file-column=' + FILE_COLUMN
-      ])
-
-  def test_should_require_crf_or_cv_model_with_pdf_path(self):
-    with pytest.raises(SystemExit):
-      parse_args([
-        '--data-path=' + BASE_DATA_PATH,
-        '--pdf-path=' + PDF_PATH
-      ])
-
-  def test_should_allow_crf_model_only_with_pdf_file_list(self):
-    parse_args([
-      '--data-path=' + BASE_DATA_PATH,
-      '--pdf-file-list=' + FILE_LIST_PATH,
-      '--pdf-file-column=' + FILE_COLUMN,
-      '--crf-model=' + MODEL_EXPORT_DIR
-    ])
-
-  def test_should_allow_cv_model_only_with_pdf_file_list(self):
-    parse_args([
-      '--data-path=' + BASE_DATA_PATH,
-      '--pdf-file-list=' + FILE_LIST_PATH,
-      '--pdf-file-column=' + FILE_COLUMN,
-      '--cv-model-export-dir=' + CV_MODEL_EXPORT_DIR
-    ])
-
-  def test_should_allow_crf_and_cv_model_only_with_pdf_file_list(self):
-    parse_args([
-      '--data-path=' + BASE_DATA_PATH,
-      '--pdf-file-list=' + FILE_LIST_PATH,
-      '--pdf-file-column=' + FILE_COLUMN,
-      '--crf-model=' + MODEL_EXPORT_DIR,
-      '--cv-model-export-dir=' + CV_MODEL_EXPORT_DIR
-    ])
-
-  def test_should_allow_grobid_only_with_pdf_file_list(self):
-    parse_args([
-      '--data-path=' + BASE_DATA_PATH,
-      '--pdf-file-list=' + FILE_LIST_PATH,
-      '--pdf-file-column=' + FILE_COLUMN,
-      '--use-grobid'
-    ])
+    def test_should_parse_minimum_number_of_arguments(self):
+        parse_args(MIN_ARGV)
+
+    def test_should_raise_error_if_no_source_argument_was_provided(self):
+        with pytest.raises(SystemExit):
+            parse_args([
+                '--data-path=' + BASE_DATA_PATH,
+                '--crf-model=' + MODEL_EXPORT_DIR
+            ])
+
+    def test_should_allow_pdf_path_to_be_specified(self):
+        args = parse_args([
+            '--data-path=' + BASE_DATA_PATH,
+            '--pdf-path=' + PDF_PATH,
+            '--crf-model=' + MODEL_EXPORT_DIR
+        ])
+        assert args.pdf_path == PDF_PATH
+
+    def test_should_allow_pdf_file_list_and_column_to_be_specified(self):
+        args = parse_args([
+            '--data-path=' + BASE_DATA_PATH,
+            '--pdf-file-list=' + FILE_LIST_PATH,
+            '--pdf-file-column=' + FILE_COLUMN,
+            '--crf-model=' + MODEL_EXPORT_DIR
+        ])
+        assert args.pdf_file_list == FILE_LIST_PATH
+        assert args.pdf_file_column == FILE_COLUMN
+
+    def test_should_allow_lxml_file_list_and_column_to_be_specified(self):
+        args = parse_args([
+            '--data-path=' + BASE_DATA_PATH,
+            '--lxml-file-list=' + FILE_LIST_PATH,
+            '--lxml-file-column=' + FILE_COLUMN
+        ])
+        assert args.lxml_file_list == FILE_LIST_PATH
+        assert args.lxml_file_column == FILE_COLUMN
+
+    def test_should_not_allow_crf_model_with_lxml_file_list(self):
+        with pytest.raises(SystemExit):
+            parse_args([
+                '--data-path=' + BASE_DATA_PATH,
+                '--lxml-file-list=' + FILE_LIST_PATH,
+                '--lxml-file-column=' + FILE_COLUMN,
+                '--crf-model=' + MODEL_EXPORT_DIR
+            ])
+
+    def test_should_not_allow_cv_model_with_lxml_file_list(self):
+        with pytest.raises(SystemExit):
+            parse_args([
+                '--data-path=' + BASE_DATA_PATH,
+                '--lxml-file-list=' + FILE_LIST_PATH,
+                '--lxml-file-column=' + FILE_COLUMN,
+                '--cv-model-export-dir=' + CV_MODEL_EXPORT_DIR
+            ])
+
+    def test_should_require_crf_or_cv_model_with_pdf_file_list(self):
+        with pytest.raises(SystemExit):
+            parse_args([
+                '--data-path=' + BASE_DATA_PATH,
+                '--pdf-file-list=' + FILE_LIST_PATH,
+                '--pdf-file-column=' + FILE_COLUMN
+            ])
+
+    def test_should_require_crf_or_cv_model_with_pdf_path(self):
+        with pytest.raises(SystemExit):
+            parse_args([
+                '--data-path=' + BASE_DATA_PATH,
+                '--pdf-path=' + PDF_PATH
+            ])
+
+    def test_should_allow_crf_model_only_with_pdf_file_list(self):
+        parse_args([
+            '--data-path=' + BASE_DATA_PATH,
+            '--pdf-file-list=' + FILE_LIST_PATH,
+            '--pdf-file-column=' + FILE_COLUMN,
+            '--crf-model=' + MODEL_EXPORT_DIR
+        ])
+
+    def test_should_allow_cv_model_only_with_pdf_file_list(self):
+        parse_args([
+            '--data-path=' + BASE_DATA_PATH,
+            '--pdf-file-list=' + FILE_LIST_PATH,
+            '--pdf-file-column=' + FILE_COLUMN,
+            '--cv-model-export-dir=' + CV_MODEL_EXPORT_DIR
+        ])
+
+    def test_should_allow_crf_and_cv_model_only_with_pdf_file_list(self):
+        parse_args([
+            '--data-path=' + BASE_DATA_PATH,
+            '--pdf-file-list=' + FILE_LIST_PATH,
+            '--pdf-file-column=' + FILE_COLUMN,
+            '--crf-model=' + MODEL_EXPORT_DIR,
+            '--cv-model-export-dir=' + CV_MODEL_EXPORT_DIR
+        ])
+
+    def test_should_allow_grobid_only_with_pdf_file_list(self):
+        parse_args([
+            '--data-path=' + BASE_DATA_PATH,
+            '--pdf-file-list=' + FILE_LIST_PATH,
+            '--pdf-file-column=' + FILE_COLUMN,
+            '--use-grobid'
+        ])
diff --git a/sciencebeam_gym/convert/cv_conversion_utils.py b/sciencebeam_gym/convert/cv_conversion_utils.py
index d3cbc56..cd36db2 100644
--- a/sciencebeam_gym/convert/cv_conversion_utils.py
+++ b/sciencebeam_gym/convert/cv_conversion_utils.py
@@ -8,42 +8,49 @@ import numpy as np
 from PIL import Image
 
 from sciencebeam_gym.inference_model import (
-  load_inference_model
+    load_inference_model
 )
 
+
 def lazy_cached_value(value_fn):
-  cache = {}
-  def wrapper():
-    value = cache.get('value')
-    if value is None:
-      value = value_fn()
-      cache['value'] = value
-    return value
-  return wrapper
+    cache = {}
+
+    def wrapper():
+        value = cache.get('value')
+        if value is None:
+            value = value_fn()
+            cache['value'] = value
+        return value
+    return wrapper
+
 
 def image_data_to_png(image_data):
-  image = Image.fromarray(image_data, 'RGB')
-  out = BytesIO()
-  image.save(out, 'png')
-  return out.getvalue()
+    image = Image.fromarray(image_data, 'RGB')
+    out = BytesIO()
+    image.save(out, 'png')
+    return out.getvalue()
+
 
 def png_bytes_to_image_data(png_bytes):
-  return np.asarray(Image.open(BytesIO(png_bytes)).convert('RGB'), dtype=np.uint8)
+    return np.asarray(Image.open(BytesIO(png_bytes)).convert('RGB'), dtype=np.uint8)
+
 
 class InferenceModelWrapper(object):
-  def __init__(self, export_dir):
-    self.session_cache = lazy_cached_value(lambda: tf.InteractiveSession())
-    self.inference_model_cache = lazy_cached_value(
-      lambda: load_inference_model(export_dir, session=self.session_cache())
-    )
-
-  def get_color_map(self):
-    return self.inference_model_cache().get_color_map(session=self.session_cache())
-
-  def __call__(self, png_pages):
-    input_data = [
-      png_bytes_to_image_data(png_page)
-      for png_page in png_pages
-    ]
-    output_img_data_batch = self.inference_model_cache()(input_data, session=self.session_cache())
-    return output_img_data_batch
+    def __init__(self, export_dir):
+        self.session_cache = lazy_cached_value(lambda: tf.InteractiveSession())
+        self.inference_model_cache = lazy_cached_value(
+            lambda: load_inference_model(export_dir, session=self.session_cache())
+        )
+
+    def get_color_map(self):
+        return self.inference_model_cache().get_color_map(session=self.session_cache())
+
+    def __call__(self, png_pages):
+        input_data = [
+            png_bytes_to_image_data(png_page)
+            for png_page in png_pages
+        ]
+        output_img_data_batch = self.inference_model_cache()(
+            input_data, session=self.session_cache()
+        )
+        return output_img_data_batch
diff --git a/sciencebeam_gym/convert/cv_conversion_utils_test.py b/sciencebeam_gym/convert/cv_conversion_utils_test.py
index ae82821..addbaca 100644
--- a/sciencebeam_gym/convert/cv_conversion_utils_test.py
+++ b/sciencebeam_gym/convert/cv_conversion_utils_test.py
@@ -1,46 +1,64 @@
-import logging
 from mock import patch
 
+import pytest
+
 from . import cv_conversion_utils as cv_conversion_utils
 from .cv_conversion_utils import (
-  InferenceModelWrapper
+    InferenceModelWrapper
 )
 
 CV_MODEL_EXPORT_DIR = './model-export'
 PNG_BYTES = b'dummy png bytes'
 
-def setup_module():
-  logging.basicConfig(level='DEBUG')
 
-class TestInferenceModelWrapper(object):
-  def test_should_lazy_load_model(self):
+@pytest.fixture(name='load_inference_model_mock')
+def _load_inference_model_mock():
     with patch.object(cv_conversion_utils, 'load_inference_model') as load_inference_model:
-      with patch.object(cv_conversion_utils, 'png_bytes_to_image_data') as png_bytes_to_image_data:
-        with patch.object(cv_conversion_utils, 'tf') as tf:
-          inference_model_wrapper = InferenceModelWrapper(CV_MODEL_EXPORT_DIR)
-          load_inference_model.assert_not_called()
+        yield load_inference_model
 
-          output_image_data = inference_model_wrapper([PNG_BYTES])
 
-          tf.InteractiveSession.assert_called_with()
-          session = tf.InteractiveSession.return_value
+@pytest.fixture(name='png_bytes_to_image_data_mock')
+def _png_bytes_to_image_data_mock():
+    with patch.object(cv_conversion_utils, 'png_bytes_to_image_data') as png_bytes_to_image_data:
+        yield png_bytes_to_image_data
 
-          png_bytes_to_image_data.assert_called_with(PNG_BYTES)
 
-          load_inference_model.assert_called_with(CV_MODEL_EXPORT_DIR, session=session)
-          inference_model = load_inference_model.return_value
+@pytest.fixture(name='tf_mock')
+def _tf_mock():
+    with patch.object(cv_conversion_utils, 'tf') as tf_mock:
+        yield tf_mock
 
-          inference_model.assert_called_with([
-            png_bytes_to_image_data.return_value
-          ], session=session)
 
-          assert output_image_data == inference_model.return_value
+class TestInferenceModelWrapper(object):
+    def test_should_lazy_load_model(
+            self,
+            load_inference_model_mock,
+            png_bytes_to_image_data_mock,
+            tf_mock):
+        inference_model_wrapper = InferenceModelWrapper(CV_MODEL_EXPORT_DIR)
+        load_inference_model_mock.assert_not_called()
 
-  def test_should_load_model_only_once(self):
-    with patch.object(cv_conversion_utils, 'load_inference_model') as load_inference_model:
-      with patch.object(cv_conversion_utils, 'png_bytes_to_image_data') as _:
-        with patch.object(cv_conversion_utils, 'tf') as _:
-          inference_model_wrapper = InferenceModelWrapper(CV_MODEL_EXPORT_DIR)
-          inference_model_wrapper([PNG_BYTES])
-          inference_model_wrapper([PNG_BYTES])
-          load_inference_model.assert_called_once()
+        output_image_data = inference_model_wrapper([PNG_BYTES])
+
+        tf_mock.InteractiveSession.assert_called_with()
+        session = tf_mock.InteractiveSession.return_value
+
+        png_bytes_to_image_data_mock.assert_called_with(PNG_BYTES)
+
+        load_inference_model_mock.assert_called_with(CV_MODEL_EXPORT_DIR, session=session)
+        inference_model = load_inference_model_mock.return_value
+
+        inference_model.assert_called_with([
+            png_bytes_to_image_data_mock.return_value
+        ], session=session)
+
+        assert output_image_data == inference_model.return_value
+
+    @pytest.mark.usefixtures('png_bytes_to_image_data_mock', 'tf_mock')
+    def test_should_load_model_only_once(
+            self,
+            load_inference_model_mock):
+        inference_model_wrapper = InferenceModelWrapper(CV_MODEL_EXPORT_DIR)
+        inference_model_wrapper([PNG_BYTES])
+        inference_model_wrapper([PNG_BYTES])
+        load_inference_model_mock.assert_called_once()
diff --git a/sciencebeam_gym/convert/grobid/grobid_service.py b/sciencebeam_gym/convert/grobid/grobid_service.py
index e80fa08..73f1ca6 100644
--- a/sciencebeam_gym/convert/grobid/grobid_service.py
+++ b/sciencebeam_gym/convert/grobid/grobid_service.py
@@ -5,80 +5,86 @@ from functools import partial
 import requests
 
 from .grobid_service_wrapper import (
-  GrobidServiceWrapper
+    GrobidServiceWrapper
 )
 
+
 class GrobidApiPaths(object):
-  PROCESS_HEADER_DOCUMENT = '/processHeaderDocument'
-  PROCESS_HEADER_NAMES = '/processHeaderNames'
-  PROCESS_CITATION_NAMES = '/processCitationNames'
-  PROCESS_AFFILIATIONS = '/processAffiliations'
-  PROCESS_CITATION = '/processCitation'
-  PROCESS_FULL_TEXT_DOCUMENT = '/processFulltextDocument'
+    PROCESS_HEADER_DOCUMENT = '/processHeaderDocument'
+    PROCESS_HEADER_NAMES = '/processHeaderNames'
+    PROCESS_CITATION_NAMES = '/processCitationNames'
+    PROCESS_AFFILIATIONS = '/processAffiliations'
+    PROCESS_CITATION = '/processCitation'
+    PROCESS_FULL_TEXT_DOCUMENT = '/processFulltextDocument'
+
 
 service_wrapper = GrobidServiceWrapper()
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def start_service_if_not_running():
-  service_wrapper.start_service_if_not_running()
+    service_wrapper.start_service_if_not_running()
+
 
 def run_grobid_service(item, base_url, path, start_service=True, field_name=None):
-  """
-  Translates PDF content via the GROBID service.
-
-  Args:
-    item: one of:
-      * tuple (filename, pdf content)
-      * pdf content
-      * field content (requires field name)
-    base_url: base url to the GROBID service
-    path: path of the GROBID endpoint
-    start_service: if true, a GROBID service will be started automatically and
-      kept running until the application ends
-    field_name: the field name the field content relates to
-
-  Returns:
-    If item is tuple:
-      returns tuple (filename, xml result)
-    Otherwise:
-      returns xml result
-  """
-
-  url = base_url + path
-
-  if start_service:
-    start_service_if_not_running()
-
-  if field_name:
-    content = item
-    response = requests.post(url,
-      data={field_name: content}
-    )
-  else:
-    filename = item[0] if isinstance(item, tuple) else 'unknown.pdf'
-    content = item[1] if isinstance(item, tuple) else item
-    get_logger().info('processing: %s (%d) - %s', filename, len(content), url)
-    response = requests.post(url,
-      files={'input': (filename, BytesIO(content))},
-      data={
-        'consolidateHeader': '0',
-        'consolidateCitations': '0'
-      }
-    )
-  response.raise_for_status()
-  result_content = response.content
-  if isinstance(item, tuple):
-    return filename, result_content
-  else:
-    return result_content
+    """
+    Translates PDF content via the GROBID service.
+
+    Args:
+      item: one of:
+        * tuple (filename, pdf content)
+        * pdf content
+        * field content (requires field name)
+      base_url: base url to the GROBID service
+      path: path of the GROBID endpoint
+      start_service: if true, a GROBID service will be started automatically and
+        kept running until the application ends
+      field_name: the field name the field content relates to
+
+    Returns:
+      If item is tuple:
+        returns tuple (filename, xml result)
+      Otherwise:
+        returns xml result
+    """
+
+    url = base_url + path
+
+    if start_service:
+        start_service_if_not_running()
+
+    if field_name:
+        content = item
+        response = requests.post(url,
+                                 data={field_name: content}
+                                 )
+    else:
+        filename = item[0] if isinstance(item, tuple) else 'unknown.pdf'
+        content = item[1] if isinstance(item, tuple) else item
+        get_logger().info('processing: %s (%d) - %s', filename, len(content), url)
+        response = requests.post(url,
+                                 files={'input': (filename, BytesIO(content))},
+                                 data={
+                                     'consolidateHeader': '0',
+                                     'consolidateCitations': '0'
+                                 }
+                                 )
+    response.raise_for_status()
+    result_content = response.content
+    if isinstance(item, tuple):
+        return filename, result_content
+    else:
+        return result_content
+
 
 def grobid_service(base_url, path, start_service=True, field_name=None):
-  return partial(
-    run_grobid_service,
-    base_url=base_url,
-    path=path,
-    start_service=start_service,
-    field_name=field_name
-  )
+    return partial(
+        run_grobid_service,
+        base_url=base_url,
+        path=path,
+        start_service=start_service,
+        field_name=field_name
+    )
diff --git a/sciencebeam_gym/convert/grobid/grobid_service_test.py b/sciencebeam_gym/convert/grobid/grobid_service_test.py
index fd02bfe..77cc021 100644
--- a/sciencebeam_gym/convert/grobid/grobid_service_test.py
+++ b/sciencebeam_gym/convert/grobid/grobid_service_test.py
@@ -14,54 +14,56 @@ PDF_CONTENT_1 = b'pdf content1'
 FIELD_NAME_1 = 'field1'
 FIELD_VALUE_1 = 'value1'
 
+
 @pytest.fixture(name='requests_post', autouse=True)
 def _mock_requests_post():
-  with patch('requests.post') as requests_post:
-    yield requests_post
+    with patch('requests.post') as requests_post:
+        yield requests_post
+
 
 class TestCreateGrobidService(object):
-  def test_should_pass_url_and_data_as_file(self, requests_post):
-    create_grobid_service(BASE_URL, PATH_1, start_service=False)(
-      (FILENAME_1, PDF_CONTENT_1)
-    )
-    requests_post.assert_called_with(
-      BASE_URL + PATH_1,
-      files=ANY,
-      data=ANY
-    )
-    kwargs = requests_post.call_args[1]
-    assert kwargs['files']['input'][0] == FILENAME_1
-    assert kwargs['files']['input'][1].read() == PDF_CONTENT_1
+    def test_should_pass_url_and_data_as_file(self, requests_post):
+        create_grobid_service(BASE_URL, PATH_1, start_service=False)(
+            (FILENAME_1, PDF_CONTENT_1)
+        )
+        requests_post.assert_called_with(
+            BASE_URL + PATH_1,
+            files=ANY,
+            data=ANY
+        )
+        kwargs = requests_post.call_args[1]
+        assert kwargs['files']['input'][0] == FILENAME_1
+        assert kwargs['files']['input'][1].read() == PDF_CONTENT_1
 
-  def test_should_be_able_to_override_path_on_call(self, requests_post):
-    create_grobid_service(BASE_URL, PATH_1, start_service=False)(
-      (FILENAME_1, PDF_CONTENT_1),
-      path=PATH_2
-    )
-    requests_post.assert_called_with(
-      BASE_URL + PATH_2,
-      files=ANY,
-      data=ANY
-    )
+    def test_should_be_able_to_override_path_on_call(self, requests_post):
+        create_grobid_service(BASE_URL, PATH_1, start_service=False)(
+            (FILENAME_1, PDF_CONTENT_1),
+            path=PATH_2
+        )
+        requests_post.assert_called_with(
+            BASE_URL + PATH_2,
+            files=ANY,
+            data=ANY
+        )
 
-  def test_should_pass_data_as_field(self, requests_post):
-    create_grobid_service(BASE_URL, PATH_1, start_service=False, field_name=FIELD_NAME_1)(
-      FIELD_VALUE_1
-    )
-    requests_post.assert_called_with(
-      BASE_URL + PATH_1,
-      data={FIELD_NAME_1: FIELD_VALUE_1}
-    )
+    def test_should_pass_data_as_field(self, requests_post):
+        create_grobid_service(BASE_URL, PATH_1, start_service=False, field_name=FIELD_NAME_1)(
+            FIELD_VALUE_1
+        )
+        requests_post.assert_called_with(
+            BASE_URL + PATH_1,
+            data={FIELD_NAME_1: FIELD_VALUE_1}
+        )
 
-  def test_should_pass_consolidate_flags(self, requests_post):
-    create_grobid_service(BASE_URL, PATH_1, start_service=False)(
-      (FILENAME_1, PDF_CONTENT_1)
-    )
-    requests_post.assert_called_with(
-      BASE_URL + PATH_1,
-      files=ANY,
-      data={
-        'consolidateHeader': '0',
-        'consolidateCitations': '0'
-      }
-    )
+    def test_should_pass_consolidate_flags(self, requests_post):
+        create_grobid_service(BASE_URL, PATH_1, start_service=False)(
+            (FILENAME_1, PDF_CONTENT_1)
+        )
+        requests_post.assert_called_with(
+            BASE_URL + PATH_1,
+            files=ANY,
+            data={
+                'consolidateHeader': '0',
+                'consolidateCitations': '0'
+            }
+        )
diff --git a/sciencebeam_gym/convert/grobid/grobid_service_wrapper.py b/sciencebeam_gym/convert/grobid/grobid_service_wrapper.py
index 3f47f37..d69b531 100644
--- a/sciencebeam_gym/convert/grobid/grobid_service_wrapper.py
+++ b/sciencebeam_gym/convert/grobid/grobid_service_wrapper.py
@@ -13,114 +13,120 @@ from urllib import URLopener
 from sciencebeam_utils.utils.io import makedirs
 from sciencebeam_utils.utils.zip import extract_all_with_executable_permission
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def iter_read_lines(reader):
-  while True:
-    line = reader.readline()
-    if not line:
-      break
-    yield line
+    while True:
+        line = reader.readline()
+        if not line:
+            break
+        yield line
+
 
 def stream_lines_to_logger(lines, logger, prefix=''):
-  for line in lines:
-    line = line.strip()
-    if line:
-      logger.info('%s%s', prefix, line)
+    for line in lines:
+        line = line.strip()
+        if line:
+            logger.info('%s%s', prefix, line)
+
 
 class GrobidServiceWrapper(object):
-  def __init__(self):
-    self.grobid_service_instance = None
-    temp_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../.temp'))
-    self.grobid_service_target_directory = os.path.join(temp_dir, 'grobid-service')
-    self.grobid_service_zip_filename = os.path.join(temp_dir, 'grobid-service.zip')
-    self.grobid_service_zip_url = (
-      'https://storage.googleapis.com/elife-ml/artefacts/grobid-service.zip'
-    )
-
-  def stop_service_if_running(self):
-    if self.grobid_service_instance is not None:
-      get_logger().info('stopping instance: %s', self.grobid_service_instance)
-      self.grobid_service_instance.kill()
-
-  def download__grobid_service_zip_if_not_exist(self):
-    if not os.path.isfile(self.grobid_service_zip_filename):
-      get_logger().info(
-        'downloading %s to %s',
-        self.grobid_service_zip_url,
-        self.grobid_service_zip_filename
-      )
-
-      makedirs(os.path.dirname(self.grobid_service_zip_filename), exists_ok=True)
-
-      temp_zip_filename = self.grobid_service_zip_filename + '.part'
-      if os.path.isfile(temp_zip_filename):
-        os.remove(temp_zip_filename)
-      URLopener().retrieve(self.grobid_service_zip_url, temp_zip_filename)
-      os.rename(temp_zip_filename, self.grobid_service_zip_filename)
-
-  def unzip_grobid_service_zip_if_target_directory_does_not_exist(self):
-    if not os.path.isdir(self.grobid_service_target_directory):
-      self.download__grobid_service_zip_if_not_exist()
-      get_logger().info(
-        'unzipping %s to %s',
-        self.grobid_service_zip_filename,
-        self.grobid_service_target_directory
-      )
-      temp_target_directory = self.grobid_service_target_directory + '.part'
-      if os.path.isdir(temp_target_directory):
-        rmtree(temp_target_directory)
-
-      with ZipFile(self.grobid_service_zip_filename, 'r') as zf:
-        extract_all_with_executable_permission(zf, temp_target_directory)
-        sub_dir = os.path.join(temp_target_directory, 'grobid-service')
-        if os.path.isdir(sub_dir):
-          os.rename(sub_dir, self.grobid_service_target_directory)
-          rmtree(temp_target_directory)
-        else:
-          os.rename(temp_target_directory, self.grobid_service_target_directory)
-
-  def start_service_if_not_running(self):
-    get_logger().info('grobid_service_instance: %s', self.grobid_service_instance)
-    if self.grobid_service_instance is None:
-      self.unzip_grobid_service_zip_if_target_directory_does_not_exist()
-      grobid_service_home = os.path.abspath(self.grobid_service_target_directory)
-      cwd = grobid_service_home + '/bin'
-      grobid_service_home_jar_dir = grobid_service_home + '/lib'
-      command_line = 'java -cp "{}/*" org.grobid.service.main.GrobidServiceApplication'.format(
-        grobid_service_home_jar_dir
-      )
-      args = shlex.split(command_line)
-      get_logger().info('command_line: %s', command_line)
-      get_logger().info('args: %s', args)
-      self.grobid_service_instance = subprocess.Popen(
-        args, cwd=cwd, stdout=PIPE, stderr=subprocess.STDOUT
-      )
-      if self.grobid_service_instance is None:
-        raise RuntimeError('failed to start grobid service')
-      atexit.register(self.stop_service_if_running)
-      pstdout = self.grobid_service_instance.stdout
-      out_prefix = 'stdout: '
-      while True:
-        line = pstdout.readline().strip()
-        if line:
-          get_logger().info('%s%s', out_prefix, line)
-        if 'jetty.server.Server: Started' in line:
-          get_logger().info('grobid service started successfully')
-          break
-        if 'ERROR' in line or 'Error' in line:
-          raise RuntimeError('failed to start grobid service due to {}'.format(line))
-      t = Thread(target=partial(
-        stream_lines_to_logger,
-        lines=iter_read_lines(pstdout),
-        logger=get_logger(),
-        prefix=out_prefix
-      ))
-      t.daemon = True
-      t.start()
+    def __init__(self):
+        self.grobid_service_instance = None
+        temp_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../.temp'))
+        self.grobid_service_target_directory = os.path.join(temp_dir, 'grobid-service')
+        self.grobid_service_zip_filename = os.path.join(temp_dir, 'grobid-service.zip')
+        self.grobid_service_zip_url = (
+            'https://storage.googleapis.com/elife-ml/artefacts/grobid-service.zip'
+        )
+
+    def stop_service_if_running(self):
+        if self.grobid_service_instance is not None:
+            get_logger().info('stopping instance: %s', self.grobid_service_instance)
+            self.grobid_service_instance.kill()
+
+    def download__grobid_service_zip_if_not_exist(self):
+        if not os.path.isfile(self.grobid_service_zip_filename):
+            get_logger().info(
+                'downloading %s to %s',
+                self.grobid_service_zip_url,
+                self.grobid_service_zip_filename
+            )
+
+            makedirs(os.path.dirname(self.grobid_service_zip_filename), exists_ok=True)
+
+            temp_zip_filename = self.grobid_service_zip_filename + '.part'
+            if os.path.isfile(temp_zip_filename):
+                os.remove(temp_zip_filename)
+            URLopener().retrieve(self.grobid_service_zip_url, temp_zip_filename)
+            os.rename(temp_zip_filename, self.grobid_service_zip_filename)
+
+    def unzip_grobid_service_zip_if_target_directory_does_not_exist(self):
+        if not os.path.isdir(self.grobid_service_target_directory):
+            self.download__grobid_service_zip_if_not_exist()
+            get_logger().info(
+                'unzipping %s to %s',
+                self.grobid_service_zip_filename,
+                self.grobid_service_target_directory
+            )
+            temp_target_directory = self.grobid_service_target_directory + '.part'
+            if os.path.isdir(temp_target_directory):
+                rmtree(temp_target_directory)
+
+            with ZipFile(self.grobid_service_zip_filename, 'r') as zf:
+                extract_all_with_executable_permission(zf, temp_target_directory)
+                sub_dir = os.path.join(temp_target_directory, 'grobid-service')
+                if os.path.isdir(sub_dir):
+                    os.rename(sub_dir, self.grobid_service_target_directory)
+                    rmtree(temp_target_directory)
+                else:
+                    os.rename(temp_target_directory, self.grobid_service_target_directory)
+
+    def start_service_if_not_running(self):
+        get_logger().info('grobid_service_instance: %s', self.grobid_service_instance)
+        if self.grobid_service_instance is None:
+            self.unzip_grobid_service_zip_if_target_directory_does_not_exist()
+            grobid_service_home = os.path.abspath(self.grobid_service_target_directory)
+            cwd = grobid_service_home + '/bin'
+            grobid_service_home_jar_dir = grobid_service_home + '/lib'
+            main_class = "org.grobid.service.main.GrobidServiceApplication"
+            command_line = 'java -cp "%s/*" %s' % (
+                grobid_service_home_jar_dir, main_class
+            )
+            args = shlex.split(command_line)
+            get_logger().info('command_line: %s', command_line)
+            get_logger().info('args: %s', args)
+            self.grobid_service_instance = subprocess.Popen(
+                args, cwd=cwd, stdout=PIPE, stderr=subprocess.STDOUT
+            )
+            if self.grobid_service_instance is None:
+                raise RuntimeError('failed to start grobid service')
+            atexit.register(self.stop_service_if_running)
+            pstdout = self.grobid_service_instance.stdout
+            out_prefix = 'stdout: '
+            while True:
+                line = pstdout.readline().strip()
+                if line:
+                    get_logger().info('%s%s', out_prefix, line)
+                if 'jetty.server.Server: Started' in line:
+                    get_logger().info('grobid service started successfully')
+                    break
+                if 'ERROR' in line or 'Error' in line:
+                    raise RuntimeError('failed to start grobid service due to {}'.format(line))
+            t = Thread(target=partial(
+                stream_lines_to_logger,
+                lines=iter_read_lines(pstdout),
+                logger=get_logger(),
+                prefix=out_prefix
+            ))
+            t.daemon = True
+            t.start()
+
 
 if __name__ == '__main__':
-  logging.basicConfig(level='INFO')
+    logging.basicConfig(level='INFO')
 
-  GrobidServiceWrapper().start_service_if_not_running()
+    GrobidServiceWrapper().start_service_if_not_running()
diff --git a/sciencebeam_gym/convert/grobid/grobid_xml_enhancer.py b/sciencebeam_gym/convert/grobid/grobid_xml_enhancer.py
index 26646cb..74104ab 100644
--- a/sciencebeam_gym/convert/grobid/grobid_xml_enhancer.py
+++ b/sciencebeam_gym/convert/grobid/grobid_xml_enhancer.py
@@ -5,14 +5,14 @@ from lxml import etree
 from lxml.builder import E
 
 from sciencebeam_gym.inference_model.extract_to_xml import (
-  XmlPaths,
-  create_node_recursive,
-  rsplit_xml_path
+    XmlPaths,
+    create_node_recursive,
+    rsplit_xml_path
 )
 
 from .grobid_service import (
-  grobid_service,
-  GrobidApiPaths
+    grobid_service,
+    GrobidApiPaths
 )
 
 TEI_NS = 'http://www.tei-c.org/ns/1.0'
@@ -27,83 +27,86 @@ JATS_ADDR_LINE = 'addr-line'
 JATS_NAMED_CONTENT = 'named-content'
 JATS_INSTITUTION = 'institution'
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def create_or_append(xml_root, path):
-  parent_path, tag_name = rsplit_xml_path(path)
-  parent_node = create_node_recursive(xml_root, parent_path, exists_ok=True)
-  node = E(tag_name)
-  parent_node.append(node)
-  return node
+    parent_path, tag_name = rsplit_xml_path(path)
+    parent_node = create_node_recursive(xml_root, parent_path, exists_ok=True)
+    node = E(tag_name)
+    parent_node.append(node)
+    return node
+
 
 class GrobidXmlEnhancer(object):
-  def __init__(self, grobid_url, start_service):
-    self.process_header_names = grobid_service(
-      grobid_url,
-      GrobidApiPaths.PROCESS_HEADER_NAMES,
-      start_service=start_service,
-      field_name='names'
-    )
-    self.process_affiliations = grobid_service(
-      grobid_url,
-      GrobidApiPaths.PROCESS_AFFILIATIONS,
-      start_service=start_service,
-      field_name='affiliations'
-    )
+    def __init__(self, grobid_url, start_service):
+        self.process_header_names = grobid_service(
+            grobid_url,
+            GrobidApiPaths.PROCESS_HEADER_NAMES,
+            start_service=start_service,
+            field_name='names'
+        )
+        self.process_affiliations = grobid_service(
+            grobid_url,
+            GrobidApiPaths.PROCESS_AFFILIATIONS,
+            start_service=start_service,
+            field_name='affiliations'
+        )
 
-  def process_and_replace_authors(self, xml_root):
-    author_nodes = list(xml_root.findall(XmlPaths.AUTHOR))
-    if author_nodes:
-      authors = '\n'.join(x.text for x in author_nodes)
-      get_logger().debug('authors: %s', authors)
-      grobid_response = self.process_header_names(authors)
-      get_logger().debug('grobid_response: %s', grobid_response)
-      response_xml_root = etree.parse(BytesIO('<dummy>%s</dummy>' % grobid_response))
-      for author in author_nodes:
-        author.getparent().remove(author)
-      for pers_name in response_xml_root.findall(TEI_PERS_NAME):
-        get_logger().debug('pers_name: %s', pers_name)
-        node = create_or_append(xml_root, XmlPaths.AUTHOR)
-        for surname in pers_name.findall(TEI_SURNAME):
-          node.append(E(JATS_SURNAME, surname.text))
-        forenames = [x.text for x in pers_name.findall(TEI_FORNAME)]
-        if forenames:
-          node.append(E(JATS_GIVEN_NAMES, ' '.join(forenames)))
-    return xml_root
+    def process_and_replace_authors(self, xml_root):
+        author_nodes = list(xml_root.findall(XmlPaths.AUTHOR))
+        if author_nodes:
+            authors = '\n'.join(x.text for x in author_nodes)
+            get_logger().debug('authors: %s', authors)
+            grobid_response = self.process_header_names(authors)
+            get_logger().debug('grobid_response: %s', grobid_response)
+            response_xml_root = etree.parse(BytesIO('<dummy>%s</dummy>' % grobid_response))
+            for author in author_nodes:
+                author.getparent().remove(author)
+            for pers_name in response_xml_root.findall(TEI_PERS_NAME):
+                get_logger().debug('pers_name: %s', pers_name)
+                node = create_or_append(xml_root, XmlPaths.AUTHOR)
+                for surname in pers_name.findall(TEI_SURNAME):
+                    node.append(E(JATS_SURNAME, surname.text))
+                forenames = [x.text for x in pers_name.findall(TEI_FORNAME)]
+                if forenames:
+                    node.append(E(JATS_GIVEN_NAMES, ' '.join(forenames)))
+        return xml_root
 
-  def process_and_replace_affiliations(self, xml_root):
-    aff_nodes = list(xml_root.findall(XmlPaths.AUTHOR_AFF))
-    if aff_nodes:
-      affiliations = '\n'.join(x.text for x in aff_nodes)
-      get_logger().debug('affiliations: %s', affiliations)
-      grobid_response = self.process_affiliations(affiliations)
-      get_logger().debug('grobid_response: %s', grobid_response)
-      response_xml_root = etree.parse(BytesIO('<dummy>%s</dummy>' % grobid_response))
-      for aff in aff_nodes:
-        aff.getparent().remove(aff)
-      for affiliation in response_xml_root.findall('affiliation'):
-        get_logger().debug('affiliation: %s', affiliation)
-        node = create_or_append(xml_root, XmlPaths.AUTHOR_AFF)
-        for department in affiliation.xpath('./orgName[@type="department"]'):
-          node.append(E(
-            JATS_ADDR_LINE,
-            E(
-              JATS_NAMED_CONTENT,
-              department.text,
-              {
-                'content-type': 'department'
-              }
-            )
-          ))
-        for institution in affiliation.xpath('./orgName[@type="institution"]'):
-          node.append(E(
-            JATS_INSTITUTION,
-            institution.text
-          ))
+    def process_and_replace_affiliations(self, xml_root):
+        aff_nodes = list(xml_root.findall(XmlPaths.AUTHOR_AFF))
+        if aff_nodes:
+            affiliations = '\n'.join(x.text for x in aff_nodes)
+            get_logger().debug('affiliations: %s', affiliations)
+            grobid_response = self.process_affiliations(affiliations)
+            get_logger().debug('grobid_response: %s', grobid_response)
+            response_xml_root = etree.parse(BytesIO('<dummy>%s</dummy>' % grobid_response))
+            for aff in aff_nodes:
+                aff.getparent().remove(aff)
+            for affiliation in response_xml_root.findall('affiliation'):
+                get_logger().debug('affiliation: %s', affiliation)
+                node = create_or_append(xml_root, XmlPaths.AUTHOR_AFF)
+                for department in affiliation.xpath('./orgName[@type="department"]'):
+                    node.append(E(
+                        JATS_ADDR_LINE,
+                        E(
+                            JATS_NAMED_CONTENT,
+                            department.text,
+                            {
+                                'content-type': 'department'
+                            }
+                        )
+                    ))
+                for institution in affiliation.xpath('./orgName[@type="institution"]'):
+                    node.append(E(
+                        JATS_INSTITUTION,
+                        institution.text
+                    ))
 
-  def __call__(self, extracted_xml):
-    xml_root = etree.parse(BytesIO(extracted_xml))
-    self.process_and_replace_authors(xml_root)
-    self.process_and_replace_affiliations(xml_root)
-    return etree.tostring(xml_root, pretty_print=True)
+    def __call__(self, extracted_xml):
+        xml_root = etree.parse(BytesIO(extracted_xml))
+        self.process_and_replace_authors(xml_root)
+        self.process_and_replace_affiliations(xml_root)
+        return etree.tostring(xml_root, pretty_print=True)
diff --git a/sciencebeam_gym/convert/grobid/grobid_xml_enhancer_test.py b/sciencebeam_gym/convert/grobid/grobid_xml_enhancer_test.py
index e2b7052..7726fb9 100644
--- a/sciencebeam_gym/convert/grobid/grobid_xml_enhancer_test.py
+++ b/sciencebeam_gym/convert/grobid/grobid_xml_enhancer_test.py
@@ -7,17 +7,17 @@ from lxml import etree
 from lxml.builder import E
 
 from sciencebeam_gym.inference_model.extract_to_xml import (
-  XmlPaths,
-  create_xml_text
+    XmlPaths,
+    create_xml_text
 )
 
 from .grobid_service import (
-  GrobidApiPaths
+    GrobidApiPaths
 )
 
 from . import grobid_xml_enhancer as grobid_xml_enhancer
 from .grobid_xml_enhancer import (
-  GrobidXmlEnhancer
+    GrobidXmlEnhancer
 )
 
 GROBID_URL = 'http://localhost:8080/api'
@@ -38,170 +38,179 @@ INSTITUTION_2 = 'institution 2'
 DEPARTMENT_XPATH = './addr-line/named-content[@content-type="department"]'
 INSTITUTION_XPATH = './institution'
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def setup_module():
-  logging.basicConfig(level='DEBUG')
+    logging.basicConfig(level='DEBUG')
+
 
 def pers_name(*names):
-  forenames = names[:-1]
-  surname = names[-1]
-  return (
-    '<persName xmlns="http://www.tei-c.org/ns/1.0">'
-    '  %s'
-    '  <surname>%s</surname>'
-    '</persName>'
-  ) % (
-    ' '.join(
-      '<forename type="%s">%s</forename>' % (
-        'first' if i == 0 else 'middle',
-        forename
-      )
-      for i, forename in enumerate(forenames)
-    ),
-    surname
-  )
+    forenames = names[:-1]
+    surname = names[-1]
+    return (
+        '<persName xmlns="http://www.tei-c.org/ns/1.0">'
+        '  %s'
+        '  <surname>%s</surname>'
+        '</persName>'
+    ) % (
+        ' '.join(
+            '<forename type="%s">%s</forename>' % (
+                'first' if i == 0 else 'middle',
+                forename
+            )
+            for i, forename in enumerate(forenames)
+        ),
+        surname
+    )
+
 
 def tei_affiliation(department=None, institution=None):
-  affiliation = E.affiliation()
-  if department:
-    affiliation.append(E.orgName(department, type='department'))
-  if institution:
-    affiliation.append(E.orgName(institution, type='institution'))
-  return etree.tostring(affiliation)
+    affiliation = E.affiliation()
+    if department:
+        affiliation.append(E.orgName(department, type='department'))
+    if institution:
+        affiliation.append(E.orgName(institution, type='institution'))
+    return etree.tostring(affiliation)
+
 
 def get_text(node):
-  if isinstance(node, list):
-    return ''.join(get_text(x) for x in node)
-  return node.text if node is not None else None
+    if isinstance(node, list):
+        return ''.join(get_text(x) for x in node)
+    return node.text if node is not None else None
+
 
 def get_child_text(node, name):
-  return get_text(node.find(name))
+    return get_text(node.find(name))
+
 
 @contextmanager
 def patch_grobid_service():
-  with patch.object(grobid_xml_enhancer, 'grobid_service') as grobid_service:
-    process_header_names = Mock()
-    process_affiliations = Mock()
-    grobid_service.side_effect = [process_header_names, process_affiliations]
-    yield process_header_names, process_affiliations
+    with patch.object(grobid_xml_enhancer, 'grobid_service') as grobid_service:
+        process_header_names = Mock()
+        process_affiliations = Mock()
+        grobid_service.side_effect = [process_header_names, process_affiliations]
+        yield process_header_names, process_affiliations
+
 
 class TestGrobidXmlEnhancer(object):
-  def test_should_initialise_grobid_service(self):
-    with patch.object(grobid_xml_enhancer, 'grobid_service') as grobid_service:
-      GrobidXmlEnhancer(GROBID_URL, start_service=False)
-      grobid_service.assert_any_call(
-        GROBID_URL, GrobidApiPaths.PROCESS_HEADER_NAMES, start_service=False, field_name='names'
-      )
-      grobid_service.assert_any_call(
-        GROBID_URL, GrobidApiPaths.PROCESS_AFFILIATIONS, start_service=False,
-        field_name='affiliations'
-      )
-
-  def test_should_convert_single_author(self):
-    logging.basicConfig(level='DEBUG')
-    with patch_grobid_service() as (process_header_names, process_affiliations):
-      _ = process_affiliations
-      process_header_names.return_value = pers_name(FORENAME_1, SURNAME_1)
-      enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False)
-      xml_root = E.article()
-      create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_1)
-      enhanced_xml = enhancer(etree.tostring(xml_root))
-      get_logger().info('enhanced_xml: %s', enhanced_xml)
-      enhanced_xml_root = etree.parse(BytesIO(enhanced_xml))
-      authors = enhanced_xml_root.findall(XmlPaths.AUTHOR)
-      assert [
-        (get_child_text(author, 'given-names'), get_child_text(author, 'surname'))
-        for author in authors
-      ] == [(FORENAME_1, SURNAME_1)]
-
-  def test_should_convert_single_affiliation(self):
-    logging.basicConfig(level='DEBUG')
-    with patch_grobid_service() as (process_header_names, process_affiliations):
-      _ = process_header_names
-      process_affiliations.return_value = tei_affiliation(
-        department=DEPARTMENT_1,
-        institution=INSTITUTION_1
-      )
-      enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False)
-      xml_root = E.article()
-      create_xml_text(xml_root, XmlPaths.AUTHOR_AFF, TEXT_1)
-      enhanced_xml = enhancer(etree.tostring(xml_root))
-      get_logger().info('enhanced_xml: %s', enhanced_xml)
-      enhanced_xml_root = etree.parse(BytesIO(enhanced_xml))
-      affiliations = enhanced_xml_root.findall(XmlPaths.AUTHOR_AFF)
-      assert [
-        get_text(x.xpath(DEPARTMENT_XPATH)) for x in affiliations
-      ] == [DEPARTMENT_1]
-      assert [
-        get_text(x.xpath(INSTITUTION_XPATH)) for x in affiliations
-      ] == [INSTITUTION_1]
-
-  def test_should_convert_multiple_author(self):
-    logging.basicConfig(level='DEBUG')
-    with patch_grobid_service() as (process_header_names, process_affiliations):
-      _ = process_affiliations
-      process_header_names.return_value = (
-        pers_name(FORENAME_1, SURNAME_1) +
-        pers_name(FORENAME_2, SURNAME_2)
-      )
-      enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False)
-      xml_root = E.article()
-      create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_1)
-      create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_2)
-      enhanced_xml = enhancer(etree.tostring(xml_root))
-      get_logger().info('enhanced_xml: %s', enhanced_xml)
-      enhanced_xml_root = etree.parse(BytesIO(enhanced_xml))
-      authors = enhanced_xml_root.findall(XmlPaths.AUTHOR)
-      assert [
-        (get_child_text(author, 'given-names'), get_child_text(author, 'surname'))
-        for author in authors
-      ] == [(FORENAME_1, SURNAME_1), (FORENAME_2, SURNAME_2)]
-
-  def test_should_convert_multiple_affiliations(self):
-    logging.basicConfig(level='DEBUG')
-    with patch_grobid_service() as (process_header_names, process_affiliations):
-      _ = process_header_names
-      process_affiliations.return_value = (
-        tei_affiliation(
-          department=DEPARTMENT_1,
-          institution=INSTITUTION_1
-        ) +
-        tei_affiliation(
-          department=DEPARTMENT_2,
-          institution=INSTITUTION_2
-        )
-      )
-      enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False)
-      xml_root = E.article()
-      create_xml_text(xml_root, XmlPaths.AUTHOR_AFF, TEXT_1)
-      enhanced_xml = enhancer(etree.tostring(xml_root))
-      get_logger().info('enhanced_xml: %s', enhanced_xml)
-      enhanced_xml_root = etree.parse(BytesIO(enhanced_xml))
-      affiliations = enhanced_xml_root.findall(XmlPaths.AUTHOR_AFF)
-      assert [
-        get_text(x.xpath(DEPARTMENT_XPATH)) for x in affiliations
-      ] == [DEPARTMENT_1, DEPARTMENT_2]
-      assert [
-        get_text(x.xpath(INSTITUTION_XPATH)) for x in affiliations
-      ] == [INSTITUTION_1, INSTITUTION_2]
-
-  def test_should_combine_multiple_forenames(self):
-    logging.basicConfig(level='DEBUG')
-    with patch_grobid_service() as (process_header_names, process_affiliations):
-      _ = process_affiliations
-      process_header_names.return_value = pers_name(
-        FORENAME_1, FORENAME_2, SURNAME_1
-      )
-      enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False)
-      xml_root = E.article()
-      create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_1)
-      enhanced_xml = enhancer(etree.tostring(xml_root))
-      get_logger().info('enhanced_xml: %s', enhanced_xml)
-      enhanced_xml_root = etree.parse(BytesIO(enhanced_xml))
-      authors = enhanced_xml_root.findall(XmlPaths.AUTHOR)
-      assert [
-        (get_child_text(author, 'given-names'), get_child_text(author, 'surname'))
-        for author in authors
-      ] == [(' '.join([FORENAME_1, FORENAME_2]), SURNAME_1)]
+    def test_should_initialise_grobid_service(self):
+        with patch.object(grobid_xml_enhancer, 'grobid_service') as grobid_service:
+            GrobidXmlEnhancer(GROBID_URL, start_service=False)
+            grobid_service.assert_any_call(
+                GROBID_URL, GrobidApiPaths.PROCESS_HEADER_NAMES,
+                start_service=False, field_name='names'
+            )
+            grobid_service.assert_any_call(
+                GROBID_URL, GrobidApiPaths.PROCESS_AFFILIATIONS, start_service=False,
+                field_name='affiliations'
+            )
+
+    def test_should_convert_single_author(self):
+        logging.basicConfig(level='DEBUG')
+        with patch_grobid_service() as (process_header_names, process_affiliations):
+            _ = process_affiliations  # noqa: F841
+            process_header_names.return_value = pers_name(FORENAME_1, SURNAME_1)
+            enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False)
+            xml_root = E.article()
+            create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_1)
+            enhanced_xml = enhancer(etree.tostring(xml_root))
+            get_logger().info('enhanced_xml: %s', enhanced_xml)
+            enhanced_xml_root = etree.parse(BytesIO(enhanced_xml))
+            authors = enhanced_xml_root.findall(XmlPaths.AUTHOR)
+            assert [
+                (get_child_text(author, 'given-names'), get_child_text(author, 'surname'))
+                for author in authors
+            ] == [(FORENAME_1, SURNAME_1)]
+
+    def test_should_convert_single_affiliation(self):
+        logging.basicConfig(level='DEBUG')
+        with patch_grobid_service() as (process_header_names, process_affiliations):
+            _ = process_header_names  # noqa: F841
+            process_affiliations.return_value = tei_affiliation(
+                department=DEPARTMENT_1,
+                institution=INSTITUTION_1
+            )
+            enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False)
+            xml_root = E.article()
+            create_xml_text(xml_root, XmlPaths.AUTHOR_AFF, TEXT_1)
+            enhanced_xml = enhancer(etree.tostring(xml_root))
+            get_logger().info('enhanced_xml: %s', enhanced_xml)
+            enhanced_xml_root = etree.parse(BytesIO(enhanced_xml))
+            affiliations = enhanced_xml_root.findall(XmlPaths.AUTHOR_AFF)
+            assert [
+                get_text(x.xpath(DEPARTMENT_XPATH)) for x in affiliations
+            ] == [DEPARTMENT_1]
+            assert [
+                get_text(x.xpath(INSTITUTION_XPATH)) for x in affiliations
+            ] == [INSTITUTION_1]
+
+    def test_should_convert_multiple_author(self):
+        logging.basicConfig(level='DEBUG')
+        with patch_grobid_service() as (process_header_names, process_affiliations):
+            _ = process_affiliations  # noqa: F841
+            process_header_names.return_value = (
+                pers_name(FORENAME_1, SURNAME_1) +
+                pers_name(FORENAME_2, SURNAME_2)
+            )
+            enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False)
+            xml_root = E.article()
+            create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_1)
+            create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_2)
+            enhanced_xml = enhancer(etree.tostring(xml_root))
+            get_logger().info('enhanced_xml: %s', enhanced_xml)
+            enhanced_xml_root = etree.parse(BytesIO(enhanced_xml))
+            authors = enhanced_xml_root.findall(XmlPaths.AUTHOR)
+            assert [
+                (get_child_text(author, 'given-names'), get_child_text(author, 'surname'))
+                for author in authors
+            ] == [(FORENAME_1, SURNAME_1), (FORENAME_2, SURNAME_2)]
+
+    def test_should_convert_multiple_affiliations(self):
+        logging.basicConfig(level='DEBUG')
+        with patch_grobid_service() as (process_header_names, process_affiliations):
+            _ = process_header_names  # noqa: F841
+            process_affiliations.return_value = (
+                tei_affiliation(
+                    department=DEPARTMENT_1,
+                    institution=INSTITUTION_1
+                ) +
+                tei_affiliation(
+                    department=DEPARTMENT_2,
+                    institution=INSTITUTION_2
+                )
+            )
+            enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False)
+            xml_root = E.article()
+            create_xml_text(xml_root, XmlPaths.AUTHOR_AFF, TEXT_1)
+            enhanced_xml = enhancer(etree.tostring(xml_root))
+            get_logger().info('enhanced_xml: %s', enhanced_xml)
+            enhanced_xml_root = etree.parse(BytesIO(enhanced_xml))
+            affiliations = enhanced_xml_root.findall(XmlPaths.AUTHOR_AFF)
+            assert [
+                get_text(x.xpath(DEPARTMENT_XPATH)) for x in affiliations
+            ] == [DEPARTMENT_1, DEPARTMENT_2]
+            assert [
+                get_text(x.xpath(INSTITUTION_XPATH)) for x in affiliations
+            ] == [INSTITUTION_1, INSTITUTION_2]
+
+    def test_should_combine_multiple_forenames(self):
+        logging.basicConfig(level='DEBUG')
+        with patch_grobid_service() as (process_header_names, process_affiliations):
+            _ = process_affiliations  # noqa: F841
+            process_header_names.return_value = pers_name(
+                FORENAME_1, FORENAME_2, SURNAME_1
+            )
+            enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False)
+            xml_root = E.article()
+            create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_1)
+            enhanced_xml = enhancer(etree.tostring(xml_root))
+            get_logger().info('enhanced_xml: %s', enhanced_xml)
+            enhanced_xml_root = etree.parse(BytesIO(enhanced_xml))
+            authors = enhanced_xml_root.findall(XmlPaths.AUTHOR)
+            assert [
+                (get_child_text(author, 'given-names'), get_child_text(author, 'surname'))
+                for author in authors
+            ] == [(' '.join([FORENAME_1, FORENAME_2]), SURNAME_1)]
diff --git a/sciencebeam_gym/inference_model/__init__.py b/sciencebeam_gym/inference_model/__init__.py
index c26ac14..a8a17ea 100644
--- a/sciencebeam_gym/inference_model/__init__.py
+++ b/sciencebeam_gym/inference_model/__init__.py
@@ -15,93 +15,98 @@ OUTPUTS_KEY = 'annotation'
 LABELS_KEY = 'labels'
 COLORS_KEY = 'colors'
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 class InferenceModel(object):
-  def __init__(self, inputs, outputs, labels_tensor=None, colors_tensor=None):
-    self.inputs_tensor = inputs
-    self.outputs_tensor = outputs
-    self.labels_tensor = labels_tensor
-    self.colors_tensor = colors_tensor
-    self._color_map = None
+    def __init__(self, inputs, outputs, labels_tensor=None, colors_tensor=None):
+        self.inputs_tensor = inputs
+        self.outputs_tensor = outputs
+        self.labels_tensor = labels_tensor
+        self.colors_tensor = colors_tensor
+        self._color_map = None
 
-  def get_color_map(self, session=None):
-    if self._color_map is None:
-      assert self.labels_tensor is not None
-      assert self.colors_tensor is not None
-      session = session or tf.get_default_session()
-      assert session is not None
-      labels, colors = session.run(
-        [self.labels_tensor, self.colors_tensor]
-      )
-      self._color_map = {
-        k: tuple(v)
-        for k, v in zip(labels, colors)
-      }
-    return self._color_map
+    def get_color_map(self, session=None):
+        if self._color_map is None:
+            assert self.labels_tensor is not None
+            assert self.colors_tensor is not None
+            session = session or tf.get_default_session()
+            assert session is not None
+            labels, colors = session.run(
+                [self.labels_tensor, self.colors_tensor]
+            )
+            self._color_map = {
+                k: tuple(v)
+                for k, v in zip(labels, colors)
+            }
+        return self._color_map
+
+    def __call__(self, inputs, session=None):
+        session = session or tf.get_default_session()
+        assert session is not None
+        return session.run(self.outputs_tensor, feed_dict={
+            self.inputs_tensor: inputs
+        })
 
-  def __call__(self, inputs, session=None):
-    session = session or tf.get_default_session()
-    assert session is not None
-    return session.run(self.outputs_tensor, feed_dict={
-      self.inputs_tensor: inputs
-    })
 
 def save_inference_model(export_dir, inference_model, session=None, replace=True):
-  if session is None:
-    session = tf.get_default_session()
-  assert session is not None
-  if replace and is_directory(export_dir):
-    get_logger().info('replacing %s', export_dir)
-    delete_recursively(export_dir)
-  prediction_signature = predict_signature_def(
-    inputs={INPUTS_KEY: inference_model.inputs_tensor},
-    outputs={k: v for k, v in {
-        OUTPUTS_KEY: inference_model.outputs_tensor,
-        LABELS_KEY: inference_model.labels_tensor,
-        COLORS_KEY: inference_model.colors_tensor
-      }.items() if v is not None
+    if session is None:
+        session = tf.get_default_session()
+    assert session is not None
+    if replace and is_directory(export_dir):
+        get_logger().info('replacing %s', export_dir)
+        delete_recursively(export_dir)
+    prediction_signature = predict_signature_def(
+        inputs={INPUTS_KEY: inference_model.inputs_tensor},
+        outputs={k: v for k, v in {
+            OUTPUTS_KEY: inference_model.outputs_tensor,
+            LABELS_KEY: inference_model.labels_tensor,
+            COLORS_KEY: inference_model.colors_tensor
+        }.items() if v is not None
+        }
+    )
+    signature_def_map = {
+        DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature
     }
-  )
-  signature_def_map = {
-    DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature
-  }
-  legacy_init_op = tf.group(
-    tf.tables_initializer(),
-    name='legacy_init_op'
-  )
-  builder = SavedModelBuilder(export_dir)
-  builder.add_meta_graph_and_variables(
-    session,
-    [SERVING],
-    signature_def_map=signature_def_map,
-    legacy_init_op=legacy_init_op
-  )
-  builder.save()
+    legacy_init_op = tf.group(
+        tf.tables_initializer(),
+        name='legacy_init_op'
+    )
+    builder = SavedModelBuilder(export_dir)
+    builder.add_meta_graph_and_variables(
+        session,
+        [SERVING],
+        signature_def_map=signature_def_map,
+        legacy_init_op=legacy_init_op
+    )
+    builder.save()
+
 
 def get_output_tensor_or_none(graph, signature, name):
-  tensor_name = signature.outputs[name].name
-  return graph.get_tensor_by_name(tensor_name) if tensor_name else None
+    tensor_name = signature.outputs[name].name
+    return graph.get_tensor_by_name(tensor_name) if tensor_name else None
+
 
 def load_inference_model(export_dir, session=None):
-  if session is None:
-    session = tf.get_default_session()
-  assert session is not None
-  meta_graph_def = tf.saved_model.loader.load(session, [SERVING], export_dir)
-  signature = meta_graph_def.signature_def[DEFAULT_SERVING_SIGNATURE_DEF_KEY]
-  inputs_name = signature.inputs[INPUTS_KEY].name
-  outputs_name = signature.outputs[OUTPUTS_KEY].name
-  get_logger().info('inputs_name: %s', inputs_name)
-  get_logger().info('outputs_name: %s', outputs_name)
-  graph = tf.get_default_graph()
-  inputs = graph.get_tensor_by_name(inputs_name)
-  outputs = graph.get_tensor_by_name(outputs_name)
-  get_logger().info('inputs: %s', inputs)
-  get_logger().info('output: %s', outputs)
-  return InferenceModel(
-    inputs,
-    outputs,
-    get_output_tensor_or_none(graph, signature, LABELS_KEY),
-    get_output_tensor_or_none(graph, signature, COLORS_KEY)
-  )
+    if session is None:
+        session = tf.get_default_session()
+    assert session is not None
+    meta_graph_def = tf.saved_model.loader.load(session, [SERVING], export_dir)
+    signature = meta_graph_def.signature_def[DEFAULT_SERVING_SIGNATURE_DEF_KEY]
+    inputs_name = signature.inputs[INPUTS_KEY].name
+    outputs_name = signature.outputs[OUTPUTS_KEY].name
+    get_logger().info('inputs_name: %s', inputs_name)
+    get_logger().info('outputs_name: %s', outputs_name)
+    graph = tf.get_default_graph()
+    inputs = graph.get_tensor_by_name(inputs_name)
+    outputs = graph.get_tensor_by_name(outputs_name)
+    get_logger().info('inputs: %s', inputs)
+    get_logger().info('output: %s', outputs)
+    return InferenceModel(
+        inputs,
+        outputs,
+        get_output_tensor_or_none(graph, signature, LABELS_KEY),
+        get_output_tensor_or_none(graph, signature, COLORS_KEY)
+    )
diff --git a/sciencebeam_gym/inference_model/__init___test.py b/sciencebeam_gym/inference_model/__init___test.py
index 9eb117f..01b60a9 100644
--- a/sciencebeam_gym/inference_model/__init___test.py
+++ b/sciencebeam_gym/inference_model/__init___test.py
@@ -6,13 +6,13 @@ import tensorflow as tf
 import numpy as np
 
 from sciencebeam_utils.utils.num import (
-  assert_all_close
+    assert_all_close
 )
 
 from sciencebeam_gym.inference_model import (
-  InferenceModel,
-  save_inference_model,
-  load_inference_model
+    InferenceModel,
+    save_inference_model,
+    load_inference_model
 )
 
 TEMP_DIR = '.temp/tests/%s' % __name__
@@ -20,59 +20,63 @@ TEMP_DIR = '.temp/tests/%s' % __name__
 LABELS = ['label 1', 'label 2', 'label 3']
 COLORS = [(1, 1, 1), (2, 2, 2), (3, 3, 3)]
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def setup_module():
-  logging.basicConfig(level='DEBUG')
+    logging.basicConfig(level='DEBUG')
+
 
 class TestInferenceModelSaverLoader(object):
-  def test_should_export_import_and_run_simple_model(self):
-    export_dir = os.path.join(TEMP_DIR, 'export')
-    if os.path.isdir(export_dir):
-      rmtree(export_dir)
+    def test_should_export_import_and_run_simple_model(self):
+        export_dir = os.path.join(TEMP_DIR, 'export')
+        if os.path.isdir(export_dir):
+            rmtree(export_dir)
 
-    # sample fn that works with tf and np
-    sample_fn = lambda x: x * 2.0 + 10.0
+        # sample fn that works with tf and np
+        def sample_fn(x):
+            return x * 2.0 + 10.0
 
-    with tf.Graph().as_default():
-      with tf.variable_scope('scope1'):
-        inputs = tf.placeholder(tf.float32, (None, 16, 16, 3))
-        outputs = sample_fn(inputs)
-        get_logger().info('outputs: %s', outputs)
-        with tf.Session() as session:
-          save_inference_model(
-            export_dir,
-            InferenceModel(inputs, outputs)
-          )
-    with tf.Graph().as_default():
-      with tf.Session() as session:
-        inference_model = load_inference_model(export_dir)
-        inputs_value = np.ones((5, 16, 16, 3))
-        assert_all_close(
-          inference_model(inputs_value, session=session),
-          sample_fn(inputs_value)
-        )
+        with tf.Graph().as_default():
+            with tf.variable_scope('scope1'):
+                inputs = tf.placeholder(tf.float32, (None, 16, 16, 3))
+                outputs = sample_fn(inputs)
+                get_logger().info('outputs: %s', outputs)
+                with tf.Session() as session:
+                    save_inference_model(
+                        export_dir,
+                        InferenceModel(inputs, outputs)
+                    )
+        with tf.Graph().as_default():
+            with tf.Session() as session:
+                inference_model = load_inference_model(export_dir)
+                inputs_value = np.ones((5, 16, 16, 3))
+                assert_all_close(
+                    inference_model(inputs_value, session=session),
+                    sample_fn(inputs_value)
+                )
 
-  def test_should_export_import_color_map(self):
-    export_dir = os.path.join(TEMP_DIR, 'export')
-    if os.path.isdir(export_dir):
-      rmtree(export_dir)
+    def test_should_export_import_color_map(self):
+        export_dir = os.path.join(TEMP_DIR, 'export')
+        if os.path.isdir(export_dir):
+            rmtree(export_dir)
 
-    with tf.Graph().as_default():
-      with tf.variable_scope('scope1'):
-        inputs = tf.placeholder(tf.float32, (None, 16, 16, 3))
-        outputs = inputs * 2.0
-        labels = tf.constant(LABELS)
-        colors = tf.constant(COLORS)
-        with tf.Session():
-          save_inference_model(
-            export_dir,
-            InferenceModel(inputs, outputs, labels, colors)
-          )
-    with tf.Graph().as_default():
-      with tf.Session():
-        inference_model = load_inference_model(export_dir)
-        color_map = inference_model.get_color_map()
-        get_logger().debug('color_map: %s', color_map)
-        assert set(color_map.items()) == set(zip(LABELS, COLORS))
+        with tf.Graph().as_default():
+            with tf.variable_scope('scope1'):
+                inputs = tf.placeholder(tf.float32, (None, 16, 16, 3))
+                outputs = inputs * 2.0
+                labels = tf.constant(LABELS)
+                colors = tf.constant(COLORS)
+                with tf.Session():
+                    save_inference_model(
+                        export_dir,
+                        InferenceModel(inputs, outputs, labels, colors)
+                    )
+        with tf.Graph().as_default():
+            with tf.Session():
+                inference_model = load_inference_model(export_dir)
+                color_map = inference_model.get_color_map()
+                get_logger().debug('color_map: %s', color_map)
+                assert set(color_map.items()) == set(zip(LABELS, COLORS))
diff --git a/sciencebeam_gym/inference_model/annotate_using_predictions.py b/sciencebeam_gym/inference_model/annotate_using_predictions.py
index d3af139..6cef2a6 100644
--- a/sciencebeam_gym/inference_model/annotate_using_predictions.py
+++ b/sciencebeam_gym/inference_model/annotate_using_predictions.py
@@ -8,183 +8,198 @@ import numpy as np
 from PIL import Image
 
 from sciencebeam_utils.beam_utils.io import (
-  read_all_from_path
+    read_all_from_path
 )
 
 from sciencebeam_gym.utils.bounding_box import (
-  BoundingBox
+    BoundingBox
 )
 
 from sciencebeam_gym.preprocess.color_map import (
-  parse_color_map_from_file
+    parse_color_map_from_file
 )
 
 from sciencebeam_gym.structured_document.structured_document_loader import (
-  load_structured_document
+    load_structured_document
 )
 
 from sciencebeam_gym.structured_document.structured_document_saver import (
-  save_structured_document
+    save_structured_document
 )
 
 CV_TAG_SCOPE = 'cv'
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 class AnnotatedImage(object):
-  def __init__(self, data, color_map):
-    self.data = data
-    self.color_map = color_map
-    self.size = (data.shape[1], data.shape[0])
-
-  def get_tag_probabilities_within(self, bounding_box):
-    image_area = self.data[
-      int(bounding_box.y):int(bounding_box.y + bounding_box.height),
-      int(bounding_box.x):int(bounding_box.x + bounding_box.width)
-    ]
-    counts = {
-      k: np.sum(np.all(image_area == v, axis=-1))
-      for k, v in self.color_map.items()
-    }
-    total = image_area.size / image_area.shape[-1]
-    return {
-      k: v / total if total > 0.0 else 0.0
-      for k, v in counts.items()
-    }
+    def __init__(self, data, color_map):
+        self.data = data
+        self.color_map = color_map
+        self.size = (data.shape[1], data.shape[0])
+
+    def get_tag_probabilities_within(self, bounding_box):
+        image_area = self.data[
+            int(bounding_box.y):int(bounding_box.y + bounding_box.height),
+            int(bounding_box.x):int(bounding_box.x + bounding_box.width)
+        ]
+        counts = {
+            k: np.sum(np.all(image_area == v, axis=-1))
+            for k, v in self.color_map.items()
+        }
+        total = image_area.size / image_area.shape[-1]
+        return {
+            k: v / total if total > 0.0 else 0.0
+            for k, v in counts.items()
+        }
+
 
 def calculate_rescale_factors(structured_document, page, annotated_image):
-  page_bounding_box = structured_document.get_bounding_box(page)
-  get_logger().debug('page_bounding_box: %s', page_bounding_box)
-  assert page_bounding_box is not None
-  page_width = page_bounding_box.width
-  page_height = page_bounding_box.height
-  annotated_image_width, annotated_image_height = annotated_image.size
-  get_logger().debug(
-    'annotated_image width, height: %f, %f',
-    annotated_image_width, annotated_image_height
-  )
-  rx = annotated_image_width / page_width
-  ry = annotated_image_height / page_height
-  return rx, ry
+    page_bounding_box = structured_document.get_bounding_box(page)
+    get_logger().debug('page_bounding_box: %s', page_bounding_box)
+    assert page_bounding_box is not None
+    page_width = page_bounding_box.width
+    page_height = page_bounding_box.height
+    annotated_image_width, annotated_image_height = annotated_image.size
+    get_logger().debug(
+        'annotated_image width, height: %f, %f',
+        annotated_image_width, annotated_image_height
+    )
+    rx = annotated_image_width / page_width
+    ry = annotated_image_height / page_height
+    return rx, ry
+
 
 def scale_bounding_box(bounding_box, rx, ry):
-  return BoundingBox(
-    bounding_box.x * rx,
-    bounding_box.y * ry,
-    bounding_box.width * rx,
-    bounding_box.height * ry
-  )
+    return BoundingBox(
+        bounding_box.x * rx,
+        bounding_box.y * ry,
+        bounding_box.width * rx,
+        bounding_box.height * ry
+    )
+
 
 def annotate_page_using_predicted_image(
-  structured_document, page, annotated_image, tag_scope=CV_TAG_SCOPE):
-
-  rx, ry = calculate_rescale_factors(structured_document, page, annotated_image)
-  get_logger().debug('rx, ry: %f, %f', rx, ry)
-  for line in structured_document.get_lines_of_page(page):
-    for token in structured_document.get_tokens_of_line(line):
-      bounding_box = structured_document.get_bounding_box(token)
-      if bounding_box:
-        get_logger().debug('original bounding_box: %s', bounding_box)
-        bounding_box = scale_bounding_box(bounding_box, rx, ry)
-        get_logger().debug('scaled bounding_box: %s', bounding_box)
-        tag_probabilites = sorted(
-          ((k, v) for k, v in annotated_image.get_tag_probabilities_within(bounding_box).items()),
-          key=lambda x: x[1],
-          reverse=True
-        )
-        top_probability_tag, top_probability_value = tag_probabilites[0]
-        get_logger().debug('tag_probabilites: %s', tag_probabilites)
-        if top_probability_value > 0.5:
-          get_logger().debug(
-            'tagging token: %s: %s',
-            top_probability_tag,
-            structured_document.get_text(token)
-          )
-          structured_document.set_tag(token, top_probability_tag, scope=tag_scope)
+        structured_document, page, annotated_image, tag_scope=CV_TAG_SCOPE):
+
+    rx, ry = calculate_rescale_factors(structured_document, page, annotated_image)
+    get_logger().debug('rx, ry: %f, %f', rx, ry)
+    for line in structured_document.get_lines_of_page(page):
+        for token in structured_document.get_tokens_of_line(line):
+            bounding_box = structured_document.get_bounding_box(token)
+            if bounding_box:
+                get_logger().debug('original bounding_box: %s', bounding_box)
+                bounding_box = scale_bounding_box(bounding_box, rx, ry)
+                get_logger().debug('scaled bounding_box: %s', bounding_box)
+                tag_probabilites = sorted(
+                    (
+                        (k, v)
+                        for k, v in annotated_image.get_tag_probabilities_within(
+                            bounding_box
+                        ).items()
+                    ),
+                    key=lambda x: x[1],
+                    reverse=True
+                )
+                top_probability_tag, top_probability_value = tag_probabilites[0]
+                get_logger().debug('tag_probabilites: %s', tag_probabilites)
+                if top_probability_value > 0.5:
+                    get_logger().debug(
+                        'tagging token: %s: %s',
+                        top_probability_tag,
+                        structured_document.get_text(token)
+                    )
+                    structured_document.set_tag(token, top_probability_tag, scope=tag_scope)
+
 
 def annotate_structured_document_using_predicted_images(
-  structured_document, annotated_images, tag_scope=CV_TAG_SCOPE):
+        structured_document, annotated_images, tag_scope=CV_TAG_SCOPE):
+
+    for page, annotated_image in zip(structured_document.get_pages(), annotated_images):
+        annotate_page_using_predicted_image(
+            structured_document, page, annotated_image, tag_scope=tag_scope
+        )
+    return structured_document
 
-  for page, annotated_image in zip(structured_document.get_pages(), annotated_images):
-    annotate_page_using_predicted_image(
-      structured_document, page, annotated_image, tag_scope=tag_scope
-    )
-  return structured_document
 
 def parse_args(argv=None):
-  parser = argparse.ArgumentParser('Annotated LXML using prediction images')
-  source = parser.add_mutually_exclusive_group(required=True)
-  source.add_argument(
-    '--lxml-path', type=str, required=False,
-    help='path to lxml or svg pages document'
-  )
-
-  images = parser.add_mutually_exclusive_group(required=True)
-  images.add_argument(
-    '--images-path', type=str, nargs='+',
-    help='path to lxml document'
-  )
-
-  parser.add_argument(
-    '--output-path', type=str, required=True,
-    help='output path to annotated document'
-  )
-
-  parser.add_argument(
-    '--tag-scope', type=str, required=False,
-    default=CV_TAG_SCOPE,
-    help='target tag scope for the predicted tags'
-  )
-
-  parser.add_argument(
-    '--color-map', default='color_map.conf',
-    help='color map to use'
-  )
-
-  parser.add_argument(
-    '--debug', action='store_true', default=False,
-    help='enable debug output'
-  )
-
-  return parser.parse_args(argv)
+    parser = argparse.ArgumentParser('Annotated LXML using prediction images')
+    source = parser.add_mutually_exclusive_group(required=True)
+    source.add_argument(
+        '--lxml-path', type=str, required=False,
+        help='path to lxml or svg pages document'
+    )
+
+    images = parser.add_mutually_exclusive_group(required=True)
+    images.add_argument(
+        '--images-path', type=str, nargs='+',
+        help='path to lxml document'
+    )
+
+    parser.add_argument(
+        '--output-path', type=str, required=True,
+        help='output path to annotated document'
+    )
+
+    parser.add_argument(
+        '--tag-scope', type=str, required=False,
+        default=CV_TAG_SCOPE,
+        help='target tag scope for the predicted tags'
+    )
+
+    parser.add_argument(
+        '--color-map', default='color_map.conf',
+        help='color map to use'
+    )
+
+    parser.add_argument(
+        '--debug', action='store_true', default=False,
+        help='enable debug output'
+    )
+
+    return parser.parse_args(argv)
+
 
 def load_annotation_image(path, color_map):
-  get_logger().debug('loading annotation image: %s', path)
-  return AnnotatedImage(
-    np.asarray(
-      Image.open(BytesIO(read_all_from_path(path, 'rb'))).convert('RGB'),
-      dtype=np.uint8
-    ),
-    color_map
-  )
+    get_logger().debug('loading annotation image: %s', path)
+    return AnnotatedImage(
+        np.asarray(
+            Image.open(BytesIO(read_all_from_path(path, 'rb'))).convert('RGB'),
+            dtype=np.uint8
+        ),
+        color_map
+    )
+
 
 def main(argv=None):
-  args = parse_args(argv)
+    args = parse_args(argv)
 
-  if args.debug:
-    logging.getLogger().setLevel('DEBUG')
+    if args.debug:
+        logging.getLogger().setLevel('DEBUG')
 
-  color_map = parse_color_map_from_file(args.color_map)
-  get_logger().debug('color_map: %s', color_map)
+    color_map = parse_color_map_from_file(args.color_map)
+    get_logger().debug('color_map: %s', color_map)
 
-  structured_document = load_structured_document(args.lxml_path, 'rb')
+    structured_document = load_structured_document(args.lxml_path, 'rb')
 
-  annotated_images = (
-    load_annotation_image(path, color_map)
-    for path in args.images_path
-  )
+    annotated_images = (
+        load_annotation_image(path, color_map)
+        for path in args.images_path
+    )
+
+    structured_document = annotate_structured_document_using_predicted_images(
+        structured_document,
+        annotated_images,
+        tag_scope=args.tag_scope
+    )
 
-  structured_document = annotate_structured_document_using_predicted_images(
-    structured_document,
-    annotated_images,
-    tag_scope=args.tag_scope
-  )
+    get_logger().info('writing result to: %s', args.output_path)
+    save_structured_document(args.output_path, structured_document.root)
 
-  get_logger().info('writing result to: %s', args.output_path)
-  save_structured_document(args.output_path, structured_document.root)
 
 if __name__ == '__main__':
-  logging.basicConfig(level='INFO')
+    logging.basicConfig(level='INFO')
 
-  main()
+    main()
diff --git a/sciencebeam_gym/inference_model/annotate_using_predictions_test.py b/sciencebeam_gym/inference_model/annotate_using_predictions_test.py
index f483b39..312fd0f 100644
--- a/sciencebeam_gym/inference_model/annotate_using_predictions_test.py
+++ b/sciencebeam_gym/inference_model/annotate_using_predictions_test.py
@@ -3,20 +3,20 @@ import pytest
 import numpy as np
 
 from sciencebeam_gym.structured_document import (
-  SimpleStructuredDocument,
-  SimpleLine,
-  SimpleToken
+    SimpleStructuredDocument,
+    SimpleLine,
+    SimpleToken
 )
 
 from sciencebeam_gym.utils.bounding_box import (
-  BoundingBox
+    BoundingBox
 )
 
 from sciencebeam_gym.inference_model.annotate_using_predictions import (
-  AnnotatedImage,
-  annotate_structured_document_using_predicted_images,
-  parse_args,
-  CV_TAG_SCOPE
+    AnnotatedImage,
+    annotate_structured_document_using_predicted_images,
+    parse_args,
+    CV_TAG_SCOPE
 )
 
 TAG_1 = 'tag1'
@@ -30,163 +30,168 @@ DEFAULT_HEIGHT = 3
 DEFAULT_WIDTH = 3
 DEFAULT_BOUNDING_BOX = BoundingBox(0, 0, DEFAULT_WIDTH, DEFAULT_HEIGHT)
 
+
 def filled_image(color, color_map, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT):
-  return AnnotatedImage(
-    np.full((height, width, 3), color),
-    color_map
-  )
+    return AnnotatedImage(
+        np.full((height, width, 3), color),
+        color_map
+    )
+
 
 def fill_rect(annoted_image, bounding_box, color):
-  for y in range(bounding_box.y, bounding_box.y + bounding_box.height):
-    for x in range(bounding_box.x, bounding_box.x + bounding_box.width):
-      annoted_image.data[y, x] = color
+    for y in range(bounding_box.y, bounding_box.y + bounding_box.height):
+        for x in range(bounding_box.x, bounding_box.x + bounding_box.width):
+            annoted_image.data[y, x] = color
+
 
 class TestAnnotatedImage(object):
-  def test_should_return_zero_tag_probality_if_color_not_in_output(self):
-    annotated_image = filled_image(BG_COLOR, {TAG_1: COLOR_1})
-    assert annotated_image.get_tag_probabilities_within(
-      DEFAULT_BOUNDING_BOX
-    ).get(TAG_1) == 0.0
-
-  def test_should_return_one_tag_probality_if_color_is_only_color_in_output(self):
-    annotated_image = filled_image(COLOR_1, {TAG_1: COLOR_1})
-    assert annotated_image.get_tag_probabilities_within(
-      DEFAULT_BOUNDING_BOX
-    ).get(TAG_1) == 1.0
-
-  def test_should_return_zero_tag_probality_if_bounding_box_is_empty(self):
-    annotated_image = filled_image(BG_COLOR, {TAG_1: COLOR_1})
-    assert annotated_image.get_tag_probabilities_within(
-      BoundingBox(0, 0, 0, 0)
-    ).get(TAG_1) == 0.0
-
-  def test_should_return_zero_tag_probality_if_bounding_box_is_outside_image(self):
-    annotated_image = filled_image(BG_COLOR, {TAG_1: COLOR_1})
-    assert annotated_image.get_tag_probabilities_within(
-      DEFAULT_BOUNDING_BOX.move_by(DEFAULT_WIDTH, 0)
-    ).get(TAG_1) == 0.0
+    def test_should_return_zero_tag_probality_if_color_not_in_output(self):
+        annotated_image = filled_image(BG_COLOR, {TAG_1: COLOR_1})
+        assert annotated_image.get_tag_probabilities_within(
+            DEFAULT_BOUNDING_BOX
+        ).get(TAG_1) == 0.0
+
+    def test_should_return_one_tag_probality_if_color_is_only_color_in_output(self):
+        annotated_image = filled_image(COLOR_1, {TAG_1: COLOR_1})
+        assert annotated_image.get_tag_probabilities_within(
+            DEFAULT_BOUNDING_BOX
+        ).get(TAG_1) == 1.0
+
+    def test_should_return_zero_tag_probality_if_bounding_box_is_empty(self):
+        annotated_image = filled_image(BG_COLOR, {TAG_1: COLOR_1})
+        assert annotated_image.get_tag_probabilities_within(
+            BoundingBox(0, 0, 0, 0)
+        ).get(TAG_1) == 0.0
+
+    def test_should_return_zero_tag_probality_if_bounding_box_is_outside_image(self):
+        annotated_image = filled_image(BG_COLOR, {TAG_1: COLOR_1})
+        assert annotated_image.get_tag_probabilities_within(
+            DEFAULT_BOUNDING_BOX.move_by(DEFAULT_WIDTH, 0)
+        ).get(TAG_1) == 0.0
+
 
 class TestAnnotateStructuredDocumentUsingPredictedImages(object):
-  def test_should_not_fail_with_empty_document(self):
-    structured_document = SimpleStructuredDocument()
-    annotate_structured_document_using_predicted_images(
-      structured_document,
-      []
-    )
+    def test_should_not_fail_with_empty_document(self):
+        structured_document = SimpleStructuredDocument()
+        annotate_structured_document_using_predicted_images(
+            structured_document,
+            []
+        )
+
+    def test_should_not_tag_single_token_not_within_prediction(self):
+        token_1 = SimpleToken(TOKEN_TEXT_1)
+        structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])])
+        structured_document.set_bounding_box(
+            structured_document.get_pages()[0],
+            DEFAULT_BOUNDING_BOX
+        )
+        structured_document.set_bounding_box(token_1, DEFAULT_BOUNDING_BOX)
+        annotate_structured_document_using_predicted_images(
+            structured_document,
+            [filled_image(BG_COLOR, {TAG_1: COLOR_1})]
+        )
+        assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) is None
+
+    def test_should_tag_single_token_within_prediction(self):
+        token_1 = SimpleToken(TOKEN_TEXT_1)
+        structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])])
+        structured_document.set_bounding_box(
+            structured_document.get_pages()[0],
+            DEFAULT_BOUNDING_BOX
+        )
+        structured_document.set_bounding_box(token_1, DEFAULT_BOUNDING_BOX)
+        annotate_structured_document_using_predicted_images(
+            structured_document,
+            [filled_image(COLOR_1, {TAG_1: COLOR_1})]
+        )
+        assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) == TAG_1
+
+    def test_should_tag_single_token_within_full_prediction_at_smaller_scale(self):
+        token_1 = SimpleToken(TOKEN_TEXT_1)
+        structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])])
+        structured_document.set_bounding_box(
+            structured_document.get_pages()[0],
+            BoundingBox(0, 0, DEFAULT_WIDTH * 10, DEFAULT_HEIGHT * 10)
+        )
+        structured_document.set_bounding_box(
+            token_1,
+            BoundingBox(0, 0, DEFAULT_WIDTH * 10, DEFAULT_HEIGHT * 10)
+        )
+        annotate_structured_document_using_predicted_images(
+            structured_document,
+            [filled_image(COLOR_1, {TAG_1: COLOR_1}, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT)]
+        )
+        assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) == TAG_1
+
+    def test_should_tag_single_token_within_partial_prediction_at_same_scale(self):
+        token_1 = SimpleToken(TOKEN_TEXT_1)
+        structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])])
+        structured_document.set_bounding_box(
+            structured_document.get_pages()[0],
+            DEFAULT_BOUNDING_BOX
+        )
+        structured_document.set_bounding_box(
+            structured_document.get_pages()[0],
+            BoundingBox(0, 0, DEFAULT_WIDTH * 10, DEFAULT_HEIGHT * 10)
+        )
+        structured_document.set_bounding_box(
+            token_1,
+            BoundingBox(0, 0, DEFAULT_WIDTH, DEFAULT_HEIGHT)
+        )
+        annotated_image = filled_image(
+            BG_COLOR, {TAG_1: COLOR_1},
+            width=DEFAULT_WIDTH * 10,
+            height=DEFAULT_HEIGHT * 10
+        )
+        fill_rect(
+            annotated_image,
+            BoundingBox(0, 0, DEFAULT_WIDTH, DEFAULT_HEIGHT),
+            COLOR_1
+        )
+        annotate_structured_document_using_predicted_images(
+            structured_document,
+            [annotated_image]
+        )
+        assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) == TAG_1
+
+    def test_should_tag_single_token_within_partial_prediction_at_smaller_scale(self):
+        token_1 = SimpleToken(TOKEN_TEXT_1)
+        structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])])
+        structured_document.set_bounding_box(
+            structured_document.get_pages()[0],
+            BoundingBox(0, 0, DEFAULT_WIDTH * 100, DEFAULT_HEIGHT * 100)
+        )
+        structured_document.set_bounding_box(
+            token_1,
+            BoundingBox(0, 0, DEFAULT_WIDTH * 10, DEFAULT_HEIGHT * 10)
+        )
+        annotated_image = filled_image(
+            BG_COLOR, {TAG_1: COLOR_1},
+            width=DEFAULT_WIDTH * 10,
+            height=DEFAULT_HEIGHT * 10
+        )
+        fill_rect(
+            annotated_image,
+            BoundingBox(0, 0, DEFAULT_WIDTH, DEFAULT_HEIGHT),
+            COLOR_1
+        )
+        annotate_structured_document_using_predicted_images(
+            structured_document,
+            [annotated_image]
+        )
+        assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) == TAG_1
 
-  def test_should_not_tag_single_token_not_within_prediction(self):
-    token_1 = SimpleToken(TOKEN_TEXT_1)
-    structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])])
-    structured_document.set_bounding_box(
-      structured_document.get_pages()[0],
-      DEFAULT_BOUNDING_BOX
-    )
-    structured_document.set_bounding_box(token_1, DEFAULT_BOUNDING_BOX)
-    annotate_structured_document_using_predicted_images(
-      structured_document,
-      [filled_image(BG_COLOR, {TAG_1: COLOR_1})]
-    )
-    assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) is None
-
-  def test_should_tag_single_token_within_prediction(self):
-    token_1 = SimpleToken(TOKEN_TEXT_1)
-    structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])])
-    structured_document.set_bounding_box(
-      structured_document.get_pages()[0],
-      DEFAULT_BOUNDING_BOX
-    )
-    structured_document.set_bounding_box(token_1, DEFAULT_BOUNDING_BOX)
-    annotate_structured_document_using_predicted_images(
-      structured_document,
-      [filled_image(COLOR_1, {TAG_1: COLOR_1})]
-    )
-    assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) == TAG_1
-
-  def test_should_tag_single_token_within_full_prediction_at_smaller_scale(self):
-    token_1 = SimpleToken(TOKEN_TEXT_1)
-    structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])])
-    structured_document.set_bounding_box(
-      structured_document.get_pages()[0],
-      BoundingBox(0, 0, DEFAULT_WIDTH * 10, DEFAULT_HEIGHT * 10)
-    )
-    structured_document.set_bounding_box(
-      token_1,
-      BoundingBox(0, 0, DEFAULT_WIDTH * 10, DEFAULT_HEIGHT * 10)
-    )
-    annotate_structured_document_using_predicted_images(
-      structured_document,
-      [filled_image(COLOR_1, {TAG_1: COLOR_1}, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT)]
-    )
-    assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) == TAG_1
-
-  def test_should_tag_single_token_within_partial_prediction_at_same_scale(self):
-    token_1 = SimpleToken(TOKEN_TEXT_1)
-    structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])])
-    structured_document.set_bounding_box(
-      structured_document.get_pages()[0],
-      DEFAULT_BOUNDING_BOX
-    )
-    structured_document.set_bounding_box(
-      structured_document.get_pages()[0],
-      BoundingBox(0, 0, DEFAULT_WIDTH * 10, DEFAULT_HEIGHT * 10)
-    )
-    structured_document.set_bounding_box(
-      token_1,
-      BoundingBox(0, 0, DEFAULT_WIDTH, DEFAULT_HEIGHT)
-    )
-    annotated_image = filled_image(
-      BG_COLOR, {TAG_1: COLOR_1},
-      width=DEFAULT_WIDTH * 10,
-      height=DEFAULT_HEIGHT * 10
-    )
-    fill_rect(
-      annotated_image,
-      BoundingBox(0, 0, DEFAULT_WIDTH, DEFAULT_HEIGHT),
-      COLOR_1
-    )
-    annotate_structured_document_using_predicted_images(
-      structured_document,
-      [annotated_image]
-    )
-    assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) == TAG_1
-
-  def test_should_tag_single_token_within_partial_prediction_at_smaller_scale(self):
-    token_1 = SimpleToken(TOKEN_TEXT_1)
-    structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])])
-    structured_document.set_bounding_box(
-      structured_document.get_pages()[0],
-      BoundingBox(0, 0, DEFAULT_WIDTH * 100, DEFAULT_HEIGHT * 100)
-    )
-    structured_document.set_bounding_box(
-      token_1,
-      BoundingBox(0, 0, DEFAULT_WIDTH * 10, DEFAULT_HEIGHT * 10)
-    )
-    annotated_image = filled_image(
-      BG_COLOR, {TAG_1: COLOR_1},
-      width=DEFAULT_WIDTH * 10,
-      height=DEFAULT_HEIGHT * 10
-    )
-    fill_rect(
-      annotated_image,
-      BoundingBox(0, 0, DEFAULT_WIDTH, DEFAULT_HEIGHT),
-      COLOR_1
-    )
-    annotate_structured_document_using_predicted_images(
-      structured_document,
-      [annotated_image]
-    )
-    assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) == TAG_1
 
 class TestParseArgs(object):
-  def test_should_raise_error_if_not_enough_arguments_are_passed(self):
-    with pytest.raises(SystemExit):
-      parse_args([])
-
-  def test_should_not_raise_error_with_minimum_args(self):
-    parse_args(['--lxml-path=test', '--images-path=test', '--output-path=test'])
-
-  # def test_should_raise_error_if_mutliple_source_args_are_specified(self):
-  #   with pytest.raises(SystemExit):
-  #     parse_args([
-  #       '--lxml-path=test', '--svg-path=test', '--images-path=test', '--output-path=test'
-  #     ])
+    def test_should_raise_error_if_not_enough_arguments_are_passed(self):
+        with pytest.raises(SystemExit):
+            parse_args([])
+
+    def test_should_not_raise_error_with_minimum_args(self):
+        parse_args(['--lxml-path=test', '--images-path=test', '--output-path=test'])
+
+    # def test_should_raise_error_if_multiple_source_args_are_specified(self):
+    #   with pytest.raises(SystemExit):
+    #     parse_args([
+    #       '--lxml-path=test', '--svg-path=test', '--images-path=test', '--output-path=test'
+    #     ])
diff --git a/sciencebeam_gym/inference_model/extract_from_annotated_document.py b/sciencebeam_gym/inference_model/extract_from_annotated_document.py
index aa06c4c..eb12356 100644
--- a/sciencebeam_gym/inference_model/extract_from_annotated_document.py
+++ b/sciencebeam_gym/inference_model/extract_from_annotated_document.py
@@ -1,95 +1,102 @@
 import logging
 
 from sciencebeam_gym.structured_document import (
-  B_TAG_PREFIX
+    B_TAG_PREFIX
 )
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 class ExtractedItem(object):
-  def __init__(self, tag, text, tokens=None, tag_prefix=None, sub_items=None):
-    self.tag = tag
-    self.tag_prefix = tag_prefix
-    self.text = text
-    self.tokens = tokens or []
-    self.sub_items = sub_items or []
-
-  def extend(self, other_item):
-    return ExtractedItem(
-      self.tag,
-      self.text + '\n' + other_item.text,
-      tokens=self.tokens + other_item.tokens,
-      tag_prefix=self.tag_prefix,
-      sub_items=self.sub_items + other_item.sub_items
-    )
+    def __init__(self, tag, text, tokens=None, tag_prefix=None, sub_items=None):
+        self.tag = tag
+        self.tag_prefix = tag_prefix
+        self.text = text
+        self.tokens = tokens or []
+        self.sub_items = sub_items or []
+
+    def extend(self, other_item):
+        return ExtractedItem(
+            self.tag,
+            self.text + '\n' + other_item.text,
+            tokens=self.tokens + other_item.tokens,
+            tag_prefix=self.tag_prefix,
+            sub_items=self.sub_items + other_item.sub_items
+        )
+
 
 def get_lines(structured_document):
-  for page in structured_document.get_pages():
-    for line in structured_document.get_lines_of_page(page):
-      yield line
+    for page in structured_document.get_pages():
+        for line in structured_document.get_lines_of_page(page):
+            yield line
+
 
 def extract_from_annotated_tokens(structured_document, tokens, tag_scope=None, level=None):
-  previous_tokens = []
-  previous_tag = None
-  previous_tag_prefix = None
-  for token in tokens:
-    tag_prefix, tag = structured_document.get_tag_prefix_and_value(
-      token, scope=tag_scope, level=level
-    )
-    if not previous_tokens:
-      previous_tokens = [token]
-      previous_tag = tag
-      previous_tag_prefix = tag_prefix
-    elif tag == previous_tag and tag_prefix != B_TAG_PREFIX:
-      previous_tokens.append(token)
-    else:
-      yield ExtractedItem(
-        previous_tag,
-        ' '.join(structured_document.get_text(t) for t in previous_tokens),
-        tokens=previous_tokens,
-        tag_prefix=previous_tag_prefix
-      )
-      previous_tokens = [token]
-      previous_tag = tag
-      previous_tag_prefix = tag_prefix
-  if previous_tokens:
-    yield ExtractedItem(
-      previous_tag,
-      ' '.join(structured_document.get_text(t) for t in previous_tokens),
-      tokens=previous_tokens,
-      tag_prefix=previous_tag_prefix
-    )
+    previous_tokens = []
+    previous_tag = None
+    previous_tag_prefix = None
+    for token in tokens:
+        tag_prefix, tag = structured_document.get_tag_prefix_and_value(
+            token, scope=tag_scope, level=level
+        )
+        if not previous_tokens:
+            previous_tokens = [token]
+            previous_tag = tag
+            previous_tag_prefix = tag_prefix
+        elif tag == previous_tag and tag_prefix != B_TAG_PREFIX:
+            previous_tokens.append(token)
+        else:
+            yield ExtractedItem(
+                previous_tag,
+                ' '.join(structured_document.get_text(t) for t in previous_tokens),
+                tokens=previous_tokens,
+                tag_prefix=previous_tag_prefix
+            )
+            previous_tokens = [token]
+            previous_tag = tag
+            previous_tag_prefix = tag_prefix
+    if previous_tokens:
+        yield ExtractedItem(
+            previous_tag,
+            ' '.join(structured_document.get_text(t) for t in previous_tokens),
+            tokens=previous_tokens,
+            tag_prefix=previous_tag_prefix
+        )
+
 
 def with_sub_items(structured_document, extracted_item, tag_scope=None):
-  return ExtractedItem(
-    extracted_item.tag,
-    extracted_item.text,
-    tokens=extracted_item.tokens,
-    tag_prefix=extracted_item.tag_prefix,
-    sub_items=list(extract_from_annotated_tokens(
-      structured_document, extracted_item.tokens,
-      tag_scope=tag_scope, level=2
-    ))
-  )
+    return ExtractedItem(
+        extracted_item.tag,
+        extracted_item.text,
+        tokens=extracted_item.tokens,
+        tag_prefix=extracted_item.tag_prefix,
+        sub_items=list(extract_from_annotated_tokens(
+            structured_document, extracted_item.tokens,
+            tag_scope=tag_scope, level=2
+        ))
+    )
+
 
 def extract_from_annotated_lines(structured_document, lines, tag_scope=None):
-  previous_item = None
-  for line in lines:
-    tokens = structured_document.get_tokens_of_line(line)
-    for item in extract_from_annotated_tokens(structured_document, tokens, tag_scope=tag_scope):
-      if previous_item is not None:
-        if previous_item.tag == item.tag and item.tag_prefix != B_TAG_PREFIX:
-          previous_item = previous_item.extend(item)
-        else:
-          yield with_sub_items(structured_document, previous_item, tag_scope=tag_scope)
-          previous_item = item
-      else:
-        previous_item = item
-  if previous_item is not None:
-    yield with_sub_items(structured_document, previous_item, tag_scope=tag_scope)
+    previous_item = None
+    for line in lines:
+        tokens = structured_document.get_tokens_of_line(line)
+        for item in extract_from_annotated_tokens(structured_document, tokens, tag_scope=tag_scope):
+            if previous_item is not None:
+                if previous_item.tag == item.tag and item.tag_prefix != B_TAG_PREFIX:
+                    previous_item = previous_item.extend(item)
+                else:
+                    yield with_sub_items(structured_document, previous_item, tag_scope=tag_scope)
+                    previous_item = item
+            else:
+                previous_item = item
+    if previous_item is not None:
+        yield with_sub_items(structured_document, previous_item, tag_scope=tag_scope)
+
 
 def extract_from_annotated_document(structured_document, tag_scope=None):
-  return extract_from_annotated_lines(
-    structured_document, get_lines(structured_document), tag_scope=tag_scope
-  )
+    return extract_from_annotated_lines(
+        structured_document, get_lines(structured_document), tag_scope=tag_scope
+    )
diff --git a/sciencebeam_gym/inference_model/extract_from_annotated_document_test.py b/sciencebeam_gym/inference_model/extract_from_annotated_document_test.py
index 847f0e8..9c467c2 100644
--- a/sciencebeam_gym/inference_model/extract_from_annotated_document_test.py
+++ b/sciencebeam_gym/inference_model/extract_from_annotated_document_test.py
@@ -1,15 +1,15 @@
 import logging
 
 from sciencebeam_gym.structured_document import (
-  SimpleToken,
-  SimpleLine,
-  SimpleStructuredDocument,
-  B_TAG_PREFIX,
-  I_TAG_PREFIX
+    SimpleToken,
+    SimpleLine,
+    SimpleStructuredDocument,
+    B_TAG_PREFIX,
+    I_TAG_PREFIX
 )
 
 from sciencebeam_gym.inference_model.extract_from_annotated_document import (
-  extract_from_annotated_document
+    extract_from_annotated_document
 )
 
 VALUE_1 = 'value1'
@@ -24,233 +24,242 @@ TAG_3 = 'tag3'
 
 TAG_SCOPE_1 = 'tag_scope1'
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def with_tag(x, tag):
-  if isinstance(x, SimpleToken):
-    x.set_tag(tag)
-  elif isinstance(x, list):
-    return [with_tag(y, tag) for y in x]
-  elif isinstance(x, SimpleLine):
-    return SimpleLine(with_tag(x.tokens, tag))
-  return x
+    if isinstance(x, SimpleToken):
+        x.set_tag(tag)
+    elif isinstance(x, list):
+        return [with_tag(y, tag) for y in x]
+    elif isinstance(x, SimpleLine):
+        return SimpleLine(with_tag(x.tokens, tag))
+    return x
+
 
 def to_token(token):
-  return SimpleToken(token) if isinstance(token, str) else token
+    return SimpleToken(token) if isinstance(token, str) else token
+
 
 def to_tokens(tokens):
-  if isinstance(tokens, str):
-    tokens = tokens.split(' ')
-  return [to_token(t) for t in tokens]
+    if isinstance(tokens, str):
+        tokens = tokens.split(' ')
+    return [to_token(t) for t in tokens]
+
 
 def to_line(tokens):
-  return SimpleLine(to_tokens(tokens))
+    return SimpleLine(to_tokens(tokens))
+
 
 def annotated_tokens(tokens, tag):
-  return with_tag(to_tokens(tokens), tag)
+    return with_tag(to_tokens(tokens), tag)
+
 
 def annotated_line(tokens, tag):
-  return with_tag(to_line(tokens), tag)
+    return with_tag(to_line(tokens), tag)
+
 
 def _token_with_sub_tag(text, tag=None, tag_prefix=None, sub_tag=None, sub_tag_prefix=None):
-  token = SimpleToken(text, tag=tag, tag_prefix=tag_prefix)
-  if sub_tag:
-    token.set_tag(sub_tag, prefix=sub_tag_prefix, level=2)
-  return token
+    token = SimpleToken(text, tag=tag, tag_prefix=tag_prefix)
+    if sub_tag:
+        token.set_tag(sub_tag, prefix=sub_tag_prefix, level=2)
+    return token
+
 
 class TestExtractFromAnnotatedDocument(object):
-  def test_should_not_fail_on_empty_document(self):
-    structured_document = SimpleStructuredDocument()
-    extract_from_annotated_document(structured_document)
+    def test_should_not_fail_on_empty_document(self):
+        structured_document = SimpleStructuredDocument()
+        extract_from_annotated_document(structured_document)
 
-  def test_should_extract_single_annotated_line(self):
-    lines = [annotated_line(TEXT_1, TAG_1)]
-    structured_document = SimpleStructuredDocument(lines=lines)
-    result = [
-      (x.tag, x.text)
-      for x in
-      extract_from_annotated_document(structured_document)
-    ]
-    assert result == [(TAG_1, TEXT_1)]
+    def test_should_extract_single_annotated_line(self):
+        lines = [annotated_line(TEXT_1, TAG_1)]
+        structured_document = SimpleStructuredDocument(lines=lines)
+        result = [
+            (x.tag, x.text)
+            for x in
+            extract_from_annotated_document(structured_document)
+        ]
+        assert result == [(TAG_1, TEXT_1)]
 
-  def test_should_extract_from_different_tag_scope(self):
-    lines = [SimpleLine([SimpleToken(TEXT_1, tag=TAG_1, tag_scope=TAG_SCOPE_1)])]
-    structured_document = SimpleStructuredDocument(lines=lines)
-    result = [
-      (x.tag, x.text)
-      for x in
-      extract_from_annotated_document(structured_document, tag_scope=TAG_SCOPE_1)
-    ]
-    assert result == [(TAG_1, TEXT_1)]
+    def test_should_extract_from_different_tag_scope(self):
+        lines = [SimpleLine([SimpleToken(TEXT_1, tag=TAG_1, tag_scope=TAG_SCOPE_1)])]
+        structured_document = SimpleStructuredDocument(lines=lines)
+        result = [
+            (x.tag, x.text)
+            for x in
+            extract_from_annotated_document(structured_document, tag_scope=TAG_SCOPE_1)
+        ]
+        assert result == [(TAG_1, TEXT_1)]
 
-  def test_should_extract_multiple_annotations_on_single_line(self):
-    lines = [to_line(
-      annotated_tokens(TEXT_1, TAG_1) +
-      to_tokens(TEXT_2) +
-      annotated_tokens(TEXT_3, TAG_3)
-    )]
-    structured_document = SimpleStructuredDocument(lines=lines)
-    result = [
-      (x.tag, x.text)
-      for x in
-      extract_from_annotated_document(structured_document)
-    ]
-    assert result == [
-      (TAG_1, TEXT_1),
-      (None, TEXT_2),
-      (TAG_3, TEXT_3)
-    ]
+    def test_should_extract_multiple_annotations_on_single_line(self):
+        lines = [to_line(
+            annotated_tokens(TEXT_1, TAG_1) +
+            to_tokens(TEXT_2) +
+            annotated_tokens(TEXT_3, TAG_3)
+        )]
+        structured_document = SimpleStructuredDocument(lines=lines)
+        result = [
+            (x.tag, x.text)
+            for x in
+            extract_from_annotated_document(structured_document)
+        ]
+        assert result == [
+            (TAG_1, TEXT_1),
+            (None, TEXT_2),
+            (TAG_3, TEXT_3)
+        ]
 
-  def test_should_combine_multiple_lines(self):
-    lines = [
-      annotated_line(TEXT_1, TAG_1),
-      annotated_line(TEXT_2, TAG_1)
-    ]
-    structured_document = SimpleStructuredDocument(lines=lines)
-    result = [
-      (x.tag, x.text)
-      for x in
-      extract_from_annotated_document(structured_document)
-    ]
-    get_logger().debug('result: %s', result)
-    assert result == [(TAG_1, '\n'.join([TEXT_1, TEXT_2]))]
+    def test_should_combine_multiple_lines(self):
+        lines = [
+            annotated_line(TEXT_1, TAG_1),
+            annotated_line(TEXT_2, TAG_1)
+        ]
+        structured_document = SimpleStructuredDocument(lines=lines)
+        result = [
+            (x.tag, x.text)
+            for x in
+            extract_from_annotated_document(structured_document)
+        ]
+        get_logger().debug('result: %s', result)
+        assert result == [(TAG_1, '\n'.join([TEXT_1, TEXT_2]))]
 
-  def test_should_combine_multiple_lines_separated_by_other_tag(self):
-    lines = [
-      annotated_line(TEXT_1, TAG_1),
-      annotated_line(TEXT_2, TAG_2),
-      annotated_line(TEXT_3, TAG_2),
-      annotated_line(TEXT_1, TAG_1),
-      annotated_line(TEXT_2, TAG_2),
-      annotated_line(TEXT_3, TAG_2)
-    ]
-    structured_document = SimpleStructuredDocument(lines=lines)
-    result = [
-      (x.tag, x.text)
-      for x in
-      extract_from_annotated_document(structured_document)
-    ]
-    get_logger().debug('result: %s', result)
-    assert result == [
-      (TAG_1, TEXT_1),
-      (TAG_2, '\n'.join([TEXT_2, TEXT_3])),
-      (TAG_1, TEXT_1),
-      (TAG_2, '\n'.join([TEXT_2, TEXT_3]))
-    ]
+    def test_should_combine_multiple_lines_separated_by_other_tag(self):
+        lines = [
+            annotated_line(TEXT_1, TAG_1),
+            annotated_line(TEXT_2, TAG_2),
+            annotated_line(TEXT_3, TAG_2),
+            annotated_line(TEXT_1, TAG_1),
+            annotated_line(TEXT_2, TAG_2),
+            annotated_line(TEXT_3, TAG_2)
+        ]
+        structured_document = SimpleStructuredDocument(lines=lines)
+        result = [
+            (x.tag, x.text)
+            for x in
+            extract_from_annotated_document(structured_document)
+        ]
+        get_logger().debug('result: %s', result)
+        assert result == [
+            (TAG_1, TEXT_1),
+            (TAG_2, '\n'.join([TEXT_2, TEXT_3])),
+            (TAG_1, TEXT_1),
+            (TAG_2, '\n'.join([TEXT_2, TEXT_3]))
+        ]
 
-  def test_should_separate_items_based_on_tag_prefix(self):
-    tokens = [
-      SimpleToken(VALUE_1, tag=TAG_1, tag_prefix=B_TAG_PREFIX),
-      SimpleToken(VALUE_2, tag=TAG_1, tag_prefix=I_TAG_PREFIX),
-      SimpleToken(VALUE_3, tag=TAG_1, tag_prefix=I_TAG_PREFIX),
-      SimpleToken(VALUE_1, tag=TAG_1, tag_prefix=B_TAG_PREFIX),
-      SimpleToken(VALUE_2, tag=TAG_1, tag_prefix=I_TAG_PREFIX),
-      SimpleToken(VALUE_3, tag=TAG_1, tag_prefix=I_TAG_PREFIX)
-    ]
-    structured_document = SimpleStructuredDocument(lines=[SimpleLine(tokens)])
-    result = [
-      (x.tag, x.text)
-      for x in
-      extract_from_annotated_document(structured_document)
-    ]
-    get_logger().debug('result: %s', result)
-    assert result == [
-      (TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3])),
-      (TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]))
-    ]
+    def test_should_separate_items_based_on_tag_prefix(self):
+        tokens = [
+            SimpleToken(VALUE_1, tag=TAG_1, tag_prefix=B_TAG_PREFIX),
+            SimpleToken(VALUE_2, tag=TAG_1, tag_prefix=I_TAG_PREFIX),
+            SimpleToken(VALUE_3, tag=TAG_1, tag_prefix=I_TAG_PREFIX),
+            SimpleToken(VALUE_1, tag=TAG_1, tag_prefix=B_TAG_PREFIX),
+            SimpleToken(VALUE_2, tag=TAG_1, tag_prefix=I_TAG_PREFIX),
+            SimpleToken(VALUE_3, tag=TAG_1, tag_prefix=I_TAG_PREFIX)
+        ]
+        structured_document = SimpleStructuredDocument(lines=[SimpleLine(tokens)])
+        result = [
+            (x.tag, x.text)
+            for x in
+            extract_from_annotated_document(structured_document)
+        ]
+        get_logger().debug('result: %s', result)
+        assert result == [
+            (TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3])),
+            (TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]))
+        ]
 
-  def test_should_extract_sub_tags_from_single_item(self):
-    tokens = [
-      _token_with_sub_tag(
-        VALUE_1,
-        tag=TAG_1, tag_prefix=B_TAG_PREFIX,
-        sub_tag=TAG_2, sub_tag_prefix=B_TAG_PREFIX
-      ),
-      _token_with_sub_tag(
-        VALUE_2,
-        tag=TAG_1, tag_prefix=I_TAG_PREFIX,
-        sub_tag=TAG_2, sub_tag_prefix=I_TAG_PREFIX
-      ),
-      _token_with_sub_tag(
-        VALUE_3,
-        tag=TAG_1, tag_prefix=I_TAG_PREFIX,
-        sub_tag=TAG_3, sub_tag_prefix=B_TAG_PREFIX
-      )
-    ]
-    structured_document = SimpleStructuredDocument(lines=[SimpleLine(tokens)])
-    extracted_items = list(extract_from_annotated_document(structured_document))
-    result = [
-      (
-        x.tag, x.text, [(sub.tag, sub.text) for sub in x.sub_items]
-      ) for x in extracted_items
-    ]
-    get_logger().debug('result: %s', result)
-    assert result == [
-      (
-        TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]),
-        [
-          (TAG_2, ' '.join([VALUE_1, VALUE_2])),
-          (TAG_3, VALUE_3)
+    def test_should_extract_sub_tags_from_single_item(self):
+        tokens = [
+            _token_with_sub_tag(
+                VALUE_1,
+                tag=TAG_1, tag_prefix=B_TAG_PREFIX,
+                sub_tag=TAG_2, sub_tag_prefix=B_TAG_PREFIX
+            ),
+            _token_with_sub_tag(
+                VALUE_2,
+                tag=TAG_1, tag_prefix=I_TAG_PREFIX,
+                sub_tag=TAG_2, sub_tag_prefix=I_TAG_PREFIX
+            ),
+            _token_with_sub_tag(
+                VALUE_3,
+                tag=TAG_1, tag_prefix=I_TAG_PREFIX,
+                sub_tag=TAG_3, sub_tag_prefix=B_TAG_PREFIX
+            )
+        ]
+        structured_document = SimpleStructuredDocument(lines=[SimpleLine(tokens)])
+        extracted_items = list(extract_from_annotated_document(structured_document))
+        result = [
+            (
+                x.tag, x.text, [(sub.tag, sub.text) for sub in x.sub_items]
+            ) for x in extracted_items
+        ]
+        get_logger().debug('result: %s', result)
+        assert result == [
+            (
+                TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]),
+                [
+                    (TAG_2, ' '.join([VALUE_1, VALUE_2])),
+                    (TAG_3, VALUE_3)
+                ]
+            )
         ]
-      )
-    ]
 
-  def test_should_extract_sub_tags_from_multiple_items(self):
-    tokens = [
-      _token_with_sub_tag(
-        VALUE_1,
-        tag=TAG_1, tag_prefix=B_TAG_PREFIX,
-        sub_tag=TAG_2, sub_tag_prefix=B_TAG_PREFIX
-      ),
-      _token_with_sub_tag(
-        VALUE_2,
-        tag=TAG_1, tag_prefix=I_TAG_PREFIX,
-        sub_tag=TAG_2, sub_tag_prefix=I_TAG_PREFIX
-      ),
-      _token_with_sub_tag(
-        VALUE_3,
-        tag=TAG_1, tag_prefix=I_TAG_PREFIX,
-        sub_tag=TAG_3, sub_tag_prefix=B_TAG_PREFIX
-      ),
+    def test_should_extract_sub_tags_from_multiple_items(self):
+        tokens = [
+            _token_with_sub_tag(
+                VALUE_1,
+                tag=TAG_1, tag_prefix=B_TAG_PREFIX,
+                sub_tag=TAG_2, sub_tag_prefix=B_TAG_PREFIX
+            ),
+            _token_with_sub_tag(
+                VALUE_2,
+                tag=TAG_1, tag_prefix=I_TAG_PREFIX,
+                sub_tag=TAG_2, sub_tag_prefix=I_TAG_PREFIX
+            ),
+            _token_with_sub_tag(
+                VALUE_3,
+                tag=TAG_1, tag_prefix=I_TAG_PREFIX,
+                sub_tag=TAG_3, sub_tag_prefix=B_TAG_PREFIX
+            ),
 
-      _token_with_sub_tag(
-        VALUE_1,
-        tag=TAG_1, tag_prefix=B_TAG_PREFIX,
-        sub_tag=TAG_2, sub_tag_prefix=B_TAG_PREFIX
-      ),
-      _token_with_sub_tag(
-        VALUE_2,
-        tag=TAG_1, tag_prefix=I_TAG_PREFIX,
-        sub_tag=TAG_3, sub_tag_prefix=B_TAG_PREFIX
-      ),
-      _token_with_sub_tag(
-        VALUE_3,
-        tag=TAG_1, tag_prefix=I_TAG_PREFIX,
-        sub_tag=TAG_3, sub_tag_prefix=I_TAG_PREFIX
-      )
-    ]
-    structured_document = SimpleStructuredDocument(lines=[SimpleLine(tokens)])
-    extracted_items = list(extract_from_annotated_document(structured_document))
-    result = [
-      (
-        x.tag, x.text, [(sub.tag, sub.text) for sub in x.sub_items]
-      ) for x in extracted_items
-    ]
-    get_logger().debug('result: %s', result)
-    assert result == [
-      (
-        TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]),
-        [
-          (TAG_2, ' '.join([VALUE_1, VALUE_2])),
-          (TAG_3, VALUE_3)
+            _token_with_sub_tag(
+                VALUE_1,
+                tag=TAG_1, tag_prefix=B_TAG_PREFIX,
+                sub_tag=TAG_2, sub_tag_prefix=B_TAG_PREFIX
+            ),
+            _token_with_sub_tag(
+                VALUE_2,
+                tag=TAG_1, tag_prefix=I_TAG_PREFIX,
+                sub_tag=TAG_3, sub_tag_prefix=B_TAG_PREFIX
+            ),
+            _token_with_sub_tag(
+                VALUE_3,
+                tag=TAG_1, tag_prefix=I_TAG_PREFIX,
+                sub_tag=TAG_3, sub_tag_prefix=I_TAG_PREFIX
+            )
+        ]
+        structured_document = SimpleStructuredDocument(lines=[SimpleLine(tokens)])
+        extracted_items = list(extract_from_annotated_document(structured_document))
+        result = [
+            (
+                x.tag, x.text, [(sub.tag, sub.text) for sub in x.sub_items]
+            ) for x in extracted_items
         ]
-      ),
-      (
-        TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]),
-        [
-          (TAG_2, VALUE_1),
-          (TAG_3, ' '.join([VALUE_2, VALUE_3]))
+        get_logger().debug('result: %s', result)
+        assert result == [
+            (
+                TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]),
+                [
+                    (TAG_2, ' '.join([VALUE_1, VALUE_2])),
+                    (TAG_3, VALUE_3)
+                ]
+            ),
+            (
+                TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]),
+                [
+                    (TAG_2, VALUE_1),
+                    (TAG_3, ' '.join([VALUE_2, VALUE_3]))
+                ]
+            )
         ]
-      )
-    ]
diff --git a/sciencebeam_gym/inference_model/extract_to_xml.py b/sciencebeam_gym/inference_model/extract_to_xml.py
index fba57f7..927ee26 100644
--- a/sciencebeam_gym/inference_model/extract_to_xml.py
+++ b/sciencebeam_gym/inference_model/extract_to_xml.py
@@ -5,202 +5,220 @@ from lxml import etree
 from lxml.builder import E
 
 from sciencebeam_utils.beam_utils.io import (
-  save_file_content
+    save_file_content
 )
 
 from sciencebeam_gym.structured_document.structured_document_loader import (
-  load_structured_document
+    load_structured_document
 )
 
 from sciencebeam_gym.inference_model.extract_from_annotated_document import (
-  extract_from_annotated_document
+    extract_from_annotated_document
 )
 
+
class Tags(object):
    """Annotation tag names produced by the extraction model."""
    TITLE = 'manuscript_title'
    ABSTRACT = 'abstract'
    AUTHOR = 'author'
    AUTHOR_AFF = 'author_aff'
+
 
class XmlPaths(object):
    """Target JATS-style XML paths for each top-level tag."""
    TITLE = 'front/article-meta/title-group/article-title'
    ABSTRACT = 'front/article-meta/abstract'
    AUTHOR = 'front/article-meta/contrib-group/contrib'
    AUTHOR_AFF = 'front/article-meta/contrib-group/aff'
+
 
class SubTags(object):
    """Sub-tag names used within an author annotation."""
    AUTHOR_SURNAME = 'surname'
    AUTHOR_GIVEN_NAMES = 'givennames'
+
 
class SubXmlPaths(object):
    """XML paths, relative to the author node, for the author sub tags."""
    AUTHOR_SURNAME = 'name/surname'
    AUTHOR_GIVEN_NAMES = 'name/given-names'
+
 
def get_logger():
    """Return the logger for this module."""
    return logging.getLogger(__name__)
+
 
 def rsplit_xml_path(path):
-  i = path.rfind('/')
-  if i >= 0:
-    return path[0:i], path[i + 1:]
-  else:
-    return None, path
+    i = path.rfind('/')
+    if i >= 0:
+        return path[0:i], path[i + 1:]
+    else:
+        return None, path
+
 
 def create_node_recursive(xml_root, path, exists_ok=False):
-  node = xml_root.find(path)
-  if node is not None:
-    if not exists_ok:
-      raise RuntimeError('xml node already exists: %s' % path)
+    node = xml_root.find(path)
+    if node is not None:
+        if not exists_ok:
+            raise RuntimeError('xml node already exists: %s' % path)
+        return node
+    parent, base = rsplit_xml_path(path)
+    if parent:
+        parent_node = create_node_recursive(xml_root, parent, exists_ok=True)
+    else:
+        parent_node = xml_root
+    node = etree.Element(base)
+    parent_node.append(node)
     return node
-  parent, base = rsplit_xml_path(path)
-  if parent:
-    parent_node = create_node_recursive(xml_root, parent, exists_ok=True)
-  else:
-    parent_node = xml_root
-  node = etree.Element(base)
-  parent_node.append(node)
-  return node
+
 
 def create_and_append_xml_node(xml_root, path):
-  parent, base = rsplit_xml_path(path)
-  parent_node = (
-    create_node_recursive(xml_root, parent, exists_ok=True)
-    if parent
-    else xml_root
-  )
-  node = etree.Element(base)
-  parent_node.append(node)
-  return node
+    parent, base = rsplit_xml_path(path)
+    parent_node = (
+        create_node_recursive(xml_root, parent, exists_ok=True)
+        if parent
+        else xml_root
+    )
+    node = etree.Element(base)
+    parent_node.append(node)
+    return node
+
 
 def create_xml_text(xml_root, path, text):
-  node = create_and_append_xml_node(xml_root, path)
-  node.text = text
-  return node
+    node = create_and_append_xml_node(xml_root, path)
+    node.text = text
+    return node
+
 
 AUTHOR_JUNK_CHARS = ',+*0123456789'
 
+
 def _clean_author_name(s):
-  i = len(s)
-  while (
-    i > 0 and
-    (
-      s[i - 1] in AUTHOR_JUNK_CHARS or
-      # only remove dot after special characters
-      (s[i - 1] == '.' and i >= 2 and s[i - 2] in AUTHOR_JUNK_CHARS)
-    )
-  ):
-    i -= 1
-  return s[:i]
+    i = len(s)
+    while (
+        i > 0 and
+        (
+            s[i - 1] in AUTHOR_JUNK_CHARS or
+            # only remove dot after special characters
+            (s[i - 1] == '.' and i >= 2 and s[i - 2] in AUTHOR_JUNK_CHARS)
+        )
+    ):
+        i -= 1
+    return s[:i]
+
 
class XmlMapping(object):
    """Describes how one extracted tag is written into the output XML.

    Args:
      xml_path: target XML path for the tag's content
      single_node: if True, consecutive items with the same tag share one node
      sub_mapping: optional dict of sub tag name to XmlMapping for sub items
      attrib: optional attributes to set on the created node
      clean_fn: optional function applied to the extracted text before writing
    """

    def __init__(
            self, xml_path, single_node=False, sub_mapping=None, attrib=None,
            clean_fn=None):

        self.xml_path = xml_path
        self.single_node = single_node
        self.sub_mapping = sub_mapping
        self.attrib = attrib
        self.clean_fn = clean_fn
 
def _extract_items(parent_node, extracted_items, xml_mapping):
    """Write *extracted_items* into the XML tree below *parent_node*.

    xml_mapping maps a tag name to an XmlMapping describing where and how the
    item's text should appear. Items whose tag is not in the mapping are
    skipped with a warning; untagged items are ignored (but note they do not
    reset previous_tag, so same-tag items around them still merge).
    """
    previous_tag = None
    for extracted_item in extracted_items:
        tag = extracted_item.tag
        if tag:
            mapping_entry = xml_mapping.get(tag)
            if not mapping_entry:
                get_logger().warning('tag not configured: %s', tag)
                continue
            extracted_text = extracted_item.text
            if extracted_text and mapping_entry.clean_fn:
                extracted_text = mapping_entry.clean_fn(extracted_text)
            path = mapping_entry.xml_path
            if mapping_entry.single_node:
                # single-node tags share one element; consecutive occurrences
                # of the same tag are joined with newlines
                node = create_node_recursive(parent_node, path, exists_ok=True)
                if node.text is None:
                    node.text = extracted_text
                elif previous_tag == tag:
                    node.text += '\n' + extracted_text
                else:
                    # node already populated by a non-consecutive occurrence
                    get_logger().debug('ignoring tag %s, after tag %s', tag, previous_tag)
            else:
                # multi-node tags get a fresh element for every occurrence
                node = create_and_append_xml_node(parent_node, path)
                if mapping_entry.attrib:
                    for k, v in mapping_entry.attrib.items():
                        node.attrib[k] = v
                if extracted_item.sub_items and mapping_entry.sub_mapping:
                    # recurse for sub items (e.g. author name parts)
                    _extract_items(node, extracted_item.sub_items, mapping_entry.sub_mapping)
                else:
                    node.text = extracted_text
            previous_tag = tag
+
 
 def extracted_items_to_xml(extracted_items):
-  xml_mapping = {
-    Tags.TITLE: XmlMapping(XmlPaths.TITLE, single_node=True),
-    Tags.ABSTRACT: XmlMapping(XmlPaths.ABSTRACT, single_node=True),
-    Tags.AUTHOR: XmlMapping(XmlPaths.AUTHOR, sub_mapping={
-      SubTags.AUTHOR_GIVEN_NAMES: XmlMapping(
-        SubXmlPaths.AUTHOR_GIVEN_NAMES, clean_fn=_clean_author_name
-      ),
-      SubTags.AUTHOR_SURNAME: XmlMapping(
-        SubXmlPaths.AUTHOR_SURNAME, clean_fn=_clean_author_name
-      )
-    }, attrib={
-      'contrib-type': 'author'
-    }, clean_fn=_clean_author_name),
-    Tags.AUTHOR_AFF: XmlMapping(XmlPaths.AUTHOR_AFF)
-  }
-  xml_root = E.article()
-  _extract_items(xml_root, extracted_items, xml_mapping)
-  return xml_root
+    xml_mapping = {
+        Tags.TITLE: XmlMapping(XmlPaths.TITLE, single_node=True),
+        Tags.ABSTRACT: XmlMapping(XmlPaths.ABSTRACT, single_node=True),
+        Tags.AUTHOR: XmlMapping(XmlPaths.AUTHOR, sub_mapping={
+            SubTags.AUTHOR_GIVEN_NAMES: XmlMapping(
+                SubXmlPaths.AUTHOR_GIVEN_NAMES, clean_fn=_clean_author_name
+            ),
+            SubTags.AUTHOR_SURNAME: XmlMapping(
+                SubXmlPaths.AUTHOR_SURNAME, clean_fn=_clean_author_name
+            )
+        }, attrib={
+            'contrib-type': 'author'
+        }, clean_fn=_clean_author_name),
+        Tags.AUTHOR_AFF: XmlMapping(XmlPaths.AUTHOR_AFF)
+    }
+    xml_root = E.article()
+    _extract_items(xml_root, extracted_items, xml_mapping)
+    return xml_root
+
 
 def extract_structured_document_to_xml(structured_document, tag_scope=None):
-  return extracted_items_to_xml(
-    extract_from_annotated_document(structured_document, tag_scope=tag_scope)
-  )
+    return extracted_items_to_xml(
+        extract_from_annotated_document(structured_document, tag_scope=tag_scope)
+    )
+
 
def parse_args(argv=None):
    """Parse command line arguments (defaults to sys.argv when argv is None)."""
    parser = argparse.ArgumentParser('Extract JATSy XML from annotated LXML')
    parser.add_argument(
        '--lxml-path', type=str, required=True,
        help='path to lxml or svg pages document'
    )

    parser.add_argument(
        '--tag-scope', type=str, required=False,
        help='tag scope to extract based on'
    )

    parser.add_argument(
        '--output-path', type=str, required=True,
        help='output path to annotated document'
    )

    parser.add_argument(
        '--debug', action='store_true', default=False,
        help='enable debug output'
    )

    return parser.parse_args(argv)
 
def main(argv=None):
    """Command line entry point: load document, extract XML, save result."""
    args = parse_args(argv)

    if args.debug:
        logging.getLogger().setLevel('DEBUG')

    structured_document = load_structured_document(args.lxml_path)

    xml_root = extract_structured_document_to_xml(
        structured_document,
        tag_scope=args.tag_scope
    )

    get_logger().info('writing result to: %s', args.output_path)
    save_file_content(args.output_path, etree.tostring(xml_root, pretty_print=True))
 
if __name__ == '__main__':
    # default to INFO level logging when run as a script
    logging.basicConfig(level='INFO')

    main()
diff --git a/sciencebeam_gym/inference_model/extract_to_xml_test.py b/sciencebeam_gym/inference_model/extract_to_xml_test.py
index dad436a..abf1329 100644
--- a/sciencebeam_gym/inference_model/extract_to_xml_test.py
+++ b/sciencebeam_gym/inference_model/extract_to_xml_test.py
@@ -6,166 +6,169 @@ from lxml import etree
 from lxml.builder import E
 
 from sciencebeam_utils.utils.xml import (
-  get_text_content,
-  get_text_content_list
+    get_text_content,
+    get_text_content_list
 )
 
 from sciencebeam_gym.inference_model.extract_from_annotated_document import (
-  ExtractedItem
+    ExtractedItem
 )
 
 from sciencebeam_gym.inference_model.extract_to_xml import (
-  extracted_items_to_xml,
-  Tags,
-  XmlPaths,
-  SubTags,
-  SubXmlPaths,
-  main
+    extracted_items_to_xml,
+    Tags,
+    XmlPaths,
+    SubTags,
+    SubXmlPaths,
+    main
 )
 
 TEXT_1 = 'some text here'
 TEXT_2 = 'more text to come'
 TEXT_3 = 'does not stop here'
 
+
 def _create_author_extracted_items(given_names, surname):
-  return [
-    ExtractedItem(Tags.AUTHOR, ' '.join([given_names, surname]), sub_items=[
-      ExtractedItem(SubTags.AUTHOR_GIVEN_NAMES, given_names),
-      ExtractedItem(SubTags.AUTHOR_SURNAME, surname)
-    ])
-  ]
+    return [
+        ExtractedItem(Tags.AUTHOR, ' '.join([given_names, surname]), sub_items=[
+            ExtractedItem(SubTags.AUTHOR_GIVEN_NAMES, given_names),
+            ExtractedItem(SubTags.AUTHOR_SURNAME, surname)
+        ])
+    ]
+
 
class TestExtractedItemsToXml(object):
    """Tests for extracted_items_to_xml (node placement, merging, cleaning)."""

    def test_should_return_empty_xml_for_no_empty_list_of_extracted_items(self):
        xml_root = extracted_items_to_xml([])
        assert xml_root is not None

    def test_should_populate_title(self):
        xml_root = extracted_items_to_xml([
            ExtractedItem(Tags.TITLE, TEXT_1)
        ])
        assert xml_root is not None
        assert get_text_content(xml_root.find(XmlPaths.TITLE)) == TEXT_1

    def test_should_append_to_abstract(self):
        xml_root = extracted_items_to_xml([
            ExtractedItem(Tags.ABSTRACT, TEXT_1),
            ExtractedItem(Tags.ABSTRACT, TEXT_2)
        ])
        assert xml_root is not None
        assert get_text_content(xml_root.find(XmlPaths.ABSTRACT)) == '\n'.join([TEXT_1, TEXT_2])

    def test_should_not_append_to_abstract_after_untagged_content(self):
        xml_root = extracted_items_to_xml([
            ExtractedItem(Tags.ABSTRACT, TEXT_1),
            ExtractedItem(None, TEXT_2),
            ExtractedItem(Tags.ABSTRACT, TEXT_3)
        ])
        assert xml_root is not None
        assert get_text_content(xml_root.find(XmlPaths.ABSTRACT)) == '\n'.join([TEXT_1, TEXT_3])

    def test_should_not_append_to_abstract_after_another_tag_occured(self):
        xml_root = extracted_items_to_xml([
            ExtractedItem(Tags.ABSTRACT, TEXT_1),
            ExtractedItem(Tags.AUTHOR, TEXT_2),
            ExtractedItem(Tags.ABSTRACT, TEXT_3)
        ])
        assert xml_root is not None
        assert get_text_content(xml_root.find(XmlPaths.ABSTRACT)) == '\n'.join([TEXT_1])

    def test_should_create_separate_author_node(self):
        xml_root = extracted_items_to_xml([
            ExtractedItem(Tags.AUTHOR, TEXT_1),
            ExtractedItem(Tags.AUTHOR, TEXT_2)
        ])
        assert xml_root is not None
        assert get_text_content_list(xml_root.findall(XmlPaths.AUTHOR)) == [TEXT_1, TEXT_2]

    def test_should_extract_author_surname_and_given_names_from_single_author(self):
        xml_root = extracted_items_to_xml([
            ExtractedItem(Tags.AUTHOR, ' '.join([TEXT_1, TEXT_2]), sub_items=[
                ExtractedItem(SubTags.AUTHOR_GIVEN_NAMES, TEXT_1),
                ExtractedItem(SubTags.AUTHOR_SURNAME, TEXT_2)
            ])
        ])
        assert xml_root is not None
        author = xml_root.find(XmlPaths.AUTHOR)
        assert author is not None
        assert get_text_content(author.find(SubXmlPaths.AUTHOR_GIVEN_NAMES)) == TEXT_1
        assert get_text_content(author.find(SubXmlPaths.AUTHOR_SURNAME)) == TEXT_2

    def test_should_remove_special_characters_and_numbers_from_author(self):
        special_num_chars = ',+*0123456789'
        xml_root = extracted_items_to_xml(_create_author_extracted_items(
            TEXT_1 + special_num_chars, TEXT_2 + special_num_chars
        ))
        assert xml_root is not None
        author = xml_root.find(XmlPaths.AUTHOR)
        assert author is not None
        assert get_text_content(author.find(SubXmlPaths.AUTHOR_GIVEN_NAMES)) == TEXT_1
        assert get_text_content(author.find(SubXmlPaths.AUTHOR_SURNAME)) == TEXT_2

    def test_should_not_remove_dot_after_initials_from_author(self):
        xml_root = extracted_items_to_xml(_create_author_extracted_items(
            'Mr T.', 'E.'
        ))
        assert xml_root is not None
        author = xml_root.find(XmlPaths.AUTHOR)
        assert author is not None
        assert get_text_content(author.find(SubXmlPaths.AUTHOR_GIVEN_NAMES)) == 'Mr T.'
        assert get_text_content(author.find(SubXmlPaths.AUTHOR_SURNAME)) == 'E.'

    def test_should_not_remove_dot_after_suffix_from_author(self):
        xml_root = extracted_items_to_xml(_create_author_extracted_items(
            'Mr T.', 'Jr.'
        ))
        assert xml_root is not None
        author = xml_root.find(XmlPaths.AUTHOR)
        assert author is not None
        assert get_text_content(author.find(SubXmlPaths.AUTHOR_GIVEN_NAMES)) == 'Mr T.'
        assert get_text_content(author.find(SubXmlPaths.AUTHOR_SURNAME)) == 'Jr.'

    def test_should_remove_dot_after_other_special_characters(self):
        xml_root = extracted_items_to_xml(_create_author_extracted_items(
            'Mr T*.', 'E*.'
        ))
        assert xml_root is not None
        author = xml_root.find(XmlPaths.AUTHOR)
        assert author is not None
        assert get_text_content(author.find(SubXmlPaths.AUTHOR_GIVEN_NAMES)) == 'Mr T'
        assert get_text_content(author.find(SubXmlPaths.AUTHOR_SURNAME)) == 'E'

    def test_should_add_contrib_type_author_attribute(self):
        xml_root = extracted_items_to_xml(_create_author_extracted_items(TEXT_1, TEXT_2))
        assert xml_root is not None
        author = xml_root.find(XmlPaths.AUTHOR)
        assert author is not None
        assert author.tag == 'contrib'
        assert author.attrib.get('contrib-type') == 'author'
+
 
 class TestMain(object):
-  def test_should_extract_from_simple_annotated_document(self):
-    with TemporaryDirectory() as path:
-      lxml_root = E.DOCUMENT(
-        E.PAGE(
-          E.TEXT(
-            E.TOKEN(
-              TEXT_1,
-              {
-                'tag': Tags.TITLE
-              }
+    def test_should_extract_from_simple_annotated_document(self):
+        with TemporaryDirectory() as path:
+            lxml_root = E.DOCUMENT(
+                E.PAGE(
+                    E.TEXT(
+                        E.TOKEN(
+                            TEXT_1,
+                            {
+                                'tag': Tags.TITLE
+                            }
+                        )
+                    )
+                )
             )
-          )
-        )
-      )
 
-      lxml_path = os.path.join(path, 'test.lxml')
-      with open(lxml_path, 'w') as f:
-        f.write(etree.tostring(lxml_root))
+            lxml_path = os.path.join(path, 'test.lxml')
+            with open(lxml_path, 'w') as f:
+                f.write(etree.tostring(lxml_root))
 
-      output_path = os.path.join(path, 'test.xml')
+            output_path = os.path.join(path, 'test.xml')
 
-      main(['--lxml-path=%s' % lxml_path, '--output-path=%s' % output_path])
+            main(['--lxml-path=%s' % lxml_path, '--output-path=%s' % output_path])
 
-      xml_root = etree.parse(output_path)
-      assert get_text_content(xml_root.find(XmlPaths.TITLE)) == TEXT_1
+            xml_root = etree.parse(output_path)
+            assert get_text_content(xml_root.find(XmlPaths.TITLE)) == TEXT_1
diff --git a/sciencebeam_gym/model_utils/channels.py b/sciencebeam_gym/model_utils/channels.py
index b342dce..4328e3e 100644
--- a/sciencebeam_gym/model_utils/channels.py
+++ b/sciencebeam_gym/model_utils/channels.py
@@ -1,34 +1,41 @@
 import tensorflow as tf
 
 from sciencebeam_gym.utils.tf import (
-  variable_scoped
+    variable_scoped
 )
 
+
 def color_equals_mask(image, color):
-  return tf.reduce_all(
-    tf.equal(image, color),
-    axis=-1,
-    name='is_color'
-  )
+    return tf.reduce_all(
+        tf.equal(image, color),
+        axis=-1,
+        name='is_color'
+    )
+
 
def color_equals_mask_as_float(image, color):
    """Like color_equals_mask but returns the mask cast to float32 (1.0 / 0.0)."""
    return tf.cast(color_equals_mask(image, color), tf.float32)
+
 
def calculate_color_masks(image, colors, use_unknown_class=False):
    """Return one float mask tensor per color, each in its own variable scope.

    Args:
      image: image tensor with channels as the last dimension
      colors: list of colors to build masks for
      use_unknown_class: if True, append a catch-all mask that is 1.0
        wherever none of the color masks matched

    Returns:
      list of float32 mask tensors (len(colors), plus one if
      use_unknown_class is True)
    """
    color_masks = [
        variable_scoped(
            'channel_%d' % i,
            # the color is passed as an explicit argument instead of closing
            # over the loop variable (avoids the late-binding closure pitfall;
            # assumes variable_scoped forwards keyword args to the fn — confirm)
            lambda color_param: color_equals_mask_as_float(image, color_param),
            color_param=color
        )
        for i, color in enumerate(colors)
    ]
    if use_unknown_class:
        with tf.variable_scope("unknown_class"):
            shape = tf.shape(color_masks[0])
            ones = tf.fill(shape, 1.0, name='ones')
            zeros = tf.fill(shape, 0.0, name='zeros')
            # 1.0 wherever the sum of all color masks is (close to) zero
            color_masks.append(
                tf.where(
                    tf.add_n(color_masks) < 0.5,
                    ones,
                    zeros
                )
            )
    return color_masks
diff --git a/sciencebeam_gym/models/text/crf/annotate_using_predictions.py b/sciencebeam_gym/models/text/crf/annotate_using_predictions.py
index cd76934..de47dad 100644
--- a/sciencebeam_gym/models/text/crf/annotate_using_predictions.py
+++ b/sciencebeam_gym/models/text/crf/annotate_using_predictions.py
@@ -4,138 +4,146 @@ import pickle
 from itertools import repeat
 
 from sciencebeam_gym.utils.tf import (
-  FileIO
+    FileIO
 )
 
 from sciencebeam_gym.models.text.feature_extractor import (
-  structured_document_to_token_props,
-  token_props_list_to_features,
-  merge_with_cv_structured_document,
-  NONE_TAG
+    structured_document_to_token_props,
+    token_props_list_to_features,
+    merge_with_cv_structured_document,
+    NONE_TAG
 )
 
 from sciencebeam_gym.structured_document.structured_document_loader import (
-  load_lxml_structured_document
+    load_lxml_structured_document
 )
 
 from sciencebeam_gym.structured_document.structured_document_saver import (
-  save_structured_document
+    save_structured_document
 )
 
 CRF_TAG_SCOPE = 'crf'
 
+
def get_logger():
    """Return the logger for this module."""
    return logging.getLogger(__name__)
+
 
 def _iter_tokens(structured_document):
-  for page in structured_document.get_pages():
-    for line in structured_document.get_lines_of_page(page):
-      for token in structured_document.get_tokens_of_line(line):
-        yield token
+    for page in structured_document.get_pages():
+        for line in structured_document.get_lines_of_page(page):
+            for token in structured_document.get_tokens_of_line(line):
+                yield token
+
 
def annotate_structured_document_using_predictions(
        structured_document, predictions, token_props_list=None,
        tag_scope=CRF_TAG_SCOPE):
    """
    Annotates the structured document using the predicted tags.

    Tokens, predictions and (optional) token props are zipped together in
    order; iteration stops at the shortest of the three.

    Args:
      structured_document: the document that will be tagged
      predictions: list of predicted tags
      token_props_list: optional, used to verify that the correct token is being tagged
      tag_scope: tag scope to use when setting predicted tags
    """

    if token_props_list is None:
        token_props_list = repeat(None)
    for token, prediction, token_props in zip(
        _iter_tokens(structured_document),
        predictions, token_props_list
    ):

        if token_props:
            # NOTE(review): a bare assert is stripped under `python -O`;
            # raise an explicit error if this check must always run
            assert structured_document.get_text(token) == token_props['text']

        # NONE_TAG predictions leave the token untagged
        if prediction and prediction != NONE_TAG:
            structured_document.set_tag(token, prediction, scope=tag_scope)
 
 def predict_and_annotate_structured_document(structured_document, model, tag_scope=CRF_TAG_SCOPE):
-  token_props = list(structured_document_to_token_props(structured_document))
-  x = token_props_list_to_features(token_props)
-  y_pred = model.predict([x])[0]
-  annotate_structured_document_using_predictions(
-    structured_document, y_pred, token_props, tag_scope=tag_scope
-  )
-  return structured_document
+    token_props = list(structured_document_to_token_props(structured_document))
+    x = token_props_list_to_features(token_props)
+    y_pred = model.predict([x])[0]
+    annotate_structured_document_using_predictions(
+        structured_document, y_pred, token_props, tag_scope=tag_scope
+    )
+    return structured_document
+
 
def parse_args(argv=None):
    """Parse command line arguments (defaults to sys.argv when argv is None)."""
    parser = argparse.ArgumentParser('Annotated LXML using CRF model')
    source = parser.add_argument_group('source')
    source.add_argument(
        '--lxml-path', type=str, required=False,
        help='path to lxml document'
    )

    cv_source = parser.add_argument_group('CV source')
    cv_source.add_argument(
        '--cv-lxml-path', type=str, required=False,
        help='path to lxml document with cv predicted tags'
    )

    parser.add_argument(
        '--crf-model', type=str, required=True,
        help='path to saved crf model'
    )

    parser.add_argument(
        '--output-path', type=str, required=True,
        help='output path to annotated document'
    )

    parser.add_argument(
        '--tag-scope', type=str, required=False,
        default=CRF_TAG_SCOPE,
        help='target tag scope for the predicted tags'
    )

    parser.add_argument(
        '--debug', action='store_true', default=False,
        help='enable debug output'
    )

    return parser.parse_args(argv)
+
 
def load_crf_model(path):
    """Load a pickled CRF model from *path* via FileIO.

    NOTE(review): pickle.load can execute arbitrary code from the file;
    only load models from trusted locations.
    """
    with FileIO(path, 'rb') as crf_model_f:
        return pickle.load(crf_model_f)
+
 
def main(argv=None):
    """Command line entry point: load, optionally merge CV tags, predict, save."""
    args = parse_args(argv)

    if args.debug:
        logging.getLogger().setLevel('DEBUG')

    structured_document = load_lxml_structured_document(args.lxml_path)

    if args.cv_lxml_path:
        # merge tags predicted by the CV model into the document
        # (presumably consumed as extra features downstream — confirm)
        cv_structured_document = load_lxml_structured_document(args.cv_lxml_path)
        structured_document = merge_with_cv_structured_document(
            structured_document,
            cv_structured_document
        )

    model = load_crf_model(args.crf_model)

    predict_and_annotate_structured_document(
        structured_document,
        model,
        tag_scope=args.tag_scope
    )

    get_logger().info('writing result to: %s', args.output_path)
    save_structured_document(args.output_path, structured_document)
 
if __name__ == '__main__':
    # default to INFO level logging when run as a script
    logging.basicConfig(level='INFO')

    main()
diff --git a/sciencebeam_gym/models/text/crf/annotate_using_predictions_test.py b/sciencebeam_gym/models/text/crf/annotate_using_predictions_test.py
index 7b576e8..7ba0029 100644
--- a/sciencebeam_gym/models/text/crf/annotate_using_predictions_test.py
+++ b/sciencebeam_gym/models/text/crf/annotate_using_predictions_test.py
@@ -3,26 +3,26 @@ from mock import MagicMock
 import pytest
 
 from sciencebeam_gym.structured_document import (
-  SimpleStructuredDocument,
-  SimplePage,
-  SimpleLine,
-  SimpleToken
+    SimpleStructuredDocument,
+    SimplePage,
+    SimpleLine,
+    SimpleToken
 )
 
 from sciencebeam_gym.models.text.feature_extractor import (
-  structured_document_to_token_props,
-  token_props_list_to_features,
-  NONE_TAG
+    structured_document_to_token_props,
+    token_props_list_to_features,
+    NONE_TAG
 )
 
 from sciencebeam_gym.utils.bounding_box import (
-  BoundingBox
+    BoundingBox
 )
 
 from sciencebeam_gym.models.text.crf.annotate_using_predictions import (
-  annotate_structured_document_using_predictions,
-  predict_and_annotate_structured_document,
-  CRF_TAG_SCOPE
+    annotate_structured_document_using_predictions,
+    predict_and_annotate_structured_document,
+    CRF_TAG_SCOPE
 )
 
 TAG_1 = 'tag1'
@@ -32,75 +32,77 @@ TOKEN_TEXT_2 = 'token 2'
 
 BOUNDING_BOX = BoundingBox(0, 0, 10, 10)
 
+
 class TestAnnotateStructuredDocumentUsingPredictions(object):
-  def test_should_not_fail_with_empty_document(self):
-    structured_document = SimpleStructuredDocument()
-    annotate_structured_document_using_predictions(
-      structured_document,
-      []
-    )
-
-  def test_should_tag_single_token_using_prediction(self):
-    token_1 = SimpleToken(TOKEN_TEXT_1)
-    structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])])
-    annotate_structured_document_using_predictions(
-      structured_document,
-      [TAG_1]
-    )
-    assert structured_document.get_tag(token_1, scope=CRF_TAG_SCOPE) == TAG_1
-
-  def test_should_not_tag_using_none_tag(self):
-    token_1 = SimpleToken(TOKEN_TEXT_1)
-    structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])])
-    annotate_structured_document_using_predictions(
-      structured_document,
-      [NONE_TAG]
-    )
-    assert structured_document.get_tag(token_1, scope=CRF_TAG_SCOPE) is None
-
-  def test_should_tag_single_token_using_prediction_and_check_token_props(self):
-    token_1 = SimpleToken(TOKEN_TEXT_1, bounding_box=BOUNDING_BOX)
-    structured_document = SimpleStructuredDocument(SimplePage(
-      lines=[SimpleLine([token_1])],
-      bounding_box=BOUNDING_BOX
-    ))
-    token_props_list = structured_document_to_token_props(structured_document)
-    annotate_structured_document_using_predictions(
-      structured_document,
-      [TAG_1],
-      token_props_list
-    )
-    assert structured_document.get_tag(token_1, scope=CRF_TAG_SCOPE) == TAG_1
-
-  def test_should_raise_error_if_token_props_do_not_match(self):
-    token_1 = SimpleToken(TOKEN_TEXT_1, bounding_box=BOUNDING_BOX)
-    structured_document = SimpleStructuredDocument(SimplePage(
-      lines=[SimpleLine([token_1])],
-      bounding_box=BOUNDING_BOX
-    ))
-    token_props_list = list(structured_document_to_token_props(structured_document))
-    token_props_list[0]['text'] = TOKEN_TEXT_2
-    with pytest.raises(AssertionError):
-      annotate_structured_document_using_predictions(
-        structured_document,
-        [TAG_1],
-        token_props_list
-      )
+    def test_should_not_fail_with_empty_document(self):
+        structured_document = SimpleStructuredDocument()
+        annotate_structured_document_using_predictions(
+            structured_document,
+            []
+        )
+
+    def test_should_tag_single_token_using_prediction(self):
+        token_1 = SimpleToken(TOKEN_TEXT_1)
+        structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])])
+        annotate_structured_document_using_predictions(
+            structured_document,
+            [TAG_1]
+        )
+        assert structured_document.get_tag(token_1, scope=CRF_TAG_SCOPE) == TAG_1
+
+    def test_should_not_tag_using_none_tag(self):
+        token_1 = SimpleToken(TOKEN_TEXT_1)
+        structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])])
+        annotate_structured_document_using_predictions(
+            structured_document,
+            [NONE_TAG]
+        )
+        assert structured_document.get_tag(token_1, scope=CRF_TAG_SCOPE) is None
+
+    def test_should_tag_single_token_using_prediction_and_check_token_props(self):
+        token_1 = SimpleToken(TOKEN_TEXT_1, bounding_box=BOUNDING_BOX)
+        structured_document = SimpleStructuredDocument(SimplePage(
+            lines=[SimpleLine([token_1])],
+            bounding_box=BOUNDING_BOX
+        ))
+        token_props_list = structured_document_to_token_props(structured_document)
+        annotate_structured_document_using_predictions(
+            structured_document,
+            [TAG_1],
+            token_props_list
+        )
+        assert structured_document.get_tag(token_1, scope=CRF_TAG_SCOPE) == TAG_1
+
+    def test_should_raise_error_if_token_props_do_not_match(self):
+        token_1 = SimpleToken(TOKEN_TEXT_1, bounding_box=BOUNDING_BOX)
+        structured_document = SimpleStructuredDocument(SimplePage(
+            lines=[SimpleLine([token_1])],
+            bounding_box=BOUNDING_BOX
+        ))
+        token_props_list = list(structured_document_to_token_props(structured_document))
+        token_props_list[0]['text'] = TOKEN_TEXT_2
+        with pytest.raises(AssertionError):
+            annotate_structured_document_using_predictions(
+                structured_document,
+                [TAG_1],
+                token_props_list
+            )
+
 
 class TestPredictAndAnnotateStructuredDocument(object):
-  def test_should_predict_and_annotate_single_token(self):
-    token_1 = SimpleToken(TOKEN_TEXT_1, bounding_box=BOUNDING_BOX)
-    structured_document = SimpleStructuredDocument(SimplePage(
-      lines=[SimpleLine([token_1])],
-      bounding_box=BOUNDING_BOX
-    ))
-    model = MagicMock()
-    model.predict.return_value = [[TAG_1]]
-    token_props = list(structured_document_to_token_props(structured_document))
-    X = [token_props_list_to_features(token_props)]
-    predict_and_annotate_structured_document(
-      structured_document,
-      model
-    )
-    assert structured_document.get_tag(token_1, scope=CRF_TAG_SCOPE) == TAG_1
-    model.predict.assert_called_with(X)
+    def test_should_predict_and_annotate_single_token(self):
+        token_1 = SimpleToken(TOKEN_TEXT_1, bounding_box=BOUNDING_BOX)
+        structured_document = SimpleStructuredDocument(SimplePage(
+            lines=[SimpleLine([token_1])],
+            bounding_box=BOUNDING_BOX
+        ))
+        model = MagicMock()
+        model.predict.return_value = [[TAG_1]]
+        token_props = list(structured_document_to_token_props(structured_document))
+        X = [token_props_list_to_features(token_props)]
+        predict_and_annotate_structured_document(
+            structured_document,
+            model
+        )
+        assert structured_document.get_tag(token_1, scope=CRF_TAG_SCOPE) == TAG_1
+        model.predict.assert_called_with(X)
diff --git a/sciencebeam_gym/models/text/crf/crfsuite_model.py b/sciencebeam_gym/models/text/crf/crfsuite_model.py
index a831eff..6650359 100644
--- a/sciencebeam_gym/models/text/crf/crfsuite_model.py
+++ b/sciencebeam_gym/models/text/crf/crfsuite_model.py
@@ -2,19 +2,22 @@ import logging
 
 from sklearn_crfsuite import CRF
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 DEFAULT_PARAMS = dict(
-  algorithm='lbfgs',
-  c1=0.1,
-  c2=0.1,
-  max_iterations=100,
-  all_possible_transitions=True
+    algorithm='lbfgs',
+    c1=0.1,
+    c2=0.1,
+    max_iterations=100,
+    all_possible_transitions=True
 )
 
+
 class CrfSuiteModel(CRF):
-  def __init__(self, **kwargs):
-    d = dict(DEFAULT_PARAMS)
-    d.update(kwargs)
-    super(CrfSuiteModel, self).__init__(**d)
+    def __init__(self, **kwargs):
+        d = dict(DEFAULT_PARAMS)
+        d.update(kwargs)
+        super(CrfSuiteModel, self).__init__(**d)
diff --git a/sciencebeam_gym/models/text/crf/crfsuite_model_test.py b/sciencebeam_gym/models/text/crf/crfsuite_model_test.py
index 54a6da8..5c9c1c4 100644
--- a/sciencebeam_gym/models/text/crf/crfsuite_model_test.py
+++ b/sciencebeam_gym/models/text/crf/crfsuite_model_test.py
@@ -3,24 +3,24 @@ import pickle
 from contextlib import contextmanager
 
 from sciencebeam_gym.utils.bounding_box import (
-  BoundingBox
+    BoundingBox
 )
 
 from sciencebeam_gym.structured_document import (
-  SimpleStructuredDocument,
-  SimplePage,
-  SimpleLine,
-  SimpleToken
+    SimpleStructuredDocument,
+    SimplePage,
+    SimpleLine,
+    SimpleToken
 )
 
 from sciencebeam_gym.models.text.feature_extractor import (
-  structured_document_to_token_props,
-  token_props_list_to_features,
-  token_props_list_to_labels
+    structured_document_to_token_props,
+    token_props_list_to_features,
+    token_props_list_to_labels
 )
 
 from sciencebeam_gym.models.text.crf.crfsuite_model import (
-  CrfSuiteModel
+    CrfSuiteModel
 )
 
 PAGE_BOUNDING_BOX = BoundingBox(0, 0, 100, 200)
@@ -34,89 +34,93 @@ TAG_1 = 'tag1'
 TAG_2 = 'tag2'
 TAG_3 = 'tag3'
 
+
 def setup_module():
-  logging.basicConfig(level='DEBUG')
+    logging.basicConfig(level='DEBUG')
+
 
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 @contextmanager
 def create_crf_suite_model():
-  model = CrfSuiteModel()
-  yield model
+    model = CrfSuiteModel()
+    yield model
+
 
 class TestCrfSuiteModel(object):
-  def test_should_learn_simple_sequence(self):
-    structured_document = SimpleStructuredDocument(
-      SimplePage([
-        SimpleLine([
-          SimpleToken(TEXT_1, tag=TAG_1),
-          SimpleToken(TEXT_2, tag=TAG_2),
-          SimpleToken(TEXT_3, tag=TAG_3)
-        ])
-      ], bounding_box=PAGE_BOUNDING_BOX)
-    )
-    token_props_list = list(structured_document_to_token_props(structured_document))
-    get_logger().debug('token_props_list:\n%s', token_props_list)
-    X = [token_props_list_to_features(token_props_list)]
-    y = [token_props_list_to_labels(token_props_list)]
-    get_logger().debug('X:\n%s', X)
-    get_logger().debug('y:\n%s', y)
-    with create_crf_suite_model() as model:
-      model.fit(X, y)
-      y_predicted = model.predict(X)
-      assert y_predicted == y
-
-  def test_should_learn_similar_sequence(self):
-    structured_document_train = SimpleStructuredDocument(
-      SimplePage([
-        SimpleLine([
-          SimpleToken(TEXT_1, tag=TAG_1),
-          SimpleToken(TEXT_1, tag=TAG_1),
-          SimpleToken(TEXT_2, tag=TAG_2),
-          SimpleToken(TEXT_3, tag=TAG_3)
-        ])
-      ], bounding_box=PAGE_BOUNDING_BOX)
-    )
-    structured_document_test = SimpleStructuredDocument(
-      SimplePage([
-        SimpleLine([
-          SimpleToken(TEXT_1, tag=TAG_1),
-          SimpleToken(TEXT_2, tag=TAG_2),
-          SimpleToken(TEXT_3, tag=TAG_3)
-        ])
-      ], bounding_box=PAGE_BOUNDING_BOX)
-    )
-    token_props_list_train = list(structured_document_to_token_props(structured_document_train))
-    X_train = [token_props_list_to_features(token_props_list_train)]
-    y_train = [token_props_list_to_labels(token_props_list_train)]
-
-    token_props_list_test = list(structured_document_to_token_props(structured_document_test))
-    X_test = [token_props_list_to_features(token_props_list_test)]
-    y_test = [token_props_list_to_labels(token_props_list_test)]
-
-    with create_crf_suite_model() as model:
-      model.fit(X_train, y_train)
-      y_predicted = model.predict(X_test)
-      assert y_predicted == y_test
-
-  def test_should_pickle_and_unpickle_model(self):
-    structured_document = SimpleStructuredDocument(
-      SimplePage([
-        SimpleLine([
-          SimpleToken(TEXT_1, tag=TAG_1),
-          SimpleToken(TEXT_2, tag=TAG_2),
-          SimpleToken(TEXT_3, tag=TAG_3)
-        ])
-      ], bounding_box=PAGE_BOUNDING_BOX)
-    )
-    token_props_list = list(structured_document_to_token_props(structured_document))
-    X = [token_props_list_to_features(token_props_list)]
-    y = [token_props_list_to_labels(token_props_list)]
-    with create_crf_suite_model() as model:
-      model.fit(X, y)
-      serialized_model = pickle.dumps(model)
-
-    model = pickle.loads(serialized_model)
-    y_predicted = model.predict(X)
-    assert y_predicted == y
+    def test_should_learn_simple_sequence(self):
+        structured_document = SimpleStructuredDocument(
+            SimplePage([
+                SimpleLine([
+                    SimpleToken(TEXT_1, tag=TAG_1),
+                    SimpleToken(TEXT_2, tag=TAG_2),
+                    SimpleToken(TEXT_3, tag=TAG_3)
+                ])
+            ], bounding_box=PAGE_BOUNDING_BOX)
+        )
+        token_props_list = list(structured_document_to_token_props(structured_document))
+        get_logger().debug('token_props_list:\n%s', token_props_list)
+        X = [token_props_list_to_features(token_props_list)]
+        y = [token_props_list_to_labels(token_props_list)]
+        get_logger().debug('X:\n%s', X)
+        get_logger().debug('y:\n%s', y)
+        with create_crf_suite_model() as model:
+            model.fit(X, y)
+            y_predicted = model.predict(X)
+            assert y_predicted == y
+
+    def test_should_learn_similar_sequence(self):
+        structured_document_train = SimpleStructuredDocument(
+            SimplePage([
+                SimpleLine([
+                    SimpleToken(TEXT_1, tag=TAG_1),
+                    SimpleToken(TEXT_1, tag=TAG_1),
+                    SimpleToken(TEXT_2, tag=TAG_2),
+                    SimpleToken(TEXT_3, tag=TAG_3)
+                ])
+            ], bounding_box=PAGE_BOUNDING_BOX)
+        )
+        structured_document_test = SimpleStructuredDocument(
+            SimplePage([
+                SimpleLine([
+                    SimpleToken(TEXT_1, tag=TAG_1),
+                    SimpleToken(TEXT_2, tag=TAG_2),
+                    SimpleToken(TEXT_3, tag=TAG_3)
+                ])
+            ], bounding_box=PAGE_BOUNDING_BOX)
+        )
+        token_props_list_train = list(structured_document_to_token_props(structured_document_train))
+        X_train = [token_props_list_to_features(token_props_list_train)]
+        y_train = [token_props_list_to_labels(token_props_list_train)]
+
+        token_props_list_test = list(structured_document_to_token_props(structured_document_test))
+        X_test = [token_props_list_to_features(token_props_list_test)]
+        y_test = [token_props_list_to_labels(token_props_list_test)]
+
+        with create_crf_suite_model() as model:
+            model.fit(X_train, y_train)
+            y_predicted = model.predict(X_test)
+            assert y_predicted == y_test
+
+    def test_should_pickle_and_unpickle_model(self):
+        structured_document = SimpleStructuredDocument(
+            SimplePage([
+                SimpleLine([
+                    SimpleToken(TEXT_1, tag=TAG_1),
+                    SimpleToken(TEXT_2, tag=TAG_2),
+                    SimpleToken(TEXT_3, tag=TAG_3)
+                ])
+            ], bounding_box=PAGE_BOUNDING_BOX)
+        )
+        token_props_list = list(structured_document_to_token_props(structured_document))
+        X = [token_props_list_to_features(token_props_list)]
+        y = [token_props_list_to_labels(token_props_list)]
+        with create_crf_suite_model() as model:
+            model.fit(X, y)
+            serialized_model = pickle.dumps(model)
+
+        model = pickle.loads(serialized_model)
+        y_predicted = model.predict(X)
+        assert y_predicted == y
diff --git a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py
index 50d7c80..30d6303 100644
--- a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py
+++ b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py
@@ -9,215 +9,227 @@ from six import raise_from
 from tqdm import tqdm
 
 from sciencebeam_utils.beam_utils.io import (
-  save_file_content
+    save_file_content
 )
 
 from sciencebeam_utils.utils.stopwatch import (
-  StopWatchRecorder
+    StopWatchRecorder
 )
 
 from sciencebeam_utils.utils.file_list import (
-  load_file_list
+    load_file_list
 )
 
 from sciencebeam_gym.structured_document.structured_document_loader import (
-  load_structured_document
+    load_structured_document
 )
 
 from sciencebeam_gym.preprocess.preprocessing_utils import (
-  parse_page_range
+    parse_page_range
 )
 
 from sciencebeam_gym.models.text.feature_extractor import (
-  structured_document_to_token_props,
-  token_props_list_to_features,
-  token_props_list_to_labels,
-  merge_with_cv_structured_document,
-  CV_TAG_SCOPE
+    structured_document_to_token_props,
+    token_props_list_to_features,
+    token_props_list_to_labels,
+    merge_with_cv_structured_document,
+    CV_TAG_SCOPE
 )
 
 from sciencebeam_gym.models.text.crf.crfsuite_model import (
-  CrfSuiteModel
+    CrfSuiteModel
 )
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def parse_args(argv=None):
-  parser = argparse.ArgumentParser('Trains the CRF Suite model')
-  source = parser.add_argument_group('source')
-  source.add_argument(
-    '--source-file-list', type=str, required=True,
-    help='path to source file list (tsv/csv/lst)'
-  )
-  source.add_argument(
-    '--source-file-column', type=str, required=False,
-    default='url',
-    help='csv/tsv column (ignored for plain file list)'
-  )
-
-  cv_source = parser.add_argument_group('CV source')
-  cv_source.add_argument(
-    '--cv-source-file-list', type=str, required=False,
-    help='path to cv source file list (tsv/csv/lst)'
-    ' (must be in line with main source file list)'
-  )
-  source.add_argument(
-    '--cv-source-file-column', type=str, required=False,
-    default='url',
-    help='csv/tsv column (ignored for plain file list)'
-  )
-  source.add_argument(
-    '--cv-source-tag-scope', type=str, required=False,
-    default=CV_TAG_SCOPE,
-    help='source tag scope to get the cv tag from'
-  )
-
-  parser.add_argument(
-    '--limit', type=int, required=False,
-    help='limit the files to process'
-  )
-  parser.add_argument(
-    '--pages', type=parse_page_range, default=None,
-    help='only processes the selected pages'
-  )
-
-  output = parser.add_argument_group('output')
-  output.add_argument(
-    '--output-path', type=str, required=True,
-    help='output path to model'
-  )
-
-  parser.add_argument(
-    '--debug', action='store_true', default=False,
-    help='enable debug output'
-  )
-
-  return parser.parse_args(argv)
+    parser = argparse.ArgumentParser('Trains the CRF Suite model')
+    source = parser.add_argument_group('source')
+    source.add_argument(
+        '--source-file-list', type=str, required=True,
+        help='path to source file list (tsv/csv/lst)'
+    )
+    source.add_argument(
+        '--source-file-column', type=str, required=False,
+        default='url',
+        help='csv/tsv column (ignored for plain file list)'
+    )
+
+    cv_source = parser.add_argument_group('CV source')
+    cv_source.add_argument(
+        '--cv-source-file-list', type=str, required=False,
+        help='path to cv source file list (tsv/csv/lst)'
+        ' (must be in line with main source file list)'
+    )
+    source.add_argument(
+        '--cv-source-file-column', type=str, required=False,
+        default='url',
+        help='csv/tsv column (ignored for plain file list)'
+    )
+    source.add_argument(
+        '--cv-source-tag-scope', type=str, required=False,
+        default=CV_TAG_SCOPE,
+        help='source tag scope to get the cv tag from'
+    )
+
+    parser.add_argument(
+        '--limit', type=int, required=False,
+        help='limit the files to process'
+    )
+    parser.add_argument(
+        '--pages', type=parse_page_range, default=None,
+        help='only processes the selected pages'
+    )
+
+    output = parser.add_argument_group('output')
+    output.add_argument(
+        '--output-path', type=str, required=True,
+        help='output path to model'
+    )
+
+    parser.add_argument(
+        '--debug', action='store_true', default=False,
+        help='enable debug output'
+    )
+
+    return parser.parse_args(argv)
+
 
 def load_and_convert_to_token_props(filename, cv_filename, cv_source_tag_scope, page_range=None):
-  try:
-    structured_document = load_structured_document(filename, page_range=page_range)
-    if cv_filename:
-      cv_structured_document = load_structured_document(cv_filename, page_range=page_range)
-      structured_document = merge_with_cv_structured_document(
-        structured_document,
-        cv_structured_document,
-        cv_source_tag_scope=cv_source_tag_scope
-      )
-    return list(structured_document_to_token_props(
-      structured_document
-    ))
-  except StandardError as e:
-    raise_from(RuntimeError('failed to process %s (due to %s: %s)' % (filename, type(e), e)), e)
+    try:
+        structured_document = load_structured_document(filename, page_range=page_range)
+        if cv_filename:
+            cv_structured_document = load_structured_document(cv_filename, page_range=page_range)
+            structured_document = merge_with_cv_structured_document(
+                structured_document,
+                cv_structured_document,
+                cv_source_tag_scope=cv_source_tag_scope
+            )
+        return list(structured_document_to_token_props(
+            structured_document
+        ))
+    except StandardError as e:
+        raise_from(RuntimeError('failed to process %s (due to %s: %s)' % (filename, type(e), e)), e)
+
 
 def serialize_model(model):
-  return pickle.dumps(model)
+    return pickle.dumps(model)
+
 
 def submit_all(executor, fn, iterable):
-  return {executor.submit(fn, x) for x in iterable}
+    return {executor.submit(fn, x) for x in iterable}
+
 
 def load_token_props_list_by_document(
-  file_list, cv_file_list, cv_source_tag_scope, page_range=None, progress=True):
-
-  if not cv_file_list:
-    cv_file_list = [None] * len(file_list)
-
-  token_props_list_by_document = []
-  total = len(file_list)
-  error_count = 0
-  with tqdm(total=total, leave=False, desc='loading files', disable=not progress) as pbar:
-    with ThreadPoolExecutor(max_workers=50) as executor:
-      process_fn = lambda (filename, cv_filename): (
-        load_and_convert_to_token_props(
-          filename, cv_filename, cv_source_tag_scope=cv_source_tag_scope,
-          page_range=page_range
+        file_list, cv_file_list, cv_source_tag_scope, page_range=None, progress=True):
+
+    if not cv_file_list:
+        cv_file_list = [None] * len(file_list)
+
+    token_props_list_by_document = []
+    total = len(file_list)
+    error_count = 0
+    with tqdm(total=total, leave=False, desc='loading files', disable=not progress) as pbar:
+        with ThreadPoolExecutor(max_workers=50) as executor:
+            def process_fn((filename, cv_filename)):
+                return (
+                    load_and_convert_to_token_props(
+                        filename, cv_filename, cv_source_tag_scope=cv_source_tag_scope,
+                        page_range=page_range
+                    )
+                )
+            futures = submit_all(executor, process_fn, zip(file_list, cv_file_list))
+            for future in concurrent.futures.as_completed(futures):
+                try:
+                    token_props_list_by_document.append(future.result())
+                except StandardError as e:
+                    get_logger().warning(str(e), exc_info=e)
+                    error_count += 1
+                pbar.update(1)
+    if error_count:
+        get_logger().info(
+            'loading error count: %d (loaded: %d)', error_count, len(token_props_list_by_document)
         )
-      )
-      futures = submit_all(executor, process_fn, zip(file_list, cv_file_list))
-      for future in concurrent.futures.as_completed(futures):
-        try:
-          token_props_list_by_document.append(future.result())
-        except StandardError as e:
-          get_logger().warning(str(e), exc_info=e)
-          error_count += 1
-        pbar.update(1)
-  if error_count:
-    get_logger().info(
-      'loading error count: %d (loaded: %d)', error_count, len(token_props_list_by_document)
-    )
-  return token_props_list_by_document
+    return token_props_list_by_document
+
 
 def train_model(
-  file_list, cv_file_list, cv_source_tag_scope, page_range=None, progress=True):
+        file_list, cv_file_list, cv_source_tag_scope, page_range=None, progress=True):
+
+    stop_watch_recorder = StopWatchRecorder()
+    model = CrfSuiteModel()
 
-  stop_watch_recorder = StopWatchRecorder()
-  model = CrfSuiteModel()
+    stop_watch_recorder.start('loading files')
+    token_props_list_by_document = load_token_props_list_by_document(
+        file_list, cv_file_list, cv_source_tag_scope=cv_source_tag_scope,
+        page_range=page_range, progress=progress
+    )
 
-  stop_watch_recorder.start('loading files')
-  token_props_list_by_document = load_token_props_list_by_document(
-    file_list, cv_file_list, cv_source_tag_scope=cv_source_tag_scope,
-    page_range=page_range, progress=progress
-  )
+    assert token_props_list_by_document
 
-  assert token_props_list_by_document
+    stop_watch_recorder.start('converting to features')
+    X = [token_props_list_to_features(x) for x in token_props_list_by_document]
+    y = [token_props_list_to_labels(x) for x in token_props_list_by_document]
 
-  stop_watch_recorder.start('converting to features')
-  X = [token_props_list_to_features(x) for x in token_props_list_by_document]
-  y = [token_props_list_to_labels(x) for x in token_props_list_by_document]
+    get_logger().info('training model (with %d documents)', len(X))
+    stop_watch_recorder.start('train')
+    model.fit(X, y)
 
-  get_logger().info('training model (with %d documents)', len(X))
-  stop_watch_recorder.start('train')
-  model.fit(X, y)
+    stop_watch_recorder.start('serialize')
+    serialized_model = serialize_model(model)
 
-  stop_watch_recorder.start('serialize')
-  serialized_model = serialize_model(model)
+    stop_watch_recorder.stop()
+    get_logger().info('timings: %s', stop_watch_recorder)
 
-  stop_watch_recorder.stop()
-  get_logger().info('timings: %s', stop_watch_recorder)
+    return serialized_model
 
-  return serialized_model
 
 def save_model(output_filename, model_bytes):
-  get_logger().info('saving model to %s', output_filename)
-  save_file_content(output_filename, model_bytes)
+    get_logger().info('saving model to %s', output_filename)
+    save_file_content(output_filename, model_bytes)
+
 
 def run(opt):
-  file_list = load_file_list(
-    opt.source_file_list,
-    opt.source_file_column,
-    limit=opt.limit
-  )
-  if opt.cv_source_file_list:
-    cv_file_list = load_file_list(
-      opt.cv_source_file_list,
-      opt.cv_source_file_column,
-      limit=opt.limit
+    file_list = load_file_list(
+        opt.source_file_list,
+        opt.source_file_column,
+        limit=opt.limit
+    )
+    if opt.cv_source_file_list:
+        cv_file_list = load_file_list(
+            opt.cv_source_file_list,
+            opt.cv_source_file_column,
+            limit=opt.limit
+        )
+    else:
+        cv_file_list = None
+    get_logger().info(
+        'training using %d files (limit %d), page range: %s',
+        len(file_list), opt.limit, opt.pages
     )
-  else:
-    cv_file_list = None
-  get_logger().info(
-    'training using %d files (limit %d), page range: %s',
-    len(file_list), opt.limit, opt.pages
-  )
-  save_model(
-    opt.output_path,
-    train_model(
-      file_list, cv_file_list, cv_source_tag_scope=opt.cv_source_tag_scope,
-      page_range=opt.pages
+    save_model(
+        opt.output_path,
+        train_model(
+            file_list, cv_file_list, cv_source_tag_scope=opt.cv_source_tag_scope,
+            page_range=opt.pages
+        )
     )
-  )
+
 
 def main(argv=None):
-  args = parse_args(argv)
+    args = parse_args(argv)
+
+    if args.debug:
+        logging.getLogger().setLevel('DEBUG')
 
-  if args.debug:
-    logging.getLogger().setLevel('DEBUG')
+    run(args)
 
-  run(args)
 
 if __name__ == '__main__':
-  logging.basicConfig(level='INFO')
-  logging.getLogger('oauth2client').setLevel('WARN')
+    logging.basicConfig(level='INFO')
+    logging.getLogger('oauth2client').setLevel('WARN')
 
-  main()
+    main()
diff --git a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline_test.py b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline_test.py
index 24cd23e..a571262 100644
--- a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline_test.py
+++ b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline_test.py
@@ -3,27 +3,27 @@ from mock import patch, Mock, ANY
 import pytest
 
 from sciencebeam_utils.utils.collection import (
-  to_namedtuple
+    to_namedtuple
 )
 
 from sciencebeam_gym.structured_document import (
-  SimpleStructuredDocument,
-  SimpleLine,
-  SimpleToken
+    SimpleStructuredDocument,
+    SimpleLine,
+    SimpleToken
 )
 
 from sciencebeam_gym.models.text.feature_extractor import (
-  CV_TAG_SCOPE
+    CV_TAG_SCOPE
 )
 
 import sciencebeam_gym.models.text.crf.crfsuite_training_pipeline as crfsuite_training_pipeline
 from sciencebeam_gym.models.text.crf.crfsuite_training_pipeline import (
-  load_and_convert_to_token_props,
-  load_token_props_list_by_document,
-  train_model,
-  save_model,
-  run,
-  main
+    load_and_convert_to_token_props,
+    load_token_props_list_by_document,
+    train_model,
+    save_model,
+    run,
+    main
 )
 
 SOURCE_FILE_LIST_PATH = '.temp/source-file-list.lst'
@@ -43,252 +43,320 @@ TAG_1 = 'tag1'
 TAG_2 = 'tag2'
 
 DEFAULT_ARGS = dict(
-  source_file_column='url',
-  cv_source_file_list=None,
-  cv_source_file_column='url',
-  cv_source_tag_scope=CV_TAG_SCOPE
+    source_file_column='url',
+    cv_source_file_list=None,
+    cv_source_file_column='url',
+    cv_source_tag_scope=CV_TAG_SCOPE
 )
 
+
+@pytest.fixture(name='load_structured_document_mock')
+def _load_structured_document_mock():
+    with patch.object(crfsuite_training_pipeline, 'load_structured_document') as _mock:
+        yield _mock
+
+
+@pytest.fixture(name='structured_document_to_token_props_mock')
+def _structured_document_to_token_props_mock():
+    with patch.object(crfsuite_training_pipeline, 'structured_document_to_token_props') as _mock:
+        yield _mock
+
+
+@pytest.fixture(name='token_props_list_to_features_mock')
+def _token_props_list_to_features_mock():
+    with patch.object(crfsuite_training_pipeline, 'token_props_list_to_features') as _mock:
+        yield _mock
+
+
+@pytest.fixture(name='load_token_props_list_by_document_mock')
+def _load_token_props_list_by_document_mock():
+    with patch.object(crfsuite_training_pipeline, 'load_token_props_list_by_document') as _mock:
+        yield _mock
+
+
+@pytest.fixture(name='load_and_convert_to_token_props_mock')
+def _load_and_convert_to_token_props_mock():
+    with patch.object(crfsuite_training_pipeline, 'load_and_convert_to_token_props') as _mock:
+        yield _mock
+
+
+@pytest.fixture(name='CrfSuiteModel_mock')
+def _CrfSuiteModel_mock():
+    with patch.object(crfsuite_training_pipeline, 'CrfSuiteModel') as _mock:
+        yield _mock
+
+
+@pytest.fixture(name='pickle_mock')
+def _pickle_mock():
+    with patch.object(crfsuite_training_pipeline, 'pickle') as _mock:
+        yield _mock
+
+
+@pytest.fixture(name='save_file_content_mock')
+def _save_file_content_mock():
+    with patch.object(crfsuite_training_pipeline, 'save_file_content') as _mock:
+        yield _mock
+
+
+@pytest.fixture(name='load_file_list_mock')
+def _load_file_list_mock():
+    with patch.object(crfsuite_training_pipeline, 'load_file_list') as _mock:
+        yield _mock
+
+
+@pytest.fixture(name='train_model_mock')
+def _train_model_mock():
+    with patch.object(crfsuite_training_pipeline, 'train_model') as _mock:
+        yield _mock
+
+
+@pytest.fixture(name='save_model_mock')
+def _save_model_mock():
+    with patch.object(crfsuite_training_pipeline, 'save_model') as _mock:
+        yield _mock
+
+
 class TestLoadAndConvertToTokenProps(object):
-  def test_should_load_and_convert_document(self):
-    m = crfsuite_training_pipeline
-    with patch.object(m, 'load_structured_document') as load_structured_document_mock:
-      with patch.object(m, 'structured_document_to_token_props') as \
-      structured_document_to_token_props_mock:
+    def test_should_load_and_convert_document(
+            self,
+            load_structured_document_mock,
+            structured_document_to_token_props_mock):
 
         load_and_convert_to_token_props(
-          FILE_1, None, cv_source_tag_scope=CV_TAG_SCOPE,
-          page_range=PAGE_RANGE
+            FILE_1, None, cv_source_tag_scope=CV_TAG_SCOPE,
+            page_range=PAGE_RANGE
         )
         load_structured_document_mock.assert_called_with(
-          FILE_1, page_range=PAGE_RANGE
+            FILE_1, page_range=PAGE_RANGE
         )
         structured_document_to_token_props_mock.assert_called_with(
-          load_structured_document_mock.return_value
+            load_structured_document_mock.return_value
         )
 
-  def test_should_load_and_convert_document_with_cv(self):
-    m = crfsuite_training_pipeline
-    with patch.object(m, 'load_structured_document') as load_structured_document_mock:
-      with patch.object(m, 'structured_document_to_token_props') as \
-      structured_document_to_token_props_mock:
+    def test_should_load_and_convert_document_with_cv(
+            self,
+            load_structured_document_mock,
+            structured_document_to_token_props_mock):
 
         load_and_convert_to_token_props(
-          FILE_1, FILE_2, cv_source_tag_scope=CV_TAG_SCOPE,
-          page_range=PAGE_RANGE
+            FILE_1, FILE_2, cv_source_tag_scope=CV_TAG_SCOPE,
+            page_range=PAGE_RANGE
         )
         load_structured_document_mock.assert_any_call(
-          FILE_1, page_range=PAGE_RANGE
+            FILE_1, page_range=PAGE_RANGE
         )
         structured_document_to_token_props_mock.assert_called_with(
-          load_structured_document_mock.return_value
+            load_structured_document_mock.return_value
         )
 
-  def test_should_merge_doc_and_scope_cv_tag(self):
-    structured_document = SimpleStructuredDocument(lines=[SimpleLine([
-      SimpleToken(TEXT_1, tag=TAG_1)
-    ])])
-    cv_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
-      SimpleToken(TEXT_1, tag=TAG_2, tag_scope=CV_TAG_SCOPE)
-    ])])
-    m = crfsuite_training_pipeline
-    with patch.object(m, 'load_structured_document') as load_structured_document_mock:
-      with patch.object(m, 'structured_document_to_token_props') as \
-      structured_document_to_token_props_mock:
-
+    def test_should_merge_doc_and_scope_cv_tag(
+            self,
+            load_structured_document_mock,
+            structured_document_to_token_props_mock):
+
+        structured_document = SimpleStructuredDocument(lines=[SimpleLine([
+            SimpleToken(TEXT_1, tag=TAG_1)
+        ])])
+        cv_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
+            SimpleToken(TEXT_1, tag=TAG_2, tag_scope=CV_TAG_SCOPE)
+        ])])
         load_structured_document_mock.side_effect = [
-          structured_document,
-          cv_structured_document
+            structured_document,
+            cv_structured_document
         ]
         load_and_convert_to_token_props(
-          FILE_1, FILE_2, cv_source_tag_scope=CV_TAG_SCOPE,
-          page_range=PAGE_RANGE
+            FILE_1, FILE_2, cv_source_tag_scope=CV_TAG_SCOPE,
+            page_range=PAGE_RANGE
         )
         load_structured_document_mock.assert_any_call(
-          FILE_1, page_range=PAGE_RANGE
+            FILE_1, page_range=PAGE_RANGE
         )
         structured_document_arg = structured_document_to_token_props_mock.call_args[0][0]
         assert [
-          structured_document_arg.get_tag_by_scope(t)
-          for t in structured_document_arg.iter_all_tokens()
+            structured_document_arg.get_tag_by_scope(t)
+            for t in structured_document_arg.iter_all_tokens()
         ] == [{None: TAG_1, CV_TAG_SCOPE: TAG_2}]
 
+
 class TestLoadTokenPropsListByDocument(object):
-  def test_should_load_single_file_without_cv(self):
-    m = crfsuite_training_pipeline
-    with patch.object(m, 'load_and_convert_to_token_props') as \
-      load_and_convert_to_token_props_mock:
-
-      result = load_token_props_list_by_document(
-        [FILE_1], None, cv_source_tag_scope=CV_TAG_SCOPE,
-        page_range=PAGE_RANGE, progress=False
-      )
-      load_and_convert_to_token_props_mock.assert_called_with(
-        FILE_1, None, cv_source_tag_scope=CV_TAG_SCOPE, page_range=PAGE_RANGE
-      )
-      assert result == [load_and_convert_to_token_props_mock.return_value]
-
-  def test_should_load_single_file_with_cv(self):
-    m = crfsuite_training_pipeline
-    with patch.object(m, 'load_and_convert_to_token_props') as \
-      load_and_convert_to_token_props_mock:
-
-      result = load_token_props_list_by_document(
-        [FILE_1], [FILE_2], cv_source_tag_scope=CV_TAG_SCOPE,
-        page_range=PAGE_RANGE, progress=False
-      )
-      load_and_convert_to_token_props_mock.assert_called_with(
-        FILE_1, FILE_2, cv_source_tag_scope=CV_TAG_SCOPE,
-        page_range=PAGE_RANGE
-      )
-      assert result == [load_and_convert_to_token_props_mock.return_value]
-
-  def test_should_load_multiple_files(self):
-    m = crfsuite_training_pipeline
-    with patch.object(m, 'load_and_convert_to_token_props') as \
-      load_and_convert_to_token_props_mock:
-
-      return_values = [Mock(), Mock()]
-      load_and_convert_to_token_props_mock.side_effect = return_values
-      result = load_token_props_list_by_document(
-        [FILE_1, FILE_2], None, cv_source_tag_scope=CV_TAG_SCOPE,
-        page_range=PAGE_RANGE, progress=False
-      )
-      load_and_convert_to_token_props_mock.assert_any_call(
-        FILE_1, None, cv_source_tag_scope=CV_TAG_SCOPE,
-        page_range=PAGE_RANGE
-      )
-      load_and_convert_to_token_props_mock.assert_any_call(
-        FILE_2, None, cv_source_tag_scope=CV_TAG_SCOPE,
-        page_range=PAGE_RANGE
-      )
-      assert set(result) == set(return_values)
-
-  def test_should_skip_files_with_errors(self):
-    m = crfsuite_training_pipeline
-    with patch.object(m, 'load_and_convert_to_token_props') as \
-      load_and_convert_to_token_props_mock:
-
-      valid_response = Mock()
-      load_and_convert_to_token_props_mock.side_effect = [
-        RuntimeError('oh dear'), valid_response
-      ]
-      result = load_token_props_list_by_document(
-        [FILE_1, FILE_2], None, cv_source_tag_scope=CV_TAG_SCOPE,
-        page_range=PAGE_RANGE, progress=False
-      )
-      assert result == [valid_response]
+    def test_should_load_single_file_without_cv(
+            self,
+            load_and_convert_to_token_props_mock):
+
+        result = load_token_props_list_by_document(
+            [FILE_1], None, cv_source_tag_scope=CV_TAG_SCOPE,
+            page_range=PAGE_RANGE, progress=False
+        )
+        load_and_convert_to_token_props_mock.assert_called_with(
+            FILE_1, None, cv_source_tag_scope=CV_TAG_SCOPE, page_range=PAGE_RANGE
+        )
+        assert result == [load_and_convert_to_token_props_mock.return_value]
 
+    def test_should_load_single_file_with_cv(
+            self,
+            load_and_convert_to_token_props_mock):
+
+        result = load_token_props_list_by_document(
+            [FILE_1], [FILE_2], cv_source_tag_scope=CV_TAG_SCOPE,
+            page_range=PAGE_RANGE, progress=False
+        )
+        load_and_convert_to_token_props_mock.assert_called_with(
+            FILE_1, FILE_2, cv_source_tag_scope=CV_TAG_SCOPE,
+            page_range=PAGE_RANGE
+        )
+        assert result == [load_and_convert_to_token_props_mock.return_value]
+
+    def test_should_load_multiple_files(
+            self,
+            load_and_convert_to_token_props_mock):
+
+        return_values = [Mock(), Mock()]
+        load_and_convert_to_token_props_mock.side_effect = return_values
+        result = load_token_props_list_by_document(
+            [FILE_1, FILE_2], None, cv_source_tag_scope=CV_TAG_SCOPE,
+            page_range=PAGE_RANGE, progress=False
+        )
+        load_and_convert_to_token_props_mock.assert_any_call(
+            FILE_1, None, cv_source_tag_scope=CV_TAG_SCOPE,
+            page_range=PAGE_RANGE
+        )
+        load_and_convert_to_token_props_mock.assert_any_call(
+            FILE_2, None, cv_source_tag_scope=CV_TAG_SCOPE,
+            page_range=PAGE_RANGE
+        )
+        assert set(result) == set(return_values)
+
+    def test_should_skip_files_with_errors(
+            self,
+            load_and_convert_to_token_props_mock):
+
+        valid_response = Mock()
+        load_and_convert_to_token_props_mock.side_effect = [
+            RuntimeError('oh dear'), valid_response
+        ]
+        result = load_token_props_list_by_document(
+            [FILE_1, FILE_2], None, cv_source_tag_scope=CV_TAG_SCOPE,
+            page_range=PAGE_RANGE, progress=False
+        )
+        assert result == [valid_response]
+
+
+@pytest.mark.usefixtures(
+    'token_props_list_to_features_mock',
+    'CrfSuiteModel_mock',
+    'pickle_mock'
+)
 class TestTrainModel(object):
-  def test_should_train_on_single_file(self):
-    m = crfsuite_training_pipeline
-    with patch.object(m, 'load_token_props_list_by_document') as \
-      load_token_props_list_by_document_mock:
-      with patch.object(m, 'CrfSuiteModel') as CrfSuiteModel_mock:
-        with patch.object(m, 'pickle') as pickle:
-          with patch.object(m, 'token_props_list_to_features') as _:
+    def test_should_train_on_single_file(
+            self,
+            load_token_props_list_by_document_mock,
+            CrfSuiteModel_mock,
+            pickle_mock):
+
+        train_model(
+            [FILE_1], [FILE_2],
+            cv_source_tag_scope=CV_TAG_SCOPE,
+            page_range=PAGE_RANGE, progress=False
+        )
+        load_token_props_list_by_document_mock.assert_called_with(
+            [FILE_1], [FILE_2], cv_source_tag_scope=CV_TAG_SCOPE,
+            page_range=PAGE_RANGE, progress=False
+        )
+        model = CrfSuiteModel_mock.return_value
+        model.fit.assert_called_with(ANY, ANY)
+        pickle_mock.dumps.assert_called_with(model)
+
+    def test_should_raise_error_if_no_documents_have_been_loaded(
+            self,
+            load_token_props_list_by_document_mock):
+
+        with pytest.raises(AssertionError):
+            load_token_props_list_by_document_mock.return_value = []
             train_model(
-              [FILE_1], [FILE_2],
-              cv_source_tag_scope=CV_TAG_SCOPE,
-              page_range=PAGE_RANGE, progress=False
-            )
-            load_token_props_list_by_document_mock.assert_called_with(
-              [FILE_1], [FILE_2], cv_source_tag_scope=CV_TAG_SCOPE,
-              page_range=PAGE_RANGE, progress=False
-            )
-            model = CrfSuiteModel_mock.return_value
-            model.fit.assert_called_with(ANY, ANY)
-            pickle.dumps.assert_called_with(model)
-
-  def test_should_raise_error_if_no_documents_have_been_loaded(self):
-    m = crfsuite_training_pipeline
-    with patch.object(m, 'load_token_props_list_by_document') as \
-      load_token_props_list_by_document_mock:
-      with patch.object(m, 'CrfSuiteModel'):
-        with patch.object(m, 'pickle'):
-          with patch.object(m, 'token_props_list_to_features') as _:
-            with pytest.raises(AssertionError):
-              load_token_props_list_by_document_mock.return_value = []
-              train_model(
                 [FILE_1], [FILE_2],
                 cv_source_tag_scope=CV_TAG_SCOPE,
                 page_range=PAGE_RANGE
-              )
+            )
+
 
 class TestSaveModel(object):
-  def test_should_call_save_content(self):
-    m = crfsuite_training_pipeline
-    with patch.object(m, 'save_file_content') as save_file_content:
-      save_model(FILE_1, MODEL_DATA)
-      save_file_content.assert_called_with(FILE_1, MODEL_DATA)
+    def test_should_call_save_content(
+            self,
+            save_file_content_mock):
+        save_model(FILE_1, MODEL_DATA)
+        save_file_content_mock.assert_called_with(FILE_1, MODEL_DATA)
+
 
 class TestRun(object):
-  def test_should_train_on_single_file(self):
-    m = crfsuite_training_pipeline
-    opt = to_namedtuple(
-      DEFAULT_ARGS,
-      source_file_list=SOURCE_FILE_LIST_PATH,
-      output_path=FILE_1,
-      limit=2,
-      pages=PAGE_RANGE
-    )
-    with patch.object(m, 'load_file_list') as load_file_list:
-      with patch.object(m, 'train_model') as train_model_mock:
-        with patch.object(m, 'save_model') as save_model_mock:
-          run(opt)
-          load_file_list.assert_called_with(
+    def test_should_train_on_single_file(
+            self,
+            load_file_list_mock,
+            train_model_mock,
+            save_model_mock):
+
+        opt = to_namedtuple(
+            DEFAULT_ARGS,
+            source_file_list=SOURCE_FILE_LIST_PATH,
+            output_path=FILE_1,
+            limit=2,
+            pages=PAGE_RANGE
+        )
+        run(opt)
+        load_file_list_mock.assert_called_with(
             opt.source_file_list, opt.source_file_column, limit=opt.limit
-          )
-          train_model_mock.assert_called_with(
-            load_file_list.return_value,
+        )
+        train_model_mock.assert_called_with(
+            load_file_list_mock.return_value,
             None,
             cv_source_tag_scope=opt.cv_source_tag_scope,
             page_range=PAGE_RANGE
-          )
-          save_model_mock.assert_called_with(
+        )
+        save_model_mock.assert_called_with(
             opt.output_path,
             train_model_mock.return_value
-          )
-
-  def test_should_train_on_single_file_with_cv_output(self):
-    m = crfsuite_training_pipeline
-    opt = to_namedtuple(
-      DEFAULT_ARGS,
-      source_file_list=SOURCE_FILE_LIST_PATH,
-      cv_source_file_list=CV_SOURCE_FILE_LIST_PATH,
-      output_path=FILE_1,
-      limit=2,
-      pages=PAGE_RANGE
-    )
-    file_list = [FILE_1, FILE_2]
-    cv_file_list = ['cv.' + FILE_1, 'cv.' + FILE_2]
-    with patch.object(m, 'load_file_list') as load_file_list:
-      with patch.object(m, 'train_model') as train_model_mock:
-        with patch.object(m, 'save_model') as save_model_mock:
-          load_file_list.side_effect = [file_list, cv_file_list]
-          run(opt)
-          load_file_list.assert_any_call(
+        )
+
+    def test_should_train_on_single_file_with_cv_output(
+            self,
+            load_file_list_mock,
+            train_model_mock,
+            save_model_mock):
+
+        opt = to_namedtuple(
+            DEFAULT_ARGS,
+            source_file_list=SOURCE_FILE_LIST_PATH,
+            cv_source_file_list=CV_SOURCE_FILE_LIST_PATH,
+            output_path=FILE_1,
+            limit=2,
+            pages=PAGE_RANGE
+        )
+        file_list = [FILE_1, FILE_2]
+        cv_file_list = ['cv.' + FILE_1, 'cv.' + FILE_2]
+        load_file_list_mock.side_effect = [file_list, cv_file_list]
+        run(opt)
+        load_file_list_mock.assert_any_call(
             opt.source_file_list, opt.source_file_column, limit=opt.limit
-          )
-          load_file_list.assert_any_call(
+        )
+        load_file_list_mock.assert_any_call(
             opt.cv_source_file_list, opt.cv_source_file_column, limit=opt.limit
-          )
-          train_model_mock.assert_called_with(
+        )
+        train_model_mock.assert_called_with(
             file_list,
             cv_file_list,
             cv_source_tag_scope=opt.cv_source_tag_scope,
             page_range=PAGE_RANGE
-          )
-          save_model_mock.assert_called_with(
+        )
+        save_model_mock.assert_called_with(
             opt.output_path,
             train_model_mock.return_value
-          )
+        )
+
 
 class TestMain(object):
-  def test_should(self):
-    argv = ['--source-file-list', SOURCE_FILE_LIST_PATH]
-    with patch.object(crfsuite_training_pipeline, 'parse_args') as parse_args_mock:
-      with patch.object(crfsuite_training_pipeline, 'run') as run_mock:
-        main(argv)
-        parse_args_mock.assert_called_with(argv)
-        run_mock.assert_called_with(parse_args_mock.return_value)
+    def test_should(self):
+        argv = ['--source-file-list', SOURCE_FILE_LIST_PATH]
+        with patch.object(crfsuite_training_pipeline, 'parse_args') as parse_args_mock:
+            with patch.object(crfsuite_training_pipeline, 'run') as run_mock:
+                main(argv)
+                parse_args_mock.assert_called_with(argv)
+                run_mock.assert_called_with(parse_args_mock.return_value)
diff --git a/sciencebeam_gym/models/text/feature_extractor.py b/sciencebeam_gym/models/text/feature_extractor.py
index b8e11ff..b0fce1c 100644
--- a/sciencebeam_gym/models/text/feature_extractor.py
+++ b/sciencebeam_gym/models/text/feature_extractor.py
@@ -1,119 +1,125 @@
 from functools import partial
 
 from sciencebeam_gym.inference_model.annotate_using_predictions import (
-  CV_TAG_SCOPE
+    CV_TAG_SCOPE
 )
 
 from sciencebeam_gym.structured_document import (
-  merge_token_tag
+    merge_token_tag
 )
 
 NONE_TAG = 'O'
 
+
 def structured_document_to_token_props(structured_document):
-  pages = list(structured_document.get_pages())
-  for page_index, page in enumerate(pages):
-    page_bounding_box = structured_document.get_bounding_box(page)
-    assert page_bounding_box is not None
-    page_width = page_bounding_box.width
-    page_height = page_bounding_box.height
-    page_info = {
-      'index': page_index,
-      'count': len(pages),
-      'width': page_width,
-      'height': page_height
-    }
-    page_rx = 1.0 / page_width
-    page_ry = 1.0 / page_height
-    lines = list(structured_document.get_lines_of_page(page))
-    for line_index, line in enumerate(lines):
-      line_tokens = list(structured_document.get_tokens_of_line(line))
-      line_info = {
-        'index': line_index,
-        'count': len(lines)
-      }
-      for line_token_index, token in enumerate(line_tokens):
-        line_token_info = {
-          'index': line_token_index,
-          'count': len(line_tokens)
-        }
-        bounding_box = structured_document.get_bounding_box(token)
-        rel_bounding_box = (
-          bounding_box.scale_by(page_rx, page_ry)
-          if bounding_box else None
-        )
-        yield {
-          'text': structured_document.get_text(token),
-          'tag': structured_document.get_tag(token),
-          'scoped_tags': {
-            k: v
-            for k, v in structured_document.get_tag_by_scope(token).items()
-            if k
-          },
-          'bounding_box': bounding_box,
-          'rel_bounding_box': rel_bounding_box,
-          'line_token': line_token_info,
-          'page': page_info,
-          'line': line_info
+    pages = list(structured_document.get_pages())
+    for page_index, page in enumerate(pages):
+        page_bounding_box = structured_document.get_bounding_box(page)
+        assert page_bounding_box is not None
+        page_width = page_bounding_box.width
+        page_height = page_bounding_box.height
+        page_info = {
+            'index': page_index,
+            'count': len(pages),
+            'width': page_width,
+            'height': page_height
         }
+        page_rx = 1.0 / page_width
+        page_ry = 1.0 / page_height
+        lines = list(structured_document.get_lines_of_page(page))
+        for line_index, line in enumerate(lines):
+            line_tokens = list(structured_document.get_tokens_of_line(line))
+            line_info = {
+                'index': line_index,
+                'count': len(lines)
+            }
+            for line_token_index, token in enumerate(line_tokens):
+                line_token_info = {
+                    'index': line_token_index,
+                    'count': len(line_tokens)
+                }
+                bounding_box = structured_document.get_bounding_box(token)
+                rel_bounding_box = (
+                    bounding_box.scale_by(page_rx, page_ry)
+                    if bounding_box else None
+                )
+                yield {
+                    'text': structured_document.get_text(token),
+                    'tag': structured_document.get_tag(token),
+                    'scoped_tags': {
+                        k: v
+                        for k, v in structured_document.get_tag_by_scope(token).items()
+                        if k
+                    },
+                    'bounding_box': bounding_box,
+                    'rel_bounding_box': rel_bounding_box,
+                    'line_token': line_token_info,
+                    'page': page_info,
+                    'line': line_info
+                }
+
 
 def token_props_features(token_props, prefix=''):
-  word = token_props.get('text') or ''
-  word_lower = word.lower()
-  d = {
-    prefix + 'word.lower': word_lower,
-    prefix + 'word[:1]': word_lower[:1],
-    prefix + 'word[-3:]': word_lower[-3:],
-    prefix + 'word[-2:]': word_lower[-2:],
-    prefix + 'word[:1].isupper': word[:1].istitle(),
-    prefix + 'word.isupper': word.isupper(),
-    prefix + 'word.isdigit': word.isdigit()
-  }
-  for scope, tag in token_props.get('scoped_tags', {}).items():
-    d[prefix + scope + '.tag'] = tag
-  return d
+    word = token_props.get('text') or ''
+    word_lower = word.lower()
+    d = {
+        prefix + 'word.lower': word_lower,
+        prefix + 'word[:1]': word_lower[:1],
+        prefix + 'word[-3:]': word_lower[-3:],
+        prefix + 'word[-2:]': word_lower[-2:],
+        prefix + 'word[:1].isupper': word[:1].istitle(),
+        prefix + 'word.isupper': word.isupper(),
+        prefix + 'word.isdigit': word.isdigit()
+    }
+    for scope, tag in token_props.get('scoped_tags', {}).items():
+        d[prefix + scope + '.tag'] = tag
+    return d
+
 
 def token_props_to_features(token_props_list, i):
-  features = token_props_features(token_props_list[i])
-  if i > 0:
-    pass
-  for rel_token_index in [-2, -1, 1, 2]:
-    abs_token_index = i + rel_token_index
-    if abs_token_index < 0:
-      features['BOD[%d]' % rel_token_index] = True
-    elif abs_token_index >= len(token_props_list):
-      features['EOD[%d]' % rel_token_index] = True
-    else:
-      features.update(token_props_features(
-        token_props_list[abs_token_index],
-        str(rel_token_index) + ':'
-      ))
-  return features
+    features = token_props_features(token_props_list[i])
+    if i > 0:
+        pass
+    for rel_token_index in [-2, -1, 1, 2]:
+        abs_token_index = i + rel_token_index
+        if abs_token_index < 0:
+            features['BOD[%d]' % rel_token_index] = True
+        elif abs_token_index >= len(token_props_list):
+            features['EOD[%d]' % rel_token_index] = True
+        else:
+            features.update(token_props_features(
+                token_props_list[abs_token_index],
+                str(rel_token_index) + ':'
+            ))
+    return features
 
 
 def token_props_list_to_features(token_props_list):
-  token_props_list = list(token_props_list)
-  return [token_props_to_features(token_props_list, i) for i in range(len(token_props_list))]
+    token_props_list = list(token_props_list)
+    return [token_props_to_features(token_props_list, i) for i in range(len(token_props_list))]
+
 
 def remove_labels_from_token_props_list(token_props_list):
-  return [
-    {k: v for k, v in token_props.items() if k != 'tag'}
-    for token_props in token_props_list
-  ]
+    return [
+        {k: v for k, v in token_props.items() if k != 'tag'}
+        for token_props in token_props_list
+    ]
+
 
 def token_props_list_to_labels(token_props_list):
-  return [token_props.get('tag') or NONE_TAG for token_props in token_props_list]
+    return [token_props.get('tag') or NONE_TAG for token_props in token_props_list]
+
 
 def merge_with_cv_structured_document(
-  structured_document, cv_structured_document,
-  cv_source_tag_scope=CV_TAG_SCOPE):
-
-  structured_document.merge_with(
-    cv_structured_document,
-    partial(
-      merge_token_tag,
-      source_scope=cv_source_tag_scope,
-      target_scope=CV_TAG_SCOPE
+        structured_document, cv_structured_document,
+        cv_source_tag_scope=CV_TAG_SCOPE):
+
+    structured_document.merge_with(
+        cv_structured_document,
+        partial(
+            merge_token_tag,
+            source_scope=cv_source_tag_scope,
+            target_scope=CV_TAG_SCOPE
+        )
     )
-  )
-  return structured_document
+    return structured_document
diff --git a/sciencebeam_gym/models/text/feature_extractor_test.py b/sciencebeam_gym/models/text/feature_extractor_test.py
index acaa016..4ee97b1 100644
--- a/sciencebeam_gym/models/text/feature_extractor_test.py
+++ b/sciencebeam_gym/models/text/feature_extractor_test.py
@@ -1,22 +1,22 @@
 from sciencebeam_gym.utils.bounding_box import (
-  BoundingBox
+    BoundingBox
 )
 
 from sciencebeam_gym.structured_document import (
-  SimpleStructuredDocument,
-  SimplePage,
-  SimpleLine,
-  SimpleToken
+    SimpleStructuredDocument,
+    SimplePage,
+    SimpleLine,
+    SimpleToken
 )
 
 from sciencebeam_gym.models.text.feature_extractor import (
-  structured_document_to_token_props,
-  token_props_list_to_features,
-  token_props_list_to_labels,
-  remove_labels_from_token_props_list,
-  merge_with_cv_structured_document,
-  NONE_TAG,
-  CV_TAG_SCOPE
+    structured_document_to_token_props,
+    token_props_list_to_features,
+    token_props_list_to_labels,
+    remove_labels_from_token_props_list,
+    merge_with_cv_structured_document,
+    NONE_TAG,
+    CV_TAG_SCOPE
 )
 
 PAGE_BOUNDING_BOX = BoundingBox(0, 0, 100, 200)
@@ -32,287 +32,294 @@ TAG_3 = 'tag3'
 
 SCOPE_1 = 'scope1'
 
+
 class TestStructuredDocumentToTokenProps(object):
-  def test_should_return_empty_token_list_if_document_has_no_pages(self):
-    structured_document = SimpleStructuredDocument([])
-    assert list(structured_document_to_token_props(
-      structured_document
-    )) == []
-
-  def test_should_return_empty_token_list_if_document_has_no_lines(self):
-    structured_document = SimpleStructuredDocument(
-      SimplePage([], bounding_box=PAGE_BOUNDING_BOX)
-    )
-    assert list(structured_document_to_token_props(
-      structured_document
-    )) == []
-
-  def test_should_return_single_token_text(self):
-    structured_document = SimpleStructuredDocument(
-      SimplePage([SimpleLine([
-        SimpleToken(TEXT_1)
-      ])], bounding_box=PAGE_BOUNDING_BOX)
-    )
-    result = list(structured_document_to_token_props(
-      structured_document
-    ))
-    assert [t.get('text') for t in result] == [TEXT_1]
-
-  def test_should_return_multiple_token_texts(self):
-    structured_document = SimpleStructuredDocument(
-      SimplePage([SimpleLine([
-        SimpleToken(TEXT_1),
-        SimpleToken(TEXT_2),
-        SimpleToken(TEXT_3)
-      ])], bounding_box=PAGE_BOUNDING_BOX)
-    )
-    result = list(structured_document_to_token_props(
-      structured_document
-    ))
-    assert [t.get('text') for t in result] == [TEXT_1, TEXT_2, TEXT_3]
-
-  def test_should_return_tag(self):
-    structured_document = SimpleStructuredDocument([
-      SimplePage([SimpleLine([
-        SimpleToken(TEXT_1, tag=TAG_1),
-        SimpleToken(TEXT_2)
-      ])], bounding_box=PAGE_BOUNDING_BOX),
-      SimplePage([SimpleLine([
-        SimpleToken(TEXT_3, tag=TAG_3)
-      ])], bounding_box=PAGE_BOUNDING_BOX)
-    ])
-    result = list(structured_document_to_token_props(
-      structured_document
-    ))
-    assert [t.get('tag') for t in result] == [TAG_1, None, TAG_3]
-
-  def test_should_return_scoped_tags(self):
-    structured_document = SimpleStructuredDocument([
-      SimplePage([SimpleLine([
-        SimpleToken(TEXT_1, tag=TAG_1),
-        SimpleToken(TEXT_2)
-      ])], bounding_box=PAGE_BOUNDING_BOX),
-      SimplePage([SimpleLine([
-        SimpleToken(TEXT_3, tag=TAG_3, tag_scope=SCOPE_1)
-      ])], bounding_box=PAGE_BOUNDING_BOX)
-    ])
-    result = list(structured_document_to_token_props(
-      structured_document
-    ))
-    assert [t.get('scoped_tags') for t in result] == [{}, {}, {SCOPE_1: TAG_3}]
-
-  def test_should_return_bounding_box(self):
-    structured_document = SimpleStructuredDocument([
-      SimplePage([SimpleLine([
-        SimpleToken(TEXT_1, bounding_box=TOKEN_BOUNDING_BOX)
-      ])], bounding_box=PAGE_BOUNDING_BOX)
-    ])
-    result = list(structured_document_to_token_props(
-      structured_document
-    ))
-    assert [t.get('bounding_box') for t in result] == [TOKEN_BOUNDING_BOX]
-
-  def test_should_return_rel_bounding_box(self):
-    structured_document = SimpleStructuredDocument([
-      SimplePage([SimpleLine([
-        SimpleToken(TEXT_1, bounding_box=TOKEN_BOUNDING_BOX)
-      ])], bounding_box=PAGE_BOUNDING_BOX)
-    ])
-    result = list(structured_document_to_token_props(
-      structured_document
-    ))
-    assert [t.get('rel_bounding_box') for t in result] == [
-      TOKEN_BOUNDING_BOX.scale_by(
-        1.0 / PAGE_BOUNDING_BOX.width,
-        1.0 / PAGE_BOUNDING_BOX.height
-      )
-    ]
-
-  def test_should_return_page_index_and_page_count(self):
-    structured_document = SimpleStructuredDocument([
-      SimplePage([SimpleLine([
-        SimpleToken(TEXT_1),
-        SimpleToken(TEXT_2)
-      ])], bounding_box=PAGE_BOUNDING_BOX),
-      SimplePage([SimpleLine([
-        SimpleToken(TEXT_3)
-      ])], bounding_box=PAGE_BOUNDING_BOX)
-    ])
-    result = list(structured_document_to_token_props(
-      structured_document
-    ))
-    pages = [t.get('page') for t in result]
-    assert [p.get('index') for p in pages] == [0, 0, 1]
-    assert [p.get('count') for p in pages] == [2, 2, 2]
-
-  def test_should_return_page_width_and_height(self):
-    structured_document = SimpleStructuredDocument([
-      SimplePage([SimpleLine([
-        SimpleToken(TEXT_1)
-      ])], bounding_box=PAGE_BOUNDING_BOX)
-    ])
-    result = list(structured_document_to_token_props(
-      structured_document
-    ))
-    pages = [t.get('page') for t in result]
-    assert [p.get('width') for p in pages] == [PAGE_BOUNDING_BOX.width]
-    assert [p.get('height') for p in pages] == [PAGE_BOUNDING_BOX.height]
-
-  def test_should_return_line_index_and_page_count(self):
-    structured_document = SimpleStructuredDocument([
-      SimplePage([SimpleLine([
-        SimpleToken(TEXT_1)
-      ]), SimpleLine([
-        SimpleToken(TEXT_2)
-      ])], bounding_box=PAGE_BOUNDING_BOX),
-      SimplePage([SimpleLine([
-        SimpleToken(TEXT_3)
-      ])], bounding_box=PAGE_BOUNDING_BOX)
-    ])
-    result = list(structured_document_to_token_props(
-      structured_document
-    ))
-    lines = [t.get('line') for t in result]
-    assert [l.get('index') for l in lines] == [0, 1, 0]
-    assert [l.get('count') for l in lines] == [2, 2, 1]
-
-  def test_should_return_line_token_index_and_page_count(self):
-    structured_document = SimpleStructuredDocument([
-      SimplePage([SimpleLine([
-        SimpleToken(TEXT_1),
-        SimpleToken(TEXT_2)
-      ])], bounding_box=PAGE_BOUNDING_BOX),
-      SimplePage([SimpleLine([
-        SimpleToken(TEXT_3)
-      ])], bounding_box=PAGE_BOUNDING_BOX)
-    ])
-    result = list(structured_document_to_token_props(
-      structured_document
-    ))
-    line_tokens = [t.get('line_token') for t in result]
-    assert [t.get('index') for t in line_tokens] == [0, 1, 0]
-    assert [t.get('count') for t in line_tokens] == [2, 2, 1]
+    def test_should_return_empty_token_list_if_document_has_no_pages(self):
+        structured_document = SimpleStructuredDocument([])
+        assert list(structured_document_to_token_props(
+            structured_document
+        )) == []
+
+    def test_should_return_empty_token_list_if_document_has_no_lines(self):
+        structured_document = SimpleStructuredDocument(
+            SimplePage([], bounding_box=PAGE_BOUNDING_BOX)
+        )
+        assert list(structured_document_to_token_props(
+            structured_document
+        )) == []
+
+    def test_should_return_single_token_text(self):
+        structured_document = SimpleStructuredDocument(
+            SimplePage([SimpleLine([
+                SimpleToken(TEXT_1)
+            ])], bounding_box=PAGE_BOUNDING_BOX)
+        )
+        result = list(structured_document_to_token_props(
+            structured_document
+        ))
+        assert [t.get('text') for t in result] == [TEXT_1]
+
+    def test_should_return_multiple_token_texts(self):
+        structured_document = SimpleStructuredDocument(
+            SimplePage([SimpleLine([
+                SimpleToken(TEXT_1),
+                SimpleToken(TEXT_2),
+                SimpleToken(TEXT_3)
+            ])], bounding_box=PAGE_BOUNDING_BOX)
+        )
+        result = list(structured_document_to_token_props(
+            structured_document
+        ))
+        assert [t.get('text') for t in result] == [TEXT_1, TEXT_2, TEXT_3]
+
+    def test_should_return_tag(self):
+        structured_document = SimpleStructuredDocument([
+            SimplePage([SimpleLine([
+                SimpleToken(TEXT_1, tag=TAG_1),
+                SimpleToken(TEXT_2)
+            ])], bounding_box=PAGE_BOUNDING_BOX),
+            SimplePage([SimpleLine([
+                SimpleToken(TEXT_3, tag=TAG_3)
+            ])], bounding_box=PAGE_BOUNDING_BOX)
+        ])
+        result = list(structured_document_to_token_props(
+            structured_document
+        ))
+        assert [t.get('tag') for t in result] == [TAG_1, None, TAG_3]
+
+    def test_should_return_scoped_tags(self):
+        structured_document = SimpleStructuredDocument([
+            SimplePage([SimpleLine([
+                SimpleToken(TEXT_1, tag=TAG_1),
+                SimpleToken(TEXT_2)
+            ])], bounding_box=PAGE_BOUNDING_BOX),
+            SimplePage([SimpleLine([
+                SimpleToken(TEXT_3, tag=TAG_3, tag_scope=SCOPE_1)
+            ])], bounding_box=PAGE_BOUNDING_BOX)
+        ])
+        result = list(structured_document_to_token_props(
+            structured_document
+        ))
+        assert [t.get('scoped_tags') for t in result] == [{}, {}, {SCOPE_1: TAG_3}]
+
+    def test_should_return_bounding_box(self):
+        structured_document = SimpleStructuredDocument([
+            SimplePage([SimpleLine([
+                SimpleToken(TEXT_1, bounding_box=TOKEN_BOUNDING_BOX)
+            ])], bounding_box=PAGE_BOUNDING_BOX)
+        ])
+        result = list(structured_document_to_token_props(
+            structured_document
+        ))
+        assert [t.get('bounding_box') for t in result] == [TOKEN_BOUNDING_BOX]
+
+    def test_should_return_rel_bounding_box(self):
+        structured_document = SimpleStructuredDocument([
+            SimplePage([SimpleLine([
+                SimpleToken(TEXT_1, bounding_box=TOKEN_BOUNDING_BOX)
+            ])], bounding_box=PAGE_BOUNDING_BOX)
+        ])
+        result = list(structured_document_to_token_props(
+            structured_document
+        ))
+        assert [t.get('rel_bounding_box') for t in result] == [
+            TOKEN_BOUNDING_BOX.scale_by(
+                1.0 / PAGE_BOUNDING_BOX.width,
+                1.0 / PAGE_BOUNDING_BOX.height
+            )
+        ]
+
+    def test_should_return_page_index_and_page_count(self):
+        structured_document = SimpleStructuredDocument([
+            SimplePage([SimpleLine([
+                SimpleToken(TEXT_1),
+                SimpleToken(TEXT_2)
+            ])], bounding_box=PAGE_BOUNDING_BOX),
+            SimplePage([SimpleLine([
+                SimpleToken(TEXT_3)
+            ])], bounding_box=PAGE_BOUNDING_BOX)
+        ])
+        result = list(structured_document_to_token_props(
+            structured_document
+        ))
+        pages = [t.get('page') for t in result]
+        assert [p.get('index') for p in pages] == [0, 0, 1]
+        assert [p.get('count') for p in pages] == [2, 2, 2]
+
+    def test_should_return_page_width_and_height(self):
+        structured_document = SimpleStructuredDocument([
+            SimplePage([SimpleLine([
+                SimpleToken(TEXT_1)
+            ])], bounding_box=PAGE_BOUNDING_BOX)
+        ])
+        result = list(structured_document_to_token_props(
+            structured_document
+        ))
+        pages = [t.get('page') for t in result]
+        assert [p.get('width') for p in pages] == [PAGE_BOUNDING_BOX.width]
+        assert [p.get('height') for p in pages] == [PAGE_BOUNDING_BOX.height]
+
+    def test_should_return_line_index_and_page_count(self):
+        structured_document = SimpleStructuredDocument([
+            SimplePage([SimpleLine([
+                SimpleToken(TEXT_1)
+            ]), SimpleLine([
+                SimpleToken(TEXT_2)
+            ])], bounding_box=PAGE_BOUNDING_BOX),
+            SimplePage([SimpleLine([
+                SimpleToken(TEXT_3)
+            ])], bounding_box=PAGE_BOUNDING_BOX)
+        ])
+        result = list(structured_document_to_token_props(
+            structured_document
+        ))
+        lines = [t.get('line') for t in result]
+        assert [l.get('index') for l in lines] == [0, 1, 0]
+        assert [l.get('count') for l in lines] == [2, 2, 1]
+
+    def test_should_return_line_token_index_and_page_count(self):
+        structured_document = SimpleStructuredDocument([
+            SimplePage([SimpleLine([
+                SimpleToken(TEXT_1),
+                SimpleToken(TEXT_2)
+            ])], bounding_box=PAGE_BOUNDING_BOX),
+            SimplePage([SimpleLine([
+                SimpleToken(TEXT_3)
+            ])], bounding_box=PAGE_BOUNDING_BOX)
+        ])
+        result = list(structured_document_to_token_props(
+            structured_document
+        ))
+        line_tokens = [t.get('line_token') for t in result]
+        assert [t.get('index') for t in line_tokens] == [0, 1, 0]
+        assert [t.get('count') for t in line_tokens] == [2, 2, 1]
+
 
 def create_token_props(text, **kwargs):
-  d = {
-    'text': text
-  }
-  d.update(kwargs)
-  return d
+    d = {
+        'text': text
+    }
+    d.update(kwargs)
+    return d
+
 
 class TestTokenPropsListToFeatures(object):
-  def test_should_extract_various_word_features(self):
-    result = token_props_list_to_features([
-      create_token_props('TestMe')
-    ])
-    assert [x.get('word.lower') for x in result] == ['testme']
-    assert [x.get('word[:1]') for x in result] == ['t']
-    assert [x.get('word[-3:]') for x in result] == ['tme']
-    assert [x.get('word[-2:]') for x in result] == ['me']
-    assert [x.get('word[:1].isupper') for x in result] == [True]
-    assert [x.get('word.isupper') for x in result] == [False]
-    assert [x.get('word.isdigit') for x in result] == [False]
-
-  def test_should_extract_scoped_tags(self):
-    token_props = create_token_props(TEXT_1)
-    token_props['scoped_tags'] = {
-      SCOPE_1: TAG_1
-    }
-    result = token_props_list_to_features([token_props])
-    assert [x.get('%s.tag' % SCOPE_1) for x in result] == [TAG_1]
-
-  def test_should_add_previous_and_next_token_word_features(self):
-    result = token_props_list_to_features([
-      create_token_props(TEXT_1),
-      create_token_props(TEXT_2),
-      create_token_props(TEXT_3)
-    ])
-    assert [x.get('word.lower') for x in result] == [
-      TEXT_1.lower(), TEXT_2.lower(), TEXT_3.lower()
-    ]
-    assert [x.get('-2:word.lower') for x in result] == [
-      None, None, TEXT_1.lower()
-    ]
-    assert [x.get('-1:word.lower') for x in result] == [
-      None, TEXT_1.lower(), TEXT_2.lower()
-    ]
-    assert [x.get('1:word.lower') for x in result] == [
-      TEXT_2.lower(), TEXT_3.lower(), None
-    ]
-    assert [x.get('2:word.lower') for x in result] == [
-      TEXT_3.lower(), None, None
-    ]
-    assert [x.get('BOD[-2]') for x in result] == [
-      True, True, None
-    ]
-    assert [x.get('BOD[-1]') for x in result] == [
-      True, None, None
-    ]
-    assert [x.get('EOD[1]') for x in result] == [
-      None, None, True
-    ]
-    assert [x.get('EOD[2]') for x in result] == [
-      None, True, True
-    ]
-
-  def test_should_not_include_tag(self):
-    result = token_props_list_to_features([
-      create_token_props(TEXT_1, tag=TAG_1)
-    ])
-    assert [x.get('tag') for x in result] == [None]
+    def test_should_extract_various_word_features(self):
+        result = token_props_list_to_features([
+            create_token_props('TestMe')
+        ])
+        assert [x.get('word.lower') for x in result] == ['testme']
+        assert [x.get('word[:1]') for x in result] == ['t']
+        assert [x.get('word[-3:]') for x in result] == ['tme']
+        assert [x.get('word[-2:]') for x in result] == ['me']
+        assert [x.get('word[:1].isupper') for x in result] == [True]
+        assert [x.get('word.isupper') for x in result] == [False]
+        assert [x.get('word.isdigit') for x in result] == [False]
+
+    def test_should_extract_scoped_tags(self):
+        token_props = create_token_props(TEXT_1)
+        token_props['scoped_tags'] = {
+            SCOPE_1: TAG_1
+        }
+        result = token_props_list_to_features([token_props])
+        assert [x.get('%s.tag' % SCOPE_1) for x in result] == [TAG_1]
+
+    def test_should_add_previous_and_next_token_word_features(self):
+        result = token_props_list_to_features([
+            create_token_props(TEXT_1),
+            create_token_props(TEXT_2),
+            create_token_props(TEXT_3)
+        ])
+        assert [x.get('word.lower') for x in result] == [
+            TEXT_1.lower(), TEXT_2.lower(), TEXT_3.lower()
+        ]
+        assert [x.get('-2:word.lower') for x in result] == [
+            None, None, TEXT_1.lower()
+        ]
+        assert [x.get('-1:word.lower') for x in result] == [
+            None, TEXT_1.lower(), TEXT_2.lower()
+        ]
+        assert [x.get('1:word.lower') for x in result] == [
+            TEXT_2.lower(), TEXT_3.lower(), None
+        ]
+        assert [x.get('2:word.lower') for x in result] == [
+            TEXT_3.lower(), None, None
+        ]
+        assert [x.get('BOD[-2]') for x in result] == [
+            True, True, None
+        ]
+        assert [x.get('BOD[-1]') for x in result] == [
+            True, None, None
+        ]
+        assert [x.get('EOD[1]') for x in result] == [
+            None, None, True
+        ]
+        assert [x.get('EOD[2]') for x in result] == [
+            None, True, True
+        ]
+
+    def test_should_not_include_tag(self):
+        result = token_props_list_to_features([
+            create_token_props(TEXT_1, tag=TAG_1)
+        ])
+        assert [x.get('tag') for x in result] == [None]
+
 
 class TestTokenPropsListToLabels(object):
-  def test_should_extract_tag(self):
-    assert token_props_list_to_labels([
-      create_token_props(TEXT_1, tag=TAG_1),
-      create_token_props(TEXT_2, tag=TAG_2)
-    ]) == [TAG_1, TAG_2]
-
-  def test_should_replace_none_tag(self):
-    assert token_props_list_to_labels([
-      create_token_props(TEXT_1, tag=TAG_1),
-      create_token_props(TEXT_2, tag=None)
-    ]) == [TAG_1, NONE_TAG]
+    def test_should_extract_tag(self):
+        assert token_props_list_to_labels([
+            create_token_props(TEXT_1, tag=TAG_1),
+            create_token_props(TEXT_2, tag=TAG_2)
+        ]) == [TAG_1, TAG_2]
+
+    def test_should_replace_none_tag(self):
+        assert token_props_list_to_labels([
+            create_token_props(TEXT_1, tag=TAG_1),
+            create_token_props(TEXT_2, tag=None)
+        ]) == [TAG_1, NONE_TAG]
+
 
 class TestRemoveLabelsFromTokenPropsList(object):
-  def test_should_remove_tag(self):
-    token_props_list = [
-      create_token_props(TEXT_1, tag=TAG_1),
-      create_token_props(TEXT_2, tag=TAG_2)
-    ]
-    updated_token_props_list = remove_labels_from_token_props_list(token_props_list)
-    assert [x.get('tag') for x in token_props_list] == [TAG_1, TAG_2]
-    assert [x.get('tag') for x in updated_token_props_list] == [None, None]
-    assert [x.get('text') for x in updated_token_props_list] == [TEXT_1, TEXT_2]
+    def test_should_remove_tag(self):
+        token_props_list = [
+            create_token_props(TEXT_1, tag=TAG_1),
+            create_token_props(TEXT_2, tag=TAG_2)
+        ]
+        updated_token_props_list = remove_labels_from_token_props_list(token_props_list)
+        assert [x.get('tag') for x in token_props_list] == [TAG_1, TAG_2]
+        assert [x.get('tag') for x in updated_token_props_list] == [None, None]
+        assert [x.get('text') for x in updated_token_props_list] == [TEXT_1, TEXT_2]
+
 
 def get_all_token_tags(structured_document, scope=None):
-  return [x.get_tag(scope=scope) for x in structured_document.iter_all_tokens()]
+    return [x.get_tag(scope=scope) for x in structured_document.iter_all_tokens()]
+
 
 class TestMergeWithCvStructuredDocument(object):
-  def test_should_merge_from_default_tag_scope(self):
-    structured_document = SimpleStructuredDocument(lines=[SimpleLine([
-      SimpleToken(TEXT_1, tag_scope=None, tag=TAG_1)
-    ])])
-    cv_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
-      SimpleToken(TEXT_1, tag_scope=None, tag=TAG_2)
-    ])])
-    structured_document = merge_with_cv_structured_document(
-      structured_document, cv_structured_document,
-      cv_source_tag_scope=None
-    )
-    assert get_all_token_tags(structured_document) == [TAG_1]
-    assert get_all_token_tags(structured_document, scope=CV_TAG_SCOPE) == [TAG_2]
-
-  def test_should_merge_from_cv_tag_scope(self):
-    structured_document = SimpleStructuredDocument(lines=[SimpleLine([
-      SimpleToken(TEXT_1, tag_scope=None, tag=TAG_1)
-    ])])
-    cv_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
-      SimpleToken(TEXT_1, tag_scope=CV_TAG_SCOPE, tag=TAG_2)
-    ])])
-    structured_document = merge_with_cv_structured_document(
-      structured_document, cv_structured_document,
-      cv_source_tag_scope=CV_TAG_SCOPE
-    )
-    assert get_all_token_tags(structured_document) == [TAG_1]
-    assert get_all_token_tags(structured_document, scope=CV_TAG_SCOPE) == [TAG_2]
+    def test_should_merge_from_default_tag_scope(self):
+        structured_document = SimpleStructuredDocument(lines=[SimpleLine([
+            SimpleToken(TEXT_1, tag_scope=None, tag=TAG_1)
+        ])])
+        cv_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
+            SimpleToken(TEXT_1, tag_scope=None, tag=TAG_2)
+        ])])
+        structured_document = merge_with_cv_structured_document(
+            structured_document, cv_structured_document,
+            cv_source_tag_scope=None
+        )
+        assert get_all_token_tags(structured_document) == [TAG_1]
+        assert get_all_token_tags(structured_document, scope=CV_TAG_SCOPE) == [TAG_2]
+
+    def test_should_merge_from_cv_tag_scope(self):
+        structured_document = SimpleStructuredDocument(lines=[SimpleLine([
+            SimpleToken(TEXT_1, tag_scope=None, tag=TAG_1)
+        ])])
+        cv_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
+            SimpleToken(TEXT_1, tag_scope=CV_TAG_SCOPE, tag=TAG_2)
+        ])])
+        structured_document = merge_with_cv_structured_document(
+            structured_document, cv_structured_document,
+            cv_source_tag_scope=CV_TAG_SCOPE
+        )
+        assert get_all_token_tags(structured_document) == [TAG_1]
+        assert get_all_token_tags(structured_document, scope=CV_TAG_SCOPE) == [TAG_2]
diff --git a/sciencebeam_gym/pdf/__init__.py b/sciencebeam_gym/pdf/__init__.py
index e6c3295..236de0d 100644
--- a/sciencebeam_gym/pdf/__init__.py
+++ b/sciencebeam_gym/pdf/__init__.py
@@ -1,2 +1,2 @@
-from sciencebeam_gym.pdf.pdf_to_lxml_wrapper import PdfToLxmlWrapper
-from sciencebeam_gym.pdf.pdf_to_png import PdfToPng
+from sciencebeam_gym.pdf.pdf_to_lxml_wrapper import PdfToLxmlWrapper  # flake8: noqa
+from sciencebeam_gym.pdf.pdf_to_png import PdfToPng  # flake8: noqa
diff --git a/sciencebeam_gym/pdf/pdf_to_lxml_wrapper.py b/sciencebeam_gym/pdf/pdf_to_lxml_wrapper.py
index 0c6fadb..ceae411 100644
--- a/sciencebeam_gym/pdf/pdf_to_lxml_wrapper.py
+++ b/sciencebeam_gym/pdf/pdf_to_lxml_wrapper.py
@@ -10,137 +10,144 @@ from tempfile import NamedTemporaryFile
 from sciencebeam_utils.utils.io import makedirs
 from sciencebeam_utils.utils.zip import extract_all_with_executable_permission
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def iter_read_lines(reader):
-  while True:
-    line = reader.readline()
-    if not line:
-      break
-    yield line
+    while True:
+        line = reader.readline()
+        if not line:
+            break
+        yield line
+
 
 def stream_lines_to_logger(lines, logger, prefix=''):
-  for line in lines:
-    line = line.strip()
-    if line:
-      logger.info('%s%s', prefix, line)
+    for line in lines:
+        line = line.strip()
+        if line:
+            logger.info('%s%s', prefix, line)
+
 
 def download_if_not_exist(url, target_file):
-  if not os.path.isfile(target_file):
-    get_logger().info('downloading %s to %s', url, target_file)
+    if not os.path.isfile(target_file):
+        get_logger().info('downloading %s to %s', url, target_file)
 
-    makedirs(os.path.dirname(target_file), exists_ok=True)
+        makedirs(os.path.dirname(target_file), exists_ok=True)
 
-    temp_filename = target_file + '.part'
-    if os.path.isfile(temp_filename):
-      os.remove(temp_filename)
-    URLopener().retrieve(url, temp_filename)
-    os.rename(temp_filename, target_file)
-  return target_file
+        temp_filename = target_file + '.part'
+        if os.path.isfile(temp_filename):
+            os.remove(temp_filename)
+        URLopener().retrieve(url, temp_filename)
+        os.rename(temp_filename, target_file)
+    return target_file
 
-def unzip_if_not_exist(zip_filename, target_directory, ignore_subdirectory=None):
-  if ignore_subdirectory is None:
-    ignore_subdirectory = os.path.basename(zip_filename)
-  if not os.path.isdir(target_directory):
-    get_logger().info('unzipping %s to %s', zip_filename, target_directory)
-    temp_target_directory = target_directory + '.part'
-    if os.path.isdir(temp_target_directory):
-      rmtree(temp_target_directory)
-
-    with ZipFile(zip_filename, 'r') as zf:
-      extract_all_with_executable_permission(zf, temp_target_directory)
-      # ignore first level in the directory structure, if applicable
-      sub_dir = (
-        os.path.join(temp_target_directory, ignore_subdirectory)
-        if ignore_subdirectory
-        else target_directory
-      )
-      if os.path.isdir(sub_dir):
-        os.rename(sub_dir, target_directory)
-        rmtree(temp_target_directory)
-      else:
-        os.rename(temp_target_directory, target_directory)
-  return target_directory
 
-class PdfToLxmlWrapper(object):
-  def __init__(self):
-    temp_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../.temp'))
-    self.target_directory = os.path.join(temp_dir, 'pdf2xml')
-    self.zip_filename = os.path.join(temp_dir, 'pdf2xml.zip')
-    self.zip_url = (
-      'https://storage.googleapis.com/elife-ml/artefacts/pdf2xml-linux-64.zip'
-    )
+def unzip_if_not_exist(zip_filename, target_directory, ignore_subdirectory=None):
+    if ignore_subdirectory is None:
+        ignore_subdirectory = os.path.basename(zip_filename)
+    if not os.path.isdir(target_directory):
+        get_logger().info('unzipping %s to %s', zip_filename, target_directory)
+        temp_target_directory = target_directory + '.part'
+        if os.path.isdir(temp_target_directory):
+            rmtree(temp_target_directory)
+
+        with ZipFile(zip_filename, 'r') as zf:
+            extract_all_with_executable_permission(zf, temp_target_directory)
+            # ignore first level in the directory structure, if applicable
+            sub_dir = (
+                os.path.join(temp_target_directory, ignore_subdirectory)
+                if ignore_subdirectory
+                else target_directory
+            )
+            if os.path.isdir(sub_dir):
+                os.rename(sub_dir, target_directory)
+                rmtree(temp_target_directory)
+            else:
+                os.rename(temp_target_directory, target_directory)
+    return target_directory
 
-  def download_pdf2xml_zip_if_not_exist(self):
-    download_if_not_exist(
-      self.zip_url,
-      self.zip_filename
-    )
 
-  def unzip_pdf2xml_zip_if_target_directory_does_not_exist(self):
-    if not os.path.isdir(self.target_directory):
-      self.download_pdf2xml_zip_if_not_exist()
-      unzip_if_not_exist(self.zip_filename, self.target_directory)
-
-  def get_pdf2xml_executable_path(self):
-    self.unzip_pdf2xml_zip_if_target_directory_does_not_exist()
-    # use pdftoxml_server as it already handles timeouts
-    return os.path.join(
-      self.target_directory,
-      'lin-64/pdftoxml'
-    )
+class PdfToLxmlWrapper(object):
+    def __init__(self):
+        temp_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../.temp'))
+        self.target_directory = os.path.join(temp_dir, 'pdf2xml')
+        self.zip_filename = os.path.join(temp_dir, 'pdf2xml.zip')
+        self.zip_url = (
+            'https://storage.googleapis.com/elife-ml/artefacts/pdf2xml-linux-64.zip'
+        )
+
+    def download_pdf2xml_zip_if_not_exist(self):
+        download_if_not_exist(
+            self.zip_url,
+            self.zip_filename
+        )
+
+    def unzip_pdf2xml_zip_if_target_directory_does_not_exist(self):
+        if not os.path.isdir(self.target_directory):
+            self.download_pdf2xml_zip_if_not_exist()
+            unzip_if_not_exist(self.zip_filename, self.target_directory)
+
+    def get_pdf2xml_executable_path(self):
+        self.unzip_pdf2xml_zip_if_target_directory_does_not_exist()
+        # use pdftoxml_server as it already handles timeouts
+        return os.path.join(
+            self.target_directory,
+            'lin-64/pdftoxml'
+        )
+
+    def process_input(self, source_data, args):
+        with NamedTemporaryFile() as f:
+            f.write(source_data)
+            f.flush()
+            os.fsync(f)
+            return self.process_file(f.name, args)
+
+    def process_file(self, source_filename, args):
+        pdf2xml = self.get_pdf2xml_executable_path()
+        get_logger().info('processing %s using %s', source_filename, pdf2xml)
+        cmd = [pdf2xml] + args + [source_filename, '-']
+        p = subprocess.Popen(
+            ['timeout', '20s'] + cmd,
+            stdout=PIPE,
+            stderr=PIPE,
+            stdin=None
+        )
+        out, err = p.communicate()
+        return_code = p.returncode
+        if return_code != 0:
+            get_logger().warning(
+                'process failed with %d, stderr=%s, stdout=%.200s...',
+                return_code, err, out
+            )
+            raise RuntimeError('process failed with %d, stderr: %s' % (return_code, err))
+        if len(out) == 0:
+            get_logger().warning(
+                'process returned empty response (code %d), stderr=%s, stdout=%.500s...',
+                return_code, err, out
+            )
+            raise RuntimeError(
+                'process returned empty response (code %d), stderr: %s' % (return_code, err)
+            )
+        get_logger().info(
+            'received response for %s (%s bytes)',
+            source_filename, format(len(out), ',')
+        )
+        return out
 
-  def process_input(self, source_data, args):
-    with NamedTemporaryFile() as f:
-      f.write(source_data)
-      f.flush()
-      os.fsync(f)
-      return self.process_file(f.name, args)
-
-  def process_file(self, source_filename, args):
-    pdf2xml = self.get_pdf2xml_executable_path()
-    get_logger().info('processing %s using %s', source_filename, pdf2xml)
-    cmd = [pdf2xml] + args + [source_filename, '-']
-    p = subprocess.Popen(
-      ['timeout', '20s'] + cmd,
-      stdout=PIPE,
-      stderr=PIPE,
-      stdin=None
-    )
-    out, err = p.communicate()
-    return_code = p.returncode
-    if return_code != 0:
-      get_logger().warning(
-        'process failed with %d, stderr=%s, stdout=%.200s...',
-        return_code, err, out
-      )
-      raise RuntimeError('process failed with %d, stderr: %s' % (return_code, err))
-    if len(out) == 0:
-      get_logger().warning(
-        'process returned empty response (code %d), stderr=%s, stdout=%.500s...',
-        return_code, err, out
-      )
-      raise RuntimeError(
-        'process returned empty response (code %d), stderr: %s' % (return_code, err)
-      )
-    get_logger().info(
-      'received response for %s (%s bytes)',
-      source_filename, format(len(out), ',')
-    )
-    return out
 
 if __name__ == '__main__':
-  logging.basicConfig(level='INFO')
-
-  sample_pdf_url = 'https://rawgit.com/elifesciences/XML-mapping/master/elife-00666.pdf'
-  sample_pdf_filename = '.temp/elife-00666.pdf'
-  download_if_not_exist(sample_pdf_url, sample_pdf_filename)
-  with open(sample_pdf_filename, 'rb') as sample_f:
-    sample_pdf_contents = sample_f.read()
-  get_logger().info('pdf size: %s bytes', format(len(sample_pdf_contents), ','))
-  process_out = PdfToLxmlWrapper().process_input(
-    sample_pdf_contents,
-    '-blocks -noImageInline -noImage -fullFontName'.split()
-  )
-  get_logger().info('out: %.1000s...', process_out)
+    logging.basicConfig(level='INFO')
+
+    sample_pdf_url = 'https://rawgit.com/elifesciences/XML-mapping/master/elife-00666.pdf'
+    sample_pdf_filename = '.temp/elife-00666.pdf'
+    download_if_not_exist(sample_pdf_url, sample_pdf_filename)
+    with open(sample_pdf_filename, 'rb') as sample_f:
+        sample_pdf_contents = sample_f.read()
+    get_logger().info('pdf size: %s bytes', format(len(sample_pdf_contents), ','))
+    process_out = PdfToLxmlWrapper().process_input(
+        sample_pdf_contents,
+        '-blocks -noImageInline -noImage -fullFontName'.split()
+    )
+    get_logger().info('out: %.1000s...', process_out)
diff --git a/sciencebeam_gym/pdf/pdf_to_png.py b/sciencebeam_gym/pdf/pdf_to_png.py
index 0bc816a..1aec457 100644
--- a/sciencebeam_gym/pdf/pdf_to_png.py
+++ b/sciencebeam_gym/pdf/pdf_to_png.py
@@ -4,59 +4,65 @@ from subprocess import Popen, PIPE
 
 from backports.tempfile import TemporaryDirectory
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 class PdfToPng(object):
-  def __init__(self, dpi=None, image_size=None, page_range=None):
-    self.dpi = dpi
-    self.image_size = image_size
-    self.page_range = page_range
-
-  def iter_pdf_bytes_to_png_fp(self, pdf_bytes):
-    cmd = ['pdftoppm', '-png']
-    if self.page_range:
-      cmd += ['-f', str(self.page_range[0]), '-l', str(self.page_range[1])]
-    if self.image_size:
-      cmd += ['-scale-to-x', str(self.image_size[0]), '-scale-to-y', str(self.image_size[1])]
-    elif self.dpi:
-      cmd += ['-r', str(self.dpi)]
-    cmd += ['-']
-    with TemporaryDirectory() as path:
-      cmd += [os.path.join(path, 'page')]
-
-      p = Popen(cmd, stdout=PIPE, stdin=PIPE, stderr=PIPE)
-      try:
-        p.stdin.write(pdf_bytes)
-      except IOError:
-        # we'll check the returncode
-        pass
-
-      out, err = p.communicate()
-      if p.returncode != 0:
-        get_logger().debug(
-          'process failed with return code %d: cmd=%s, out=%s, err=%s',
-          p.returncode, cmd, out, err
-        )
-        raise IOError(
-          'process failed with return code %d, cmd=%s, err=%s' %
-          (p.returncode, cmd, err)
-        )
-
-      for filename in sorted(os.listdir(path)):
-        with open(os.path.join(path, filename), 'rb') as f:
-          yield f
+    def __init__(self, dpi=None, image_size=None, page_range=None):
+        self.dpi = dpi
+        self.image_size = image_size
+        self.page_range = page_range
+
+    def iter_pdf_bytes_to_png_fp(self, pdf_bytes):
+        cmd = ['pdftoppm', '-png']
+        if self.page_range:
+            cmd += ['-f', str(self.page_range[0]), '-l', str(self.page_range[1])]
+        if self.image_size:
+            cmd += ['-scale-to-x', str(self.image_size[0]), '-scale-to-y', str(self.image_size[1])]
+        elif self.dpi:
+            cmd += ['-r', str(self.dpi)]
+        cmd += ['-']
+        with TemporaryDirectory() as path:
+            cmd += [os.path.join(path, 'page')]
+
+            p = Popen(cmd, stdout=PIPE, stdin=PIPE, stderr=PIPE)
+            try:
+                p.stdin.write(pdf_bytes)
+            except IOError:
+                # we'll check the returncode
+                pass
+
+            out, err = p.communicate()
+            if p.returncode != 0:
+                get_logger().debug(
+                    'process failed with return code %d: cmd=%s, out=%s, err=%s',
+                    p.returncode, cmd, out, err
+                )
+                raise IOError(
+                    'process failed with return code %d, cmd=%s, err=%s' %
+                    (p.returncode, cmd, err)
+                )
+
+            for filename in sorted(os.listdir(path)):
+                with open(os.path.join(path, filename), 'rb') as f:
+                    yield f
+
 
 if __name__ == '__main__':
-  from sciencebeam_gym.pdf.pdf_to_lxml_wrapper import download_if_not_exist
-
-  logging.basicConfig(level='INFO')
-
-  sample_pdf_url = 'https://rawgit.com/elifesciences/XML-mapping/master/elife-00666.pdf'
-  sample_pdf_filename = '.temp/elife-00666.pdf'
-  download_if_not_exist(sample_pdf_url, sample_pdf_filename)
-  with open(sample_pdf_filename, 'rb') as sample_f:
-    sample_pdf_contents = sample_f.read()
-  get_logger().info('pdf size: %s bytes', format(len(sample_pdf_contents), ','))
-  png_bytes = [f.read() for f in PdfToPng(dpi=30).iter_pdf_bytes_to_png_fp(sample_pdf_contents)]
-  get_logger().info('read: total %d (%d files)', sum(len(x) for x in png_bytes), len(png_bytes))
+    from sciencebeam_gym.pdf.pdf_to_lxml_wrapper import download_if_not_exist
+
+    logging.basicConfig(level='INFO')
+
+    sample_pdf_url = 'https://rawgit.com/elifesciences/XML-mapping/master/elife-00666.pdf'
+    sample_pdf_filename = '.temp/elife-00666.pdf'
+    download_if_not_exist(sample_pdf_url, sample_pdf_filename)
+    with open(sample_pdf_filename, 'rb') as sample_f:
+        sample_pdf_contents = sample_f.read()
+    get_logger().info('pdf size: %s bytes', format(len(sample_pdf_contents), ','))
+    png_bytes = [
+        _f.read()
+        for _f in PdfToPng(dpi=30).iter_pdf_bytes_to_png_fp(sample_pdf_contents)
+    ]
+    get_logger().info('read: total %d (%d files)', sum(len(x) for x in png_bytes), len(png_bytes))
diff --git a/sciencebeam_gym/pdf/pdf_to_png_test.py b/sciencebeam_gym/pdf/pdf_to_png_test.py
index 4a19024..a84ad07 100644
--- a/sciencebeam_gym/pdf/pdf_to_png_test.py
+++ b/sciencebeam_gym/pdf/pdf_to_png_test.py
@@ -1,10 +1,9 @@
-import logging
 from subprocess import PIPE
 from contextlib import contextmanager
 from mock import patch
 
 from sciencebeam_gym.pdf.pdf_to_png import (
-  PdfToPng
+    PdfToPng
 )
 
 import sciencebeam_gym.pdf.pdf_to_png as pdf_to_png
@@ -17,52 +16,56 @@ ARGS_PREFIX = ['pdftoppm', '-png']
 ARGS_SUFFIX = ['-', TEMP_DIR + '/page']
 DEFAULT_KWARGS = dict(stdout=PIPE, stdin=PIPE, stderr=PIPE)
 
+
 @contextmanager
 def patch_popen():
-  with patch.object(pdf_to_png, 'Popen') as mock:
-    p = mock.return_value
-    p.communicate.return_value = (None, None)
-    p.returncode = 0
-    yield mock
+    with patch.object(pdf_to_png, 'Popen') as mock:
+        p = mock.return_value
+        p.communicate.return_value = (None, None)
+        p.returncode = 0
+        yield mock
+
 
 @contextmanager
 def mock_temp_dir():
-  with patch.object(pdf_to_png, 'TemporaryDirectory') as mock:
-    mock.return_value.__enter__.return_value = TEMP_DIR
-    with patch('os.listdir') as listdir:
-      listdir.return_value = []
-      yield mock
+    with patch.object(pdf_to_png, 'TemporaryDirectory') as mock:
+        mock.return_value.__enter__.return_value = TEMP_DIR
+        with patch('os.listdir') as listdir:
+            listdir.return_value = []
+            yield mock
+
 
 class TestPdfToPng(object):
-  def test_should_pass_default_args_to_Popen(self):
-    with patch_popen() as mock:
-      with mock_temp_dir():
-        list(PdfToPng().iter_pdf_bytes_to_png_fp(PDF_CONTENT_1))
-        assert mock.called
-        mock.assert_called_with(
-          ARGS_PREFIX + ARGS_SUFFIX, **DEFAULT_KWARGS
-        )
+    def test_should_pass_default_args_to_Popen(self):
+        with patch_popen() as mock:
+            with mock_temp_dir():
+                list(PdfToPng().iter_pdf_bytes_to_png_fp(PDF_CONTENT_1))
+                assert mock.called
+                mock.assert_called_with(
+                    ARGS_PREFIX + ARGS_SUFFIX, **DEFAULT_KWARGS
+                )
 
-  def test_should_add_page_range_to_args(self):
-    with patch_popen() as mock:
-      with mock_temp_dir():
-        list(PdfToPng(page_range=(1, 3)).iter_pdf_bytes_to_png_fp(PDF_CONTENT_1))
-        mock.assert_called_with(
-          ARGS_PREFIX + ['-f', '1', '-l', '3'] + ARGS_SUFFIX, **DEFAULT_KWARGS
-        )
+    def test_should_add_page_range_to_args(self):
+        with patch_popen() as mock:
+            with mock_temp_dir():
+                list(PdfToPng(page_range=(1, 3)).iter_pdf_bytes_to_png_fp(PDF_CONTENT_1))
+                mock.assert_called_with(
+                    ARGS_PREFIX + ['-f', '1', '-l', '3'] + ARGS_SUFFIX, **DEFAULT_KWARGS
+                )
 
-  def test_should_add_image_size_to_args(self):
-    with patch_popen() as mock:
-      with mock_temp_dir():
-        list(PdfToPng(image_size=(100, 200)).iter_pdf_bytes_to_png_fp(PDF_CONTENT_1))
-        mock.assert_called_with(
-          ARGS_PREFIX + ['-scale-to-x', '100', '-scale-to-y', '200'] + ARGS_SUFFIX, **DEFAULT_KWARGS
-        )
+    def test_should_add_image_size_to_args(self):
+        with patch_popen() as mock:
+            with mock_temp_dir():
+                list(PdfToPng(image_size=(100, 200)).iter_pdf_bytes_to_png_fp(PDF_CONTENT_1))
+                mock.assert_called_with(
+                    ARGS_PREFIX + ['-scale-to-x', '100', '-scale-to-y', '200'] + ARGS_SUFFIX,
+                    **DEFAULT_KWARGS
+                )
 
-  def test_should_add_dpi_to_args(self):
-    with patch_popen() as mock:
-      with mock_temp_dir():
-        list(PdfToPng(dpi=200).iter_pdf_bytes_to_png_fp(PDF_CONTENT_1))
-        mock.assert_called_with(
-          ARGS_PREFIX + ['-r', '200'] + ARGS_SUFFIX, **DEFAULT_KWARGS
-        )
+    def test_should_add_dpi_to_args(self):
+        with patch_popen() as mock:
+            with mock_temp_dir():
+                list(PdfToPng(dpi=200).iter_pdf_bytes_to_png_fp(PDF_CONTENT_1))
+                mock.assert_called_with(
+                    ARGS_PREFIX + ['-r', '200'] + ARGS_SUFFIX, **DEFAULT_KWARGS
+                )
diff --git a/sciencebeam_gym/preprocess/annotation/annotation_evaluation.py b/sciencebeam_gym/preprocess/annotation/annotation_evaluation.py
index 5ba16ce..909a2dd 100644
--- a/sciencebeam_gym/preprocess/annotation/annotation_evaluation.py
+++ b/sciencebeam_gym/preprocess/annotation/annotation_evaluation.py
@@ -5,54 +5,59 @@ from collections import Counter
 from six import iteritems
 
 from sciencebeam_utils.utils.collection import (
-  flatten
+    flatten
 )
 
+
 class EvaluationFields(object):
-  DOCUMENT = 'document'
-  PAGE = 'page'
-  TAG = 'tag'
-  COUNT = 'count'
+    DOCUMENT = 'document'
+    PAGE = 'page'
+    TAG = 'tag'
+    COUNT = 'count'
+
 
 DEFAULT_EVALUATION_COLUMNS = [
-  EvaluationFields.DOCUMENT,
-  EvaluationFields.PAGE,
-  EvaluationFields.TAG,
-  EvaluationFields.COUNT
+    EvaluationFields.DOCUMENT,
+    EvaluationFields.PAGE,
+    EvaluationFields.TAG,
+    EvaluationFields.COUNT
 ]
 
+
 def evaluate_document_page(structured_document, page):
-  tag_counter = Counter()
-  for line in structured_document.get_lines_of_page(page):
-    tag_counter.update(
-      structured_document.get_tag_value(token)
-      for token in structured_document.get_tokens_of_line(line)
-    )
-  num_tokens = sum(tag_counter.values())
-  return {
-    'count': dict(tag_counter),
-    'percentage': {
-      k: c / num_tokens
-      for k, c in iteritems(tag_counter)
+    tag_counter = Counter()
+    for line in structured_document.get_lines_of_page(page):
+        tag_counter.update(
+            structured_document.get_tag_value(token)
+            for token in structured_document.get_tokens_of_line(line)
+        )
+    num_tokens = sum(tag_counter.values())
+    return {
+        'count': dict(tag_counter),
+        'percentage': {
+            k: c / num_tokens
+            for k, c in iteritems(tag_counter)
+        }
     }
-  }
+
 
 def evaluate_document_by_page(structured_document):
-  return [
-    evaluate_document_page(structured_document, page)
-    for page in structured_document.get_pages()
-  ]
+    return [
+        evaluate_document_page(structured_document, page)
+        for page in structured_document.get_pages()
+    ]
+
 
 def to_csv_dict_rows(evaluation_result, document=None):
-  return flatten(
-    [
-      {
-        EvaluationFields.DOCUMENT: document,
-        EvaluationFields.PAGE: 1 + page_index,
-        EvaluationFields.TAG: tag,
-        EvaluationFields.COUNT: count
-      }
-      for tag, count in iteritems(page_evaluation['count'])
-    ]
-    for page_index, page_evaluation in enumerate(evaluation_result)
-  )
+    return flatten(
+        [
+            {
+                EvaluationFields.DOCUMENT: document,
+                EvaluationFields.PAGE: 1 + page_index,
+                EvaluationFields.TAG: tag,
+                EvaluationFields.COUNT: count
+            }
+            for tag, count in iteritems(page_evaluation['count'])
+        ]
+        for page_index, page_evaluation in enumerate(evaluation_result)
+    )
diff --git a/sciencebeam_gym/preprocess/annotation/annotation_evaluation_test.py b/sciencebeam_gym/preprocess/annotation/annotation_evaluation_test.py
index 37121b0..c51c4f6 100644
--- a/sciencebeam_gym/preprocess/annotation/annotation_evaluation_test.py
+++ b/sciencebeam_gym/preprocess/annotation/annotation_evaluation_test.py
@@ -1,56 +1,57 @@
 from __future__ import division
 
 from sciencebeam_gym.structured_document import (
-  SimpleStructuredDocument,
-  SimpleLine,
-  SimpleToken,
-  B_TAG_PREFIX,
-  I_TAG_PREFIX
+    SimpleStructuredDocument,
+    SimpleLine,
+    SimpleToken,
+    B_TAG_PREFIX,
+    I_TAG_PREFIX
 )
 
 from sciencebeam_gym.preprocess.annotation.annotation_evaluation import (
-  evaluate_document_by_page
+    evaluate_document_by_page
 )
 
 TAG1 = 'tag1'
 
+
 class TestEvaluateDocumentByPage(object):
-  def test_should_return_ratio_and_count_of_tagged_tokens(self):
-    tagged_tokens = [
-      SimpleToken('this'),
-      SimpleToken('is'),
-      SimpleToken('tagged')
-    ]
-    not_tagged_tokens = [
-      SimpleToken('this'),
-      SimpleToken('isn\'t')
-    ]
-    doc = SimpleStructuredDocument(lines=[SimpleLine(
-      tagged_tokens + not_tagged_tokens
-    )])
-    for token in tagged_tokens:
-      doc.set_tag(token, TAG1)
-    num_total = len(tagged_tokens) + len(not_tagged_tokens)
-    results = evaluate_document_by_page(doc)
-    assert results == [{
-      'count': {
-        TAG1: len(tagged_tokens),
-        None: len(not_tagged_tokens)
-      },
-      'percentage': {
-        TAG1: len(tagged_tokens) / num_total,
-        None: len(not_tagged_tokens) / num_total
-      }
-    }]
+    def test_should_return_ratio_and_count_of_tagged_tokens(self):
+        tagged_tokens = [
+            SimpleToken('this'),
+            SimpleToken('is'),
+            SimpleToken('tagged')
+        ]
+        not_tagged_tokens = [
+            SimpleToken('this'),
+            SimpleToken('isn\'t')
+        ]
+        doc = SimpleStructuredDocument(lines=[SimpleLine(
+            tagged_tokens + not_tagged_tokens
+        )])
+        for token in tagged_tokens:
+            doc.set_tag(token, TAG1)
+        num_total = len(tagged_tokens) + len(not_tagged_tokens)
+        results = evaluate_document_by_page(doc)
+        assert results == [{
+            'count': {
+                TAG1: len(tagged_tokens),
+                None: len(not_tagged_tokens)
+            },
+            'percentage': {
+                TAG1: len(tagged_tokens) / num_total,
+                None: len(not_tagged_tokens) / num_total
+            }
+        }]
 
-  def test_should_strip_prefix(self):
-    tagged_tokens = [
-      SimpleToken('this', tag=TAG1, tag_prefix=B_TAG_PREFIX),
-      SimpleToken('is', tag=TAG1, tag_prefix=I_TAG_PREFIX),
-      SimpleToken('tagged', tag=TAG1, tag_prefix=I_TAG_PREFIX)
-    ]
-    doc = SimpleStructuredDocument(lines=[SimpleLine(
-      tagged_tokens
-    )])
-    results = evaluate_document_by_page(doc)
-    assert set(results[0]['count'].keys()) == {TAG1}
+    def test_should_strip_prefix(self):
+        tagged_tokens = [
+            SimpleToken('this', tag=TAG1, tag_prefix=B_TAG_PREFIX),
+            SimpleToken('is', tag=TAG1, tag_prefix=I_TAG_PREFIX),
+            SimpleToken('tagged', tag=TAG1, tag_prefix=I_TAG_PREFIX)
+        ]
+        doc = SimpleStructuredDocument(lines=[SimpleLine(
+            tagged_tokens
+        )])
+        results = evaluate_document_by_page(doc)
+        assert set(results[0]['count'].keys()) == {TAG1}
diff --git a/sciencebeam_gym/preprocess/annotation/annotator.py b/sciencebeam_gym/preprocess/annotation/annotator.py
index 35ab00d..d6a526e 100644
--- a/sciencebeam_gym/preprocess/annotation/annotator.py
+++ b/sciencebeam_gym/preprocess/annotation/annotator.py
@@ -1,36 +1,40 @@
 from abc import ABCMeta, abstractmethod
 
 from sciencebeam_gym.preprocess.annotation.find_line_number import (
-  find_line_number_tokens
+    find_line_number_tokens
 )
 
+
 class AbstractAnnotator(object):
-  __metaclass__ = ABCMeta
+    __metaclass__ = ABCMeta
+
+    @abstractmethod
+    def annotate(self, structured_document):
+        pass
 
-  @abstractmethod
-  def annotate(self, structured_document):
-    pass
 
 class LineAnnotator(AbstractAnnotator):
-  def __init__(self, tag='line_no'):
-    self.tag = tag
+    def __init__(self, tag='line_no'):
+        self.tag = tag
+
+    def annotate(self, structured_document):
+        for t in find_line_number_tokens(structured_document):
+            structured_document.set_tag(t, self.tag)
+        return structured_document
 
-  def annotate(self, structured_document):
-    for t in find_line_number_tokens(structured_document):
-      structured_document.set_tag(t, self.tag)
-    return structured_document
 
 DEFAULT_ANNOTATORS = [
-  LineAnnotator()
+    LineAnnotator()
 ]
 
+
 class Annotator(object):
-  def __init__(self, annotators=None):
-    if annotators is None:
-      annotators = DEFAULT_ANNOTATORS
-    self.annotators = annotators
-
-  def annotate(self, structured_document):
-    for annotator in self.annotators:
-      structured_document = annotator.annotate(structured_document)
-    return structured_document
+    def __init__(self, annotators=None):
+        if annotators is None:
+            annotators = DEFAULT_ANNOTATORS
+        self.annotators = annotators
+
+    def annotate(self, structured_document):
+        for annotator in self.annotators:
+            structured_document = annotator.annotate(structured_document)
+        return structured_document
diff --git a/sciencebeam_gym/preprocess/annotation/annotator_test.py b/sciencebeam_gym/preprocess/annotation/annotator_test.py
index 93a4c7d..53494ff 100644
--- a/sciencebeam_gym/preprocess/annotation/annotator_test.py
+++ b/sciencebeam_gym/preprocess/annotation/annotator_test.py
@@ -1,32 +1,33 @@
 from sciencebeam_gym.preprocess.annotation.annotator import (
-  LineAnnotator
+    LineAnnotator
 )
 
 from sciencebeam_gym.structured_document import (
-  SimpleLine,
-  SimpleToken,
-  SimpleStructuredDocument
+    SimpleLine,
+    SimpleToken,
+    SimpleStructuredDocument
 )
 
 line_annotator = LineAnnotator()
 
+
 class TestLineAnnotator(object):
-  def test_x(self):
-    line_number_tokens = [
-      SimpleToken(str(line_no), dict(x='1', y=str(line_no * 20)))
-      for line_no in range(1, 5)
-    ]
-    doc = SimpleStructuredDocument(lines=[
-      SimpleLine([
-        line_number_token,
-        SimpleToken('other text', dict(
-          x=str(float(line_number_token.get_x()) + 50),
-          y=line_number_token.get_y()
-        ))
-      ])
-      for line_number_token in line_number_tokens
-    ])
-    line_annotator.annotate(
-      doc
-    )
-    assert [t.get_tag() for t in line_number_tokens] == ['line_no'] * len(line_number_tokens)
+    def test_x(self):
+        line_number_tokens = [
+            SimpleToken(str(line_no), dict(x='1', y=str(line_no * 20)))
+            for line_no in range(1, 5)
+        ]
+        doc = SimpleStructuredDocument(lines=[
+            SimpleLine([
+                line_number_token,
+                SimpleToken('other text', dict(
+                    x=str(float(line_number_token.get_x()) + 50),
+                    y=line_number_token.get_y()
+                ))
+            ])
+            for line_number_token in line_number_tokens
+        ])
+        line_annotator.annotate(
+            doc
+        )
+        assert [t.get_tag() for t in line_number_tokens] == ['line_no'] * len(line_number_tokens)
diff --git a/sciencebeam_gym/preprocess/annotation/find_line_number.py b/sciencebeam_gym/preprocess/annotation/find_line_number.py
index 6730efa..918f7b8 100644
--- a/sciencebeam_gym/preprocess/annotation/find_line_number.py
+++ b/sciencebeam_gym/preprocess/annotation/find_line_number.py
@@ -10,52 +10,55 @@ DEFAULT_X_THRESHOLD = 10
 # (low ratio indicates numbers may be figures or table values rather than line numbers)
 DEFAULT_TOKEN_RATIO_THRESHOLD = 0.7
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def _find_line_number_token_candidates(structured_document, page):
-  for line in structured_document.get_lines_of_page(page):
-    text_tokens = sorted(
-      structured_document.get_tokens_of_line(line),
-      key=lambda t: structured_document.get_x(t)
-    )
-    if text_tokens:
-      token = text_tokens[0]
-      token_text = structured_document.get_text(token)
-      if token_text and token_text.isdigit():
-        yield token
+    for line in structured_document.get_lines_of_page(page):
+        text_tokens = sorted(
+            structured_document.get_tokens_of_line(line),
+            key=lambda t: structured_document.get_x(t)
+        )
+        if text_tokens:
+            token = text_tokens[0]
+            token_text = structured_document.get_text(token)
+            if token_text and token_text.isdigit():
+                yield token
+
 
 def find_line_number_tokens(
-  structured_document,
-  x_threshold=DEFAULT_X_THRESHOLD,
-  token_ratio_threshold=DEFAULT_TOKEN_RATIO_THRESHOLD):
-
-  for page in structured_document.get_pages():
-    line_number_candidates = list(_find_line_number_token_candidates(
-      structured_document,
-      page
-    ))
-    # we need more than two lines
-    if len(line_number_candidates) > 2:
-      c = Counter((
-        round(float(structured_document.get_x(t)))
-        for t in line_number_candidates
-      ))
-      get_logger().debug('counter: %s', c)
-      most_common_x, most_common_count = c.most_common(1)[0]
-      get_logger().debug('most_common: x: %s (count: %s)', most_common_x, most_common_count)
-      tokens_within_range = [
-        token
-        for token in line_number_candidates
-        if abs(float(token.attrib['x']) - most_common_x) < x_threshold
-      ]
-      token_within_range_ratio = len(tokens_within_range) / len(line_number_candidates)
-      if token_within_range_ratio < token_ratio_threshold:
-        get_logger().debug(
-          'token within range ratio not meeting threshold: %f < %f',
-          token_within_range_ratio,
-          token_ratio_threshold
-        )
-      else:
-        for token in tokens_within_range:
-          yield token
+        structured_document,
+        x_threshold=DEFAULT_X_THRESHOLD,
+        token_ratio_threshold=DEFAULT_TOKEN_RATIO_THRESHOLD):
+
+    for page in structured_document.get_pages():
+        line_number_candidates = list(_find_line_number_token_candidates(
+            structured_document,
+            page
+        ))
+        # we need more than two lines
+        if len(line_number_candidates) > 2:
+            c = Counter((
+                round(float(structured_document.get_x(t)))
+                for t in line_number_candidates
+            ))
+            get_logger().debug('counter: %s', c)
+            most_common_x, most_common_count = c.most_common(1)[0]
+            get_logger().debug('most_common: x: %s (count: %s)', most_common_x, most_common_count)
+            tokens_within_range = [
+                token
+                for token in line_number_candidates
+                if abs(float(token.attrib['x']) - most_common_x) < x_threshold
+            ]
+            token_within_range_ratio = len(tokens_within_range) / len(line_number_candidates)
+            if token_within_range_ratio < token_ratio_threshold:
+                get_logger().debug(
+                    'token within range ratio not meeting threshold: %f < %f',
+                    token_within_range_ratio,
+                    token_ratio_threshold
+                )
+            else:
+                for token in tokens_within_range:
+                    yield token
diff --git a/sciencebeam_gym/preprocess/annotation/find_line_numbers_test.py b/sciencebeam_gym/preprocess/annotation/find_line_numbers_test.py
index 367ed0b..5719466 100644
--- a/sciencebeam_gym/preprocess/annotation/find_line_numbers_test.py
+++ b/sciencebeam_gym/preprocess/annotation/find_line_numbers_test.py
@@ -1,107 +1,108 @@
 from sciencebeam_utils.utils.collection import (
-  flatten
+    flatten
 )
 
 from sciencebeam_gym.structured_document import (
-  SimpleStructuredDocument,
-  SimpleLine,
-  SimpleToken
+    SimpleStructuredDocument,
+    SimpleLine,
+    SimpleToken
 )
 
 from sciencebeam_gym.preprocess.annotation.find_line_number import (
-  find_line_number_tokens
+    find_line_number_tokens
 )
 
+
 class TestFindLxmlLineNumberTokens(object):
-  def test_should_return_empty_list_for_empty_page(self):
-    doc = SimpleStructuredDocument(lines=[])
-    line_number_tokens = list(find_line_number_tokens(doc))
-    assert len(line_number_tokens) == 0
+    def test_should_return_empty_list_for_empty_page(self):
+        doc = SimpleStructuredDocument(lines=[])
+        line_number_tokens = list(find_line_number_tokens(doc))
+        assert len(line_number_tokens) == 0
 
-  def test_should_return_line_number_tokens_appearing_first_in_line(self):
-    line_number_tokens = [
-      SimpleToken(str(line_no), dict(
-        x=str(10),
-        y=str(line_no * 20))
-      )
-      for line_no in range(1, 5)
-    ]
-    doc = SimpleStructuredDocument(lines=[
-      SimpleLine([
-        line_number_token,
-        SimpleToken('other text', dict(
-          x=str(50),
-          y=line_number_token.get_y()
-        ))
-      ])
-      for line_number_token in line_number_tokens
-    ])
-    expected_line_number_tokens = line_number_tokens
-    actual_line_number_tokens = list(find_line_number_tokens(doc))
-    assert actual_line_number_tokens == expected_line_number_tokens
+    def test_should_return_line_number_tokens_appearing_first_in_line(self):
+        line_number_tokens = [
+            SimpleToken(str(line_no), dict(
+                x=str(10),
+                y=str(line_no * 20))
+            )
+            for line_no in range(1, 5)
+        ]
+        doc = SimpleStructuredDocument(lines=[
+            SimpleLine([
+                line_number_token,
+                SimpleToken('other text', dict(
+                    x=str(50),
+                    y=line_number_token.get_y()
+                ))
+            ])
+            for line_number_token in line_number_tokens
+        ])
+        expected_line_number_tokens = line_number_tokens
+        actual_line_number_tokens = list(find_line_number_tokens(doc))
+        assert actual_line_number_tokens == expected_line_number_tokens
 
-  def test_should_not_return_line_number_tokens_if_not_line(self):
-    line_number_tokens = [
-      SimpleToken(str(line_no), dict(
-        x=str(30),
-        y=str(line_no * 20))
-      )
-      for line_no in range(1, 5)
-    ]
-    doc = SimpleStructuredDocument(lines=[
-      SimpleLine([
-        line_number_token,
-        SimpleToken('other text', dict(
-          x=str(20),
-          y=line_number_token.get_y()
-        ))
-      ])
-      for line_number_token in line_number_tokens
-    ])
-    expected_line_number_tokens = []
-    actual_line_number_tokens = list(find_line_number_tokens(doc))
-    assert actual_line_number_tokens == expected_line_number_tokens
+    def test_should_not_return_line_number_tokens_if_not_line(self):
+        line_number_tokens = [
+            SimpleToken(str(line_no), dict(
+                x=str(30),
+                y=str(line_no * 20))
+            )
+            for line_no in range(1, 5)
+        ]
+        doc = SimpleStructuredDocument(lines=[
+            SimpleLine([
+                line_number_token,
+                SimpleToken('other text', dict(
+                    x=str(20),
+                    y=line_number_token.get_y()
+                ))
+            ])
+            for line_number_token in line_number_tokens
+        ])
+        expected_line_number_tokens = []
+        actual_line_number_tokens = list(find_line_number_tokens(doc))
+        assert actual_line_number_tokens == expected_line_number_tokens
 
-  def test_should_not_return_line_number_tokens_at_unusual_position(self):
-    usual_line_number_x = 1
-    line_number_tokens = [
-      SimpleToken(str(line_no), dict(
-        x=str(usual_line_number_x if line_no != 2 else usual_line_number_x + 30),
-        y=str(line_no * 20))
-      )
-      for line_no in range(1, 5)
-    ]
-    doc = SimpleStructuredDocument(lines=[
-      SimpleLine([
-        line_number_token,
-        SimpleToken('other text', dict(
-          x=str(50),
-          y=line_number_token.get_y()
-        ))
-      ])
-      for line_number_token in line_number_tokens
-    ])
-    expected_line_number_tokens = [
-      t for t in line_number_tokens
-      if int(t.get_x()) == usual_line_number_x
-    ]
-    actual_line_number_tokens = list(find_line_number_tokens(doc))
-    assert actual_line_number_tokens == expected_line_number_tokens
+    def test_should_not_return_line_number_tokens_at_unusual_position(self):
+        usual_line_number_x = 1
+        line_number_tokens = [
+            SimpleToken(str(line_no), dict(
+                x=str(usual_line_number_x if line_no != 2 else usual_line_number_x + 30),
+                y=str(line_no * 20))
+            )
+            for line_no in range(1, 5)
+        ]
+        doc = SimpleStructuredDocument(lines=[
+            SimpleLine([
+                line_number_token,
+                SimpleToken('other text', dict(
+                    x=str(50),
+                    y=line_number_token.get_y()
+                ))
+            ])
+            for line_number_token in line_number_tokens
+        ])
+        expected_line_number_tokens = [
+            t for t in line_number_tokens
+            if int(t.get_x()) == usual_line_number_x
+        ]
+        actual_line_number_tokens = list(find_line_number_tokens(doc))
+        assert actual_line_number_tokens == expected_line_number_tokens
 
-  def test_should_not_return_line_number_tokens_at_unusual_position2(self):
-    number_tokens = flatten([
-      [
-        SimpleToken(str(line_no), dict(
-          x=str(x * 50),
-          y=str(line_no * 20))
-        )
-        for line_no in range(1, 5)
-      ]
-      for x in range(1, 3)
-    ])
-    doc = SimpleStructuredDocument(lines=[
-      SimpleLine([number_token])
-      for number_token in number_tokens
-    ])
-    actual_line_number_tokens = list(find_line_number_tokens(doc))
-    assert actual_line_number_tokens == []
+    def test_should_not_return_line_number_tokens_at_unusual_position2(self):
+        number_tokens = flatten([
+            [
+                SimpleToken(str(line_no), dict(
+                    x=str(x * 50),
+                    y=str(line_no * 20))
+                )
+                for line_no in range(1, 5)
+            ]
+            for x in range(1, 3)
+        ])
+        doc = SimpleStructuredDocument(lines=[
+            SimpleLine([number_token])
+            for number_token in number_tokens
+        ])
+        actual_line_number_tokens = list(find_line_number_tokens(doc))
+        assert actual_line_number_tokens == []
diff --git a/sciencebeam_gym/preprocess/annotation/fuzzy_match.py b/sciencebeam_gym/preprocess/annotation/fuzzy_match.py
index aa80315..65360a0 100644
--- a/sciencebeam_gym/preprocess/annotation/fuzzy_match.py
+++ b/sciencebeam_gym/preprocess/annotation/fuzzy_match.py
@@ -3,273 +3,287 @@ from __future__ import division
 import logging
 
 from sciencebeam_utils.utils.string import (
-  LazyStr
+    LazyStr
 )
 
 from sciencebeam_alignment.align import (
-  LocalSequenceMatcher,
-  SimpleScoring
+    LocalSequenceMatcher,
+    SimpleScoring
 )
 
 from sciencebeam_alignment.word_sequence_matcher import (
-  WordSequenceMatcher
+    WordSequenceMatcher
 )
 
 DEFAULT_SCORING = SimpleScoring(
-  match_score=2,
-  mismatch_score=-1,
-  gap_score=-2
+    match_score=2,
+    mismatch_score=-1,
+    gap_score=-2
 )
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def len_index_range(index_range):
-  return index_range[1] - index_range[0]
+    return index_range[1] - index_range[0]
+
 
 # Treat space or comma after a dot, or a dot after a letter as junk
-DEFAULT_ISJUNK = lambda s, i: (
-  (i > 0 and s[i - 1] == '.' and (s[i] == ' ' or s[i] == ',')) or
-  (i > 0 and s[i - 1].isalpha() and s[i] == '.') or
-  (i > 0 and s[i - 1] == s[i]) or
-  s[i] == '*'
-)
+def DEFAULT_ISJUNK(s, i):
+    return (
+        (i > 0 and s[i - 1] == '.' and (s[i] == ' ' or s[i] == ',')) or
+        (i > 0 and s[i - 1].isalpha() and s[i] == '.') or
+        (i > 0 and s[i - 1] == s[i]) or
+        s[i] == '*'
+    )
 
-DOT_IS_JUNK = lambda s, i: s[i] == '.'
 
-def remove_junk(s, isjunk=None):
-  if isjunk is None:
-    isjunk = DEFAULT_ISJUNK
-  result = None
-  start = 0
-  for i in range(len(s)):
-    if isjunk(s, i):
-      if result is None:
-        result = []
-      if i > start:
-        result.append(s[start:i])
-      start = i + 1
-  if result is None:
-    return s
-  if len(s) > start:
-    result.append(s[start:])
-  return ''.join(result)
+def DOT_IS_JUNK(s, i):
+    return s[i] == '.'
 
-def invert_index_ranges(range_list, start, end):
-  i = start
-  for r_start, r_end in range_list:
-    if i >= end:
-      return
-    if i < r_start:
-      yield i, min(end, r_start)
-    i = r_end
-  if i < end:
-    yield i, end
 
-class FuzzyMatchResult(object):
-  def __init__(self, a, b, matching_blocks, isjunk=None):
-    self.a = a
-    self.b = b
-    self.matching_blocks = matching_blocks
-    self.non_empty_matching_blocks = [x for x in self.matching_blocks if x[-1]]
-    self._match_count = None
-    self._a_index_range = None
-    self._b_index_range = None
-    self.isjunk = isjunk or DEFAULT_ISJUNK
-
-  def has_match(self):
-    return len(self.non_empty_matching_blocks) > 0
-
-  def match_count(self):
-    if self._match_count is None:
-      self._match_count = sum(triple[-1] for triple in self.matching_blocks)
-    return self._match_count
-
-  def ratio_to(self, size):
-    if not size:
-      return 0.0
-    return self.match_count() / size
-
-  def ratio(self):
-    a_match_len = len_index_range(self.a_index_range())
-    b_match_len = len_index_range(self.b_index_range())
-    max_len = max(a_match_len, b_match_len)
-    if max_len == a_match_len:
-      junk_match_count = self.a_non_matching_junk_count(self.a_index_range())
-    else:
-      junk_match_count = self.b_non_matching_junk_count(self.b_index_range())
-    max_len_excl_junk = max_len - junk_match_count
-    result = self.ratio_to(max_len_excl_junk)
-    if result > 1.0:
-      get_logger().debug(
-        'ratio: ratio greater than 1.0, a_match_len=%d, b_match_len=%d, max_len=%d,'
-        ' junk_match_count=%d, max_len_excl_junk=%d, result=%f',
-        a_match_len, b_match_len, max_len, junk_match_count, max_len_excl_junk, result
-      )
-    return result
-
-  def count_junk_between(self, s, index_range):
-    if not self.isjunk:
-      return 0
-    return sum(self.isjunk(s, i) for i in range(index_range[0], index_range[1]))
-
-  def count_non_matching_junk(self, s, s_matching_blocks, index_range=None):
-    if not self.isjunk:
-      return 0
-    if index_range is None:
-      index_range = (0, len(s))
-    return sum(
-      self.count_junk_between(s, block_index_range)
-      for block_index_range in invert_index_ranges(
-        s_matching_blocks, index_range[0], index_range[1]
-      )
-    )
+def remove_junk(s, isjunk=None):
+    if isjunk is None:
+        isjunk = DEFAULT_ISJUNK
+    result = None
+    start = 0
+    for i in range(len(s)):
+        if isjunk(s, i):
+            if result is None:
+                result = []
+            if i > start:
+                result.append(s[start:i])
+            start = i + 1
+    if result is None:
+        return s
+    if len(s) > start:
+        result.append(s[start:])
+    return ''.join(result)
 
-  def a_junk_match_count(self):
-    return self.count_junk_between(self.a, self.a_index_range())
-
-  def a_junk_count(self):
-    return self.count_junk_between(self.a, (0, len(self.a)))
-
-  def a_non_matching_junk_count(self, index_range=None):
-    return self.count_non_matching_junk(self.a, self.a_matching_blocks(), index_range)
-
-  def b_junk_match_count(self):
-    return self.count_junk_between(self.b, self.b_index_range())
-
-  def b_junk_count(self):
-    return self.count_junk_between(self.b, (0, len(self.b)))
-
-  def b_non_matching_junk_count(self, index_range=None):
-    return self.count_non_matching_junk(self.b, self.b_matching_blocks(), index_range)
-
-  def a_ratio(self):
-    return self.ratio_to(len(self.a) - self.a_non_matching_junk_count())
-
-  def b_ratio(self):
-    return self.ratio_to(len(self.b) - self.b_non_matching_junk_count())
-
-  def b_gap_ratio(self):
-    """
-    Calculates the ratio of matches vs the length of b,
-    but also adds any gaps / mismatches within a.
-    """
-    a_index_range = self.a_index_range()
-    a_match_len = len_index_range(a_index_range)
-    match_count = self.match_count()
-    a_junk_match_count = self.a_non_matching_junk_count(a_index_range)
-    b_junk_count = self.b_non_matching_junk_count()
-    a_gaps = a_match_len - match_count
-    return self.ratio_to(len(self.b) + a_gaps - a_junk_match_count - b_junk_count)
-
-  def a_matching_blocks(self):
-    return ((a, a + size) for a, _, size in self.non_empty_matching_blocks)
-
-  def b_matching_blocks(self):
-    return ((b, b + size) for _, b, size in self.non_empty_matching_blocks)
-
-  def a_start_index(self):
-    return self.non_empty_matching_blocks[0][0] if self.has_match() else None
-
-  def a_end_index(self):
-    if not self.has_match():
-      return None
-    ai, _, size = self.non_empty_matching_blocks[-1]
-    return ai + size
-
-  def a_index_range(self):
-    if not self.non_empty_matching_blocks:
-      return (0, 0)
-    if not self._a_index_range:
-      self._a_index_range = (self.a_start_index(), self.a_end_index())
-    return self._a_index_range
-
-  def b_start_index(self):
-    return self.non_empty_matching_blocks[0][1] if self.has_match() else None
-
-  def b_end_index(self):
-    if not self.has_match():
-      return None
-    _, bi, size = self.non_empty_matching_blocks[-1]
-    return bi + size
-
-  def b_index_range(self):
-    if not self.non_empty_matching_blocks:
-      return (0, 0)
-    if not self._b_index_range:
-      self._b_index_range = (self.b_start_index(), self.b_end_index())
-    return self._b_index_range
-
-  def a_split_at(self, index, a_pre_split=None, a_post_split=None):
-    if a_pre_split is None:
-      a_pre_split = self.a[:index]
-    if a_post_split is None:
-      a_post_split = self.a[index:]
-    if not self.non_empty_matching_blocks or self.a_end_index() <= index:
-      return (
-        FuzzyMatchResult(a_pre_split, self.b, self.non_empty_matching_blocks),
-        FuzzyMatchResult(a_post_split, self.b, [])
-      )
-    return (
-      FuzzyMatchResult(a_pre_split, self.b, [
-        (ai, bi, min(size, index - ai))
-        for ai, bi, size in self.non_empty_matching_blocks
-        if ai < index
-      ]),
-      FuzzyMatchResult(a_post_split, self.b, [
-        (max(0, ai - index), bi, size if ai >= index else size + ai - index)
-        for ai, bi, size in self.non_empty_matching_blocks
-        if ai + size > index
-      ])
-    )
 
-  def b_split_at(self, index, b_pre_split=None, b_post_split=None):
-    if b_pre_split is None:
-      b_pre_split = self.b[:index]
-    if b_post_split is None:
-      b_post_split = self.b[index:]
-    if not self.non_empty_matching_blocks or self.b_end_index() <= index:
-      return (
-        FuzzyMatchResult(self.a, b_pre_split, self.non_empty_matching_blocks),
-        FuzzyMatchResult(self.a, b_post_split, [])
-      )
-    result = (
-      FuzzyMatchResult(self.a, b_pre_split, [
-        (ai, bi, min(size, index - bi))
-        for ai, bi, size in self.non_empty_matching_blocks
-        if bi < index
-      ]),
-      FuzzyMatchResult(self.a, b_post_split, [
-        (ai, max(0, bi - index), size if bi >= index else size + bi - index)
-        for ai, bi, size in self.non_empty_matching_blocks
-        if bi + size > index
-      ])
-    )
-    return result
-
-  def detailed_str(self):
-    return 'matching_blocks=[%s]' % (
-      ', '.join([
-        '(a[%d:+%d] = b[%d:+%d] = "%s")' % (ai, size, bi, size, self.a[ai:ai + size])
-        for ai, bi, size in self.non_empty_matching_blocks
-      ])
-    )
+def invert_index_ranges(range_list, start, end):
+    i = start
+    for r_start, r_end in range_list:
+        if i >= end:
+            return
+        if i < r_start:
+            yield i, min(end, r_start)
+        i = r_end
+    if i < end:
+        yield i, end
 
-  def detailed(self):
-    return LazyStr(self.detailed_str)
 
-  def __repr__(self):
-    return (
-      'FuzzyMatchResult(matching_blocks={}, match_count={}, ratio={},'
-      ' a_ratio={}, b_gap_ratio={})'.format(
-        self.matching_blocks, self.match_count(), self.ratio(), self.a_ratio(), self.b_gap_ratio()
-      )
-    )
+class FuzzyMatchResult(object):
+    def __init__(self, a, b, matching_blocks, isjunk=None):
+        self.a = a
+        self.b = b
+        self.matching_blocks = matching_blocks
+        self.non_empty_matching_blocks = [x for x in self.matching_blocks if x[-1]]
+        self._match_count = None
+        self._a_index_range = None
+        self._b_index_range = None
+        self.isjunk = isjunk or DEFAULT_ISJUNK
+
+    def has_match(self):
+        return len(self.non_empty_matching_blocks) > 0
+
+    def match_count(self):
+        if self._match_count is None:
+            self._match_count = sum(triple[-1] for triple in self.matching_blocks)
+        return self._match_count
+
+    def ratio_to(self, size):
+        if not size:
+            return 0.0
+        return self.match_count() / size
+
+    def ratio(self):
+        a_match_len = len_index_range(self.a_index_range())
+        b_match_len = len_index_range(self.b_index_range())
+        max_len = max(a_match_len, b_match_len)
+        if max_len == a_match_len:
+            junk_match_count = self.a_non_matching_junk_count(self.a_index_range())
+        else:
+            junk_match_count = self.b_non_matching_junk_count(self.b_index_range())
+        max_len_excl_junk = max_len - junk_match_count
+        result = self.ratio_to(max_len_excl_junk)
+        if result > 1.0:
+            get_logger().debug(
+                'ratio: ratio greater than 1.0, a_match_len=%d, b_match_len=%d, max_len=%d,'
+                ' junk_match_count=%d, max_len_excl_junk=%d, result=%f',
+                a_match_len, b_match_len, max_len, junk_match_count, max_len_excl_junk, result
+            )
+        return result
+
+    def count_junk_between(self, s, index_range):
+        if not self.isjunk:
+            return 0
+        return sum(self.isjunk(s, i) for i in range(index_range[0], index_range[1]))
+
+    def count_non_matching_junk(self, s, s_matching_blocks, index_range=None):
+        if not self.isjunk:
+            return 0
+        if index_range is None:
+            index_range = (0, len(s))
+        return sum(
+            self.count_junk_between(s, block_index_range)
+            for block_index_range in invert_index_ranges(
+                s_matching_blocks, index_range[0], index_range[1]
+            )
+        )
+
+    def a_junk_match_count(self):
+        return self.count_junk_between(self.a, self.a_index_range())
+
+    def a_junk_count(self):
+        return self.count_junk_between(self.a, (0, len(self.a)))
+
+    def a_non_matching_junk_count(self, index_range=None):
+        return self.count_non_matching_junk(self.a, self.a_matching_blocks(), index_range)
+
+    def b_junk_match_count(self):
+        return self.count_junk_between(self.b, self.b_index_range())
+
+    def b_junk_count(self):
+        return self.count_junk_between(self.b, (0, len(self.b)))
+
+    def b_non_matching_junk_count(self, index_range=None):
+        return self.count_non_matching_junk(self.b, self.b_matching_blocks(), index_range)
+
+    def a_ratio(self):
+        return self.ratio_to(len(self.a) - self.a_non_matching_junk_count())
+
+    def b_ratio(self):
+        return self.ratio_to(len(self.b) - self.b_non_matching_junk_count())
+
+    def b_gap_ratio(self):
+        """
+        Calculates the ratio of matches vs the length of b,
+        but also adds any gaps / mismatches within a.
+        """
+        a_index_range = self.a_index_range()
+        a_match_len = len_index_range(a_index_range)
+        match_count = self.match_count()
+        a_junk_match_count = self.a_non_matching_junk_count(a_index_range)
+        b_junk_count = self.b_non_matching_junk_count()
+        a_gaps = a_match_len - match_count
+        return self.ratio_to(len(self.b) + a_gaps - a_junk_match_count - b_junk_count)
+
+    def a_matching_blocks(self):
+        return ((a, a + size) for a, _, size in self.non_empty_matching_blocks)
+
+    def b_matching_blocks(self):
+        return ((b, b + size) for _, b, size in self.non_empty_matching_blocks)
+
+    def a_start_index(self):
+        return self.non_empty_matching_blocks[0][0] if self.has_match() else None
+
+    def a_end_index(self):
+        if not self.has_match():
+            return None
+        ai, _, size = self.non_empty_matching_blocks[-1]
+        return ai + size
+
+    def a_index_range(self):
+        if not self.non_empty_matching_blocks:
+            return (0, 0)
+        if not self._a_index_range:
+            self._a_index_range = (self.a_start_index(), self.a_end_index())
+        return self._a_index_range
+
+    def b_start_index(self):
+        return self.non_empty_matching_blocks[0][1] if self.has_match() else None
+
+    def b_end_index(self):
+        if not self.has_match():
+            return None
+        _, bi, size = self.non_empty_matching_blocks[-1]
+        return bi + size
+
+    def b_index_range(self):
+        if not self.non_empty_matching_blocks:
+            return (0, 0)
+        if not self._b_index_range:
+            self._b_index_range = (self.b_start_index(), self.b_end_index())
+        return self._b_index_range
+
+    def a_split_at(self, index, a_pre_split=None, a_post_split=None):
+        if a_pre_split is None:
+            a_pre_split = self.a[:index]
+        if a_post_split is None:
+            a_post_split = self.a[index:]
+        if not self.non_empty_matching_blocks or self.a_end_index() <= index:
+            return (
+                FuzzyMatchResult(a_pre_split, self.b, self.non_empty_matching_blocks),
+                FuzzyMatchResult(a_post_split, self.b, [])
+            )
+        return (
+            FuzzyMatchResult(a_pre_split, self.b, [
+                (ai, bi, min(size, index - ai))
+                for ai, bi, size in self.non_empty_matching_blocks
+                if ai < index
+            ]),
+            FuzzyMatchResult(a_post_split, self.b, [
+                (max(0, ai - index), bi, size if ai >= index else size + ai - index)
+                for ai, bi, size in self.non_empty_matching_blocks
+                if ai + size > index
+            ])
+        )
+
+    def b_split_at(self, index, b_pre_split=None, b_post_split=None):
+        if b_pre_split is None:
+            b_pre_split = self.b[:index]
+        if b_post_split is None:
+            b_post_split = self.b[index:]
+        if not self.non_empty_matching_blocks or self.b_end_index() <= index:
+            return (
+                FuzzyMatchResult(self.a, b_pre_split, self.non_empty_matching_blocks),
+                FuzzyMatchResult(self.a, b_post_split, [])
+            )
+        result = (
+            FuzzyMatchResult(self.a, b_pre_split, [
+                (ai, bi, min(size, index - bi))
+                for ai, bi, size in self.non_empty_matching_blocks
+                if bi < index
+            ]),
+            FuzzyMatchResult(self.a, b_post_split, [
+                (ai, max(0, bi - index), size if bi >= index else size + bi - index)
+                for ai, bi, size in self.non_empty_matching_blocks
+                if bi + size > index
+            ])
+        )
+        return result
+
+    def detailed_str(self):
+        return 'matching_blocks=[%s]' % (
+            ', '.join([
+                '(a[%d:+%d] = b[%d:+%d] = "%s")' % (ai, size, bi, size, self.a[ai:ai + size])
+                for ai, bi, size in self.non_empty_matching_blocks
+            ])
+        )
+
+    def detailed(self):
+        return LazyStr(self.detailed_str)
+
+    def __repr__(self):
+        return (
+            'FuzzyMatchResult(matching_blocks={}, match_count={}, ratio={},'
+            ' a_ratio={}, b_gap_ratio={})'.format(
+                self.matching_blocks,
+                self.match_count(),
+                self.ratio(),
+                self.a_ratio(),
+                self.b_gap_ratio()
+            )
+        )
+
 
 def fuzzy_match(a, b, exact_word_match_threshold=5):
-  if min(len(a), len(b)) < exact_word_match_threshold:
-    sm = WordSequenceMatcher(None, a, b)
-  else:
-    sm = LocalSequenceMatcher(a=a, b=b, scoring=DEFAULT_SCORING)
-  matching_blocks = sm.get_matching_blocks()
-  return FuzzyMatchResult(a, b, matching_blocks)
+    if min(len(a), len(b)) < exact_word_match_threshold:
+        sm = WordSequenceMatcher(None, a, b)
+    else:
+        sm = LocalSequenceMatcher(a=a, b=b, scoring=DEFAULT_SCORING)
+    matching_blocks = sm.get_matching_blocks()
+    return FuzzyMatchResult(a, b, matching_blocks)
diff --git a/sciencebeam_gym/preprocess/annotation/fuzzy_match_test.py b/sciencebeam_gym/preprocess/annotation/fuzzy_match_test.py
index 5b52a30..9b4fe9f 100644
--- a/sciencebeam_gym/preprocess/annotation/fuzzy_match_test.py
+++ b/sciencebeam_gym/preprocess/annotation/fuzzy_match_test.py
@@ -3,186 +3,191 @@ from __future__ import division
 import logging
 
 from sciencebeam_gym.preprocess.annotation.fuzzy_match import (
-  remove_junk,
-  invert_index_ranges,
-  FuzzyMatchResult,
-  fuzzy_match,
-  DOT_IS_JUNK
+    remove_junk,
+    invert_index_ranges,
+    FuzzyMatchResult,
+    fuzzy_match,
+    DOT_IS_JUNK
 )
 
+
 def setup_module():
-  logging.basicConfig(level='DEBUG')
+    logging.basicConfig(level='DEBUG')
+
 
 class TestRemoveJunk(object):
-  def test_should_keep_str_without_junk(self):
-    assert remove_junk('abc', DOT_IS_JUNK) == 'abc'
+    def test_should_keep_str_without_junk(self):
+        assert remove_junk('abc', DOT_IS_JUNK) == 'abc'
+
+    def test_should_remove_dots_after_capitals(self):
+        assert remove_junk('P.O. Box', DOT_IS_JUNK) == 'PO Box'
 
-  def test_should_remove_dots_after_capitals(self):
-    assert remove_junk('P.O. Box', DOT_IS_JUNK) == 'PO Box'
+    def test_should_remove_asterisk_after_capitals(self):
+        assert remove_junk('Mr Beam*') == 'Mr Beam'
 
-  def test_should_remove_asterisk_after_capitals(self):
-    assert remove_junk('Mr Beam*') == 'Mr Beam'
+    def test_should_remove_repeating_characters(self):
+        assert remove_junk('Mr Beeeam') == 'Mr Beam'
 
-  def test_should_remove_repeating_characters(self):
-    assert remove_junk('Mr Beeeam') == 'Mr Beam'
 
 class TestInvertIndexRanges(object):
-  def test_should_return_empty_for_empty_range(self):
-    assert list(invert_index_ranges([], 0, 0)) == list([])
+    def test_should_return_empty_for_empty_range(self):
+        assert list(invert_index_ranges([], 0, 0)) == list([])
 
-  def test_should_return_whole_range_for_empty_range_list(self):
-    assert list(invert_index_ranges([], 0, 10)) == list([(0, 10)])
+    def test_should_return_whole_range_for_empty_range_list(self):
+        assert list(invert_index_ranges([], 0, 10)) == list([(0, 10)])
 
-  def test_should_exclude_range_in_the_beginning(self):
-    assert list(invert_index_ranges([(0, 3)], 0, 10)) == list([(3, 10)])
+    def test_should_exclude_range_in_the_beginning(self):
+        assert list(invert_index_ranges([(0, 3)], 0, 10)) == list([(3, 10)])
 
-  def test_should_exclude_range_in_the_beginning_beyond_start(self):
-    assert list(invert_index_ranges([(0, 13)], 10, 20)) == list([(13, 20)])
+    def test_should_exclude_range_in_the_beginning_beyond_start(self):
+        assert list(invert_index_ranges([(0, 13)], 10, 20)) == list([(13, 20)])
 
-  def test_should_exclude_range_in_the_middle(self):
-    assert list(invert_index_ranges([(4, 7)], 0, 10)) == list([(0, 4), (7, 10)])
+    def test_should_exclude_range_in_the_middle(self):
+        assert list(invert_index_ranges([(4, 7)], 0, 10)) == list([(0, 4), (7, 10)])
 
-  def test_should_exclude_range_at_the_end(self):
-    assert list(invert_index_ranges([(7, 10)], 0, 10)) == list([(0, 7)])
+    def test_should_exclude_range_at_the_end(self):
+        assert list(invert_index_ranges([(7, 10)], 0, 10)) == list([(0, 7)])
+
+    def test_should_exclude_range_at_the_end_beyond_end(self):
+        assert list(invert_index_ranges([(7, 100)], 0, 10)) == list([(0, 7)])
 
-  def test_should_exclude_range_at_the_end_beyond_end(self):
-    assert list(invert_index_ranges([(7, 100)], 0, 10)) == list([(0, 7)])
 
 class TestFuzzyMatch(object):
-  def test_match_count_should_be_the_same_independent_of_order(self):
-    s1 = 'this is a some sequence'
-    choice = 'this is another sequence'
-    fm_1 = fuzzy_match(s1, choice)
-    fm_2 = fuzzy_match(choice, s1)
-    assert fm_1.match_count() == fm_2.match_count()
+    def test_match_count_should_be_the_same_independent_of_order(self):
+        s1 = 'this is a some sequence'
+        choice = 'this is another sequence'
+        fm_1 = fuzzy_match(s1, choice)
+        fm_2 = fuzzy_match(choice, s1)
+        assert fm_1.match_count() == fm_2.match_count()
+
 
 class TestFuzzyMatchResult(object):
-  def test_exact_match(self):
-    fm = FuzzyMatchResult('abc', 'abc', [(0, 0, 3)])
-    assert fm.has_match()
-    assert fm.match_count() == 3
-    assert fm.ratio() == 1.0
-    assert fm.a_ratio() == 1.0
-    assert fm.b_ratio() == 1.0
-    assert fm.b_gap_ratio() == 1.0
-    assert fm.a_index_range() == (0, 3)
-    assert fm.b_index_range() == (0, 3)
-
-  def test_no_match(self):
-    fm = FuzzyMatchResult('abc', 'xyz', [])
-    assert not fm.has_match()
-    assert fm.match_count() == 0
-
-  def test_partial_match(self):
-    fm = FuzzyMatchResult('abx', 'aby', [(0, 0, 2)])
-    assert fm.has_match()
-    assert fm.match_count() == 2
-    assert fm.ratio() == 1.0
-    assert fm.a_ratio() == 2 / 3
-    assert fm.b_ratio() == 2 / 3
-    assert fm.b_gap_ratio() == 2 / 3
-    assert fm.a_index_range() == (0, 2)
-    assert fm.b_index_range() == (0, 2)
-
-  def test_partial_match_ignore_junk_at_the_end_of_a(self):
-    fm = FuzzyMatchResult('ab.', 'ab', [(0, 0, 2)], isjunk=lambda s, i: s[i] == '.')
-    assert fm.has_match()
-    assert fm.match_count() == 2
-    assert fm.ratio() == 1.0
-    assert fm.a_ratio() == 1.0
-    assert fm.b_ratio() == 1.0
-    assert fm.b_gap_ratio() == 1.0
-    assert fm.a_index_range() == (0, 2)
-    assert fm.b_index_range() == (0, 2)
-
-  def test_partial_match_ignore_junk_at_the_end_of_b(self):
-    fm = FuzzyMatchResult('ab', 'ab.', [(0, 0, 2)], isjunk=lambda s, i: s[i] == '.')
-    assert fm.has_match()
-    assert fm.match_count() == 2
-    assert fm.ratio() == 1.0
-    assert fm.a_ratio() == 1.0
-    assert fm.b_ratio() == 1.0
-    assert fm.b_gap_ratio() == 1.0
-    assert fm.a_index_range() == (0, 2)
-    assert fm.b_index_range() == (0, 2)
-
-  def test_partial_match_ignore_junk_in_the_middle_of_a(self):
-    fm = FuzzyMatchResult('a.b', 'ab', [(0, 0, 1), (2, 1, 1)], isjunk=lambda s, i: s[i] == '.')
-    assert fm.has_match()
-    assert fm.match_count() == 2
-    assert fm.ratio() == 1.0
-    assert fm.a_ratio() == 1.0
-    assert fm.b_ratio() == 1.0
-    assert fm.b_gap_ratio() == 1.0
-    assert fm.a_index_range() == (0, 3)
-    assert fm.b_index_range() == (0, 2)
-
-  def test_partial_match_ignore_junk_in_the_middle_of_b(self):
-    fm = FuzzyMatchResult('ab', 'a.b', [(0, 0, 1), (1, 2, 1)], isjunk=lambda s, i: s[i] == '.')
-    assert fm.has_match()
-    assert fm.match_count() == 2
-    assert fm.ratio() == 1.0
-    assert fm.a_ratio() == 1.0
-    assert fm.b_ratio() == 1.0
-    assert fm.b_gap_ratio() == 1.0
-    assert fm.a_index_range() == (0, 2)
-    assert fm.b_index_range() == (0, 3)
-
-  def test_should_not_double_count_matching_junk(self):
-    fm = FuzzyMatchResult('a.b', 'a.b', [(0, 0, 3)], isjunk=lambda s, i: s[i] == '.')
-    assert fm.has_match()
-    assert fm.match_count() == 3
-    assert fm.ratio() == 1.0
-    assert fm.a_ratio() == 1.0
-    assert fm.b_ratio() == 1.0
-    assert fm.b_gap_ratio() == 1.0
-    assert fm.a_index_range() == (0, 3)
-    assert fm.b_index_range() == (0, 3)
-
-  def test_a_split_no_match(self):
-    fm = FuzzyMatchResult('abc', 'xyz', [])
-    fm_1, fm_2 = fm.a_split_at(2)
-
-    assert not fm_1.has_match()
-    assert fm_1.a == 'ab'
-    assert fm_1.b == 'xyz'
-
-    assert not fm_2.has_match()
-    assert fm_2.a == 'c'
-    assert fm_2.b == 'xyz'
-
-  def test_b_split_no_match(self):
-    fm = FuzzyMatchResult('abc', 'xyz', [])
-    fm_1, fm_2 = fm.b_split_at(2)
-
-    assert not fm_1.has_match()
-    assert fm_1.a == 'abc'
-    assert fm_1.b == 'xy'
-
-    assert not fm_2.has_match()
-    assert fm_2.a == 'abc'
-    assert fm_2.b == 'z'
-
-  def test_a_split_exact_match(self):
-    fm = FuzzyMatchResult('abc', 'abc', [(0, 0, 3)])
-    fm_1, fm_2 = fm.a_split_at(2)
-
-    assert fm_1.a == 'ab'
-    assert fm_1.b == 'abc'
-    assert fm_1.has_match()
-    assert fm_1.ratio() == 1.0
-    assert fm_1.a_ratio() == 1.0
-    assert fm_1.b_ratio() == 2 / 3
-    assert fm_1.b_gap_ratio() == 2 / 3
-    assert fm_1.a_index_range() == (0, 2)
-    assert fm_1.b_index_range() == (0, 2)
-
-    assert fm_2.a == 'c'
-    assert fm_2.b == 'abc'
-    assert fm_2.has_match()
-    assert fm_2.ratio() == 1.0
-    assert fm_2.a_ratio() == 1.0
-    assert fm_2.b_ratio() == 1 / 3
-    assert fm_2.b_gap_ratio() == 1 / 3
-    assert fm_2.a_index_range() == (0, 1)
-    assert fm_2.b_index_range() == (0, 1)
+    def test_exact_match(self):
+        fm = FuzzyMatchResult('abc', 'abc', [(0, 0, 3)])
+        assert fm.has_match()
+        assert fm.match_count() == 3
+        assert fm.ratio() == 1.0
+        assert fm.a_ratio() == 1.0
+        assert fm.b_ratio() == 1.0
+        assert fm.b_gap_ratio() == 1.0
+        assert fm.a_index_range() == (0, 3)
+        assert fm.b_index_range() == (0, 3)
+
+    def test_no_match(self):
+        fm = FuzzyMatchResult('abc', 'xyz', [])
+        assert not fm.has_match()
+        assert fm.match_count() == 0
+
+    def test_partial_match(self):
+        fm = FuzzyMatchResult('abx', 'aby', [(0, 0, 2)])
+        assert fm.has_match()
+        assert fm.match_count() == 2
+        assert fm.ratio() == 1.0
+        assert fm.a_ratio() == 2 / 3
+        assert fm.b_ratio() == 2 / 3
+        assert fm.b_gap_ratio() == 2 / 3
+        assert fm.a_index_range() == (0, 2)
+        assert fm.b_index_range() == (0, 2)
+
+    def test_partial_match_ignore_junk_at_the_end_of_a(self):
+        fm = FuzzyMatchResult('ab.', 'ab', [(0, 0, 2)], isjunk=lambda s, i: s[i] == '.')
+        assert fm.has_match()
+        assert fm.match_count() == 2
+        assert fm.ratio() == 1.0
+        assert fm.a_ratio() == 1.0
+        assert fm.b_ratio() == 1.0
+        assert fm.b_gap_ratio() == 1.0
+        assert fm.a_index_range() == (0, 2)
+        assert fm.b_index_range() == (0, 2)
+
+    def test_partial_match_ignore_junk_at_the_end_of_b(self):
+        fm = FuzzyMatchResult('ab', 'ab.', [(0, 0, 2)], isjunk=lambda s, i: s[i] == '.')
+        assert fm.has_match()
+        assert fm.match_count() == 2
+        assert fm.ratio() == 1.0
+        assert fm.a_ratio() == 1.0
+        assert fm.b_ratio() == 1.0
+        assert fm.b_gap_ratio() == 1.0
+        assert fm.a_index_range() == (0, 2)
+        assert fm.b_index_range() == (0, 2)
+
+    def test_partial_match_ignore_junk_in_the_middle_of_a(self):
+        fm = FuzzyMatchResult('a.b', 'ab', [(0, 0, 1), (2, 1, 1)], isjunk=lambda s, i: s[i] == '.')
+        assert fm.has_match()
+        assert fm.match_count() == 2
+        assert fm.ratio() == 1.0
+        assert fm.a_ratio() == 1.0
+        assert fm.b_ratio() == 1.0
+        assert fm.b_gap_ratio() == 1.0
+        assert fm.a_index_range() == (0, 3)
+        assert fm.b_index_range() == (0, 2)
+
+    def test_partial_match_ignore_junk_in_the_middle_of_b(self):
+        fm = FuzzyMatchResult('ab', 'a.b', [(0, 0, 1), (1, 2, 1)], isjunk=lambda s, i: s[i] == '.')
+        assert fm.has_match()
+        assert fm.match_count() == 2
+        assert fm.ratio() == 1.0
+        assert fm.a_ratio() == 1.0
+        assert fm.b_ratio() == 1.0
+        assert fm.b_gap_ratio() == 1.0
+        assert fm.a_index_range() == (0, 2)
+        assert fm.b_index_range() == (0, 3)
+
+    def test_should_not_double_count_matching_junk(self):
+        fm = FuzzyMatchResult('a.b', 'a.b', [(0, 0, 3)], isjunk=lambda s, i: s[i] == '.')
+        assert fm.has_match()
+        assert fm.match_count() == 3
+        assert fm.ratio() == 1.0
+        assert fm.a_ratio() == 1.0
+        assert fm.b_ratio() == 1.0
+        assert fm.b_gap_ratio() == 1.0
+        assert fm.a_index_range() == (0, 3)
+        assert fm.b_index_range() == (0, 3)
+
+    def test_a_split_no_match(self):
+        fm = FuzzyMatchResult('abc', 'xyz', [])
+        fm_1, fm_2 = fm.a_split_at(2)
+
+        assert not fm_1.has_match()
+        assert fm_1.a == 'ab'
+        assert fm_1.b == 'xyz'
+
+        assert not fm_2.has_match()
+        assert fm_2.a == 'c'
+        assert fm_2.b == 'xyz'
+
+    def test_b_split_no_match(self):
+        fm = FuzzyMatchResult('abc', 'xyz', [])
+        fm_1, fm_2 = fm.b_split_at(2)
+
+        assert not fm_1.has_match()
+        assert fm_1.a == 'abc'
+        assert fm_1.b == 'xy'
+
+        assert not fm_2.has_match()
+        assert fm_2.a == 'abc'
+        assert fm_2.b == 'z'
+
+    def test_a_split_exact_match(self):
+        fm = FuzzyMatchResult('abc', 'abc', [(0, 0, 3)])
+        fm_1, fm_2 = fm.a_split_at(2)
+
+        assert fm_1.a == 'ab'
+        assert fm_1.b == 'abc'
+        assert fm_1.has_match()
+        assert fm_1.ratio() == 1.0
+        assert fm_1.a_ratio() == 1.0
+        assert fm_1.b_ratio() == 2 / 3
+        assert fm_1.b_gap_ratio() == 2 / 3
+        assert fm_1.a_index_range() == (0, 2)
+        assert fm_1.b_index_range() == (0, 2)
+
+        assert fm_2.a == 'c'
+        assert fm_2.b == 'abc'
+        assert fm_2.has_match()
+        assert fm_2.ratio() == 1.0
+        assert fm_2.a_ratio() == 1.0
+        assert fm_2.b_ratio() == 1 / 3
+        assert fm_2.b_gap_ratio() == 1 / 3
+        assert fm_2.a_index_range() == (0, 1)
+        assert fm_2.b_index_range() == (0, 1)
diff --git a/sciencebeam_gym/preprocess/annotation/matching_annotator.py b/sciencebeam_gym/preprocess/annotation/matching_annotator.py
index 564d45c..7b5b59b 100644
--- a/sciencebeam_gym/preprocess/annotation/matching_annotator.py
+++ b/sciencebeam_gym/preprocess/annotation/matching_annotator.py
@@ -8,35 +8,35 @@ from itertools import tee, islice
 from six.moves import zip_longest
 
 from sciencebeam_utils.utils.compat import (
-  python_2_unicode_compatible
+    python_2_unicode_compatible
 )
 
 from sciencebeam_utils.utils.csv import (
-  csv_delimiter_by_filename,
-  write_csv_row
+    csv_delimiter_by_filename,
+    write_csv_row
 )
 
 from sciencebeam_utils.utils.string import (
-  LazyStr
+    LazyStr
 )
 
 from sciencebeam_utils.utils.collection import (
-  iter_flatten,
-  extract_from_dict
+    iter_flatten,
+    extract_from_dict
 )
 
 from sciencebeam_gym.structured_document import (
-  B_TAG_PREFIX,
-  I_TAG_PREFIX
+    B_TAG_PREFIX,
+    I_TAG_PREFIX
 )
 
 from sciencebeam_gym.preprocess.annotation.fuzzy_match import (
-  remove_junk,
-  fuzzy_match
+    remove_junk,
+    fuzzy_match
 )
 
 from sciencebeam_gym.preprocess.annotation.annotator import (
-  AbstractAnnotator
+    AbstractAnnotator
 )
 
 THIN_SPACE = u'\u2009'
@@ -46,623 +46,653 @@ EM_DASH = u'\u2014'
 DEFAULT_SCORE_THRESHOLD = 0.9
 DEFAULT_MAX_MATCH_GAP = 5
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def normalise_str(s):
-  return s.lower().replace(EM_DASH, u'-').replace(EN_DASH, u'-').replace(THIN_SPACE, ' ')
+    return s.lower().replace(EM_DASH, u'-').replace(EN_DASH, u'-').replace(THIN_SPACE, ' ')
+
 
 def normalise_str_or_list(x):
-  if isinstance(x, list):
-    return [normalise_str(s) for s in x]
-  else:
-    return normalise_str(x)
+    if isinstance(x, list):
+        return [normalise_str(s) for s in x]
+    else:
+        return normalise_str(x)
+
 
 def normalise_and_remove_junk_str(s):
-  return remove_junk(normalise_str(s))
+    return remove_junk(normalise_str(s))
+
 
 def normalise_and_remove_junk_str_or_list(x):
-  if isinstance(x, list):
-    return [normalise_and_remove_junk_str(s) for s in x]
-  else:
-    return normalise_and_remove_junk_str(x)
+    if isinstance(x, list):
+        return [normalise_and_remove_junk_str(s) for s in x]
+    else:
+        return normalise_and_remove_junk_str(x)
+
 
 class SequenceWrapper(object):
-  def __init__(self, structured_document, tokens, str_filter_f=None):
-    self.structured_document = structured_document
-    self.str_filter_f = str_filter_f
-    self.tokens = tokens
-    self.token_str_list = [structured_document.get_text(t) or '' for t in tokens]
-    if str_filter_f:
-      self.token_str_list = [str_filter_f(s) for s in self.token_str_list]
-    self.tokens_as_str = ' '.join(self.token_str_list)
-
-  def tokens_between(self, index_range):
-    start, end = index_range
-    i = 0
-    for token, token_str in zip(self.tokens, self.token_str_list):
-      if i >= end:
-        break
-      token_end = i + len(token_str)
-      if token_end > start:
-        yield token
-      i = token_end + 1
-
-  def sub_sequence_for_tokens(self, tokens):
-    return SequenceWrapper(self.structured_document, tokens, str_filter_f=self.str_filter_f)
-
-  def untagged_sub_sequences(self):
-    token_tags = [self.structured_document.get_tag(t) for t in self.tokens]
-    tagged_count = len([t for t in token_tags if t])
-    if tagged_count == 0:
-      yield self
-    elif tagged_count == len(self.tokens):
-      pass
-    else:
-      untagged_tokens = []
-      for token, tag in zip(self.tokens, token_tags):
-        if not tag:
-          untagged_tokens.append(token)
-        elif untagged_tokens:
-          yield self.sub_sequence_for_tokens(untagged_tokens)
-          untagged_tokens = []
-      if untagged_tokens:
-        yield self.sub_sequence_for_tokens(untagged_tokens)
-
-  def __str__(self):
-    return self.tokens_as_str
-
-  def __repr__(self):
-    return '{}({})'.format('SequenceWrapper', self.tokens_as_str)
+    def __init__(self, structured_document, tokens, str_filter_f=None):
+        self.structured_document = structured_document
+        self.str_filter_f = str_filter_f
+        self.tokens = tokens
+        self.token_str_list = [structured_document.get_text(t) or '' for t in tokens]
+        if str_filter_f:
+            self.token_str_list = [str_filter_f(s) for s in self.token_str_list]
+        self.tokens_as_str = ' '.join(self.token_str_list)
+
+    def tokens_between(self, index_range):
+        start, end = index_range
+        i = 0
+        for token, token_str in zip(self.tokens, self.token_str_list):
+            if i >= end:
+                break
+            token_end = i + len(token_str)
+            if token_end > start:
+                yield token
+            i = token_end + 1
+
+    def sub_sequence_for_tokens(self, tokens):
+        return SequenceWrapper(self.structured_document, tokens, str_filter_f=self.str_filter_f)
+
+    def untagged_sub_sequences(self):
+        token_tags = [self.structured_document.get_tag(t) for t in self.tokens]
+        tagged_count = len([t for t in token_tags if t])
+        if tagged_count == 0:
+            yield self
+        elif tagged_count == len(self.tokens):
+            pass
+        else:
+            untagged_tokens = []
+            for token, tag in zip(self.tokens, token_tags):
+                if not tag:
+                    untagged_tokens.append(token)
+                elif untagged_tokens:
+                    yield self.sub_sequence_for_tokens(untagged_tokens)
+                    untagged_tokens = []
+            if untagged_tokens:
+                yield self.sub_sequence_for_tokens(untagged_tokens)
+
+    def __str__(self):
+        return self.tokens_as_str
+
+    def __repr__(self):
+        return '{}({})'.format('SequenceWrapper', self.tokens_as_str)
+
 
 class SequenceWrapperWithPosition(SequenceWrapper):
-  def __init__(self, *args, **kwargs):
-    position, kwargs = extract_from_dict(kwargs, 'position')
-    super(SequenceWrapperWithPosition, self).__init__(*args, **kwargs)
-    self.position = position
-
-  def sub_sequence_for_tokens(self, tokens):
-    return SequenceWrapperWithPosition(
-      self.structured_document, tokens,
-      str_filter_f=self.str_filter_f,
-      position=self.position
-    )
+    def __init__(self, *args, **kwargs):
+        position, kwargs = extract_from_dict(kwargs, 'position')
+        super(SequenceWrapperWithPosition, self).__init__(*args, **kwargs)
+        self.position = position
+
+    def sub_sequence_for_tokens(self, tokens):
+        return SequenceWrapperWithPosition(
+            self.structured_document, tokens,
+            str_filter_f=self.str_filter_f,
+            position=self.position
+        )
+
+    def __repr__(self):
+        return '{}({}, {})'.format('SequenceWrapperWithPosition', self.tokens_as_str, self.position)
 
-  def __repr__(self):
-    return '{}({}, {})'.format('SequenceWrapperWithPosition', self.tokens_as_str, self.position)
 
 @python_2_unicode_compatible
 class SequenceMatch(object):
-  def __init__(self, seq1, seq2, index1_range, index2_range):
-    self.seq1 = seq1
-    self.seq2 = seq2
-    self.index1_range = index1_range
-    self.index2_range = index2_range
-
-  def __repr__(self):
-    return u"SequenceMatch('{}'[{}:{}], '{}'[{}:{}])".format(
-      self.seq1,
-      self.index1_range[0],
-      self.index1_range[1],
-      self.seq2,
-      self.index2_range[0],
-      self.index2_range[1]
-    )
+    def __init__(self, seq1, seq2, index1_range, index2_range):
+        self.seq1 = seq1
+        self.seq2 = seq2
+        self.index1_range = index1_range
+        self.index2_range = index2_range
+
+    def __repr__(self):
+        return u"SequenceMatch('{}'[{}:{}], '{}'[{}:{}])".format(
+            self.seq1,
+            self.index1_range[0],
+            self.index1_range[1],
+            self.seq2,
+            self.index2_range[0],
+            self.index2_range[1]
+        )
+
 
 @python_2_unicode_compatible
 class PositionedSequenceSet(object):
-  def __init__(self):
-    self.data = set()
+    def __init__(self):
+        self.data = set()
+
+    def add(self, sequence):
+        self.data.add(sequence.position)
 
-  def add(self, sequence):
-    self.data.add(sequence.position)
+    def is_close_to_any(self, sequence, max_gap):
+        if not max_gap or not self.data:
+            return True
+        position = sequence.position
+        max_distance = max_gap + 1
+        for other_position in self.data:
+            if abs(position - other_position) <= max_distance:
+                return True
+        return False
 
-  def is_close_to_any(self, sequence, max_gap):
-    if not max_gap or not self.data:
-      return True
-    position = sequence.position
-    max_distance = max_gap + 1
-    for other_position in self.data:
-      if abs(position - other_position) <= max_distance:
-        return True
-    return False
+    def __str__(self):
+        return str(self.data)
 
-  def __str__(self):
-    return str(self.data)
 
 def offset_range_by(index_range, offset):
-  if not offset:
-    return index_range
-  return (offset + index_range[0], offset + index_range[1])
+    if not offset:
+        return index_range
+    return (offset + index_range[0], offset + index_range[1])
+
 
 def skip_whitespaces(s, start):
-  while start < len(s) and s[start].isspace():
-    start += 1
-  return start
+    while start < len(s) and s[start].isspace():
+        start += 1
+    return start
+
 
 def get_fuzzy_match_filter(
-  b_score_threshold, min_match_count, total_match_threshold,
-  ratio_min_match_count, ratio_threshold):
-  def check(fm, fm_next=None, previous_match=False):
-    if (
-      fm.match_count() >= ratio_min_match_count and
-      fm.ratio() >= ratio_threshold):
-      return True
-    return (
-      fm.b_gap_ratio() >= b_score_threshold and
-      (
-        previous_match or
-        (
-          fm.match_count() >= min_match_count and
-          (fm_next is None or fm_next.ratio() >= ratio_threshold)
-        ) or
-        fm.a_ratio() >= total_match_threshold
-      )
-    )
-  return check
+        b_score_threshold, min_match_count, total_match_threshold,
+        ratio_min_match_count, ratio_threshold):
+    def check(fm, fm_next=None, previous_match=False):
+        if (
+                fm.match_count() >= ratio_min_match_count and
+                fm.ratio() >= ratio_threshold):
+            return True
+        return (
+            fm.b_gap_ratio() >= b_score_threshold and
+            (
+                previous_match or
+                (
+                    fm.match_count() >= min_match_count and
+                    (fm_next is None or fm_next.ratio() >= ratio_threshold)
+                ) or
+                fm.a_ratio() >= total_match_threshold
+            )
+        )
+    return check
+
 
 DEFAULT_SEQ_FUZZY_MATCH_FILTER = get_fuzzy_match_filter(
-  DEFAULT_SCORE_THRESHOLD,
-  5,
-  0.9,
-  50,
-  0.9
+    DEFAULT_SCORE_THRESHOLD,
+    5,
+    0.9,
+    50,
+    0.9
 )
 
 DEFAULT_CHOICE_FUZZY_MATCH_FILTER = get_fuzzy_match_filter(
-  DEFAULT_SCORE_THRESHOLD,
-  1,
-  0.9,
-  100,
-  0.9
+    DEFAULT_SCORE_THRESHOLD,
+    1,
+    0.9,
+    100,
+    0.9
 )
 
+
 class MatchDebugFields(object):
-  ID = 'id'
-  TAG = 'tag'
-  MATCH_MULTIPLE = 'match_multiple'
-  TAG_VALUE_PRE = 'tag_value_pre'
-  TAG_VALUE_CURRENT = 'tag_value_current'
-  START_INDEX = 'start_index'
-  NEXT_START_INDEX = 'next_start_index'
-  REACHED_END = 'reached_end'
-  CHOICE_COMBINED = 'choice_combined'
-  CHOICE_CURRENT = 'choice_current'
-  CHOICE_NEXT = 'choice_next'
-  ACCEPTED = 'accepted'
-  TAG_TO_CHOICE_MATCH = 'tag_to_choice_match'
-  SUB_ANNOTATION = 'sub_annotation'
-  FM_COMBINED = 'fm_combined'
-  FM_COMBINED_DETAILED = 'fm_combined_detailed'
-  FM_CURRENT = 'fm_current'
-  FM_CURRENT_DETAILED = 'fm_current_detailed'
-  FM_NEXT = 'fm_next'
-  FM_NEXT_DETAILED = 'fm_next_detailed'
+    ID = 'id'
+    TAG = 'tag'
+    MATCH_MULTIPLE = 'match_multiple'
+    TAG_VALUE_PRE = 'tag_value_pre'
+    TAG_VALUE_CURRENT = 'tag_value_current'
+    START_INDEX = 'start_index'
+    NEXT_START_INDEX = 'next_start_index'
+    REACHED_END = 'reached_end'
+    CHOICE_COMBINED = 'choice_combined'
+    CHOICE_CURRENT = 'choice_current'
+    CHOICE_NEXT = 'choice_next'
+    ACCEPTED = 'accepted'
+    TAG_TO_CHOICE_MATCH = 'tag_to_choice_match'
+    SUB_ANNOTATION = 'sub_annotation'
+    FM_COMBINED = 'fm_combined'
+    FM_COMBINED_DETAILED = 'fm_combined_detailed'
+    FM_CURRENT = 'fm_current'
+    FM_CURRENT_DETAILED = 'fm_current_detailed'
+    FM_NEXT = 'fm_next'
+    FM_NEXT_DETAILED = 'fm_next_detailed'
+
 
 DEFAULT_MATCH_DEBUG_COLUMNS = [
-  MatchDebugFields.ID,
-  MatchDebugFields.TAG,
-  MatchDebugFields.MATCH_MULTIPLE,
-  MatchDebugFields.TAG_VALUE_PRE,
-  MatchDebugFields.TAG_VALUE_CURRENT,
-  MatchDebugFields.START_INDEX,
-  MatchDebugFields.NEXT_START_INDEX,
-  MatchDebugFields.REACHED_END,
-  MatchDebugFields.CHOICE_COMBINED,
-  MatchDebugFields.CHOICE_CURRENT,
-  MatchDebugFields.CHOICE_NEXT,
-  MatchDebugFields.ACCEPTED,
-  MatchDebugFields.TAG_TO_CHOICE_MATCH,
-  MatchDebugFields.SUB_ANNOTATION,
-  MatchDebugFields.FM_COMBINED,
-  MatchDebugFields.FM_COMBINED_DETAILED,
-  MatchDebugFields.FM_CURRENT,
-  MatchDebugFields.FM_CURRENT_DETAILED,
-  MatchDebugFields.FM_NEXT,
-  MatchDebugFields.FM_NEXT_DETAILED
+    MatchDebugFields.ID,
+    MatchDebugFields.TAG,
+    MatchDebugFields.MATCH_MULTIPLE,
+    MatchDebugFields.TAG_VALUE_PRE,
+    MatchDebugFields.TAG_VALUE_CURRENT,
+    MatchDebugFields.START_INDEX,
+    MatchDebugFields.NEXT_START_INDEX,
+    MatchDebugFields.REACHED_END,
+    MatchDebugFields.CHOICE_COMBINED,
+    MatchDebugFields.CHOICE_CURRENT,
+    MatchDebugFields.CHOICE_NEXT,
+    MatchDebugFields.ACCEPTED,
+    MatchDebugFields.TAG_TO_CHOICE_MATCH,
+    MatchDebugFields.SUB_ANNOTATION,
+    MatchDebugFields.FM_COMBINED,
+    MatchDebugFields.FM_COMBINED_DETAILED,
+    MatchDebugFields.FM_CURRENT,
+    MatchDebugFields.FM_CURRENT_DETAILED,
+    MatchDebugFields.FM_NEXT,
+    MatchDebugFields.FM_NEXT_DETAILED
 ]
 
+
 class TargetAnnotationMatchFinder(object):
-  def __init__(
-    self,
-    target_annotation,
-    sequence,
-    choices,
-    seq_match_filter=DEFAULT_SEQ_FUZZY_MATCH_FILTER,
-    choice_match_filter=DEFAULT_CHOICE_FUZZY_MATCH_FILTER,
-    max_gap=DEFAULT_MAX_MATCH_GAP,
-    matched_choices=None,
-    match_detail_reporter=None,
-    is_sub_match=False
-  ):
-    if matched_choices is None:
-      matched_choices = PositionedSequenceSet()
-    self.target_annotation = target_annotation
-    self.sequence = sequence
-    self.choices = choices
-    self.seq_match_filter = seq_match_filter
-    self.choice_match_filter = choice_match_filter
-    self.max_gap = max_gap
-    self.matched_choices = matched_choices
-    self.match_detail_reporter = match_detail_reporter
-    self.is_sub_match = is_sub_match
-    self.current_choices, self.next_choices = tee(choices, 2)
-    self.next_choices = islice(self.next_choices, 1, None)
-
-  def find_next_best_matches(self):
-    if isinstance(self.sequence, list):
-      all_matches = []
-      get_logger().debug('found sequence list: %s', self.sequence)
-      # Use tee as choices may be an iterable instead of a list
-      for s, sub_current_choices, sub_next_choices in zip(
-        self.sequence,
-        tee(self.current_choices, len(self.sequence)),
-        tee(self.next_choices, len(self.sequence))
-      ):
-        all_matches.extend(self._do_find_next_best_matches(
-          s,
-          sub_current_choices,
-          sub_next_choices
-        ))
-      get_logger().debug(
-        'all_matches (bonding=%s): %s', self.target_annotation.bonding, all_matches
-      )
-      if not self.target_annotation.bonding or len(all_matches) > 1 or self.target_annotation.name == 'author':
-        for m in all_matches:
-          yield m
-    else:
-      matches = self._do_find_next_best_matches(
-        self.sequence, self.current_choices, self.next_choices
-      )
-      for m in matches:
-        yield m
-
-  def _do_find_next_best_matches(self, sequence, current_choices, next_choices):
-    target_annotation = self.target_annotation
-    seq_match_filter = self.seq_match_filter
-    choice_match_filter = self.choice_match_filter
-    max_gap = self.max_gap
-    matched_choices = self.matched_choices
-
-    start_index = 0
-    s1 = text(sequence)
-    too_distant_choices = []
-
-    is_last_match = False
-    previous_match = False
-
-    for choice, next_choice in zip_longest(current_choices, next_choices):
-      if not matched_choices.is_close_to_any(choice, max_gap=max_gap):
-        too_distant_choices.append(choice)
-        continue
-      current_choice_str = text(choice)
-      if not current_choice_str:
-        return
-      if next_choice:
-        next_choice_str = text(next_choice)
-        choice_str = current_choice_str + ' ' + next_choice_str
-      else:
-        choice_str = current_choice_str
-        next_choice_str = None
-      current_start_index = start_index
-      get_logger().debug(
-        'processing choice: tag=%s, s1[:%d]=%s, s1[%d:]=%s, current=%s, next=%s (%s), combined=%s',
-        target_annotation.name,
-        start_index, s1[:start_index],
-        start_index, s1[start_index:],
-        current_choice_str,
-        next_choice_str, type(next_choice_str), choice_str
-      )
-      fm_combined, fm, fm_next = None, None, None
-      reached_end = None
-      tag_to_choice_match = self.is_sub_match or (len(s1) - start_index < len(current_choice_str))
-      if not tag_to_choice_match:
-        fm_combined = fuzzy_match(s1, choice_str)
-        fm, fm_next = fm_combined.b_split_at(len(current_choice_str))
-        get_logger().debug(
-          'regular match: s1=%s, choice=%s, fm=%s (combined: %s)',
-          s1, choice, fm, fm_combined
-        )
-        get_logger().debug('detailed match: %s', fm_combined.detailed())
-        accept_match = fm.has_match() and (
-          seq_match_filter(fm, fm_next, previous_match=previous_match) or
-          (seq_match_filter(fm_combined) and fm.b_start_index() < len(current_choice_str))
-        )
-        if accept_match:
-          previous_match = True
-          sm = SequenceMatch(
-            sequence,
-            choice,
-            fm.a_index_range(),
-            fm.b_index_range()
-          )
-          matched_choices.add(choice)
-          get_logger().debug('found match: %s', sm)
-          yield sm
-          if fm_next.has_match():
-            sm = SequenceMatch(
-              sequence,
-              next_choice,
-              fm_next.a_index_range(),
-              fm_next.b_index_range()
+    def __init__(
+        self,
+        target_annotation,
+        sequence,
+        choices,
+        seq_match_filter=DEFAULT_SEQ_FUZZY_MATCH_FILTER,
+        choice_match_filter=DEFAULT_CHOICE_FUZZY_MATCH_FILTER,
+        max_gap=DEFAULT_MAX_MATCH_GAP,
+        matched_choices=None,
+        match_detail_reporter=None,
+        is_sub_match=False
+    ):
+        if matched_choices is None:
+            matched_choices = PositionedSequenceSet()
+        self.target_annotation = target_annotation
+        self.sequence = sequence
+        self.choices = choices
+        self.seq_match_filter = seq_match_filter
+        self.choice_match_filter = choice_match_filter
+        self.max_gap = max_gap
+        self.matched_choices = matched_choices
+        self.match_detail_reporter = match_detail_reporter
+        self.is_sub_match = is_sub_match
+        self.current_choices, self.next_choices = tee(choices, 2)
+        self.next_choices = islice(self.next_choices, 1, None)
+
+    def find_next_best_matches(self):
+        if isinstance(self.sequence, list):
+            all_matches = []
+            get_logger().debug('found sequence list: %s', self.sequence)
+            # Use tee as choices may be an iterable instead of a list
+            for s, sub_current_choices, sub_next_choices in zip(
+                self.sequence,
+                tee(self.current_choices, len(self.sequence)),
+                tee(self.next_choices, len(self.sequence))
+            ):
+                all_matches.extend(self._do_find_next_best_matches(
+                    s,
+                    sub_current_choices,
+                    sub_next_choices
+                ))
+            get_logger().debug(
+                'all_matches (bonding=%s): %s', self.target_annotation.bonding, all_matches
             )
-            matched_choices.add(choice)
-            get_logger().debug('found next match: %s', sm)
-            yield sm
-            index1_end = skip_whitespaces(s1, fm_next.a_end_index())
-          else:
-            index1_end = skip_whitespaces(s1, fm.a_end_index())
-          reached_end = index1_end >= len(s1)
-          if reached_end:
-            get_logger().debug('end reached: %d >= %d', index1_end, len(s1))
-            is_last_match = True
-          else:
-            start_index = index1_end
-            get_logger().debug('setting start index to: %d', start_index)
-      else:
-        s1_sub = s1[start_index:]
-        fm_combined = fuzzy_match(choice_str, s1_sub)
-        fm, fm_next = fm_combined.a_split_at(len(current_choice_str))
-        get_logger().debug(
-          'short match: s1_sub=%s, choice=%s, fm=%s (combined: %s)',
-          s1_sub, choice, fm, fm_combined
-        )
-        get_logger().debug('detailed match: %s', fm_combined.detailed())
-        accept_match = fm.has_match() and (
-          choice_match_filter(fm, previous_match=previous_match) or
-          (
-            choice_match_filter(fm_combined) and
-            fm_combined.a_start_index() < len(current_choice_str)
-          )
-        )
-        if accept_match:
-          sm = SequenceMatch(
-            sequence,
-            choice,
-            offset_range_by(fm.b_index_range(), start_index),
-            fm.a_index_range()
-          )
-          matched_choices.add(choice)
-          get_logger().debug('found match: %s', sm)
-          yield sm
-          if fm_next.has_match():
-            sm = SequenceMatch(
-              sequence,
-              next_choice,
-              offset_range_by(fm_next.b_index_range(), start_index),
-              fm_next.a_index_range()
+            if (
+                not self.target_annotation.bonding or
+                len(all_matches) > 1 or
+                self.target_annotation.name == 'author'
+            ):
+                for m in all_matches:
+                    yield m
+        else:
+            matches = self._do_find_next_best_matches(
+                self.sequence, self.current_choices, self.next_choices
+            )
+            for m in matches:
+                yield m
+
+    def _do_find_next_best_matches(self, sequence, current_choices, next_choices):
+        target_annotation = self.target_annotation
+        seq_match_filter = self.seq_match_filter
+        choice_match_filter = self.choice_match_filter
+        max_gap = self.max_gap
+        matched_choices = self.matched_choices
+
+        start_index = 0
+        s1 = text(sequence)
+        too_distant_choices = []
+
+        is_last_match = False
+        previous_match = False
+
+        for choice, next_choice in zip_longest(current_choices, next_choices):
+            if not matched_choices.is_close_to_any(choice, max_gap=max_gap):
+                too_distant_choices.append(choice)
+                continue
+            current_choice_str = text(choice)
+            if not current_choice_str:
+                return
+            if next_choice:
+                next_choice_str = text(next_choice)
+                choice_str = current_choice_str + ' ' + next_choice_str
+            else:
+                choice_str = current_choice_str
+                next_choice_str = None
+            current_start_index = start_index
+            get_logger().debug(
+                'processing choice: tag=%s, s1[:%d]=%s, s1[%d:]=%s'
+                ', current=%s, next=%s (%s), combined=%s',
+                target_annotation.name,
+                start_index, s1[:start_index],
+                start_index, s1[start_index:],
+                current_choice_str,
+                next_choice_str, type(next_choice_str), choice_str
             )
-            get_logger().debug('found next match: %s', sm)
-            matched_choices.add(next_choice)
-            yield sm
-          is_last_match = True
-      if self.match_detail_reporter:
-        self.match_detail_reporter({
-          MatchDebugFields.TAG: target_annotation.name,
-          MatchDebugFields.MATCH_MULTIPLE: target_annotation.match_multiple,
-          MatchDebugFields.TAG_VALUE_PRE: s1[:current_start_index],
-          MatchDebugFields.TAG_VALUE_CURRENT: s1[current_start_index:],
-          MatchDebugFields.START_INDEX: current_start_index,
-          MatchDebugFields.NEXT_START_INDEX: start_index,
-          MatchDebugFields.REACHED_END: reached_end,
-          MatchDebugFields.CHOICE_COMBINED: choice_str,
-          MatchDebugFields.CHOICE_CURRENT: current_choice_str,
-          MatchDebugFields.CHOICE_NEXT: next_choice_str,
-          MatchDebugFields.ACCEPTED: accept_match,
-          MatchDebugFields.TAG_TO_CHOICE_MATCH: tag_to_choice_match,
-          MatchDebugFields.SUB_ANNOTATION: self.is_sub_match,
-          MatchDebugFields.FM_COMBINED: fm_combined,
-          MatchDebugFields.FM_COMBINED_DETAILED: fm_combined and fm_combined.detailed_str(),
-          MatchDebugFields.FM_CURRENT: fm,
-          MatchDebugFields.FM_CURRENT_DETAILED: fm and fm.detailed_str(),
-          MatchDebugFields.FM_NEXT: fm_next,
-          MatchDebugFields.FM_NEXT_DETAILED: fm_next.detailed_str()
-        })
-      if is_last_match:
-        break
-    if too_distant_choices:
-      get_logger().debug(
-        'ignored too distant choices: matched=%s (ignored=%s)',
-        matched_choices,
-        LazyStr(lambda: ' '.join(str(choice.position) for choice in too_distant_choices))
-      )
+            fm_combined, fm, fm_next = None, None, None
+            reached_end = None
+            tag_to_choice_match = self.is_sub_match or (
+                len(s1) - start_index < len(current_choice_str))
+            if not tag_to_choice_match:
+                fm_combined = fuzzy_match(s1, choice_str)
+                fm, fm_next = fm_combined.b_split_at(len(current_choice_str))
+                get_logger().debug(
+                    'regular match: s1=%s, choice=%s, fm=%s (combined: %s)',
+                    s1, choice, fm, fm_combined
+                )
+                get_logger().debug('detailed match: %s', fm_combined.detailed())
+                accept_match = fm.has_match() and (
+                    seq_match_filter(fm, fm_next, previous_match=previous_match) or
+                    (seq_match_filter(fm_combined) and fm.b_start_index() < len(current_choice_str))
+                )
+                if accept_match:
+                    previous_match = True
+                    sm = SequenceMatch(
+                        sequence,
+                        choice,
+                        fm.a_index_range(),
+                        fm.b_index_range()
+                    )
+                    matched_choices.add(choice)
+                    get_logger().debug('found match: %s', sm)
+                    yield sm
+                    if fm_next.has_match():
+                        sm = SequenceMatch(
+                            sequence,
+                            next_choice,
+                            fm_next.a_index_range(),
+                            fm_next.b_index_range()
+                        )
+                        matched_choices.add(next_choice)
+                        get_logger().debug('found next match: %s', sm)
+                        yield sm
+                        index1_end = skip_whitespaces(s1, fm_next.a_end_index())
+                    else:
+                        index1_end = skip_whitespaces(s1, fm.a_end_index())
+                    reached_end = index1_end >= len(s1)
+                    if reached_end:
+                        get_logger().debug('end reached: %d >= %d', index1_end, len(s1))
+                        is_last_match = True
+                    else:
+                        start_index = index1_end
+                        get_logger().debug('setting start index to: %d', start_index)
+            else:
+                s1_sub = s1[start_index:]
+                fm_combined = fuzzy_match(choice_str, s1_sub)
+                fm, fm_next = fm_combined.a_split_at(len(current_choice_str))
+                get_logger().debug(
+                    'short match: s1_sub=%s, choice=%s, fm=%s (combined: %s)',
+                    s1_sub, choice, fm, fm_combined
+                )
+                get_logger().debug('detailed match: %s', fm_combined.detailed())
+                accept_match = fm.has_match() and (
+                    choice_match_filter(fm, previous_match=previous_match) or
+                    (
+                        choice_match_filter(fm_combined) and
+                        fm_combined.a_start_index() < len(current_choice_str)
+                    )
+                )
+                if accept_match:
+                    sm = SequenceMatch(
+                        sequence,
+                        choice,
+                        offset_range_by(fm.b_index_range(), start_index),
+                        fm.a_index_range()
+                    )
+                    matched_choices.add(choice)
+                    get_logger().debug('found match: %s', sm)
+                    yield sm
+                    if fm_next.has_match():
+                        sm = SequenceMatch(
+                            sequence,
+                            next_choice,
+                            offset_range_by(fm_next.b_index_range(), start_index),
+                            fm_next.a_index_range()
+                        )
+                        get_logger().debug('found next match: %s', sm)
+                        matched_choices.add(next_choice)
+                        yield sm
+                    is_last_match = True
+            if self.match_detail_reporter:
+                self.match_detail_reporter({
+                    MatchDebugFields.TAG: target_annotation.name,
+                    MatchDebugFields.MATCH_MULTIPLE: target_annotation.match_multiple,
+                    MatchDebugFields.TAG_VALUE_PRE: s1[:current_start_index],
+                    MatchDebugFields.TAG_VALUE_CURRENT: s1[current_start_index:],
+                    MatchDebugFields.START_INDEX: current_start_index,
+                    MatchDebugFields.NEXT_START_INDEX: start_index,
+                    MatchDebugFields.REACHED_END: reached_end,
+                    MatchDebugFields.CHOICE_COMBINED: choice_str,
+                    MatchDebugFields.CHOICE_CURRENT: current_choice_str,
+                    MatchDebugFields.CHOICE_NEXT: next_choice_str,
+                    MatchDebugFields.ACCEPTED: accept_match,
+                    MatchDebugFields.TAG_TO_CHOICE_MATCH: tag_to_choice_match,
+                    MatchDebugFields.SUB_ANNOTATION: self.is_sub_match,
+                    MatchDebugFields.FM_COMBINED: fm_combined,
+                    MatchDebugFields.FM_COMBINED_DETAILED: (
+                        fm_combined and fm_combined.detailed_str()
+                    ),
+                    MatchDebugFields.FM_CURRENT: fm,
+                    MatchDebugFields.FM_CURRENT_DETAILED: fm and fm.detailed_str(),
+                    MatchDebugFields.FM_NEXT: fm_next,
+                    MatchDebugFields.FM_NEXT_DETAILED: fm_next.detailed_str()
+                })
+            if is_last_match:
+                break
+        if too_distant_choices:
+            get_logger().debug(
+                'ignored too distant choices: matched=%s (ignored=%s)',
+                matched_choices,
+                LazyStr(lambda: ' '.join(str(choice.position) for choice in too_distant_choices))
+            )
+
 
 class CsvMatchDetailReporter(object):
-  def __init__(self, fp, filename=None, fields=None):
-    self.fp = fp
-    self.fields = fields or DEFAULT_MATCH_DEBUG_COLUMNS
-    self.writer = csv.writer(
-      fp,
-      delimiter=csv_delimiter_by_filename(filename)
-    )
-    self.writer.writerow(self.fields)
-    self.id = 1
+    def __init__(self, fp, filename=None, fields=None):
+        self.fp = fp
+        self.fields = fields or DEFAULT_MATCH_DEBUG_COLUMNS
+        self.writer = csv.writer(
+            fp,
+            delimiter=csv_delimiter_by_filename(filename)
+        )
+        self.writer.writerow(self.fields)
+        self.id = 1
 
-  def __call__(self, row):
-    get_logger().debug('logging debug match id %d', self.id)
-    write_csv_row(self.writer, [
-      self.id if k == MatchDebugFields.ID else row.get(k)
-      for k in self.fields
-    ])
-    self.id += 1
+    def __call__(self, row):
+        get_logger().debug('logging debug match id %d', self.id)
+        write_csv_row(self.writer, [
+            self.id if k == MatchDebugFields.ID else row.get(k)
+            for k in self.fields
+        ])
+        self.id += 1
+
+    def close(self):
+        self.fp.close()
 
-  def close(self):
-    self.fp.close()
 
 def sorted_matches_by_position(matches):
-  return sorted(
-    matches,
-    key=lambda m: (m.seq2.position, m.index2_range)
-  )
+    return sorted(
+        matches,
+        key=lambda m: (m.seq2.position, m.index2_range)
+    )
+
 
 def matches_position_range(matches):
-  positions = [m.seq2.position for m in matches]
-  return min(positions), max(positions)
+    positions = [m.seq2.position for m in matches]
+    return min(positions), max(positions)
+
 
 def distance_between_matches(matches1, matches2):
-  matches1_start, matches1_end = matches_position_range(matches1)
-  matches2_start, matches2_end = matches_position_range(matches2)
-  return min(
-    abs(matches2_start - matches1_end),
-    abs(matches1_start - matches2_end)
-  )
+    matches1_start, matches1_end = matches_position_range(matches1)
+    matches2_start, matches2_end = matches_position_range(matches2)
+    return min(
+        abs(matches2_start - matches1_end),
+        abs(matches1_start - matches2_end)
+    )
+
 
 def _apply_sub_annotations(
-  target_annotation, structured_document, matching_tokens,
-  match_detail_reporter, use_tag_begin_prefix):
-
-  seq = SequenceWrapperWithPosition(
-    structured_document, matching_tokens, normalise_str,
-    position=0
-  )
-  matched_choices = PositionedSequenceSet()
-  for sub_annotation in target_annotation.sub_annotations:
-    sub_tag = sub_annotation.name
-    match_finder = TargetAnnotationMatchFinder(
-      sub_annotation,
-      normalise_str_or_list(sub_annotation.value),
-      [seq],
-      matched_choices=matched_choices,
-      match_detail_reporter=match_detail_reporter,
-      is_sub_match=True
-    )
-    matches = match_finder.find_next_best_matches()
-    matches = list(matches)
-    get_logger().info('sub annotation matches: %s', matches)
-    first_token = True
-    for m in matches:
-      choice = m.seq2
-      matching_sub_tokens = list(choice.tokens_between(m.index2_range))
-      get_logger().debug(
-        'matching_sub_tokens: %s %s',
-        [structured_document.get_text(token) for token in matching_tokens],
-        m.index2_range
-      )
-      for token in matching_sub_tokens:
-        tag_prefix = None
-        if use_tag_begin_prefix:
-          tag_prefix = B_TAG_PREFIX if first_token else I_TAG_PREFIX
-        structured_document.set_sub_tag_with_prefix(token, sub_tag, prefix=tag_prefix)
-        first_token = False
+        target_annotation, structured_document, matching_tokens,
+        match_detail_reporter, use_tag_begin_prefix):
 
-def _apply_annotations_to_matches(
-  target_annotation,
-  structured_document,
-  matches,
-  match_detail_reporter,
-  use_tag_begin_prefix):
-
-  first_token = True
-  all_matching_tokens = []
-  for m in matches:
-    choice = m.seq2
-    matching_tokens = list(choice.tokens_between(m.index2_range))
-    get_logger().debug(
-      'matching_tokens: %s %s',
-      [structured_document.get_text(token) for token in matching_tokens],
-      m.index2_range
+    seq = SequenceWrapperWithPosition(
+        structured_document, matching_tokens, normalise_str,
+        position=0
     )
-    for token in matching_tokens:
-      if not structured_document.get_tag(token):
-        tag_prefix = None
-        if use_tag_begin_prefix:
-          tag_prefix = B_TAG_PREFIX if first_token else I_TAG_PREFIX
-        structured_document.set_tag_with_prefix(
-          token,
-          target_annotation.name,
-          prefix=tag_prefix
+    matched_choices = PositionedSequenceSet()
+    for sub_annotation in target_annotation.sub_annotations:
+        sub_tag = sub_annotation.name
+        match_finder = TargetAnnotationMatchFinder(
+            sub_annotation,
+            normalise_str_or_list(sub_annotation.value),
+            [seq],
+            matched_choices=matched_choices,
+            match_detail_reporter=match_detail_reporter,
+            is_sub_match=True
         )
-        first_token = False
-        all_matching_tokens.append(token)
-    if target_annotation.sub_annotations:
-      _apply_sub_annotations(
-        target_annotation, structured_document, all_matching_tokens,
-        match_detail_reporter, use_tag_begin_prefix
-      )
+        matches = match_finder.find_next_best_matches()
+        matches = list(matches)
+        get_logger().info('sub annotation matches: %s', matches)
+        first_token = True
+        for m in matches:
+            choice = m.seq2
+            matching_sub_tokens = list(choice.tokens_between(m.index2_range))
+            get_logger().debug(
+                'matching_sub_tokens: %s %s',
+                [structured_document.get_text(token) for token in matching_tokens],
+                m.index2_range
+            )
+            for token in matching_sub_tokens:
+                tag_prefix = None
+                if use_tag_begin_prefix:
+                    tag_prefix = B_TAG_PREFIX if first_token else I_TAG_PREFIX
+                structured_document.set_sub_tag_with_prefix(token, sub_tag, prefix=tag_prefix)
+                first_token = False
 
 
def _apply_annotations_to_matches(
        target_annotation,
        structured_document,
        matches,
        match_detail_reporter,
        use_tag_begin_prefix):
    """Tag the tokens covered by matches with target_annotation's name.

    Tokens that already carry a tag are left untouched. Sub annotations,
    if present, are applied to the accumulated matching tokens.
    """
    is_first_tagged_token = True
    all_matching_tokens = []
    for match in matches:
        choice = match.seq2
        matching_tokens = list(choice.tokens_between(match.index2_range))
        get_logger().debug(
            'matching_tokens: %s %s',
            [structured_document.get_text(token) for token in matching_tokens],
            match.index2_range
        )
        for token in matching_tokens:
            if structured_document.get_tag(token):
                # never overwrite an existing tag
                continue
            tag_prefix = None
            if use_tag_begin_prefix:
                tag_prefix = (
                    B_TAG_PREFIX if is_first_tagged_token else I_TAG_PREFIX
                )
            structured_document.set_tag_with_prefix(
                token,
                target_annotation.name,
                prefix=tag_prefix
            )
            is_first_tagged_token = False
            all_matching_tokens.append(token)
        # NOTE(review): this runs once per match, re-applying sub annotations
        # to the tokens accumulated so far — behavior preserved from the
        # original; confirm whether it was meant to run once after the loop
        if target_annotation.sub_annotations:
            _apply_sub_annotations(
                target_annotation, structured_document, all_matching_tokens,
                match_detail_reporter, use_tag_begin_prefix
            )
+
+
class MatchingAnnotator(AbstractAnnotator):
    """Annotates a structured document by fuzzily matching target annotation values."""

    def __init__(
            self, target_annotations, match_detail_reporter=None,
            use_tag_begin_prefix=False):

        self.target_annotations = target_annotations
        self.match_detail_reporter = match_detail_reporter
        self.use_tag_begin_prefix = use_tag_begin_prefix

    def annotate(self, structured_document):
        """Tag untagged tokens of structured_document and return it."""
        # one sequence per line, restricted to tokens that are not yet tagged
        pending_sequences = []
        for page in structured_document.get_pages():
            for line in structured_document.get_lines_of_page(page):
                untagged_tokens = [
                    token
                    for token in structured_document.get_tokens_of_line(line)
                    if not structured_document.get_tag(token)
                ]
                if not untagged_tokens:
                    continue
                get_logger().debug(
                    'tokens without tag: %s',
                    [structured_document.get_text(token) for token in untagged_tokens]
                )
                pending_sequences.append(SequenceWrapperWithPosition(
                    structured_document,
                    untagged_tokens,
                    normalise_and_remove_junk_str,
                    position=len(pending_sequences)
                ))

        # a match for a require_next annotation is held back until the
        # following annotation matches close by (distance <= 1)
        conditional_match = None

        matched_choices_map = {}
        for target_annotation in self.target_annotations:
            get_logger().debug('target annotation: %s', target_annotation)
            target_value = normalise_and_remove_junk_str_or_list(target_annotation.value)
            untagged_pending_sequences = iter_flatten(
                seq.untagged_sub_sequences() for seq in pending_sequences
            )
            if target_annotation.bonding:
                # bonded annotations with the same name share matched choices
                matched_choices = matched_choices_map.setdefault(
                    target_annotation.name,
                    PositionedSequenceSet()
                )
            else:
                matched_choices = PositionedSequenceSet()
            match_finder = TargetAnnotationMatchFinder(
                target_annotation,
                target_value,
                untagged_pending_sequences,
                matched_choices=matched_choices,
                match_detail_reporter=self.match_detail_reporter
            )
            iteration = 0
            # a single pass normally; keep going for match_multiple annotations
            while iteration == 0 or target_annotation.match_multiple:
                get_logger().info('calling find_next_best_matches')
                matches = sorted_matches_by_position(
                    match_finder.find_next_best_matches()
                )
                if not matches:
                    conditional_match = None
                    break
                get_logger().info('matches: %s', matches)
                if (
                    conditional_match and
                    distance_between_matches(matches, conditional_match['matches']) <= 1
                ):
                    # the deferred require_next match is adjacent: apply it now
                    _apply_annotations_to_matches(
                        conditional_match['target_annotation'],
                        structured_document,
                        conditional_match['matches'],
                        self.match_detail_reporter, self.use_tag_begin_prefix
                    )
                if target_annotation.require_next:
                    conditional_match = dict(
                        target_annotation=target_annotation,
                        matches=matches
                    )
                else:
                    _apply_annotations_to_matches(
                        target_annotation, structured_document, matches,
                        self.match_detail_reporter, self.use_tag_begin_prefix
                    )
                iteration += 1
        return structured_document
diff --git a/sciencebeam_gym/preprocess/annotation/matching_annotator_test.py b/sciencebeam_gym/preprocess/annotation/matching_annotator_test.py
index 76fb765..fafac01 100644
--- a/sciencebeam_gym/preprocess/annotation/matching_annotator_test.py
+++ b/sciencebeam_gym/preprocess/annotation/matching_annotator_test.py
@@ -3,26 +3,26 @@ from __future__ import division
 import logging
 
 from sciencebeam_utils.utils.collection import (
-  flatten
+    flatten
 )
 
 from sciencebeam_gym.structured_document import (
-  SimpleStructuredDocument,
-  SimpleLine,
-  SimpleToken
+    SimpleStructuredDocument,
+    SimpleLine,
+    SimpleToken
 )
 
 from sciencebeam_gym.preprocess.annotation.target_annotation import (
-  TargetAnnotation
+    TargetAnnotation
 )
 
 from sciencebeam_gym.preprocess.annotation.matching_annotator import (
-  normalise_str,
-  MatchingAnnotator,
-  SequenceWrapper,
-  THIN_SPACE,
-  EN_DASH,
-  EM_DASH
+    normalise_str,
+    MatchingAnnotator,
+    SequenceWrapper,
+    THIN_SPACE,
+    EN_DASH,
+    EM_DASH
 )
 
 TAG1 = 'tag1'
@@ -39,877 +39,892 @@ SOME_VALUE_2 = 'some value2'
 SOME_LONGER_VALUE = 'some longer value1'
 SOME_SHORTER_VALUE = 'value1'
 
+
 def setup_module():
-  logging.basicConfig(level='DEBUG')
+    logging.basicConfig(level='DEBUG')
+
 
 def _get_tags_of_tokens(tokens):
-  return [t.get_tag() for t in tokens]
+    return [t.get_tag() for t in tokens]
+
 
 def _copy_tokens(tokens):
-  return [SimpleToken(t.text) for t in tokens]
+    return [SimpleToken(t.text) for t in tokens]
+
 
 def _tokens_for_text(text):
-  return [SimpleToken(s) for s in text.split(' ')]
+    return [SimpleToken(s) for s in text.split(' ')]
+
 
 def _tokens_to_text(tokens):
-  return ' '.join([t.text for t in tokens])
+    return ' '.join([t.text for t in tokens])
+
 
 def _tokens_for_text_lines(text_lines):
-  return [_tokens_for_text(line) for line in text_lines]
+    return [_tokens_for_text(line) for line in text_lines]
+
 
 def _lines_for_tokens(tokens_by_line):
-  return [SimpleLine(tokens) for tokens in tokens_by_line]
+    return [SimpleLine(tokens) for tokens in tokens_by_line]
+
 
 def _document_for_tokens(tokens_by_line):
-  return SimpleStructuredDocument(lines=_lines_for_tokens(tokens_by_line))
+    return SimpleStructuredDocument(lines=_lines_for_tokens(tokens_by_line))
+
 
 class TestNormaliseStr(object):
-  def test_should_replace_thin_space_with_regular_space(self):
-    assert normalise_str(THIN_SPACE) == ' '
+    def test_should_replace_thin_space_with_regular_space(self):
+        assert normalise_str(THIN_SPACE) == ' '
 
-  def test_should_replace_em_dash_with_hyphen(self):
-    assert normalise_str(EM_DASH) == '-'
+    def test_should_replace_em_dash_with_hyphen(self):
+        assert normalise_str(EM_DASH) == '-'
+
+    def test_should_replace_en_dash_with_hyphen(self):
+        assert normalise_str(EN_DASH) == '-'
 
-  def test_should_replace_en_dash_with_hyphen(self):
-    assert normalise_str(EN_DASH) == '-'
 
 class TestSequenceWrapper(object):
-  def test_should_find_all_tokens_without_str_filter(self):
-    text = 'this is matching'
-    tokens = _tokens_for_text(text)
-    doc = _document_for_tokens([tokens])
-    seq = SequenceWrapper(doc, tokens)
-    assert str(seq) == text
-    assert list(seq.tokens_between((0, len(text)))) == tokens
-
-  def test_should_find_tokens_between_without_str_filter(self):
-    text = 'this is matching'
-    tokens = _tokens_for_text(text)
-    doc = _document_for_tokens([tokens])
-    seq = SequenceWrapper(doc, tokens)
-    assert str(seq) == text
-    assert list(seq.tokens_between((6, 7))) == [tokens[1]]
-
-  def test_should_find_tokens_between_adjusted_indices_due_to_str_filter(self):
-    text = 'this is matching'
-    tokens = _tokens_for_text(text)
-    doc = _document_for_tokens([tokens])
-    seq = SequenceWrapper(doc, tokens, str_filter_f=lambda s: s.replace('th', ''))
-    assert str(seq) == 'is is matching'
-    assert list(seq.tokens_between((4, 5))) == [tokens[1]]
+    def test_should_find_all_tokens_without_str_filter(self):
+        text = 'this is matching'
+        tokens = _tokens_for_text(text)
+        doc = _document_for_tokens([tokens])
+        seq = SequenceWrapper(doc, tokens)
+        assert str(seq) == text
+        assert list(seq.tokens_between((0, len(text)))) == tokens
+
+    def test_should_find_tokens_between_without_str_filter(self):
+        text = 'this is matching'
+        tokens = _tokens_for_text(text)
+        doc = _document_for_tokens([tokens])
+        seq = SequenceWrapper(doc, tokens)
+        assert str(seq) == text
+        assert list(seq.tokens_between((6, 7))) == [tokens[1]]
+
+    def test_should_find_tokens_between_adjusted_indices_due_to_str_filter(self):
+        text = 'this is matching'
+        tokens = _tokens_for_text(text)
+        doc = _document_for_tokens([tokens])
+        seq = SequenceWrapper(doc, tokens, str_filter_f=lambda s: s.replace('th', ''))
+        assert str(seq) == 'is is matching'
+        assert list(seq.tokens_between((4, 5))) == [tokens[1]]
+
 
 class TestMatchingAnnotator(object):
-  def test_should_not_fail_on_empty_document(self):
-    doc = SimpleStructuredDocument(lines=[])
-    MatchingAnnotator([]).annotate(doc)
-
-  def test_should_not_fail_on_empty_line_with_blank_token(self):
-    target_annotations = [
-      TargetAnnotation('this is. matching', TAG1)
-    ]
-    doc = _document_for_tokens([[SimpleToken('')]])
-    MatchingAnnotator(target_annotations).annotate(doc)
-
-  def test_should_annotate_exactly_matching(self):
-    matching_tokens = _tokens_for_text('this is matching')
-    target_annotations = [
-      TargetAnnotation('this is matching', TAG1)
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-
-  def test_should_use_begin_prefix_if_enabled(self):
-    matching_tokens = _tokens_for_text('this is matching')
-    target_annotations = [
-      TargetAnnotation('this is matching', TAG1)
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [B_TAG_1, I_TAG_1, I_TAG_1]
-
-  def test_should_match_normalised_characters(self):
-    matching_tokens = [
-      SimpleToken('this'),
-      SimpleToken('is' + THIN_SPACE + EN_DASH + EM_DASH),
-      SimpleToken('matching')
-    ]
-    target_annotations = [
-      TargetAnnotation('this is -- matching', TAG1)
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-
-  def test_should_match_case_insensitive(self):
-    matching_tokens = _tokens_for_text('This Is Matching')
-    target_annotations = [
-      TargetAnnotation('tHIS iS mATCHING', TAG1)
-    ]
-    doc = SimpleStructuredDocument(lines=[SimpleLine(matching_tokens)])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-
-  def test_should_prefer_word_boundaries(self):
-    pre_tokens = _tokens_for_text('this')
-    matching_tokens = _tokens_for_text('is')
-    post_tokens = _tokens_for_text('miss')
-    target_annotations = [
-      TargetAnnotation('is', TAG1)
-    ]
-    doc = _document_for_tokens([
-      pre_tokens + matching_tokens + post_tokens
-    ])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-    assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens)
-    assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens)
-
-  def test_should_annotate_multiple_value_target_annotation(self):
-    matching_tokens = _tokens_for_text('this may match')
-    target_annotations = [
-      TargetAnnotation([
-        'this', 'may', 'match'
-      ], TAG1)
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-
-  def test_should_annotate_multiple_value_target_annotation_with_begin_prefix(self):
-    matching_tokens = _tokens_for_text('this may match')
-    target_annotations = [
-      TargetAnnotation([
-        'this', 'may', 'match'
-      ], TAG1)
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == (
-      [B_TAG_1] + [I_TAG_1] * (len(matching_tokens) - 1)
-    )
-
-  def test_should_annotate_multiple_value_target_annotation_rev_order_with_begin_prefix(self):
-    matching_tokens = _tokens_for_text('this may match')
-    target_annotations = [
-      TargetAnnotation(list(reversed([
-        'this', 'may', 'match'
-      ])), TAG1)
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == (
-      [B_TAG_1] + [I_TAG_1] * (len(matching_tokens) - 1)
-    )
-
-  def test_should_annotate_multiple_value_target_annotation_over_multiple_lines(self):
-    tokens_by_line = [
-      _tokens_for_text('this may'),
-      _tokens_for_text('match')
-    ]
-    matching_tokens = flatten(tokens_by_line)
-    target_annotations = [
-      TargetAnnotation([
-        'this', 'may', 'match'
-      ], TAG1)
-    ]
-    doc = _document_for_tokens(tokens_by_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-
-  def test_should_annotate_mult_value_target_annot_rev_order_over_mult_lines_with_b_prefix(self):
-    tokens_by_line = [
-      _tokens_for_text('this may'),
-      _tokens_for_text('match')
-    ]
-    matching_tokens = flatten(tokens_by_line)
-    target_annotations = [
-      TargetAnnotation(list(reversed([
-        'this', 'may', 'match'
-      ])), TAG1)
-    ]
-    doc = _document_for_tokens(tokens_by_line)
-    MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == (
-      [B_TAG_1] +
-      [I_TAG_1] * (len(matching_tokens) - 1)
-    )
-
-  def test_should_annotate_not_match_distant_value_of_multiple_value_target_annotation(self):
-    matching_tokens = _tokens_for_text('this may match')
-    distant_matching_tokens = _tokens_for_text('not')
-    distance_in_lines = 10
-    tokens_by_line = [matching_tokens] + [
-      _tokens_for_text('other') for _ in range(distance_in_lines)
-    ] + [distant_matching_tokens]
-    target_annotations = [
-      TargetAnnotation([
-        'this', 'may', 'match', 'not'
-      ], TAG1)
-    ]
-    doc = _document_for_tokens(tokens_by_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-    assert _get_tags_of_tokens(distant_matching_tokens) == [None] * len(distant_matching_tokens)
-
-  def test_should_annotate_not_match_distant_value_of_target_annotation_with_bonding(self):
-    matching_tokens = _tokens_for_text('this may match')
-    distant_matching_tokens = _tokens_for_text('not')
-    distance_in_lines = 10
-    tokens_by_line = [matching_tokens] + [
-      _tokens_for_text('other') for _ in range(distance_in_lines)
-    ] + [distant_matching_tokens]
-    target_annotations = [
-      TargetAnnotation('this may match', TAG1, bonding=True),
-      TargetAnnotation('not', TAG1, bonding=True)
-    ]
-    doc = _document_for_tokens(tokens_by_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-    assert _get_tags_of_tokens(distant_matching_tokens) == [None] * len(distant_matching_tokens)
-
-  def test_should_annotate_fuzzily_matching(self):
-    matching_tokens = _tokens_for_text('this is matching')
-    target_annotations = [
-      TargetAnnotation('this is. matching', TAG1)
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-
-  def test_should_annotate_ignoring_space_after_dot_short_sequence(self):
-    matching_tokens = [
-      SimpleToken('A.B.,')
-    ]
-    target_annotations = [
-      TargetAnnotation('A. B.', TAG1)
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-
-  def test_should_annotate_ignoring_comma_after_short_sequence(self):
-    matching_tokens = [
-      SimpleToken('Name,'),
-    ]
-    target_annotations = [
-      TargetAnnotation('Name', TAG1)
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-
-  def test_should_annotate_ignoring_dots_after_capitals_in_target_annotation(self):
-    matching_tokens = _tokens_for_text('PO Box 12345')
-    target_annotations = [
-      TargetAnnotation('P.O. Box 12345', TAG1)
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-
-  def test_should_annotate_ignoring_dots_after_capitals_in_document(self):
-    matching_tokens = _tokens_for_text('P.O. Box 12345')
-    target_annotations = [
-      TargetAnnotation('PO Box 12345', TAG1)
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-
-  def test_should_annotate_with_local_matching_smaller_gaps(self):
-    matching_tokens = _tokens_for_text('this is matching')
-    target_annotations = [
-      TargetAnnotation('this is. matching indeed matching', TAG1)
-    ]
-    # this should align with 'this is_ matching' with one gap'
-    # instead of globally 'this is_ ________ ______ matching'
-    # (which would result in a worse b_gap_ratio)
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-
-  def test_should_not_annotate_fuzzily_matching_with_many_differences(self):
-    matching_tokens = _tokens_for_text('this is matching')
-    target_annotations = [
-      TargetAnnotation('txhxixsx ixsx mxaxtxcxhxixnxgx', TAG1)
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [None] * len(matching_tokens)
-
-  def test_should_annotate_fuzzily_matching_longer_matches_based_on_ratio(self):
-    long_matching_text = 'this is matching and is really really long match that we can trust'
-    matching_tokens = _tokens_for_text(long_matching_text)
-    no_matching_tokens = _tokens_for_text('what comes next is different')
-    target_annotations = [
-      TargetAnnotation(long_matching_text + ' but this is not and is another matter', TAG1)
-    ]
-    doc = _document_for_tokens([
-      matching_tokens + no_matching_tokens
-    ])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-    assert _get_tags_of_tokens(no_matching_tokens) == [None] * len(no_matching_tokens)
-
-  def test_should_not_annotate_not_matching(self):
-    not_matching_tokens = _tokens_for_text('something completely different')
-    target_annotations = [
-      TargetAnnotation('this is matching', TAG1)
-    ]
-    doc = _document_for_tokens([not_matching_tokens])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(not_matching_tokens) == [None] * len(not_matching_tokens)
-
-  def test_should_annotate_exactly_matching_across_multiple_lines(self):
-    matching_tokens_per_line = [
-      _tokens_for_text('this is matching'),
-      _tokens_for_text('and continues here')
-    ]
-    matching_tokens = flatten(matching_tokens_per_line)
-    target_annotations = [
-      TargetAnnotation('this is matching and continues here', TAG1)
-    ]
-    doc = _document_for_tokens(matching_tokens_per_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-
-  def test_should_annotate_exactly_matching_across_multiple_lines_with_begin_prefix(self):
-    matching_tokens_per_line = [
-      _tokens_for_text('this is matching'),
-      _tokens_for_text('and continues here')
-    ]
-    matching_tokens = flatten(matching_tokens_per_line)
-    target_annotations = [
-      TargetAnnotation('this is matching and continues here', TAG1)
-    ]
-    doc = _document_for_tokens(matching_tokens_per_line)
-    MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == (
-      [B_TAG_1] + [I_TAG_1] * (len(matching_tokens) - 1)
-    )
-
-  def test_should_not_annotate_shorter_sequence_if_next_line_does_not_match(self):
-    tokens_per_line = [
-      _tokens_for_text('this is'),
-      _tokens_for_text('something completely different')
-    ]
-    tokens = flatten(tokens_per_line)
-    target_annotations = [
-      TargetAnnotation('this is not matching', TAG1)
-    ]
-    doc = _document_for_tokens(tokens_per_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(tokens) == [None] * len(tokens)
-
-  def test_should_annotate_over_multiple_lines_with_tag_transition(self):
-    tag1_tokens_by_line = [
-      _tokens_for_text('this may'),
-      _tokens_for_text('match')
-    ]
-    tag1_tokens = flatten(tag1_tokens_by_line)
-    tag2_tokens_by_line = [
-      _tokens_for_text('another'),
-      _tokens_for_text('tag here')
-    ]
-    tag2_tokens = flatten(tag2_tokens_by_line)
-    tokens_by_line = [
-      tag1_tokens_by_line[0],
-      tag1_tokens_by_line[1] + tag2_tokens_by_line[0],
-      tag2_tokens_by_line[1]
-    ]
-    target_annotations = [
-      TargetAnnotation('this may match', TAG1),
-      TargetAnnotation('another tag here', TAG2)
-    ]
-    doc = _document_for_tokens(tokens_by_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(tag1_tokens) == [TAG1] * len(tag1_tokens)
-    assert _get_tags_of_tokens(tag2_tokens) == [TAG2] * len(tag2_tokens)
-
-  def test_should_annotate_over_multiple_lines_with_tag_transition_with_begin_prefix(self):
-    tag1_tokens_by_line = [
-      _tokens_for_text('this may'),
-      _tokens_for_text('match')
-    ]
-    tag1_tokens = flatten(tag1_tokens_by_line)
-    tag2_tokens_by_line = [
-      _tokens_for_text('another'),
-      _tokens_for_text('tag here')
-    ]
-    tag2_tokens = flatten(tag2_tokens_by_line)
-    tokens_by_line = [
-      tag1_tokens_by_line[0],
-      tag1_tokens_by_line[1] + tag2_tokens_by_line[0],
-      tag2_tokens_by_line[1]
-    ]
-    target_annotations = [
-      TargetAnnotation('this may match', TAG1),
-      TargetAnnotation('another tag here', TAG2)
-    ]
-    doc = _document_for_tokens(tokens_by_line)
-    MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
-    assert _get_tags_of_tokens(tag1_tokens) == (
-      [B_TAG_1] + [I_TAG_1] * (len(tag1_tokens) - 1)
-    )
-    assert _get_tags_of_tokens(tag2_tokens) == (
-      [B_TAG_2] + [I_TAG_2] * (len(tag2_tokens) - 1)
-    )
-
-  def test_should_annotate_short_section_title_followed_by_paragraph(self):
-    section_title_text = 'section title'
-    section_paragraph_text = 'paragraph text to come here.'
-    section_title_tokens = _tokens_for_text(section_title_text + '.')
-    section_paragraph_tokens = _tokens_for_text(section_paragraph_text)
-    tokens_per_line = [
-      section_title_tokens + section_paragraph_tokens
-    ]
-    target_annotations = [
-      TargetAnnotation(section_title_text, 'section_title', require_next=True),
-      TargetAnnotation(section_paragraph_text, 'section_paragraph')
-    ]
-    doc = _document_for_tokens(tokens_per_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert (
-      _get_tags_of_tokens(section_title_tokens) ==
-      ['section_title'] * len(section_title_tokens)
-    )
-    assert (
-      _get_tags_of_tokens(section_paragraph_tokens) ==
-      ['section_paragraph'] * len(section_paragraph_tokens)
-    )
-
-  def test_should_not_annotate_short_section_title_not_followed_by_paragraph(self):
-    section_title_text = 'section title'
-    section_title_tokens = _tokens_for_text(section_title_text + '.')
-    section_paragraph_text = 'paragraph text to come here.'
-    section_paragraph_tokens = _tokens_for_text(section_paragraph_text)
-    tokens_per_line = [
-      section_title_tokens + _tokens_for_text('other text to come here.'),
-      _tokens_for_text('more unrelated text.'),
-      _tokens_for_text('even more.'),
-      section_paragraph_tokens
-    ]
-    target_annotations = [
-      TargetAnnotation(section_title_text, 'section_title', require_next=True),
-      TargetAnnotation(section_paragraph_text, 'section_paragraph')
-    ]
-    doc = _document_for_tokens(tokens_per_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert (
-      _get_tags_of_tokens(section_title_tokens) ==
-      [None] * len(section_title_tokens)
-    )
-
-  def test_should_not_annotate_short_section_title_if_paragraph_follows_later(self):
-    section_title_text = 'section title'
-    section_title_tokens = _tokens_for_text(section_title_text + '.')
-    other_tokens = _tokens_for_text('other text to come here.')
-    tokens_per_line = [
-      section_title_tokens + other_tokens
-    ]
-    target_annotations = [
-      TargetAnnotation(section_title_text, 'section_title', require_next=True)
-    ]
-    doc = _document_for_tokens(tokens_per_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert (
-      _get_tags_of_tokens(section_title_tokens) ==
-      [None] * len(section_title_tokens)
-    )
-
-  def test_should_annotate_short_reference_item_followed_by_other_reference_items(self):
-    reference_item_texts = ['ref_id', 'ref_title']
-    reference_item_tokens = _tokens_for_text(' '.join(reference_item_texts))
-    tokens_per_line = [
-      reference_item_tokens
-    ]
-    target_annotations = [
-      TargetAnnotation(reference_item_texts, 'reference', bonding=True)
-    ]
-    doc = _document_for_tokens(tokens_per_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert (
-      _get_tags_of_tokens(reference_item_tokens) ==
-      ['reference'] * len(reference_item_tokens)
-    )
-
-  def test_should_not_annotate_short_reference_item_not_followed_by_other_reference_items(self):
-    matching_reference_item_text = 'ref_id'
-    reference_item_texts = [matching_reference_item_text] + ['ref_title']
-    matching_reference_item_tokens = _tokens_for_text(matching_reference_item_text)
-    other_tokens = _tokens_for_text('other')
-    tokens_per_line = [
-      matching_reference_item_tokens + other_tokens
-    ]
-    target_annotations = [
-      TargetAnnotation(reference_item_texts, 'reference', bonding=True)
-    ]
-    doc = _document_for_tokens(tokens_per_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert (
-      _get_tags_of_tokens(matching_reference_item_tokens) ==
-      [None] * len(matching_reference_item_tokens)
-    )
-
-  def test_should_annotate_last_line_of_block_followed_by_other_text(self):
-    block_text_lines = [
-      'this is the first row',
-      'second row follows',
-      'here we are on the third',
-      'last line of block'
-    ]
-    block_tokens_per_line = _tokens_for_text_lines(block_text_lines)
-    block_tokens = flatten(block_tokens_per_line)
-    tokens_per_line = block_tokens_per_line + [
-      _tokens_for_text('other text')
-    ]
-    target_annotations = [
-      TargetAnnotation('\n'.join(block_text_lines), TAG1)
-    ]
-    doc = _document_for_tokens(tokens_per_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert (
-      _get_tags_of_tokens(block_tokens) ==
-      [TAG1] * len(block_tokens)
-    )
-
-  def test_should_annotate_longer_sequence_over_multiple_lines_considering_next_line(self):
-    # we need a long enough sequence to fall into the first branch
-    # and match the partial match threshold
-    exact_matching_text_lines = ('this may', 'indeed match very well without the slightest doubt')
-    # add a short prefix that doesn't affect the score much
-    # but would be skipped if we only matched the second line
-    matching_text_lines = (exact_matching_text_lines[0], 'x ' + exact_matching_text_lines[1])
-    matching_tokens_by_line = _tokens_for_text_lines(matching_text_lines)
-    matching_tokens = flatten(matching_tokens_by_line)
-    pre_tokens = _tokens_for_text(matching_text_lines[0] + ' this may not')
-    post_tokens = _tokens_for_text('or not')
-    tokens_by_line = [
-      pre_tokens + matching_tokens_by_line[0],
-      matching_tokens_by_line[1] + post_tokens
-    ]
-    target_annotations = [
-      TargetAnnotation(' '.join(exact_matching_text_lines), TAG1)
-    ]
-    doc = _document_for_tokens(tokens_by_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-    assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens)
-    assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens)
-
-  def test_should_annotate_shorter_sequence_over_multiple_lines_considering_next_line(self):
-    # use a short sequence that wouldn't get matched on it's own
-    matching_text_lines = ('this may', 'match')
-    matching_tokens_by_line = _tokens_for_text_lines(matching_text_lines)
-    matching_tokens = flatten(matching_tokens_by_line)
-    # repeat the same text on the two lines, only by combining the lines would it be clear
-    # which tokens to match
-    pre_tokens = _tokens_for_text(matching_text_lines[0] + ' be some other longer preceeding text')
-    post_tokens = _tokens_for_text('this is some text after but no ' + matching_text_lines[1])
-    tokens_by_line = [
-      pre_tokens + matching_tokens_by_line[0],
-      matching_tokens_by_line[1] + post_tokens
-    ]
-    target_annotations = [
-      TargetAnnotation('this may match', TAG1)
-    ]
-    doc = _document_for_tokens(tokens_by_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-    assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens)
-    assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens)
-
-  def test_should_not_annotate_too_short_match_of_longer_sequence(self):
-    matching_tokens = _tokens_for_text('this is matching')
-    too_short_tokens = _tokens_for_text('1')
-    tokens_per_line = [
-      too_short_tokens,
-      matching_tokens
-    ]
-    target_annotations = [
-      TargetAnnotation('this is matching 1', TAG1)
-    ]
-    doc = _document_for_tokens(tokens_per_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(too_short_tokens) == [None] * len(too_short_tokens)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-
-  def test_should_not_annotate_similar_sequence_multiple_times(self):
-    matching_tokens_per_line = [
-      _tokens_for_text('this is matching'),
-      _tokens_for_text('and continues here')
-    ]
-    not_matching_tokens = _tokens_for_text('this is matching')
-
-    matching_tokens = flatten(matching_tokens_per_line)
-    target_annotations = [
-      TargetAnnotation('this is matching and continues here', TAG1)
-    ]
-    doc = _document_for_tokens(
-      matching_tokens_per_line + [not_matching_tokens]
-    )
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-    assert _get_tags_of_tokens(not_matching_tokens) == [None] * len(not_matching_tokens)
-
-  def test_should_annotate_same_sequence_multiple_times_if_enabled(self):
-    matching_tokens_per_line = [
-      _tokens_for_text('this is matching'),
-      _tokens_for_text('this is matching')
-    ]
-
-    matching_tokens = flatten(matching_tokens_per_line)
-    target_annotations = [
-      TargetAnnotation('this is matching', TAG1, match_multiple=True)
-    ]
-    doc = _document_for_tokens(matching_tokens_per_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-
-  def test_should_annotate_same_sequence_multiple_times_with_begin_prefix(self):
-    matching_tokens_per_line = [
-      _tokens_for_text('this is matching'),
-      _tokens_for_text('this is matching')
-    ]
-
-    matching_tokens = flatten(matching_tokens_per_line)
-    target_annotations = [
-      TargetAnnotation('this is matching', TAG1, match_multiple=True)
-    ]
-    doc = _document_for_tokens(matching_tokens_per_line)
-    MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
-    # the begin tag should appear at the beginning of each match
-    assert (
-      _get_tags_of_tokens(matching_tokens) ==
-      [B_TAG_1, I_TAG_1, I_TAG_1, B_TAG_1, I_TAG_1, I_TAG_1]
-    )
-
-  def test_should_not_override_annotation(self):
-    matching_tokens_per_line = [
-      _tokens_for_text('this is matching')
-    ]
-
-    matching_tokens = flatten(matching_tokens_per_line)
-    target_annotations = [
-      TargetAnnotation('this is matching', TAG1),
-      TargetAnnotation('matching', TAG2)
-    ]
-    doc = _document_for_tokens(matching_tokens_per_line)
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-
-  def test_should_not_annotate_pre_annotated_tokens_on_separate_lines(self):
-    line_no_tokens = _tokens_for_text('1')
-    line_no_tokens[0].set_tag('line_no')
-    matching_tokens = _tokens_for_text('this is matching')
-    target_annotations = [
-      TargetAnnotation('1', TAG2),
-      TargetAnnotation('this is matching', TAG1)
-    ]
-    doc = _document_for_tokens([
-      line_no_tokens + matching_tokens
-    ])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(line_no_tokens) == ['line_no'] * len(line_no_tokens)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-
-  def test_should_annotate_shorter_target_annotation_in_longer_line(self):
-    pre_tokens = _tokens_for_text('pre')
-    matching_tokens = _tokens_for_text('this is matching')
-    post_tokens = _tokens_for_text('post')
-    target_annotations = [
-      TargetAnnotation('this is matching', TAG1)
-    ]
-    doc = _document_for_tokens([
-      pre_tokens + matching_tokens + post_tokens
-    ])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-    assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens)
-
-  def test_should_annotate_shorter_target_annotation_fuzzily(self):
-    pre_tokens = _tokens_for_text('pre')
-    matching_tokens = _tokens_for_text('this is matching')
-    post_tokens = _tokens_for_text('post')
-    target_annotations = [
-      TargetAnnotation('this is. matching', TAG1)
-    ]
-    doc = _document_for_tokens([
-      pre_tokens + matching_tokens + post_tokens
-    ])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-    assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens)
-
-  def test_should_annotate_multiple_shorter_target_annotation_in_longer_line(self):
-    pre_tokens = _tokens_for_text('pre')
-    matching_tokens_tag_1 = _tokens_for_text('this is matching')
-    mid_tokens = _tokens_for_text('mid')
-    matching_tokens_tag_2 = _tokens_for_text('also good')
-    post_tokens = _tokens_for_text('post')
-    target_annotations = [
-      TargetAnnotation('this is matching', TAG1),
-      TargetAnnotation('also good', TAG2)
-    ]
-    doc = _document_for_tokens([
-      pre_tokens + matching_tokens_tag_1 + mid_tokens + matching_tokens_tag_2 + post_tokens
-    ])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens)
-    assert _get_tags_of_tokens(matching_tokens_tag_1) == [TAG1] * len(matching_tokens_tag_1)
-    assert _get_tags_of_tokens(mid_tokens) == [None] * len(mid_tokens)
-    assert _get_tags_of_tokens(matching_tokens_tag_2) == [TAG2] * len(matching_tokens_tag_2)
-    assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens)
-
-  def test_should_not_annotate_shorter_target_annotation_in_longer_line_multiple_times(self):
-    pre_tokens = _tokens_for_text('pre')
-    matching_tokens = _tokens_for_text('this is matching')
-    post_tokens = _tokens_for_text('post')
-    first_line_tokens = pre_tokens + matching_tokens + post_tokens
-    similar_line_tokens = _copy_tokens(first_line_tokens)
-    target_annotations = [
-      TargetAnnotation('this is matching', TAG1)
-    ]
-    doc = _document_for_tokens([
-      first_line_tokens,
-      similar_line_tokens
-    ])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-    assert _get_tags_of_tokens(similar_line_tokens) == [None] * len(similar_line_tokens)
-
-  def test_should_annotate_shorter_target_annotation_in_longer_line_multiple_times_if_enabled(self):
-    pre_tokens = _tokens_for_text('pre')
-    matching_tokens = _tokens_for_text('this is matching')
-    post_tokens = _tokens_for_text('post')
-    same_matching_tokens = _copy_tokens(matching_tokens)
-    target_annotations = [
-      TargetAnnotation('this is matching', TAG1, match_multiple=True)
-    ]
-    doc = _document_for_tokens([
-      pre_tokens + matching_tokens + post_tokens,
-      _copy_tokens(pre_tokens) + same_matching_tokens + _copy_tokens(post_tokens)
-    ])
-    MatchingAnnotator(target_annotations).annotate(doc)
-    assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
-    assert _get_tags_of_tokens(same_matching_tokens) == [TAG1] * len(same_matching_tokens)
+    def test_should_not_fail_on_empty_document(self):
+        doc = SimpleStructuredDocument(lines=[])
+        MatchingAnnotator([]).annotate(doc)
+
+    def test_should_not_fail_on_empty_line_with_blank_token(self):
+        target_annotations = [
+            TargetAnnotation('this is. matching', TAG1)
+        ]
+        doc = _document_for_tokens([[SimpleToken('')]])
+        MatchingAnnotator(target_annotations).annotate(doc)
+
+    def test_should_annotate_exactly_matching(self):
+        matching_tokens = _tokens_for_text('this is matching')
+        target_annotations = [
+            TargetAnnotation('this is matching', TAG1)
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+
+    def test_should_use_begin_prefix_if_enabled(self):
+        matching_tokens = _tokens_for_text('this is matching')
+        target_annotations = [
+            TargetAnnotation('this is matching', TAG1)
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [B_TAG_1, I_TAG_1, I_TAG_1]
+
+    def test_should_match_normalised_characters(self):
+        matching_tokens = [
+            SimpleToken('this'),
+            SimpleToken('is' + THIN_SPACE + EN_DASH + EM_DASH),
+            SimpleToken('matching')
+        ]
+        target_annotations = [
+            TargetAnnotation('this is -- matching', TAG1)
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+
+    def test_should_match_case_insensitive(self):
+        matching_tokens = _tokens_for_text('This Is Matching')
+        target_annotations = [
+            TargetAnnotation('tHIS iS mATCHING', TAG1)
+        ]
+        doc = SimpleStructuredDocument(lines=[SimpleLine(matching_tokens)])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+
+    def test_should_prefer_word_boundaries(self):
+        pre_tokens = _tokens_for_text('this')
+        matching_tokens = _tokens_for_text('is')
+        post_tokens = _tokens_for_text('miss')
+        target_annotations = [
+            TargetAnnotation('is', TAG1)
+        ]
+        doc = _document_for_tokens([
+            pre_tokens + matching_tokens + post_tokens
+        ])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+        assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens)
+        assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens)
+
+    def test_should_annotate_multiple_value_target_annotation(self):
+        matching_tokens = _tokens_for_text('this may match')
+        target_annotations = [
+            TargetAnnotation([
+                'this', 'may', 'match'
+            ], TAG1)
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+
+    def test_should_annotate_multiple_value_target_annotation_with_begin_prefix(self):
+        matching_tokens = _tokens_for_text('this may match')
+        target_annotations = [
+            TargetAnnotation([
+                'this', 'may', 'match'
+            ], TAG1)
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == (
+            [B_TAG_1] + [I_TAG_1] * (len(matching_tokens) - 1)
+        )
+
+    def test_should_annotate_multiple_value_target_annotation_rev_order_with_begin_prefix(self):
+        matching_tokens = _tokens_for_text('this may match')
+        target_annotations = [
+            TargetAnnotation(list(reversed([
+                'this', 'may', 'match'
+            ])), TAG1)
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == (
+            [B_TAG_1] + [I_TAG_1] * (len(matching_tokens) - 1)
+        )
+
+    def test_should_annotate_multiple_value_target_annotation_over_multiple_lines(self):
+        tokens_by_line = [
+            _tokens_for_text('this may'),
+            _tokens_for_text('match')
+        ]
+        matching_tokens = flatten(tokens_by_line)
+        target_annotations = [
+            TargetAnnotation([
+                'this', 'may', 'match'
+            ], TAG1)
+        ]
+        doc = _document_for_tokens(tokens_by_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+
+    def test_should_annotate_mult_value_target_annot_rev_order_over_mult_lines_with_b_prefix(self):
+        tokens_by_line = [
+            _tokens_for_text('this may'),
+            _tokens_for_text('match')
+        ]
+        matching_tokens = flatten(tokens_by_line)
+        target_annotations = [
+            TargetAnnotation(list(reversed([
+                'this', 'may', 'match'
+            ])), TAG1)
+        ]
+        doc = _document_for_tokens(tokens_by_line)
+        MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == (
+            [B_TAG_1] +
+            [I_TAG_1] * (len(matching_tokens) - 1)
+        )
+
+    def test_should_annotate_not_match_distant_value_of_multiple_value_target_annotation(self):
+        matching_tokens = _tokens_for_text('this may match')
+        distant_matching_tokens = _tokens_for_text('not')
+        distance_in_lines = 10
+        tokens_by_line = [matching_tokens] + [
+            _tokens_for_text('other') for _ in range(distance_in_lines)
+        ] + [distant_matching_tokens]
+        target_annotations = [
+            TargetAnnotation([
+                'this', 'may', 'match', 'not'
+            ], TAG1)
+        ]
+        doc = _document_for_tokens(tokens_by_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+        assert _get_tags_of_tokens(distant_matching_tokens) == [None] * len(distant_matching_tokens)
+
+    def test_should_annotate_not_match_distant_value_of_target_annotation_with_bonding(self):
+        matching_tokens = _tokens_for_text('this may match')
+        distant_matching_tokens = _tokens_for_text('not')
+        distance_in_lines = 10
+        tokens_by_line = [matching_tokens] + [
+            _tokens_for_text('other') for _ in range(distance_in_lines)
+        ] + [distant_matching_tokens]
+        target_annotations = [
+            TargetAnnotation('this may match', TAG1, bonding=True),
+            TargetAnnotation('not', TAG1, bonding=True)
+        ]
+        doc = _document_for_tokens(tokens_by_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+        assert _get_tags_of_tokens(distant_matching_tokens) == [None] * len(distant_matching_tokens)
+
+    def test_should_annotate_fuzzily_matching(self):
+        matching_tokens = _tokens_for_text('this is matching')
+        target_annotations = [
+            TargetAnnotation('this is. matching', TAG1)
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+
+    def test_should_annotate_ignoring_space_after_dot_short_sequence(self):
+        matching_tokens = [
+            SimpleToken('A.B.,')
+        ]
+        target_annotations = [
+            TargetAnnotation('A. B.', TAG1)
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+
+    def test_should_annotate_ignoring_comma_after_short_sequence(self):
+        matching_tokens = [
+            SimpleToken('Name,'),
+        ]
+        target_annotations = [
+            TargetAnnotation('Name', TAG1)
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+
+    def test_should_annotate_ignoring_dots_after_capitals_in_target_annotation(self):
+        matching_tokens = _tokens_for_text('PO Box 12345')
+        target_annotations = [
+            TargetAnnotation('P.O. Box 12345', TAG1)
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+
+    def test_should_annotate_ignoring_dots_after_capitals_in_document(self):
+        matching_tokens = _tokens_for_text('P.O. Box 12345')
+        target_annotations = [
+            TargetAnnotation('PO Box 12345', TAG1)
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+
+    def test_should_annotate_with_local_matching_smaller_gaps(self):
+        matching_tokens = _tokens_for_text('this is matching')
+        target_annotations = [
+            TargetAnnotation('this is. matching indeed matching', TAG1)
+        ]
+        # this should align with 'this is_ matching' with one gap
+        # instead of globally 'this is_ ________ ______ matching'
+        # (which would result in a worse b_gap_ratio)
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+
+    def test_should_not_annotate_fuzzily_matching_with_many_differences(self):
+        matching_tokens = _tokens_for_text('this is matching')
+        target_annotations = [
+            TargetAnnotation('txhxixsx ixsx mxaxtxcxhxixnxgx', TAG1)
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [None] * len(matching_tokens)
+
+    def test_should_annotate_fuzzily_matching_longer_matches_based_on_ratio(self):
+        long_matching_text = 'this is matching and is really really long match that we can trust'
+        matching_tokens = _tokens_for_text(long_matching_text)
+        no_matching_tokens = _tokens_for_text('what comes next is different')
+        target_annotations = [
+            TargetAnnotation(long_matching_text + ' but this is not and is another matter', TAG1)
+        ]
+        doc = _document_for_tokens([
+            matching_tokens + no_matching_tokens
+        ])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+        assert _get_tags_of_tokens(no_matching_tokens) == [None] * len(no_matching_tokens)
+
+    def test_should_not_annotate_not_matching(self):
+        not_matching_tokens = _tokens_for_text('something completely different')
+        target_annotations = [
+            TargetAnnotation('this is matching', TAG1)
+        ]
+        doc = _document_for_tokens([not_matching_tokens])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(not_matching_tokens) == [None] * len(not_matching_tokens)
+
+    def test_should_annotate_exactly_matching_across_multiple_lines(self):
+        matching_tokens_per_line = [
+            _tokens_for_text('this is matching'),
+            _tokens_for_text('and continues here')
+        ]
+        matching_tokens = flatten(matching_tokens_per_line)
+        target_annotations = [
+            TargetAnnotation('this is matching and continues here', TAG1)
+        ]
+        doc = _document_for_tokens(matching_tokens_per_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+
+    def test_should_annotate_exactly_matching_across_multiple_lines_with_begin_prefix(self):
+        matching_tokens_per_line = [
+            _tokens_for_text('this is matching'),
+            _tokens_for_text('and continues here')
+        ]
+        matching_tokens = flatten(matching_tokens_per_line)
+        target_annotations = [
+            TargetAnnotation('this is matching and continues here', TAG1)
+        ]
+        doc = _document_for_tokens(matching_tokens_per_line)
+        MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == (
+            [B_TAG_1] + [I_TAG_1] * (len(matching_tokens) - 1)
+        )
+
+    def test_should_not_annotate_shorter_sequence_if_next_line_does_not_match(self):
+        tokens_per_line = [
+            _tokens_for_text('this is'),
+            _tokens_for_text('something completely different')
+        ]
+        tokens = flatten(tokens_per_line)
+        target_annotations = [
+            TargetAnnotation('this is not matching', TAG1)
+        ]
+        doc = _document_for_tokens(tokens_per_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(tokens) == [None] * len(tokens)
+
+    def test_should_annotate_over_multiple_lines_with_tag_transition(self):
+        tag1_tokens_by_line = [
+            _tokens_for_text('this may'),
+            _tokens_for_text('match')
+        ]
+        tag1_tokens = flatten(tag1_tokens_by_line)
+        tag2_tokens_by_line = [
+            _tokens_for_text('another'),
+            _tokens_for_text('tag here')
+        ]
+        tag2_tokens = flatten(tag2_tokens_by_line)
+        tokens_by_line = [
+            tag1_tokens_by_line[0],
+            tag1_tokens_by_line[1] + tag2_tokens_by_line[0],
+            tag2_tokens_by_line[1]
+        ]
+        target_annotations = [
+            TargetAnnotation('this may match', TAG1),
+            TargetAnnotation('another tag here', TAG2)
+        ]
+        doc = _document_for_tokens(tokens_by_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(tag1_tokens) == [TAG1] * len(tag1_tokens)
+        assert _get_tags_of_tokens(tag2_tokens) == [TAG2] * len(tag2_tokens)
+
+    def test_should_annotate_over_multiple_lines_with_tag_transition_with_begin_prefix(self):
+        tag1_tokens_by_line = [
+            _tokens_for_text('this may'),
+            _tokens_for_text('match')
+        ]
+        tag1_tokens = flatten(tag1_tokens_by_line)
+        tag2_tokens_by_line = [
+            _tokens_for_text('another'),
+            _tokens_for_text('tag here')
+        ]
+        tag2_tokens = flatten(tag2_tokens_by_line)
+        tokens_by_line = [
+            tag1_tokens_by_line[0],
+            tag1_tokens_by_line[1] + tag2_tokens_by_line[0],
+            tag2_tokens_by_line[1]
+        ]
+        target_annotations = [
+            TargetAnnotation('this may match', TAG1),
+            TargetAnnotation('another tag here', TAG2)
+        ]
+        doc = _document_for_tokens(tokens_by_line)
+        MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
+        assert _get_tags_of_tokens(tag1_tokens) == (
+            [B_TAG_1] + [I_TAG_1] * (len(tag1_tokens) - 1)
+        )
+        assert _get_tags_of_tokens(tag2_tokens) == (
+            [B_TAG_2] + [I_TAG_2] * (len(tag2_tokens) - 1)
+        )
+
+    def test_should_annotate_short_section_title_followed_by_paragraph(self):
+        section_title_text = 'section title'
+        section_paragraph_text = 'paragraph text to come here.'
+        section_title_tokens = _tokens_for_text(section_title_text + '.')
+        section_paragraph_tokens = _tokens_for_text(section_paragraph_text)
+        tokens_per_line = [
+            section_title_tokens + section_paragraph_tokens
+        ]
+        target_annotations = [
+            TargetAnnotation(section_title_text, 'section_title', require_next=True),
+            TargetAnnotation(section_paragraph_text, 'section_paragraph')
+        ]
+        doc = _document_for_tokens(tokens_per_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert (
+            _get_tags_of_tokens(section_title_tokens) ==
+            ['section_title'] * len(section_title_tokens)
+        )
+        assert (
+            _get_tags_of_tokens(section_paragraph_tokens) ==
+            ['section_paragraph'] * len(section_paragraph_tokens)
+        )
+
+    def test_should_not_annotate_short_section_title_not_followed_by_paragraph(self):
+        section_title_text = 'section title'
+        section_title_tokens = _tokens_for_text(section_title_text + '.')
+        section_paragraph_text = 'paragraph text to come here.'
+        section_paragraph_tokens = _tokens_for_text(section_paragraph_text)
+        tokens_per_line = [
+            section_title_tokens + _tokens_for_text('other text to come here.'),
+            _tokens_for_text('more unrelated text.'),
+            _tokens_for_text('even more.'),
+            section_paragraph_tokens
+        ]
+        target_annotations = [
+            TargetAnnotation(section_title_text, 'section_title', require_next=True),
+            TargetAnnotation(section_paragraph_text, 'section_paragraph')
+        ]
+        doc = _document_for_tokens(tokens_per_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert (
+            _get_tags_of_tokens(section_title_tokens) ==
+            [None] * len(section_title_tokens)
+        )
+
+    def test_should_not_annotate_short_section_title_if_paragraph_follows_later(self):
+        section_title_text = 'section title'
+        section_title_tokens = _tokens_for_text(section_title_text + '.')
+        other_tokens = _tokens_for_text('other text to come here.')
+        tokens_per_line = [
+            section_title_tokens + other_tokens
+        ]
+        target_annotations = [
+            TargetAnnotation(section_title_text, 'section_title', require_next=True)
+        ]
+        doc = _document_for_tokens(tokens_per_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert (
+            _get_tags_of_tokens(section_title_tokens) ==
+            [None] * len(section_title_tokens)
+        )
+
+    def test_should_annotate_short_reference_item_followed_by_other_reference_items(self):
+        reference_item_texts = ['ref_id', 'ref_title']
+        reference_item_tokens = _tokens_for_text(' '.join(reference_item_texts))
+        tokens_per_line = [
+            reference_item_tokens
+        ]
+        target_annotations = [
+            TargetAnnotation(reference_item_texts, 'reference', bonding=True)
+        ]
+        doc = _document_for_tokens(tokens_per_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert (
+            _get_tags_of_tokens(reference_item_tokens) ==
+            ['reference'] * len(reference_item_tokens)
+        )
+
+    def test_should_not_annotate_short_reference_item_not_followed_by_other_reference_items(self):
+        matching_reference_item_text = 'ref_id'
+        reference_item_texts = [matching_reference_item_text] + ['ref_title']
+        matching_reference_item_tokens = _tokens_for_text(matching_reference_item_text)
+        other_tokens = _tokens_for_text('other')
+        tokens_per_line = [
+            matching_reference_item_tokens + other_tokens
+        ]
+        target_annotations = [
+            TargetAnnotation(reference_item_texts, 'reference', bonding=True)
+        ]
+        doc = _document_for_tokens(tokens_per_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert (
+            _get_tags_of_tokens(matching_reference_item_tokens) ==
+            [None] * len(matching_reference_item_tokens)
+        )
+
+    def test_should_annotate_last_line_of_block_followed_by_other_text(self):
+        block_text_lines = [
+            'this is the first row',
+            'second row follows',
+            'here we are on the third',
+            'last line of block'
+        ]
+        block_tokens_per_line = _tokens_for_text_lines(block_text_lines)
+        block_tokens = flatten(block_tokens_per_line)
+        tokens_per_line = block_tokens_per_line + [
+            _tokens_for_text('other text')
+        ]
+        target_annotations = [
+            TargetAnnotation('\n'.join(block_text_lines), TAG1)
+        ]
+        doc = _document_for_tokens(tokens_per_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert (
+            _get_tags_of_tokens(block_tokens) ==
+            [TAG1] * len(block_tokens)
+        )
+
+    def test_should_annotate_longer_sequence_over_multiple_lines_considering_next_line(self):
+        # we need a long enough sequence to fall into the first branch
+        # and match the partial match threshold
+        exact_matching_text_lines = (
+            'this may', 'indeed match very well without the slightest doubt')
+        # add a short prefix that doesn't affect the score much
+        # but would be skipped if we only matched the second line
+        matching_text_lines = (exact_matching_text_lines[0], 'x ' + exact_matching_text_lines[1])
+        matching_tokens_by_line = _tokens_for_text_lines(matching_text_lines)
+        matching_tokens = flatten(matching_tokens_by_line)
+        pre_tokens = _tokens_for_text(matching_text_lines[0] + ' this may not')
+        post_tokens = _tokens_for_text('or not')
+        tokens_by_line = [
+            pre_tokens + matching_tokens_by_line[0],
+            matching_tokens_by_line[1] + post_tokens
+        ]
+        target_annotations = [
+            TargetAnnotation(' '.join(exact_matching_text_lines), TAG1)
+        ]
+        doc = _document_for_tokens(tokens_by_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+        assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens)
+        assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens)
+
+    def test_should_annotate_shorter_sequence_over_multiple_lines_considering_next_line(self):
+        # use a short sequence that wouldn't get matched on its own
+        matching_text_lines = ('this may', 'match')
+        matching_tokens_by_line = _tokens_for_text_lines(matching_text_lines)
+        matching_tokens = flatten(matching_tokens_by_line)
+        # repeat the same text on the two lines, only by combining the lines would it be clear
+        # which tokens to match
+        pre_tokens = _tokens_for_text(
+            matching_text_lines[0] + ' be some other longer preceeding text')
+        post_tokens = _tokens_for_text('this is some text after but no ' + matching_text_lines[1])
+        tokens_by_line = [
+            pre_tokens + matching_tokens_by_line[0],
+            matching_tokens_by_line[1] + post_tokens
+        ]
+        target_annotations = [
+            TargetAnnotation('this may match', TAG1)
+        ]
+        doc = _document_for_tokens(tokens_by_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+        assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens)
+        assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens)
+
+    def test_should_not_annotate_too_short_match_of_longer_sequence(self):
+        matching_tokens = _tokens_for_text('this is matching')
+        too_short_tokens = _tokens_for_text('1')
+        tokens_per_line = [
+            too_short_tokens,
+            matching_tokens
+        ]
+        target_annotations = [
+            TargetAnnotation('this is matching 1', TAG1)
+        ]
+        doc = _document_for_tokens(tokens_per_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(too_short_tokens) == [None] * len(too_short_tokens)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+
+    def test_should_not_annotate_similar_sequence_multiple_times(self):
+        matching_tokens_per_line = [
+            _tokens_for_text('this is matching'),
+            _tokens_for_text('and continues here')
+        ]
+        not_matching_tokens = _tokens_for_text('this is matching')
+
+        matching_tokens = flatten(matching_tokens_per_line)
+        target_annotations = [
+            TargetAnnotation('this is matching and continues here', TAG1)
+        ]
+        doc = _document_for_tokens(
+            matching_tokens_per_line + [not_matching_tokens]
+        )
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+        assert _get_tags_of_tokens(not_matching_tokens) == [None] * len(not_matching_tokens)
+
+    def test_should_annotate_same_sequence_multiple_times_if_enabled(self):
+        matching_tokens_per_line = [
+            _tokens_for_text('this is matching'),
+            _tokens_for_text('this is matching')
+        ]
+
+        matching_tokens = flatten(matching_tokens_per_line)
+        target_annotations = [
+            TargetAnnotation('this is matching', TAG1, match_multiple=True)
+        ]
+        doc = _document_for_tokens(matching_tokens_per_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+
+    def test_should_annotate_same_sequence_multiple_times_with_begin_prefix(self):
+        matching_tokens_per_line = [
+            _tokens_for_text('this is matching'),
+            _tokens_for_text('this is matching')
+        ]
+
+        matching_tokens = flatten(matching_tokens_per_line)
+        target_annotations = [
+            TargetAnnotation('this is matching', TAG1, match_multiple=True)
+        ]
+        doc = _document_for_tokens(matching_tokens_per_line)
+        MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
+        # the begin tag should appear at the beginning of each match
+        assert (
+            _get_tags_of_tokens(matching_tokens) ==
+            [B_TAG_1, I_TAG_1, I_TAG_1, B_TAG_1, I_TAG_1, I_TAG_1]
+        )
+
+    def test_should_not_override_annotation(self):
+        matching_tokens_per_line = [
+            _tokens_for_text('this is matching')
+        ]
+
+        matching_tokens = flatten(matching_tokens_per_line)
+        target_annotations = [
+            TargetAnnotation('this is matching', TAG1),
+            TargetAnnotation('matching', TAG2)
+        ]
+        doc = _document_for_tokens(matching_tokens_per_line)
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+
+    def test_should_not_annotate_pre_annotated_tokens_on_separate_lines(self):
+        line_no_tokens = _tokens_for_text('1')
+        line_no_tokens[0].set_tag('line_no')
+        matching_tokens = _tokens_for_text('this is matching')
+        target_annotations = [
+            TargetAnnotation('1', TAG2),
+            TargetAnnotation('this is matching', TAG1)
+        ]
+        doc = _document_for_tokens([
+            line_no_tokens + matching_tokens
+        ])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(line_no_tokens) == ['line_no'] * len(line_no_tokens)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+
+    def test_should_annotate_shorter_target_annotation_in_longer_line(self):
+        pre_tokens = _tokens_for_text('pre')
+        matching_tokens = _tokens_for_text('this is matching')
+        post_tokens = _tokens_for_text('post')
+        target_annotations = [
+            TargetAnnotation('this is matching', TAG1)
+        ]
+        doc = _document_for_tokens([
+            pre_tokens + matching_tokens + post_tokens
+        ])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+        assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens)
+
+    def test_should_annotate_shorter_target_annotation_fuzzily(self):
+        pre_tokens = _tokens_for_text('pre')
+        matching_tokens = _tokens_for_text('this is matching')
+        post_tokens = _tokens_for_text('post')
+        target_annotations = [
+            TargetAnnotation('this is. matching', TAG1)
+        ]
+        doc = _document_for_tokens([
+            pre_tokens + matching_tokens + post_tokens
+        ])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+        assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens)
+
+    def test_should_annotate_multiple_shorter_target_annotation_in_longer_line(self):
+        pre_tokens = _tokens_for_text('pre')
+        matching_tokens_tag_1 = _tokens_for_text('this is matching')
+        mid_tokens = _tokens_for_text('mid')
+        matching_tokens_tag_2 = _tokens_for_text('also good')
+        post_tokens = _tokens_for_text('post')
+        target_annotations = [
+            TargetAnnotation('this is matching', TAG1),
+            TargetAnnotation('also good', TAG2)
+        ]
+        doc = _document_for_tokens([
+            pre_tokens + matching_tokens_tag_1 + mid_tokens + matching_tokens_tag_2 + post_tokens
+        ])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens)
+        assert _get_tags_of_tokens(matching_tokens_tag_1) == [TAG1] * len(matching_tokens_tag_1)
+        assert _get_tags_of_tokens(mid_tokens) == [None] * len(mid_tokens)
+        assert _get_tags_of_tokens(matching_tokens_tag_2) == [TAG2] * len(matching_tokens_tag_2)
+        assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens)
+
+    def test_should_not_annotate_shorter_target_annotation_in_longer_line_multiple_times(self):
+        pre_tokens = _tokens_for_text('pre')
+        matching_tokens = _tokens_for_text('this is matching')
+        post_tokens = _tokens_for_text('post')
+        first_line_tokens = pre_tokens + matching_tokens + post_tokens
+        similar_line_tokens = _copy_tokens(first_line_tokens)
+        target_annotations = [
+            TargetAnnotation('this is matching', TAG1)
+        ]
+        doc = _document_for_tokens([
+            first_line_tokens,
+            similar_line_tokens
+        ])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+        assert _get_tags_of_tokens(similar_line_tokens) == [None] * len(similar_line_tokens)
+
+    def test_should_annotate_shorter_target_annotation_in_longer_line_multiple_times_if_enabled(
+            self):
+        pre_tokens = _tokens_for_text('pre')
+        matching_tokens = _tokens_for_text('this is matching')
+        post_tokens = _tokens_for_text('post')
+        same_matching_tokens = _copy_tokens(matching_tokens)
+        target_annotations = [
+            TargetAnnotation('this is matching', TAG1, match_multiple=True)
+        ]
+        doc = _document_for_tokens([
+            pre_tokens + matching_tokens + post_tokens,
+            _copy_tokens(pre_tokens) + same_matching_tokens + _copy_tokens(post_tokens)
+        ])
+        MatchingAnnotator(target_annotations).annotate(doc)
+        assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens)
+        assert _get_tags_of_tokens(same_matching_tokens) == [TAG1] * len(same_matching_tokens)
+
 
 class TestMatchingAnnotatorSubAnnotations(object):
-  def test_should_annotate_sub_tag_exactly_matching_without_begin_prefix(self):
-    matching_tokens = _tokens_for_text('this is matching')
-    target_annotations = [
-      TargetAnnotation('this is matching', TAG2, sub_annotations=[
-        TargetAnnotation('this', TAG1)
-      ])
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc)
-    assert [doc.get_sub_tag(x) for x in matching_tokens] == [TAG1, None, None]
-
-  def test_should_annotate_sub_tag_exactly_matching_with_begin_prefix(self):
-    matching_tokens = _tokens_for_text('this is matching')
-    target_annotations = [
-      TargetAnnotation('this is matching', TAG2, sub_annotations=[
-        TargetAnnotation('this', TAG1)
-      ])
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
-    assert [doc.get_sub_tag(x) for x in matching_tokens] == [B_TAG_1, None, None]
-
-  def test_should_annotate_sub_tag_case_insensitive(self):
-    matching_tokens = _tokens_for_text('this is matching')
-    target_annotations = [
-      TargetAnnotation('this is matching', TAG2, sub_annotations=[
-        TargetAnnotation('This', TAG1)
-      ])
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc)
-    assert [doc.get_sub_tag(x) for x in matching_tokens] == [TAG1, None, None]
-
-  def test_should_annotate_multiple_sub_tag_exactly_matching(self):
-    matching_tokens = _tokens_for_text('this is matching')
-    target_annotations = [
-      TargetAnnotation('this is matching', TAG2, sub_annotations=[
-        TargetAnnotation('this', TAG1),
-        TargetAnnotation('is', TAG2)
-      ])
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc)
-    assert [doc.get_sub_tag(x) for x in matching_tokens] == [TAG1, TAG2, None]
-
-  def test_should_annotate_multiple_sub_annotations_with_same_sub_tag(self):
-    matching_tokens = _tokens_for_text('this is matching')
-    target_annotations = [
-      TargetAnnotation('this is matching', TAG2, sub_annotations=[
-        TargetAnnotation('this', TAG1),
-        TargetAnnotation('is', TAG1)
-      ])
-    ]
-    doc = _document_for_tokens([matching_tokens])
-    MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc)
-    assert [doc.get_sub_tag(x) for x in matching_tokens] == [TAG1, TAG1, None]
-
-  def test_should_annotate_same_sub_annotations_multiple_times_with_begin_prefic(self):
-    matching_tokens_by_line = [
-      _tokens_for_text('this is matching'),
-      _tokens_for_text('this is matching')
-    ]
-    matching_tokens = flatten(matching_tokens_by_line)
-    target_annotations = [
-      TargetAnnotation('this is matching', TAG2, match_multiple=True, sub_annotations=[
-        TargetAnnotation('this is', TAG1)
-      ])
-    ]
-    doc = _document_for_tokens(matching_tokens_by_line)
-    MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
-    assert (
-      [doc.get_sub_tag(x) for x in matching_tokens] ==
-      [B_TAG_1, I_TAG_1, None, B_TAG_1, I_TAG_1, None]
-    )
-
-  def test_should_annotate_sub_tag_across_multiple_tokens(self):
-    sub_matching_tokens = _tokens_for_text('this is matching')
-    tag_matching_tokens = (
-      _tokens_for_text('something before') +
-      sub_matching_tokens +
-      _tokens_for_text('more text to come')
-    )
-    all_tokens = (
-      _tokens_for_text('not matching') + tag_matching_tokens + _tokens_for_text('and there')
-    )
-    target_annotations = [
-      TargetAnnotation(_tokens_to_text(tag_matching_tokens), TAG2, sub_annotations=[
-        TargetAnnotation(_tokens_to_text(sub_matching_tokens), TAG1)
-      ])
-    ]
-    doc = _document_for_tokens([all_tokens])
-    MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc)
-    assert [doc.get_sub_tag(x) for x in sub_matching_tokens] == [TAG1, TAG1, TAG1]
-
-  def test_should_annotate_sub_tag_across_multiple_tokens_with_junk_characters(self):
-    junk_char = ','
-    sub_matching_tokens = _tokens_for_text('this is matching' + junk_char)
-    tag_matching_tokens = (
-      _tokens_for_text('something before') +
-      sub_matching_tokens +
-      _tokens_for_text('more text to come')
-    )
-    all_tokens = (
-      _tokens_for_text('not matching') + tag_matching_tokens + _tokens_for_text('and there')
-    )
-    tag_matching_text = _tokens_to_text(tag_matching_tokens).replace(junk_char, '')
-    sub_matching_text = _tokens_to_text(sub_matching_tokens).replace(junk_char, '')
-    target_annotations = [
-      TargetAnnotation(tag_matching_text, TAG2, sub_annotations=[
-        TargetAnnotation(sub_matching_text, TAG1)
-      ])
-    ]
-    doc = _document_for_tokens([all_tokens])
-    MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc)
-    assert [doc.get_sub_tag(x) for x in sub_matching_tokens] == [TAG1, TAG1, TAG1]
+    def test_should_annotate_sub_tag_exactly_matching_without_begin_prefix(self):
+        matching_tokens = _tokens_for_text('this is matching')
+        target_annotations = [
+            TargetAnnotation('this is matching', TAG2, sub_annotations=[
+                TargetAnnotation('this', TAG1)
+            ])
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc)
+        assert [doc.get_sub_tag(x) for x in matching_tokens] == [TAG1, None, None]
+
+    def test_should_annotate_sub_tag_exactly_matching_with_begin_prefix(self):
+        matching_tokens = _tokens_for_text('this is matching')
+        target_annotations = [
+            TargetAnnotation('this is matching', TAG2, sub_annotations=[
+                TargetAnnotation('this', TAG1)
+            ])
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
+        assert [doc.get_sub_tag(x) for x in matching_tokens] == [B_TAG_1, None, None]
+
+    def test_should_annotate_sub_tag_case_insensitive(self):
+        matching_tokens = _tokens_for_text('this is matching')
+        target_annotations = [
+            TargetAnnotation('this is matching', TAG2, sub_annotations=[
+                TargetAnnotation('This', TAG1)
+            ])
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc)
+        assert [doc.get_sub_tag(x) for x in matching_tokens] == [TAG1, None, None]
+
+    def test_should_annotate_multiple_sub_tag_exactly_matching(self):
+        matching_tokens = _tokens_for_text('this is matching')
+        target_annotations = [
+            TargetAnnotation('this is matching', TAG2, sub_annotations=[
+                TargetAnnotation('this', TAG1),
+                TargetAnnotation('is', TAG2)
+            ])
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc)
+        assert [doc.get_sub_tag(x) for x in matching_tokens] == [TAG1, TAG2, None]
+
+    def test_should_annotate_multiple_sub_annotations_with_same_sub_tag(self):
+        matching_tokens = _tokens_for_text('this is matching')
+        target_annotations = [
+            TargetAnnotation('this is matching', TAG2, sub_annotations=[
+                TargetAnnotation('this', TAG1),
+                TargetAnnotation('is', TAG1)
+            ])
+        ]
+        doc = _document_for_tokens([matching_tokens])
+        MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc)
+        assert [doc.get_sub_tag(x) for x in matching_tokens] == [TAG1, TAG1, None]
+
+    def test_should_annotate_same_sub_annotations_multiple_times_with_begin_prefix(self):
+        matching_tokens_by_line = [
+            _tokens_for_text('this is matching'),
+            _tokens_for_text('this is matching')
+        ]
+        matching_tokens = flatten(matching_tokens_by_line)
+        target_annotations = [
+            TargetAnnotation('this is matching', TAG2, match_multiple=True, sub_annotations=[
+                TargetAnnotation('this is', TAG1)
+            ])
+        ]
+        doc = _document_for_tokens(matching_tokens_by_line)
+        MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc)
+        assert (
+            [doc.get_sub_tag(x) for x in matching_tokens] ==
+            [B_TAG_1, I_TAG_1, None, B_TAG_1, I_TAG_1, None]
+        )
+
+    def test_should_annotate_sub_tag_across_multiple_tokens(self):
+        sub_matching_tokens = _tokens_for_text('this is matching')
+        tag_matching_tokens = (
+            _tokens_for_text('something before') +
+            sub_matching_tokens +
+            _tokens_for_text('more text to come')
+        )
+        all_tokens = (
+            _tokens_for_text('not matching') + tag_matching_tokens + _tokens_for_text('and there')
+        )
+        target_annotations = [
+            TargetAnnotation(_tokens_to_text(tag_matching_tokens), TAG2, sub_annotations=[
+                TargetAnnotation(_tokens_to_text(sub_matching_tokens), TAG1)
+            ])
+        ]
+        doc = _document_for_tokens([all_tokens])
+        MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc)
+        assert [doc.get_sub_tag(x) for x in sub_matching_tokens] == [TAG1, TAG1, TAG1]
+
+    def test_should_annotate_sub_tag_across_multiple_tokens_with_junk_characters(self):
+        junk_char = ','
+        sub_matching_tokens = _tokens_for_text('this is matching' + junk_char)
+        tag_matching_tokens = (
+            _tokens_for_text('something before') +
+            sub_matching_tokens +
+            _tokens_for_text('more text to come')
+        )
+        all_tokens = (
+            _tokens_for_text('not matching') + tag_matching_tokens + _tokens_for_text('and there')
+        )
+        tag_matching_text = _tokens_to_text(tag_matching_tokens).replace(junk_char, '')
+        sub_matching_text = _tokens_to_text(sub_matching_tokens).replace(junk_char, '')
+        target_annotations = [
+            TargetAnnotation(tag_matching_text, TAG2, sub_annotations=[
+                TargetAnnotation(sub_matching_text, TAG1)
+            ])
+        ]
+        doc = _document_for_tokens([all_tokens])
+        MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc)
+        assert [doc.get_sub_tag(x) for x in sub_matching_tokens] == [TAG1, TAG1, TAG1]
diff --git a/sciencebeam_gym/preprocess/annotation/target_annotation.py b/sciencebeam_gym/preprocess/annotation/target_annotation.py
index b1ff43d..82f2b2c 100644
--- a/sciencebeam_gym/preprocess/annotation/target_annotation.py
+++ b/sciencebeam_gym/preprocess/annotation/target_annotation.py
@@ -1,411 +1,448 @@
+from __future__ import absolute_import
+
 import logging
 import json
 import re
 from itertools import chain
 
+from six.moves.configparser import ConfigParser
 import six
-from six.moves.configparser import ConfigParser # pylint: disable=E0401
 
 from lxml import etree
 
 from sciencebeam_utils.utils.compat import (
-  python_2_unicode_compatible
+    python_2_unicode_compatible
 )
 
 from sciencebeam_utils.utils.string import (
-  LazyStr
+    LazyStr
 )
 
 from sciencebeam_utils.utils.xml import (
-  get_text_content,
-  get_immediate_text
+    get_text_content,
+    get_immediate_text
 )
 
 from sciencebeam_utils.utils.collection import (
-  filter_truthy,
-  strip_all
+    filter_truthy,
+    strip_all
 )
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 class XmlMappingSuffix(object):
-  REGEX = '.regex'
-  EXTRACT_REGEX = '.extract-regex'
-  MATCH_MULTIPLE = '.match-multiple'
-  BONDING = '.bonding'
-  REQUIRE_NEXT = '.require-next'
-  CHILDREN = '.children'
-  CHILDREN_CONCAT = '.children.concat'
-  CHILDREN_RANGE = '.children.range'
-  UNMATCHED_PARENT_TEXT = '.unmatched-parent-text'
-  PRIORITY = '.priority'
-  SUB = '.sub'
+    REGEX = '.regex'
+    EXTRACT_REGEX = '.extract-regex'
+    MATCH_MULTIPLE = '.match-multiple'
+    BONDING = '.bonding'
+    REQUIRE_NEXT = '.require-next'
+    CHILDREN = '.children'
+    CHILDREN_CONCAT = '.children.concat'
+    CHILDREN_RANGE = '.children.range'
+    UNMATCHED_PARENT_TEXT = '.unmatched-parent-text'
+    PRIORITY = '.priority'
+    SUB = '.sub'
+
 
@python_2_unicode_compatible
class TargetAnnotation(object):
    """A target value (extracted from the training XML) together with the
    flags controlling how it should be matched; the flags are stored as-is,
    their semantics are applied by the matching code elsewhere.
    """

    def __init__(
            self, value, name,
            match_multiple=False, bonding=False, require_next=False,
            sub_annotations=None):

        self.value = value  # a string or a list of alternative strings
        self.name = name  # the tag name this annotation belongs to
        self.match_multiple = match_multiple  # NOTE(review): flag semantics defined by the matcher
        self.bonding = bonding
        self.require_next = require_next
        self.sub_annotations = sub_annotations  # optional list of nested TargetAnnotation

    def __repr__(self):
        # single-line summary including all flags and the value
        return u'TA {} (match_multiple={}, bonding={}, require_next={}): {}'.format(
            self.name, self.match_multiple, self.bonding, self.require_next,
            self.value
        )
+
 
 def parse_xml_mapping(xml_mapping_filename):
-  with open(xml_mapping_filename, 'r') as f:
-    config = ConfigParser()
-    if six.PY3:
-      config.read_file(f)
-    else:
-      config.readfp(f)
-    return {
-      k: dict(config.items(k))
-      for k in config.sections()
-    }
+    with open(xml_mapping_filename, 'r') as f:
+        config = ConfigParser()
+        if six.PY3:
+            config.read_file(f)  # pylint: disable=no-member
+        else:
+            config.readfp(f)
+        return {
+            k: dict(config.items(k))
+            for k in config.sections()
+        }
+
 
 def relace_recursive(s, old, new):
-  previous = None
-  while s != previous:
-    previous = s
-    s = s.replace(old, new)
-  return s
+    previous = None
+    while s != previous:
+        previous = s
+        s = s.replace(old, new)
+    return s
+
 
 def strip_whitespace(s):
-  replacements = [
-    ('\t', ' '),
-    ('  ', ' '),
-    ('\r', '\n'),
-    (' \n', '\n'),
-    ('\n ', '\n'),
-    ('\n\n', '\n')
-  ]
-  for old, new in replacements:
-    s = relace_recursive(s, old, new)
-  return s
+    replacements = [
+        ('\t', ' '),
+        ('  ', ' '),
+        ('\r', '\n'),
+        (' \n', '\n'),
+        ('\n ', '\n'),
+        ('\n\n', '\n')
+    ]
+    for old, new in replacements:
+        s = relace_recursive(s, old, new)
+    return s
+
 
 def get_stripped_text_content(node, **kwargs):
-  return strip_whitespace(get_text_content(node, **kwargs).strip())
+    return strip_whitespace(get_text_content(node, **kwargs).strip())
+
 
 def get_stripped_text_content_list(nodes, **kwargs):
-  return [get_stripped_text_content(node, **kwargs) for node in nodes]
+    return [get_stripped_text_content(node, **kwargs) for node in nodes]
+
 
 def iter_flatten_if_nested(a):
-  for x in a:
-    if isinstance(x, list):
-      for y in iter_flatten_if_nested(x):
-        yield y
-    else:
-      yield x
+    for x in a:
+        if isinstance(x, list):
+            for y in iter_flatten_if_nested(x):
+                yield y
+        else:
+            yield x
+
 
 def flatten_if_nested(a):
-  if not a:
-    return a
-  return list(iter_flatten_if_nested(a))
+    if not a:
+        return a
+    return list(iter_flatten_if_nested(a))
+
 
 def apply_pattern(s, compiled_pattern):
-  m = compiled_pattern.match(s)
-  if m:
-    get_logger().debug('regex match: %s -> %s', compiled_pattern, m.groups())
-    return m.group(1)
-  return s
+    m = compiled_pattern.match(s)
+    if m:
+        get_logger().debug('regex match: %s -> %s', compiled_pattern, m.groups())
+        return m.group(1)
+    return s
+
 
 def iter_parents(children):
-  for child in children:
-    p = child.getparent()
-    if p is not None:
-      yield p
+    for child in children:
+        p = child.getparent()
+        if p is not None:
+            yield p
+
 
 def exclude_parents(children):
-  if not isinstance(children, list):
-    children = list(children)
-  all_parents = set(iter_parents(children))
-  return [child for child in children if not child in all_parents]
+    if not isinstance(children, list):
+        children = list(children)
+    all_parents = set(iter_parents(children))
+    return [child for child in children if child not in all_parents]
+
 
 def extract_children_source_list(parent, children_source_list):
-  used_nodes = set()
-  values = []
-  for children_source in children_source_list:
-    xpath = children_source.get('xpath')
-    if xpath:
-      matching_nodes = exclude_parents(parent.xpath(xpath))
-      if not matching_nodes:
-        get_logger().debug(
-          'child xpath does not match any item, skipping: xpath=%s (xml=%s)',
-          xpath,
-          LazyStr(lambda: str(etree.tostring(parent)))
-        )
-        used_nodes = set()
-        values = []
-        break
-      used_nodes |= set(matching_nodes)
-      value = ' '.join(get_stripped_text_content_list(matching_nodes))
-    else:
-      value = children_source.get('value')
-    values.append(value or '')
-  return values, used_nodes
+    used_nodes = set()
+    values = []
+    for children_source in children_source_list:
+        xpath = children_source.get('xpath')
+        if xpath:
+            matching_nodes = exclude_parents(parent.xpath(xpath))
+            if not matching_nodes:
+                get_logger().debug(
+                    'child xpath does not match any item, skipping: xpath=%s (xml=%s)',
+                    xpath,
+                    LazyStr(lambda: str(etree.tostring(parent)))
+                )
+                used_nodes = set()
+                values = []
+                break
+            used_nodes |= set(matching_nodes)
+            value = ' '.join(get_stripped_text_content_list(matching_nodes))
+        else:
+            value = children_source.get('value')
+        values.append(value or '')
+    return values, used_nodes
+
 
 def extract_children_concat(parent, children_concat):
-  used_nodes = set()
-  values = []
-  get_logger().debug('children_concat: %s', children_concat)
-  for children_concat_item in children_concat:
-    temp_values, temp_used_nodes = extract_children_source_list(
-      parent, children_concat_item
-    )
-    used_nodes |= temp_used_nodes
-    if temp_values:
-      values.append(''.join(temp_values))
-  return values, used_nodes
+    used_nodes = set()
+    values = []
+    get_logger().debug('children_concat: %s', children_concat)
+    for children_concat_item in children_concat:
+        temp_values, temp_used_nodes = extract_children_source_list(
+            parent, children_concat_item
+        )
+        used_nodes |= temp_used_nodes
+        if temp_values:
+            values.append(''.join(temp_values))
+    return values, used_nodes
+
 
 def extract_children_range(parent, children_range):
-  used_nodes = set()
-  values = []
-  standalone_values = []
-  get_logger().debug('children_range: %s', children_range)
-  for range_item in children_range:
-    temp_values, temp_used_nodes = extract_children_source_list(
-      parent, [range_item.get('min'), range_item.get('max')]
-    )
-    if len(temp_values) == 2:
-      temp_values = strip_all(temp_values)
-      if all(s.isdigit() for s in temp_values):
-        num_values = [int(s) for s in temp_values]
-        range_values = [str(x) for x in range(num_values[0], num_values[1] + 1)]
-        if range_item.get('standalone'):
-          standalone_values.extend(range_values)
-        else:
-          values.extend(range_values)
-        used_nodes |= temp_used_nodes
-      else:
-        get_logger().info('values not integers: %s', temp_values)
-  return values, standalone_values, used_nodes
+    used_nodes = set()
+    values = []
+    standalone_values = []
+    get_logger().debug('children_range: %s', children_range)
+    for range_item in children_range:
+        temp_values, temp_used_nodes = extract_children_source_list(
+            parent, [range_item.get('min'), range_item.get('max')]
+        )
+        if len(temp_values) == 2:
+            temp_values = strip_all(temp_values)
+            if all(s.isdigit() for s in temp_values):
+                num_values = [int(s) for s in temp_values]
+                range_values = [str(x) for x in range(num_values[0], num_values[1] + 1)]
+                if range_item.get('standalone'):
+                    standalone_values.extend(range_values)
+                else:
+                    values.extend(range_values)
+                used_nodes |= temp_used_nodes
+            else:
+                get_logger().info('values not integers: %s', temp_values)
+    return values, standalone_values, used_nodes
+
 
 def parse_xpaths(s):
-  return strip_all(s.strip().split('\n')) if s else None
+    return strip_all(s.strip().split('\n')) if s else None
+
 
 def match_xpaths(parent, xpaths):
-  return chain(*[parent.xpath(s) for s in xpaths])
+    return chain(*[parent.xpath(s) for s in xpaths])
+
 
def extract_children(
        parent, children_xpaths, children_concat, children_range, unmatched_parent_text):
    """Collect the text values of `parent`'s children.

    Combines the values produced by the `children_concat` and
    `children_range` source configurations with the text content of the
    remaining nodes matched by `children_xpaths` and, when
    `unmatched_parent_text` is set, any parent text not covered by those
    nodes.

    :return: (text_content_list, standalone_values)
    """
    concat_values_list, concat_used_nodes = extract_children_concat(parent, children_concat)
    range_values_list, standalone_values, range_used_nodes = (
        extract_children_range(parent, children_range)
    )
    used_nodes = concat_used_nodes | range_used_nodes

    # nodes matched by the children xpaths that were not already consumed
    # by the concat / range extraction above
    other_child_nodes = [
        node for node in match_xpaths(parent, children_xpaths)
        if node not in used_nodes
    ]
    other_child_nodes_excl_parents = exclude_parents(other_child_nodes)
    text_content_list = filter_truthy(strip_all(
        get_stripped_text_content_list(other_child_nodes_excl_parents) +
        concat_values_list + range_values_list
    ))
    if len(other_child_nodes_excl_parents) != len(other_child_nodes):
        # some matched nodes are parents of other matched nodes: their full
        # text was excluded above, but keep their immediate (direct) text
        other_child_nodes_excl_parents_set = set(other_child_nodes_excl_parents)
        for child in other_child_nodes:
            if child not in other_child_nodes_excl_parents_set:
                text_values = filter_truthy(strip_all(get_immediate_text(child)))
                text_content_list.extend(text_values)
    if unmatched_parent_text:
        # add the parent's remaining text (outside the matched/used nodes),
        # unless it is already present in the collected values
        value = get_stripped_text_content(
            parent,
            exclude=set(other_child_nodes) | used_nodes
        ).strip()
        if value and value not in text_content_list:
            text_content_list.append(value)
    return text_content_list, standalone_values
+
 
 def parse_json_with_default(s, default_value):
-  return json.loads(s) if s else default_value
+    return json.loads(s) if s else default_value
+
 
 def get_prefixed_dict_values(d, key_prefix):
-  return {
-    k[len(key_prefix):]: v
-    for k, v in d.items()
-    if k.startswith(key_prefix)
-  }
+    return {
+        k[len(key_prefix):]: v
+        for k, v in d.items()
+        if k.startswith(key_prefix)
+    }
+
 
 def get_sub_mapping(mapping, tag):
-  return {
-    k: v
-    for k, v in get_prefixed_dict_values(mapping, tag + XmlMappingSuffix.SUB + '.').items()
-    if '.' not in k
-  }
+    return {
+        k: v
+        for k, v in get_prefixed_dict_values(mapping, tag + XmlMappingSuffix.SUB + '.').items()
+        if '.' not in k
+    }
+
 
 def re_compile_or_none(pattern):
-  return re.compile(pattern) if pattern else None
+    return re.compile(pattern) if pattern else None
+
 
 def extract_using_regex(s, compiled_pattern):
-  result = None
-  start = 0
-  for m in compiled_pattern.finditer(s):
-    m_start = m.start(1)
-    m_end = m.end(1)
-    m_text = m.group(1)
-    get_logger().debug('extract match: %d:%d, %s', m_start, m_end, m_text)
+    result = None
+    start = 0
+    for m in compiled_pattern.finditer(s):
+        m_start = m.start(1)
+        m_end = m.end(1)
+        m_text = m.group(1)
+        get_logger().debug('extract match: %d:%d, %s', m_start, m_end, m_text)
+        if result is None:
+            result = []
+        if start < m_start:
+            result.append(s[start:m_start].strip())
+        result.append(m_text)
+        start = m_end + 1
     if result is None:
-      result = []
-    if start < m_start:
-      result.append(s[start:m_start].strip())
-    result.append(m_text)
-    start = m_end + 1
-  if result is None:
-    return s
-  if start < len(s):
-    result.append(s[start:].strip())
-  if len(result) == 1:
-    return result[0]
-  # also include the full string
-  result.append(s)
-  return result
+        return s
+    if start < len(s):
+        result.append(s[start:].strip())
+    if len(result) == 1:
+        return result[0]
+    # also include the full string
+    result.append(s)
+    return result
+
 
 def extract_sub_annotations(parent_node, sub_xpaths, mapping, parent_key):
-  if not sub_xpaths:
-    return
-  sub_annotations = []
-  for sub_tag, sub_xpath in sub_xpaths.items():
-    sub_key_prefix = parent_key + XmlMappingSuffix.SUB + '.' + sub_tag
-    extract_re_compiled_pattern = re_compile_or_none(
-      mapping.get(sub_key_prefix + XmlMappingSuffix.EXTRACT_REGEX)
-    )
-    get_logger().debug('sub_key_prefix: %s', sub_key_prefix)
-    get_logger().debug(
-      'extract_re_compiled_pattern (%s, %s): %s',
-      parent_key, sub_tag, extract_re_compiled_pattern
-    )
+    if not sub_xpaths:
+        return None
+    sub_annotations = []
+    for sub_tag, sub_xpath in sub_xpaths.items():
+        sub_key_prefix = parent_key + XmlMappingSuffix.SUB + '.' + sub_tag
+        extract_re_compiled_pattern = re_compile_or_none(
+            mapping.get(sub_key_prefix + XmlMappingSuffix.EXTRACT_REGEX)
+        )
+        get_logger().debug('sub_key_prefix: %s', sub_key_prefix)
+        get_logger().debug(
+            'extract_re_compiled_pattern (%s, %s): %s',
+            parent_key, sub_tag, extract_re_compiled_pattern
+        )
+
+        for e in match_xpaths(parent_node, [sub_xpath]):
+            value = get_stripped_text_content(e)
+            if value:
+                value = strip_whitespace(value).strip()
+            if extract_re_compiled_pattern is not None and value:
+                value = extract_using_regex(value, extract_re_compiled_pattern)
+            if value:
+                sub_annotations.append(TargetAnnotation(value, sub_tag))
+    return sub_annotations
 
-    for e in match_xpaths(parent_node, [sub_xpath]):
-      value = get_stripped_text_content(e)
-      if value:
-        value = strip_whitespace(value).strip()
-      if extract_re_compiled_pattern is not None and value:
-        value = extract_using_regex(value, extract_re_compiled_pattern)
-      if value:
-        sub_annotations.append(TargetAnnotation(value, sub_tag))
-  return sub_annotations
 
def xml_root_to_target_annotations(xml_root, xml_mapping):
    """Convert a target XML document into a list of TargetAnnotation.

    :param xml_root: root element of the target (training) XML document
    :param xml_mapping: mapping configuration as returned by
        parse_xml_mapping; must contain an entry for xml_root's tag
    :return: list of TargetAnnotation ordered by priority (highest first)
        and document position
    :raises Exception: if xml_root's tag is not present in xml_mapping
    """
    if xml_root.tag not in xml_mapping:
        raise Exception("unrecognised tag: {} (available: {})".format(
            xml_root.tag, xml_mapping.sections())
        )

    mapping = xml_mapping[xml_root.tag]

    # top-level fields are the mapping keys without any '.'-suffix
    field_names = [k for k in mapping.keys() if '.' not in k]

    def get_mapping_flag(k, suffix):
        # a flag is enabled via the literal string 'true' in the config
        return mapping.get(k + suffix) == 'true'

    def get_match_multiple(k):
        return get_mapping_flag(k, XmlMappingSuffix.MATCH_MULTIPLE)

    def get_bonding_flag(k):
        return get_mapping_flag(k, XmlMappingSuffix.BONDING)

    def get_require_next_flag(k):
        return get_mapping_flag(k, XmlMappingSuffix.REQUIRE_NEXT)

    get_unmatched_parent_text_flag = (
        lambda k: get_mapping_flag(k, XmlMappingSuffix.UNMATCHED_PARENT_TEXT)
    )

    get_logger().debug('fields: %s', field_names)

    target_annotations_with_pos = []
    # document-order index of every node, used as part of the sort key below
    xml_pos_by_node = {node: i for i, node in enumerate(xml_root.iter())}
    for k in field_names:
        match_multiple = get_match_multiple(k)
        bonding = get_bonding_flag(k)
        require_next = get_require_next_flag(k)
        unmatched_parent_text = get_unmatched_parent_text_flag(k)
        children_xpaths = parse_xpaths(mapping.get(k + XmlMappingSuffix.CHILDREN))
        children_concat = parse_json_with_default(
            mapping.get(k + XmlMappingSuffix.CHILDREN_CONCAT), []
        )
        children_range = parse_json_with_default(
            mapping.get(k + XmlMappingSuffix.CHILDREN_RANGE), []
        )
        re_compiled_pattern = re_compile_or_none(
            mapping.get(k + XmlMappingSuffix.REGEX)
        )
        extract_re_compiled_pattern = re_compile_or_none(
            mapping.get(k + XmlMappingSuffix.EXTRACT_REGEX)
        )
        get_logger().debug('extract_re_compiled_pattern (%s): %s', k, extract_re_compiled_pattern)

        priority = int(mapping.get(k + XmlMappingSuffix.PRIORITY, '0'))
        sub_xpaths = get_sub_mapping(mapping, k)
        get_logger().debug('sub_xpaths (%s): %s', k, sub_xpaths)

        xpaths = parse_xpaths(mapping[k])
        get_logger().debug('xpaths(%s): %s', k, xpaths)
        for e in match_xpaths(xml_root, xpaths):
            e_pos = xml_pos_by_node.get(e)

            sub_annotations = extract_sub_annotations(e, sub_xpaths, mapping, k)
            get_logger().debug('sub_annotations (%s): %s', k, sub_annotations)

            if children_xpaths:
                text_content_list, standalone_values = extract_children(
                    e, children_xpaths, children_concat, children_range, unmatched_parent_text
                )
            else:
                text_content_list = filter_truthy(strip_all([get_stripped_text_content(e)]))
                standalone_values = []
            if re_compiled_pattern:
                text_content_list = filter_truthy([
                    apply_pattern(s, re_compiled_pattern) for s in text_content_list
                ])
            if extract_re_compiled_pattern:
                text_content_list = filter_truthy([
                    extract_using_regex(s, extract_re_compiled_pattern) for s in text_content_list
                ])
            text_content_list = flatten_if_nested(text_content_list)
            if text_content_list:
                # a single value is kept as-is; multiple values are sorted
                # by decreasing length (longest value first)
                value = (
                    text_content_list[0]
                    if len(text_content_list) == 1
                    else sorted(text_content_list, key=lambda s: -len(s))
                )
                target_annotations_with_pos.append((
                    (-priority, e_pos),
                    TargetAnnotation(
                        value,
                        k,
                        match_multiple=match_multiple,
                        bonding=bonding,
                        require_next=require_next,
                        sub_annotations=sub_annotations
                    )
                ))
            if standalone_values:
                # standalone values become separate annotations,
                # keeping their relative order via the extra index
                for i, standalone_value in enumerate(standalone_values):
                    target_annotations_with_pos.append((
                        (-priority, e_pos, i),
                        TargetAnnotation(
                            standalone_value,
                            k,
                            match_multiple=match_multiple,
                            bonding=bonding,
                            sub_annotations=sub_annotations
                        )
                    ))
    target_annotations_with_pos = sorted(
        target_annotations_with_pos,
        key=lambda x: x[0]
    )
    get_logger().debug('target_annotations_with_pos:\n%s', target_annotations_with_pos)
    target_annotations = [
        x[1] for x in target_annotations_with_pos
    ]
    get_logger().debug('target_annotations:\n%s', '\n'.join([
        ' ' + str(a) for a in target_annotations
    ]))
    return target_annotations
diff --git a/sciencebeam_gym/preprocess/annotation/target_annotation_test.py b/sciencebeam_gym/preprocess/annotation/target_annotation_test.py
index a3fdadf..9d80098 100644
--- a/sciencebeam_gym/preprocess/annotation/target_annotation_test.py
+++ b/sciencebeam_gym/preprocess/annotation/target_annotation_test.py
@@ -5,9 +5,9 @@ import json
 from lxml.builder import E
 
 from sciencebeam_gym.preprocess.annotation.target_annotation import (
-  strip_whitespace,
-  xml_root_to_target_annotations,
-  XmlMappingSuffix
+    strip_whitespace,
+    xml_root_to_target_annotations,
+    XmlMappingSuffix
 )
 
 TAG1 = 'tag1'
@@ -18,603 +18,608 @@ SOME_VALUE_2 = 'some value2'
 SOME_LONGER_VALUE = 'some longer value1'
 SOME_SHORTER_VALUE = 'value1'
 
+
 class TestStripWhitespace(object):
-  def test_should_replace_tab_with_space(self):
-    assert strip_whitespace(SOME_VALUE + '\t' + SOME_VALUE_2) == SOME_VALUE + ' ' + SOME_VALUE_2
+    def test_should_replace_tab_with_space(self):
+        assert strip_whitespace(SOME_VALUE + '\t' + SOME_VALUE_2) == SOME_VALUE + ' ' + SOME_VALUE_2
+
+    def test_should_strip_double_space(self):
+        assert strip_whitespace(SOME_VALUE + '  ' + SOME_VALUE_2) == SOME_VALUE + ' ' + SOME_VALUE_2
 
-  def test_should_strip_double_space(self):
-    assert strip_whitespace(SOME_VALUE + '  ' + SOME_VALUE_2) == SOME_VALUE + ' ' + SOME_VALUE_2
+    def test_should_strip_double_line_feed(self):
+        assert strip_whitespace(SOME_VALUE + '\n\n' +
+                                SOME_VALUE_2) == SOME_VALUE + '\n' + SOME_VALUE_2
 
-  def test_should_strip_double_line_feed(self):
-    assert strip_whitespace(SOME_VALUE + '\n\n' + SOME_VALUE_2) == SOME_VALUE + '\n' + SOME_VALUE_2
+    def test_should_replace_cr_with_line_feed(self):
+        assert strip_whitespace(
+            SOME_VALUE + '\r' + SOME_VALUE_2) == SOME_VALUE + '\n' + SOME_VALUE_2
 
-  def test_should_replace_cr_with_line_feed(self):
-    assert strip_whitespace(SOME_VALUE + '\r' + SOME_VALUE_2) == SOME_VALUE + '\n' + SOME_VALUE_2
+    def test_should_strip_spaces_around_line_feed(self):
+        assert strip_whitespace(SOME_VALUE + ' \n ' +
+                                SOME_VALUE_2) == SOME_VALUE + '\n' + SOME_VALUE_2
 
-  def test_should_strip_spaces_around_line_feed(self):
-    assert strip_whitespace(SOME_VALUE + ' \n ' + SOME_VALUE_2) == SOME_VALUE + '\n' + SOME_VALUE_2
+    def test_should_strip_multiple_lines_with_blanks(self):
+        assert (
+            strip_whitespace(SOME_VALUE + ' \n \n \n ' + SOME_VALUE_2) ==
+            SOME_VALUE + '\n' + SOME_VALUE_2
+        )
 
-  def test_should_strip_multiple_lines_with_blanks(self):
-    assert (
-      strip_whitespace(SOME_VALUE + ' \n \n \n ' + SOME_VALUE_2) ==
-      SOME_VALUE + '\n' + SOME_VALUE_2
-    )
 
 class TestXmlRootToTargetAnnotations(object):
-  def test_should_return_empty_target_annotations_for_empty_xml(self):
-    xml_root = E.article(
-    )
-    xml_mapping = {
-      'article': {
-        'title': 'title'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert target_annotations == []
-
-  def test_should_return_empty_target_annotations_for_no_matching_annotations(self):
-    xml_root = E.article(
-      E.other(SOME_VALUE)
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'title'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert target_annotations == []
-
-  def test_should_return_matching_target_annotations(self):
-    xml_root = E.article(
-      E.title(SOME_VALUE)
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'title'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert len(target_annotations) == 1
-    assert target_annotations[0].name == TAG1
-    assert target_annotations[0].value == SOME_VALUE
-
-  def test_should_strip_extra_space(self):
-    xml_root = E.article(
-      E.abstract(SOME_VALUE + '  ' + SOME_VALUE_2)
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'abstract'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert len(target_annotations) == 1
-    assert target_annotations[0].name == TAG1
-    assert target_annotations[0].value == SOME_VALUE + ' ' + SOME_VALUE_2
-
-  def test_should_apply_regex_to_result(self):
-    xml_root = E.article(
-      E.title('1.1. ' + SOME_VALUE)
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'title',
-        TAG1 + XmlMappingSuffix.REGEX: r'(?:\d+\.?)* ?(.*)'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert len(target_annotations) == 1
-    assert target_annotations[0].name == TAG1
-    assert target_annotations[0].value == SOME_VALUE
-
-  def test_should_apply_match_multiple_flag(self):
-    xml_root = E.article(
-      E.title(SOME_VALUE)
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'title',
-        TAG1 + XmlMappingSuffix.MATCH_MULTIPLE: 'true'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [t.match_multiple for t in target_annotations] == [True]
-
-  def test_should_not_apply_match_multiple_flag_if_not_set(self):
-    xml_root = E.article(
-      E.title(SOME_VALUE)
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'title'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [t.match_multiple for t in target_annotations] == [False]
-
-  def test_should_apply_match_bonding_flag(self):
-    xml_root = E.article(
-      E.title(SOME_VALUE)
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'title',
-        TAG1 + XmlMappingSuffix.BONDING: 'true'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [t.bonding for t in target_annotations] == [True]
-
-  def test_should_not_apply_match_bonding_flag_if_not_set(self):
-    xml_root = E.article(
-      E.title(SOME_VALUE)
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'title'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [t.bonding for t in target_annotations] == [False]
-
-  def test_should_apply_match_require_next_flag(self):
-    xml_root = E.article(
-      E.title(SOME_VALUE)
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'title',
-        TAG1 + XmlMappingSuffix.REQUIRE_NEXT: 'true'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [t.require_next for t in target_annotations] == [True]
-
-  def test_should_not_apply_match_require_next_flag_if_not_set(self):
-    xml_root = E.article(
-      E.title(SOME_VALUE)
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'title'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [t.require_next for t in target_annotations] == [False]
-
-  def test_should_use_multiple_xpaths(self):
-    xml_root = E.article(
-      E.entry(
-        E.child1(SOME_VALUE),
-        E.child2(SOME_VALUE_2)
-      )
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: '\n{}\n{}\n'.format(
-          'entry/child1',
-          'entry/child2'
+    def test_should_return_empty_target_annotations_for_empty_xml(self):
+        xml_root = E.article(
+        )
+        xml_mapping = {
+            'article': {
+                'title': 'title'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert target_annotations == []
+
+    def test_should_return_empty_target_annotations_for_no_matching_annotations(self):
+        xml_root = E.article(
+            E.other(SOME_VALUE)
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'title'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert target_annotations == []
+
+    def test_should_return_matching_target_annotations(self):
+        xml_root = E.article(
+            E.title(SOME_VALUE)
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'title'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert len(target_annotations) == 1
+        assert target_annotations[0].name == TAG1
+        assert target_annotations[0].value == SOME_VALUE
+
+    def test_should_strip_extra_space(self):
+        xml_root = E.article(
+            E.abstract(SOME_VALUE + '  ' + SOME_VALUE_2)
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'abstract'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert len(target_annotations) == 1
+        assert target_annotations[0].name == TAG1
+        assert target_annotations[0].value == SOME_VALUE + ' ' + SOME_VALUE_2
+
+    def test_should_apply_regex_to_result(self):
+        xml_root = E.article(
+            E.title('1.1. ' + SOME_VALUE)
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'title',
+                TAG1 + XmlMappingSuffix.REGEX: r'(?:\d+\.?)* ?(.*)'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert len(target_annotations) == 1
+        assert target_annotations[0].name == TAG1
+        assert target_annotations[0].value == SOME_VALUE
+
+    def test_should_apply_match_multiple_flag(self):
+        xml_root = E.article(
+            E.title(SOME_VALUE)
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'title',
+                TAG1 + XmlMappingSuffix.MATCH_MULTIPLE: 'true'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [t.match_multiple for t in target_annotations] == [True]
+
+    def test_should_not_apply_match_multiple_flag_if_not_set(self):
+        xml_root = E.article(
+            E.title(SOME_VALUE)
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'title'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [t.match_multiple for t in target_annotations] == [False]
+
+    def test_should_apply_match_bonding_flag(self):
+        xml_root = E.article(
+            E.title(SOME_VALUE)
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'title',
+                TAG1 + XmlMappingSuffix.BONDING: 'true'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [t.bonding for t in target_annotations] == [True]
+
+    def test_should_not_apply_match_bonding_flag_if_not_set(self):
+        xml_root = E.article(
+            E.title(SOME_VALUE)
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'title'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [t.bonding for t in target_annotations] == [False]
+
+    def test_should_apply_match_require_next_flag(self):
+        xml_root = E.article(
+            E.title(SOME_VALUE)
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'title',
+                TAG1 + XmlMappingSuffix.REQUIRE_NEXT: 'true'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [t.require_next for t in target_annotations] == [True]
+
+    def test_should_not_apply_match_require_next_flag_if_not_set(self):
+        xml_root = E.article(
+            E.title(SOME_VALUE)
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'title'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [t.require_next for t in target_annotations] == [False]
+
+    def test_should_use_multiple_xpaths(self):
+        xml_root = E.article(
+            E.entry(
+                E.child1(SOME_VALUE),
+                E.child2(SOME_VALUE_2)
+            )
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: '\n{}\n{}\n'.format(
+                    'entry/child1',
+                    'entry/child2'
+                )
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [(t.name, t.value) for t in target_annotations] == [
+            (TAG1, SOME_VALUE),
+            (TAG1, SOME_VALUE_2)
+        ]
+
+    def test_should_apply_children_xpaths_and_sort_by_value_descending(self):
+        xml_root = E.article(
+            E.entry(
+                E.child1(SOME_SHORTER_VALUE),
+                E.child2(SOME_LONGER_VALUE)
+            ),
+            E.entry(
+                E.child1(SOME_LONGER_VALUE)
+            )
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'entry',
+                TAG1 + XmlMappingSuffix.CHILDREN: './/*'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [(t.name, t.value) for t in target_annotations] == [
+            (TAG1, [SOME_LONGER_VALUE, SOME_SHORTER_VALUE]),
+            (TAG1, SOME_LONGER_VALUE)
+        ]
+
+    def test_should_apply_children_xpaths_and_exclude_parents(self):
+        xml_root = E.article(
+            E.entry(
+                E.parent(
+                    E.child2(SOME_LONGER_VALUE),
+                    E.child1(SOME_SHORTER_VALUE)
+                )
+            )
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'entry',
+                TAG1 + XmlMappingSuffix.CHILDREN: './/*'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [(t.name, t.value) for t in target_annotations] == [
+            (TAG1, [SOME_LONGER_VALUE, SOME_SHORTER_VALUE])
+        ]
+
+    def test_should_apply_children_xpaths_and_include_parent_text_between_matched_children(self):
+        xml_root = E.article(
+            E.entry(
+                E.parent(
+                    E.child2(SOME_LONGER_VALUE),
+                    SOME_VALUE,
+                    E.child1(SOME_SHORTER_VALUE)
+                )
+            )
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'entry',
+                TAG1 + XmlMappingSuffix.CHILDREN: './/*'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [(t.name, t.value) for t in target_annotations] == [
+            (TAG1, [SOME_LONGER_VALUE, SOME_VALUE, SOME_SHORTER_VALUE])
+        ]
+
+    def test_should_apply_multiple_children_xpaths_and_include_parent_text_if_enabled(self):
+        xml_root = E.article(
+            E.entry(
+                E.child1(SOME_SHORTER_VALUE),
+                SOME_LONGER_VALUE
+            )
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'entry',
+                TAG1 + XmlMappingSuffix.CHILDREN: '\n{}\n{}\n'.format('.//*', '.'),
+                TAG1 + XmlMappingSuffix.UNMATCHED_PARENT_TEXT: 'true'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [(t.name, t.value) for t in target_annotations] == [
+            (TAG1, [SOME_LONGER_VALUE, SOME_SHORTER_VALUE])
+        ]
+
+    def test_should_apply_concat_children(self):
+        num_values = ['101', '202']
+        xml_root = E.article(
+            E.entry(
+                E.parent(
+                    E.child1(SOME_VALUE),
+                    E.fpage(num_values[0]),
+                    E.lpage(num_values[1])
+                )
+            )
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'entry',
+                TAG1 + XmlMappingSuffix.CHILDREN: './/*',
+                TAG1 + XmlMappingSuffix.CHILDREN_CONCAT: json.dumps([[{
+                    'xpath': './/fpage'
+                }, {
+                    'value': '-'
+                }, {
+                    'xpath': './/lpage'
+                }]])
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [(t.name, t.value) for t in target_annotations] == [
+            (TAG1, [SOME_VALUE, '-'.join(num_values)])
+        ]
+
+    def test_should_not_apply_concat_children_if_one_node_was_not_found(self):
+        num_values = ['101', '202']
+        xml_root = E.article(
+            E.entry(
+                E.parent(
+                    E.child1(SOME_VALUE),
+                    E.fpage(num_values[0]),
+                    E.lpage(num_values[1])
+                )
+            )
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'entry',
+                TAG1 + XmlMappingSuffix.CHILDREN: './/*',
+                TAG1 + XmlMappingSuffix.CHILDREN_CONCAT: json.dumps([[{
+                    'xpath': './/fpage'
+                }, {
+                    'value': '-'
+                }, {
+                    'xpath': './/unknown'
+                }]])
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [(t.name, t.value) for t in target_annotations] == [
+            (TAG1, [SOME_VALUE, num_values[0], num_values[1]])
+        ]
+
+    def test_should_apply_range_children(self):
+        num_values = [101, 102, 103, 104, 105, 106, 107]
+        xml_root = E.article(
+            E.entry(
+                E.child1(SOME_VALUE),
+                E.fpage(str(min(num_values))),
+                E.lpage(str(max(num_values)))
+            )
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'entry',
+                TAG1 + XmlMappingSuffix.CHILDREN: 'fpage|lpage',
+                TAG1 + XmlMappingSuffix.CHILDREN_RANGE: json.dumps([{
+                    'min': {
+                        'xpath': 'fpage'
+                    },
+                    'max': {
+                        'xpath': 'lpage'
+                    }
+                }])
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [(t.name, t.value) for t in target_annotations] == [
+            (TAG1, [str(x) for x in num_values])
+        ]
+
+    def test_should_apply_range_children_as_separate_target_annotations(self):
+        num_values = [101, 102, 103, 104, 105, 106, 107]
+        xml_root = E.article(
+            E.entry(
+                E.child1(SOME_VALUE),
+                E.fpage(str(min(num_values))),
+                E.lpage(str(max(num_values)))
+            )
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'entry',
+                TAG1 + XmlMappingSuffix.CHILDREN: 'fpage|lpage',
+                TAG1 + XmlMappingSuffix.CHILDREN_RANGE: json.dumps([{
+                    'min': {
+                        'xpath': 'fpage'
+                    },
+                    'max': {
+                        'xpath': 'lpage'
+                    },
+                    'standalone': True
+                }])
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [(t.name, t.value) for t in target_annotations] == [
+            (TAG1, str(x))
+            for x in num_values
+        ]
+
+    def test_should_not_apply_range_children_if_xpath_not_matching(self):
+        num_values = [101, 102, 103, 104, 105, 106, 107]
+        fpage = str(min(num_values))
+        lpage = str(max(num_values))
+        xml_root = E.article(
+            E.entry(
+                E.child1(SOME_VALUE),
+                E.fpage(fpage),
+                E.lpage(lpage)
+            )
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'entry',
+                TAG1 + XmlMappingSuffix.CHILDREN: 'fpage|unknown',
+                TAG1 + XmlMappingSuffix.CHILDREN_RANGE: json.dumps([{
+                    'min': {
+                        'xpath': 'fpage'
+                    },
+                    'max': {
+                        'xpath': 'unknown'
+                    }
+                }])
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [(t.name, t.value) for t in target_annotations] == [
+            (TAG1, fpage)
+        ]
+
+    def test_should_not_apply_range_children_if_value_is_not_integer(self):
+        fpage = 'abc'
+        lpage = 'xyz'
+        xml_root = E.article(
+            E.entry(
+                E.child1(SOME_VALUE),
+                E.fpage(fpage),
+                E.lpage(lpage)
+            )
+        )
+        xml_mapping = {
+            'article': {
+                TAG1: 'entry',
+                TAG1 + XmlMappingSuffix.CHILDREN: 'fpage|lpage',
+                TAG1 + XmlMappingSuffix.CHILDREN_RANGE: json.dumps([{
+                    'min': {
+                        'xpath': 'fpage'
+                    },
+                    'max': {
+                        'xpath': 'lpage'
+                    }
+                }])
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [(t.name, t.value) for t in target_annotations] == [
+            (TAG1, [fpage, lpage])
+        ]
+
+    def test_should_add_sub_annotations(self):
+        xml_root = E.article(
+            E.entry(
+                E.firstname(SOME_VALUE),
+                E.givennames(SOME_VALUE_2)
+            )
         )
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [(t.name, t.value) for t in target_annotations] == [
-      (TAG1, SOME_VALUE),
-      (TAG1, SOME_VALUE_2)
-    ]
-
-  def test_should_apply_children_xpaths_and_sort_by_value_descending(self):
-    xml_root = E.article(
-      E.entry(
-        E.child1(SOME_SHORTER_VALUE),
-        E.child2(SOME_LONGER_VALUE)
-      ),
-      E.entry(
-        E.child1(SOME_LONGER_VALUE)
-      )
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'entry',
-        TAG1 + XmlMappingSuffix.CHILDREN: './/*'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [(t.name, t.value) for t in target_annotations] == [
-      (TAG1, [SOME_LONGER_VALUE, SOME_SHORTER_VALUE]),
-      (TAG1, SOME_LONGER_VALUE)
-    ]
-
-  def test_should_apply_children_xpaths_and_exclude_parents(self):
-    xml_root = E.article(
-      E.entry(
-        E.parent(
-          E.child2(SOME_LONGER_VALUE),
-          E.child1(SOME_SHORTER_VALUE)
+        xml_mapping = {
+            'article': {
+                TAG1: 'entry',
+                TAG1 + XmlMappingSuffix.SUB + '.firstname': './firstname',
+                TAG1 + XmlMappingSuffix.SUB + '.givennames': './givennames',
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [(t.name, t.value) for t in target_annotations[0].sub_annotations] == [
+            ('firstname', SOME_VALUE),
+            ('givennames', SOME_VALUE_2)
+        ]
+
+    def test_should_add_sub_annotations_with_multiple_values(self):
+        xml_root = E.article(
+            E.entry(
+                E.value(SOME_VALUE),
+                E.value(SOME_VALUE_2)
+            )
         )
-      )
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'entry',
-        TAG1 + XmlMappingSuffix.CHILDREN: './/*'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [(t.name, t.value) for t in target_annotations] == [
-      (TAG1, [SOME_LONGER_VALUE, SOME_SHORTER_VALUE])
-    ]
-
-  def test_should_apply_children_xpaths_and_include_parent_text_between_matched_children(self):
-    xml_root = E.article(
-      E.entry(
-        E.parent(
-          E.child2(SOME_LONGER_VALUE),
-          SOME_VALUE,
-          E.child1(SOME_SHORTER_VALUE)
+        xml_mapping = {
+            'article': {
+                TAG1: 'entry',
+                TAG1 + XmlMappingSuffix.SUB + '.value': './value'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [(t.name, t.value) for t in target_annotations[0].sub_annotations] == [
+            ('value', SOME_VALUE),
+            ('value', SOME_VALUE_2)
+        ]
+
+    def test_should_extract_numbers_from_value_after_text(self):
+        xml_root = E.article(E.entry(
+            E.value(SOME_VALUE + ' 12345')
+        ))
+        xml_mapping = {
+            'article': {
+                TAG1: 'entry',
+                TAG1 + XmlMappingSuffix.EXTRACT_REGEX: r'.*\b(\d+)\b.*'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert len(target_annotations) == 1
+        assert [(t.name, set(t.value)) for t in target_annotations] == [
+            (TAG1, {SOME_VALUE + ' 12345', SOME_VALUE, '12345'})
+        ]
+
+    def test_should_extract_single_value_if_its_the_only_value(self):
+        xml_root = E.article(E.entry(
+            E.value('12345')
+        ))
+        xml_mapping = {
+            'article': {
+                TAG1: 'entry',
+                TAG1 + XmlMappingSuffix.EXTRACT_REGEX: r'.*\b(\d+)\b.*'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert len(target_annotations) == 1
+        assert [(t.name, t.value) for t in target_annotations] == [
+            (TAG1, '12345')
+        ]
+
+    def test_should_unnest_extract_value_from_children(self):
+        xml_root = E.article(E.entry(
+            E.value(SOME_VALUE + ' 12345'),
+            E.value(SOME_VALUE_2 + ' 54321')
+        ))
+        xml_mapping = {
+            'article': {
+                TAG1: 'entry',
+                TAG1 + XmlMappingSuffix.CHILDREN: r'.//*',
+                TAG1 + XmlMappingSuffix.EXTRACT_REGEX: r'.*\b(\d+)\b.*'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert len(target_annotations) == 1
+        assert [(t.name, set(t.value)) for t in target_annotations] == [
+            (TAG1, {
+                SOME_VALUE + ' 12345', SOME_VALUE, '12345',
+                SOME_VALUE_2 + ' 54321', SOME_VALUE_2, '54321'
+            })
+        ]
+
+    def test_should_extract_numbers_from_sub_value_after_text(self):
+        xml_root = E.article(E.entry(
+            E.value(SOME_VALUE + ' 12345')
+        ))
+        sub_key = TAG1 + XmlMappingSuffix.SUB + '.value'
+        xml_mapping = {
+            'article': {
+                TAG1: 'entry',
+                sub_key: './value',
+                sub_key + XmlMappingSuffix.EXTRACT_REGEX: r'.*\b(\d+)\b.*'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert len(target_annotations) == 1
+        assert [(t.name, set(t.value)) for t in target_annotations[0].sub_annotations] == [
+            ('value', {SOME_VALUE + ' 12345', SOME_VALUE, '12345'})
+        ]
+
+    def test_should_return_full_text(self):
+        xml_root = E.article(
+            E.title(
+                'some ',
+                E.other('embedded'),
+                ' text'
+            )
         )
-      )
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'entry',
-        TAG1 + XmlMappingSuffix.CHILDREN: './/*'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [(t.name, t.value) for t in target_annotations] == [
-      (TAG1, [SOME_LONGER_VALUE, SOME_VALUE, SOME_SHORTER_VALUE])
-    ]
-
-  def test_should_apply_multiple_children_xpaths_and_include_parent_text_if_enabled(self):
-    xml_root = E.article(
-      E.entry(
-        E.child1(SOME_SHORTER_VALUE),
-        SOME_LONGER_VALUE
-      )
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'entry',
-        TAG1 + XmlMappingSuffix.CHILDREN: '\n{}\n{}\n'.format('.//*', '.'),
-        TAG1 + XmlMappingSuffix.UNMATCHED_PARENT_TEXT: 'true'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [(t.name, t.value) for t in target_annotations] == [
-      (TAG1, [SOME_LONGER_VALUE, SOME_SHORTER_VALUE])
-    ]
-
-  def test_should_apply_concat_children(self):
-    num_values = ['101', '202']
-    xml_root = E.article(
-      E.entry(
-        E.parent(
-          E.child1(SOME_VALUE),
-          E.fpage(num_values[0]),
-          E.lpage(num_values[1])
+        xml_mapping = {
+            'article': {
+                TAG1: 'title'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert len(target_annotations) == 1
+        assert target_annotations[0].name == TAG1
+        assert target_annotations[0].value == 'some embedded text'
+
+    def test_should_return_target_annotations_in_order_of_xml(self):
+        xml_root = E.article(
+            E.tag1('tag1.1'), E.tag2('tag2.1'), E.tag1('tag1.2'), E.tag2('tag2.2'),
         )
-      )
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'entry',
-        TAG1 + XmlMappingSuffix.CHILDREN: './/*',
-        TAG1 + XmlMappingSuffix.CHILDREN_CONCAT: json.dumps([[{
-          'xpath': './/fpage'
-        }, {
-          'value': '-'
-        }, {
-          'xpath': './/lpage'
-        }]])
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [(t.name, t.value) for t in target_annotations] == [
-      (TAG1, [SOME_VALUE, '-'.join(num_values)])
-    ]
-
-  def test_should_not_apply_concat_children_if_one_node_was_not_found(self):
-    num_values = ['101', '202']
-    xml_root = E.article(
-      E.entry(
-        E.parent(
-          E.child1(SOME_VALUE),
-          E.fpage(num_values[0]),
-          E.lpage(num_values[1])
+        xml_mapping = {
+            'article': {
+                TAG1: 'tag1',
+                TAG2: 'tag2'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [(ta.name, ta.value) for ta in target_annotations] == [
+            (TAG1, 'tag1.1'), (TAG2, 'tag2.1'), (TAG1, 'tag1.2'), (TAG2, 'tag2.2')
+        ]
+
+    def test_should_return_target_annotations_in_order_of_priority_first(self):
+        xml_root = E.article(
+            E.tag1('tag1.1'), E.tag2('tag2.1'), E.tag1('tag1.2'), E.tag2('tag2.2'),
         )
-      )
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'entry',
-        TAG1 + XmlMappingSuffix.CHILDREN: './/*',
-        TAG1 + XmlMappingSuffix.CHILDREN_CONCAT: json.dumps([[{
-          'xpath': './/fpage'
-        }, {
-          'value': '-'
-        }, {
-          'xpath': './/unknown'
-        }]])
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [(t.name, t.value) for t in target_annotations] == [
-      (TAG1, [SOME_VALUE, num_values[0], num_values[1]])
-    ]
-
-  def test_should_apply_range_children(self):
-    num_values = [101, 102, 103, 104, 105, 106, 107]
-    xml_root = E.article(
-      E.entry(
-        E.child1(SOME_VALUE),
-        E.fpage(str(min(num_values))),
-        E.lpage(str(max(num_values)))
-      )
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'entry',
-        TAG1 + XmlMappingSuffix.CHILDREN: 'fpage|lpage',
-        TAG1 + XmlMappingSuffix.CHILDREN_RANGE: json.dumps([{
-          'min': {
-            'xpath': 'fpage'
-          },
-          'max': {
-            'xpath': 'lpage'
-          }
-        }])
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [(t.name, t.value) for t in target_annotations] == [
-      (TAG1, [str(x) for x in num_values])
-    ]
-
-  def test_should_apply_range_children_as_separate_target_annotations(self):
-    num_values = [101, 102, 103, 104, 105, 106, 107]
-    xml_root = E.article(
-      E.entry(
-        E.child1(SOME_VALUE),
-        E.fpage(str(min(num_values))),
-        E.lpage(str(max(num_values)))
-      )
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'entry',
-        TAG1 + XmlMappingSuffix.CHILDREN: 'fpage|lpage',
-        TAG1 + XmlMappingSuffix.CHILDREN_RANGE: json.dumps([{
-          'min': {
-            'xpath': 'fpage'
-          },
-          'max': {
-            'xpath': 'lpage'
-          },
-          'standalone': True
-        }])
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [(t.name, t.value) for t in target_annotations] == [
-      (TAG1, str(x))
-      for x in num_values
-    ]
-
-  def test_should_not_apply_range_children_if_xpath_not_matching(self):
-    num_values = [101, 102, 103, 104, 105, 106, 107]
-    fpage = str(min(num_values))
-    lpage = str(max(num_values))
-    xml_root = E.article(
-      E.entry(
-        E.child1(SOME_VALUE),
-        E.fpage(fpage),
-        E.lpage(lpage)
-      )
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'entry',
-        TAG1 + XmlMappingSuffix.CHILDREN: 'fpage|unknown',
-        TAG1 + XmlMappingSuffix.CHILDREN_RANGE: json.dumps([{
-          'min': {
-            'xpath': 'fpage'
-          },
-          'max': {
-            'xpath': 'unknown'
-          }
-        }])
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [(t.name, t.value) for t in target_annotations] == [
-      (TAG1, fpage)
-    ]
-
-  def test_should_not_apply_range_children_if_value_is_not_integer(self):
-    fpage = 'abc'
-    lpage = 'xyz'
-    xml_root = E.article(
-      E.entry(
-        E.child1(SOME_VALUE),
-        E.fpage(fpage),
-        E.lpage(lpage)
-      )
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'entry',
-        TAG1 + XmlMappingSuffix.CHILDREN: 'fpage|lpage',
-        TAG1 + XmlMappingSuffix.CHILDREN_RANGE: json.dumps([{
-          'min': {
-            'xpath': 'fpage'
-          },
-          'max': {
-            'xpath': 'lpage'
-          }
-        }])
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [(t.name, t.value) for t in target_annotations] == [
-      (TAG1, [fpage, lpage])
-    ]
-
-  def test_should_add_sub_annotations(self):
-    xml_root = E.article(
-      E.entry(
-        E.firstname(SOME_VALUE),
-        E.givennames(SOME_VALUE_2)
-      )
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'entry',
-        TAG1 + XmlMappingSuffix.SUB + '.firstname': './firstname',
-        TAG1 + XmlMappingSuffix.SUB + '.givennames': './givennames',
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [(t.name, t.value) for t in target_annotations[0].sub_annotations] == [
-      ('firstname', SOME_VALUE),
-      ('givennames', SOME_VALUE_2)
-    ]
-
-  def test_should_add_sub_annotations_with_multiple_values(self):
-    xml_root = E.article(
-      E.entry(
-        E.value(SOME_VALUE),
-        E.value(SOME_VALUE_2)
-      )
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'entry',
-        TAG1 + XmlMappingSuffix.SUB + '.value': './value'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [(t.name, t.value) for t in target_annotations[0].sub_annotations] == [
-      ('value', SOME_VALUE),
-      ('value', SOME_VALUE_2)
-    ]
-
-  def test_should_extract_numbers_from_value_after_text(self):
-    xml_root = E.article(E.entry(
-      E.value(SOME_VALUE + ' 12345')
-    ))
-    xml_mapping = {
-      'article': {
-        TAG1: 'entry',
-        TAG1 + XmlMappingSuffix.EXTRACT_REGEX: r'.*\b(\d+)\b.*'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert len(target_annotations) == 1
-    assert [(t.name, set(t.value)) for t in target_annotations] == [
-      (TAG1, {SOME_VALUE + ' 12345', SOME_VALUE, '12345'})
-    ]
-
-  def test_should_extract_single_value_if_its_the_only_value(self):
-    xml_root = E.article(E.entry(
-      E.value('12345')
-    ))
-    xml_mapping = {
-      'article': {
-        TAG1: 'entry',
-        TAG1 + XmlMappingSuffix.EXTRACT_REGEX: r'.*\b(\d+)\b.*'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert len(target_annotations) == 1
-    assert [(t.name, t.value) for t in target_annotations] == [
-      (TAG1, '12345')
-    ]
-
-  def test_should_unnest_extract_value_from_children(self):
-    xml_root = E.article(E.entry(
-      E.value(SOME_VALUE + ' 12345'),
-      E.value(SOME_VALUE_2 + ' 54321')
-    ))
-    xml_mapping = {
-      'article': {
-        TAG1: 'entry',
-        TAG1 + XmlMappingSuffix.CHILDREN: r'.//*',
-        TAG1 + XmlMappingSuffix.EXTRACT_REGEX: r'.*\b(\d+)\b.*'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert len(target_annotations) == 1
-    assert [(t.name, set(t.value)) for t in target_annotations] == [
-      (TAG1, {
-        SOME_VALUE + ' 12345', SOME_VALUE, '12345',
-        SOME_VALUE_2 + ' 54321', SOME_VALUE_2, '54321'
-      })
-    ]
-
-  def test_should_extract_numbers_from_sub_value_after_text(self):
-    xml_root = E.article(E.entry(
-      E.value(SOME_VALUE + ' 12345')
-    ))
-    sub_key = TAG1 + XmlMappingSuffix.SUB + '.value'
-    xml_mapping = {
-      'article': {
-        TAG1: 'entry',
-        sub_key: './value',
-        sub_key + XmlMappingSuffix.EXTRACT_REGEX: r'.*\b(\d+)\b.*'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert len(target_annotations) == 1
-    assert [(t.name, set(t.value)) for t in target_annotations[0].sub_annotations] == [
-      ('value', {SOME_VALUE + ' 12345', SOME_VALUE, '12345'})
-    ]
-
-  def test_should_return_full_text(self):
-    xml_root = E.article(
-      E.title(
-        'some ',
-        E.other('embedded'),
-        ' text'
-      )
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'title'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert len(target_annotations) == 1
-    assert target_annotations[0].name == TAG1
-    assert target_annotations[0].value == 'some embedded text'
-
-  def test_should_return_target_annotations_in_order_of_xml(self):
-    xml_root = E.article(
-      E.tag1('tag1.1'), E.tag2('tag2.1'), E.tag1('tag1.2'), E.tag2('tag2.2'),
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'tag1',
-        TAG2: 'tag2'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [(ta.name, ta.value) for ta in target_annotations] == [
-      (TAG1, 'tag1.1'), (TAG2, 'tag2.1'), (TAG1, 'tag1.2'), (TAG2, 'tag2.2')
-    ]
-
-  def test_should_return_target_annotations_in_order_of_priority_first(self):
-    xml_root = E.article(
-      E.tag1('tag1.1'), E.tag2('tag2.1'), E.tag1('tag1.2'), E.tag2('tag2.2'),
-    )
-    xml_mapping = {
-      'article': {
-        TAG1: 'tag1',
-        TAG2: 'tag2',
-        TAG2 + XmlMappingSuffix.PRIORITY: '1'
-      }
-    }
-    target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
-    assert [(ta.name, ta.value) for ta in target_annotations] == [
-      (TAG2, 'tag2.1'), (TAG2, 'tag2.2'), (TAG1, 'tag1.1'), (TAG1, 'tag1.2')
-    ]
+        xml_mapping = {
+            'article': {
+                TAG1: 'tag1',
+                TAG2: 'tag2',
+                TAG2 + XmlMappingSuffix.PRIORITY: '1'
+            }
+        }
+        target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping)
+        assert [(ta.name, ta.value) for ta in target_annotations] == [
+            (TAG2, 'tag2.1'), (TAG2, 'tag2.2'), (TAG1, 'tag1.1'), (TAG1, 'tag1.2')
+        ]
diff --git a/sciencebeam_gym/preprocess/blockify_annotations.py b/sciencebeam_gym/preprocess/blockify_annotations.py
index 355604d..7f30755 100644
--- a/sciencebeam_gym/preprocess/blockify_annotations.py
+++ b/sciencebeam_gym/preprocess/blockify_annotations.py
@@ -10,267 +10,289 @@ from pyqtree import Index as PqtreeIndex
 from PIL import Image, ImageDraw, ImageColor
 
 from sciencebeam_gym.structured_document.svg import (
-  SVG_NSMAP,
-  SVG_DOC,
-  SVG_RECT,
+    SVG_NSMAP,
+    SVG_DOC,
+    SVG_RECT,
 )
 
 DEFAULT_NEARBY_TOLERANCE = 5
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 class AnnotationBlock(object):
-  def __init__(self, tag, bounding_box):
-    self.tag = tag
-    self.bounding_box = bounding_box
-
-  def merge_with(self, other):
-    return AnnotationBlock(
-      self.tag,
-      self.bounding_box.include(other.bounding_box)
-    )
+    def __init__(self, tag, bounding_box):
+        self.tag = tag
+        self.bounding_box = bounding_box
+
+    def merge_with(self, other):
+        return AnnotationBlock(
+            self.tag,
+            self.bounding_box.include(other.bounding_box)
+        )
+
+    def __str__(self):
+        return 'AnnotationBlock({}, {})'.format(self.tag, self.bounding_box)
 
-  def __str__(self):
-    return 'AnnotationBlock({}, {})'.format(self.tag, self.bounding_box)
+    def __repr__(self):
+        return str(self)
 
-  def __repr__(self):
-    return str(self)
 
 class BlockPoint(object):
-  def __init__(self, block, x, y):
-    self.block = block
-    self.point = (x, y)
+    def __init__(self, block, x, y):
+        self.block = block
+        self.point = (x, y)
 
-  def __str__(self):
-    return 'BlockPoint({}, {})'.format(self.block, self.point)
+    def __str__(self):
+        return 'BlockPoint({}, {})'.format(self.block, self.point)
 
-  def __repr__(self):
-    return str(self)
+    def __repr__(self):
+        return str(self)
 
-  def __len__(self):
-    return len(self.point)
+    def __len__(self):
+        return len(self.point)
+
+    def __getitem__(self, index):
+        return self.point[index]
 
-  def __getitem__(self, index):
-    return self.point[index]
 
 def _to_bbox(bb):
-  return (bb.x, bb.y, bb.x + bb.width - 1, bb.y + bb.height - 1)
+    return (bb.x, bb.y, bb.x + bb.width - 1, bb.y + bb.height - 1)
+
 
 ProcessedWrapper = namedtuple('ProcessedWrapper', field_names=['data', 'deleted'])
 
-class DeletableWrapper(object):
-  def __init__(self, data):
-    self.data = data
-    self.deleted = False
 
-  def __hash__(self):
-    return hash(self.data)
+class DeletableWrapper(object):
+    def __init__(self, data):
+        self.data = data
+        self.deleted = False
 
-  def __eq__(self, other):
-    return self.data == other.data
+    def __hash__(self):
+        return hash(self.data)
 
-class BlockSearch(object):
-  def __init__(self, blocks):
-    bboxs = [block.bounding_box for block in blocks]
-    xmax = max([bb.x + bb.width for bb in bboxs])
-    ymax = max([bb.y + bb.height for bb in bboxs])
-    self.spindex = PqtreeIndex(bbox=(0, 0, xmax, ymax))
-    self.wrapper_map = {}
-    for block in blocks:
-      wrapper = DeletableWrapper(block)
-      self.wrapper_map[block] = wrapper
-      self.spindex.insert(wrapper, _to_bbox(block.bounding_box))
+    def __eq__(self, other):
+        return self.data == other.data
 
-  def find_intersection_with(self, search_bounding_box):
-    return [
-      wrapper.data
-      for wrapper in self.spindex.intersect(_to_bbox(search_bounding_box))
-      if not wrapper.deleted
-    ]
 
-  def remove(self, block):
-    wrapper = self.wrapper_map.get(block)
-    if wrapper is not None:
-      wrapper.deleted = True
+class BlockSearch(object):
+    def __init__(self, blocks):
+        bboxs = [block.bounding_box for block in blocks]
+        xmax = max([bb.x + bb.width for bb in bboxs])
+        ymax = max([bb.y + bb.height for bb in bboxs])
+        self.spindex = PqtreeIndex(bbox=(0, 0, xmax, ymax))
+        self.wrapper_map = {}
+        for block in blocks:
+            wrapper = DeletableWrapper(block)
+            self.wrapper_map[block] = wrapper
+            self.spindex.insert(wrapper, _to_bbox(block.bounding_box))
+
+    def find_intersection_with(self, search_bounding_box):
+        return [
+            wrapper.data
+            for wrapper in self.spindex.intersect(_to_bbox(search_bounding_box))
+            if not wrapper.deleted
+        ]
+
+    def remove(self, block):
+        wrapper = self.wrapper_map.get(block)
+        if wrapper is not None:
+            wrapper.deleted = True
 
 
 def merge_blocks(blocks, nearby_tolerance=0):
-  if len(blocks) <= 1:
-    return blocks
-  merged_blocks = deque()
-  logger = get_logger()
-  logger.debug('nearby_tolerance: %s', nearby_tolerance)
-  logger.debug('blocks: %s', blocks)
-  logger.debug('bboxs: %s', [_to_bbox(block.bounding_box) for block in blocks])
-  tags = sorted(set([b.tag for b in blocks]))
-  logger.debug('tags: %s', tags)
-  remaining_blocks = deque(blocks)
-  search_by_tag = {
-    tag: BlockSearch([b for b in remaining_blocks if b.tag == tag])
-    for tag in tags
-  }
-  while len(remaining_blocks) >= 2:
-    merged_block = remaining_blocks.popleft()
-    search = search_by_tag[merged_block.tag]
-    search.remove(merged_block)
-    search_bounding_box = merged_block.bounding_box.with_margin(1 + nearby_tolerance, 0)
-    logger.debug('search_bounding_box: %s (%s)', search_bounding_box, _to_bbox(search_bounding_box))
-    neighbours = search.find_intersection_with(search_bounding_box)
-    logger.debug('neighbours: %s', neighbours)
-    neighbours_blocks_count = 0
-    for neighbour in neighbours:
-      if neighbour.tag == merged_block.tag:
-        merged_block = merged_block.merge_with(neighbour)
-        search.remove(neighbour)
-        remaining_blocks.remove(neighbour)
-        neighbours_blocks_count += 1
-    if neighbours_blocks_count == 0 or len(remaining_blocks) == 0:
-      logger.debug(
-        'no or all remaining blocks merged, mark block as merged: %d',
-        neighbours_blocks_count
-      )
-      merged_blocks.append(merged_block)
-    else:
-      logger.debug(
-        'some but not all remaining blocks merged, continue search: %d',
-        neighbours_blocks_count
-      )
-      remaining_blocks.appendleft(merged_block)
-  result = list(merged_blocks) + list(remaining_blocks)
-  return result
+    if len(blocks) <= 1:
+        return blocks
+    merged_blocks = deque()
+    logger = get_logger()
+    logger.debug('nearby_tolerance: %s', nearby_tolerance)
+    logger.debug('blocks: %s', blocks)
+    logger.debug('bboxs: %s', [_to_bbox(block.bounding_box) for block in blocks])
+    tags = sorted(set([b.tag for b in blocks]))
+    logger.debug('tags: %s', tags)
+    remaining_blocks = deque(blocks)
+    search_by_tag = {
+        tag: BlockSearch([b for b in remaining_blocks if b.tag == tag])
+        for tag in tags
+    }
+    while len(remaining_blocks) >= 2:
+        merged_block = remaining_blocks.popleft()
+        search = search_by_tag[merged_block.tag]
+        search.remove(merged_block)
+        search_bounding_box = merged_block.bounding_box.with_margin(1 + nearby_tolerance, 0)
+        logger.debug('search_bounding_box: %s (%s)',
+                     search_bounding_box, _to_bbox(search_bounding_box))
+        neighbours = search.find_intersection_with(search_bounding_box)
+        logger.debug('neighbours: %s', neighbours)
+        neighbours_blocks_count = 0
+        for neighbour in neighbours:
+            if neighbour.tag == merged_block.tag:
+                merged_block = merged_block.merge_with(neighbour)
+                search.remove(neighbour)
+                remaining_blocks.remove(neighbour)
+                neighbours_blocks_count += 1
+        if neighbours_blocks_count == 0 or len(remaining_blocks) == 0:
+            logger.debug(
+                'no or all remaining blocks merged, mark block as merged: %d',
+                neighbours_blocks_count
+            )
+            merged_blocks.append(merged_block)
+        else:
+            logger.debug(
+                'some but not all remaining blocks merged, continue search: %d',
+                neighbours_blocks_count
+            )
+            remaining_blocks.appendleft(merged_block)
+    result = list(merged_blocks) + list(remaining_blocks)
+    return result
+
 
 def expand_bounding_box(bb):
-  return bb.with_margin(4, 2)
+    return bb.with_margin(4, 2)
+
 
 def expand_block(block):
-  return AnnotationBlock(block.tag, expand_bounding_box(block.bounding_box))
+    return AnnotationBlock(block.tag, expand_bounding_box(block.bounding_box))
+
 
 def expand_blocks(blocks):
-  return [expand_block(block) for block in blocks]
+    return [expand_block(block) for block in blocks]
+
 
 def annotation_document_page_to_annotation_blocks(structured_document, page):
-  tags_and_tokens = (
-    (structured_document.get_tag_value(token), token)
-    for line in structured_document.get_lines_of_page(page)
-    for token in structured_document.get_tokens_of_line(line)
-  )
-  tags_and_bounding_boxes = (
-    (tag, structured_document.get_bounding_box(token))
-    for tag, token in tags_and_tokens
-    if tag
-  )
-  return [
-    AnnotationBlock(tag, bounding_box)
-    for tag, bounding_box in tags_and_bounding_boxes
-    if bounding_box
-  ]
+    tags_and_tokens = (
+        (structured_document.get_tag_value(token), token)
+        for line in structured_document.get_lines_of_page(page)
+        for token in structured_document.get_tokens_of_line(line)
+    )
+    tags_and_bounding_boxes = (
+        (tag, structured_document.get_bounding_box(token))
+        for tag, token in tags_and_tokens
+        if tag
+    )
+    return [
+        AnnotationBlock(tag, bounding_box)
+        for tag, bounding_box in tags_and_bounding_boxes
+        if bounding_box
+    ]
+
 
 def annotation_document_page_to_merged_blocks(structured_document, page, **kwargs):
-  return merge_blocks(
-    annotation_document_page_to_annotation_blocks(structured_document, page),
-    **kwargs
-  )
+    return merge_blocks(
+        annotation_document_page_to_annotation_blocks(structured_document, page),
+        **kwargs
+    )
+
 
 def extend_color_map_for_tags(color_map, tags):
-  updated_color_map = dict(color_map)
-  for tag in tags:
-    if tag not in updated_color_map:
-      updated_color_map[tag] = (
-        max(updated_color_map.values()) + 1 if len(updated_color_map) > 0 else 1
-      )
-  return updated_color_map
+    updated_color_map = dict(color_map)
+    for tag in tags:
+        if tag not in updated_color_map:
+            updated_color_map[tag] = (
+                max(updated_color_map.values()) + 1 if len(updated_color_map) > 0 else 1
+            )
+    return updated_color_map
+
 
 def extend_color_map_for_blocks(color_map, blocks):
-  return extend_color_map_for_tags(
-    color_map,
-    sorted(set([b.tag for b in blocks]))
-  )
+    return extend_color_map_for_tags(
+        color_map,
+        sorted(set([b.tag for b in blocks]))
+    )
+
 
 class AbstractSurface(object, with_metaclass(ABCMeta)):
-  @abstractmethod
-  def rect(self, bounding_box, color, tag=None):
-    pass
-
-class SvgSurface(object):
-  def __init__(self, width, height, background):
-    if not (width and height):
-      raise AttributeError('width and height are required')
-
-    self.svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP, attrib={
-      'width': str(width),
-      'height': str(height)
-    })
-
-    if background:
-      self.svg_root.append(etree.Element(SVG_RECT, attrib={
-        'width': '100%',
-        'height': '100%',
-        'fill': background,
-        'class': 'background'
-      }))
-
-  def rect(self, bounding_box, color, tag=None):
-    attrib = {
-      'class': str(tag),
-      'shape-rendering': 'crispEdges',
-      'x': str(bounding_box.x),
-      'y': str(bounding_box.y),
-      'width': str(bounding_box.width),
-      'height': str(bounding_box.height)
-    }
-    if color:
-      attrib['fill'] = str(color)
-    rect = etree.Element(SVG_RECT, attrib=attrib)
-    self.svg_root.append(rect)
-    return rect
+    @abstractmethod
+    def rect(self, bounding_box, color, tag=None):
+        pass
+
+
+class SvgSurface(AbstractSurface):
+    def __init__(self, width, height, background):
+        if not (width and height):
+            raise AttributeError('width and height are required')
+
+        self.svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP, attrib={
+            'width': str(width),
+            'height': str(height)
+        })
+
+        if background:
+            self.svg_root.append(etree.Element(SVG_RECT, attrib={
+                'width': '100%',
+                'height': '100%',
+                'fill': background,
+                'class': 'background'
+            }))
+
+    def rect(self, bounding_box, color, tag=None):
+        attrib = {
+            'class': str(tag),
+            'shape-rendering': 'crispEdges',
+            'x': str(bounding_box.x),
+            'y': str(bounding_box.y),
+            'width': str(bounding_box.width),
+            'height': str(bounding_box.height)
+        }
+        if color:
+            attrib['fill'] = str(color)
+        rect = etree.Element(SVG_RECT, attrib=attrib)
+        self.svg_root.append(rect)
+        return rect
+
 
 def color_to_tuple(color):
-  if isinstance(color, tuple):
-    return color
-  return ImageColor.getrgb(color)
-
-class ImageSurface(object):
-  def __init__(self, width, height, background):
-    if not (width and height):
-      raise AttributeError('width and height are required')
-
-    width = int(math.ceil(width))
-    height = int(math.ceil(height))
-    if background:
-      self.image = Image.new('RGB', (width, height), color_to_tuple(background))
-    else:
-      self.image = Image.new('RGBA', (width, height), (255, 255, 255, 0))
-    self._draw = ImageDraw.Draw(self.image)
-
-  def rect(self, bounding_box, color, tag=None):
-    if color is None:
-      return
-    self._draw.rectangle(
-      (
-        (bounding_box.x, bounding_box.y),
-        (bounding_box.x + bounding_box.width, bounding_box.y + bounding_box.height)
-      ),
-      fill=color_to_tuple(color)
-    )
+    if isinstance(color, tuple):
+        return color
+    return ImageColor.getrgb(color)
+
+
+class ImageSurface(AbstractSurface):
+    def __init__(self, width, height, background):
+        if not (width and height):
+            raise AttributeError('width and height are required')
+
+        width = int(math.ceil(width))
+        height = int(math.ceil(height))
+        if background:
+            self.image = Image.new('RGB', (width, height), color_to_tuple(background))
+        else:
+            self.image = Image.new('RGBA', (width, height), (255, 255, 255, 0))
+        self._draw = ImageDraw.Draw(self.image)
+
+    def rect(self, bounding_box, color, tag=None):
+        if color is None:
+            return
+        self._draw.rectangle(
+            (
+                (bounding_box.x, bounding_box.y),
+                (bounding_box.x + bounding_box.width, bounding_box.y + bounding_box.height)
+            ),
+            fill=color_to_tuple(color)
+        )
+
 
 def annotated_blocks_to_surface(blocks, surface, color_map):
-  for block in blocks:
-    color = color_map.get(block.tag)
-    surface.rect(block.bounding_box, color, block.tag)
+    for block in blocks:
+        color = color_map.get(block.tag)
+        surface.rect(block.bounding_box, color, block.tag)
+
 
 def annotated_blocks_to_svg(blocks, color_map, width=None, height=None, background=None):
-  surface = SvgSurface(width, height, background)
-  annotated_blocks_to_surface(blocks, surface, color_map)
-  return surface.svg_root
+    surface = SvgSurface(width, height, background)
+    annotated_blocks_to_surface(blocks, surface, color_map)
+    return surface.svg_root
+
 
 def annotated_blocks_to_image(
-  blocks, color_map, width=None, height=None, background=None,
-  scale_to_size=None):
-
-  surface = ImageSurface(width, height, background)
-  annotated_blocks_to_surface(blocks, surface, color_map)
-  image = surface.image
-  if scale_to_size:
-    image = image.resize(scale_to_size, Image.NEAREST)
-  return image
+        blocks, color_map, width=None, height=None, background=None,
+        scale_to_size=None):
+
+    surface = ImageSurface(width, height, background)
+    annotated_blocks_to_surface(blocks, surface, color_map)
+    image = surface.image
+    if scale_to_size:
+        image = image.resize(scale_to_size, Image.NEAREST)
+    return image
diff --git a/sciencebeam_gym/preprocess/blockify_annotations_test.py b/sciencebeam_gym/preprocess/blockify_annotations_test.py
index c00e875..6d23c09 100644
--- a/sciencebeam_gym/preprocess/blockify_annotations_test.py
+++ b/sciencebeam_gym/preprocess/blockify_annotations_test.py
@@ -1,27 +1,27 @@
 import logging
 
 from sciencebeam_gym.utils.bounding_box import (
-  BoundingBox
+    BoundingBox
 )
 
 from sciencebeam_gym.structured_document import (
-  SimpleStructuredDocument,
-  SimpleLine,
-  SimpleToken,
-  B_TAG_PREFIX
+    SimpleStructuredDocument,
+    SimpleLine,
+    SimpleToken,
+    B_TAG_PREFIX
 )
 
 from sciencebeam_gym.structured_document.svg import (
-  SVG_NS
+    SVG_NS
 )
 
 from sciencebeam_gym.preprocess.blockify_annotations import (
-  annotation_document_page_to_annotation_blocks,
-  annotation_document_page_to_merged_blocks,
-  annotated_blocks_to_svg,
-  annotated_blocks_to_image,
-  merge_blocks,
-  AnnotationBlock
+    annotation_document_page_to_annotation_blocks,
+    annotation_document_page_to_merged_blocks,
+    annotated_blocks_to_svg,
+    annotated_blocks_to_image,
+    merge_blocks,
+    AnnotationBlock
 )
 
 TAG1 = 'tag1'
@@ -35,245 +35,251 @@ DEFAULT_BOUNDING_BOX = BoundingBox(0, 0, 16, 10)
 
 DEFAULT_NEARBY_TOLERANCE = 10
 
+
def setup_module():
    # enable debug logging for all tests in this module
    logging.basicConfig(level=logging.DEBUG)
+
 
class TestAnnotatedBlocksToSvg(object):
    """Tests for annotated_blocks_to_svg (rendering blocks as SVG rects)."""

    def test_should_create_rect_for_single_annotated_block(self):
        blocks = [
            AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX)
        ]

        result_svg = annotated_blocks_to_svg(blocks, color_map={
            TAG1: DEFAULT_COLOR
        }, width=100, height=100)
        result_rect_elements = result_svg.xpath('svg:rect', namespaces={'svg': SVG_NS})
        assert len(result_rect_elements) == 1
        assert result_rect_elements[0].attrib['class'] == TAG1
        assert result_rect_elements[0].attrib['fill'] == DEFAULT_COLOR

    def test_should_add_background(self):
        blocks = [
            AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX)
        ]

        result_svg = annotated_blocks_to_svg(blocks, color_map={
            TAG1: '#123'
        }, width=100, height=100, background=DEFAULT_COLOR)
        # the background is rendered as a separate rect with class "background"
        background_elements = result_svg.xpath(
            'svg:rect[@class="background"]', namespaces={'svg': SVG_NS}
        )
        assert len(background_elements) == 1
        assert background_elements[0].attrib['fill'] == DEFAULT_COLOR
+
 
class TestAnnotatedBlocksToImage(object):
    """Tests for annotated_blocks_to_image (rendering blocks as a PIL image)."""

    def test_should_create_rect_for_single_annotated_block(self):
        blocks = [
            AnnotationBlock(TAG1, BoundingBox(0, 0, 1, 1))
        ]

        image = annotated_blocks_to_image(blocks, color_map={
            TAG1: (0, 255, 0)
        }, width=3, height=3)
        assert image.getpixel((0, 0)) == (0, 255, 0, 255)

    def test_should_accept_float_image_size(self):
        blocks = [
            AnnotationBlock(TAG1, BoundingBox(0, 0, 1, 1))
        ]

        # fractional sizes are expected to be rounded up to whole pixels
        image = annotated_blocks_to_image(blocks, color_map={
            TAG1: (0, 255, 0)
        }, width=3.1, height=3.9)
        assert image.size == (4, 4)

    def test_should_convert_rect_color_name(self):
        blocks = [
            AnnotationBlock(TAG1, BoundingBox(0, 0, 1, 1))
        ]

        image = annotated_blocks_to_image(blocks, color_map={
            TAG1: 'green'
        }, width=3, height=3)
        assert image.getpixel((0, 0)) == (0, 128, 0, 255)

    def test_should_ignore_unmapped_tag(self):
        blocks = [
            AnnotationBlock(TAG1, BoundingBox(0, 0, 1, 1))
        ]

        # a tag without a color mapping should leave the pixel transparent
        image = annotated_blocks_to_image(blocks, color_map={
        }, width=3, height=3)
        assert image.getpixel((0, 0)) == (255, 255, 255, 0)

    def test_should_add_background(self):
        width = 3
        height = 2
        image = annotated_blocks_to_image([], color_map={
            TAG1: '#123'
        }, width=width, height=height, background=(255, 0, 0))
        data = list(image.getdata())
        assert data == [(255, 0, 0)] * (width * height)

    def test_should_convert_background_color_name(self):
        width = 3
        height = 2
        image = annotated_blocks_to_image([], color_map={
            TAG1: '#123'
        }, width=width, height=height, background='red')
        data = list(image.getdata())
        assert data == [(255, 0, 0)] * (width * height)
+
 
class TestAnnotationDocumentPageToAnnotationBlocks(object):
    """Tests for annotation_document_page_to_annotation_blocks."""

    def test_should_convert_single_token_to_block_with_same_bounding_box(self):
        token = SimpleToken('test', tag=TAG1, bounding_box=DEFAULT_BOUNDING_BOX)
        structured_document = SimpleStructuredDocument(lines=[SimpleLine([token])])

        blocks = annotation_document_page_to_annotation_blocks(
            structured_document,
            structured_document.get_pages()[0]
        )
        assert len(blocks) == 1

        block = blocks[0]
        assert block.tag == TAG1
        assert block.bounding_box == DEFAULT_BOUNDING_BOX

    def test_should_strip_tag_prefix(self):
        # B- style tag prefixes should not appear on the resulting block tags
        token = SimpleToken(
            'test', tag=TAG1, tag_prefix=B_TAG_PREFIX,
            bounding_box=DEFAULT_BOUNDING_BOX
        )
        assert token.get_tag() == B_TAG_PREFIX + TAG1
        structured_document = SimpleStructuredDocument(lines=[SimpleLine([token])])

        blocks = annotation_document_page_to_annotation_blocks(
            structured_document,
            structured_document.get_pages()[0]
        )
        assert [b.tag for b in blocks] == [TAG1]

    def test_should_ignore_block_without_bounding_box(self):
        token = SimpleToken('test')
        structured_document = SimpleStructuredDocument(lines=[SimpleLine([token])])
        structured_document.set_tag(token, TAG1)

        blocks = annotation_document_page_to_annotation_blocks(
            structured_document,
            structured_document.get_pages()[0]
        )
        assert len(blocks) == 0
+
 
class TestAnnotationDocumentPageToMergedBlocks(object):
    """Tests for annotation_document_page_to_merged_blocks."""

    def test_should_convert_single_token_to_block_with_same_bounding_box(self):
        token = SimpleToken('test', tag=TAG1, bounding_box=DEFAULT_BOUNDING_BOX)
        structured_document = SimpleStructuredDocument(lines=[SimpleLine([token])])

        blocks = annotation_document_page_to_merged_blocks(
            structured_document,
            structured_document.get_pages()[0]
        )
        assert len(blocks) == 1

        block = blocks[0]
        assert block.tag == TAG1
        assert block.bounding_box == DEFAULT_BOUNDING_BOX
 
class TestMergeBlocks(object):
    """Tests for merge_blocks (merging adjacent/nearby blocks of the same tag)."""

    def test_should_return_same_single_blocks(self):
        block = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX)

        merged_blocks = merge_blocks([block])
        assert merged_blocks == [block]

    def test_should_merge_right_adjacent_block_with_same_tag(self):
        block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX)
        # block2 starts exactly where block1 ends (right-adjacent)
        block2 = AnnotationBlock(
            TAG1,
            block1.bounding_box.move_by(block1.bounding_box.width, 0)
        )

        merged_blocks = merge_blocks([block1, block2])
        assert [b.tag for b in merged_blocks] == [TAG1]
        assert merged_blocks[0].bounding_box == block1.bounding_box.include(block2.bounding_box)

    def test_should_not_merge_right_adjacent_block_with_same_different_tag(self):
        block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX)
        block2 = AnnotationBlock(
            TAG2,
            block1.bounding_box.move_by(block1.bounding_box.width, 0)
        )

        merged_blocks = merge_blocks([block1, block2])
        assert [b.tag for b in merged_blocks] == [TAG1, TAG2]

    def test_should_merge_multiple_separate_right_adjacent_blocks(self):
        block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX)
        block2 = AnnotationBlock(
            TAG1,
            block1.bounding_box.move_by(block1.bounding_box.width, 0)
        )

        block3 = AnnotationBlock(
            TAG2,
            block1.bounding_box.move_by(block1.bounding_box.width * 2, 0)
        )
        block4 = AnnotationBlock(
            TAG2,
            block3.bounding_box.move_by(block3.bounding_box.width, 0)
        )

        # two pairs of adjacent blocks with different tags merge pairwise
        merged_blocks = merge_blocks([block1, block2, block3, block4])
        assert [b.tag for b in merged_blocks] == [TAG1, TAG2]
        assert merged_blocks[0].bounding_box == block1.bounding_box.include(block2.bounding_box)
        assert merged_blocks[1].bounding_box == block3.bounding_box.include(block4.bounding_box)

    def test_should_merge_multiple_sequential_right_adjacent_blocks(self):
        block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX)
        block2 = AnnotationBlock(
            TAG1,
            block1.bounding_box.move_by(block1.bounding_box.width, 0)
        )
        block3 = AnnotationBlock(
            TAG1,
            block2.bounding_box.move_by(block2.bounding_box.width, 0)
        )

        merged_blocks = merge_blocks([block1, block2, block3])
        assert [b.tag for b in merged_blocks] == [TAG1]

        assert merged_blocks[0].bounding_box == (
            block1.bounding_box.include(block2.bounding_box).include(block3.bounding_box)
        )

    def test_should_merge_right_nearby_block_with_same_tag(self):
        block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX)
        # gap equals the tolerance: still considered nearby
        block2 = AnnotationBlock(
            TAG1,
            block1.bounding_box.move_by(block1.bounding_box.width + DEFAULT_NEARBY_TOLERANCE, 0)
        )

        merged_blocks = merge_blocks([block1, block2], nearby_tolerance=DEFAULT_NEARBY_TOLERANCE)
        assert [b.tag for b in merged_blocks] == [TAG1]
        assert merged_blocks[0].bounding_box == block1.bounding_box.include(block2.bounding_box)

    def test_should_not_merge_too_far_away_block_with_same_tag(self):
        block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX)
        # gap exceeds the tolerance by one: must not merge
        block2 = AnnotationBlock(
            TAG1,
            block1.bounding_box.move_by(block1.bounding_box.width + DEFAULT_NEARBY_TOLERANCE + 1, 0)
        )

        merged_blocks = merge_blocks([block1, block2], nearby_tolerance=DEFAULT_NEARBY_TOLERANCE)
        assert [b.tag for b in merged_blocks] == [TAG1, TAG1]

    def test_should_merge_right_nearby_block_with_same_tag_using_fractions(self):
        block1 = AnnotationBlock(TAG1, BoundingBox(10.5, 10.5, 10.9, 26.3))
        block2 = AnnotationBlock(
            TAG1,
            block1.bounding_box.move_by(block1.bounding_box.width + DEFAULT_NEARBY_TOLERANCE, 0)
        )

        merged_blocks = merge_blocks([block1, block2], nearby_tolerance=DEFAULT_NEARBY_TOLERANCE)
        assert [b.tag for b in merged_blocks] == [TAG1]
        assert merged_blocks[0].bounding_box == block1.bounding_box.include(block2.bounding_box)
diff --git a/sciencebeam_gym/preprocess/color_map.py b/sciencebeam_gym/preprocess/color_map.py
index ce8ba2e..ecd7271 100644
--- a/sciencebeam_gym/preprocess/color_map.py
+++ b/sciencebeam_gym/preprocess/color_map.py
@@ -1,37 +1,40 @@
 from __future__ import absolute_import
 import re
 
-from six import string_types
 from six.moves.configparser import ConfigParser
+from six import string_types
+
 
 def parse_color_map_from_configparser(color_map_config):
-  num_pattern = re.compile(r'(\d+)')
-  rgb_pattern = re.compile(r'\((\d+),\s*(\d+),\s*(\d+)\)')
-
-  def parse_color(s):
-    m = num_pattern.match(s)
-    if m:
-      x = int(m.group(1))
-      return (x, x, x)
-    else:
-      m = rgb_pattern.match(s)
-      if m:
-        return (int(m.group(1)), int(m.group(2)), int(m.group(3)))
-    raise Exception('invalid color value: {}'.format(s))
+    num_pattern = re.compile(r'(\d+)')
+    rgb_pattern = re.compile(r'\((\d+),\s*(\d+),\s*(\d+)\)')
+
+    def parse_color(s):
+        m = num_pattern.match(s)
+        if m:
+            x = int(m.group(1))
+            return (x, x, x)
+        else:
+            m = rgb_pattern.match(s)
+            if m:
+                return (int(m.group(1)), int(m.group(2)), int(m.group(3)))
+        raise Exception('invalid color value: {}'.format(s))
+
+    color_map = dict()
+    for k, v in color_map_config.items('color_map'):
+        color_map[k] = parse_color(v)
+    return color_map
 
-  color_map = dict()
-  for k, v in color_map_config.items('color_map'):
-    color_map[k] = parse_color(v)
-  return color_map
 
 def parse_color_map_from_file(f):
-  color_map_config = ConfigParser()
-  if isinstance(f, string_types):
-    with open(f, 'r') as fp:
-      color_map_config.readfp(fp)
-  else:
-    color_map_config.readfp(f)
-  return parse_color_map_from_configparser(color_map_config)
+    color_map_config = ConfigParser()
+    if isinstance(f, string_types):
+        with open(f, 'r') as fp:
+            color_map_config.readfp(fp)
+    else:
+        color_map_config.readfp(f)
+    return parse_color_map_from_configparser(color_map_config)
+
 
def parse_color_map(f):
    """Parse a color map from a file path or file-like object.

    Convenience alias for parse_color_map_from_file.
    """
    return parse_color_map_from_file(f)
diff --git a/sciencebeam_gym/preprocess/color_map_test.py b/sciencebeam_gym/preprocess/color_map_test.py
index 5cc3207..1bad14a 100644
--- a/sciencebeam_gym/preprocess/color_map_test.py
+++ b/sciencebeam_gym/preprocess/color_map_test.py
@@ -1,16 +1,17 @@
 from six import BytesIO
 
 from sciencebeam_gym.preprocess.color_map import (
-  parse_color_map_from_file
+    parse_color_map_from_file
 )
 
+
class TestParseColorMapFromFile(object):
    """Tests for parse_color_map_from_file."""

    def test_should_parse_rgb_color_values(self):
        data = b'\n'.join([
            b'[color_map]',
            b'tag1 = (255, 0, 0)'
        ])
        color_map = parse_color_map_from_file(BytesIO(data))
        assert color_map == {
            'tag1': (255, 0, 0)
        }
diff --git a/sciencebeam_gym/preprocess/lxml_to_svg.py b/sciencebeam_gym/preprocess/lxml_to_svg.py
index 8186739..887695b 100644
--- a/sciencebeam_gym/preprocess/lxml_to_svg.py
+++ b/sciencebeam_gym/preprocess/lxml_to_svg.py
@@ -5,232 +5,240 @@ import os
 from lxml import etree
 
 from sciencebeam_utils.utils.csv import (
-  open_csv_output,
-  write_dict_csv
+    open_csv_output,
+    write_dict_csv
 )
 
 from sciencebeam_gym.utils.bounding_box import (
-  BoundingBox
+    BoundingBox
 )
 
 from sciencebeam_gym.preprocess.annotation.annotator import (
-  Annotator,
-  DEFAULT_ANNOTATORS
+    Annotator,
+    DEFAULT_ANNOTATORS
 )
 
 from sciencebeam_gym.preprocess.annotation.matching_annotator import (
-  MatchingAnnotator,
-  CsvMatchDetailReporter
+    MatchingAnnotator,
+    CsvMatchDetailReporter
 )
 
 from sciencebeam_gym.preprocess.annotation.target_annotation import (
-  parse_xml_mapping,
-  xml_root_to_target_annotations
+    parse_xml_mapping,
+    xml_root_to_target_annotations
 )
 
 from sciencebeam_gym.preprocess.annotation.annotation_evaluation import (
-  evaluate_document_by_page,
-  DEFAULT_EVALUATION_COLUMNS,
-  to_csv_dict_rows as to_annotation_evaluation_csv_dict_rows
+    evaluate_document_by_page,
+    DEFAULT_EVALUATION_COLUMNS,
+    to_csv_dict_rows as to_annotation_evaluation_csv_dict_rows
 )
 
 from sciencebeam_gym.structured_document.svg import (
-  SVG_TEXT,
-  SVG_G,
-  SVG_RECT,
-  SVG_DOC,
-  SVG_NSMAP,
-  SVGE_BOUNDING_BOX,
-  SvgStructuredDocument,
-  SvgStyleClasses,
-  format_bounding_box as svg_format_bounding_box
+    SVG_TEXT,
+    SVG_G,
+    SVG_RECT,
+    SVG_DOC,
+    SVG_NSMAP,
+    SVGE_BOUNDING_BOX,
+    SvgStructuredDocument,
+    SvgStyleClasses,
+    format_bounding_box as svg_format_bounding_box
 )
 
 from sciencebeam_gym.preprocess.visualize_svg_annotation import (
-  visualize_svg_annotations
+    visualize_svg_annotations
 )
 
+
def get_logger():
    """Return the module level logger."""
    return logging.getLogger(__name__)
+
 
 def ElementWithText(tag, text, **kwargs):
-  node = etree.Element(tag, **kwargs)
-  node.text = text
-  return node
+    node = etree.Element(tag, **kwargs)
+    node.text = text
+    return node
+
 
 def svg_pattern_for_lxml_path(lxml_path):
-  name, _ = os.path.splitext(lxml_path)
-  return name + '-page{}.svg'
+    name, _ = os.path.splitext(lxml_path)
+    return name + '-page{}.svg'
+
 
 def parse_args(argv=None):
-  parser = argparse.ArgumentParser(
-    description='Convert to LXML (pdftoxml) to SVG'
-  )
-  parser.add_argument(
-    '--lxml-path', type=str, required=True,
-    help='path to lxml file'
-  )
-  parser.add_argument(
-    '--svg-path', type=str, required=False,
-    help='path to svg file'
-  )
-  parser.add_argument(
-    '--xml-path', type=str, required=False,
-    help='path to xml file'
-  )
-  parser.add_argument(
-    '--xml-mapping-path', type=str, default='annot-xml-front.conf',
-    help='path to xml mapping file'
-  )
-  parser.add_argument(
-    '--annotate', action='store_true', required=False,
-    help='enable annotation'
-  )
-  parser.add_argument(
-    '--debug', action='store_true', required=False,
-    help='enable debug logging'
-  )
-  parser.add_argument(
-    '--debug-match', type=str, required=False,
-    help='debug matches and save as csv'
-  )
-  parser.add_argument(
-    '--annotation-evaluation-csv', type=str, required=False,
-    help='Annotation evaluation CSV output file'
-  )
-  args = parser.parse_args(argv)
-  return args
+    parser = argparse.ArgumentParser(
+        description='Convert to LXML (pdftoxml) to SVG'
+    )
+    parser.add_argument(
+        '--lxml-path', type=str, required=True,
+        help='path to lxml file'
+    )
+    parser.add_argument(
+        '--svg-path', type=str, required=False,
+        help='path to svg file'
+    )
+    parser.add_argument(
+        '--xml-path', type=str, required=False,
+        help='path to xml file'
+    )
+    parser.add_argument(
+        '--xml-mapping-path', type=str, default='annot-xml-front.conf',
+        help='path to xml mapping file'
+    )
+    parser.add_argument(
+        '--annotate', action='store_true', required=False,
+        help='enable annotation'
+    )
+    parser.add_argument(
+        '--debug', action='store_true', required=False,
+        help='enable debug logging'
+    )
+    parser.add_argument(
+        '--debug-match', type=str, required=False,
+        help='debug matches and save as csv'
+    )
+    parser.add_argument(
+        '--annotation-evaluation-csv', type=str, required=False,
+        help='Annotation evaluation CSV output file'
+    )
+    args = parser.parse_args(argv)
+    return args
+
 
def iter_svg_pages_for_lxml(lxml_root, add_background=True):
    """Yield one SVG root element per PAGE of an LXML (pdftoxml) document.

    Each TOKEN becomes an SVG text element carrying its bounding box;
    every TEXT line becomes an SVG group, and consecutive lines belonging
    to the same BLOCK are nested under a shared block-level group.

    Args:
        lxml_root: root element of the parsed LXML document
        add_background: if True, insert a white full-size background rect
            into each page

    Yields:
        an SVG root element for each page
    """
    # track the current LXML BLOCK and its SVG group so consecutive
    # lines of the same block end up under one group
    previous_block = None
    previous_svg_block = None
    for page in lxml_root.xpath('//DOCUMENT/PAGE'):
        svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP)
        page_width = page.attrib.get('width')
        page_height = page.attrib.get('height')
        if page_width and page_height:
            svg_root.attrib['viewBox'] = '0 0 %s %s' % (page_width, page_height)
        if add_background:
            svg_root.append(etree.Element(SVG_RECT, attrib={
                'width': '100%',
                'height': '100%',
                'fill': 'white',
                'class': 'background'
            }))
        for text in page.xpath('.//TEXT'):
            svg_g = etree.Element(SVG_G, nsmap=SVG_NSMAP, attrib={
                'class': SvgStyleClasses.LINE
            })
            for token in text.xpath('./TOKEN'):
                x = float(token.attrib.get('x'))
                y = float(token.attrib.get('y'))
                height = float(token.attrib.get('height'))
                width = float(token.attrib.get('width'))
                # 'base' is the text baseline; fall back to the token top
                base = float(token.attrib.get('base', y))
                y_center = y + height / 2.0
                attrib = {
                    'x': str(x),
                    'y': str(base),
                    'font-size': token.attrib.get('font-size'),
                    'font-family': token.attrib.get('font-name'),
                    'fill': token.attrib.get('font-color'),
                    SVGE_BOUNDING_BOX: svg_format_bounding_box(BoundingBox(
                        x, y, width, height
                    ))
                }
                angle = float(token.attrib.get('angle', '0'))
                if token.attrib.get('rotation') == '1' and angle == 90.0:
                    # rotated token: position via a transform instead of x/y
                    attrib['x'] = '0'
                    attrib['y'] = '0'
                    attrib['transform'] = 'translate({x} {y}) rotate({angle})'.format(
                        x=str(x),
                        y=str(y_center),
                        angle=str(-angle)
                    )
                svg_g.append(
                    ElementWithText(SVG_TEXT, token.text, attrib=attrib)
                )
            text_parent = text.getparent()
            if text_parent.tag == 'BLOCK':
                if text_parent != previous_block:
                    # new BLOCK: start a new block-level SVG group
                    previous_svg_block = etree.Element(SVG_G, nsmap=SVG_NSMAP, attrib={
                        'class': SvgStyleClasses.BLOCK
                    })
                    svg_root.append(previous_svg_block)
                    previous_block = text_parent
                previous_svg_block.append(svg_g)
            else:
                # text outside a BLOCK: attach directly to the page root
                previous_block = None
                previous_svg_block = None
                svg_root.append(svg_g)
        yield svg_root
+
 
 def convert(args):
-  logger = get_logger()
-  svg_filename_pattern = args.svg_path
-  if not svg_filename_pattern:
-    svg_filename_pattern = svg_pattern_for_lxml_path(args.lxml_path)
-  logger.debug('svg_filename_pattern: %s', svg_filename_pattern)
-  lxml_root = etree.parse(args.lxml_path).getroot()
-
-  match_detail_reporter = None
-  if args.annotate:
-    annotators = DEFAULT_ANNOTATORS
-    if args.debug_match:
-      match_detail_reporter = CsvMatchDetailReporter(
-        open_csv_output(args.debug_match),
-        args.debug_match
-      )
-    if args.xml_path:
-      xml_mapping = parse_xml_mapping(args.xml_mapping_path)
-      target_annotations = xml_root_to_target_annotations(
-        etree.parse(args.xml_path).getroot(),
-        xml_mapping
-      )
-      annotators = annotators + [MatchingAnnotator(
-        target_annotations, match_detail_reporter=match_detail_reporter,
-        use_tag_begin_prefix=True
-      )]
-    annotator = Annotator(annotators)
-  else:
-    annotator = None
-
-  if annotator:
-    svg_roots = list(iter_svg_pages_for_lxml(lxml_root))
-    annotator.annotate(SvgStructuredDocument(svg_roots))
-  else:
-    svg_roots = iter_svg_pages_for_lxml(lxml_root)
-  for page_index, svg_root in enumerate(svg_roots):
+    logger = get_logger()
+    svg_filename_pattern = args.svg_path
+    if not svg_filename_pattern:
+        svg_filename_pattern = svg_pattern_for_lxml_path(args.lxml_path)
+    logger.debug('svg_filename_pattern: %s', svg_filename_pattern)
+    lxml_root = etree.parse(args.lxml_path).getroot()
+
+    match_detail_reporter = None
+    if args.annotate:
+        annotators = DEFAULT_ANNOTATORS
+        if args.debug_match:
+            match_detail_reporter = CsvMatchDetailReporter(
+                open_csv_output(args.debug_match),
+                args.debug_match
+            )
+        if args.xml_path:
+            xml_mapping = parse_xml_mapping(args.xml_mapping_path)
+            target_annotations = xml_root_to_target_annotations(
+                etree.parse(args.xml_path).getroot(),
+                xml_mapping
+            )
+            annotators = annotators + [MatchingAnnotator(
+                target_annotations, match_detail_reporter=match_detail_reporter,
+                use_tag_begin_prefix=True
+            )]
+        annotator = Annotator(annotators)
+    else:
+        annotator = None
+
+    if annotator:
+        svg_roots = list(iter_svg_pages_for_lxml(lxml_root))
+        annotator.annotate(SvgStructuredDocument(svg_roots))
+    else:
+        svg_roots = iter_svg_pages_for_lxml(lxml_root)
+    for page_index, svg_root in enumerate(svg_roots):
+        if annotator:
+            svg_root = visualize_svg_annotations(svg_root)
+        svg_filename = svg_filename_pattern.format(1 + page_index)
+        logger.info('writing to: %s', svg_filename)
+        with open(svg_filename, 'wb') as f:
+            etree.ElementTree(svg_root).write(f, pretty_print=True)
     if annotator:
-      svg_root = visualize_svg_annotations(svg_root)
-    svg_filename = svg_filename_pattern.format(1 + page_index)
-    logger.info('writing to: %s', svg_filename)
-    with open(svg_filename, 'wb') as f:
-      etree.ElementTree(svg_root).write(f, pretty_print=True)
-  if annotator:
-    tagging_evaluation_results = evaluate_document_by_page(SvgStructuredDocument(svg_roots))
-    logger.info('tagging evaluation:\n%s', '\n'.join([
-      'page{}: {}'.format(1 + i, r) for i, r in enumerate(tagging_evaluation_results)
-    ]))
-    if args.annotation_evaluation_csv:
-      write_dict_csv(
-        args.annotation_evaluation_csv,
-        DEFAULT_EVALUATION_COLUMNS,
-        to_annotation_evaluation_csv_dict_rows(
-          tagging_evaluation_results,
-          document=os.path.basename(args.lxml_path)
-        )
-      )
-  if match_detail_reporter:
-    match_detail_reporter.close()
+        tagging_evaluation_results = evaluate_document_by_page(SvgStructuredDocument(svg_roots))
+        logger.info('tagging evaluation:\n%s', '\n'.join([
+            'page{}: {}'.format(1 + i, r) for i, r in enumerate(tagging_evaluation_results)
+        ]))
+        if args.annotation_evaluation_csv:
+            write_dict_csv(
+                args.annotation_evaluation_csv,
+                DEFAULT_EVALUATION_COLUMNS,
+                to_annotation_evaluation_csv_dict_rows(
+                    tagging_evaluation_results,
+                    document=os.path.basename(args.lxml_path)
+                )
+            )
+    if match_detail_reporter:
+        match_detail_reporter.close()
+
 
 def main():
-  args = parse_args()
-  if args.debug:
-    logging.basicConfig(level=logging.DEBUG)
-  else:
-    logging.basicConfig(level=logging.INFO)
-  convert(args)
+    args = parse_args()
+    if args.debug:
+        logging.basicConfig(level=logging.DEBUG)
+    else:
+        logging.basicConfig(level=logging.INFO)
+    convert(args)
+
 
 if __name__ == "__main__":
-  main()
+    main()
diff --git a/sciencebeam_gym/preprocess/lxml_to_svg_test.py b/sciencebeam_gym/preprocess/lxml_to_svg_test.py
index c754cfa..98a6398 100644
--- a/sciencebeam_gym/preprocess/lxml_to_svg_test.py
+++ b/sciencebeam_gym/preprocess/lxml_to_svg_test.py
@@ -1,20 +1,20 @@
 from lxml.builder import E
 
 from sciencebeam_gym.utils.bounding_box import (
-  BoundingBox
+    BoundingBox
 )
 
 from sciencebeam_gym.structured_document.svg import (
-  SVG_TEXT,
-  SVG_G,
-  SVG_DOC,
-  SVG_NS,
-  SVGE_BOUNDING_BOX,
-  parse_bounding_box
+    SVG_TEXT,
+    SVG_G,
+    SVG_DOC,
+    SVG_NS,
+    SVGE_BOUNDING_BOX,
+    parse_bounding_box
 )
 
 from sciencebeam_gym.preprocess.lxml_to_svg import (
-  iter_svg_pages_for_lxml
+    iter_svg_pages_for_lxml
 )
 
 SOME_TEXT = "some text"
@@ -27,177 +27,182 @@ SOME_FONT_SIZE = "40"
 SOME_FONT_FAMILY = "Fontastic"
 SOME_FONT_COLOR = '#123'
 
+
 class LXML(object):
-  X = 'x'
-  Y = 'y'
-  BASE = 'base'
-  WIDTH = 'width'
-  HEIGHT = 'height'
-  FONT_SIZE = 'font-size'
-  FONT_NAME = 'font-name'
-  FONT_COLOR = 'font-color'
+    X = 'x'
+    Y = 'y'
+    BASE = 'base'
+    WIDTH = 'width'
+    HEIGHT = 'height'
+    FONT_SIZE = 'font-size'
+    FONT_NAME = 'font-name'
+    FONT_COLOR = 'font-color'
+
 
 class SVG(object):
-  X = 'x'
-  Y = 'y'
-  HEIGHT = 'height'
-  FONT_SIZE = 'font-size'
-  FONT_FAMILY = 'font-family'
-  FILL = 'fill'
-  BOUNDING_BOX = SVGE_BOUNDING_BOX
+    X = 'x'
+    Y = 'y'
+    HEIGHT = 'height'
+    FONT_SIZE = 'font-size'
+    FONT_FAMILY = 'font-family'
+    FILL = 'fill'
+    BOUNDING_BOX = SVGE_BOUNDING_BOX
+
 
 COMMON_LXML_TOKEN_ATTRIBS = {
-  LXML.X: SOME_X,
-  LXML.Y: SOME_Y,
-  LXML.WIDTH: SOME_WIDTH,
-  LXML.HEIGHT: SOME_HEIGHT,
-  LXML.FONT_SIZE: SOME_FONT_SIZE,
-  LXML.FONT_NAME: SOME_FONT_FAMILY,
-  LXML.FONT_COLOR: SOME_FONT_COLOR
+    LXML.X: SOME_X,
+    LXML.Y: SOME_Y,
+    LXML.WIDTH: SOME_WIDTH,
+    LXML.HEIGHT: SOME_HEIGHT,
+    LXML.FONT_SIZE: SOME_FONT_SIZE,
+    LXML.FONT_NAME: SOME_FONT_FAMILY,
+    LXML.FONT_COLOR: SOME_FONT_COLOR
 }
 
+
 def dict_extend(*dicts):
-  d = dict()
-  for x in dicts:
-    d.update(x)
-  return d
+    d = dict()
+    for x in dicts:
+        d.update(x)
+    return d
+
 
 class TestIterSvgPagesForLxml(object):
-  def test_should_return_one_page(self):
-    lxml_root = E.DOCUMENT(
-      E.PAGE(
-      )
-    )
-    svg_pages = list(iter_svg_pages_for_lxml(lxml_root))
-    assert len(svg_pages) == 1
-
-  def test_should_return_multiple_pages(self):
-    lxml_root = E.DOCUMENT(
-      E.PAGE(
-      ),
-      E.PAGE(
-      ),
-      E.PAGE(
-      )
-    )
-    svg_pages = list(iter_svg_pages_for_lxml(lxml_root))
-    assert len(svg_pages) == 3
-
-  def test_should_set_svg_dimensions(self):
-    lxml_root = E.DOCUMENT(
-      E.PAGE(
-        width='600',
-        height='800'
-      )
-    )
-    svg_pages = list(iter_svg_pages_for_lxml(lxml_root))
-    assert len(svg_pages) == 1
-    assert svg_pages[0].attrib.get('viewBox') == '0 0 600 800'
-
-  def test_should_add_background_rect(self):
-    lxml_root = E.DOCUMENT(
-      E.PAGE(
-        width='600',
-        height='800'
-      )
-    )
-    svg_pages = list(iter_svg_pages_for_lxml(lxml_root))
-    assert len(svg_pages) == 1
-    background_rect = svg_pages[0].xpath(
-      'svg:rect[@class="background"]',
-      namespaces={'svg': SVG_NS}
-    )
-    assert len(background_rect) == 1
-
-  def test_should_create_text_node_with_common_attributes(self):
-    lxml_root = E.DOCUMENT(
-      E.PAGE(
-        E.TEXT(
-          E.TOKEN(
-            SOME_TEXT,
-            COMMON_LXML_TOKEN_ATTRIBS
-          )
+    def test_should_return_one_page(self):
+        lxml_root = E.DOCUMENT(
+            E.PAGE(
+            )
         )
-      )
-    )
-    svg_pages = list(iter_svg_pages_for_lxml(lxml_root))
-    assert len(svg_pages) == 1
-    first_page = svg_pages[0]
-    svg_text = first_page.find('.//' + SVG_TEXT)
-    assert svg_text is not None
-    assert svg_text.text == SOME_TEXT
-    assert float(svg_text.attrib[SVG.X]) == float(SOME_X)
-    assert float(svg_text.attrib[SVG.Y]) == float(SOME_Y)
-    assert float(svg_text.attrib[SVG.FONT_SIZE]) == float(SOME_FONT_SIZE)
-    assert svg_text.attrib[SVG.FONT_FAMILY] == SOME_FONT_FAMILY
-    assert svg_text.attrib[SVG.FILL] == SOME_FONT_COLOR
-    assert parse_bounding_box(svg_text.attrib.get(SVG.BOUNDING_BOX)) == BoundingBox(
-      float(COMMON_LXML_TOKEN_ATTRIBS[LXML.X]),
-      float(COMMON_LXML_TOKEN_ATTRIBS[LXML.Y]),
-      float(COMMON_LXML_TOKEN_ATTRIBS[LXML.WIDTH]),
-      float(COMMON_LXML_TOKEN_ATTRIBS[LXML.HEIGHT])
-    )
-
-  def test_should_use_base_as_y_in_svg_if_available(self):
-    lxml_root = E.DOCUMENT(
-      E.PAGE(
-        E.TEXT(
-          E.TOKEN(
-            SOME_TEXT,
-            dict_extend(COMMON_LXML_TOKEN_ATTRIBS, {
-              LXML.BASE: SOME_BASE
-            })
-          )
+        svg_pages = list(iter_svg_pages_for_lxml(lxml_root))
+        assert len(svg_pages) == 1
+
+    def test_should_return_multiple_pages(self):
+        lxml_root = E.DOCUMENT(
+            E.PAGE(
+            ),
+            E.PAGE(
+            ),
+            E.PAGE(
+            )
         )
-      )
-    )
-    svg_pages = list(iter_svg_pages_for_lxml(lxml_root))
-    assert len(svg_pages) == 1
-    first_page = svg_pages[0]
-    svg_text = first_page.find('.//' + SVG_TEXT)
-    assert float(svg_text.attrib[SVG.Y]) == float(SOME_BASE)
-
-  def test_should_keep_text_block_structure_without_block(self):
-    lxml_root = E.DOCUMENT(
-      E.PAGE(
-        E.TEXT(
-          E.TOKEN(
-            SOME_TEXT,
-            dict_extend(COMMON_LXML_TOKEN_ATTRIBS, {
-              LXML.BASE: SOME_BASE
-            })
-          )
+        svg_pages = list(iter_svg_pages_for_lxml(lxml_root))
+        assert len(svg_pages) == 3
+
+    def test_should_set_svg_dimensions(self):
+        lxml_root = E.DOCUMENT(
+            E.PAGE(
+                width='600',
+                height='800'
+            )
+        )
+        svg_pages = list(iter_svg_pages_for_lxml(lxml_root))
+        assert len(svg_pages) == 1
+        assert svg_pages[0].attrib.get('viewBox') == '0 0 600 800'
+
+    def test_should_add_background_rect(self):
+        lxml_root = E.DOCUMENT(
+            E.PAGE(
+                width='600',
+                height='800'
+            )
+        )
+        svg_pages = list(iter_svg_pages_for_lxml(lxml_root))
+        assert len(svg_pages) == 1
+        background_rect = svg_pages[0].xpath(
+            'svg:rect[@class="background"]',
+            namespaces={'svg': SVG_NS}
+        )
+        assert len(background_rect) == 1
+
+    def test_should_create_text_node_with_common_attributes(self):
+        lxml_root = E.DOCUMENT(
+            E.PAGE(
+                E.TEXT(
+                    E.TOKEN(
+                        SOME_TEXT,
+                        COMMON_LXML_TOKEN_ATTRIBS
+                    )
+                )
+            )
+        )
+        svg_pages = list(iter_svg_pages_for_lxml(lxml_root))
+        assert len(svg_pages) == 1
+        first_page = svg_pages[0]
+        svg_text = first_page.find('.//' + SVG_TEXT)
+        assert svg_text is not None
+        assert svg_text.text == SOME_TEXT
+        assert float(svg_text.attrib[SVG.X]) == float(SOME_X)
+        assert float(svg_text.attrib[SVG.Y]) == float(SOME_Y)
+        assert float(svg_text.attrib[SVG.FONT_SIZE]) == float(SOME_FONT_SIZE)
+        assert svg_text.attrib[SVG.FONT_FAMILY] == SOME_FONT_FAMILY
+        assert svg_text.attrib[SVG.FILL] == SOME_FONT_COLOR
+        assert parse_bounding_box(svg_text.attrib.get(SVG.BOUNDING_BOX)) == BoundingBox(
+            float(COMMON_LXML_TOKEN_ATTRIBS[LXML.X]),
+            float(COMMON_LXML_TOKEN_ATTRIBS[LXML.Y]),
+            float(COMMON_LXML_TOKEN_ATTRIBS[LXML.WIDTH]),
+            float(COMMON_LXML_TOKEN_ATTRIBS[LXML.HEIGHT])
+        )
+
+    def test_should_use_base_as_y_in_svg_if_available(self):
+        lxml_root = E.DOCUMENT(
+            E.PAGE(
+                E.TEXT(
+                    E.TOKEN(
+                        SOME_TEXT,
+                        dict_extend(COMMON_LXML_TOKEN_ATTRIBS, {
+                            LXML.BASE: SOME_BASE
+                        })
+                    )
+                )
+            )
+        )
+        svg_pages = list(iter_svg_pages_for_lxml(lxml_root))
+        assert len(svg_pages) == 1
+        first_page = svg_pages[0]
+        svg_text = first_page.find('.//' + SVG_TEXT)
+        assert float(svg_text.attrib[SVG.Y]) == float(SOME_BASE)
+
+    def test_should_keep_text_block_structure_without_block(self):
+        lxml_root = E.DOCUMENT(
+            E.PAGE(
+                E.TEXT(
+                    E.TOKEN(
+                        SOME_TEXT,
+                        dict_extend(COMMON_LXML_TOKEN_ATTRIBS, {
+                            LXML.BASE: SOME_BASE
+                        })
+                    )
+                )
+            )
         )
-      )
-    )
-    svg_pages = list(iter_svg_pages_for_lxml(lxml_root))
-    assert len(svg_pages) == 1
-    first_page = svg_pages[0]
-    svg_text = first_page.find('.//' + SVG_TEXT)
-    assert svg_text is not None
-    assert svg_text.getparent().tag == SVG_G
-    assert svg_text.getparent().getparent().tag == SVG_DOC
-
-  def test_should_keep_text_block_structure_with_block(self):
-    lxml_root = E.DOCUMENT(
-      E.PAGE(
-        E.BLOCK(
-          E.TEXT(
-            E.TOKEN(
-              SOME_TEXT,
-              dict_extend(COMMON_LXML_TOKEN_ATTRIBS, {
-                LXML.BASE: SOME_BASE
-              })
+        svg_pages = list(iter_svg_pages_for_lxml(lxml_root))
+        assert len(svg_pages) == 1
+        first_page = svg_pages[0]
+        svg_text = first_page.find('.//' + SVG_TEXT)
+        assert svg_text is not None
+        assert svg_text.getparent().tag == SVG_G
+        assert svg_text.getparent().getparent().tag == SVG_DOC
+
+    def test_should_keep_text_block_structure_with_block(self):
+        lxml_root = E.DOCUMENT(
+            E.PAGE(
+                E.BLOCK(
+                    E.TEXT(
+                        E.TOKEN(
+                            SOME_TEXT,
+                            dict_extend(COMMON_LXML_TOKEN_ATTRIBS, {
+                                LXML.BASE: SOME_BASE
+                            })
+                        )
+                    )
+                )
             )
-          )
         )
-      )
-    )
-    svg_pages = list(iter_svg_pages_for_lxml(lxml_root))
-    assert len(svg_pages) == 1
-    first_page = svg_pages[0]
-    svg_text = first_page.find('.//' + SVG_TEXT)
-    assert svg_text is not None
-    assert svg_text.getparent().tag == SVG_G
-    assert svg_text.getparent().getparent().tag == SVG_G
-    assert svg_text.getparent().getparent().getparent().tag == SVG_DOC
+        svg_pages = list(iter_svg_pages_for_lxml(lxml_root))
+        assert len(svg_pages) == 1
+        first_page = svg_pages[0]
+        svg_text = first_page.find('.//' + SVG_TEXT)
+        assert svg_text is not None
+        assert svg_text.getparent().tag == SVG_G
+        assert svg_text.getparent().getparent().tag == SVG_G
+        assert svg_text.getparent().getparent().getparent().tag == SVG_DOC
diff --git a/sciencebeam_gym/preprocess/preprocessing_pipeline.py b/sciencebeam_gym/preprocess/preprocessing_pipeline.py
index 9f36244..9f3749e 100644
--- a/sciencebeam_gym/preprocess/preprocessing_pipeline.py
+++ b/sciencebeam_gym/preprocess/preprocessing_pipeline.py
@@ -10,550 +10,563 @@ from apache_beam.io.filesystems import FileSystems
 from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions
 
 from sciencebeam_utils.beam_utils.utils import (
-  TransformAndCount,
-  TransformAndLog,
-  MapOrLog,
-  PreventFusion
+    TransformAndCount,
+    TransformAndLog,
+    MapOrLog,
+    PreventFusion
 )
 
 from sciencebeam_utils.beam_utils.csv import (
-  WriteDictCsv,
-  ReadDictCsv
+    WriteDictCsv,
+    ReadDictCsv
 )
 
 from sciencebeam_utils.beam_utils.io import (
-  read_all_from_path,
-  basename,
-  save_file_content
+    read_all_from_path,
+    basename,
+    save_file_content
 )
 
 from sciencebeam_utils.beam_utils.main import (
-  add_cloud_args,
-  process_cloud_args
+    add_cloud_args,
+    process_cloud_args
 )
 
 from sciencebeam_utils.utils.collection import (
-  extend_dict,
-  remove_keys_from_dict
+    extend_dict,
+    remove_keys_from_dict
 )
 
 from sciencebeam_utils.utils.file_path import (
-  change_ext,
-  relative_path,
-  join_if_relative_path
+    change_ext,
+    relative_path,
+    join_if_relative_path
 )
 
 from sciencebeam_utils.utils.file_pairs import (
-  find_file_pairs_grouped_by_parent_directory_or_name,
+    find_file_pairs_grouped_by_parent_directory_or_name,
 )
 
 from sciencebeam_gym.structured_document.svg import (
-  SvgStructuredDocument
+    SvgStructuredDocument
 )
 
 from sciencebeam_gym.preprocess.annotation.target_annotation import (
-  parse_xml_mapping
+    parse_xml_mapping
 )
 
 from sciencebeam_gym.preprocess.color_map import (
-  parse_color_map_from_file
+    parse_color_map_from_file
 )
 
 from sciencebeam_gym.preprocess.annotation.annotation_evaluation import (
-  evaluate_document_by_page,
-  DEFAULT_EVALUATION_COLUMNS,
-  to_csv_dict_rows as to_annotation_evaluation_csv_dict_rows
+    evaluate_document_by_page,
+    DEFAULT_EVALUATION_COLUMNS,
+    to_csv_dict_rows as to_annotation_evaluation_csv_dict_rows
 )
 
 from sciencebeam_gym.preprocess.preprocessing_utils import (
-  convert_pdf_bytes_to_lxml,
-  convert_and_annotate_lxml_content,
-  pdf_bytes_to_png_pages,
-  svg_page_to_blockified_png_bytes,
-  save_pages,
-  save_svg_roots,
-  filter_list_props_by_indices,
-  get_page_indices_with_min_annotation_percentage,
-  parse_page_range
+    convert_pdf_bytes_to_lxml,
+    convert_and_annotate_lxml_content,
+    pdf_bytes_to_png_pages,
+    svg_page_to_blockified_png_bytes,
+    save_pages,
+    save_svg_roots,
+    filter_list_props_by_indices,
+    get_page_indices_with_min_annotation_percentage,
+    parse_page_range
 )
 
 from sciencebeam_gym.preprocess.preprocessing_transforms import (
-  WritePropsToTFRecord
+    WritePropsToTFRecord
 )
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 class MetricCounters(object):
-  FILE_PAIR = 'file_pair_count'
-  PAGE = 'page_count'
-  FILTERED_PAGE = 'filtered_page_count'
-  CONVERT_PDF_TO_LXML_ERROR = 'ConvertPdfToLxml_error_count'
-  CONVERT_PDF_TO_PNG_ERROR = 'ConvertPdfToPng_error_count'
-  CONVERT_LXML_TO_SVG_ANNOT_ERROR = 'ConvertPdfToSvgAnnot_error_count'
+    FILE_PAIR = 'file_pair_count'
+    PAGE = 'page_count'
+    FILTERED_PAGE = 'filtered_page_count'
+    CONVERT_PDF_TO_LXML_ERROR = 'ConvertPdfToLxml_error_count'
+    CONVERT_PDF_TO_PNG_ERROR = 'ConvertPdfToPng_error_count'
+    CONVERT_LXML_TO_SVG_ANNOT_ERROR = 'ConvertPdfToSvgAnnot_error_count'
+
 
 def configure_pipeline(p, opt):
-  image_size = (
-    (opt.image_width, opt.image_height)
-    if opt.image_width and opt.image_height
-    else None
-  )
-  page_range = opt.pages
-  first_page = page_range[0] if page_range else 1
-  xml_mapping = parse_xml_mapping(opt.xml_mapping_path)
-  if opt.lxml_path:
-    lxml_xml_file_pairs = (
-      p |
-      beam.Create([[
-        join_if_relative_path(opt.base_data_path, s)
-        for s in [opt.lxml_path, opt.xml_path]
-      ]]) |
-      "FindFilePairs" >> TransformAndLog(
-        beam.FlatMap(
-          lambda patterns: islice(
-            find_file_pairs_grouped_by_parent_directory_or_name(patterns),
-            opt.limit
-          )
-        ),
-        log_prefix='file pairs: ',
-        log_level='debug'
-      ) |
-      PreventFusion() |
-      "ReadFileContent" >> beam.Map(lambda filenames: {
-        'source_filename': filenames[0],
-        'xml_filename': filenames[1],
-        'lxml_content': read_all_from_path(filenames[0]),
-        'xml_content': read_all_from_path(filenames[1])
-      })
+    image_size = (
+        (opt.image_width, opt.image_height)
+        if opt.image_width and opt.image_height
+        else None
     )
-  elif opt.pdf_path or opt.pdf_xml_file_list:
-    if opt.pdf_xml_file_list:
-      pdf_xml_url_pairs = (
-        p |
-        "ReadFilePairUrls" >> ReadDictCsv(opt.pdf_xml_file_list, limit=opt.limit) |
-        "TranslateFilePairUrls" >> beam.Map(lambda row: (row['source_url'], row['xml_url']))
-      )
-    else:
-      pdf_xml_url_pairs = (
-        p |
-        beam.Create([[
-          join_if_relative_path(opt.base_data_path, s)
-          for s in [opt.pdf_path, opt.xml_path]
-        ]]) |
-        "FindFilePairs" >> TransformAndLog(
-          beam.FlatMap(
-            lambda patterns: islice(
-              find_file_pairs_grouped_by_parent_directory_or_name(patterns),
-              opt.limit
+    page_range = opt.pages
+    first_page = page_range[0] if page_range else 1
+    xml_mapping = parse_xml_mapping(opt.xml_mapping_path)
+    if opt.lxml_path:
+        lxml_xml_file_pairs = (
+            p |
+            beam.Create([[
+                join_if_relative_path(opt.base_data_path, s)
+                for s in [opt.lxml_path, opt.xml_path]
+            ]]) |
+            "FindFilePairs" >> TransformAndLog(
+                beam.FlatMap(
+                    lambda patterns: islice(
+                        find_file_pairs_grouped_by_parent_directory_or_name(patterns),
+                        opt.limit
+                    )
+                ),
+                log_prefix='file pairs: ',
+                log_level='debug'
+            ) |
+            PreventFusion() |
+            "ReadFileContent" >> beam.Map(lambda filenames: {
+                'source_filename': filenames[0],
+                'xml_filename': filenames[1],
+                'lxml_content': read_all_from_path(filenames[0]),
+                'xml_content': read_all_from_path(filenames[1])
+            })
+        )
+    elif opt.pdf_path or opt.pdf_xml_file_list:
+        if opt.pdf_xml_file_list:
+            pdf_xml_url_pairs = (
+                p |
+                "ReadFilePairUrls" >> ReadDictCsv(opt.pdf_xml_file_list, limit=opt.limit) |
+                "TranslateFilePairUrls" >> beam.Map(lambda row: (row['source_url'], row['xml_url']))
+            )
+        else:
+            pdf_xml_url_pairs = (
+                p |
+                beam.Create([[
+                    join_if_relative_path(opt.base_data_path, s)
+                    for s in [opt.pdf_path, opt.xml_path]
+                ]]) |
+                "FindFilePairs" >> TransformAndLog(
+                    beam.FlatMap(
+                        lambda patterns: islice(
+                            find_file_pairs_grouped_by_parent_directory_or_name(patterns),
+                            opt.limit
+                        )
+                    ),
+                    log_prefix='file pairs: ',
+                    log_level='debug'
+                )
+            )
+        pdf_xml_file_pairs = (
+            pdf_xml_url_pairs |
+            PreventFusion() |
+            "ReadFileContent" >> TransformAndCount(
+                beam.Map(lambda filenames: {
+                    'source_filename': filenames[0],
+                    'xml_filename': filenames[1],
+                    'pdf_content': read_all_from_path(filenames[0]),
+                    'xml_content': read_all_from_path(filenames[1])
+                }),
+                MetricCounters.FILE_PAIR
             )
-          ),
-          log_prefix='file pairs: ',
-          log_level='debug'
         )
-      )
-    pdf_xml_file_pairs = (
-      pdf_xml_url_pairs |
-      PreventFusion() |
-      "ReadFileContent" >> TransformAndCount(
-        beam.Map(lambda filenames: {
-          'source_filename': filenames[0],
-          'xml_filename': filenames[1],
-          'pdf_content': read_all_from_path(filenames[0]),
-          'xml_content': read_all_from_path(filenames[1])
-        }),
-        MetricCounters.FILE_PAIR
-      )
-    )
 
-    lxml_xml_file_pairs = (
-      pdf_xml_file_pairs |
-      "ConvertPdfToLxml" >> MapOrLog(lambda v: remove_keys_from_dict(
-        extend_dict(v, {
-          'lxml_content': convert_pdf_bytes_to_lxml(
-            v['pdf_content'], path=v['source_filename'],
-            page_range=page_range
-          )
-        }),
-        # we don't need the pdf_content unless we are writing tf_records
-        None if opt.save_tfrecords else {'pdf_content'}
-      ), log_fn=lambda e, v: (
-        get_logger().warning(
-          'caught exception (ignoring item): %s, pdf: %s, xml: %s',
-          e, v['source_filename'], v['xml_filename'], exc_info=e
+        lxml_xml_file_pairs = (
+            pdf_xml_file_pairs |
+            "ConvertPdfToLxml" >> MapOrLog(lambda v: remove_keys_from_dict(
+                extend_dict(v, {
+                    'lxml_content': convert_pdf_bytes_to_lxml(
+                        v['pdf_content'], path=v['source_filename'],
+                        page_range=page_range
+                    )
+                }),
+                # we don't need the pdf_content unless we are writing tf_records
+                None if opt.save_tfrecords else {'pdf_content'}
+            ), log_fn=lambda e, v: (
+                get_logger().warning(
+                    'caught exception (ignoring item): %s, pdf: %s, xml: %s',
+                    e, v['source_filename'], v['xml_filename'], exc_info=e
+                )
+            ), error_count=MetricCounters.CONVERT_PDF_TO_LXML_ERROR)
         )
-      ), error_count=MetricCounters.CONVERT_PDF_TO_LXML_ERROR)
-    )
-  else:
-    raise RuntimeError('either lxml-path or pdf-path required')
-
-  if opt.save_png or opt.save_tfrecords:
-    with_pdf_png_pages = (
-      (lxml_xml_file_pairs if opt.save_tfrecords else pdf_xml_file_pairs) |
-      "ConvertPdfToPng" >> MapOrLog(lambda v: remove_keys_from_dict(
-        extend_dict(v, {
-          'pdf_png_pages':  list(pdf_bytes_to_png_pages(
-            v['pdf_content'],
-            dpi=opt.png_dpi,
-            image_size=image_size,
-            page_range=page_range
-          ))
-        }),
-        {'pdf_content'} # we no longer need the pdf_content
-      ), error_count=MetricCounters.CONVERT_PDF_TO_PNG_ERROR)
-    )
-
-    if opt.save_png:
-      _ = (
-        with_pdf_png_pages |
-        "SavePdfToPng" >> TransformAndLog(
-          beam.Map(lambda v: save_pages(
-            FileSystems.join(
-              opt.output_path,
-              change_ext(
-                relative_path(opt.base_data_path, v['source_filename']),
-                None, '.png.zip'
-              )
-            ),
-            '.png',
-            v['pdf_png_pages']
-          )),
-          log_fn=lambda x: get_logger().info('saved result: %s', x)
+    else:
+        raise RuntimeError('either lxml-path or pdf-path required')
+
+    if opt.save_png or opt.save_tfrecords:
+        with_pdf_png_pages = (
+            (lxml_xml_file_pairs if opt.save_tfrecords else pdf_xml_file_pairs) |
+            "ConvertPdfToPng" >> MapOrLog(lambda v: remove_keys_from_dict(
+                extend_dict(v, {
+                    'pdf_png_pages': list(pdf_bytes_to_png_pages(
+                        v['pdf_content'],
+                        dpi=opt.png_dpi,
+                        image_size=image_size,
+                        page_range=page_range
+                    ))
+                }),
+                {'pdf_content'}  # we no longer need the pdf_content
+            ), error_count=MetricCounters.CONVERT_PDF_TO_PNG_ERROR)
         )
-      )
-
-  if opt.save_lxml:
-    _ = (
-      lxml_xml_file_pairs |
-      "SaveLxml" >> TransformAndLog(
-        beam.Map(lambda v: save_file_content(
-          FileSystems.join(
-            opt.output_path,
-            change_ext(
-              relative_path(opt.base_data_path, v['source_filename']),
-              None, '.lxml.gz'
+
+        if opt.save_png:
+            _ = (
+                with_pdf_png_pages |
+                "SavePdfToPng" >> TransformAndLog(
+                    beam.Map(lambda v: save_pages(
+                        FileSystems.join(
+                            opt.output_path,
+                            change_ext(
+                                relative_path(opt.base_data_path, v['source_filename']),
+                                None, '.png.zip'
+                            )
+                        ),
+                        '.png',
+                        v['pdf_png_pages']
+                    )),
+                    log_fn=lambda x: get_logger().info('saved result: %s', x)
+                )
             )
-          ),
-          v['lxml_content']
-        )),
-        log_fn=lambda x: get_logger().info('saved lxml: %s', x)
-      )
-    )
 
-  annotation_results = (
-    (with_pdf_png_pages if opt.save_tfrecords else lxml_xml_file_pairs) |
-    "ConvertLxmlToSvgAndAnnotate" >> TransformAndCount(
-      MapOrLog(lambda v: remove_keys_from_dict(
-        extend_dict(v, {
-          'svg_pages': list(convert_and_annotate_lxml_content(
-            v['lxml_content'], v['xml_content'], xml_mapping,
-            name=v['source_filename']
-          ))
-        }),
-        # Won't need the XML anymore
-        {'lxml_content', 'xml_content'}
-      ), log_fn=lambda e, v: (
-        get_logger().warning(
-          'caught exception (ignoring item): %s, source: %s, xml: %s',
-          e, v['source_filename'], v['xml_filename'], exc_info=e
-        )
-      ), error_count=MetricCounters.CONVERT_LXML_TO_SVG_ANNOT_ERROR),
-      MetricCounters.PAGE,
-      lambda v: len(v['svg_pages'])
-    )
-  )
-
-  if opt.save_svg:
-    _ = (
-      annotation_results |
-      "SaveSvgPages" >> TransformAndLog(
-        beam.Map(lambda v: save_svg_roots(
-          FileSystems.join(
-            opt.output_path,
-            change_ext(
-              relative_path(opt.base_data_path, v['source_filename']),
-              None, '.svg.zip'
+    if opt.save_lxml:
+        _ = (
+            lxml_xml_file_pairs |
+            "SaveLxml" >> TransformAndLog(
+                beam.Map(lambda v: save_file_content(
+                    FileSystems.join(
+                        opt.output_path,
+                        change_ext(
+                            relative_path(opt.base_data_path, v['source_filename']),
+                            None, '.lxml.gz'
+                        )
+                    ),
+                    v['lxml_content']
+                )),
+                log_fn=lambda x: get_logger().info('saved lxml: %s', x)
             )
-          ),
-          v['svg_pages']
-        )),
-        log_fn=lambda x: get_logger().info('saved result: %s', x)
-      )
+        )
+
+    annotation_results = (
+        (with_pdf_png_pages if opt.save_tfrecords else lxml_xml_file_pairs) |
+        "ConvertLxmlToSvgAndAnnotate" >> TransformAndCount(
+            MapOrLog(lambda v: remove_keys_from_dict(
+                extend_dict(v, {
+                    'svg_pages': list(convert_and_annotate_lxml_content(
+                        v['lxml_content'], v['xml_content'], xml_mapping,
+                        name=v['source_filename']
+                    ))
+                }),
+                # Won't need the XML anymore
+                {'lxml_content', 'xml_content'}
+            ), log_fn=lambda e, v: (
+                get_logger().warning(
+                    'caught exception (ignoring item): %s, source: %s, xml: %s',
+                    e, v['source_filename'], v['xml_filename'], exc_info=e
+                )
+            ), error_count=MetricCounters.CONVERT_LXML_TO_SVG_ANNOT_ERROR),
+            MetricCounters.PAGE,
+            lambda v: len(v['svg_pages'])
+        )
     )
 
-  if opt.annotation_evaluation_csv or opt.min_annotation_percentage:
-    annotation_evaluation_results = (
-      annotation_results |
-      "EvaluateAnnotations" >> TransformAndLog(
-        beam.Map(lambda v: remove_keys_from_dict(
-          extend_dict(v, {
-            'annotation_evaluation': evaluate_document_by_page(
-              SvgStructuredDocument(v['svg_pages'])
+    if opt.save_svg:
+        _ = (
+            annotation_results |
+            "SaveSvgPages" >> TransformAndLog(
+                beam.Map(lambda v: save_svg_roots(
+                    FileSystems.join(
+                        opt.output_path,
+                        change_ext(
+                            relative_path(opt.base_data_path, v['source_filename']),
+                            None, '.svg.zip'
+                        )
+                    ),
+                    v['svg_pages']
+                )),
+                log_fn=lambda x: get_logger().info('saved result: %s', x)
             )
-          }),
-          None if opt.min_annotation_percentage else {'svg_pages'}
-        )),
-        log_fn=lambda x: get_logger().info(
-          'annotation evaluation result: %s: %s',
-          x['source_filename'], x['annotation_evaluation']
         )
-      )
-    )
 
-  if opt.save_block_png or opt.save_tfrecords:
-    color_map = parse_color_map_from_file(opt.color_map)
-    with_block_png_pages = (
-      (annotation_evaluation_results if opt.min_annotation_percentage else annotation_results) |
-      "GenerateBlockPng" >> beam.Map(lambda v: remove_keys_from_dict(
-        extend_dict(v, {
-          'block_png_pages': [
-            svg_page_to_blockified_png_bytes(svg_page, color_map, image_size=image_size)
-            for svg_page in v['svg_pages']
-          ]
-        }),
-        {'svg_pages'}
-      ))
-    )
+    if opt.annotation_evaluation_csv or opt.min_annotation_percentage:
+        annotation_evaluation_results = (
+            annotation_results |
+            "EvaluateAnnotations" >> TransformAndLog(
+                beam.Map(lambda v: remove_keys_from_dict(
+                    extend_dict(v, {
+                        'annotation_evaluation': evaluate_document_by_page(
+                            SvgStructuredDocument(v['svg_pages'])
+                        )
+                    }),
+                    None if opt.min_annotation_percentage else {'svg_pages'}
+                )),
+                log_fn=lambda x: get_logger().info(
+                    'annotation evaluation result: %s: %s',
+                    x['source_filename'], x['annotation_evaluation']
+                )
+            )
+        )
 
-    if opt.save_block_png:
-      _ = (
-        with_block_png_pages |
-        "SaveBlockPng" >> TransformAndLog(
-          beam.Map(lambda v: save_pages(
-            FileSystems.join(
-              opt.output_path,
-              change_ext(
-                relative_path(opt.base_data_path, v['source_filename']),
-                None, '.block-png.zip'
-              )
-            ),
-            '.png',
-            v['block_png_pages']
-          )),
-          log_fn=lambda x: get_logger().info('saved result: %s', x)
+    if opt.save_block_png or opt.save_tfrecords:
+        color_map = parse_color_map_from_file(opt.color_map)
+        with_block_png_pages = (
+            (
+                annotation_evaluation_results
+                if opt.min_annotation_percentage
+                else annotation_results
+            ) |
+            "GenerateBlockPng" >> beam.Map(lambda v: remove_keys_from_dict(
+                extend_dict(v, {
+                    'block_png_pages': [
+                        svg_page_to_blockified_png_bytes(svg_page, color_map, image_size=image_size)
+                        for svg_page in v['svg_pages']
+                    ]
+                }),
+                {'svg_pages'}
+            ))
         )
-      )
-
-    if opt.save_tfrecords:
-      if opt.min_annotation_percentage:
-        filtered_pages = (
-          with_block_png_pages |
-          "FilterPages" >> TransformAndCount(
-            beam.Map(
-              lambda v: filter_list_props_by_indices(
-                v,
-                get_page_indices_with_min_annotation_percentage(
-                  v['annotation_evaluation'],
-                  opt.min_annotation_percentage
-                ),
-                {'pdf_png_pages', 'block_png_pages'}
-              )
-            ),
-            MetricCounters.FILTERED_PAGE,
-            lambda v: len(v['block_png_pages'])
-          )
+
+        if opt.save_block_png:
+            _ = (
+                with_block_png_pages |
+                "SaveBlockPng" >> TransformAndLog(
+                    beam.Map(lambda v: save_pages(
+                        FileSystems.join(
+                            opt.output_path,
+                            change_ext(
+                                relative_path(opt.base_data_path, v['source_filename']),
+                                None, '.block-png.zip'
+                            )
+                        ),
+                        '.png',
+                        v['block_png_pages']
+                    )),
+                    log_fn=lambda x: get_logger().info('saved result: %s', x)
+                )
+            )
+
+        if opt.save_tfrecords:
+            if opt.min_annotation_percentage:
+                filtered_pages = (
+                    with_block_png_pages |
+                    "FilterPages" >> TransformAndCount(
+                        beam.Map(
+                            lambda v: filter_list_props_by_indices(
+                                v,
+                                get_page_indices_with_min_annotation_percentage(
+                                    v['annotation_evaluation'],
+                                    opt.min_annotation_percentage
+                                ),
+                                {'pdf_png_pages', 'block_png_pages'}
+                            )
+                        ),
+                        MetricCounters.FILTERED_PAGE,
+                        lambda v: len(v['block_png_pages'])
+                    )
+                )
+            else:
+                filtered_pages = with_block_png_pages
+            _ = (
+                filtered_pages |
+                "WriteTFRecords" >> WritePropsToTFRecord(
+                    FileSystems.join(opt.output_path, 'data'),
+                    lambda v: (
+                        {
+                            'input_uri': v['source_filename'] + '#page%d' % (first_page + i),
+                            'input_image': pdf_png_page,
+                            'annotation_uri': (
+                                v['source_filename'] + '.annot' + '#page%d' % (first_page + i)
+                            ),
+                            'annotation_image': block_png_page,
+                            'page_no': first_page + i
+                        }
+                        for i, pdf_png_page, block_png_page in zip(
+                            range(len(v['pdf_png_pages'])), v['pdf_png_pages'], v['block_png_pages']
+                        )
+                    )
+                )
+            )
+
+    if opt.annotation_evaluation_csv:
+        annotation_evaluation_csv_name, annotation_evaluation_ext = (
+            os.path.splitext(opt.annotation_evaluation_csv)
         )
-      else:
-        filtered_pages = with_block_png_pages
-      _ = (
-        filtered_pages |
-        "WriteTFRecords" >> WritePropsToTFRecord(
-          FileSystems.join(opt.output_path, 'data'),
-          lambda v: (
-            {
-              'input_uri': v['source_filename'] + '#page%d' % (first_page + i),
-              'input_image': pdf_png_page,
-              'annotation_uri': v['source_filename'] + '.annot' + '#page%d' % (first_page + i),
-              'annotation_image': block_png_page,
-              'page_no': first_page + i
-            }
-            for i, pdf_png_page, block_png_page in zip(
-              range(len(v['pdf_png_pages'])), v['pdf_png_pages'], v['block_png_pages']
+        _ = (  # flake8: noqa
+            annotation_evaluation_results |
+            "FlattenAnotationEvaluationResults" >> beam.FlatMap(
+                lambda v: to_annotation_evaluation_csv_dict_rows(
+                    v['annotation_evaluation'],
+                    document=basename(v['source_filename'])
+                )
+            ) |
+            "WriteAnnotationEvaluationToCsv" >> WriteDictCsv(
+                join_if_relative_path(opt.output_path, annotation_evaluation_csv_name),
+                file_name_suffix=annotation_evaluation_ext,
+                columns=DEFAULT_EVALUATION_COLUMNS
             )
-          )
         )
-      )
 
-  if opt.annotation_evaluation_csv:
-    annotation_evaluation_csv_name, annotation_evaluation_ext = (
-      os.path.splitext(opt.annotation_evaluation_csv)
+
def add_main_args(parser):
    """Register the pipeline's own command line options on *parser*."""
    add = parser.add_argument

    add(
        '--data-path', type=str, required=True,
        help='base data path'
    )

    # Exactly one of the three input sources must be given.
    source_group = parser.add_mutually_exclusive_group(required=True)
    source_group.add_argument(
        '--lxml-path', type=str, required=False,
        help='path to lxml file(s)'
    )
    source_group.add_argument(
        '--pdf-path', type=str, required=False,
        help='path to pdf file(s) (alternative to lxml)'
    )
    source_group.add_argument(
        '--pdf-xml-file-list', type=str, required=False,
        help='path to pdf-xml csv/tsv file list'
    )
    add(
        '--limit', type=int, required=False,
        help='limit the number of file pairs to process'
    )

    # Output options; at least one is required (enforced in process_main_args).
    add(
        '--save-lxml', default=False, action='store_true',
        help='save generated lxml (if using pdf as an input)'
    )
    add(
        '--save-svg', default=False, action='store_true',
        help='save svg pages with annotation tags'
    )
    add(
        '--save-png', default=False, action='store_true',
        help='save png pages of the original pdf'
    )
    add(
        '--png-dpi', type=int, default=90,
        help='dpi of rendered pdf pages'
    )
    add(
        '--image-width', type=int, required=False,
        help='image width of resulting PNGs'
    )
    add(
        '--image-height', type=int, required=False,
        help='image height of resulting PNGs'
    )
    add(
        '--save-block-png', default=False, action='store_true',
        help='save blockified version of the svg as a png'
    )
    add(
        '--color-map', default='color_map.conf',
        help='color map to use (see save-block-png)'
    )

    # Annotation inputs and page selection.
    add(
        '--xml-path', type=str, required=False,
        help='path to xml file(s)'
    )
    add(
        '--xml-mapping-path', type=str, default='annot-xml-front.conf',
        help='path to xml mapping file'
    )
    add(
        '--pages', type=parse_page_range, default=None,
        help='only processes the selected pages'
    )
    add(
        '--save-tfrecords', default=False, action='store_true',
        help='Save TFRecords with PDF PNG and Annotation PNG'
        ' (--image-width and --image-height recommended)'
    )
    add(
        '--min-annotation-percentage', type=float, required=False,
        help='Minimum percentage of annotations per page'
        ' (pages below that threshold will get dropped)'
    )
    add(
        '--annotation-evaluation-csv', type=str, required=False,
        help='Annotation evaluation CSV output file'
    )
    add(
        '--output-path', required=False,
        help='Output directory to write results to.'
    )
+
+
def process_main_args(parser, args):
    """Validate parsed arguments and derive dependent options in place.

    Adds ``base_data_path`` and a default ``output_path`` to *args*; calls
    ``parser.error`` (which exits) on invalid flag combinations.
    """
    # Collapse a '/*/' wildcard segment to obtain the base data directory.
    args.base_data_path = args.data_path.replace('/*/', '/')

    if not args.output_path:
        args.output_path = os.path.join(
            os.path.dirname(args.base_data_path),
            os.path.basename(args.base_data_path + '-results')
        )

    if not args.xml_path and not args.pdf_xml_file_list:
        parser.error('--xml-path required unless --pdf-xml-file-list is specified')

    # Options below only make sense when starting from PDF input.
    has_pdf_source = args.pdf_path or args.pdf_xml_file_list

    if args.save_lxml and not has_pdf_source:
        parser.error('--save-lxml only valid with --pdf-path or --pdf-xml-file-list')

    if args.save_png and not has_pdf_source:
        parser.error('--save-png only valid with --pdf-path or --pdf-xml-file-list')

    if args.save_tfrecords and not has_pdf_source:
        parser.error('--save-tfrecords only valid with --pdf-path or --pdf-xml-file-list')

    # Width and height must be given together; exactly one being set is an error.
    if bool(args.image_width) != bool(args.image_height):
        parser.error('--image-width and --image-height need to be specified together')

    if not any([args.save_lxml, args.save_svg, args.save_png, args.save_tfrecords]):
        parser.error(
            'at least one of the output options required:'
            ' --save-lxml --save-svg --save-png or --save-tfrecords'
        )
 
 def parse_args(argv=None):
-  parser = argparse.ArgumentParser()
-  add_main_args(parser)
-  add_cloud_args(parser)
+    parser = argparse.ArgumentParser()
+    add_main_args(parser)
+    add_cloud_args(parser)
 
-  # parsed_args, other_args = parser.parse_known_args(argv)
-  parsed_args = parser.parse_args(argv)
+    # parsed_args, other_args = parser.parse_known_args(argv)
+    parsed_args = parser.parse_args(argv)
+
+    process_main_args(parser, parsed_args)
+    process_cloud_args(
+        parsed_args, parsed_args.output_path,
+        name='sciencbeam-gym-preprocessing'
+    )
 
-  process_main_args(parser, parsed_args)
-  process_cloud_args(
-    parsed_args, parsed_args.output_path,
-    name='sciencbeam-gym-preprocessing'
-  )
+    get_logger().info('parsed_args: %s', parsed_args)
 
-  get_logger().info('parsed_args: %s', parsed_args)
+    return parsed_args
 
-  return parsed_args
 
def run(argv=None):
    """Main entry point; defines and runs the preprocessing pipeline."""
    known_args = parse_args(argv)

    # We use the save_main_session option because one or more DoFn's in this
    # workflow rely on global context (e.g., a module imported at module level).
    pipeline_options = PipelineOptions.from_dictionary(vars(known_args))
    pipeline_options.view_as(SetupOptions).save_main_session = True

    # Leaving the `with` block runs the pipeline and blocks until it finishes.
    with beam.Pipeline(known_args.runner, options=pipeline_options) as p:
        configure_pipeline(p, known_args)

        # Execute the pipeline and wait until it is completed.
 
 
if __name__ == '__main__':
    # Configure root logging before running so pipeline log output is visible.
    logging.basicConfig(level='INFO')

    run()
diff --git a/sciencebeam_gym/preprocess/preprocessing_pipeline_test.py b/sciencebeam_gym/preprocess/preprocessing_pipeline_test.py
index f444ce7..382c30a 100644
--- a/sciencebeam_gym/preprocess/preprocessing_pipeline_test.py
+++ b/sciencebeam_gym/preprocess/preprocessing_pipeline_test.py
@@ -7,24 +7,24 @@ import pytest
 import apache_beam as beam
 
 from sciencebeam_utils.beam_utils.utils import (
-  TransformAndLog
+    TransformAndLog
 )
 
 from sciencebeam_utils.beam_utils.testing import (
-  BeamTest,
-  TestPipeline,
-  get_current_test_context,
-  get_counter_value
+    BeamTest,
+    TestPipeline,
+    get_current_test_context,
+    get_counter_value
 )
 
 from sciencebeam_utils.utils.collection import (
-  extend_dict
+    extend_dict
 )
 
 from sciencebeam_gym.preprocess.preprocessing_pipeline import (
-  parse_args,
-  configure_pipeline,
-  MetricCounters
+    parse_args,
+    configure_pipeline,
+    MetricCounters
 )
 
 PREPROCESSING_PIPELINE = 'sciencebeam_gym.preprocess.preprocessing_pipeline'
@@ -39,386 +39,410 @@ PDF_FILE_2 = '2/file.pdf'
 XML_FILE_2 = '2/file.xml'
 PDF_XML_FILE_LIST_FILE_1 = 'pdf-xml-files.tsv'
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def fake_content(path):
-  return 'fake content: %s' % path
+    return 'fake content: %s' % path
+
 
 def fake_lxml_for_pdf(pdf, path, page_range=None):
-  return 'fake lxml for pdf: %s (%s) [%s]' % (pdf, path, page_range)
+    return 'fake lxml for pdf: %s (%s) [%s]' % (pdf, path, page_range)
+
def fake_svg_page(i=0):
    """Return a deterministic fake SVG page payload for page index *i*."""
    page_template = 'fake svg page: %d'
    return page_template % i
def fake_pdf_png_page(i=0):
    """Return a deterministic fake PDF PNG page payload for page index *i*."""
    page_template = 'fake pdf png page: %d'
    return page_template % i
def fake_block_png_page(i=0):
    """Return a deterministic fake block PNG page payload for page index *i*."""
    page_template = 'fake block png page: %d'
    return page_template % i
 
def get_global_tfrecords_mock():
    """Return the shared tfrecords mock from the current test context."""
    # workaround for mock that would get serialized/deserialized before being invoked
    # (looking it up at call time avoids capturing the mock in a closure that
    # Beam would pickle)
    return get_current_test_context().tfrecords_mock
 
@contextmanager
def patch_preprocessing_pipeline(**kwargs):
    """Patch the preprocessing pipeline module's collaborators for a test.

    Every name in ``always_mock`` is replaced with a Mock (or with the value
    passed via *kwargs* for that name), and ``WritePropsToTFRecord`` is
    replaced with a dummy transform that records its calls on a shared
    ``tfrecords_mock``. Yields the dict of mocks extended with the
    ``'tfrecords'`` mock.
    """
    # Names in the pipeline module that are always replaced by mocks.
    always_mock = {
        'find_file_pairs_grouped_by_parent_directory_or_name',
        'read_all_from_path',
        'pdf_bytes_to_png_pages',
        'convert_pdf_bytes_to_lxml',
        'convert_and_annotate_lxml_content',
        'svg_page_to_blockified_png_bytes',
        'save_svg_roots',
        'save_pages',
        'evaluate_document_by_page',
        'ReadDictCsv'
    }
    tfrecords_mock = Mock(name='tfrecords_mock')

    def DummyWritePropsToTFRecord(file_path, extract_props):
        # Stand-in transform: forwards (file_path, extracted props) to the
        # shared mock instead of writing actual TFRecord files.
        return TransformAndLog(beam.Map(
            lambda v: tfrecords_mock(file_path, list(extract_props(v)))
        ), log_fn=lambda x: get_logger().info('tfrecords: %s', x))

    with patch.multiple(
        PREPROCESSING_PIPELINE,
        WritePropsToTFRecord=DummyWritePropsToTFRecord,
        **{
            # kwargs may supply a custom replacement; DEFAULT means MagicMock.
            k: kwargs.get(k, DEFAULT)
            for k in always_mock
        }
    ) as mocks:
        get_current_test_context().mocks = mocks
        mocks['read_all_from_path'].side_effect = fake_content
        mocks['convert_pdf_bytes_to_lxml'].side_effect = fake_lxml_for_pdf
        yield extend_dict(
            mocks,
            {'tfrecords': tfrecords_mock}
        )
# Minimal argument list that satisfies parse_args() validation:
# one input source, an xml path and at least one output option.
MIN_ARGV = [
    '--data-path=' + BASE_DATA_PATH,
    '--pdf-path=' + PDF_PATH,
    '--xml-path=' + XML_PATH,
    '--save-svg'
]
+
+
def get_default_args():
    """Return parsed args for the minimal valid command line.

    Delegates to parse_args with MIN_ARGV instead of repeating the same
    literal argument list, so the two stay in sync. A copy is passed so the
    module-level constant cannot be mutated.
    """
    return parse_args(list(MIN_ARGV))
+
 
 def page_uri_suffix(page_no):
-  return '#page%d' % page_no
+    return '#page%d' % page_no
+
 
 def _expected_tfrecord_props(pdf_file, page_no=1):
-  return {
-    'input_uri': pdf_file + page_uri_suffix(page_no),
-    'annotation_uri': pdf_file + '.annot' + page_uri_suffix(page_no),
-    'input_image': fake_pdf_png_page(page_no),
-    'annotation_image': fake_block_png_page(page_no),
-    'page_no': page_no
-  }
+    return {
+        'input_uri': pdf_file + page_uri_suffix(page_no),
+        'annotation_uri': pdf_file + '.annot' + page_uri_suffix(page_no),
+        'input_image': fake_pdf_png_page(page_no),
+        'annotation_image': fake_block_png_page(page_no),
+        'page_no': page_no
+    }
+
 
 def _setup_mocks_for_pages(mocks, page_no_list, file_count=1):
-  mocks['convert_and_annotate_lxml_content'].return_value = [
-    fake_svg_page(i) for i in page_no_list
-  ]
-  mocks['pdf_bytes_to_png_pages'].return_value = [
-    fake_pdf_png_page(i) for i in page_no_list
-  ]
-  mocks['svg_page_to_blockified_png_bytes'].side_effect = [
-    fake_block_png_page(i)
-    for _ in range(file_count)
-    for i in page_no_list
-  ]
+    mocks['convert_and_annotate_lxml_content'].return_value = [
+        fake_svg_page(i) for i in page_no_list
+    ]
+    mocks['pdf_bytes_to_png_pages'].return_value = [
+        fake_pdf_png_page(i) for i in page_no_list
+    ]
+    mocks['svg_page_to_blockified_png_bytes'].side_effect = [
+        fake_block_png_page(i)
+        for _ in range(file_count)
+        for i in page_no_list
+    ]
+
 
 @pytest.mark.slow
 class TestConfigurePipeline(BeamTest):
-  def test_should_pass_pdf_and_xml_patterns_to_find_file_pairs_grouped_by_parent_directory(self):
-    with patch_preprocessing_pipeline() as mocks:
-      opt = get_default_args()
-      opt.base_data_path = 'base'
-      opt.pdf_path = 'pdf'
-      opt.xml_path = 'xml'
-      with TestPipeline() as p:
-        mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = []
-        configure_pipeline(p, opt)
-
-      mocks['find_file_pairs_grouped_by_parent_directory_or_name'].assert_called_with(
-        ['base/pdf', 'base/xml']
-      )
-
-  def test_should_pass_lxml_and_xml_patterns_to_find_file_pairs_grouped_by_parent_directory(self):
-    with patch_preprocessing_pipeline() as mocks:
-      opt = get_default_args()
-      opt.base_data_path = 'base'
-      opt.pdf_path = ''
-      opt.lxml_path = 'lxml'
-      opt.xml_path = 'xml'
-      with TestPipeline() as p:
-        mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = []
-        configure_pipeline(p, opt)
-
-      mocks['find_file_pairs_grouped_by_parent_directory_or_name'].assert_called_with(
-        ['base/lxml', 'base/xml']
-      )
-
-  def test_should_write_tfrecords_from_pdf_xml_file_list(self):
-    with patch_preprocessing_pipeline() as mocks:
-      opt = get_default_args()
-      opt.pdf_path = None
-      opt.xml_path = None
-      opt.pdf_xml_file_list = '.temp/file-list.tsv'
-      opt.save_tfrecords = True
-      with TestPipeline() as p:
-        mocks['ReadDictCsv'].return_value = beam.Create([{
-          'source_url': PDF_FILE_1,
-          'xml_url': XML_FILE_1
-        }])
-        _setup_mocks_for_pages(mocks, [1])
-        configure_pipeline(p, opt)
-
-      mocks['ReadDictCsv'].assert_called_with(opt.pdf_xml_file_list, limit=None)
-      mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [
-        _expected_tfrecord_props(PDF_FILE_1)
-      ])
-
-  def test_should_write_multiple_tfrecords_from_pdf_xml_file_list(self):
-    with patch_preprocessing_pipeline() as mocks:
-      opt = get_default_args()
-      opt.pdf_path = None
-      opt.xml_path = None
-      opt.pdf_xml_file_list = '.temp/file-list.tsv'
-      opt.save_tfrecords = True
-      with TestPipeline() as p:
-        mocks['ReadDictCsv'].return_value = beam.Create([{
-          'source_url': PDF_FILE_1,
-          'xml_url': XML_FILE_1
-        }, {
-          'source_url': PDF_FILE_2,
-          'xml_url': XML_FILE_2
-        }])
-        _setup_mocks_for_pages(mocks, [1], file_count=2)
-        configure_pipeline(p, opt)
-
-      mocks['ReadDictCsv'].assert_called_with(opt.pdf_xml_file_list, limit=None)
-      for pdf_file in [PDF_FILE_1, PDF_FILE_2]:
-        mocks['tfrecords'].assert_any_call(opt.output_path + '/data', [
-          _expected_tfrecord_props(pdf_file)
-        ])
-      assert mocks['tfrecords'].call_count == 2
-
-  def test_should_pass_limit_to_read_dict_csv(self):
-    with patch_preprocessing_pipeline() as mocks:
-      opt = get_default_args()
-      opt.pdf_path = None
-      opt.xml_path = None
-      opt.pdf_xml_file_list = '.temp/file-list.tsv'
-      opt.limit = 1
-      opt.save_tfrecords = True
-      with TestPipeline() as p:
-        mocks['ReadDictCsv'].return_value = beam.Create([{
-          'source_url': PDF_FILE_1,
-          'xml_url': XML_FILE_1
-        }])
-        _setup_mocks_for_pages(mocks, [1])
-        configure_pipeline(p, opt)
-
-      mocks['ReadDictCsv'].assert_called_with(opt.pdf_xml_file_list, limit=opt.limit)
-      assert mocks['tfrecords'].call_count == 1
-
-  def test_should_pass_limit_to_find_file_pairs_grouped_by_parent_directory_or_name(self):
-    with patch_preprocessing_pipeline() as mocks:
-      opt = get_default_args()
-      opt.base_data_path = 'base'
-      opt.pdf_path = ''
-      opt.lxml_path = 'lxml'
-      opt.xml_path = 'xml'
-      opt.save_tfrecords = True
-      opt.limit = 1
-      with TestPipeline() as p:
-        mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [
-          (PDF_FILE_1, XML_FILE_1),
-          (PDF_FILE_2, XML_FILE_2)
-        ]
-        configure_pipeline(p, opt)
-
-      mocks['tfrecords'].call_count == 1
-
-  def test_should_write_tfrecords_from_pdf_xml_path(self):
-    with patch_preprocessing_pipeline() as mocks:
-      opt = get_default_args()
-      opt.save_tfrecords = True
-      with TestPipeline() as p:
-        mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [
-          (PDF_FILE_1, XML_FILE_1)
-        ]
-        _setup_mocks_for_pages(mocks, [1])
-        configure_pipeline(p, opt)
-
-      mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [
-        _expected_tfrecord_props(PDF_FILE_1)
-      ])
-
-  def test_should_write_multiple_tfrecords_and_count_pages(self):
-    with patch_preprocessing_pipeline() as mocks:
-      opt = get_default_args()
-      opt.save_tfrecords = True
-      with TestPipeline() as p:
-        mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [
-          (PDF_FILE_1, XML_FILE_1)
-        ]
-        _setup_mocks_for_pages(mocks, [1, 2])
-        configure_pipeline(p, opt)
-
-        p_result = p.run()
-        assert get_counter_value(p_result, MetricCounters.FILE_PAIR) == 1
-        assert get_counter_value(p_result, MetricCounters.PAGE) == 2
-        assert get_counter_value(p_result, MetricCounters.FILTERED_PAGE) is None
-
-      mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [
-        _expected_tfrecord_props(PDF_FILE_1, page_no=i)
-        for i in [1, 2]
-      ])
-
-  def test_should_not_write_tfrecord_below_annotation_threshold_and_count_pages(self):
-    custom_mocks = dict(
-      evaluate_document_by_page=lambda _: [{
-        'percentage': {
-          # low percentage of None (no annotation, include)
-          None: 0.1
-        }
-      }, {
-        'percentage': {
-          # low percentage of None (no annotation, exclude)
-          None: 0.9
-        }
-      }]
-    )
-    with patch_preprocessing_pipeline(**custom_mocks) as mocks:
-      opt = get_default_args()
-      opt.save_tfrecords = True
-      opt.min_annotation_percentage = 0.5
-      with TestPipeline() as p:
-        mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [
-          (PDF_FILE_1, XML_FILE_1)
-        ]
-        _setup_mocks_for_pages(mocks, [1, 2])
-        configure_pipeline(p, opt)
-
-        p_result = p.run()
-        assert get_counter_value(p_result, MetricCounters.FILE_PAIR) == 1
-        assert get_counter_value(p_result, MetricCounters.PAGE) == 2
-        assert get_counter_value(p_result, MetricCounters.FILTERED_PAGE) == 1
-
-      mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [
-        _expected_tfrecord_props(PDF_FILE_1, page_no=i)
-        for i in [1]
-      ])
-
-  def test_should_only_process_selected_pages(self):
-    with patch_preprocessing_pipeline() as mocks:
-      opt = get_default_args()
-      opt.save_tfrecords = True
-      opt.save_png = True
-      opt.pages = (1, 3)
-      with TestPipeline() as p:
-        mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [
-          (PDF_FILE_1, XML_FILE_1)
-        ]
-        _setup_mocks_for_pages(mocks, [1, 2])
-        configure_pipeline(p, opt)
-
-      assert mocks['convert_pdf_bytes_to_lxml'].called
-      assert mocks['convert_pdf_bytes_to_lxml'].call_args[1].get('page_range') == opt.pages
-
-      assert mocks['pdf_bytes_to_png_pages'].called
-      assert mocks['pdf_bytes_to_png_pages'].call_args[1].get('page_range') == opt.pages
+    def test_should_pass_pdf_and_xml_patterns_to_find_file_pairs_grouped_by_parent_directory(self):
+        with patch_preprocessing_pipeline() as mocks:
+            opt = get_default_args()
+            opt.base_data_path = 'base'
+            opt.pdf_path = 'pdf'
+            opt.xml_path = 'xml'
+            with TestPipeline() as p:
+                mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = []
+                configure_pipeline(p, opt)
+
+            mocks['find_file_pairs_grouped_by_parent_directory_or_name'].assert_called_with(
+                ['base/pdf', 'base/xml']
+            )
+
+    def test_should_pass_lxml_and_xml_patterns_to_find_file_pairs_grouped_by_parent_directory(self):
+        with patch_preprocessing_pipeline() as mocks:
+            opt = get_default_args()
+            opt.base_data_path = 'base'
+            opt.pdf_path = ''
+            opt.lxml_path = 'lxml'
+            opt.xml_path = 'xml'
+            with TestPipeline() as p:
+                mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = []
+                configure_pipeline(p, opt)
+
+            mocks['find_file_pairs_grouped_by_parent_directory_or_name'].assert_called_with(
+                ['base/lxml', 'base/xml']
+            )
+
+    def test_should_write_tfrecords_from_pdf_xml_file_list(self):
+        with patch_preprocessing_pipeline() as mocks:
+            opt = get_default_args()
+            opt.pdf_path = None
+            opt.xml_path = None
+            opt.pdf_xml_file_list = '.temp/file-list.tsv'
+            opt.save_tfrecords = True
+            with TestPipeline() as p:
+                mocks['ReadDictCsv'].return_value = beam.Create([{
+                    'source_url': PDF_FILE_1,
+                    'xml_url': XML_FILE_1
+                }])
+                _setup_mocks_for_pages(mocks, [1])
+                configure_pipeline(p, opt)
+
+            mocks['ReadDictCsv'].assert_called_with(opt.pdf_xml_file_list, limit=None)
+            mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [
+                _expected_tfrecord_props(PDF_FILE_1)
+            ])
+
+    def test_should_write_multiple_tfrecords_from_pdf_xml_file_list(self):
+        with patch_preprocessing_pipeline() as mocks:
+            opt = get_default_args()
+            opt.pdf_path = None
+            opt.xml_path = None
+            opt.pdf_xml_file_list = '.temp/file-list.tsv'
+            opt.save_tfrecords = True
+            with TestPipeline() as p:
+                mocks['ReadDictCsv'].return_value = beam.Create([{
+                    'source_url': PDF_FILE_1,
+                    'xml_url': XML_FILE_1
+                }, {
+                    'source_url': PDF_FILE_2,
+                    'xml_url': XML_FILE_2
+                }])
+                _setup_mocks_for_pages(mocks, [1], file_count=2)
+                configure_pipeline(p, opt)
+
+            mocks['ReadDictCsv'].assert_called_with(opt.pdf_xml_file_list, limit=None)
+            for pdf_file in [PDF_FILE_1, PDF_FILE_2]:
+                mocks['tfrecords'].assert_any_call(opt.output_path + '/data', [
+                    _expected_tfrecord_props(pdf_file)
+                ])
+            assert mocks['tfrecords'].call_count == 2
+
+    def test_should_pass_limit_to_read_dict_csv(self):
+        with patch_preprocessing_pipeline() as mocks:
+            opt = get_default_args()
+            opt.pdf_path = None
+            opt.xml_path = None
+            opt.pdf_xml_file_list = '.temp/file-list.tsv'
+            opt.limit = 1
+            opt.save_tfrecords = True
+            with TestPipeline() as p:
+                mocks['ReadDictCsv'].return_value = beam.Create([{
+                    'source_url': PDF_FILE_1,
+                    'xml_url': XML_FILE_1
+                }])
+                _setup_mocks_for_pages(mocks, [1])
+                configure_pipeline(p, opt)
+
+            mocks['ReadDictCsv'].assert_called_with(opt.pdf_xml_file_list, limit=opt.limit)
+            assert mocks['tfrecords'].call_count == 1
+
+    def test_should_pass_limit_to_find_file_pairs_grouped_by_parent_directory_or_name(self):
+        with patch_preprocessing_pipeline() as mocks:
+            opt = get_default_args()
+            opt.base_data_path = 'base'
+            opt.pdf_path = 'pdf'
+            opt.lxml_path = ''
+            opt.xml_path = 'xml'
+            opt.save_tfrecords = True
+            opt.limit = 1
+            with TestPipeline() as p:
+                mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [
+                    (PDF_FILE_1, XML_FILE_1),
+                    (PDF_FILE_2, XML_FILE_2)
+                ]
+                configure_pipeline(p, opt)
+
+            assert mocks['tfrecords'].call_count == 1
+
+    def test_should_write_tfrecords_from_pdf_xml_path(self):
+        with patch_preprocessing_pipeline() as mocks:
+            opt = get_default_args()
+            opt.save_tfrecords = True
+            with TestPipeline() as p:
+                mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [
+                    (PDF_FILE_1, XML_FILE_1)
+                ]
+                _setup_mocks_for_pages(mocks, [1])
+                configure_pipeline(p, opt)
+
+            mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [
+                _expected_tfrecord_props(PDF_FILE_1)
+            ])
+
+    def test_should_write_multiple_tfrecords_and_count_pages(self):
+        with patch_preprocessing_pipeline() as mocks:
+            opt = get_default_args()
+            opt.save_tfrecords = True
+            with TestPipeline() as p:
+                mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [
+                    (PDF_FILE_1, XML_FILE_1)
+                ]
+                _setup_mocks_for_pages(mocks, [1, 2])
+                configure_pipeline(p, opt)
+
+                p_result = p.run()
+                assert get_counter_value(p_result, MetricCounters.FILE_PAIR) == 1
+                assert get_counter_value(p_result, MetricCounters.PAGE) == 2
+                assert get_counter_value(p_result, MetricCounters.FILTERED_PAGE) is None
+
+            mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [
+                _expected_tfrecord_props(PDF_FILE_1, page_no=i)
+                for i in [1, 2]
+            ])
+
+    def test_should_not_write_tfrecord_below_annotation_threshold_and_count_pages(self):
+        custom_mocks = dict(
+            evaluate_document_by_page=lambda _: [{
+                'percentage': {
+                    # low percentage of None (no annotation, include)
+                    None: 0.1
+                }
+            }, {
+                'percentage': {
+                    # high percentage of None (no annotation, exclude)
+                    None: 0.9
+                }
+            }]
+        )
+        with patch_preprocessing_pipeline(**custom_mocks) as mocks:
+            opt = get_default_args()
+            opt.save_tfrecords = True
+            opt.min_annotation_percentage = 0.5
+            with TestPipeline() as p:
+                mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [
+                    (PDF_FILE_1, XML_FILE_1)
+                ]
+                _setup_mocks_for_pages(mocks, [1, 2])
+                configure_pipeline(p, opt)
+
+                p_result = p.run()
+                assert get_counter_value(p_result, MetricCounters.FILE_PAIR) == 1
+                assert get_counter_value(p_result, MetricCounters.PAGE) == 2
+                assert get_counter_value(p_result, MetricCounters.FILTERED_PAGE) == 1
+
+            mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [
+                _expected_tfrecord_props(PDF_FILE_1, page_no=i)
+                for i in [1]
+            ])
+
+    def test_should_only_process_selected_pages(self):
+        with patch_preprocessing_pipeline() as mocks:
+            opt = get_default_args()
+            opt.save_tfrecords = True
+            opt.save_png = True
+            opt.pages = (1, 3)
+            with TestPipeline() as p:
+                mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [
+                    (PDF_FILE_1, XML_FILE_1)
+                ]
+                _setup_mocks_for_pages(mocks, [1, 2])
+                configure_pipeline(p, opt)
+
+            assert mocks['convert_pdf_bytes_to_lxml'].called
+            assert mocks['convert_pdf_bytes_to_lxml'].call_args[1].get('page_range') == opt.pages
+
+            assert mocks['pdf_bytes_to_png_pages'].called
+            assert mocks['pdf_bytes_to_png_pages'].call_args[1].get('page_range') == opt.pages
 
-class TestParseArgs(object):
-  def test_should_raise_error_without_arguments(self):
-    with pytest.raises(SystemExit):
-      parse_args([])
 
-  def test_should_not_raise_error_with_minimum_arguments(self):
-    parse_args(['--data-path=test', '--pdf-path=test', '--xml-path=test', '--save-svg'])
+class TestParseArgs(object):
+    def test_should_raise_error_without_arguments(self):
+        with pytest.raises(SystemExit):
+            parse_args([])
 
-  def test_should_not_raise_error_with_lxml_path_instead_of_pdf_path(self):
-    parse_args(['--data-path=test', '--lxml-path=test', '--xml-path=test', '--save-svg'])
+    def test_should_not_raise_error_with_minimum_arguments(self):
+        parse_args(['--data-path=test', '--pdf-path=test', '--xml-path=test', '--save-svg'])
 
-  def test_should_raise_error_if_no_output_option_specified(self):
-    with pytest.raises(SystemExit):
-      parse_args(['--data-path=test', '--pdf-path=test', '--xml-path=test'])
+    def test_should_not_raise_error_with_lxml_path_instead_of_pdf_path(self):
+        parse_args(['--data-path=test', '--lxml-path=test', '--xml-path=test', '--save-svg'])
 
-  def test_should_raise_error_if_pdf_and_lxml_path_are_specified(self):
-    with pytest.raises(SystemExit):
-      parse_args([
-        '--data-path=test', '--pdf-path=test', '--lxml-path=test', '--xml-path=test',
-        '--save-svg'
-      ])
+    def test_should_raise_error_if_no_output_option_specified(self):
+        with pytest.raises(SystemExit):
+            parse_args(['--data-path=test', '--pdf-path=test', '--xml-path=test'])
 
-  def test_should_raise_error_if_pdf_path_specified_without_xml_path(self):
-    with pytest.raises(SystemExit):
-      parse_args(['--data-path=test', '--pdf-path=test', '--save-svg'])
+    def test_should_raise_error_if_pdf_and_lxml_path_are_specified(self):
+        with pytest.raises(SystemExit):
+            parse_args([
+                '--data-path=test', '--pdf-path=test', '--lxml-path=test', '--xml-path=test',
+                '--save-svg'
+            ])
 
-  def test_should_not_raise_error_if_pdf_xml_file_list_specified_without_xml_path(self):
-    parse_args(['--data-path=test', '--pdf-xml-file-list=test', '--save-svg'])
+    def test_should_raise_error_if_pdf_path_specified_without_xml_path(self):
+        with pytest.raises(SystemExit):
+            parse_args(['--data-path=test', '--pdf-path=test', '--save-svg'])
 
-  def test_should_not_raise_error_with_save_lxml_path_together_with_pdf_path(self):
-    parse_args(['--data-path=test', '--pdf-path=test', '--save-lxml', '--xml-path=test'])
+    def test_should_not_raise_error_if_pdf_xml_file_list_specified_without_xml_path(self):
+        parse_args(['--data-path=test', '--pdf-xml-file-list=test', '--save-svg'])
 
-  def test_should_not_raise_error_with_save_lxml_path_together_with_pdf_xml_file_list(self):
-    parse_args(['--data-path=test', '--pdf-xml-file-list=test', '--save-lxml', '--xml-path=test'])
+    def test_should_not_raise_error_with_save_lxml_path_together_with_pdf_path(self):
+        parse_args(['--data-path=test', '--pdf-path=test', '--save-lxml', '--xml-path=test'])
 
-  def test_should_raise_error_if_save_lxml_specified_without_pdf_path(self):
-    with pytest.raises(SystemExit):
-      parse_args(['--data-path=test', '--lxml-path=test', '--save-lxml', '--xml-path=test'])
+    def test_should_not_raise_error_with_save_lxml_path_together_with_pdf_xml_file_list(self):
+        parse_args(['--data-path=test', '--pdf-xml-file-list=test',
+                    '--save-lxml', '--xml-path=test'])
 
-  def test_should_raise_error_if_save_png_is_specified_without_pdf_path(self):
-    with pytest.raises(SystemExit):
-      parse_args(['--data-path=test', '--lxml-path=test', '--save-png', '--xml-path=test'])
+    def test_should_raise_error_if_save_lxml_specified_without_pdf_path(self):
+        with pytest.raises(SystemExit):
+            parse_args(['--data-path=test', '--lxml-path=test', '--save-lxml', '--xml-path=test'])
 
-  def test_should_not_raise_error_with_save_png_path_together_with_pdf_path(self):
-    parse_args(['--data-path=test', '--pdf-path=test', '--save-png', '--xml-path=test'])
+    def test_should_raise_error_if_save_png_is_specified_without_pdf_path(self):
+        with pytest.raises(SystemExit):
+            parse_args(['--data-path=test', '--lxml-path=test', '--save-png', '--xml-path=test'])
 
-  def test_should_not_raise_error_with_save_png_path_together_with_pdf_xml_file_list(self):
-    parse_args(['--data-path=test', '--pdf-xml-file-list=test', '--save-png', '--xml-path=test'])
+    def test_should_not_raise_error_with_save_png_path_together_with_pdf_path(self):
+        parse_args(['--data-path=test', '--pdf-path=test', '--save-png', '--xml-path=test'])
 
-  def test_should_raise_error_if_image_width_was_specified_without_image_height(self):
-    with pytest.raises(SystemExit):
-      parse_args([
-        '--data-path=test', '--pdf-path=test', '--xml-path=test',
-        '--save-png', '--image-width=100'
-      ])
-
-  def test_should_raise_error_if_image_height_was_specified_without_image_width(self):
-    with pytest.raises(SystemExit):
-      parse_args([
-        '--data-path=test', '--pdf-path=test', '--xml-path=test',
-        '--save-png', '--image-height=100'
-      ])
+    def test_should_not_raise_error_with_save_png_path_together_with_pdf_xml_file_list(self):
+        parse_args([
+            '--data-path=test', '--pdf-xml-file-list=test', '--save-png', '--xml-path=test'
+        ])
 
-  def test_should_not_raise_error_if_both_image_width_and_height_are_specified(self):
-    parse_args([
-      '--data-path=test', '--pdf-path=test', '--xml-path=test',
-      '--save-png', '--image-width=100', '--image-height=100'
-    ])
+    def test_should_raise_error_if_image_width_was_specified_without_image_height(self):
+        with pytest.raises(SystemExit):
+            parse_args([
+                '--data-path=test', '--pdf-path=test', '--xml-path=test',
+                '--save-png', '--image-width=100'
+            ])
+
+    def test_should_raise_error_if_image_height_was_specified_without_image_width(self):
+        with pytest.raises(SystemExit):
+            parse_args([
+                '--data-path=test', '--pdf-path=test', '--xml-path=test',
+                '--save-png', '--image-height=100'
+            ])
+
+    def test_should_not_raise_error_if_both_image_width_and_height_are_specified(self):
+        parse_args([
+            '--data-path=test', '--pdf-path=test', '--xml-path=test',
+            '--save-png', '--image-width=100', '--image-height=100'
+        ])
 
-  def test_should_raise_error_if_save_tfrecords_specified_without_pdf_path(self):
-    with pytest.raises(SystemExit):
-      parse_args(['--data-path=test', '--lxml-path=test', '--xml-path=test', '--save-tfrecords'])
+    def test_should_raise_error_if_save_tfrecords_specified_without_pdf_path(self):
+        with pytest.raises(SystemExit):
+            parse_args(['--data-path=test', '--lxml-path=test',
+                        '--xml-path=test', '--save-tfrecords'])
 
-  def test_should_not_raise_error_if_save_tfrecords_specified_with_pdf_path(self):
-    parse_args(['--data-path=test', '--pdf-path=test', '--xml-path=test', '--save-tfrecords'])
+    def test_should_not_raise_error_if_save_tfrecords_specified_with_pdf_path(self):
+        parse_args(['--data-path=test', '--pdf-path=test', '--xml-path=test', '--save-tfrecords'])
 
-  def test_should_not_raise_error_if_save_tfrecords_specified_with_pdf_xml_file_list(self):
-    parse_args([
-      '--data-path=test', '--pdf-xml-file-list=test', '--xml-path=test', '--save-tfrecords'
-    ])
+    def test_should_not_raise_error_if_save_tfrecords_specified_with_pdf_xml_file_list(self):
+        parse_args([
+            '--data-path=test', '--pdf-xml-file-list=test', '--xml-path=test', '--save-tfrecords'
+        ])
 
-  def test_should_have_none_page_range_by_default(self):
-    assert parse_args(MIN_ARGV).pages is None
+    def test_should_have_none_page_range_by_default(self):
+        assert parse_args(MIN_ARGV).pages is None
 
-  def test_should_parse_pages_as_list(self):
-    assert parse_args(MIN_ARGV + ['--pages=1-3']).pages == (1, 3)
+    def test_should_parse_pages_as_list(self):
+        assert parse_args(MIN_ARGV + ['--pages=1-3']).pages == (1, 3)
diff --git a/sciencebeam_gym/preprocess/preprocessing_transforms.py b/sciencebeam_gym/preprocess/preprocessing_transforms.py
index 20bfe32..5cbf7bf 100644
--- a/sciencebeam_gym/preprocess/preprocessing_transforms.py
+++ b/sciencebeam_gym/preprocess/preprocessing_transforms.py
@@ -1,34 +1,34 @@
 import apache_beam as beam
 
 try:
-  import tensorflow as tf
+    import tensorflow as tf
 
-  from sciencebeam_gym.utils.tfrecord import (
-    dict_to_example
-  )
+    from sciencebeam_gym.utils.tfrecord import (
+        dict_to_example
+    )
 except ImportError:
-  # Make tensorflow optional
-  tf = None
+    # Make tensorflow optional
+    tf = None
 
 
 class WritePropsToTFRecord(beam.PTransform):
-  def __init__(self, file_path, extract_props):
-    super(WritePropsToTFRecord, self).__init__()
-    self.file_path = file_path
-    self.extract_props = extract_props
-    if tf is None:
-      raise RuntimeError('TensorFlow required for this transform')
+    def __init__(self, file_path, extract_props):
+        super(WritePropsToTFRecord, self).__init__()
+        self.file_path = file_path
+        self.extract_props = extract_props
+        if tf is None:
+            raise RuntimeError('TensorFlow required for this transform')
 
-  def expand(self, pcoll): # pylint: disable=W0221
-    return (
-      pcoll |
-      'ConvertToTfExamples' >> beam.FlatMap(lambda v: (
-        dict_to_example(props)
-        for props in self.extract_props(v)
-      )) |
-      'SerializeToString' >> beam.Map(lambda x: x.SerializeToString()) |
-      'SaveToTfRecords' >> beam.io.WriteToTFRecord(
-        self.file_path,
-        file_name_suffix='.tfrecord.gz'
-      )
-    )
+    def expand(self, pcoll):  # pylint: disable=W0221
+        return (
+            pcoll |
+            'ConvertToTfExamples' >> beam.FlatMap(lambda v: (
+                dict_to_example(props)
+                for props in self.extract_props(v)
+            )) |
+            'SerializeToString' >> beam.Map(lambda x: x.SerializeToString()) |
+            'SaveToTfRecords' >> beam.io.WriteToTFRecord(
+                self.file_path,
+                file_name_suffix='.tfrecord.gz'
+            )
+        )
diff --git a/sciencebeam_gym/preprocess/preprocessing_transforms_test.py b/sciencebeam_gym/preprocess/preprocessing_transforms_test.py
index 732700c..9b1a170 100644
--- a/sciencebeam_gym/preprocess/preprocessing_transforms_test.py
+++ b/sciencebeam_gym/preprocess/preprocessing_transforms_test.py
@@ -4,20 +4,20 @@ import apache_beam as beam
 from apache_beam.io.filesystems import FileSystems
 
 from sciencebeam_utils.beam_utils.io import (
-  find_matching_filenames
+    find_matching_filenames
 )
 
 from sciencebeam_utils.beam_utils.testing import (
-  BeamTest,
-  TestPipeline
+    BeamTest,
+    TestPipeline
 )
 
 from sciencebeam_gym.utils.tfrecord import (
-  iter_read_tfrecord_file_as_dict_list
+    iter_read_tfrecord_file_as_dict_list
 )
 
 from sciencebeam_gym.preprocess.preprocessing_transforms import (
-  WritePropsToTFRecord
+    WritePropsToTFRecord
 )
 
 TFRECORDS_PATH = '.temp/test-data'
@@ -28,35 +28,36 @@ KEY_2 = b'key2'
 VALUE_1 = b'value 1'
 VALUE_2 = b'value 2'
 
+
 @pytest.mark.slow
 class TestWritePropsToTFRecord(BeamTest):
-  def test_should_write_single_entry(self):
-    dict_list = [{KEY_1: VALUE_1}]
-    with TestPipeline() as p:
-      _ = (
-        p |
-        beam.Create(dict_list) |
-        WritePropsToTFRecord(TFRECORDS_PATH, lambda x: [x])
-      )
-    filenames = list(find_matching_filenames(TFRECORDS_PATH + '*'))
-    assert len(filenames) == 1
-    records = list(iter_read_tfrecord_file_as_dict_list(filenames[0]))
-    assert records == dict_list
-    FileSystems.delete(filenames)
-
-  def test_should_write_multiple_entries(self):
-    dict_list = [
-      {KEY_1: VALUE_1},
-      {KEY_2: VALUE_2}
-    ]
-    with TestPipeline() as p:
-      _ = (
-        p |
-        beam.Create(dict_list) |
-        WritePropsToTFRecord(TFRECORDS_PATH, lambda x: [x])
-      )
-    filenames = list(find_matching_filenames(TFRECORDS_PATH + '*'))
-    assert len(filenames) == 1
-    records = list(iter_read_tfrecord_file_as_dict_list(filenames[0]))
-    assert records == dict_list
-    FileSystems.delete(filenames)
+    def test_should_write_single_entry(self):
+        dict_list = [{KEY_1: VALUE_1}]
+        with TestPipeline() as p:
+            _ = (  # flake8: noqa
+                p |
+                beam.Create(dict_list) |
+                WritePropsToTFRecord(TFRECORDS_PATH, lambda x: [x])
+            )
+        filenames = list(find_matching_filenames(TFRECORDS_PATH + '*'))
+        assert len(filenames) == 1
+        records = list(iter_read_tfrecord_file_as_dict_list(filenames[0]))
+        assert records == dict_list
+        FileSystems.delete(filenames)
+
+    def test_should_write_multiple_entries(self):
+        dict_list = [
+            {KEY_1: VALUE_1},
+            {KEY_2: VALUE_2}
+        ]
+        with TestPipeline() as p:
+            _ = (  # flake8: noqa
+                p |
+                beam.Create(dict_list) |
+                WritePropsToTFRecord(TFRECORDS_PATH, lambda x: [x])
+            )
+        filenames = list(find_matching_filenames(TFRECORDS_PATH + '*'))
+        assert len(filenames) == 1
+        records = list(iter_read_tfrecord_file_as_dict_list(filenames[0]))
+        assert records == dict_list
+        FileSystems.delete(filenames)
diff --git a/sciencebeam_gym/preprocess/preprocessing_utils.py b/sciencebeam_gym/preprocess/preprocessing_utils.py
index 62d7be8..0800000 100644
--- a/sciencebeam_gym/preprocess/preprocessing_utils.py
+++ b/sciencebeam_gym/preprocess/preprocessing_utils.py
@@ -1,217 +1,208 @@
 from __future__ import absolute_import
 
-import os
 import logging
 from io import BytesIO
-from functools import reduce # pylint: disable=W0622
 
 from six import iteritems
 
 from lxml import etree
 
-from apache_beam.io.filesystems import FileSystems
-
 from sciencebeam_alignment.align import (
-  native_enabled as align_native_enabled
-)
-
-from sciencebeam_utils.beam_utils.io import (
-  find_matching_filenames
+    native_enabled as align_native_enabled
 )
 
 from sciencebeam_utils.utils.xml import (
-  xml_from_string_with_recover
+    xml_from_string_with_recover
 )
 
 from sciencebeam_utils.utils.stopwatch import (
-  StopWatchRecorder
-)
-
-from sciencebeam_utils.utils.collection import (
-  groupby_to_dict,
-  sort_and_groupby_to_dict
-)
-
-from sciencebeam_utils.utils.file_path import (
-  relative_path
+    StopWatchRecorder
 )
 
 from sciencebeam_gym.utils.pages_zip import (
-  save_pages
+    save_pages
 )
 
 from sciencebeam_gym.preprocess.lxml_to_svg import (
-  iter_svg_pages_for_lxml
+    iter_svg_pages_for_lxml
 )
 
 from sciencebeam_gym.structured_document.svg import (
-  SvgStructuredDocument
+    SvgStructuredDocument
 )
 
 from sciencebeam_gym.preprocess.annotation.annotator import (
-  Annotator,
-  DEFAULT_ANNOTATORS
+    Annotator,
+    DEFAULT_ANNOTATORS
 )
 
 from sciencebeam_gym.preprocess.annotation.matching_annotator import (
-  MatchingAnnotator
+    MatchingAnnotator
 )
 
 from sciencebeam_gym.preprocess.annotation.target_annotation import (
-  xml_root_to_target_annotations
+    xml_root_to_target_annotations
 )
 
 from sciencebeam_gym.preprocess.visualize_svg_annotation import (
-  visualize_svg_annotations
+    visualize_svg_annotations
 )
 
 from sciencebeam_gym.preprocess.blockify_annotations import (
-  annotation_document_page_to_annotation_blocks,
-  merge_blocks,
-  expand_blocks,
-  annotated_blocks_to_image
+    annotation_document_page_to_annotation_blocks,
+    merge_blocks,
+    expand_blocks,
+    annotated_blocks_to_image
 )
 
 from sciencebeam_gym.pdf import (
-  PdfToLxmlWrapper,
-  PdfToPng
+    PdfToLxmlWrapper,
+    PdfToPng
 )
 
 
 def get_logger():
-  return logging.getLogger(__name__)
-
-def convert_pdf_bytes_to_lxml(pdf_content, path=None, page_range=None):
-  stop_watch_recorder = StopWatchRecorder()
+    return logging.getLogger(__name__)
 
-  args = '-blocks -noImageInline -noImage -fullFontName'.split()
-  if page_range:
-    args += ['-f', str(page_range[0]), '-l', str(page_range[1])]
 
-  stop_watch_recorder.start('convert to lxml')
-  lxml_content = PdfToLxmlWrapper().process_input(
-    pdf_content,
-    args
-  )
-  stop_watch_recorder.stop()
-
-  get_logger().info(
-    'converted to lxml: path=%s, pdf size=%s, lxml size=%s, timings=[%s]',
-    path, format(len(pdf_content), ','), format(len(lxml_content), ','),
-    stop_watch_recorder
-  )
+def convert_pdf_bytes_to_lxml(pdf_content, path=None, page_range=None):
+    stop_watch_recorder = StopWatchRecorder()
 
-  return lxml_content
+    args = '-blocks -noImageInline -noImage -fullFontName'.split()
+    if page_range:
+        args += ['-f', str(page_range[0]), '-l', str(page_range[1])]
 
-def convert_and_annotate_lxml_content(lxml_content, xml_content, xml_mapping, name=None):
-  stop_watch_recorder = StopWatchRecorder()
+    stop_watch_recorder.start('convert to lxml')
+    lxml_content = PdfToLxmlWrapper().process_input(
+        pdf_content,
+        args
+    )
+    stop_watch_recorder.stop()
 
-  stop_watch_recorder.start('parse lxml')
-  lxml_root = etree.fromstring(lxml_content)
+    get_logger().info(
+        'converted to lxml: path=%s, pdf size=%s, lxml size=%s, timings=[%s]',
+        path, format(len(pdf_content), ','), format(len(lxml_content), ','),
+        stop_watch_recorder
+    )
 
-  # use a more lenient way to parse xml as xml errors are not uncomment
-  stop_watch_recorder.start('parse xml')
-  xml_root = xml_from_string_with_recover(xml_content)
+    return lxml_content
 
-  stop_watch_recorder.start('extract target annotations')
-  target_annotations = xml_root_to_target_annotations(
-    xml_root,
-    xml_mapping
-  )
-  stop_watch_recorder.stop()
 
-  annotators = DEFAULT_ANNOTATORS + [MatchingAnnotator(
-    target_annotations,
-    use_tag_begin_prefix=True
-  )]
-  annotator = Annotator(annotators)
+def convert_and_annotate_lxml_content(lxml_content, xml_content, xml_mapping, name=None):
+    stop_watch_recorder = StopWatchRecorder()
 
-  stop_watch_recorder.start('convert to svg')
-  svg_roots = list(iter_svg_pages_for_lxml(lxml_root))
+    stop_watch_recorder.start('parse lxml')
+    lxml_root = etree.fromstring(lxml_content)
 
-  stop_watch_recorder.start('annotate svg')
-  annotator.annotate(SvgStructuredDocument(svg_roots))
+    # use a more lenient way to parse xml as xml errors are not uncommon
+    stop_watch_recorder.start('parse xml')
+    xml_root = xml_from_string_with_recover(xml_content)
 
-  stop_watch_recorder.start('add visualisation')
-  svg_roots = [
-    visualize_svg_annotations(svg_root)
-    for svg_root in svg_roots
-  ]
-  stop_watch_recorder.stop()
+    stop_watch_recorder.start('extract target annotations')
+    target_annotations = xml_root_to_target_annotations(
+        xml_root,
+        xml_mapping
+    )
+    stop_watch_recorder.stop()
+
+    annotators = DEFAULT_ANNOTATORS + [MatchingAnnotator(
+        target_annotations,
+        use_tag_begin_prefix=True
+    )]
+    annotator = Annotator(annotators)
+
+    stop_watch_recorder.start('convert to svg')
+    svg_roots = list(iter_svg_pages_for_lxml(lxml_root))
+
+    stop_watch_recorder.start('annotate svg')
+    annotator.annotate(SvgStructuredDocument(svg_roots))
+
+    stop_watch_recorder.start('add visualisation')
+    svg_roots = [
+        visualize_svg_annotations(svg_root)
+        for svg_root in svg_roots
+    ]
+    stop_watch_recorder.stop()
+
+    get_logger().info(
+        'processed: name=%s, lxml size=%s, xml size=%s, timings=[%s] (native align impl=%s)',
+        name, format(len(lxml_content), ','), format(len(xml_content), ','),
+        stop_watch_recorder, align_native_enabled
+    )
 
-  get_logger().info(
-    'processed: name=%s, lxml size=%s, xml size=%s, timings=[%s] (native align impl=%s)',
-    name, format(len(lxml_content), ','), format(len(xml_content), ','),
-    stop_watch_recorder, align_native_enabled
-  )
+    return svg_roots
 
-  return svg_roots
 
 def save_svg_roots(output_filename, svg_pages):
-  return save_pages(output_filename, '.svg', (
-    etree.tostring(svg_page)
-    for svg_page in svg_pages
-  ))
+    return save_pages(output_filename, '.svg', (
+        etree.tostring(svg_page)
+        for svg_page in svg_pages
+    ))
+
 
 def pdf_bytes_to_png_pages(pdf_bytes, dpi, image_size, page_range=None):
-  pdf_to_png = PdfToPng(dpi=dpi, image_size=image_size, page_range=page_range)
-  return (
-    fp.read()
-    for fp in pdf_to_png.iter_pdf_bytes_to_png_fp(pdf_bytes)
-  )
+    pdf_to_png = PdfToPng(dpi=dpi, image_size=image_size, page_range=page_range)
+    return (
+        fp.read()
+        for fp in pdf_to_png.iter_pdf_bytes_to_png_fp(pdf_bytes)
+    )
+
 
 def svg_page_to_blockified_png_bytes(svg_page, color_map, image_size=None):
-  structured_document = SvgStructuredDocument(svg_page)
-  blocks = expand_blocks(
-    merge_blocks(
-      annotation_document_page_to_annotation_blocks(
-        structured_document,
-        structured_document.get_pages()[0]
-      )
+    structured_document = SvgStructuredDocument(svg_page)
+    blocks = expand_blocks(
+        merge_blocks(
+            annotation_document_page_to_annotation_blocks(
+                structured_document,
+                structured_document.get_pages()[0]
+            )
+        )
     )
-  )
-  viewbox = svg_page.attrib.get('viewBox')
-  if not viewbox:
-    raise RuntimeError(
-      'viewbox missing on svg, available attributes: %s' % svg_page.attrib.keys()
+    viewbox = svg_page.attrib.get('viewBox')
+    if not viewbox:
+        raise RuntimeError(
+            'viewbox missing on svg, available attributes: %s' % svg_page.attrib.keys()
+        )
+    _, _, width, height = viewbox.split()
+    image = annotated_blocks_to_image(
+        blocks, color_map,
+        width=float(width), height=float(height), background='white',
+        scale_to_size=image_size
     )
-  _, _, width, height = viewbox.split()
-  image = annotated_blocks_to_image(
-    blocks, color_map,
-    width=float(width), height=float(height), background='white',
-    scale_to_size=image_size
-  )
-  out = BytesIO()
-  image.save(out, 'png')
-  return out.getvalue()
+    out = BytesIO()
+    image.save(out, 'png')
+    return out.getvalue()
+
 
 def filter_list_props_by_indices(d, indices, list_props):
-  return {
-    k: (
-      [x for i, x in enumerate(v) if i in indices]
-      if k in list_props
-      else v
-    )
-    for k, v in iteritems(d)
-  }
+    return {
+        k: (
+            [x for i, x in enumerate(v) if i in indices]
+            if k in list_props
+            else v
+        )
+        for k, v in iteritems(d)
+    }
+
 
 def get_page_indices_with_min_annotation_percentage(
-  annotation_evaluation, min_annotation_percentage):
+        annotation_evaluation, min_annotation_percentage):
+
+    return [
+        i
+        for i, page_evaluation in enumerate(annotation_evaluation)
+        if page_evaluation['percentage'].get(None) <= (1 - min_annotation_percentage)
+    ]
 
-  return [
-    i
-    for i, page_evaluation in enumerate(annotation_evaluation)
-    if page_evaluation['percentage'].get(None) <= (1 - min_annotation_percentage)
-  ]
 
 def parse_page_range(s):
-  s = s.strip()
-  if not s:
-    return None
-  a = tuple([int(x) for x in s.split('-')])
-  if len(a) == 1:
-    return (a[0], a[0])
-  elif len(a) == 2:
-    return a
-  else:
-    raise TypeError('invalid page range: %s' % s)
+    s = s.strip()
+    if not s:
+        return None
+    a = tuple([int(x) for x in s.split('-')])
+    if len(a) == 1:
+        return (a[0], a[0])
+    elif len(a) == 2:
+        return a
+    else:
+        raise TypeError('invalid page range: %s' % s)
diff --git a/sciencebeam_gym/preprocess/preprocessing_utils_test.py b/sciencebeam_gym/preprocess/preprocessing_utils_test.py
index a287dd3..c6ac391 100644
--- a/sciencebeam_gym/preprocess/preprocessing_utils_test.py
+++ b/sciencebeam_gym/preprocess/preprocessing_utils_test.py
@@ -3,68 +3,72 @@ from mock import patch, MagicMock, DEFAULT
 from lxml import etree
 
 from sciencebeam_gym.structured_document.svg import (
-  SVG_DOC
+    SVG_DOC
 )
 
 from sciencebeam_gym.preprocess.preprocessing_utils import (
-  svg_page_to_blockified_png_bytes,
-  convert_pdf_bytes_to_lxml,
-  parse_page_range,
+    svg_page_to_blockified_png_bytes,
+    convert_pdf_bytes_to_lxml,
+    parse_page_range,
 )
 
 PROCESSING_UTILS = 'sciencebeam_gym.preprocess.preprocessing_utils'
 
 PDF_CONTENT_1 = b'pdf content 1'
 
+
 class TestSvgPageToBlockifiedPngBytes(object):
-  def test_should_parse_viewbox_and_pass_width_and_height_to_annotated_blocks_to_image(self):
-    with patch.multiple(PROCESSING_UTILS, annotated_blocks_to_image=DEFAULT) as mocks:
-      svg_page = etree.Element(SVG_DOC, attrib={
-        'viewBox': '0 0 100.1 200.9'
-      })
-      color_map = {}
-      image_size = (100, 200)
-      svg_page_to_blockified_png_bytes(svg_page, color_map, image_size)
-      call_args = mocks['annotated_blocks_to_image'].call_args
-      kwargs = call_args[1]
-      assert (kwargs.get('width'), kwargs.get('height')) == (100.1, 200.9)
+    def test_should_parse_viewbox_and_pass_width_and_height_to_annotated_blocks_to_image(self):
+        with patch.multiple(PROCESSING_UTILS, annotated_blocks_to_image=DEFAULT) as mocks:
+            svg_page = etree.Element(SVG_DOC, attrib={
+                'viewBox': '0 0 100.1 200.9'
+            })
+            color_map = {}
+            image_size = (100, 200)
+            svg_page_to_blockified_png_bytes(svg_page, color_map, image_size)
+            call_args = mocks['annotated_blocks_to_image'].call_args
+            kwargs = call_args[1]
+            assert (kwargs.get('width'), kwargs.get('height')) == (100.1, 200.9)
+
 
 DEFAULT_PDF_TO_LXML_ARGS = ['-blocks', '-noImageInline', '-noImage', '-fullFontName']
 
 LXML_CONTENT_1 = b'lxml content 1'
 
+
 class TestConvertPdfBytesToLxml(object):
-  def test_should_pass_pdf_content_and_default_args_to_process_input(self):
-    mock = MagicMock()
-    with patch.multiple(PROCESSING_UTILS, PdfToLxmlWrapper=mock):
-      mock.return_value.process_input.return_value = LXML_CONTENT_1
-      lxml_content = convert_pdf_bytes_to_lxml(PDF_CONTENT_1)
-      mock.return_value.process_input.assert_called_with(
-        PDF_CONTENT_1,
-        DEFAULT_PDF_TO_LXML_ARGS
-      )
-      assert lxml_content == LXML_CONTENT_1
-
-  def test_should_pass_include_page_range_in_args(self):
-    mock = MagicMock()
-    with patch.multiple(PROCESSING_UTILS, PdfToLxmlWrapper=mock):
-      mock.return_value.process_input.return_value = LXML_CONTENT_1
-      lxml_content = convert_pdf_bytes_to_lxml(PDF_CONTENT_1, page_range=(1, 3))
-      mock.return_value.process_input.assert_called_with(
-        PDF_CONTENT_1,
-        DEFAULT_PDF_TO_LXML_ARGS + ['-f', '1', '-l', '3']
-      )
-      assert lxml_content == LXML_CONTENT_1
+    def test_should_pass_pdf_content_and_default_args_to_process_input(self):
+        mock = MagicMock()
+        with patch.multiple(PROCESSING_UTILS, PdfToLxmlWrapper=mock):
+            mock.return_value.process_input.return_value = LXML_CONTENT_1
+            lxml_content = convert_pdf_bytes_to_lxml(PDF_CONTENT_1)
+            mock.return_value.process_input.assert_called_with(
+                PDF_CONTENT_1,
+                DEFAULT_PDF_TO_LXML_ARGS
+            )
+            assert lxml_content == LXML_CONTENT_1
+
+    def test_should_pass_include_page_range_in_args(self):
+        mock = MagicMock()
+        with patch.multiple(PROCESSING_UTILS, PdfToLxmlWrapper=mock):
+            mock.return_value.process_input.return_value = LXML_CONTENT_1
+            lxml_content = convert_pdf_bytes_to_lxml(PDF_CONTENT_1, page_range=(1, 3))
+            mock.return_value.process_input.assert_called_with(
+                PDF_CONTENT_1,
+                DEFAULT_PDF_TO_LXML_ARGS + ['-f', '1', '-l', '3']
+            )
+            assert lxml_content == LXML_CONTENT_1
+
 
 class TestPageRange(object):
-  def test_should_parse_single_page_number_as_range(self):
-    assert parse_page_range('1') == (1, 1)
+    def test_should_parse_single_page_number_as_range(self):
+        assert parse_page_range('1') == (1, 1)
 
-  def test_should_parse_range_with_hyphen(self):
-    assert parse_page_range('1-3') == (1, 3)
+    def test_should_parse_range_with_hyphen(self):
+        assert parse_page_range('1-3') == (1, 3)
 
-  def test_should_parse_range_with_spaces(self):
-    assert parse_page_range(' 1 - 3 ') == (1, 3)
+    def test_should_parse_range_with_spaces(self):
+        assert parse_page_range(' 1 - 3 ') == (1, 3)
 
-  def test_should_return_none_for_empty_range(self):
-    assert parse_page_range('') is None
+    def test_should_return_none_for_empty_range(self):
+        assert parse_page_range('') is None
diff --git a/sciencebeam_gym/preprocess/visualize_svg_annotation.py b/sciencebeam_gym/preprocess/visualize_svg_annotation.py
index fcd18ec..90439ee 100644
--- a/sciencebeam_gym/preprocess/visualize_svg_annotation.py
+++ b/sciencebeam_gym/preprocess/visualize_svg_annotation.py
@@ -2,87 +2,97 @@ import hashlib
 
 from lxml import etree
 
+from sciencebeam_utils.utils.collection import flatten
+
 from sciencebeam_gym.structured_document.svg import (
-  SVG_TAG_ATTRIB
+    SVG_TAG_ATTRIB
 )
 
 DEFAULT_COLORS = [
-  'maroon', 'red', 'purple', '#c0c', 'green', '#0c0',
-  'olive', '#cc0', 'navy', 'blue', 'teal', '#0cc'
+    'maroon', 'red', 'purple', '#c0c', 'green', '#0c0',
+    'olive', '#cc0', 'navy', 'blue', 'teal', '#0cc'
 ]
 # colors replaced due readability issues:
 # fuchsia, lime, yellow, aqua
 
-flatten = lambda l: [item for sublist in l for item in sublist]
-
 
 def color_for_tag(tag, colors=None):
-  if colors is None:
-    colors = DEFAULT_COLORS
-  h = int(hashlib.md5(tag.encode('utf8')).hexdigest(), 16)
-  return colors[h % len(colors)]
+    if colors is None:
+        colors = DEFAULT_COLORS
+    h = int(hashlib.md5(tag.encode('utf8')).hexdigest(), 16)
+    return colors[h % len(colors)]
+
 
 def style_props_for_tag(tag):
-  return {
-    'fill': color_for_tag(tag)
-  }
+    return {
+        'fill': color_for_tag(tag)
+    }
+
 
 def style_props_for_tags(tags):
-  return {
-    tag: style_props_for_tag(tag)
-    for tag in tags
-  }
+    return {
+        tag: style_props_for_tag(tag)
+        for tag in tags
+    }
+
 
 def render_style_props(style_props):
-  return '\n'.join([
-    '  {}: {}'.format(k, v)
-    for k, v in style_props.items()
-  ])
+    return '\n'.join([
+        '  {}: {}'.format(k, v)
+        for k, v in style_props.items()
+    ])
+
 
 def style_block_for_tag(tag, style_props):
-  return 'text[{tag_attrib}~="{tag}"] {{\n{content}\n}}'.format(
-    tag_attrib=SVG_TAG_ATTRIB,
-    tag=tag,
-    content=render_style_props(style_props)
-  )
+    return 'text[{tag_attrib}~="{tag}"] {{\n{content}\n}}'.format(
+        tag_attrib=SVG_TAG_ATTRIB,
+        tag=tag,
+        content=render_style_props(style_props)
+    )
+
 
 def style_block_for_tags(tags):
-  style_props_map = style_props_for_tags(tags)
-  return '\n\n'.join([
-    style_block_for_tag(tag, style_props_map[tag])
-    for tag in tags
-  ])
+    style_props_map = style_props_for_tags(tags)
+    return '\n\n'.join([
+        style_block_for_tag(tag, style_props_map[tag])
+        for tag in tags
+    ])
+
 
 def tags_for_node(node):
-  svga_tags = node.attrib.get(SVG_TAG_ATTRIB, '').strip()
-  if len(svga_tags) == 0:
-    return []
-  return svga_tags.split(' ')
+    svga_tags = node.attrib.get(SVG_TAG_ATTRIB, '').strip()
+    if len(svga_tags) == 0:
+        return []
+    return svga_tags.split(' ')
+
 
 def tags_for_nodes(nodes):
-  return sorted(set(flatten([
-    tags_for_node(node)
-    for node in nodes
-  ])))
+    return sorted(set(flatten([
+        tags_for_node(node)
+        for node in nodes
+    ])))
+
 
 def nodes_with_tags(svg_root):
-  for node in  svg_root.findall('*[@{}]'.format(SVG_TAG_ATTRIB)):
-    yield node
-    for nested_node in nodes_with_tags(node):
-      yield nested_node
+    for node in svg_root.findall('*[@{}]'.format(SVG_TAG_ATTRIB)):
+        yield node
+        for nested_node in nodes_with_tags(node):
+            yield nested_node
+
 
 def add_title_to_nodes(nodes):
-  for node in nodes:
-    tags = tags_for_node(node)
-    if len(tags) > 0:
-      title = etree.Element('title')
-      title.text = ' '.join(tags)
-      node.append(title)
+    for node in nodes:
+        tags = tags_for_node(node)
+        if len(tags) > 0:
+            title = etree.Element('title')
+            title.text = ' '.join(tags)
+            node.append(title)
+
 
 def visualize_svg_annotations(svg_root):
-  nodes = list(nodes_with_tags(svg_root))
-  style_block = etree.Element('style')
-  style_block.text = style_block_for_tags(tags_for_nodes(nodes))
-  svg_root.insert(0, style_block)
-  add_title_to_nodes(nodes)
-  return svg_root
+    nodes = list(nodes_with_tags(svg_root))
+    style_block = etree.Element('style')
+    style_block.text = style_block_for_tags(tags_for_nodes(nodes))
+    svg_root.insert(0, style_block)
+    add_title_to_nodes(nodes)
+    return svg_root
diff --git a/sciencebeam_gym/preprocess/visualize_svg_annotation_test.py b/sciencebeam_gym/preprocess/visualize_svg_annotation_test.py
index 67b573b..3dfbf1f 100644
--- a/sciencebeam_gym/preprocess/visualize_svg_annotation_test.py
+++ b/sciencebeam_gym/preprocess/visualize_svg_annotation_test.py
@@ -3,19 +3,19 @@ import logging
 from lxml import etree
 
 from sciencebeam_gym.structured_document.svg import (
-  SVG_NSMAP,
-  SVG_DOC,
-  SVG_TEXT,
-  SVG_TAG_ATTRIB
+    SVG_NSMAP,
+    SVG_DOC,
+    SVG_TEXT,
+    SVG_TAG_ATTRIB
 )
 
 from sciencebeam_gym.preprocess.visualize_svg_annotation import (
-  visualize_svg_annotations,
-  style_block_for_tags,
-  style_block_for_tag,
-  style_props_for_tags,
-  render_style_props,
-  color_for_tag
+    visualize_svg_annotations,
+    style_block_for_tags,
+    style_block_for_tag,
+    style_props_for_tags,
+    render_style_props,
+    color_for_tag
 )
 
 TAG1 = 'tag1'
@@ -28,92 +28,100 @@ logger = logging.getLogger('test')
 
 
 def _create_xml_node(tag, text=None, attrib=None):
-  node = etree.Element(tag)
-  if text is not None:
-    node.text = text
-  if attrib is not None:
-    for k, v in attrib.items():
-      node.attrib[k] = str(v)
-  return node
+    node = etree.Element(tag)
+    if text is not None:
+        node.text = text
+    if attrib is not None:
+        for k, v in attrib.items():
+            node.attrib[k] = str(v)
+    return node
+
 
 def _create_tagged_text(svga_tags):
-  return _create_xml_node(SVG_TEXT, attrib={
-    SVG_TAG_ATTRIB: svga_tags
-  })
+    return _create_xml_node(SVG_TEXT, attrib={
+        SVG_TAG_ATTRIB: svga_tags
+    })
 
 
 def test_add_style_block_for_single_tag_on_multiple_nodes():
-  svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP)
+    svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP)
+
+    svg_root.append(_create_tagged_text(TAG1))
+    svg_root.append(_create_tagged_text(TAG1))
 
-  svg_root.append(_create_tagged_text(TAG1))
-  svg_root.append(_create_tagged_text(TAG1))
+    result_svg = visualize_svg_annotations(svg_root)
+    style_block = result_svg.find('style')
 
-  result_svg = visualize_svg_annotations(svg_root)
-  style_block = result_svg.find('style')
+    assert style_block is not None
+    assert style_block.text == style_block_for_tags([TAG1])
 
-  assert style_block is not None
-  assert style_block.text == style_block_for_tags([TAG1])
 
 def test_add_style_block_for_multiple_tags_on_separate_nodes():
-  svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP)
+    svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP)
 
-  svg_root.append(_create_tagged_text(TAG1))
-  svg_root.append(_create_tagged_text(TAG2))
+    svg_root.append(_create_tagged_text(TAG1))
+    svg_root.append(_create_tagged_text(TAG2))
 
-  result_svg = visualize_svg_annotations(svg_root)
-  style_block = result_svg.find('style')
+    result_svg = visualize_svg_annotations(svg_root)
+    style_block = result_svg.find('style')
+
+    assert style_block is not None
+    assert style_block.text == style_block_for_tags([TAG1, TAG2])
 
-  assert style_block is not None
-  assert style_block.text == style_block_for_tags([TAG1, TAG2])
 
 def test_add_style_block_for_multiple_tags_on_same_node():
-  svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP)
+    svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP)
+
+    svg_root.append(_create_tagged_text(' '.join([TAG1, TAG2])))
 
-  svg_root.append(_create_tagged_text(' '.join([TAG1, TAG2])))
+    result_svg = visualize_svg_annotations(svg_root)
+    style_block = result_svg.find('style')
 
-  result_svg = visualize_svg_annotations(svg_root)
-  style_block = result_svg.find('style')
+    assert style_block is not None
+    assert style_block.text == style_block_for_tags([TAG1, TAG2])
 
-  assert style_block is not None
-  assert style_block.text == style_block_for_tags([TAG1, TAG2])
 
 def test_add_title_with_tags():
-  svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP)
+    svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP)
 
-  svg_root.append(_create_tagged_text(TAG1))
+    svg_root.append(_create_tagged_text(TAG1))
 
-  result_svg = visualize_svg_annotations(svg_root)
-  text_node = result_svg.find(SVG_TEXT + '/title')
+    result_svg = visualize_svg_annotations(svg_root)
+    text_node = result_svg.find(SVG_TEXT + '/title')
+
+    assert text_node is not None
+    assert text_node.text == TAG1
 
-  assert text_node is not None
-  assert text_node.text == TAG1
 
 def test_style_block_for_single_tag():
-  style_block_text = style_block_for_tags([TAG1])
+    style_block_text = style_block_for_tags([TAG1])
+
+    assert style_block_text == (
+        style_block_for_tag(TAG1, style_props_for_tags([TAG1])[TAG1])
+    )
 
-  assert style_block_text == (
-    style_block_for_tag(TAG1, style_props_for_tags([TAG1])[TAG1])
-  )
 
 def test_style_block_for_multiple_tags():
-  style_block_text = style_block_for_tags([TAG1, TAG2])
-  style_props_map = style_props_for_tags([TAG1, TAG2])
+    style_block_text = style_block_for_tags([TAG1, TAG2])
+    style_props_map = style_props_for_tags([TAG1, TAG2])
+
+    assert style_block_text == (
+        '\n\n'.join([
+            style_block_for_tag(TAG1, style_props_map[TAG1]),
+            style_block_for_tag(TAG2, style_props_map[TAG2])
+        ])
+    )
 
-  assert style_block_text == (
-    '\n\n'.join([
-      style_block_for_tag(TAG1, style_props_map[TAG1]),
-      style_block_for_tag(TAG2, style_props_map[TAG2])
-    ])
-  )
 
 def test_style_block_for_tag():
-  style_props = style_props_for_tags([TAG1])[TAG1]
-  style_block_text = style_block_for_tag(TAG1, style_props)
+    style_props = style_props_for_tags([TAG1])[TAG1]
+    style_block_text = style_block_for_tag(TAG1, style_props)
+
+    assert (
+        style_block_text ==
+        'text[class~="' + TAG1 + '"] {\n' + render_style_props(style_props) + '\n}'
+    )
 
-  assert (
-    style_block_text ==
-    'text[class~="' + TAG1 + '"] {\n' + render_style_props(style_props) + '\n}'
-  )
 
 def test_color_for_tag_should_be_different_for_different_tags():
-  assert color_for_tag(TAG1) != color_for_tag(TAG2)
+    assert color_for_tag(TAG1) != color_for_tag(TAG2)
diff --git a/sciencebeam_gym/structured_document/__init__.py b/sciencebeam_gym/structured_document/__init__.py
index 6f8ccf1..f7aba87 100644
--- a/sciencebeam_gym/structured_document/__init__.py
+++ b/sciencebeam_gym/structured_document/__init__.py
@@ -11,230 +11,244 @@ LEVEL_ATTRIB_SEP = '_'
 
 SIMPLE_TAG_ATTRIB_NAME = 'tag'
 
+
 def merge_token_tag(
-  merged_structured_document, merged_token,
-  other_structured_document, other_token,
-  source_scope=None, target_scope=None):
-
-  tag = other_structured_document.get_tag(other_token, scope=source_scope)
-  if tag:
-    merged_structured_document.set_tag(
-      merged_token,
-      tag,
-      scope=target_scope
-    )
+        merged_structured_document, merged_token,
+        other_structured_document, other_token,
+        source_scope=None, target_scope=None):
+
+    tag = other_structured_document.get_tag(other_token, scope=source_scope)
+    if tag:
+        merged_structured_document.set_tag(
+            merged_token,
+            tag,
+            scope=target_scope
+        )
+
 
 def get_scoped_attrib_name(name, scope=None, level=None):
-  if level:
-    name = 'level%s%s%s' % (level, LEVEL_ATTRIB_SEP, name)
-  return '%s%s%s' % (scope, SCOPE_ATTRIB_SEP, name) if scope else name
+    if level:
+        name = 'level%s%s%s' % (level, LEVEL_ATTRIB_SEP, name)
+    return '%s%s%s' % (scope, SCOPE_ATTRIB_SEP, name) if scope else name
+
 
 def get_attrib_by_scope(attrib, name):
-  suffix = '%s%s' % (SCOPE_ATTRIB_SEP, name)
-  return {
-    (None if k == name else k[:-len(suffix)]): v
-    for k, v in attrib.items()
-    if k.endswith(suffix) or k == name
-  }
+    suffix = '%s%s' % (SCOPE_ATTRIB_SEP, name)
+    return {
+        (None if k == name else k[:-len(suffix)]): v
+        for k, v in attrib.items()
+        if k.endswith(suffix) or k == name
+    }
+
 
 def get_simple_tag_attrib_name(scope, level=None):
-  return get_scoped_attrib_name(SIMPLE_TAG_ATTRIB_NAME, scope, level)
+    return get_scoped_attrib_name(SIMPLE_TAG_ATTRIB_NAME, scope, level)
+
 
 def split_tag_prefix(tag):
-  if tag:
-    if tag.startswith(B_TAG_PREFIX):
-      return B_TAG_PREFIX, tag[len(B_TAG_PREFIX):]
-    if tag.startswith(I_TAG_PREFIX):
-      return I_TAG_PREFIX, tag[len(I_TAG_PREFIX):]
-  return None, tag
+    if tag:
+        if tag.startswith(B_TAG_PREFIX):
+            return B_TAG_PREFIX, tag[len(B_TAG_PREFIX):]
+        if tag.startswith(I_TAG_PREFIX):
+            return I_TAG_PREFIX, tag[len(I_TAG_PREFIX):]
+    return None, tag
+
 
 def strip_tag_prefix(tag):
-  return split_tag_prefix(tag)[1]
+    return split_tag_prefix(tag)[1]
+
 
 def add_tag_prefix(tag, prefix):
-  return prefix + tag if prefix and tag else tag
+    return prefix + tag if prefix and tag else tag
+
 
 class AbstractStructuredDocument(object, with_metaclass(ABCMeta)):
-  def clone(self):
-    return deepcopy(self)
-
-  def iter_all_tokens(self):
-    for page in self.get_pages():
-      for line in self.get_lines_of_page(page):
-        for token in self.get_tokens_of_line(line):
-          yield token
-
-  def merge_with(
-    self,
-    other_structured_document,
-    merge_fn):
-    """
-    Merges this structured document with another structured document using the merge fn.
-
-    Note: this document will be changed (operate on a clone if that is undesired)
-    """
-
-    for merged_token, other_token in zip(
-      self.iter_all_tokens(),
-      other_structured_document.iter_all_tokens()
-      ):
-      assert (
-        self.get_text(merged_token) ==
-        other_structured_document.get_text(other_token)
-      )
-      merge_fn(
-        self, merged_token,
-        other_structured_document, other_token
-      )
-
-  def get_tag_prefix_and_value(self, parent, scope=None, level=None):
-    return split_tag_prefix(self.get_tag(parent, scope=scope, level=level))
-
-  def get_tag_value(self, parent, scope=None):
-    return self.get_tag_prefix_and_value(parent, scope=scope)[1]
-
-  def set_tag_with_prefix(self, parent, tag, scope=None, prefix=None):
-    self.set_tag(parent, add_tag_prefix(tag, prefix), scope=scope)
-
-  def get_sub_tag(self, parent, scope=None):
-    return self.get_tag(parent, scope=scope, level=2)
-
-  def set_sub_tag(self, parent, tag, scope=None):
-    self.set_tag(parent, tag, scope=scope, level=2)
-
-  def set_sub_tag_with_prefix(self, parent, tag, scope=None, prefix=None):
-    self.set_sub_tag(parent, add_tag_prefix(tag, prefix), scope=scope)
-
-  @abstractmethod
-  def get_pages(self):
-    pass
-
-  @abstractmethod
-  def get_lines_of_page(self, page):
-    pass
-
-  @abstractmethod
-  def get_tokens_of_line(self, line):
-    pass
-
-  @abstractmethod
-  def get_x(self, parent):
-    pass
-
-  @abstractmethod
-  def get_text(self, parent):
-    pass
-
-  @abstractmethod
-  def get_tag(self, parent, scope=None, level=None):
-    pass
-
-  @abstractmethod
-  def set_tag(self, parent, tag, scope=None, level=None):
-    pass
-
-  @abstractmethod
-  def get_tag_by_scope(self, parent):
-    pass
-
-  @abstractmethod
-  def get_bounding_box(self, parent):
-    pass
-
-  @abstractmethod
-  def set_bounding_box(self, parent, bounding_box):
-    pass
+    def clone(self):
+        return deepcopy(self)
+
+    def iter_all_tokens(self):
+        for page in self.get_pages():
+            for line in self.get_lines_of_page(page):
+                for token in self.get_tokens_of_line(line):
+                    yield token
+
+    def merge_with(
+            self,
+            other_structured_document,
+            merge_fn):
+        """
+        Merges this structured document with another structured document using the merge fn.
+
+        Note: this document will be changed (operate on a clone if that is undesired)
+        """
+
+        for merged_token, other_token in zip(
+            self.iter_all_tokens(),
+            other_structured_document.iter_all_tokens()
+        ):
+            assert (
+                self.get_text(merged_token) ==
+                other_structured_document.get_text(other_token)
+            )
+            merge_fn(
+                self, merged_token,
+                other_structured_document, other_token
+            )
+
+    def get_tag_prefix_and_value(self, parent, scope=None, level=None):
+        return split_tag_prefix(self.get_tag(parent, scope=scope, level=level))
+
+    def get_tag_value(self, parent, scope=None):
+        return self.get_tag_prefix_and_value(parent, scope=scope)[1]
+
+    def set_tag_with_prefix(self, parent, tag, scope=None, prefix=None):
+        self.set_tag(parent, add_tag_prefix(tag, prefix), scope=scope)
+
+    def get_sub_tag(self, parent, scope=None):
+        return self.get_tag(parent, scope=scope, level=2)
+
+    def set_sub_tag(self, parent, tag, scope=None):
+        self.set_tag(parent, tag, scope=scope, level=2)
+
+    def set_sub_tag_with_prefix(self, parent, tag, scope=None, prefix=None):
+        self.set_sub_tag(parent, add_tag_prefix(tag, prefix), scope=scope)
+
+    @abstractmethod
+    def get_pages(self):
+        pass
+
+    @abstractmethod
+    def get_lines_of_page(self, page):
+        pass
+
+    @abstractmethod
+    def get_tokens_of_line(self, line):
+        pass
+
+    @abstractmethod
+    def get_x(self, parent):
+        pass
+
+    @abstractmethod
+    def get_text(self, parent):
+        pass
+
+    @abstractmethod
+    def get_tag(self, parent, scope=None, level=None):
+        pass
+
+    @abstractmethod
+    def set_tag(self, parent, tag, scope=None, level=None):
+        pass
+
+    @abstractmethod
+    def get_tag_by_scope(self, parent):
+        pass
+
+    @abstractmethod
+    def get_bounding_box(self, parent):
+        pass
+
+    @abstractmethod
+    def set_bounding_box(self, parent, bounding_box):
+        pass
+
 
 class SimpleElement(object):
-  def __init__(self, bounding_box=None):
-    self._bounding_box = bounding_box
+    def __init__(self, bounding_box=None):
+        self._bounding_box = bounding_box
 
-  def get_bounding_box(self):
-    return self._bounding_box
+    def get_bounding_box(self):
+        return self._bounding_box
+
+    def set_bounding_box(self, bounding_box):
+        self._bounding_box = bounding_box
 
-  def set_bounding_box(self, bounding_box):
-    self._bounding_box = bounding_box
 
 class SimpleToken(SimpleElement):
-  def __init__(
-    self, text, attrib=None, tag=None, tag_scope=None, tag_prefix=None, **kwargs):
-    super(SimpleToken, self).__init__(**kwargs)
-    self.text = text
-    if attrib is None:
-      attrib = {}
-    self.attrib = attrib
-    if tag is not None:
-      self.set_tag(tag, scope=tag_scope, prefix=tag_prefix)
+    def __init__(
+            self, text, attrib=None, tag=None, tag_scope=None, tag_prefix=None, **kwargs):
+        super(SimpleToken, self).__init__(**kwargs)
+        self.text = text
+        if attrib is None:
+            attrib = {}
+        self.attrib = attrib
+        if tag is not None:
+            self.set_tag(tag, scope=tag_scope, prefix=tag_prefix)
+
+    def get_x(self):
+        return self.attrib.get('x')
 
-  def get_x(self):
-    return self.attrib.get('x')
+    def get_y(self):
+        return self.attrib.get('y')
 
-  def get_y(self):
-    return self.attrib.get('y')
+    def get_tag(self, scope=None, level=None):
+        return self.attrib.get(get_simple_tag_attrib_name(scope=scope, level=level))
 
-  def get_tag(self, scope=None, level=None):
-    return self.attrib.get(get_simple_tag_attrib_name(scope=scope, level=level))
+    def set_tag(self, tag, scope=None, level=None, prefix=None):
+        self.attrib[get_simple_tag_attrib_name(
+            scope=scope, level=level)] = add_tag_prefix(tag, prefix)
 
-  def set_tag(self, tag, scope=None, level=None, prefix=None):
-    self.attrib[get_simple_tag_attrib_name(scope=scope, level=level)] = add_tag_prefix(tag, prefix)
+    def get_tag_by_scope(self):
+        return get_attrib_by_scope(self.attrib, SIMPLE_TAG_ATTRIB_NAME)
 
-  def get_tag_by_scope(self):
-    return get_attrib_by_scope(self.attrib, SIMPLE_TAG_ATTRIB_NAME)
+    def get_text(self):
+        return self.text
 
-  def get_text(self):
-    return self.text
+    def __repr__(self):
+        return '%s(%s)' % (type(self).__name__, self.text)
 
-  def __repr__(self):
-    return '%s(%s)' % (type(self).__name__, self.text)
 
 class SimpleLine(SimpleElement):
-  def __init__(self, tokens):
-    super(SimpleLine, self).__init__()
-    self.tokens = tokens
+    def __init__(self, tokens):
+        super(SimpleLine, self).__init__()
+        self.tokens = tokens
+
 
 class SimplePage(SimpleElement):
-  def __init__(self, lines, **kwargs):
-    super(SimplePage, self).__init__(**kwargs)
-    self.lines = lines
+    def __init__(self, lines, **kwargs):
+        super(SimplePage, self).__init__(**kwargs)
+        self.lines = lines
+
 
 class SimpleStructuredDocument(AbstractStructuredDocument):
-  def __init__(self, page_or_pages=None, lines=None):
-    if lines is not None:
-      pages = [SimplePage(lines)]
-    elif page_or_pages is None:
-      pages = []
-    elif isinstance(page_or_pages, list):
-      pages = page_or_pages
-    else:
-      pages = [page_or_pages]
-    self._pages = pages
+    def __init__(self, page_or_pages=None, lines=None):
+        if lines is not None:
+            pages = [SimplePage(lines)]
+        elif page_or_pages is None:
+            pages = []
+        elif isinstance(page_or_pages, list):
+            pages = page_or_pages
+        else:
+            pages = [page_or_pages]
+        self._pages = pages
 
-  def get_pages(self):
-    return self._pages
+    def get_pages(self):
+        return self._pages
 
-  def get_lines_of_page(self, page):
-    return page.lines
+    def get_lines_of_page(self, page):
+        return page.lines
 
-  def get_tokens_of_line(self, line):
-    return line.tokens
+    def get_tokens_of_line(self, line):
+        return line.tokens
 
-  def get_x(self, parent):
-    return parent.get_x()
+    def get_x(self, parent):
+        return parent.get_x()
 
-  def get_text(self, parent):
-    return parent.get_text()
+    def get_text(self, parent):
+        return parent.get_text()
 
-  def get_tag(self, parent, scope=None, level=None):
-    return parent.get_tag(scope=scope, level=level)
+    def get_tag(self, parent, scope=None, level=None):
+        return parent.get_tag(scope=scope, level=level)
 
-  def set_tag(self, parent, tag, scope=None, level=None):
-    return parent.set_tag(tag, scope=scope, level=level)
+    def set_tag(self, parent, tag, scope=None, level=None):
+        return parent.set_tag(tag, scope=scope, level=level)
 
-  def get_tag_by_scope(self, parent):
-    return parent.get_tag_by_scope()
+    def get_tag_by_scope(self, parent):
+        return parent.get_tag_by_scope()
 
-  def get_bounding_box(self, parent):
-    return parent.get_bounding_box()
+    def get_bounding_box(self, parent):
+        return parent.get_bounding_box()
 
-  def set_bounding_box(self, parent, bounding_box):
-    parent.set_bounding_box(bounding_box)
+    def set_bounding_box(self, parent, bounding_box):
+        parent.set_bounding_box(bounding_box)
diff --git a/sciencebeam_gym/structured_document/__init___test.py b/sciencebeam_gym/structured_document/__init___test.py
index 4cb537c..3ecb681 100644
--- a/sciencebeam_gym/structured_document/__init___test.py
+++ b/sciencebeam_gym/structured_document/__init___test.py
@@ -3,10 +3,10 @@ from functools import partial
 import pytest
 
 from sciencebeam_gym.structured_document import (
-  SimpleStructuredDocument,
-  SimpleLine,
-  SimpleToken,
-  merge_token_tag
+    SimpleStructuredDocument,
+    SimpleLine,
+    SimpleToken,
+    merge_token_tag
 )
 
 TEXT_1 = 'text 1'
@@ -17,109 +17,111 @@ TAG_2 = 'tag2'
 
 SCOPE_1 = 'scope1'
 
-class TestAbstractStructuredDocumentMergeWith(object):
-  def test_should_merge_single_token_and_add_prefix(self):
-    merged_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
-      SimpleToken(TEXT_1, tag=TAG_1)
-    ])])
-    other_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
-      SimpleToken(TEXT_1, tag=TAG_2)
-    ])])
-    merged_structured_document.merge_with(
-      other_structured_document,
-      partial(
-        merge_token_tag,
-        target_scope=SCOPE_1
-      )
-    )
-    merged_tokens = list(merged_structured_document.iter_all_tokens())
-    assert (
-      [merged_structured_document.get_text(t) for t in merged_tokens] ==
-      [TEXT_1]
-    )
-    assert (
-      [merged_structured_document.get_tag(t) for t in merged_tokens] ==
-      [TAG_1]
-    )
-    assert (
-      [merged_structured_document.get_tag(t, scope=SCOPE_1) for t in merged_tokens] ==
-      [TAG_2]
-    )
 
-  def test_should_not_fail_with_absent_tags(self):
-    merged_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
-      SimpleToken(TEXT_1, tag=TAG_1)
-    ])])
-    other_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
-      SimpleToken(TEXT_1)
-    ])])
-    merged_structured_document.merge_with(
-      other_structured_document,
-      partial(
-        merge_token_tag,
-        target_scope=SCOPE_1
-      )
-    )
-    merged_tokens = list(merged_structured_document.iter_all_tokens())
-    assert (
-      [merged_structured_document.get_tag(t, scope=SCOPE_1) for t in merged_tokens] ==
-      [None]
-    )
+class TestAbstractStructuredDocumentMergeWith(object):
+    def test_should_merge_single_token_and_add_prefix(self):
+        merged_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
+            SimpleToken(TEXT_1, tag=TAG_1)
+        ])])
+        other_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
+            SimpleToken(TEXT_1, tag=TAG_2)
+        ])])
+        merged_structured_document.merge_with(
+            other_structured_document,
+            partial(
+                merge_token_tag,
+                target_scope=SCOPE_1
+            )
+        )
+        merged_tokens = list(merged_structured_document.iter_all_tokens())
+        assert (
+            [merged_structured_document.get_text(t) for t in merged_tokens] ==
+            [TEXT_1]
+        )
+        assert (
+            [merged_structured_document.get_tag(t) for t in merged_tokens] ==
+            [TAG_1]
+        )
+        assert (
+            [merged_structured_document.get_tag(t, scope=SCOPE_1) for t in merged_tokens] ==
+            [TAG_2]
+        )
 
-  def test_should_not_override_with_empty_tags(self):
-    merged_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
-      SimpleToken(TEXT_1, tag=TAG_1)
-    ])])
-    other_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
-      SimpleToken(TEXT_1)
-    ])])
-    merged_structured_document.merge_with(
-      other_structured_document,
-      partial(
-        merge_token_tag
-      )
-    )
-    merged_tokens = list(merged_structured_document.iter_all_tokens())
-    assert (
-      [merged_structured_document.get_tag(t) for t in merged_tokens] ==
-      [TAG_1]
-    )
+    def test_should_not_fail_with_absent_tags(self):
+        merged_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
+            SimpleToken(TEXT_1, tag=TAG_1)
+        ])])
+        other_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
+            SimpleToken(TEXT_1)
+        ])])
+        merged_structured_document.merge_with(
+            other_structured_document,
+            partial(
+                merge_token_tag,
+                target_scope=SCOPE_1
+            )
+        )
+        merged_tokens = list(merged_structured_document.iter_all_tokens())
+        assert (
+            [merged_structured_document.get_tag(t, scope=SCOPE_1) for t in merged_tokens] ==
+            [None]
+        )
 
-  def test_should_raise_assertion_error_if_tokens_mismatch(self):
-    merged_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
-      SimpleToken(TEXT_1, tag=TAG_1)
-    ])])
-    other_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
-      SimpleToken(TEXT_2, tag=TAG_2)
-    ])])
-    with pytest.raises(AssertionError):
-      merged_structured_document.merge_with(
-        other_structured_document,
-        partial(
-          merge_token_tag,
-          target_scope=SCOPE_1
+    def test_should_not_override_with_empty_tags(self):
+        merged_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
+            SimpleToken(TEXT_1, tag=TAG_1)
+        ])])
+        other_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
+            SimpleToken(TEXT_1)
+        ])])
+        merged_structured_document.merge_with(
+            other_structured_document,
+            partial(
+                merge_token_tag
+            )
+        )
+        merged_tokens = list(merged_structured_document.iter_all_tokens())
+        assert (
+            [merged_structured_document.get_tag(t) for t in merged_tokens] ==
+            [TAG_1]
         )
-      )
+
+    def test_should_raise_assertion_error_if_tokens_mismatch(self):
+        merged_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
+            SimpleToken(TEXT_1, tag=TAG_1)
+        ])])
+        other_structured_document = SimpleStructuredDocument(lines=[SimpleLine([
+            SimpleToken(TEXT_2, tag=TAG_2)
+        ])])
+        with pytest.raises(AssertionError):
+            merged_structured_document.merge_with(
+                other_structured_document,
+                partial(
+                    merge_token_tag,
+                    target_scope=SCOPE_1
+                )
+            )
+
 
 class TestSimpleStructuredDocument(object):
-  def test_should_set_tag_without_scope(self):
-    token = SimpleToken(TEXT_1)
-    doc = SimpleStructuredDocument(lines=[SimpleLine([token])])
-    doc.set_tag(token, TAG_1)
-    assert doc.get_tag(token) == TAG_1
+    def test_should_set_tag_without_scope(self):
+        token = SimpleToken(TEXT_1)
+        doc = SimpleStructuredDocument(lines=[SimpleLine([token])])
+        doc.set_tag(token, TAG_1)
+        assert doc.get_tag(token) == TAG_1
 
-  def test_should_set_tag_with_scope(self):
-    token = SimpleToken(TEXT_1)
-    doc = SimpleStructuredDocument(lines=[SimpleLine([token])])
-    doc.set_tag(token, TAG_1, scope=SCOPE_1)
-    assert doc.get_tag(token, scope=SCOPE_1) == TAG_1
-    assert doc.get_tag(token) is None
+    def test_should_set_tag_with_scope(self):
+        token = SimpleToken(TEXT_1)
+        doc = SimpleStructuredDocument(lines=[SimpleLine([token])])
+        doc.set_tag(token, TAG_1, scope=SCOPE_1)
+        assert doc.get_tag(token, scope=SCOPE_1) == TAG_1
+        assert doc.get_tag(token) is None
 
-  def test_should_return_all_tag_by_scope(self):
-    token = SimpleToken(TEXT_1)
-    doc = SimpleStructuredDocument(lines=[SimpleLine([token])])
-    doc.set_tag(token, TAG_1)
-    doc.set_tag(token, TAG_2, scope=SCOPE_1)
-    assert doc.get_tag(token) == TAG_1
-    assert doc.get_tag(token, scope=SCOPE_1) == TAG_2
-    assert doc.get_tag_by_scope(token) == {None: TAG_1, SCOPE_1: TAG_2}
+    def test_should_return_all_tag_by_scope(self):
+        token = SimpleToken(TEXT_1)
+        doc = SimpleStructuredDocument(lines=[SimpleLine([token])])
+        doc.set_tag(token, TAG_1)
+        doc.set_tag(token, TAG_2, scope=SCOPE_1)
+        assert doc.get_tag(token) == TAG_1
+        assert doc.get_tag(token, scope=SCOPE_1) == TAG_2
+        assert doc.get_tag_by_scope(token) == {None: TAG_1, SCOPE_1: TAG_2}
diff --git a/sciencebeam_gym/structured_document/lxml.py b/sciencebeam_gym/structured_document/lxml.py
index 5219529..1b04f65 100644
--- a/sciencebeam_gym/structured_document/lxml.py
+++ b/sciencebeam_gym/structured_document/lxml.py
@@ -1,63 +1,66 @@
 from sciencebeam_utils.utils.xml import (
-  set_or_remove_attrib
+    set_or_remove_attrib
 )
 
 from sciencebeam_gym.utils.bounding_box import (
-  BoundingBox
+    BoundingBox
 )
 
 from sciencebeam_gym.structured_document import (
-  AbstractStructuredDocument,
-  get_scoped_attrib_name,
-  get_attrib_by_scope
+    AbstractStructuredDocument,
+    get_scoped_attrib_name,
+    get_attrib_by_scope
 )
 
 TAG_ATTRIB_NAME = 'tag'
 
+
 def get_node_bounding_box(t):
-  return BoundingBox(
-    float(t.attrib.get('x', 0)),
-    float(t.attrib.get('y', 0)),
-    float(t.attrib['width']),
-    float(t.attrib['height'])
-  )
+    return BoundingBox(
+        float(t.attrib.get('x', 0)),
+        float(t.attrib.get('y', 0)),
+        float(t.attrib['width']),
+        float(t.attrib['height'])
+    )
+
 
 def _get_tag_attrib_name(scope, level):
-  return get_scoped_attrib_name(TAG_ATTRIB_NAME, scope=scope, level=level)
+    return get_scoped_attrib_name(TAG_ATTRIB_NAME, scope=scope, level=level)
+
 
 class LxmlStructuredDocument(AbstractStructuredDocument):
-  def __init__(self, root):
-    self.root = root
+    def __init__(self, root):
+        self.root = root
 
-  def get_pages(self):
-    return self.root.findall('.//PAGE')
+    def get_pages(self):
+        return self.root.findall('.//PAGE')
 
-  def get_lines_of_page(self, page):
-    return page.findall('.//TEXT')
+    def get_lines_of_page(self, page):
+        return page.findall('.//TEXT')
 
-  def get_tokens_of_line(self, line):
-    return line.findall('./TOKEN')
+    def get_tokens_of_line(self, line):
+        return line.findall('./TOKEN')
 
-  def get_x(self, parent):
-    return parent.attrib.get('x')
+    def get_x(self, parent):
+        return parent.attrib.get('x')
 
-  def get_text(self, parent):
-    return parent.text
+    def get_text(self, parent):
+        return parent.text
 
-  def get_tag(self, parent, scope=None, level=None):
-    return parent.attrib.get(_get_tag_attrib_name(scope, level))
+    def get_tag(self, parent, scope=None, level=None):
+        return parent.attrib.get(_get_tag_attrib_name(scope, level))
 
-  def set_tag(self, parent, tag, scope=None, level=None):
-    set_or_remove_attrib(parent.attrib, _get_tag_attrib_name(scope, level), tag)
+    def set_tag(self, parent, tag, scope=None, level=None):
+        set_or_remove_attrib(parent.attrib, _get_tag_attrib_name(scope, level), tag)
 
-  def get_tag_by_scope(self, parent):
-    return get_attrib_by_scope(parent.attrib, TAG_ATTRIB_NAME)
+    def get_tag_by_scope(self, parent):
+        return get_attrib_by_scope(parent.attrib, TAG_ATTRIB_NAME)
 
-  def get_bounding_box(self, parent):
-    return get_node_bounding_box(parent)
+    def get_bounding_box(self, parent):
+        return get_node_bounding_box(parent)
 
-  def set_bounding_box(self, parent, bounding_box):
-    parent.attrib['x'] = str(bounding_box.x)
-    parent.attrib['y'] = str(bounding_box.y)
-    parent.attrib['width'] = str(bounding_box.width)
-    parent.attrib['height'] = str(bounding_box.height)
+    def set_bounding_box(self, parent, bounding_box):
+        parent.attrib['x'] = str(bounding_box.x)
+        parent.attrib['y'] = str(bounding_box.y)
+        parent.attrib['width'] = str(bounding_box.width)
+        parent.attrib['height'] = str(bounding_box.height)
diff --git a/sciencebeam_gym/structured_document/lxml_test.py b/sciencebeam_gym/structured_document/lxml_test.py
index 774b991..ef361fc 100644
--- a/sciencebeam_gym/structured_document/lxml_test.py
+++ b/sciencebeam_gym/structured_document/lxml_test.py
@@ -3,11 +3,11 @@ from __future__ import absolute_import
 from lxml.builder import E
 
 from sciencebeam_gym.utils.bounding_box import (
-  BoundingBox
+    BoundingBox
 )
 
 from sciencebeam_gym.structured_document.lxml import (
-  LxmlStructuredDocument
+    LxmlStructuredDocument
 )
 
 TAG_1 = 'tag1'
@@ -15,142 +15,143 @@ TAG_2 = 'tag2'
 
 SCOPE_1 = 'scope1'
 
+
 class TestLxmlStructuredDocument(object):
-  def test_should_find_pages(self):
-    pages = [
-      E.PAGE(),
-      E.PAGE()
-    ]
-    doc = LxmlStructuredDocument(
-      E.DOCUMENT(
-        *pages
-      )
-    )
-    assert list(doc.get_pages()) == pages
-
-  def test_should_find_lines_of_page_without_blocks(self):
-    lines = [
-      E.TEXT(),
-      E.TEXT()
-    ]
-    page = E.PAGE(*lines)
-    doc = LxmlStructuredDocument(
-      E.DOCUMENT(
-        page,
-        # add another page just for effect
-        E.PAGE(
-          E.TEXT()
+    def test_should_find_pages(self):
+        pages = [
+            E.PAGE(),
+            E.PAGE()
+        ]
+        doc = LxmlStructuredDocument(
+            E.DOCUMENT(
+                *pages
+            )
+        )
+        assert list(doc.get_pages()) == pages
+
+    def test_should_find_lines_of_page_without_blocks(self):
+        lines = [
+            E.TEXT(),
+            E.TEXT()
+        ]
+        page = E.PAGE(*lines)
+        doc = LxmlStructuredDocument(
+            E.DOCUMENT(
+                page,
+                # add another page just for effect
+                E.PAGE(
+                    E.TEXT()
+                )
+            )
         )
-      )
-    )
-    assert list(doc.get_lines_of_page(page)) == lines
-
-  def test_should_find_lines_of_page_with_blocks(self):
-    lines = [
-      E.TEXT(),
-      E.TEXT()
-    ]
-    page = E.PAGE(E.BLOCK(*lines))
-    doc = LxmlStructuredDocument(
-      E.DOCUMENT(
-        page,
-        # add another page just for effect
-        E.PAGE(
-          E.BLOCK(E.TEXT())
+        assert list(doc.get_lines_of_page(page)) == lines
+
+    def test_should_find_lines_of_page_with_blocks(self):
+        lines = [
+            E.TEXT(),
+            E.TEXT()
+        ]
+        page = E.PAGE(E.BLOCK(*lines))
+        doc = LxmlStructuredDocument(
+            E.DOCUMENT(
+                page,
+                # add another page just for effect
+                E.PAGE(
+                    E.BLOCK(E.TEXT())
+                )
+            )
         )
-      )
-    )
-    assert list(doc.get_lines_of_page(page)) == lines
-
-  def test_should_find_tokens_of_line(self):
-    tokens = [
-      E.TOKEN(),
-      E.TOKEN()
-    ]
-    line = E.TEXT(*tokens)
-    doc = LxmlStructuredDocument(
-      E.DOCUMENT(
-        E.PAGE(
-          line,
-          E.TEXT(E.TOKEN)
+        assert list(doc.get_lines_of_page(page)) == lines
+
+    def test_should_find_tokens_of_line(self):
+        tokens = [
+            E.TOKEN(),
+            E.TOKEN()
+        ]
+        line = E.TEXT(*tokens)
+        doc = LxmlStructuredDocument(
+            E.DOCUMENT(
+                E.PAGE(
+                    line,
+                    E.TEXT(E.TOKEN)
+                )
+            )
         )
-      )
-    )
-    assert list(doc.get_tokens_of_line(line)) == tokens
-
-  def test_should_calculate_default_bounding_box(self):
-    token = E.TOKEN({
-      'x': '10',
-      'y': '11',
-      'width': '100',
-      'height': '101'
-    })
-    doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.TEXT(token))))
-    assert doc.get_bounding_box(token) == BoundingBox(10, 11, 100, 101)
-
-  def test_should_be_able_to_set_bounding_box(self):
-    bounding_box = BoundingBox(10, 11, 100, 101)
-    token = E.TOKEN({
-      'x': '20',
-      'y': '21',
-      'width': '200',
-      'height': '201'
-    })
-    doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.TEXT(token))))
-    doc.set_bounding_box(token, bounding_box)
-    assert doc.get_bounding_box(token) == bounding_box
-
-  def test_should_calculate_bounding_box_of_page_without_xy(self):
-    page = E.PAGE({
-      'width': '100',
-      'height': '101'
-    })
-    doc = LxmlStructuredDocument(E.DOCUMENT(page))
-    assert doc.get_bounding_box(page) == BoundingBox(0, 0, 100, 101)
-
-  def test_should_set_tag_without_scope(self):
-    token = E.TEXT()
-    doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token))))
-    doc.set_tag(token, TAG_1)
-    assert doc.get_tag(token) == TAG_1
-
-  def test_should_set_tag_with_scope(self):
-    token = E.TEXT()
-    doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token))))
-    doc.set_tag(token, TAG_1, scope=SCOPE_1)
-    assert doc.get_tag(token, scope=SCOPE_1) == TAG_1
-    assert doc.get_tag(token) is None
-
-  def test_should_set_tag_with_level(self):
-    token = E.TEXT()
-    doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token))))
-    doc.set_tag(token, TAG_1, level=2)
-    assert doc.get_tag(token, level=2) == TAG_1
-    assert doc.get_tag(token) is None
-
-  def test_should_clear_tag_when_setting_tag_to_none(self):
-    token = E.TEXT()
-    doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token))))
-    doc.set_tag(token, TAG_1)
-    doc.set_tag(token, TAG_1, scope=SCOPE_1)
-    doc.set_tag(token, None)
-    doc.set_tag(token, None, scope=SCOPE_1)
-    assert doc.get_tag(token) is None
-    assert doc.get_tag(token, scope=SCOPE_1) is None
-
-  def test_should_not_fail_setting_empty_tag_to_none(self):
-    token = E.TEXT()
-    doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token))))
-    doc.set_tag(token, None)
-    doc.set_tag(token, None, scope=SCOPE_1)
-    assert doc.get_tag(token) is None
-    assert doc.get_tag(token, scope=SCOPE_1) is None
-
-  def test_should_return_all_tag_by_scope(self):
-    token = E.TEXT()
-    doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token))))
-    doc.set_tag(token, TAG_1)
-    doc.set_tag(token, TAG_2, scope=SCOPE_1)
-    assert doc.get_tag(token) == TAG_1
-    assert doc.get_tag(token, scope=SCOPE_1) == TAG_2
-    assert doc.get_tag_by_scope(token) == {None: TAG_1, SCOPE_1: TAG_2}
+        assert list(doc.get_tokens_of_line(line)) == tokens
+
+    def test_should_calculate_default_bounding_box(self):
+        token = E.TOKEN({
+            'x': '10',
+            'y': '11',
+            'width': '100',
+            'height': '101'
+        })
+        doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.TEXT(token))))
+        assert doc.get_bounding_box(token) == BoundingBox(10, 11, 100, 101)
+
+    def test_should_be_able_to_set_bounding_box(self):
+        bounding_box = BoundingBox(10, 11, 100, 101)
+        token = E.TOKEN({
+            'x': '20',
+            'y': '21',
+            'width': '200',
+            'height': '201'
+        })
+        doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.TEXT(token))))
+        doc.set_bounding_box(token, bounding_box)
+        assert doc.get_bounding_box(token) == bounding_box
+
+    def test_should_calculate_bounding_box_of_page_without_xy(self):
+        page = E.PAGE({
+            'width': '100',
+            'height': '101'
+        })
+        doc = LxmlStructuredDocument(E.DOCUMENT(page))
+        assert doc.get_bounding_box(page) == BoundingBox(0, 0, 100, 101)
+
+    def test_should_set_tag_without_scope(self):
+        token = E.TEXT()
+        doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token))))
+        doc.set_tag(token, TAG_1)
+        assert doc.get_tag(token) == TAG_1
+
+    def test_should_set_tag_with_scope(self):
+        token = E.TEXT()
+        doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token))))
+        doc.set_tag(token, TAG_1, scope=SCOPE_1)
+        assert doc.get_tag(token, scope=SCOPE_1) == TAG_1
+        assert doc.get_tag(token) is None
+
+    def test_should_set_tag_with_level(self):
+        token = E.TEXT()
+        doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token))))
+        doc.set_tag(token, TAG_1, level=2)
+        assert doc.get_tag(token, level=2) == TAG_1
+        assert doc.get_tag(token) is None
+
+    def test_should_clear_tag_when_setting_tag_to_none(self):
+        token = E.TEXT()
+        doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token))))
+        doc.set_tag(token, TAG_1)
+        doc.set_tag(token, TAG_1, scope=SCOPE_1)
+        doc.set_tag(token, None)
+        doc.set_tag(token, None, scope=SCOPE_1)
+        assert doc.get_tag(token) is None
+        assert doc.get_tag(token, scope=SCOPE_1) is None
+
+    def test_should_not_fail_setting_empty_tag_to_none(self):
+        token = E.TEXT()
+        doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token))))
+        doc.set_tag(token, None)
+        doc.set_tag(token, None, scope=SCOPE_1)
+        assert doc.get_tag(token) is None
+        assert doc.get_tag(token, scope=SCOPE_1) is None
+
+    def test_should_return_all_tag_by_scope(self):
+        token = E.TEXT()
+        doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token))))
+        doc.set_tag(token, TAG_1)
+        doc.set_tag(token, TAG_2, scope=SCOPE_1)
+        assert doc.get_tag(token) == TAG_1
+        assert doc.get_tag(token, scope=SCOPE_1) == TAG_2
+        assert doc.get_tag_by_scope(token) == {None: TAG_1, SCOPE_1: TAG_2}
diff --git a/sciencebeam_gym/structured_document/structured_document_loader.py b/sciencebeam_gym/structured_document/structured_document_loader.py
index 89fd82f..1798dd0 100644
--- a/sciencebeam_gym/structured_document/structured_document_loader.py
+++ b/sciencebeam_gym/structured_document/structured_document_loader.py
@@ -1,59 +1,66 @@
 from __future__ import absolute_import
 
-from zipfile import ZipFile
-
 from lxml import etree
 from lxml.builder import E
 
 from apache_beam.io.filesystems import FileSystems
 
 from sciencebeam_gym.utils.pages_zip import (
-  load_pages
+    load_pages
 )
 
 from sciencebeam_gym.structured_document.lxml import (
-  LxmlStructuredDocument
+    LxmlStructuredDocument
 )
 
 from sciencebeam_gym.structured_document.svg import (
-  SvgStructuredDocument
+    SvgStructuredDocument
 )
 
+
 class StructuredDocumentType(object):
-  LXML = 'lxml'
-  SVG_PAGES = 'svg-pages'
+    LXML = 'lxml'
+    SVG_PAGES = 'svg-pages'
+
 
 def get_structuctured_document_type(filename):
-  if filename.endswith('.zip'):
-    return StructuredDocumentType.SVG_PAGES
-  return StructuredDocumentType.LXML
+    if filename.endswith('.zip'):
+        return StructuredDocumentType.SVG_PAGES
+    return StructuredDocumentType.LXML
+
 
 def load_lxml_structured_document(filename, page_range=None):
-  with FileSystems.open(filename) as f:
-    structured_document = LxmlStructuredDocument(etree.parse(f).getroot())
-    if page_range:
-      structured_document = LxmlStructuredDocument(
-        E.DOCUMENT(
-          *structured_document.get_pages()[
-            max(0, page_range[0] - 1):
-            page_range[1]
-          ]
-        )
-      )
-    return structured_document
+    with FileSystems.open(filename) as f:
+        structured_document = LxmlStructuredDocument(etree.parse(f).getroot())
+        if page_range:
+            structured_document = LxmlStructuredDocument(
+                E.DOCUMENT(
+                    *structured_document.get_pages()[
+                        max(0, page_range[0] - 1):
+                        page_range[1]
+                    ]
+                )
+            )
+        return structured_document
+
 
 def load_svg_pages_structured_document(filename, page_range=None):
-  return SvgStructuredDocument([
-    etree.parse(svg_f).getroot()
-    for svg_f in load_pages(filename, page_range=page_range)
-  ])
+    return SvgStructuredDocument([
+        etree.parse(svg_f).getroot()
+        for svg_f in load_pages(filename, page_range=page_range)
+    ])
+
 
 def load_structured_document(filename, page_range=None):
-  structured_document_type = get_structuctured_document_type(filename)
-  if structured_document_type == StructuredDocumentType.LXML:
-    return load_lxml_structured_document(filename, page_range=page_range)
-  if structured_document_type == StructuredDocumentType.SVG_PAGES:
-    return load_svg_pages_structured_document(filename, page_range=page_range)
+    structured_document_type = get_structuctured_document_type(filename)
+    if structured_document_type == StructuredDocumentType.LXML:
+        return load_lxml_structured_document(filename, page_range=page_range)
+    if structured_document_type == StructuredDocumentType.SVG_PAGES:
+        return load_svg_pages_structured_document(filename, page_range=page_range)
+    raise RuntimeError('unsupported structured_document_type: %s (%s)' % (
+        structured_document_type, filename
+    ))
+
 
 def load_structured_documents_from_file_list(file_list, page_range=None):
-  return (load_structured_document(s, page_range=page_range) for s in file_list)
+    return (load_structured_document(s, page_range=page_range) for s in file_list)
diff --git a/sciencebeam_gym/structured_document/structured_document_loader_test.py b/sciencebeam_gym/structured_document/structured_document_loader_test.py
index 651802a..5842ccb 100644
--- a/sciencebeam_gym/structured_document/structured_document_loader_test.py
+++ b/sciencebeam_gym/structured_document/structured_document_loader_test.py
@@ -10,12 +10,12 @@ from lxml.builder import E
 import sciencebeam_gym.structured_document.structured_document_loader as structured_document_loader
 
 from sciencebeam_gym.structured_document.structured_document_loader import (
-  StructuredDocumentType,
-  get_structuctured_document_type,
-  load_structured_documents_from_file_list,
-  load_lxml_structured_document,
-  load_svg_pages_structured_document,
-  load_structured_document
+    StructuredDocumentType,
+    get_structuctured_document_type,
+    load_structured_documents_from_file_list,
+    load_lxml_structured_document,
+    load_svg_pages_structured_document,
+    load_structured_document
 )
 
 FILE_1 = 'file1.pdf'
@@ -23,93 +23,98 @@ FILE_2 = 'file2.pdf'
 
 PAGE_RANGE = (2, 3)
 
+
 class TestLoadLxmlStructuredDocument(object):
-  def test_should_load_file(self):
-    lxml_content = etree.tostring(E.test('test'))
-    with NamedTemporaryFile() as f:
-      f.write(lxml_content)
-      f.flush()
-      structured_document = load_lxml_structured_document(f.name)
-      assert etree.tostring(structured_document.root) == lxml_content
-      assert hasattr(structured_document.root, 'attrib')
-
-  def test_should_limit_page_range(self):
-    lxml_content = etree.tostring(E.DOCUMENT(
-      E.PAGE('page 1'),
-      E.PAGE('page 2'),
-      E.PAGE('page 3'),
-      E.PAGE('page 4')
-    ))
-    with NamedTemporaryFile() as f:
-      f.write(lxml_content)
-      f.flush()
-      structured_document = load_lxml_structured_document(f.name, page_range=(2, 3))
-      assert [x.text for x in structured_document.get_pages()] == ['page 2', 'page 3']
+    def test_should_load_file(self):
+        lxml_content = etree.tostring(E.test('test'))
+        with NamedTemporaryFile() as f:
+            f.write(lxml_content)
+            f.flush()
+            structured_document = load_lxml_structured_document(f.name)
+            assert etree.tostring(structured_document.root) == lxml_content
+            assert hasattr(structured_document.root, 'attrib')
+
+    def test_should_limit_page_range(self):
+        lxml_content = etree.tostring(E.DOCUMENT(
+            E.PAGE('page 1'),
+            E.PAGE('page 2'),
+            E.PAGE('page 3'),
+            E.PAGE('page 4')
+        ))
+        with NamedTemporaryFile() as f:
+            f.write(lxml_content)
+            f.flush()
+            structured_document = load_lxml_structured_document(f.name, page_range=(2, 3))
+            assert [x.text for x in structured_document.get_pages()] == ['page 2', 'page 3']
+
 
 class TestLoadSvgPagesStructuredDocument(object):
-  def test_should_load_file_with_multiple_pages(self):
-    svg_pages_content = [
-      etree.tostring(E.svg('page 1')),
-      etree.tostring(E.svg('page 2'))
-    ]
-    with NamedTemporaryFile() as f:
-      with ZipFile(f, 'w') as zf:
-        for i, svg_page_content in enumerate(svg_pages_content):
-          zf.writestr('page-%d.svg' % (1 + i), svg_page_content)
-      f.flush()
-      structured_document = load_svg_pages_structured_document(f.name)
-      assert (
-        [etree.tostring(x) for x in structured_document.page_roots] ==
-        svg_pages_content
-      )
-      assert hasattr(structured_document.page_roots[0], 'attrib')
-
-  def test_should_limit_page_range(self):
-    svg_pages_content = [
-      etree.tostring(E.svg('page 1')),
-      etree.tostring(E.svg('page 2')),
-      etree.tostring(E.svg('page 3')),
-      etree.tostring(E.svg('page 4'))
-    ]
-    with NamedTemporaryFile() as f:
-      with ZipFile(f, 'w') as zf:
-        for i, svg_page_content in enumerate(svg_pages_content):
-          zf.writestr('page-%d.svg' % (1 + i), svg_page_content)
-      f.flush()
-      structured_document = load_svg_pages_structured_document(f.name, page_range=(2, 3))
-      assert (
-        [x.text for x in structured_document.page_roots] ==
-        ['page 2', 'page 3']
-      )
+    def test_should_load_file_with_multiple_pages(self):
+        svg_pages_content = [
+            etree.tostring(E.svg('page 1')),
+            etree.tostring(E.svg('page 2'))
+        ]
+        with NamedTemporaryFile() as f:
+            with ZipFile(f, 'w') as zf:
+                for i, svg_page_content in enumerate(svg_pages_content):
+                    zf.writestr('page-%d.svg' % (1 + i), svg_page_content)
+            f.flush()
+            structured_document = load_svg_pages_structured_document(f.name)
+            assert (
+                [etree.tostring(x) for x in structured_document.page_roots] ==
+                svg_pages_content
+            )
+            assert hasattr(structured_document.page_roots[0], 'attrib')
+
+    def test_should_limit_page_range(self):
+        svg_pages_content = [
+            etree.tostring(E.svg('page 1')),
+            etree.tostring(E.svg('page 2')),
+            etree.tostring(E.svg('page 3')),
+            etree.tostring(E.svg('page 4'))
+        ]
+        with NamedTemporaryFile() as f:
+            with ZipFile(f, 'w') as zf:
+                for i, svg_page_content in enumerate(svg_pages_content):
+                    zf.writestr('page-%d.svg' % (1 + i), svg_page_content)
+            f.flush()
+            structured_document = load_svg_pages_structured_document(f.name, page_range=(2, 3))
+            assert (
+                [x.text for x in structured_document.page_roots] ==
+                ['page 2', 'page 3']
+            )
+
 
 class TestGetStructuredDocumentType(object):
-  def test_should_return_lxml_for_lxml_file(self):
-    assert get_structuctured_document_type('file.lxml') == StructuredDocumentType.LXML
+    def test_should_return_lxml_for_lxml_file(self):
+        assert get_structuctured_document_type('file.lxml') == StructuredDocumentType.LXML
 
-  def test_should_return_lxml_for_lxml_gz_file(self):
-    assert get_structuctured_document_type('file.lxml.gz') == StructuredDocumentType.LXML
+    def test_should_return_lxml_for_lxml_gz_file(self):
+        assert get_structuctured_document_type('file.lxml.gz') == StructuredDocumentType.LXML
+
+    def test_should_return_lxml_for_svg_zip_file(self):
+        assert get_structuctured_document_type('file.svg.zip') == StructuredDocumentType.SVG_PAGES
 
-  def test_should_return_lxml_for_svg_zip_file(self):
-    assert get_structuctured_document_type('file.svg.zip') == StructuredDocumentType.SVG_PAGES
 
 class TestLoadStructuredDocument(object):
-  def test_should_call_load_plain_file_list_if_file(self):
-    with patch.object(structured_document_loader, 'load_lxml_structured_document') as mock:
-      result = load_structured_document('file.lxml', page_range=PAGE_RANGE)
-      mock.assert_called_with('file.lxml', page_range=PAGE_RANGE)
-      assert result == mock.return_value
-
-  def test_should_call_load_csv_or_tsv_file_list_if_file(self):
-    with patch.object(structured_document_loader, 'load_svg_pages_structured_document') as mock:
-      result = load_structured_document('file.svg.zip', page_range=PAGE_RANGE)
-      mock.assert_called_with('file.svg.zip', page_range=PAGE_RANGE)
-      assert result == mock.return_value
+    def test_should_call_load_plain_file_list_if_file(self):
+        with patch.object(structured_document_loader, 'load_lxml_structured_document') as mock:
+            result = load_structured_document('file.lxml', page_range=PAGE_RANGE)
+            mock.assert_called_with('file.lxml', page_range=PAGE_RANGE)
+            assert result == mock.return_value
+
+    def test_should_call_load_csv_or_tsv_file_list_if_file(self):
+        with patch.object(structured_document_loader, 'load_svg_pages_structured_document') as mock:
+            result = load_structured_document('file.svg.zip', page_range=PAGE_RANGE)
+            mock.assert_called_with('file.svg.zip', page_range=PAGE_RANGE)
+            assert result == mock.return_value
+
 
 class TestLoadStructuredDocumentsFromFileList(object):
-  def test_should_single_file(self):
-    with patch.object(structured_document_loader, 'load_structured_document') as mock:
-      assert (
-        list(load_structured_documents_from_file_list([FILE_1], page_range=PAGE_RANGE)) ==
-        [mock.return_value]
-      )
-      mock.assert_called_with(FILE_1, page_range=PAGE_RANGE)
+    def test_should_single_file(self):
+        with patch.object(structured_document_loader, 'load_structured_document') as mock:
+            assert (
+                list(load_structured_documents_from_file_list([FILE_1], page_range=PAGE_RANGE)) ==
+                [mock.return_value]
+            )
+            mock.assert_called_with(FILE_1, page_range=PAGE_RANGE)
diff --git a/sciencebeam_gym/structured_document/structured_document_saver.py b/sciencebeam_gym/structured_document/structured_document_saver.py
index 5ca17bd..7b541df 100644
--- a/sciencebeam_gym/structured_document/structured_document_saver.py
+++ b/sciencebeam_gym/structured_document/structured_document_saver.py
@@ -3,33 +3,36 @@ from __future__ import absolute_import
 from lxml import etree
 
 from sciencebeam_utils.beam_utils.io import (
-  save_file_content
+    save_file_content
 )
 
 from sciencebeam_gym.utils.pages_zip import (
-  save_pages
+    save_pages
 )
 
 from sciencebeam_gym.structured_document.lxml import (
-  LxmlStructuredDocument
+    LxmlStructuredDocument
 )
 
 from sciencebeam_gym.structured_document.svg import (
-  SvgStructuredDocument
+    SvgStructuredDocument
 )
 
+
 def save_lxml_structured_document(filename, lxml_structured_document):
-  save_file_content(filename, etree.tostring(lxml_structured_document.root))
+    save_file_content(filename, etree.tostring(lxml_structured_document.root))
+
 
 def save_svg_structured_document(filename, svg_structured_document):
-  return save_pages(filename, '.svg', (
-    etree.tostring(svg_page)
-    for svg_page in svg_structured_document.get_pages()
-  ))
+    return save_pages(filename, '.svg', (
+        etree.tostring(svg_page)
+        for svg_page in svg_structured_document.get_pages()
+    ))
+
 
 def save_structured_document(filename, structured_document):
-  if isinstance(structured_document, LxmlStructuredDocument):
-    return save_lxml_structured_document(filename, structured_document)
-  if isinstance(structured_document, SvgStructuredDocument):
-    return save_svg_structured_document(filename, structured_document)
-  raise RuntimeError('unsupported type: %s' % type(structured_document))
+    if isinstance(structured_document, LxmlStructuredDocument):
+        return save_lxml_structured_document(filename, structured_document)
+    if isinstance(structured_document, SvgStructuredDocument):
+        return save_svg_structured_document(filename, structured_document)
+    raise RuntimeError('unsupported type: %s' % type(structured_document))
diff --git a/sciencebeam_gym/structured_document/structured_document_saver_test.py b/sciencebeam_gym/structured_document/structured_document_saver_test.py
index fdbece8..1cbc455 100644
--- a/sciencebeam_gym/structured_document/structured_document_saver_test.py
+++ b/sciencebeam_gym/structured_document/structured_document_saver_test.py
@@ -5,53 +5,56 @@ from mock import patch, ANY
 from lxml.builder import E
 
 from sciencebeam_gym.structured_document.lxml import (
-  LxmlStructuredDocument
+    LxmlStructuredDocument
 )
 
 from sciencebeam_gym.structured_document.svg import (
-  SvgStructuredDocument
+    SvgStructuredDocument
 )
 
 import sciencebeam_gym.structured_document.structured_document_saver as structured_document_saver
 from sciencebeam_gym.structured_document.structured_document_saver import (
-  save_lxml_structured_document,
-  save_svg_structured_document,
-  save_structured_document
+    save_lxml_structured_document,
+    save_svg_structured_document,
+    save_structured_document
 )
 
 FILE_1 = 'file1'
 
+
 class TestSaveLxmlStructuredDocument(object):
-  def test_should_call_save_file_content(self):
-    m = structured_document_saver
-    root = E.DOCUMENT()
-    with patch.object(m, 'save_file_content') as save_file_content:
-      with patch.object(m, 'etree') as etree:
-        save_lxml_structured_document(FILE_1, LxmlStructuredDocument(root))
-        save_file_content.assert_called_with(FILE_1, etree.tostring(root))
+    def test_should_call_save_file_content(self):
+        m = structured_document_saver
+        root = E.DOCUMENT()
+        with patch.object(m, 'save_file_content') as save_file_content:
+            with patch.object(m, 'etree') as etree:
+                save_lxml_structured_document(FILE_1, LxmlStructuredDocument(root))
+                save_file_content.assert_called_with(FILE_1, etree.tostring(root))
+
 
 class TestSaveSvgStructuredDocument(object):
-  def test_should_call_save_pages(self):
-    m = structured_document_saver
-    root = E.svg()
-    with patch.object(m, 'save_pages') as save_pages:
-      with patch.object(m, 'etree') as etree:
-        save_svg_structured_document(FILE_1, SvgStructuredDocument(root))
-        save_pages.assert_called_with(FILE_1, '.svg', ANY)
-        args, _ = save_pages.call_args
-        assert list(args[2]) == [etree.tostring(root)]
+    def test_should_call_save_pages(self):
+        m = structured_document_saver
+        root = E.svg()
+        with patch.object(m, 'save_pages') as save_pages:
+            with patch.object(m, 'etree') as etree:
+                save_svg_structured_document(FILE_1, SvgStructuredDocument(root))
+                save_pages.assert_called_with(FILE_1, '.svg', ANY)
+                args, _ = save_pages.call_args
+                assert list(args[2]) == [etree.tostring(root)]
+
 
 class TestSaveStructuredDocument(object):
-  def test_should_call_save_lxml_structured_document(self):
-    structured_document = LxmlStructuredDocument(E.DOCUMENT)
-    m = structured_document_saver
-    with patch.object(m, 'save_lxml_structured_document') as save_lxml_structured_document_mock:
-      save_structured_document(FILE_1, structured_document)
-      save_lxml_structured_document_mock.assert_called_with(FILE_1, structured_document)
-
-  def test_should_call_save_svg_structured_document(self):
-    structured_document = SvgStructuredDocument(E.svg)
-    m = structured_document_saver
-    with patch.object(m, 'save_svg_structured_document') as save_svg_structured_document_mock:
-      save_structured_document(FILE_1, structured_document)
-      save_svg_structured_document_mock.assert_called_with(FILE_1, structured_document)
+    def test_should_call_save_lxml_structured_document(self):
+        structured_document = LxmlStructuredDocument(E.DOCUMENT)
+        m = structured_document_saver
+        with patch.object(m, 'save_lxml_structured_document') as save_lxml_structured_document_mock:
+            save_structured_document(FILE_1, structured_document)
+            save_lxml_structured_document_mock.assert_called_with(FILE_1, structured_document)
+
+    def test_should_call_save_svg_structured_document(self):
+        structured_document = SvgStructuredDocument(E.svg)
+        m = structured_document_saver
+        with patch.object(m, 'save_svg_structured_document') as save_svg_structured_document_mock:
+            save_structured_document(FILE_1, structured_document)
+            save_svg_structured_document_mock.assert_called_with(FILE_1, structured_document)
diff --git a/sciencebeam_gym/structured_document/svg.py b/sciencebeam_gym/structured_document/svg.py
index 194ccf9..c9ecd43 100644
--- a/sciencebeam_gym/structured_document/svg.py
+++ b/sciencebeam_gym/structured_document/svg.py
@@ -1,15 +1,15 @@
 from sciencebeam_utils.utils.xml import (
-  set_or_remove_attrib
+    set_or_remove_attrib
 )
 
 from sciencebeam_gym.utils.bounding_box import (
-  BoundingBox
+    BoundingBox
 )
 
 from sciencebeam_gym.structured_document import (
-  AbstractStructuredDocument,
-  get_scoped_attrib_name,
-  get_attrib_by_scope
+    AbstractStructuredDocument,
+    get_scoped_attrib_name,
+    get_attrib_by_scope
 )
 
 SVG_NS = 'http://www.w3.org/2000/svg'
@@ -30,89 +30,95 @@ SVGE_BOUNDING_BOX = SVGE_NS_PREFIX + 'bounding-box'
 SCOPED_TAG_ATTRIB_SUFFIX = 'tag'
 
 SVG_NSMAP = {
-  None : SVG_NS,
-  'svge': SVGE_NS
+    None: SVG_NS,
+    'svge': SVGE_NS
 }
 
+
 class SvgStyleClasses(object):
-  LINE = 'line'
-  BLOCK = 'block'
-  LINE_NO = 'line_no'
+    LINE = 'line'
+    BLOCK = 'block'
+    LINE_NO = 'line_no'
+
 
 def format_bounding_box(bounding_box):
-  return '%s %s %s %s' % (bounding_box.x, bounding_box.y, bounding_box.width, bounding_box.height)
+    return '%s %s %s %s' % (bounding_box.x, bounding_box.y, bounding_box.width, bounding_box.height)
+
 
 def parse_bounding_box(bounding_box_str):
-  if not bounding_box_str:
-    return None
-  x, y, width, height = bounding_box_str.split()
-  return BoundingBox(float(x), float(y), float(width), float(height))
+    if not bounding_box_str:
+        return None
+    x, y, width, height = bounding_box_str.split()
+    return BoundingBox(float(x), float(y), float(width), float(height))
+
 
 def get_node_bounding_box(t):
-  attrib = t.attrib
-  if SVGE_BOUNDING_BOX in attrib:
-    return parse_bounding_box(attrib[SVGE_BOUNDING_BOX])
-  if SVG_VIEWBOX_ATTRIB in attrib:
-    return parse_bounding_box(attrib[SVG_VIEWBOX_ATTRIB])
-  if not ('font-size' in attrib and 'x' in attrib and 'y' in attrib):
-    return None
-  font_size = float(attrib['font-size'])
-  width = font_size * 0.8 * max(1, len(t.text))
-  return BoundingBox(
-    float(attrib['x']),
-    float(attrib['y']),
-    width,
-    font_size
-  )
+    attrib = t.attrib
+    if SVGE_BOUNDING_BOX in attrib:
+        return parse_bounding_box(attrib[SVGE_BOUNDING_BOX])
+    if SVG_VIEWBOX_ATTRIB in attrib:
+        return parse_bounding_box(attrib[SVG_VIEWBOX_ATTRIB])
+    if not ('font-size' in attrib and 'x' in attrib and 'y' in attrib):
+        return None
+    font_size = float(attrib['font-size'])
+    width = font_size * 0.8 * max(1, len(t.text))
+    return BoundingBox(
+        float(attrib['x']),
+        float(attrib['y']),
+        width,
+        font_size
+    )
+
 
 def _get_tag_attrib_name(scope, level):
-  return (
-    SVGE_NS_PREFIX + get_scoped_attrib_name(SCOPED_TAG_ATTRIB_SUFFIX, scope=scope, level=level)
-    if scope or level
-    else SVG_TAG_ATTRIB
-  )
+    return (
+        SVGE_NS_PREFIX + get_scoped_attrib_name(SCOPED_TAG_ATTRIB_SUFFIX, scope=scope, level=level)
+        if scope or level
+        else SVG_TAG_ATTRIB
+    )
+
 
 class SvgStructuredDocument(AbstractStructuredDocument):
-  def __init__(self, root_or_roots):
-    if isinstance(root_or_roots, list):
-      self.page_roots = root_or_roots
-    else:
-      self.page_roots = [root_or_roots]
+    def __init__(self, root_or_roots):
+        if isinstance(root_or_roots, list):
+            self.page_roots = root_or_roots
+        else:
+            self.page_roots = [root_or_roots]
 
-  def get_pages(self):
-    return self.page_roots
+    def get_pages(self):
+        return self.page_roots
 
-  def get_lines_of_page(self, page):
-    return page.findall('.//{}[@class="{}"]'.format(SVG_G, SvgStyleClasses.LINE))
+    def get_lines_of_page(self, page):
+        return page.findall('.//{}[@class="{}"]'.format(SVG_G, SvgStyleClasses.LINE))
 
-  def get_tokens_of_line(self, line):
-    return line.findall('./{}'.format(SVG_TEXT))
+    def get_tokens_of_line(self, line):
+        return line.findall('./{}'.format(SVG_TEXT))
 
-  def get_x(self, parent):
-    return parent.attrib.get('x')
+    def get_x(self, parent):
+        return parent.attrib.get('x')
 
-  def get_text(self, parent):
-    return parent.text
+    def get_text(self, parent):
+        return parent.text
 
-  def get_tag(self, parent, scope=None, level=None):
-    return parent.attrib.get(_get_tag_attrib_name(scope, level))
+    def get_tag(self, parent, scope=None, level=None):
+        return parent.attrib.get(_get_tag_attrib_name(scope, level))
 
-  def set_tag(self, parent, tag, scope=None, level=None):
-    set_or_remove_attrib(parent.attrib, _get_tag_attrib_name(scope, level), tag)
+    def set_tag(self, parent, tag, scope=None, level=None):
+        set_or_remove_attrib(parent.attrib, _get_tag_attrib_name(scope, level), tag)
 
-  def get_tag_by_scope(self, parent):
-    d = {
-      k[len(SVGE_NS_PREFIX):]: v
-      for k, v in get_attrib_by_scope(parent.attrib, SCOPED_TAG_ATTRIB_SUFFIX).items()
-      if k.startswith(SVGE_NS_PREFIX)
-    }
-    tag = self.get_tag(parent)
-    if tag:
-      d[None] = tag
-    return d
+    def get_tag_by_scope(self, parent):
+        d = {
+            k[len(SVGE_NS_PREFIX):]: v
+            for k, v in get_attrib_by_scope(parent.attrib, SCOPED_TAG_ATTRIB_SUFFIX).items()
+            if k.startswith(SVGE_NS_PREFIX)
+        }
+        tag = self.get_tag(parent)
+        if tag:
+            d[None] = tag
+        return d
 
-  def get_bounding_box(self, parent):
-    return get_node_bounding_box(parent)
+    def get_bounding_box(self, parent):
+        return get_node_bounding_box(parent)
 
-  def set_bounding_box(self, parent, bounding_box):
-    parent.attrib[SVGE_BOUNDING_BOX] = format_bounding_box(bounding_box)
+    def set_bounding_box(self, parent, bounding_box):
+        parent.attrib[SVGE_BOUNDING_BOX] = format_bounding_box(bounding_box)
diff --git a/sciencebeam_gym/structured_document/svg_test.py b/sciencebeam_gym/structured_document/svg_test.py
index 879597c..db8ef2c 100644
--- a/sciencebeam_gym/structured_document/svg_test.py
+++ b/sciencebeam_gym/structured_document/svg_test.py
@@ -3,16 +3,16 @@ from __future__ import absolute_import
 from lxml.builder import ElementMaker
 
 from sciencebeam_gym.utils.bounding_box import (
-  BoundingBox
+    BoundingBox
 )
 
 from sciencebeam_gym.structured_document.svg import (
-  SvgStructuredDocument,
-  SvgStyleClasses,
-  SVG_NS,
-  SVG_VIEWBOX_ATTRIB,
-  SVGE_BOUNDING_BOX,
-  format_bounding_box
+    SvgStructuredDocument,
+    SvgStyleClasses,
+    SVG_NS,
+    SVG_VIEWBOX_ATTRIB,
+    SVGE_BOUNDING_BOX,
+    format_bounding_box
 )
 
 E = ElementMaker(namespace=SVG_NS)
@@ -25,166 +25,167 @@ TAG_2 = 'tag2'
 
 SCOPE_1 = 'scope1'
 
+
 class TestSvgStructuredDocument(object):
-  def test_should_return_root_as_pages(self):
-    root = E.svg()
-    doc = SvgStructuredDocument(root)
-    assert list(doc.get_pages()) == [root]
-
-  def test_should_find_lines_of_page_without_blocks(self):
-    lines = [
-      SVG_TEXT_LINE(),
-      SVG_TEXT_LINE()
-    ]
-    doc = SvgStructuredDocument(
-      E.svg(
-        *lines
-      )
-    )
-    page = doc.get_pages()[0]
-    assert list(doc.get_lines_of_page(page)) == lines
-
-  def test_should_find_lines_of_page_with_blocks(self):
-    lines = [
-      SVG_TEXT_LINE(),
-      SVG_TEXT_LINE()
-    ]
-    doc = SvgStructuredDocument(
-      E.svg(
-        SVG_TEXT_BLOCK(
-          *lines
+    def test_should_return_root_as_pages(self):
+        root = E.svg()
+        doc = SvgStructuredDocument(root)
+        assert list(doc.get_pages()) == [root]
+
+    def test_should_find_lines_of_page_without_blocks(self):
+        lines = [
+            SVG_TEXT_LINE(),
+            SVG_TEXT_LINE()
+        ]
+        doc = SvgStructuredDocument(
+            E.svg(
+                *lines
+            )
+        )
+        page = doc.get_pages()[0]
+        assert list(doc.get_lines_of_page(page)) == lines
+
+    def test_should_find_lines_of_page_with_blocks(self):
+        lines = [
+            SVG_TEXT_LINE(),
+            SVG_TEXT_LINE()
+        ]
+        doc = SvgStructuredDocument(
+            E.svg(
+                SVG_TEXT_BLOCK(
+                    *lines
+                )
+            )
+        )
+        page = doc.get_pages()[0]
+        assert list(doc.get_lines_of_page(page)) == lines
+
+    def test_should_find_tokens_of_line(self):
+        tokens = [
+            SVG_TEXT(),
+            SVG_TEXT()
+        ]
+        line = SVG_TEXT_LINE(*tokens)
+        doc = SvgStructuredDocument(
+            E.svg(
+                line,
+                SVG_TEXT_LINE(SVG_TEXT())
+            )
         )
-      )
-    )
-    page = doc.get_pages()[0]
-    assert list(doc.get_lines_of_page(page)) == lines
-
-  def test_should_find_tokens_of_line(self):
-    tokens = [
-      SVG_TEXT(),
-      SVG_TEXT()
-    ]
-    line = SVG_TEXT_LINE(*tokens)
-    doc = SvgStructuredDocument(
-      E.svg(
-        line,
-        SVG_TEXT_LINE(SVG_TEXT())
-      )
-    )
-    assert list(doc.get_tokens_of_line(line)) == tokens
-
-  def test_should_tag_text_as_line_no(self):
-    text = SVG_TEXT()
-    doc = SvgStructuredDocument(
-      E.svg(
-        SVG_TEXT_LINE(text)
-      )
-    )
-    doc.set_tag(text, SvgStyleClasses.LINE_NO)
-    assert text.attrib['class'] == SvgStyleClasses.LINE_NO
-
-  def test_should_calculate_default_bounding_box(self):
-    text = SVG_TEXT('a', {
-      'x': '10',
-      'y': '11',
-      'font-size': '100'
-    })
-    doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text)))
-    assert doc.get_bounding_box(text) == BoundingBox(10, 11, 100 * 0.8, 100)
-
-  def test_should_estimate_width_based_on_number_of_characters(self):
-    s = 'abc'
-    text = SVG_TEXT(s, {
-      'x': '10',
-      'y': '11',
-      'font-size': '100'
-    })
-    doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text)))
-    assert doc.get_bounding_box(text) == BoundingBox(
-      10, 11, 100 * 0.8 * len(s), 100
-    )
-
-  def test_should_not_return_bounding_box_if_font_size_is_missing(self):
-    text = SVG_TEXT({
-      'x': '10',
-      'y': '11'
-    })
-    doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text)))
-    assert doc.get_bounding_box(text) is None
-
-  def test_should_use_bounding_box_if_available(self):
-    bounding_box = BoundingBox(11, 12, 101, 102)
-    text = SVG_TEXT('a', {
-      'x': '10',
-      'y': '11',
-      'font-size': '100',
-      SVGE_BOUNDING_BOX: format_bounding_box(bounding_box)
-    })
-    doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text)))
-    assert doc.get_bounding_box(text) == bounding_box
-
-  def test_should_be_able_to_set_bounding_box(self):
-    bounding_box = BoundingBox(11, 12, 101, 102)
-    text = SVG_TEXT('a', {
-      'x': '10',
-      'y': '11',
-      'font-size': '100'
-    })
-    doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text)))
-    doc.set_bounding_box(text, bounding_box)
-    assert text.attrib[SVGE_BOUNDING_BOX] == format_bounding_box(bounding_box)
-
-  def test_should_use_viewbox_if_available(self):
-    bounding_box = BoundingBox(11, 12, 101, 102)
-    page = E.svg({
-      SVG_VIEWBOX_ATTRIB: format_bounding_box(bounding_box)
-    })
-    doc = SvgStructuredDocument(page)
-    assert doc.get_bounding_box(page) == bounding_box
-
-  def test_should_set_tag_without_scope(self):
-    token = SVG_TEXT('test')
-    doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token)))
-    doc.set_tag(token, TAG_1)
-    assert doc.get_tag(token) == TAG_1
-
-  def test_should_set_tag_with_scope(self):
-    token = SVG_TEXT('test')
-    doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token)))
-    doc.set_tag(token, TAG_1, scope=SCOPE_1)
-    assert doc.get_tag(token, scope=SCOPE_1) == TAG_1
-    assert doc.get_tag(token) is None
-
-  def test_should_set_tag_with_level(self):
-    token = SVG_TEXT('test')
-    doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token)))
-    doc.set_tag(token, TAG_1, level=2)
-    assert doc.get_tag(token, level=2) == TAG_1
-    assert doc.get_tag(token) is None
-
-  def test_should_return_all_tag_by_scope(self):
-    token = SVG_TEXT('test')
-    doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token)))
-    doc.set_tag(token, TAG_1)
-    doc.set_tag(token, TAG_2, scope=SCOPE_1)
-    assert doc.get_tag(token) == TAG_1
-    assert doc.get_tag(token, scope=SCOPE_1) == TAG_2
-    assert doc.get_tag_by_scope(token) == {None: TAG_1, SCOPE_1: TAG_2}
-
-  def test_should_clear_tag_when_setting_tag_to_none(self):
-    token = SVG_TEXT('test')
-    doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token)))
-    doc.set_tag(token, TAG_1)
-    doc.set_tag(token, TAG_1, scope=SCOPE_1)
-    doc.set_tag(token, None)
-    doc.set_tag(token, None, scope=SCOPE_1)
-    assert doc.get_tag(token) is None
-    assert doc.get_tag(token, scope=SCOPE_1) is None
-
-  def test_should_not_fail_setting_empty_tag_to_none(self):
-    token = SVG_TEXT('test')
-    doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token)))
-    doc.set_tag(token, None)
-    doc.set_tag(token, None, scope=SCOPE_1)
-    assert doc.get_tag(token) is None
-    assert doc.get_tag(token, scope=SCOPE_1) is None
+        assert list(doc.get_tokens_of_line(line)) == tokens
+
+    def test_should_tag_text_as_line_no(self):
+        text = SVG_TEXT()
+        doc = SvgStructuredDocument(
+            E.svg(
+                SVG_TEXT_LINE(text)
+            )
+        )
+        doc.set_tag(text, SvgStyleClasses.LINE_NO)
+        assert text.attrib['class'] == SvgStyleClasses.LINE_NO
+
+    def test_should_calculate_default_bounding_box(self):
+        text = SVG_TEXT('a', {
+            'x': '10',
+            'y': '11',
+            'font-size': '100'
+        })
+        doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text)))
+        assert doc.get_bounding_box(text) == BoundingBox(10, 11, 100 * 0.8, 100)
+
+    def test_should_estimate_width_based_on_number_of_characters(self):
+        s = 'abc'
+        text = SVG_TEXT(s, {
+            'x': '10',
+            'y': '11',
+            'font-size': '100'
+        })
+        doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text)))
+        assert doc.get_bounding_box(text) == BoundingBox(
+            10, 11, 100 * 0.8 * len(s), 100
+        )
+
+    def test_should_not_return_bounding_box_if_font_size_is_missing(self):
+        text = SVG_TEXT({
+            'x': '10',
+            'y': '11'
+        })
+        doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text)))
+        assert doc.get_bounding_box(text) is None
+
+    def test_should_use_bounding_box_if_available(self):
+        bounding_box = BoundingBox(11, 12, 101, 102)
+        text = SVG_TEXT('a', {
+            'x': '10',
+            'y': '11',
+            'font-size': '100',
+            SVGE_BOUNDING_BOX: format_bounding_box(bounding_box)
+        })
+        doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text)))
+        assert doc.get_bounding_box(text) == bounding_box
+
+    def test_should_be_able_to_set_bounding_box(self):
+        bounding_box = BoundingBox(11, 12, 101, 102)
+        text = SVG_TEXT('a', {
+            'x': '10',
+            'y': '11',
+            'font-size': '100'
+        })
+        doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text)))
+        doc.set_bounding_box(text, bounding_box)
+        assert text.attrib[SVGE_BOUNDING_BOX] == format_bounding_box(bounding_box)
+
+    def test_should_use_viewbox_if_available(self):
+        bounding_box = BoundingBox(11, 12, 101, 102)
+        page = E.svg({
+            SVG_VIEWBOX_ATTRIB: format_bounding_box(bounding_box)
+        })
+        doc = SvgStructuredDocument(page)
+        assert doc.get_bounding_box(page) == bounding_box
+
+    def test_should_set_tag_without_scope(self):
+        token = SVG_TEXT('test')
+        doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token)))
+        doc.set_tag(token, TAG_1)
+        assert doc.get_tag(token) == TAG_1
+
+    def test_should_set_tag_with_scope(self):
+        token = SVG_TEXT('test')
+        doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token)))
+        doc.set_tag(token, TAG_1, scope=SCOPE_1)
+        assert doc.get_tag(token, scope=SCOPE_1) == TAG_1
+        assert doc.get_tag(token) is None
+
+    def test_should_set_tag_with_level(self):
+        token = SVG_TEXT('test')
+        doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token)))
+        doc.set_tag(token, TAG_1, level=2)
+        assert doc.get_tag(token, level=2) == TAG_1
+        assert doc.get_tag(token) is None
+
+    def test_should_return_all_tag_by_scope(self):
+        token = SVG_TEXT('test')
+        doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token)))
+        doc.set_tag(token, TAG_1)
+        doc.set_tag(token, TAG_2, scope=SCOPE_1)
+        assert doc.get_tag(token) == TAG_1
+        assert doc.get_tag(token, scope=SCOPE_1) == TAG_2
+        assert doc.get_tag_by_scope(token) == {None: TAG_1, SCOPE_1: TAG_2}
+
+    def test_should_clear_tag_when_setting_tag_to_none(self):
+        token = SVG_TEXT('test')
+        doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token)))
+        doc.set_tag(token, TAG_1)
+        doc.set_tag(token, TAG_1, scope=SCOPE_1)
+        doc.set_tag(token, None)
+        doc.set_tag(token, None, scope=SCOPE_1)
+        assert doc.get_tag(token) is None
+        assert doc.get_tag(token, scope=SCOPE_1) is None
+
+    def test_should_not_fail_setting_empty_tag_to_none(self):
+        token = SVG_TEXT('test')
+        doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token)))
+        doc.set_tag(token, None)
+        doc.set_tag(token, None, scope=SCOPE_1)
+        assert doc.get_tag(token) is None
+        assert doc.get_tag(token, scope=SCOPE_1) is None
diff --git a/sciencebeam_gym/tools/calculate_class_weights.py b/sciencebeam_gym/tools/calculate_class_weights.py
index b8446ca..fed0464 100644
--- a/sciencebeam_gym/tools/calculate_class_weights.py
+++ b/sciencebeam_gym/tools/calculate_class_weights.py
@@ -7,254 +7,276 @@ import json
 
 import numpy as np
 import tensorflow as tf
-from tensorflow.python.lib.io import file_io # pylint: disable=E0611
+from tensorflow.python.lib.io import file_io  # pylint: disable=E0611
 from tqdm import tqdm
 
 from sciencebeam_gym.preprocess.color_map import (
-  parse_color_map_from_file
+    parse_color_map_from_file
 )
 
 from sciencebeam_gym.utils.tfrecord import (
-  iter_read_tfrecord_file_as_dict_list
+    iter_read_tfrecord_file_as_dict_list
 )
 
 from sciencebeam_gym.model_utils.channels import (
-  color_equals_mask_as_float,
-  calculate_color_masks
+    color_equals_mask_as_float,
+    calculate_color_masks
 )
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def color_frequency(image, color):
-  return tf.reduce_sum(color_equals_mask_as_float(image, color))
+    return tf.reduce_sum(color_equals_mask_as_float(image, color))
+
 
 def get_shape(x):
-  try:
-    return x.shape
-  except AttributeError:
-    return tf.constant(x).shape
+    try:
+        return x.shape
+    except AttributeError:
+        return tf.constant(x).shape
+
 
 def calculate_sample_frequencies(image, colors, use_unknown_class=False):
-  color_masks = calculate_color_masks(image, colors, use_unknown_class)
-  return [
-    tf.reduce_sum(color_mask)
-    for color_mask in color_masks
-  ]
+    color_masks = calculate_color_masks(image, colors, use_unknown_class)
+    return [
+        tf.reduce_sum(color_mask)
+        for color_mask in color_masks
+    ]
+
 
 def iter_calculate_sample_frequencies(
-  images, colors, image_shape=None, image_format=None, use_unknown_class=False):
+        images, colors, image_shape=None, image_format=None, use_unknown_class=False):
+
+    with tf.Graph().as_default():
+        if image_format == 'png':
+            image_tensor = tf.placeholder(tf.string, shape=[], name='image')
+            decoded_image_tensor = tf.image.decode_png(image_tensor, channels=3)
+        else:
+            if image_shape is None:
+                image_shape = (None, None, 3)
+            image_tensor = tf.placeholder(tf.uint8, shape=image_shape, name='image')
+            decoded_image_tensor = image_tensor
+        get_logger().debug('decoded_image_tensor: %s', decoded_image_tensor)
+        frequency_tensors = calculate_sample_frequencies(
+            decoded_image_tensor, colors, use_unknown_class=use_unknown_class
+        )
+        with tf.Session() as session:
+            for image in images:
+                frequencies = session.run(frequency_tensors, {
+                    image_tensor: image
+                })
+                get_logger().debug('frequencies: %s', frequencies)
+                yield frequencies
 
-  with tf.Graph().as_default():
-    if image_format == 'png':
-      image_tensor = tf.placeholder(tf.string, shape=[], name='image')
-      decoded_image_tensor = tf.image.decode_png(image_tensor, channels=3)
-    else:
-      if image_shape is None:
-        image_shape = (None, None, 3)
-      image_tensor = tf.placeholder(tf.uint8, shape=image_shape, name='image')
-      decoded_image_tensor = image_tensor
-    get_logger().debug('decoded_image_tensor: %s', decoded_image_tensor)
-    frequency_tensors = calculate_sample_frequencies(
-      decoded_image_tensor, colors, use_unknown_class=use_unknown_class
-    )
-    with tf.Session() as session:
-      for image in images:
-        frequencies = session.run(frequency_tensors, {
-          image_tensor: image
-        })
-        get_logger().debug('frequencies: %s', frequencies)
-        yield frequencies
 
 def tf_calculate_efnet_weights_for_frequency_by_label(
-  frequency_by_label, return_zero_for_zero_frequency=True):
+        frequency_by_label, return_zero_for_zero_frequency=True):
+
+    total_frequency = tf.reduce_sum(frequency_by_label)
+    class_weights = 1.0 / tf.log(1.02 + frequency_by_label / total_frequency)
+    return (
+        class_weights * tf.cast(tf.minimum(frequency_by_label, 1), class_weights.dtype)
+        if return_zero_for_zero_frequency
+        else class_weights
+    )
 
-  total_frequency = tf.reduce_sum(frequency_by_label)
-  class_weights = 1.0 / tf.log(1.02 + frequency_by_label / total_frequency)
-  return (
-    class_weights * tf.cast(tf.minimum(frequency_by_label, 1), class_weights.dtype)
-    if return_zero_for_zero_frequency
-    else class_weights
-  )
 
 def calculate_efnet_weights_for_frequency_by_label(frequency_by_label):
-  total_frequency = sum(frequency_by_label)
-  return [
-    1 / np.log(1.02 + (frequency / total_frequency))
-    if frequency > 0 else 0.0
-    for frequency in frequency_by_label
-  ]
+    total_frequency = sum(frequency_by_label)
+    return [
+        1 / np.log(1.02 + (frequency / total_frequency))
+        if frequency > 0 else 0.0
+        for frequency in frequency_by_label
+    ]
+
 
 def sum_frequencies_by_label(frequencies_by_label):
-  return [sum(x) for x in frequencies_by_label]
+    return [sum(x) for x in frequencies_by_label]
+
 
 def calculate_efnet_weights_for_frequencies_by_label(frequencies_by_label):
-  return calculate_efnet_weights_for_frequency_by_label(
-    sum_frequencies_by_label(frequencies_by_label)
-  )
+    return calculate_efnet_weights_for_frequency_by_label(
+        sum_frequencies_by_label(frequencies_by_label)
+    )
+
 
 def calculate_median_class_weight(class_frequencies):
-  """
-  Perform median frequency balancing on the image files, given by the formula:
-    f = Median_freq_c / total_freq_c
-    where median_freq_c is the median frequency of the class
-    for all pixels of C that appeared in images
-    and total_freq_c is the total number of pixels of c in the total pixels
-    of the images where c appeared.
-  """
-
-  non_zero_frequencies = [f for f in class_frequencies if f != 0.0]
-  if not non_zero_frequencies:
-    return 0.0
-  get_logger().debug('non_zero_frequencies: %s', non_zero_frequencies)
-  total_freq_c = sum(non_zero_frequencies)
-  get_logger().debug('total_freq_c: %s', total_freq_c)
-  median_freq_c = np.median(non_zero_frequencies)
-  get_logger().debug('median_freq_c: %s', median_freq_c)
-  return median_freq_c / total_freq_c
+    """
+    Perform median frequency balancing on the image files, given by the formula:
+      f = Median_freq_c / total_freq_c
+      where median_freq_c is the median frequency of the class
+      for all pixels of C that appeared in images
+      and total_freq_c is the total number of pixels of c in the total pixels
+      of the images where c appeared.
+    """
+
+    non_zero_frequencies = [f for f in class_frequencies if f != 0.0]
+    if not non_zero_frequencies:
+        return 0.0
+    get_logger().debug('non_zero_frequencies: %s', non_zero_frequencies)
+    total_freq_c = sum(non_zero_frequencies)
+    get_logger().debug('total_freq_c: %s', total_freq_c)
+    median_freq_c = np.median(non_zero_frequencies)
+    get_logger().debug('median_freq_c: %s', median_freq_c)
+    return median_freq_c / total_freq_c
+
 
 def calculate_median_weights_for_frequencies(frequencies):
-  median_frequencies_balanced = [
-    calculate_median_class_weight(f)
-    for f in frequencies
-  ]
-  total = sum(median_frequencies_balanced)
-  return [
-    f / total
-    for f in median_frequencies_balanced
-  ]
+    median_frequencies_balanced = [
+        calculate_median_class_weight(f)
+        for f in frequencies
+    ]
+    total = sum(median_frequencies_balanced)
+    return [
+        f / total
+        for f in median_frequencies_balanced
+    ]
+
 
 def parse_color_map(color_map_filename):
-  with file_io.FileIO(color_map_filename, 'r') as config_f:
-    return parse_color_map_from_file(
-      config_f
-    )
+    with file_io.FileIO(color_map_filename, 'r') as config_f:
+        return parse_color_map_from_file(
+            config_f
+        )
+
 
 def transpose(m):
-  return zip(*m)
+    return zip(*m)
+
 
 def iter_images_for_tfrecord_paths(tfrecord_paths, image_key, progress=False):
-  for tfrecord_path in tfrecord_paths:
-    get_logger().info('tfrecord_path: %s', tfrecord_path)
-    filenames = file_io.get_matching_files(tfrecord_path)
-    with tqdm(list(filenames), leave=False, disable=not progress) as pbar:
-      for tfrecord_filename in pbar:
-        pbar.set_description('%-40s' % tfrecord_filename)
-        get_logger().debug('tfrecord_filename: %s', tfrecord_filename)
-        for d in iter_read_tfrecord_file_as_dict_list(tfrecord_filename, keys={image_key}):
-          yield d[image_key]
+    for tfrecord_path in tfrecord_paths:
+        get_logger().info('tfrecord_path: %s', tfrecord_path)
+        filenames = file_io.get_matching_files(tfrecord_path)
+        with tqdm(list(filenames), leave=False, disable=not progress) as pbar:
+            for tfrecord_filename in pbar:
+                pbar.set_description('%-40s' % tfrecord_filename)
+                get_logger().debug('tfrecord_filename: %s', tfrecord_filename)
+                for d in iter_read_tfrecord_file_as_dict_list(tfrecord_filename, keys={image_key}):
+                    yield d[image_key]
+
 
 def calculate_median_class_weights_for_tfrecord_paths_and_colors(
-  tfrecord_paths, image_key, colors, use_unknown_class=False, progress=False):
-
-  get_logger().debug('colors: %s', colors)
-  get_logger().info('loading tfrecords: %s', tfrecord_paths)
-  images = iter_images_for_tfrecord_paths(tfrecord_paths, image_key, progress=progress)
-  if progress:
-    images = list(images)
-    images = tqdm(images, 'analysing images', leave=False)
-  frequency_list = list(iter_calculate_sample_frequencies(
-    images, colors, image_format='png', use_unknown_class=use_unknown_class
-  ))
-  get_logger().debug('frequency_list: %s', frequency_list)
-  frequencies = transpose(frequency_list)
-  get_logger().debug('frequencies: %s', frequencies)
-  class_weights = calculate_median_weights_for_frequencies(frequencies)
-  return class_weights
+        tfrecord_paths, image_key, colors, use_unknown_class=False, progress=False):
+
+    get_logger().debug('colors: %s', colors)
+    get_logger().info('loading tfrecords: %s', tfrecord_paths)
+    images = iter_images_for_tfrecord_paths(tfrecord_paths, image_key, progress=progress)
+    if progress:
+        images = list(images)
+        images = tqdm(images, 'analysing images', leave=False)
+    frequency_list = list(iter_calculate_sample_frequencies(
+        images, colors, image_format='png', use_unknown_class=use_unknown_class
+    ))
+    get_logger().debug('frequency_list: %s', frequency_list)
+    frequencies = transpose(frequency_list)
+    get_logger().debug('frequencies: %s', frequencies)
+    class_weights = calculate_median_weights_for_frequencies(frequencies)
+    return class_weights
+
 
 def calculate_median_class_weights_for_tfrecord_paths_and_color_map(
-  tfrecord_paths, image_key, color_map, channels=None,
-  use_unknown_class=False, unknown_class_label='unknown',
-  progress=False):
-  if not channels:
-    channels = sorted(color_map.keys())
-  colors = [color_map[k] for k in channels]
-  class_weights = calculate_median_class_weights_for_tfrecord_paths_and_colors(
-    tfrecord_paths,
-    image_key,
-    colors,
-    progress=progress,
-    use_unknown_class=use_unknown_class
-  )
-  if use_unknown_class:
-    channels += [unknown_class_label]
-  return {
-    k: class_weight for k, class_weight in zip(channels, class_weights)
-  }
+        tfrecord_paths, image_key, color_map, channels=None,
+        use_unknown_class=False, unknown_class_label='unknown',
+        progress=False):
+    if not channels:
+        channels = sorted(color_map.keys())
+    colors = [color_map[k] for k in channels]
+    class_weights = calculate_median_class_weights_for_tfrecord_paths_and_colors(
+        tfrecord_paths,
+        image_key,
+        colors,
+        progress=progress,
+        use_unknown_class=use_unknown_class
+    )
+    if use_unknown_class:
+        channels += [unknown_class_label]
+    return {
+        k: class_weight for k, class_weight in zip(channels, class_weights)
+    }
+
 
 def str_to_bool(s):
-  return s.lower() in ('yes', 'true', '1')
+    return s.lower() in ('yes', 'true', '1')
+
 
 def str_to_list(s):
-  s = s.strip()
-  if not s:
-    return []
-  return [x.strip() for x in s.split(',')]
+    s = s.strip()
+    if not s:
+        return []
+    return [x.strip() for x in s.split(',')]
+
 
 def get_args_parser():
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-    '--tfrecord-paths',
-    required=True,
-    type=str,
-    action='append',
-    help='The paths to the tf-records files to analyse.'
-  )
-  parser.add_argument(
-    '--image-key',
-    required=False,
-    type=str,
-    help='The name of the image key to do the class weights on.'
-  )
-  parser.add_argument(
-    '--color-map',
-    required=True,
-    type=str,
-    help='The color-map filename.'
-  )
-  parser.add_argument(
-    '--channels',
-    type=str_to_list,
-    help='The channels to use (subset of color map), otherwise all of the labels will be used'
-  )
-  parser.add_argument(
-    '--use-unknown-class',
-    type=str_to_bool,
-    default=True,
-    help='Use unknown class channel'
-  )
-  parser.add_argument(
-    '--out',
-    required=False,
-    type=str,
-    help='The filename the output file (json), otherwise the output will be written to stdout.'
-  )
-  return parser
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--tfrecord-paths',
+        required=True,
+        type=str,
+        action='append',
+        help='The paths to the tf-records files to analyse.'
+    )
+    parser.add_argument(
+        '--image-key',
+        required=False,
+        type=str,
+        help='The name of the image key to do the class weights on.'
+    )
+    parser.add_argument(
+        '--color-map',
+        required=True,
+        type=str,
+        help='The color-map filename.'
+    )
+    parser.add_argument(
+        '--channels',
+        type=str_to_list,
+        help='The channels to use (subset of color map), otherwise all of the labels will be used'
+    )
+    parser.add_argument(
+        '--use-unknown-class',
+        type=str_to_bool,
+        default=True,
+        help='Use unknown class channel'
+    )
+    parser.add_argument(
+        '--out',
+        required=False,
+        type=str,
+        help='The filename the output file (json), otherwise the output will be written to stdout.'
+    )
+    return parser
+
 
 def parse_args(argv=None):
-  parser = get_args_parser()
-  parsed_args = parser.parse_args(argv)
-  return parsed_args
+    parser = get_args_parser()
+    parsed_args = parser.parse_args(argv)
+    return parsed_args
+
 
 def main(argv=None):
-  args = parse_args(argv)
-  color_map = parse_color_map(args.color_map)
-  class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map(
-    args.tfrecord_paths,
-    args.image_key,
-    color_map,
-    channels=args.channels,
-    use_unknown_class=args.use_unknown_class,
-    progress=True
-  )
-  get_logger().info('class_weights: %s', class_weights_map)
-  json_str = json.dumps(class_weights_map, indent=2)
-  if args.out:
-    with file_io.FileIO(args.out, 'wb') as out_f:
-      out_f.write(json_str)
-  else:
-    print(json_str)
+    args = parse_args(argv)
+    color_map = parse_color_map(args.color_map)
+    class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map(
+        args.tfrecord_paths,
+        args.image_key,
+        color_map,
+        channels=args.channels,
+        use_unknown_class=args.use_unknown_class,
+        progress=True
+    )
+    get_logger().info('class_weights: %s', class_weights_map)
+    json_str = json.dumps(class_weights_map, indent=2)
+    if args.out:
+        with file_io.FileIO(args.out, 'wb') as out_f:
+            out_f.write(json_str)
+    else:
+        print(json_str)
+
 
 if __name__ == '__main__':
-  logging.basicConfig(level='INFO')
-  main()
+    logging.basicConfig(level='INFO')
+    main()
diff --git a/sciencebeam_gym/tools/calculate_class_weights_test.py b/sciencebeam_gym/tools/calculate_class_weights_test.py
index 5dd9935..6fcd14b 100644
--- a/sciencebeam_gym/tools/calculate_class_weights_test.py
+++ b/sciencebeam_gym/tools/calculate_class_weights_test.py
@@ -6,346 +6,359 @@ from io import BytesIO
 
 from backports.tempfile import TemporaryDirectory
 
+import tensorflow as tf
+import numpy as np
+from PIL import Image
+
 from sciencebeam_utils.utils.num import (
-  assert_close,
-  assert_all_close
+    assert_close,
+    assert_all_close
 )
 
 from sciencebeam_gym.utils.tfrecord import (
-  dict_to_example,
-  write_examples_to_tfrecord
+    dict_to_example,
+    write_examples_to_tfrecord
 )
 
 from sciencebeam_gym.tools.calculate_class_weights import (
-  calculate_sample_frequencies,
-  iter_calculate_sample_frequencies,
-  calculate_median_class_weight,
-  calculate_median_weights_for_frequencies,
-  calculate_median_class_weights_for_tfrecord_paths_and_colors,
-  calculate_median_class_weights_for_tfrecord_paths_and_color_map,
-  calculate_efnet_weights_for_frequencies_by_label,
-  tf_calculate_efnet_weights_for_frequency_by_label
+    calculate_sample_frequencies,
+    iter_calculate_sample_frequencies,
+    calculate_median_class_weight,
+    calculate_median_weights_for_frequencies,
+    calculate_median_class_weights_for_tfrecord_paths_and_colors,
+    calculate_median_class_weights_for_tfrecord_paths_and_color_map,
+    calculate_efnet_weights_for_frequencies_by_label,
+    tf_calculate_efnet_weights_for_frequency_by_label
 )
 
-import tensorflow as tf
-import numpy as np
-from PIL import Image
 
 def color(i):
-  return (i, i, i)
+    return (i, i, i)
+
 
 COLOR_0 = color(0)
 COLOR_1 = color(1)
 COLOR_2 = color(2)
 COLOR_3 = color(3)
 
+
 def setup_module():
-  logging.basicConfig(level='DEBUG')
+    logging.basicConfig(level='DEBUG')
+
 
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 class TestCalculateSampleFrequencies(object):
-  def test_should_return_zero_for_single_not_matching_color(self):
-    with tf.Session() as session:
-      assert session.run(calculate_sample_frequencies([[
-        COLOR_0
-      ]], [COLOR_1])) == [0.0]
-
-  def test_should_return_one_for_single_matching_color(self):
-    with tf.Session() as session:
-      assert session.run(calculate_sample_frequencies([[
-        COLOR_1
-      ]], [COLOR_1])) == [1.0]
-
-  def test_should_return_total_count_for_multiple_all_matching_color(self):
-    with tf.Session() as session:
-      assert session.run(calculate_sample_frequencies([[
-        COLOR_1, COLOR_1, COLOR_1
-      ]], [COLOR_1])) == [3.0]
-
-  def test_should_return_total_count_for_multiple_mixed_color(self):
-    with tf.Session() as session:
-      assert session.run(calculate_sample_frequencies([[
-        COLOR_1, COLOR_1, COLOR_2
-      ]], [COLOR_1, COLOR_2])) == [2.0, 1.0]
-
-  def test_should_include_unknown_class_count_if_enabled(self):
-    with tf.Session() as session:
-      assert session.run(calculate_sample_frequencies([[
-        COLOR_1, COLOR_2, COLOR_3
-      ]], [COLOR_1], use_unknown_class=True)) == [1.0, 2.0]
+    def test_should_return_zero_for_single_not_matching_color(self):
+        with tf.Session() as session:
+            assert session.run(calculate_sample_frequencies([[
+                COLOR_0
+            ]], [COLOR_1])) == [0.0]
+
+    def test_should_return_one_for_single_matching_color(self):
+        with tf.Session() as session:
+            assert session.run(calculate_sample_frequencies([[
+                COLOR_1
+            ]], [COLOR_1])) == [1.0]
+
+    def test_should_return_total_count_for_multiple_all_matching_color(self):
+        with tf.Session() as session:
+            assert session.run(calculate_sample_frequencies([[
+                COLOR_1, COLOR_1, COLOR_1
+            ]], [COLOR_1])) == [3.0]
+
+    def test_should_return_total_count_for_multiple_mixed_color(self):
+        with tf.Session() as session:
+            assert session.run(calculate_sample_frequencies([[
+                COLOR_1, COLOR_1, COLOR_2
+            ]], [COLOR_1, COLOR_2])) == [2.0, 1.0]
+
+    def test_should_include_unknown_class_count_if_enabled(self):
+        with tf.Session() as session:
+            assert session.run(calculate_sample_frequencies([[
+                COLOR_1, COLOR_2, COLOR_3
+            ]], [COLOR_1], use_unknown_class=True)) == [1.0, 2.0]
+
 
 def encode_png(data):
-  out = BytesIO()
-  data = np.array(data, dtype=np.uint8)
-  image_size = data.shape[:-1]
-  get_logger().debug('data type: %s', data.dtype)
-  get_logger().debug('image_size: %s', image_size)
-  mode = 'RGB'
-  image = Image.fromarray(data, mode)
-  image.save(out, 'png')
-  image_bytes = out.getvalue()
-  return image_bytes
+    out = BytesIO()
+    data = np.array(data, dtype=np.uint8)
+    image_size = data.shape[:-1]
+    get_logger().debug('data type: %s', data.dtype)
+    get_logger().debug('image_size: %s', image_size)
+    mode = 'RGB'
+    image = Image.fromarray(data, mode)
+    image.save(out, 'png')
+    image_bytes = out.getvalue()
+    return image_bytes
+
 
 class TestIterCalculateSampleFrequencies(object):
-  def test_should_return_zero_for_single_not_matching_color(self):
-    assert list(iter_calculate_sample_frequencies([
-      [[
-        COLOR_0
-      ]]
-    ], [COLOR_1], image_shape=(1, 1, 3))) == [[0.0]]
-
-  def test_should_infer_image_shape(self):
-    assert list(iter_calculate_sample_frequencies([
-      [[
-        COLOR_0
-      ]]
-    ], [COLOR_1])) == [[0.0]]
-
-  def test_should_include_unknown_class_if_enabled(self):
-    assert list(iter_calculate_sample_frequencies([
-      [[
-        COLOR_0
-      ]]
-    ], [COLOR_1], image_shape=(1, 1, 3), use_unknown_class=True)) == [[0.0, 1.0]]
-
-  def test_should_include_unknown_class_if_enabled_and_infer_shape(self):
-    assert list(iter_calculate_sample_frequencies([
-      [[
-        COLOR_0
-      ]]
-    ], [COLOR_1], use_unknown_class=True)) == [[0.0, 1.0]]
-
-  def test_should_return_total_count_for_multiple_mixed_color(self):
-    assert list(iter_calculate_sample_frequencies([
-      [[
-        COLOR_0, COLOR_0, COLOR_0
-      ]], [[
-        COLOR_0, COLOR_1, COLOR_2
-      ]], [[
-        COLOR_1, COLOR_1, COLOR_2
-      ]]
-    ], [COLOR_1, COLOR_2])) == [
-      [0.0, 0.0],
-      [1.0, 1.0],
-      [2.0, 1.0]
-    ]
-
-  def test_should_decode_png(self):
-    assert list(iter_calculate_sample_frequencies([
-      encode_png([[
-        COLOR_1
-      ]])
-    ], [COLOR_1], image_shape=(1, 1, 3), image_format='png')) == [[1.0]]
-
-  def test_should_infer_shape_when_decoding_png(self):
-    assert list(iter_calculate_sample_frequencies([
-      encode_png([[
-        COLOR_1
-      ]])
-    ], [COLOR_1], image_format='png')) == [[1.0]]
-
-  def test_should_infer_shape_when_decoding_png_and_include_unknown_class(self):
-    assert list(iter_calculate_sample_frequencies([
-      encode_png([[
-        COLOR_1, COLOR_2, COLOR_3
-      ]])
-    ], [COLOR_1], image_format='png', use_unknown_class=True)) == [[1.0, 2.0]]
+    def test_should_return_zero_for_single_not_matching_color(self):
+        assert list(iter_calculate_sample_frequencies([
+            [[
+                COLOR_0
+            ]]
+        ], [COLOR_1], image_shape=(1, 1, 3))) == [[0.0]]
+
+    def test_should_infer_image_shape(self):
+        assert list(iter_calculate_sample_frequencies([
+            [[
+                COLOR_0
+            ]]
+        ], [COLOR_1])) == [[0.0]]
+
+    def test_should_include_unknown_class_if_enabled(self):
+        assert list(iter_calculate_sample_frequencies([
+            [[
+                COLOR_0
+            ]]
+        ], [COLOR_1], image_shape=(1, 1, 3), use_unknown_class=True)) == [[0.0, 1.0]]
+
+    def test_should_include_unknown_class_if_enabled_and_infer_shape(self):
+        assert list(iter_calculate_sample_frequencies([
+            [[
+                COLOR_0
+            ]]
+        ], [COLOR_1], use_unknown_class=True)) == [[0.0, 1.0]]
+
+    def test_should_return_total_count_for_multiple_mixed_color(self):
+        assert list(iter_calculate_sample_frequencies([
+            [[
+                COLOR_0, COLOR_0, COLOR_0
+            ]], [[
+                COLOR_0, COLOR_1, COLOR_2
+            ]], [[
+                COLOR_1, COLOR_1, COLOR_2
+            ]]
+        ], [COLOR_1, COLOR_2])) == [
+            [0.0, 0.0],
+            [1.0, 1.0],
+            [2.0, 1.0]
+        ]
+
+    def test_should_decode_png(self):
+        assert list(iter_calculate_sample_frequencies([
+            encode_png([[
+                COLOR_1
+            ]])
+        ], [COLOR_1], image_shape=(1, 1, 3), image_format='png')) == [[1.0]]
+
+    def test_should_infer_shape_when_decoding_png(self):
+        assert list(iter_calculate_sample_frequencies([
+            encode_png([[
+                COLOR_1
+            ]])
+        ], [COLOR_1], image_format='png')) == [[1.0]]
+
+    def test_should_infer_shape_when_decoding_png_and_include_unknown_class(self):
+        assert list(iter_calculate_sample_frequencies([
+            encode_png([[
+                COLOR_1, COLOR_2, COLOR_3
+            ]])
+        ], [COLOR_1], image_format='png', use_unknown_class=True)) == [[1.0, 2.0]]
+
 
 class TestTfCalculateEfnetForFrequencyByLabel(object):
-  def test_should_return_same_value_for_classes_with_same_frequencies(self):
-    with tf.Graph().as_default():
-      with tf.Session():
-        frequencies = [1, 1]
-        result = tf_calculate_efnet_weights_for_frequency_by_label(frequencies).eval()
+    def test_should_return_same_value_for_classes_with_same_frequencies(self):
+        with tf.Graph().as_default():
+            with tf.Session():
+                frequencies = [1, 1]
+                result = tf_calculate_efnet_weights_for_frequency_by_label(frequencies).eval()
+                assert result[0] == result[1]
+
+    def test_should_return_higher_value_for_less_frequent_occuring_class(self):
+        with tf.Graph().as_default():
+            with tf.Session():
+                frequencies = [2, 1]
+                result = tf_calculate_efnet_weights_for_frequency_by_label(frequencies).eval()
+                assert result[0] < result[1]
+
+    def test_should_return_zero_value_for_not_occuring_class(self):
+        with tf.Graph().as_default():
+            with tf.Session():
+                frequencies = [1, 0]
+                result = tf_calculate_efnet_weights_for_frequency_by_label(frequencies).eval()
+                assert result[-1] == 0.0
+
+
+class TestCalculateEfnetForFrequenciesByLabel(object):
+    def test_should_return_same_value_for_classes_with_same_frequencies(self):
+        frequencies = [
+            [0, 1],
+            [0, 1]
+        ]
+        result = calculate_efnet_weights_for_frequencies_by_label(frequencies)
         assert result[0] == result[1]
 
-  def test_should_return_higher_value_for_less_frequent_occuring_class(self):
-    with tf.Graph().as_default():
-      with tf.Session():
-        frequencies = [2, 1]
-        result = tf_calculate_efnet_weights_for_frequency_by_label(frequencies).eval()
+    def test_should_return_higher_value_for_less_frequent_occuring_class(self):
+        frequencies = [
+            [1, 1],
+            [0, 1]
+        ]
+        result = calculate_efnet_weights_for_frequencies_by_label(frequencies)
         assert result[0] < result[1]
 
-  def test_should_return_zero_value_for_not_occuring_class(self):
-    with tf.Graph().as_default():
-      with tf.Session():
-        frequencies = [1, 0]
-        result = tf_calculate_efnet_weights_for_frequency_by_label(frequencies).eval()
+    def test_should_return_zero_value_for_not_occuring_class(self):
+        frequencies = [
+            [1, 1],
+            [0, 0]
+        ]
+        result = calculate_efnet_weights_for_frequencies_by_label(frequencies)
         assert result[-1] == 0.0
 
-class TestCalculateEfnetForFrequenciesByLabel(object):
-  def test_should_return_same_value_for_classes_with_same_frequencies(self):
-    frequencies = [
-      [0, 1],
-      [0, 1]
-    ]
-    result = calculate_efnet_weights_for_frequencies_by_label(frequencies)
-    assert result[0] == result[1]
-
-  def test_should_return_higher_value_for_less_frequent_occuring_class(self):
-    frequencies = [
-      [1, 1],
-      [0, 1]
-    ]
-    result = calculate_efnet_weights_for_frequencies_by_label(frequencies)
-    assert result[0] < result[1]
-
-  def test_should_return_zero_value_for_not_occuring_class(self):
-    frequencies = [
-      [1, 1],
-      [0, 0]
-    ]
-    result = calculate_efnet_weights_for_frequencies_by_label(frequencies)
-    assert result[-1] == 0.0
 
 class TestCalculateMedianClassWeight(object):
-  def test_should_return_median_frequency_balanced_for_same_frequencies(self):
-    assert calculate_median_class_weight([3, 3, 3]) == 1 / 3
+    def test_should_return_median_frequency_balanced_for_same_frequencies(self):
+        assert calculate_median_class_weight([3, 3, 3]) == 1 / 3
+
+    def test_should_return_median_frequence_balanced_for_different_frequencies(self):
+        assert calculate_median_class_weight([1, 3, 5]) == 1 / 3
 
-  def test_should_return_median_frequence_balanced_for_different_frequencies(self):
-    assert calculate_median_class_weight([1, 3, 5]) == 1 / 3
+    def test_should_return_zero_for_all_zero_frequencies(self):
+        assert calculate_median_class_weight([0, 0, 0]) == 0.0
 
-  def test_should_return_zero_for_all_zero_frequencies(self):
-    assert calculate_median_class_weight([0, 0, 0]) == 0.0
 
 class TestCalculateWeightsForFrequencies(object):
-  def test_should_return_one_for_single_class(self):
-    assert calculate_median_weights_for_frequencies([
-      [3, 3, 3]
-    ]) == [1.0]
-
-  def test_should_return_50p_for_classes_with_same_frequencies(self):
-    assert calculate_median_weights_for_frequencies([
-      [3, 3, 3],
-      [3, 3, 3]
-    ]) == [0.5, 0.5]
-
-  def test_should_return_higher_value_for_less_frequent_occuring_class(self):
-    frequencies = [
-      [1, 1],
-      [1, 1],
-      [0, 1]
-    ]
-    result = calculate_median_weights_for_frequencies(frequencies)
-    get_logger().debug('result: %s', result)
-    assert_close(sum(result), 1.0)
-    assert_all_close(result, [0.25, 0.25, 0.5], atol=0.001)
-
-  def test_should_return_zero_value_for_not_occuring_class(self):
-    frequencies = [
-      [1, 1],
-      [1, 1],
-      [0, 0]
-    ]
-    result = calculate_median_weights_for_frequencies(frequencies)
-    get_logger().debug('result: %s', result)
-    assert_close(sum(result), 1.0)
-    assert_all_close(result, [0.5, 0.5, 0.0], atol=0.001)
+    def test_should_return_one_for_single_class(self):
+        assert calculate_median_weights_for_frequencies([
+            [3, 3, 3]
+        ]) == [1.0]
+
+    def test_should_return_50p_for_classes_with_same_frequencies(self):
+        assert calculate_median_weights_for_frequencies([
+            [3, 3, 3],
+            [3, 3, 3]
+        ]) == [0.5, 0.5]
+
+    def test_should_return_higher_value_for_less_frequent_occuring_class(self):
+        frequencies = [
+            [1, 1],
+            [1, 1],
+            [0, 1]
+        ]
+        result = calculate_median_weights_for_frequencies(frequencies)
+        get_logger().debug('result: %s', result)
+        assert_close(sum(result), 1.0)
+        assert_all_close(result, [0.25, 0.25, 0.5], atol=0.001)
+
+    def test_should_return_zero_value_for_not_occuring_class(self):
+        frequencies = [
+            [1, 1],
+            [1, 1],
+            [0, 0]
+        ]
+        result = calculate_median_weights_for_frequencies(frequencies)
+        get_logger().debug('result: %s', result)
+        assert_close(sum(result), 1.0)
+        assert_all_close(result, [0.5, 0.5, 0.0], atol=0.001)
+
 
 class TestCalculateMedianClassWeightsForFfrecordPathsAndColors(object):
-  def test_should_calculate_median_class_weights_for_single_image_and_single_color(self):
-    with TemporaryDirectory() as path:
-      tfrecord_filename = os.path.join(path, 'data.tfrecord')
-      get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename)
-      write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({
-        'image': encode_png([[
-          COLOR_1
-        ]])
-      })])
-      class_weights = calculate_median_class_weights_for_tfrecord_paths_and_colors(
-        [tfrecord_filename], 'image', [COLOR_1]
-      )
-      assert class_weights == [1.0]
-
-  def test_should_calculate_median_class_weights_for_multiple_image_and_multiple_images(self):
-    with TemporaryDirectory() as path:
-      tfrecord_filename = os.path.join(path, 'data.tfrecord')
-      get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename)
-      write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({
-        'image': encode_png([[
-          COLOR_0, COLOR_1, COLOR_2
-        ]])
-      }), dict_to_example({
-        'image': encode_png([[
-          COLOR_1, COLOR_2, COLOR_3
-        ]])
-      })])
-      class_weights = calculate_median_class_weights_for_tfrecord_paths_and_colors(
-        [tfrecord_filename], 'image', [COLOR_1, COLOR_2, COLOR_3]
-      )
-      assert class_weights == [0.25, 0.25, 0.5]
-
-  def test_should_return_zero_for_non_occuring_class(self):
-    with TemporaryDirectory() as path:
-      tfrecord_filename = os.path.join(path, 'data.tfrecord')
-      get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename)
-      write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({
-        'image': encode_png([[
-          COLOR_1
-        ]])
-      })])
-      class_weights = calculate_median_class_weights_for_tfrecord_paths_and_colors(
-        [tfrecord_filename], 'image', [COLOR_1, COLOR_2]
-      )
-      assert class_weights == [1.0, 0.0]
+    def test_should_calculate_median_class_weights_for_single_image_and_single_color(self):
+        with TemporaryDirectory() as path:
+            tfrecord_filename = os.path.join(path, 'data.tfrecord')
+            get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename)
+            write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({
+                'image': encode_png([[
+                    COLOR_1
+                ]])
+            })])
+            class_weights = calculate_median_class_weights_for_tfrecord_paths_and_colors(
+                [tfrecord_filename], 'image', [COLOR_1]
+            )
+            assert class_weights == [1.0]
+
+    def test_should_calculate_median_class_weights_for_multiple_image_and_multiple_images(self):
+        with TemporaryDirectory() as path:
+            tfrecord_filename = os.path.join(path, 'data.tfrecord')
+            get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename)
+            write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({
+                'image': encode_png([[
+                    COLOR_0, COLOR_1, COLOR_2
+                ]])
+            }), dict_to_example({
+                'image': encode_png([[
+                    COLOR_1, COLOR_2, COLOR_3
+                ]])
+            })])
+            class_weights = calculate_median_class_weights_for_tfrecord_paths_and_colors(
+                [tfrecord_filename], 'image', [COLOR_1, COLOR_2, COLOR_3]
+            )
+            assert class_weights == [0.25, 0.25, 0.5]
+
+    def test_should_return_zero_for_non_occuring_class(self):
+        with TemporaryDirectory() as path:
+            tfrecord_filename = os.path.join(path, 'data.tfrecord')
+            get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename)
+            write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({
+                'image': encode_png([[
+                    COLOR_1
+                ]])
+            })])
+            class_weights = calculate_median_class_weights_for_tfrecord_paths_and_colors(
+                [tfrecord_filename], 'image', [COLOR_1, COLOR_2]
+            )
+            assert class_weights == [1.0, 0.0]
+
 
 class TestCalculateMedianClassWeightsForFfrecordPathsAndColorMap(object):
-  def test_should_calculate_median_class_weights_for_single_image_and_single_color(self):
-    with TemporaryDirectory() as path:
-      tfrecord_filename = os.path.join(path, 'data.tfrecord')
-      get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename)
-      write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({
-        'image': encode_png([[
-          COLOR_1, COLOR_2
-        ]])
-      })])
-      class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map(
-        [tfrecord_filename], 'image', {
-          'color1': COLOR_1,
-          'color2': COLOR_2,
-          'color3': COLOR_3
-        },
-        channels=['color1', 'color2']
-      )
-      assert class_weights_map == {
-        'color1': 0.5,
-        'color2': 0.5
-      }
-
-  def test_should_use_color_map_keys_as_channels_by_default(self):
-    with TemporaryDirectory() as path:
-      tfrecord_filename = os.path.join(path, 'data.tfrecord')
-      get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename)
-      write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({
-        'image': encode_png([[
-          COLOR_1, COLOR_2
-        ]])
-      })])
-      class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map(
-        [tfrecord_filename], 'image', {
-          'color1': COLOR_1,
-          'color2': COLOR_2
-        }
-      )
-      assert set(class_weights_map.keys()) == {'color1', 'color2'}
-
-  def test_should_include_unknown_class_if_enabled(self):
-    with TemporaryDirectory() as path:
-      tfrecord_filename = os.path.join(path, 'data.tfrecord')
-      get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename)
-      write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({
-        'image': encode_png([[
-          COLOR_0, COLOR_1, COLOR_2, COLOR_3
-        ]])
-      })])
-      class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map(
-        [tfrecord_filename], 'image', {
-          'color1': COLOR_1,
-          'color2': COLOR_2
-        },
-        use_unknown_class=True,
-        unknown_class_label='unknown'
-      )
-      assert set(class_weights_map.keys()) == {'color1', 'color2', 'unknown'}
+    def test_should_calculate_median_class_weights_for_single_image_and_single_color(self):
+        with TemporaryDirectory() as path:
+            tfrecord_filename = os.path.join(path, 'data.tfrecord')
+            get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename)
+            write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({
+                'image': encode_png([[
+                    COLOR_1, COLOR_2
+                ]])
+            })])
+            class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map(
+                [tfrecord_filename], 'image', {
+                    'color1': COLOR_1,
+                    'color2': COLOR_2,
+                    'color3': COLOR_3
+                },
+                channels=['color1', 'color2']
+            )
+            assert class_weights_map == {
+                'color1': 0.5,
+                'color2': 0.5
+            }
+
+    def test_should_use_color_map_keys_as_channels_by_default(self):
+        with TemporaryDirectory() as path:
+            tfrecord_filename = os.path.join(path, 'data.tfrecord')
+            get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename)
+            write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({
+                'image': encode_png([[
+                    COLOR_1, COLOR_2
+                ]])
+            })])
+            class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map(
+                [tfrecord_filename], 'image', {
+                    'color1': COLOR_1,
+                    'color2': COLOR_2
+                }
+            )
+            assert set(class_weights_map.keys()) == {'color1', 'color2'}
+
+    def test_should_include_unknown_class_if_enabled(self):
+        with TemporaryDirectory() as path:
+            tfrecord_filename = os.path.join(path, 'data.tfrecord')
+            get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename)
+            write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({
+                'image': encode_png([[
+                    COLOR_0, COLOR_1, COLOR_2, COLOR_3
+                ]])
+            })])
+            class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map(
+                [tfrecord_filename], 'image', {
+                    'color1': COLOR_1,
+                    'color2': COLOR_2
+                },
+                use_unknown_class=True,
+                unknown_class_label='unknown'
+            )
+            assert set(class_weights_map.keys()) == {'color1', 'color2', 'unknown'}
diff --git a/sciencebeam_gym/tools/colorize_image.py b/sciencebeam_gym/tools/colorize_image.py
index 438d7fa..8306996 100644
--- a/sciencebeam_gym/tools/colorize_image.py
+++ b/sciencebeam_gym/tools/colorize_image.py
@@ -9,93 +9,101 @@ from six.moves.configparser import ConfigParser
 
 from sciencebeam_gym.utils.tf import FileIO
 
+
 def get_args_parser():
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-    '--color_map',
-    default='color_map.conf',
-    type=str,
-    help='The path to the color map configuration.'
-  )
-  parser.add_argument(
-    '--input_image',
-    required=True,
-    type=str,
-    help='The path to the input image.'
-  )
-  parser.add_argument(
-    '--output_image',
-    required=False,
-    type=str,
-    help='The path to the output image.'
-  )
-  return parser
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--color_map',
+        default='color_map.conf',
+        type=str,
+        help='The path to the color map configuration.'
+    )
+    parser.add_argument(
+        '--input_image',
+        required=True,
+        type=str,
+        help='The path to the input image.'
+    )
+    parser.add_argument(
+        '--output_image',
+        required=False,
+        type=str,
+        help='The path to the output image.'
+    )
+    return parser
+
 
 def parse_args(argv=None):
-  parser = get_args_parser()
-  parsed_args, _ = parser.parse_known_args(argv)
-  return parsed_args
+    parser = get_args_parser()
+    parsed_args, _ = parser.parse_known_args(argv)
+    return parsed_args
+
 
 def parse_color_map_from_configparser(color_map_config):
-  num_pattern = re.compile(r'(\d+)')
-  rgb_pattern = re.compile(r'\((\d+),(\d+),(\d+)\)')
-
-  def parse_color(s):
-    m = num_pattern.match(s)
-    if m:
-      x = int(m.group(1))
-      return (x, x, x)
-    else:
-      m = rgb_pattern.match(s)
-      if m:
-        return (int(m.group(1)), int(m.group(2)), int(m.group(3)))
-    raise Exception('invalid color value: {}'.format(s))
-
-  color_map = dict()
-  for k, v in color_map_config.items('color_map'):
-    color_map[parse_color(k)] = parse_color(v)
-  return color_map
+    num_pattern = re.compile(r'(\d+)')
+    rgb_pattern = re.compile(r'\((\d+),(\d+),(\d+)\)')
+
+    def parse_color(s):
+        m = num_pattern.match(s)
+        if m:
+            x = int(m.group(1))
+            return (x, x, x)
+        else:
+            m = rgb_pattern.match(s)
+            if m:
+                return (int(m.group(1)), int(m.group(2)), int(m.group(3)))
+        raise ValueError('invalid color value: {}'.format(s))
+
+    color_map = {}
+    for k, v in color_map_config.items('color_map'):
+        color_map[parse_color(k)] = parse_color(v)
+    return color_map
+
 
 def parse_color_map_from_file(f):
-  color_map_config = ConfigParser()
-  color_map_config.readfp(f)
-  return parse_color_map_from_configparser(color_map_config)
+    color_map_config = ConfigParser()
+    color_map_config.readfp(f)
+    return parse_color_map_from_configparser(color_map_config)
+
 
 def parse_color_map(f):
-  return parse_color_map_from_file(f)
+    return parse_color_map_from_file(f)
+
 
 def map_colors(img, color_map):
-  if color_map is None or len(color_map) == 0:
+    if not color_map:
+        return img
+    original_data = img.getdata()
+    mapped_data = [
+        color_map.get(color, color)
+        for color in original_data
+    ]
+    img.putdata(mapped_data)
     return img
-  original_data = img.getdata()
-  mapped_data = [
-    color_map.get(color, color)
-    for color in original_data
-  ]
-  img.putdata(mapped_data)
-  return img
+
 
 def main():
-  from PIL import Image
+    from PIL import Image
+
+    logger = logging.getLogger(__name__)
+    args = parse_args()
 
-  logger = logging.getLogger(__name__)
-  args = parse_args()
+    with FileIO(args.color_map, 'r') as config_f:
+        color_map = parse_color_map(config_f)
 
-  with FileIO(args.color_map, 'r') as config_f:
-    color_map = parse_color_map(config_f)
+    logger.info('read %s color mappings', len(color_map))
 
-  logger.info('read {} color mappings'.format(len(color_map)))
+    with FileIO(args.input_image, 'rb') as input_f:
+        image_bytes = input_f.read()
+        img = Image.open(io.BytesIO(image_bytes)).convert('RGB')
 
-  with FileIO(args.input_image, 'rb') as input_f:
-    image_bytes = input_f.read()
-    img = Image.open(io.BytesIO(image_bytes)).convert('RGB')
+        img = map_colors(img, color_map)
 
-    img = map_colors(img, color_map)
+        with FileIO(args.output_image, 'wb') as output_f:
+            img.save(output_f, 'png')
 
-    with FileIO(args.output_image, 'wb') as output_f:
-      img.save(output_f, 'png')
 
 if __name__ == "__main__":
-  logging.basicConfig(level=logging.INFO)
+    logging.basicConfig(level=logging.INFO)
 
-  main()
+    main()
diff --git a/sciencebeam_gym/tools/inspect_tfrecords.py b/sciencebeam_gym/tools/inspect_tfrecords.py
index c3373b1..19416e8 100644
--- a/sciencebeam_gym/tools/inspect_tfrecords.py
+++ b/sciencebeam_gym/tools/inspect_tfrecords.py
@@ -7,94 +7,100 @@ import tensorflow as tf
 from tensorflow.python.lib.io import file_io
 
 from sciencebeam_gym.utils.tf import (
-  FileIO
+    FileIO
 )
 
+
 def get_args_parser():
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-    '--records_paths',
-    required=True,
-    type=str,
-    action='append',
-    help='The paths to the tf-records files to inspect.'
-  )
-  parser.add_argument(
-    '--inspect_key',
-    required=False,
-    type=str,
-    help='The name of the key to further inspect.'
-  )
-  parser.add_argument(
-    '--extract_dir',
-    required=False,
-    default=".",
-    type=str,
-    help='The directory to extract to.'
-  )
-  parser.add_argument(
-    '--extract_image',
-    required=False,
-    action='append',
-    type=str,
-    help='The name of the key to extract as an image.'
-  )
-  return parser
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--records_paths',
+        required=True,
+        type=str,
+        action='append',
+        help='The paths to the tf-records files to inspect.'
+    )
+    parser.add_argument(
+        '--inspect_key',
+        required=False,
+        type=str,
+        help='The name of the key to further inspect.'
+    )
+    parser.add_argument(
+        '--extract_dir',
+        required=False,
+        default=".",
+        type=str,
+        help='The directory to extract to.'
+    )
+    parser.add_argument(
+        '--extract_image',
+        required=False,
+        action='append',
+        type=str,
+        help='The name of the key to extract as an image.'
+    )
+    return parser
+
 
 def parse_args(argv=None):
-  parser = get_args_parser()
-  parsed_args, _ = parser.parse_known_args(argv)
-  return parsed_args
+    parser = get_args_parser()
+    parsed_args, _ = parser.parse_known_args(argv)
+    return parsed_args
+
 
 def get_matching_files_for_paths(paths):
-  files = []
-  for path in paths:
-    files.extend(file_io.get_matching_files(path))
-  # logging.info('files: %s (%s)', files, paths)
-  return files
+    files = []
+    for path in paths:
+        files.extend(file_io.get_matching_files(path))
+    # logging.info('files: %s (%s)', files, paths)
+    return files
+
 
 def main():
-  args = parse_args()
-
-  files = get_matching_files_for_paths(args.records_paths)
-
-  if args.extract_image:
-    file_io.recursive_create_dir(args.extract_dir)
-
-  total_count = 0
-  for f in files:
-    options = None
-    if f.endswith('.gz'):
-      options = tf.python_io.TFRecordOptions(
-        compression_type=tf.python_io.TFRecordCompressionType.GZIP)
-    print("file:", f)
-    count = 0
-    for i, example in enumerate(tf.python_io.tf_record_iterator(f, options=options)):
-      result = tf.train.Example.FromString(example)
-      if i == 0:
-        print("  features:", result.features.feature.keys())
-        if args.inspect_key:
-          print("  first value of {}:\n    {}".format(
-            args.inspect_key,
-            result.features.feature.get(args.inspect_key).bytes_list.value[0]
-          ))
-        if args.extract_image:
-          for extract_image_key in args.extract_image:
-            image_bytes = result.features.feature.get(extract_image_key).bytes_list.value[0]
-            print("  image size %d bytes (%s)" % (len(image_bytes), type(image_bytes)))
-            image_filename = os.path.join(
-              args.extract_dir,
-              '{}-{}-{}.png'.format(os.path.basename(f), extract_image_key, i)
-            )
-            print("  extracting image to {}".format(image_filename))
-            with FileIO(image_filename, 'wb') as image_f:
-              image_f.write(image_bytes)
-      count += 1
-    print("  found {} records".format(count))
-    total_count += count
-  print("found total of {} records in {} files".format(total_count, len(files)))
+    args = parse_args()
+
+    files = get_matching_files_for_paths(args.records_paths)
+
+    if args.extract_image:
+        file_io.recursive_create_dir(args.extract_dir)
+
+    total_count = 0
+    for f in files:
+        options = None
+        if f.endswith('.gz'):
+            options = tf.python_io.TFRecordOptions(
+                compression_type=tf.python_io.TFRecordCompressionType.GZIP)
+        print("file:", f)
+        count = 0
+        for i, example in enumerate(tf.python_io.tf_record_iterator(f, options=options)):
+            result = tf.train.Example.FromString(example)  # pylint: disable=no-member
+            if i == 0:
+                print("  features:", result.features.feature.keys())
+                if args.inspect_key:
+                    print("  first value of {}:\n    {}".format(
+                        args.inspect_key,
+                        result.features.feature.get(args.inspect_key).bytes_list.value[0]
+                    ))
+                if args.extract_image:
+                    for extract_image_key in args.extract_image:
+                        image_bytes = result.features.feature.get(
+                            extract_image_key).bytes_list.value[0]
+                        print("  image size %d bytes (%s)" % (len(image_bytes), type(image_bytes)))
+                        image_filename = os.path.join(
+                            args.extract_dir,
+                            '{}-{}-{}.png'.format(os.path.basename(f), extract_image_key, i)
+                        )
+                        print("  extracting image to {}".format(image_filename))
+                        with FileIO(image_filename, 'wb') as image_f:
+                            image_f.write(image_bytes)
+            count += 1
+        print("  found {} records".format(count))
+        total_count += count
+    print("found total of {} records in {} files".format(total_count, len(files)))
+
 
 if __name__ == "__main__":
-  logging.basicConfig(level=logging.INFO)
+    logging.basicConfig(level=logging.INFO)
 
-  main()
+    main()
diff --git a/sciencebeam_gym/tools/resize_image.py b/sciencebeam_gym/tools/resize_image.py
index d2f250f..d407a28 100644
--- a/sciencebeam_gym/tools/resize_image.py
+++ b/sciencebeam_gym/tools/resize_image.py
@@ -3,59 +3,62 @@ from __future__ import absolute_import
 import argparse
 import io
 
-from six.moves.configparser import ConfigParser
-
 from PIL import Image
 
 from sciencebeam_gym.utils.tf import FileIO
 
+
 def get_args_parser():
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-      '--image_width',
-      type=int,
-      required=True,
-      help='Resize images to the specified width')
-  parser.add_argument(
-      '--image_height',
-      type=int,
-      required=True,
-      help='Resize images to the specified height')
-  parser.add_argument(
-    '--input_image',
-    required=True,
-    type=str,
-    help='The path to the input image.'
-  )
-  parser.add_argument(
-    '--output_image',
-    required=False,
-    type=str,
-    help='The path to the output image.'
-  )
-  return parser
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--image_width',
+        type=int,
+        required=True,
+        help='Resize images to the specified width')
+    parser.add_argument(
+        '--image_height',
+        type=int,
+        required=True,
+        help='Resize images to the specified height')
+    parser.add_argument(
+        '--input_image',
+        required=True,
+        type=str,
+        help='The path to the input image.'
+    )
+    parser.add_argument(
+        '--output_image',
+        required=False,
+        type=str,
+        help='The path to the output image.'
+    )
+    return parser
+
 
 def parse_args(argv=None):
-  parser = get_args_parser()
-  parsed_args, _ = parser.parse_known_args(argv)
-  return parsed_args
+    parser = get_args_parser()
+    parsed_args, _ = parser.parse_known_args(argv)
+    return parsed_args
+
 
 def image_resize_bicubic(image, size):
-  return image.resize(size, Image.BICUBIC)
+    return image.resize(size, Image.BICUBIC)
+
 
 def main():
-  args = parse_args()
+    args = parse_args()
+
+    image_size = (args.image_width, args.image_height)
 
-  image_size = (args.image_width, args.image_height)
+    with FileIO(args.input_image, 'rb') as input_f:
+        image_bytes = input_f.read()
+        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
 
-  with FileIO(args.input_image, 'rb') as input_f:
-    image_bytes = input_f.read()
-    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
+        image = image_resize_bicubic(image, image_size)
 
-    image = image_resize_bicubic(image, image_size)
+        with FileIO(args.output_image, 'wb') as output_f:
+            image.save(output_f, 'png')
 
-    with FileIO(args.output_image, 'wb') as output_f:
-      image.save(output_f, 'png')
 
 if __name__ == "__main__":
-  main()
+    main()
diff --git a/sciencebeam_gym/trainer/checkpoint.py b/sciencebeam_gym/trainer/checkpoint.py
index 1c4fd80..a849858 100644
--- a/sciencebeam_gym/trainer/checkpoint.py
+++ b/sciencebeam_gym/trainer/checkpoint.py
@@ -3,38 +3,40 @@ import logging
 import tensorflow as tf
 
 from sciencebeam_gym.trainer.models.pix2pix.pix2pix_model import (
-  batch_dimensions_to_most_likely_colors_list
+    batch_dimensions_to_most_likely_colors_list
 )
 
 from sciencebeam_gym.inference_model import (
-  InferenceModel
+    InferenceModel
 )
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def load_last_checkpoint_as_inference_model(model, checkpoint_path, session=None):
-  if session is None:
-    session = tf.get_default_session()
-  assert session is not None
-  last_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
-  get_logger().info(
-    'last_checkpoint: %s (%s)', last_checkpoint, checkpoint_path
-  )
-  tensors = model.build_predict_graph()
-  inputs_tensor = tensors.inputs['image']
-  outputs_tensor = tensors.pred
-  if model.use_separate_channels:
-    outputs_tensor = batch_dimensions_to_most_likely_colors_list(
-      outputs_tensor,
-      model.dimension_colors_with_unknown
+    if session is None:
+        session = tf.get_default_session()
+    assert session is not None
+    last_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
+    get_logger().info(
+        'last_checkpoint: %s (%s)', last_checkpoint, checkpoint_path
+    )
+    tensors = model.build_predict_graph()
+    inputs_tensor = tensors.inputs['image']
+    outputs_tensor = tensors.pred
+    if model.use_separate_channels:
+        outputs_tensor = batch_dimensions_to_most_likely_colors_list(
+            outputs_tensor,
+            model.dimension_colors_with_unknown
+        )
+    saver = tf.train.Saver()
+    saver.restore(session, last_checkpoint)
+    labels = tf.constant(model.dimension_labels_with_unknown)
+    colors = tf.constant(model.dimension_colors_with_unknown)
+    inference_model = InferenceModel(
+        inputs_tensor, outputs_tensor,
+        labels, colors
     )
-  saver = tf.train.Saver()
-  saver.restore(session, last_checkpoint)
-  labels = tf.constant(model.dimension_labels_with_unknown)
-  colors = tf.constant(model.dimension_colors_with_unknown)
-  inference_model = InferenceModel(
-    inputs_tensor, outputs_tensor,
-    labels, colors
-  )
-  return inference_model
+    return inference_model
diff --git a/sciencebeam_gym/trainer/data/examples.py b/sciencebeam_gym/trainer/data/examples.py
index 0c13b04..d71c650 100644
--- a/sciencebeam_gym/trainer/data/examples.py
+++ b/sciencebeam_gym/trainer/data/examples.py
@@ -2,125 +2,133 @@ import logging
 from functools import partial
 
 import tensorflow as tf
-from tensorflow.python.lib.io import file_io # pylint: disable=E0611
+from tensorflow.python.lib.io import file_io  # pylint: disable=E0611
 
 from sciencebeam_utils.utils.collection import (
-  extend_dict
+    extend_dict
 )
 
 from sciencebeam_gym.model_utils.channels import (
-  calculate_color_masks
+    calculate_color_masks
 )
 
 try:
-  # TensorFlow 1.4+
-  tf_data = tf.data
+    # TensorFlow 1.4+
+    tf_data = tf.data
 except AttributeError:
-  tf_data = tf.contrib.data
+    tf_data = tf.contrib.data
 
 Dataset = tf_data.Dataset
 TFRecordDataset = tf_data.TFRecordDataset
 
 DEFAULT_FEATURE_MAP = {
-  'input_uri': tf.FixedLenFeature(
-    shape=[], dtype=tf.string, default_value=['']
-  ),
-  'annotation_uri': tf.FixedLenFeature(
-    shape=[], dtype=tf.string, default_value=['']
-  ),
-  'input_image': tf.FixedLenFeature(
-    shape=[], dtype=tf.string
-  ),
-  'annotation_image': tf.FixedLenFeature(
-    shape=[], dtype=tf.string
-  )
+    'input_uri': tf.FixedLenFeature(
+        shape=[], dtype=tf.string, default_value=['']
+    ),
+    'annotation_uri': tf.FixedLenFeature(
+        shape=[], dtype=tf.string, default_value=['']
+    ),
+    'input_image': tf.FixedLenFeature(
+        shape=[], dtype=tf.string
+    ),
+    'annotation_image': tf.FixedLenFeature(
+        shape=[], dtype=tf.string
+    )
 }
 
 PAGE_NO_FEATURE = {
-  'page_no': tf.FixedLenFeature(
-    shape=[], dtype=tf.int64
-  )
+    'page_no': tf.FixedLenFeature(
+        shape=[], dtype=tf.int64
+    )
 }
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def get_matching_files(paths):
-  files = []
-  for e in paths:
-    for path in e.split(','):
-      files.extend(file_io.get_matching_files(path))
-  return files
+    files = []
+    for e in paths:
+        for path in e.split(','):
+            files.extend(file_io.get_matching_files(path))
+    return files
+
 
 def parse_example(example, feature_map=None):
-  if feature_map is None:
-    feature_map = DEFAULT_FEATURE_MAP
-  get_logger().info('example: %s', example)
-  return tf.parse_single_example(example, features=feature_map)
+    if feature_map is None:
+        feature_map = DEFAULT_FEATURE_MAP
+    get_logger().info('example: %s', example)
+    return tf.parse_single_example(example, features=feature_map)
 
 # Workaround for Tensorflow 1.2 not supporting dicts
+
+
 class MapKeysTracker(object):
-  def __init__(self):
-    self.keys = None
+    def __init__(self):
+        self.keys = None
 
-  def wrap(self, fn):
-    def wrapper(x):
-      x = fn(x)
-      if self.keys is not None:
-        get_logger().warn('keys already set: %s', self.keys)
-      self.keys = sorted(x.keys())
-      return [x[k] for k in self.keys]
-    return wrapper
+    def wrap(self, fn):
+        def wrapper(x):
+            x = fn(x)
+            if self.keys is not None:
+                get_logger().warning('keys already set: %s', self.keys)
+            self.keys = sorted(x.keys())
+            return [x[k] for k in self.keys]
+        return wrapper
+
+    def unwrap(self, result):
+        return {k: v for k, v in zip(self.keys, result)}
 
-  def unwrap(self, result):
-    return {k: v for k, v in zip(self.keys, result)}
 
 def page_no_is_within(page_no, page_range):
-  get_logger().debug('page_no: %s, page_range: %s', page_no, page_range)
-  return tf.logical_and(page_no >= page_range[0], page_no <= page_range[1])
+    get_logger().debug('page_no: %s, page_range: %s', page_no, page_range)
+    return tf.logical_and(page_no >= page_range[0], page_no <= page_range[1])
+
 
 def image_contains_any_of_the_colors(image, colors):
-  decoded_image = tf.image.decode_png(image, channels=3)
-  color_masks = calculate_color_masks(decoded_image, colors)
-  return tf.reduce_any([
-    tf.reduce_any(color_mask >= 0.5)
-    for color_mask in color_masks
-  ])
+    decoded_image = tf.image.decode_png(image, channels=3)
+    color_masks = calculate_color_masks(decoded_image, colors)
+    return tf.reduce_any([
+        tf.reduce_any(color_mask >= 0.5)
+        for color_mask in color_masks
+    ])
+
 
 def read_examples(
-  filenames,
-  shuffle,
-  num_epochs=None,
-  page_range=None,
-  channel_colors=None):
-
-  # Convert num_epochs == 0 -> num_epochs is None, if necessary
-  num_epochs = num_epochs or None
-
-  feature_map = DEFAULT_FEATURE_MAP
-  if page_range is not None:
-    feature_map = extend_dict(feature_map, PAGE_NO_FEATURE)
-
-  map_keys_tracker = MapKeysTracker()
-
-  dataset = TFRecordDataset(filenames, compression_type='GZIP')
-  dataset = dataset.map(map_keys_tracker.wrap(
-    partial(parse_example, feature_map=feature_map)
-  ))
-  if page_range is not None:
-    dataset = dataset.filter(lambda *x: page_no_is_within(
-      map_keys_tracker.unwrap(x)['page_no'],
-      page_range
-    ))
-  if channel_colors is not None:
-    dataset = dataset.filter(lambda *x: image_contains_any_of_the_colors(
-      map_keys_tracker.unwrap(x)['annotation_image'],
-      channel_colors
-    ))
-  if shuffle:
-    dataset = dataset.shuffle(buffer_size=10000)
-  dataset = dataset.repeat(num_epochs)
+        filenames,
+        shuffle,
+        num_epochs=None,
+        page_range=None,
+        channel_colors=None):
+
+    # Convert num_epochs == 0 -> num_epochs is None, if necessary
+    num_epochs = num_epochs or None
 
-  return map_keys_tracker.unwrap(
-    dataset.make_one_shot_iterator().get_next()
-  )
+    feature_map = DEFAULT_FEATURE_MAP
+    if page_range is not None:
+        feature_map = extend_dict(feature_map, PAGE_NO_FEATURE)
+
+    map_keys_tracker = MapKeysTracker()
+
+    dataset = TFRecordDataset(filenames, compression_type='GZIP')
+    dataset = dataset.map(map_keys_tracker.wrap(
+        partial(parse_example, feature_map=feature_map)
+    ))
+    if page_range is not None:
+        dataset = dataset.filter(lambda *x: page_no_is_within(
+            map_keys_tracker.unwrap(x)['page_no'],
+            page_range
+        ))
+    if channel_colors is not None:
+        dataset = dataset.filter(lambda *x: image_contains_any_of_the_colors(
+            map_keys_tracker.unwrap(x)['annotation_image'],
+            channel_colors
+        ))
+    if shuffle:
+        dataset = dataset.shuffle(buffer_size=10000)
+    dataset = dataset.repeat(num_epochs)
+
+    return map_keys_tracker.unwrap(
+        dataset.make_one_shot_iterator().get_next()
+    )
diff --git a/sciencebeam_gym/trainer/data/examples_test.py b/sciencebeam_gym/trainer/data/examples_test.py
index 7a2886a..cb99ef6 100644
--- a/sciencebeam_gym/trainer/data/examples_test.py
+++ b/sciencebeam_gym/trainer/data/examples_test.py
@@ -1,26 +1,24 @@
 import logging
 from mock import patch
 
-import pytest
-
 import tensorflow as tf
 
 from sciencebeam_utils.utils.collection import (
-  extend_dict
+    extend_dict
 )
 
 from sciencebeam_gym.utils.tfrecord import (
-  dict_to_example
+    dict_to_example
 )
 
 from sciencebeam_gym.tools.calculate_class_weights_test import (
-  encode_png
+    encode_png
 )
 
 import sciencebeam_gym.trainer.data.examples as examples_module
 from sciencebeam_gym.trainer.data.examples import (
-  read_examples,
-  tf_data
+    read_examples,
+    tf_data
 )
 
 DATA_PATH = '.temp/data/*.tfrecord'
@@ -28,78 +26,87 @@ DATA_PATH = '.temp/data/*.tfrecord'
 IMAGE_SHAPE = (5, 5)
 
 EXAMPLE_PROPS_1 = {
-  'input_uri': 'input.png',
-  'input_image': b'input image',
-  'annotation_uri': 'annotation.png',
-  'annotation_image': b'annotation image'
+    'input_uri': 'input.png',
+    'input_image': b'input image',
+    'annotation_uri': 'annotation.png',
+    'annotation_image': b'annotation image'
 }
 
 RECORD_1 = dict_to_example(EXAMPLE_PROPS_1).SerializeToString()
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def setup_module():
-  logging.basicConfig(level='DEBUG')
+    logging.basicConfig(level='DEBUG')
+
 
 def list_dataset(data, dtype):
-  data = tf.constant(data, dtype=dtype)
-  return tf_data.Dataset.from_tensor_slices(data)
+    data = tf.constant(data, dtype=dtype)
+    return tf_data.Dataset.from_tensor_slices(data)
+
 
 def fetch_examples(session, examples_tensor, max_examples=1000):
-  try:
-    for _ in range(max_examples):
-      yield session.run([examples_tensor])[0]
-  except tf.errors.OutOfRangeError:
-    get_logger().debug('end of dataset')
+    try:
+        for _ in range(max_examples):
+            yield session.run([examples_tensor])[0]
+    except tf.errors.OutOfRangeError:
+        get_logger().debug('end of dataset')
+
 
 def some_color(i):
-  return (i, i, i)
+    return (i, i, i)
+
 
 def image_with_color(color, shape=IMAGE_SHAPE):
-  return encode_png([[color] * shape[1]] * shape[0])
+    return encode_png([[color] * shape[1]] * shape[0])
 
 # @pytest.mark.slow
+
+
 class TestReadExamples(object):
-  def test_should_read_single_example(self):
-    with patch.object(examples_module, 'TFRecordDataset') as TFRecordDataset:
-      with tf.Graph().as_default():
-        TFRecordDataset.return_value = list_dataset([RECORD_1], tf.string)
-        examples = read_examples(DATA_PATH, shuffle=False)
-        TFRecordDataset.assert_called_with(DATA_PATH, compression_type='GZIP')
-        with tf.Session() as session:
-          next_example = session.run([examples])[0]
-          get_logger().info('next_example: %s', next_example)
-          assert next_example == EXAMPLE_PROPS_1
-
-  def test_should_filter_by_page_no(self):
-    with patch.object(examples_module, 'TFRecordDataset') as TFRecordDataset:
-      with tf.Graph().as_default():
-        TFRecordDataset.return_value = list_dataset([
-          dict_to_example(extend_dict(EXAMPLE_PROPS_1, page_no=page_no)).SerializeToString()
-          for page_no in [1, 2, 3, 4]
-        ], tf.string)
-        examples = read_examples(DATA_PATH, shuffle=False, num_epochs=1, page_range=(2, 3))
-        TFRecordDataset.assert_called_with(DATA_PATH, compression_type='GZIP')
-        with tf.Session() as session:
-          assert [x['page_no'] for x in fetch_examples(session, examples)] == [2, 3]
-
-  def test_should_filter_by_channel_colors(self):
-    with patch.object(examples_module, 'TFRecordDataset') as TFRecordDataset:
-      with tf.Graph().as_default():
-        TFRecordDataset.return_value = list_dataset([
-          dict_to_example(extend_dict(
-            EXAMPLE_PROPS_1,
-            page_no=page_no,
-            annotation_image=image_with_color(some_color(page_no))
-          )).SerializeToString()
-          for page_no in [1, 2, 3, 4]
-        ], tf.string)
-        examples = read_examples(
-          DATA_PATH, shuffle=False, num_epochs=1,
-          page_range=(0, 100),
-          channel_colors=[some_color(i) for i in [2, 3]]
-        )
-        TFRecordDataset.assert_called_with(DATA_PATH, compression_type='GZIP')
-        with tf.Session() as session:
-          assert [x['page_no'] for x in fetch_examples(session, examples)] == [2, 3]
+    def test_should_read_single_example(self):
+        with patch.object(examples_module, 'TFRecordDataset') as TFRecordDataset:
+            with tf.Graph().as_default():
+                TFRecordDataset.return_value = list_dataset([RECORD_1], tf.string)
+                examples = read_examples(DATA_PATH, shuffle=False)
+                TFRecordDataset.assert_called_with(DATA_PATH, compression_type='GZIP')
+                with tf.Session() as session:
+                    next_example = session.run([examples])[0]
+                    get_logger().info('next_example: %s', next_example)
+                    assert next_example == EXAMPLE_PROPS_1
+
+    def test_should_filter_by_page_no(self):
+        with patch.object(examples_module, 'TFRecordDataset') as TFRecordDataset:
+            with tf.Graph().as_default():
+                TFRecordDataset.return_value = list_dataset([
+                    dict_to_example(extend_dict(EXAMPLE_PROPS_1, page_no=page_no)
+                                    ).SerializeToString()
+                    for page_no in [1, 2, 3, 4]
+                ], tf.string)
+                examples = read_examples(DATA_PATH, shuffle=False, num_epochs=1, page_range=(2, 3))
+                TFRecordDataset.assert_called_with(DATA_PATH, compression_type='GZIP')
+                with tf.Session() as session:
+                    assert [x['page_no'] for x in fetch_examples(session, examples)] == [2, 3]
+
+    def test_should_filter_by_channel_colors(self):
+        with patch.object(examples_module, 'TFRecordDataset') as TFRecordDataset:
+            with tf.Graph().as_default():
+                TFRecordDataset.return_value = list_dataset([
+                    dict_to_example(extend_dict(
+                        EXAMPLE_PROPS_1,
+                        page_no=page_no,
+                        annotation_image=image_with_color(some_color(page_no))
+                    )).SerializeToString()
+                    for page_no in [1, 2, 3, 4]
+                ], tf.string)
+                examples = read_examples(
+                    DATA_PATH, shuffle=False, num_epochs=1,
+                    page_range=(0, 100),
+                    channel_colors=[some_color(i) for i in [2, 3]]
+                )
+                TFRecordDataset.assert_called_with(DATA_PATH, compression_type='GZIP')
+                with tf.Session() as session:
+                    assert [x['page_no'] for x in fetch_examples(session, examples)] == [2, 3]
diff --git a/sciencebeam_gym/trainer/evaluator.py b/sciencebeam_gym/trainer/evaluator.py
index 67d62cc..1fa78bf 100644
--- a/sciencebeam_gym/trainer/evaluator.py
+++ b/sciencebeam_gym/trainer/evaluator.py
@@ -3,478 +3,489 @@ import logging
 import json
 from io import BytesIO
 
-import matplotlib as mpl
-# this is important to run on the cloud - we won't have python-tk installed
-mpl.use("Agg")
-
-# pylint: disable=C0413
-from matplotlib import pyplot as plt
 import numpy as np
 import six
 
 import tensorflow as tf
-from tensorflow.python.lib.io import file_io # pylint: disable=E0611
+from tensorflow.python.lib.io import file_io
 
 from PIL import Image
 
+from sciencebeam_gym.utils.pyplot import pyplot as plt
+
 from sciencebeam_gym.utils.tf import (
-  FileIO
+    FileIO
 )
 
 from sciencebeam_gym.trainer.util import (
-  CustomSupervisor,
-  get_graph_size
+    CustomSupervisor,
+    get_graph_size
 )
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def plot_image(ax, image, label):
-  if len(image.shape) == 3:
-    get_logger().info('image shape: %s (%s)', image.shape, image.shape[-1])
-    if image.shape[-1] == 1:
-      ax.imshow(image.squeeze(), aspect='auto', vmin=0, vmax=255, cmap=plt.get_cmap('gray'))
+    if len(image.shape) == 3:
+        get_logger().info('image shape: %s (%s)', image.shape, image.shape[-1])
+        if image.shape[-1] == 1:
+            ax.imshow(image.squeeze(), aspect='auto', vmin=0, vmax=255, cmap=plt.get_cmap('gray'))
+        else:
+            ax.imshow(image, aspect='auto')
     else:
-      ax.imshow(image, aspect='auto')
-  else:
-    ax.imshow(np.dstack((image.astype(np.uint8),)*3)*100, aspect='auto')
-  ax.set_title(label, color=(0.5, 0.5, 0.5), y=0.995)
-  ax.set_axis_off()
-  ax.set(xlim=[0, 255], ylim=[255, 0], aspect=1)
-  ax.axes.get_xaxis().set_visible(False)
-  ax.axes.get_yaxis().set_visible(False)
+        ax.imshow(np.dstack((image.astype(np.uint8),) * 3) * 100, aspect='auto')
+    ax.set_title(label, color=(0.5, 0.5, 0.5), y=0.995)
+    ax.set_axis_off()
+    ax.set(xlim=[0, 255], ylim=[255, 0], aspect=1)
+    ax.axes.get_xaxis().set_visible(False)
+    ax.axes.get_yaxis().set_visible(False)
+
 
 def show_result_images3(input_image, annot, output_image):
-  figsize = plt.figaspect(1.1 / 3.0)
-  fig, (ax_img, ax_annot, ax_out) = plt.subplots(
-    1,
-    3,
-    sharey=True,
-    figsize=figsize,
-    frameon=False,
-    facecolor=None,
-    dpi=60
-  )
-
-  plot_image(ax_img, input_image, 'input')
-  plot_image(ax_annot, annot, 'target')
-  plot_image(ax_out, output_image, 'prediction')
-
-  margin = 0.01
-  fig.subplots_adjust(
-    left=margin,
-    right=1.0 - margin,
-    top=0.95 - margin,
-    bottom=margin,
-    wspace=0.05
-  )
-
-  return fig
+    figsize = plt.figaspect(1.1 / 3.0)
+    fig, (ax_img, ax_annot, ax_out) = plt.subplots(
+        1,
+        3,
+        sharey=True,
+        figsize=figsize,
+        frameon=False,
+        facecolor=None,
+        dpi=60
+    )
+
+    plot_image(ax_img, input_image, 'input')
+    plot_image(ax_annot, annot, 'target')
+    plot_image(ax_out, output_image, 'prediction')
+
+    margin = 0.01
+    fig.subplots_adjust(
+        left=margin,
+        right=1.0 - margin,
+        top=0.95 - margin,
+        bottom=margin,
+        wspace=0.05
+    )
+
+    return fig
+
 
 def save_image_data(image_filename, image_data):
-  image = Image.fromarray(np.array(image_data), "RGB")
-  with FileIO(image_filename, 'wb') as image_f:
-    image.save(image_f, 'png')
+    image = Image.fromarray(np.array(image_data), "RGB")
+    with FileIO(image_filename, 'wb') as image_f:
+        image.save(image_f, 'png')
+
 
 def save_file(filename, data):
-  with FileIO(filename, 'wb') as f:
-    f.write(data)
+    with FileIO(filename, 'wb') as f:
+        f.write(data)
+
 
 def precision_from_tp_fp(tp, fp):
-  return tp / (tp + fp)
+    return tp / (tp + fp)
+
 
 def recall_from_tp_fn(tp, fn):
-  return tp / (tp + fn)
+    return tp / (tp + fn)
+
 
 def f1_from_precision_recall(precision, recall):
-  return 2 * precision * recall / (precision + recall)
+    return 2 * precision * recall / (precision + recall)
+
 
 def f1_from_tp_fp_fn(tp, fp, fn):
-  return f1_from_precision_recall(
-    precision_from_tp_fp(tp, fp),
-    recall_from_tp_fn(tp, fn)
-  )
+    return f1_from_precision_recall(
+        precision_from_tp_fp(tp, fp),
+        recall_from_tp_fn(tp, fn)
+    )
+
 
 def to_list_if_not_none(x):
-  return x.tolist() if x is not None else x
+    return x.tolist() if x is not None else x
+
 
 IMAGE_PREFIX = 'image_'
 
+
 class Evaluator(object):
-  """Loads variables from latest checkpoint and performs model evaluation."""
-
-  def __init__(
-    self, args, model,
-    checkpoint_path,
-    data_paths,
-    dataset='eval',
-    eval_batch_size=None,
-    eval_set_size=None,
-    qualitative_set_size=None,
-    run_async=None):
-
-    self.eval_batch_size = eval_batch_size or args.eval_batch_size
-    self.num_eval_batches = (eval_set_size or args.eval_set_size) // self.eval_batch_size
-    self.num_detail_eval_batches = (
-      min((qualitative_set_size or 10), args.eval_set_size) // self.eval_batch_size
-    )
-    self.checkpoint_path = checkpoint_path
-    self.output_path = os.path.join(args.output_path, dataset)
-    self.eval_data_paths = data_paths
-    self.batch_size = args.batch_size
-    self.stream = args.streaming_eval
-    self.model = model
-    self.results_dir = os.path.join(self.output_path, 'results')
-    self.graph_size = None
-    self.run_async = run_async
-    if not run_async:
-      self.run_async = lambda f, args: f(*args)
-
-  def init(self):
-    file_io.recursive_create_dir(self.results_dir)
-
-  def _check_fetches(self, fetches):
-    for k, v in six.iteritems(fetches):
-      if v is None:
-        raise Exception('fetches tensor is None: {}'.format(k))
-
-  def _get_default_fetches(self, tensors):
-    return {
-      'global_step': tensors.global_step,
-      'input_uri': tensors.input_uri,
-      'input_image': tensors.image_tensor,
-      'annotation_image': tensors.annotation_tensor,
-      'output_image': tensors.summaries.get('output_image'),
-      'metric_values': tensors.metric_values
-    }
-
-  def _add_image_fetches(self, fetches, tensors):
-    for k, v in six.iteritems(tensors.image_tensors):
-      fetches[IMAGE_PREFIX + k] = v
-    return fetches
-
-  def _add_evaluation_result_fetches(self, fetches, tensors):
-    if tensors.evaluation_result:
-      if tensors.output_layer_labels is not None:
-        fetches['output_layer_labels'] = tensors.output_layer_labels
-      fetches['confusion_matrix'] = tensors.evaluation_result.confusion_matrix
-      fetches['tp'] = tensors.evaluation_result.tp
-      fetches['fp'] = tensors.evaluation_result.fp
-      fetches['fn'] = tensors.evaluation_result.fn
-      fetches['tn'] = tensors.evaluation_result.tn
-      fetches['accuracy'] = tensors.evaluation_result.accuracy
-      fetches['micro_f1'] = tensors.evaluation_result.micro_f1
-    return fetches
-
-  def _accumulate_evaluation_results(self, results, accumulated_results=None):
-    if results.get('confusion_matrix') is None:
-      return accumulated_results
-    if accumulated_results is None:
-      accumulated_results = []
-    accumulated_results.append({
-      'output_layer_labels': results.get('output_layer_labels'),
-      'confusion_matrix': results['confusion_matrix'],
-      'tp': results['tp'],
-      'fp': results['fp'],
-      'fn': results['fn'],
-      'tn': results['tn'],
-      'accuracy': results['accuracy'],
-      'micro_f1': results['micro_f1'],
-      'count': self.batch_size,
-      'global_step': results['global_step']
-    })
-    return accumulated_results
-
-  def _save_accumulate_evaluation_results(self, accumulated_results):
-    if accumulated_results:
-      first_result = accumulated_results[0]
-      global_step = first_result['global_step']
-      graph_size = get_graph_size()
-      output_layer_labels = to_list_if_not_none(first_result.get('output_layer_labels'))
-      scores_file = os.path.join(
-        self.results_dir, 'result_{}_scores.json'.format(
-          global_step
-        )
-      )
-      tp = np.sum([r['tp'] for r in accumulated_results], axis=0)
-      fp = np.sum([r['fp'] for r in accumulated_results], axis=0)
-      fn = np.sum([r['fn'] for r in accumulated_results], axis=0)
-      tn = np.sum([r['tn'] for r in accumulated_results], axis=0)
-      f1 = f1_from_tp_fp_fn(tp.astype(float), fp, fn)
-      meta = {
-        'global_step': global_step,
-        'batch_size': self.batch_size
-      }
-      if self.graph_size:
-        meta['graph_size'] = self.graph_size
-      scores = {
-        'accuracy': float(np.mean([r['accuracy'] for r in accumulated_results])),
-        'output_layer_labels': output_layer_labels,
-        'confusion_matrix': sum([r['confusion_matrix'] for r in accumulated_results]).tolist(),
-        'tp': to_list_if_not_none(tp),
-        'fp': to_list_if_not_none(fp),
-        'fn': to_list_if_not_none(fn),
-        'tn': to_list_if_not_none(tn),
-        'f1': to_list_if_not_none(f1),
-        'micro_f1': float(np.mean([r['micro_f1'] for r in accumulated_results])),
-        'macro_f1': float(np.mean(f1)),
-        'count': sum([r['count'] for r in accumulated_results]),
-        'meta': meta
-      }
-      scores_str = json.dumps(scores, indent=2)
-      with FileIO(scores_file, 'w') as f:
-        f.write(scores_str)
-
-  def _save_prediction_summary_image_for(
-    self, eval_index, global_step, inputs, targets, outputs, name):
-
-    for batch_index, input_image, target_image, output_image in zip(
-      range(len(inputs)), inputs, targets, outputs
-      ):
-
-      fig = show_result_images3(
-        input_image,
-        target_image,
-        output_image
-      )
-      result_file = os.path.join(
-        self.results_dir, 'result_{}_{}_{}_{}.png'.format(
-          global_step, eval_index, batch_index, name
+    """Loads variables from latest checkpoint and performs model evaluation."""
+
+    def __init__(
+            self, args, model,
+            checkpoint_path,
+            data_paths,
+            dataset='eval',
+            eval_batch_size=None,
+            eval_set_size=None,
+            qualitative_set_size=None,
+            run_async=None):
+
+        self.eval_batch_size = eval_batch_size or args.eval_batch_size
+        self.num_eval_batches = (eval_set_size or args.eval_set_size) // self.eval_batch_size
+        self.num_detail_eval_batches = (
+            min((qualitative_set_size or 10), args.eval_set_size) // self.eval_batch_size
         )
-      )
-      logging.info('result_file: %s', result_file)
-      bio = BytesIO()
-      plt.savefig(bio, format='png', transparent=False, frameon=True, dpi='figure')
-      plt.close(fig)
-      self.run_async(save_file, (result_file, bio.getvalue()))
-
-  def _save_prediction_summary_image(self, eval_index, results):
-    global_step = results['global_step']
-    self._save_prediction_summary_image_for(
-      eval_index,
-      global_step,
-      results['input_image'],
-      results['annotation_image'],
-      results['output_image'],
-      'summary_output'
-    )
-
-    if results.get(IMAGE_PREFIX + 'outputs_most_likely') is not None:
-      self._save_prediction_summary_image_for(
-        eval_index,
-        global_step,
-        results['input_image'],
-        results['annotation_image'],
-        results[IMAGE_PREFIX + 'outputs_most_likely'],
-        'summary_most_likely'
-      )
-    outputs_key_needle = '_outputs_'
-    for k in six.iterkeys(results):
-      outputs_key_needle_index = k.find(outputs_key_needle)
-      if k.startswith(IMAGE_PREFIX) and outputs_key_needle_index >= 0:
-        targets_key = k.replace(outputs_key_needle, '_targets_')
-        if not targets_key in results:
-          continue
+        self.checkpoint_path = checkpoint_path
+        self.output_path = os.path.join(args.output_path, dataset)
+        self.eval_data_paths = data_paths
+        self.batch_size = args.batch_size
+        self.stream = args.streaming_eval
+        self.model = model
+        self.results_dir = os.path.join(self.output_path, 'results')
+        self.graph_size = None
+        self.run_async = run_async
+        if not run_async:
+            self.run_async = lambda f, args: f(*args)
+
+    def init(self):
+        file_io.recursive_create_dir(self.results_dir)
+
+    def _check_fetches(self, fetches):
+        for k, v in six.iteritems(fetches):
+            if v is None:
+                raise Exception('fetches tensor is None: {}'.format(k))
+
+    def _get_default_fetches(self, tensors):
+        return {
+            'global_step': tensors.global_step,
+            'input_uri': tensors.input_uri,
+            'input_image': tensors.image_tensor,
+            'annotation_image': tensors.annotation_tensor,
+            'output_image': tensors.summaries.get('output_image'),
+            'metric_values': tensors.metric_values
+        }
+
+    def _add_image_fetches(self, fetches, tensors):
+        for k, v in six.iteritems(tensors.image_tensors):
+            fetches[IMAGE_PREFIX + k] = v
+        return fetches
+
+    def _add_evaluation_result_fetches(self, fetches, tensors):
+        if tensors.evaluation_result:
+            if tensors.output_layer_labels is not None:
+                fetches['output_layer_labels'] = tensors.output_layer_labels
+            fetches['confusion_matrix'] = tensors.evaluation_result.confusion_matrix
+            fetches['tp'] = tensors.evaluation_result.tp
+            fetches['fp'] = tensors.evaluation_result.fp
+            fetches['fn'] = tensors.evaluation_result.fn
+            fetches['tn'] = tensors.evaluation_result.tn
+            fetches['accuracy'] = tensors.evaluation_result.accuracy
+            fetches['micro_f1'] = tensors.evaluation_result.micro_f1
+        return fetches
+
+    def _accumulate_evaluation_results(self, results, accumulated_results=None):
+        if results.get('confusion_matrix') is None:
+            return accumulated_results
+        if accumulated_results is None:
+            accumulated_results = []
+        accumulated_results.append({
+            'output_layer_labels': results.get('output_layer_labels'),
+            'confusion_matrix': results['confusion_matrix'],
+            'tp': results['tp'],
+            'fp': results['fp'],
+            'fn': results['fn'],
+            'tn': results['tn'],
+            'accuracy': results['accuracy'],
+            'micro_f1': results['micro_f1'],
+            'count': self.batch_size,
+            'global_step': results['global_step']
+        })
+        return accumulated_results
+
+    def _save_accumulate_evaluation_results(self, accumulated_results):
+        if accumulated_results:
+            first_result = accumulated_results[0]
+            global_step = first_result['global_step']
+            output_layer_labels = to_list_if_not_none(first_result.get('output_layer_labels'))
+            scores_file = os.path.join(
+                self.results_dir, 'result_{}_scores.json'.format(
+                    global_step
+                )
+            )
+            tp = np.sum([r['tp'] for r in accumulated_results], axis=0)
+            fp = np.sum([r['fp'] for r in accumulated_results], axis=0)
+            fn = np.sum([r['fn'] for r in accumulated_results], axis=0)
+            tn = np.sum([r['tn'] for r in accumulated_results], axis=0)
+            f1 = f1_from_tp_fp_fn(tp.astype(float), fp, fn)
+            meta = {
+                'global_step': global_step,
+                'batch_size': self.batch_size
+            }
+            if self.graph_size:
+                meta['graph_size'] = self.graph_size
+            scores = {
+                'accuracy': float(np.mean([r['accuracy'] for r in accumulated_results])),
+                'output_layer_labels': output_layer_labels,
+                'confusion_matrix': sum(
+                    [r['confusion_matrix'] for r in accumulated_results]
+                ).tolist(),
+                'tp': to_list_if_not_none(tp),
+                'fp': to_list_if_not_none(fp),
+                'fn': to_list_if_not_none(fn),
+                'tn': to_list_if_not_none(tn),
+                'f1': to_list_if_not_none(f1),
+                'micro_f1': float(np.mean([r['micro_f1'] for r in accumulated_results])),
+                'macro_f1': float(np.mean(f1)),
+                'count': sum([r['count'] for r in accumulated_results]),
+                'meta': meta
+            }
+            scores_str = json.dumps(scores, indent=2)
+            with FileIO(scores_file, 'w') as f:
+                f.write(scores_str)
+
+    def _save_prediction_summary_image_for(
+            self, eval_index, global_step, inputs, targets, outputs, name):
+
+        for batch_index, input_image, target_image, output_image in zip(
+            range(len(inputs)), inputs, targets, outputs
+        ):
+
+            fig = show_result_images3(
+                input_image,
+                target_image,
+                output_image
+            )
+            result_file = os.path.join(
+                self.results_dir, 'result_{}_{}_{}_{}.png'.format(
+                    global_step, eval_index, batch_index, name
+                )
+            )
+            logging.info('result_file: %s', result_file)
+            bio = BytesIO()
+            plt.savefig(bio, format='png', transparent=False, frameon=True, dpi='figure')
+            plt.close(fig)
+            self.run_async(save_file, (result_file, bio.getvalue()))
+
+    def _save_prediction_summary_image(self, eval_index, results):
+        global_step = results['global_step']
         self._save_prediction_summary_image_for(
-          eval_index,
-          global_step,
-          results['input_image'],
-          results[targets_key],
-          results[k],
-          'summary_' + k[(outputs_key_needle_index + len(outputs_key_needle)):]
+            eval_index,
+            global_step,
+            results['input_image'],
+            results['annotation_image'],
+            results['output_image'],
+            'summary_output'
         )
 
-  def _save_result_images(self, eval_index, results):
-    global_step = results['global_step']
-    for k in six.iterkeys(results):
-      if k.startswith(IMAGE_PREFIX):
-        batch_image_data = results[k]
-        name = k[len(IMAGE_PREFIX):]
-        for batch_index, image_data in enumerate(batch_image_data):
-          image_filename = os.path.join(
-            self.results_dir, 'result_{}_{}_{}_{}.png'.format(
-              global_step, eval_index, batch_index, name
+        if results.get(IMAGE_PREFIX + 'outputs_most_likely') is not None:
+            self._save_prediction_summary_image_for(
+                eval_index,
+                global_step,
+                results['input_image'],
+                results['annotation_image'],
+                results[IMAGE_PREFIX + 'outputs_most_likely'],
+                'summary_most_likely'
             )
-          )
-          self.run_async(save_image_data, (image_filename, image_data))
-
-  def _save_meta(self, eval_index, results):
-    global_step = results['global_step']
-    metric_values = results['metric_values']
-    for batch_index, input_uri in enumerate(results['input_uri']):
-      meta_file = os.path.join(
-        self.results_dir, 'result_{}_{}_{}_meta.json'.format(
-          global_step, eval_index, batch_index
-        )
-      )
-      meta_str = json.dumps({
-        'global_step': int(global_step),
-        'eval_index': eval_index,
-        'batch_index': batch_index,
-        'metric_values': [float(x) for x in metric_values],
-        'input_uri': input_uri
-      }, indent=2)
-      with FileIO(meta_file, 'w') as meta_f:
-        meta_f.write(meta_str)
-
-  def evaluate_in_session(self, session, tensors, num_eval_batches=None):
-    summary_writer = tf.summary.FileWriter(self.output_path)
-    num_eval_batches = num_eval_batches or self.num_eval_batches
-    num_detailed_eval_batches = min(self.num_detail_eval_batches, num_eval_batches)
-    if self.stream:
-      for _ in range(num_eval_batches):
-        session.run(tensors.metric_updates, feed_dict={
-          tensors.is_training: False
-        })
-    else:
-      get_logger().info('tensors.examples: %s', tensors.examples)
-
-      metric_values = None
-      accumulated_results = None
-
-      for eval_index in range(num_eval_batches):
-        detailed_evaluation = eval_index < num_detailed_eval_batches
-
-        fetches = self._get_default_fetches(tensors)
-        self._add_evaluation_result_fetches(fetches, tensors)
-        if detailed_evaluation:
-          self._add_image_fetches(fetches, tensors)
-        fetches['summary_value'] = tensors.summary
-        self._check_fetches(fetches)
-        results = session.run(fetches, feed_dict={
-          tensors.is_training: False
-        })
-
-        accumulated_results = self._accumulate_evaluation_results(results, accumulated_results)
-        if detailed_evaluation:
-          self._save_prediction_summary_image(eval_index, results)
-          self._save_result_images(eval_index, results)
-          self._save_meta(eval_index, results)
-
-          global_step = results['global_step']
-          summary_value = results['summary_value']
-          summary_writer.add_summary(summary_value, global_step)
-          summary_writer.flush()
-
+        outputs_key_needle = '_outputs_'
+        for k in six.iterkeys(results):
+            outputs_key_needle_index = k.find(outputs_key_needle)
+            if k.startswith(IMAGE_PREFIX) and outputs_key_needle_index >= 0:
+                targets_key = k.replace(outputs_key_needle, '_targets_')
+                if targets_key not in results:
+                    continue
+                self._save_prediction_summary_image_for(
+                    eval_index,
+                    global_step,
+                    results['input_image'],
+                    results[targets_key],
+                    results[k],
+                    'summary_' + k[(outputs_key_needle_index + len(outputs_key_needle)):]
+                )
+
+    def _save_result_images(self, eval_index, results):
+        global_step = results['global_step']
+        for k in six.iterkeys(results):
+            if k.startswith(IMAGE_PREFIX):
+                batch_image_data = results[k]
+                name = k[len(IMAGE_PREFIX):]
+                for batch_index, image_data in enumerate(batch_image_data):
+                    image_filename = os.path.join(
+                        self.results_dir, 'result_{}_{}_{}_{}.png'.format(
+                            global_step, eval_index, batch_index, name
+                        )
+                    )
+                    self.run_async(save_image_data, (image_filename, image_data))
+
+    def _save_meta(self, eval_index, results):
+        global_step = results['global_step']
         metric_values = results['metric_values']
+        for batch_index, input_uri in enumerate(results['input_uri']):
+            meta_file = os.path.join(
+                self.results_dir, 'result_{}_{}_{}_meta.json'.format(
+                    global_step, eval_index, batch_index
+                )
+            )
+            meta_str = json.dumps({
+                'global_step': int(global_step),
+                'eval_index': eval_index,
+                'batch_index': batch_index,
+                'metric_values': [float(x) for x in metric_values],
+                'input_uri': input_uri
+            }, indent=2)
+            with FileIO(meta_file, 'w') as meta_f:
+                meta_f.write(meta_str)
+
+    def evaluate_in_session(self, session, tensors, num_eval_batches=None):
+        summary_writer = tf.summary.FileWriter(self.output_path)
+        num_eval_batches = num_eval_batches or self.num_eval_batches
+        num_detailed_eval_batches = min(self.num_detail_eval_batches, num_eval_batches)
+        if self.stream:
+            for _ in range(num_eval_batches):
+                session.run(tensors.metric_updates, feed_dict={
+                    tensors.is_training: False
+                })
+        else:
+            get_logger().info('tensors.examples: %s', tensors.examples)
 
-      self._save_accumulate_evaluation_results(accumulated_results)
-
-      logging.info('eval done')
-      return metric_values
+            metric_values = None
+            accumulated_results = None
 
-    metric_values = session.run(tensors.metric_values, feed_dict={
-      tensors.is_training: False
-    })
-    return metric_values
+            for eval_index in range(num_eval_batches):
+                detailed_evaluation = eval_index < num_detailed_eval_batches
 
-  def evaluate(self, num_eval_batches=None, session=None):
-    """Run one round of evaluation, return loss and accuracy."""
+                fetches = self._get_default_fetches(tensors)
+                self._add_evaluation_result_fetches(fetches, tensors)
+                if detailed_evaluation:
+                    self._add_image_fetches(fetches, tensors)
+                fetches['summary_value'] = tensors.summary
+                self._check_fetches(fetches)
+                results = session.run(fetches, feed_dict={
+                    tensors.is_training: False
+                })
 
-    num_eval_batches = num_eval_batches or self.num_eval_batches
-    with tf.Graph().as_default() as graph:
-      tensors = self.model.build_eval_graph(
-        self.eval_data_paths,
-        self.eval_batch_size
-      )
-      self.graph_size = get_graph_size()
+                accumulated_results = self._accumulate_evaluation_results(
+                    results, accumulated_results)
+                if detailed_evaluation:
+                    self._save_prediction_summary_image(eval_index, results)
+                    self._save_result_images(eval_index, results)
+                    self._save_meta(eval_index, results)
 
-      saver = tf.train.Saver()
+                    global_step = results['global_step']
+                    summary_value = results['summary_value']
+                    summary_writer.add_summary(summary_value, global_step)
+                    summary_writer.flush()
 
-    sv = CustomSupervisor(
-      model=self.model,
-      graph=graph,
-      logdir=self.output_path,
-      summary_op=None,
-      global_step=None,
-      saver=saver
-    )
-    try:
-      last_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path)
-      logging.info('last_checkpoint: %s (%s)', last_checkpoint, self.checkpoint_path)
+                metric_values = results['metric_values']
 
-      file_io.recursive_create_dir(self.results_dir)
+            self._save_accumulate_evaluation_results(accumulated_results)
 
-      with sv.managed_session(
-          master='', start_standard_services=False) as session:
-        sv.saver.restore(session, last_checkpoint)
+            logging.info('eval done')
+            return metric_values
 
-        logging.info('session restored')
+        metric_values = session.run(tensors.metric_values, feed_dict={
+            tensors.is_training: False
+        })
+        return metric_values
 
-        if self.stream:
-          logging.info('start queue runners (stream)')
-          sv.start_queue_runners(session)
-          for _ in range(num_eval_batches):
-            session.run(self.tensors.metric_updates, feed_dict={
-              tensors.is_training: False
-            })
-        else:
-          logging.info('start queue runners (batch)')
-          sv.start_queue_runners(session)
-
-        logging.info('evaluate_in_session')
-        return self.evaluate_in_session(session, tensors)
-    finally:
-      sv.stop()
-
-  def write_predictions(self):
-    """Run one round of predictions and write predictions to csv file."""
-    num_eval_batches = self.num_eval_batches
-    num_detailed_eval_batches = self.num_detail_eval_batches
-    with tf.Graph().as_default() as graph:
-      tensors = self.model.build_eval_graph(
-        self.eval_data_paths,
-        self.batch_size
-      )
-      self.graph_size = get_graph_size()
-      saver = tf.train.Saver()
-
-    sv = CustomSupervisor(
-      model=self.model,
-      graph=graph,
-      logdir=self.output_path,
-      summary_op=None,
-      global_step=None,
-      saver=saver
-    )
+    def evaluate(self, num_eval_batches=None):
+        """Run one round of evaluation, return loss and accuracy."""
 
-    file_io.recursive_create_dir(self.results_dir)
-
-    accumulated_results = None
-
-    last_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path)
-    with sv.managed_session(
-        master='', start_standard_services=False) as session:
-      sv.saver.restore(session, last_checkpoint)
-      predictions_filename = os.path.join(self.output_path, 'predictions.csv')
-      with FileIO(predictions_filename, 'w') as csv_f:
-        sv.start_queue_runners(session)
-        last_log_progress = 0
-        for eval_index in range(num_eval_batches):
-          progress = eval_index * 100 // num_eval_batches
-          if progress > last_log_progress:
-            logging.info('%3d%% predictions processed', progress)
-            last_log_progress = progress
-
-          detailed_evaluation = eval_index < num_detailed_eval_batches
-
-          fetches = self._get_default_fetches(tensors)
-          self._add_evaluation_result_fetches(fetches, tensors)
-          if detailed_evaluation:
-            self._add_image_fetches(fetches, tensors)
-          self._check_fetches(fetches)
-          results = session.run(fetches, feed_dict={
-            tensors.is_training: False
-          })
+        num_eval_batches = num_eval_batches or self.num_eval_batches
+        with tf.Graph().as_default() as graph:
+            tensors = self.model.build_eval_graph(
+                self.eval_data_paths,
+                self.eval_batch_size
+            )
+            self.graph_size = get_graph_size()
 
-          accumulated_results = self._accumulate_evaluation_results(results, accumulated_results)
-          if detailed_evaluation:
-            self._save_prediction_summary_image(eval_index, results)
-            self._save_result_images(eval_index, results)
-            self._save_meta(eval_index, results)
+            saver = tf.train.Saver()
 
-          input_uri = results['input_uri']
-          metric_values = results['metric_values']
-          csv_f.write('{},{}\n'.format(input_uri, metric_values[0]))
+        sv = CustomSupervisor(
+            model=self.model,
+            graph=graph,
+            logdir=self.output_path,
+            summary_op=None,
+            global_step=None,
+            saver=saver
+        )
+        try:
+            last_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path)
+            logging.info('last_checkpoint: %s (%s)', last_checkpoint, self.checkpoint_path)
+
+            file_io.recursive_create_dir(self.results_dir)
+
+            with sv.managed_session(
+                    master='', start_standard_services=False) as session:
+                sv.saver.restore(session, last_checkpoint)
+
+                logging.info('session restored')
+
+                if self.stream:
+                    logging.info('start queue runners (stream)')
+                    sv.start_queue_runners(session)
+                    for _ in range(num_eval_batches):
+                        session.run(tensors.metric_updates, feed_dict={
+                            tensors.is_training: False
+                        })
+                else:
+                    logging.info('start queue runners (batch)')
+                    sv.start_queue_runners(session)
+
+                logging.info('evaluate_in_session')
+                return self.evaluate_in_session(session, tensors)
+        finally:
+            sv.stop()
+
+    def write_predictions(self):
+        """Run one round of predictions and write predictions to csv file."""
+        num_eval_batches = self.num_eval_batches
+        num_detailed_eval_batches = self.num_detail_eval_batches
+        with tf.Graph().as_default() as graph:
+            tensors = self.model.build_eval_graph(
+                self.eval_data_paths,
+                self.batch_size
+            )
+            self.graph_size = get_graph_size()
+            saver = tf.train.Saver()
+
+        sv = CustomSupervisor(
+            model=self.model,
+            graph=graph,
+            logdir=self.output_path,
+            summary_op=None,
+            global_step=None,
+            saver=saver
+        )
 
-        self._save_accumulate_evaluation_results(accumulated_results)
+        file_io.recursive_create_dir(self.results_dir)
+
+        accumulated_results = None
+
+        last_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path)
+        with sv.managed_session(
+                master='', start_standard_services=False) as session:
+            sv.saver.restore(session, last_checkpoint)
+            predictions_filename = os.path.join(self.output_path, 'predictions.csv')
+            with FileIO(predictions_filename, 'w') as csv_f:
+                sv.start_queue_runners(session)
+                last_log_progress = 0
+                for eval_index in range(num_eval_batches):
+                    progress = eval_index * 100 // num_eval_batches
+                    if progress > last_log_progress:
+                        logging.info('%3d%% predictions processed', progress)
+                        last_log_progress = progress
+
+                    detailed_evaluation = eval_index < num_detailed_eval_batches
+
+                    fetches = self._get_default_fetches(tensors)
+                    self._add_evaluation_result_fetches(fetches, tensors)
+                    if detailed_evaluation:
+                        self._add_image_fetches(fetches, tensors)
+                    self._check_fetches(fetches)
+                    results = session.run(fetches, feed_dict={
+                        tensors.is_training: False
+                    })
+
+                    accumulated_results = self._accumulate_evaluation_results(
+                        results, accumulated_results)
+                    if detailed_evaluation:
+                        self._save_prediction_summary_image(eval_index, results)
+                        self._save_result_images(eval_index, results)
+                        self._save_meta(eval_index, results)
+
+                    input_uri = results['input_uri']
+                    metric_values = results['metric_values']
+                    csv_f.write('{},{}\n'.format(input_uri, metric_values[0]))
+
+                self._save_accumulate_evaluation_results(accumulated_results)
diff --git a/sciencebeam_gym/trainer/evaluator_test.py b/sciencebeam_gym/trainer/evaluator_test.py
index 3f257f0..1a891fe 100644
--- a/sciencebeam_gym/trainer/evaluator_test.py
+++ b/sciencebeam_gym/trainer/evaluator_test.py
@@ -4,25 +4,25 @@ import pytest
 import tensorflow as tf
 
 from sciencebeam_utils.utils.collection import (
-  to_namedtuple
+    to_namedtuple
 )
 
 from sciencebeam_gym.utils.tfrecord import (
-  dict_to_example
+    dict_to_example
 )
 
 from sciencebeam_gym.trainer.data.examples import (
-  parse_example,
-  MapKeysTracker
+    parse_example,
+    MapKeysTracker
 )
 
 from sciencebeam_gym.trainer.data.examples_test import (
-  EXAMPLE_PROPS_1,
-  list_dataset
+    EXAMPLE_PROPS_1,
+    list_dataset
 )
 
 from sciencebeam_gym.trainer.evaluator import (
-  Evaluator
+    Evaluator
 )
 
 TEST_PATH = '.temp/test/evaluator'
@@ -34,102 +34,108 @@ BATCH_SIZE = 10
 EVAL_SET_SIZE = 10
 
# Arguments namespace handed to the Evaluator constructor as `args`
# (converted via to_namedtuple in the tests below).
DEFAULT_ARGS = dict(
    batch_size=BATCH_SIZE,
    eval_set_size=EVAL_SET_SIZE,
    streaming_eval=False,
    output_path=OUTPUT_PATH
)

# Keyword arguments passed directly to the Evaluator constructor.
DEFAULT_KWARGS = dict(
    checkpoint_path=CHECKPOINT_PATH,
    data_paths=DATA_PATHS,
    eval_batch_size=BATCH_SIZE
)
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def setup_module():
-  logging.basicConfig(level='DEBUG')
+    logging.basicConfig(level='DEBUG')
+
 
 class GraphMode(object):
-  TRAIN = 'train'
-  EVALUATE = 'eval'
+    TRAIN = 'train'
+    EVALUATE = 'eval'
+
 
 def example_dataset(map_keys_tracker, examples):
-  dataset = list_dataset([
-    dict_to_example(example).SerializeToString()
-    for example in examples
-  ], tf.string)
-  dataset = dataset.map(map_keys_tracker.wrap(parse_example))
-  return dataset
+    dataset = list_dataset([
+        dict_to_example(example).SerializeToString()
+        for example in examples
+    ], tf.string)
+    dataset = dataset.map(map_keys_tracker.wrap(parse_example))
+    return dataset
+
 
class ExampleModel(object):
    """Minimal model stub producing the tensors namedtuple the Evaluator expects."""

    def __init__(self, examples):
        # example property dicts to serve via an in-memory dataset
        self.examples = examples

    def build_graph(self, data_paths, batch_size, graph_mode):  # pylint: disable=unused-argument
        """Build a graph over self.examples; data_paths and graph_mode are ignored."""
        tensors = dict()
        tensors['is_training'] = tf.placeholder(tf.bool)
        map_keys_tracker = MapKeysTracker()
        dataset = example_dataset(map_keys_tracker, self.examples)
        iterator = dataset.make_one_shot_iterator()
        parsed = map_keys_tracker.unwrap(iterator.get_next())
        get_logger().debug('parsed: %s', parsed)
        tensors['examples'] = parsed
        # empty metrics / placeholders so the Evaluator has every key it reads
        tensors['metric_values'] = []
        tensors['metric_updates'] = []
        tensors['global_step'] = tf.constant(100, tf.int32)
        tensors['summaries'] = dict()
        tensors['image_tensors'] = dict()
        tensors['evaluation_result'] = None
        image_shape = (10, 10, 3)
        # single-example tensors; tf.train.batch adds the leading batch dimension
        pre_batch_tensors = {
            'input_uri': tf.squeeze(parsed['input_uri']),
            'annotation_uri': tf.squeeze(parsed['annotation_uri']),
            'image_tensor': tf.zeros(image_shape),
            'annotation_tensor': tf.zeros(image_shape),
            'output_image': tf.zeros(image_shape)
        }
        post_batch_tensors = tf.train.batch(pre_batch_tensors, batch_size=batch_size)
        tensors.update(post_batch_tensors)
        # register the three image tensors both as summaries and for direct access
        for name in ['image_tensor', 'annotation_tensor', 'output_image']:
            image_tensor = tensors[name]
            get_logger().debug('name=%s, image_tensor=%s', name, image_tensor)
            tf.summary.image(name, image_tensor)
            tensors['image_tensors'][name] = image_tensor
            tensors['summaries'][name] = image_tensor
        tensors['summary'] = tf.summary.merge_all()
        tensors['initializer'] = [tf.global_variables_initializer()]
        return to_namedtuple(tensors, name='Tensors')

    def build_train_graph(self, data_paths, batch_size):
        """Graph for training mode."""
        return self.build_graph(data_paths, batch_size, GraphMode.TRAIN)

    def build_eval_graph(self, data_paths, batch_size):
        """Graph for evaluation mode."""
        return self.build_graph(data_paths, batch_size, GraphMode.EVALUATE)
+
 
@pytest.mark.slow
class TestEvaluator(object):
    """Smoke test: Evaluator.evaluate_in_session runs against the ExampleModel stub."""

    def test_should_not_fail_eval_in_session(self):
        with tf.Graph().as_default():
            model = ExampleModel([EXAMPLE_PROPS_1] * BATCH_SIZE)
            tensors = model.build_train_graph(
                DATA_PATHS, BATCH_SIZE
            )

            evaluator = Evaluator(
                args=to_namedtuple(DEFAULT_ARGS, name='args'),
                model=model,
                **DEFAULT_KWARGS
            )
            evaluator.init()

            get_logger().info('starting session')
            with tf.Session() as session:
                coord = tf.train.Coordinator()
                tf.train.start_queue_runners(sess=session, coord=coord)
                get_logger().info('evaluating')
                session.run(tensors.initializer)
                # the actual assertion: this must complete without raising
                evaluator.evaluate_in_session(session, tensors)
            get_logger().info('done')
diff --git a/sciencebeam_gym/trainer/models/pix2pix/evaluate.py b/sciencebeam_gym/trainer/models/pix2pix/evaluate.py
index c8a7846..5c79b05 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/evaluate.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/evaluate.py
@@ -6,114 +6,122 @@ import collections
 import tensorflow as tf
 
 EvaluationTensors = collections.namedtuple(
-  "EvaluationTensors", [
-    "confusion_matrix",
-    "tp",
-    "fp",
-    "fn",
-    "tn",
-    "precision",
-    "recall",
-    "f1",
-    "accuracy",
-    "micro_precision",
-    "micro_recall",
-    "micro_f1",
-    "macro_precision",
-    "macro_recall",
-    "macro_f1"
-  ]
+    "EvaluationTensors", [
+        "confusion_matrix",
+        "tp",
+        "fp",
+        "fn",
+        "tn",
+        "precision",
+        "recall",
+        "f1",
+        "accuracy",
+        "micro_precision",
+        "micro_recall",
+        "micro_f1",
+        "macro_precision",
+        "macro_recall",
+        "macro_f1"
+    ]
 )
 
+
 def output_probabilities_to_class(outputs):
-  return tf.argmax(outputs, 3)
+    return tf.argmax(outputs, 3)
+
 
 def to_1d_vector(tensor):
-  return tf.reshape(tensor, [-1])
+    return tf.reshape(tensor, [-1])
+
 
 def precision_from_tp_fp(tp, fp):
-  return tp / (tp + fp)
+    return tp / (tp + fp)
+
 
 def recall_from_tp_fn(tp, fn):
-  return tp / (tp + fn)
+    return tp / (tp + fn)
+
 
 def f1_from_precision_recall(precision, recall):
-  return 2 * precision * recall / (precision + recall)
+    return 2 * precision * recall / (precision + recall)
+
 
 def _evaluate_from_confusion_matrix(confusion, accuracy=None):
-  total = tf.reduce_sum(confusion)
-  actual_p = tf.reduce_sum(confusion, axis=0)
-  pred_p = tf.reduce_sum(confusion, axis=1)
-  tp = tf.diag_part(confusion)
-  fp = actual_p - tp
-  fn = pred_p - tp
-  tn = total - tp - fp - fn
-  precision = precision_from_tp_fp(tp, fp)
-  recall = recall_from_tp_fn(tp, fn)
-  f1 = f1_from_precision_recall(precision, recall)
-  total_tp = tf.reduce_sum(tp)
-  total_fp = tf.reduce_sum(fp)
-  total_fn = tf.reduce_sum(fn)
-  # Note: micro averages (with equal weights) will lead to the same precision, recall, f1
-  micro_precision = precision_from_tp_fp(total_tp, total_fp)
-  micro_recall = recall_from_tp_fn(total_tp, total_fn)
-  micro_f1 = f1_from_precision_recall(micro_precision, micro_recall)
-  macro_precision = tf.reduce_sum(precision)
-  macro_recall = tf.reduce_sum(recall)
-  macro_f1 = tf.reduce_sum(f1)
-  return EvaluationTensors(
-    confusion_matrix=confusion,
-    tp=tp,
-    fp=fp,
-    fn=fn,
-    tn=tn,
-    precision=precision,
-    recall=recall,
-    f1=f1,
-    accuracy=accuracy,
-    micro_precision=micro_precision,
-    micro_recall=micro_recall,
-    micro_f1=micro_f1,
-    macro_precision=macro_precision,
-    macro_recall=macro_recall,
-    macro_f1=macro_f1
-  )
+    total = tf.reduce_sum(confusion)
+    actual_p = tf.reduce_sum(confusion, axis=0)
+    pred_p = tf.reduce_sum(confusion, axis=1)
+    tp = tf.diag_part(confusion)
+    fp = actual_p - tp
+    fn = pred_p - tp
+    tn = total - tp - fp - fn
+    precision = precision_from_tp_fp(tp, fp)
+    recall = recall_from_tp_fn(tp, fn)
+    f1 = f1_from_precision_recall(precision, recall)
+    total_tp = tf.reduce_sum(tp)
+    total_fp = tf.reduce_sum(fp)
+    total_fn = tf.reduce_sum(fn)
+    # Note: micro averages (with equal weights) will lead to the same precision, recall, f1
+    micro_precision = precision_from_tp_fp(total_tp, total_fp)
+    micro_recall = recall_from_tp_fn(total_tp, total_fn)
+    micro_f1 = f1_from_precision_recall(micro_precision, micro_recall)
+    macro_precision = tf.reduce_sum(precision)
+    macro_recall = tf.reduce_sum(recall)
+    macro_f1 = tf.reduce_sum(f1)
+    return EvaluationTensors(
+        confusion_matrix=confusion,
+        tp=tp,
+        fp=fp,
+        fn=fn,
+        tn=tn,
+        precision=precision,
+        recall=recall,
+        f1=f1,
+        accuracy=accuracy,
+        micro_precision=micro_precision,
+        micro_recall=micro_recall,
+        micro_f1=micro_f1,
+        macro_precision=macro_precision,
+        macro_recall=macro_recall,
+        macro_f1=macro_f1
+    )
+
 
 def evaluate_predictions(labels, predictions, n_classes):
-  labels = to_1d_vector(labels)
-  predictions = to_1d_vector(predictions)
+    labels = to_1d_vector(labels)
+    predictions = to_1d_vector(predictions)
+
+    correct_prediction = tf.equal(labels, predictions)
+    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
 
-  correct_prediction = tf.equal(labels, predictions)
-  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+    confusion = tf.contrib.metrics.confusion_matrix(
+        labels=labels,
+        predictions=predictions,
+        num_classes=n_classes
+    )
 
-  confusion = tf.contrib.metrics.confusion_matrix(
-    labels=labels,
-    predictions=predictions,
-    num_classes=n_classes
-  )
+    return _evaluate_from_confusion_matrix(
+        confusion=confusion,
+        accuracy=accuracy
+    )
 
-  return _evaluate_from_confusion_matrix(
-    confusion=confusion,
-    accuracy=accuracy
-  )
 
 def evaluate_separate_channels(targets, outputs):
-  n_classes = targets.shape[-1]
+    n_classes = targets.shape[-1]
 
-  labels = output_probabilities_to_class(targets)
-  predictions = output_probabilities_to_class(outputs)
-  return evaluate_predictions(
-    labels=labels,
-    predictions=predictions,
-    n_classes=n_classes
-  )
+    labels = output_probabilities_to_class(targets)
+    predictions = output_probabilities_to_class(outputs)
+    return evaluate_predictions(
+        labels=labels,
+        predictions=predictions,
+        n_classes=n_classes
+    )
 
 
 def evaluation_summary(evaluation_tensors, layer_labels):
-  tf.summary.scalar("micro_precision", evaluation_tensors.micro_precision)
-  tf.summary.scalar("micro_recall", evaluation_tensors.micro_recall)
-  tf.summary.scalar("micro_f1", evaluation_tensors.micro_f1)
-  tf.summary.scalar("macro_f1", evaluation_tensors.macro_f1)
-  tf.summary.scalar("accuracy", evaluation_tensors.accuracy)
-  for i, layer_label in enumerate(layer_labels):
-    tf.summary.scalar("f1_{}_{}".format(i, layer_label), evaluation_tensors.f1[i])
+    tf.summary.scalar("micro_precision", evaluation_tensors.micro_precision)
+    tf.summary.scalar("micro_recall", evaluation_tensors.micro_recall)
+    tf.summary.scalar("micro_f1", evaluation_tensors.micro_f1)
+    tf.summary.scalar("macro_f1", evaluation_tensors.macro_f1)
+    tf.summary.scalar("accuracy", evaluation_tensors.accuracy)
+    for i, layer_label in enumerate(layer_labels):
+        tf.summary.scalar("f1_{}_{}".format(i, layer_label), evaluation_tensors.f1[i])
diff --git a/sciencebeam_gym/trainer/models/pix2pix/evaluate_test.py b/sciencebeam_gym/trainer/models/pix2pix/evaluate_test.py
index 99e9c6f..f6a4877 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/evaluate_test.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/evaluate_test.py
@@ -7,49 +7,50 @@ import tensorflow as tf
 import numpy as np
 
 from sciencebeam_utils.utils.num import (
-  assert_close
+    assert_close
 )
 
 from sciencebeam_gym.trainer.models.pix2pix.evaluate import (
-  evaluate_predictions
+    evaluate_predictions
 )
 
+
@pytest.mark.slow
def test_evaluate_predictions():
    """evaluate_predictions yields the expected confusion matrix and micro metrics."""
    n_classes = 4
    predictions = tf.constant(np.array([0, 1, 1, 2, 3, 3]))
    labels = tf.constant(np.array([0, 1, 2, 3, 3, 3]))

    evaluation_tensors = evaluate_predictions(
        labels=labels,
        predictions=predictions,
        n_classes=n_classes
    )
    with tf.Session() as session:
        # rows: true labels, columns: predictions
        assert np.array_equal(session.run(evaluation_tensors.confusion_matrix), np.array([
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, 2]
        ]))
        assert np.array_equal(session.run(evaluation_tensors.tp), np.array([1, 1, 0, 2]))
        assert np.array_equal(session.run(evaluation_tensors.fp), np.array([0, 1, 1, 0]))
        assert np.array_equal(session.run(evaluation_tensors.fn), np.array([0, 0, 1, 1]))
        # 4 correct out of 6 predictions; pooled counts make precision == recall
        expected_micro_precision = 4.0 / (4 + 2)
        expected_micro_recall = 4.0 / (4 + 2)
        expected_micro_f1 = (
            2 * expected_micro_precision * expected_micro_recall /
            (expected_micro_precision + expected_micro_recall)
        )
        assert_close(
            session.run(evaluation_tensors.micro_precision),
            expected_micro_precision
        )
        assert_close(
            session.run(evaluation_tensors.micro_recall),
            expected_micro_recall
        )
        assert_close(
            session.run(evaluation_tensors.micro_f1),
            expected_micro_f1
        )
diff --git a/sciencebeam_gym/trainer/models/pix2pix/loss.py b/sciencebeam_gym/trainer/models/pix2pix/loss.py
index 2a60587..a685fdf 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/loss.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/loss.py
@@ -1,36 +1,39 @@
 import tensorflow as tf
 
+
 def l1_loss(labels, outputs):
-  with tf.name_scope("l1_loss"):
-    # abs(labels - outputs) => 0
-    return tf.reduce_mean(tf.abs(labels - outputs))
+    with tf.name_scope("l1_loss"):
+        # abs(labels - outputs) => 0
+        return tf.reduce_mean(tf.abs(labels - outputs))
+
 
 def cross_entropy_loss(labels, logits):
-  with tf.name_scope("cross_entropy"):
-    return tf.reduce_mean(
-      tf.nn.softmax_cross_entropy_with_logits(
-        logits=logits,
-        labels=labels
-      )
-    )
+    with tf.name_scope("cross_entropy"):
+        return tf.reduce_mean(
+            tf.nn.softmax_cross_entropy_with_logits(
+                logits=logits,
+                labels=labels
+            )
+        )
+
 
 def weighted_cross_entropy_loss(labels, logits, pos_weight, scalar=True):
-  with tf.name_scope("weighted_cross_entropy"):
-    softmax_loss = tf.nn.softmax_cross_entropy_with_logits(
-      logits=logits,
-      labels=labels
-    )
-    # calculate weight per sample using the labels and weight per class
-    weight_per_sample = tf.reduce_sum(
-      tf.multiply(
-        labels,
-        pos_weight
-      ),
-      axis=-1
-    )
-    # weight each loss per sample
-    value = tf.multiply(
-      softmax_loss,
-      weight_per_sample
-    )
-    return tf.reduce_mean(value) if scalar else value
+    with tf.name_scope("weighted_cross_entropy"):
+        softmax_loss = tf.nn.softmax_cross_entropy_with_logits(
+            logits=logits,
+            labels=labels
+        )
+        # calculate weight per sample using the labels and weight per class
+        weight_per_sample = tf.reduce_sum(
+            tf.multiply(
+                labels,
+                pos_weight
+            ),
+            axis=-1
+        )
+        # weight each loss per sample
+        value = tf.multiply(
+            softmax_loss,
+            weight_per_sample
+        )
+        return tf.reduce_mean(value) if scalar else value
diff --git a/sciencebeam_gym/trainer/models/pix2pix/loss_test.py b/sciencebeam_gym/trainer/models/pix2pix/loss_test.py
index 25916b3..e5cc2cf 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/loss_test.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/loss_test.py
@@ -1,137 +1,138 @@
 import logging
 
-from six import raise_from
-
 import tensorflow as tf
-import numpy as np
 
 from sciencebeam_utils.utils.num import (
-  assert_close
+    assert_close
 )
 
 from sciencebeam_gym.trainer.models.pix2pix.loss import (
-  l1_loss,
-  cross_entropy_loss,
-  weighted_cross_entropy_loss
+    l1_loss,
+    cross_entropy_loss,
+    weighted_cross_entropy_loss
 )
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
class TestL1Loss(object):
    """Tests for l1_loss."""

    def test_should_return_abs_diff_for_single_value(self):
        with tf.Graph().as_default():
            labels = tf.constant([0.9])
            outputs = tf.constant([0.1])
            loss = l1_loss(labels, outputs)
            with tf.Session() as session:
                # |0.9 - 0.1| == 0.8
                assert_close(session.run([loss])[0], 0.8)
+
 
class TestCrossEntropyLoss(object):
    """Tests for cross_entropy_loss."""

    def test_should_return_zero_if_logits_are_matching_labels_with_neg_pos_value(self):
        with tf.Graph().as_default():
            labels = tf.constant([
                [[0.0, 1.0]]
            ])
            logits = tf.constant([
                [[-10.0, 10.0]]
            ])
            loss = cross_entropy_loss(labels, logits)
            with tf.Session() as session:
                # softmax of [-10, 10] is ~[0, 1], matching the labels exactly
                assert_close(session.run([loss])[0], 0.0)

    def test_should_return_not_zero_if_logits_are_not_matching_labels(self):
        with tf.Graph().as_default():
            labels = tf.constant([
                [[0.0, 1.0]]
            ])
            logits = tf.constant([
                [[10.0, 10.0]]
            ])
            loss = cross_entropy_loss(labels, logits)
            with tf.Session() as session:
                # equal logits give 0.5 probability on the true class, loss ~ ln(2)
                assert session.run([loss])[0] > 0.5
 
class TestWeightedCrossEntropyLoss(object):
    """Tests for weighted_cross_entropy_loss."""

    def test_should_return_zero_if_logits_are_matching_labels_with_neg_pos_value(self):
        with tf.Graph().as_default():
            labels = tf.constant([
                [[0.0, 1.0]]
            ])
            logits = tf.constant([
                [[-10.0, 10.0]]
            ])
            pos_weight = tf.constant([
                1.0
            ])
            loss = weighted_cross_entropy_loss(labels, logits, pos_weight)
            with tf.Session() as session:
                # matching logits: loss is ~0 regardless of the (unit) weight
                assert_close(session.run([loss])[0], 0.0, atol=0.0001)

    def test_should_return_not_zero_if_logits_are_not_matching_labels(self):
        with tf.Graph().as_default():
            labels = tf.constant([
                [[0.0, 1.0]]
            ])
            logits = tf.constant([
                [[10.0, 10.0]]
            ])
            pos_weight = tf.constant([
                1.0
            ])
            loss = weighted_cross_entropy_loss(labels, logits, pos_weight)
            with tf.Session() as session:
                assert session.run([loss])[0] > 0.5

    def test_should_return_higher_loss_for_value_with_greater_weight(self):
        with tf.Graph().as_default():
            labels = tf.constant([
                [[0.0, 1.0]]
            ])
            logits = tf.constant([
                [[10.0, 10.0]]
            ])
            pos_weight_1 = tf.constant([
                0.5
            ])
            pos_weight_2 = tf.constant([
                1.0
            ])
            loss_1 = weighted_cross_entropy_loss(labels, logits, pos_weight_1)
            loss_2 = weighted_cross_entropy_loss(labels, logits, pos_weight_2)
            with tf.Session() as session:
                # same mismatch, larger weight => larger loss
                loss_1_value, loss_2_value = session.run([loss_1, loss_2])
                assert loss_1_value < loss_2_value

    def test_should_support_batch_example_pos_weights(self):
        batch_size = 3
        with tf.Graph().as_default():
            labels = tf.constant([[0.0, 1.0]] * batch_size)
            logits = tf.constant([[10.0, 10.0]] * batch_size)
            # per-example, per-class weights; rows differ only in examples 0 and 1
            pos_weight_1 = tf.constant([
                [0.5, 0.5],
                [1.0, 1.0],
                [1.0, 1.0]
            ])
            pos_weight_2 = tf.constant([
                [1.0, 1.0],
                [0.5, 0.5],
                [1.0, 1.0]
            ])
            loss_1 = weighted_cross_entropy_loss(
                labels, logits, pos_weight_1, scalar=False
            )
            loss_2 = weighted_cross_entropy_loss(
                labels, logits, pos_weight_2, scalar=False
            )
            with tf.Session() as session:
                get_logger().debug('labels=\n%s', labels.eval())
                get_logger().debug('logits=\n%s', logits.eval())
                loss_1_value, loss_2_value = session.run([loss_1, loss_2])
                get_logger().debug(
                    '\nloss_1_value=\n%s\nloss_2_value=\n%s',
                    loss_1_value, loss_2_value
                )
                # per-example losses follow the per-example weights
                assert loss_1_value[0] < loss_2_value[0]
                assert loss_1_value[1] > loss_2_value[1]
                assert loss_1_value[2] == loss_2_value[2]
diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py
index effd89c..912b65f 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py
@@ -10,560 +10,598 @@ import collections
 import tensorflow as tf
 
 from sciencebeam_gym.trainer.models.pix2pix.tf_utils import (
-  blank_other_channels,
-  get_channel_slice
+    blank_other_channels,
+    get_channel_slice
 )
 
 from sciencebeam_gym.trainer.models.pix2pix.loss import (
-  l1_loss,
-  cross_entropy_loss,
-  weighted_cross_entropy_loss
+    l1_loss,
+    cross_entropy_loss,
+    weighted_cross_entropy_loss
 )
 
 EPS = 1e-12
 
+
 class BaseLoss(object):
-  L1 = "L1"
-  CROSS_ENTROPY = "CE"
-  # Loss with pre-calculated weights (and cross entropy)
-  # (weights calculated across entire dataset)
-  WEIGHTED_CROSS_ENTROPY = "WCE"
-  # Loss with per sample weights (and cross entropy)
-  # (weights calculated as part of the graph)
-  SAMPLE_WEIGHTED_CROSS_ENTROPY = "SWCE"
-  # Combination of the above; i.e. sample weights multiplied by pre-calculated weights
-  WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY = "WSWCE"
+    L1 = "L1"
+    CROSS_ENTROPY = "CE"
+    # Loss with pre-calculated weights (and cross entropy)
+    # (weights calculated across entire dataset)
+    WEIGHTED_CROSS_ENTROPY = "WCE"
+    # Loss with per sample weights (and cross entropy)
+    # (weights calculated as part of the graph)
+    SAMPLE_WEIGHTED_CROSS_ENTROPY = "SWCE"
+    # Combination of the above; i.e. sample weights multiplied by pre-calculated weights
+    WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY = "WSWCE"
+
 
 ALL_BASE_LOSS = [
-  BaseLoss.L1, BaseLoss.CROSS_ENTROPY, BaseLoss.WEIGHTED_CROSS_ENTROPY,
-  BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY, BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY
+    BaseLoss.L1, BaseLoss.CROSS_ENTROPY, BaseLoss.WEIGHTED_CROSS_ENTROPY,
+    BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY, BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY
 ]
 
 ALL_CE_BASE_LOSS = {
-  BaseLoss.CROSS_ENTROPY,
-  BaseLoss.WEIGHTED_CROSS_ENTROPY,
-  BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY,
-  BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY
+    BaseLoss.CROSS_ENTROPY,
+    BaseLoss.WEIGHTED_CROSS_ENTROPY,
+    BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY,
+    BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY
 }
 
 ALL_WCE_BASE_LOSS = ALL_CE_BASE_LOSS - {BaseLoss.CROSS_ENTROPY}
 
 Pix2PixModel = collections.namedtuple(
-  "Pix2PixModel", [
-    "inputs",
-    "targets",
-    "outputs",
-    "predict_real",
-    "predict_fake",
-    "discrim_loss",
-    "discrim_grads_and_vars",
-    "gen_loss_GAN",
-    "gen_loss_L1",
-    "gen_grads_and_vars",
-    "global_step",
-    "train"
-  ]
+    "Pix2PixModel", [
+        "inputs",
+        "targets",
+        "outputs",
+        "predict_real",
+        "predict_fake",
+        "discrim_loss",
+        "discrim_grads_and_vars",
+        "gen_loss_GAN",
+        "gen_loss_L1",
+        "gen_grads_and_vars",
+        "global_step",
+        "train"
+    ]
 )
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def lrelu(x, a):
-  with tf.name_scope("lrelu"):
-    # adding these together creates the leak part and linear part
-    # then cancels them out by subtracting/adding an absolute value term
-    # leak: a*x/2 - a*abs(x)/2
-    # linear: x/2 + abs(x)/2
+    with tf.name_scope("lrelu"):
+        # adding these together creates the leak part and linear part
+        # then cancels them out by subtracting/adding an absolute value term
+        # leak: a*x/2 - a*abs(x)/2
+        # linear: x/2 + abs(x)/2
+
+        # this block looks like it has 2 inputs on the graph unless we do this
+        x = tf.identity(x)
+        return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)
 
-    # this block looks like it has 2 inputs on the graph unless we do this
-    x = tf.identity(x)
-    return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)
 
 def batchnorm(batch_input):
-  with tf.variable_scope("batchnorm"):
-    # this block looks like it has 3 inputs on the graph unless we do this
-    batch_input = tf.identity(batch_input)
-
-    input_shape = batch_input.get_shape()
-    get_logger().debug('batchnorm, input_shape: %s', input_shape)
-    channels = input_shape[-1]
-    offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer())
-    scale = tf.get_variable("scale", [channels], dtype=tf.float32, initializer=tf.random_normal_initializer(1.0, 0.02))
-    mean, variance = tf.nn.moments(batch_input, axes=[0, 1, 2], keep_dims=False)
-    variance_epsilon = 1e-5
-    normalized = tf.nn.batch_normalization(batch_input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)
-    return normalized
+    with tf.variable_scope("batchnorm"):
+        # this block looks like it has 3 inputs on the graph unless we do this
+        batch_input = tf.identity(batch_input)
+
+        input_shape = batch_input.get_shape()
+        get_logger().debug('batchnorm, input_shape: %s', input_shape)
+        channels = input_shape[-1]
+        offset = tf.get_variable(
+            "offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer()
+        )
+        scale = tf.get_variable(
+            "scale", [channels],
+            dtype=tf.float32, initializer=tf.random_normal_initializer(1.0, 0.02)
+        )
+        mean, variance = tf.nn.moments(batch_input, axes=[0, 1, 2], keep_dims=False)
+        variance_epsilon = 1e-5
+        normalized = tf.nn.batch_normalization(
+            batch_input, mean, variance, offset, scale, variance_epsilon=variance_epsilon
+        )
+        return normalized
+
 
 def conv(batch_input, out_channels, stride):
-  with tf.variable_scope("conv"):
-    input_shape = batch_input.get_shape()
-    get_logger().debug('conv, input_shape: %s', input_shape)
-    in_channels = input_shape[-1]
-    conv_filter = tf.get_variable(
-      "filter",
-      [4, 4, in_channels, out_channels],
-      dtype=tf.float32,
-      initializer=tf.random_normal_initializer(0, 0.02)
-    )
-    # [batch, in_height, in_width, in_channels],
-    # [filter_width, filter_height, in_channels, out_channels]
-    #     => [batch, out_height, out_width, out_channels]
-    padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT")
-    return tf.nn.conv2d(padded_input, conv_filter, [1, stride, stride, 1], padding="VALID")
+    with tf.variable_scope("conv"):
+        input_shape = batch_input.get_shape()
+        get_logger().debug('conv, input_shape: %s', input_shape)
+        in_channels = input_shape[-1]
+        conv_filter = tf.get_variable(
+            "filter",
+            [4, 4, in_channels, out_channels],
+            dtype=tf.float32,
+            initializer=tf.random_normal_initializer(0, 0.02)
+        )
+        # [batch, in_height, in_width, in_channels],
+        # [filter_width, filter_height, in_channels, out_channels]
+        #     => [batch, out_height, out_width, out_channels]
+        padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT")
+        return tf.nn.conv2d(padded_input, conv_filter, [1, stride, stride, 1], padding="VALID")
+
 
 def deconv(batch_input, out_channels):
-  with tf.variable_scope("deconv"):
-    input_shape = batch_input.get_shape()
-    get_logger().debug('deconv, input_shape: %s', input_shape)
-
-    batch, in_height, in_width, in_channels = input_shape.as_list()
-    conv_filter = tf.get_variable(
-      "filter",
-      [4, 4, out_channels, in_channels],
-      dtype=tf.float32,
-      initializer=tf.random_normal_initializer(0, 0.02)
-    )
-    # [batch, in_height, in_width, in_channels],
-    # [filter_width, filter_height, out_channels, in_channels]
-    #     => [batch, out_height, out_width, out_channels]
-    # tf.shape creates a tensor to allow calculation of shape at runtime
-    dynamic_batch_size = tf.shape(batch_input)[0]
-    output_shape = [batch, in_height * 2, in_width * 2, out_channels]
-    return tf.reshape(
-      tf.nn.conv2d_transpose(
-        batch_input,
-        conv_filter,
-        tf.stack([dynamic_batch_size] + output_shape[1:]),
-        [1, 2, 2, 1],
-        padding="SAME"
-      ),
-      [batch or -1] + output_shape[1:]
-    )
+    with tf.variable_scope("deconv"):
+        input_shape = batch_input.get_shape()
+        get_logger().debug('deconv, input_shape: %s', input_shape)
+
+        batch, in_height, in_width, in_channels = input_shape.as_list()
+        conv_filter = tf.get_variable(
+            "filter",
+            [4, 4, out_channels, in_channels],
+            dtype=tf.float32,
+            initializer=tf.random_normal_initializer(0, 0.02)
+        )
+        # [batch, in_height, in_width, in_channels],
+        # [filter_width, filter_height, out_channels, in_channels]
+        #     => [batch, out_height, out_width, out_channels]
+        # tf.shape creates a tensor to allow calculation of shape at runtime
+        dynamic_batch_size = tf.shape(batch_input)[0]
+        output_shape = [batch, in_height * 2, in_width * 2, out_channels]
+        return tf.reshape(
+            tf.nn.conv2d_transpose(
+                batch_input,
+                conv_filter,
+                tf.stack([dynamic_batch_size] + output_shape[1:]),
+                [1, 2, 2, 1],
+                padding="SAME"
+            ),
+            [batch or -1] + output_shape[1:]
+        )
+
 
 def create_encoder_layers(
-  generator_inputs,
-  layer_specs):
-  """
-  Creates encoders with every layer halfing width and height,
-  using an output depth specified by layer_specs.
-
-  Args:
-    generator_inputs: A `Tensor`. The input to the generator.
-    layer_specs: A list of number of output channels for each layer.
-  Returns:
-    A list of `Tensor`. The output tensor of each layer.
-  """
-
-  layers = []
-  for out_channels in layer_specs:
-    with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
-      if not layers:
-        # very first layer is connected to inputs and doesn't use batch norm
-        output = conv(generator_inputs, out_channels, stride=2)
-        layers.append(output)
-      else:
-        rectified = lrelu(layers[-1], 0.2)
-        # [batch, in_height, in_width, in_channels] =>
-        # [batch, in_height/2, in_width/2, out_channels]
-        convolved = conv(rectified, out_channels, stride=2)
-        output = batchnorm(convolved)
-        layers.append(output)
-  return layers
+        generator_inputs,
+        layer_specs):
+    """
+    Creates encoders with every layer halfing width and height,
+    using an output depth specified by layer_specs.
+
+    Args:
+      generator_inputs: A `Tensor`. The input to the generator.
+      layer_specs: A list of number of output channels for each layer.
+    Returns:
+      A list of `Tensor`. The output tensor of each layer.
+    """
+
+    layers = []
+    for out_channels in layer_specs:
+        with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
+            if not layers:
+                # very first layer is connected to inputs and doesn't use batch norm
+                output = conv(generator_inputs, out_channels, stride=2)
+                layers.append(output)
+            else:
+                rectified = lrelu(layers[-1], 0.2)
+                # [batch, in_height, in_width, in_channels] =>
+                # [batch, in_height/2, in_width/2, out_channels]
+                convolved = conv(rectified, out_channels, stride=2)
+                output = batchnorm(convolved)
+                layers.append(output)
+    return layers
+
 
 def conditional_dropout(cond, x, **kwargs):
-  return tf.cond(cond, lambda: tf.nn.dropout(x, **kwargs), lambda: x)
+    return tf.cond(cond, lambda: tf.nn.dropout(x, **kwargs), lambda: x)
+
 
 def create_decoder_layers(
-  decoder_inputs,
-  layer_specs,
-  skip_connection_layers,
-  is_training):
-  """
-  Creates decoders with every layer doubling width and height,
-  using an output depth and dropout specified by layer_specs.
-
-  Args:
-    generator_inputs: A `Tensor`. The input to the generator.
-    layer_specs: A list of tuples for each layer. Each tuple containing
-      number of output channels and dropout.
-  Returns:
-    A list of `Tensor`. The output tensor of each layer.
-  """
-
-  layers = []
-  num_encoder_layers = len(skip_connection_layers)
-  for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
-    skip_layer = num_encoder_layers - decoder_layer - 1
-    is_last_layer = decoder_layer == len(layer_specs) - 1
-    with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
-      if decoder_layer == 0:
-        # first decoder layer doesn't have skip connections
-        # since it is directly connected to the skip_layer
-        layer_input = decoder_inputs
-      else:
-        layer_input = tf.concat([layers[-1], skip_connection_layers[skip_layer]], axis=3)
-
-      rectified = tf.nn.relu(layer_input)
-      # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
-      output = deconv(rectified, out_channels)
-      if not is_last_layer:
-        # very last layer does not use batch norm
-        output = batchnorm(output)
-
-      if dropout > 0.0:
-        # only apply dropout in training mode
-        output = conditional_dropout(
-          tf.convert_to_tensor(is_training),
-          output,
-          keep_prob=1 - dropout
-        )
+        decoder_inputs,
+        layer_specs,
+        skip_connection_layers,
+        is_training):
+    """
+    Creates decoders with every layer doubling width and height,
+    using an output depth and dropout specified by layer_specs.
+
+    Args:
+      generator_inputs: A `Tensor`. The input to the generator.
+      layer_specs: A list of tuples for each layer. Each tuple containing
+        number of output channels and dropout.
+    Returns:
+      A list of `Tensor`. The output tensor of each layer.
+    """
+
+    layers = []
+    num_encoder_layers = len(skip_connection_layers)
+    for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
+        skip_layer = num_encoder_layers - decoder_layer - 1
+        is_last_layer = decoder_layer == len(layer_specs) - 1
+        with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
+            if decoder_layer == 0:
+                # first decoder layer doesn't have skip connections
+                # since it is directly connected to the skip_layer
+                layer_input = decoder_inputs
+            else:
+                layer_input = tf.concat([layers[-1], skip_connection_layers[skip_layer]], axis=3)
+
+            rectified = tf.nn.relu(layer_input)
+            # [batch, in_height, in_width, in_channels]
+            # => [batch, in_height*2, in_width*2, out_channels]
+            output = deconv(rectified, out_channels)
+            if not is_last_layer:
+                # very last layer does not use batch norm
+                output = batchnorm(output)
+
+            if dropout > 0.0:
+                # only apply dropout in training mode
+                output = conditional_dropout(
+                    tf.convert_to_tensor(is_training),
+                    output,
+                    keep_prob=1 - dropout
+                )
+
+            layers.append(output)
+    return layers
 
-      layers.append(output)
-  return layers
 
 def create_encoder_decoder(
-  generator_inputs,
-  encoder_layer_specs,
-  decoder_layer_specs,
-  is_training):
-
-  encoder_layers = create_encoder_layers(
-    generator_inputs,
-    encoder_layer_specs
-  )
-  decoder_layers = create_decoder_layers(
-    encoder_layers[-1],
-    decoder_layer_specs,
-    encoder_layers,
-    is_training
-  )
-  return decoder_layers[-1]
+        generator_inputs,
+        encoder_layer_specs,
+        decoder_layer_specs,
+        is_training):
+
+    encoder_layers = create_encoder_layers(
+        generator_inputs,
+        encoder_layer_specs
+    )
+    decoder_layers = create_decoder_layers(
+        encoder_layers[-1],
+        decoder_layer_specs,
+        encoder_layers,
+        is_training
+    )
+    return decoder_layers[-1]
+
 
 def create_generator(generator_inputs, generator_outputs_channels, a, is_training):
-  encoder_layer_specs = [
-    a.ngf,     # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
-    a.ngf * 2, # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
-    a.ngf * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
-    a.ngf * 8, # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
-    a.ngf * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
-    a.ngf * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
-    a.ngf * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
-    a.ngf * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
-  ]
-  decoder_layer_specs = [
-    (a.ngf * 8, 0.5),   # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
-    (a.ngf * 8, 0.5),   # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
-    (a.ngf * 8, 0.5),   # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
-    (a.ngf * 8, 0.0),   # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
-    (a.ngf * 4, 0.0),   # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
-    (a.ngf * 2, 0.0),   # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
-    (a.ngf, 0.0),       # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
-    # decoder_1 [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
-    (generator_outputs_channels, 0.0)
-  ]
-  return create_encoder_decoder(
-    generator_inputs,
-    encoder_layer_specs,
-    decoder_layer_specs,
-    is_training
-  )
+    encoder_layer_specs = [
+        a.ngf,     # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
+        a.ngf * 2,  # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
+        a.ngf * 4,  # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
+        a.ngf * 8,  # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
+        a.ngf * 8,  # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
+        a.ngf * 8,  # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
+        a.ngf * 8,  # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
+        a.ngf * 8,  # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
+    ]
+    decoder_layer_specs = [
+        (a.ngf * 8, 0.5),   # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
+        (a.ngf * 8, 0.5),   # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
+        (a.ngf * 8, 0.5),   # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
+        (a.ngf * 8, 0.0),   # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
+        # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
+        (a.ngf * 4, 0.0),
+        # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
+        (a.ngf * 2, 0.0),
+        (a.ngf, 0.0),       # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
+        # decoder_1 [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
+        (generator_outputs_channels, 0.0)
+    ]
+    return create_encoder_decoder(
+        generator_inputs,
+        encoder_layer_specs,
+        decoder_layer_specs,
+        is_training
+    )
+
 
 def create_discriminator(discrim_inputs, discrim_targets, a, out_channels=1):
-  n_layers = 3
-  layers = []
-
-  # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2]
-  layer_input = tf.concat([discrim_inputs, discrim_targets], axis=3)
-
-  # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
-  with tf.variable_scope("layer_1"):
-    convolved = conv(layer_input, a.ndf, stride=2)
-    rectified = lrelu(convolved, 0.2)
-    layers.append(rectified)
-
-  # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
-  # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
-  # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
-  for i in range(n_layers):
+    n_layers = 3
+    layers = []
+
+    # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2]
+    layer_input = tf.concat([discrim_inputs, discrim_targets], axis=3)
+
+    # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
+    with tf.variable_scope("layer_1"):
+        convolved = conv(layer_input, a.ndf, stride=2)
+        rectified = lrelu(convolved, 0.2)
+        layers.append(rectified)
+
+    # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
+    # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
+    # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
+    for i in range(n_layers):
+        with tf.variable_scope("layer_%d" % (len(layers) + 1)):
+            layer_out_channels = a.ndf * min(2 ** (i + 1), 8)
+            stride = 1 if i == n_layers - 1 else 2  # last layer here has stride 1
+            convolved = conv(layers[-1], layer_out_channels, stride=stride)
+            normalized = batchnorm(convolved)
+            rectified = lrelu(normalized, 0.2)
+            layers.append(rectified)
+
+    # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
     with tf.variable_scope("layer_%d" % (len(layers) + 1)):
-      layer_out_channels = a.ndf * min(2**(i+1), 8)
-      stride = 1 if i == n_layers - 1 else 2  # last layer here has stride 1
-      convolved = conv(layers[-1], layer_out_channels, stride=stride)
-      normalized = batchnorm(convolved)
-      rectified = lrelu(normalized, 0.2)
-      layers.append(rectified)
+        convolved = conv(rectified, out_channels=out_channels, stride=1)
+        output = tf.sigmoid(convolved)
+        layers.append(output)
 
-  # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
-  with tf.variable_scope("layer_%d" % (len(layers) + 1)):
-    convolved = conv(rectified, out_channels=out_channels, stride=1)
-    output = tf.sigmoid(convolved)
-    layers.append(output)
+    return layers[-1]
 
-  return layers[-1]
 
 def with_variable_scope(name, fn, reuse=False, **kwargs):
-  with tf.variable_scope(name, reuse=reuse):
-    return fn(**kwargs)
+    with tf.variable_scope(name, reuse=reuse):
+        return fn(**kwargs)
+
 
 def create_separate_discriminators(discrim_inputs, discrim_targets, a, out_channels=1):
-  if out_channels == 1:
-    return create_discriminator(discrim_inputs, discrim_targets, a, out_channels=1)
-  n_targets_channels = discrim_targets.shape[-1]
-  reuse = tf.get_variable_scope().reuse
-  return tf.concat([
-    with_variable_scope(
-      'layer_{}'.format(channel_index),
-      lambda i: create_discriminator(
-        discrim_inputs,
-        get_channel_slice(discrim_targets, i),
-        a,
-        out_channels=1
-      ),
-      reuse=reuse,
-      i=channel_index
-    )
-    for channel_index in range(n_targets_channels)
-  ], axis=0)
+    if out_channels == 1:
+        return create_discriminator(discrim_inputs, discrim_targets, a, out_channels=1)
+    n_targets_channels = discrim_targets.shape[-1]
+    reuse = tf.get_variable_scope().reuse
+    return tf.concat([
+        with_variable_scope(
+            'layer_{}'.format(channel_index),
+            lambda i: create_discriminator(
+                discrim_inputs,
+                get_channel_slice(discrim_targets, i),
+                a,
+                out_channels=1
+            ),
+            reuse=reuse,
+            i=channel_index
+        )
+        for channel_index in range(n_targets_channels)
+    ], axis=0)
+
 
 def create_separate_channel_discriminator_by_blanking_out_channels(inputs, targets, a):
-  # We need to teach the discriminator to detect the real channels,
-  # by just looking at the real channel.
-  # For each channel:
-  # - let the discriminator only see the current channel, blank out all other channels
-  # - expect output to not be fake for the not blanked out channel
-  n_targets_channels = int(targets.shape[-1])
-  predict_real_channels = []
-  predict_real_blanked_list = []
-  for i in range(n_targets_channels):
-    masked_targets = blank_other_channels(
-      targets,
-      i
+    # We need to teach the discriminator to detect the real channels,
+    # by just looking at the real channel.
+    # For each channel:
+    # - let the discriminator only see the current channel, blank out all other channels
+    # - expect output to not be fake for the not blanked out channel
+    n_targets_channels = int(targets.shape[-1])
+    predict_real_channels = []
+    predict_real_blanked_list = []
+    for i in range(n_targets_channels):
+        masked_targets = blank_other_channels(
+            targets,
+            i
+        )
+        with tf.variable_scope("discriminator", reuse=(i > 0)):
+            # 2x [batch, height, width, channels] => [batch, 30, 30, n_targets_channels]
+            predict_real_i = create_discriminator(
+                inputs, masked_targets, a,
+                out_channels=n_targets_channels
+            )
+            predict_real_channels.append(predict_real_i[:, :, :, i])
+            for j in range(n_targets_channels):
+                if j != i:
+                    predict_real_blanked_list.append(predict_real_i[:, :, :, j])
+    predict_real = tf.stack(
+        predict_real_channels,
+        axis=-1,
+        name='predict_real'
     )
-    with tf.variable_scope("discriminator", reuse=(i > 0)):
-      # 2x [batch, height, width, channels] => [batch, 30, 30, n_targets_channels]
-      predict_real_i = create_discriminator(
-        inputs, masked_targets, a,
-        out_channels=n_targets_channels
-      )
-      predict_real_channels.append(predict_real_i[:, :, :, i])
-      for j in range(n_targets_channels):
-        if j != i:
-          predict_real_blanked_list.append(predict_real_i[:, :, :, j])
-  predict_real = tf.stack(
-    predict_real_channels,
-    axis=-1,
-    name='predict_real'
-  )
-  predict_real_blanked = tf.stack(
-    predict_real_blanked_list,
-    axis=-1,
-    name='predict_real_blanked'
-  )
-  return predict_real, predict_real_blanked
+    predict_real_blanked = tf.stack(
+        predict_real_blanked_list,
+        axis=-1,
+        name='predict_real_blanked'
+    )
+    return predict_real, predict_real_blanked
 
 
 def create_pix2pix_model(inputs, targets, a, is_training, pos_weight=None, n_output_channels=None):
-  get_logger().info('gan_weight: %s, l1_weight: %s', a.gan_weight, a.l1_weight)
-  gan_enabled = abs(a.gan_weight) > 0.000001
-
-  is_predict = targets is None
-  if n_output_channels is None:
-    n_output_channels = int(targets.shape[-1])
-  with tf.variable_scope("generator"):
-    outputs = create_generator(inputs, n_output_channels, a, is_training)
-    if a.base_loss in ALL_CE_BASE_LOSS:
-      output_logits = outputs
-      outputs = tf.nn.softmax(output_logits)
-    else:
-      outputs = tf.tanh(outputs)
-
-  if is_predict:
-    return Pix2PixModel(
-      inputs=inputs,
-      targets=None,
-      predict_real=None,
-      predict_fake=None,
-      discrim_loss=None,
-      discrim_grads_and_vars=None,
-      gen_loss_GAN=None,
-      gen_loss_L1=None,
-      gen_grads_and_vars=None,
-      outputs=outputs,
-      global_step=None,
-      train=None,
-    )
+    get_logger().info('gan_weight: %s, l1_weight: %s', a.gan_weight, a.l1_weight)
+    gan_enabled = abs(a.gan_weight) > 0.000001
+
+    is_predict = targets is None
+    if n_output_channels is None:
+        n_output_channels = int(targets.shape[-1])
+    with tf.variable_scope("generator"):
+        outputs = create_generator(inputs, n_output_channels, a, is_training)
+        if a.base_loss in ALL_CE_BASE_LOSS:
+            output_logits = outputs
+            outputs = tf.nn.softmax(output_logits)
+        else:
+            outputs = tf.tanh(outputs)
+
+    if is_predict:
+        return Pix2PixModel(
+            inputs=inputs,
+            targets=None,
+            predict_real=None,
+            predict_fake=None,
+            discrim_loss=None,
+            discrim_grads_and_vars=None,
+            gen_loss_GAN=None,
+            gen_loss_L1=None,
+            gen_grads_and_vars=None,
+            outputs=outputs,
+            global_step=None,
+            train=None,
+        )
 
-  n_targets_channels = n_output_channels
+    n_targets_channels = n_output_channels
 
-  if gan_enabled:
-    discrim_out_channels = (
-      n_targets_channels
-      if a.use_separate_discriminator_channels
-      else 1
-    )
-    get_logger().info('discrim_out_channels: %s', discrim_out_channels)
-
-    # create two copies of discriminator, one for real pairs and one for fake pairs
-    # they share the same underlying variables
-    with tf.name_scope("real_discriminator"):
-      if discrim_out_channels > 1:
-        predict_real, predict_real_blanked = (
-          create_separate_channel_discriminator_by_blanking_out_channels(
-            inputs, targets, a
-          )
+    if gan_enabled:
+        discrim_out_channels = (
+            n_targets_channels
+            if a.use_separate_discriminator_channels
+            else 1
         )
-      else:
-        with tf.variable_scope("discriminator"):
-          if a.use_separate_discriminators:
-            get_logger().info('using separate discriminators: %s', n_targets_channels)
-            predict_real = create_separate_discriminators(
-              inputs, targets, a, out_channels=n_targets_channels
+        get_logger().info('discrim_out_channels: %s', discrim_out_channels)
+
+        # create two copies of discriminator, one for real pairs and one for fake pairs
+        # they share the same underlying variables
+        with tf.name_scope("real_discriminator"):
+            if discrim_out_channels > 1:
+                predict_real, predict_real_blanked = (
+                    create_separate_channel_discriminator_by_blanking_out_channels(
+                        inputs, targets, a
+                    )
+                )
+            else:
+                with tf.variable_scope("discriminator"):
+                    if a.use_separate_discriminators:
+                        get_logger().info('using separate discriminators: %s', n_targets_channels)
+                        predict_real = create_separate_discriminators(
+                            inputs, targets, a, out_channels=n_targets_channels
+                        )
+                    else:
+                        # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
+                        predict_real = create_discriminator(inputs, targets, a)
+
+        with tf.name_scope("fake_discriminator"):
+            with tf.variable_scope("discriminator", reuse=True):
+                # 2x [batch, height, width, channels] => [batch, 30, 30, discrim_out_channels]
+                # We don't need to split the channels,
+                # the discriminator should detect them all as fake
+                if a.use_separate_discriminators:
+                    predict_fake = create_separate_discriminators(
+                        inputs, outputs, a, out_channels=n_targets_channels
+                    )
+                else:
+                    predict_fake = create_discriminator(
+                        inputs, outputs, a,
+                        out_channels=discrim_out_channels
+                    )
+
+        with tf.name_scope("discriminator_loss"):
+            # minimizing -tf.log will try to get inputs to 1
+            # predict_real => 1
+            # predict_fake => 0
+            discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) +
+                                            tf.log(1 - predict_fake + EPS)))
+            if discrim_out_channels > 1:
+                discrim_loss += tf.reduce_mean(
+                    -tf.log(1 - tf.reshape(predict_real_blanked, [-1]) + EPS)
+                )
+
+        with tf.name_scope("discriminator_train"):
+            discrim_tvars = [
+                var for var in tf.trainable_variables() if var.name.startswith("discriminator")
+            ]
+            discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
+            discrim_grads_and_vars = discrim_optim.compute_gradients(
+                discrim_loss, var_list=discrim_tvars)
+            discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)
+    else:
+        with tf.name_scope("gan_disabled"):
+            discrim_loss = tf.constant(0.0)
+            predict_real = None
+            predict_fake = None
+            discrim_grads_and_vars = []
+
+    with tf.name_scope("generator_loss"):
+        get_logger().info('using loss: %s', a.base_loss)
+        if a.base_loss == BaseLoss.L1:
+            gen_base_loss = l1_loss(labels=targets, outputs=outputs)
+        elif a.base_loss == BaseLoss.CROSS_ENTROPY:
+            gen_base_loss = cross_entropy_loss(
+                logits=output_logits,
+                labels=targets
+            )
+        elif a.base_loss in ALL_WCE_BASE_LOSS:
+            if pos_weight is None:
+                raise ValueError('pos_weight missing')
+            pos_weight = tf.convert_to_tensor(pos_weight)
+            gen_base_loss = weighted_cross_entropy_loss(
+                logits=output_logits,
+                labels=targets,
+                pos_weight=pos_weight
             )
-          else:
-            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
-            predict_real = create_discriminator(inputs, targets, a)
-
-    with tf.name_scope("fake_discriminator"):
-      with tf.variable_scope("discriminator", reuse=True):
-        # 2x [batch, height, width, channels] => [batch, 30, 30, discrim_out_channels]
-        # We don't need to split the channels, the discriminator should detect them all as fake
-        if a.use_separate_discriminators:
-          predict_fake = create_separate_discriminators(
-            inputs, outputs, a, out_channels=n_targets_channels
-          )
         else:
-          predict_fake = create_discriminator(
-            inputs, outputs, a,
-            out_channels=discrim_out_channels
-          )
-
-    with tf.name_scope("discriminator_loss"):
-      # minimizing -tf.log will try to get inputs to 1
-      # predict_real => 1
-      # predict_fake => 0
-      discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))
-      if discrim_out_channels > 1:
-        discrim_loss += tf.reduce_mean(-tf.log(1 - tf.reshape(predict_real_blanked, [-1]) + EPS))
-
-    with tf.name_scope("discriminator_train"):
-      discrim_tvars = [
-        var for var in tf.trainable_variables() if var.name.startswith("discriminator")
-      ]
-      discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
-      discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
-      discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)
-  else:
-    with tf.name_scope("gan_disabled"):
-      discrim_loss = tf.constant(0.0)
-      predict_real = None
-      predict_fake = None
-      discrim_grads_and_vars = []
-
-  with tf.name_scope("generator_loss"):
-    get_logger().info('using loss: %s', a.base_loss)
-    if a.base_loss == BaseLoss.L1:
-      gen_base_loss = l1_loss(labels=targets, outputs=outputs)
-    elif a.base_loss == BaseLoss.CROSS_ENTROPY:
-      gen_base_loss = cross_entropy_loss(
-        logits=output_logits,
-        labels=targets
-      )
-    elif a.base_loss in ALL_WCE_BASE_LOSS:
-      if pos_weight is None:
-        raise ValueError('pos_weight missing')
-      pos_weight = tf.convert_to_tensor(pos_weight)
-      gen_base_loss = weighted_cross_entropy_loss(
-        logits=output_logits,
-        labels=targets,
-        pos_weight=pos_weight
-      )
-    else:
-      raise ValueError('unrecognised base loss: %s' % a.base_loss)
-    gen_loss = gen_base_loss * a.l1_weight
+            raise ValueError('unrecognised base loss: %s' % a.base_loss)
+        gen_loss = gen_base_loss * a.l1_weight
 
-    if gan_enabled:
-      # predict_fake => 1
-      gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
-      gen_loss += gen_loss_GAN * a.gan_weight
-    else:
-      gen_loss_GAN = tf.constant(0.0)
+        if gan_enabled:
+            # predict_fake => 1
+            gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
+            gen_loss += gen_loss_GAN * a.gan_weight
+        else:
+            gen_loss_GAN = tf.constant(0.0)
+
+    with tf.name_scope("generator_train"):
+        generator_train_dependencies = (
+            [discrim_train] if gan_enabled
+            else []
+        )
+        with tf.control_dependencies(generator_train_dependencies):
+            gen_tvars = [var for var in tf.trainable_variables(
+            ) if var.name.startswith("generator")]
+            gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
+            gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
+            gen_train = gen_optim.apply_gradients(gen_grads_and_vars)
+
+    ema = tf.train.ExponentialMovingAverage(decay=0.99)
+    update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_base_loss])
+
+    global_step = tf.contrib.framework.get_or_create_global_step()
+    incr_global_step = tf.assign(global_step, global_step + 1)
 
-  with tf.name_scope("generator_train"):
-    generator_train_dependencies = (
-      [discrim_train] if gan_enabled
-      else []
+    # TODO change gen_loss_L1 name
+    return Pix2PixModel(
+        inputs=inputs,
+        targets=targets,
+        predict_real=predict_real,
+        predict_fake=predict_fake,
+        discrim_loss=ema.average(discrim_loss),
+        discrim_grads_and_vars=discrim_grads_and_vars,
+        gen_loss_GAN=ema.average(gen_loss_GAN),
+        gen_loss_L1=ema.average(gen_base_loss),
+        gen_grads_and_vars=gen_grads_and_vars,
+        outputs=outputs,
+        global_step=global_step,
+        train=tf.group(update_losses, incr_global_step, gen_train),
     )
-    with tf.control_dependencies(generator_train_dependencies):
-      gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
-      gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
-      gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
-      gen_train = gen_optim.apply_gradients(gen_grads_and_vars)
-
-  ema = tf.train.ExponentialMovingAverage(decay=0.99)
-  update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_base_loss])
-
-  global_step = tf.contrib.framework.get_or_create_global_step()
-  incr_global_step = tf.assign(global_step, global_step+1)
-
-  # TODO change gen_loss_L1 name
-  return Pix2PixModel(
-    inputs=inputs,
-    targets=targets,
-    predict_real=predict_real,
-    predict_fake=predict_fake,
-    discrim_loss=ema.average(discrim_loss),
-    discrim_grads_and_vars=discrim_grads_and_vars,
-    gen_loss_GAN=ema.average(gen_loss_GAN),
-    gen_loss_L1=ema.average(gen_base_loss),
-    gen_grads_and_vars=gen_grads_and_vars,
-    outputs=outputs,
-    global_step=global_step,
-    train=tf.group(update_losses, incr_global_step, gen_train),
-  )
+
 
 def create_image_summaries(model):
-  def convert(image):
-    return tf.image.convert_image_dtype(
-      image,
-      dtype=tf.uint8,
-      saturate=True
-    )
-  summaries = {}
+    def convert(image):
+        return tf.image.convert_image_dtype(
+            image,
+            dtype=tf.uint8,
+            saturate=True
+        )
+    summaries = {}
 
-  # reverse any processing on images so they can be written to disk or displayed to user
-  with tf.name_scope("convert_inputs"):
-    converted_inputs = convert(model.inputs)
+    # reverse any processing on images so they can be written to disk or displayed to user
+    with tf.name_scope("convert_inputs"):
+        converted_inputs = convert(model.inputs)
 
-  with tf.name_scope("convert_targets"):
-    converted_targets = convert(model.targets)
+    with tf.name_scope("convert_targets"):
+        converted_targets = convert(model.targets)
 
-  with tf.name_scope("convert_outputs"):
-    converted_outputs = convert(model.outputs)
+    with tf.name_scope("convert_outputs"):
+        converted_outputs = convert(model.outputs)
 
-  with tf.name_scope("inputs_summary"):
-    tf.summary.image("inputs", converted_inputs)
+    with tf.name_scope("inputs_summary"):
+        tf.summary.image("inputs", converted_inputs)
 
-  with tf.name_scope("targets_summary"):
-    tf.summary.image("targets", converted_targets)
+    with tf.name_scope("targets_summary"):
+        tf.summary.image("targets", converted_targets)
 
-  with tf.name_scope("outputs_summary"):
-    tf.summary.image("outputs", converted_outputs)
-    summaries['output_image'] = converted_outputs
+    with tf.name_scope("outputs_summary"):
+        tf.summary.image("outputs", converted_outputs)
+        summaries['output_image'] = converted_outputs
 
-  with tf.name_scope("predict_real_summary"):
-    tf.summary.image("predict_real", tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))
+    with tf.name_scope("predict_real_summary"):
+        tf.summary.image(
+            "predict_real", tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8)
+        )
+
+    with tf.name_scope("predict_fake_summary"):
+        tf.summary.image(
+            "predict_fake", tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8)
+        )
+    return summaries
 
-  with tf.name_scope("predict_fake_summary"):
-    tf.summary.image("predict_fake", tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8))
-  return summaries
 
 def create_other_summaries(model):
-  tf.summary.scalar("discriminator_loss", model.discrim_loss)
-  tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
-  tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
+    tf.summary.scalar("discriminator_loss", model.discrim_loss)
+    tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
+    tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
 
-  for var in tf.trainable_variables():
-    tf.summary.histogram(var.op.name + "/values", var)
+    for var in tf.trainable_variables():
+        tf.summary.histogram(var.op.name + "/values", var)
 
-  for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars:
-    tf.summary.histogram(var.op.name + "/gradients", grad)
+    for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars:
+        tf.summary.histogram(var.op.name + "/gradients", grad)
diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core_test.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core_test.py
index b514198..b71197a 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core_test.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core_test.py
@@ -7,32 +7,32 @@ import numpy as np
 import pytest
 
 from sciencebeam_utils.utils.num import (
-  assert_all_close,
-  assert_all_not_close
+    assert_all_close,
+    assert_all_not_close
 )
 
 from sciencebeam_utils.utils.collection import (
-  extend_dict
+    extend_dict
 )
 
 import sciencebeam_gym.trainer.models.pix2pix.pix2pix_core as pix2pix_core
 
 from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core import (
-  create_encoder_decoder,
-  create_pix2pix_model,
-  BaseLoss
+    create_encoder_decoder,
+    create_pix2pix_model,
+    BaseLoss
 )
 
 DEFAULT_ARGS = dict(
-  ngf=2, # use small ngf and ndf to keep the graph smaller
-  ndf=2,
-  lr=0.0002,
-  beta1=0.5,
-  l1_weight=1.0,
-  gan_weight=0.0,
-  base_loss=BaseLoss.L1,
-  use_separate_discriminator_channels=False,
-  use_separate_discriminators=False
+    ngf=2,  # use small ngf and ndf to keep the graph smaller
+    ndf=2,
+    lr=0.0002,
+    beta1=0.5,
+    l1_weight=1.0,
+    gan_weight=0.0,
+    base_loss=BaseLoss.L1,
+    use_separate_discriminator_channels=False,
+    use_separate_discriminators=False
 )
 
 BATCH_SIZE = 10
@@ -40,164 +40,180 @@ WIDTH = 256
 HEIGHT = 256
 CHANNELS = 3
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def setup_module():
-  logging.basicConfig(level='DEBUG')
+    logging.basicConfig(level='DEBUG')
+
 
 def create_args(*args, **kwargs):
-  d = extend_dict(*list(args) + [kwargs])
-  return namedtuple('args', d.keys())(**d)
+    d = extend_dict(*list(args) + [kwargs])
+    return namedtuple('args', d.keys())(**d)
+
 
 def patch_spy_object(o, name):
-  return patch.object(o, name, wraps=getattr(o, name))
+    return patch.object(o, name, wraps=getattr(o, name))
+
 
 class TestCreateEncoderDecoder(object):
-  def test_should_add_dropout_in_training_mode_using_constant(self):
-    with tf.Graph().as_default():
-      encoder_inputs = tf.ones((1, 8, 8, 3))
-      encoder_layer_specs = [5, 10]
-      decoder_layer_specs = [(5, 0.5), (3, 0.0)]
-      outputs = create_encoder_decoder(
-        encoder_inputs,
-        encoder_layer_specs,
-        decoder_layer_specs,
-        is_training=True
-      )
-      with tf.Session() as session:
-        session.run(tf.global_variables_initializer())
-        # with dropout, the outputs are expected to be different for every run
-        assert_all_not_close(session.run(outputs), session.run(outputs))
-
-  def test_should_not_add_dropout_not_in_training_mode_using_constant(self):
-    with tf.Graph().as_default():
-      encoder_inputs = tf.ones((1, 8, 8, 3))
-      encoder_layer_specs = [5, 10]
-      decoder_layer_specs = [(5, 0.5), (3, 0.0)]
-      outputs = create_encoder_decoder(
-        encoder_inputs,
-        encoder_layer_specs,
-        decoder_layer_specs,
-        is_training=False
-      )
-      with tf.Session() as session:
-        session.run(tf.global_variables_initializer())
-        # without dropout, the outputs are expected to the same for every run
-        assert_all_close(session.run(outputs), session.run(outputs))
-
-  def test_should_add_dropout_in_training_mode_using_placeholder(self):
-    with tf.Graph().as_default():
-      is_training = tf.placeholder(tf.bool)
-      encoder_inputs = tf.ones((1, 8, 8, 3))
-      encoder_layer_specs = [5, 10]
-      decoder_layer_specs = [(5, 0.5), (3, 0.0)]
-      outputs = create_encoder_decoder(
-        encoder_inputs,
-        encoder_layer_specs,
-        decoder_layer_specs,
-        is_training=is_training
-      )
-      with tf.Session() as session:
-        session.run(tf.global_variables_initializer())
-        feed_dict = {is_training: True}
-        # with dropout, the outputs are expected to be different for every run
-        assert_all_not_close(
-          session.run(outputs, feed_dict=feed_dict),
-          session.run(outputs, feed_dict=feed_dict)
-        )
-
-  def test_should_not_add_dropout_not_in_training_mode_using_placeholder(self):
-    with tf.Graph().as_default():
-      is_training = tf.placeholder(tf.bool)
-      encoder_inputs = tf.ones((1, 8, 8, 3))
-      encoder_layer_specs = [5, 10]
-      decoder_layer_specs = [(5, 0.5), (3, 0.0)]
-      outputs = create_encoder_decoder(
-        encoder_inputs,
-        encoder_layer_specs,
-        decoder_layer_specs,
-        is_training=is_training
-      )
-      with tf.Session() as session:
-        session.run(tf.global_variables_initializer())
-        feed_dict = {is_training: False}
-        # without dropout, the outputs are expected to the same for every run
-        assert_all_close(
-          session.run(outputs, feed_dict=feed_dict),
-          session.run(outputs, feed_dict=feed_dict)
-        )
-
-  def test_should_allow_undefined_batch_size(self):
-    with tf.Graph().as_default():
-      input_shape = [None, 8, 8, 3]
-      encoder_inputs = tf.placeholder(tf.float32, input_shape)
-      encoder_layer_specs = [5, 10]
-      decoder_layer_specs = [(5, 0.5), (3, 0.0)]
-      outputs = create_encoder_decoder(
-        encoder_inputs,
-        encoder_layer_specs,
-        decoder_layer_specs,
-        is_training=False
-      )
-      assert outputs.get_shape().as_list() == input_shape
-      with tf.Session() as session:
-        session.run(tf.global_variables_initializer())
-        feed_dict = {encoder_inputs: np.ones([1] + input_shape[1:])}
-        assert_all_close(
-          session.run(outputs, feed_dict=feed_dict),
-          session.run(outputs, feed_dict=feed_dict)
-        )
+    def test_should_add_dropout_in_training_mode_using_constant(self):
+        with tf.Graph().as_default():
+            encoder_inputs = tf.ones((1, 8, 8, 3))
+            encoder_layer_specs = [5, 10]
+            decoder_layer_specs = [(5, 0.5), (3, 0.0)]
+            outputs = create_encoder_decoder(
+                encoder_inputs,
+                encoder_layer_specs,
+                decoder_layer_specs,
+                is_training=True
+            )
+            with tf.Session() as session:
+                session.run(tf.global_variables_initializer())
+                # with dropout, the outputs are expected to be different for every run
+                assert_all_not_close(session.run(outputs), session.run(outputs))
+
+    def test_should_not_add_dropout_not_in_training_mode_using_constant(self):
+        with tf.Graph().as_default():
+            encoder_inputs = tf.ones((1, 8, 8, 3))
+            encoder_layer_specs = [5, 10]
+            decoder_layer_specs = [(5, 0.5), (3, 0.0)]
+            outputs = create_encoder_decoder(
+                encoder_inputs,
+                encoder_layer_specs,
+                decoder_layer_specs,
+                is_training=False
+            )
+            with tf.Session() as session:
+                session.run(tf.global_variables_initializer())
+                # without dropout, the outputs are expected to be the same for every run
+                assert_all_close(session.run(outputs), session.run(outputs))
+
+    def test_should_add_dropout_in_training_mode_using_placeholder(self):
+        with tf.Graph().as_default():
+            is_training = tf.placeholder(tf.bool)
+            encoder_inputs = tf.ones((1, 8, 8, 3))
+            encoder_layer_specs = [5, 10]
+            decoder_layer_specs = [(5, 0.5), (3, 0.0)]
+            outputs = create_encoder_decoder(
+                encoder_inputs,
+                encoder_layer_specs,
+                decoder_layer_specs,
+                is_training=is_training
+            )
+            with tf.Session() as session:
+                session.run(tf.global_variables_initializer())
+                feed_dict = {is_training: True}
+                # with dropout, the outputs are expected to be different for every run
+                assert_all_not_close(
+                    session.run(outputs, feed_dict=feed_dict),
+                    session.run(outputs, feed_dict=feed_dict)
+                )
+
+    def test_should_not_add_dropout_not_in_training_mode_using_placeholder(self):
+        with tf.Graph().as_default():
+            is_training = tf.placeholder(tf.bool)
+            encoder_inputs = tf.ones((1, 8, 8, 3))
+            encoder_layer_specs = [5, 10]
+            decoder_layer_specs = [(5, 0.5), (3, 0.0)]
+            outputs = create_encoder_decoder(
+                encoder_inputs,
+                encoder_layer_specs,
+                decoder_layer_specs,
+                is_training=is_training
+            )
+            with tf.Session() as session:
+                session.run(tf.global_variables_initializer())
+                feed_dict = {is_training: False}
+                # without dropout, the outputs are expected to be the same for every run
+                assert_all_close(
+                    session.run(outputs, feed_dict=feed_dict),
+                    session.run(outputs, feed_dict=feed_dict)
+                )
+
+    def test_should_allow_undefined_batch_size(self):
+        with tf.Graph().as_default():
+            input_shape = [None, 8, 8, 3]
+            encoder_inputs = tf.placeholder(tf.float32, input_shape)
+            encoder_layer_specs = [5, 10]
+            decoder_layer_specs = [(5, 0.5), (3, 0.0)]
+            outputs = create_encoder_decoder(
+                encoder_inputs,
+                encoder_layer_specs,
+                decoder_layer_specs,
+                is_training=False
+            )
+            assert outputs.get_shape().as_list() == input_shape
+            with tf.Session() as session:
+                session.run(tf.global_variables_initializer())
+                feed_dict = {encoder_inputs: np.ones([1] + input_shape[1:])}
+                assert_all_close(
+                    session.run(outputs, feed_dict=feed_dict),
+                    session.run(outputs, feed_dict=feed_dict)
+                )
+
 
 @pytest.mark.slow
 @pytest.mark.very_slow
 class TestCreatePix2pixModel(object):
-  def test_should_be_able_to_construct_graph_with_defaults_without_gan(self):
-    with tf.Graph().as_default():
-      inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
-      targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
-      a = create_args(DEFAULT_ARGS, gan_weight=0.0)
-      create_pix2pix_model(inputs, targets, a, is_training=True)
-
-  def test_should_be_able_to_construct_graph_with_defaults_and_gan(self):
-    with tf.Graph().as_default():
-      inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
-      targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
-      a = create_args(DEFAULT_ARGS, gan_weight=1.0)
-      create_pix2pix_model(inputs, targets, a, is_training=True)
-
-  def test_should_be_able_to_construct_graph_with_gan_and_sep_discrim_channels(self):
-    with patch_spy_object(pix2pix_core, 'l1_loss') as l1_loss:
-      with tf.Graph().as_default():
-        inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
-        targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
-        a = create_args(DEFAULT_ARGS, gan_weight=1.0, use_separate_discriminator_channels=True)
-        create_pix2pix_model(inputs, targets, a, is_training=True)
-        assert l1_loss.called
-
-  def test_should_be_able_to_construct_graph_with_sep_discrim_channels_and_cross_entropy_loss(self):
-    with patch_spy_object(pix2pix_core, 'cross_entropy_loss') as cross_entropy_loss:
-      with tf.Graph().as_default():
-        inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
-        targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
-        a = create_args(
-          DEFAULT_ARGS,
-          gan_weight=1.0, use_separate_discriminator_channels=True, base_loss=BaseLoss.CROSS_ENTROPY
-        )
-        create_pix2pix_model(inputs, targets, a, is_training=True)
-        assert cross_entropy_loss.called
-
-  def test_should_be_able_to_construct_graph_with_weighted_cross_entropy_loss(self):
-    with patch_spy_object(pix2pix_core, 'weighted_cross_entropy_loss') \
-      as weighted_cross_entropy_loss:
-
-      with tf.Graph().as_default():
-        inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
-        targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
-        a = create_args(
-          DEFAULT_ARGS,
-          gan_weight=1.0, use_separate_discriminator_channels=True,
-          base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY
-        )
-        create_pix2pix_model(inputs, targets, a, is_training=True, pos_weight=[1.0] * CHANNELS)
-        assert weighted_cross_entropy_loss.called
+    def test_should_be_able_to_construct_graph_with_defaults_without_gan(self):
+        with tf.Graph().as_default():
+            inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
+            targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
+            a = create_args(DEFAULT_ARGS, gan_weight=0.0)
+            create_pix2pix_model(inputs, targets, a, is_training=True)
+
+    def test_should_be_able_to_construct_graph_with_defaults_and_gan(self):
+        with tf.Graph().as_default():
+            inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
+            targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
+            a = create_args(DEFAULT_ARGS, gan_weight=1.0)
+            create_pix2pix_model(inputs, targets, a, is_training=True)
+
+    def test_should_be_able_to_construct_graph_with_gan_and_sep_discrim_channels(self):
+        with patch_spy_object(pix2pix_core, 'l1_loss') as l1_loss:
+            with tf.Graph().as_default():
+                inputs = tf.constant(
+                    np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
+                targets = tf.constant(
+                    np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
+                a = create_args(DEFAULT_ARGS, gan_weight=1.0,
+                                use_separate_discriminator_channels=True)
+                create_pix2pix_model(inputs, targets, a, is_training=True)
+                assert l1_loss.called
+
+    def test_should_be_able_to_construct_graph_with_sep_discrim_channels_and_cross_entropy_loss(
+            self):
+        with patch_spy_object(pix2pix_core, 'cross_entropy_loss') as cross_entropy_loss:
+            with tf.Graph().as_default():
+                inputs = tf.constant(
+                    np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
+                targets = tf.constant(
+                    np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
+                a = create_args(
+                    DEFAULT_ARGS,
+                    gan_weight=1.0, use_separate_discriminator_channels=True,
+                    base_loss=BaseLoss.CROSS_ENTROPY
+                )
+                create_pix2pix_model(inputs, targets, a, is_training=True)
+                assert cross_entropy_loss.called
+
+    def test_should_be_able_to_construct_graph_with_weighted_cross_entropy_loss(self):
+        with patch_spy_object(pix2pix_core, 'weighted_cross_entropy_loss') \
+                as weighted_cross_entropy_loss:
+
+            with tf.Graph().as_default():
+                inputs = tf.constant(
+                    np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
+                targets = tf.constant(
+                    np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
+                a = create_args(
+                    DEFAULT_ARGS,
+                    gan_weight=1.0, use_separate_discriminator_channels=True,
+                    base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY
+                )
+                create_pix2pix_model(inputs, targets, a, is_training=True,
+                                     pos_weight=[1.0] * CHANNELS)
+                assert weighted_cross_entropy_loss.called
diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py
index e4e2e5f..805be05 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py
@@ -12,44 +12,44 @@ from six.moves import reduce
 import tensorflow as tf
 
 
-from tensorflow.python.lib.io.file_io import FileIO # pylint: disable=E0611
+from tensorflow.python.lib.io.file_io import FileIO  # pylint: disable=E0611
 
 from sciencebeam_gym.trainer.data.examples import (
-  get_matching_files,
-  read_examples
+    get_matching_files,
+    read_examples
 )
 
 from sciencebeam_gym.preprocess.color_map import (
-  parse_color_map_from_file
+    parse_color_map_from_file
 )
 
 from sciencebeam_gym.tools.calculate_class_weights import (
-  tf_calculate_efnet_weights_for_frequency_by_label
+    tf_calculate_efnet_weights_for_frequency_by_label
 )
 
 from sciencebeam_gym.trainer.models.pix2pix.tf_utils import (
-  find_nearest_centroid_indices
+    find_nearest_centroid_indices
 )
 
 from sciencebeam_gym.preprocess.preprocessing_utils import (
-  parse_page_range
+    parse_page_range
 )
 
 from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core import (
-  BaseLoss,
-  ALL_BASE_LOSS,
-  create_pix2pix_model,
-  create_other_summaries
+    BaseLoss,
+    ALL_BASE_LOSS,
+    create_pix2pix_model,
+    create_other_summaries
 )
 
 from sciencebeam_gym.trainer.models.pix2pix.evaluate import (
-  evaluate_separate_channels,
-  evaluate_predictions,
-  evaluation_summary
+    evaluate_separate_channels,
+    evaluate_predictions,
+    evaluation_summary
 )
 
 from sciencebeam_gym.model_utils.channels import (
-  calculate_color_masks
+    calculate_color_masks
 )
 
 
@@ -58,628 +58,669 @@ UNKNOWN_LABEL = 'unknown'
 
 DEFAULT_UNKNOWN_CLASS_WEIGHT = 0.1
 
+
 class GraphMode(object):
-  TRAIN = 1
-  EVALUATE = 2
-  PREDICT = 3
+    TRAIN = 1
+    EVALUATE = 2
+    PREDICT = 3
+
 
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
 
 
 class GraphReferences(object):
-  """Holder of base tensors used for training model using common task."""
-
-  def __init__(self):
-    self.is_training = None
-    self.inputs = dict()
-    self.examples = None
-    self.train = None
-    self.global_step = None
-    self.metric_updates = []
-    self.metric_values = []
-    self.predictions = []
-    self.input_jpeg = None
-    self.input_uri = None
-    self.image_tensor = None
-    self.annotation_uri = None
-    self.annotation_tensor = None
-    self.separate_channel_annotation_tensor = None
-    self.class_labels_tensor = None
-    self.pred = None
-    self.probabilities = None
-    self.summary = None
-    self.summaries = None
-    self.image_tensors = None
-    self.targets_class_indices = None
-    self.outputs_class_indices = None
-    self.output_layer_labels = None
-    self.evaluation_result = None
-    self.pos_weight = None
+    """Holder of base tensors used for training model using common task."""
+
+    def __init__(self):
+        self.is_training = None
+        self.inputs = dict()
+        self.examples = None
+        self.train = None
+        self.global_step = None
+        self.metric_updates = []
+        self.metric_values = []
+        self.predictions = []
+        self.input_jpeg = None
+        self.input_uri = None
+        self.image_tensor = None
+        self.annotation_uri = None
+        self.annotation_tensor = None
+        self.separate_channel_annotation_tensor = None
+        self.class_labels_tensor = None
+        self.pred = None
+        self.probabilities = None
+        self.summary = None
+        self.summaries = None
+        self.image_tensors = None
+        self.targets_class_indices = None
+        self.outputs_class_indices = None
+        self.output_layer_labels = None
+        self.evaluation_result = None
+        self.pos_weight = None
+
 
 def batch_dimensions_to_colors_list(image_tensor, colors):
-  batch_images = []
-  for i, single_label_color in enumerate(colors):
-    batch_images.append(
-      tf.expand_dims(
-        image_tensor[:, :, :, i],
-        axis=-1
-      ) * ([x / 255.0 for x in single_label_color])
-    )
-  return batch_images
+    batch_images = []
+    for i, single_label_color in enumerate(colors):
+        batch_images.append(
+            tf.expand_dims(
+                image_tensor[:, :, :, i],
+                axis=-1
+            ) * ([x / 255.0 for x in single_label_color])
+        )
+    return batch_images
+
 
 def batch_dimensions_to_most_likely_colors_list(image_tensor, colors):
-  with tf.variable_scope("batch_dimensions_to_most_likely_colors_list"):
-    colors_tensor = tf.constant(colors, dtype=tf.uint8, name='colors')
-    most_likely_class_index = tf.argmax(image_tensor, 3)
-    return tf.gather(params=colors_tensor, indices=most_likely_class_index)
+    with tf.variable_scope("batch_dimensions_to_most_likely_colors_list"):
+        colors_tensor = tf.constant(colors, dtype=tf.uint8, name='colors')
+        most_likely_class_index = tf.argmax(image_tensor, 3)
+        return tf.gather(params=colors_tensor, indices=most_likely_class_index)
+
 
 def add_summary_image(tensors, name, image):
-  tensors.image_tensors[name] = image
-  tf.summary.image(name, image)
+    tensors.image_tensors[name] = image
+    tf.summary.image(name, image)
+
 
 def convert_image(image_tensor):
-  return tf.image.convert_image_dtype(
-    image_tensor,
-    dtype=tf.uint8,
-    saturate=True
-  )
+    return tf.image.convert_image_dtype(
+        image_tensor,
+        dtype=tf.uint8,
+        saturate=True
+    )
+
 
 def add_simple_summary_image(tensors, name, image_tensor):
-  with tf.name_scope(name):
-    add_summary_image(
-      tensors,
-      name,
-      convert_image(image_tensor)
-    )
+    with tf.name_scope(name):
+        add_summary_image(
+            tensors,
+            name,
+            convert_image(image_tensor)
+        )
+
 
 def replace_black_with_white_color(image_tensor):
-  is_black = tf.reduce_all(
-  tf.equal(image_tensor, (0, 0, 0)),
-    axis=-1
-  )
-  is_black = tf.stack([is_black] * 3, axis=-1)
-  return tf.where(
-    is_black,
-    255 * tf.ones_like(image_tensor),
-    image_tensor
-  )
+    is_black = tf.reduce_all(
+        tf.equal(image_tensor, (0, 0, 0)),
+        axis=-1
+    )
+    is_black = tf.stack([is_black] * 3, axis=-1)
+    return tf.where(
+        is_black,
+        255 * tf.ones_like(image_tensor),
+        image_tensor
+    )
+
 
 def combine_image(batch_images, replace_black_with_white=False):
-  clipped_batch_images = [
-    tf.clip_by_value(batch_image, 0.0, 1.0)
-    for batch_image in batch_images
-  ]
-  combined_image = convert_image(
-    reduce(
-      lambda a, b: a + b,
-      clipped_batch_images
+    clipped_batch_images = [
+        tf.clip_by_value(batch_image, 0.0, 1.0)
+        for batch_image in batch_images
+    ]
+    combined_image = convert_image(
+        reduce(
+            lambda a, b: a + b,
+            clipped_batch_images
+        )
     )
-  )
-  if replace_black_with_white:
-    combined_image = replace_black_with_white_color(combined_image)
-  return combined_image
+    if replace_black_with_white:
+        combined_image = replace_black_with_white_color(combined_image)
+    return combined_image
+
 
 def remove_last(a):
-  return a[:-1]
+    return a[:-1]
+
 
 def add_model_summary_images(
-  tensors, dimension_colors, dimension_labels,
-  use_separate_channels=False,
-  has_unknown_class=False):
-
-  tensors.summaries = {}
-  add_simple_summary_image(
-    tensors, 'input', tensors.image_tensor
-  )
-  add_simple_summary_image(
-    tensors, 'target', tensors.annotation_tensor
-  )
-  if (has_unknown_class or not use_separate_channels) and dimension_labels is not None:
-    dimension_labels_with_unknown = dimension_labels + [UNKNOWN_LABEL]
-    dimension_colors_with_unknown = dimension_colors + [(255, 255, 255)]
-  else:
-    dimension_labels_with_unknown = dimension_labels
-    dimension_colors_with_unknown = dimension_colors
-  if use_separate_channels:
-    for name, outputs in [
-        ('targets', tensors.separate_channel_annotation_tensor),
-        ('outputs', tensors.pred)
-      ]:
-
-      batch_images = batch_dimensions_to_colors_list(
-        outputs,
-        dimension_colors_with_unknown
-      )
-      batch_images_excluding_unknown = (
-        remove_last(batch_images)
-        if has_unknown_class
-        else batch_images
-      )
-      for i, (batch_image, dimension_label) in enumerate(zip(
-        batch_images, dimension_labels_with_unknown)):
-
-        suffix = "_{}_{}".format(
-          i, dimension_label if dimension_label else 'unknown_label'
-        )
-        add_simple_summary_image(
-          tensors, name + suffix, batch_image
-        )
-      with tf.name_scope(name + "_combined"):
-        combined_image = combine_image(batch_images_excluding_unknown)
-        if name == 'outputs':
-          tensors.summaries['output_image'] = combined_image
-        add_summary_image(
-          tensors,
-          name + "_combined",
-          combined_image
-        )
+        tensors, dimension_colors, dimension_labels,
+        use_separate_channels=False,
+        has_unknown_class=False):
 
-      if name == 'outputs':
-        with tf.name_scope(name + "_most_likely"):
-          add_summary_image(
-            tensors,
-            name + "_most_likely",
-            batch_dimensions_to_most_likely_colors_list(
-              outputs,
-              dimension_colors_with_unknown)
-          )
-  else:
+    tensors.summaries = {}
     add_simple_summary_image(
-      tensors,
-      "output",
-      tensors.pred
+        tensors, 'input', tensors.image_tensor
     )
-    if tensors.outputs_class_indices is not None:
-      outputs = tensors.pred
-      with tf.name_scope("outputs_most_likely"):
-        colors_tensor = tf.constant(
-          dimension_colors_with_unknown,
-          dtype=tf.uint8, name='colors'
-        )
-        add_summary_image(
-          tensors,
-          "outputs_most_likely",
-          tf.gather(
-            params=colors_tensor,
-            indices=tensors.outputs_class_indices
-          )
+    add_simple_summary_image(
+        tensors, 'target', tensors.annotation_tensor
+    )
+    if (has_unknown_class or not use_separate_channels) and dimension_labels is not None:
+        dimension_labels_with_unknown = dimension_labels + [UNKNOWN_LABEL]
+        dimension_colors_with_unknown = dimension_colors + [(255, 255, 255)]
+    else:
+        dimension_labels_with_unknown = dimension_labels
+        dimension_colors_with_unknown = dimension_colors
+    if use_separate_channels:
+        for name, outputs in [
+            ('targets', tensors.separate_channel_annotation_tensor),
+            ('outputs', tensors.pred)
+        ]:
+
+            batch_images = batch_dimensions_to_colors_list(
+                outputs,
+                dimension_colors_with_unknown
+            )
+            batch_images_excluding_unknown = (
+                remove_last(batch_images)
+                if has_unknown_class
+                else batch_images
+            )
+            for i, (batch_image, dimension_label) in enumerate(zip(
+                    batch_images, dimension_labels_with_unknown)):
+
+                suffix = "_{}_{}".format(
+                    i, dimension_label if dimension_label else 'unknown_label'
+                )
+                add_simple_summary_image(
+                    tensors, name + suffix, batch_image
+                )
+            with tf.name_scope(name + "_combined"):
+                combined_image = combine_image(batch_images_excluding_unknown)
+                if name == 'outputs':
+                    tensors.summaries['output_image'] = combined_image
+                add_summary_image(
+                    tensors,
+                    name + "_combined",
+                    combined_image
+                )
+
+            if name == 'outputs':
+                with tf.name_scope(name + "_most_likely"):
+                    add_summary_image(
+                        tensors,
+                        name + "_most_likely",
+                        batch_dimensions_to_most_likely_colors_list(
+                            outputs,
+                            dimension_colors_with_unknown)
+                    )
+    else:
+        add_simple_summary_image(
+            tensors,
+            "output",
+            tensors.pred
         )
-    tensors.summaries['output_image'] = tensors.image_tensors['output']
+        if tensors.outputs_class_indices is not None:
+            outputs = tensors.pred
+            with tf.name_scope("outputs_most_likely"):
+                colors_tensor = tf.constant(
+                    dimension_colors_with_unknown,
+                    dtype=tf.uint8, name='colors'
+                )
+                add_summary_image(
+                    tensors,
+                    "outputs_most_likely",
+                    tf.gather(
+                        params=colors_tensor,
+                        indices=tensors.outputs_class_indices
+                    )
+                )
+        tensors.summaries['output_image'] = tensors.image_tensors['output']
+
 
 def parse_json_file(filename):
-  with FileIO(filename, 'r') as f:
-    return json.load(f)
+    with FileIO(filename, 'r') as f:
+        return json.load(f)
+
 
 def class_weights_to_pos_weight(
-  class_weights, labels,
-  use_unknown_class, unknown_class_weight=DEFAULT_UNKNOWN_CLASS_WEIGHT):
+        class_weights, labels,
+        use_unknown_class, unknown_class_weight=DEFAULT_UNKNOWN_CLASS_WEIGHT):
+
+    pos_weight = [class_weights[k] for k in labels]
+    return pos_weight + [unknown_class_weight] if use_unknown_class else pos_weight
 
-  pos_weight = [class_weights[k] for k in labels]
-  return pos_weight + [unknown_class_weight] if use_unknown_class else pos_weight
 
 def parse_color_map(color_map_filename):
-  with FileIO(color_map_filename, 'r') as config_f:
-    return parse_color_map_from_file(
-      config_f
-    )
+    with FileIO(color_map_filename, 'r') as config_f:
+        return parse_color_map_from_file(
+            config_f
+        )
+
 
 def color_map_to_labels(color_map, labels=None):
-  if labels:
-    if not all(color_map.has_key(k) for k in labels):
-      raise ValueError(
-        'not all lables found in color map, labels=%s, available keys=%s' %
-        (labels, color_map.keys())
-      )
-    return labels
-  return sorted(color_map.keys())
+    if labels:
+        if not all(k in color_map for k in labels):
+            raise ValueError(
+                'not all labels found in color map, labels=%s, available keys=%s' %
+                (labels, color_map.keys())
+            )
+        return labels
+    return sorted(color_map.keys())
+
 
 def color_map_to_colors(color_map, labels):
-  return [color_map[k] for k in labels]
+    return [color_map[k] for k in labels]
+
 
 def colors_and_labels_with_unknown_class(colors, labels, use_unknown_class):
-  if use_unknown_class or not colors:
-    return (
-      colors + [UNKNOWN_COLOR],
-      labels + [UNKNOWN_LABEL]
-    )
-  else:
-    return colors, labels
+    if use_unknown_class or not colors:
+        return (
+            colors + [UNKNOWN_COLOR],
+            labels + [UNKNOWN_LABEL]
+        )
+    else:
+        return colors, labels
 
-def remove_none_from_dict(d):
-  return {k: v for k, v in iteritems(d) if v is not None}
 
-class Model(object):
-  def __init__(self, args):
-    self.args = args
-    self.image_width = 256
-    self.image_height = 256
-    self.color_map = None
-    self.pos_weight = None
-    self.dimension_colors = None
-    self.dimension_labels = None
-    self.use_unknown_class = args.use_unknown_class
-    self.use_separate_channels = args.use_separate_channels and self.args.color_map is not None
-    logger = get_logger()
-    logger.info('use_separate_channels: %s', self.use_separate_channels)
-    if self.args.color_map:
-      color_map = parse_color_map(args.color_map)
-      class_weights = (
-        parse_json_file(self.args.class_weights)
-        if (
-          self.args.class_weights and
-          self.args.base_loss in {
-            BaseLoss.WEIGHTED_CROSS_ENTROPY,
-            BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY
-          }
-        )
-        else None
-      )
-      available_labels = color_map_to_labels(color_map)
-      if class_weights:
-        # remove labels with zero class weights
-        available_labels = [k for k in available_labels if class_weights.get(k, 0.0) != 0.0]
-      self.dimension_labels = args.channels if args.channels else available_labels
-      self.dimension_colors = color_map_to_colors(color_map, self.dimension_labels)
-      self.dimension_colors_with_unknown, self.dimension_labels_with_unknown = (
-        colors_and_labels_with_unknown_class(
-          self.dimension_colors,
-          self.dimension_labels,
-          self.use_unknown_class
-        )
-      )
-      logger.debug("dimension_colors: %s", self.dimension_colors)
-      logger.debug("dimension_labels: %s", self.dimension_labels)
-      if class_weights:
-        self.pos_weight = class_weights_to_pos_weight(
-          class_weights,
-          self.dimension_labels,
-          self.use_separate_channels,
-          class_weights.get(UNKNOWN_LABEL, DEFAULT_UNKNOWN_CLASS_WEIGHT)
-        )
-        logger.info("pos_weight: %s", self.pos_weight)
+def remove_none_from_dict(d):
+    return {k: v for k, v in iteritems(d) if v is not None}
 
-  def _build_predict_graph(self):
-    tensors = GraphReferences()
-    input_image_tensor = tf.placeholder(
-      tf.uint8, (None, None, None, 3),
-      name='inputs_image'
-    )
-    tensors.inputs = dict(
-      image=input_image_tensor
-    )
 
-    tensors.image_tensor = tf.image.resize_images(
-      tf.image.convert_image_dtype(input_image_tensor, tf.float32),
-      (self.image_height, self.image_width),
-      method=tf.image.ResizeMethod.BILINEAR
-    )
+def _create_pos_weights_tensor(
+        base_loss,
+        separate_channel_annotation_tensor,
+        pos_weight_values,
+        input_uri,
+        debug):
 
-    if self.use_separate_channels:
-      n_output_channels = len(self.dimension_labels_with_unknown)
-    else:
-      n_output_channels = 3
-    pix2pix_model = create_pix2pix_model(
-      tensors.image_tensor,
-      None,
-      self.args,
-      is_training=False,
-      pos_weight=tensors.pos_weight,
-      n_output_channels=n_output_channels
+    frequency_by_label = tf.reduce_sum(
+        separate_channel_annotation_tensor,
+        axis=[0, 1],
+        keep_dims=True,
+        name='frequency_by_channel'
     )
-    tensors.pred = pix2pix_model.outputs
-    return tensors
-
-  def build_graph(self, data_paths, batch_size, graph_mode):
-    if graph_mode == GraphMode.PREDICT:
-      return self._build_predict_graph()
-
-    logger = get_logger()
-    logger.debug('batch_size: %s', batch_size)
-    tensors = GraphReferences()
-    tensors.is_training = tf.constant(graph_mode == GraphMode.TRAIN)
-    is_training = (
-      graph_mode == GraphMode.TRAIN or
-      graph_mode == GraphMode.EVALUATE
+    pos_weight_sample = tf_calculate_efnet_weights_for_frequency_by_label(
+        frequency_by_label
     )
-
-    if not data_paths:
-      raise ValueError('data_paths required')
-    get_logger().info('reading examples from %s', data_paths)
-    tensors.examples = read_examples(
-      get_matching_files(data_paths),
-      shuffle=(graph_mode == GraphMode.TRAIN),
-      num_epochs=None if is_training else 2,
-      page_range=self.args.pages,
-      channel_colors=(
-        self.dimension_colors if self.args.filter_annotated
-        else None
-      )
+    pos_weight = (
+        pos_weight_sample * pos_weight_values
+        if base_loss == BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY
+        else pos_weight_sample
     )
-    parsed = tensors.examples
-
-    tensors.image_tensors = {}
-
-    tensors.input_uri = tf.squeeze(parsed['input_uri'])
-    tensors.annotation_uri = tf.squeeze(parsed['annotation_uri'])
-    raw_input_image = tf.squeeze(parsed['input_image'])
-    logging.info('raw_input_image: %s', raw_input_image)
-    raw_annotation_image = tf.squeeze(parsed['annotation_image'])
-    tensors.image_tensor = tf.image.decode_png(raw_input_image, channels=3)
-    tensors.annotation_tensor = tf.image.decode_png(raw_annotation_image, channels=3)
-
-    # TODO resize_images and tf.cast did not work on input image
-    #   but did work on annotation image
-    tensors.image_tensor = tf.image.resize_image_with_crop_or_pad(
-      tensors.image_tensor, self.image_height, self.image_width
+    if debug:
+        pos_weight = tf.Print(
+            pos_weight, [
+                pos_weight,
+                pos_weight_sample,
+                frequency_by_label,
+                input_uri
+            ],
+            'pos weights, sample, frequency, uri: ',
+            summarize=1000
+        )
+    get_logger().debug(
+        'pos_weight before batch: %s (frequency_by_label: %s)',
+        pos_weight, frequency_by_label
     )
+    return pos_weight
 
-    tensors.image_tensor = tf.image.convert_image_dtype(tensors.image_tensor, tf.float32)
-
-    tensors.annotation_tensor = tf.image.resize_image_with_crop_or_pad(
-      tensors.annotation_tensor, self.image_height, self.image_width
-    )
 
-    if self.use_separate_channels:
-      with tf.variable_scope('channels'):
-        color_masks = calculate_color_masks(
-          tensors.annotation_tensor,
-          self.dimension_colors,
-          use_unknown_class=self.use_unknown_class
-        )
-        tensors.separate_channel_annotation_tensor = tf.stack(color_masks, axis=-1)
-        if self.args.base_loss == BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY:
-          with tf.variable_scope('class_weights'):
-            frequency_by_label = tf.reduce_sum(
-              tensors.separate_channel_annotation_tensor,
-              axis=[0, 1],
-              keep_dims=True,
-              name='frequency_by_channel'
+class Model(object):
+    def __init__(self, args):
+        self.args = args
+        self.image_width = 256
+        self.image_height = 256
+        self.color_map = None
+        self.pos_weight = None
+        self.dimension_colors = None
+        self.dimension_labels = None
+        self.use_unknown_class = args.use_unknown_class
+        self.use_separate_channels = args.use_separate_channels and self.args.color_map is not None
+        logger = get_logger()
+        logger.info('use_separate_channels: %s', self.use_separate_channels)
+        if self.args.color_map:
+            color_map = parse_color_map(args.color_map)
+            class_weights = (
+                parse_json_file(self.args.class_weights)
+                if (
+                    self.args.class_weights and
+                    self.args.base_loss in {
+                        BaseLoss.WEIGHTED_CROSS_ENTROPY,
+                        BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY
+                    }
+                )
+                else None
             )
-            pos_weight_sample = tf_calculate_efnet_weights_for_frequency_by_label(
-              frequency_by_label
+            available_labels = color_map_to_labels(color_map)
+            if class_weights:
+                # remove labels with zero class weights
+                available_labels = [k for k in available_labels if class_weights.get(k, 0.0) != 0.0]
+            self.dimension_labels = args.channels if args.channels else available_labels
+            self.dimension_colors = color_map_to_colors(color_map, self.dimension_labels)
+            self.dimension_colors_with_unknown, self.dimension_labels_with_unknown = (
+                colors_and_labels_with_unknown_class(
+                    self.dimension_colors,
+                    self.dimension_labels,
+                    self.use_unknown_class
+                )
             )
-            tensors.pos_weight = (
-              pos_weight_sample * self.pos_weight
-              if self.args.base_loss == BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY
-              else pos_weight_sample
+            logger.debug("dimension_colors: %s", self.dimension_colors)
+            logger.debug("dimension_labels: %s", self.dimension_labels)
+            if class_weights:
+                self.pos_weight = class_weights_to_pos_weight(
+                    class_weights,
+                    self.dimension_labels,
+                    self.use_separate_channels,
+                    class_weights.get(UNKNOWN_LABEL, DEFAULT_UNKNOWN_CLASS_WEIGHT)
+                )
+                logger.info("pos_weight: %s", self.pos_weight)
+
+    def _build_predict_graph(self):
+        tensors = GraphReferences()
+        input_image_tensor = tf.placeholder(
+            tf.uint8, (None, None, None, 3),
+            name='inputs_image'
+        )
+        tensors.inputs = dict(
+            image=input_image_tensor
+        )
+
+        tensors.image_tensor = tf.image.resize_images(
+            tf.image.convert_image_dtype(input_image_tensor, tf.float32),
+            (self.image_height, self.image_width),
+            method=tf.image.ResizeMethod.BILINEAR
+        )
+
+        if self.use_separate_channels:
+            n_output_channels = len(self.dimension_labels_with_unknown)
+        else:
+            n_output_channels = 3
+        pix2pix_model = create_pix2pix_model(
+            tensors.image_tensor,
+            None,
+            self.args,
+            is_training=False,
+            pos_weight=tensors.pos_weight,
+            n_output_channels=n_output_channels
+        )
+        tensors.pred = pix2pix_model.outputs
+        return tensors
+
+    def build_graph(self, data_paths, batch_size, graph_mode):
+        if graph_mode == GraphMode.PREDICT:
+            return self._build_predict_graph()
+
+        logger = get_logger()
+        logger.debug('batch_size: %s', batch_size)
+        tensors = GraphReferences()
+        tensors.is_training = tf.constant(graph_mode == GraphMode.TRAIN)
+        is_training = (
+            graph_mode == GraphMode.TRAIN or
+            graph_mode == GraphMode.EVALUATE
+        )
+
+        if not data_paths:
+            raise ValueError('data_paths required')
+        get_logger().info('reading examples from %s', data_paths)
+        tensors.examples = read_examples(
+            get_matching_files(data_paths),
+            shuffle=(graph_mode == GraphMode.TRAIN),
+            num_epochs=None if is_training else 2,
+            page_range=self.args.pages,
+            channel_colors=(
+                self.dimension_colors if self.args.filter_annotated
+                else None
             )
-            if self.args.debug:
-              tensors.pos_weight = tf.Print(
-                tensors.pos_weight, [
-                  tensors.pos_weight,
-                  pos_weight_sample,
-                  frequency_by_label,
-                  tensors.input_uri
-                ],
-                'pos weights, sample, frequency, uri: ',
-                summarize=1000
-              )
-            get_logger().debug(
-              'pos_weight before batch: %s (frequency_by_label: %s)',
-              tensors.pos_weight, frequency_by_label
+        )
+        parsed = tensors.examples
+
+        tensors.image_tensors = {}
+
+        tensors.input_uri = tf.squeeze(parsed['input_uri'])
+        tensors.annotation_uri = tf.squeeze(parsed['annotation_uri'])
+        raw_input_image = tf.squeeze(parsed['input_image'])
+        logging.info('raw_input_image: %s', raw_input_image)
+        raw_annotation_image = tf.squeeze(parsed['annotation_image'])
+        tensors.image_tensor = tf.image.decode_png(raw_input_image, channels=3)
+        tensors.annotation_tensor = tf.image.decode_png(raw_annotation_image, channels=3)
+
+        # TODO resize_images and tf.cast did not work on input image
+        #   but did work on annotation image
+        tensors.image_tensor = tf.image.resize_image_with_crop_or_pad(
+            tensors.image_tensor, self.image_height, self.image_width
+        )
+
+        tensors.image_tensor = tf.image.convert_image_dtype(tensors.image_tensor, tf.float32)
+
+        tensors.annotation_tensor = tf.image.resize_image_with_crop_or_pad(
+            tensors.annotation_tensor, self.image_height, self.image_width
+        )
+
+        if self.use_separate_channels:
+            with tf.variable_scope('channels'):
+                color_masks = calculate_color_masks(
+                    tensors.annotation_tensor,
+                    self.dimension_colors,
+                    use_unknown_class=self.use_unknown_class
+                )
+                tensors.separate_channel_annotation_tensor = tf.stack(color_masks, axis=-1)
+                if self.args.base_loss == BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY:
+                    with tf.variable_scope('class_weights'):
+                        tensors.pos_weight = _create_pos_weights_tensor(
+                            base_loss=self.args.base_loss,
+                            separate_channel_annotation_tensor=(
+                                tensors.separate_channel_annotation_tensor
+                            ),
+                            pos_weight_values=self.pos_weight,
+                            input_uri=tensors.input_uri,
+                            debug=self.args.debug
+                        )
+        else:
+            tensors.annotation_tensor = tf.image.convert_image_dtype(
+                tensors.annotation_tensor, tf.float32
             )
-    else:
-      tensors.annotation_tensor = tf.image.convert_image_dtype(
-        tensors.annotation_tensor, tf.float32
-      )
-      tensors.separate_channel_annotation_tensor = tensors.annotation_tensor
-
-    batched_tensors = tf.train.batch(
-      remove_none_from_dict({
-        k: getattr(tensors, k)
-        for k in {
-          'input_uri',
-          'annotation_uri',
-          'image_tensor',
-          'annotation_tensor',
-          'separate_channel_annotation_tensor',
-          'pos_weight'
-        }
-      }),
-      batch_size=batch_size
-    )
-    for k, v in iteritems(batched_tensors):
-      setattr(tensors, k, v)
-
-    if tensors.pos_weight is None:
-      tensors.pos_weight = self.pos_weight
-
-    pix2pix_model = create_pix2pix_model(
-      tensors.image_tensor,
-      tensors.separate_channel_annotation_tensor,
-      self.args,
-      is_training=tensors.is_training,
-      pos_weight=tensors.pos_weight
-    )
+            tensors.separate_channel_annotation_tensor = tensors.annotation_tensor
+
+        batched_tensors = tf.train.batch(
+            remove_none_from_dict({
+                k: getattr(tensors, k)
+                for k in {
+                    'input_uri',
+                    'annotation_uri',
+                    'image_tensor',
+                    'annotation_tensor',
+                    'separate_channel_annotation_tensor',
+                    'pos_weight'
+                }
+            }),
+            batch_size=batch_size
+        )
+        for k, v in iteritems(batched_tensors):
+            setattr(tensors, k, v)
+
+        if tensors.pos_weight is None:
+            tensors.pos_weight = self.pos_weight
+
+        pix2pix_model = create_pix2pix_model(
+            tensors.image_tensor,
+            tensors.separate_channel_annotation_tensor,
+            self.args,
+            is_training=tensors.is_training,
+            pos_weight=tensors.pos_weight
+        )
 
-    if self.use_separate_channels:
-      with tf.name_scope("evaluation"):
-        tensors.output_layer_labels = tf.constant(self.dimension_labels_with_unknown)
-        evaluation_result = evaluate_separate_channels(
-          targets=pix2pix_model.targets,
-          outputs=pix2pix_model.outputs
+        if self.use_separate_channels:
+            with tf.name_scope("evaluation"):
+                tensors.output_layer_labels = tf.constant(self.dimension_labels_with_unknown)
+                evaluation_result = evaluate_separate_channels(
+                    targets=pix2pix_model.targets,
+                    outputs=pix2pix_model.outputs
+                )
+                tensors.evaluation_result = evaluation_result
+                evaluation_summary(evaluation_result, self.dimension_labels_with_unknown)
+        else:
+            with tf.name_scope('evaluation'):
+                if self.dimension_colors:
+                    tensors.output_layer_labels = tf.constant(self.dimension_labels)
+                    colors_tensor = tf.constant(
+                        self.dimension_colors_with_unknown,
+                        dtype=tf.float32
+                    ) / 255.0
+                    tensors.outputs_class_indices = find_nearest_centroid_indices(
+                        predictions=pix2pix_model.outputs,
+                        centroids=colors_tensor
+                    )
+                    tensors.targets_class_indices = find_nearest_centroid_indices(
+                        predictions=pix2pix_model.targets,
+                        centroids=colors_tensor
+                    )
+                    evaluation_result = evaluate_predictions(
+                        labels=tensors.targets_class_indices,
+                        predictions=tensors.outputs_class_indices,
+                        n_classes=len(self.dimension_colors_with_unknown)
+                    )
+                    tensors.evaluation_result = evaluation_result
+                    evaluation_summary(evaluation_result, self.dimension_labels)
+
+        tensors.global_step = pix2pix_model.global_step
+        tensors.train = pix2pix_model.train
+        tensors.class_labels_tensor = tensors.annotation_tensor
+        tensors.pred = pix2pix_model.outputs
+        tensors.probabilities = pix2pix_model.outputs
+        tensors.metric_values = [pix2pix_model.discrim_loss]
+
+        add_model_summary_images(
+            tensors,
+            self.dimension_colors,
+            self.dimension_labels,
+            use_separate_channels=self.use_separate_channels,
+            has_unknown_class=self.use_unknown_class
         )
-        tensors.evaluation_result = evaluation_result
-        evaluation_summary(evaluation_result, self.dimension_labels_with_unknown)
-    else:
-      with tf.name_scope('evaluation'):
-        if self.dimension_colors:
-          tensors.output_layer_labels = tf.constant(self.dimension_labels)
-          colors_tensor = tf.constant(
-            self.dimension_colors_with_unknown,
-            dtype=tf.float32
-          ) / 255.0
-          tensors.outputs_class_indices = find_nearest_centroid_indices(
-            predictions=pix2pix_model.outputs,
-            centroids=colors_tensor
-          )
-          tensors.targets_class_indices = find_nearest_centroid_indices(
-            predictions=pix2pix_model.targets,
-            centroids=colors_tensor
-          )
-          evaluation_result = evaluate_predictions(
-            labels=tensors.targets_class_indices,
-            predictions=tensors.outputs_class_indices,
-            n_classes=len(self.dimension_colors_with_unknown)
-          )
-          tensors.evaluation_result = evaluation_result
-          evaluation_summary(evaluation_result, self.dimension_labels)
-
-    tensors.global_step = pix2pix_model.global_step
-    tensors.train = pix2pix_model.train
-    tensors.class_labels_tensor = tensors.annotation_tensor
-    tensors.pred = pix2pix_model.outputs
-    tensors.probabilities = pix2pix_model.outputs
-    tensors.metric_values = [pix2pix_model.discrim_loss]
-
-    add_model_summary_images(
-      tensors,
-      self.dimension_colors,
-      self.dimension_labels,
-      use_separate_channels=self.use_separate_channels,
-      has_unknown_class=self.use_unknown_class
-    )
 
-    # tensors.summaries = create_summaries(pix2pix_model)
-    create_other_summaries(pix2pix_model)
+        # tensors.summaries = create_summaries(pix2pix_model)
+        create_other_summaries(pix2pix_model)
+
+        if (
+            self.args.base_loss == BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY and
+            tensors.pos_weight is not None
+        ):
+            with tf.variable_scope('pos_weight_summary'):
+                tf.summary.text('pos_weight', tf.as_string(tf.reshape(
+                    tensors.pos_weight, [-1, int(tensors.pos_weight.shape[-1])]
+                )))
 
-    if (
-      self.args.base_loss == BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY and
-      tensors.pos_weight is not None
-    ):
-      with tf.variable_scope('pos_weight_summary'):
-        tf.summary.text('pos_weight', tf.as_string(tf.reshape(
-          tensors.pos_weight, [-1, int(tensors.pos_weight.shape[-1])]
-        )))
+        tensors.summary = tf.summary.merge_all()
+        return tensors
 
-    tensors.summary = tf.summary.merge_all()
-    return tensors
+    def build_train_graph(self, data_paths, batch_size):
+        return self.build_graph(data_paths, batch_size, GraphMode.TRAIN)
 
-  def build_train_graph(self, data_paths, batch_size):
-    return self.build_graph(data_paths, batch_size, GraphMode.TRAIN)
+    def build_eval_graph(self, data_paths, batch_size):
+        return self.build_graph(data_paths, batch_size, GraphMode.EVALUATE)
 
-  def build_eval_graph(self, data_paths, batch_size):
-    return self.build_graph(data_paths, batch_size, GraphMode.EVALUATE)
+    def build_predict_graph(self):
+        return self.build_graph(None, None, GraphMode.PREDICT)
 
-  def build_predict_graph(self):
-    return self.build_graph(None, None, GraphMode.PREDICT)
+    def initialize(self, session):
+        pass
 
-  def initialize(self, session):
-    pass
+    def format_metric_values(self, metric_values):
+        """Formats metric values - used for logging purpose."""
 
-  def format_metric_values(self, metric_values):
-    """Formats metric values - used for logging purpose."""
+        # Early in training, metric_values may actually be None.
+        loss_str = 'N/A'
+        accuracy_str = 'N/A'
+        try:
+            loss_str = '%.3f' % metric_values[0]
+            accuracy_str = '%.3f' % metric_values[1]
+        except (TypeError, IndexError):
+            pass
 
-    # Early in training, metric_values may actually be None.
-    loss_str = 'N/A'
-    accuracy_str = 'N/A'
-    try:
-      loss_str = '%.3f' % metric_values[0]
-      accuracy_str = '%.3f' % metric_values[1]
-    except (TypeError, IndexError):
-      pass
+        return '%s, %s' % (loss_str, accuracy_str)
 
-    return '%s, %s' % (loss_str, accuracy_str)
 
 def str_to_bool(s):
-  return s.lower() in ('yes', 'true', '1')
+    return s.lower() in ('yes', 'true', '1')
+
 
 def str_to_list(s):
-  s = s.strip()
-  if not s:
-    return []
-  return [x.strip() for x in s.split(',')]
+    s = s.strip()
+    if not s:
+        return []
+    return [x.strip() for x in s.split(',')]
+
 
 def model_args_parser():
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-    "--ngf", type=int, default=64, help="number of generator filters in first conv layer"
-  )
-  parser.add_argument(
-    "--ndf", type=int, default=64, help="number of discriminator filters in first conv layer"
-  )
-  parser.add_argument(
-    "--lr", type=float, default=0.0002, help="initial learning rate for adam"
-  )
-  parser.add_argument(
-    "--beta1", type=float, default=0.5, help="momentum term of adam"
-  )
-  parser.add_argument(
-    "--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient"
-  )
-  parser.add_argument(
-    "--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient"
-  )
-
-  parser.add_argument(
-    '--pages', type=parse_page_range, default=None,
-    help='only processes the selected pages'
-  )
-  parser.add_argument(
-    '--color_map',
-    type=str,
-    help='The path to the color map configuration.'
-  )
-  parser.add_argument(
-    '--class_weights',
-    type=str,
-    help='The path to the class weights configuration.'
-  )
-  parser.add_argument(
-    '--channels',
-    type=str_to_list,
-    help='The channels to use (subset of color map), otherwise all of the labels will be used'
-  )
-  parser.add_argument(
-    '--filter_annotated',
-    type=str_to_bool,
-    default=False,
-    help='Only include pages that have annotations for the selected channels'
-    ' (if color map is provided)'
-  )
-  parser.add_argument(
-    '--use_unknown_class',
-    type=str_to_bool,
-    default=True,
-    help='Use unknown class channel (if color map is provided)'
-  )
-  parser.add_argument(
-    '--use_separate_channels',
-    type=str_to_bool,
-    default=False,
-    help='The separate output channels per annotation (if color map is provided)'
-  )
-  parser.add_argument(
-    '--use_separate_discriminator_channels',
-    type=str_to_bool,
-    default=False,
-    help='The separate discriminator channels per annotation (if color map is provided)'
-  )
-  parser.add_argument(
-    '--use_separate_discriminators',
-    type=str_to_bool,
-    default=False,
-    help='The separate discriminators per annotation (if color map is provided)'
-  )
-  parser.add_argument(
-    '--base_loss',
-    type=str,
-    default=BaseLoss.L1,
-    choices=ALL_BASE_LOSS,
-    help='The base loss function to use'
-  )
-  parser.add_argument(
-    '--debug',
-    type=str_to_bool,
-    default=True,
-    help='Enable debug mode'
-  )
-  return parser
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--ngf", type=int, default=64, help="number of generator filters in first conv layer"
+    )
+    parser.add_argument(
+        "--ndf", type=int, default=64, help="number of discriminator filters in first conv layer"
+    )
+    parser.add_argument(
+        "--lr", type=float, default=0.0002, help="initial learning rate for adam"
+    )
+    parser.add_argument(
+        "--beta1", type=float, default=0.5, help="momentum term of adam"
+    )
+    parser.add_argument(
+        "--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient"
+    )
+    parser.add_argument(
+        "--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient"
+    )
+
+    parser.add_argument(
+        '--pages', type=parse_page_range, default=None,
+        help='only processes the selected pages'
+    )
+    parser.add_argument(
+        '--color_map',
+        type=str,
+        help='The path to the color map configuration.'
+    )
+    parser.add_argument(
+        '--class_weights',
+        type=str,
+        help='The path to the class weights configuration.'
+    )
+    parser.add_argument(
+        '--channels',
+        type=str_to_list,
+        help='The channels to use (subset of color map), otherwise all of the labels will be used'
+    )
+    parser.add_argument(
+        '--filter_annotated',
+        type=str_to_bool,
+        default=False,
+        help='Only include pages that have annotations for the selected channels'
+        ' (if color map is provided)'
+    )
+    parser.add_argument(
+        '--use_unknown_class',
+        type=str_to_bool,
+        default=True,
+        help='Use unknown class channel (if color map is provided)'
+    )
+    parser.add_argument(
+        '--use_separate_channels',
+        type=str_to_bool,
+        default=False,
+        help='The separate output channels per annotation (if color map is provided)'
+    )
+    parser.add_argument(
+        '--use_separate_discriminator_channels',
+        type=str_to_bool,
+        default=False,
+        help='The separate discriminator channels per annotation (if color map is provided)'
+    )
+    parser.add_argument(
+        '--use_separate_discriminators',
+        type=str_to_bool,
+        default=False,
+        help='The separate discriminators per annotation (if color map is provided)'
+    )
+    parser.add_argument(
+        '--base_loss',
+        type=str,
+        default=BaseLoss.L1,
+        choices=ALL_BASE_LOSS,
+        help='The base loss function to use'
+    )
+    parser.add_argument(
+        '--debug',
+        type=str_to_bool,
+        default=True,
+        help='Enable debug mode'
+    )
+    return parser
 
 
 def create_model(argv=None):
-  """Factory method that creates model to be used by generic task.py."""
-  parser = model_args_parser()
-  args, task_args = parser.parse_known_args(argv)
-  return Model(args), task_args
+    """Factory method that creates model to be used by generic task.py."""
+    parser = model_args_parser()
+    args, task_args = parser.parse_known_args(argv)
+    return Model(args), task_args
diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model_test.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model_test.py
index 7377812..3d59547 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model_test.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model_test.py
@@ -7,36 +7,36 @@ from pytest import raises
 import tensorflow as tf
 
 from sciencebeam_utils.utils.collection import (
-  extend_dict
+    extend_dict
 )
 
 from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core import (
-  BaseLoss,
-  ALL_BASE_LOSS
+    BaseLoss,
+    ALL_BASE_LOSS
 )
 
 from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core_test import (
-  DEFAULT_ARGS as CORE_DEFAULT_ARGS
+    DEFAULT_ARGS as CORE_DEFAULT_ARGS
 )
 
 from sciencebeam_gym.trainer.data.examples_test import (
-  EXAMPLE_PROPS_1
+    EXAMPLE_PROPS_1
 )
 
 import sciencebeam_gym.trainer.models.pix2pix.pix2pix_model as pix2pix_model
 
 from sciencebeam_gym.trainer.models.pix2pix.pix2pix_model import (
-  parse_color_map,
-  color_map_to_labels,
-  color_map_to_colors,
-  colors_and_labels_with_unknown_class,
-  UNKNOWN_COLOR,
-  UNKNOWN_LABEL,
-  DEFAULT_UNKNOWN_CLASS_WEIGHT,
-  Model,
-  str_to_list,
-  model_args_parser,
-  class_weights_to_pos_weight
+    parse_color_map,
+    color_map_to_labels,
+    color_map_to_colors,
+    colors_and_labels_with_unknown_class,
+    UNKNOWN_COLOR,
+    UNKNOWN_LABEL,
+    DEFAULT_UNKNOWN_CLASS_WEIGHT,
+    Model,
+    str_to_list,
+    model_args_parser,
+    class_weights_to_pos_weight
 )
 
 COLOR_MAP_FILENAME = 'color_map.conf'
@@ -44,384 +44,452 @@ CLASS_WEIGHTS_FILENAME = 'class-weights.json'
 DATA_PATH = 'some/where/*.tfrecord'
 BATCH_SIZE = 2
 
+
 def some_color(i):
-  return (i, i, i)
+    return (i, i, i)
+
 
 SOME_COLORS = [some_color(1), some_color(2), some_color(3)]
 SOME_LABELS = ['a', 'b', 'c']
 SOME_COLOR_MAP = {
-  k: v for k, v in zip(SOME_LABELS, SOME_COLORS)
+    k: v for k, v in zip(SOME_LABELS, SOME_COLORS)
 }
 SOME_CLASS_WEIGHTS = {
-  k: float(1 + i) for i, k in enumerate(SOME_LABELS)
+    k: float(1 + i) for i, k in enumerate(SOME_LABELS)
 }
 
-class TestParseColorMap(object):
-  def test_should_use_fileio_to_load_file_and_pass_to_parser(self):
+
+@pytest.fixture(name='FileIO_mock')
+def _FileIO_mock():
     with patch.object(pix2pix_model, 'FileIO') as FileIO:
-      with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
+        yield FileIO
+
+
+@pytest.fixture(name='parse_color_map_from_file_mock')
+def _parse_color_map_from_file_mock():
+    with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
+        yield parse_color_map_from_file
+
+
+@pytest.fixture(name='parse_json_file_mock')
+def _parse_json_file_mock():
+    with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file:
+        yield parse_json_file
+
+
+@pytest.fixture(name='get_matching_files_mock')
+def _get_matching_files_mock():
+    with patch.object(pix2pix_model, 'get_matching_files') as get_matching_files:
+        yield get_matching_files
+
+
+@pytest.fixture(name='read_examples_mock')
+def _read_examples_mock():
+    with patch.object(pix2pix_model, 'read_examples') as read_examples:
+        yield read_examples
+
+
+@pytest.fixture(name='default_graph')
+def _graph():
+    with tf.Graph().as_default() as graph:
+        yield graph
+
+
+class TestParseColorMap(object):
+    def test_should_use_fileio_to_load_file_and_pass_to_parser(
+            self, FileIO_mock, parse_color_map_from_file_mock):
+
         parse_color_map(COLOR_MAP_FILENAME)
-        FileIO.assert_called_with(COLOR_MAP_FILENAME, 'r')
-        parse_color_map_from_file.assert_called_with(FileIO.return_value.__enter__.return_value)
+        FileIO_mock.assert_called_with(COLOR_MAP_FILENAME, 'r')
+        parse_color_map_from_file_mock.assert_called_with(
+            FileIO_mock.return_value.__enter__.return_value)
+
 
 class TestColorMapToLabels(object):
-  def test_should_use_color_maps_keys_by_default(self):
-    color_map = {
-      'a': some_color(1),
-      'b': some_color(2),
-      'c': some_color(3)
-    }
-    assert color_map_to_labels(color_map) == ['a', 'b', 'c']
-
-  def test_should_return_specified_labels(self):
-    color_map = {
-      'a': some_color(1),
-      'b': some_color(2),
-      'c': some_color(3)
-    }
-    assert color_map_to_labels(color_map, ['b', 'a']) == ['b', 'a']
-
-  def test_should_raise_error_if_specified_label_not_in_color_map(self):
-    color_map = {
-      'a': some_color(1),
-      'c': some_color(3)
-    }
-    with raises(ValueError):
-      color_map_to_labels(color_map, ['a', 'b'])
+    def test_should_use_color_maps_keys_by_default(self):
+        color_map = {
+            'a': some_color(1),
+            'b': some_color(2),
+            'c': some_color(3)
+        }
+        assert color_map_to_labels(color_map) == ['a', 'b', 'c']
+
+    def test_should_return_specified_labels(self):
+        color_map = {
+            'a': some_color(1),
+            'b': some_color(2),
+            'c': some_color(3)
+        }
+        assert color_map_to_labels(color_map, ['b', 'a']) == ['b', 'a']
+
+    def test_should_raise_error_if_specified_label_not_in_color_map(self):
+        color_map = {
+            'a': some_color(1),
+            'c': some_color(3)
+        }
+        with raises(ValueError):
+            color_map_to_labels(color_map, ['a', 'b'])
+
 
 class TestColorMapToColors(object):
-  def test_should_return_colors_for_labels(self):
-    color_map = {
-      'a': some_color(1),
-      'b': some_color(2),
-      'c': some_color(3)
-    }
-    assert color_map_to_colors(color_map, ['a', 'b']) == [
-      some_color(1),
-      some_color(2)
-    ]
+    def test_should_return_colors_for_labels(self):
+        color_map = {
+            'a': some_color(1),
+            'b': some_color(2),
+            'c': some_color(3)
+        }
+        assert color_map_to_colors(color_map, ['a', 'b']) == [
+            some_color(1),
+            some_color(2)
+        ]
+
 
 class TestColorsAndLabelsWithUnknownClass(object):
-  def test_should_not_add_unknown_class_if_not_enabled(self):
-    colors_with_unknown, labels_with_unknown = colors_and_labels_with_unknown_class(
-      SOME_COLORS,
-      SOME_LABELS,
-      use_unknown_class=False
-    )
-    assert colors_with_unknown == SOME_COLORS
-    assert labels_with_unknown == SOME_LABELS
-
-  def test_should_add_unknown_class_if_enabled(self):
-    colors_with_unknown, labels_with_unknown = colors_and_labels_with_unknown_class(
-      SOME_COLORS,
-      SOME_LABELS,
-      use_unknown_class=True
-    )
-    assert colors_with_unknown == SOME_COLORS + [UNKNOWN_COLOR]
-    assert labels_with_unknown == SOME_LABELS + [UNKNOWN_LABEL]
-
-  def test_should_add_unknown_class_if_colors_are_empty(self):
-    colors_with_unknown, labels_with_unknown = colors_and_labels_with_unknown_class(
-      [],
-      [],
-      use_unknown_class=False
-    )
-    assert colors_with_unknown == [UNKNOWN_COLOR]
-    assert labels_with_unknown == [UNKNOWN_LABEL]
+    def test_should_not_add_unknown_class_if_not_enabled(self):
+        colors_with_unknown, labels_with_unknown = colors_and_labels_with_unknown_class(
+            SOME_COLORS,
+            SOME_LABELS,
+            use_unknown_class=False
+        )
+        assert colors_with_unknown == SOME_COLORS
+        assert labels_with_unknown == SOME_LABELS
+
+    def test_should_add_unknown_class_if_enabled(self):
+        colors_with_unknown, labels_with_unknown = colors_and_labels_with_unknown_class(
+            SOME_COLORS,
+            SOME_LABELS,
+            use_unknown_class=True
+        )
+        assert colors_with_unknown == SOME_COLORS + [UNKNOWN_COLOR]
+        assert labels_with_unknown == SOME_LABELS + [UNKNOWN_LABEL]
+
+    def test_should_add_unknown_class_if_colors_are_empty(self):
+        colors_with_unknown, labels_with_unknown = colors_and_labels_with_unknown_class(
+            [],
+            [],
+            use_unknown_class=False
+        )
+        assert colors_with_unknown == [UNKNOWN_COLOR]
+        assert labels_with_unknown == [UNKNOWN_LABEL]
+
 
 class TestClassWeightsToPosWeight(object):
-  def test_should_extract_selected_weights(self):
-    assert class_weights_to_pos_weight({
-      'a': 0.1,
-      'b': 0.2,
-      'c': 0.3
-    }, ['a', 'b'], False) == [0.1, 0.2]
-
-  def test_should_add_zero_if_unknown_class_is_true(self):
-    assert class_weights_to_pos_weight({
-      'a': 0.1,
-      'b': 0.2,
-      'c': 0.3
-    }, ['a', 'b'], True, DEFAULT_UNKNOWN_CLASS_WEIGHT) == (
-      [0.1, 0.2, DEFAULT_UNKNOWN_CLASS_WEIGHT]
-    )
+    def test_should_extract_selected_weights(self):
+        assert class_weights_to_pos_weight({
+            'a': 0.1,
+            'b': 0.2,
+            'c': 0.3
+        }, ['a', 'b'], False) == [0.1, 0.2]
+
+    def test_should_add_zero_if_unknown_class_is_true(self):
+        assert class_weights_to_pos_weight({
+            'a': 0.1,
+            'b': 0.2,
+            'c': 0.3
+        }, ['a', 'b'], True, DEFAULT_UNKNOWN_CLASS_WEIGHT) == (
+            [0.1, 0.2, DEFAULT_UNKNOWN_CLASS_WEIGHT]
+        )
+
 
 DEFAULT_ARGS = extend_dict(
-  CORE_DEFAULT_ARGS,
-  dict(
-    pages=None,
-    color_map=None,
-    class_weights=None,
-    channels=None,
-    filter_annotated=False,
-    use_separate_channels=False,
-    use_unknown_class=False,
-    debug=False
-  )
+    CORE_DEFAULT_ARGS,
+    dict(
+        pages=None,
+        color_map=None,
+        class_weights=None,
+        channels=None,
+        filter_annotated=False,
+        use_separate_channels=False,
+        use_unknown_class=False,
+        debug=False
+    )
 )
 
+
 def create_args(*args, **kwargs):
-  d = extend_dict(*list(args) + [kwargs])
-  return namedtuple('args', d.keys())(**d)
+    d = extend_dict(*list(args) + [kwargs])
+    return namedtuple('args', d.keys())(**d)
+
 
+@pytest.mark.usefixtures(
+    'parse_json_file_mock'
+)
 class TestModel(object):
-  def test_parse_separate_channels_with_color_map_without_class_weights(self):
-    with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
-      parse_color_map_from_file.return_value = {
-        'a': some_color(1),
-        'b': some_color(2),
-        'c': some_color(3)
-      }
-      args = create_args(
-        DEFAULT_ARGS,
-        color_map=COLOR_MAP_FILENAME,
-        class_weights=None,
-        channels=['a', 'b'],
-        use_separate_channels=True,
-        use_unknown_class=True
-      )
-      model = Model(args)
-      assert model.dimension_colors == [some_color(1), some_color(2)]
-      assert model.dimension_labels == ['a', 'b']
-      assert model.dimension_colors_with_unknown == [some_color(1), some_color(2), UNKNOWN_COLOR]
-      assert model.dimension_labels_with_unknown == ['a', 'b', UNKNOWN_LABEL]
-      assert model.pos_weight is None
-
-  def test_parse_separate_channels_with_color_map_and_class_weights(self):
-    with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
-      with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file:
-        parse_color_map_from_file.return_value = {
-          'a': some_color(1),
-          'b': some_color(2),
-          'c': some_color(3)
+    def test_parse_separate_channels_with_color_map_without_class_weights(
+            self,
+            parse_color_map_from_file_mock):
+
+        parse_color_map_from_file_mock.return_value = {
+            'a': some_color(1),
+            'b': some_color(2),
+            'c': some_color(3)
         }
-        parse_json_file.return_value = {
-          'a': 0.1,
-          'b': 0.2,
-          'c': 0.3
+        args = create_args(
+            DEFAULT_ARGS,
+            color_map=COLOR_MAP_FILENAME,
+            class_weights=None,
+            channels=['a', 'b'],
+            use_separate_channels=True,
+            use_unknown_class=True
+        )
+        model = Model(args)
+        assert model.dimension_colors == [some_color(1), some_color(2)]
+        assert model.dimension_labels == ['a', 'b']
+        assert model.dimension_colors_with_unknown == [
+            some_color(1), some_color(2), UNKNOWN_COLOR]
+        assert model.dimension_labels_with_unknown == ['a', 'b', UNKNOWN_LABEL]
+        assert model.pos_weight is None
+
+    def test_parse_separate_channels_with_color_map_and_class_weights(
+            self,
+            parse_color_map_from_file_mock,
+            parse_json_file_mock):
+
+        parse_color_map_from_file_mock.return_value = {
+            'a': some_color(1),
+            'b': some_color(2),
+            'c': some_color(3)
+        }
+        parse_json_file_mock.return_value = {
+            'a': 0.1,
+            'b': 0.2,
+            'c': 0.3
         }
         args = create_args(
-          DEFAULT_ARGS,
-          base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY,
-          color_map=COLOR_MAP_FILENAME,
-          class_weights=CLASS_WEIGHTS_FILENAME,
-          channels=['a', 'b'],
-          use_separate_channels=True,
-          use_unknown_class=True
+            DEFAULT_ARGS,
+            base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY,
+            color_map=COLOR_MAP_FILENAME,
+            class_weights=CLASS_WEIGHTS_FILENAME,
+            channels=['a', 'b'],
+            use_separate_channels=True,
+            use_unknown_class=True
         )
         model = Model(args)
         assert model.dimension_colors == [some_color(1), some_color(2)]
         assert model.dimension_labels == ['a', 'b']
         assert (
-          model.dimension_colors_with_unknown == [some_color(1), some_color(2), UNKNOWN_COLOR]
+            model.dimension_colors_with_unknown == [
+                some_color(1), some_color(2), UNKNOWN_COLOR]
         )
         assert model.dimension_labels_with_unknown == ['a', 'b', UNKNOWN_LABEL]
         assert model.pos_weight == [0.1, 0.2, DEFAULT_UNKNOWN_CLASS_WEIGHT]
 
-  def test_should_only_include_labels_with_non_zero_class_labels_by_default(self):
-    with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
-      with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file:
-        parse_color_map_from_file.return_value = {
-          'a': some_color(1),
-          'b': some_color(2),
-          'c': some_color(3)
+    def test_should_only_include_labels_with_non_zero_class_labels_by_default(
+            self,
+            parse_color_map_from_file_mock,
+            parse_json_file_mock):
+
+        parse_color_map_from_file_mock.return_value = {
+            'a': some_color(1),
+            'b': some_color(2),
+            'c': some_color(3)
         }
-        parse_json_file.return_value = {
-          'a': 0.1,
-          'b': 0.0,
-          'c': 0.3
+        parse_json_file_mock.return_value = {
+            'a': 0.1,
+            'b': 0.0,
+            'c': 0.3
         }
         args = create_args(
-          DEFAULT_ARGS,
-          base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY,
-          color_map=COLOR_MAP_FILENAME,
-          class_weights=CLASS_WEIGHTS_FILENAME,
-          use_separate_channels=True,
-          use_unknown_class=True
+            DEFAULT_ARGS,
+            base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY,
+            color_map=COLOR_MAP_FILENAME,
+            class_weights=CLASS_WEIGHTS_FILENAME,
+            use_separate_channels=True,
+            use_unknown_class=True
         )
         model = Model(args)
         assert model.dimension_labels == ['a', 'c']
         assert model.dimension_colors == [some_color(1), some_color(3)]
         assert model.pos_weight == [0.1, 0.3, DEFAULT_UNKNOWN_CLASS_WEIGHT]
 
-  def test_should_use_unknown_class_weight_from_configuration(self):
-    with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
-      with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file:
-        parse_color_map_from_file.return_value = SOME_COLOR_MAP
-        parse_json_file.return_value = extend_dict(SOME_CLASS_WEIGHTS, {
-          'unknown': 0.99
+    def test_should_use_unknown_class_weight_from_configuration(
+            self,
+            parse_color_map_from_file_mock,
+            parse_json_file_mock):
+
+        parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP
+        parse_json_file_mock.return_value = extend_dict(SOME_CLASS_WEIGHTS, {
+            'unknown': 0.99
         })
         args = create_args(
-          DEFAULT_ARGS,
-          base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY,
-          color_map=COLOR_MAP_FILENAME,
-          class_weights=CLASS_WEIGHTS_FILENAME,
-          use_separate_channels=True,
-          use_unknown_class=True
+            DEFAULT_ARGS,
+            base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY,
+            color_map=COLOR_MAP_FILENAME,
+            class_weights=CLASS_WEIGHTS_FILENAME,
+            use_separate_channels=True,
+            use_unknown_class=True
         )
         model = Model(args)
         assert model.pos_weight[-1] == 0.99
 
-  def test_should_not_load_class_weights_for_cross_entropy(self):
-    with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
-      with patch.object(pix2pix_model, 'parse_json_file'):
-        parse_color_map_from_file.return_value = SOME_COLOR_MAP
+    def test_should_not_load_class_weights_for_cross_entropy(
+            self,
+            parse_color_map_from_file_mock):
+
+        parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP
         args = create_args(
-          DEFAULT_ARGS,
-          base_loss=BaseLoss.CROSS_ENTROPY,
-          color_map=COLOR_MAP_FILENAME,
-          class_weights=CLASS_WEIGHTS_FILENAME,
-          use_separate_channels=True,
-          use_unknown_class=True
+            DEFAULT_ARGS,
+            base_loss=BaseLoss.CROSS_ENTROPY,
+            color_map=COLOR_MAP_FILENAME,
+            class_weights=CLASS_WEIGHTS_FILENAME,
+            use_separate_channels=True,
+            use_unknown_class=True
         )
         model = Model(args)
         assert model.pos_weight is None
 
-  def test_should_not_load_class_weights_for_sample_weighted_cross_entropy(self):
-    with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
-      with patch.object(pix2pix_model, 'parse_json_file'):
-        parse_color_map_from_file.return_value = SOME_COLOR_MAP
+    def test_should_not_load_class_weights_for_sample_weighted_cross_entropy(
+            self,
+            parse_color_map_from_file_mock):
+
+        parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP
         args = create_args(
-          DEFAULT_ARGS,
-          base_loss=BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY,
-          color_map=COLOR_MAP_FILENAME,
-          class_weights=CLASS_WEIGHTS_FILENAME,
-          use_separate_channels=True,
-          use_unknown_class=True
+            DEFAULT_ARGS,
+            base_loss=BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY,
+            color_map=COLOR_MAP_FILENAME,
+            class_weights=CLASS_WEIGHTS_FILENAME,
+            use_separate_channels=True,
+            use_unknown_class=True
         )
         model = Model(args)
         assert model.pos_weight is None
 
+
 @pytest.mark.slow
 @pytest.mark.very_slow
+@pytest.mark.usefixtures(
+    'get_matching_files_mock',
+    'default_graph'
+)
 class TestModelBuildGraph(object):
-  def test_should_build_train_graph_with_defaults(self):
-    with tf.Graph().as_default():
-      with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
-        with patch.object(pix2pix_model, 'get_matching_files'):
-          with patch.object(pix2pix_model, 'read_examples') as read_examples:
-            parse_color_map_from_file.return_value = SOME_COLOR_MAP
-            read_examples.return_value = EXAMPLE_PROPS_1
-            args = create_args(
-              DEFAULT_ARGS
-            )
-            model = Model(args)
-            model.build_train_graph(DATA_PATH, BATCH_SIZE)
-
-  def test_should_build_train_graph_with_class_weights(self):
-    with tf.Graph().as_default():
-      with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
-        with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file:
-          with patch.object(pix2pix_model, 'get_matching_files'):
-            with patch.object(pix2pix_model, 'read_examples') as read_examples:
-              parse_color_map_from_file.return_value = SOME_COLOR_MAP
-              parse_json_file.return_value = SOME_CLASS_WEIGHTS
-              read_examples.return_value = EXAMPLE_PROPS_1
-              args = create_args(
-                DEFAULT_ARGS,
-                base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY,
-                color_map=COLOR_MAP_FILENAME,
-                class_weights=CLASS_WEIGHTS_FILENAME,
-                channels=['a', 'b'],
-                use_separate_channels=True,
-                use_unknown_class=True
-              )
-              model = Model(args)
-              tensors = model.build_train_graph(DATA_PATH, BATCH_SIZE)
-              assert tensors.pos_weight is not None
-
-  def test_should_build_train_graph_with_sample_class_weights(self):
-    with tf.Graph().as_default():
-      with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
-        with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file:
-          with patch.object(pix2pix_model, 'get_matching_files'):
-            with patch.object(pix2pix_model, 'read_examples') as read_examples:
-              parse_color_map_from_file.return_value = SOME_COLOR_MAP
-              parse_json_file.return_value = SOME_CLASS_WEIGHTS
-              read_examples.return_value = EXAMPLE_PROPS_1
-              args = create_args(
-                DEFAULT_ARGS,
-                base_loss=BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY,
-                color_map=COLOR_MAP_FILENAME,
-                channels=SOME_LABELS,
-                use_separate_channels=True,
-                use_unknown_class=True
-              )
-              model = Model(args)
-              tensors = model.build_train_graph(DATA_PATH, BATCH_SIZE)
-              n_output_channels = len(SOME_LABELS) + 1
-              assert (
-                tensors.separate_channel_annotation_tensor.shape.as_list() ==
-                [BATCH_SIZE, model.image_height, model.image_width, n_output_channels]
-              )
-              assert tensors.pos_weight.shape.as_list() == [BATCH_SIZE, 1, 1, n_output_channels]
-
-  def test_should_build_predict_graph_with_defaults(self):
-    with tf.Graph().as_default():
-      with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
-        with patch.object(pix2pix_model, 'get_matching_files'):
-          with patch.object(pix2pix_model, 'read_examples') as read_examples:
-            parse_color_map_from_file.return_value = SOME_COLOR_MAP
-            read_examples.return_value = EXAMPLE_PROPS_1
-            args = create_args(
-              DEFAULT_ARGS
-            )
-            model = Model(args)
-            tensors = model.build_predict_graph()
-            n_output_channels = 3
-            assert (
-              tensors.pred.shape.as_list() ==
-              [None, model.image_height, model.image_width, n_output_channels]
-            )
-
-  def test_should_build_predict_graph_with_sample_class_weights(self):
-    with tf.Graph().as_default():
-      with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
-        with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file:
-          with patch.object(pix2pix_model, 'get_matching_files'):
-            with patch.object(pix2pix_model, 'read_examples') as read_examples:
-              parse_color_map_from_file.return_value = SOME_COLOR_MAP
-              parse_json_file.return_value = SOME_CLASS_WEIGHTS
-              read_examples.return_value = EXAMPLE_PROPS_1
-              args = create_args(
-                DEFAULT_ARGS,
-                base_loss=BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY,
-                color_map=COLOR_MAP_FILENAME,
-                channels=SOME_LABELS,
-                use_separate_channels=True,
-                use_unknown_class=True
-              )
-              model = Model(args)
-              tensors = model.build_predict_graph()
-              n_output_channels = len(SOME_LABELS) + 1
-              assert (
-                tensors.pred.shape.as_list() ==
-                [None, model.image_height, model.image_width, n_output_channels]
-              )
+    def test_should_build_train_graph_with_defaults(
+            self,
+            parse_color_map_from_file_mock,
+            read_examples_mock):
+
+        parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP
+        read_examples_mock.return_value = EXAMPLE_PROPS_1
+        args = create_args(
+            DEFAULT_ARGS
+        )
+        model = Model(args)
+        model.build_train_graph(DATA_PATH, BATCH_SIZE)
+
+    def test_should_build_train_graph_with_class_weights(
+            self,
+            parse_color_map_from_file_mock,
+            parse_json_file_mock,
+            read_examples_mock):
+
+        parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP
+        parse_json_file_mock.return_value = SOME_CLASS_WEIGHTS
+        read_examples_mock.return_value = EXAMPLE_PROPS_1
+        args = create_args(
+            DEFAULT_ARGS,
+            base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY,
+            color_map=COLOR_MAP_FILENAME,
+            class_weights=CLASS_WEIGHTS_FILENAME,
+            channels=['a', 'b'],
+            use_separate_channels=True,
+            use_unknown_class=True
+        )
+        model = Model(args)
+        tensors = model.build_train_graph(DATA_PATH, BATCH_SIZE)
+        assert tensors.pos_weight is not None
+
+    def test_should_build_train_graph_with_sample_class_weights(
+            self,
+            parse_color_map_from_file_mock,
+            parse_json_file_mock,
+            read_examples_mock):
+
+        parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP
+        parse_json_file_mock.return_value = SOME_CLASS_WEIGHTS
+        read_examples_mock.return_value = EXAMPLE_PROPS_1
+        args = create_args(
+            DEFAULT_ARGS,
+            base_loss=BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY,
+            color_map=COLOR_MAP_FILENAME,
+            channels=SOME_LABELS,
+            use_separate_channels=True,
+            use_unknown_class=True
+        )
+        model = Model(args)
+        tensors = model.build_train_graph(DATA_PATH, BATCH_SIZE)
+        n_output_channels = len(SOME_LABELS) + 1
+        assert (
+            tensors.separate_channel_annotation_tensor.shape.as_list() ==
+            [BATCH_SIZE, model.image_height, model.image_width, n_output_channels]
+        )
+        assert tensors.pos_weight.shape.as_list(
+        ) == [BATCH_SIZE, 1, 1, n_output_channels]
+
+    def test_should_build_predict_graph_with_defaults(
+            self,
+            parse_color_map_from_file_mock,
+            read_examples_mock):
+        parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP
+        read_examples_mock.return_value = EXAMPLE_PROPS_1
+        args = create_args(
+            DEFAULT_ARGS
+        )
+        model = Model(args)
+        tensors = model.build_predict_graph()
+        n_output_channels = 3
+        assert (
+            tensors.pred.shape.as_list() ==
+            [None, model.image_height, model.image_width, n_output_channels]
+        )
+
+    def test_should_build_predict_graph_with_sample_class_weights(
+            self,
+            parse_color_map_from_file_mock,
+            parse_json_file_mock,
+            read_examples_mock):
+        parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP
+        parse_json_file_mock.return_value = SOME_CLASS_WEIGHTS
+        read_examples_mock.return_value = EXAMPLE_PROPS_1
+        args = create_args(
+            DEFAULT_ARGS,
+            base_loss=BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY,
+            color_map=COLOR_MAP_FILENAME,
+            channels=SOME_LABELS,
+            use_separate_channels=True,
+            use_unknown_class=True
+        )
+        model = Model(args)
+        tensors = model.build_predict_graph()
+        n_output_channels = len(SOME_LABELS) + 1
+        assert (
+            tensors.pred.shape.as_list() ==
+            [None, model.image_height, model.image_width, n_output_channels]
+        )
+
 
 class TestStrToList(object):
-  def test_should_parse_empty_string_as_empty_list(self):
-    assert str_to_list('') == []
+    def test_should_parse_empty_string_as_empty_list(self):
+        assert str_to_list('') == []
+
+    def test_should_parse_blank_string_as_empty_list(self):
+        assert str_to_list(' ') == []
 
-  def test_should_parse_blank_string_as_empty_list(self):
-    assert str_to_list(' ') == []
+    def test_should_parse_comma_separated_list(self):
+        assert str_to_list('a,b,c') == ['a', 'b', 'c']
 
-  def test_should_parse_comma_separated_list(self):
-    assert str_to_list('a,b,c') == ['a', 'b', 'c']
+    def test_should_ignore_white_space_around_values(self):
+        assert str_to_list(' a , b , c ') == ['a', 'b', 'c']
 
-  def test_should_ignore_white_space_around_values(self):
-    assert str_to_list(' a , b , c ') == ['a', 'b', 'c']
 
 class TestModelArgsParser(object):
-  def test_should_parse_channels(self):
-    args = model_args_parser().parse_args(['--channels', 'a,b,c'])
-    assert args.channels == ['a', 'b', 'c']
-
-  def test_should_set_channels_to_none_by_default(self):
-    args = model_args_parser().parse_args([])
-    assert args.channels is None
-
-  def test_should_allow_all_base_loss_options(self):
-    for base_loss in ALL_BASE_LOSS:
-      args = model_args_parser().parse_args(['--base_loss', base_loss])
-      assert args.base_loss == base_loss
+    def test_should_parse_channels(self):
+        args = model_args_parser().parse_args(['--channels', 'a,b,c'])
+        assert args.channels == ['a', 'b', 'c']
+
+    def test_should_set_channels_to_none_by_default(self):
+        args = model_args_parser().parse_args([])
+        assert args.channels is None
+
+    def test_should_allow_all_base_loss_options(self):
+        for base_loss in ALL_BASE_LOSS:
+            args = model_args_parser().parse_args(['--base_loss', base_loss])
+            assert args.base_loss == base_loss
diff --git a/sciencebeam_gym/trainer/models/pix2pix/tf_utils.py b/sciencebeam_gym/trainer/models/pix2pix/tf_utils.py
index 440c100..ba14e6c 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/tf_utils.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/tf_utils.py
@@ -3,49 +3,54 @@ from __future__ import division
 
 import tensorflow as tf
 
+
 def pairwise_squared_distance_with(predictions, centroids):
-  centroid_count = centroids.shape[0]
-  return tf.stack([
-    tf.reduce_sum(
-      tf.square(predictions - centroids[i]),
-      axis=-1
-    )
-    for i in range(centroid_count)
-  ], axis=-1)
+    centroid_count = centroids.shape[0]
+    return tf.stack([
+        tf.reduce_sum(
+            tf.square(predictions - centroids[i]),
+            axis=-1
+        )
+        for i in range(centroid_count)
+    ], axis=-1)
+
 
 def find_nearest_centroid_indices(predictions, centroids):
-  return tf.argmin(
-    pairwise_squared_distance_with(predictions, centroids),
-    axis=-1
-  )
+    return tf.argmin(
+        pairwise_squared_distance_with(predictions, centroids),
+        axis=-1
+    )
+
 
 def find_nearest_centroid(predictions, centroids):
-  return tf.gather(
-    params=centroids,
-    indices=find_nearest_centroid_indices(predictions, centroids)
-  )
+    return tf.gather(
+        params=centroids,
+        indices=find_nearest_centroid_indices(predictions, centroids)
+    )
+
 
 def get_channel_slice(tensor, channel_index):
-  rank = len(tensor.shape)
-  return tf.slice(
-    tensor,
-    begin=[0] * (rank - 1) + [channel_index],
-    size=[-1] * (rank - 1) + [1]
-  )
+    rank = len(tensor.shape)
+    return tf.slice(
+        tensor,
+        begin=[0] * (rank - 1) + [channel_index],
+        size=[-1] * (rank - 1) + [1]
+    )
+
 
 def blank_other_channels(tensor, keep_index):
-  tensor_shape = tensor.shape
-  n_channels = int(tensor_shape[-1])
-  rank = len(tensor_shape)
-  tensor_slice = get_channel_slice(
-    tensor,
-    keep_index
-  )
-  paddings = tf.constant(
-    [[0, 0]] * (rank - 1) +
-    [[keep_index, n_channels - keep_index - 1]]
-  )
-  padded = tf.pad(
-    tensor_slice, paddings, "CONSTANT"
-  )
-  return padded
+    tensor_shape = tensor.shape
+    n_channels = int(tensor_shape[-1])
+    rank = len(tensor_shape)
+    tensor_slice = get_channel_slice(
+        tensor,
+        keep_index
+    )
+    paddings = tf.constant(
+        [[0, 0]] * (rank - 1) +
+        [[keep_index, n_channels - keep_index - 1]]
+    )
+    padded = tf.pad(
+        tensor_slice, paddings, "CONSTANT"
+    )
+    return padded
diff --git a/sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py b/sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py
index 22cfc0d..d5e0887 100644
--- a/sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py
+++ b/sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py
@@ -2,90 +2,92 @@ from __future__ import absolute_import
 from __future__ import division
 
 import tensorflow as tf
-import numpy as np
 
 from sciencebeam_utils.utils.num import (
-  assert_all_close
+    assert_all_close
 )
 
 from sciencebeam_gym.trainer.models.pix2pix.tf_utils import (
-  find_nearest_centroid,
-  blank_other_channels
+    find_nearest_centroid,
+    blank_other_channels
 )
 
+
 def test_find_nearest_centroid():
-  colors = tf.constant([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
-  outputs = tf.constant([[[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]]])
-  nearest_color = find_nearest_centroid(outputs, colors)
+    colors = tf.constant([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
+    outputs = tf.constant([[[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]]])
+    nearest_color = find_nearest_centroid(outputs, colors)
 
-  with tf.Session() as session:
-    assert_all_close(
-      session.run(nearest_color),
-      [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]]]
-    )
+    with tf.Session() as session:
+        assert_all_close(
+            session.run(nearest_color),
+            [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]]]
+        )
 
-def test_find_nearest_centroid1():
-  colors = tf.constant([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
-  outputs = tf.constant([
-    [
-      [[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]],
-      [[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]]
-    ],
-    [
-      [[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]],
-      [[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]]
-    ]
-  ])
-  nearest_color = find_nearest_centroid(outputs, colors)
 
-  with tf.Session() as session:
-    assert_all_close(
-      session.run(nearest_color),
-      [
+def test_find_nearest_centroid1():
+    colors = tf.constant([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
+    outputs = tf.constant([
         [
-          [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]],
-          [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]]
+            [[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]],
+            [[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]]
         ],
         [
-          [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]],
-          [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]]
+            [[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]],
+            [[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]]
         ]
-      ]
-    )
+    ])
+    nearest_color = find_nearest_centroid(outputs, colors)
+
+    with tf.Session() as session:
+        assert_all_close(
+            session.run(nearest_color),
+            [
+                [
+                    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]],
+                    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]]
+                ],
+                [
+                    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]],
+                    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]]
+                ]
+            ]
+        )
+
 
 def test_blank_other_channels():
-  tensor = tf.constant([
-    [
-      [5, 5, 5, 5],
-      [6, 6, 6, 6],
-      [7, 7, 7, 7],
-      [8, 8, 8, 8]
-    ],
-    [
-      [5, 5, 5, 5],
-      [6, 6, 6, 6],
-      [7, 7, 7, 7],
-      [8, 8, 8, 8]
-    ]
-  ])
-  padded = blank_other_channels(
-    tensor, 1
-  )
-  with tf.Session() as session:
-    assert_all_close(
-      session.run(padded),
-      [
+    tensor = tf.constant([
         [
-          [0, 5, 0, 0],
-          [0, 6, 0, 0],
-          [0, 7, 0, 0],
-          [0, 8, 0, 0]
+            [5, 5, 5, 5],
+            [6, 6, 6, 6],
+            [7, 7, 7, 7],
+            [8, 8, 8, 8]
         ],
         [
-          [0, 5, 0, 0],
-          [0, 6, 0, 0],
-          [0, 7, 0, 0],
-          [0, 8, 0, 0]
+            [5, 5, 5, 5],
+            [6, 6, 6, 6],
+            [7, 7, 7, 7],
+            [8, 8, 8, 8]
         ]
-      ]
+    ])
+    padded = blank_other_channels(
+        tensor, 1
     )
+    with tf.Session() as session:
+        assert_all_close(
+            session.run(padded),
+            [
+                [
+                    [0, 5, 0, 0],
+                    [0, 6, 0, 0],
+                    [0, 7, 0, 0],
+                    [0, 8, 0, 0]
+                ],
+                [
+                    [0, 5, 0, 0],
+                    [0, 6, 0, 0],
+                    [0, 7, 0, 0],
+                    [0, 8, 0, 0]
+                ]
+            ]
+        )
diff --git a/sciencebeam_gym/trainer/predict.py b/sciencebeam_gym/trainer/predict.py
index 441a528..a81de41 100644
--- a/sciencebeam_gym/trainer/predict.py
+++ b/sciencebeam_gym/trainer/predict.py
@@ -8,48 +8,52 @@ import tensorflow as tf
 from sciencebeam_gym.utils.tf import FileIO
 
 from sciencebeam_gym.inference_model import (
-  load_inference_model
+    load_inference_model
 )
 
 from sciencebeam_gym.trainer.checkpoint import (
-  load_last_checkpoint_as_inference_model
+    load_last_checkpoint_as_inference_model
 )
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def predict_using_inference_model(
-  inference_model, predict_filename, output_image_filename):
+        inference_model, predict_filename, output_image_filename):
 
-  with FileIO(predict_filename, 'rb') as input_f:
-    image_bytes = input_f.read()
-    img = Image.open(BytesIO(image_bytes)).convert('RGB')
-    img_data = np.asarray(img, dtype=np.uint8)
-    img_data_batch = np.reshape(img_data, tuple([1] + list(img_data.shape)))
+    with FileIO(predict_filename, 'rb') as input_f:
+        image_bytes = input_f.read()
+        img = Image.open(BytesIO(image_bytes)).convert('RGB')
+        img_data = np.asarray(img, dtype=np.uint8)
+        img_data_batch = np.reshape(img_data, tuple([1] + list(img_data.shape)))
 
-  output_img_data_batch = inference_model(img_data_batch)
-  output_img_data = output_img_data_batch[0]
-  output_image = Image.fromarray(output_img_data, 'RGB')
-  out = BytesIO()
-  output_image.save(out, 'png')
-  output_image_bytes = out.getvalue()
+    output_img_data_batch = inference_model(img_data_batch)
+    output_img_data = output_img_data_batch[0]
+    output_image = Image.fromarray(output_img_data, 'RGB')
+    out = BytesIO()
+    output_image.save(out, 'png')
+    output_image_bytes = out.getvalue()
+
+    get_logger().info('writing to %s', output_image_filename)
+    with FileIO(output_image_filename, 'wb') as output_f:
+        output_f.write(output_image_bytes)
 
-  get_logger().info('writing to %s', output_image_filename)
-  with FileIO(output_image_filename, 'wb') as output_f:
-    output_f.write(output_image_bytes)
 
 def load_saved_model_and_predict(export_dir, predict_filename, output_image_filename):
-  with tf.Session(graph=tf.Graph()):
-    predict_using_inference_model(
-      load_inference_model(export_dir),
-      predict_filename,
-      output_image_filename
-    )
+    with tf.Session(graph=tf.Graph()):
+        predict_using_inference_model(
+            load_inference_model(export_dir),
+            predict_filename,
+            output_image_filename
+        )
+
 
 def load_checkpoint_and_predict(model, checkpoint_path, predict_filename, output_image_filename):
-  with tf.Session(graph=tf.Graph()):
-    predict_using_inference_model(
-      load_last_checkpoint_as_inference_model(model, checkpoint_path),
-      predict_filename,
-      output_image_filename
-    )
+    with tf.Session(graph=tf.Graph()):
+        predict_using_inference_model(
+            load_last_checkpoint_as_inference_model(model, checkpoint_path),
+            predict_filename,
+            output_image_filename
+        )
diff --git a/sciencebeam_gym/trainer/saver.py b/sciencebeam_gym/trainer/saver.py
index ef9797e..1ad9b04 100644
--- a/sciencebeam_gym/trainer/saver.py
+++ b/sciencebeam_gym/trainer/saver.py
@@ -3,17 +3,19 @@ import logging
 import tensorflow as tf
 
 from sciencebeam_gym.inference_model import (
-  save_inference_model
+    save_inference_model
 )
 
 from sciencebeam_gym.trainer.checkpoint import (
-  load_last_checkpoint_as_inference_model
+    load_last_checkpoint_as_inference_model
 )
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def load_checkpoint_and_save_model(model, checkpoint_path, export_dir):
-  with tf.Session(graph=tf.Graph()):
-    inference_model = load_last_checkpoint_as_inference_model(model, checkpoint_path)
-    save_inference_model(export_dir, inference_model)
+    with tf.Session(graph=tf.Graph()):
+        inference_model = load_last_checkpoint_as_inference_model(model, checkpoint_path)
+        save_inference_model(export_dir, inference_model)
diff --git a/sciencebeam_gym/trainer/task.py b/sciencebeam_gym/trainer/task.py
index 192f483..3adf822 100644
--- a/sciencebeam_gym/trainer/task.py
+++ b/sciencebeam_gym/trainer/task.py
@@ -21,748 +21,771 @@ from tensorflow.python.client.device_lib import list_local_devices
 
 from sciencebeam_gym.trainer.evaluator import Evaluator
 from sciencebeam_gym.trainer.util import (
-  CustomSupervisor,
-  SimpleStepScheduler,
-  get_graph_size
+    CustomSupervisor,
+    SimpleStepScheduler,
+    get_graph_size
 )
 
 from sciencebeam_gym.trainer.predict import (
-  load_checkpoint_and_predict,
-  load_saved_model_and_predict
+    load_checkpoint_and_predict,
+    load_saved_model_and_predict
 )
 
 from sciencebeam_gym.trainer.saver import (
-  load_checkpoint_and_save_model
+    load_checkpoint_and_save_model
 )
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 class TrainingProgressLogger(object):
-  def __init__(self, start_time, start_step, task):
-    self.start_time = start_time
-    self.start_step = start_step
-    self.last_log_time = start_time
-    self.last_global_step = start_step
-    self.last_local_step = 0
-    self.task = task
-
-  def get_last_log_time(self):
-    return self.last_log_time
-
-  def log(self, now, global_step, local_step):
-    """Logs training progress."""
-    logging.info(
-      'Train [%s/%d], step %d (%.3f sec) %.1f '
-      'global steps/s, %.1f local steps/s',
-      self.task.type,
-      self.task.index,
-      global_step,
-      (now - self.start_time),
-      (global_step - self.last_global_step) /
-      (now - self.last_log_time),
-      (local_step - self.last_local_step) /
-      (now - self.last_log_time)
-    )
-    self.last_log_time = now
-    self.last_global_step = global_step
-    self.last_local_step = local_step
+    def __init__(self, start_time, start_step, task):
+        self.start_time = start_time
+        self.start_step = start_step
+        self.last_log_time = start_time
+        self.last_global_step = start_step
+        self.last_local_step = 0
+        self.task = task
+
+    def get_last_log_time(self):
+        return self.last_log_time
+
+    def log(self, now, global_step, local_step):
+        """Logs training progress."""
+        logging.info(
+            'Train [%s/%d], step %d (%.3f sec) %.1f '
+            'global steps/s, %.1f local steps/s',
+            self.task.type,
+            self.task.index,
+            global_step,
+            (now - self.start_time),
+            (global_step - self.last_global_step) /
+            (now - self.last_log_time),
+            (local_step - self.last_local_step) /
+            (now - self.last_log_time)
+        )
+        self.last_log_time = now
+        self.last_global_step = global_step
+        self.last_local_step = local_step
+
 
 def get_qualitative_evaluator(args, model, run_async):
-  if args.qualitative_data_paths:
-    return Evaluator(
-      args,
-      model,
-      train_dir(args.output_path),
-      args.qualitative_data_paths,
-      dataset='qualitative_set',
-      eval_set_size=args.qualitative_set_size or args.eval_set_size,
-      qualitative_set_size=args.qualitative_set_size,
-      run_async=run_async
-    )
-  else:
-    return None
+    if args.qualitative_data_paths:
+        return Evaluator(
+            args,
+            model,
+            train_dir(args.output_path),
+            args.qualitative_data_paths,
+            dataset='qualitative_set',
+            eval_set_size=args.qualitative_set_size or args.eval_set_size,
+            qualitative_set_size=args.qualitative_set_size,
+            run_async=run_async
+        )
+    else:
+        return None
+
 
 def set_random_seed(seed):
-  get_logger().info('setting random seed to: %s', seed)
-  random.seed(seed)
-  tf.set_random_seed(seed)
+    get_logger().info('setting random seed to: %s', seed)
+    random.seed(seed)
+    tf.set_random_seed(seed)
 
-class Trainer(object):
-  """Performs model training and optionally evaluation."""
-
-  def __init__(self, args, model, cluster, task):
-    self.args = args
-    self.model = model
-    self.cluster = cluster
-    self.task = task
-    self.run_async = None
-    default_run_async = lambda f, args: f(*args)
-    run_async = lambda f, args: (self.run_async or default_run_async)(f, args)
-    self.evaluator = Evaluator(
-      self.args, self.model,
-      train_dir(self.args.output_path),
-      self.args.eval_data_paths,
-      'eval_set',
-      run_async=run_async
-    )
-    self.train_evaluator = Evaluator(
-      self.args, self.model,
-      train_dir(self.args.output_path),
-      self.args.train_data_paths,
-      'train_set',
-      run_async=run_async
-    )
-    self.qualitative_evaluator = get_qualitative_evaluator(
-      self.args,
-      self.model,
-      run_async=run_async
-    )
-    self.min_train_eval_rate = args.min_train_eval_rate
-    self.global_step = None
-    self.last_save = 0
-
-  def run_training(self):
-    get_logger().info('creating async pool, pool size: %d', self.args.pool_size)
-    pool = Pool(processes=self.args.pool_size)
-    self.run_async = lambda f, args: pool.apply_async(f, args)
-    self._do_run_training()
-    get_logger().info('Waiting for tasks to complete')
-    pool.close()
-    pool.join()
-    self.run_async = None
 
-  def _do_run_training(self):
-    """Runs a Master."""
-    logger = get_logger()
+class Trainer(object):
+    """Performs model training and optionally evaluation."""
+
+    def __init__(self, args, model, cluster, task):
+        self.args = args
+        self.model = model
+        self.cluster = cluster
+        self.task = task
+        self.run_async = None
+
+        def default_run_async(f, args):
+            return f(*args)
+
+        def run_async(f, args):
+            return (self.run_async or default_run_async)(f, args)
+
+        self.evaluator = Evaluator(
+            self.args, self.model,
+            train_dir(self.args.output_path),
+            self.args.eval_data_paths,
+            'eval_set',
+            run_async=run_async
+        )
+        self.train_evaluator = Evaluator(
+            self.args, self.model,
+            train_dir(self.args.output_path),
+            self.args.train_data_paths,
+            'train_set',
+            run_async=run_async
+        )
+        self.qualitative_evaluator = get_qualitative_evaluator(
+            self.args,
+            self.model,
+            run_async=run_async
+        )
+        self.min_train_eval_rate = args.min_train_eval_rate
+        self.global_step = None
+        self.last_save = 0
+
+    def run_training(self):
+        get_logger().info('creating async pool, pool size: %d', self.args.pool_size)
+        pool = Pool(processes=self.args.pool_size)
+        self.run_async = lambda f, args: pool.apply_async(f, args)
+        self._do_run_training()
+        get_logger().info('Waiting for tasks to complete')
+        pool.close()
+        pool.join()
+        self.run_async = None
+
+    def _do_run_training(self):
+        """Runs a Master."""
+        logger = get_logger()
+
+        logger.info('tensorflow version: %s', tf.__version__)
+
+        if self.args.seed is not None:
+            set_random_seed(self.args.seed)
+        self.train_evaluator.init()
+        self.evaluator.init()
+        if self.qualitative_evaluator:
+            self.qualitative_evaluator.init()
+        ensure_output_path(self.args.output_path)
+        train_path = train_dir(self.args.output_path)
+        # model_path = model_dir(self.args.output_path)
+        is_master = self.task.type != 'worker'
+        log_interval = self.args.log_interval_secs
+        save_interval = self.args.save_interval_secs
+        eval_interval = self.args.eval_interval_secs
+        summary_interval = log_interval
+        summary_freq = self.args.log_freq
+        if is_master and self.task.index > 0:
+            raise StandardError('Only one replica of master expected')
+
+        if self.cluster:
+            logging.info('Starting %s/%d', self.task.type, self.task.index)
+            server = start_server(self.cluster, self.task)
+            target = server.target
+            device_fn = tf.train.replica_device_setter(
+                ps_device='/job:ps',
+                worker_device='/job:%s/task:%d' % (self.task.type, self.task.index),
+                cluster=self.cluster
+            )
+            # We use a device_filter to limit the communication between this job
+            # and the parameter servers, i.e., there is no need to directly
+            # communicate with the other workers; attempting to do so can result
+            # in reliability problems.
+            device_filters = [
+                '/job:ps', '/job:%s/task:%d' % (self.task.type, self.task.index)
+            ]
+            config = tf.ConfigProto(device_filters=device_filters)
+        else:
+            target = ''
+            device_fn = ''
+            config = None
+
+        logger.info('batch_size: %s', self.args.batch_size)
+
+        logger.info(
+            'available devices: %s',
+            ', '.join([
+                '%s (%s)' % (x.name, x.device_type)
+                for x in list_local_devices()
+            ])
+        )
 
-    logger.info('tensorflow version: %s', tf.__version__)
-
-    if self.args.seed is not None:
-      set_random_seed(self.args.seed)
-    self.train_evaluator.init()
-    self.evaluator.init()
-    if self.qualitative_evaluator:
-      self.qualitative_evaluator.init()
-    ensure_output_path(self.args.output_path)
-    train_path = train_dir(self.args.output_path)
-    # model_path = model_dir(self.args.output_path)
-    is_master = self.task.type != 'worker'
-    log_interval = self.args.log_interval_secs
-    save_interval = self.args.save_interval_secs
-    eval_interval = self.args.eval_interval_secs
-    summary_interval = log_interval
-    summary_freq = self.args.log_freq
-    if is_master and self.task.index > 0:
-      raise StandardError('Only one replica of master expected')
-
-    if self.cluster:
-      logging.info('Starting %s/%d', self.task.type, self.task.index)
-      server = start_server(self.cluster, self.task)
-      target = server.target
-      device_fn = tf.train.replica_device_setter(
-        ps_device='/job:ps',
-        worker_device='/job:%s/task:%d' % (self.task.type, self.task.index),
-        cluster=self.cluster
-      )
-      # We use a device_filter to limit the communication between this job
-      # and the parameter servers, i.e., there is no need to directly
-      # communicate with the other workers; attempting to do so can result
-      # in reliability problems.
-      device_filters = [
-        '/job:ps', '/job:%s/task:%d' % (self.task.type, self.task.index)
-      ]
-      config = tf.ConfigProto(device_filters=device_filters)
-    else:
-      target = ''
-      device_fn = ''
-      config = None
-
-    logger.info('batch_size: %s', self.args.batch_size)
-
-    logger.info(
-      'available devices: %s',
-      ', '.join([
-        '%s (%s)' % (x.name, x.device_type)
-        for x in list_local_devices()
-      ])
-    )
-
-    with tf.Graph().as_default() as graph:
-      with tf.device(device_fn):
-        # Build the training graph.
-        logger.info('building graph...')
-        tensors = self.model.build_train_graph(
-          self.args.train_data_paths,
-          self.args.batch_size)
-
-        logger.info('done building graph, calculating graph size...')
-        logger.info('graph_size: %s bytes', '{:,}'.format(get_graph_size()))
-
-        # Create a saver for writing training checkpoints.
-        saver = tf.train.Saver(
-          max_to_keep=self.args.save_max_to_keep
+        with tf.Graph().as_default() as graph:
+            with tf.device(device_fn):
+                # Build the training graph.
+                logger.info('building graph...')
+                tensors = self.model.build_train_graph(
+                    self.args.train_data_paths,
+                    self.args.batch_size)
+
+                logger.info('done building graph, calculating graph size...')
+                logger.info('graph_size: %s bytes', '{:,}'.format(get_graph_size()))
+
+                # Create a saver for writing training checkpoints.
+                saver = tf.train.Saver(
+                    max_to_keep=self.args.save_max_to_keep
+                )
+
+        # Create a "supervisor", which oversees the training process.
+        sv = CustomSupervisor(
+            model=self.model,
+            graph=graph,
+            is_chief=is_master,
+            logdir=train_path,
+            saver=saver,
+            # Write summary_ops by hand.
+            summary_op=None,
+            global_step=tensors.global_step
         )
 
-    # Create a "supervisor", which oversees the training process.
-    sv = CustomSupervisor(
-      model=self.model,
-      graph=graph,
-      is_chief=is_master,
-      logdir=train_path,
-      saver=saver,
-      # Write summary_ops by hand.
-      summary_op=None,
-      global_step=tensors.global_step
-    )
-
-    save_path = sv.save_path
-
-    should_retry = True
-    local_step = 0
-
-    while should_retry:
-      try:
-        should_retry = False
-        with sv.managed_session(target, config=config) as session:
-          start_time = time.time()
-          now = start_time
-
-          global_step = session.run(tensors.global_step)
-          training_progress_logger = TrainingProgressLogger(
-            start_time,
-            global_step,
-            self.task
-          )
-
-          log_scheduler = SimpleStepScheduler(
-            lambda: training_progress_logger.log(now, global_step, local_step),
-            min_interval=log_interval,
-            min_freq=self.args.log_freq,
-            step=global_step,
-            last_run=start_time
-          )
-
-          def do_save():
-            logger.info('saving model to %s (%s)', save_path, global_step)
-            saver.save(session, save_path, tensors.global_step)
-
-          save_scheduler = SimpleStepScheduler(
-            do_save,
-            min_interval=save_interval,
-            min_freq=self.args.save_freq,
-            step=global_step,
-            last_run=start_time
-          )
-
-          eval_train_scheduler = SimpleStepScheduler(
-            lambda: self.eval_train(session, tensors, global_step),
-            min_interval=eval_interval,
-            min_freq=self.args.eval_freq,
-            step=global_step,
-            last_run=start_time
-          )
-
-          schedulers = [
-            log_scheduler,
-            save_scheduler,
-            eval_train_scheduler
-          ]
-
-          if is_master:
-            eval_scheduler = SimpleStepScheduler(
-              lambda: self.eval(global_step=global_step),
-              min_interval=save_interval,
-              min_freq=self.args.save_freq,
-              step=global_step,
-              last_run=start_time
-            )
-            schedulers = schedulers + [eval_scheduler]
-
-          summary_op = sv.summary_op if tensors.summary is None else tensors.summary
-          if summary_op is not None:
-            schedulers.append(SimpleStepScheduler(
-              lambda: sv.summary_writer.add_summary(
-                *session.run([summary_op, tensors.global_step])
-              ),
-              min_interval=summary_interval,
-              min_freq=summary_freq,
-              step=global_step,
-              last_run=start_time
-            ))
-
-          # Loop until the supervisor shuts down or args.max_steps have
-          # completed.
-          max_steps = self.args.max_steps
-          while not sv.should_stop() and global_step < max_steps:
-            logging.info("global_step: %s", global_step)
-            try:
-              # Run one step of the model.
-              global_step = session.run([tensors.global_step, tensors.train])[0]
-              logging.info("global_step: %s", global_step)
-              local_step += 1
+        save_path = sv.save_path
 
-              now = time.time()
-              for scheduler in schedulers:
-                scheduler.step(now)
+        should_retry = True
+        local_step = 0
+
+        while should_retry:
+            try:
+                should_retry = False
+                with sv.managed_session(target, config=config) as session:
+                    start_time = time.time()
+                    now = start_time
+
+                    global_step = session.run(tensors.global_step)
+                    training_progress_logger = TrainingProgressLogger(
+                        start_time,
+                        global_step,
+                        self.task
+                    )
+
+                    log_scheduler = SimpleStepScheduler(
+                        lambda: training_progress_logger.log(now, global_step, local_step),
+                        min_interval=log_interval,
+                        min_freq=self.args.log_freq,
+                        step=global_step,
+                        last_run=start_time
+                    )
+
+                    def do_save():
+                        logger.info('saving model to %s (%s)', save_path, global_step)
+                        saver.save(session, save_path, tensors.global_step)
+
+                    save_scheduler = SimpleStepScheduler(
+                        do_save,
+                        min_interval=save_interval,
+                        min_freq=self.args.save_freq,
+                        step=global_step,
+                        last_run=start_time
+                    )
+
+                    eval_train_scheduler = SimpleStepScheduler(
+                        lambda: self.eval_train(session, tensors, global_step),
+                        min_interval=eval_interval,
+                        min_freq=self.args.eval_freq,
+                        step=global_step,
+                        last_run=start_time
+                    )
+
+                    schedulers = [
+                        log_scheduler,
+                        save_scheduler,
+                        eval_train_scheduler
+                    ]
+
+                    if is_master:
+                        eval_scheduler = SimpleStepScheduler(
+                            lambda: self.eval(global_step=global_step),
+                            min_interval=save_interval,
+                            min_freq=self.args.save_freq,
+                            step=global_step,
+                            last_run=start_time
+                        )
+                        schedulers = schedulers + [eval_scheduler]
+
+                    summary_op = sv.summary_op if tensors.summary is None else tensors.summary
+                    if summary_op is not None:
+                        schedulers.append(SimpleStepScheduler(
+                            lambda: sv.summary_writer.add_summary(
+                                *session.run([summary_op, tensors.global_step])
+                            ),
+                            min_interval=summary_interval,
+                            min_freq=summary_freq,
+                            step=global_step,
+                            last_run=start_time
+                        ))
+
+                    # Loop until the supervisor shuts down or args.max_steps have
+                    # completed.
+                    max_steps = self.args.max_steps
+                    while not sv.should_stop() and global_step < max_steps:
+                        logging.info("global_step: %s", global_step)
+                        try:
+                            # Run one step of the model.
+                            global_step = session.run([tensors.global_step, tensors.train])[0]
+                            logging.info("global_step: %s", global_step)
+                            local_step += 1
+
+                            now = time.time()
+                            for scheduler in schedulers:
+                                scheduler.step(now)
+
+                        except tf.errors.AbortedError as e:
+                            should_retry = True
+                            logging.info('AbortedError (%s)', e)
+                        except (KeyboardInterrupt, tf.errors.CancelledError):
+                            logging.info('cancelled')
+                            should_retry = False
+
+                    logging.info('finished (is_master: %s)', is_master)
+
+                    if is_master:
+                        # Take the final checkpoint and compute the final accuracy.
+                        now = time.time()
+                        for scheduler in schedulers:
+                            scheduler.flush(now)
 
             except tf.errors.AbortedError as e:
-              should_retry = True
-              logging.info('AbortedError (%s)', e)
-            except (KeyboardInterrupt, tf.errors.CancelledError):
-              logging.info('cancelled')
-              should_retry = False
+                should_retry = True
+                logging.info('AbortedError (%s)', e)
 
-          logging.info('finished (is_master: %s)', is_master)
+        # Ask for all the services to stop.
+        sv.stop()
 
-          if is_master:
-            # Take the final checkpoint and compute the final accuracy.
-            now = time.time()
-            for scheduler in schedulers:
-              scheduler.flush(now)
+    def eval_train(self, session, tensors, global_step):
+        """Runs evaluation loop."""
+        logging.info(
+            'Eval, step %d:\n- on train set %s',
+            global_step,
+            self.model.format_metric_values(
+                self.train_evaluator.evaluate_in_session(
+                    session=session,
+                    tensors=tensors
+                )
+            )
+        )
 
-      except tf.errors.AbortedError as e:
-        should_retry = True
-        logging.info('AbortedError (%s)', e)
-
-    # Ask for all the services to stop.
-    sv.stop()
-
-  def eval_train(self, session, tensors, global_step):
-    """Runs evaluation loop."""
-    logging.info(
-      'Eval, step %d:\n- on train set %s',
-      global_step,
-      self.model.format_metric_values(
-        self.train_evaluator.evaluate_in_session(
-          session=session,
-          tensors=tensors
+    def eval(self, global_step=None):
+        """Runs evaluation loop."""
+        if self.qualitative_evaluator:
+            logging.info(
+                'Qualitative Eval, step %s:\n- on eval set %s',
+                global_step,
+                self.model.format_metric_values(self.qualitative_evaluator.evaluate())
+            )
+        logging.info(
+            'Eval, step %s:\n- on eval set %s',
+            global_step,
+            self.model.format_metric_values(self.evaluator.evaluate())
         )
-      )
-    )
 
-  def eval(self, global_step=None):
-    """Runs evaluation loop."""
-    if self.qualitative_evaluator:
-      logging.info(
-        'Quantitive Eval, step %s:\n- on eval set %s',
-        global_step,
-        self.model.format_metric_values(self.qualitative_evaluator.evaluate())
-      )
-    logging.info(
-      'Eval, step %s:\n- on eval set %s',
-      global_step,
-      self.model.format_metric_values(self.evaluator.evaluate())
-    )
 
 def copy_data_to_tmp(input_files):
-  """Copies data to /tmp/ and returns glob matching the files."""
-  files = []
-  for e in input_files:
-    for path in e.split(','):
-      files.extend(file_io.get_matching_files(path))
+    """Copies data to /tmp/ and returns glob matching the files."""
+    files = []
+    for e in input_files:
+        for path in e.split(','):
+            files.extend(file_io.get_matching_files(path))
 
-  for path in files:
-    if not path.startswith('gs://'):
-      return input_files
+    for path in files:
+        if not path.startswith('gs://'):
+            return input_files
 
-  tmp_path = os.path.join('/tmp/', str(uuid.uuid4()))
-  os.makedirs(tmp_path)
-  subprocess.check_call(['gsutil', '-m', '-q', 'cp', '-r'] + files + [tmp_path])
-  return [os.path.join(tmp_path, '*')]
+    tmp_path = os.path.join('/tmp/', str(uuid.uuid4()))
+    os.makedirs(tmp_path)
+    subprocess.check_call(['gsutil', '-m', '-q', 'cp', '-r'] + files + [tmp_path])
+    return [os.path.join(tmp_path, '*')]
 
 
 def write_predictions(args, model, cluster, task):
-  if not cluster or not task or task.type == 'master':
-    pass  # Run locally.
-  else:
-    raise ValueError('invalid task_type %s' % (task.type,))
-
-  if args.seed is not None:
-    set_random_seed(args.seed)
-
-  logger = get_logger()
-  logger.info('Starting to write predictions on %s/%d', task.type, task.index)
-  pool = Pool(processes=args.pool_size)
-  run_async = lambda f, args: pool.apply_async(f, args)
-
-  qualitative_evaluator = get_qualitative_evaluator(
-    args,
-    model,
-    run_async=run_async
-  )
-  if qualitative_evaluator:
-    qualitative_evaluator.init()
-    qualitative_evaluator.write_predictions()
-
-  evaluator = Evaluator(
-    args, model, train_dir(args.output_path), args.eval_data_paths,
-    run_async=run_async
-  )
-  evaluator.init()
-  evaluator.write_predictions()
-
-  logger.info('Waiting for background tasks to finish')
-  pool.close()
-  pool.join()
-  logger.info('Done writing predictions on %s/%d', task.type, task.index)
+    if not cluster or not task or task.type == 'master':
+        pass  # Run locally.
+    else:
+        raise ValueError('invalid task_type %s' % (task.type,))
+
+    if args.seed is not None:
+        set_random_seed(args.seed)
+
+    logger = get_logger()
+    logger.info('Starting to write predictions on %s/%d', task.type, task.index)
+    pool = Pool(processes=args.pool_size)
+
+    def run_async(f, args):
+        return pool.apply_async(f, args)
+
+    qualitative_evaluator = get_qualitative_evaluator(
+        args,
+        model,
+        run_async=run_async
+    )
+    if qualitative_evaluator:
+        qualitative_evaluator.init()
+        qualitative_evaluator.write_predictions()
+
+    evaluator = Evaluator(
+        args, model, train_dir(args.output_path), args.eval_data_paths,
+        run_async=run_async
+    )
+    evaluator.init()
+    evaluator.write_predictions()
+
+    logger.info('Waiting for background tasks to finish')
+    pool.close()
+    pool.join()
+    logger.info('Done writing predictions on %s/%d', task.type, task.index)
+
 
 def predict(args, model, cluster, task):
-  if not cluster or not task or task.type == 'master':
-    pass  # Run locally.
-  else:
-    raise ValueError('invalid task_type %s' % (task.type,))
-
-  if args.seed is not None:
-    set_random_seed(args.seed)
-
-  predict_filename = args.predict
-  output_image_filename = args.predict_output
-  assert output_image_filename
-  if args.model_export_path:
-    load_saved_model_and_predict(args.model_export_path, predict_filename, output_image_filename)
-  else:
-    checkpoint_path = train_dir(args.output_path)
-    load_checkpoint_and_predict(model, checkpoint_path, predict_filename, output_image_filename)
+    if not cluster or not task or task.type == 'master':
+        pass  # Run locally.
+    else:
+        raise ValueError('invalid task_type %s' % (task.type,))
+
+    if args.seed is not None:
+        set_random_seed(args.seed)
+
+    predict_filename = args.predict
+    output_image_filename = args.predict_output
+    assert output_image_filename
+    if args.model_export_path:
+        load_saved_model_and_predict(args.model_export_path,
+                                     predict_filename, output_image_filename)
+    else:
+        checkpoint_path = train_dir(args.output_path)
+        load_checkpoint_and_predict(model, checkpoint_path, predict_filename, output_image_filename)
+
 
 def save_model(args, model, cluster, task):
-  if not cluster or not task or task.type == 'master':
-    pass  # Run locally.
-  else:
-    raise ValueError('invalid task_type %s' % (task.type,))
+    if not cluster or not task or task.type == 'master':
+        pass  # Run locally.
+    else:
+        raise ValueError('invalid task_type %s' % (task.type,))
 
-  if args.seed is not None:
-    set_random_seed(args.seed)
+    if args.seed is not None:
+        set_random_seed(args.seed)
+
+    export_dir = args.save_model
+    checkpoint_path = train_dir(args.output_path)
+    load_checkpoint_and_save_model(model, checkpoint_path, export_dir)
 
-  export_dir = args.save_model
-  checkpoint_path = train_dir(args.output_path)
-  load_checkpoint_and_save_model(model, checkpoint_path, export_dir)
 
 def dispatch(args, model, cluster, task):
-  if not cluster or not task or task.type == 'master':
-    # Run locally.
-    Trainer(args, model, cluster, task).run_training()
-  elif task.type == 'ps':
-    run_parameter_server(cluster, task)
-  elif task.type == 'worker':
-    Trainer(args, model, cluster, task).run_training()
-  else:
-    raise ValueError('invalid task_type %s' % (task.type,))
+    if not cluster or not task or task.type == 'master':
+        # Run locally.
+        Trainer(args, model, cluster, task).run_training()
+    elif task.type == 'ps':
+        run_parameter_server(cluster, task)
+    elif task.type == 'worker':
+        Trainer(args, model, cluster, task).run_training()
+    else:
+        raise ValueError('invalid task_type %s' % (task.type,))
 
 
 def run_parameter_server(cluster, task):
-  logging.info('Starting parameter server %d', task.index)
-  server = start_server(cluster, task)
-  server.join()
+    logging.info('Starting parameter server %d', task.index)
+    server = start_server(cluster, task)
+    server.join()
 
 
 def start_server(cluster, task):
-  if not task.type:
-    raise ValueError('--task_type must be specified.')
-  if task.index is None:
-    raise ValueError('--task_index must be specified.')
-
-  # Create and start a server.
-  return tf.train.Server(
-    tf.train.ClusterSpec(cluster),
-    protocol='grpc',
-    job_name=task.type,
-    task_index=task.index
-  )
+    if not task.type:
+        raise ValueError('--task_type must be specified.')
+    if task.index is None:
+        raise ValueError('--task_index must be specified.')
+
+    # Create and start a server.
+    return tf.train.Server(
+        tf.train.ClusterSpec(cluster),
+        protocol='grpc',
+        job_name=task.type,
+        task_index=task.index
+    )
 
 
 def ensure_output_path(output_path):
-  if not output_path:
-    raise ValueError('output_path must be specified')
+    if not output_path:
+        raise ValueError('output_path must be specified')
 
-  # GCS doesn't have real directories.
-  if output_path.startswith('gs://'):
-    return
+    # GCS doesn't have real directories.
+    if output_path.startswith('gs://'):
+        return
 
-  ensure_dir(output_path)
+    ensure_dir(output_path)
 
 
 def ensure_dir(path):
-  try:
-    os.makedirs(path)
-  except OSError as e:
-    # If the directory already existed, ignore the error.
-    if e.args[0] == 17:
-      pass
-    else:
-      raise
+    try:
+        os.makedirs(path)
+    except OSError as e:
+        # If the directory already existed, ignore the error.
+        if e.args[0] == 17:
+            pass
+        else:
+            raise
 
 
 def train_dir(output_path):
-  return os.path.join(output_path, 'train')
+    return os.path.join(output_path, 'train')
 
 
 def eval_dir(output_path):
-  return os.path.join(output_path, 'eval')
+    return os.path.join(output_path, 'eval')
 
 
 def model_dir(output_path):
-  return os.path.join(output_path, 'model')
+    return os.path.join(output_path, 'model')
+
 
 def run(model, argv):
-  """Runs the training loop."""
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-    '--train_data_paths',
-    nargs='+',
-    type=str,
-    help='The paths to the training data files. '
-    'Can be comma separated list of files or glob pattern.'
-  )
-  parser.add_argument(
-    '--eval_data_paths',
-    nargs='+',
-    type=str,
-    help='The path to the files used for evaluation. '
-    'Can be comma separated list of files or glob pattern.'
-  )
-  parser.add_argument(
-    '--qualitative_data_paths',
-    nargs='+',
-    type=str,
-    help='The path to the files used for qualitative evaluation. '
-    'You may choose a different set for the qualitative analysis to keep the results consistent.'
-  )
-  parser.add_argument(
-    '--output_path',
-    type=str,
-    help='The path to which checkpoints and other outputs '
-    'should be saved. This can be either a local or GCS '
-    'path.'
-  )
-  parser.add_argument(
-    '--max_steps',
-    type=int,
-    default=1000
-  )
-  parser.add_argument(
-    '--batch_size',
-    type=int,
-    default=100,
-    help='Number of examples to be processed per mini-batch.'
-  )
-  parser.add_argument(
-    '--eval_set_size', type=int, default=370,
-    help='Number of examples in the eval set.'
-  )
-  parser.add_argument(
-    '--qualitative_set_size',
-    type=int,
-    help='Number of examples in the qualitative eval set.'
-  )
-  parser.add_argument(
-    '--eval_batch_size', type=int, help='Number of examples per eval batch.'
-  )
-  parser.add_argument(
-    '--eval_interval_secs',
-    type=float,
-    default=5,
-    help='Minimal interval between calculating evaluation metrics and saving'
-    ' evaluation summaries.'
-  )
-  parser.add_argument(
-    '--eval_freq',
-    type=int,
-    default=100,
-    help='Frequancy in steps between calculating evaluation metrics and saving'
-    ' evaluation summaries.'
-  )
-  parser.add_argument(
-    '--save_interval_secs',
-    type=float,
-    default=300,
-    help='Minimal interval between saving the model checkpoint'
-  )
-  parser.add_argument(
-    '--save_freq',
-    type=int,
-    default=1000,
-    help='Frequancy in steps between saving the model checkpoint'
-  )
-  parser.add_argument(
-    '--save_max_to_keep',
-    type=int,
-    default=2,
-    help='Maximum number of recent checkpoint files to keep'
-  )
-  parser.add_argument(
-    '--log_interval_secs',
-    type=float,
-    default=5,
-    help='Minimal interval between logging training metrics and saving '
-    'training summaries.'
-  )
-  parser.add_argument(
-    '--log_freq',
-    type=int,
-    default=500,
-    help='Frequancy in steps between logging training metrics and saving '
-    'training summaries.'
-  )
-  parser.add_argument(
-    '--write_predictions',
-    action='store_true',
-    default=False,
-    help='If set, model is restored from latest checkpoint '
-    'and predictions are written to a csv file and no training is performed.'
-  )
-  parser.add_argument(
-    '--predict',
-    type=str,
-    default=False,
-    help='If set, predicts output for a given image using the latest checkpoint.'
-  )
-  parser.add_argument(
-    '--predict-output',
-    type=str,
-    default=False,
-    help='The output file for the prediction.'
-  )
-  parser.add_argument(
-    '--model_export_path',
-    type=str,
-    default=False,
-    help='If specified, predict using the "saved model".'
-  )
-  parser.add_argument(
-    '--save_model',
-    type=str,
-    default=False,
-    help='If specified, export directory for the latest checkpoint'
-    ' to be saved as a more portable "saved model".'
-  )
-  parser.add_argument(
-    '--min_train_eval_rate',
-    type=int,
-    default=20,
-    help='Minimal train / eval time ratio on master. '
-    'Default value 20 means that 20x more time is used for training than '
-    'for evaluation. If evaluation takes more time the eval_interval_secs '
-    'is increased.'
-  )
-  parser.add_argument(
-    '--write_to_tmp',
-    action='store_true',
-    default=False,
-    help='If set, all checkpoints and summaries are written to '
-    'local filesystem (/tmp/) and copied to gcs once training is done. '
-    'This can speed up training but if training job fails all the summaries '
-    'and checkpoints are lost.'
-  )
-  parser.add_argument(
-    '--copy_train_data_to_tmp',
-    action='store_true',
-    default=False,
-    help='If set, training data is copied to local filesystem '
-    '(/tmp/). This can speed up training but requires extra space on the '
-    'local filesystem.'
-  )
-  parser.add_argument(
-    '--copy_eval_data_to_tmp',
-    action='store_true',
-    default=False,
-    help='If set, evaluation data is copied to local filesystem '
-    '(/tmp/). This can speed up training but requires extra space on the '
-    'local filesystem.'
-  )
-  parser.add_argument(
-    '--streaming_eval',
-    action='store_true',
-    default=False,
-    help='If set to True the evaluation is performed in streaming mode. '
-    'During each eval cycle the evaluation data is read and parsed from '
-    'files. This allows for having very large evaluation set. '
-    'If set to False (default) evaluation data is read once and cached in '
-    'memory. This results in faster evaluation cycle but can potentially '
-    'use more memory (in streaming mode large per-file read-ahead buffer is '
-    'used - which may exceed eval data size).'
-  )
-  parser.add_argument(
-    '--pool_size',
-    type=int,
-    default=50,
-    help='Number of examples in the eval set.'
-  )
-  parser.add_argument(
-    '--seed',
-    type=int,
-    help='The random seed to use'
-  )
-
-  args = parser.parse_args(argv)
-
-  env = json.loads(os.environ.get('TF_CONFIG', '{}'))
-
-  # Print the job data as provided by the service.
-  logging.info('Original job data: %s', env.get('job', {}))
-
-  # First find out if there's a task value on the environment variable.
-  # If there is none or it is empty define a default one.
-  task_data = env.get('task', None) or {'type': 'master', 'index': 0}
-  task = type('TaskSpec', (object,), task_data)
-  trial = task_data.get('trial')
-  if trial is not None:
-    args.output_path = os.path.join(args.output_path, trial)
-  if args.write_to_tmp and args.output_path.startswith('gs://'):
-    output_path = args.output_path
-    args.output_path = os.path.join('/tmp/', str(uuid.uuid4()))
-    os.makedirs(args.output_path)
-  else:
-    output_path = None
-
-  if args.copy_train_data_to_tmp:
-    args.train_data_paths = copy_data_to_tmp(args.train_data_paths)
-  if args.copy_eval_data_to_tmp:
-    args.eval_data_paths = copy_data_to_tmp(args.eval_data_paths)
-
-  if not args.eval_batch_size:
-    # If eval_batch_size not set, use min of batch_size and eval_set_size
-    args.eval_batch_size = min(args.batch_size, args.eval_set_size)
-    logging.info("setting eval batch size to %s", args.eval_batch_size)
-
-  cluster_data = env.get('cluster', None)
-  cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
-  if args.write_predictions:
-    write_predictions(args, model, cluster, task)
-  elif args.save_model:
-    save_model(args, model, cluster, task)
-  elif args.predict:
-    predict(args, model, cluster, task)
-  else:
-    dispatch(args, model, cluster, task)
-
-  if output_path and (not cluster or not task or task.type == 'master'):
-    subprocess.check_call([
-        'gsutil', '-m', '-q', 'cp', '-r', args.output_path + '/*', output_path
-    ])
-    shutil.rmtree(args.output_path, ignore_errors=True)
+    """Runs the training loop."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--train_data_paths',
+        nargs='+',
+        type=str,
+        help='The paths to the training data files. '
+        'Can be comma separated list of files or glob pattern.'
+    )
+    parser.add_argument(
+        '--eval_data_paths',
+        nargs='+',
+        type=str,
+        help='The path to the files used for evaluation. '
+        'Can be comma separated list of files or glob pattern.'
+    )
+    parser.add_argument(
+        '--qualitative_data_paths',
+        nargs='+',
+        type=str,
+        help='The path to the files used for qualitative evaluation.'
+        ' You may choose a different set for the qualitative analysis'
+        ' to keep the results consistent.'
+    )
+    parser.add_argument(
+        '--output_path',
+        type=str,
+        help='The path to which checkpoints and other outputs '
+        'should be saved. This can be either a local or GCS '
+        'path.'
+    )
+    parser.add_argument(
+        '--max_steps',
+        type=int,
+        default=1000
+    )
+    parser.add_argument(
+        '--batch_size',
+        type=int,
+        default=100,
+        help='Number of examples to be processed per mini-batch.'
+    )
+    parser.add_argument(
+        '--eval_set_size', type=int, default=370,
+        help='Number of examples in the eval set.'
+    )
+    parser.add_argument(
+        '--qualitative_set_size',
+        type=int,
+        help='Number of examples in the qualitative eval set.'
+    )
+    parser.add_argument(
+        '--eval_batch_size', type=int, help='Number of examples per eval batch.'
+    )
+    parser.add_argument(
+        '--eval_interval_secs',
+        type=float,
+        default=5,
+        help='Minimal interval between calculating evaluation metrics and saving'
+        ' evaluation summaries.'
+    )
+    parser.add_argument(
+        '--eval_freq',
+        type=int,
+        default=100,
+        help='Frequency in steps between calculating evaluation metrics and saving'
+        ' evaluation summaries.'
+    )
+    parser.add_argument(
+        '--save_interval_secs',
+        type=float,
+        default=300,
+        help='Minimal interval between saving the model checkpoint'
+    )
+    parser.add_argument(
+        '--save_freq',
+        type=int,
+        default=1000,
+        help='Frequency in steps between saving the model checkpoint'
+    )
+    parser.add_argument(
+        '--save_max_to_keep',
+        type=int,
+        default=2,
+        help='Maximum number of recent checkpoint files to keep'
+    )
+    parser.add_argument(
+        '--log_interval_secs',
+        type=float,
+        default=5,
+        help='Minimal interval between logging training metrics and saving '
+        'training summaries.'
+    )
+    parser.add_argument(
+        '--log_freq',
+        type=int,
+        default=500,
+        help='Frequency in steps between logging training metrics and saving '
+        'training summaries.'
+    )
+    parser.add_argument(
+        '--write_predictions',
+        action='store_true',
+        default=False,
+        help='If set, model is restored from latest checkpoint '
+        'and predictions are written to a csv file and no training is performed.'
+    )
+    parser.add_argument(
+        '--predict',
+        type=str,
+        default=False,
+        help='If set, predicts output for a given image using the latest checkpoint.'
+    )
+    parser.add_argument(
+        '--predict-output',
+        type=str,
+        default=False,
+        help='The output file for the prediction.'
+    )
+    parser.add_argument(
+        '--model_export_path',
+        type=str,
+        default=False,
+        help='If specified, predict using the "saved model".'
+    )
+    parser.add_argument(
+        '--save_model',
+        type=str,
+        default=False,
+        help='If specified, export directory for the latest checkpoint'
+        ' to be saved as a more portable "saved model".'
+    )
+    parser.add_argument(
+        '--min_train_eval_rate',
+        type=int,
+        default=20,
+        help='Minimal train / eval time ratio on master. '
+        'Default value 20 means that 20x more time is used for training than '
+        'for evaluation. If evaluation takes more time the eval_interval_secs '
+        'is increased.'
+    )
+    parser.add_argument(
+        '--write_to_tmp',
+        action='store_true',
+        default=False,
+        help='If set, all checkpoints and summaries are written to '
+        'local filesystem (/tmp/) and copied to gcs once training is done. '
+        'This can speed up training but if training job fails all the summaries '
+        'and checkpoints are lost.'
+    )
+    parser.add_argument(
+        '--copy_train_data_to_tmp',
+        action='store_true',
+        default=False,
+        help='If set, training data is copied to local filesystem '
+        '(/tmp/). This can speed up training but requires extra space on the '
+        'local filesystem.'
+    )
+    parser.add_argument(
+        '--copy_eval_data_to_tmp',
+        action='store_true',
+        default=False,
+        help='If set, evaluation data is copied to local filesystem '
+        '(/tmp/). This can speed up training but requires extra space on the '
+        'local filesystem.'
+    )
+    parser.add_argument(
+        '--streaming_eval',
+        action='store_true',
+        default=False,
+        help='If set to True the evaluation is performed in streaming mode. '
+        'During each eval cycle the evaluation data is read and parsed from '
+        'files. This allows for having very large evaluation set. '
+        'If set to False (default) evaluation data is read once and cached in '
+        'memory. This results in faster evaluation cycle but can potentially '
+        'use more memory (in streaming mode large per-file read-ahead buffer is '
+        'used - which may exceed eval data size).'
+    )
+    parser.add_argument(
+        '--pool_size',
+        type=int,
+        default=50,
+        help='Number of examples in the eval set.'
+    )
+    parser.add_argument(
+        '--seed',
+        type=int,
+        help='The random seed to use'
+    )
+
+    args = parser.parse_args(argv)
+
+    env = json.loads(os.environ.get('TF_CONFIG', '{}'))
+
+    # Print the job data as provided by the service.
+    logging.info('Original job data: %s', env.get('job', {}))
+
+    # First find out if there's a task value on the environment variable.
+    # If there is none or it is empty define a default one.
+    task_data = env.get('task', None) or {'type': 'master', 'index': 0}
+    task = type('TaskSpec', (object,), task_data)
+    trial = task_data.get('trial')
+    if trial is not None:
+        args.output_path = os.path.join(args.output_path, trial)
+    if args.write_to_tmp and args.output_path.startswith('gs://'):
+        output_path = args.output_path
+        args.output_path = os.path.join('/tmp/', str(uuid.uuid4()))
+        os.makedirs(args.output_path)
+    else:
+        output_path = None
+
+    if args.copy_train_data_to_tmp:
+        args.train_data_paths = copy_data_to_tmp(args.train_data_paths)
+    if args.copy_eval_data_to_tmp:
+        args.eval_data_paths = copy_data_to_tmp(args.eval_data_paths)
+
+    if not args.eval_batch_size:
+        # If eval_batch_size not set, use min of batch_size and eval_set_size
+        args.eval_batch_size = min(args.batch_size, args.eval_set_size)
+        logging.info("setting eval batch size to %s", args.eval_batch_size)
+
+    cluster_data = env.get('cluster', None)
+    cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
+    if args.write_predictions:
+        write_predictions(args, model, cluster, task)
+    elif args.save_model:
+        save_model(args, model, cluster, task)
+    elif args.predict:
+        predict(args, model, cluster, task)
+    else:
+        dispatch(args, model, cluster, task)
+
+    is_master = not task or task.type == 'master'  # pylint: disable=no-member
+    if output_path and (not cluster or is_master):
+        subprocess.check_call([
+            'gsutil', '-m', '-q', 'cp', '-r', args.output_path + '/*', output_path
+        ])
+        shutil.rmtree(args.output_path, ignore_errors=True)
+
 
 def get_model_factory(model_name):
-  if model_name == 'pix2pix':
-    import sciencebeam_gym.trainer.models.pix2pix.pix2pix_model as model_factory
-    return model_factory
-  raise Exception('unsupported model: {}'.format(model_name))
+    if model_name == 'pix2pix':
+        import sciencebeam_gym.trainer.models.pix2pix.pix2pix_model as model_factory
+        return model_factory
+    raise Exception('unsupported model: {}'.format(model_name))
+
 
 def main(_):
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-    '--model',
-    type=str,
-    required=True,
-    help='The name of the model'
-  )
-  args, other_args = parser.parse_known_args()
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--model',
+        type=str,
+        required=True,
+        help='The name of the model'
+    )
+    args, other_args = parser.parse_known_args()
+
+    model_factory = get_model_factory(args.model)
 
-  model_factory = get_model_factory(args.model)
+    model, task_args = model_factory.create_model(other_args)
+    run(model, task_args)
 
-  model, task_args = model_factory.create_model(other_args)
-  run(model, task_args)
 
 if __name__ == '__main__':
-  logging.basicConfig(level=logging.INFO)
-  tf.app.run()
+    logging.basicConfig(level=logging.INFO)
+    tf.app.run()
diff --git a/sciencebeam_gym/trainer/util.py b/sciencebeam_gym/trainer/util.py
index 4d2b75c..557b914 100644
--- a/sciencebeam_gym/trainer/util.py
+++ b/sciencebeam_gym/trainer/util.py
@@ -3,119 +3,128 @@ import logging
 
 import numpy as np
 import tensorflow as tf
-from tensorflow.python.framework import ops # pylint: disable=E0611
-from tensorflow.python.client.session import Session # pylint: disable=E0611
-from tensorflow.python.training.saver import get_checkpoint_state # pylint: disable=E0611
+from tensorflow.python.framework import ops  # pylint: disable=E0611
+from tensorflow.python.client.session import Session  # pylint: disable=E0611
+from tensorflow.python.training.saver import get_checkpoint_state  # pylint: disable=E0611
+
 
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 class CustomSessionManager(object):
-  def __init__(self, session_init_fn, graph=None):
-    self._session_init_fn = session_init_fn
-    if graph is None:
-      graph = ops.get_default_graph()
-    self._graph = graph
-
-  def prepare_session(self, master, checkpoint_dir=None, saver=None, config=None, **_):
-    logger = get_logger()
-    logger.info('prepare_session')
-    session = Session(master, graph=self._graph, config=config)
-    self._session_init_fn(session)
-    if saver and checkpoint_dir:
-      ckpt = get_checkpoint_state(checkpoint_dir)
-      if ckpt and ckpt.model_checkpoint_path:
-        logger.info('restoring from %s', ckpt.model_checkpoint_path)
-        saver.restore(session, ckpt.model_checkpoint_path)
-        saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
-      else:
-        logger.info('no valid checkpoint in %s', checkpoint_dir)
-    return session
+    def __init__(self, session_init_fn, graph=None):
+        self._session_init_fn = session_init_fn
+        if graph is None:
+            graph = ops.get_default_graph()
+        self._graph = graph
+
+    def prepare_session(self, master, checkpoint_dir=None, saver=None, config=None, **_):
+        logger = get_logger()
+        logger.info('prepare_session')
+        session = Session(master, graph=self._graph, config=config)
+        self._session_init_fn(session)
+        if saver and checkpoint_dir:
+            ckpt = get_checkpoint_state(checkpoint_dir)
+            if ckpt and ckpt.model_checkpoint_path:  # pylint: disable=no-member
+                logger.info('restoring from %s',
+                            ckpt.model_checkpoint_path)  # pylint: disable=no-member
+                saver.restore(session, ckpt.model_checkpoint_path)  # pylint: disable=no-member
+                saver.recover_last_checkpoints(
+                    ckpt.all_model_checkpoint_paths)  # pylint: disable=no-member
+            else:
+                logger.info('no valid checkpoint in %s', checkpoint_dir)
+        return session
+
 
 class CustomSupervisor(tf.train.Supervisor):
-  def __init__(self, model, graph, init_op=None, ready_op=None, save_model_secs=0, **kwargs):
-    with graph.as_default():
-      init_op = tf.global_variables_initializer()
-
-    def custom_init(session):
-      logging.info('initializing, session: %s', session)
-      session.run(init_op)
-      model.initialize(session)
-      return True
-
-    session_manager = CustomSessionManager(
-      session_init_fn=custom_init,
-      graph=graph
-    )
-    super(CustomSupervisor, self).__init__(
-      session_manager=session_manager,
-      graph=graph,
-      init_op=init_op,
-      ready_op=ready_op,
-      save_model_secs=save_model_secs,
-      **kwargs
-    )
+    def __init__(self, model, graph, init_op=None, ready_op=None, save_model_secs=0, **kwargs):
+        with graph.as_default():
+            init_op = tf.global_variables_initializer()
+
+        def custom_init(session):
+            logging.info('initializing, session: %s', session)
+            session.run(init_op)
+            model.initialize(session)
+            return True
+
+        session_manager = CustomSessionManager(
+            session_init_fn=custom_init,
+            graph=graph
+        )
+        super(CustomSupervisor, self).__init__(
+            session_manager=session_manager,
+            graph=graph,
+            init_op=init_op,
+            ready_op=ready_op,
+            save_model_secs=save_model_secs,
+            **kwargs
+        )
+
 
 class SimpleStepScheduler(object):
-  """
-  Rather than using threads, with this scheduler the client has full control.
-  For example it can be triggered any time intentionally or at the end.
-  """
-  def __init__(self, do_fn, min_interval, min_freq=0, step=0, last_run=None):
-    self.do_fn = do_fn
-    self.min_interval = min_interval
-    self.min_freq = min_freq
-    self.current_step = step
-    self.last_run = last_run
-    self.dirty = False
-
-  def run_now(self, now):
-    self.do_fn()
-    self.last_run = now
-    self.dirty = False
-
-  def should_trigger(self, now):
-    result = (
-      (
-        (self.min_freq > 0) and
-        (self.current_step % self.min_freq == 0)
-      ) or
-      (
-        (self.min_interval > 0) and
-        (self.last_run is None or (now - self.last_run) >= self.min_interval)
-      )
-    )
-    if result:
-      get_logger().info(
-        'should_trigger: current_step:%s, min_freq=%s, now=%s, '
-        'last_run=%s, min_interval=%s, result=%s',
-        self.current_step, self.min_freq, now,
-        self.last_run, self.min_interval, result
-      )
-    return result
-
-  def step(self, now):
-    self.current_step += 1
-    if self.should_trigger(now=now):
-      self.run_now(now=now)
-    else:
-      self.dirty = True
-
-  def flush(self, now):
-    if self.dirty:
-      self.run_now(now)
+    """
+    Rather than using threads, with this scheduler the client has full control.
+    For example it can be triggered any time intentionally or at the end.
+    """
+
+    def __init__(self, do_fn, min_interval, min_freq=0, step=0, last_run=None):
+        self.do_fn = do_fn
+        self.min_interval = min_interval
+        self.min_freq = min_freq
+        self.current_step = step
+        self.last_run = last_run
+        self.dirty = False
+
+    def run_now(self, now):
+        self.do_fn()
+        self.last_run = now
+        self.dirty = False
+
+    def should_trigger(self, now):
+        result = (
+            (
+                (self.min_freq > 0) and
+                (self.current_step % self.min_freq == 0)
+            ) or
+            (
+                (self.min_interval > 0) and
+                (self.last_run is None or (now - self.last_run) >= self.min_interval)
+            )
+        )
+        if result:
+            get_logger().info(
+                'should_trigger: current_step:%s, min_freq=%s, now=%s, '
+                'last_run=%s, min_interval=%s, result=%s',
+                self.current_step, self.min_freq, now,
+                self.last_run, self.min_interval, result
+            )
+        return result
+
+    def step(self, now):
+        self.current_step += 1
+        if self.should_trigger(now=now):
+            self.run_now(now=now)
+        else:
+            self.dirty = True
+
+    def flush(self, now):
+        if self.dirty:
+            self.run_now(now)
+
 
 def loss(loss_value):
-  """Calculates aggregated mean loss."""
-  total_loss = tf.Variable(0.0, False)
-  loss_count = tf.Variable(0, False)
-  total_loss_update = tf.assign_add(total_loss, loss_value)
-  loss_count_update = tf.assign_add(loss_count, 1)
-  loss_op = total_loss / tf.cast(loss_count, tf.float32)
-  return [total_loss_update, loss_count_update], loss_op
+    """Calculates aggregated mean loss."""
+    total_loss = tf.Variable(0.0, False)
+    loss_count = tf.Variable(0, False)
+    total_loss_update = tf.assign_add(total_loss, loss_value)
+    loss_count_update = tf.assign_add(loss_count, 1)
+    loss_op = total_loss / tf.cast(loss_count, tf.float32)
+    return [total_loss_update, loss_count_update], loss_op
+
 
 def get_graph_size():
-  return sum([
-    int(np.product(v.get_shape().as_list()) * v.dtype.size)
-    for v in tf.global_variables()
-  ])
+    return sum([
+        int(np.product(v.get_shape().as_list()) * v.dtype.size)
+        for v in tf.global_variables()
+    ])
diff --git a/sciencebeam_gym/utils/bounding_box.py b/sciencebeam_gym/utils/bounding_box.py
index 5531c4b..4a4c719 100644
--- a/sciencebeam_gym/utils/bounding_box.py
+++ b/sciencebeam_gym/utils/bounding_box.py
@@ -1,103 +1,111 @@
 class BoundingRange(object):
-  def __init__(self, start, length):
-    self.start = start
-    self.length = length
-    if length < 0:
-      raise ValueError('length must not be less than zero, was: ' + str(length))
+    def __init__(self, start, length):
+        self.start = start
+        self.length = length
+        if length < 0:
+            raise ValueError('length must not be less than zero, was: ' + str(length))
 
-  def __str__(self):
-    return '({}, {})'.format(self.start, self.length)
+    def __str__(self):
+        return '({}, {})'.format(self.start, self.length)
 
-  def __len__(self):
-    return self.length
+    def __len__(self):
+        return self.length
 
-  def empty(self):
-    return self.length == 0
+    def empty(self):
+        return self.length == 0
 
-  def intersects(self, other):
-    return (self.start < other.start + other.length) and (other.start < self.start + self.length)
+    def intersects(self, other):
+        return (
+            (self.start < other.start + other.length) and
+            (other.start < self.start + self.length)
+        )
 
-  def include(self, other):
-    if other.empty():
-      return self
-    if self.empty():
-      return other
-    start = min(self.start, other.start)
-    length = max(self.start + self.length, other.start + other.length) - start
-    return BoundingRange(start, length)
+    def include(self, other):
+        if other.empty():
+            return self
+        if self.empty():
+            return other
+        start = min(self.start, other.start)
+        length = max(self.start + self.length, other.start + other.length) - start
+        return BoundingRange(start, length)
+
+    def __add__(self, other):
+        return self.include(other)
 
-  def __add__(self, other):
-    return self.include(other)
 
 class BoundingBox(object):
-  def __init__(self, x, y, width, height):
-    self.x = x
-    self.y = y
-    self.width = width
-    self.height = height
-    if width < 0:
-      raise ValueError('width must not be less than zero, was: ' + str(width))
-    if height < 0:
-      raise ValueError('height must not be less than zero, was: ' + str(height))
-
-  def __str__(self):
-    return '({}, {}, {}, {})'.format(self.x, self.y, self.width, self.height)
-
-  def __repr__(self):
-    return 'BB({}, {}, {}, {})'.format(self.x, self.y, self.width, self.height)
-
-  def empty(self):
-    return self.width == 0 or self.height == 0
-
-  def move_by(self, rx, ry):
-    return BoundingBox(self.x + rx, self.y + ry, self.width, self.height)
-
-  def scale_by(self, rx, ry):
-    return BoundingBox(self.x * rx, self.y * ry, self.width * rx, self.height * ry)
-
-  def include(self, other):
-    if other.empty():
-      return self
-    if self.empty():
-      return other
-    x = min(self.x, other.x)
-    y = min(self.y, other.y)
-    w = max(self.x + self.width, other.x + other.width) - x
-    h = max(self.y + self.height, other.y + other.height) - y
-    return BoundingBox(x, y, w, h)
-
-  def with_margin(self, dx, dy=None):
-    if dy is None:
-      dy = dx
-    return BoundingBox(
-      self.x - dx,
-      self.y - dy,
-      self.width + 2 * dx,
-      self.height + 2 * dy
-    )
-
-  def intersects(self, other):
-    return self.x_range().intersects(other.x_range()) and self.y_range().intersects(other.y_range())
-
-  def __add__(self, bb):
-    return self.include(bb)
-
-  def x_range(self):
-    return BoundingRange(self.x, self.width)
-
-  def y_range(self):
-    return BoundingRange(self.y, self.height)
-
-  def __eq__(self, other):
-    return (
-      other is not None and
-      self.x == other.x and
-      self.y == other.y and
-      self.width == other.width and
-      self.height == other.height
-    )
-
-  def __hash__(self):
-    return hash((self.x, self.y, self.width, self.height))
+    def __init__(self, x, y, width, height):
+        self.x = x
+        self.y = y
+        self.width = width
+        self.height = height
+        if width < 0:
+            raise ValueError('width must not be less than zero, was: ' + str(width))
+        if height < 0:
+            raise ValueError('height must not be less than zero, was: ' + str(height))
+
+    def __str__(self):
+        return '({}, {}, {}, {})'.format(self.x, self.y, self.width, self.height)
+
+    def __repr__(self):
+        return 'BB({}, {}, {}, {})'.format(self.x, self.y, self.width, self.height)
+
+    def empty(self):
+        return self.width == 0 or self.height == 0
+
+    def move_by(self, rx, ry):
+        return BoundingBox(self.x + rx, self.y + ry, self.width, self.height)
+
+    def scale_by(self, rx, ry):
+        return BoundingBox(self.x * rx, self.y * ry, self.width * rx, self.height * ry)
+
+    def include(self, other):
+        if other.empty():
+            return self
+        if self.empty():
+            return other
+        x = min(self.x, other.x)
+        y = min(self.y, other.y)
+        w = max(self.x + self.width, other.x + other.width) - x
+        h = max(self.y + self.height, other.y + other.height) - y
+        return BoundingBox(x, y, w, h)
+
+    def with_margin(self, dx, dy=None):
+        if dy is None:
+            dy = dx
+        return BoundingBox(
+            self.x - dx,
+            self.y - dy,
+            self.width + 2 * dx,
+            self.height + 2 * dy
+        )
+
+    def intersects(self, other):
+        return (
+            self.x_range().intersects(other.x_range()) and
+            self.y_range().intersects(other.y_range())
+        )
+
+    def __add__(self, bb):
+        return self.include(bb)
+
+    def x_range(self):
+        return BoundingRange(self.x, self.width)
+
+    def y_range(self):
+        return BoundingRange(self.y, self.height)
+
+    def __eq__(self, other):
+        return (
+            other is not None and
+            self.x == other.x and
+            self.y == other.y and
+            self.width == other.width and
+            self.height == other.height
+        )
+
+    def __hash__(self):
+        return hash((self.x, self.y, self.width, self.height))
+
 
 BoundingBox.EMPTY = BoundingBox(0, 0, 0, 0)
diff --git a/sciencebeam_gym/utils/bounding_box_test.py b/sciencebeam_gym/utils/bounding_box_test.py
index d7767ad..457ce53 100644
--- a/sciencebeam_gym/utils/bounding_box_test.py
+++ b/sciencebeam_gym/utils/bounding_box_test.py
@@ -1,46 +1,44 @@
 from sciencebeam_gym.utils.bounding_box import (
-  BoundingBox
+    BoundingBox
 )
 
-class TestBoundingBox(object):
-  def test_should_indicate_empty_with_zero_width(self):
-    assert BoundingBox(0, 0, 0, 100).empty()
 
-  def test_should_indicate_empty_with_zero_height(self):
-    assert BoundingBox(0, 0, 100, 0).empty()
+class TestBoundingBox(object):
+    def test_should_indicate_empty_with_zero_width(self):
+        assert BoundingBox(0, 0, 0, 100).empty()
 
-  def test_should_indicate_not_be_empty_with_non_zero_width_and_height(self):
-    assert not BoundingBox(0, 0, 100, 100).empty()
+    def test_should_indicate_empty_with_zero_height(self):
+        assert BoundingBox(0, 0, 100, 0).empty()
 
-  def test_should_equal_same_bounding_boxes(self):
-    assert BoundingBox(11, 12, 101, 102) == BoundingBox(11, 12, 101, 102)
+    def test_should_indicate_not_be_empty_with_non_zero_width_and_height(self):
+        assert not BoundingBox(0, 0, 100, 100).empty()
 
-  def test_should_not_equal_bounding_boxes_with_different_x(self):
-    assert BoundingBox(11, 12, 101, 102) != BoundingBox(99, 12, 101, 102)
+    def test_should_equal_same_bounding_boxes(self):
+        assert BoundingBox(11, 12, 101, 102) == BoundingBox(11, 12, 101, 102)
 
-  def test_should_not_equal_bounding_boxes_with_different_y(self):
-    assert BoundingBox(11, 12, 101, 102) != BoundingBox(11, 99, 101, 102)
+    def test_should_not_equal_bounding_boxes_with_different_x(self):
+        assert BoundingBox(11, 12, 101, 102) != BoundingBox(99, 12, 101, 102)
 
-  def test_should_not_equal_bounding_boxes_with_different_width(self):
-    assert BoundingBox(11, 12, 101, 102) != BoundingBox(11, 12, 999, 102)
+    def test_should_not_equal_bounding_boxes_with_different_y(self):
+        assert BoundingBox(11, 12, 101, 102) != BoundingBox(11, 99, 101, 102)
 
-  def test_should_not_equal_bounding_boxes_with_different_height(self):
-    assert BoundingBox(11, 12, 101, 102) != BoundingBox(11, 12, 101, 999)
+    def test_should_not_equal_bounding_boxes_with_different_width(self):
+        assert BoundingBox(11, 12, 101, 102) != BoundingBox(11, 12, 999, 102)
 
-  def test_should_not_equal_none(self):
-    assert BoundingBox(11, 12, 101, 102) != None
+    def test_should_not_equal_bounding_boxes_with_different_height(self):
+        assert BoundingBox(11, 12, 101, 102) != BoundingBox(11, 12, 101, 999)
 
-  def test_should_not_equal_to_none(self):
-    assert None != BoundingBox(11, 12, 101, 102)
+    def test_should_not_equal_none(self):
+        assert not BoundingBox(11, 12, 101, 102).__eq__(None)
 
-  def test_should_include_another_bounding_box_to_the_bottom_right(self):
-    assert (
-      BoundingBox(10, 20, 50, 100).include(BoundingBox(100, 100, 200, 200)) ==
-      BoundingBox(10, 20, 100 + 200 - 10, 100 + 200 - 20)
-    )
+    def test_should_include_another_bounding_box_to_the_bottom_right(self):
+        assert (
+            BoundingBox(10, 20, 50, 100).include(BoundingBox(100, 100, 200, 200)) ==
+            BoundingBox(10, 20, 100 + 200 - 10, 100 + 200 - 20)
+        )
 
-  def test_should_include_another_bounding_box_to_the_top_left(self):
-    assert (
-      BoundingBox(100, 100, 200, 200).include(BoundingBox(10, 20, 50, 100)) ==
-      BoundingBox(10, 20, 100 + 200 - 10, 100 + 200 - 20)
-    )
+    def test_should_include_another_bounding_box_to_the_top_left(self):
+        assert (
+            BoundingBox(100, 100, 200, 200).include(BoundingBox(10, 20, 50, 100)) ==
+            BoundingBox(10, 20, 100 + 200 - 10, 100 + 200 - 20)
+        )
diff --git a/sciencebeam_gym/utils/pages_zip.py b/sciencebeam_gym/utils/pages_zip.py
index f992400..f2f979b 100644
--- a/sciencebeam_gym/utils/pages_zip.py
+++ b/sciencebeam_gym/utils/pages_zip.py
@@ -4,32 +4,35 @@ from zipfile import ZipFile, ZIP_DEFLATED
 from apache_beam.io.filesystems import FileSystems
 
 from sciencebeam_utils.beam_utils.io import (
-  dirname,
-  mkdirs_if_not_exists
+    dirname,
+    mkdirs_if_not_exists
 )
 
+
 def get_logger():
-  return logging.getLogger(__name__)
+    return logging.getLogger(__name__)
+
 
 def load_pages(filename, page_range=None):
-  with FileSystems.open(filename) as f:
-    with ZipFile(f) as zf:
-      filenames = zf.namelist()
-      if page_range:
-        filenames = filenames[
-          max(0, page_range[0] - 1):
-          page_range[1]
-        ]
-      for filename in filenames:
-        with zf.open(filename) as f:
-          yield f
+    with FileSystems.open(filename) as f:
+        with ZipFile(f) as zf:
+            filenames = zf.namelist()
+            if page_range:
+                filenames = filenames[
+                    max(0, page_range[0] - 1):
+                    page_range[1]
+                ]
+            for member_filename in filenames:
+                with zf.open(member_filename) as f:
+                    yield f
+
 
 def save_pages(output_filename, ext, bytes_by_page):
-  mkdirs_if_not_exists(dirname(output_filename))
-  with FileSystems.create(output_filename) as f:
-    with ZipFile(f, 'w', compression=ZIP_DEFLATED) as zf:
-      for i, data in enumerate(bytes_by_page):
-        page_filename = 'page-%s%s' % (1 + i, ext)
-        get_logger().debug('page_filename: %s', page_filename)
-        zf.writestr(page_filename, data)
-    return output_filename
+    mkdirs_if_not_exists(dirname(output_filename))
+    with FileSystems.create(output_filename) as f:
+        with ZipFile(f, 'w', compression=ZIP_DEFLATED) as zf:
+            for i, data in enumerate(bytes_by_page):
+                page_filename = 'page-%s%s' % (1 + i, ext)
+                get_logger().debug('page_filename: %s', page_filename)
+                zf.writestr(page_filename, data)
+        return output_filename
diff --git a/sciencebeam_gym/utils/pyplot.py b/sciencebeam_gym/utils/pyplot.py
new file mode 100644
index 0000000..622d0a9
--- /dev/null
+++ b/sciencebeam_gym/utils/pyplot.py
@@ -0,0 +1,6 @@
+import matplotlib as mpl
+# this is important to run on the cloud - we won't have python-tk installed
+mpl.use("Agg")
+
+# pylint: disable=unused-import, wrong-import-position
+from matplotlib import pyplot  # flake8: noqa
diff --git a/sciencebeam_gym/utils/tf.py b/sciencebeam_gym/utils/tf.py
index dab93df..1f78b4b 100644
--- a/sciencebeam_gym/utils/tf.py
+++ b/sciencebeam_gym/utils/tf.py
@@ -5,19 +5,22 @@ from tensorflow.python.lib.io import file_io
 from tensorflow.python.framework import errors as tf_errors
 # pylint: enable=E0611
 
-def variable_scoped(name, fn):
-  with tf.variable_scope(name):
-    return fn()
+
+def variable_scoped(name, fn, *args, **kwargs):
+    with tf.variable_scope(name):
+        return fn(*args, **kwargs)
+
 
 def tf_print(x, message=None, **kwargs):
-  return tf.Print(x, [x], message=message, **kwargs)
+    return tf.Print(x, [x], message=message, **kwargs)
+
 
 def FileIO(filename, mode):
-  try:
-    return file_io.FileIO(filename, mode)
-  except tf_errors.InvalidArgumentError:
-    if 'b' in mode:
-      # older versions of TF don't support the 'b' flag as such
-      return file_io.FileIO(filename, mode.replace('b', ''))
-    else:
-      raise
+    try:
+        return file_io.FileIO(filename, mode)
+    except tf_errors.InvalidArgumentError:
+        if 'b' in mode:
+            # older versions of TF don't support the 'b' flag as such
+            return file_io.FileIO(filename, mode.replace('b', ''))
+        else:
+            raise
diff --git a/sciencebeam_gym/utils/tfrecord.py b/sciencebeam_gym/utils/tfrecord.py
index 8cd7606..5a827c3 100644
--- a/sciencebeam_gym/utils/tfrecord.py
+++ b/sciencebeam_gym/utils/tfrecord.py
@@ -4,50 +4,56 @@ from six import iteritems, raise_from, text_type, binary_type
 
 import tensorflow as tf
 
-def encode_value_as_feature(value, name):
-  try:
-    if isinstance(value, text_type):
-      value = value.encode('utf-8')
-    if isinstance(value, binary_type):
-      return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
-    if isinstance(value, int):
-      return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
-    raise TypeError('unsupported type: %s' % type(value))
-  except TypeError as e:
-    raise_from(TypeError('failed to convert %s due to %s' % (name, e)), e)
+
+def encode_value_as_feature(value, name):  # pylint: disable=inconsistent-return-statements
+    try:
+        if isinstance(value, text_type):
+            value = value.encode('utf-8')
+        if isinstance(value, binary_type):
+            return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+        if isinstance(value, int):
+            return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+        raise TypeError('unsupported type: %s' % type(value))
+    except TypeError as e:
+        raise_from(TypeError('failed to convert %s due to %s' % (name, e)), e)
+
 
 def decode_feature_value(feature):
-  if feature.bytes_list.value:
-    return feature.bytes_list.value[0]
-  if feature.int64_list.value:
-    return feature.int64_list.value[0]
-  raise TypeError('unsupported feature: %s' % feature)
+    if feature.bytes_list.value:
+        return feature.bytes_list.value[0]
+    if feature.int64_list.value:
+        return feature.int64_list.value[0]
+    raise TypeError('unsupported feature: %s' % feature)
+
 
 def iter_examples_to_dict_list(examples, keys=None):
-  for example in examples:
-    result = tf.train.Example.FromString(example)
-    yield {
-     key: decode_feature_value(result.features.feature.get(key))
-     for key in result.features.feature.keys()
-     if keys is None or key in keys
-    }
+    for example in examples:
+        result = tf.train.Example.FromString(example)  # pylint: disable=no-member
+        yield {
+            key: decode_feature_value(result.features.feature.get(key))
+            for key in result.features.feature.keys()
+            if keys is None or key in keys
+        }
+
 
 def iter_read_tfrecord_file_as_dict_list(filename, keys=None):
-  options = None
-  if filename.endswith('.gz'):
-    options = tf.python_io.TFRecordOptions(
-      compression_type=tf.python_io.TFRecordCompressionType.GZIP
-    )
-  examples = tf.python_io.tf_record_iterator(filename, options=options)
-  return iter_examples_to_dict_list(examples, keys=keys)
+    options = None
+    if filename.endswith('.gz'):
+        options = tf.python_io.TFRecordOptions(
+            compression_type=tf.python_io.TFRecordCompressionType.GZIP
+        )
+    examples = tf.python_io.tf_record_iterator(filename, options=options)
+    return iter_examples_to_dict_list(examples, keys=keys)
+
 
 def dict_to_example(props):
-  return tf.train.Example(features=tf.train.Features(feature={
-    k: encode_value_as_feature(v, name=k)
-    for k, v in iteritems(props)
-  }))
+    return tf.train.Example(features=tf.train.Features(feature={
+        k: encode_value_as_feature(v, name=k)
+        for k, v in iteritems(props)
+    }))
+
 
 def write_examples_to_tfrecord(tfrecord_filename, examples):
-  with tf.python_io.TFRecordWriter(tfrecord_filename) as writer:
-    for example in examples:
-      writer.write(example.SerializeToString())
+    with tf.python_io.TFRecordWriter(tfrecord_filename) as writer:
+        for example in examples:
+            writer.write(example.SerializeToString())
diff --git a/sciencebeam_gym/utils/tfrecord_test.py b/sciencebeam_gym/utils/tfrecord_test.py
index 3ed63fc..c1689ad 100644
--- a/sciencebeam_gym/utils/tfrecord_test.py
+++ b/sciencebeam_gym/utils/tfrecord_test.py
@@ -1,30 +1,33 @@
 from six import text_type
 
 from sciencebeam_gym.utils.tfrecord import (
-  iter_examples_to_dict_list,
-  dict_to_example
+    iter_examples_to_dict_list,
+    dict_to_example
 )
 
+
 def dict_to_example_and_reverse(props):
-  return list(iter_examples_to_dict_list([
-    dict_to_example(props).SerializeToString()
-  ]))[0]
+    return list(iter_examples_to_dict_list([
+        dict_to_example(props).SerializeToString()
+    ]))[0]
+
 
 def assert_dict_to_example_and_reverse(props):
-  assert dict_to_example_and_reverse(props) == props
+    assert dict_to_example_and_reverse(props) == props
+
 
 class TestDictToExampleAndIterExamplesToDictList(object):
-  def test_should_handle_bytes(self):
-    assert_dict_to_example_and_reverse({
-      b'a': b'data'
-    })
-
-  def test_should_handle_unicode(self):
-    assert_dict_to_example_and_reverse({
-      b'a': text_type('data')
-    })
-
-  def test_should_handle_int(self):
-    assert_dict_to_example_and_reverse({
-      b'a': 1
-    })
+    def test_should_handle_bytes(self):
+        assert_dict_to_example_and_reverse({
+            b'a': b'data'
+        })
+
+    def test_should_handle_unicode(self):
+        assert_dict_to_example_and_reverse({
+            b'a': text_type('data')
+        })
+
+    def test_should_handle_int(self):
+        assert_dict_to_example_and_reverse({
+            b'a': 1
+        })
diff --git a/setup.py b/setup.py
index 13a8d24..b32f1b2 100644
--- a/setup.py
+++ b/setup.py
@@ -4,91 +4,95 @@ import os
 import subprocess
 import shlex
 
+from distutils.command.build import build  # pylint: disable=import-error, no-name-in-module
+
 from setuptools import (
-  find_packages,
-  setup,
-  Command,
-  Extension
+    find_packages,
+    setup,
+    Command,
+    Extension
 )
 
-from distutils.command.build import build
-
 import numpy as np
 
 
 CUSTOM_COMMANDS = [
-  shlex.split(command_line) for command_line in [
-    'apt-get update',
-    'apt-get --assume-yes install libxml2',
-    'apt-get --assume-yes install poppler-utils'
-  ]
+    shlex.split(command_line) for command_line in [
+        'apt-get update',
+        'apt-get --assume-yes install libxml2',
+        'apt-get --assume-yes install poppler-utils'
+    ]
 ]
 
 with open(os.path.join('requirements.txt'), 'r') as f:
-  REQUIRED_PACKAGES = f.readlines()
+    REQUIRED_PACKAGES = f.readlines()
 
 packages = find_packages()
 
 # This class handles the pip install mechanism.
+
+
 class CustomBuild(build):
-  """A build command class that will be invoked during package install.
-  The package built using the current setup.py will be staged and later
-  installed in the worker using `pip install package'. This class will be
-  instantiated during install for this specific scenario and will trigger
-  running the custom commands specified.
-  """
-  sub_commands = build.sub_commands + [('CustomCommands', None)]
+    """A build command class that will be invoked during package install.
+    The package built using the current setup.py will be staged and later
+    installed in the worker using 'pip install package'. This class will be
+    instantiated during install for this specific scenario and will trigger
+    running the custom commands specified.
+    """
+    sub_commands = build.sub_commands + [('CustomCommands', None)]
+
 
 class CustomCommands(Command):
-  """A setuptools Command class able to run arbitrary commands."""
-
-  def initialize_options(self):
-    pass
-
-  def finalize_options(self):
-    pass
-
-  def _run_custom_command(self, command_list):
-    print('Running command: %s' % command_list)
-    p = subprocess.Popen(
-      command_list,
-      stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
-    )
-    # Can use communicate(input='y\n'.encode()) if the command run requires
-    # some confirmation.
-    stdout_data, _ = p.communicate()
-    print('Command output: %s' % stdout_data)
-    if p.returncode != 0:
-      raise RuntimeError(
-        'Command %s failed: exit code: %s (output: %s)' %
-        (command_list, p.returncode, stdout_data)
-      )
-
-  def run(self):
-    for command in CUSTOM_COMMANDS:
-      self._run_custom_command(command)
+    """A setuptools Command class able to run arbitrary commands."""
+
+    def initialize_options(self):
+        pass
+
+    def finalize_options(self):
+        pass
+
+    def _run_custom_command(self, command_list):
+        print('Running command: %s' % command_list)
+        p = subprocess.Popen(
+            command_list,
+            stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
+        )
+        # Can use communicate(input='y\n'.encode()) if the command run requires
+        # some confirmation.
+        stdout_data, _ = p.communicate()
+        print('Command output: %s' % stdout_data)
+        if p.returncode != 0:
+            raise RuntimeError(
+                'Command %s failed: exit code: %s (output: %s)' %
+                (command_list, p.returncode, stdout_data)
+            )
+
+    def run(self):
+        for command in CUSTOM_COMMANDS:
+            self._run_custom_command(command)
+
 
 setup(
-  name='sciencebeam_gym',
-  version='0.0.1',
-  install_requires=REQUIRED_PACKAGES,
-  packages=packages,
-  include_package_data=True,
-  description='ScienceBeam Gym',
-  setup_requires=[
-    # Setuptools 18.0 properly handles Cython extensions.
-    'setuptools>=18.0',
-    'cython',
-  ],
-  ext_modules=[
-    Extension(
-      'sciencebeam_gym.alignment.align_fast_utils',
-      sources=['sciencebeam_gym/alignment/align_fast_utils.pyx'],
-    ),
-  ],
-  include_dirs=[np.get_include()],
-  cmdclass={
-    'build': CustomBuild,
-    'CustomCommands': CustomCommands
-  }
+    name='sciencebeam_gym',
+    version='0.0.1',
+    install_requires=REQUIRED_PACKAGES,
+    packages=packages,
+    include_package_data=True,
+    description='ScienceBeam Gym',
+    setup_requires=[
+        # Setuptools 18.0 properly handles Cython extensions.
+        'setuptools>=18.0',
+        'cython',
+    ],
+    ext_modules=[
+        Extension(
+            'sciencebeam_gym.alignment.align_fast_utils',
+            sources=['sciencebeam_gym/alignment/align_fast_utils.pyx'],
+        ),
+    ],
+    include_dirs=[np.get_include()],
+    cmdclass={
+        'build': CustomBuild,
+        'CustomCommands': CustomCommands
+    }
 )
-- 
GitLab