From 91e1c0d0b16ff2dfc6f9dd8e6cee2e975639b53a Mon Sep 17 00:00:00 2001 From: Daniel Ecer <de-code@users.noreply.github.com> Date: Fri, 2 Nov 2018 16:45:33 +0000 Subject: [PATCH] pylint and flake8 checking (#39) * added pylint check * added pylintrc to docker image * reduced accessive apache beam debug logging * configured pylint, addressed linting * enabled flake8 checks * downgrade pycodestyle to 2.3.1 due to error * switch to 4 spaces indent * autopep8 * more flake8 * added new line to .flake8 --- .flake8 | 4 + .pylintrc | 34 +- Dockerfile | 3 + project_tests.sh | 8 + requirements.dev.txt | 2 + sciencebeam_gym/conftest.py | 6 +- .../convert/conversion_pipeline.py | 901 ++++----- .../convert/conversion_pipeline_test.py | 856 ++++---- .../convert/cv_conversion_utils.py | 67 +- .../convert/cv_conversion_utils_test.py | 74 +- .../convert/grobid/grobid_service.py | 134 +- .../convert/grobid/grobid_service_test.py | 90 +- .../convert/grobid/grobid_service_wrapper.py | 208 +- .../convert/grobid/grobid_xml_enhancer.py | 157 +- .../grobid/grobid_xml_enhancer_test.py | 319 +-- sciencebeam_gym/inference_model/__init__.py | 161 +- .../inference_model/__init___test.py | 106 +- .../annotate_using_predictions.py | 289 +-- .../annotate_using_predictions_test.py | 321 +-- .../extract_from_annotated_document.py | 161 +- .../extract_from_annotated_document_test.py | 423 ++-- .../inference_model/extract_to_xml.py | 308 +-- .../inference_model/extract_to_xml_test.py | 285 +-- sciencebeam_gym/model_utils/channels.py | 53 +- .../text/crf/annotate_using_predictions.py | 202 +- .../crf/annotate_using_predictions_test.py | 162 +- .../models/text/crf/crfsuite_model.py | 23 +- .../models/text/crf/crfsuite_model_test.py | 178 +- .../text/crf/crfsuite_training_pipeline.py | 326 ++-- .../crf/crfsuite_training_pipeline_test.py | 476 +++-- .../models/text/feature_extractor.py | 192 +- .../models/text/feature_extractor_test.py | 573 +++--- sciencebeam_gym/pdf/__init__.py | 4 +- 
sciencebeam_gym/pdf/pdf_to_lxml_wrapper.py | 243 +-- sciencebeam_gym/pdf/pdf_to_png.py | 108 +- sciencebeam_gym/pdf/pdf_to_png_test.py | 85 +- .../annotation/annotation_evaluation.py | 81 +- .../annotation/annotation_evaluation_test.py | 89 +- .../preprocess/annotation/annotator.py | 46 +- .../preprocess/annotation/annotator_test.py | 47 +- .../preprocess/annotation/find_line_number.py | 91 +- .../annotation/find_line_numbers_test.py | 187 +- .../preprocess/annotation/fuzzy_match.py | 502 ++--- .../preprocess/annotation/fuzzy_match_test.py | 331 ++-- .../annotation/matching_annotator.py | 1166 +++++------ .../annotation/matching_annotator_test.py | 1735 +++++++++-------- .../annotation/target_annotation.py | 695 +++---- .../annotation/target_annotation_test.py | 1185 +++++------ .../preprocess/blockify_annotations.py | 446 +++-- .../preprocess/blockify_annotations_test.py | 488 ++--- sciencebeam_gym/preprocess/color_map.py | 55 +- sciencebeam_gym/preprocess/color_map_test.py | 21 +- sciencebeam_gym/preprocess/lxml_to_svg.py | 388 ++-- .../preprocess/lxml_to_svg_test.py | 339 ++-- .../preprocess/preprocessing_pipeline.py | 923 ++++----- .../preprocess/preprocessing_pipeline_test.py | 714 +++---- .../preprocess/preprocessing_transforms.py | 50 +- .../preprocessing_transforms_test.py | 71 +- .../preprocess/preprocessing_utils.py | 275 ++- .../preprocess/preprocessing_utils_test.py | 92 +- .../preprocess/visualize_svg_annotation.py | 118 +- .../visualize_svg_annotation_test.py | 138 +- .../structured_document/__init__.py | 380 ++-- .../structured_document/__init___test.py | 206 +- sciencebeam_gym/structured_document/lxml.py | 77 +- .../structured_document/lxml_test.py | 275 +-- .../structured_document_loader.py | 71 +- .../structured_document_loader_test.py | 175 +- .../structured_document_saver.py | 31 +- .../structured_document_saver_test.py | 71 +- sciencebeam_gym/structured_document/svg.py | 142 +- .../structured_document/svg_test.py | 337 ++-- 
.../tools/calculate_class_weights.py | 412 ++-- .../tools/calculate_class_weights_test.py | 611 +++--- sciencebeam_gym/tools/colorize_image.py | 142 +- sciencebeam_gym/tools/inspect_tfrecords.py | 164 +- sciencebeam_gym/tools/resize_image.py | 81 +- sciencebeam_gym/trainer/checkpoint.py | 54 +- sciencebeam_gym/trainer/data/examples.py | 180 +- sciencebeam_gym/trainer/data/examples_test.py | 135 +- sciencebeam_gym/trainer/evaluator.py | 847 ++++---- sciencebeam_gym/trainer/evaluator_test.py | 184 +- .../trainer/models/pix2pix/evaluate.py | 182 +- .../trainer/models/pix2pix/evaluate_test.py | 75 +- .../trainer/models/pix2pix/loss.py | 61 +- .../trainer/models/pix2pix/loss_test.py | 233 +-- .../trainer/models/pix2pix/pix2pix_core.py | 996 +++++----- .../models/pix2pix/pix2pix_core_test.py | 342 ++-- .../trainer/models/pix2pix/pix2pix_model.py | 1165 +++++------ .../models/pix2pix/pix2pix_model_test.py | 716 ++++--- .../trainer/models/pix2pix/tf_utils.py | 79 +- .../trainer/models/pix2pix/tf_utils_test.py | 132 +- sciencebeam_gym/trainer/predict.py | 64 +- sciencebeam_gym/trainer/saver.py | 14 +- sciencebeam_gym/trainer/task.py | 1361 ++++++------- sciencebeam_gym/trainer/util.py | 217 ++- sciencebeam_gym/utils/bounding_box.py | 192 +- sciencebeam_gym/utils/bounding_box_test.py | 62 +- sciencebeam_gym/utils/pages_zip.py | 47 +- sciencebeam_gym/utils/pyplot.py | 6 + sciencebeam_gym/utils/tf.py | 27 +- sciencebeam_gym/utils/tfrecord.py | 80 +- sciencebeam_gym/utils/tfrecord_test.py | 43 +- setup.py | 140 +- 104 files changed, 14774 insertions(+), 13850 deletions(-) create mode 100644 .flake8 create mode 100644 sciencebeam_gym/utils/pyplot.py diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..c184d2f --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +max-line-length = 100 +ignore = + E731 # do not assign a lambda expression, use a def diff --git a/.pylintrc b/.pylintrc index 4242187..1d99acd 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,11 +1,29 @@ 
-[FORMAT] -indent-string=' ' -indent-after-paren=2 - [MESSAGES CONTROL] -# Disable C0111 (docstrings) - use useful method names instead -# Disable C0103 (invalid variable name) - in very short blocks it is okay to use generic names -disable=C0111,C0103,W0108,E1101,C0330,E1129,C1801 +disable= + missing-docstring, + too-few-public-methods, + no-self-use, + invalid-name, + too-many-arguments, + duplicate-code, + too-many-locals, + too-many-instance-attributes, + too-many-statements, + too-many-branches, + too-many-public-methods, + no-else-return, + unnecessary-lambda, + len-as-condition, + fixme, + bad-continuation [TYPECHECK] -ignored-modules=sciencelab.alignment.AlignmentMatrix +ignored-classes= + lxml.builder.ElementMaker, + lxml.builder.E +extension-pkg-whitelist= + lxml +ignored-modules= + sciencelab.alignment.AlignmentMatrix, + tensorflow, + apache_beam diff --git a/Dockerfile b/Dockerfile index 181da08..4528f26 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,3 +16,6 @@ RUN venv/bin/pip install -r requirements.txt COPY sciencebeam_gym ${PROJECT_HOME}/sciencebeam_gym COPY *.conf *.sh *.in *.txt *.py ${PROJECT_HOME}/ + +# tests +COPY .pylintrc .flake8 ${PROJECT_HOME}/ diff --git a/project_tests.sh b/project_tests.sh index 9584518..27f2114 100755 --- a/project_tests.sh +++ b/project_tests.sh @@ -4,3 +4,11 @@ set -e pip install -r requirements.dev.txt pytest sciencebeam_gym + +echo "running pylint" +pylint sciencebeam_gym setup.py + +echo "running flake8" +flake8 sciencebeam_gym setup.py + +echo "done" diff --git a/requirements.dev.txt b/requirements.dev.txt index 220de02..88ffb2a 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,5 +1,7 @@ +flake8==3.5.0 Cython==0.28.1 nose==1.3.7 +pycodestyle==2.3.1 pylint==1.8.3 pytest==3.4.2 pytest-watch==4.1.0 diff --git a/sciencebeam_gym/conftest.py b/sciencebeam_gym/conftest.py index b12470c..8102c5b 100644 --- a/sciencebeam_gym/conftest.py +++ b/sciencebeam_gym/conftest.py @@ -2,7 +2,9 @@ import logging 
import pytest + @pytest.fixture(scope='session', autouse=True) def setup_logging(): - logging.root.handlers = [] - logging.basicConfig(level='DEBUG') + logging.root.handlers = [] + logging.basicConfig(level='WARNING') + logging.getLogger('sciencebeam_gym').setLevel('DEBUG') diff --git a/sciencebeam_gym/convert/conversion_pipeline.py b/sciencebeam_gym/convert/conversion_pipeline.py index 3a47ef3..fc5ea0a 100644 --- a/sciencebeam_gym/convert/conversion_pipeline.py +++ b/sciencebeam_gym/convert/conversion_pipeline.py @@ -13,557 +13,584 @@ from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions from lxml import etree from sciencebeam_utils.beam_utils.utils import ( - TransformAndCount, - TransformAndLog, - MapOrLog, - PreventFusion + TransformAndCount, + TransformAndLog, + MapOrLog, + PreventFusion ) from sciencebeam_utils.beam_utils.files import ( - ReadFileList, - FindFiles + ReadFileList, + FindFiles ) from sciencebeam_utils.beam_utils.io import ( - read_all_from_path, - save_file_content + read_all_from_path, + save_file_content ) from sciencebeam_utils.beam_utils.main import ( - add_cloud_args, - process_cloud_args, - process_sciencebeam_gym_dep_args + add_cloud_args, + process_cloud_args, + process_sciencebeam_gym_dep_args ) from sciencebeam_utils.utils.collection import ( - extend_dict, - remove_keys_from_dict + extend_dict, + remove_keys_from_dict ) from sciencebeam_utils.utils.file_path import ( - join_if_relative_path, - get_output_file + join_if_relative_path, + get_output_file ) from sciencebeam_gym.structured_document.structured_document_loader import ( - load_structured_document + load_structured_document ) from sciencebeam_gym.structured_document.lxml import ( - LxmlStructuredDocument + LxmlStructuredDocument ) from sciencebeam_gym.preprocess.preprocessing_utils import ( - convert_pdf_bytes_to_lxml, - parse_page_range, - save_pages, - pdf_bytes_to_png_pages + convert_pdf_bytes_to_lxml, + parse_page_range, + save_pages, + 
pdf_bytes_to_png_pages ) from sciencebeam_gym.inference_model.extract_to_xml import ( - extract_structured_document_to_xml + extract_structured_document_to_xml ) from sciencebeam_gym.models.text.crf.annotate_using_predictions import ( - predict_and_annotate_structured_document, - CRF_TAG_SCOPE + predict_and_annotate_structured_document, + CRF_TAG_SCOPE ) from sciencebeam_gym.inference_model.annotate_using_predictions import ( - annotate_structured_document_using_predicted_images, - AnnotatedImage, - CV_TAG_SCOPE + annotate_structured_document_using_predicted_images, + AnnotatedImage, + CV_TAG_SCOPE ) from .grobid.grobid_xml_enhancer import ( - GrobidXmlEnhancer + GrobidXmlEnhancer ) from .cv_conversion_utils import ( - InferenceModelWrapper, - image_data_to_png + InferenceModelWrapper, + image_data_to_png ) from .grobid.grobid_service import ( - grobid_service, - GrobidApiPaths + grobid_service, + GrobidApiPaths ) def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + class MetricCounters(object): - FILES = 'files' - READ_LXML_ERROR = 'read_lxml_error_count' - CONVERT_PDF_TO_LXML_ERROR = 'ConvertPdfToLxml_error_count' - CONVERT_PDF_TO_PNG_ERROR = 'ConvertPdfToPng_error_count' - CONVERT_LXML_TO_SVG_ANNOT_ERROR = 'ConvertPdfToSvgAnnot_error_count' - CV_PREDICTION_ERROR = 'ComputerVisionPrediction_error_count' - ANNOTATE_USING_PREDICTION_ERROR = 'AnnotateLxmlUsingPrediction_error_count' - EXTRACT_TO_XML_ERROR = 'ExtractToXml_error_count' - GROBID_ERROR = 'Grobid_error_count' + FILES = 'files' + READ_LXML_ERROR = 'read_lxml_error_count' + CONVERT_PDF_TO_LXML_ERROR = 'ConvertPdfToLxml_error_count' + CONVERT_PDF_TO_PNG_ERROR = 'ConvertPdfToPng_error_count' + CONVERT_LXML_TO_SVG_ANNOT_ERROR = 'ConvertPdfToSvgAnnot_error_count' + CV_PREDICTION_ERROR = 'ComputerVisionPrediction_error_count' + ANNOTATE_USING_PREDICTION_ERROR = 'AnnotateLxmlUsingPrediction_error_count' + EXTRACT_TO_XML_ERROR = 'ExtractToXml_error_count' + GROBID_ERROR = 
'Grobid_error_count' + class OutputExt(object): - CRF_ANNOT_LXML = '.crf.lxml.gz' - CRF_CV_ANNOT_LXML = '.crf-cv.lxml.gz' - CV_ANNOT_LXML = '.cv.lxml.gz' - CV_PNG = '.cv-png.zip' + CRF_ANNOT_LXML = '.crf.lxml.gz' + CRF_CV_ANNOT_LXML = '.crf-cv.lxml.gz' + CV_ANNOT_LXML = '.cv.lxml.gz' + CV_PNG = '.cv-png.zip' + class DataProps(object): - SOURCE_FILENAME = 'source_filename' - PDF_CONTENT = 'pdf_content' - STRUCTURED_DOCUMENT = 'structured_document' - PDF_PNG_PAGES = 'pdf_png_pages' - CV_PREDICTION_PNG_PAGES = 'cv_prediction_png_pages' - COLOR_MAP = 'color_map' - EXTRACTED_XML = 'extracted_xml' + SOURCE_FILENAME = 'source_filename' + PDF_CONTENT = 'pdf_content' + STRUCTURED_DOCUMENT = 'structured_document' + PDF_PNG_PAGES = 'pdf_png_pages' + CV_PREDICTION_PNG_PAGES = 'cv_prediction_png_pages' + COLOR_MAP = 'color_map' + EXTRACTED_XML = 'extracted_xml' + def convert_pdf_bytes_to_structured_document(pdf_content, path=None, page_range=None): - return LxmlStructuredDocument(etree.parse(BytesIO( - convert_pdf_bytes_to_lxml(pdf_content, path=path, page_range=page_range) - ))) + return LxmlStructuredDocument(etree.parse(BytesIO( + convert_pdf_bytes_to_lxml(pdf_content, path=path, page_range=page_range) + ))) + def annotate_structured_document_using_predicted_image_data( - structured_document, prediction_images, color_map, tag_scope=None): + structured_document, prediction_images, color_map, tag_scope=None): + + return annotate_structured_document_using_predicted_images( + structured_document, ( + AnnotatedImage(prediction_image, color_map) + for prediction_image in prediction_images + ), tag_scope=tag_scope + ) - return annotate_structured_document_using_predicted_images( - structured_document, ( - AnnotatedImage(prediction_image, color_map) - for prediction_image in prediction_images - ), tag_scope=tag_scope - ) def extract_annotated_structured_document_to_xml(structured_document, tag_scope=None): - xml_root = extract_structured_document_to_xml(structured_document, 
tag_scope=tag_scope) - return etree.tostring(xml_root, pretty_print=True) + xml_root = extract_structured_document_to_xml(structured_document, tag_scope=tag_scope) + return etree.tostring(xml_root, pretty_print=True) + def load_crf_model(path): - with FileSystems.open(path) as crf_model_f: - return pickle.load(crf_model_f) + with FileSystems.open(path) as crf_model_f: + return pickle.load(crf_model_f) + def save_structured_document(filename, structured_document): - # only support saving lxml for now - assert isinstance(structured_document, LxmlStructuredDocument) - save_file_content(filename, etree.tostring(structured_document.root, pretty_print=True)) - return filename + # only support saving lxml for now + assert isinstance(structured_document, LxmlStructuredDocument) + save_file_content(filename, etree.tostring(structured_document.root, pretty_print=True)) + return filename + def get_annot_lxml_ext(crf_enabled, cv_enabled): - if crf_enabled and cv_enabled: - return OutputExt.CRF_CV_ANNOT_LXML - if crf_enabled: - return OutputExt.CRF_ANNOT_LXML - if cv_enabled: - return OutputExt.CV_ANNOT_LXML - raise AssertionError('at least one of crf or cv need to be enabled') + if crf_enabled and cv_enabled: + return OutputExt.CRF_CV_ANNOT_LXML + if crf_enabled: + return OutputExt.CRF_ANNOT_LXML + if cv_enabled: + return OutputExt.CV_ANNOT_LXML + raise AssertionError('at least one of crf or cv need to be enabled') + def PdfUrlSource(opt): - if opt.pdf_file_list: - return ReadFileList(opt.pdf_file_list, column=opt.pdf_file_column, limit=opt.limit) - else: - return FindFiles(join_if_relative_path(opt.base_data_path, opt.pdf_path)) + if opt.pdf_file_list: + return ReadFileList(opt.pdf_file_list, column=opt.pdf_file_column, limit=opt.limit) + else: + return FindFiles(join_if_relative_path(opt.base_data_path, opt.pdf_path)) + def ReadPdfContent(): - return "ReadPdfContent" >> TransformAndCount( - beam.Map(lambda pdf_url: { - DataProps.SOURCE_FILENAME: pdf_url, - 
DataProps.PDF_CONTENT: read_all_from_path(pdf_url) - }), - MetricCounters.FILES - ) + return "ReadPdfContent" >> TransformAndCount( + beam.Map(lambda pdf_url: { + DataProps.SOURCE_FILENAME: pdf_url, + DataProps.PDF_CONTENT: read_all_from_path(pdf_url) + }), + MetricCounters.FILES + ) + def add_read_pdfs_to_annotated_lxml_pipeline_steps(p, opt, get_pipeline_output_file): - page_range = opt.pages - - cv_enabled = opt.cv_model_export_dir - - extract_tag_scope = None - - pdf_urls = p | PdfUrlSource(opt) - - lxml_content = ( - pdf_urls | - PreventFusion() | - - ReadPdfContent() | - - "ConvertPdfToLxml" >> MapOrLog(lambda v: extend_dict(v, { - DataProps.STRUCTURED_DOCUMENT: convert_pdf_bytes_to_structured_document( - v[DataProps.PDF_CONTENT], path=v[DataProps.SOURCE_FILENAME], - page_range=page_range - ) - }), log_fn=lambda e, v: ( - get_logger().warning( - 'caught exception (ignoring item): %s, pdf: %s', - e, v[DataProps.SOURCE_FILENAME], exc_info=e - ) - ), error_count=MetricCounters.CONVERT_PDF_TO_LXML_ERROR) - ) - - if cv_enabled: - image_size = ( - (opt.image_width, opt.image_height) - if opt.image_width and opt.image_height - else None + page_range = opt.pages + + cv_enabled = opt.cv_model_export_dir + + extract_tag_scope = None + + pdf_urls = p | PdfUrlSource(opt) + + lxml_content = ( + pdf_urls | + PreventFusion() | + + ReadPdfContent() | + + "ConvertPdfToLxml" >> MapOrLog(lambda v: extend_dict(v, { + DataProps.STRUCTURED_DOCUMENT: convert_pdf_bytes_to_structured_document( + v[DataProps.PDF_CONTENT], path=v[DataProps.SOURCE_FILENAME], + page_range=page_range + ) + }), log_fn=lambda e, v: ( + get_logger().warning( + 'caught exception (ignoring item): %s, pdf: %s', + e, v[DataProps.SOURCE_FILENAME], exc_info=e + ) + ), error_count=MetricCounters.CONVERT_PDF_TO_LXML_ERROR) ) - inference_model_wrapper = InferenceModelWrapper(opt.cv_model_export_dir) - - cv_predictions = ( - lxml_content | - - "ConvertPdfToPng" >> MapOrLog(lambda v: remove_keys_from_dict( - 
extend_dict(v, { - DataProps.PDF_PNG_PAGES: list(pdf_bytes_to_png_pages( - v[DataProps.PDF_CONTENT], - dpi=90, # not used if the image is scaled - image_size=image_size, - page_range=page_range - )) - }), - keys_to_remove={DataProps.PDF_CONTENT} - ), error_count=MetricCounters.CONVERT_PDF_TO_PNG_ERROR) | - "ComputerVisionPrediction" >> MapOrLog(lambda v: remove_keys_from_dict( - extend_dict(v, { - DataProps.CV_PREDICTION_PNG_PAGES: inference_model_wrapper(v[DataProps.PDF_PNG_PAGES]), - DataProps.COLOR_MAP: inference_model_wrapper.get_color_map() - }), - keys_to_remove={DataProps.PDF_PNG_PAGES} - ), error_count=MetricCounters.CV_PREDICTION_ERROR) + if cv_enabled: + image_size = ( + (opt.image_width, opt.image_height) + if opt.image_width and opt.image_height + else None + ) + inference_model_wrapper = InferenceModelWrapper(opt.cv_model_export_dir) + + cv_predictions = ( + lxml_content | + + "ConvertPdfToPng" >> MapOrLog(lambda v: remove_keys_from_dict( + extend_dict(v, { + DataProps.PDF_PNG_PAGES: list(pdf_bytes_to_png_pages( + v[DataProps.PDF_CONTENT], + dpi=90, # not used if the image is scaled + image_size=image_size, + page_range=page_range + )) + }), + keys_to_remove={DataProps.PDF_CONTENT} + ), error_count=MetricCounters.CONVERT_PDF_TO_PNG_ERROR) | + + "ComputerVisionPrediction" >> MapOrLog(lambda v: remove_keys_from_dict( + extend_dict(v, { + DataProps.CV_PREDICTION_PNG_PAGES: inference_model_wrapper( + v[DataProps.PDF_PNG_PAGES] + ), + DataProps.COLOR_MAP: inference_model_wrapper.get_color_map() + }), + keys_to_remove={DataProps.PDF_PNG_PAGES} + ), error_count=MetricCounters.CV_PREDICTION_ERROR) + ) + + if opt.save_cv_output: + _ = ( + cv_predictions | + "SaveComputerVisionOutput" >> TransformAndLog( + beam.Map(lambda v: save_pages( + get_pipeline_output_file( + v[DataProps.SOURCE_FILENAME], + OutputExt.CV_PNG + ), + '.png', + [image_data_to_png(image_data) + for image_data in v[DataProps.CV_PREDICTION_PNG_PAGES]] + )), + log_fn=lambda x: 
get_logger().info('saved cv output: %s', x) + ) + ) + + cv_annotated_lxml = ( + cv_predictions | + "AnnotateLxmlUsingCvPrediction" >> MapOrLog(lambda v: remove_keys_from_dict( + extend_dict(v, { + DataProps.STRUCTURED_DOCUMENT: ( + annotate_structured_document_using_predicted_image_data( + v[DataProps.STRUCTURED_DOCUMENT], + v[DataProps.CV_PREDICTION_PNG_PAGES], + v[DataProps.COLOR_MAP], + tag_scope=CV_TAG_SCOPE + ) + ) + }), + keys_to_remove={DataProps.PDF_PNG_PAGES} + ), error_count=MetricCounters.ANNOTATE_USING_PREDICTION_ERROR) + ) + + lxml_content = cv_annotated_lxml + extract_tag_scope = CV_TAG_SCOPE + + if opt.crf_model: + model = load_crf_model(opt.crf_model) + crf_annotated_lxml = ( + lxml_content | + "AnnotateLxmlUsingCrfPrediction" >> MapOrLog(lambda v: extend_dict(v, { + DataProps.STRUCTURED_DOCUMENT: predict_and_annotate_structured_document( + v[DataProps.STRUCTURED_DOCUMENT], model + ) + }), error_count=MetricCounters.ANNOTATE_USING_PREDICTION_ERROR) + ) + + lxml_content = crf_annotated_lxml + extract_tag_scope = CRF_TAG_SCOPE + + if opt.save_annot_lxml: + _ = ( # flake8: noqa + lxml_content | + "SaveAnnotLxml" >> TransformAndLog( + beam.Map(lambda v: save_structured_document( + get_pipeline_output_file( + v[DataProps.SOURCE_FILENAME], + get_annot_lxml_ext( + crf_enabled=opt.crf_model, + cv_enabled=cv_enabled + ) + ), + v[DataProps.STRUCTURED_DOCUMENT] + )), + log_fn=lambda x: get_logger().info('saved annoted lxml to: %s', x) + ) + ) + return lxml_content, extract_tag_scope + + +def add_read_pdfs_to_grobid_xml_pipeline_steps(p, opt): + grobid_transformer = grobid_service( + opt.grobid_url, opt.grobid_action, start_service=opt.start_grobid_service + ) + + return ( + p | + PdfUrlSource(opt) | + PreventFusion() | + ReadPdfContent() | + "Grobid" >> MapOrLog(lambda v: extend_dict(v, { + DataProps.EXTRACTED_XML: grobid_transformer( + (v[DataProps.SOURCE_FILENAME], v[DataProps.PDF_CONTENT]) + )[1] + }), error_count=MetricCounters.GROBID_ERROR) ) - if 
opt.save_cv_output: - _ = ( - cv_predictions | - "SaveComputerVisionOutput" >> TransformAndLog( - beam.Map(lambda v: save_pages( - get_pipeline_output_file( - v[DataProps.SOURCE_FILENAME], - OutputExt.CV_PNG - ), - '.png', - [image_data_to_png(image_data) for image_data in v[DataProps.CV_PREDICTION_PNG_PAGES]] - )), - log_fn=lambda x: get_logger().info('saved cv output: %s', x) + +def add_read_source_to_extracted_xml_pipeline_steps(p, opt, get_pipeline_output_file): + if opt.lxml_file_list: + lxml_urls = p | ReadFileList( + opt.lxml_file_list, column=opt.lxml_file_column, limit=opt.limit) + + annotated_lxml = ( + lxml_urls | + PreventFusion() | + + "ReadLxmlContent" >> TransformAndCount( + MapOrLog(lambda url: { + DataProps.SOURCE_FILENAME: url, + DataProps.STRUCTURED_DOCUMENT: load_structured_document(url) + }, error_count=MetricCounters.READ_LXML_ERROR), + MetricCounters.FILES + ) ) - ) - - cv_annotated_lxml = ( - cv_predictions | - "AnnotateLxmlUsingCvPrediction" >> MapOrLog(lambda v: remove_keys_from_dict( - extend_dict(v, { - DataProps.STRUCTURED_DOCUMENT: annotate_structured_document_using_predicted_image_data( - v[DataProps.STRUCTURED_DOCUMENT], - v[DataProps.CV_PREDICTION_PNG_PAGES], - v[DataProps.COLOR_MAP], - tag_scope=CV_TAG_SCOPE - ) - }), - keys_to_remove={DataProps.PDF_PNG_PAGES} - ), error_count=MetricCounters.ANNOTATE_USING_PREDICTION_ERROR) + + extract_tag_scope = None + else: + annotated_lxml, extract_tag_scope = add_read_pdfs_to_annotated_lxml_pipeline_steps( + p, opt, get_pipeline_output_file + ) + + extracted_xml = ( + annotated_lxml | + "ExtractToXml" >> MapOrLog(lambda v: remove_keys_from_dict( + extend_dict(v, { + DataProps.EXTRACTED_XML: extract_annotated_structured_document_to_xml( + v[DataProps.STRUCTURED_DOCUMENT], + tag_scope=extract_tag_scope + ) + }), + keys_to_remove={DataProps.STRUCTURED_DOCUMENT} + ), error_count=MetricCounters.EXTRACT_TO_XML_ERROR) ) - lxml_content = cv_annotated_lxml - extract_tag_scope = CV_TAG_SCOPE + if 
opt.use_grobid: + enhancer = GrobidXmlEnhancer( + opt.grobid_url, start_service=opt.start_grobid_service + ) + extracted_xml = ( + extracted_xml | + "GrobidEnhanceXml" >> MapOrLog(lambda v: extend_dict(v, { + DataProps.EXTRACTED_XML: enhancer( + v[DataProps.EXTRACTED_XML] + ) + }), error_count=MetricCounters.GROBID_ERROR) + ) + return extracted_xml + - if opt.crf_model: - model = load_crf_model(opt.crf_model) - crf_annotated_lxml = ( - lxml_content | - "AnnotateLxmlUsingCrfPrediction" >> MapOrLog(lambda v: extend_dict(v, { - DataProps.STRUCTURED_DOCUMENT: predict_and_annotate_structured_document( - v[DataProps.STRUCTURED_DOCUMENT], model +def configure_pipeline(p, opt): + def get_pipeline_output_file(source_url, ext): + return get_output_file( + source_url, + opt.base_data_path, + opt.output_path, + ext + ) + + if ( + opt.use_grobid and not opt.crf_model and + not opt.cv_model_export_dir and not opt.lxml_file_list + ): + extracted_xml = add_read_pdfs_to_grobid_xml_pipeline_steps(p, opt) + else: + extracted_xml = add_read_source_to_extracted_xml_pipeline_steps( + p, opt, get_pipeline_output_file + ) + + _ = ( # flake8: noqa + extracted_xml | + "WriteXml" >> TransformAndLog( + beam.Map(lambda v: save_file_content( + get_pipeline_output_file( + v[DataProps.SOURCE_FILENAME], + opt.output_suffix + ), + v[DataProps.EXTRACTED_XML] + )), + log_fn=lambda x: get_logger().info('saved xml to: %s', x) ) - }), error_count=MetricCounters.ANNOTATE_USING_PREDICTION_ERROR) ) - lxml_content = crf_annotated_lxml - extract_tag_scope = CRF_TAG_SCOPE - - if opt.save_annot_lxml: - _ = ( - lxml_content | - "SaveAnnotLxml" >> TransformAndLog( - beam.Map(lambda v: save_structured_document( - get_pipeline_output_file( - v[DataProps.SOURCE_FILENAME], - get_annot_lxml_ext( - crf_enabled=opt.crf_model, - cv_enabled=cv_enabled - ) - ), - v[DataProps.STRUCTURED_DOCUMENT] - )), - log_fn=lambda x: get_logger().info('saved annoted lxml to: %s', x) - ) + +def add_main_args(parser): + 
parser.add_argument( + '--data-path', type=str, required=True, + help='base data path' ) - return lxml_content, extract_tag_scope -def add_read_pdfs_to_grobid_xml_pipeline_steps(p, opt): - grobid_transformer = grobid_service( - opt.grobid_url, opt.grobid_action, start_service=opt.start_grobid_service - ) - - return ( - p | - PdfUrlSource(opt) | - PreventFusion() | - ReadPdfContent() | - "Grobid" >> MapOrLog(lambda v: extend_dict(v, { - DataProps.EXTRACTED_XML: grobid_transformer( - (v[DataProps.SOURCE_FILENAME], v[DataProps.PDF_CONTENT]) - )[1] - }), error_count=MetricCounters.GROBID_ERROR) - ) + source_group = parser.add_argument_group('source') + source_one_of_group = source_group.add_mutually_exclusive_group(required=True) + source_one_of_group.add_argument( + '--pdf-path', type=str, required=False, + help='path to pdf file(s), relative to data-path' + ) + source_one_of_group.add_argument( + '--pdf-file-list', type=str, required=False, + help='path to pdf csv/tsv file list' + ) + source_group.add_argument( + '--pdf-file-column', type=str, required=False, default='pdf_url', + help='the column of the pdf file list to use' + ) + source_one_of_group.add_argument( + '--lxml-file-list', type=str, required=False, + help='path to annotated lxml or svg pages zip file list' + '; (CRF and CV models are not supported in this mode)' + ) + source_group.add_argument( + '--lxml-file-column', type=str, required=False, default='url', + help='the column of the lxml file list to use' + ) -def add_read_source_to_extracted_xml_pipeline_steps(p, opt, get_pipeline_output_file): - if opt.lxml_file_list: - lxml_urls = p | ReadFileList(opt.lxml_file_list, column=opt.lxml_file_column, limit=opt.limit) - - annotated_lxml = ( - lxml_urls | - PreventFusion() | - - "ReadLxmlContent" >> TransformAndCount( - MapOrLog(lambda url: { - DataProps.SOURCE_FILENAME: url, - DataProps.STRUCTURED_DOCUMENT: load_structured_document(url) - }, error_count=MetricCounters.READ_LXML_ERROR), - 
MetricCounters.FILES - ) + parser.add_argument( + '--limit', type=int, required=False, + help='limit the number of file pairs to process' ) - extract_tag_scope = None - else: - annotated_lxml, extract_tag_scope = add_read_pdfs_to_annotated_lxml_pipeline_steps( - p, opt, get_pipeline_output_file + output_group = parser.add_argument_group('output') + output_group.add_argument( + '--output-path', required=False, + help='Output directory to write results to.' + ) + output_group.add_argument( + '--output-suffix', required=False, default='.crf.xml', + help='Output file suffix to add to the filename (excluding the file extension).' ) - extracted_xml = ( - annotated_lxml | - "ExtractToXml" >> MapOrLog(lambda v: remove_keys_from_dict( - extend_dict(v, { - DataProps.EXTRACTED_XML: extract_annotated_structured_document_to_xml( - v[DataProps.STRUCTURED_DOCUMENT], - tag_scope=extract_tag_scope - ) - }), - keys_to_remove={DataProps.STRUCTURED_DOCUMENT} - ), error_count=MetricCounters.EXTRACT_TO_XML_ERROR) - ) - - if opt.use_grobid: - enhancer = GrobidXmlEnhancer( - opt.grobid_url, start_service=opt.start_grobid_service + parser.add_argument( + '--save-annot-lxml', action='store_true', default=False, + help='enable saving of annotated lxml' ) - extracted_xml = ( - extracted_xml | - "GrobidEnhanceXml" >> MapOrLog(lambda v: extend_dict(v, { - DataProps.EXTRACTED_XML: enhancer( - v[DataProps.EXTRACTED_XML] - ) - }), error_count=MetricCounters.GROBID_ERROR) + + grobid_group = parser.add_argument_group('Grobid') + grobid_group.add_argument( + '--use-grobid', action='store_true', default=False, + help='enable the use of grobid' + ) + grobid_group.add_argument( + '--grobid-url', required=False, default=None, + help='Base URL to the Grobid service' + ) + parser.add_argument( + '--grobid-action', required=False, + default=GrobidApiPaths.PROCESS_HEADER_DOCUMENT, + help='Name of the Grobid action (if Grobid is used without CRF or CV model)' ) - return extracted_xml -def 
configure_pipeline(p, opt): - get_pipeline_output_file = lambda source_url, ext: get_output_file( - source_url, - opt.base_data_path, - opt.output_path, - ext - ) - - if ( - opt.use_grobid and not opt.crf_model and - not opt.cv_model_export_dir and not opt.lxml_file_list - ): - extracted_xml = add_read_pdfs_to_grobid_xml_pipeline_steps(p, opt) - else: - extracted_xml = add_read_source_to_extracted_xml_pipeline_steps( - p, opt, get_pipeline_output_file + parser.add_argument( + '--debug', action='store_true', default=False, + help='enable debug output' ) - _ = ( - extracted_xml | - "WriteXml" >> TransformAndLog( - beam.Map(lambda v: save_file_content( - get_pipeline_output_file( - v[DataProps.SOURCE_FILENAME], - opt.output_suffix - ), - v[DataProps.EXTRACTED_XML] - )), - log_fn=lambda x: get_logger().info('saved xml to: %s', x) + parser.add_argument( + '--pages', type=parse_page_range, default=None, + help='only processes the selected pages' ) - ) + crf_group = parser.add_argument_group('CRF') + crf_group.add_argument( + '--crf-model', type=str, required=False, + help='path to saved crf model' + ) + + cv_group = parser.add_argument_group('CV') + cv_group.add_argument( + '--cv-model-export-dir', type=str, required=False, + help='path to cv model export dir' + ) + cv_group.add_argument( + '--image-width', type=int, required=False, + default=256, + help='image width of resulting PNGs' + ) + cv_group.add_argument( + '--image-height', type=int, required=False, + default=256, + help='image height of resulting PNGs' + ) + cv_group.add_argument( + '--save-cv-output', action='store_true', default=False, + help='enable saving of computer vision output (png pages)' + ) -def add_main_args(parser): - parser.add_argument( - '--data-path', type=str, required=True, - help='base data path' - ) - - source_group = parser.add_argument_group('source') - source_one_of_group = source_group.add_mutually_exclusive_group(required=True) - source_one_of_group.add_argument( - '--pdf-path', 
type=str, required=False, - help='path to pdf file(s), relative to data-path' - ) - source_one_of_group.add_argument( - '--pdf-file-list', type=str, required=False, - help='path to pdf csv/tsv file list' - ) - source_group.add_argument( - '--pdf-file-column', type=str, required=False, default='pdf_url', - help='the column of the pdf file list to use' - ) - source_one_of_group.add_argument( - '--lxml-file-list', type=str, required=False, - help='path to annotated lxml or svg pages zip file list' - '; (CRF and CV models are not supported in this mode)' - ) - source_group.add_argument( - '--lxml-file-column', type=str, required=False, default='url', - help='the column of the lxml file list to use' - ) - - parser.add_argument( - '--limit', type=int, required=False, - help='limit the number of file pairs to process' - ) - - output_group = parser.add_argument_group('output') - output_group.add_argument( - '--output-path', required=False, - help='Output directory to write results to.' - ) - output_group.add_argument( - '--output-suffix', required=False, default='.crf.xml', - help='Output file suffix to add to the filename (excluding the file extension).' 
- ) - - parser.add_argument( - '--save-annot-lxml', action='store_true', default=False, - help='enable saving of annotated lxml' - ) - - grobid_group = parser.add_argument_group('Grobid') - grobid_group.add_argument( - '--use-grobid', action='store_true', default=False, - help='enable the use of grobid' - ) - grobid_group.add_argument( - '--grobid-url', required=False, default=None, - help='Base URL to the Grobid service' - ) - parser.add_argument( - '--grobid-action', required=False, - default=GrobidApiPaths.PROCESS_HEADER_DOCUMENT, - help='Name of the Grobid action (if Grobid is used without CRF or CV model)' - ) - - parser.add_argument( - '--debug', action='store_true', default=False, - help='enable debug output' - ) - - parser.add_argument( - '--pages', type=parse_page_range, default=None, - help='only processes the selected pages' - ) - - crf_group = parser.add_argument_group('CRF') - crf_group.add_argument( - '--crf-model', type=str, required=False, - help='path to saved crf model' - ) - - cv_group = parser.add_argument_group('CV') - cv_group.add_argument( - '--cv-model-export-dir', type=str, required=False, - help='path to cv model export dir' - ) - cv_group.add_argument( - '--image-width', type=int, required=False, - default=256, - help='image width of resulting PNGs' - ) - cv_group.add_argument( - '--image-height', type=int, required=False, - default=256, - help='image height of resulting PNGs' - ) - cv_group.add_argument( - '--save-cv-output', action='store_true', default=False, - help='enable saving of computer vision output (png pages)' - ) def process_main_args(args, parser): - args.base_data_path = args.data_path.replace('/*/', '/') + args.base_data_path = args.data_path.replace('/*/', '/') - if not args.output_path: - args.output_path = os.path.join( - os.path.dirname(args.base_data_path), - os.path.basename(args.base_data_path + '-results') - ) + if not args.output_path: + args.output_path = os.path.join( + os.path.dirname(args.base_data_path), + 
os.path.basename(args.base_data_path + '-results') + ) + + if args.lxml_file_list: + if args.crf_model: + parser.error('--crf-model cannot be used in conjunction with --lxml-file-list') + + if args.cv_model_export_dir: + parser.error( + '--crf-model-export-dir cannot be used in conjunction with --lxml-file-list' + ) + else: + if not args.crf_model and not args.cv_model_export_dir and not args.use_grobid: + parser.error( + '--crf-model, --cv-model-export-dir or --use-grobid required in conjunction' + ' with --pdf-file-list or --pdf-path' + ) + + if args.use_grobid and not args.grobid_url: + args.grobid_url = 'http://localhost:8080/api' + args.start_grobid_service = True + else: + args.start_grobid_service = False - if args.lxml_file_list: - if args.crf_model: - parser.error('--crf-model cannot be used in conjunction with --lxml-file-list') - - if args.cv_model_export_dir: - parser.error('--crf-model-export-dir cannot be used in conjunction with --lxml-file-list') - else: - if not args.crf_model and not args.cv_model_export_dir and not args.use_grobid: - parser.error( - '--crf-model, --cv-model-export-dir or --use-grobid required in conjunction' - ' with --pdf-file-list or --pdf-path' - ) - - if args.use_grobid and not args.grobid_url: - args.grobid_url = 'http://localhost:8080/api' - args.start_grobid_service = True - else: - args.start_grobid_service = False def parse_args(argv=None): - parser = argparse.ArgumentParser() - add_main_args(parser) - add_cloud_args(parser) + parser = argparse.ArgumentParser() + add_main_args(parser) + add_cloud_args(parser) - args = parser.parse_args(argv) + args = parser.parse_args(argv) - if args.debug: - logging.getLogger().setLevel('DEBUG') + if args.debug: + logging.getLogger().setLevel('DEBUG') + + process_main_args(args, parser) + process_cloud_args( + args, args.output_path, + name='sciencebeam-convert' + ) + process_sciencebeam_gym_dep_args(args) - process_main_args(args, parser) - process_cloud_args( - args, args.output_path, 
- name='sciencebeam-convert' - ) - process_sciencebeam_gym_dep_args(args) + get_logger().info('args: %s', args) - get_logger().info('args: %s', args) + return args - return args def run(argv=None): - args = parse_args(argv) + args = parse_args(argv) - # We use the save_main_session option because one or more DoFn's in this - # workflow rely on global context (e.g., a module imported at module level). - pipeline_options = PipelineOptions.from_dictionary(vars(args)) - pipeline_options.view_as(SetupOptions).save_main_session = True + # We use the save_main_session option because one or more DoFn's in this + # workflow rely on global context (e.g., a module imported at module level). + pipeline_options = PipelineOptions.from_dictionary(vars(args)) + pipeline_options.view_as(SetupOptions).save_main_session = True - with beam.Pipeline(args.runner, options=pipeline_options) as p: - configure_pipeline(p, args) + with beam.Pipeline(args.runner, options=pipeline_options) as p: + configure_pipeline(p, args) - # Execute the pipeline and wait until it is completed. + # Execute the pipeline and wait until it is completed. if __name__ == '__main__': - logging.basicConfig(level='INFO') + logging.basicConfig(level='INFO') - run() + run() diff --git a/sciencebeam_gym/convert/conversion_pipeline_test.py b/sciencebeam_gym/convert/conversion_pipeline_test.py index 4bf728d..9247cfd 100644 --- a/sciencebeam_gym/convert/conversion_pipeline_test.py +++ b/sciencebeam_gym/convert/conversion_pipeline_test.py @@ -6,18 +6,18 @@ import pytest import apache_beam as beam from sciencebeam_utils.beam_utils.testing import ( - BeamTest, - TestPipeline + BeamTest, + TestPipeline ) from . 
import conversion_pipeline as conversion_pipeline from .conversion_pipeline import ( - get_annot_lxml_ext, - configure_pipeline, - parse_args, - OutputExt, - CV_TAG_SCOPE, - CRF_TAG_SCOPE + get_annot_lxml_ext, + configure_pipeline, + parse_args, + OutputExt, + CV_TAG_SCOPE, + CRF_TAG_SCOPE ) @@ -41,439 +41,449 @@ PDF_CONTENT_1 = b'pdf content' LXML_CONTENT_1 = b'<LXML>lxml content</LXML>' TEI_XML_CONTENT_1 = b'<TEI>tei content</TEI>' -fake_pdf_png_page = lambda i=0: 'fake pdf png page: %d' % i + +def fake_pdf_png_page(i=0): + return 'fake pdf png page: %d' % i + MIN_ARGV = [ - '--data-path=' + BASE_DATA_PATH, - '--pdf-path=' + PDF_PATH, - '--crf-model=' + MODEL_EXPORT_DIR + '--data-path=' + BASE_DATA_PATH, + '--pdf-path=' + PDF_PATH, + '--crf-model=' + MODEL_EXPORT_DIR ] + def setup_module(): - logging.basicConfig(level='DEBUG') + logging.basicConfig(level='DEBUG') + def get_default_args(): - return parse_args(MIN_ARGV) + return parse_args(MIN_ARGV) + def patch_conversion_pipeline(**kwargs): - always_mock = { - 'read_all_from_path', - 'load_structured_document', - 'convert_pdf_bytes_to_lxml', - 'convert_pdf_bytes_to_structured_document', - 'load_crf_model', - 'ReadFileList', - 'FindFiles', - 'predict_and_annotate_structured_document', - 'extract_annotated_structured_document_to_xml', - 'save_structured_document', - 'save_file_content', - 'GrobidXmlEnhancer', - 'grobid_service', - 'pdf_bytes_to_png_pages', - 'InferenceModelWrapper', - 'annotate_structured_document_using_predicted_image_data' - } - - return patch.multiple( - conversion_pipeline, - **{ - k: kwargs.get(k, DEFAULT) - for k in always_mock + always_mock = { + 'read_all_from_path', + 'load_structured_document', + 'convert_pdf_bytes_to_lxml', + 'convert_pdf_bytes_to_structured_document', + 'load_crf_model', + 'ReadFileList', + 'FindFiles', + 'predict_and_annotate_structured_document', + 'extract_annotated_structured_document_to_xml', + 'save_structured_document', + 'save_file_content', + 
'GrobidXmlEnhancer', + 'grobid_service', + 'pdf_bytes_to_png_pages', + 'InferenceModelWrapper', + 'annotate_structured_document_using_predicted_image_data' } - ) -def _setup_mocks_for_pages(mocks, page_no_list, file_count=1): - mocks['pdf_bytes_to_png_pages'].return_value = [ - fake_pdf_png_page(i) for i in page_no_list - ] + return patch.multiple( + conversion_pipeline, + **{ + k: kwargs.get(k, DEFAULT) + for k in always_mock + } + ) + + +def _setup_mocks_for_pages(mocks, page_no_list): + mocks['pdf_bytes_to_png_pages'].return_value = [ + fake_pdf_png_page(i) for i in page_no_list + ] + class TestGetAnnotLxmlExt(object): - def test_should_return_crf_cv_annot_lxml(self): - assert get_annot_lxml_ext(crf_enabled=True, cv_enabled=True) == OutputExt.CRF_CV_ANNOT_LXML + def test_should_return_crf_cv_annot_lxml(self): + assert get_annot_lxml_ext(crf_enabled=True, cv_enabled=True) == OutputExt.CRF_CV_ANNOT_LXML - def test_should_return_crf_annot_lxml(self): - assert get_annot_lxml_ext(crf_enabled=True, cv_enabled=False) == OutputExt.CRF_ANNOT_LXML + def test_should_return_crf_annot_lxml(self): + assert get_annot_lxml_ext(crf_enabled=True, cv_enabled=False) == OutputExt.CRF_ANNOT_LXML - def test_should_return_cv_annot_lxml(self): - assert get_annot_lxml_ext(crf_enabled=False, cv_enabled=True) == OutputExt.CV_ANNOT_LXML + def test_should_return_cv_annot_lxml(self): + assert get_annot_lxml_ext(crf_enabled=False, cv_enabled=True) == OutputExt.CV_ANNOT_LXML + + def test_should_raise_error_if_neither_crf_or_cv_are_enabled(self): + with pytest.raises(AssertionError): + get_annot_lxml_ext(crf_enabled=False, cv_enabled=False) - def test_should_raise_error_if_neither_crf_or_cv_are_enabled(self): - with pytest.raises(AssertionError): - get_annot_lxml_ext(crf_enabled=False, cv_enabled=False) @pytest.mark.slow class TestConfigurePipeline(BeamTest): - def test_should_pass_pdf_pattern_to_find_files_and_read_pdf_file(self): - with patch_conversion_pipeline() as mocks: - opt = 
get_default_args() - opt.base_data_path = BASE_DATA_PATH - opt.pdf_path = PDF_PATH - opt.pdf_file_list = None - with TestPipeline() as p: - mocks['FindFiles'].return_value = beam.Create([PDF_FILE_1]) - configure_pipeline(p, opt) - - mocks['FindFiles'].assert_called_with( - BASE_DATA_PATH + '/' + PDF_PATH - ) - mocks['read_all_from_path'].assert_called_with( - PDF_FILE_1 - ) - - def test_should_pass_pdf_file_list_and_limit_to_read_file_list_and_read_pdf_file(self): - with patch_conversion_pipeline() as mocks: - opt = get_default_args() - opt.base_data_path = BASE_DATA_PATH - opt.pdf_path = None - opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' - opt.limit = 100 - with TestPipeline() as p: - mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) - configure_pipeline(p, opt) - - mocks['ReadFileList'].assert_called_with( - opt.pdf_file_list, column='pdf_url', limit=opt.limit - ) - mocks['read_all_from_path'].assert_called_with( - PDF_FILE_1 - ) - - def test_should_pass_around_values_with_default_pipeline(self): - with patch_conversion_pipeline() as mocks: - opt = get_default_args() - opt.base_data_path = BASE_DATA_PATH - opt.pdf_path = None - opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' - opt.output_path = OUTPUT_PATH - opt.output_suffix = OUTPUT_SUFFIX - with TestPipeline() as p: - mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) - mocks['read_all_from_path'].return_value = PDF_CONTENT_1 - configure_pipeline(p, opt) - - mocks['convert_pdf_bytes_to_structured_document'].assert_called_with( - PDF_CONTENT_1, page_range=None, path=PDF_FILE_1 - ) - mocks['predict_and_annotate_structured_document'].assert_called_with( - mocks['convert_pdf_bytes_to_structured_document'].return_value, - mocks['load_crf_model'].return_value - ) - mocks['extract_annotated_structured_document_to_xml'].assert_called_with( - mocks['predict_and_annotate_structured_document'].return_value, - tag_scope=CRF_TAG_SCOPE - ) - 
mocks['save_file_content'].assert_called_with( - OUTPUT_XML_FILE_1, - mocks['extract_annotated_structured_document_to_xml'].return_value - ) - - def test_should_save_annotated_lxml_if_enabled(self): - with patch_conversion_pipeline() as mocks: - opt = get_default_args() - opt.base_data_path = BASE_DATA_PATH - opt.pdf_path = None - opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' - opt.output_path = OUTPUT_PATH - opt.output_suffix = OUTPUT_SUFFIX - opt.save_annot_lxml = True - with TestPipeline() as p: - mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) - configure_pipeline(p, opt) - - mocks['save_structured_document'].assert_called_with( - OUTPUT_PATH + '/' + REL_PDF_FILE_WITHOUT_EXT_1 + OutputExt.CRF_ANNOT_LXML, - mocks['predict_and_annotate_structured_document'].return_value - ) - - def test_should_use_lxml_file_list_if_provided_and_load_structured_documents(self): - with patch_conversion_pipeline() as mocks: - opt = get_default_args() - opt.base_data_path = BASE_DATA_PATH - opt.pdf_path = None - opt.pdf_file_list = None - opt.lxml_file_list = BASE_DATA_PATH + '/file-list.tsv' - opt.output_path = OUTPUT_PATH - opt.output_suffix = OUTPUT_SUFFIX - with TestPipeline() as p: - mocks['ReadFileList'].return_value = beam.Create([LXML_FILE_1]) - configure_pipeline(p, opt) - - mocks['extract_annotated_structured_document_to_xml'].assert_called_with( - mocks['load_structured_document'].return_value, - tag_scope=None - ) - mocks['save_file_content'].assert_called_with( - OUTPUT_XML_FILE_1, - mocks['extract_annotated_structured_document_to_xml'].return_value - ) - - def test_should_use_crf_model_with_cv_model_if_enabled(self): - with patch_conversion_pipeline() as mocks: - inference_model_wrapper = mocks['InferenceModelWrapper'].return_value - opt = get_default_args() - opt.base_data_path = BASE_DATA_PATH - opt.pdf_path = None - opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' - opt.output_path = OUTPUT_PATH - opt.output_suffix = OUTPUT_SUFFIX - 
opt.cv_model_export_dir = CV_MODEL_EXPORT_DIR - with TestPipeline() as p: - mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) - mocks['read_all_from_path'].return_value = PDF_CONTENT_1 - _setup_mocks_for_pages(mocks, [1, 2]) - configure_pipeline(p, opt) - - mocks['convert_pdf_bytes_to_structured_document'].assert_called_with( - PDF_CONTENT_1, page_range=None, path=PDF_FILE_1 - ) - - # cv model - inference_model_wrapper.assert_called_with( - [fake_pdf_png_page(i) for i in [1, 2]] - ) - mocks['annotate_structured_document_using_predicted_image_data'].assert_called_with( - mocks['convert_pdf_bytes_to_structured_document'].return_value, - inference_model_wrapper.return_value, - inference_model_wrapper.get_color_map.return_value, - tag_scope=CV_TAG_SCOPE - ) - - # crf model should receive output from cv model - mocks['predict_and_annotate_structured_document'].assert_called_with( - mocks['annotate_structured_document_using_predicted_image_data'].return_value, - mocks['load_crf_model'].return_value - ) - mocks['extract_annotated_structured_document_to_xml'].assert_called_with( - mocks['predict_and_annotate_structured_document'].return_value, - tag_scope=CRF_TAG_SCOPE - ) - - def test_should_use_cv_model_only_if_enabled(self): - with patch_conversion_pipeline() as mocks: - inference_model_wrapper = mocks['InferenceModelWrapper'].return_value - opt = get_default_args() - opt.base_data_path = BASE_DATA_PATH - opt.pdf_path = None - opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' - opt.output_path = OUTPUT_PATH - opt.output_suffix = OUTPUT_SUFFIX - opt.crf_model = None - opt.cv_model_export_dir = CV_MODEL_EXPORT_DIR - with TestPipeline() as p: - mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) - mocks['read_all_from_path'].return_value = PDF_CONTENT_1 - _setup_mocks_for_pages(mocks, [1, 2]) - configure_pipeline(p, opt) - - mocks['convert_pdf_bytes_to_structured_document'].assert_called_with( - PDF_CONTENT_1, page_range=None, path=PDF_FILE_1 - 
) - - # cv model - inference_model_wrapper.assert_called_with( - [fake_pdf_png_page(i) for i in [1, 2]] - ) - mocks['annotate_structured_document_using_predicted_image_data'].assert_called_with( - mocks['convert_pdf_bytes_to_structured_document'].return_value, - inference_model_wrapper.return_value, - inference_model_wrapper.get_color_map.return_value, - tag_scope=CV_TAG_SCOPE - ) - mocks['extract_annotated_structured_document_to_xml'].assert_called_with( - mocks['annotate_structured_document_using_predicted_image_data'].return_value, - tag_scope=CV_TAG_SCOPE - ) - - # crf model not be called - mocks['predict_and_annotate_structured_document'].assert_not_called() - - def test_should_use_grobid_if_enabled(self): - with patch_conversion_pipeline() as mocks: - grobid_xml_enhancer = mocks['GrobidXmlEnhancer'].return_value - opt = get_default_args() - opt.base_data_path = BASE_DATA_PATH - opt.pdf_path = None - opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' - opt.output_path = OUTPUT_PATH - opt.output_suffix = OUTPUT_SUFFIX - opt.use_grobid = True - opt.grobid_url = 'http://test/api' - with TestPipeline() as p: - mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) - mocks['read_all_from_path'].return_value = PDF_CONTENT_1 - mocks['convert_pdf_bytes_to_lxml'].return_value = LXML_CONTENT_1 - configure_pipeline(p, opt) - - mocks['GrobidXmlEnhancer'].assert_called_with( - opt.grobid_url, - start_service=opt.start_grobid_service - ) - grobid_xml_enhancer.assert_called_with( - mocks['extract_annotated_structured_document_to_xml'].return_value - ) - mocks['save_file_content'].assert_called_with( - OUTPUT_XML_FILE_1, - grobid_xml_enhancer.return_value - ) - - def test_should_use_grobid_with_lxml_file_list_if_enabled(self): - with patch_conversion_pipeline() as mocks: - grobid_xml_enhancer = mocks['GrobidXmlEnhancer'].return_value - opt = get_default_args() - opt.base_data_path = BASE_DATA_PATH - opt.pdf_path = None - opt.pdf_file_list = None - 
opt.lxml_file_list = BASE_DATA_PATH + '/file-list.tsv' - opt.output_path = OUTPUT_PATH - opt.output_suffix = OUTPUT_SUFFIX - opt.crf_model = None - opt.cv_model_export_dir = None - opt.use_grobid = True - opt.grobid_url = 'http://test/api' - with TestPipeline() as p: - mocks['ReadFileList'].return_value = beam.Create([LXML_FILE_1]) - configure_pipeline(p, opt) - - mocks['extract_annotated_structured_document_to_xml'].assert_called_with( - mocks['load_structured_document'].return_value, - tag_scope=None - ) - mocks['save_file_content'].assert_called_with( - OUTPUT_XML_FILE_1, - grobid_xml_enhancer.return_value - ) - - def test_should_use_grobid_only_if_crf_or_cv_model_are_not_enabled(self): - with patch_conversion_pipeline() as mocks: - opt = get_default_args() - opt.base_data_path = BASE_DATA_PATH - opt.pdf_path = None - opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' - opt.output_path = OUTPUT_PATH - opt.output_suffix = OUTPUT_SUFFIX - opt.crf_model = None - opt.cv_model_export_dir = None - opt.use_grobid = True - opt.grobid_url = 'http://test/api' - with TestPipeline() as p: - mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) - mocks['read_all_from_path'].return_value = PDF_CONTENT_1 - mocks['grobid_service'].return_value = lambda x: ( - PDF_FILE_1, TEI_XML_CONTENT_1 - ) - configure_pipeline(p, opt) - - mocks['grobid_service'].assert_called_with( - opt.grobid_url, - opt.grobid_action, - start_service=opt.start_grobid_service - ) - mocks['save_file_content'].assert_called_with( - OUTPUT_XML_FILE_1, - TEI_XML_CONTENT_1 - ) + def test_should_pass_pdf_pattern_to_find_files_and_read_pdf_file(self): + with patch_conversion_pipeline() as mocks: + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = PDF_PATH + opt.pdf_file_list = None + with TestPipeline() as p: + mocks['FindFiles'].return_value = beam.Create([PDF_FILE_1]) + configure_pipeline(p, opt) + + mocks['FindFiles'].assert_called_with( + BASE_DATA_PATH + '/' + 
PDF_PATH + ) + mocks['read_all_from_path'].assert_called_with( + PDF_FILE_1 + ) + + def test_should_pass_pdf_file_list_and_limit_to_read_file_list_and_read_pdf_file(self): + with patch_conversion_pipeline() as mocks: + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.limit = 100 + with TestPipeline() as p: + mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) + configure_pipeline(p, opt) + + mocks['ReadFileList'].assert_called_with( + opt.pdf_file_list, column='pdf_url', limit=opt.limit + ) + mocks['read_all_from_path'].assert_called_with( + PDF_FILE_1 + ) + + def test_should_pass_around_values_with_default_pipeline(self): + with patch_conversion_pipeline() as mocks: + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.output_path = OUTPUT_PATH + opt.output_suffix = OUTPUT_SUFFIX + with TestPipeline() as p: + mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) + mocks['read_all_from_path'].return_value = PDF_CONTENT_1 + configure_pipeline(p, opt) + + mocks['convert_pdf_bytes_to_structured_document'].assert_called_with( + PDF_CONTENT_1, page_range=None, path=PDF_FILE_1 + ) + mocks['predict_and_annotate_structured_document'].assert_called_with( + mocks['convert_pdf_bytes_to_structured_document'].return_value, + mocks['load_crf_model'].return_value + ) + mocks['extract_annotated_structured_document_to_xml'].assert_called_with( + mocks['predict_and_annotate_structured_document'].return_value, + tag_scope=CRF_TAG_SCOPE + ) + mocks['save_file_content'].assert_called_with( + OUTPUT_XML_FILE_1, + mocks['extract_annotated_structured_document_to_xml'].return_value + ) + + def test_should_save_annotated_lxml_if_enabled(self): + with patch_conversion_pipeline() as mocks: + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + 
opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.output_path = OUTPUT_PATH + opt.output_suffix = OUTPUT_SUFFIX + opt.save_annot_lxml = True + with TestPipeline() as p: + mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) + configure_pipeline(p, opt) + + mocks['save_structured_document'].assert_called_with( + OUTPUT_PATH + '/' + REL_PDF_FILE_WITHOUT_EXT_1 + OutputExt.CRF_ANNOT_LXML, + mocks['predict_and_annotate_structured_document'].return_value + ) + + def test_should_use_lxml_file_list_if_provided_and_load_structured_documents(self): + with patch_conversion_pipeline() as mocks: + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = None + opt.lxml_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.output_path = OUTPUT_PATH + opt.output_suffix = OUTPUT_SUFFIX + with TestPipeline() as p: + mocks['ReadFileList'].return_value = beam.Create([LXML_FILE_1]) + configure_pipeline(p, opt) + + mocks['extract_annotated_structured_document_to_xml'].assert_called_with( + mocks['load_structured_document'].return_value, + tag_scope=None + ) + mocks['save_file_content'].assert_called_with( + OUTPUT_XML_FILE_1, + mocks['extract_annotated_structured_document_to_xml'].return_value + ) + + def test_should_use_crf_model_with_cv_model_if_enabled(self): + with patch_conversion_pipeline() as mocks: + inference_model_wrapper = mocks['InferenceModelWrapper'].return_value + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.output_path = OUTPUT_PATH + opt.output_suffix = OUTPUT_SUFFIX + opt.cv_model_export_dir = CV_MODEL_EXPORT_DIR + with TestPipeline() as p: + mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) + mocks['read_all_from_path'].return_value = PDF_CONTENT_1 + _setup_mocks_for_pages(mocks, [1, 2]) + configure_pipeline(p, opt) + + 
mocks['convert_pdf_bytes_to_structured_document'].assert_called_with( + PDF_CONTENT_1, page_range=None, path=PDF_FILE_1 + ) + + # cv model + inference_model_wrapper.assert_called_with( + [fake_pdf_png_page(i) for i in [1, 2]] + ) + mocks['annotate_structured_document_using_predicted_image_data'].assert_called_with( + mocks['convert_pdf_bytes_to_structured_document'].return_value, + inference_model_wrapper.return_value, + inference_model_wrapper.get_color_map.return_value, + tag_scope=CV_TAG_SCOPE + ) + + # crf model should receive output from cv model + mocks['predict_and_annotate_structured_document'].assert_called_with( + mocks['annotate_structured_document_using_predicted_image_data'].return_value, + mocks['load_crf_model'].return_value + ) + mocks['extract_annotated_structured_document_to_xml'].assert_called_with( + mocks['predict_and_annotate_structured_document'].return_value, + tag_scope=CRF_TAG_SCOPE + ) + + def test_should_use_cv_model_only_if_enabled(self): + with patch_conversion_pipeline() as mocks: + inference_model_wrapper = mocks['InferenceModelWrapper'].return_value + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.output_path = OUTPUT_PATH + opt.output_suffix = OUTPUT_SUFFIX + opt.crf_model = None + opt.cv_model_export_dir = CV_MODEL_EXPORT_DIR + with TestPipeline() as p: + mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) + mocks['read_all_from_path'].return_value = PDF_CONTENT_1 + _setup_mocks_for_pages(mocks, [1, 2]) + configure_pipeline(p, opt) + + mocks['convert_pdf_bytes_to_structured_document'].assert_called_with( + PDF_CONTENT_1, page_range=None, path=PDF_FILE_1 + ) + + # cv model + inference_model_wrapper.assert_called_with( + [fake_pdf_png_page(i) for i in [1, 2]] + ) + mocks['annotate_structured_document_using_predicted_image_data'].assert_called_with( + mocks['convert_pdf_bytes_to_structured_document'].return_value, + 
inference_model_wrapper.return_value, + inference_model_wrapper.get_color_map.return_value, + tag_scope=CV_TAG_SCOPE + ) + mocks['extract_annotated_structured_document_to_xml'].assert_called_with( + mocks['annotate_structured_document_using_predicted_image_data'].return_value, + tag_scope=CV_TAG_SCOPE + ) + + # crf model not be called + mocks['predict_and_annotate_structured_document'].assert_not_called() + + def test_should_use_grobid_if_enabled(self): + with patch_conversion_pipeline() as mocks: + grobid_xml_enhancer = mocks['GrobidXmlEnhancer'].return_value + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.output_path = OUTPUT_PATH + opt.output_suffix = OUTPUT_SUFFIX + opt.use_grobid = True + opt.grobid_url = 'http://test/api' + with TestPipeline() as p: + mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) + mocks['read_all_from_path'].return_value = PDF_CONTENT_1 + mocks['convert_pdf_bytes_to_lxml'].return_value = LXML_CONTENT_1 + configure_pipeline(p, opt) + + mocks['GrobidXmlEnhancer'].assert_called_with( + opt.grobid_url, + start_service=opt.start_grobid_service + ) + grobid_xml_enhancer.assert_called_with( + mocks['extract_annotated_structured_document_to_xml'].return_value + ) + mocks['save_file_content'].assert_called_with( + OUTPUT_XML_FILE_1, + grobid_xml_enhancer.return_value + ) + + def test_should_use_grobid_with_lxml_file_list_if_enabled(self): + with patch_conversion_pipeline() as mocks: + grobid_xml_enhancer = mocks['GrobidXmlEnhancer'].return_value + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = None + opt.lxml_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.output_path = OUTPUT_PATH + opt.output_suffix = OUTPUT_SUFFIX + opt.crf_model = None + opt.cv_model_export_dir = None + opt.use_grobid = True + opt.grobid_url = 'http://test/api' + with TestPipeline() as p: + 
mocks['ReadFileList'].return_value = beam.Create([LXML_FILE_1]) + configure_pipeline(p, opt) + + mocks['extract_annotated_structured_document_to_xml'].assert_called_with( + mocks['load_structured_document'].return_value, + tag_scope=None + ) + mocks['save_file_content'].assert_called_with( + OUTPUT_XML_FILE_1, + grobid_xml_enhancer.return_value + ) + + def test_should_use_grobid_only_if_crf_or_cv_model_are_not_enabled(self): + with patch_conversion_pipeline() as mocks: + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.output_path = OUTPUT_PATH + opt.output_suffix = OUTPUT_SUFFIX + opt.crf_model = None + opt.cv_model_export_dir = None + opt.use_grobid = True + opt.grobid_url = 'http://test/api' + with TestPipeline() as p: + mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) + mocks['read_all_from_path'].return_value = PDF_CONTENT_1 + mocks['grobid_service'].return_value = lambda x: ( + PDF_FILE_1, TEI_XML_CONTENT_1 + ) + configure_pipeline(p, opt) + + mocks['grobid_service'].assert_called_with( + opt.grobid_url, + opt.grobid_action, + start_service=opt.start_grobid_service + ) + mocks['save_file_content'].assert_called_with( + OUTPUT_XML_FILE_1, + TEI_XML_CONTENT_1 + ) + class TestParseArgs(object): - def test_should_parse_minimum_number_of_arguments(self): - parse_args(MIN_ARGV) - - def test_should_raise_error_if_no_source_argument_was_provided(self): - with pytest.raises(SystemExit): - parse_args([ - '--data-path=' + BASE_DATA_PATH, - '--crf-model=' + MODEL_EXPORT_DIR - ]) - - def test_should_allow_pdf_path_to_be_specified(self): - args = parse_args([ - '--data-path=' + BASE_DATA_PATH, - '--pdf-path=' + PDF_PATH, - '--crf-model=' + MODEL_EXPORT_DIR - ]) - assert args.pdf_path == PDF_PATH - - def test_should_allow_pdf_file_list_and_column_to_be_specified(self): - args = parse_args([ - '--data-path=' + BASE_DATA_PATH, - '--pdf-file-list=' + FILE_LIST_PATH, 
- '--pdf-file-column=' + FILE_COLUMN, - '--crf-model=' + MODEL_EXPORT_DIR - ]) - assert args.pdf_file_list == FILE_LIST_PATH - assert args.pdf_file_column == FILE_COLUMN - - def test_should_allow_lxml_file_list_and_column_to_be_specified(self): - args = parse_args([ - '--data-path=' + BASE_DATA_PATH, - '--lxml-file-list=' + FILE_LIST_PATH, - '--lxml-file-column=' + FILE_COLUMN - ]) - assert args.lxml_file_list == FILE_LIST_PATH - assert args.lxml_file_column == FILE_COLUMN - - def test_should_not_allow_crf_model_with_lxml_file_list(self): - with pytest.raises(SystemExit): - parse_args([ - '--data-path=' + BASE_DATA_PATH, - '--lxml-file-list=' + FILE_LIST_PATH, - '--lxml-file-column=' + FILE_COLUMN, - '--crf-model=' + MODEL_EXPORT_DIR - ]) - - def test_should_not_allow_cv_model_with_lxml_file_list(self): - with pytest.raises(SystemExit): - parse_args([ - '--data-path=' + BASE_DATA_PATH, - '--lxml-file-list=' + FILE_LIST_PATH, - '--lxml-file-column=' + FILE_COLUMN, - '--cv-model-export-dir=' + CV_MODEL_EXPORT_DIR - ]) - - def test_should_require_crf_or_cv_model_with_pdf_file_list(self): - with pytest.raises(SystemExit): - parse_args([ - '--data-path=' + BASE_DATA_PATH, - '--pdf-file-list=' + FILE_LIST_PATH, - '--pdf-file-column=' + FILE_COLUMN - ]) - - def test_should_require_crf_or_cv_model_with_pdf_path(self): - with pytest.raises(SystemExit): - parse_args([ - '--data-path=' + BASE_DATA_PATH, - '--pdf-path=' + PDF_PATH - ]) - - def test_should_allow_crf_model_only_with_pdf_file_list(self): - parse_args([ - '--data-path=' + BASE_DATA_PATH, - '--pdf-file-list=' + FILE_LIST_PATH, - '--pdf-file-column=' + FILE_COLUMN, - '--crf-model=' + MODEL_EXPORT_DIR - ]) - - def test_should_allow_cv_model_only_with_pdf_file_list(self): - parse_args([ - '--data-path=' + BASE_DATA_PATH, - '--pdf-file-list=' + FILE_LIST_PATH, - '--pdf-file-column=' + FILE_COLUMN, - '--cv-model-export-dir=' + CV_MODEL_EXPORT_DIR - ]) - - def 
test_should_allow_crf_and_cv_model_only_with_pdf_file_list(self): - parse_args([ - '--data-path=' + BASE_DATA_PATH, - '--pdf-file-list=' + FILE_LIST_PATH, - '--pdf-file-column=' + FILE_COLUMN, - '--crf-model=' + MODEL_EXPORT_DIR, - '--cv-model-export-dir=' + CV_MODEL_EXPORT_DIR - ]) - - def test_should_allow_grobid_only_with_pdf_file_list(self): - parse_args([ - '--data-path=' + BASE_DATA_PATH, - '--pdf-file-list=' + FILE_LIST_PATH, - '--pdf-file-column=' + FILE_COLUMN, - '--use-grobid' - ]) + def test_should_parse_minimum_number_of_arguments(self): + parse_args(MIN_ARGV) + + def test_should_raise_error_if_no_source_argument_was_provided(self): + with pytest.raises(SystemExit): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--crf-model=' + MODEL_EXPORT_DIR + ]) + + def test_should_allow_pdf_path_to_be_specified(self): + args = parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-path=' + PDF_PATH, + '--crf-model=' + MODEL_EXPORT_DIR + ]) + assert args.pdf_path == PDF_PATH + + def test_should_allow_pdf_file_list_and_column_to_be_specified(self): + args = parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-file-list=' + FILE_LIST_PATH, + '--pdf-file-column=' + FILE_COLUMN, + '--crf-model=' + MODEL_EXPORT_DIR + ]) + assert args.pdf_file_list == FILE_LIST_PATH + assert args.pdf_file_column == FILE_COLUMN + + def test_should_allow_lxml_file_list_and_column_to_be_specified(self): + args = parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--lxml-file-list=' + FILE_LIST_PATH, + '--lxml-file-column=' + FILE_COLUMN + ]) + assert args.lxml_file_list == FILE_LIST_PATH + assert args.lxml_file_column == FILE_COLUMN + + def test_should_not_allow_crf_model_with_lxml_file_list(self): + with pytest.raises(SystemExit): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--lxml-file-list=' + FILE_LIST_PATH, + '--lxml-file-column=' + FILE_COLUMN, + '--crf-model=' + MODEL_EXPORT_DIR + ]) + + def test_should_not_allow_cv_model_with_lxml_file_list(self): + with 
pytest.raises(SystemExit): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--lxml-file-list=' + FILE_LIST_PATH, + '--lxml-file-column=' + FILE_COLUMN, + '--cv-model-export-dir=' + CV_MODEL_EXPORT_DIR + ]) + + def test_should_require_crf_or_cv_model_with_pdf_file_list(self): + with pytest.raises(SystemExit): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-file-list=' + FILE_LIST_PATH, + '--pdf-file-column=' + FILE_COLUMN + ]) + + def test_should_require_crf_or_cv_model_with_pdf_path(self): + with pytest.raises(SystemExit): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-path=' + PDF_PATH + ]) + + def test_should_allow_crf_model_only_with_pdf_file_list(self): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-file-list=' + FILE_LIST_PATH, + '--pdf-file-column=' + FILE_COLUMN, + '--crf-model=' + MODEL_EXPORT_DIR + ]) + + def test_should_allow_cv_model_only_with_pdf_file_list(self): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-file-list=' + FILE_LIST_PATH, + '--pdf-file-column=' + FILE_COLUMN, + '--cv-model-export-dir=' + CV_MODEL_EXPORT_DIR + ]) + + def test_should_allow_crf_and_cv_model_only_with_pdf_file_list(self): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-file-list=' + FILE_LIST_PATH, + '--pdf-file-column=' + FILE_COLUMN, + '--crf-model=' + MODEL_EXPORT_DIR, + '--cv-model-export-dir=' + CV_MODEL_EXPORT_DIR + ]) + + def test_should_allow_grobid_only_with_pdf_file_list(self): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-file-list=' + FILE_LIST_PATH, + '--pdf-file-column=' + FILE_COLUMN, + '--use-grobid' + ]) diff --git a/sciencebeam_gym/convert/cv_conversion_utils.py b/sciencebeam_gym/convert/cv_conversion_utils.py index d3cbc56..cd36db2 100644 --- a/sciencebeam_gym/convert/cv_conversion_utils.py +++ b/sciencebeam_gym/convert/cv_conversion_utils.py @@ -8,42 +8,49 @@ import numpy as np from PIL import Image from sciencebeam_gym.inference_model import ( - load_inference_model + 
load_inference_model ) + def lazy_cached_value(value_fn): - cache = {} - def wrapper(): - value = cache.get('value') - if value is None: - value = value_fn() - cache['value'] = value - return value - return wrapper + cache = {} + + def wrapper(): + value = cache.get('value') + if value is None: + value = value_fn() + cache['value'] = value + return value + return wrapper + def image_data_to_png(image_data): - image = Image.fromarray(image_data, 'RGB') - out = BytesIO() - image.save(out, 'png') - return out.getvalue() + image = Image.fromarray(image_data, 'RGB') + out = BytesIO() + image.save(out, 'png') + return out.getvalue() + def png_bytes_to_image_data(png_bytes): - return np.asarray(Image.open(BytesIO(png_bytes)).convert('RGB'), dtype=np.uint8) + return np.asarray(Image.open(BytesIO(png_bytes)).convert('RGB'), dtype=np.uint8) + class InferenceModelWrapper(object): - def __init__(self, export_dir): - self.session_cache = lazy_cached_value(lambda: tf.InteractiveSession()) - self.inference_model_cache = lazy_cached_value( - lambda: load_inference_model(export_dir, session=self.session_cache()) - ) - - def get_color_map(self): - return self.inference_model_cache().get_color_map(session=self.session_cache()) - - def __call__(self, png_pages): - input_data = [ - png_bytes_to_image_data(png_page) - for png_page in png_pages - ] - output_img_data_batch = self.inference_model_cache()(input_data, session=self.session_cache()) - return output_img_data_batch + def __init__(self, export_dir): + self.session_cache = lazy_cached_value(lambda: tf.InteractiveSession()) + self.inference_model_cache = lazy_cached_value( + lambda: load_inference_model(export_dir, session=self.session_cache()) + ) + + def get_color_map(self): + return self.inference_model_cache().get_color_map(session=self.session_cache()) + + def __call__(self, png_pages): + input_data = [ + png_bytes_to_image_data(png_page) + for png_page in png_pages + ] + output_img_data_batch = self.inference_model_cache()( + 
input_data, session=self.session_cache() + ) + return output_img_data_batch diff --git a/sciencebeam_gym/convert/cv_conversion_utils_test.py b/sciencebeam_gym/convert/cv_conversion_utils_test.py index ae82821..addbaca 100644 --- a/sciencebeam_gym/convert/cv_conversion_utils_test.py +++ b/sciencebeam_gym/convert/cv_conversion_utils_test.py @@ -1,46 +1,64 @@ -import logging from mock import patch +import pytest + from . import cv_conversion_utils as cv_conversion_utils from .cv_conversion_utils import ( - InferenceModelWrapper + InferenceModelWrapper ) CV_MODEL_EXPORT_DIR = './model-export' PNG_BYTES = b'dummy png bytes' -def setup_module(): - logging.basicConfig(level='DEBUG') -class TestInferenceModelWrapper(object): - def test_should_lazy_load_model(self): +@pytest.fixture(name='load_inference_model_mock') +def _load_inference_model_mock(): with patch.object(cv_conversion_utils, 'load_inference_model') as load_inference_model: - with patch.object(cv_conversion_utils, 'png_bytes_to_image_data') as png_bytes_to_image_data: - with patch.object(cv_conversion_utils, 'tf') as tf: - inference_model_wrapper = InferenceModelWrapper(CV_MODEL_EXPORT_DIR) - load_inference_model.assert_not_called() + yield load_inference_model - output_image_data = inference_model_wrapper([PNG_BYTES]) - tf.InteractiveSession.assert_called_with() - session = tf.InteractiveSession.return_value +@pytest.fixture(name='png_bytes_to_image_data_mock') +def _png_bytes_to_image_data_mock(): + with patch.object(cv_conversion_utils, 'png_bytes_to_image_data') as png_bytes_to_image_data: + yield png_bytes_to_image_data - png_bytes_to_image_data.assert_called_with(PNG_BYTES) - load_inference_model.assert_called_with(CV_MODEL_EXPORT_DIR, session=session) - inference_model = load_inference_model.return_value +@pytest.fixture(name='tf_mock') +def _tf_mock(): + with patch.object(cv_conversion_utils, 'tf') as tf_mock: + yield tf_mock - inference_model.assert_called_with([ - png_bytes_to_image_data.return_value 
- ], session=session) - assert output_image_data == inference_model.return_value +class TestInferenceModelWrapper(object): + def test_should_lazy_load_model( + self, + load_inference_model_mock, + png_bytes_to_image_data_mock, + tf_mock): + inference_model_wrapper = InferenceModelWrapper(CV_MODEL_EXPORT_DIR) + load_inference_model_mock.assert_not_called() - def test_should_load_model_only_once(self): - with patch.object(cv_conversion_utils, 'load_inference_model') as load_inference_model: - with patch.object(cv_conversion_utils, 'png_bytes_to_image_data') as _: - with patch.object(cv_conversion_utils, 'tf') as _: - inference_model_wrapper = InferenceModelWrapper(CV_MODEL_EXPORT_DIR) - inference_model_wrapper([PNG_BYTES]) - inference_model_wrapper([PNG_BYTES]) - load_inference_model.assert_called_once() + output_image_data = inference_model_wrapper([PNG_BYTES]) + + tf_mock.InteractiveSession.assert_called_with() + session = tf_mock.InteractiveSession.return_value + + png_bytes_to_image_data_mock.assert_called_with(PNG_BYTES) + + load_inference_model_mock.assert_called_with(CV_MODEL_EXPORT_DIR, session=session) + inference_model = load_inference_model_mock.return_value + + inference_model.assert_called_with([ + png_bytes_to_image_data_mock.return_value + ], session=session) + + assert output_image_data == inference_model.return_value + + @pytest.mark.usefixtures('png_bytes_to_image_data_mock', 'tf_mock') + def test_should_load_model_only_once( + self, + load_inference_model_mock): + inference_model_wrapper = InferenceModelWrapper(CV_MODEL_EXPORT_DIR) + inference_model_wrapper([PNG_BYTES]) + inference_model_wrapper([PNG_BYTES]) + load_inference_model_mock.assert_called_once() diff --git a/sciencebeam_gym/convert/grobid/grobid_service.py b/sciencebeam_gym/convert/grobid/grobid_service.py index e80fa08..73f1ca6 100644 --- a/sciencebeam_gym/convert/grobid/grobid_service.py +++ b/sciencebeam_gym/convert/grobid/grobid_service.py @@ -5,80 +5,86 @@ from functools import 
partial import requests from .grobid_service_wrapper import ( - GrobidServiceWrapper + GrobidServiceWrapper ) + class GrobidApiPaths(object): - PROCESS_HEADER_DOCUMENT = '/processHeaderDocument' - PROCESS_HEADER_NAMES = '/processHeaderNames' - PROCESS_CITATION_NAMES = '/processCitationNames' - PROCESS_AFFILIATIONS = '/processAffiliations' - PROCESS_CITATION = '/processCitation' - PROCESS_FULL_TEXT_DOCUMENT = '/processFulltextDocument' + PROCESS_HEADER_DOCUMENT = '/processHeaderDocument' + PROCESS_HEADER_NAMES = '/processHeaderNames' + PROCESS_CITATION_NAMES = '/processCitationNames' + PROCESS_AFFILIATIONS = '/processAffiliations' + PROCESS_CITATION = '/processCitation' + PROCESS_FULL_TEXT_DOCUMENT = '/processFulltextDocument' + service_wrapper = GrobidServiceWrapper() + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def start_service_if_not_running(): - service_wrapper.start_service_if_not_running() + service_wrapper.start_service_if_not_running() + def run_grobid_service(item, base_url, path, start_service=True, field_name=None): - """ - Translates PDF content via the GROBID service. 
- - Args: - item: one of: - * tuple (filename, pdf content) - * pdf content - * field content (requires field name) - base_url: base url to the GROBID service - path: path of the GROBID endpoint - start_service: if true, a GROBID service will be started automatically and - kept running until the application ends - field_name: the field name the field content relates to - - Returns: - If item is tuple: - returns tuple (filename, xml result) - Otherwise: - returns xml result - """ - - url = base_url + path - - if start_service: - start_service_if_not_running() - - if field_name: - content = item - response = requests.post(url, - data={field_name: content} - ) - else: - filename = item[0] if isinstance(item, tuple) else 'unknown.pdf' - content = item[1] if isinstance(item, tuple) else item - get_logger().info('processing: %s (%d) - %s', filename, len(content), url) - response = requests.post(url, - files={'input': (filename, BytesIO(content))}, - data={ - 'consolidateHeader': '0', - 'consolidateCitations': '0' - } - ) - response.raise_for_status() - result_content = response.content - if isinstance(item, tuple): - return filename, result_content - else: - return result_content + """ + Translates PDF content via the GROBID service. 
+ + Args: + item: one of: + * tuple (filename, pdf content) + * pdf content + * field content (requires field name) + base_url: base url to the GROBID service + path: path of the GROBID endpoint + start_service: if true, a GROBID service will be started automatically and + kept running until the application ends + field_name: the field name the field content relates to + + Returns: + If item is tuple: + returns tuple (filename, xml result) + Otherwise: + returns xml result + """ + + url = base_url + path + + if start_service: + start_service_if_not_running() + + if field_name: + content = item + response = requests.post(url, + data={field_name: content} + ) + else: + filename = item[0] if isinstance(item, tuple) else 'unknown.pdf' + content = item[1] if isinstance(item, tuple) else item + get_logger().info('processing: %s (%d) - %s', filename, len(content), url) + response = requests.post(url, + files={'input': (filename, BytesIO(content))}, + data={ + 'consolidateHeader': '0', + 'consolidateCitations': '0' + } + ) + response.raise_for_status() + result_content = response.content + if isinstance(item, tuple): + return filename, result_content + else: + return result_content + def grobid_service(base_url, path, start_service=True, field_name=None): - return partial( - run_grobid_service, - base_url=base_url, - path=path, - start_service=start_service, - field_name=field_name - ) + return partial( + run_grobid_service, + base_url=base_url, + path=path, + start_service=start_service, + field_name=field_name + ) diff --git a/sciencebeam_gym/convert/grobid/grobid_service_test.py b/sciencebeam_gym/convert/grobid/grobid_service_test.py index fd02bfe..77cc021 100644 --- a/sciencebeam_gym/convert/grobid/grobid_service_test.py +++ b/sciencebeam_gym/convert/grobid/grobid_service_test.py @@ -14,54 +14,56 @@ PDF_CONTENT_1 = b'pdf content1' FIELD_NAME_1 = 'field1' FIELD_VALUE_1 = 'value1' + @pytest.fixture(name='requests_post', autouse=True) def _mock_requests_post(): - with 
patch('requests.post') as requests_post: - yield requests_post + with patch('requests.post') as requests_post: + yield requests_post + class TestCreateGrobidService(object): - def test_should_pass_url_and_data_as_file(self, requests_post): - create_grobid_service(BASE_URL, PATH_1, start_service=False)( - (FILENAME_1, PDF_CONTENT_1) - ) - requests_post.assert_called_with( - BASE_URL + PATH_1, - files=ANY, - data=ANY - ) - kwargs = requests_post.call_args[1] - assert kwargs['files']['input'][0] == FILENAME_1 - assert kwargs['files']['input'][1].read() == PDF_CONTENT_1 + def test_should_pass_url_and_data_as_file(self, requests_post): + create_grobid_service(BASE_URL, PATH_1, start_service=False)( + (FILENAME_1, PDF_CONTENT_1) + ) + requests_post.assert_called_with( + BASE_URL + PATH_1, + files=ANY, + data=ANY + ) + kwargs = requests_post.call_args[1] + assert kwargs['files']['input'][0] == FILENAME_1 + assert kwargs['files']['input'][1].read() == PDF_CONTENT_1 - def test_should_be_able_to_override_path_on_call(self, requests_post): - create_grobid_service(BASE_URL, PATH_1, start_service=False)( - (FILENAME_1, PDF_CONTENT_1), - path=PATH_2 - ) - requests_post.assert_called_with( - BASE_URL + PATH_2, - files=ANY, - data=ANY - ) + def test_should_be_able_to_override_path_on_call(self, requests_post): + create_grobid_service(BASE_URL, PATH_1, start_service=False)( + (FILENAME_1, PDF_CONTENT_1), + path=PATH_2 + ) + requests_post.assert_called_with( + BASE_URL + PATH_2, + files=ANY, + data=ANY + ) - def test_should_pass_data_as_field(self, requests_post): - create_grobid_service(BASE_URL, PATH_1, start_service=False, field_name=FIELD_NAME_1)( - FIELD_VALUE_1 - ) - requests_post.assert_called_with( - BASE_URL + PATH_1, - data={FIELD_NAME_1: FIELD_VALUE_1} - ) + def test_should_pass_data_as_field(self, requests_post): + create_grobid_service(BASE_URL, PATH_1, start_service=False, field_name=FIELD_NAME_1)( + FIELD_VALUE_1 + ) + requests_post.assert_called_with( + BASE_URL + 
PATH_1, + data={FIELD_NAME_1: FIELD_VALUE_1} + ) - def test_should_pass_consolidate_flags(self, requests_post): - create_grobid_service(BASE_URL, PATH_1, start_service=False)( - (FILENAME_1, PDF_CONTENT_1) - ) - requests_post.assert_called_with( - BASE_URL + PATH_1, - files=ANY, - data={ - 'consolidateHeader': '0', - 'consolidateCitations': '0' - } - ) + def test_should_pass_consolidate_flags(self, requests_post): + create_grobid_service(BASE_URL, PATH_1, start_service=False)( + (FILENAME_1, PDF_CONTENT_1) + ) + requests_post.assert_called_with( + BASE_URL + PATH_1, + files=ANY, + data={ + 'consolidateHeader': '0', + 'consolidateCitations': '0' + } + ) diff --git a/sciencebeam_gym/convert/grobid/grobid_service_wrapper.py b/sciencebeam_gym/convert/grobid/grobid_service_wrapper.py index 3f47f37..d69b531 100644 --- a/sciencebeam_gym/convert/grobid/grobid_service_wrapper.py +++ b/sciencebeam_gym/convert/grobid/grobid_service_wrapper.py @@ -13,114 +13,120 @@ from urllib import URLopener from sciencebeam_utils.utils.io import makedirs from sciencebeam_utils.utils.zip import extract_all_with_executable_permission + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def iter_read_lines(reader): - while True: - line = reader.readline() - if not line: - break - yield line + while True: + line = reader.readline() + if not line: + break + yield line + def stream_lines_to_logger(lines, logger, prefix=''): - for line in lines: - line = line.strip() - if line: - logger.info('%s%s', prefix, line) + for line in lines: + line = line.strip() + if line: + logger.info('%s%s', prefix, line) + class GrobidServiceWrapper(object): - def __init__(self): - self.grobid_service_instance = None - temp_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../.temp')) - self.grobid_service_target_directory = os.path.join(temp_dir, 'grobid-service') - self.grobid_service_zip_filename = os.path.join(temp_dir, 'grobid-service.zip') - 
self.grobid_service_zip_url = ( - 'https://storage.googleapis.com/elife-ml/artefacts/grobid-service.zip' - ) - - def stop_service_if_running(self): - if self.grobid_service_instance is not None: - get_logger().info('stopping instance: %s', self.grobid_service_instance) - self.grobid_service_instance.kill() - - def download__grobid_service_zip_if_not_exist(self): - if not os.path.isfile(self.grobid_service_zip_filename): - get_logger().info( - 'downloading %s to %s', - self.grobid_service_zip_url, - self.grobid_service_zip_filename - ) - - makedirs(os.path.dirname(self.grobid_service_zip_filename), exists_ok=True) - - temp_zip_filename = self.grobid_service_zip_filename + '.part' - if os.path.isfile(temp_zip_filename): - os.remove(temp_zip_filename) - URLopener().retrieve(self.grobid_service_zip_url, temp_zip_filename) - os.rename(temp_zip_filename, self.grobid_service_zip_filename) - - def unzip_grobid_service_zip_if_target_directory_does_not_exist(self): - if not os.path.isdir(self.grobid_service_target_directory): - self.download__grobid_service_zip_if_not_exist() - get_logger().info( - 'unzipping %s to %s', - self.grobid_service_zip_filename, - self.grobid_service_target_directory - ) - temp_target_directory = self.grobid_service_target_directory + '.part' - if os.path.isdir(temp_target_directory): - rmtree(temp_target_directory) - - with ZipFile(self.grobid_service_zip_filename, 'r') as zf: - extract_all_with_executable_permission(zf, temp_target_directory) - sub_dir = os.path.join(temp_target_directory, 'grobid-service') - if os.path.isdir(sub_dir): - os.rename(sub_dir, self.grobid_service_target_directory) - rmtree(temp_target_directory) - else: - os.rename(temp_target_directory, self.grobid_service_target_directory) - - def start_service_if_not_running(self): - get_logger().info('grobid_service_instance: %s', self.grobid_service_instance) - if self.grobid_service_instance is None: - self.unzip_grobid_service_zip_if_target_directory_does_not_exist() - 
grobid_service_home = os.path.abspath(self.grobid_service_target_directory) - cwd = grobid_service_home + '/bin' - grobid_service_home_jar_dir = grobid_service_home + '/lib' - command_line = 'java -cp "{}/*" org.grobid.service.main.GrobidServiceApplication'.format( - grobid_service_home_jar_dir - ) - args = shlex.split(command_line) - get_logger().info('command_line: %s', command_line) - get_logger().info('args: %s', args) - self.grobid_service_instance = subprocess.Popen( - args, cwd=cwd, stdout=PIPE, stderr=subprocess.STDOUT - ) - if self.grobid_service_instance is None: - raise RuntimeError('failed to start grobid service') - atexit.register(self.stop_service_if_running) - pstdout = self.grobid_service_instance.stdout - out_prefix = 'stdout: ' - while True: - line = pstdout.readline().strip() - if line: - get_logger().info('%s%s', out_prefix, line) - if 'jetty.server.Server: Started' in line: - get_logger().info('grobid service started successfully') - break - if 'ERROR' in line or 'Error' in line: - raise RuntimeError('failed to start grobid service due to {}'.format(line)) - t = Thread(target=partial( - stream_lines_to_logger, - lines=iter_read_lines(pstdout), - logger=get_logger(), - prefix=out_prefix - )) - t.daemon = True - t.start() + def __init__(self): + self.grobid_service_instance = None + temp_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../.temp')) + self.grobid_service_target_directory = os.path.join(temp_dir, 'grobid-service') + self.grobid_service_zip_filename = os.path.join(temp_dir, 'grobid-service.zip') + self.grobid_service_zip_url = ( + 'https://storage.googleapis.com/elife-ml/artefacts/grobid-service.zip' + ) + + def stop_service_if_running(self): + if self.grobid_service_instance is not None: + get_logger().info('stopping instance: %s', self.grobid_service_instance) + self.grobid_service_instance.kill() + + def download__grobid_service_zip_if_not_exist(self): + if not os.path.isfile(self.grobid_service_zip_filename): + 
get_logger().info( + 'downloading %s to %s', + self.grobid_service_zip_url, + self.grobid_service_zip_filename + ) + + makedirs(os.path.dirname(self.grobid_service_zip_filename), exists_ok=True) + + temp_zip_filename = self.grobid_service_zip_filename + '.part' + if os.path.isfile(temp_zip_filename): + os.remove(temp_zip_filename) + URLopener().retrieve(self.grobid_service_zip_url, temp_zip_filename) + os.rename(temp_zip_filename, self.grobid_service_zip_filename) + + def unzip_grobid_service_zip_if_target_directory_does_not_exist(self): + if not os.path.isdir(self.grobid_service_target_directory): + self.download__grobid_service_zip_if_not_exist() + get_logger().info( + 'unzipping %s to %s', + self.grobid_service_zip_filename, + self.grobid_service_target_directory + ) + temp_target_directory = self.grobid_service_target_directory + '.part' + if os.path.isdir(temp_target_directory): + rmtree(temp_target_directory) + + with ZipFile(self.grobid_service_zip_filename, 'r') as zf: + extract_all_with_executable_permission(zf, temp_target_directory) + sub_dir = os.path.join(temp_target_directory, 'grobid-service') + if os.path.isdir(sub_dir): + os.rename(sub_dir, self.grobid_service_target_directory) + rmtree(temp_target_directory) + else: + os.rename(temp_target_directory, self.grobid_service_target_directory) + + def start_service_if_not_running(self): + get_logger().info('grobid_service_instance: %s', self.grobid_service_instance) + if self.grobid_service_instance is None: + self.unzip_grobid_service_zip_if_target_directory_does_not_exist() + grobid_service_home = os.path.abspath(self.grobid_service_target_directory) + cwd = grobid_service_home + '/bin' + grobid_service_home_jar_dir = grobid_service_home + '/lib' + main_class = "org.grobid.service.main.GrobidServiceApplication" + command_line = 'java -cp "%s/*" %s' % ( + grobid_service_home_jar_dir, main_class + ) + args = shlex.split(command_line) + get_logger().info('command_line: %s', command_line) + 
get_logger().info('args: %s', args) + self.grobid_service_instance = subprocess.Popen( + args, cwd=cwd, stdout=PIPE, stderr=subprocess.STDOUT + ) + if self.grobid_service_instance is None: + raise RuntimeError('failed to start grobid service') + atexit.register(self.stop_service_if_running) + pstdout = self.grobid_service_instance.stdout + out_prefix = 'stdout: ' + while True: + line = pstdout.readline().strip() + if line: + get_logger().info('%s%s', out_prefix, line) + if 'jetty.server.Server: Started' in line: + get_logger().info('grobid service started successfully') + break + if 'ERROR' in line or 'Error' in line: + raise RuntimeError('failed to start grobid service due to {}'.format(line)) + t = Thread(target=partial( + stream_lines_to_logger, + lines=iter_read_lines(pstdout), + logger=get_logger(), + prefix=out_prefix + )) + t.daemon = True + t.start() + if __name__ == '__main__': - logging.basicConfig(level='INFO') + logging.basicConfig(level='INFO') - GrobidServiceWrapper().start_service_if_not_running() + GrobidServiceWrapper().start_service_if_not_running() diff --git a/sciencebeam_gym/convert/grobid/grobid_xml_enhancer.py b/sciencebeam_gym/convert/grobid/grobid_xml_enhancer.py index 26646cb..74104ab 100644 --- a/sciencebeam_gym/convert/grobid/grobid_xml_enhancer.py +++ b/sciencebeam_gym/convert/grobid/grobid_xml_enhancer.py @@ -5,14 +5,14 @@ from lxml import etree from lxml.builder import E from sciencebeam_gym.inference_model.extract_to_xml import ( - XmlPaths, - create_node_recursive, - rsplit_xml_path + XmlPaths, + create_node_recursive, + rsplit_xml_path ) from .grobid_service import ( - grobid_service, - GrobidApiPaths + grobid_service, + GrobidApiPaths ) TEI_NS = 'http://www.tei-c.org/ns/1.0' @@ -27,83 +27,86 @@ JATS_ADDR_LINE = 'addr-line' JATS_NAMED_CONTENT = 'named-content' JATS_INSTITUTION = 'institution' + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def create_or_append(xml_root, path): - 
parent_path, tag_name = rsplit_xml_path(path) - parent_node = create_node_recursive(xml_root, parent_path, exists_ok=True) - node = E(tag_name) - parent_node.append(node) - return node + parent_path, tag_name = rsplit_xml_path(path) + parent_node = create_node_recursive(xml_root, parent_path, exists_ok=True) + node = E(tag_name) + parent_node.append(node) + return node + class GrobidXmlEnhancer(object): - def __init__(self, grobid_url, start_service): - self.process_header_names = grobid_service( - grobid_url, - GrobidApiPaths.PROCESS_HEADER_NAMES, - start_service=start_service, - field_name='names' - ) - self.process_affiliations = grobid_service( - grobid_url, - GrobidApiPaths.PROCESS_AFFILIATIONS, - start_service=start_service, - field_name='affiliations' - ) + def __init__(self, grobid_url, start_service): + self.process_header_names = grobid_service( + grobid_url, + GrobidApiPaths.PROCESS_HEADER_NAMES, + start_service=start_service, + field_name='names' + ) + self.process_affiliations = grobid_service( + grobid_url, + GrobidApiPaths.PROCESS_AFFILIATIONS, + start_service=start_service, + field_name='affiliations' + ) - def process_and_replace_authors(self, xml_root): - author_nodes = list(xml_root.findall(XmlPaths.AUTHOR)) - if author_nodes: - authors = '\n'.join(x.text for x in author_nodes) - get_logger().debug('authors: %s', authors) - grobid_response = self.process_header_names(authors) - get_logger().debug('grobid_response: %s', grobid_response) - response_xml_root = etree.parse(BytesIO('<dummy>%s</dummy>' % grobid_response)) - for author in author_nodes: - author.getparent().remove(author) - for pers_name in response_xml_root.findall(TEI_PERS_NAME): - get_logger().debug('pers_name: %s', pers_name) - node = create_or_append(xml_root, XmlPaths.AUTHOR) - for surname in pers_name.findall(TEI_SURNAME): - node.append(E(JATS_SURNAME, surname.text)) - forenames = [x.text for x in pers_name.findall(TEI_FORNAME)] - if forenames: - node.append(E(JATS_GIVEN_NAMES, ' 
'.join(forenames))) - return xml_root + def process_and_replace_authors(self, xml_root): + author_nodes = list(xml_root.findall(XmlPaths.AUTHOR)) + if author_nodes: + authors = '\n'.join(x.text for x in author_nodes) + get_logger().debug('authors: %s', authors) + grobid_response = self.process_header_names(authors) + get_logger().debug('grobid_response: %s', grobid_response) + response_xml_root = etree.parse(BytesIO('<dummy>%s</dummy>' % grobid_response)) + for author in author_nodes: + author.getparent().remove(author) + for pers_name in response_xml_root.findall(TEI_PERS_NAME): + get_logger().debug('pers_name: %s', pers_name) + node = create_or_append(xml_root, XmlPaths.AUTHOR) + for surname in pers_name.findall(TEI_SURNAME): + node.append(E(JATS_SURNAME, surname.text)) + forenames = [x.text for x in pers_name.findall(TEI_FORNAME)] + if forenames: + node.append(E(JATS_GIVEN_NAMES, ' '.join(forenames))) + return xml_root - def process_and_replace_affiliations(self, xml_root): - aff_nodes = list(xml_root.findall(XmlPaths.AUTHOR_AFF)) - if aff_nodes: - affiliations = '\n'.join(x.text for x in aff_nodes) - get_logger().debug('affiliations: %s', affiliations) - grobid_response = self.process_affiliations(affiliations) - get_logger().debug('grobid_response: %s', grobid_response) - response_xml_root = etree.parse(BytesIO('<dummy>%s</dummy>' % grobid_response)) - for aff in aff_nodes: - aff.getparent().remove(aff) - for affiliation in response_xml_root.findall('affiliation'): - get_logger().debug('affiliation: %s', affiliation) - node = create_or_append(xml_root, XmlPaths.AUTHOR_AFF) - for department in affiliation.xpath('./orgName[@type="department"]'): - node.append(E( - JATS_ADDR_LINE, - E( - JATS_NAMED_CONTENT, - department.text, - { - 'content-type': 'department' - } - ) - )) - for institution in affiliation.xpath('./orgName[@type="institution"]'): - node.append(E( - JATS_INSTITUTION, - institution.text - )) + def process_and_replace_affiliations(self, xml_root): + 
aff_nodes = list(xml_root.findall(XmlPaths.AUTHOR_AFF)) + if aff_nodes: + affiliations = '\n'.join(x.text for x in aff_nodes) + get_logger().debug('affiliations: %s', affiliations) + grobid_response = self.process_affiliations(affiliations) + get_logger().debug('grobid_response: %s', grobid_response) + response_xml_root = etree.parse(BytesIO('<dummy>%s</dummy>' % grobid_response)) + for aff in aff_nodes: + aff.getparent().remove(aff) + for affiliation in response_xml_root.findall('affiliation'): + get_logger().debug('affiliation: %s', affiliation) + node = create_or_append(xml_root, XmlPaths.AUTHOR_AFF) + for department in affiliation.xpath('./orgName[@type="department"]'): + node.append(E( + JATS_ADDR_LINE, + E( + JATS_NAMED_CONTENT, + department.text, + { + 'content-type': 'department' + } + ) + )) + for institution in affiliation.xpath('./orgName[@type="institution"]'): + node.append(E( + JATS_INSTITUTION, + institution.text + )) - def __call__(self, extracted_xml): - xml_root = etree.parse(BytesIO(extracted_xml)) - self.process_and_replace_authors(xml_root) - self.process_and_replace_affiliations(xml_root) - return etree.tostring(xml_root, pretty_print=True) + def __call__(self, extracted_xml): + xml_root = etree.parse(BytesIO(extracted_xml)) + self.process_and_replace_authors(xml_root) + self.process_and_replace_affiliations(xml_root) + return etree.tostring(xml_root, pretty_print=True) diff --git a/sciencebeam_gym/convert/grobid/grobid_xml_enhancer_test.py b/sciencebeam_gym/convert/grobid/grobid_xml_enhancer_test.py index e2b7052..7726fb9 100644 --- a/sciencebeam_gym/convert/grobid/grobid_xml_enhancer_test.py +++ b/sciencebeam_gym/convert/grobid/grobid_xml_enhancer_test.py @@ -7,17 +7,17 @@ from lxml import etree from lxml.builder import E from sciencebeam_gym.inference_model.extract_to_xml import ( - XmlPaths, - create_xml_text + XmlPaths, + create_xml_text ) from .grobid_service import ( - GrobidApiPaths + GrobidApiPaths ) from . 
import grobid_xml_enhancer as grobid_xml_enhancer from .grobid_xml_enhancer import ( - GrobidXmlEnhancer + GrobidXmlEnhancer ) GROBID_URL = 'http://localhost:8080/api' @@ -38,170 +38,179 @@ INSTITUTION_2 = 'institution 2' DEPARTMENT_XPATH = './addr-line/named-content[@content-type="department"]' INSTITUTION_XPATH = './institution' + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def setup_module(): - logging.basicConfig(level='DEBUG') + logging.basicConfig(level='DEBUG') + def pers_name(*names): - forenames = names[:-1] - surname = names[-1] - return ( - '<persName xmlns="http://www.tei-c.org/ns/1.0">' - ' %s' - ' <surname>%s</surname>' - '</persName>' - ) % ( - ' '.join( - '<forename type="%s">%s</forename>' % ( - 'first' if i == 0 else 'middle', - forename - ) - for i, forename in enumerate(forenames) - ), - surname - ) + forenames = names[:-1] + surname = names[-1] + return ( + '<persName xmlns="http://www.tei-c.org/ns/1.0">' + ' %s' + ' <surname>%s</surname>' + '</persName>' + ) % ( + ' '.join( + '<forename type="%s">%s</forename>' % ( + 'first' if i == 0 else 'middle', + forename + ) + for i, forename in enumerate(forenames) + ), + surname + ) + def tei_affiliation(department=None, institution=None): - affiliation = E.affiliation() - if department: - affiliation.append(E.orgName(department, type='department')) - if institution: - affiliation.append(E.orgName(institution, type='institution')) - return etree.tostring(affiliation) + affiliation = E.affiliation() + if department: + affiliation.append(E.orgName(department, type='department')) + if institution: + affiliation.append(E.orgName(institution, type='institution')) + return etree.tostring(affiliation) + def get_text(node): - if isinstance(node, list): - return ''.join(get_text(x) for x in node) - return node.text if node is not None else None + if isinstance(node, list): + return ''.join(get_text(x) for x in node) + return node.text if node is not None else 
None + def get_child_text(node, name): - return get_text(node.find(name)) + return get_text(node.find(name)) + @contextmanager def patch_grobid_service(): - with patch.object(grobid_xml_enhancer, 'grobid_service') as grobid_service: - process_header_names = Mock() - process_affiliations = Mock() - grobid_service.side_effect = [process_header_names, process_affiliations] - yield process_header_names, process_affiliations + with patch.object(grobid_xml_enhancer, 'grobid_service') as grobid_service: + process_header_names = Mock() + process_affiliations = Mock() + grobid_service.side_effect = [process_header_names, process_affiliations] + yield process_header_names, process_affiliations + class TestGrobidXmlEnhancer(object): - def test_should_initialise_grobid_service(self): - with patch.object(grobid_xml_enhancer, 'grobid_service') as grobid_service: - GrobidXmlEnhancer(GROBID_URL, start_service=False) - grobid_service.assert_any_call( - GROBID_URL, GrobidApiPaths.PROCESS_HEADER_NAMES, start_service=False, field_name='names' - ) - grobid_service.assert_any_call( - GROBID_URL, GrobidApiPaths.PROCESS_AFFILIATIONS, start_service=False, - field_name='affiliations' - ) - - def test_should_convert_single_author(self): - logging.basicConfig(level='DEBUG') - with patch_grobid_service() as (process_header_names, process_affiliations): - _ = process_affiliations - process_header_names.return_value = pers_name(FORENAME_1, SURNAME_1) - enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False) - xml_root = E.article() - create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_1) - enhanced_xml = enhancer(etree.tostring(xml_root)) - get_logger().info('enhanced_xml: %s', enhanced_xml) - enhanced_xml_root = etree.parse(BytesIO(enhanced_xml)) - authors = enhanced_xml_root.findall(XmlPaths.AUTHOR) - assert [ - (get_child_text(author, 'given-names'), get_child_text(author, 'surname')) - for author in authors - ] == [(FORENAME_1, SURNAME_1)] - - def 
test_should_convert_single_affiliation(self): - logging.basicConfig(level='DEBUG') - with patch_grobid_service() as (process_header_names, process_affiliations): - _ = process_header_names - process_affiliations.return_value = tei_affiliation( - department=DEPARTMENT_1, - institution=INSTITUTION_1 - ) - enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False) - xml_root = E.article() - create_xml_text(xml_root, XmlPaths.AUTHOR_AFF, TEXT_1) - enhanced_xml = enhancer(etree.tostring(xml_root)) - get_logger().info('enhanced_xml: %s', enhanced_xml) - enhanced_xml_root = etree.parse(BytesIO(enhanced_xml)) - affiliations = enhanced_xml_root.findall(XmlPaths.AUTHOR_AFF) - assert [ - get_text(x.xpath(DEPARTMENT_XPATH)) for x in affiliations - ] == [DEPARTMENT_1] - assert [ - get_text(x.xpath(INSTITUTION_XPATH)) for x in affiliations - ] == [INSTITUTION_1] - - def test_should_convert_multiple_author(self): - logging.basicConfig(level='DEBUG') - with patch_grobid_service() as (process_header_names, process_affiliations): - _ = process_affiliations - process_header_names.return_value = ( - pers_name(FORENAME_1, SURNAME_1) + - pers_name(FORENAME_2, SURNAME_2) - ) - enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False) - xml_root = E.article() - create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_1) - create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_2) - enhanced_xml = enhancer(etree.tostring(xml_root)) - get_logger().info('enhanced_xml: %s', enhanced_xml) - enhanced_xml_root = etree.parse(BytesIO(enhanced_xml)) - authors = enhanced_xml_root.findall(XmlPaths.AUTHOR) - assert [ - (get_child_text(author, 'given-names'), get_child_text(author, 'surname')) - for author in authors - ] == [(FORENAME_1, SURNAME_1), (FORENAME_2, SURNAME_2)] - - def test_should_convert_multiple_affiliations(self): - logging.basicConfig(level='DEBUG') - with patch_grobid_service() as (process_header_names, process_affiliations): - _ = process_header_names - process_affiliations.return_value = ( - 
tei_affiliation( - department=DEPARTMENT_1, - institution=INSTITUTION_1 - ) + - tei_affiliation( - department=DEPARTMENT_2, - institution=INSTITUTION_2 - ) - ) - enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False) - xml_root = E.article() - create_xml_text(xml_root, XmlPaths.AUTHOR_AFF, TEXT_1) - enhanced_xml = enhancer(etree.tostring(xml_root)) - get_logger().info('enhanced_xml: %s', enhanced_xml) - enhanced_xml_root = etree.parse(BytesIO(enhanced_xml)) - affiliations = enhanced_xml_root.findall(XmlPaths.AUTHOR_AFF) - assert [ - get_text(x.xpath(DEPARTMENT_XPATH)) for x in affiliations - ] == [DEPARTMENT_1, DEPARTMENT_2] - assert [ - get_text(x.xpath(INSTITUTION_XPATH)) for x in affiliations - ] == [INSTITUTION_1, INSTITUTION_2] - - def test_should_combine_multiple_forenames(self): - logging.basicConfig(level='DEBUG') - with patch_grobid_service() as (process_header_names, process_affiliations): - _ = process_affiliations - process_header_names.return_value = pers_name( - FORENAME_1, FORENAME_2, SURNAME_1 - ) - enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False) - xml_root = E.article() - create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_1) - enhanced_xml = enhancer(etree.tostring(xml_root)) - get_logger().info('enhanced_xml: %s', enhanced_xml) - enhanced_xml_root = etree.parse(BytesIO(enhanced_xml)) - authors = enhanced_xml_root.findall(XmlPaths.AUTHOR) - assert [ - (get_child_text(author, 'given-names'), get_child_text(author, 'surname')) - for author in authors - ] == [(' '.join([FORENAME_1, FORENAME_2]), SURNAME_1)] + def test_should_initialise_grobid_service(self): + with patch.object(grobid_xml_enhancer, 'grobid_service') as grobid_service: + GrobidXmlEnhancer(GROBID_URL, start_service=False) + grobid_service.assert_any_call( + GROBID_URL, GrobidApiPaths.PROCESS_HEADER_NAMES, + start_service=False, field_name='names' + ) + grobid_service.assert_any_call( + GROBID_URL, GrobidApiPaths.PROCESS_AFFILIATIONS, start_service=False, + 
field_name='affiliations' + ) + + def test_should_convert_single_author(self): + logging.basicConfig(level='DEBUG') + with patch_grobid_service() as (process_header_names, process_affiliations): + _ = process_affiliations # flake8: noqa + process_header_names.return_value = pers_name(FORENAME_1, SURNAME_1) + enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False) + xml_root = E.article() + create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_1) + enhanced_xml = enhancer(etree.tostring(xml_root)) + get_logger().info('enhanced_xml: %s', enhanced_xml) + enhanced_xml_root = etree.parse(BytesIO(enhanced_xml)) + authors = enhanced_xml_root.findall(XmlPaths.AUTHOR) + assert [ + (get_child_text(author, 'given-names'), get_child_text(author, 'surname')) + for author in authors + ] == [(FORENAME_1, SURNAME_1)] + + def test_should_convert_single_affiliation(self): + logging.basicConfig(level='DEBUG') + with patch_grobid_service() as (process_header_names, process_affiliations): + _ = process_header_names # flake8: noqa + process_affiliations.return_value = tei_affiliation( + department=DEPARTMENT_1, + institution=INSTITUTION_1 + ) + enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False) + xml_root = E.article() + create_xml_text(xml_root, XmlPaths.AUTHOR_AFF, TEXT_1) + enhanced_xml = enhancer(etree.tostring(xml_root)) + get_logger().info('enhanced_xml: %s', enhanced_xml) + enhanced_xml_root = etree.parse(BytesIO(enhanced_xml)) + affiliations = enhanced_xml_root.findall(XmlPaths.AUTHOR_AFF) + assert [ + get_text(x.xpath(DEPARTMENT_XPATH)) for x in affiliations + ] == [DEPARTMENT_1] + assert [ + get_text(x.xpath(INSTITUTION_XPATH)) for x in affiliations + ] == [INSTITUTION_1] + + def test_should_convert_multiple_author(self): + logging.basicConfig(level='DEBUG') + with patch_grobid_service() as (process_header_names, process_affiliations): + _ = process_affiliations # flake8: noqa + process_header_names.return_value = ( + pers_name(FORENAME_1, SURNAME_1) + + 
pers_name(FORENAME_2, SURNAME_2) + ) + enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False) + xml_root = E.article() + create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_1) + create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_2) + enhanced_xml = enhancer(etree.tostring(xml_root)) + get_logger().info('enhanced_xml: %s', enhanced_xml) + enhanced_xml_root = etree.parse(BytesIO(enhanced_xml)) + authors = enhanced_xml_root.findall(XmlPaths.AUTHOR) + assert [ + (get_child_text(author, 'given-names'), get_child_text(author, 'surname')) + for author in authors + ] == [(FORENAME_1, SURNAME_1), (FORENAME_2, SURNAME_2)] + + def test_should_convert_multiple_affiliations(self): + logging.basicConfig(level='DEBUG') + with patch_grobid_service() as (process_header_names, process_affiliations): + _ = process_header_names # flake8: noqa + process_affiliations.return_value = ( + tei_affiliation( + department=DEPARTMENT_1, + institution=INSTITUTION_1 + ) + + tei_affiliation( + department=DEPARTMENT_2, + institution=INSTITUTION_2 + ) + ) + enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False) + xml_root = E.article() + create_xml_text(xml_root, XmlPaths.AUTHOR_AFF, TEXT_1) + enhanced_xml = enhancer(etree.tostring(xml_root)) + get_logger().info('enhanced_xml: %s', enhanced_xml) + enhanced_xml_root = etree.parse(BytesIO(enhanced_xml)) + affiliations = enhanced_xml_root.findall(XmlPaths.AUTHOR_AFF) + assert [ + get_text(x.xpath(DEPARTMENT_XPATH)) for x in affiliations + ] == [DEPARTMENT_1, DEPARTMENT_2] + assert [ + get_text(x.xpath(INSTITUTION_XPATH)) for x in affiliations + ] == [INSTITUTION_1, INSTITUTION_2] + + def test_should_combine_multiple_forenames(self): + logging.basicConfig(level='DEBUG') + with patch_grobid_service() as (process_header_names, process_affiliations): + _ = process_affiliations # flake8: noqa + process_header_names.return_value = pers_name( + FORENAME_1, FORENAME_2, SURNAME_1 + ) + enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False) + 
xml_root = E.article() + create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_1) + enhanced_xml = enhancer(etree.tostring(xml_root)) + get_logger().info('enhanced_xml: %s', enhanced_xml) + enhanced_xml_root = etree.parse(BytesIO(enhanced_xml)) + authors = enhanced_xml_root.findall(XmlPaths.AUTHOR) + assert [ + (get_child_text(author, 'given-names'), get_child_text(author, 'surname')) + for author in authors + ] == [(' '.join([FORENAME_1, FORENAME_2]), SURNAME_1)] diff --git a/sciencebeam_gym/inference_model/__init__.py b/sciencebeam_gym/inference_model/__init__.py index c26ac14..a8a17ea 100644 --- a/sciencebeam_gym/inference_model/__init__.py +++ b/sciencebeam_gym/inference_model/__init__.py @@ -15,93 +15,98 @@ OUTPUTS_KEY = 'annotation' LABELS_KEY = 'labels' COLORS_KEY = 'colors' + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + class InferenceModel(object): - def __init__(self, inputs, outputs, labels_tensor=None, colors_tensor=None): - self.inputs_tensor = inputs - self.outputs_tensor = outputs - self.labels_tensor = labels_tensor - self.colors_tensor = colors_tensor - self._color_map = None + def __init__(self, inputs, outputs, labels_tensor=None, colors_tensor=None): + self.inputs_tensor = inputs + self.outputs_tensor = outputs + self.labels_tensor = labels_tensor + self.colors_tensor = colors_tensor + self._color_map = None - def get_color_map(self, session=None): - if self._color_map is None: - assert self.labels_tensor is not None - assert self.colors_tensor is not None - session = session or tf.get_default_session() - assert session is not None - labels, colors = session.run( - [self.labels_tensor, self.colors_tensor] - ) - self._color_map = { - k: tuple(v) - for k, v in zip(labels, colors) - } - return self._color_map + def get_color_map(self, session=None): + if self._color_map is None: + assert self.labels_tensor is not None + assert self.colors_tensor is not None + session = session or tf.get_default_session() + 
assert session is not None + labels, colors = session.run( + [self.labels_tensor, self.colors_tensor] + ) + self._color_map = { + k: tuple(v) + for k, v in zip(labels, colors) + } + return self._color_map + + def __call__(self, inputs, session=None): + session = session or tf.get_default_session() + assert session is not None + return session.run(self.outputs_tensor, feed_dict={ + self.inputs_tensor: inputs + }) - def __call__(self, inputs, session=None): - session = session or tf.get_default_session() - assert session is not None - return session.run(self.outputs_tensor, feed_dict={ - self.inputs_tensor: inputs - }) def save_inference_model(export_dir, inference_model, session=None, replace=True): - if session is None: - session = tf.get_default_session() - assert session is not None - if replace and is_directory(export_dir): - get_logger().info('replacing %s', export_dir) - delete_recursively(export_dir) - prediction_signature = predict_signature_def( - inputs={INPUTS_KEY: inference_model.inputs_tensor}, - outputs={k: v for k, v in { - OUTPUTS_KEY: inference_model.outputs_tensor, - LABELS_KEY: inference_model.labels_tensor, - COLORS_KEY: inference_model.colors_tensor - }.items() if v is not None + if session is None: + session = tf.get_default_session() + assert session is not None + if replace and is_directory(export_dir): + get_logger().info('replacing %s', export_dir) + delete_recursively(export_dir) + prediction_signature = predict_signature_def( + inputs={INPUTS_KEY: inference_model.inputs_tensor}, + outputs={k: v for k, v in { + OUTPUTS_KEY: inference_model.outputs_tensor, + LABELS_KEY: inference_model.labels_tensor, + COLORS_KEY: inference_model.colors_tensor + }.items() if v is not None + } + ) + signature_def_map = { + DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature } - ) - signature_def_map = { - DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature - } - legacy_init_op = tf.group( - tf.tables_initializer(), - name='legacy_init_op' - ) - 
builder = SavedModelBuilder(export_dir) - builder.add_meta_graph_and_variables( - session, - [SERVING], - signature_def_map=signature_def_map, - legacy_init_op=legacy_init_op - ) - builder.save() + legacy_init_op = tf.group( + tf.tables_initializer(), + name='legacy_init_op' + ) + builder = SavedModelBuilder(export_dir) + builder.add_meta_graph_and_variables( + session, + [SERVING], + signature_def_map=signature_def_map, + legacy_init_op=legacy_init_op + ) + builder.save() + def get_output_tensor_or_none(graph, signature, name): - tensor_name = signature.outputs[name].name - return graph.get_tensor_by_name(tensor_name) if tensor_name else None + tensor_name = signature.outputs[name].name + return graph.get_tensor_by_name(tensor_name) if tensor_name else None + def load_inference_model(export_dir, session=None): - if session is None: - session = tf.get_default_session() - assert session is not None - meta_graph_def = tf.saved_model.loader.load(session, [SERVING], export_dir) - signature = meta_graph_def.signature_def[DEFAULT_SERVING_SIGNATURE_DEF_KEY] - inputs_name = signature.inputs[INPUTS_KEY].name - outputs_name = signature.outputs[OUTPUTS_KEY].name - get_logger().info('inputs_name: %s', inputs_name) - get_logger().info('outputs_name: %s', outputs_name) - graph = tf.get_default_graph() - inputs = graph.get_tensor_by_name(inputs_name) - outputs = graph.get_tensor_by_name(outputs_name) - get_logger().info('inputs: %s', inputs) - get_logger().info('output: %s', outputs) - return InferenceModel( - inputs, - outputs, - get_output_tensor_or_none(graph, signature, LABELS_KEY), - get_output_tensor_or_none(graph, signature, COLORS_KEY) - ) + if session is None: + session = tf.get_default_session() + assert session is not None + meta_graph_def = tf.saved_model.loader.load(session, [SERVING], export_dir) + signature = meta_graph_def.signature_def[DEFAULT_SERVING_SIGNATURE_DEF_KEY] + inputs_name = signature.inputs[INPUTS_KEY].name + outputs_name = 
signature.outputs[OUTPUTS_KEY].name + get_logger().info('inputs_name: %s', inputs_name) + get_logger().info('outputs_name: %s', outputs_name) + graph = tf.get_default_graph() + inputs = graph.get_tensor_by_name(inputs_name) + outputs = graph.get_tensor_by_name(outputs_name) + get_logger().info('inputs: %s', inputs) + get_logger().info('output: %s', outputs) + return InferenceModel( + inputs, + outputs, + get_output_tensor_or_none(graph, signature, LABELS_KEY), + get_output_tensor_or_none(graph, signature, COLORS_KEY) + ) diff --git a/sciencebeam_gym/inference_model/__init___test.py b/sciencebeam_gym/inference_model/__init___test.py index 9eb117f..01b60a9 100644 --- a/sciencebeam_gym/inference_model/__init___test.py +++ b/sciencebeam_gym/inference_model/__init___test.py @@ -6,13 +6,13 @@ import tensorflow as tf import numpy as np from sciencebeam_utils.utils.num import ( - assert_all_close + assert_all_close ) from sciencebeam_gym.inference_model import ( - InferenceModel, - save_inference_model, - load_inference_model + InferenceModel, + save_inference_model, + load_inference_model ) TEMP_DIR = '.temp/tests/%s' % __name__ @@ -20,59 +20,63 @@ TEMP_DIR = '.temp/tests/%s' % __name__ LABELS = ['label 1', 'label 2', 'label 3'] COLORS = [(1, 1, 1), (2, 2, 2), (3, 3, 3)] + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def setup_module(): - logging.basicConfig(level='DEBUG') + logging.basicConfig(level='DEBUG') + class TestInferenceModelSaverLoader(object): - def test_should_export_import_and_run_simple_model(self): - export_dir = os.path.join(TEMP_DIR, 'export') - if os.path.isdir(export_dir): - rmtree(export_dir) + def test_should_export_import_and_run_simple_model(self): + export_dir = os.path.join(TEMP_DIR, 'export') + if os.path.isdir(export_dir): + rmtree(export_dir) - # sample fn that works with tf and np - sample_fn = lambda x: x * 2.0 + 10.0 + # sample fn that works with tf and np + def sample_fn(x): + return x * 2.0 
+ 10.0 - with tf.Graph().as_default(): - with tf.variable_scope('scope1'): - inputs = tf.placeholder(tf.float32, (None, 16, 16, 3)) - outputs = sample_fn(inputs) - get_logger().info('outputs: %s', outputs) - with tf.Session() as session: - save_inference_model( - export_dir, - InferenceModel(inputs, outputs) - ) - with tf.Graph().as_default(): - with tf.Session() as session: - inference_model = load_inference_model(export_dir) - inputs_value = np.ones((5, 16, 16, 3)) - assert_all_close( - inference_model(inputs_value, session=session), - sample_fn(inputs_value) - ) + with tf.Graph().as_default(): + with tf.variable_scope('scope1'): + inputs = tf.placeholder(tf.float32, (None, 16, 16, 3)) + outputs = sample_fn(inputs) + get_logger().info('outputs: %s', outputs) + with tf.Session() as session: + save_inference_model( + export_dir, + InferenceModel(inputs, outputs) + ) + with tf.Graph().as_default(): + with tf.Session() as session: + inference_model = load_inference_model(export_dir) + inputs_value = np.ones((5, 16, 16, 3)) + assert_all_close( + inference_model(inputs_value, session=session), + sample_fn(inputs_value) + ) - def test_should_export_import_color_map(self): - export_dir = os.path.join(TEMP_DIR, 'export') - if os.path.isdir(export_dir): - rmtree(export_dir) + def test_should_export_import_color_map(self): + export_dir = os.path.join(TEMP_DIR, 'export') + if os.path.isdir(export_dir): + rmtree(export_dir) - with tf.Graph().as_default(): - with tf.variable_scope('scope1'): - inputs = tf.placeholder(tf.float32, (None, 16, 16, 3)) - outputs = inputs * 2.0 - labels = tf.constant(LABELS) - colors = tf.constant(COLORS) - with tf.Session(): - save_inference_model( - export_dir, - InferenceModel(inputs, outputs, labels, colors) - ) - with tf.Graph().as_default(): - with tf.Session(): - inference_model = load_inference_model(export_dir) - color_map = inference_model.get_color_map() - get_logger().debug('color_map: %s', color_map) - assert set(color_map.items()) == 
set(zip(LABELS, COLORS)) + with tf.Graph().as_default(): + with tf.variable_scope('scope1'): + inputs = tf.placeholder(tf.float32, (None, 16, 16, 3)) + outputs = inputs * 2.0 + labels = tf.constant(LABELS) + colors = tf.constant(COLORS) + with tf.Session(): + save_inference_model( + export_dir, + InferenceModel(inputs, outputs, labels, colors) + ) + with tf.Graph().as_default(): + with tf.Session(): + inference_model = load_inference_model(export_dir) + color_map = inference_model.get_color_map() + get_logger().debug('color_map: %s', color_map) + assert set(color_map.items()) == set(zip(LABELS, COLORS)) diff --git a/sciencebeam_gym/inference_model/annotate_using_predictions.py b/sciencebeam_gym/inference_model/annotate_using_predictions.py index d3af139..6cef2a6 100644 --- a/sciencebeam_gym/inference_model/annotate_using_predictions.py +++ b/sciencebeam_gym/inference_model/annotate_using_predictions.py @@ -8,183 +8,198 @@ import numpy as np from PIL import Image from sciencebeam_utils.beam_utils.io import ( - read_all_from_path + read_all_from_path ) from sciencebeam_gym.utils.bounding_box import ( - BoundingBox + BoundingBox ) from sciencebeam_gym.preprocess.color_map import ( - parse_color_map_from_file + parse_color_map_from_file ) from sciencebeam_gym.structured_document.structured_document_loader import ( - load_structured_document + load_structured_document ) from sciencebeam_gym.structured_document.structured_document_saver import ( - save_structured_document + save_structured_document ) CV_TAG_SCOPE = 'cv' + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + class AnnotatedImage(object): - def __init__(self, data, color_map): - self.data = data - self.color_map = color_map - self.size = (data.shape[1], data.shape[0]) - - def get_tag_probabilities_within(self, bounding_box): - image_area = self.data[ - int(bounding_box.y):int(bounding_box.y + bounding_box.height), - int(bounding_box.x):int(bounding_box.x + 
bounding_box.width) - ] - counts = { - k: np.sum(np.all(image_area == v, axis=-1)) - for k, v in self.color_map.items() - } - total = image_area.size / image_area.shape[-1] - return { - k: v / total if total > 0.0 else 0.0 - for k, v in counts.items() - } + def __init__(self, data, color_map): + self.data = data + self.color_map = color_map + self.size = (data.shape[1], data.shape[0]) + + def get_tag_probabilities_within(self, bounding_box): + image_area = self.data[ + int(bounding_box.y):int(bounding_box.y + bounding_box.height), + int(bounding_box.x):int(bounding_box.x + bounding_box.width) + ] + counts = { + k: np.sum(np.all(image_area == v, axis=-1)) + for k, v in self.color_map.items() + } + total = image_area.size / image_area.shape[-1] + return { + k: v / total if total > 0.0 else 0.0 + for k, v in counts.items() + } + def calculate_rescale_factors(structured_document, page, annotated_image): - page_bounding_box = structured_document.get_bounding_box(page) - get_logger().debug('page_bounding_box: %s', page_bounding_box) - assert page_bounding_box is not None - page_width = page_bounding_box.width - page_height = page_bounding_box.height - annotated_image_width, annotated_image_height = annotated_image.size - get_logger().debug( - 'annotated_image width, height: %f, %f', - annotated_image_width, annotated_image_height - ) - rx = annotated_image_width / page_width - ry = annotated_image_height / page_height - return rx, ry + page_bounding_box = structured_document.get_bounding_box(page) + get_logger().debug('page_bounding_box: %s', page_bounding_box) + assert page_bounding_box is not None + page_width = page_bounding_box.width + page_height = page_bounding_box.height + annotated_image_width, annotated_image_height = annotated_image.size + get_logger().debug( + 'annotated_image width, height: %f, %f', + annotated_image_width, annotated_image_height + ) + rx = annotated_image_width / page_width + ry = annotated_image_height / page_height + return rx, ry + def 
scale_bounding_box(bounding_box, rx, ry): - return BoundingBox( - bounding_box.x * rx, - bounding_box.y * ry, - bounding_box.width * rx, - bounding_box.height * ry - ) + return BoundingBox( + bounding_box.x * rx, + bounding_box.y * ry, + bounding_box.width * rx, + bounding_box.height * ry + ) + def annotate_page_using_predicted_image( - structured_document, page, annotated_image, tag_scope=CV_TAG_SCOPE): - - rx, ry = calculate_rescale_factors(structured_document, page, annotated_image) - get_logger().debug('rx, ry: %f, %f', rx, ry) - for line in structured_document.get_lines_of_page(page): - for token in structured_document.get_tokens_of_line(line): - bounding_box = structured_document.get_bounding_box(token) - if bounding_box: - get_logger().debug('original bounding_box: %s', bounding_box) - bounding_box = scale_bounding_box(bounding_box, rx, ry) - get_logger().debug('scaled bounding_box: %s', bounding_box) - tag_probabilites = sorted( - ((k, v) for k, v in annotated_image.get_tag_probabilities_within(bounding_box).items()), - key=lambda x: x[1], - reverse=True - ) - top_probability_tag, top_probability_value = tag_probabilites[0] - get_logger().debug('tag_probabilites: %s', tag_probabilites) - if top_probability_value > 0.5: - get_logger().debug( - 'tagging token: %s: %s', - top_probability_tag, - structured_document.get_text(token) - ) - structured_document.set_tag(token, top_probability_tag, scope=tag_scope) + structured_document, page, annotated_image, tag_scope=CV_TAG_SCOPE): + + rx, ry = calculate_rescale_factors(structured_document, page, annotated_image) + get_logger().debug('rx, ry: %f, %f', rx, ry) + for line in structured_document.get_lines_of_page(page): + for token in structured_document.get_tokens_of_line(line): + bounding_box = structured_document.get_bounding_box(token) + if bounding_box: + get_logger().debug('original bounding_box: %s', bounding_box) + bounding_box = scale_bounding_box(bounding_box, rx, ry) + get_logger().debug('scaled 
bounding_box: %s', bounding_box) + tag_probabilites = sorted( + ( + (k, v) + for k, v in annotated_image.get_tag_probabilities_within( + bounding_box + ).items() + ), + key=lambda x: x[1], + reverse=True + ) + top_probability_tag, top_probability_value = tag_probabilites[0] + get_logger().debug('tag_probabilites: %s', tag_probabilites) + if top_probability_value > 0.5: + get_logger().debug( + 'tagging token: %s: %s', + top_probability_tag, + structured_document.get_text(token) + ) + structured_document.set_tag(token, top_probability_tag, scope=tag_scope) + def annotate_structured_document_using_predicted_images( - structured_document, annotated_images, tag_scope=CV_TAG_SCOPE): + structured_document, annotated_images, tag_scope=CV_TAG_SCOPE): + + for page, annotated_image in zip(structured_document.get_pages(), annotated_images): + annotate_page_using_predicted_image( + structured_document, page, annotated_image, tag_scope=tag_scope + ) + return structured_document - for page, annotated_image in zip(structured_document.get_pages(), annotated_images): - annotate_page_using_predicted_image( - structured_document, page, annotated_image, tag_scope=tag_scope - ) - return structured_document def parse_args(argv=None): - parser = argparse.ArgumentParser('Annotated LXML using prediction images') - source = parser.add_mutually_exclusive_group(required=True) - source.add_argument( - '--lxml-path', type=str, required=False, - help='path to lxml or svg pages document' - ) - - images = parser.add_mutually_exclusive_group(required=True) - images.add_argument( - '--images-path', type=str, nargs='+', - help='path to lxml document' - ) - - parser.add_argument( - '--output-path', type=str, required=True, - help='output path to annotated document' - ) - - parser.add_argument( - '--tag-scope', type=str, required=False, - default=CV_TAG_SCOPE, - help='target tag scope for the predicted tags' - ) - - parser.add_argument( - '--color-map', default='color_map.conf', - help='color map to 
use' - ) - - parser.add_argument( - '--debug', action='store_true', default=False, - help='enable debug output' - ) - - return parser.parse_args(argv) + parser = argparse.ArgumentParser('Annotated LXML using prediction images') + source = parser.add_mutually_exclusive_group(required=True) + source.add_argument( + '--lxml-path', type=str, required=False, + help='path to lxml or svg pages document' + ) + + images = parser.add_mutually_exclusive_group(required=True) + images.add_argument( + '--images-path', type=str, nargs='+', + help='path to lxml document' + ) + + parser.add_argument( + '--output-path', type=str, required=True, + help='output path to annotated document' + ) + + parser.add_argument( + '--tag-scope', type=str, required=False, + default=CV_TAG_SCOPE, + help='target tag scope for the predicted tags' + ) + + parser.add_argument( + '--color-map', default='color_map.conf', + help='color map to use' + ) + + parser.add_argument( + '--debug', action='store_true', default=False, + help='enable debug output' + ) + + return parser.parse_args(argv) + def load_annotation_image(path, color_map): - get_logger().debug('loading annotation image: %s', path) - return AnnotatedImage( - np.asarray( - Image.open(BytesIO(read_all_from_path(path, 'rb'))).convert('RGB'), - dtype=np.uint8 - ), - color_map - ) + get_logger().debug('loading annotation image: %s', path) + return AnnotatedImage( + np.asarray( + Image.open(BytesIO(read_all_from_path(path, 'rb'))).convert('RGB'), + dtype=np.uint8 + ), + color_map + ) + def main(argv=None): - args = parse_args(argv) + args = parse_args(argv) - if args.debug: - logging.getLogger().setLevel('DEBUG') + if args.debug: + logging.getLogger().setLevel('DEBUG') - color_map = parse_color_map_from_file(args.color_map) - get_logger().debug('color_map: %s', color_map) + color_map = parse_color_map_from_file(args.color_map) + get_logger().debug('color_map: %s', color_map) - structured_document = load_structured_document(args.lxml_path, 'rb') + 
structured_document = load_structured_document(args.lxml_path, 'rb') - annotated_images = ( - load_annotation_image(path, color_map) - for path in args.images_path - ) + annotated_images = ( + load_annotation_image(path, color_map) + for path in args.images_path + ) + + structured_document = annotate_structured_document_using_predicted_images( + structured_document, + annotated_images, + tag_scope=args.tag_scope + ) - structured_document = annotate_structured_document_using_predicted_images( - structured_document, - annotated_images, - tag_scope=args.tag_scope - ) + get_logger().info('writing result to: %s', args.output_path) + save_structured_document(args.output_path, structured_document.root) - get_logger().info('writing result to: %s', args.output_path) - save_structured_document(args.output_path, structured_document.root) if __name__ == '__main__': - logging.basicConfig(level='INFO') + logging.basicConfig(level='INFO') - main() + main() diff --git a/sciencebeam_gym/inference_model/annotate_using_predictions_test.py b/sciencebeam_gym/inference_model/annotate_using_predictions_test.py index f483b39..312fd0f 100644 --- a/sciencebeam_gym/inference_model/annotate_using_predictions_test.py +++ b/sciencebeam_gym/inference_model/annotate_using_predictions_test.py @@ -3,20 +3,20 @@ import pytest import numpy as np from sciencebeam_gym.structured_document import ( - SimpleStructuredDocument, - SimpleLine, - SimpleToken + SimpleStructuredDocument, + SimpleLine, + SimpleToken ) from sciencebeam_gym.utils.bounding_box import ( - BoundingBox + BoundingBox ) from sciencebeam_gym.inference_model.annotate_using_predictions import ( - AnnotatedImage, - annotate_structured_document_using_predicted_images, - parse_args, - CV_TAG_SCOPE + AnnotatedImage, + annotate_structured_document_using_predicted_images, + parse_args, + CV_TAG_SCOPE ) TAG_1 = 'tag1' @@ -30,163 +30,168 @@ DEFAULT_HEIGHT = 3 DEFAULT_WIDTH = 3 DEFAULT_BOUNDING_BOX = BoundingBox(0, 0, DEFAULT_WIDTH, DEFAULT_HEIGHT) 
+ def filled_image(color, color_map, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT): - return AnnotatedImage( - np.full((height, width, 3), color), - color_map - ) + return AnnotatedImage( + np.full((height, width, 3), color), + color_map + ) + def fill_rect(annoted_image, bounding_box, color): - for y in range(bounding_box.y, bounding_box.y + bounding_box.height): - for x in range(bounding_box.x, bounding_box.x + bounding_box.width): - annoted_image.data[y, x] = color + for y in range(bounding_box.y, bounding_box.y + bounding_box.height): + for x in range(bounding_box.x, bounding_box.x + bounding_box.width): + annoted_image.data[y, x] = color + class TestAnnotatedImage(object): - def test_should_return_zero_tag_probality_if_color_not_in_output(self): - annotated_image = filled_image(BG_COLOR, {TAG_1: COLOR_1}) - assert annotated_image.get_tag_probabilities_within( - DEFAULT_BOUNDING_BOX - ).get(TAG_1) == 0.0 - - def test_should_return_one_tag_probality_if_color_is_only_color_in_output(self): - annotated_image = filled_image(COLOR_1, {TAG_1: COLOR_1}) - assert annotated_image.get_tag_probabilities_within( - DEFAULT_BOUNDING_BOX - ).get(TAG_1) == 1.0 - - def test_should_return_zero_tag_probality_if_bounding_box_is_empty(self): - annotated_image = filled_image(BG_COLOR, {TAG_1: COLOR_1}) - assert annotated_image.get_tag_probabilities_within( - BoundingBox(0, 0, 0, 0) - ).get(TAG_1) == 0.0 - - def test_should_return_zero_tag_probality_if_bounding_box_is_outside_image(self): - annotated_image = filled_image(BG_COLOR, {TAG_1: COLOR_1}) - assert annotated_image.get_tag_probabilities_within( - DEFAULT_BOUNDING_BOX.move_by(DEFAULT_WIDTH, 0) - ).get(TAG_1) == 0.0 + def test_should_return_zero_tag_probality_if_color_not_in_output(self): + annotated_image = filled_image(BG_COLOR, {TAG_1: COLOR_1}) + assert annotated_image.get_tag_probabilities_within( + DEFAULT_BOUNDING_BOX + ).get(TAG_1) == 0.0 + + def test_should_return_one_tag_probality_if_color_is_only_color_in_output(self): 
+ annotated_image = filled_image(COLOR_1, {TAG_1: COLOR_1}) + assert annotated_image.get_tag_probabilities_within( + DEFAULT_BOUNDING_BOX + ).get(TAG_1) == 1.0 + + def test_should_return_zero_tag_probality_if_bounding_box_is_empty(self): + annotated_image = filled_image(BG_COLOR, {TAG_1: COLOR_1}) + assert annotated_image.get_tag_probabilities_within( + BoundingBox(0, 0, 0, 0) + ).get(TAG_1) == 0.0 + + def test_should_return_zero_tag_probality_if_bounding_box_is_outside_image(self): + annotated_image = filled_image(BG_COLOR, {TAG_1: COLOR_1}) + assert annotated_image.get_tag_probabilities_within( + DEFAULT_BOUNDING_BOX.move_by(DEFAULT_WIDTH, 0) + ).get(TAG_1) == 0.0 + class TestAnnotateStructuredDocumentUsingPredictedImages(object): - def test_should_not_fail_with_empty_document(self): - structured_document = SimpleStructuredDocument() - annotate_structured_document_using_predicted_images( - structured_document, - [] - ) + def test_should_not_fail_with_empty_document(self): + structured_document = SimpleStructuredDocument() + annotate_structured_document_using_predicted_images( + structured_document, + [] + ) + + def test_should_not_tag_single_token_not_within_prediction(self): + token_1 = SimpleToken(TOKEN_TEXT_1) + structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])]) + structured_document.set_bounding_box( + structured_document.get_pages()[0], + DEFAULT_BOUNDING_BOX + ) + structured_document.set_bounding_box(token_1, DEFAULT_BOUNDING_BOX) + annotate_structured_document_using_predicted_images( + structured_document, + [filled_image(BG_COLOR, {TAG_1: COLOR_1})] + ) + assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) is None + + def test_should_tag_single_token_within_prediction(self): + token_1 = SimpleToken(TOKEN_TEXT_1) + structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])]) + structured_document.set_bounding_box( + structured_document.get_pages()[0], + DEFAULT_BOUNDING_BOX + ) + 
structured_document.set_bounding_box(token_1, DEFAULT_BOUNDING_BOX) + annotate_structured_document_using_predicted_images( + structured_document, + [filled_image(COLOR_1, {TAG_1: COLOR_1})] + ) + assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) == TAG_1 + + def test_should_tag_single_token_within_full_prediction_at_smaller_scale(self): + token_1 = SimpleToken(TOKEN_TEXT_1) + structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])]) + structured_document.set_bounding_box( + structured_document.get_pages()[0], + BoundingBox(0, 0, DEFAULT_WIDTH * 10, DEFAULT_HEIGHT * 10) + ) + structured_document.set_bounding_box( + token_1, + BoundingBox(0, 0, DEFAULT_WIDTH * 10, DEFAULT_HEIGHT * 10) + ) + annotate_structured_document_using_predicted_images( + structured_document, + [filled_image(COLOR_1, {TAG_1: COLOR_1}, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT)] + ) + assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) == TAG_1 + + def test_should_tag_single_token_within_partial_prediction_at_same_scale(self): + token_1 = SimpleToken(TOKEN_TEXT_1) + structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])]) + structured_document.set_bounding_box( + structured_document.get_pages()[0], + DEFAULT_BOUNDING_BOX + ) + structured_document.set_bounding_box( + structured_document.get_pages()[0], + BoundingBox(0, 0, DEFAULT_WIDTH * 10, DEFAULT_HEIGHT * 10) + ) + structured_document.set_bounding_box( + token_1, + BoundingBox(0, 0, DEFAULT_WIDTH, DEFAULT_HEIGHT) + ) + annotated_image = filled_image( + BG_COLOR, {TAG_1: COLOR_1}, + width=DEFAULT_WIDTH * 10, + height=DEFAULT_HEIGHT * 10 + ) + fill_rect( + annotated_image, + BoundingBox(0, 0, DEFAULT_WIDTH, DEFAULT_HEIGHT), + COLOR_1 + ) + annotate_structured_document_using_predicted_images( + structured_document, + [annotated_image] + ) + assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) == TAG_1 + + def 
test_should_tag_single_token_within_partial_prediction_at_smaller_scale(self): + token_1 = SimpleToken(TOKEN_TEXT_1) + structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])]) + structured_document.set_bounding_box( + structured_document.get_pages()[0], + BoundingBox(0, 0, DEFAULT_WIDTH * 100, DEFAULT_HEIGHT * 100) + ) + structured_document.set_bounding_box( + token_1, + BoundingBox(0, 0, DEFAULT_WIDTH * 10, DEFAULT_HEIGHT * 10) + ) + annotated_image = filled_image( + BG_COLOR, {TAG_1: COLOR_1}, + width=DEFAULT_WIDTH * 10, + height=DEFAULT_HEIGHT * 10 + ) + fill_rect( + annotated_image, + BoundingBox(0, 0, DEFAULT_WIDTH, DEFAULT_HEIGHT), + COLOR_1 + ) + annotate_structured_document_using_predicted_images( + structured_document, + [annotated_image] + ) + assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) == TAG_1 - def test_should_not_tag_single_token_not_within_prediction(self): - token_1 = SimpleToken(TOKEN_TEXT_1) - structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])]) - structured_document.set_bounding_box( - structured_document.get_pages()[0], - DEFAULT_BOUNDING_BOX - ) - structured_document.set_bounding_box(token_1, DEFAULT_BOUNDING_BOX) - annotate_structured_document_using_predicted_images( - structured_document, - [filled_image(BG_COLOR, {TAG_1: COLOR_1})] - ) - assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) is None - - def test_should_tag_single_token_within_prediction(self): - token_1 = SimpleToken(TOKEN_TEXT_1) - structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])]) - structured_document.set_bounding_box( - structured_document.get_pages()[0], - DEFAULT_BOUNDING_BOX - ) - structured_document.set_bounding_box(token_1, DEFAULT_BOUNDING_BOX) - annotate_structured_document_using_predicted_images( - structured_document, - [filled_image(COLOR_1, {TAG_1: COLOR_1})] - ) - assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) == TAG_1 - - def 
test_should_tag_single_token_within_full_prediction_at_smaller_scale(self): - token_1 = SimpleToken(TOKEN_TEXT_1) - structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])]) - structured_document.set_bounding_box( - structured_document.get_pages()[0], - BoundingBox(0, 0, DEFAULT_WIDTH * 10, DEFAULT_HEIGHT * 10) - ) - structured_document.set_bounding_box( - token_1, - BoundingBox(0, 0, DEFAULT_WIDTH * 10, DEFAULT_HEIGHT * 10) - ) - annotate_structured_document_using_predicted_images( - structured_document, - [filled_image(COLOR_1, {TAG_1: COLOR_1}, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT)] - ) - assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) == TAG_1 - - def test_should_tag_single_token_within_partial_prediction_at_same_scale(self): - token_1 = SimpleToken(TOKEN_TEXT_1) - structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])]) - structured_document.set_bounding_box( - structured_document.get_pages()[0], - DEFAULT_BOUNDING_BOX - ) - structured_document.set_bounding_box( - structured_document.get_pages()[0], - BoundingBox(0, 0, DEFAULT_WIDTH * 10, DEFAULT_HEIGHT * 10) - ) - structured_document.set_bounding_box( - token_1, - BoundingBox(0, 0, DEFAULT_WIDTH, DEFAULT_HEIGHT) - ) - annotated_image = filled_image( - BG_COLOR, {TAG_1: COLOR_1}, - width=DEFAULT_WIDTH * 10, - height=DEFAULT_HEIGHT * 10 - ) - fill_rect( - annotated_image, - BoundingBox(0, 0, DEFAULT_WIDTH, DEFAULT_HEIGHT), - COLOR_1 - ) - annotate_structured_document_using_predicted_images( - structured_document, - [annotated_image] - ) - assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) == TAG_1 - - def test_should_tag_single_token_within_partial_prediction_at_smaller_scale(self): - token_1 = SimpleToken(TOKEN_TEXT_1) - structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])]) - structured_document.set_bounding_box( - structured_document.get_pages()[0], - BoundingBox(0, 0, DEFAULT_WIDTH * 100, DEFAULT_HEIGHT * 100) 
- ) - structured_document.set_bounding_box( - token_1, - BoundingBox(0, 0, DEFAULT_WIDTH * 10, DEFAULT_HEIGHT * 10) - ) - annotated_image = filled_image( - BG_COLOR, {TAG_1: COLOR_1}, - width=DEFAULT_WIDTH * 10, - height=DEFAULT_HEIGHT * 10 - ) - fill_rect( - annotated_image, - BoundingBox(0, 0, DEFAULT_WIDTH, DEFAULT_HEIGHT), - COLOR_1 - ) - annotate_structured_document_using_predicted_images( - structured_document, - [annotated_image] - ) - assert structured_document.get_tag(token_1, scope=CV_TAG_SCOPE) == TAG_1 class TestParseArgs(object): - def test_should_raise_error_if_not_enough_arguments_are_passed(self): - with pytest.raises(SystemExit): - parse_args([]) - - def test_should_not_raise_error_with_minimum_args(self): - parse_args(['--lxml-path=test', '--images-path=test', '--output-path=test']) - - # def test_should_raise_error_if_mutliple_source_args_are_specified(self): - # with pytest.raises(SystemExit): - # parse_args([ - # '--lxml-path=test', '--svg-path=test', '--images-path=test', '--output-path=test' - # ]) + def test_should_raise_error_if_not_enough_arguments_are_passed(self): + with pytest.raises(SystemExit): + parse_args([]) + + def test_should_not_raise_error_with_minimum_args(self): + parse_args(['--lxml-path=test', '--images-path=test', '--output-path=test']) + + # def test_should_raise_error_if_mutliple_source_args_are_specified(self): + # with pytest.raises(SystemExit): + # parse_args([ + # '--lxml-path=test', '--svg-path=test', '--images-path=test', '--output-path=test' + # ]) diff --git a/sciencebeam_gym/inference_model/extract_from_annotated_document.py b/sciencebeam_gym/inference_model/extract_from_annotated_document.py index aa06c4c..eb12356 100644 --- a/sciencebeam_gym/inference_model/extract_from_annotated_document.py +++ b/sciencebeam_gym/inference_model/extract_from_annotated_document.py @@ -1,95 +1,102 @@ import logging from sciencebeam_gym.structured_document import ( - B_TAG_PREFIX + B_TAG_PREFIX ) + def get_logger(): - return 
logging.getLogger(__name__) + return logging.getLogger(__name__) + class ExtractedItem(object): - def __init__(self, tag, text, tokens=None, tag_prefix=None, sub_items=None): - self.tag = tag - self.tag_prefix = tag_prefix - self.text = text - self.tokens = tokens or [] - self.sub_items = sub_items or [] - - def extend(self, other_item): - return ExtractedItem( - self.tag, - self.text + '\n' + other_item.text, - tokens=self.tokens + other_item.tokens, - tag_prefix=self.tag_prefix, - sub_items=self.sub_items + other_item.sub_items - ) + def __init__(self, tag, text, tokens=None, tag_prefix=None, sub_items=None): + self.tag = tag + self.tag_prefix = tag_prefix + self.text = text + self.tokens = tokens or [] + self.sub_items = sub_items or [] + + def extend(self, other_item): + return ExtractedItem( + self.tag, + self.text + '\n' + other_item.text, + tokens=self.tokens + other_item.tokens, + tag_prefix=self.tag_prefix, + sub_items=self.sub_items + other_item.sub_items + ) + def get_lines(structured_document): - for page in structured_document.get_pages(): - for line in structured_document.get_lines_of_page(page): - yield line + for page in structured_document.get_pages(): + for line in structured_document.get_lines_of_page(page): + yield line + def extract_from_annotated_tokens(structured_document, tokens, tag_scope=None, level=None): - previous_tokens = [] - previous_tag = None - previous_tag_prefix = None - for token in tokens: - tag_prefix, tag = structured_document.get_tag_prefix_and_value( - token, scope=tag_scope, level=level - ) - if not previous_tokens: - previous_tokens = [token] - previous_tag = tag - previous_tag_prefix = tag_prefix - elif tag == previous_tag and tag_prefix != B_TAG_PREFIX: - previous_tokens.append(token) - else: - yield ExtractedItem( - previous_tag, - ' '.join(structured_document.get_text(t) for t in previous_tokens), - tokens=previous_tokens, - tag_prefix=previous_tag_prefix - ) - previous_tokens = [token] - previous_tag = tag - 
previous_tag_prefix = tag_prefix - if previous_tokens: - yield ExtractedItem( - previous_tag, - ' '.join(structured_document.get_text(t) for t in previous_tokens), - tokens=previous_tokens, - tag_prefix=previous_tag_prefix - ) + previous_tokens = [] + previous_tag = None + previous_tag_prefix = None + for token in tokens: + tag_prefix, tag = structured_document.get_tag_prefix_and_value( + token, scope=tag_scope, level=level + ) + if not previous_tokens: + previous_tokens = [token] + previous_tag = tag + previous_tag_prefix = tag_prefix + elif tag == previous_tag and tag_prefix != B_TAG_PREFIX: + previous_tokens.append(token) + else: + yield ExtractedItem( + previous_tag, + ' '.join(structured_document.get_text(t) for t in previous_tokens), + tokens=previous_tokens, + tag_prefix=previous_tag_prefix + ) + previous_tokens = [token] + previous_tag = tag + previous_tag_prefix = tag_prefix + if previous_tokens: + yield ExtractedItem( + previous_tag, + ' '.join(structured_document.get_text(t) for t in previous_tokens), + tokens=previous_tokens, + tag_prefix=previous_tag_prefix + ) + def with_sub_items(structured_document, extracted_item, tag_scope=None): - return ExtractedItem( - extracted_item.tag, - extracted_item.text, - tokens=extracted_item.tokens, - tag_prefix=extracted_item.tag_prefix, - sub_items=list(extract_from_annotated_tokens( - structured_document, extracted_item.tokens, - tag_scope=tag_scope, level=2 - )) - ) + return ExtractedItem( + extracted_item.tag, + extracted_item.text, + tokens=extracted_item.tokens, + tag_prefix=extracted_item.tag_prefix, + sub_items=list(extract_from_annotated_tokens( + structured_document, extracted_item.tokens, + tag_scope=tag_scope, level=2 + )) + ) + def extract_from_annotated_lines(structured_document, lines, tag_scope=None): - previous_item = None - for line in lines: - tokens = structured_document.get_tokens_of_line(line) - for item in extract_from_annotated_tokens(structured_document, tokens, tag_scope=tag_scope): - if 
previous_item is not None: - if previous_item.tag == item.tag and item.tag_prefix != B_TAG_PREFIX: - previous_item = previous_item.extend(item) - else: - yield with_sub_items(structured_document, previous_item, tag_scope=tag_scope) - previous_item = item - else: - previous_item = item - if previous_item is not None: - yield with_sub_items(structured_document, previous_item, tag_scope=tag_scope) + previous_item = None + for line in lines: + tokens = structured_document.get_tokens_of_line(line) + for item in extract_from_annotated_tokens(structured_document, tokens, tag_scope=tag_scope): + if previous_item is not None: + if previous_item.tag == item.tag and item.tag_prefix != B_TAG_PREFIX: + previous_item = previous_item.extend(item) + else: + yield with_sub_items(structured_document, previous_item, tag_scope=tag_scope) + previous_item = item + else: + previous_item = item + if previous_item is not None: + yield with_sub_items(structured_document, previous_item, tag_scope=tag_scope) + def extract_from_annotated_document(structured_document, tag_scope=None): - return extract_from_annotated_lines( - structured_document, get_lines(structured_document), tag_scope=tag_scope - ) + return extract_from_annotated_lines( + structured_document, get_lines(structured_document), tag_scope=tag_scope + ) diff --git a/sciencebeam_gym/inference_model/extract_from_annotated_document_test.py b/sciencebeam_gym/inference_model/extract_from_annotated_document_test.py index 847f0e8..9c467c2 100644 --- a/sciencebeam_gym/inference_model/extract_from_annotated_document_test.py +++ b/sciencebeam_gym/inference_model/extract_from_annotated_document_test.py @@ -1,15 +1,15 @@ import logging from sciencebeam_gym.structured_document import ( - SimpleToken, - SimpleLine, - SimpleStructuredDocument, - B_TAG_PREFIX, - I_TAG_PREFIX + SimpleToken, + SimpleLine, + SimpleStructuredDocument, + B_TAG_PREFIX, + I_TAG_PREFIX ) from sciencebeam_gym.inference_model.extract_from_annotated_document import ( - 
extract_from_annotated_document + extract_from_annotated_document ) VALUE_1 = 'value1' @@ -24,233 +24,242 @@ TAG_3 = 'tag3' TAG_SCOPE_1 = 'tag_scope1' + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def with_tag(x, tag): - if isinstance(x, SimpleToken): - x.set_tag(tag) - elif isinstance(x, list): - return [with_tag(y, tag) for y in x] - elif isinstance(x, SimpleLine): - return SimpleLine(with_tag(x.tokens, tag)) - return x + if isinstance(x, SimpleToken): + x.set_tag(tag) + elif isinstance(x, list): + return [with_tag(y, tag) for y in x] + elif isinstance(x, SimpleLine): + return SimpleLine(with_tag(x.tokens, tag)) + return x + def to_token(token): - return SimpleToken(token) if isinstance(token, str) else token + return SimpleToken(token) if isinstance(token, str) else token + def to_tokens(tokens): - if isinstance(tokens, str): - tokens = tokens.split(' ') - return [to_token(t) for t in tokens] + if isinstance(tokens, str): + tokens = tokens.split(' ') + return [to_token(t) for t in tokens] + def to_line(tokens): - return SimpleLine(to_tokens(tokens)) + return SimpleLine(to_tokens(tokens)) + def annotated_tokens(tokens, tag): - return with_tag(to_tokens(tokens), tag) + return with_tag(to_tokens(tokens), tag) + def annotated_line(tokens, tag): - return with_tag(to_line(tokens), tag) + return with_tag(to_line(tokens), tag) + def _token_with_sub_tag(text, tag=None, tag_prefix=None, sub_tag=None, sub_tag_prefix=None): - token = SimpleToken(text, tag=tag, tag_prefix=tag_prefix) - if sub_tag: - token.set_tag(sub_tag, prefix=sub_tag_prefix, level=2) - return token + token = SimpleToken(text, tag=tag, tag_prefix=tag_prefix) + if sub_tag: + token.set_tag(sub_tag, prefix=sub_tag_prefix, level=2) + return token + class TestExtractFromAnnotatedDocument(object): - def test_should_not_fail_on_empty_document(self): - structured_document = SimpleStructuredDocument() - extract_from_annotated_document(structured_document) + def 
test_should_not_fail_on_empty_document(self): + structured_document = SimpleStructuredDocument() + extract_from_annotated_document(structured_document) - def test_should_extract_single_annotated_line(self): - lines = [annotated_line(TEXT_1, TAG_1)] - structured_document = SimpleStructuredDocument(lines=lines) - result = [ - (x.tag, x.text) - for x in - extract_from_annotated_document(structured_document) - ] - assert result == [(TAG_1, TEXT_1)] + def test_should_extract_single_annotated_line(self): + lines = [annotated_line(TEXT_1, TAG_1)] + structured_document = SimpleStructuredDocument(lines=lines) + result = [ + (x.tag, x.text) + for x in + extract_from_annotated_document(structured_document) + ] + assert result == [(TAG_1, TEXT_1)] - def test_should_extract_from_different_tag_scope(self): - lines = [SimpleLine([SimpleToken(TEXT_1, tag=TAG_1, tag_scope=TAG_SCOPE_1)])] - structured_document = SimpleStructuredDocument(lines=lines) - result = [ - (x.tag, x.text) - for x in - extract_from_annotated_document(structured_document, tag_scope=TAG_SCOPE_1) - ] - assert result == [(TAG_1, TEXT_1)] + def test_should_extract_from_different_tag_scope(self): + lines = [SimpleLine([SimpleToken(TEXT_1, tag=TAG_1, tag_scope=TAG_SCOPE_1)])] + structured_document = SimpleStructuredDocument(lines=lines) + result = [ + (x.tag, x.text) + for x in + extract_from_annotated_document(structured_document, tag_scope=TAG_SCOPE_1) + ] + assert result == [(TAG_1, TEXT_1)] - def test_should_extract_multiple_annotations_on_single_line(self): - lines = [to_line( - annotated_tokens(TEXT_1, TAG_1) + - to_tokens(TEXT_2) + - annotated_tokens(TEXT_3, TAG_3) - )] - structured_document = SimpleStructuredDocument(lines=lines) - result = [ - (x.tag, x.text) - for x in - extract_from_annotated_document(structured_document) - ] - assert result == [ - (TAG_1, TEXT_1), - (None, TEXT_2), - (TAG_3, TEXT_3) - ] + def test_should_extract_multiple_annotations_on_single_line(self): + lines = [to_line( + 
annotated_tokens(TEXT_1, TAG_1) + + to_tokens(TEXT_2) + + annotated_tokens(TEXT_3, TAG_3) + )] + structured_document = SimpleStructuredDocument(lines=lines) + result = [ + (x.tag, x.text) + for x in + extract_from_annotated_document(structured_document) + ] + assert result == [ + (TAG_1, TEXT_1), + (None, TEXT_2), + (TAG_3, TEXT_3) + ] - def test_should_combine_multiple_lines(self): - lines = [ - annotated_line(TEXT_1, TAG_1), - annotated_line(TEXT_2, TAG_1) - ] - structured_document = SimpleStructuredDocument(lines=lines) - result = [ - (x.tag, x.text) - for x in - extract_from_annotated_document(structured_document) - ] - get_logger().debug('result: %s', result) - assert result == [(TAG_1, '\n'.join([TEXT_1, TEXT_2]))] + def test_should_combine_multiple_lines(self): + lines = [ + annotated_line(TEXT_1, TAG_1), + annotated_line(TEXT_2, TAG_1) + ] + structured_document = SimpleStructuredDocument(lines=lines) + result = [ + (x.tag, x.text) + for x in + extract_from_annotated_document(structured_document) + ] + get_logger().debug('result: %s', result) + assert result == [(TAG_1, '\n'.join([TEXT_1, TEXT_2]))] - def test_should_combine_multiple_lines_separated_by_other_tag(self): - lines = [ - annotated_line(TEXT_1, TAG_1), - annotated_line(TEXT_2, TAG_2), - annotated_line(TEXT_3, TAG_2), - annotated_line(TEXT_1, TAG_1), - annotated_line(TEXT_2, TAG_2), - annotated_line(TEXT_3, TAG_2) - ] - structured_document = SimpleStructuredDocument(lines=lines) - result = [ - (x.tag, x.text) - for x in - extract_from_annotated_document(structured_document) - ] - get_logger().debug('result: %s', result) - assert result == [ - (TAG_1, TEXT_1), - (TAG_2, '\n'.join([TEXT_2, TEXT_3])), - (TAG_1, TEXT_1), - (TAG_2, '\n'.join([TEXT_2, TEXT_3])) - ] + def test_should_combine_multiple_lines_separated_by_other_tag(self): + lines = [ + annotated_line(TEXT_1, TAG_1), + annotated_line(TEXT_2, TAG_2), + annotated_line(TEXT_3, TAG_2), + annotated_line(TEXT_1, TAG_1), + annotated_line(TEXT_2, 
TAG_2), + annotated_line(TEXT_3, TAG_2) + ] + structured_document = SimpleStructuredDocument(lines=lines) + result = [ + (x.tag, x.text) + for x in + extract_from_annotated_document(structured_document) + ] + get_logger().debug('result: %s', result) + assert result == [ + (TAG_1, TEXT_1), + (TAG_2, '\n'.join([TEXT_2, TEXT_3])), + (TAG_1, TEXT_1), + (TAG_2, '\n'.join([TEXT_2, TEXT_3])) + ] - def test_should_separate_items_based_on_tag_prefix(self): - tokens = [ - SimpleToken(VALUE_1, tag=TAG_1, tag_prefix=B_TAG_PREFIX), - SimpleToken(VALUE_2, tag=TAG_1, tag_prefix=I_TAG_PREFIX), - SimpleToken(VALUE_3, tag=TAG_1, tag_prefix=I_TAG_PREFIX), - SimpleToken(VALUE_1, tag=TAG_1, tag_prefix=B_TAG_PREFIX), - SimpleToken(VALUE_2, tag=TAG_1, tag_prefix=I_TAG_PREFIX), - SimpleToken(VALUE_3, tag=TAG_1, tag_prefix=I_TAG_PREFIX) - ] - structured_document = SimpleStructuredDocument(lines=[SimpleLine(tokens)]) - result = [ - (x.tag, x.text) - for x in - extract_from_annotated_document(structured_document) - ] - get_logger().debug('result: %s', result) - assert result == [ - (TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3])), - (TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3])) - ] + def test_should_separate_items_based_on_tag_prefix(self): + tokens = [ + SimpleToken(VALUE_1, tag=TAG_1, tag_prefix=B_TAG_PREFIX), + SimpleToken(VALUE_2, tag=TAG_1, tag_prefix=I_TAG_PREFIX), + SimpleToken(VALUE_3, tag=TAG_1, tag_prefix=I_TAG_PREFIX), + SimpleToken(VALUE_1, tag=TAG_1, tag_prefix=B_TAG_PREFIX), + SimpleToken(VALUE_2, tag=TAG_1, tag_prefix=I_TAG_PREFIX), + SimpleToken(VALUE_3, tag=TAG_1, tag_prefix=I_TAG_PREFIX) + ] + structured_document = SimpleStructuredDocument(lines=[SimpleLine(tokens)]) + result = [ + (x.tag, x.text) + for x in + extract_from_annotated_document(structured_document) + ] + get_logger().debug('result: %s', result) + assert result == [ + (TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3])), + (TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3])) + ] - def 
test_should_extract_sub_tags_from_single_item(self): - tokens = [ - _token_with_sub_tag( - VALUE_1, - tag=TAG_1, tag_prefix=B_TAG_PREFIX, - sub_tag=TAG_2, sub_tag_prefix=B_TAG_PREFIX - ), - _token_with_sub_tag( - VALUE_2, - tag=TAG_1, tag_prefix=I_TAG_PREFIX, - sub_tag=TAG_2, sub_tag_prefix=I_TAG_PREFIX - ), - _token_with_sub_tag( - VALUE_3, - tag=TAG_1, tag_prefix=I_TAG_PREFIX, - sub_tag=TAG_3, sub_tag_prefix=B_TAG_PREFIX - ) - ] - structured_document = SimpleStructuredDocument(lines=[SimpleLine(tokens)]) - extracted_items = list(extract_from_annotated_document(structured_document)) - result = [ - ( - x.tag, x.text, [(sub.tag, sub.text) for sub in x.sub_items] - ) for x in extracted_items - ] - get_logger().debug('result: %s', result) - assert result == [ - ( - TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]), - [ - (TAG_2, ' '.join([VALUE_1, VALUE_2])), - (TAG_3, VALUE_3) + def test_should_extract_sub_tags_from_single_item(self): + tokens = [ + _token_with_sub_tag( + VALUE_1, + tag=TAG_1, tag_prefix=B_TAG_PREFIX, + sub_tag=TAG_2, sub_tag_prefix=B_TAG_PREFIX + ), + _token_with_sub_tag( + VALUE_2, + tag=TAG_1, tag_prefix=I_TAG_PREFIX, + sub_tag=TAG_2, sub_tag_prefix=I_TAG_PREFIX + ), + _token_with_sub_tag( + VALUE_3, + tag=TAG_1, tag_prefix=I_TAG_PREFIX, + sub_tag=TAG_3, sub_tag_prefix=B_TAG_PREFIX + ) + ] + structured_document = SimpleStructuredDocument(lines=[SimpleLine(tokens)]) + extracted_items = list(extract_from_annotated_document(structured_document)) + result = [ + ( + x.tag, x.text, [(sub.tag, sub.text) for sub in x.sub_items] + ) for x in extracted_items + ] + get_logger().debug('result: %s', result) + assert result == [ + ( + TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]), + [ + (TAG_2, ' '.join([VALUE_1, VALUE_2])), + (TAG_3, VALUE_3) + ] + ) ] - ) - ] - def test_should_extract_sub_tags_from_multiple_items(self): - tokens = [ - _token_with_sub_tag( - VALUE_1, - tag=TAG_1, tag_prefix=B_TAG_PREFIX, - sub_tag=TAG_2, sub_tag_prefix=B_TAG_PREFIX - ), - 
_token_with_sub_tag( - VALUE_2, - tag=TAG_1, tag_prefix=I_TAG_PREFIX, - sub_tag=TAG_2, sub_tag_prefix=I_TAG_PREFIX - ), - _token_with_sub_tag( - VALUE_3, - tag=TAG_1, tag_prefix=I_TAG_PREFIX, - sub_tag=TAG_3, sub_tag_prefix=B_TAG_PREFIX - ), + def test_should_extract_sub_tags_from_multiple_items(self): + tokens = [ + _token_with_sub_tag( + VALUE_1, + tag=TAG_1, tag_prefix=B_TAG_PREFIX, + sub_tag=TAG_2, sub_tag_prefix=B_TAG_PREFIX + ), + _token_with_sub_tag( + VALUE_2, + tag=TAG_1, tag_prefix=I_TAG_PREFIX, + sub_tag=TAG_2, sub_tag_prefix=I_TAG_PREFIX + ), + _token_with_sub_tag( + VALUE_3, + tag=TAG_1, tag_prefix=I_TAG_PREFIX, + sub_tag=TAG_3, sub_tag_prefix=B_TAG_PREFIX + ), - _token_with_sub_tag( - VALUE_1, - tag=TAG_1, tag_prefix=B_TAG_PREFIX, - sub_tag=TAG_2, sub_tag_prefix=B_TAG_PREFIX - ), - _token_with_sub_tag( - VALUE_2, - tag=TAG_1, tag_prefix=I_TAG_PREFIX, - sub_tag=TAG_3, sub_tag_prefix=B_TAG_PREFIX - ), - _token_with_sub_tag( - VALUE_3, - tag=TAG_1, tag_prefix=I_TAG_PREFIX, - sub_tag=TAG_3, sub_tag_prefix=I_TAG_PREFIX - ) - ] - structured_document = SimpleStructuredDocument(lines=[SimpleLine(tokens)]) - extracted_items = list(extract_from_annotated_document(structured_document)) - result = [ - ( - x.tag, x.text, [(sub.tag, sub.text) for sub in x.sub_items] - ) for x in extracted_items - ] - get_logger().debug('result: %s', result) - assert result == [ - ( - TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]), - [ - (TAG_2, ' '.join([VALUE_1, VALUE_2])), - (TAG_3, VALUE_3) + _token_with_sub_tag( + VALUE_1, + tag=TAG_1, tag_prefix=B_TAG_PREFIX, + sub_tag=TAG_2, sub_tag_prefix=B_TAG_PREFIX + ), + _token_with_sub_tag( + VALUE_2, + tag=TAG_1, tag_prefix=I_TAG_PREFIX, + sub_tag=TAG_3, sub_tag_prefix=B_TAG_PREFIX + ), + _token_with_sub_tag( + VALUE_3, + tag=TAG_1, tag_prefix=I_TAG_PREFIX, + sub_tag=TAG_3, sub_tag_prefix=I_TAG_PREFIX + ) + ] + structured_document = SimpleStructuredDocument(lines=[SimpleLine(tokens)]) + extracted_items = 
list(extract_from_annotated_document(structured_document)) + result = [ + ( + x.tag, x.text, [(sub.tag, sub.text) for sub in x.sub_items] + ) for x in extracted_items ] - ), - ( - TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]), - [ - (TAG_2, VALUE_1), - (TAG_3, ' '.join([VALUE_2, VALUE_3])) + get_logger().debug('result: %s', result) + assert result == [ + ( + TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]), + [ + (TAG_2, ' '.join([VALUE_1, VALUE_2])), + (TAG_3, VALUE_3) + ] + ), + ( + TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]), + [ + (TAG_2, VALUE_1), + (TAG_3, ' '.join([VALUE_2, VALUE_3])) + ] + ) ] - ) - ] diff --git a/sciencebeam_gym/inference_model/extract_to_xml.py b/sciencebeam_gym/inference_model/extract_to_xml.py index fba57f7..927ee26 100644 --- a/sciencebeam_gym/inference_model/extract_to_xml.py +++ b/sciencebeam_gym/inference_model/extract_to_xml.py @@ -5,202 +5,220 @@ from lxml import etree from lxml.builder import E from sciencebeam_utils.beam_utils.io import ( - save_file_content + save_file_content ) from sciencebeam_gym.structured_document.structured_document_loader import ( - load_structured_document + load_structured_document ) from sciencebeam_gym.inference_model.extract_from_annotated_document import ( - extract_from_annotated_document + extract_from_annotated_document ) + class Tags(object): - TITLE = 'manuscript_title' - ABSTRACT = 'abstract' - AUTHOR = 'author' - AUTHOR_AFF = 'author_aff' + TITLE = 'manuscript_title' + ABSTRACT = 'abstract' + AUTHOR = 'author' + AUTHOR_AFF = 'author_aff' + class XmlPaths(object): - TITLE = 'front/article-meta/title-group/article-title' - ABSTRACT = 'front/article-meta/abstract' - AUTHOR = 'front/article-meta/contrib-group/contrib' - AUTHOR_AFF = 'front/article-meta/contrib-group/aff' + TITLE = 'front/article-meta/title-group/article-title' + ABSTRACT = 'front/article-meta/abstract' + AUTHOR = 'front/article-meta/contrib-group/contrib' + AUTHOR_AFF = 'front/article-meta/contrib-group/aff' + class SubTags(object): 
- AUTHOR_SURNAME = 'surname' - AUTHOR_GIVEN_NAMES = 'givennames' + AUTHOR_SURNAME = 'surname' + AUTHOR_GIVEN_NAMES = 'givennames' + class SubXmlPaths(object): - AUTHOR_SURNAME = 'name/surname' - AUTHOR_GIVEN_NAMES = 'name/given-names' + AUTHOR_SURNAME = 'name/surname' + AUTHOR_GIVEN_NAMES = 'name/given-names' + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def rsplit_xml_path(path): - i = path.rfind('/') - if i >= 0: - return path[0:i], path[i + 1:] - else: - return None, path + i = path.rfind('/') + if i >= 0: + return path[0:i], path[i + 1:] + else: + return None, path + def create_node_recursive(xml_root, path, exists_ok=False): - node = xml_root.find(path) - if node is not None: - if not exists_ok: - raise RuntimeError('xml node already exists: %s' % path) + node = xml_root.find(path) + if node is not None: + if not exists_ok: + raise RuntimeError('xml node already exists: %s' % path) + return node + parent, base = rsplit_xml_path(path) + if parent: + parent_node = create_node_recursive(xml_root, parent, exists_ok=True) + else: + parent_node = xml_root + node = etree.Element(base) + parent_node.append(node) return node - parent, base = rsplit_xml_path(path) - if parent: - parent_node = create_node_recursive(xml_root, parent, exists_ok=True) - else: - parent_node = xml_root - node = etree.Element(base) - parent_node.append(node) - return node + def create_and_append_xml_node(xml_root, path): - parent, base = rsplit_xml_path(path) - parent_node = ( - create_node_recursive(xml_root, parent, exists_ok=True) - if parent - else xml_root - ) - node = etree.Element(base) - parent_node.append(node) - return node + parent, base = rsplit_xml_path(path) + parent_node = ( + create_node_recursive(xml_root, parent, exists_ok=True) + if parent + else xml_root + ) + node = etree.Element(base) + parent_node.append(node) + return node + def create_xml_text(xml_root, path, text): - node = create_and_append_xml_node(xml_root, path) - 
node.text = text - return node + node = create_and_append_xml_node(xml_root, path) + node.text = text + return node + AUTHOR_JUNK_CHARS = ',+*0123456789' + def _clean_author_name(s): - i = len(s) - while ( - i > 0 and - ( - s[i - 1] in AUTHOR_JUNK_CHARS or - # only remove dot after special characters - (s[i - 1] == '.' and i >= 2 and s[i - 2] in AUTHOR_JUNK_CHARS) - ) - ): - i -= 1 - return s[:i] + i = len(s) + while ( + i > 0 and + ( + s[i - 1] in AUTHOR_JUNK_CHARS or + # only remove dot after special characters + (s[i - 1] == '.' and i >= 2 and s[i - 2] in AUTHOR_JUNK_CHARS) + ) + ): + i -= 1 + return s[:i] + class XmlMapping(object): - def __init__( - self, xml_path, single_node=False, sub_mapping=None, attrib=None, - clean_fn=None): + def __init__( + self, xml_path, single_node=False, sub_mapping=None, attrib=None, + clean_fn=None): + + self.xml_path = xml_path + self.single_node = single_node + self.sub_mapping = sub_mapping + self.attrib = attrib + self.clean_fn = clean_fn - self.xml_path = xml_path - self.single_node = single_node - self.sub_mapping = sub_mapping - self.attrib = attrib - self.clean_fn = clean_fn def _extract_items(parent_node, extracted_items, xml_mapping): - previous_tag = None - for extracted_item in extracted_items: - tag = extracted_item.tag - if tag: - mapping_entry = xml_mapping.get(tag) - if not mapping_entry: - get_logger().warning('tag not configured: %s', tag) - continue - extracted_text = extracted_item.text - if extracted_text and mapping_entry.clean_fn: - extracted_text = mapping_entry.clean_fn(extracted_text) - path = mapping_entry.xml_path - if mapping_entry.single_node: - node = create_node_recursive(parent_node, path, exists_ok=True) - if node.text is None: - node.text = extracted_text - elif previous_tag == tag: - node.text += '\n' + extracted_text - else: - get_logger().debug('ignoring tag %s, after tag %s', tag, previous_tag) - else: - node = create_and_append_xml_node(parent_node, path) - if mapping_entry.attrib: - for 
k, v in mapping_entry.attrib.items(): - node.attrib[k] = v - if extracted_item.sub_items and mapping_entry.sub_mapping: - _extract_items(node, extracted_item.sub_items, mapping_entry.sub_mapping) - else: - node.text = extracted_text - previous_tag = tag + previous_tag = None + for extracted_item in extracted_items: + tag = extracted_item.tag + if tag: + mapping_entry = xml_mapping.get(tag) + if not mapping_entry: + get_logger().warning('tag not configured: %s', tag) + continue + extracted_text = extracted_item.text + if extracted_text and mapping_entry.clean_fn: + extracted_text = mapping_entry.clean_fn(extracted_text) + path = mapping_entry.xml_path + if mapping_entry.single_node: + node = create_node_recursive(parent_node, path, exists_ok=True) + if node.text is None: + node.text = extracted_text + elif previous_tag == tag: + node.text += '\n' + extracted_text + else: + get_logger().debug('ignoring tag %s, after tag %s', tag, previous_tag) + else: + node = create_and_append_xml_node(parent_node, path) + if mapping_entry.attrib: + for k, v in mapping_entry.attrib.items(): + node.attrib[k] = v + if extracted_item.sub_items and mapping_entry.sub_mapping: + _extract_items(node, extracted_item.sub_items, mapping_entry.sub_mapping) + else: + node.text = extracted_text + previous_tag = tag + def extracted_items_to_xml(extracted_items): - xml_mapping = { - Tags.TITLE: XmlMapping(XmlPaths.TITLE, single_node=True), - Tags.ABSTRACT: XmlMapping(XmlPaths.ABSTRACT, single_node=True), - Tags.AUTHOR: XmlMapping(XmlPaths.AUTHOR, sub_mapping={ - SubTags.AUTHOR_GIVEN_NAMES: XmlMapping( - SubXmlPaths.AUTHOR_GIVEN_NAMES, clean_fn=_clean_author_name - ), - SubTags.AUTHOR_SURNAME: XmlMapping( - SubXmlPaths.AUTHOR_SURNAME, clean_fn=_clean_author_name - ) - }, attrib={ - 'contrib-type': 'author' - }, clean_fn=_clean_author_name), - Tags.AUTHOR_AFF: XmlMapping(XmlPaths.AUTHOR_AFF) - } - xml_root = E.article() - _extract_items(xml_root, extracted_items, xml_mapping) - return xml_root + 
xml_mapping = { + Tags.TITLE: XmlMapping(XmlPaths.TITLE, single_node=True), + Tags.ABSTRACT: XmlMapping(XmlPaths.ABSTRACT, single_node=True), + Tags.AUTHOR: XmlMapping(XmlPaths.AUTHOR, sub_mapping={ + SubTags.AUTHOR_GIVEN_NAMES: XmlMapping( + SubXmlPaths.AUTHOR_GIVEN_NAMES, clean_fn=_clean_author_name + ), + SubTags.AUTHOR_SURNAME: XmlMapping( + SubXmlPaths.AUTHOR_SURNAME, clean_fn=_clean_author_name + ) + }, attrib={ + 'contrib-type': 'author' + }, clean_fn=_clean_author_name), + Tags.AUTHOR_AFF: XmlMapping(XmlPaths.AUTHOR_AFF) + } + xml_root = E.article() + _extract_items(xml_root, extracted_items, xml_mapping) + return xml_root + def extract_structured_document_to_xml(structured_document, tag_scope=None): - return extracted_items_to_xml( - extract_from_annotated_document(structured_document, tag_scope=tag_scope) - ) + return extracted_items_to_xml( + extract_from_annotated_document(structured_document, tag_scope=tag_scope) + ) + def parse_args(argv=None): - parser = argparse.ArgumentParser('Extract JATSy XML from annotated LXML') - parser.add_argument( - '--lxml-path', type=str, required=True, - help='path to lxml or svg pages document' - ) + parser = argparse.ArgumentParser('Extract JATSy XML from annotated LXML') + parser.add_argument( + '--lxml-path', type=str, required=True, + help='path to lxml or svg pages document' + ) - parser.add_argument( - '--tag-scope', type=str, required=False, - help='tag scope to extract based on' - ) + parser.add_argument( + '--tag-scope', type=str, required=False, + help='tag scope to extract based on' + ) - parser.add_argument( - '--output-path', type=str, required=True, - help='output path to annotated document' - ) + parser.add_argument( + '--output-path', type=str, required=True, + help='output path to annotated document' + ) - parser.add_argument( - '--debug', action='store_true', default=False, - help='enable debug output' - ) + parser.add_argument( + '--debug', action='store_true', default=False, + help='enable debug 
output' + ) + + return parser.parse_args(argv) - return parser.parse_args(argv) def main(argv=None): - args = parse_args(argv) + args = parse_args(argv) + + if args.debug: + logging.getLogger().setLevel('DEBUG') - if args.debug: - logging.getLogger().setLevel('DEBUG') + structured_document = load_structured_document(args.lxml_path) - structured_document = load_structured_document(args.lxml_path) + xml_root = extract_structured_document_to_xml( + structured_document, + tag_scope=args.tag_scope + ) - xml_root = extract_structured_document_to_xml( - structured_document, - tag_scope=args.tag_scope - ) + get_logger().info('writing result to: %s', args.output_path) + save_file_content(args.output_path, etree.tostring(xml_root, pretty_print=True)) - get_logger().info('writing result to: %s', args.output_path) - save_file_content(args.output_path, etree.tostring(xml_root, pretty_print=True)) if __name__ == '__main__': - logging.basicConfig(level='INFO') + logging.basicConfig(level='INFO') - main() + main() diff --git a/sciencebeam_gym/inference_model/extract_to_xml_test.py b/sciencebeam_gym/inference_model/extract_to_xml_test.py index dad436a..abf1329 100644 --- a/sciencebeam_gym/inference_model/extract_to_xml_test.py +++ b/sciencebeam_gym/inference_model/extract_to_xml_test.py @@ -6,166 +6,169 @@ from lxml import etree from lxml.builder import E from sciencebeam_utils.utils.xml import ( - get_text_content, - get_text_content_list + get_text_content, + get_text_content_list ) from sciencebeam_gym.inference_model.extract_from_annotated_document import ( - ExtractedItem + ExtractedItem ) from sciencebeam_gym.inference_model.extract_to_xml import ( - extracted_items_to_xml, - Tags, - XmlPaths, - SubTags, - SubXmlPaths, - main + extracted_items_to_xml, + Tags, + XmlPaths, + SubTags, + SubXmlPaths, + main ) TEXT_1 = 'some text here' TEXT_2 = 'more text to come' TEXT_3 = 'does not stop here' + def _create_author_extracted_items(given_names, surname): - return [ - 
ExtractedItem(Tags.AUTHOR, ' '.join([given_names, surname]), sub_items=[ - ExtractedItem(SubTags.AUTHOR_GIVEN_NAMES, given_names), - ExtractedItem(SubTags.AUTHOR_SURNAME, surname) - ]) - ] + return [ + ExtractedItem(Tags.AUTHOR, ' '.join([given_names, surname]), sub_items=[ + ExtractedItem(SubTags.AUTHOR_GIVEN_NAMES, given_names), + ExtractedItem(SubTags.AUTHOR_SURNAME, surname) + ]) + ] + class TestExtractedItemsToXml(object): - def test_should_return_empty_xml_for_no_empty_list_of_extracted_items(self): - xml_root = extracted_items_to_xml([]) - assert xml_root is not None - - def test_should_populate_title(self): - xml_root = extracted_items_to_xml([ - ExtractedItem(Tags.TITLE, TEXT_1) - ]) - assert xml_root is not None - assert get_text_content(xml_root.find(XmlPaths.TITLE)) == TEXT_1 - - def test_should_append_to_abstract(self): - xml_root = extracted_items_to_xml([ - ExtractedItem(Tags.ABSTRACT, TEXT_1), - ExtractedItem(Tags.ABSTRACT, TEXT_2) - ]) - assert xml_root is not None - assert get_text_content(xml_root.find(XmlPaths.ABSTRACT)) == '\n'.join([TEXT_1, TEXT_2]) - - def test_should_not_append_to_abstract_after_untagged_content(self): - xml_root = extracted_items_to_xml([ - ExtractedItem(Tags.ABSTRACT, TEXT_1), - ExtractedItem(None, TEXT_2), - ExtractedItem(Tags.ABSTRACT, TEXT_3) - ]) - assert xml_root is not None - assert get_text_content(xml_root.find(XmlPaths.ABSTRACT)) == '\n'.join([TEXT_1, TEXT_3]) - - def test_should_not_append_to_abstract_after_another_tag_occured(self): - xml_root = extracted_items_to_xml([ - ExtractedItem(Tags.ABSTRACT, TEXT_1), - ExtractedItem(Tags.AUTHOR, TEXT_2), - ExtractedItem(Tags.ABSTRACT, TEXT_3) - ]) - assert xml_root is not None - assert get_text_content(xml_root.find(XmlPaths.ABSTRACT)) == '\n'.join([TEXT_1]) - - def test_should_create_separate_author_node(self): - xml_root = extracted_items_to_xml([ - ExtractedItem(Tags.AUTHOR, TEXT_1), - ExtractedItem(Tags.AUTHOR, TEXT_2) - ]) - assert xml_root is not None - assert 
get_text_content_list(xml_root.findall(XmlPaths.AUTHOR)) == [TEXT_1, TEXT_2] - - def test_should_extract_author_surname_and_given_names_from_single_author(self): - xml_root = extracted_items_to_xml([ - ExtractedItem(Tags.AUTHOR, ' '.join([TEXT_1, TEXT_2]), sub_items=[ - ExtractedItem(SubTags.AUTHOR_GIVEN_NAMES, TEXT_1), - ExtractedItem(SubTags.AUTHOR_SURNAME, TEXT_2) - ]) - ]) - assert xml_root is not None - author = xml_root.find(XmlPaths.AUTHOR) - assert author is not None - assert get_text_content(author.find(SubXmlPaths.AUTHOR_GIVEN_NAMES)) == TEXT_1 - assert get_text_content(author.find(SubXmlPaths.AUTHOR_SURNAME)) == TEXT_2 - - def test_should_remove_special_characters_and_numbers_from_author(self): - special_num_chars = ',+*0123456789' - xml_root = extracted_items_to_xml(_create_author_extracted_items( - TEXT_1 + special_num_chars, TEXT_2 + special_num_chars - )) - assert xml_root is not None - author = xml_root.find(XmlPaths.AUTHOR) - assert author is not None - assert get_text_content(author.find(SubXmlPaths.AUTHOR_GIVEN_NAMES)) == TEXT_1 - assert get_text_content(author.find(SubXmlPaths.AUTHOR_SURNAME)) == TEXT_2 - - def test_should_not_remove_dot_after_initials_from_author(self): - xml_root = extracted_items_to_xml(_create_author_extracted_items( - 'Mr T.', 'E.' - )) - assert xml_root is not None - author = xml_root.find(XmlPaths.AUTHOR) - assert author is not None - assert get_text_content(author.find(SubXmlPaths.AUTHOR_GIVEN_NAMES)) == 'Mr T.' - assert get_text_content(author.find(SubXmlPaths.AUTHOR_SURNAME)) == 'E.' - - def test_should_not_remove_dot_after_suffix_from_author(self): - xml_root = extracted_items_to_xml(_create_author_extracted_items( - 'Mr T.', 'Jr.' - )) - assert xml_root is not None - author = xml_root.find(XmlPaths.AUTHOR) - assert author is not None - assert get_text_content(author.find(SubXmlPaths.AUTHOR_GIVEN_NAMES)) == 'Mr T.' - assert get_text_content(author.find(SubXmlPaths.AUTHOR_SURNAME)) == 'Jr.' 
- - def test_should_remove_dot_after_other_special_characters(self): - xml_root = extracted_items_to_xml(_create_author_extracted_items( - 'Mr T*.', 'E*.' - )) - assert xml_root is not None - author = xml_root.find(XmlPaths.AUTHOR) - assert author is not None - assert get_text_content(author.find(SubXmlPaths.AUTHOR_GIVEN_NAMES)) == 'Mr T' - assert get_text_content(author.find(SubXmlPaths.AUTHOR_SURNAME)) == 'E' - - def test_should_add_contrib_type_author_attribute(self): - xml_root = extracted_items_to_xml(_create_author_extracted_items(TEXT_1, TEXT_2)) - assert xml_root is not None - author = xml_root.find(XmlPaths.AUTHOR) - assert author is not None - assert author.tag == 'contrib' - assert author.attrib.get('contrib-type') == 'author' + def test_should_return_empty_xml_for_no_empty_list_of_extracted_items(self): + xml_root = extracted_items_to_xml([]) + assert xml_root is not None + + def test_should_populate_title(self): + xml_root = extracted_items_to_xml([ + ExtractedItem(Tags.TITLE, TEXT_1) + ]) + assert xml_root is not None + assert get_text_content(xml_root.find(XmlPaths.TITLE)) == TEXT_1 + + def test_should_append_to_abstract(self): + xml_root = extracted_items_to_xml([ + ExtractedItem(Tags.ABSTRACT, TEXT_1), + ExtractedItem(Tags.ABSTRACT, TEXT_2) + ]) + assert xml_root is not None + assert get_text_content(xml_root.find(XmlPaths.ABSTRACT)) == '\n'.join([TEXT_1, TEXT_2]) + + def test_should_not_append_to_abstract_after_untagged_content(self): + xml_root = extracted_items_to_xml([ + ExtractedItem(Tags.ABSTRACT, TEXT_1), + ExtractedItem(None, TEXT_2), + ExtractedItem(Tags.ABSTRACT, TEXT_3) + ]) + assert xml_root is not None + assert get_text_content(xml_root.find(XmlPaths.ABSTRACT)) == '\n'.join([TEXT_1, TEXT_3]) + + def test_should_not_append_to_abstract_after_another_tag_occured(self): + xml_root = extracted_items_to_xml([ + ExtractedItem(Tags.ABSTRACT, TEXT_1), + ExtractedItem(Tags.AUTHOR, TEXT_2), + ExtractedItem(Tags.ABSTRACT, TEXT_3) + ]) + assert 
xml_root is not None + assert get_text_content(xml_root.find(XmlPaths.ABSTRACT)) == '\n'.join([TEXT_1]) + + def test_should_create_separate_author_node(self): + xml_root = extracted_items_to_xml([ + ExtractedItem(Tags.AUTHOR, TEXT_1), + ExtractedItem(Tags.AUTHOR, TEXT_2) + ]) + assert xml_root is not None + assert get_text_content_list(xml_root.findall(XmlPaths.AUTHOR)) == [TEXT_1, TEXT_2] + + def test_should_extract_author_surname_and_given_names_from_single_author(self): + xml_root = extracted_items_to_xml([ + ExtractedItem(Tags.AUTHOR, ' '.join([TEXT_1, TEXT_2]), sub_items=[ + ExtractedItem(SubTags.AUTHOR_GIVEN_NAMES, TEXT_1), + ExtractedItem(SubTags.AUTHOR_SURNAME, TEXT_2) + ]) + ]) + assert xml_root is not None + author = xml_root.find(XmlPaths.AUTHOR) + assert author is not None + assert get_text_content(author.find(SubXmlPaths.AUTHOR_GIVEN_NAMES)) == TEXT_1 + assert get_text_content(author.find(SubXmlPaths.AUTHOR_SURNAME)) == TEXT_2 + + def test_should_remove_special_characters_and_numbers_from_author(self): + special_num_chars = ',+*0123456789' + xml_root = extracted_items_to_xml(_create_author_extracted_items( + TEXT_1 + special_num_chars, TEXT_2 + special_num_chars + )) + assert xml_root is not None + author = xml_root.find(XmlPaths.AUTHOR) + assert author is not None + assert get_text_content(author.find(SubXmlPaths.AUTHOR_GIVEN_NAMES)) == TEXT_1 + assert get_text_content(author.find(SubXmlPaths.AUTHOR_SURNAME)) == TEXT_2 + + def test_should_not_remove_dot_after_initials_from_author(self): + xml_root = extracted_items_to_xml(_create_author_extracted_items( + 'Mr T.', 'E.' + )) + assert xml_root is not None + author = xml_root.find(XmlPaths.AUTHOR) + assert author is not None + assert get_text_content(author.find(SubXmlPaths.AUTHOR_GIVEN_NAMES)) == 'Mr T.' + assert get_text_content(author.find(SubXmlPaths.AUTHOR_SURNAME)) == 'E.' 
+ + def test_should_not_remove_dot_after_suffix_from_author(self): + xml_root = extracted_items_to_xml(_create_author_extracted_items( + 'Mr T.', 'Jr.' + )) + assert xml_root is not None + author = xml_root.find(XmlPaths.AUTHOR) + assert author is not None + assert get_text_content(author.find(SubXmlPaths.AUTHOR_GIVEN_NAMES)) == 'Mr T.' + assert get_text_content(author.find(SubXmlPaths.AUTHOR_SURNAME)) == 'Jr.' + + def test_should_remove_dot_after_other_special_characters(self): + xml_root = extracted_items_to_xml(_create_author_extracted_items( + 'Mr T*.', 'E*.' + )) + assert xml_root is not None + author = xml_root.find(XmlPaths.AUTHOR) + assert author is not None + assert get_text_content(author.find(SubXmlPaths.AUTHOR_GIVEN_NAMES)) == 'Mr T' + assert get_text_content(author.find(SubXmlPaths.AUTHOR_SURNAME)) == 'E' + + def test_should_add_contrib_type_author_attribute(self): + xml_root = extracted_items_to_xml(_create_author_extracted_items(TEXT_1, TEXT_2)) + assert xml_root is not None + author = xml_root.find(XmlPaths.AUTHOR) + assert author is not None + assert author.tag == 'contrib' + assert author.attrib.get('contrib-type') == 'author' + class TestMain(object): - def test_should_extract_from_simple_annotated_document(self): - with TemporaryDirectory() as path: - lxml_root = E.DOCUMENT( - E.PAGE( - E.TEXT( - E.TOKEN( - TEXT_1, - { - 'tag': Tags.TITLE - } + def test_should_extract_from_simple_annotated_document(self): + with TemporaryDirectory() as path: + lxml_root = E.DOCUMENT( + E.PAGE( + E.TEXT( + E.TOKEN( + TEXT_1, + { + 'tag': Tags.TITLE + } + ) + ) + ) ) - ) - ) - ) - lxml_path = os.path.join(path, 'test.lxml') - with open(lxml_path, 'w') as f: - f.write(etree.tostring(lxml_root)) + lxml_path = os.path.join(path, 'test.lxml') + with open(lxml_path, 'w') as f: + f.write(etree.tostring(lxml_root)) - output_path = os.path.join(path, 'test.xml') + output_path = os.path.join(path, 'test.xml') - main(['--lxml-path=%s' % lxml_path, '--output-path=%s' % 
output_path]) + main(['--lxml-path=%s' % lxml_path, '--output-path=%s' % output_path]) - xml_root = etree.parse(output_path) - assert get_text_content(xml_root.find(XmlPaths.TITLE)) == TEXT_1 + xml_root = etree.parse(output_path) + assert get_text_content(xml_root.find(XmlPaths.TITLE)) == TEXT_1 diff --git a/sciencebeam_gym/model_utils/channels.py b/sciencebeam_gym/model_utils/channels.py index b342dce..4328e3e 100644 --- a/sciencebeam_gym/model_utils/channels.py +++ b/sciencebeam_gym/model_utils/channels.py @@ -1,34 +1,41 @@ import tensorflow as tf from sciencebeam_gym.utils.tf import ( - variable_scoped + variable_scoped ) + def color_equals_mask(image, color): - return tf.reduce_all( - tf.equal(image, color), - axis=-1, - name='is_color' - ) + return tf.reduce_all( + tf.equal(image, color), + axis=-1, + name='is_color' + ) + def color_equals_mask_as_float(image, color): - return tf.cast(color_equals_mask(image, color), tf.float32) + return tf.cast(color_equals_mask(image, color), tf.float32) + def calculate_color_masks(image, colors, use_unknown_class=False): - color_masks = [ - variable_scoped('channel_%d' % i, lambda: color_equals_mask_as_float(image, color)) - for i, color in enumerate(colors) - ] - if use_unknown_class: - with tf.variable_scope("unknown_class"): - shape = tf.shape(color_masks[0]) - ones = tf.fill(shape, 1.0, name='ones') - zeros = tf.fill(shape, 0.0, name='zeros') - color_masks.append( - tf.where( - tf.add_n(color_masks) < 0.5, - ones, - zeros + color_masks = [ + variable_scoped( + 'channel_%d' % i, + lambda color_param: color_equals_mask_as_float(image, color_param), + color_param=color ) - ) - return color_masks + for i, color in enumerate(colors) + ] + if use_unknown_class: + with tf.variable_scope("unknown_class"): + shape = tf.shape(color_masks[0]) + ones = tf.fill(shape, 1.0, name='ones') + zeros = tf.fill(shape, 0.0, name='zeros') + color_masks.append( + tf.where( + tf.add_n(color_masks) < 0.5, + ones, + zeros + ) + ) + return 
color_masks diff --git a/sciencebeam_gym/models/text/crf/annotate_using_predictions.py b/sciencebeam_gym/models/text/crf/annotate_using_predictions.py index cd76934..de47dad 100644 --- a/sciencebeam_gym/models/text/crf/annotate_using_predictions.py +++ b/sciencebeam_gym/models/text/crf/annotate_using_predictions.py @@ -4,138 +4,146 @@ import pickle from itertools import repeat from sciencebeam_gym.utils.tf import ( - FileIO + FileIO ) from sciencebeam_gym.models.text.feature_extractor import ( - structured_document_to_token_props, - token_props_list_to_features, - merge_with_cv_structured_document, - NONE_TAG + structured_document_to_token_props, + token_props_list_to_features, + merge_with_cv_structured_document, + NONE_TAG ) from sciencebeam_gym.structured_document.structured_document_loader import ( - load_lxml_structured_document + load_lxml_structured_document ) from sciencebeam_gym.structured_document.structured_document_saver import ( - save_structured_document + save_structured_document ) CRF_TAG_SCOPE = 'crf' + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def _iter_tokens(structured_document): - for page in structured_document.get_pages(): - for line in structured_document.get_lines_of_page(page): - for token in structured_document.get_tokens_of_line(line): - yield token + for page in structured_document.get_pages(): + for line in structured_document.get_lines_of_page(page): + for token in structured_document.get_tokens_of_line(line): + yield token + def annotate_structured_document_using_predictions( - structured_document, predictions, token_props_list=None, - tag_scope=CRF_TAG_SCOPE): - """ - Annotates the structured document using the predicted tags. 
- - Args: - structured_document: the document that will be tagged - predictions: list of predicted tags - token_props_list: optional, used to verify that the correct token is being tagged - tag_scope: tag scope to use when setting predicted tags - """ - - if token_props_list is None: - token_props_list = repeat(None) - for token, prediction, token_props in zip( - _iter_tokens(structured_document), - predictions, token_props_list + structured_document, predictions, token_props_list=None, + tag_scope=CRF_TAG_SCOPE): + """ + Annotates the structured document using the predicted tags. + + Args: + structured_document: the document that will be tagged + predictions: list of predicted tags + token_props_list: optional, used to verify that the correct token is being tagged + tag_scope: tag scope to use when setting predicted tags + """ + + if token_props_list is None: + token_props_list = repeat(None) + for token, prediction, token_props in zip( + _iter_tokens(structured_document), + predictions, token_props_list ): - if token_props: - assert structured_document.get_text(token) == token_props['text'] + if token_props: + assert structured_document.get_text(token) == token_props['text'] + + if prediction and prediction != NONE_TAG: + structured_document.set_tag(token, prediction, scope=tag_scope) - if prediction and prediction != NONE_TAG: - structured_document.set_tag(token, prediction, scope=tag_scope) def predict_and_annotate_structured_document(structured_document, model, tag_scope=CRF_TAG_SCOPE): - token_props = list(structured_document_to_token_props(structured_document)) - x = token_props_list_to_features(token_props) - y_pred = model.predict([x])[0] - annotate_structured_document_using_predictions( - structured_document, y_pred, token_props, tag_scope=tag_scope - ) - return structured_document + token_props = list(structured_document_to_token_props(structured_document)) + x = token_props_list_to_features(token_props) + y_pred = model.predict([x])[0] + 
annotate_structured_document_using_predictions( + structured_document, y_pred, token_props, tag_scope=tag_scope + ) + return structured_document + def parse_args(argv=None): - parser = argparse.ArgumentParser('Annotated LXML using CRF model') - source = parser.add_argument_group('source') - source.add_argument( - '--lxml-path', type=str, required=False, - help='path to lxml document' - ) - - cv_source = parser.add_argument_group('CV source') - cv_source.add_argument( - '--cv-lxml-path', type=str, required=False, - help='path to lxml document with cv predicted tags' - ) - - parser.add_argument( - '--crf-model', type=str, required=True, - help='path to saved crf model' - ) - - parser.add_argument( - '--output-path', type=str, required=True, - help='output path to annotated document' - ) - - parser.add_argument( - '--tag-scope', type=str, required=False, - default=CRF_TAG_SCOPE, - help='target tag scope for the predicted tags' - ) - - parser.add_argument( - '--debug', action='store_true', default=False, - help='enable debug output' - ) - - return parser.parse_args(argv) + parser = argparse.ArgumentParser('Annotated LXML using CRF model') + source = parser.add_argument_group('source') + source.add_argument( + '--lxml-path', type=str, required=False, + help='path to lxml document' + ) + + cv_source = parser.add_argument_group('CV source') + cv_source.add_argument( + '--cv-lxml-path', type=str, required=False, + help='path to lxml document with cv predicted tags' + ) + + parser.add_argument( + '--crf-model', type=str, required=True, + help='path to saved crf model' + ) + + parser.add_argument( + '--output-path', type=str, required=True, + help='output path to annotated document' + ) + + parser.add_argument( + '--tag-scope', type=str, required=False, + default=CRF_TAG_SCOPE, + help='target tag scope for the predicted tags' + ) + + parser.add_argument( + '--debug', action='store_true', default=False, + help='enable debug output' + ) + + return parser.parse_args(argv) + def 
load_crf_model(path): - with FileIO(path, 'rb') as crf_model_f: - return pickle.load(crf_model_f) + with FileIO(path, 'rb') as crf_model_f: + return pickle.load(crf_model_f) + def main(argv=None): - args = parse_args(argv) + args = parse_args(argv) - if args.debug: - logging.getLogger().setLevel('DEBUG') + if args.debug: + logging.getLogger().setLevel('DEBUG') - structured_document = load_lxml_structured_document(args.lxml_path) + structured_document = load_lxml_structured_document(args.lxml_path) - if args.cv_lxml_path: - cv_structured_document = load_lxml_structured_document(args.cv_lxml_path) - structured_document = merge_with_cv_structured_document( - structured_document, - cv_structured_document - ) + if args.cv_lxml_path: + cv_structured_document = load_lxml_structured_document(args.cv_lxml_path) + structured_document = merge_with_cv_structured_document( + structured_document, + cv_structured_document + ) - model = load_crf_model(args.crf_model) + model = load_crf_model(args.crf_model) + + predict_and_annotate_structured_document( + structured_document, + model, + tag_scope=args.tag_scope + ) - predict_and_annotate_structured_document( - structured_document, - model, - tag_scope=args.tag_scope - ) + get_logger().info('writing result to: %s', args.output_path) + save_structured_document(args.output_path, structured_document) - get_logger().info('writing result to: %s', args.output_path) - save_structured_document(args.output_path, structured_document) if __name__ == '__main__': - logging.basicConfig(level='INFO') + logging.basicConfig(level='INFO') - main() + main() diff --git a/sciencebeam_gym/models/text/crf/annotate_using_predictions_test.py b/sciencebeam_gym/models/text/crf/annotate_using_predictions_test.py index 7b576e8..7ba0029 100644 --- a/sciencebeam_gym/models/text/crf/annotate_using_predictions_test.py +++ b/sciencebeam_gym/models/text/crf/annotate_using_predictions_test.py @@ -3,26 +3,26 @@ from mock import MagicMock import pytest from 
sciencebeam_gym.structured_document import ( - SimpleStructuredDocument, - SimplePage, - SimpleLine, - SimpleToken + SimpleStructuredDocument, + SimplePage, + SimpleLine, + SimpleToken ) from sciencebeam_gym.models.text.feature_extractor import ( - structured_document_to_token_props, - token_props_list_to_features, - NONE_TAG + structured_document_to_token_props, + token_props_list_to_features, + NONE_TAG ) from sciencebeam_gym.utils.bounding_box import ( - BoundingBox + BoundingBox ) from sciencebeam_gym.models.text.crf.annotate_using_predictions import ( - annotate_structured_document_using_predictions, - predict_and_annotate_structured_document, - CRF_TAG_SCOPE + annotate_structured_document_using_predictions, + predict_and_annotate_structured_document, + CRF_TAG_SCOPE ) TAG_1 = 'tag1' @@ -32,75 +32,77 @@ TOKEN_TEXT_2 = 'token 2' BOUNDING_BOX = BoundingBox(0, 0, 10, 10) + class TestAnnotateStructuredDocumentUsingPredictions(object): - def test_should_not_fail_with_empty_document(self): - structured_document = SimpleStructuredDocument() - annotate_structured_document_using_predictions( - structured_document, - [] - ) - - def test_should_tag_single_token_using_prediction(self): - token_1 = SimpleToken(TOKEN_TEXT_1) - structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])]) - annotate_structured_document_using_predictions( - structured_document, - [TAG_1] - ) - assert structured_document.get_tag(token_1, scope=CRF_TAG_SCOPE) == TAG_1 - - def test_should_not_tag_using_none_tag(self): - token_1 = SimpleToken(TOKEN_TEXT_1) - structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])]) - annotate_structured_document_using_predictions( - structured_document, - [NONE_TAG] - ) - assert structured_document.get_tag(token_1, scope=CRF_TAG_SCOPE) is None - - def test_should_tag_single_token_using_prediction_and_check_token_props(self): - token_1 = SimpleToken(TOKEN_TEXT_1, bounding_box=BOUNDING_BOX) - structured_document = 
SimpleStructuredDocument(SimplePage( - lines=[SimpleLine([token_1])], - bounding_box=BOUNDING_BOX - )) - token_props_list = structured_document_to_token_props(structured_document) - annotate_structured_document_using_predictions( - structured_document, - [TAG_1], - token_props_list - ) - assert structured_document.get_tag(token_1, scope=CRF_TAG_SCOPE) == TAG_1 - - def test_should_raise_error_if_token_props_do_not_match(self): - token_1 = SimpleToken(TOKEN_TEXT_1, bounding_box=BOUNDING_BOX) - structured_document = SimpleStructuredDocument(SimplePage( - lines=[SimpleLine([token_1])], - bounding_box=BOUNDING_BOX - )) - token_props_list = list(structured_document_to_token_props(structured_document)) - token_props_list[0]['text'] = TOKEN_TEXT_2 - with pytest.raises(AssertionError): - annotate_structured_document_using_predictions( - structured_document, - [TAG_1], - token_props_list - ) + def test_should_not_fail_with_empty_document(self): + structured_document = SimpleStructuredDocument() + annotate_structured_document_using_predictions( + structured_document, + [] + ) + + def test_should_tag_single_token_using_prediction(self): + token_1 = SimpleToken(TOKEN_TEXT_1) + structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])]) + annotate_structured_document_using_predictions( + structured_document, + [TAG_1] + ) + assert structured_document.get_tag(token_1, scope=CRF_TAG_SCOPE) == TAG_1 + + def test_should_not_tag_using_none_tag(self): + token_1 = SimpleToken(TOKEN_TEXT_1) + structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])]) + annotate_structured_document_using_predictions( + structured_document, + [NONE_TAG] + ) + assert structured_document.get_tag(token_1, scope=CRF_TAG_SCOPE) is None + + def test_should_tag_single_token_using_prediction_and_check_token_props(self): + token_1 = SimpleToken(TOKEN_TEXT_1, bounding_box=BOUNDING_BOX) + structured_document = SimpleStructuredDocument(SimplePage( + lines=[SimpleLine([token_1])], 
+ bounding_box=BOUNDING_BOX + )) + token_props_list = structured_document_to_token_props(structured_document) + annotate_structured_document_using_predictions( + structured_document, + [TAG_1], + token_props_list + ) + assert structured_document.get_tag(token_1, scope=CRF_TAG_SCOPE) == TAG_1 + + def test_should_raise_error_if_token_props_do_not_match(self): + token_1 = SimpleToken(TOKEN_TEXT_1, bounding_box=BOUNDING_BOX) + structured_document = SimpleStructuredDocument(SimplePage( + lines=[SimpleLine([token_1])], + bounding_box=BOUNDING_BOX + )) + token_props_list = list(structured_document_to_token_props(structured_document)) + token_props_list[0]['text'] = TOKEN_TEXT_2 + with pytest.raises(AssertionError): + annotate_structured_document_using_predictions( + structured_document, + [TAG_1], + token_props_list + ) + class TestPredictAndAnnotateStructuredDocument(object): - def test_should_predict_and_annotate_single_token(self): - token_1 = SimpleToken(TOKEN_TEXT_1, bounding_box=BOUNDING_BOX) - structured_document = SimpleStructuredDocument(SimplePage( - lines=[SimpleLine([token_1])], - bounding_box=BOUNDING_BOX - )) - model = MagicMock() - model.predict.return_value = [[TAG_1]] - token_props = list(structured_document_to_token_props(structured_document)) - X = [token_props_list_to_features(token_props)] - predict_and_annotate_structured_document( - structured_document, - model - ) - assert structured_document.get_tag(token_1, scope=CRF_TAG_SCOPE) == TAG_1 - model.predict.assert_called_with(X) + def test_should_predict_and_annotate_single_token(self): + token_1 = SimpleToken(TOKEN_TEXT_1, bounding_box=BOUNDING_BOX) + structured_document = SimpleStructuredDocument(SimplePage( + lines=[SimpleLine([token_1])], + bounding_box=BOUNDING_BOX + )) + model = MagicMock() + model.predict.return_value = [[TAG_1]] + token_props = list(structured_document_to_token_props(structured_document)) + X = [token_props_list_to_features(token_props)] + 
predict_and_annotate_structured_document( + structured_document, + model + ) + assert structured_document.get_tag(token_1, scope=CRF_TAG_SCOPE) == TAG_1 + model.predict.assert_called_with(X) diff --git a/sciencebeam_gym/models/text/crf/crfsuite_model.py b/sciencebeam_gym/models/text/crf/crfsuite_model.py index a831eff..6650359 100644 --- a/sciencebeam_gym/models/text/crf/crfsuite_model.py +++ b/sciencebeam_gym/models/text/crf/crfsuite_model.py @@ -2,19 +2,22 @@ import logging from sklearn_crfsuite import CRF + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + DEFAULT_PARAMS = dict( - algorithm='lbfgs', - c1=0.1, - c2=0.1, - max_iterations=100, - all_possible_transitions=True + algorithm='lbfgs', + c1=0.1, + c2=0.1, + max_iterations=100, + all_possible_transitions=True ) + class CrfSuiteModel(CRF): - def __init__(self, **kwargs): - d = dict(DEFAULT_PARAMS) - d.update(kwargs) - super(CrfSuiteModel, self).__init__(**d) + def __init__(self, **kwargs): + d = dict(DEFAULT_PARAMS) + d.update(kwargs) + super(CrfSuiteModel, self).__init__(**d) diff --git a/sciencebeam_gym/models/text/crf/crfsuite_model_test.py b/sciencebeam_gym/models/text/crf/crfsuite_model_test.py index 54a6da8..5c9c1c4 100644 --- a/sciencebeam_gym/models/text/crf/crfsuite_model_test.py +++ b/sciencebeam_gym/models/text/crf/crfsuite_model_test.py @@ -3,24 +3,24 @@ import pickle from contextlib import contextmanager from sciencebeam_gym.utils.bounding_box import ( - BoundingBox + BoundingBox ) from sciencebeam_gym.structured_document import ( - SimpleStructuredDocument, - SimplePage, - SimpleLine, - SimpleToken + SimpleStructuredDocument, + SimplePage, + SimpleLine, + SimpleToken ) from sciencebeam_gym.models.text.feature_extractor import ( - structured_document_to_token_props, - token_props_list_to_features, - token_props_list_to_labels + structured_document_to_token_props, + token_props_list_to_features, + token_props_list_to_labels ) from 
sciencebeam_gym.models.text.crf.crfsuite_model import ( - CrfSuiteModel + CrfSuiteModel ) PAGE_BOUNDING_BOX = BoundingBox(0, 0, 100, 200) @@ -34,89 +34,93 @@ TAG_1 = 'tag1' TAG_2 = 'tag2' TAG_3 = 'tag3' + def setup_module(): - logging.basicConfig(level='DEBUG') + logging.basicConfig(level='DEBUG') + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + @contextmanager def create_crf_suite_model(): - model = CrfSuiteModel() - yield model + model = CrfSuiteModel() + yield model + class TestCrfSuiteModel(object): - def test_should_learn_simple_sequence(self): - structured_document = SimpleStructuredDocument( - SimplePage([ - SimpleLine([ - SimpleToken(TEXT_1, tag=TAG_1), - SimpleToken(TEXT_2, tag=TAG_2), - SimpleToken(TEXT_3, tag=TAG_3) - ]) - ], bounding_box=PAGE_BOUNDING_BOX) - ) - token_props_list = list(structured_document_to_token_props(structured_document)) - get_logger().debug('token_props_list:\n%s', token_props_list) - X = [token_props_list_to_features(token_props_list)] - y = [token_props_list_to_labels(token_props_list)] - get_logger().debug('X:\n%s', X) - get_logger().debug('y:\n%s', y) - with create_crf_suite_model() as model: - model.fit(X, y) - y_predicted = model.predict(X) - assert y_predicted == y - - def test_should_learn_similar_sequence(self): - structured_document_train = SimpleStructuredDocument( - SimplePage([ - SimpleLine([ - SimpleToken(TEXT_1, tag=TAG_1), - SimpleToken(TEXT_1, tag=TAG_1), - SimpleToken(TEXT_2, tag=TAG_2), - SimpleToken(TEXT_3, tag=TAG_3) - ]) - ], bounding_box=PAGE_BOUNDING_BOX) - ) - structured_document_test = SimpleStructuredDocument( - SimplePage([ - SimpleLine([ - SimpleToken(TEXT_1, tag=TAG_1), - SimpleToken(TEXT_2, tag=TAG_2), - SimpleToken(TEXT_3, tag=TAG_3) - ]) - ], bounding_box=PAGE_BOUNDING_BOX) - ) - token_props_list_train = list(structured_document_to_token_props(structured_document_train)) - X_train = [token_props_list_to_features(token_props_list_train)] - y_train = 
[token_props_list_to_labels(token_props_list_train)] - - token_props_list_test = list(structured_document_to_token_props(structured_document_test)) - X_test = [token_props_list_to_features(token_props_list_test)] - y_test = [token_props_list_to_labels(token_props_list_test)] - - with create_crf_suite_model() as model: - model.fit(X_train, y_train) - y_predicted = model.predict(X_test) - assert y_predicted == y_test - - def test_should_pickle_and_unpickle_model(self): - structured_document = SimpleStructuredDocument( - SimplePage([ - SimpleLine([ - SimpleToken(TEXT_1, tag=TAG_1), - SimpleToken(TEXT_2, tag=TAG_2), - SimpleToken(TEXT_3, tag=TAG_3) - ]) - ], bounding_box=PAGE_BOUNDING_BOX) - ) - token_props_list = list(structured_document_to_token_props(structured_document)) - X = [token_props_list_to_features(token_props_list)] - y = [token_props_list_to_labels(token_props_list)] - with create_crf_suite_model() as model: - model.fit(X, y) - serialized_model = pickle.dumps(model) - - model = pickle.loads(serialized_model) - y_predicted = model.predict(X) - assert y_predicted == y + def test_should_learn_simple_sequence(self): + structured_document = SimpleStructuredDocument( + SimplePage([ + SimpleLine([ + SimpleToken(TEXT_1, tag=TAG_1), + SimpleToken(TEXT_2, tag=TAG_2), + SimpleToken(TEXT_3, tag=TAG_3) + ]) + ], bounding_box=PAGE_BOUNDING_BOX) + ) + token_props_list = list(structured_document_to_token_props(structured_document)) + get_logger().debug('token_props_list:\n%s', token_props_list) + X = [token_props_list_to_features(token_props_list)] + y = [token_props_list_to_labels(token_props_list)] + get_logger().debug('X:\n%s', X) + get_logger().debug('y:\n%s', y) + with create_crf_suite_model() as model: + model.fit(X, y) + y_predicted = model.predict(X) + assert y_predicted == y + + def test_should_learn_similar_sequence(self): + structured_document_train = SimpleStructuredDocument( + SimplePage([ + SimpleLine([ + SimpleToken(TEXT_1, tag=TAG_1), + 
SimpleToken(TEXT_1, tag=TAG_1), + SimpleToken(TEXT_2, tag=TAG_2), + SimpleToken(TEXT_3, tag=TAG_3) + ]) + ], bounding_box=PAGE_BOUNDING_BOX) + ) + structured_document_test = SimpleStructuredDocument( + SimplePage([ + SimpleLine([ + SimpleToken(TEXT_1, tag=TAG_1), + SimpleToken(TEXT_2, tag=TAG_2), + SimpleToken(TEXT_3, tag=TAG_3) + ]) + ], bounding_box=PAGE_BOUNDING_BOX) + ) + token_props_list_train = list(structured_document_to_token_props(structured_document_train)) + X_train = [token_props_list_to_features(token_props_list_train)] + y_train = [token_props_list_to_labels(token_props_list_train)] + + token_props_list_test = list(structured_document_to_token_props(structured_document_test)) + X_test = [token_props_list_to_features(token_props_list_test)] + y_test = [token_props_list_to_labels(token_props_list_test)] + + with create_crf_suite_model() as model: + model.fit(X_train, y_train) + y_predicted = model.predict(X_test) + assert y_predicted == y_test + + def test_should_pickle_and_unpickle_model(self): + structured_document = SimpleStructuredDocument( + SimplePage([ + SimpleLine([ + SimpleToken(TEXT_1, tag=TAG_1), + SimpleToken(TEXT_2, tag=TAG_2), + SimpleToken(TEXT_3, tag=TAG_3) + ]) + ], bounding_box=PAGE_BOUNDING_BOX) + ) + token_props_list = list(structured_document_to_token_props(structured_document)) + X = [token_props_list_to_features(token_props_list)] + y = [token_props_list_to_labels(token_props_list)] + with create_crf_suite_model() as model: + model.fit(X, y) + serialized_model = pickle.dumps(model) + + model = pickle.loads(serialized_model) + y_predicted = model.predict(X) + assert y_predicted == y diff --git a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py index 50d7c80..30d6303 100644 --- a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py +++ b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py @@ -9,215 +9,227 @@ from six import raise_from 
from tqdm import tqdm from sciencebeam_utils.beam_utils.io import ( - save_file_content + save_file_content ) from sciencebeam_utils.utils.stopwatch import ( - StopWatchRecorder + StopWatchRecorder ) from sciencebeam_utils.utils.file_list import ( - load_file_list + load_file_list ) from sciencebeam_gym.structured_document.structured_document_loader import ( - load_structured_document + load_structured_document ) from sciencebeam_gym.preprocess.preprocessing_utils import ( - parse_page_range + parse_page_range ) from sciencebeam_gym.models.text.feature_extractor import ( - structured_document_to_token_props, - token_props_list_to_features, - token_props_list_to_labels, - merge_with_cv_structured_document, - CV_TAG_SCOPE + structured_document_to_token_props, + token_props_list_to_features, + token_props_list_to_labels, + merge_with_cv_structured_document, + CV_TAG_SCOPE ) from sciencebeam_gym.models.text.crf.crfsuite_model import ( - CrfSuiteModel + CrfSuiteModel ) + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def parse_args(argv=None): - parser = argparse.ArgumentParser('Trains the CRF Suite model') - source = parser.add_argument_group('source') - source.add_argument( - '--source-file-list', type=str, required=True, - help='path to source file list (tsv/csv/lst)' - ) - source.add_argument( - '--source-file-column', type=str, required=False, - default='url', - help='csv/tsv column (ignored for plain file list)' - ) - - cv_source = parser.add_argument_group('CV source') - cv_source.add_argument( - '--cv-source-file-list', type=str, required=False, - help='path to cv source file list (tsv/csv/lst)' - ' (must be in line with main source file list)' - ) - source.add_argument( - '--cv-source-file-column', type=str, required=False, - default='url', - help='csv/tsv column (ignored for plain file list)' - ) - source.add_argument( - '--cv-source-tag-scope', type=str, required=False, - default=CV_TAG_SCOPE, - help='source tag 
scope to get the cv tag from' - ) - - parser.add_argument( - '--limit', type=int, required=False, - help='limit the files to process' - ) - parser.add_argument( - '--pages', type=parse_page_range, default=None, - help='only processes the selected pages' - ) - - output = parser.add_argument_group('output') - output.add_argument( - '--output-path', type=str, required=True, - help='output path to model' - ) - - parser.add_argument( - '--debug', action='store_true', default=False, - help='enable debug output' - ) - - return parser.parse_args(argv) + parser = argparse.ArgumentParser('Trains the CRF Suite model') + source = parser.add_argument_group('source') + source.add_argument( + '--source-file-list', type=str, required=True, + help='path to source file list (tsv/csv/lst)' + ) + source.add_argument( + '--source-file-column', type=str, required=False, + default='url', + help='csv/tsv column (ignored for plain file list)' + ) + + cv_source = parser.add_argument_group('CV source') + cv_source.add_argument( + '--cv-source-file-list', type=str, required=False, + help='path to cv source file list (tsv/csv/lst)' + ' (must be in line with main source file list)' + ) + source.add_argument( + '--cv-source-file-column', type=str, required=False, + default='url', + help='csv/tsv column (ignored for plain file list)' + ) + source.add_argument( + '--cv-source-tag-scope', type=str, required=False, + default=CV_TAG_SCOPE, + help='source tag scope to get the cv tag from' + ) + + parser.add_argument( + '--limit', type=int, required=False, + help='limit the files to process' + ) + parser.add_argument( + '--pages', type=parse_page_range, default=None, + help='only processes the selected pages' + ) + + output = parser.add_argument_group('output') + output.add_argument( + '--output-path', type=str, required=True, + help='output path to model' + ) + + parser.add_argument( + '--debug', action='store_true', default=False, + help='enable debug output' + ) + + return parser.parse_args(argv) + 
def load_and_convert_to_token_props(filename, cv_filename, cv_source_tag_scope, page_range=None): - try: - structured_document = load_structured_document(filename, page_range=page_range) - if cv_filename: - cv_structured_document = load_structured_document(cv_filename, page_range=page_range) - structured_document = merge_with_cv_structured_document( - structured_document, - cv_structured_document, - cv_source_tag_scope=cv_source_tag_scope - ) - return list(structured_document_to_token_props( - structured_document - )) - except StandardError as e: - raise_from(RuntimeError('failed to process %s (due to %s: %s)' % (filename, type(e), e)), e) + try: + structured_document = load_structured_document(filename, page_range=page_range) + if cv_filename: + cv_structured_document = load_structured_document(cv_filename, page_range=page_range) + structured_document = merge_with_cv_structured_document( + structured_document, + cv_structured_document, + cv_source_tag_scope=cv_source_tag_scope + ) + return list(structured_document_to_token_props( + structured_document + )) + except StandardError as e: + raise_from(RuntimeError('failed to process %s (due to %s: %s)' % (filename, type(e), e)), e) + def serialize_model(model): - return pickle.dumps(model) + return pickle.dumps(model) + def submit_all(executor, fn, iterable): - return {executor.submit(fn, x) for x in iterable} + return {executor.submit(fn, x) for x in iterable} + def load_token_props_list_by_document( - file_list, cv_file_list, cv_source_tag_scope, page_range=None, progress=True): - - if not cv_file_list: - cv_file_list = [None] * len(file_list) - - token_props_list_by_document = [] - total = len(file_list) - error_count = 0 - with tqdm(total=total, leave=False, desc='loading files', disable=not progress) as pbar: - with ThreadPoolExecutor(max_workers=50) as executor: - process_fn = lambda (filename, cv_filename): ( - load_and_convert_to_token_props( - filename, cv_filename, cv_source_tag_scope=cv_source_tag_scope, - 
page_range=page_range + file_list, cv_file_list, cv_source_tag_scope, page_range=None, progress=True): + + if not cv_file_list: + cv_file_list = [None] * len(file_list) + + token_props_list_by_document = [] + total = len(file_list) + error_count = 0 + with tqdm(total=total, leave=False, desc='loading files', disable=not progress) as pbar: + with ThreadPoolExecutor(max_workers=50) as executor: + def process_fn((filename, cv_filename)): + return ( + load_and_convert_to_token_props( + filename, cv_filename, cv_source_tag_scope=cv_source_tag_scope, + page_range=page_range + ) + ) + futures = submit_all(executor, process_fn, zip(file_list, cv_file_list)) + for future in concurrent.futures.as_completed(futures): + try: + token_props_list_by_document.append(future.result()) + except StandardError as e: + get_logger().warning(str(e), exc_info=e) + error_count += 1 + pbar.update(1) + if error_count: + get_logger().info( + 'loading error count: %d (loaded: %d)', error_count, len(token_props_list_by_document) ) - ) - futures = submit_all(executor, process_fn, zip(file_list, cv_file_list)) - for future in concurrent.futures.as_completed(futures): - try: - token_props_list_by_document.append(future.result()) - except StandardError as e: - get_logger().warning(str(e), exc_info=e) - error_count += 1 - pbar.update(1) - if error_count: - get_logger().info( - 'loading error count: %d (loaded: %d)', error_count, len(token_props_list_by_document) - ) - return token_props_list_by_document + return token_props_list_by_document + def train_model( - file_list, cv_file_list, cv_source_tag_scope, page_range=None, progress=True): + file_list, cv_file_list, cv_source_tag_scope, page_range=None, progress=True): + + stop_watch_recorder = StopWatchRecorder() + model = CrfSuiteModel() - stop_watch_recorder = StopWatchRecorder() - model = CrfSuiteModel() + stop_watch_recorder.start('loading files') + token_props_list_by_document = load_token_props_list_by_document( + file_list, cv_file_list, 
cv_source_tag_scope=cv_source_tag_scope, + page_range=page_range, progress=progress + ) - stop_watch_recorder.start('loading files') - token_props_list_by_document = load_token_props_list_by_document( - file_list, cv_file_list, cv_source_tag_scope=cv_source_tag_scope, - page_range=page_range, progress=progress - ) + assert token_props_list_by_document - assert token_props_list_by_document + stop_watch_recorder.start('converting to features') + X = [token_props_list_to_features(x) for x in token_props_list_by_document] + y = [token_props_list_to_labels(x) for x in token_props_list_by_document] - stop_watch_recorder.start('converting to features') - X = [token_props_list_to_features(x) for x in token_props_list_by_document] - y = [token_props_list_to_labels(x) for x in token_props_list_by_document] + get_logger().info('training model (with %d documents)', len(X)) + stop_watch_recorder.start('train') + model.fit(X, y) - get_logger().info('training model (with %d documents)', len(X)) - stop_watch_recorder.start('train') - model.fit(X, y) + stop_watch_recorder.start('serialize') + serialized_model = serialize_model(model) - stop_watch_recorder.start('serialize') - serialized_model = serialize_model(model) + stop_watch_recorder.stop() + get_logger().info('timings: %s', stop_watch_recorder) - stop_watch_recorder.stop() - get_logger().info('timings: %s', stop_watch_recorder) + return serialized_model - return serialized_model def save_model(output_filename, model_bytes): - get_logger().info('saving model to %s', output_filename) - save_file_content(output_filename, model_bytes) + get_logger().info('saving model to %s', output_filename) + save_file_content(output_filename, model_bytes) + def run(opt): - file_list = load_file_list( - opt.source_file_list, - opt.source_file_column, - limit=opt.limit - ) - if opt.cv_source_file_list: - cv_file_list = load_file_list( - opt.cv_source_file_list, - opt.cv_source_file_column, - limit=opt.limit + file_list = load_file_list( + 
opt.source_file_list, + opt.source_file_column, + limit=opt.limit + ) + if opt.cv_source_file_list: + cv_file_list = load_file_list( + opt.cv_source_file_list, + opt.cv_source_file_column, + limit=opt.limit + ) + else: + cv_file_list = None + get_logger().info( + 'training using %d files (limit %d), page range: %s', + len(file_list), opt.limit, opt.pages ) - else: - cv_file_list = None - get_logger().info( - 'training using %d files (limit %d), page range: %s', - len(file_list), opt.limit, opt.pages - ) - save_model( - opt.output_path, - train_model( - file_list, cv_file_list, cv_source_tag_scope=opt.cv_source_tag_scope, - page_range=opt.pages + save_model( + opt.output_path, + train_model( + file_list, cv_file_list, cv_source_tag_scope=opt.cv_source_tag_scope, + page_range=opt.pages + ) ) - ) + def main(argv=None): - args = parse_args(argv) + args = parse_args(argv) + + if args.debug: + logging.getLogger().setLevel('DEBUG') - if args.debug: - logging.getLogger().setLevel('DEBUG') + run(args) - run(args) if __name__ == '__main__': - logging.basicConfig(level='INFO') - logging.getLogger('oauth2client').setLevel('WARN') + logging.basicConfig(level='INFO') + logging.getLogger('oauth2client').setLevel('WARN') - main() + main() diff --git a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline_test.py b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline_test.py index 24cd23e..a571262 100644 --- a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline_test.py +++ b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline_test.py @@ -3,27 +3,27 @@ from mock import patch, Mock, ANY import pytest from sciencebeam_utils.utils.collection import ( - to_namedtuple + to_namedtuple ) from sciencebeam_gym.structured_document import ( - SimpleStructuredDocument, - SimpleLine, - SimpleToken + SimpleStructuredDocument, + SimpleLine, + SimpleToken ) from sciencebeam_gym.models.text.feature_extractor import ( - CV_TAG_SCOPE + CV_TAG_SCOPE ) import 
sciencebeam_gym.models.text.crf.crfsuite_training_pipeline as crfsuite_training_pipeline from sciencebeam_gym.models.text.crf.crfsuite_training_pipeline import ( - load_and_convert_to_token_props, - load_token_props_list_by_document, - train_model, - save_model, - run, - main + load_and_convert_to_token_props, + load_token_props_list_by_document, + train_model, + save_model, + run, + main ) SOURCE_FILE_LIST_PATH = '.temp/source-file-list.lst' @@ -43,252 +43,320 @@ TAG_1 = 'tag1' TAG_2 = 'tag2' DEFAULT_ARGS = dict( - source_file_column='url', - cv_source_file_list=None, - cv_source_file_column='url', - cv_source_tag_scope=CV_TAG_SCOPE + source_file_column='url', + cv_source_file_list=None, + cv_source_file_column='url', + cv_source_tag_scope=CV_TAG_SCOPE ) + +@pytest.fixture(name='load_structured_document_mock') +def _load_structured_document_mock(): + with patch.object(crfsuite_training_pipeline, 'load_structured_document') as _mock: + yield _mock + + +@pytest.fixture(name='structured_document_to_token_props_mock') +def _structured_document_to_token_props_mock(): + with patch.object(crfsuite_training_pipeline, 'structured_document_to_token_props') as _mock: + yield _mock + + +@pytest.fixture(name='token_props_list_to_features_mock') +def _token_props_list_to_features_mock(): + with patch.object(crfsuite_training_pipeline, 'token_props_list_to_features') as _mock: + yield _mock + + +@pytest.fixture(name='load_token_props_list_by_document_mock') +def _load_token_props_list_by_document_mock(): + with patch.object(crfsuite_training_pipeline, 'load_token_props_list_by_document') as _mock: + yield _mock + + +@pytest.fixture(name='load_and_convert_to_token_props_mock') +def _load_and_convert_to_token_props_mock(): + with patch.object(crfsuite_training_pipeline, 'load_and_convert_to_token_props') as _mock: + yield _mock + + +@pytest.fixture(name='CrfSuiteModel_mock') +def _CrfSuiteModel_mock(): + with patch.object(crfsuite_training_pipeline, 'CrfSuiteModel') as _mock: + 
yield _mock + + +@pytest.fixture(name='pickle_mock') +def _pickle_mock(): + with patch.object(crfsuite_training_pipeline, 'pickle') as _mock: + yield _mock + + +@pytest.fixture(name='save_file_content_mock') +def _save_file_content_mock(): + with patch.object(crfsuite_training_pipeline, 'save_file_content') as _mock: + yield _mock + + +@pytest.fixture(name='load_file_list_mock') +def _load_file_list_mock(): + with patch.object(crfsuite_training_pipeline, 'load_file_list') as _mock: + yield _mock + + +@pytest.fixture(name='train_model_mock') +def _train_model_mock(): + with patch.object(crfsuite_training_pipeline, 'train_model') as _mock: + yield _mock + + +@pytest.fixture(name='save_model_mock') +def _save_model_mock(): + with patch.object(crfsuite_training_pipeline, 'save_model') as _mock: + yield _mock + + class TestLoadAndConvertToTokenProps(object): - def test_should_load_and_convert_document(self): - m = crfsuite_training_pipeline - with patch.object(m, 'load_structured_document') as load_structured_document_mock: - with patch.object(m, 'structured_document_to_token_props') as \ - structured_document_to_token_props_mock: + def test_should_load_and_convert_document( + self, + load_structured_document_mock, + structured_document_to_token_props_mock): load_and_convert_to_token_props( - FILE_1, None, cv_source_tag_scope=CV_TAG_SCOPE, - page_range=PAGE_RANGE + FILE_1, None, cv_source_tag_scope=CV_TAG_SCOPE, + page_range=PAGE_RANGE ) load_structured_document_mock.assert_called_with( - FILE_1, page_range=PAGE_RANGE + FILE_1, page_range=PAGE_RANGE ) structured_document_to_token_props_mock.assert_called_with( - load_structured_document_mock.return_value + load_structured_document_mock.return_value ) - def test_should_load_and_convert_document_with_cv(self): - m = crfsuite_training_pipeline - with patch.object(m, 'load_structured_document') as load_structured_document_mock: - with patch.object(m, 'structured_document_to_token_props') as \ - 
structured_document_to_token_props_mock: + def test_should_load_and_convert_document_with_cv( + self, + load_structured_document_mock, + structured_document_to_token_props_mock): load_and_convert_to_token_props( - FILE_1, FILE_2, cv_source_tag_scope=CV_TAG_SCOPE, - page_range=PAGE_RANGE + FILE_1, FILE_2, cv_source_tag_scope=CV_TAG_SCOPE, + page_range=PAGE_RANGE ) load_structured_document_mock.assert_any_call( - FILE_1, page_range=PAGE_RANGE + FILE_1, page_range=PAGE_RANGE ) structured_document_to_token_props_mock.assert_called_with( - load_structured_document_mock.return_value + load_structured_document_mock.return_value ) - def test_should_merge_doc_and_scope_cv_tag(self): - structured_document = SimpleStructuredDocument(lines=[SimpleLine([ - SimpleToken(TEXT_1, tag=TAG_1) - ])]) - cv_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ - SimpleToken(TEXT_1, tag=TAG_2, tag_scope=CV_TAG_SCOPE) - ])]) - m = crfsuite_training_pipeline - with patch.object(m, 'load_structured_document') as load_structured_document_mock: - with patch.object(m, 'structured_document_to_token_props') as \ - structured_document_to_token_props_mock: - + def test_should_merge_doc_and_scope_cv_tag( + self, + load_structured_document_mock, + structured_document_to_token_props_mock): + + structured_document = SimpleStructuredDocument(lines=[SimpleLine([ + SimpleToken(TEXT_1, tag=TAG_1) + ])]) + cv_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ + SimpleToken(TEXT_1, tag=TAG_2, tag_scope=CV_TAG_SCOPE) + ])]) load_structured_document_mock.side_effect = [ - structured_document, - cv_structured_document + structured_document, + cv_structured_document ] load_and_convert_to_token_props( - FILE_1, FILE_2, cv_source_tag_scope=CV_TAG_SCOPE, - page_range=PAGE_RANGE + FILE_1, FILE_2, cv_source_tag_scope=CV_TAG_SCOPE, + page_range=PAGE_RANGE ) load_structured_document_mock.assert_any_call( - FILE_1, page_range=PAGE_RANGE + FILE_1, page_range=PAGE_RANGE ) 
structured_document_arg = structured_document_to_token_props_mock.call_args[0][0] assert [ - structured_document_arg.get_tag_by_scope(t) - for t in structured_document_arg.iter_all_tokens() + structured_document_arg.get_tag_by_scope(t) + for t in structured_document_arg.iter_all_tokens() ] == [{None: TAG_1, CV_TAG_SCOPE: TAG_2}] + class TestLoadTokenPropsListByDocument(object): - def test_should_load_single_file_without_cv(self): - m = crfsuite_training_pipeline - with patch.object(m, 'load_and_convert_to_token_props') as \ - load_and_convert_to_token_props_mock: - - result = load_token_props_list_by_document( - [FILE_1], None, cv_source_tag_scope=CV_TAG_SCOPE, - page_range=PAGE_RANGE, progress=False - ) - load_and_convert_to_token_props_mock.assert_called_with( - FILE_1, None, cv_source_tag_scope=CV_TAG_SCOPE, page_range=PAGE_RANGE - ) - assert result == [load_and_convert_to_token_props_mock.return_value] - - def test_should_load_single_file_with_cv(self): - m = crfsuite_training_pipeline - with patch.object(m, 'load_and_convert_to_token_props') as \ - load_and_convert_to_token_props_mock: - - result = load_token_props_list_by_document( - [FILE_1], [FILE_2], cv_source_tag_scope=CV_TAG_SCOPE, - page_range=PAGE_RANGE, progress=False - ) - load_and_convert_to_token_props_mock.assert_called_with( - FILE_1, FILE_2, cv_source_tag_scope=CV_TAG_SCOPE, - page_range=PAGE_RANGE - ) - assert result == [load_and_convert_to_token_props_mock.return_value] - - def test_should_load_multiple_files(self): - m = crfsuite_training_pipeline - with patch.object(m, 'load_and_convert_to_token_props') as \ - load_and_convert_to_token_props_mock: - - return_values = [Mock(), Mock()] - load_and_convert_to_token_props_mock.side_effect = return_values - result = load_token_props_list_by_document( - [FILE_1, FILE_2], None, cv_source_tag_scope=CV_TAG_SCOPE, - page_range=PAGE_RANGE, progress=False - ) - load_and_convert_to_token_props_mock.assert_any_call( - FILE_1, None, 
cv_source_tag_scope=CV_TAG_SCOPE, - page_range=PAGE_RANGE - ) - load_and_convert_to_token_props_mock.assert_any_call( - FILE_2, None, cv_source_tag_scope=CV_TAG_SCOPE, - page_range=PAGE_RANGE - ) - assert set(result) == set(return_values) - - def test_should_skip_files_with_errors(self): - m = crfsuite_training_pipeline - with patch.object(m, 'load_and_convert_to_token_props') as \ - load_and_convert_to_token_props_mock: - - valid_response = Mock() - load_and_convert_to_token_props_mock.side_effect = [ - RuntimeError('oh dear'), valid_response - ] - result = load_token_props_list_by_document( - [FILE_1, FILE_2], None, cv_source_tag_scope=CV_TAG_SCOPE, - page_range=PAGE_RANGE, progress=False - ) - assert result == [valid_response] + def test_should_load_single_file_without_cv( + self, + load_and_convert_to_token_props_mock): + + result = load_token_props_list_by_document( + [FILE_1], None, cv_source_tag_scope=CV_TAG_SCOPE, + page_range=PAGE_RANGE, progress=False + ) + load_and_convert_to_token_props_mock.assert_called_with( + FILE_1, None, cv_source_tag_scope=CV_TAG_SCOPE, page_range=PAGE_RANGE + ) + assert result == [load_and_convert_to_token_props_mock.return_value] + def test_should_load_single_file_with_cv( + self, + load_and_convert_to_token_props_mock): + + result = load_token_props_list_by_document( + [FILE_1], [FILE_2], cv_source_tag_scope=CV_TAG_SCOPE, + page_range=PAGE_RANGE, progress=False + ) + load_and_convert_to_token_props_mock.assert_called_with( + FILE_1, FILE_2, cv_source_tag_scope=CV_TAG_SCOPE, + page_range=PAGE_RANGE + ) + assert result == [load_and_convert_to_token_props_mock.return_value] + + def test_should_load_multiple_files( + self, + load_and_convert_to_token_props_mock): + + return_values = [Mock(), Mock()] + load_and_convert_to_token_props_mock.side_effect = return_values + result = load_token_props_list_by_document( + [FILE_1, FILE_2], None, cv_source_tag_scope=CV_TAG_SCOPE, + page_range=PAGE_RANGE, progress=False + ) + 
load_and_convert_to_token_props_mock.assert_any_call( + FILE_1, None, cv_source_tag_scope=CV_TAG_SCOPE, + page_range=PAGE_RANGE + ) + load_and_convert_to_token_props_mock.assert_any_call( + FILE_2, None, cv_source_tag_scope=CV_TAG_SCOPE, + page_range=PAGE_RANGE + ) + assert set(result) == set(return_values) + + def test_should_skip_files_with_errors( + self, + load_and_convert_to_token_props_mock): + + valid_response = Mock() + load_and_convert_to_token_props_mock.side_effect = [ + RuntimeError('oh dear'), valid_response + ] + result = load_token_props_list_by_document( + [FILE_1, FILE_2], None, cv_source_tag_scope=CV_TAG_SCOPE, + page_range=PAGE_RANGE, progress=False + ) + assert result == [valid_response] + + +@pytest.mark.usefixtures( + 'token_props_list_to_features_mock', + 'CrfSuiteModel_mock', + 'pickle_mock' +) class TestTrainModel(object): - def test_should_train_on_single_file(self): - m = crfsuite_training_pipeline - with patch.object(m, 'load_token_props_list_by_document') as \ - load_token_props_list_by_document_mock: - with patch.object(m, 'CrfSuiteModel') as CrfSuiteModel_mock: - with patch.object(m, 'pickle') as pickle: - with patch.object(m, 'token_props_list_to_features') as _: + def test_should_train_on_single_file( + self, + load_token_props_list_by_document_mock, + CrfSuiteModel_mock, + pickle_mock): + + train_model( + [FILE_1], [FILE_2], + cv_source_tag_scope=CV_TAG_SCOPE, + page_range=PAGE_RANGE, progress=False + ) + load_token_props_list_by_document_mock.assert_called_with( + [FILE_1], [FILE_2], cv_source_tag_scope=CV_TAG_SCOPE, + page_range=PAGE_RANGE, progress=False + ) + model = CrfSuiteModel_mock.return_value + model.fit.assert_called_with(ANY, ANY) + pickle_mock.dumps.assert_called_with(model) + + def test_should_raise_error_if_no_documents_have_been_loaded( + self, + load_token_props_list_by_document_mock): + + with pytest.raises(AssertionError): + load_token_props_list_by_document_mock.return_value = [] train_model( - [FILE_1], 
[FILE_2], - cv_source_tag_scope=CV_TAG_SCOPE, - page_range=PAGE_RANGE, progress=False - ) - load_token_props_list_by_document_mock.assert_called_with( - [FILE_1], [FILE_2], cv_source_tag_scope=CV_TAG_SCOPE, - page_range=PAGE_RANGE, progress=False - ) - model = CrfSuiteModel_mock.return_value - model.fit.assert_called_with(ANY, ANY) - pickle.dumps.assert_called_with(model) - - def test_should_raise_error_if_no_documents_have_been_loaded(self): - m = crfsuite_training_pipeline - with patch.object(m, 'load_token_props_list_by_document') as \ - load_token_props_list_by_document_mock: - with patch.object(m, 'CrfSuiteModel'): - with patch.object(m, 'pickle'): - with patch.object(m, 'token_props_list_to_features') as _: - with pytest.raises(AssertionError): - load_token_props_list_by_document_mock.return_value = [] - train_model( [FILE_1], [FILE_2], cv_source_tag_scope=CV_TAG_SCOPE, page_range=PAGE_RANGE - ) + ) + class TestSaveModel(object): - def test_should_call_save_content(self): - m = crfsuite_training_pipeline - with patch.object(m, 'save_file_content') as save_file_content: - save_model(FILE_1, MODEL_DATA) - save_file_content.assert_called_with(FILE_1, MODEL_DATA) + def test_should_call_save_content( + self, + save_file_content_mock): + save_model(FILE_1, MODEL_DATA) + save_file_content_mock.assert_called_with(FILE_1, MODEL_DATA) + class TestRun(object): - def test_should_train_on_single_file(self): - m = crfsuite_training_pipeline - opt = to_namedtuple( - DEFAULT_ARGS, - source_file_list=SOURCE_FILE_LIST_PATH, - output_path=FILE_1, - limit=2, - pages=PAGE_RANGE - ) - with patch.object(m, 'load_file_list') as load_file_list: - with patch.object(m, 'train_model') as train_model_mock: - with patch.object(m, 'save_model') as save_model_mock: - run(opt) - load_file_list.assert_called_with( + def test_should_train_on_single_file( + self, + load_file_list_mock, + train_model_mock, + save_model_mock): + + opt = to_namedtuple( + DEFAULT_ARGS, + 
source_file_list=SOURCE_FILE_LIST_PATH, + output_path=FILE_1, + limit=2, + pages=PAGE_RANGE + ) + run(opt) + load_file_list_mock.assert_called_with( opt.source_file_list, opt.source_file_column, limit=opt.limit - ) - train_model_mock.assert_called_with( - load_file_list.return_value, + ) + train_model_mock.assert_called_with( + load_file_list_mock.return_value, None, cv_source_tag_scope=opt.cv_source_tag_scope, page_range=PAGE_RANGE - ) - save_model_mock.assert_called_with( + ) + save_model_mock.assert_called_with( opt.output_path, train_model_mock.return_value - ) - - def test_should_train_on_single_file_with_cv_output(self): - m = crfsuite_training_pipeline - opt = to_namedtuple( - DEFAULT_ARGS, - source_file_list=SOURCE_FILE_LIST_PATH, - cv_source_file_list=CV_SOURCE_FILE_LIST_PATH, - output_path=FILE_1, - limit=2, - pages=PAGE_RANGE - ) - file_list = [FILE_1, FILE_2] - cv_file_list = ['cv.' + FILE_1, 'cv.' + FILE_2] - with patch.object(m, 'load_file_list') as load_file_list: - with patch.object(m, 'train_model') as train_model_mock: - with patch.object(m, 'save_model') as save_model_mock: - load_file_list.side_effect = [file_list, cv_file_list] - run(opt) - load_file_list.assert_any_call( + ) + + def test_should_train_on_single_file_with_cv_output( + self, + load_file_list_mock, + train_model_mock, + save_model_mock): + + opt = to_namedtuple( + DEFAULT_ARGS, + source_file_list=SOURCE_FILE_LIST_PATH, + cv_source_file_list=CV_SOURCE_FILE_LIST_PATH, + output_path=FILE_1, + limit=2, + pages=PAGE_RANGE + ) + file_list = [FILE_1, FILE_2] + cv_file_list = ['cv.' + FILE_1, 'cv.' 
+ FILE_2] + load_file_list_mock.side_effect = [file_list, cv_file_list] + run(opt) + load_file_list_mock.assert_any_call( opt.source_file_list, opt.source_file_column, limit=opt.limit - ) - load_file_list.assert_any_call( + ) + load_file_list_mock.assert_any_call( opt.cv_source_file_list, opt.cv_source_file_column, limit=opt.limit - ) - train_model_mock.assert_called_with( + ) + train_model_mock.assert_called_with( file_list, cv_file_list, cv_source_tag_scope=opt.cv_source_tag_scope, page_range=PAGE_RANGE - ) - save_model_mock.assert_called_with( + ) + save_model_mock.assert_called_with( opt.output_path, train_model_mock.return_value - ) + ) + class TestMain(object): - def test_should(self): - argv = ['--source-file-list', SOURCE_FILE_LIST_PATH] - with patch.object(crfsuite_training_pipeline, 'parse_args') as parse_args_mock: - with patch.object(crfsuite_training_pipeline, 'run') as run_mock: - main(argv) - parse_args_mock.assert_called_with(argv) - run_mock.assert_called_with(parse_args_mock.return_value) + def test_should(self): + argv = ['--source-file-list', SOURCE_FILE_LIST_PATH] + with patch.object(crfsuite_training_pipeline, 'parse_args') as parse_args_mock: + with patch.object(crfsuite_training_pipeline, 'run') as run_mock: + main(argv) + parse_args_mock.assert_called_with(argv) + run_mock.assert_called_with(parse_args_mock.return_value) diff --git a/sciencebeam_gym/models/text/feature_extractor.py b/sciencebeam_gym/models/text/feature_extractor.py index b8e11ff..b0fce1c 100644 --- a/sciencebeam_gym/models/text/feature_extractor.py +++ b/sciencebeam_gym/models/text/feature_extractor.py @@ -1,119 +1,125 @@ from functools import partial from sciencebeam_gym.inference_model.annotate_using_predictions import ( - CV_TAG_SCOPE + CV_TAG_SCOPE ) from sciencebeam_gym.structured_document import ( - merge_token_tag + merge_token_tag ) NONE_TAG = 'O' + def structured_document_to_token_props(structured_document): - pages = list(structured_document.get_pages()) - for 
page_index, page in enumerate(pages): - page_bounding_box = structured_document.get_bounding_box(page) - assert page_bounding_box is not None - page_width = page_bounding_box.width - page_height = page_bounding_box.height - page_info = { - 'index': page_index, - 'count': len(pages), - 'width': page_width, - 'height': page_height - } - page_rx = 1.0 / page_width - page_ry = 1.0 / page_height - lines = list(structured_document.get_lines_of_page(page)) - for line_index, line in enumerate(lines): - line_tokens = list(structured_document.get_tokens_of_line(line)) - line_info = { - 'index': line_index, - 'count': len(lines) - } - for line_token_index, token in enumerate(line_tokens): - line_token_info = { - 'index': line_token_index, - 'count': len(line_tokens) - } - bounding_box = structured_document.get_bounding_box(token) - rel_bounding_box = ( - bounding_box.scale_by(page_rx, page_ry) - if bounding_box else None - ) - yield { - 'text': structured_document.get_text(token), - 'tag': structured_document.get_tag(token), - 'scoped_tags': { - k: v - for k, v in structured_document.get_tag_by_scope(token).items() - if k - }, - 'bounding_box': bounding_box, - 'rel_bounding_box': rel_bounding_box, - 'line_token': line_token_info, - 'page': page_info, - 'line': line_info + pages = list(structured_document.get_pages()) + for page_index, page in enumerate(pages): + page_bounding_box = structured_document.get_bounding_box(page) + assert page_bounding_box is not None + page_width = page_bounding_box.width + page_height = page_bounding_box.height + page_info = { + 'index': page_index, + 'count': len(pages), + 'width': page_width, + 'height': page_height } + page_rx = 1.0 / page_width + page_ry = 1.0 / page_height + lines = list(structured_document.get_lines_of_page(page)) + for line_index, line in enumerate(lines): + line_tokens = list(structured_document.get_tokens_of_line(line)) + line_info = { + 'index': line_index, + 'count': len(lines) + } + for line_token_index, token in 
enumerate(line_tokens): + line_token_info = { + 'index': line_token_index, + 'count': len(line_tokens) + } + bounding_box = structured_document.get_bounding_box(token) + rel_bounding_box = ( + bounding_box.scale_by(page_rx, page_ry) + if bounding_box else None + ) + yield { + 'text': structured_document.get_text(token), + 'tag': structured_document.get_tag(token), + 'scoped_tags': { + k: v + for k, v in structured_document.get_tag_by_scope(token).items() + if k + }, + 'bounding_box': bounding_box, + 'rel_bounding_box': rel_bounding_box, + 'line_token': line_token_info, + 'page': page_info, + 'line': line_info + } + def token_props_features(token_props, prefix=''): - word = token_props.get('text') or '' - word_lower = word.lower() - d = { - prefix + 'word.lower': word_lower, - prefix + 'word[:1]': word_lower[:1], - prefix + 'word[-3:]': word_lower[-3:], - prefix + 'word[-2:]': word_lower[-2:], - prefix + 'word[:1].isupper': word[:1].istitle(), - prefix + 'word.isupper': word.isupper(), - prefix + 'word.isdigit': word.isdigit() - } - for scope, tag in token_props.get('scoped_tags', {}).items(): - d[prefix + scope + '.tag'] = tag - return d + word = token_props.get('text') or '' + word_lower = word.lower() + d = { + prefix + 'word.lower': word_lower, + prefix + 'word[:1]': word_lower[:1], + prefix + 'word[-3:]': word_lower[-3:], + prefix + 'word[-2:]': word_lower[-2:], + prefix + 'word[:1].isupper': word[:1].istitle(), + prefix + 'word.isupper': word.isupper(), + prefix + 'word.isdigit': word.isdigit() + } + for scope, tag in token_props.get('scoped_tags', {}).items(): + d[prefix + scope + '.tag'] = tag + return d + def token_props_to_features(token_props_list, i): - features = token_props_features(token_props_list[i]) - if i > 0: - pass - for rel_token_index in [-2, -1, 1, 2]: - abs_token_index = i + rel_token_index - if abs_token_index < 0: - features['BOD[%d]' % rel_token_index] = True - elif abs_token_index >= len(token_props_list): - features['EOD[%d]' % 
rel_token_index] = True - else: - features.update(token_props_features( - token_props_list[abs_token_index], - str(rel_token_index) + ':' - )) - return features + features = token_props_features(token_props_list[i]) + if i > 0: + pass + for rel_token_index in [-2, -1, 1, 2]: + abs_token_index = i + rel_token_index + if abs_token_index < 0: + features['BOD[%d]' % rel_token_index] = True + elif abs_token_index >= len(token_props_list): + features['EOD[%d]' % rel_token_index] = True + else: + features.update(token_props_features( + token_props_list[abs_token_index], + str(rel_token_index) + ':' + )) + return features def token_props_list_to_features(token_props_list): - token_props_list = list(token_props_list) - return [token_props_to_features(token_props_list, i) for i in range(len(token_props_list))] + token_props_list = list(token_props_list) + return [token_props_to_features(token_props_list, i) for i in range(len(token_props_list))] + def remove_labels_from_token_props_list(token_props_list): - return [ - {k: v for k, v in token_props.items() if k != 'tag'} - for token_props in token_props_list - ] + return [ + {k: v for k, v in token_props.items() if k != 'tag'} + for token_props in token_props_list + ] + def token_props_list_to_labels(token_props_list): - return [token_props.get('tag') or NONE_TAG for token_props in token_props_list] + return [token_props.get('tag') or NONE_TAG for token_props in token_props_list] + def merge_with_cv_structured_document( - structured_document, cv_structured_document, - cv_source_tag_scope=CV_TAG_SCOPE): - - structured_document.merge_with( - cv_structured_document, - partial( - merge_token_tag, - source_scope=cv_source_tag_scope, - target_scope=CV_TAG_SCOPE + structured_document, cv_structured_document, + cv_source_tag_scope=CV_TAG_SCOPE): + + structured_document.merge_with( + cv_structured_document, + partial( + merge_token_tag, + source_scope=cv_source_tag_scope, + target_scope=CV_TAG_SCOPE + ) ) - ) - return 
structured_document + return structured_document diff --git a/sciencebeam_gym/models/text/feature_extractor_test.py b/sciencebeam_gym/models/text/feature_extractor_test.py index acaa016..4ee97b1 100644 --- a/sciencebeam_gym/models/text/feature_extractor_test.py +++ b/sciencebeam_gym/models/text/feature_extractor_test.py @@ -1,22 +1,22 @@ from sciencebeam_gym.utils.bounding_box import ( - BoundingBox + BoundingBox ) from sciencebeam_gym.structured_document import ( - SimpleStructuredDocument, - SimplePage, - SimpleLine, - SimpleToken + SimpleStructuredDocument, + SimplePage, + SimpleLine, + SimpleToken ) from sciencebeam_gym.models.text.feature_extractor import ( - structured_document_to_token_props, - token_props_list_to_features, - token_props_list_to_labels, - remove_labels_from_token_props_list, - merge_with_cv_structured_document, - NONE_TAG, - CV_TAG_SCOPE + structured_document_to_token_props, + token_props_list_to_features, + token_props_list_to_labels, + remove_labels_from_token_props_list, + merge_with_cv_structured_document, + NONE_TAG, + CV_TAG_SCOPE ) PAGE_BOUNDING_BOX = BoundingBox(0, 0, 100, 200) @@ -32,287 +32,294 @@ TAG_3 = 'tag3' SCOPE_1 = 'scope1' + class TestStructuredDocumentToTokenProps(object): - def test_should_return_empty_token_list_if_document_has_no_pages(self): - structured_document = SimpleStructuredDocument([]) - assert list(structured_document_to_token_props( - structured_document - )) == [] - - def test_should_return_empty_token_list_if_document_has_no_lines(self): - structured_document = SimpleStructuredDocument( - SimplePage([], bounding_box=PAGE_BOUNDING_BOX) - ) - assert list(structured_document_to_token_props( - structured_document - )) == [] - - def test_should_return_single_token_text(self): - structured_document = SimpleStructuredDocument( - SimplePage([SimpleLine([ - SimpleToken(TEXT_1) - ])], bounding_box=PAGE_BOUNDING_BOX) - ) - result = list(structured_document_to_token_props( - structured_document - )) - assert 
[t.get('text') for t in result] == [TEXT_1] - - def test_should_return_multiple_token_texts(self): - structured_document = SimpleStructuredDocument( - SimplePage([SimpleLine([ - SimpleToken(TEXT_1), - SimpleToken(TEXT_2), - SimpleToken(TEXT_3) - ])], bounding_box=PAGE_BOUNDING_BOX) - ) - result = list(structured_document_to_token_props( - structured_document - )) - assert [t.get('text') for t in result] == [TEXT_1, TEXT_2, TEXT_3] - - def test_should_return_tag(self): - structured_document = SimpleStructuredDocument([ - SimplePage([SimpleLine([ - SimpleToken(TEXT_1, tag=TAG_1), - SimpleToken(TEXT_2) - ])], bounding_box=PAGE_BOUNDING_BOX), - SimplePage([SimpleLine([ - SimpleToken(TEXT_3, tag=TAG_3) - ])], bounding_box=PAGE_BOUNDING_BOX) - ]) - result = list(structured_document_to_token_props( - structured_document - )) - assert [t.get('tag') for t in result] == [TAG_1, None, TAG_3] - - def test_should_return_scoped_tags(self): - structured_document = SimpleStructuredDocument([ - SimplePage([SimpleLine([ - SimpleToken(TEXT_1, tag=TAG_1), - SimpleToken(TEXT_2) - ])], bounding_box=PAGE_BOUNDING_BOX), - SimplePage([SimpleLine([ - SimpleToken(TEXT_3, tag=TAG_3, tag_scope=SCOPE_1) - ])], bounding_box=PAGE_BOUNDING_BOX) - ]) - result = list(structured_document_to_token_props( - structured_document - )) - assert [t.get('scoped_tags') for t in result] == [{}, {}, {SCOPE_1: TAG_3}] - - def test_should_return_bounding_box(self): - structured_document = SimpleStructuredDocument([ - SimplePage([SimpleLine([ - SimpleToken(TEXT_1, bounding_box=TOKEN_BOUNDING_BOX) - ])], bounding_box=PAGE_BOUNDING_BOX) - ]) - result = list(structured_document_to_token_props( - structured_document - )) - assert [t.get('bounding_box') for t in result] == [TOKEN_BOUNDING_BOX] - - def test_should_return_rel_bounding_box(self): - structured_document = SimpleStructuredDocument([ - SimplePage([SimpleLine([ - SimpleToken(TEXT_1, bounding_box=TOKEN_BOUNDING_BOX) - ])], bounding_box=PAGE_BOUNDING_BOX) - ]) - 
result = list(structured_document_to_token_props( - structured_document - )) - assert [t.get('rel_bounding_box') for t in result] == [ - TOKEN_BOUNDING_BOX.scale_by( - 1.0 / PAGE_BOUNDING_BOX.width, - 1.0 / PAGE_BOUNDING_BOX.height - ) - ] - - def test_should_return_page_index_and_page_count(self): - structured_document = SimpleStructuredDocument([ - SimplePage([SimpleLine([ - SimpleToken(TEXT_1), - SimpleToken(TEXT_2) - ])], bounding_box=PAGE_BOUNDING_BOX), - SimplePage([SimpleLine([ - SimpleToken(TEXT_3) - ])], bounding_box=PAGE_BOUNDING_BOX) - ]) - result = list(structured_document_to_token_props( - structured_document - )) - pages = [t.get('page') for t in result] - assert [p.get('index') for p in pages] == [0, 0, 1] - assert [p.get('count') for p in pages] == [2, 2, 2] - - def test_should_return_page_width_and_height(self): - structured_document = SimpleStructuredDocument([ - SimplePage([SimpleLine([ - SimpleToken(TEXT_1) - ])], bounding_box=PAGE_BOUNDING_BOX) - ]) - result = list(structured_document_to_token_props( - structured_document - )) - pages = [t.get('page') for t in result] - assert [p.get('width') for p in pages] == [PAGE_BOUNDING_BOX.width] - assert [p.get('height') for p in pages] == [PAGE_BOUNDING_BOX.height] - - def test_should_return_line_index_and_page_count(self): - structured_document = SimpleStructuredDocument([ - SimplePage([SimpleLine([ - SimpleToken(TEXT_1) - ]), SimpleLine([ - SimpleToken(TEXT_2) - ])], bounding_box=PAGE_BOUNDING_BOX), - SimplePage([SimpleLine([ - SimpleToken(TEXT_3) - ])], bounding_box=PAGE_BOUNDING_BOX) - ]) - result = list(structured_document_to_token_props( - structured_document - )) - lines = [t.get('line') for t in result] - assert [l.get('index') for l in lines] == [0, 1, 0] - assert [l.get('count') for l in lines] == [2, 2, 1] - - def test_should_return_line_token_index_and_page_count(self): - structured_document = SimpleStructuredDocument([ - SimplePage([SimpleLine([ - SimpleToken(TEXT_1), - SimpleToken(TEXT_2) 
- ])], bounding_box=PAGE_BOUNDING_BOX), - SimplePage([SimpleLine([ - SimpleToken(TEXT_3) - ])], bounding_box=PAGE_BOUNDING_BOX) - ]) - result = list(structured_document_to_token_props( - structured_document - )) - line_tokens = [t.get('line_token') for t in result] - assert [t.get('index') for t in line_tokens] == [0, 1, 0] - assert [t.get('count') for t in line_tokens] == [2, 2, 1] + def test_should_return_empty_token_list_if_document_has_no_pages(self): + structured_document = SimpleStructuredDocument([]) + assert list(structured_document_to_token_props( + structured_document + )) == [] + + def test_should_return_empty_token_list_if_document_has_no_lines(self): + structured_document = SimpleStructuredDocument( + SimplePage([], bounding_box=PAGE_BOUNDING_BOX) + ) + assert list(structured_document_to_token_props( + structured_document + )) == [] + + def test_should_return_single_token_text(self): + structured_document = SimpleStructuredDocument( + SimplePage([SimpleLine([ + SimpleToken(TEXT_1) + ])], bounding_box=PAGE_BOUNDING_BOX) + ) + result = list(structured_document_to_token_props( + structured_document + )) + assert [t.get('text') for t in result] == [TEXT_1] + + def test_should_return_multiple_token_texts(self): + structured_document = SimpleStructuredDocument( + SimplePage([SimpleLine([ + SimpleToken(TEXT_1), + SimpleToken(TEXT_2), + SimpleToken(TEXT_3) + ])], bounding_box=PAGE_BOUNDING_BOX) + ) + result = list(structured_document_to_token_props( + structured_document + )) + assert [t.get('text') for t in result] == [TEXT_1, TEXT_2, TEXT_3] + + def test_should_return_tag(self): + structured_document = SimpleStructuredDocument([ + SimplePage([SimpleLine([ + SimpleToken(TEXT_1, tag=TAG_1), + SimpleToken(TEXT_2) + ])], bounding_box=PAGE_BOUNDING_BOX), + SimplePage([SimpleLine([ + SimpleToken(TEXT_3, tag=TAG_3) + ])], bounding_box=PAGE_BOUNDING_BOX) + ]) + result = list(structured_document_to_token_props( + structured_document + )) + assert [t.get('tag') for t 
in result] == [TAG_1, None, TAG_3] + + def test_should_return_scoped_tags(self): + structured_document = SimpleStructuredDocument([ + SimplePage([SimpleLine([ + SimpleToken(TEXT_1, tag=TAG_1), + SimpleToken(TEXT_2) + ])], bounding_box=PAGE_BOUNDING_BOX), + SimplePage([SimpleLine([ + SimpleToken(TEXT_3, tag=TAG_3, tag_scope=SCOPE_1) + ])], bounding_box=PAGE_BOUNDING_BOX) + ]) + result = list(structured_document_to_token_props( + structured_document + )) + assert [t.get('scoped_tags') for t in result] == [{}, {}, {SCOPE_1: TAG_3}] + + def test_should_return_bounding_box(self): + structured_document = SimpleStructuredDocument([ + SimplePage([SimpleLine([ + SimpleToken(TEXT_1, bounding_box=TOKEN_BOUNDING_BOX) + ])], bounding_box=PAGE_BOUNDING_BOX) + ]) + result = list(structured_document_to_token_props( + structured_document + )) + assert [t.get('bounding_box') for t in result] == [TOKEN_BOUNDING_BOX] + + def test_should_return_rel_bounding_box(self): + structured_document = SimpleStructuredDocument([ + SimplePage([SimpleLine([ + SimpleToken(TEXT_1, bounding_box=TOKEN_BOUNDING_BOX) + ])], bounding_box=PAGE_BOUNDING_BOX) + ]) + result = list(structured_document_to_token_props( + structured_document + )) + assert [t.get('rel_bounding_box') for t in result] == [ + TOKEN_BOUNDING_BOX.scale_by( + 1.0 / PAGE_BOUNDING_BOX.width, + 1.0 / PAGE_BOUNDING_BOX.height + ) + ] + + def test_should_return_page_index_and_page_count(self): + structured_document = SimpleStructuredDocument([ + SimplePage([SimpleLine([ + SimpleToken(TEXT_1), + SimpleToken(TEXT_2) + ])], bounding_box=PAGE_BOUNDING_BOX), + SimplePage([SimpleLine([ + SimpleToken(TEXT_3) + ])], bounding_box=PAGE_BOUNDING_BOX) + ]) + result = list(structured_document_to_token_props( + structured_document + )) + pages = [t.get('page') for t in result] + assert [p.get('index') for p in pages] == [0, 0, 1] + assert [p.get('count') for p in pages] == [2, 2, 2] + + def test_should_return_page_width_and_height(self): + 
structured_document = SimpleStructuredDocument([ + SimplePage([SimpleLine([ + SimpleToken(TEXT_1) + ])], bounding_box=PAGE_BOUNDING_BOX) + ]) + result = list(structured_document_to_token_props( + structured_document + )) + pages = [t.get('page') for t in result] + assert [p.get('width') for p in pages] == [PAGE_BOUNDING_BOX.width] + assert [p.get('height') for p in pages] == [PAGE_BOUNDING_BOX.height] + + def test_should_return_line_index_and_page_count(self): + structured_document = SimpleStructuredDocument([ + SimplePage([SimpleLine([ + SimpleToken(TEXT_1) + ]), SimpleLine([ + SimpleToken(TEXT_2) + ])], bounding_box=PAGE_BOUNDING_BOX), + SimplePage([SimpleLine([ + SimpleToken(TEXT_3) + ])], bounding_box=PAGE_BOUNDING_BOX) + ]) + result = list(structured_document_to_token_props( + structured_document + )) + lines = [t.get('line') for t in result] + assert [l.get('index') for l in lines] == [0, 1, 0] + assert [l.get('count') for l in lines] == [2, 2, 1] + + def test_should_return_line_token_index_and_page_count(self): + structured_document = SimpleStructuredDocument([ + SimplePage([SimpleLine([ + SimpleToken(TEXT_1), + SimpleToken(TEXT_2) + ])], bounding_box=PAGE_BOUNDING_BOX), + SimplePage([SimpleLine([ + SimpleToken(TEXT_3) + ])], bounding_box=PAGE_BOUNDING_BOX) + ]) + result = list(structured_document_to_token_props( + structured_document + )) + line_tokens = [t.get('line_token') for t in result] + assert [t.get('index') for t in line_tokens] == [0, 1, 0] + assert [t.get('count') for t in line_tokens] == [2, 2, 1] + def create_token_props(text, **kwargs): - d = { - 'text': text - } - d.update(kwargs) - return d + d = { + 'text': text + } + d.update(kwargs) + return d + class TestTokenPropsListToFeatures(object): - def test_should_extract_various_word_features(self): - result = token_props_list_to_features([ - create_token_props('TestMe') - ]) - assert [x.get('word.lower') for x in result] == ['testme'] - assert [x.get('word[:1]') for x in result] == ['t'] - 
assert [x.get('word[-3:]') for x in result] == ['tme'] - assert [x.get('word[-2:]') for x in result] == ['me'] - assert [x.get('word[:1].isupper') for x in result] == [True] - assert [x.get('word.isupper') for x in result] == [False] - assert [x.get('word.isdigit') for x in result] == [False] - - def test_should_extract_scoped_tags(self): - token_props = create_token_props(TEXT_1) - token_props['scoped_tags'] = { - SCOPE_1: TAG_1 - } - result = token_props_list_to_features([token_props]) - assert [x.get('%s.tag' % SCOPE_1) for x in result] == [TAG_1] - - def test_should_add_previous_and_next_token_word_features(self): - result = token_props_list_to_features([ - create_token_props(TEXT_1), - create_token_props(TEXT_2), - create_token_props(TEXT_3) - ]) - assert [x.get('word.lower') for x in result] == [ - TEXT_1.lower(), TEXT_2.lower(), TEXT_3.lower() - ] - assert [x.get('-2:word.lower') for x in result] == [ - None, None, TEXT_1.lower() - ] - assert [x.get('-1:word.lower') for x in result] == [ - None, TEXT_1.lower(), TEXT_2.lower() - ] - assert [x.get('1:word.lower') for x in result] == [ - TEXT_2.lower(), TEXT_3.lower(), None - ] - assert [x.get('2:word.lower') for x in result] == [ - TEXT_3.lower(), None, None - ] - assert [x.get('BOD[-2]') for x in result] == [ - True, True, None - ] - assert [x.get('BOD[-1]') for x in result] == [ - True, None, None - ] - assert [x.get('EOD[1]') for x in result] == [ - None, None, True - ] - assert [x.get('EOD[2]') for x in result] == [ - None, True, True - ] - - def test_should_not_include_tag(self): - result = token_props_list_to_features([ - create_token_props(TEXT_1, tag=TAG_1) - ]) - assert [x.get('tag') for x in result] == [None] + def test_should_extract_various_word_features(self): + result = token_props_list_to_features([ + create_token_props('TestMe') + ]) + assert [x.get('word.lower') for x in result] == ['testme'] + assert [x.get('word[:1]') for x in result] == ['t'] + assert [x.get('word[-3:]') for x in result] == 
['tme'] + assert [x.get('word[-2:]') for x in result] == ['me'] + assert [x.get('word[:1].isupper') for x in result] == [True] + assert [x.get('word.isupper') for x in result] == [False] + assert [x.get('word.isdigit') for x in result] == [False] + + def test_should_extract_scoped_tags(self): + token_props = create_token_props(TEXT_1) + token_props['scoped_tags'] = { + SCOPE_1: TAG_1 + } + result = token_props_list_to_features([token_props]) + assert [x.get('%s.tag' % SCOPE_1) for x in result] == [TAG_1] + + def test_should_add_previous_and_next_token_word_features(self): + result = token_props_list_to_features([ + create_token_props(TEXT_1), + create_token_props(TEXT_2), + create_token_props(TEXT_3) + ]) + assert [x.get('word.lower') for x in result] == [ + TEXT_1.lower(), TEXT_2.lower(), TEXT_3.lower() + ] + assert [x.get('-2:word.lower') for x in result] == [ + None, None, TEXT_1.lower() + ] + assert [x.get('-1:word.lower') for x in result] == [ + None, TEXT_1.lower(), TEXT_2.lower() + ] + assert [x.get('1:word.lower') for x in result] == [ + TEXT_2.lower(), TEXT_3.lower(), None + ] + assert [x.get('2:word.lower') for x in result] == [ + TEXT_3.lower(), None, None + ] + assert [x.get('BOD[-2]') for x in result] == [ + True, True, None + ] + assert [x.get('BOD[-1]') for x in result] == [ + True, None, None + ] + assert [x.get('EOD[1]') for x in result] == [ + None, None, True + ] + assert [x.get('EOD[2]') for x in result] == [ + None, True, True + ] + + def test_should_not_include_tag(self): + result = token_props_list_to_features([ + create_token_props(TEXT_1, tag=TAG_1) + ]) + assert [x.get('tag') for x in result] == [None] + class TestTokenPropsListToLabels(object): - def test_should_extract_tag(self): - assert token_props_list_to_labels([ - create_token_props(TEXT_1, tag=TAG_1), - create_token_props(TEXT_2, tag=TAG_2) - ]) == [TAG_1, TAG_2] - - def test_should_replace_none_tag(self): - assert token_props_list_to_labels([ - create_token_props(TEXT_1, 
tag=TAG_1), - create_token_props(TEXT_2, tag=None) - ]) == [TAG_1, NONE_TAG] + def test_should_extract_tag(self): + assert token_props_list_to_labels([ + create_token_props(TEXT_1, tag=TAG_1), + create_token_props(TEXT_2, tag=TAG_2) + ]) == [TAG_1, TAG_2] + + def test_should_replace_none_tag(self): + assert token_props_list_to_labels([ + create_token_props(TEXT_1, tag=TAG_1), + create_token_props(TEXT_2, tag=None) + ]) == [TAG_1, NONE_TAG] + class TestRemoveLabelsFromTokenPropsList(object): - def test_should_remove_tag(self): - token_props_list = [ - create_token_props(TEXT_1, tag=TAG_1), - create_token_props(TEXT_2, tag=TAG_2) - ] - updated_token_props_list = remove_labels_from_token_props_list(token_props_list) - assert [x.get('tag') for x in token_props_list] == [TAG_1, TAG_2] - assert [x.get('tag') for x in updated_token_props_list] == [None, None] - assert [x.get('text') for x in updated_token_props_list] == [TEXT_1, TEXT_2] + def test_should_remove_tag(self): + token_props_list = [ + create_token_props(TEXT_1, tag=TAG_1), + create_token_props(TEXT_2, tag=TAG_2) + ] + updated_token_props_list = remove_labels_from_token_props_list(token_props_list) + assert [x.get('tag') for x in token_props_list] == [TAG_1, TAG_2] + assert [x.get('tag') for x in updated_token_props_list] == [None, None] + assert [x.get('text') for x in updated_token_props_list] == [TEXT_1, TEXT_2] + def get_all_token_tags(structured_document, scope=None): - return [x.get_tag(scope=scope) for x in structured_document.iter_all_tokens()] + return [x.get_tag(scope=scope) for x in structured_document.iter_all_tokens()] + class TestMergeWithCvStructuredDocument(object): - def test_should_merge_from_default_tag_scope(self): - structured_document = SimpleStructuredDocument(lines=[SimpleLine([ - SimpleToken(TEXT_1, tag_scope=None, tag=TAG_1) - ])]) - cv_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ - SimpleToken(TEXT_1, tag_scope=None, tag=TAG_2) - ])]) - structured_document = 
merge_with_cv_structured_document( - structured_document, cv_structured_document, - cv_source_tag_scope=None - ) - assert get_all_token_tags(structured_document) == [TAG_1] - assert get_all_token_tags(structured_document, scope=CV_TAG_SCOPE) == [TAG_2] - - def test_should_merge_from_cv_tag_scope(self): - structured_document = SimpleStructuredDocument(lines=[SimpleLine([ - SimpleToken(TEXT_1, tag_scope=None, tag=TAG_1) - ])]) - cv_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ - SimpleToken(TEXT_1, tag_scope=CV_TAG_SCOPE, tag=TAG_2) - ])]) - structured_document = merge_with_cv_structured_document( - structured_document, cv_structured_document, - cv_source_tag_scope=CV_TAG_SCOPE - ) - assert get_all_token_tags(structured_document) == [TAG_1] - assert get_all_token_tags(structured_document, scope=CV_TAG_SCOPE) == [TAG_2] + def test_should_merge_from_default_tag_scope(self): + structured_document = SimpleStructuredDocument(lines=[SimpleLine([ + SimpleToken(TEXT_1, tag_scope=None, tag=TAG_1) + ])]) + cv_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ + SimpleToken(TEXT_1, tag_scope=None, tag=TAG_2) + ])]) + structured_document = merge_with_cv_structured_document( + structured_document, cv_structured_document, + cv_source_tag_scope=None + ) + assert get_all_token_tags(structured_document) == [TAG_1] + assert get_all_token_tags(structured_document, scope=CV_TAG_SCOPE) == [TAG_2] + + def test_should_merge_from_cv_tag_scope(self): + structured_document = SimpleStructuredDocument(lines=[SimpleLine([ + SimpleToken(TEXT_1, tag_scope=None, tag=TAG_1) + ])]) + cv_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ + SimpleToken(TEXT_1, tag_scope=CV_TAG_SCOPE, tag=TAG_2) + ])]) + structured_document = merge_with_cv_structured_document( + structured_document, cv_structured_document, + cv_source_tag_scope=CV_TAG_SCOPE + ) + assert get_all_token_tags(structured_document) == [TAG_1] + assert get_all_token_tags(structured_document, 
scope=CV_TAG_SCOPE) == [TAG_2] diff --git a/sciencebeam_gym/pdf/__init__.py b/sciencebeam_gym/pdf/__init__.py index e6c3295..236de0d 100644 --- a/sciencebeam_gym/pdf/__init__.py +++ b/sciencebeam_gym/pdf/__init__.py @@ -1,2 +1,2 @@ -from sciencebeam_gym.pdf.pdf_to_lxml_wrapper import PdfToLxmlWrapper -from sciencebeam_gym.pdf.pdf_to_png import PdfToPng +from sciencebeam_gym.pdf.pdf_to_lxml_wrapper import PdfToLxmlWrapper # flake8: noqa +from sciencebeam_gym.pdf.pdf_to_png import PdfToPng # flake8: noqa diff --git a/sciencebeam_gym/pdf/pdf_to_lxml_wrapper.py b/sciencebeam_gym/pdf/pdf_to_lxml_wrapper.py index 0c6fadb..ceae411 100644 --- a/sciencebeam_gym/pdf/pdf_to_lxml_wrapper.py +++ b/sciencebeam_gym/pdf/pdf_to_lxml_wrapper.py @@ -10,137 +10,144 @@ from tempfile import NamedTemporaryFile from sciencebeam_utils.utils.io import makedirs from sciencebeam_utils.utils.zip import extract_all_with_executable_permission + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def iter_read_lines(reader): - while True: - line = reader.readline() - if not line: - break - yield line + while True: + line = reader.readline() + if not line: + break + yield line + def stream_lines_to_logger(lines, logger, prefix=''): - for line in lines: - line = line.strip() - if line: - logger.info('%s%s', prefix, line) + for line in lines: + line = line.strip() + if line: + logger.info('%s%s', prefix, line) + def download_if_not_exist(url, target_file): - if not os.path.isfile(target_file): - get_logger().info('downloading %s to %s', url, target_file) + if not os.path.isfile(target_file): + get_logger().info('downloading %s to %s', url, target_file) - makedirs(os.path.dirname(target_file), exists_ok=True) + makedirs(os.path.dirname(target_file), exists_ok=True) - temp_filename = target_file + '.part' - if os.path.isfile(temp_filename): - os.remove(temp_filename) - URLopener().retrieve(url, temp_filename) - os.rename(temp_filename, target_file) - return 
target_file + temp_filename = target_file + '.part' + if os.path.isfile(temp_filename): + os.remove(temp_filename) + URLopener().retrieve(url, temp_filename) + os.rename(temp_filename, target_file) + return target_file -def unzip_if_not_exist(zip_filename, target_directory, ignore_subdirectory=None): - if ignore_subdirectory is None: - ignore_subdirectory = os.path.basename(zip_filename) - if not os.path.isdir(target_directory): - get_logger().info('unzipping %s to %s', zip_filename, target_directory) - temp_target_directory = target_directory + '.part' - if os.path.isdir(temp_target_directory): - rmtree(temp_target_directory) - - with ZipFile(zip_filename, 'r') as zf: - extract_all_with_executable_permission(zf, temp_target_directory) - # ignore first level in the directory structure, if applicable - sub_dir = ( - os.path.join(temp_target_directory, ignore_subdirectory) - if ignore_subdirectory - else target_directory - ) - if os.path.isdir(sub_dir): - os.rename(sub_dir, target_directory) - rmtree(temp_target_directory) - else: - os.rename(temp_target_directory, target_directory) - return target_directory -class PdfToLxmlWrapper(object): - def __init__(self): - temp_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../.temp')) - self.target_directory = os.path.join(temp_dir, 'pdf2xml') - self.zip_filename = os.path.join(temp_dir, 'pdf2xml.zip') - self.zip_url = ( - 'https://storage.googleapis.com/elife-ml/artefacts/pdf2xml-linux-64.zip' - ) +def unzip_if_not_exist(zip_filename, target_directory, ignore_subdirectory=None): + if ignore_subdirectory is None: + ignore_subdirectory = os.path.basename(zip_filename) + if not os.path.isdir(target_directory): + get_logger().info('unzipping %s to %s', zip_filename, target_directory) + temp_target_directory = target_directory + '.part' + if os.path.isdir(temp_target_directory): + rmtree(temp_target_directory) + + with ZipFile(zip_filename, 'r') as zf: + extract_all_with_executable_permission(zf, 
temp_target_directory) + # ignore first level in the directory structure, if applicable + sub_dir = ( + os.path.join(temp_target_directory, ignore_subdirectory) + if ignore_subdirectory + else target_directory + ) + if os.path.isdir(sub_dir): + os.rename(sub_dir, target_directory) + rmtree(temp_target_directory) + else: + os.rename(temp_target_directory, target_directory) + return target_directory - def download_pdf2xml_zip_if_not_exist(self): - download_if_not_exist( - self.zip_url, - self.zip_filename - ) - def unzip_pdf2xml_zip_if_target_directory_does_not_exist(self): - if not os.path.isdir(self.target_directory): - self.download_pdf2xml_zip_if_not_exist() - unzip_if_not_exist(self.zip_filename, self.target_directory) - - def get_pdf2xml_executable_path(self): - self.unzip_pdf2xml_zip_if_target_directory_does_not_exist() - # use pdftoxml_server as it already handles timeouts - return os.path.join( - self.target_directory, - 'lin-64/pdftoxml' - ) +class PdfToLxmlWrapper(object): + def __init__(self): + temp_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../.temp')) + self.target_directory = os.path.join(temp_dir, 'pdf2xml') + self.zip_filename = os.path.join(temp_dir, 'pdf2xml.zip') + self.zip_url = ( + 'https://storage.googleapis.com/elife-ml/artefacts/pdf2xml-linux-64.zip' + ) + + def download_pdf2xml_zip_if_not_exist(self): + download_if_not_exist( + self.zip_url, + self.zip_filename + ) + + def unzip_pdf2xml_zip_if_target_directory_does_not_exist(self): + if not os.path.isdir(self.target_directory): + self.download_pdf2xml_zip_if_not_exist() + unzip_if_not_exist(self.zip_filename, self.target_directory) + + def get_pdf2xml_executable_path(self): + self.unzip_pdf2xml_zip_if_target_directory_does_not_exist() + # use pdftoxml_server as it already handles timeouts + return os.path.join( + self.target_directory, + 'lin-64/pdftoxml' + ) + + def process_input(self, source_data, args): + with NamedTemporaryFile() as f: + f.write(source_data) + 
f.flush() + os.fsync(f) + return self.process_file(f.name, args) + + def process_file(self, source_filename, args): + pdf2xml = self.get_pdf2xml_executable_path() + get_logger().info('processing %s using %s', source_filename, pdf2xml) + cmd = [pdf2xml] + args + [source_filename, '-'] + p = subprocess.Popen( + ['timeout', '20s'] + cmd, + stdout=PIPE, + stderr=PIPE, + stdin=None + ) + out, err = p.communicate() + return_code = p.returncode + if return_code != 0: + get_logger().warning( + 'process failed with %d, stderr=%s, stdout=%.200s...', + return_code, err, out + ) + raise RuntimeError('process failed with %d, stderr: %s' % (return_code, err)) + if len(out) == 0: + get_logger().warning( + 'process returned empty response (code %d), stderr=%s, stdout=%.500s...', + return_code, err, out + ) + raise RuntimeError( + 'process returned empty response (code %d), stderr: %s' % (return_code, err) + ) + get_logger().info( + 'received response for %s (%s bytes)', + source_filename, format(len(out), ',') + ) + return out - def process_input(self, source_data, args): - with NamedTemporaryFile() as f: - f.write(source_data) - f.flush() - os.fsync(f) - return self.process_file(f.name, args) - - def process_file(self, source_filename, args): - pdf2xml = self.get_pdf2xml_executable_path() - get_logger().info('processing %s using %s', source_filename, pdf2xml) - cmd = [pdf2xml] + args + [source_filename, '-'] - p = subprocess.Popen( - ['timeout', '20s'] + cmd, - stdout=PIPE, - stderr=PIPE, - stdin=None - ) - out, err = p.communicate() - return_code = p.returncode - if return_code != 0: - get_logger().warning( - 'process failed with %d, stderr=%s, stdout=%.200s...', - return_code, err, out - ) - raise RuntimeError('process failed with %d, stderr: %s' % (return_code, err)) - if len(out) == 0: - get_logger().warning( - 'process returned empty response (code %d), stderr=%s, stdout=%.500s...', - return_code, err, out - ) - raise RuntimeError( - 'process returned empty response (code 
%d), stderr: %s' % (return_code, err) - ) - get_logger().info( - 'received response for %s (%s bytes)', - source_filename, format(len(out), ',') - ) - return out if __name__ == '__main__': - logging.basicConfig(level='INFO') - - sample_pdf_url = 'https://rawgit.com/elifesciences/XML-mapping/master/elife-00666.pdf' - sample_pdf_filename = '.temp/elife-00666.pdf' - download_if_not_exist(sample_pdf_url, sample_pdf_filename) - with open(sample_pdf_filename, 'rb') as sample_f: - sample_pdf_contents = sample_f.read() - get_logger().info('pdf size: %s bytes', format(len(sample_pdf_contents), ',')) - process_out = PdfToLxmlWrapper().process_input( - sample_pdf_contents, - '-blocks -noImageInline -noImage -fullFontName'.split() - ) - get_logger().info('out: %.1000s...', process_out) + logging.basicConfig(level='INFO') + + sample_pdf_url = 'https://rawgit.com/elifesciences/XML-mapping/master/elife-00666.pdf' + sample_pdf_filename = '.temp/elife-00666.pdf' + download_if_not_exist(sample_pdf_url, sample_pdf_filename) + with open(sample_pdf_filename, 'rb') as sample_f: + sample_pdf_contents = sample_f.read() + get_logger().info('pdf size: %s bytes', format(len(sample_pdf_contents), ',')) + process_out = PdfToLxmlWrapper().process_input( + sample_pdf_contents, + '-blocks -noImageInline -noImage -fullFontName'.split() + ) + get_logger().info('out: %.1000s...', process_out) diff --git a/sciencebeam_gym/pdf/pdf_to_png.py b/sciencebeam_gym/pdf/pdf_to_png.py index 0bc816a..1aec457 100644 --- a/sciencebeam_gym/pdf/pdf_to_png.py +++ b/sciencebeam_gym/pdf/pdf_to_png.py @@ -4,59 +4,65 @@ from subprocess import Popen, PIPE from backports.tempfile import TemporaryDirectory + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + class PdfToPng(object): - def __init__(self, dpi=None, image_size=None, page_range=None): - self.dpi = dpi - self.image_size = image_size - self.page_range = page_range - - def iter_pdf_bytes_to_png_fp(self, pdf_bytes): - cmd 
= ['pdftoppm', '-png'] - if self.page_range: - cmd += ['-f', str(self.page_range[0]), '-l', str(self.page_range[1])] - if self.image_size: - cmd += ['-scale-to-x', str(self.image_size[0]), '-scale-to-y', str(self.image_size[1])] - elif self.dpi: - cmd += ['-r', str(self.dpi)] - cmd += ['-'] - with TemporaryDirectory() as path: - cmd += [os.path.join(path, 'page')] - - p = Popen(cmd, stdout=PIPE, stdin=PIPE, stderr=PIPE) - try: - p.stdin.write(pdf_bytes) - except IOError: - # we'll check the returncode - pass - - out, err = p.communicate() - if p.returncode != 0: - get_logger().debug( - 'process failed with return code %d: cmd=%s, out=%s, err=%s', - p.returncode, cmd, out, err - ) - raise IOError( - 'process failed with return code %d, cmd=%s, err=%s' % - (p.returncode, cmd, err) - ) - - for filename in sorted(os.listdir(path)): - with open(os.path.join(path, filename), 'rb') as f: - yield f + def __init__(self, dpi=None, image_size=None, page_range=None): + self.dpi = dpi + self.image_size = image_size + self.page_range = page_range + + def iter_pdf_bytes_to_png_fp(self, pdf_bytes): + cmd = ['pdftoppm', '-png'] + if self.page_range: + cmd += ['-f', str(self.page_range[0]), '-l', str(self.page_range[1])] + if self.image_size: + cmd += ['-scale-to-x', str(self.image_size[0]), '-scale-to-y', str(self.image_size[1])] + elif self.dpi: + cmd += ['-r', str(self.dpi)] + cmd += ['-'] + with TemporaryDirectory() as path: + cmd += [os.path.join(path, 'page')] + + p = Popen(cmd, stdout=PIPE, stdin=PIPE, stderr=PIPE) + try: + p.stdin.write(pdf_bytes) + except IOError: + # we'll check the returncode + pass + + out, err = p.communicate() + if p.returncode != 0: + get_logger().debug( + 'process failed with return code %d: cmd=%s, out=%s, err=%s', + p.returncode, cmd, out, err + ) + raise IOError( + 'process failed with return code %d, cmd=%s, err=%s' % + (p.returncode, cmd, err) + ) + + for filename in sorted(os.listdir(path)): + with open(os.path.join(path, filename), 'rb') as f: 
+ yield f + if __name__ == '__main__': - from sciencebeam_gym.pdf.pdf_to_lxml_wrapper import download_if_not_exist - - logging.basicConfig(level='INFO') - - sample_pdf_url = 'https://rawgit.com/elifesciences/XML-mapping/master/elife-00666.pdf' - sample_pdf_filename = '.temp/elife-00666.pdf' - download_if_not_exist(sample_pdf_url, sample_pdf_filename) - with open(sample_pdf_filename, 'rb') as sample_f: - sample_pdf_contents = sample_f.read() - get_logger().info('pdf size: %s bytes', format(len(sample_pdf_contents), ',')) - png_bytes = [f.read() for f in PdfToPng(dpi=30).iter_pdf_bytes_to_png_fp(sample_pdf_contents)] - get_logger().info('read: total %d (%d files)', sum(len(x) for x in png_bytes), len(png_bytes)) + from sciencebeam_gym.pdf.pdf_to_lxml_wrapper import download_if_not_exist + + logging.basicConfig(level='INFO') + + sample_pdf_url = 'https://rawgit.com/elifesciences/XML-mapping/master/elife-00666.pdf' + sample_pdf_filename = '.temp/elife-00666.pdf' + download_if_not_exist(sample_pdf_url, sample_pdf_filename) + with open(sample_pdf_filename, 'rb') as sample_f: + sample_pdf_contents = sample_f.read() + get_logger().info('pdf size: %s bytes', format(len(sample_pdf_contents), ',')) + png_bytes = [ + _f.read() + for _f in PdfToPng(dpi=30).iter_pdf_bytes_to_png_fp(sample_pdf_contents) + ] + get_logger().info('read: total %d (%d files)', sum(len(x) for x in png_bytes), len(png_bytes)) diff --git a/sciencebeam_gym/pdf/pdf_to_png_test.py b/sciencebeam_gym/pdf/pdf_to_png_test.py index 4a19024..a84ad07 100644 --- a/sciencebeam_gym/pdf/pdf_to_png_test.py +++ b/sciencebeam_gym/pdf/pdf_to_png_test.py @@ -1,10 +1,9 @@ -import logging from subprocess import PIPE from contextlib import contextmanager from mock import patch from sciencebeam_gym.pdf.pdf_to_png import ( - PdfToPng + PdfToPng ) import sciencebeam_gym.pdf.pdf_to_png as pdf_to_png @@ -17,52 +16,56 @@ ARGS_PREFIX = ['pdftoppm', '-png'] ARGS_SUFFIX = ['-', TEMP_DIR + '/page'] DEFAULT_KWARGS = dict(stdout=PIPE, 
stdin=PIPE, stderr=PIPE) + @contextmanager def patch_popen(): - with patch.object(pdf_to_png, 'Popen') as mock: - p = mock.return_value - p.communicate.return_value = (None, None) - p.returncode = 0 - yield mock + with patch.object(pdf_to_png, 'Popen') as mock: + p = mock.return_value + p.communicate.return_value = (None, None) + p.returncode = 0 + yield mock + @contextmanager def mock_temp_dir(): - with patch.object(pdf_to_png, 'TemporaryDirectory') as mock: - mock.return_value.__enter__.return_value = TEMP_DIR - with patch('os.listdir') as listdir: - listdir.return_value = [] - yield mock + with patch.object(pdf_to_png, 'TemporaryDirectory') as mock: + mock.return_value.__enter__.return_value = TEMP_DIR + with patch('os.listdir') as listdir: + listdir.return_value = [] + yield mock + class TestPdfToPng(object): - def test_should_pass_default_args_to_Popen(self): - with patch_popen() as mock: - with mock_temp_dir(): - list(PdfToPng().iter_pdf_bytes_to_png_fp(PDF_CONTENT_1)) - assert mock.called - mock.assert_called_with( - ARGS_PREFIX + ARGS_SUFFIX, **DEFAULT_KWARGS - ) + def test_should_pass_default_args_to_Popen(self): + with patch_popen() as mock: + with mock_temp_dir(): + list(PdfToPng().iter_pdf_bytes_to_png_fp(PDF_CONTENT_1)) + assert mock.called + mock.assert_called_with( + ARGS_PREFIX + ARGS_SUFFIX, **DEFAULT_KWARGS + ) - def test_should_add_page_range_to_args(self): - with patch_popen() as mock: - with mock_temp_dir(): - list(PdfToPng(page_range=(1, 3)).iter_pdf_bytes_to_png_fp(PDF_CONTENT_1)) - mock.assert_called_with( - ARGS_PREFIX + ['-f', '1', '-l', '3'] + ARGS_SUFFIX, **DEFAULT_KWARGS - ) + def test_should_add_page_range_to_args(self): + with patch_popen() as mock: + with mock_temp_dir(): + list(PdfToPng(page_range=(1, 3)).iter_pdf_bytes_to_png_fp(PDF_CONTENT_1)) + mock.assert_called_with( + ARGS_PREFIX + ['-f', '1', '-l', '3'] + ARGS_SUFFIX, **DEFAULT_KWARGS + ) - def test_should_add_image_size_to_args(self): - with patch_popen() as mock: - with 
mock_temp_dir(): - list(PdfToPng(image_size=(100, 200)).iter_pdf_bytes_to_png_fp(PDF_CONTENT_1)) - mock.assert_called_with( - ARGS_PREFIX + ['-scale-to-x', '100', '-scale-to-y', '200'] + ARGS_SUFFIX, **DEFAULT_KWARGS - ) + def test_should_add_image_size_to_args(self): + with patch_popen() as mock: + with mock_temp_dir(): + list(PdfToPng(image_size=(100, 200)).iter_pdf_bytes_to_png_fp(PDF_CONTENT_1)) + mock.assert_called_with( + ARGS_PREFIX + ['-scale-to-x', '100', '-scale-to-y', '200'] + ARGS_SUFFIX, + **DEFAULT_KWARGS + ) - def test_should_add_dpi_to_args(self): - with patch_popen() as mock: - with mock_temp_dir(): - list(PdfToPng(dpi=200).iter_pdf_bytes_to_png_fp(PDF_CONTENT_1)) - mock.assert_called_with( - ARGS_PREFIX + ['-r', '200'] + ARGS_SUFFIX, **DEFAULT_KWARGS - ) + def test_should_add_dpi_to_args(self): + with patch_popen() as mock: + with mock_temp_dir(): + list(PdfToPng(dpi=200).iter_pdf_bytes_to_png_fp(PDF_CONTENT_1)) + mock.assert_called_with( + ARGS_PREFIX + ['-r', '200'] + ARGS_SUFFIX, **DEFAULT_KWARGS + ) diff --git a/sciencebeam_gym/preprocess/annotation/annotation_evaluation.py b/sciencebeam_gym/preprocess/annotation/annotation_evaluation.py index 5ba16ce..909a2dd 100644 --- a/sciencebeam_gym/preprocess/annotation/annotation_evaluation.py +++ b/sciencebeam_gym/preprocess/annotation/annotation_evaluation.py @@ -5,54 +5,59 @@ from collections import Counter from six import iteritems from sciencebeam_utils.utils.collection import ( - flatten + flatten ) + class EvaluationFields(object): - DOCUMENT = 'document' - PAGE = 'page' - TAG = 'tag' - COUNT = 'count' + DOCUMENT = 'document' + PAGE = 'page' + TAG = 'tag' + COUNT = 'count' + DEFAULT_EVALUATION_COLUMNS = [ - EvaluationFields.DOCUMENT, - EvaluationFields.PAGE, - EvaluationFields.TAG, - EvaluationFields.COUNT + EvaluationFields.DOCUMENT, + EvaluationFields.PAGE, + EvaluationFields.TAG, + EvaluationFields.COUNT ] + def evaluate_document_page(structured_document, page): - tag_counter = Counter() - 
for line in structured_document.get_lines_of_page(page): - tag_counter.update( - structured_document.get_tag_value(token) - for token in structured_document.get_tokens_of_line(line) - ) - num_tokens = sum(tag_counter.values()) - return { - 'count': dict(tag_counter), - 'percentage': { - k: c / num_tokens - for k, c in iteritems(tag_counter) + tag_counter = Counter() + for line in structured_document.get_lines_of_page(page): + tag_counter.update( + structured_document.get_tag_value(token) + for token in structured_document.get_tokens_of_line(line) + ) + num_tokens = sum(tag_counter.values()) + return { + 'count': dict(tag_counter), + 'percentage': { + k: c / num_tokens + for k, c in iteritems(tag_counter) + } } - } + def evaluate_document_by_page(structured_document): - return [ - evaluate_document_page(structured_document, page) - for page in structured_document.get_pages() - ] + return [ + evaluate_document_page(structured_document, page) + for page in structured_document.get_pages() + ] + def to_csv_dict_rows(evaluation_result, document=None): - return flatten( - [ - { - EvaluationFields.DOCUMENT: document, - EvaluationFields.PAGE: 1 + page_index, - EvaluationFields.TAG: tag, - EvaluationFields.COUNT: count - } - for tag, count in iteritems(page_evaluation['count']) - ] - for page_index, page_evaluation in enumerate(evaluation_result) - ) + return flatten( + [ + { + EvaluationFields.DOCUMENT: document, + EvaluationFields.PAGE: 1 + page_index, + EvaluationFields.TAG: tag, + EvaluationFields.COUNT: count + } + for tag, count in iteritems(page_evaluation['count']) + ] + for page_index, page_evaluation in enumerate(evaluation_result) + ) diff --git a/sciencebeam_gym/preprocess/annotation/annotation_evaluation_test.py b/sciencebeam_gym/preprocess/annotation/annotation_evaluation_test.py index 37121b0..c51c4f6 100644 --- a/sciencebeam_gym/preprocess/annotation/annotation_evaluation_test.py +++ b/sciencebeam_gym/preprocess/annotation/annotation_evaluation_test.py @@ 
-1,56 +1,57 @@ from __future__ import division from sciencebeam_gym.structured_document import ( - SimpleStructuredDocument, - SimpleLine, - SimpleToken, - B_TAG_PREFIX, - I_TAG_PREFIX + SimpleStructuredDocument, + SimpleLine, + SimpleToken, + B_TAG_PREFIX, + I_TAG_PREFIX ) from sciencebeam_gym.preprocess.annotation.annotation_evaluation import ( - evaluate_document_by_page + evaluate_document_by_page ) TAG1 = 'tag1' + class TestEvaluateDocumentByPage(object): - def test_should_return_ratio_and_count_of_tagged_tokens(self): - tagged_tokens = [ - SimpleToken('this'), - SimpleToken('is'), - SimpleToken('tagged') - ] - not_tagged_tokens = [ - SimpleToken('this'), - SimpleToken('isn\'t') - ] - doc = SimpleStructuredDocument(lines=[SimpleLine( - tagged_tokens + not_tagged_tokens - )]) - for token in tagged_tokens: - doc.set_tag(token, TAG1) - num_total = len(tagged_tokens) + len(not_tagged_tokens) - results = evaluate_document_by_page(doc) - assert results == [{ - 'count': { - TAG1: len(tagged_tokens), - None: len(not_tagged_tokens) - }, - 'percentage': { - TAG1: len(tagged_tokens) / num_total, - None: len(not_tagged_tokens) / num_total - } - }] + def test_should_return_ratio_and_count_of_tagged_tokens(self): + tagged_tokens = [ + SimpleToken('this'), + SimpleToken('is'), + SimpleToken('tagged') + ] + not_tagged_tokens = [ + SimpleToken('this'), + SimpleToken('isn\'t') + ] + doc = SimpleStructuredDocument(lines=[SimpleLine( + tagged_tokens + not_tagged_tokens + )]) + for token in tagged_tokens: + doc.set_tag(token, TAG1) + num_total = len(tagged_tokens) + len(not_tagged_tokens) + results = evaluate_document_by_page(doc) + assert results == [{ + 'count': { + TAG1: len(tagged_tokens), + None: len(not_tagged_tokens) + }, + 'percentage': { + TAG1: len(tagged_tokens) / num_total, + None: len(not_tagged_tokens) / num_total + } + }] - def test_should_strip_prefix(self): - tagged_tokens = [ - SimpleToken('this', tag=TAG1, tag_prefix=B_TAG_PREFIX), - SimpleToken('is', tag=TAG1, 
tag_prefix=I_TAG_PREFIX), - SimpleToken('tagged', tag=TAG1, tag_prefix=I_TAG_PREFIX) - ] - doc = SimpleStructuredDocument(lines=[SimpleLine( - tagged_tokens - )]) - results = evaluate_document_by_page(doc) - assert set(results[0]['count'].keys()) == {TAG1} + def test_should_strip_prefix(self): + tagged_tokens = [ + SimpleToken('this', tag=TAG1, tag_prefix=B_TAG_PREFIX), + SimpleToken('is', tag=TAG1, tag_prefix=I_TAG_PREFIX), + SimpleToken('tagged', tag=TAG1, tag_prefix=I_TAG_PREFIX) + ] + doc = SimpleStructuredDocument(lines=[SimpleLine( + tagged_tokens + )]) + results = evaluate_document_by_page(doc) + assert set(results[0]['count'].keys()) == {TAG1} diff --git a/sciencebeam_gym/preprocess/annotation/annotator.py b/sciencebeam_gym/preprocess/annotation/annotator.py index 35ab00d..d6a526e 100644 --- a/sciencebeam_gym/preprocess/annotation/annotator.py +++ b/sciencebeam_gym/preprocess/annotation/annotator.py @@ -1,36 +1,40 @@ from abc import ABCMeta, abstractmethod from sciencebeam_gym.preprocess.annotation.find_line_number import ( - find_line_number_tokens + find_line_number_tokens ) + class AbstractAnnotator(object): - __metaclass__ = ABCMeta + __metaclass__ = ABCMeta + + @abstractmethod + def annotate(self, structured_document): + pass - @abstractmethod - def annotate(self, structured_document): - pass class LineAnnotator(AbstractAnnotator): - def __init__(self, tag='line_no'): - self.tag = tag + def __init__(self, tag='line_no'): + self.tag = tag + + def annotate(self, structured_document): + for t in find_line_number_tokens(structured_document): + structured_document.set_tag(t, self.tag) + return structured_document - def annotate(self, structured_document): - for t in find_line_number_tokens(structured_document): - structured_document.set_tag(t, self.tag) - return structured_document DEFAULT_ANNOTATORS = [ - LineAnnotator() + LineAnnotator() ] + class Annotator(object): - def __init__(self, annotators=None): - if annotators is None: - annotators = 
DEFAULT_ANNOTATORS - self.annotators = annotators - - def annotate(self, structured_document): - for annotator in self.annotators: - structured_document = annotator.annotate(structured_document) - return structured_document + def __init__(self, annotators=None): + if annotators is None: + annotators = DEFAULT_ANNOTATORS + self.annotators = annotators + + def annotate(self, structured_document): + for annotator in self.annotators: + structured_document = annotator.annotate(structured_document) + return structured_document diff --git a/sciencebeam_gym/preprocess/annotation/annotator_test.py b/sciencebeam_gym/preprocess/annotation/annotator_test.py index 93a4c7d..53494ff 100644 --- a/sciencebeam_gym/preprocess/annotation/annotator_test.py +++ b/sciencebeam_gym/preprocess/annotation/annotator_test.py @@ -1,32 +1,33 @@ from sciencebeam_gym.preprocess.annotation.annotator import ( - LineAnnotator + LineAnnotator ) from sciencebeam_gym.structured_document import ( - SimpleLine, - SimpleToken, - SimpleStructuredDocument + SimpleLine, + SimpleToken, + SimpleStructuredDocument ) line_annotator = LineAnnotator() + class TestLineAnnotator(object): - def test_x(self): - line_number_tokens = [ - SimpleToken(str(line_no), dict(x='1', y=str(line_no * 20))) - for line_no in range(1, 5) - ] - doc = SimpleStructuredDocument(lines=[ - SimpleLine([ - line_number_token, - SimpleToken('other text', dict( - x=str(float(line_number_token.get_x()) + 50), - y=line_number_token.get_y() - )) - ]) - for line_number_token in line_number_tokens - ]) - line_annotator.annotate( - doc - ) - assert [t.get_tag() for t in line_number_tokens] == ['line_no'] * len(line_number_tokens) + def test_x(self): + line_number_tokens = [ + SimpleToken(str(line_no), dict(x='1', y=str(line_no * 20))) + for line_no in range(1, 5) + ] + doc = SimpleStructuredDocument(lines=[ + SimpleLine([ + line_number_token, + SimpleToken('other text', dict( + x=str(float(line_number_token.get_x()) + 50), + 
y=line_number_token.get_y() + )) + ]) + for line_number_token in line_number_tokens + ]) + line_annotator.annotate( + doc + ) + assert [t.get_tag() for t in line_number_tokens] == ['line_no'] * len(line_number_tokens) diff --git a/sciencebeam_gym/preprocess/annotation/find_line_number.py b/sciencebeam_gym/preprocess/annotation/find_line_number.py index 6730efa..918f7b8 100644 --- a/sciencebeam_gym/preprocess/annotation/find_line_number.py +++ b/sciencebeam_gym/preprocess/annotation/find_line_number.py @@ -10,52 +10,55 @@ DEFAULT_X_THRESHOLD = 10 # (low ratio indicates numbers may be figures or table values rather than line numbers) DEFAULT_TOKEN_RATIO_THRESHOLD = 0.7 + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def _find_line_number_token_candidates(structured_document, page): - for line in structured_document.get_lines_of_page(page): - text_tokens = sorted( - structured_document.get_tokens_of_line(line), - key=lambda t: structured_document.get_x(t) - ) - if text_tokens: - token = text_tokens[0] - token_text = structured_document.get_text(token) - if token_text and token_text.isdigit(): - yield token + for line in structured_document.get_lines_of_page(page): + text_tokens = sorted( + structured_document.get_tokens_of_line(line), + key=lambda t: structured_document.get_x(t) + ) + if text_tokens: + token = text_tokens[0] + token_text = structured_document.get_text(token) + if token_text and token_text.isdigit(): + yield token + def find_line_number_tokens( - structured_document, - x_threshold=DEFAULT_X_THRESHOLD, - token_ratio_threshold=DEFAULT_TOKEN_RATIO_THRESHOLD): - - for page in structured_document.get_pages(): - line_number_candidates = list(_find_line_number_token_candidates( - structured_document, - page - )) - # we need more than two lines - if len(line_number_candidates) > 2: - c = Counter(( - round(float(structured_document.get_x(t))) - for t in line_number_candidates - )) - get_logger().debug('counter: 
%s', c) - most_common_x, most_common_count = c.most_common(1)[0] - get_logger().debug('most_common: x: %s (count: %s)', most_common_x, most_common_count) - tokens_within_range = [ - token - for token in line_number_candidates - if abs(float(token.attrib['x']) - most_common_x) < x_threshold - ] - token_within_range_ratio = len(tokens_within_range) / len(line_number_candidates) - if token_within_range_ratio < token_ratio_threshold: - get_logger().debug( - 'token within range ratio not meeting threshold: %f < %f', - token_within_range_ratio, - token_ratio_threshold - ) - else: - for token in tokens_within_range: - yield token + structured_document, + x_threshold=DEFAULT_X_THRESHOLD, + token_ratio_threshold=DEFAULT_TOKEN_RATIO_THRESHOLD): + + for page in structured_document.get_pages(): + line_number_candidates = list(_find_line_number_token_candidates( + structured_document, + page + )) + # we need more than two lines + if len(line_number_candidates) > 2: + c = Counter(( + round(float(structured_document.get_x(t))) + for t in line_number_candidates + )) + get_logger().debug('counter: %s', c) + most_common_x, most_common_count = c.most_common(1)[0] + get_logger().debug('most_common: x: %s (count: %s)', most_common_x, most_common_count) + tokens_within_range = [ + token + for token in line_number_candidates + if abs(float(token.attrib['x']) - most_common_x) < x_threshold + ] + token_within_range_ratio = len(tokens_within_range) / len(line_number_candidates) + if token_within_range_ratio < token_ratio_threshold: + get_logger().debug( + 'token within range ratio not meeting threshold: %f < %f', + token_within_range_ratio, + token_ratio_threshold + ) + else: + for token in tokens_within_range: + yield token diff --git a/sciencebeam_gym/preprocess/annotation/find_line_numbers_test.py b/sciencebeam_gym/preprocess/annotation/find_line_numbers_test.py index 367ed0b..5719466 100644 --- a/sciencebeam_gym/preprocess/annotation/find_line_numbers_test.py +++ 
b/sciencebeam_gym/preprocess/annotation/find_line_numbers_test.py @@ -1,107 +1,108 @@ from sciencebeam_utils.utils.collection import ( - flatten + flatten ) from sciencebeam_gym.structured_document import ( - SimpleStructuredDocument, - SimpleLine, - SimpleToken + SimpleStructuredDocument, + SimpleLine, + SimpleToken ) from sciencebeam_gym.preprocess.annotation.find_line_number import ( - find_line_number_tokens + find_line_number_tokens ) + class TestFindLxmlLineNumberTokens(object): - def test_should_return_empty_list_for_empty_page(self): - doc = SimpleStructuredDocument(lines=[]) - line_number_tokens = list(find_line_number_tokens(doc)) - assert len(line_number_tokens) == 0 + def test_should_return_empty_list_for_empty_page(self): + doc = SimpleStructuredDocument(lines=[]) + line_number_tokens = list(find_line_number_tokens(doc)) + assert len(line_number_tokens) == 0 - def test_should_return_line_number_tokens_appearing_first_in_line(self): - line_number_tokens = [ - SimpleToken(str(line_no), dict( - x=str(10), - y=str(line_no * 20)) - ) - for line_no in range(1, 5) - ] - doc = SimpleStructuredDocument(lines=[ - SimpleLine([ - line_number_token, - SimpleToken('other text', dict( - x=str(50), - y=line_number_token.get_y() - )) - ]) - for line_number_token in line_number_tokens - ]) - expected_line_number_tokens = line_number_tokens - actual_line_number_tokens = list(find_line_number_tokens(doc)) - assert actual_line_number_tokens == expected_line_number_tokens + def test_should_return_line_number_tokens_appearing_first_in_line(self): + line_number_tokens = [ + SimpleToken(str(line_no), dict( + x=str(10), + y=str(line_no * 20)) + ) + for line_no in range(1, 5) + ] + doc = SimpleStructuredDocument(lines=[ + SimpleLine([ + line_number_token, + SimpleToken('other text', dict( + x=str(50), + y=line_number_token.get_y() + )) + ]) + for line_number_token in line_number_tokens + ]) + expected_line_number_tokens = line_number_tokens + actual_line_number_tokens = 
list(find_line_number_tokens(doc)) + assert actual_line_number_tokens == expected_line_number_tokens - def test_should_not_return_line_number_tokens_if_not_line(self): - line_number_tokens = [ - SimpleToken(str(line_no), dict( - x=str(30), - y=str(line_no * 20)) - ) - for line_no in range(1, 5) - ] - doc = SimpleStructuredDocument(lines=[ - SimpleLine([ - line_number_token, - SimpleToken('other text', dict( - x=str(20), - y=line_number_token.get_y() - )) - ]) - for line_number_token in line_number_tokens - ]) - expected_line_number_tokens = [] - actual_line_number_tokens = list(find_line_number_tokens(doc)) - assert actual_line_number_tokens == expected_line_number_tokens + def test_should_not_return_line_number_tokens_if_not_line(self): + line_number_tokens = [ + SimpleToken(str(line_no), dict( + x=str(30), + y=str(line_no * 20)) + ) + for line_no in range(1, 5) + ] + doc = SimpleStructuredDocument(lines=[ + SimpleLine([ + line_number_token, + SimpleToken('other text', dict( + x=str(20), + y=line_number_token.get_y() + )) + ]) + for line_number_token in line_number_tokens + ]) + expected_line_number_tokens = [] + actual_line_number_tokens = list(find_line_number_tokens(doc)) + assert actual_line_number_tokens == expected_line_number_tokens - def test_should_not_return_line_number_tokens_at_unusual_position(self): - usual_line_number_x = 1 - line_number_tokens = [ - SimpleToken(str(line_no), dict( - x=str(usual_line_number_x if line_no != 2 else usual_line_number_x + 30), - y=str(line_no * 20)) - ) - for line_no in range(1, 5) - ] - doc = SimpleStructuredDocument(lines=[ - SimpleLine([ - line_number_token, - SimpleToken('other text', dict( - x=str(50), - y=line_number_token.get_y() - )) - ]) - for line_number_token in line_number_tokens - ]) - expected_line_number_tokens = [ - t for t in line_number_tokens - if int(t.get_x()) == usual_line_number_x - ] - actual_line_number_tokens = list(find_line_number_tokens(doc)) - assert actual_line_number_tokens == 
expected_line_number_tokens + def test_should_not_return_line_number_tokens_at_unusual_position(self): + usual_line_number_x = 1 + line_number_tokens = [ + SimpleToken(str(line_no), dict( + x=str(usual_line_number_x if line_no != 2 else usual_line_number_x + 30), + y=str(line_no * 20)) + ) + for line_no in range(1, 5) + ] + doc = SimpleStructuredDocument(lines=[ + SimpleLine([ + line_number_token, + SimpleToken('other text', dict( + x=str(50), + y=line_number_token.get_y() + )) + ]) + for line_number_token in line_number_tokens + ]) + expected_line_number_tokens = [ + t for t in line_number_tokens + if int(t.get_x()) == usual_line_number_x + ] + actual_line_number_tokens = list(find_line_number_tokens(doc)) + assert actual_line_number_tokens == expected_line_number_tokens - def test_should_not_return_line_number_tokens_at_unusual_position2(self): - number_tokens = flatten([ - [ - SimpleToken(str(line_no), dict( - x=str(x * 50), - y=str(line_no * 20)) - ) - for line_no in range(1, 5) - ] - for x in range(1, 3) - ]) - doc = SimpleStructuredDocument(lines=[ - SimpleLine([number_token]) - for number_token in number_tokens - ]) - actual_line_number_tokens = list(find_line_number_tokens(doc)) - assert actual_line_number_tokens == [] + def test_should_not_return_line_number_tokens_at_unusual_position2(self): + number_tokens = flatten([ + [ + SimpleToken(str(line_no), dict( + x=str(x * 50), + y=str(line_no * 20)) + ) + for line_no in range(1, 5) + ] + for x in range(1, 3) + ]) + doc = SimpleStructuredDocument(lines=[ + SimpleLine([number_token]) + for number_token in number_tokens + ]) + actual_line_number_tokens = list(find_line_number_tokens(doc)) + assert actual_line_number_tokens == [] diff --git a/sciencebeam_gym/preprocess/annotation/fuzzy_match.py b/sciencebeam_gym/preprocess/annotation/fuzzy_match.py index aa80315..65360a0 100644 --- a/sciencebeam_gym/preprocess/annotation/fuzzy_match.py +++ b/sciencebeam_gym/preprocess/annotation/fuzzy_match.py @@ -3,273 +3,287 @@ 
from __future__ import division import logging from sciencebeam_utils.utils.string import ( - LazyStr + LazyStr ) from sciencebeam_alignment.align import ( - LocalSequenceMatcher, - SimpleScoring + LocalSequenceMatcher, + SimpleScoring ) from sciencebeam_alignment.word_sequence_matcher import ( - WordSequenceMatcher + WordSequenceMatcher ) DEFAULT_SCORING = SimpleScoring( - match_score=2, - mismatch_score=-1, - gap_score=-2 + match_score=2, + mismatch_score=-1, + gap_score=-2 ) + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def len_index_range(index_range): - return index_range[1] - index_range[0] + return index_range[1] - index_range[0] + # Treat space or comma after a dot, or a dot after a letter as junk -DEFAULT_ISJUNK = lambda s, i: ( - (i > 0 and s[i - 1] == '.' and (s[i] == ' ' or s[i] == ',')) or - (i > 0 and s[i - 1].isalpha() and s[i] == '.') or - (i > 0 and s[i - 1] == s[i]) or - s[i] == '*' -) +def DEFAULT_ISJUNK(s, i): + return ( + (i > 0 and s[i - 1] == '.' and (s[i] == ' ' or s[i] == ',')) or + (i > 0 and s[i - 1].isalpha() and s[i] == '.') or + (i > 0 and s[i - 1] == s[i]) or + s[i] == '*' + ) -DOT_IS_JUNK = lambda s, i: s[i] == '.' -def remove_junk(s, isjunk=None): - if isjunk is None: - isjunk = DEFAULT_ISJUNK - result = None - start = 0 - for i in range(len(s)): - if isjunk(s, i): - if result is None: - result = [] - if i > start: - result.append(s[start:i]) - start = i + 1 - if result is None: - return s - if len(s) > start: - result.append(s[start:]) - return ''.join(result) +def DOT_IS_JUNK(s, i): + return s[i] == '.' 
-def invert_index_ranges(range_list, start, end): - i = start - for r_start, r_end in range_list: - if i >= end: - return - if i < r_start: - yield i, min(end, r_start) - i = r_end - if i < end: - yield i, end -class FuzzyMatchResult(object): - def __init__(self, a, b, matching_blocks, isjunk=None): - self.a = a - self.b = b - self.matching_blocks = matching_blocks - self.non_empty_matching_blocks = [x for x in self.matching_blocks if x[-1]] - self._match_count = None - self._a_index_range = None - self._b_index_range = None - self.isjunk = isjunk or DEFAULT_ISJUNK - - def has_match(self): - return len(self.non_empty_matching_blocks) > 0 - - def match_count(self): - if self._match_count is None: - self._match_count = sum(triple[-1] for triple in self.matching_blocks) - return self._match_count - - def ratio_to(self, size): - if not size: - return 0.0 - return self.match_count() / size - - def ratio(self): - a_match_len = len_index_range(self.a_index_range()) - b_match_len = len_index_range(self.b_index_range()) - max_len = max(a_match_len, b_match_len) - if max_len == a_match_len: - junk_match_count = self.a_non_matching_junk_count(self.a_index_range()) - else: - junk_match_count = self.b_non_matching_junk_count(self.b_index_range()) - max_len_excl_junk = max_len - junk_match_count - result = self.ratio_to(max_len_excl_junk) - if result > 1.0: - get_logger().debug( - 'ratio: ratio greater than 1.0, a_match_len=%d, b_match_len=%d, max_len=%d,' - ' junk_match_count=%d, max_len_excl_junk=%d, result=%f', - a_match_len, b_match_len, max_len, junk_match_count, max_len_excl_junk, result - ) - return result - - def count_junk_between(self, s, index_range): - if not self.isjunk: - return 0 - return sum(self.isjunk(s, i) for i in range(index_range[0], index_range[1])) - - def count_non_matching_junk(self, s, s_matching_blocks, index_range=None): - if not self.isjunk: - return 0 - if index_range is None: - index_range = (0, len(s)) - return sum( - self.count_junk_between(s, 
block_index_range) - for block_index_range in invert_index_ranges( - s_matching_blocks, index_range[0], index_range[1] - ) - ) +def remove_junk(s, isjunk=None): + if isjunk is None: + isjunk = DEFAULT_ISJUNK + result = None + start = 0 + for i in range(len(s)): + if isjunk(s, i): + if result is None: + result = [] + if i > start: + result.append(s[start:i]) + start = i + 1 + if result is None: + return s + if len(s) > start: + result.append(s[start:]) + return ''.join(result) - def a_junk_match_count(self): - return self.count_junk_between(self.a, self.a_index_range()) - - def a_junk_count(self): - return self.count_junk_between(self.a, (0, len(self.a))) - - def a_non_matching_junk_count(self, index_range=None): - return self.count_non_matching_junk(self.a, self.a_matching_blocks(), index_range) - - def b_junk_match_count(self): - return self.count_junk_between(self.b, self.b_index_range()) - - def b_junk_count(self): - return self.count_junk_between(self.b, (0, len(self.b))) - - def b_non_matching_junk_count(self, index_range=None): - return self.count_non_matching_junk(self.b, self.b_matching_blocks(), index_range) - - def a_ratio(self): - return self.ratio_to(len(self.a) - self.a_non_matching_junk_count()) - - def b_ratio(self): - return self.ratio_to(len(self.b) - self.b_non_matching_junk_count()) - - def b_gap_ratio(self): - """ - Calculates the ratio of matches vs the length of b, - but also adds any gaps / mismatches within a. 
- """ - a_index_range = self.a_index_range() - a_match_len = len_index_range(a_index_range) - match_count = self.match_count() - a_junk_match_count = self.a_non_matching_junk_count(a_index_range) - b_junk_count = self.b_non_matching_junk_count() - a_gaps = a_match_len - match_count - return self.ratio_to(len(self.b) + a_gaps - a_junk_match_count - b_junk_count) - - def a_matching_blocks(self): - return ((a, a + size) for a, _, size in self.non_empty_matching_blocks) - - def b_matching_blocks(self): - return ((b, b + size) for _, b, size in self.non_empty_matching_blocks) - - def a_start_index(self): - return self.non_empty_matching_blocks[0][0] if self.has_match() else None - - def a_end_index(self): - if not self.has_match(): - return None - ai, _, size = self.non_empty_matching_blocks[-1] - return ai + size - - def a_index_range(self): - if not self.non_empty_matching_blocks: - return (0, 0) - if not self._a_index_range: - self._a_index_range = (self.a_start_index(), self.a_end_index()) - return self._a_index_range - - def b_start_index(self): - return self.non_empty_matching_blocks[0][1] if self.has_match() else None - - def b_end_index(self): - if not self.has_match(): - return None - _, bi, size = self.non_empty_matching_blocks[-1] - return bi + size - - def b_index_range(self): - if not self.non_empty_matching_blocks: - return (0, 0) - if not self._b_index_range: - self._b_index_range = (self.b_start_index(), self.b_end_index()) - return self._b_index_range - - def a_split_at(self, index, a_pre_split=None, a_post_split=None): - if a_pre_split is None: - a_pre_split = self.a[:index] - if a_post_split is None: - a_post_split = self.a[index:] - if not self.non_empty_matching_blocks or self.a_end_index() <= index: - return ( - FuzzyMatchResult(a_pre_split, self.b, self.non_empty_matching_blocks), - FuzzyMatchResult(a_post_split, self.b, []) - ) - return ( - FuzzyMatchResult(a_pre_split, self.b, [ - (ai, bi, min(size, index - ai)) - for ai, bi, size in 
self.non_empty_matching_blocks - if ai < index - ]), - FuzzyMatchResult(a_post_split, self.b, [ - (max(0, ai - index), bi, size if ai >= index else size + ai - index) - for ai, bi, size in self.non_empty_matching_blocks - if ai + size > index - ]) - ) - def b_split_at(self, index, b_pre_split=None, b_post_split=None): - if b_pre_split is None: - b_pre_split = self.b[:index] - if b_post_split is None: - b_post_split = self.b[index:] - if not self.non_empty_matching_blocks or self.b_end_index() <= index: - return ( - FuzzyMatchResult(self.a, b_pre_split, self.non_empty_matching_blocks), - FuzzyMatchResult(self.a, b_post_split, []) - ) - result = ( - FuzzyMatchResult(self.a, b_pre_split, [ - (ai, bi, min(size, index - bi)) - for ai, bi, size in self.non_empty_matching_blocks - if bi < index - ]), - FuzzyMatchResult(self.a, b_post_split, [ - (ai, max(0, bi - index), size if bi >= index else size + bi - index) - for ai, bi, size in self.non_empty_matching_blocks - if bi + size > index - ]) - ) - return result - - def detailed_str(self): - return 'matching_blocks=[%s]' % ( - ', '.join([ - '(a[%d:+%d] = b[%d:+%d] = "%s")' % (ai, size, bi, size, self.a[ai:ai + size]) - for ai, bi, size in self.non_empty_matching_blocks - ]) - ) +def invert_index_ranges(range_list, start, end): + i = start + for r_start, r_end in range_list: + if i >= end: + return + if i < r_start: + yield i, min(end, r_start) + i = r_end + if i < end: + yield i, end - def detailed(self): - return LazyStr(self.detailed_str) - def __repr__(self): - return ( - 'FuzzyMatchResult(matching_blocks={}, match_count={}, ratio={},' - ' a_ratio={}, b_gap_ratio={})'.format( - self.matching_blocks, self.match_count(), self.ratio(), self.a_ratio(), self.b_gap_ratio() - ) - ) +class FuzzyMatchResult(object): + def __init__(self, a, b, matching_blocks, isjunk=None): + self.a = a + self.b = b + self.matching_blocks = matching_blocks + self.non_empty_matching_blocks = [x for x in self.matching_blocks if x[-1]] + 
self._match_count = None + self._a_index_range = None + self._b_index_range = None + self.isjunk = isjunk or DEFAULT_ISJUNK + + def has_match(self): + return len(self.non_empty_matching_blocks) > 0 + + def match_count(self): + if self._match_count is None: + self._match_count = sum(triple[-1] for triple in self.matching_blocks) + return self._match_count + + def ratio_to(self, size): + if not size: + return 0.0 + return self.match_count() / size + + def ratio(self): + a_match_len = len_index_range(self.a_index_range()) + b_match_len = len_index_range(self.b_index_range()) + max_len = max(a_match_len, b_match_len) + if max_len == a_match_len: + junk_match_count = self.a_non_matching_junk_count(self.a_index_range()) + else: + junk_match_count = self.b_non_matching_junk_count(self.b_index_range()) + max_len_excl_junk = max_len - junk_match_count + result = self.ratio_to(max_len_excl_junk) + if result > 1.0: + get_logger().debug( + 'ratio: ratio greater than 1.0, a_match_len=%d, b_match_len=%d, max_len=%d,' + ' junk_match_count=%d, max_len_excl_junk=%d, result=%f', + a_match_len, b_match_len, max_len, junk_match_count, max_len_excl_junk, result + ) + return result + + def count_junk_between(self, s, index_range): + if not self.isjunk: + return 0 + return sum(self.isjunk(s, i) for i in range(index_range[0], index_range[1])) + + def count_non_matching_junk(self, s, s_matching_blocks, index_range=None): + if not self.isjunk: + return 0 + if index_range is None: + index_range = (0, len(s)) + return sum( + self.count_junk_between(s, block_index_range) + for block_index_range in invert_index_ranges( + s_matching_blocks, index_range[0], index_range[1] + ) + ) + + def a_junk_match_count(self): + return self.count_junk_between(self.a, self.a_index_range()) + + def a_junk_count(self): + return self.count_junk_between(self.a, (0, len(self.a))) + + def a_non_matching_junk_count(self, index_range=None): + return self.count_non_matching_junk(self.a, self.a_matching_blocks(), 
index_range) + + def b_junk_match_count(self): + return self.count_junk_between(self.b, self.b_index_range()) + + def b_junk_count(self): + return self.count_junk_between(self.b, (0, len(self.b))) + + def b_non_matching_junk_count(self, index_range=None): + return self.count_non_matching_junk(self.b, self.b_matching_blocks(), index_range) + + def a_ratio(self): + return self.ratio_to(len(self.a) - self.a_non_matching_junk_count()) + + def b_ratio(self): + return self.ratio_to(len(self.b) - self.b_non_matching_junk_count()) + + def b_gap_ratio(self): + """ + Calculates the ratio of matches vs the length of b, + but also adds any gaps / mismatches within a. + """ + a_index_range = self.a_index_range() + a_match_len = len_index_range(a_index_range) + match_count = self.match_count() + a_junk_match_count = self.a_non_matching_junk_count(a_index_range) + b_junk_count = self.b_non_matching_junk_count() + a_gaps = a_match_len - match_count + return self.ratio_to(len(self.b) + a_gaps - a_junk_match_count - b_junk_count) + + def a_matching_blocks(self): + return ((a, a + size) for a, _, size in self.non_empty_matching_blocks) + + def b_matching_blocks(self): + return ((b, b + size) for _, b, size in self.non_empty_matching_blocks) + + def a_start_index(self): + return self.non_empty_matching_blocks[0][0] if self.has_match() else None + + def a_end_index(self): + if not self.has_match(): + return None + ai, _, size = self.non_empty_matching_blocks[-1] + return ai + size + + def a_index_range(self): + if not self.non_empty_matching_blocks: + return (0, 0) + if not self._a_index_range: + self._a_index_range = (self.a_start_index(), self.a_end_index()) + return self._a_index_range + + def b_start_index(self): + return self.non_empty_matching_blocks[0][1] if self.has_match() else None + + def b_end_index(self): + if not self.has_match(): + return None + _, bi, size = self.non_empty_matching_blocks[-1] + return bi + size + + def b_index_range(self): + if not 
self.non_empty_matching_blocks: + return (0, 0) + if not self._b_index_range: + self._b_index_range = (self.b_start_index(), self.b_end_index()) + return self._b_index_range + + def a_split_at(self, index, a_pre_split=None, a_post_split=None): + if a_pre_split is None: + a_pre_split = self.a[:index] + if a_post_split is None: + a_post_split = self.a[index:] + if not self.non_empty_matching_blocks or self.a_end_index() <= index: + return ( + FuzzyMatchResult(a_pre_split, self.b, self.non_empty_matching_blocks), + FuzzyMatchResult(a_post_split, self.b, []) + ) + return ( + FuzzyMatchResult(a_pre_split, self.b, [ + (ai, bi, min(size, index - ai)) + for ai, bi, size in self.non_empty_matching_blocks + if ai < index + ]), + FuzzyMatchResult(a_post_split, self.b, [ + (max(0, ai - index), bi, size if ai >= index else size + ai - index) + for ai, bi, size in self.non_empty_matching_blocks + if ai + size > index + ]) + ) + + def b_split_at(self, index, b_pre_split=None, b_post_split=None): + if b_pre_split is None: + b_pre_split = self.b[:index] + if b_post_split is None: + b_post_split = self.b[index:] + if not self.non_empty_matching_blocks or self.b_end_index() <= index: + return ( + FuzzyMatchResult(self.a, b_pre_split, self.non_empty_matching_blocks), + FuzzyMatchResult(self.a, b_post_split, []) + ) + result = ( + FuzzyMatchResult(self.a, b_pre_split, [ + (ai, bi, min(size, index - bi)) + for ai, bi, size in self.non_empty_matching_blocks + if bi < index + ]), + FuzzyMatchResult(self.a, b_post_split, [ + (ai, max(0, bi - index), size if bi >= index else size + bi - index) + for ai, bi, size in self.non_empty_matching_blocks + if bi + size > index + ]) + ) + return result + + def detailed_str(self): + return 'matching_blocks=[%s]' % ( + ', '.join([ + '(a[%d:+%d] = b[%d:+%d] = "%s")' % (ai, size, bi, size, self.a[ai:ai + size]) + for ai, bi, size in self.non_empty_matching_blocks + ]) + ) + + def detailed(self): + return LazyStr(self.detailed_str) + + def __repr__(self): 
+ return ( + 'FuzzyMatchResult(matching_blocks={}, match_count={}, ratio={},' + ' a_ratio={}, b_gap_ratio={})'.format( + self.matching_blocks, + self.match_count(), + self.ratio(), + self.a_ratio(), + self.b_gap_ratio() + ) + ) + def fuzzy_match(a, b, exact_word_match_threshold=5): - if min(len(a), len(b)) < exact_word_match_threshold: - sm = WordSequenceMatcher(None, a, b) - else: - sm = LocalSequenceMatcher(a=a, b=b, scoring=DEFAULT_SCORING) - matching_blocks = sm.get_matching_blocks() - return FuzzyMatchResult(a, b, matching_blocks) + if min(len(a), len(b)) < exact_word_match_threshold: + sm = WordSequenceMatcher(None, a, b) + else: + sm = LocalSequenceMatcher(a=a, b=b, scoring=DEFAULT_SCORING) + matching_blocks = sm.get_matching_blocks() + return FuzzyMatchResult(a, b, matching_blocks) diff --git a/sciencebeam_gym/preprocess/annotation/fuzzy_match_test.py b/sciencebeam_gym/preprocess/annotation/fuzzy_match_test.py index 5b52a30..9b4fe9f 100644 --- a/sciencebeam_gym/preprocess/annotation/fuzzy_match_test.py +++ b/sciencebeam_gym/preprocess/annotation/fuzzy_match_test.py @@ -3,186 +3,191 @@ from __future__ import division import logging from sciencebeam_gym.preprocess.annotation.fuzzy_match import ( - remove_junk, - invert_index_ranges, - FuzzyMatchResult, - fuzzy_match, - DOT_IS_JUNK + remove_junk, + invert_index_ranges, + FuzzyMatchResult, + fuzzy_match, + DOT_IS_JUNK ) + def setup_module(): - logging.basicConfig(level='DEBUG') + logging.basicConfig(level='DEBUG') + class TestRemoveJunk(object): - def test_should_keep_str_without_junk(self): - assert remove_junk('abc', DOT_IS_JUNK) == 'abc' + def test_should_keep_str_without_junk(self): + assert remove_junk('abc', DOT_IS_JUNK) == 'abc' + + def test_should_remove_dots_after_capitals(self): + assert remove_junk('P.O. Box', DOT_IS_JUNK) == 'PO Box' - def test_should_remove_dots_after_capitals(self): - assert remove_junk('P.O. 
Box', DOT_IS_JUNK) == 'PO Box' + def test_should_remove_asterisk_after_capitals(self): + assert remove_junk('Mr Beam*') == 'Mr Beam' - def test_should_remove_asterisk_after_capitals(self): - assert remove_junk('Mr Beam*') == 'Mr Beam' + def test_should_remove_repeating_characters(self): + assert remove_junk('Mr Beeeam') == 'Mr Beam' - def test_should_remove_repeating_characters(self): - assert remove_junk('Mr Beeeam') == 'Mr Beam' class TestInvertIndexRanges(object): - def test_should_return_empty_for_empty_range(self): - assert list(invert_index_ranges([], 0, 0)) == list([]) + def test_should_return_empty_for_empty_range(self): + assert list(invert_index_ranges([], 0, 0)) == list([]) - def test_should_return_whole_range_for_empty_range_list(self): - assert list(invert_index_ranges([], 0, 10)) == list([(0, 10)]) + def test_should_return_whole_range_for_empty_range_list(self): + assert list(invert_index_ranges([], 0, 10)) == list([(0, 10)]) - def test_should_exclude_range_in_the_beginning(self): - assert list(invert_index_ranges([(0, 3)], 0, 10)) == list([(3, 10)]) + def test_should_exclude_range_in_the_beginning(self): + assert list(invert_index_ranges([(0, 3)], 0, 10)) == list([(3, 10)]) - def test_should_exclude_range_in_the_beginning_beyond_start(self): - assert list(invert_index_ranges([(0, 13)], 10, 20)) == list([(13, 20)]) + def test_should_exclude_range_in_the_beginning_beyond_start(self): + assert list(invert_index_ranges([(0, 13)], 10, 20)) == list([(13, 20)]) - def test_should_exclude_range_in_the_middle(self): - assert list(invert_index_ranges([(4, 7)], 0, 10)) == list([(0, 4), (7, 10)]) + def test_should_exclude_range_in_the_middle(self): + assert list(invert_index_ranges([(4, 7)], 0, 10)) == list([(0, 4), (7, 10)]) - def test_should_exclude_range_at_the_end(self): - assert list(invert_index_ranges([(7, 10)], 0, 10)) == list([(0, 7)]) + def test_should_exclude_range_at_the_end(self): + assert list(invert_index_ranges([(7, 10)], 0, 10)) == list([(0, 7)]) 
+ + def test_should_exclude_range_at_the_end_beyond_end(self): + assert list(invert_index_ranges([(7, 100)], 0, 10)) == list([(0, 7)]) - def test_should_exclude_range_at_the_end_beyond_end(self): - assert list(invert_index_ranges([(7, 100)], 0, 10)) == list([(0, 7)]) class TestFuzzyMatch(object): - def test_match_count_should_be_the_same_independent_of_order(self): - s1 = 'this is a some sequence' - choice = 'this is another sequence' - fm_1 = fuzzy_match(s1, choice) - fm_2 = fuzzy_match(choice, s1) - assert fm_1.match_count() == fm_2.match_count() + def test_match_count_should_be_the_same_independent_of_order(self): + s1 = 'this is a some sequence' + choice = 'this is another sequence' + fm_1 = fuzzy_match(s1, choice) + fm_2 = fuzzy_match(choice, s1) + assert fm_1.match_count() == fm_2.match_count() + class TestFuzzyMatchResult(object): - def test_exact_match(self): - fm = FuzzyMatchResult('abc', 'abc', [(0, 0, 3)]) - assert fm.has_match() - assert fm.match_count() == 3 - assert fm.ratio() == 1.0 - assert fm.a_ratio() == 1.0 - assert fm.b_ratio() == 1.0 - assert fm.b_gap_ratio() == 1.0 - assert fm.a_index_range() == (0, 3) - assert fm.b_index_range() == (0, 3) - - def test_no_match(self): - fm = FuzzyMatchResult('abc', 'xyz', []) - assert not fm.has_match() - assert fm.match_count() == 0 - - def test_partial_match(self): - fm = FuzzyMatchResult('abx', 'aby', [(0, 0, 2)]) - assert fm.has_match() - assert fm.match_count() == 2 - assert fm.ratio() == 1.0 - assert fm.a_ratio() == 2 / 3 - assert fm.b_ratio() == 2 / 3 - assert fm.b_gap_ratio() == 2 / 3 - assert fm.a_index_range() == (0, 2) - assert fm.b_index_range() == (0, 2) - - def test_partial_match_ignore_junk_at_the_end_of_a(self): - fm = FuzzyMatchResult('ab.', 'ab', [(0, 0, 2)], isjunk=lambda s, i: s[i] == '.') - assert fm.has_match() - assert fm.match_count() == 2 - assert fm.ratio() == 1.0 - assert fm.a_ratio() == 1.0 - assert fm.b_ratio() == 1.0 - assert fm.b_gap_ratio() == 1.0 - assert fm.a_index_range() == 
(0, 2) - assert fm.b_index_range() == (0, 2) - - def test_partial_match_ignore_junk_at_the_end_of_b(self): - fm = FuzzyMatchResult('ab', 'ab.', [(0, 0, 2)], isjunk=lambda s, i: s[i] == '.') - assert fm.has_match() - assert fm.match_count() == 2 - assert fm.ratio() == 1.0 - assert fm.a_ratio() == 1.0 - assert fm.b_ratio() == 1.0 - assert fm.b_gap_ratio() == 1.0 - assert fm.a_index_range() == (0, 2) - assert fm.b_index_range() == (0, 2) - - def test_partial_match_ignore_junk_in_the_middle_of_a(self): - fm = FuzzyMatchResult('a.b', 'ab', [(0, 0, 1), (2, 1, 1)], isjunk=lambda s, i: s[i] == '.') - assert fm.has_match() - assert fm.match_count() == 2 - assert fm.ratio() == 1.0 - assert fm.a_ratio() == 1.0 - assert fm.b_ratio() == 1.0 - assert fm.b_gap_ratio() == 1.0 - assert fm.a_index_range() == (0, 3) - assert fm.b_index_range() == (0, 2) - - def test_partial_match_ignore_junk_in_the_middle_of_b(self): - fm = FuzzyMatchResult('ab', 'a.b', [(0, 0, 1), (1, 2, 1)], isjunk=lambda s, i: s[i] == '.') - assert fm.has_match() - assert fm.match_count() == 2 - assert fm.ratio() == 1.0 - assert fm.a_ratio() == 1.0 - assert fm.b_ratio() == 1.0 - assert fm.b_gap_ratio() == 1.0 - assert fm.a_index_range() == (0, 2) - assert fm.b_index_range() == (0, 3) - - def test_should_not_double_count_matching_junk(self): - fm = FuzzyMatchResult('a.b', 'a.b', [(0, 0, 3)], isjunk=lambda s, i: s[i] == '.') - assert fm.has_match() - assert fm.match_count() == 3 - assert fm.ratio() == 1.0 - assert fm.a_ratio() == 1.0 - assert fm.b_ratio() == 1.0 - assert fm.b_gap_ratio() == 1.0 - assert fm.a_index_range() == (0, 3) - assert fm.b_index_range() == (0, 3) - - def test_a_split_no_match(self): - fm = FuzzyMatchResult('abc', 'xyz', []) - fm_1, fm_2 = fm.a_split_at(2) - - assert not fm_1.has_match() - assert fm_1.a == 'ab' - assert fm_1.b == 'xyz' - - assert not fm_2.has_match() - assert fm_2.a == 'c' - assert fm_2.b == 'xyz' - - def test_b_split_no_match(self): - fm = FuzzyMatchResult('abc', 'xyz', []) - 
fm_1, fm_2 = fm.b_split_at(2) - - assert not fm_1.has_match() - assert fm_1.a == 'abc' - assert fm_1.b == 'xy' - - assert not fm_2.has_match() - assert fm_2.a == 'abc' - assert fm_2.b == 'z' - - def test_a_split_exact_match(self): - fm = FuzzyMatchResult('abc', 'abc', [(0, 0, 3)]) - fm_1, fm_2 = fm.a_split_at(2) - - assert fm_1.a == 'ab' - assert fm_1.b == 'abc' - assert fm_1.has_match() - assert fm_1.ratio() == 1.0 - assert fm_1.a_ratio() == 1.0 - assert fm_1.b_ratio() == 2 / 3 - assert fm_1.b_gap_ratio() == 2 / 3 - assert fm_1.a_index_range() == (0, 2) - assert fm_1.b_index_range() == (0, 2) - - assert fm_2.a == 'c' - assert fm_2.b == 'abc' - assert fm_2.has_match() - assert fm_2.ratio() == 1.0 - assert fm_2.a_ratio() == 1.0 - assert fm_2.b_ratio() == 1 / 3 - assert fm_2.b_gap_ratio() == 1 / 3 - assert fm_2.a_index_range() == (0, 1) - assert fm_2.b_index_range() == (0, 1) + def test_exact_match(self): + fm = FuzzyMatchResult('abc', 'abc', [(0, 0, 3)]) + assert fm.has_match() + assert fm.match_count() == 3 + assert fm.ratio() == 1.0 + assert fm.a_ratio() == 1.0 + assert fm.b_ratio() == 1.0 + assert fm.b_gap_ratio() == 1.0 + assert fm.a_index_range() == (0, 3) + assert fm.b_index_range() == (0, 3) + + def test_no_match(self): + fm = FuzzyMatchResult('abc', 'xyz', []) + assert not fm.has_match() + assert fm.match_count() == 0 + + def test_partial_match(self): + fm = FuzzyMatchResult('abx', 'aby', [(0, 0, 2)]) + assert fm.has_match() + assert fm.match_count() == 2 + assert fm.ratio() == 1.0 + assert fm.a_ratio() == 2 / 3 + assert fm.b_ratio() == 2 / 3 + assert fm.b_gap_ratio() == 2 / 3 + assert fm.a_index_range() == (0, 2) + assert fm.b_index_range() == (0, 2) + + def test_partial_match_ignore_junk_at_the_end_of_a(self): + fm = FuzzyMatchResult('ab.', 'ab', [(0, 0, 2)], isjunk=lambda s, i: s[i] == '.') + assert fm.has_match() + assert fm.match_count() == 2 + assert fm.ratio() == 1.0 + assert fm.a_ratio() == 1.0 + assert fm.b_ratio() == 1.0 + assert fm.b_gap_ratio() 
== 1.0 + assert fm.a_index_range() == (0, 2) + assert fm.b_index_range() == (0, 2) + + def test_partial_match_ignore_junk_at_the_end_of_b(self): + fm = FuzzyMatchResult('ab', 'ab.', [(0, 0, 2)], isjunk=lambda s, i: s[i] == '.') + assert fm.has_match() + assert fm.match_count() == 2 + assert fm.ratio() == 1.0 + assert fm.a_ratio() == 1.0 + assert fm.b_ratio() == 1.0 + assert fm.b_gap_ratio() == 1.0 + assert fm.a_index_range() == (0, 2) + assert fm.b_index_range() == (0, 2) + + def test_partial_match_ignore_junk_in_the_middle_of_a(self): + fm = FuzzyMatchResult('a.b', 'ab', [(0, 0, 1), (2, 1, 1)], isjunk=lambda s, i: s[i] == '.') + assert fm.has_match() + assert fm.match_count() == 2 + assert fm.ratio() == 1.0 + assert fm.a_ratio() == 1.0 + assert fm.b_ratio() == 1.0 + assert fm.b_gap_ratio() == 1.0 + assert fm.a_index_range() == (0, 3) + assert fm.b_index_range() == (0, 2) + + def test_partial_match_ignore_junk_in_the_middle_of_b(self): + fm = FuzzyMatchResult('ab', 'a.b', [(0, 0, 1), (1, 2, 1)], isjunk=lambda s, i: s[i] == '.') + assert fm.has_match() + assert fm.match_count() == 2 + assert fm.ratio() == 1.0 + assert fm.a_ratio() == 1.0 + assert fm.b_ratio() == 1.0 + assert fm.b_gap_ratio() == 1.0 + assert fm.a_index_range() == (0, 2) + assert fm.b_index_range() == (0, 3) + + def test_should_not_double_count_matching_junk(self): + fm = FuzzyMatchResult('a.b', 'a.b', [(0, 0, 3)], isjunk=lambda s, i: s[i] == '.') + assert fm.has_match() + assert fm.match_count() == 3 + assert fm.ratio() == 1.0 + assert fm.a_ratio() == 1.0 + assert fm.b_ratio() == 1.0 + assert fm.b_gap_ratio() == 1.0 + assert fm.a_index_range() == (0, 3) + assert fm.b_index_range() == (0, 3) + + def test_a_split_no_match(self): + fm = FuzzyMatchResult('abc', 'xyz', []) + fm_1, fm_2 = fm.a_split_at(2) + + assert not fm_1.has_match() + assert fm_1.a == 'ab' + assert fm_1.b == 'xyz' + + assert not fm_2.has_match() + assert fm_2.a == 'c' + assert fm_2.b == 'xyz' + + def test_b_split_no_match(self): + fm = 
FuzzyMatchResult('abc', 'xyz', []) + fm_1, fm_2 = fm.b_split_at(2) + + assert not fm_1.has_match() + assert fm_1.a == 'abc' + assert fm_1.b == 'xy' + + assert not fm_2.has_match() + assert fm_2.a == 'abc' + assert fm_2.b == 'z' + + def test_a_split_exact_match(self): + fm = FuzzyMatchResult('abc', 'abc', [(0, 0, 3)]) + fm_1, fm_2 = fm.a_split_at(2) + + assert fm_1.a == 'ab' + assert fm_1.b == 'abc' + assert fm_1.has_match() + assert fm_1.ratio() == 1.0 + assert fm_1.a_ratio() == 1.0 + assert fm_1.b_ratio() == 2 / 3 + assert fm_1.b_gap_ratio() == 2 / 3 + assert fm_1.a_index_range() == (0, 2) + assert fm_1.b_index_range() == (0, 2) + + assert fm_2.a == 'c' + assert fm_2.b == 'abc' + assert fm_2.has_match() + assert fm_2.ratio() == 1.0 + assert fm_2.a_ratio() == 1.0 + assert fm_2.b_ratio() == 1 / 3 + assert fm_2.b_gap_ratio() == 1 / 3 + assert fm_2.a_index_range() == (0, 1) + assert fm_2.b_index_range() == (0, 1) diff --git a/sciencebeam_gym/preprocess/annotation/matching_annotator.py b/sciencebeam_gym/preprocess/annotation/matching_annotator.py index 564d45c..7b5b59b 100644 --- a/sciencebeam_gym/preprocess/annotation/matching_annotator.py +++ b/sciencebeam_gym/preprocess/annotation/matching_annotator.py @@ -8,35 +8,35 @@ from itertools import tee, islice from six.moves import zip_longest from sciencebeam_utils.utils.compat import ( - python_2_unicode_compatible + python_2_unicode_compatible ) from sciencebeam_utils.utils.csv import ( - csv_delimiter_by_filename, - write_csv_row + csv_delimiter_by_filename, + write_csv_row ) from sciencebeam_utils.utils.string import ( - LazyStr + LazyStr ) from sciencebeam_utils.utils.collection import ( - iter_flatten, - extract_from_dict + iter_flatten, + extract_from_dict ) from sciencebeam_gym.structured_document import ( - B_TAG_PREFIX, - I_TAG_PREFIX + B_TAG_PREFIX, + I_TAG_PREFIX ) from sciencebeam_gym.preprocess.annotation.fuzzy_match import ( - remove_junk, - fuzzy_match + remove_junk, + fuzzy_match ) from 
sciencebeam_gym.preprocess.annotation.annotator import ( - AbstractAnnotator + AbstractAnnotator ) THIN_SPACE = u'\u2009' @@ -46,623 +46,653 @@ EM_DASH = u'\u2014' DEFAULT_SCORE_THRESHOLD = 0.9 DEFAULT_MAX_MATCH_GAP = 5 + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def normalise_str(s): - return s.lower().replace(EM_DASH, u'-').replace(EN_DASH, u'-').replace(THIN_SPACE, ' ') + return s.lower().replace(EM_DASH, u'-').replace(EN_DASH, u'-').replace(THIN_SPACE, ' ') + def normalise_str_or_list(x): - if isinstance(x, list): - return [normalise_str(s) for s in x] - else: - return normalise_str(x) + if isinstance(x, list): + return [normalise_str(s) for s in x] + else: + return normalise_str(x) + def normalise_and_remove_junk_str(s): - return remove_junk(normalise_str(s)) + return remove_junk(normalise_str(s)) + def normalise_and_remove_junk_str_or_list(x): - if isinstance(x, list): - return [normalise_and_remove_junk_str(s) for s in x] - else: - return normalise_and_remove_junk_str(x) + if isinstance(x, list): + return [normalise_and_remove_junk_str(s) for s in x] + else: + return normalise_and_remove_junk_str(x) + class SequenceWrapper(object): - def __init__(self, structured_document, tokens, str_filter_f=None): - self.structured_document = structured_document - self.str_filter_f = str_filter_f - self.tokens = tokens - self.token_str_list = [structured_document.get_text(t) or '' for t in tokens] - if str_filter_f: - self.token_str_list = [str_filter_f(s) for s in self.token_str_list] - self.tokens_as_str = ' '.join(self.token_str_list) - - def tokens_between(self, index_range): - start, end = index_range - i = 0 - for token, token_str in zip(self.tokens, self.token_str_list): - if i >= end: - break - token_end = i + len(token_str) - if token_end > start: - yield token - i = token_end + 1 - - def sub_sequence_for_tokens(self, tokens): - return SequenceWrapper(self.structured_document, tokens, 
str_filter_f=self.str_filter_f) - - def untagged_sub_sequences(self): - token_tags = [self.structured_document.get_tag(t) for t in self.tokens] - tagged_count = len([t for t in token_tags if t]) - if tagged_count == 0: - yield self - elif tagged_count == len(self.tokens): - pass - else: - untagged_tokens = [] - for token, tag in zip(self.tokens, token_tags): - if not tag: - untagged_tokens.append(token) - elif untagged_tokens: - yield self.sub_sequence_for_tokens(untagged_tokens) - untagged_tokens = [] - if untagged_tokens: - yield self.sub_sequence_for_tokens(untagged_tokens) - - def __str__(self): - return self.tokens_as_str - - def __repr__(self): - return '{}({})'.format('SequenceWrapper', self.tokens_as_str) + def __init__(self, structured_document, tokens, str_filter_f=None): + self.structured_document = structured_document + self.str_filter_f = str_filter_f + self.tokens = tokens + self.token_str_list = [structured_document.get_text(t) or '' for t in tokens] + if str_filter_f: + self.token_str_list = [str_filter_f(s) for s in self.token_str_list] + self.tokens_as_str = ' '.join(self.token_str_list) + + def tokens_between(self, index_range): + start, end = index_range + i = 0 + for token, token_str in zip(self.tokens, self.token_str_list): + if i >= end: + break + token_end = i + len(token_str) + if token_end > start: + yield token + i = token_end + 1 + + def sub_sequence_for_tokens(self, tokens): + return SequenceWrapper(self.structured_document, tokens, str_filter_f=self.str_filter_f) + + def untagged_sub_sequences(self): + token_tags = [self.structured_document.get_tag(t) for t in self.tokens] + tagged_count = len([t for t in token_tags if t]) + if tagged_count == 0: + yield self + elif tagged_count == len(self.tokens): + pass + else: + untagged_tokens = [] + for token, tag in zip(self.tokens, token_tags): + if not tag: + untagged_tokens.append(token) + elif untagged_tokens: + yield self.sub_sequence_for_tokens(untagged_tokens) + untagged_tokens = [] + if 
untagged_tokens: + yield self.sub_sequence_for_tokens(untagged_tokens) + + def __str__(self): + return self.tokens_as_str + + def __repr__(self): + return '{}({})'.format('SequenceWrapper', self.tokens_as_str) + class SequenceWrapperWithPosition(SequenceWrapper): - def __init__(self, *args, **kwargs): - position, kwargs = extract_from_dict(kwargs, 'position') - super(SequenceWrapperWithPosition, self).__init__(*args, **kwargs) - self.position = position - - def sub_sequence_for_tokens(self, tokens): - return SequenceWrapperWithPosition( - self.structured_document, tokens, - str_filter_f=self.str_filter_f, - position=self.position - ) + def __init__(self, *args, **kwargs): + position, kwargs = extract_from_dict(kwargs, 'position') + super(SequenceWrapperWithPosition, self).__init__(*args, **kwargs) + self.position = position + + def sub_sequence_for_tokens(self, tokens): + return SequenceWrapperWithPosition( + self.structured_document, tokens, + str_filter_f=self.str_filter_f, + position=self.position + ) + + def __repr__(self): + return '{}({}, {})'.format('SequenceWrapperWithPosition', self.tokens_as_str, self.position) - def __repr__(self): - return '{}({}, {})'.format('SequenceWrapperWithPosition', self.tokens_as_str, self.position) @python_2_unicode_compatible class SequenceMatch(object): - def __init__(self, seq1, seq2, index1_range, index2_range): - self.seq1 = seq1 - self.seq2 = seq2 - self.index1_range = index1_range - self.index2_range = index2_range - - def __repr__(self): - return u"SequenceMatch('{}'[{}:{}], '{}'[{}:{}])".format( - self.seq1, - self.index1_range[0], - self.index1_range[1], - self.seq2, - self.index2_range[0], - self.index2_range[1] - ) + def __init__(self, seq1, seq2, index1_range, index2_range): + self.seq1 = seq1 + self.seq2 = seq2 + self.index1_range = index1_range + self.index2_range = index2_range + + def __repr__(self): + return u"SequenceMatch('{}'[{}:{}], '{}'[{}:{}])".format( + self.seq1, + self.index1_range[0], + 
self.index1_range[1], + self.seq2, + self.index2_range[0], + self.index2_range[1] + ) + @python_2_unicode_compatible class PositionedSequenceSet(object): - def __init__(self): - self.data = set() + def __init__(self): + self.data = set() + + def add(self, sequence): + self.data.add(sequence.position) - def add(self, sequence): - self.data.add(sequence.position) + def is_close_to_any(self, sequence, max_gap): + if not max_gap or not self.data: + return True + position = sequence.position + max_distance = max_gap + 1 + for other_position in self.data: + if abs(position - other_position) <= max_distance: + return True + return False - def is_close_to_any(self, sequence, max_gap): - if not max_gap or not self.data: - return True - position = sequence.position - max_distance = max_gap + 1 - for other_position in self.data: - if abs(position - other_position) <= max_distance: - return True - return False + def __str__(self): + return str(self.data) - def __str__(self): - return str(self.data) def offset_range_by(index_range, offset): - if not offset: - return index_range - return (offset + index_range[0], offset + index_range[1]) + if not offset: + return index_range + return (offset + index_range[0], offset + index_range[1]) + def skip_whitespaces(s, start): - while start < len(s) and s[start].isspace(): - start += 1 - return start + while start < len(s) and s[start].isspace(): + start += 1 + return start + def get_fuzzy_match_filter( - b_score_threshold, min_match_count, total_match_threshold, - ratio_min_match_count, ratio_threshold): - def check(fm, fm_next=None, previous_match=False): - if ( - fm.match_count() >= ratio_min_match_count and - fm.ratio() >= ratio_threshold): - return True - return ( - fm.b_gap_ratio() >= b_score_threshold and - ( - previous_match or - ( - fm.match_count() >= min_match_count and - (fm_next is None or fm_next.ratio() >= ratio_threshold) - ) or - fm.a_ratio() >= total_match_threshold - ) - ) - return check + b_score_threshold, 
min_match_count, total_match_threshold, + ratio_min_match_count, ratio_threshold): + def check(fm, fm_next=None, previous_match=False): + if ( + fm.match_count() >= ratio_min_match_count and + fm.ratio() >= ratio_threshold): + return True + return ( + fm.b_gap_ratio() >= b_score_threshold and + ( + previous_match or + ( + fm.match_count() >= min_match_count and + (fm_next is None or fm_next.ratio() >= ratio_threshold) + ) or + fm.a_ratio() >= total_match_threshold + ) + ) + return check + DEFAULT_SEQ_FUZZY_MATCH_FILTER = get_fuzzy_match_filter( - DEFAULT_SCORE_THRESHOLD, - 5, - 0.9, - 50, - 0.9 + DEFAULT_SCORE_THRESHOLD, + 5, + 0.9, + 50, + 0.9 ) DEFAULT_CHOICE_FUZZY_MATCH_FILTER = get_fuzzy_match_filter( - DEFAULT_SCORE_THRESHOLD, - 1, - 0.9, - 100, - 0.9 + DEFAULT_SCORE_THRESHOLD, + 1, + 0.9, + 100, + 0.9 ) + class MatchDebugFields(object): - ID = 'id' - TAG = 'tag' - MATCH_MULTIPLE = 'match_multiple' - TAG_VALUE_PRE = 'tag_value_pre' - TAG_VALUE_CURRENT = 'tag_value_current' - START_INDEX = 'start_index' - NEXT_START_INDEX = 'next_start_index' - REACHED_END = 'reached_end' - CHOICE_COMBINED = 'choice_combined' - CHOICE_CURRENT = 'choice_current' - CHOICE_NEXT = 'choice_next' - ACCEPTED = 'accepted' - TAG_TO_CHOICE_MATCH = 'tag_to_choice_match' - SUB_ANNOTATION = 'sub_annotation' - FM_COMBINED = 'fm_combined' - FM_COMBINED_DETAILED = 'fm_combined_detailed' - FM_CURRENT = 'fm_current' - FM_CURRENT_DETAILED = 'fm_current_detailed' - FM_NEXT = 'fm_next' - FM_NEXT_DETAILED = 'fm_next_detailed' + ID = 'id' + TAG = 'tag' + MATCH_MULTIPLE = 'match_multiple' + TAG_VALUE_PRE = 'tag_value_pre' + TAG_VALUE_CURRENT = 'tag_value_current' + START_INDEX = 'start_index' + NEXT_START_INDEX = 'next_start_index' + REACHED_END = 'reached_end' + CHOICE_COMBINED = 'choice_combined' + CHOICE_CURRENT = 'choice_current' + CHOICE_NEXT = 'choice_next' + ACCEPTED = 'accepted' + TAG_TO_CHOICE_MATCH = 'tag_to_choice_match' + SUB_ANNOTATION = 'sub_annotation' + FM_COMBINED = 'fm_combined' + 
FM_COMBINED_DETAILED = 'fm_combined_detailed' + FM_CURRENT = 'fm_current' + FM_CURRENT_DETAILED = 'fm_current_detailed' + FM_NEXT = 'fm_next' + FM_NEXT_DETAILED = 'fm_next_detailed' + DEFAULT_MATCH_DEBUG_COLUMNS = [ - MatchDebugFields.ID, - MatchDebugFields.TAG, - MatchDebugFields.MATCH_MULTIPLE, - MatchDebugFields.TAG_VALUE_PRE, - MatchDebugFields.TAG_VALUE_CURRENT, - MatchDebugFields.START_INDEX, - MatchDebugFields.NEXT_START_INDEX, - MatchDebugFields.REACHED_END, - MatchDebugFields.CHOICE_COMBINED, - MatchDebugFields.CHOICE_CURRENT, - MatchDebugFields.CHOICE_NEXT, - MatchDebugFields.ACCEPTED, - MatchDebugFields.TAG_TO_CHOICE_MATCH, - MatchDebugFields.SUB_ANNOTATION, - MatchDebugFields.FM_COMBINED, - MatchDebugFields.FM_COMBINED_DETAILED, - MatchDebugFields.FM_CURRENT, - MatchDebugFields.FM_CURRENT_DETAILED, - MatchDebugFields.FM_NEXT, - MatchDebugFields.FM_NEXT_DETAILED + MatchDebugFields.ID, + MatchDebugFields.TAG, + MatchDebugFields.MATCH_MULTIPLE, + MatchDebugFields.TAG_VALUE_PRE, + MatchDebugFields.TAG_VALUE_CURRENT, + MatchDebugFields.START_INDEX, + MatchDebugFields.NEXT_START_INDEX, + MatchDebugFields.REACHED_END, + MatchDebugFields.CHOICE_COMBINED, + MatchDebugFields.CHOICE_CURRENT, + MatchDebugFields.CHOICE_NEXT, + MatchDebugFields.ACCEPTED, + MatchDebugFields.TAG_TO_CHOICE_MATCH, + MatchDebugFields.SUB_ANNOTATION, + MatchDebugFields.FM_COMBINED, + MatchDebugFields.FM_COMBINED_DETAILED, + MatchDebugFields.FM_CURRENT, + MatchDebugFields.FM_CURRENT_DETAILED, + MatchDebugFields.FM_NEXT, + MatchDebugFields.FM_NEXT_DETAILED ] + class TargetAnnotationMatchFinder(object): - def __init__( - self, - target_annotation, - sequence, - choices, - seq_match_filter=DEFAULT_SEQ_FUZZY_MATCH_FILTER, - choice_match_filter=DEFAULT_CHOICE_FUZZY_MATCH_FILTER, - max_gap=DEFAULT_MAX_MATCH_GAP, - matched_choices=None, - match_detail_reporter=None, - is_sub_match=False - ): - if matched_choices is None: - matched_choices = PositionedSequenceSet() - self.target_annotation = 
target_annotation - self.sequence = sequence - self.choices = choices - self.seq_match_filter = seq_match_filter - self.choice_match_filter = choice_match_filter - self.max_gap = max_gap - self.matched_choices = matched_choices - self.match_detail_reporter = match_detail_reporter - self.is_sub_match = is_sub_match - self.current_choices, self.next_choices = tee(choices, 2) - self.next_choices = islice(self.next_choices, 1, None) - - def find_next_best_matches(self): - if isinstance(self.sequence, list): - all_matches = [] - get_logger().debug('found sequence list: %s', self.sequence) - # Use tee as choices may be an iterable instead of a list - for s, sub_current_choices, sub_next_choices in zip( - self.sequence, - tee(self.current_choices, len(self.sequence)), - tee(self.next_choices, len(self.sequence)) - ): - all_matches.extend(self._do_find_next_best_matches( - s, - sub_current_choices, - sub_next_choices - )) - get_logger().debug( - 'all_matches (bonding=%s): %s', self.target_annotation.bonding, all_matches - ) - if not self.target_annotation.bonding or len(all_matches) > 1 or self.target_annotation.name == 'author': - for m in all_matches: - yield m - else: - matches = self._do_find_next_best_matches( - self.sequence, self.current_choices, self.next_choices - ) - for m in matches: - yield m - - def _do_find_next_best_matches(self, sequence, current_choices, next_choices): - target_annotation = self.target_annotation - seq_match_filter = self.seq_match_filter - choice_match_filter = self.choice_match_filter - max_gap = self.max_gap - matched_choices = self.matched_choices - - start_index = 0 - s1 = text(sequence) - too_distant_choices = [] - - is_last_match = False - previous_match = False - - for choice, next_choice in zip_longest(current_choices, next_choices): - if not matched_choices.is_close_to_any(choice, max_gap=max_gap): - too_distant_choices.append(choice) - continue - current_choice_str = text(choice) - if not current_choice_str: - return - if 
next_choice: - next_choice_str = text(next_choice) - choice_str = current_choice_str + ' ' + next_choice_str - else: - choice_str = current_choice_str - next_choice_str = None - current_start_index = start_index - get_logger().debug( - 'processing choice: tag=%s, s1[:%d]=%s, s1[%d:]=%s, current=%s, next=%s (%s), combined=%s', - target_annotation.name, - start_index, s1[:start_index], - start_index, s1[start_index:], - current_choice_str, - next_choice_str, type(next_choice_str), choice_str - ) - fm_combined, fm, fm_next = None, None, None - reached_end = None - tag_to_choice_match = self.is_sub_match or (len(s1) - start_index < len(current_choice_str)) - if not tag_to_choice_match: - fm_combined = fuzzy_match(s1, choice_str) - fm, fm_next = fm_combined.b_split_at(len(current_choice_str)) - get_logger().debug( - 'regular match: s1=%s, choice=%s, fm=%s (combined: %s)', - s1, choice, fm, fm_combined - ) - get_logger().debug('detailed match: %s', fm_combined.detailed()) - accept_match = fm.has_match() and ( - seq_match_filter(fm, fm_next, previous_match=previous_match) or - (seq_match_filter(fm_combined) and fm.b_start_index() < len(current_choice_str)) - ) - if accept_match: - previous_match = True - sm = SequenceMatch( - sequence, - choice, - fm.a_index_range(), - fm.b_index_range() - ) - matched_choices.add(choice) - get_logger().debug('found match: %s', sm) - yield sm - if fm_next.has_match(): - sm = SequenceMatch( - sequence, - next_choice, - fm_next.a_index_range(), - fm_next.b_index_range() + def __init__( + self, + target_annotation, + sequence, + choices, + seq_match_filter=DEFAULT_SEQ_FUZZY_MATCH_FILTER, + choice_match_filter=DEFAULT_CHOICE_FUZZY_MATCH_FILTER, + max_gap=DEFAULT_MAX_MATCH_GAP, + matched_choices=None, + match_detail_reporter=None, + is_sub_match=False + ): + if matched_choices is None: + matched_choices = PositionedSequenceSet() + self.target_annotation = target_annotation + self.sequence = sequence + self.choices = choices + 
self.seq_match_filter = seq_match_filter + self.choice_match_filter = choice_match_filter + self.max_gap = max_gap + self.matched_choices = matched_choices + self.match_detail_reporter = match_detail_reporter + self.is_sub_match = is_sub_match + self.current_choices, self.next_choices = tee(choices, 2) + self.next_choices = islice(self.next_choices, 1, None) + + def find_next_best_matches(self): + if isinstance(self.sequence, list): + all_matches = [] + get_logger().debug('found sequence list: %s', self.sequence) + # Use tee as choices may be an iterable instead of a list + for s, sub_current_choices, sub_next_choices in zip( + self.sequence, + tee(self.current_choices, len(self.sequence)), + tee(self.next_choices, len(self.sequence)) + ): + all_matches.extend(self._do_find_next_best_matches( + s, + sub_current_choices, + sub_next_choices + )) + get_logger().debug( + 'all_matches (bonding=%s): %s', self.target_annotation.bonding, all_matches ) - matched_choices.add(choice) - get_logger().debug('found next match: %s', sm) - yield sm - index1_end = skip_whitespaces(s1, fm_next.a_end_index()) - else: - index1_end = skip_whitespaces(s1, fm.a_end_index()) - reached_end = index1_end >= len(s1) - if reached_end: - get_logger().debug('end reached: %d >= %d', index1_end, len(s1)) - is_last_match = True - else: - start_index = index1_end - get_logger().debug('setting start index to: %d', start_index) - else: - s1_sub = s1[start_index:] - fm_combined = fuzzy_match(choice_str, s1_sub) - fm, fm_next = fm_combined.a_split_at(len(current_choice_str)) - get_logger().debug( - 'short match: s1_sub=%s, choice=%s, fm=%s (combined: %s)', - s1_sub, choice, fm, fm_combined - ) - get_logger().debug('detailed match: %s', fm_combined.detailed()) - accept_match = fm.has_match() and ( - choice_match_filter(fm, previous_match=previous_match) or - ( - choice_match_filter(fm_combined) and - fm_combined.a_start_index() < len(current_choice_str) - ) - ) - if accept_match: - sm = SequenceMatch( - 
sequence, - choice, - offset_range_by(fm.b_index_range(), start_index), - fm.a_index_range() - ) - matched_choices.add(choice) - get_logger().debug('found match: %s', sm) - yield sm - if fm_next.has_match(): - sm = SequenceMatch( - sequence, - next_choice, - offset_range_by(fm_next.b_index_range(), start_index), - fm_next.a_index_range() + if ( + not self.target_annotation.bonding or + len(all_matches) > 1 or + self.target_annotation.name == 'author' + ): + for m in all_matches: + yield m + else: + matches = self._do_find_next_best_matches( + self.sequence, self.current_choices, self.next_choices + ) + for m in matches: + yield m + + def _do_find_next_best_matches(self, sequence, current_choices, next_choices): + target_annotation = self.target_annotation + seq_match_filter = self.seq_match_filter + choice_match_filter = self.choice_match_filter + max_gap = self.max_gap + matched_choices = self.matched_choices + + start_index = 0 + s1 = text(sequence) + too_distant_choices = [] + + is_last_match = False + previous_match = False + + for choice, next_choice in zip_longest(current_choices, next_choices): + if not matched_choices.is_close_to_any(choice, max_gap=max_gap): + too_distant_choices.append(choice) + continue + current_choice_str = text(choice) + if not current_choice_str: + return + if next_choice: + next_choice_str = text(next_choice) + choice_str = current_choice_str + ' ' + next_choice_str + else: + choice_str = current_choice_str + next_choice_str = None + current_start_index = start_index + get_logger().debug( + 'processing choice: tag=%s, s1[:%d]=%s, s1[%d:]=%s' + ', current=%s, next=%s (%s), combined=%s', + target_annotation.name, + start_index, s1[:start_index], + start_index, s1[start_index:], + current_choice_str, + next_choice_str, type(next_choice_str), choice_str ) - get_logger().debug('found next match: %s', sm) - matched_choices.add(next_choice) - yield sm - is_last_match = True - if self.match_detail_reporter: - self.match_detail_reporter({ - 
MatchDebugFields.TAG: target_annotation.name, - MatchDebugFields.MATCH_MULTIPLE: target_annotation.match_multiple, - MatchDebugFields.TAG_VALUE_PRE: s1[:current_start_index], - MatchDebugFields.TAG_VALUE_CURRENT: s1[current_start_index:], - MatchDebugFields.START_INDEX: current_start_index, - MatchDebugFields.NEXT_START_INDEX: start_index, - MatchDebugFields.REACHED_END: reached_end, - MatchDebugFields.CHOICE_COMBINED: choice_str, - MatchDebugFields.CHOICE_CURRENT: current_choice_str, - MatchDebugFields.CHOICE_NEXT: next_choice_str, - MatchDebugFields.ACCEPTED: accept_match, - MatchDebugFields.TAG_TO_CHOICE_MATCH: tag_to_choice_match, - MatchDebugFields.SUB_ANNOTATION: self.is_sub_match, - MatchDebugFields.FM_COMBINED: fm_combined, - MatchDebugFields.FM_COMBINED_DETAILED: fm_combined and fm_combined.detailed_str(), - MatchDebugFields.FM_CURRENT: fm, - MatchDebugFields.FM_CURRENT_DETAILED: fm and fm.detailed_str(), - MatchDebugFields.FM_NEXT: fm_next, - MatchDebugFields.FM_NEXT_DETAILED: fm_next.detailed_str() - }) - if is_last_match: - break - if too_distant_choices: - get_logger().debug( - 'ignored too distant choices: matched=%s (ignored=%s)', - matched_choices, - LazyStr(lambda: ' '.join(str(choice.position) for choice in too_distant_choices)) - ) + fm_combined, fm, fm_next = None, None, None + reached_end = None + tag_to_choice_match = self.is_sub_match or ( + len(s1) - start_index < len(current_choice_str)) + if not tag_to_choice_match: + fm_combined = fuzzy_match(s1, choice_str) + fm, fm_next = fm_combined.b_split_at(len(current_choice_str)) + get_logger().debug( + 'regular match: s1=%s, choice=%s, fm=%s (combined: %s)', + s1, choice, fm, fm_combined + ) + get_logger().debug('detailed match: %s', fm_combined.detailed()) + accept_match = fm.has_match() and ( + seq_match_filter(fm, fm_next, previous_match=previous_match) or + (seq_match_filter(fm_combined) and fm.b_start_index() < len(current_choice_str)) + ) + if accept_match: + previous_match = True + sm = 
SequenceMatch( + sequence, + choice, + fm.a_index_range(), + fm.b_index_range() + ) + matched_choices.add(choice) + get_logger().debug('found match: %s', sm) + yield sm + if fm_next.has_match(): + sm = SequenceMatch( + sequence, + next_choice, + fm_next.a_index_range(), + fm_next.b_index_range() + ) + matched_choices.add(choice) + get_logger().debug('found next match: %s', sm) + yield sm + index1_end = skip_whitespaces(s1, fm_next.a_end_index()) + else: + index1_end = skip_whitespaces(s1, fm.a_end_index()) + reached_end = index1_end >= len(s1) + if reached_end: + get_logger().debug('end reached: %d >= %d', index1_end, len(s1)) + is_last_match = True + else: + start_index = index1_end + get_logger().debug('setting start index to: %d', start_index) + else: + s1_sub = s1[start_index:] + fm_combined = fuzzy_match(choice_str, s1_sub) + fm, fm_next = fm_combined.a_split_at(len(current_choice_str)) + get_logger().debug( + 'short match: s1_sub=%s, choice=%s, fm=%s (combined: %s)', + s1_sub, choice, fm, fm_combined + ) + get_logger().debug('detailed match: %s', fm_combined.detailed()) + accept_match = fm.has_match() and ( + choice_match_filter(fm, previous_match=previous_match) or + ( + choice_match_filter(fm_combined) and + fm_combined.a_start_index() < len(current_choice_str) + ) + ) + if accept_match: + sm = SequenceMatch( + sequence, + choice, + offset_range_by(fm.b_index_range(), start_index), + fm.a_index_range() + ) + matched_choices.add(choice) + get_logger().debug('found match: %s', sm) + yield sm + if fm_next.has_match(): + sm = SequenceMatch( + sequence, + next_choice, + offset_range_by(fm_next.b_index_range(), start_index), + fm_next.a_index_range() + ) + get_logger().debug('found next match: %s', sm) + matched_choices.add(next_choice) + yield sm + is_last_match = True + if self.match_detail_reporter: + self.match_detail_reporter({ + MatchDebugFields.TAG: target_annotation.name, + MatchDebugFields.MATCH_MULTIPLE: target_annotation.match_multiple, + 
MatchDebugFields.TAG_VALUE_PRE: s1[:current_start_index], + MatchDebugFields.TAG_VALUE_CURRENT: s1[current_start_index:], + MatchDebugFields.START_INDEX: current_start_index, + MatchDebugFields.NEXT_START_INDEX: start_index, + MatchDebugFields.REACHED_END: reached_end, + MatchDebugFields.CHOICE_COMBINED: choice_str, + MatchDebugFields.CHOICE_CURRENT: current_choice_str, + MatchDebugFields.CHOICE_NEXT: next_choice_str, + MatchDebugFields.ACCEPTED: accept_match, + MatchDebugFields.TAG_TO_CHOICE_MATCH: tag_to_choice_match, + MatchDebugFields.SUB_ANNOTATION: self.is_sub_match, + MatchDebugFields.FM_COMBINED: fm_combined, + MatchDebugFields.FM_COMBINED_DETAILED: ( + fm_combined and fm_combined.detailed_str() + ), + MatchDebugFields.FM_CURRENT: fm, + MatchDebugFields.FM_CURRENT_DETAILED: fm and fm.detailed_str(), + MatchDebugFields.FM_NEXT: fm_next, + MatchDebugFields.FM_NEXT_DETAILED: fm_next.detailed_str() + }) + if is_last_match: + break + if too_distant_choices: + get_logger().debug( + 'ignored too distant choices: matched=%s (ignored=%s)', + matched_choices, + LazyStr(lambda: ' '.join(str(choice.position) for choice in too_distant_choices)) + ) + class CsvMatchDetailReporter(object): - def __init__(self, fp, filename=None, fields=None): - self.fp = fp - self.fields = fields or DEFAULT_MATCH_DEBUG_COLUMNS - self.writer = csv.writer( - fp, - delimiter=csv_delimiter_by_filename(filename) - ) - self.writer.writerow(self.fields) - self.id = 1 + def __init__(self, fp, filename=None, fields=None): + self.fp = fp + self.fields = fields or DEFAULT_MATCH_DEBUG_COLUMNS + self.writer = csv.writer( + fp, + delimiter=csv_delimiter_by_filename(filename) + ) + self.writer.writerow(self.fields) + self.id = 1 - def __call__(self, row): - get_logger().debug('logging debug match id %d', self.id) - write_csv_row(self.writer, [ - self.id if k == MatchDebugFields.ID else row.get(k) - for k in self.fields - ]) - self.id += 1 + def __call__(self, row): + get_logger().debug('logging debug 
match id %d', self.id) + write_csv_row(self.writer, [ + self.id if k == MatchDebugFields.ID else row.get(k) + for k in self.fields + ]) + self.id += 1 + + def close(self): + self.fp.close() - def close(self): - self.fp.close() def sorted_matches_by_position(matches): - return sorted( - matches, - key=lambda m: (m.seq2.position, m.index2_range) - ) + return sorted( + matches, + key=lambda m: (m.seq2.position, m.index2_range) + ) + def matches_position_range(matches): - positions = [m.seq2.position for m in matches] - return min(positions), max(positions) + positions = [m.seq2.position for m in matches] + return min(positions), max(positions) + def distance_between_matches(matches1, matches2): - matches1_start, matches1_end = matches_position_range(matches1) - matches2_start, matches2_end = matches_position_range(matches2) - return min( - abs(matches2_start - matches1_end), - abs(matches1_start - matches2_end) - ) + matches1_start, matches1_end = matches_position_range(matches1) + matches2_start, matches2_end = matches_position_range(matches2) + return min( + abs(matches2_start - matches1_end), + abs(matches1_start - matches2_end) + ) + def _apply_sub_annotations( - target_annotation, structured_document, matching_tokens, - match_detail_reporter, use_tag_begin_prefix): - - seq = SequenceWrapperWithPosition( - structured_document, matching_tokens, normalise_str, - position=0 - ) - matched_choices = PositionedSequenceSet() - for sub_annotation in target_annotation.sub_annotations: - sub_tag = sub_annotation.name - match_finder = TargetAnnotationMatchFinder( - sub_annotation, - normalise_str_or_list(sub_annotation.value), - [seq], - matched_choices=matched_choices, - match_detail_reporter=match_detail_reporter, - is_sub_match=True - ) - matches = match_finder.find_next_best_matches() - matches = list(matches) - get_logger().info('sub annotation matches: %s', matches) - first_token = True - for m in matches: - choice = m.seq2 - matching_sub_tokens = 
list(choice.tokens_between(m.index2_range)) - get_logger().debug( - 'matching_sub_tokens: %s %s', - [structured_document.get_text(token) for token in matching_tokens], - m.index2_range - ) - for token in matching_sub_tokens: - tag_prefix = None - if use_tag_begin_prefix: - tag_prefix = B_TAG_PREFIX if first_token else I_TAG_PREFIX - structured_document.set_sub_tag_with_prefix(token, sub_tag, prefix=tag_prefix) - first_token = False + target_annotation, structured_document, matching_tokens, + match_detail_reporter, use_tag_begin_prefix): -def _apply_annotations_to_matches( - target_annotation, - structured_document, - matches, - match_detail_reporter, - use_tag_begin_prefix): - - first_token = True - all_matching_tokens = [] - for m in matches: - choice = m.seq2 - matching_tokens = list(choice.tokens_between(m.index2_range)) - get_logger().debug( - 'matching_tokens: %s %s', - [structured_document.get_text(token) for token in matching_tokens], - m.index2_range + seq = SequenceWrapperWithPosition( + structured_document, matching_tokens, normalise_str, + position=0 ) - for token in matching_tokens: - if not structured_document.get_tag(token): - tag_prefix = None - if use_tag_begin_prefix: - tag_prefix = B_TAG_PREFIX if first_token else I_TAG_PREFIX - structured_document.set_tag_with_prefix( - token, - target_annotation.name, - prefix=tag_prefix + matched_choices = PositionedSequenceSet() + for sub_annotation in target_annotation.sub_annotations: + sub_tag = sub_annotation.name + match_finder = TargetAnnotationMatchFinder( + sub_annotation, + normalise_str_or_list(sub_annotation.value), + [seq], + matched_choices=matched_choices, + match_detail_reporter=match_detail_reporter, + is_sub_match=True ) - first_token = False - all_matching_tokens.append(token) - if target_annotation.sub_annotations: - _apply_sub_annotations( - target_annotation, structured_document, all_matching_tokens, - match_detail_reporter, use_tag_begin_prefix - ) + matches = 
match_finder.find_next_best_matches() + matches = list(matches) + get_logger().info('sub annotation matches: %s', matches) + first_token = True + for m in matches: + choice = m.seq2 + matching_sub_tokens = list(choice.tokens_between(m.index2_range)) + get_logger().debug( + 'matching_sub_tokens: %s %s', + [structured_document.get_text(token) for token in matching_tokens], + m.index2_range + ) + for token in matching_sub_tokens: + tag_prefix = None + if use_tag_begin_prefix: + tag_prefix = B_TAG_PREFIX if first_token else I_TAG_PREFIX + structured_document.set_sub_tag_with_prefix(token, sub_tag, prefix=tag_prefix) + first_token = False -class MatchingAnnotator(AbstractAnnotator): - def __init__( - self, target_annotations, match_detail_reporter=None, - use_tag_begin_prefix=False): - - self.target_annotations = target_annotations - self.match_detail_reporter = match_detail_reporter - self.use_tag_begin_prefix = use_tag_begin_prefix - - def annotate(self, structured_document): - pending_sequences = [] - for page in structured_document.get_pages(): - for line in structured_document.get_lines_of_page(page): - tokens = [ - token - for token in structured_document.get_tokens_of_line(line) - if not structured_document.get_tag(token) - ] - if tokens: - get_logger().debug( - 'tokens without tag: %s', - [structured_document.get_text(token) for token in tokens] - ) - pending_sequences.append(SequenceWrapperWithPosition( - structured_document, - tokens, - normalise_and_remove_junk_str, - position=len(pending_sequences) - )) - - conditional_match = None - - matched_choices_map = dict() - for target_annotation in self.target_annotations: - get_logger().debug('target annotation: %s', target_annotation) - target_value = normalise_and_remove_junk_str_or_list(target_annotation.value) - untagged_pending_sequences = iter_flatten( - seq.untagged_sub_sequences() for seq in pending_sequences - ) - if target_annotation.bonding: - matched_choices = matched_choices_map.setdefault( - 
target_annotation.name, - PositionedSequenceSet() - ) - else: - matched_choices = PositionedSequenceSet() - match_finder = TargetAnnotationMatchFinder( +def _apply_annotations_to_matches( target_annotation, - target_value, - untagged_pending_sequences, - matched_choices=matched_choices, - match_detail_reporter=self.match_detail_reporter - ) - item_index = 0 - while item_index == 0 or target_annotation.match_multiple: - get_logger().info('calling find_next_best_matches') - matches = sorted_matches_by_position( - match_finder.find_next_best_matches() + structured_document, + matches, + match_detail_reporter, + use_tag_begin_prefix): + + first_token = True + all_matching_tokens = [] + for m in matches: + choice = m.seq2 + matching_tokens = list(choice.tokens_between(m.index2_range)) + get_logger().debug( + 'matching_tokens: %s %s', + [structured_document.get_text(token) for token in matching_tokens], + m.index2_range ) - if not matches: - conditional_match = None - break - get_logger().info('matches: %s', matches) - if ( - conditional_match and - distance_between_matches(matches, conditional_match['matches']) <= 1 - ): - _apply_annotations_to_matches( - conditional_match['target_annotation'], - structured_document, - conditional_match['matches'], - self.match_detail_reporter, self.use_tag_begin_prefix - ) - if target_annotation.require_next: - conditional_match = dict( - target_annotation=target_annotation, - matches=matches - ) - else: - _apply_annotations_to_matches( - target_annotation, structured_document, matches, - self.match_detail_reporter, self.use_tag_begin_prefix - ) - item_index += 1 - return structured_document + for token in matching_tokens: + if not structured_document.get_tag(token): + tag_prefix = None + if use_tag_begin_prefix: + tag_prefix = B_TAG_PREFIX if first_token else I_TAG_PREFIX + structured_document.set_tag_with_prefix( + token, + target_annotation.name, + prefix=tag_prefix + ) + first_token = False + all_matching_tokens.append(token) + if 
target_annotation.sub_annotations: + _apply_sub_annotations( + target_annotation, structured_document, all_matching_tokens, + match_detail_reporter, use_tag_begin_prefix + ) + + +class MatchingAnnotator(AbstractAnnotator): + def __init__( + self, target_annotations, match_detail_reporter=None, + use_tag_begin_prefix=False): + + self.target_annotations = target_annotations + self.match_detail_reporter = match_detail_reporter + self.use_tag_begin_prefix = use_tag_begin_prefix + + def annotate(self, structured_document): + pending_sequences = [] + for page in structured_document.get_pages(): + for line in structured_document.get_lines_of_page(page): + tokens = [ + token + for token in structured_document.get_tokens_of_line(line) + if not structured_document.get_tag(token) + ] + if tokens: + get_logger().debug( + 'tokens without tag: %s', + [structured_document.get_text(token) for token in tokens] + ) + pending_sequences.append(SequenceWrapperWithPosition( + structured_document, + tokens, + normalise_and_remove_junk_str, + position=len(pending_sequences) + )) + + conditional_match = None + + matched_choices_map = dict() + for target_annotation in self.target_annotations: + get_logger().debug('target annotation: %s', target_annotation) + target_value = normalise_and_remove_junk_str_or_list(target_annotation.value) + untagged_pending_sequences = iter_flatten( + seq.untagged_sub_sequences() for seq in pending_sequences + ) + if target_annotation.bonding: + matched_choices = matched_choices_map.setdefault( + target_annotation.name, + PositionedSequenceSet() + ) + else: + matched_choices = PositionedSequenceSet() + match_finder = TargetAnnotationMatchFinder( + target_annotation, + target_value, + untagged_pending_sequences, + matched_choices=matched_choices, + match_detail_reporter=self.match_detail_reporter + ) + item_index = 0 + while item_index == 0 or target_annotation.match_multiple: + get_logger().info('calling find_next_best_matches') + matches = 
sorted_matches_by_position( + match_finder.find_next_best_matches() + ) + if not matches: + conditional_match = None + break + get_logger().info('matches: %s', matches) + if ( + conditional_match and + distance_between_matches(matches, conditional_match['matches']) <= 1 + ): + _apply_annotations_to_matches( + conditional_match['target_annotation'], + structured_document, + conditional_match['matches'], + self.match_detail_reporter, self.use_tag_begin_prefix + ) + if target_annotation.require_next: + conditional_match = dict( + target_annotation=target_annotation, + matches=matches + ) + else: + _apply_annotations_to_matches( + target_annotation, structured_document, matches, + self.match_detail_reporter, self.use_tag_begin_prefix + ) + item_index += 1 + return structured_document diff --git a/sciencebeam_gym/preprocess/annotation/matching_annotator_test.py b/sciencebeam_gym/preprocess/annotation/matching_annotator_test.py index 76fb765..fafac01 100644 --- a/sciencebeam_gym/preprocess/annotation/matching_annotator_test.py +++ b/sciencebeam_gym/preprocess/annotation/matching_annotator_test.py @@ -3,26 +3,26 @@ from __future__ import division import logging from sciencebeam_utils.utils.collection import ( - flatten + flatten ) from sciencebeam_gym.structured_document import ( - SimpleStructuredDocument, - SimpleLine, - SimpleToken + SimpleStructuredDocument, + SimpleLine, + SimpleToken ) from sciencebeam_gym.preprocess.annotation.target_annotation import ( - TargetAnnotation + TargetAnnotation ) from sciencebeam_gym.preprocess.annotation.matching_annotator import ( - normalise_str, - MatchingAnnotator, - SequenceWrapper, - THIN_SPACE, - EN_DASH, - EM_DASH + normalise_str, + MatchingAnnotator, + SequenceWrapper, + THIN_SPACE, + EN_DASH, + EM_DASH ) TAG1 = 'tag1' @@ -39,877 +39,892 @@ SOME_VALUE_2 = 'some value2' SOME_LONGER_VALUE = 'some longer value1' SOME_SHORTER_VALUE = 'value1' + def setup_module(): - logging.basicConfig(level='DEBUG') + 
logging.basicConfig(level='DEBUG') + def _get_tags_of_tokens(tokens): - return [t.get_tag() for t in tokens] + return [t.get_tag() for t in tokens] + def _copy_tokens(tokens): - return [SimpleToken(t.text) for t in tokens] + return [SimpleToken(t.text) for t in tokens] + def _tokens_for_text(text): - return [SimpleToken(s) for s in text.split(' ')] + return [SimpleToken(s) for s in text.split(' ')] + def _tokens_to_text(tokens): - return ' '.join([t.text for t in tokens]) + return ' '.join([t.text for t in tokens]) + def _tokens_for_text_lines(text_lines): - return [_tokens_for_text(line) for line in text_lines] + return [_tokens_for_text(line) for line in text_lines] + def _lines_for_tokens(tokens_by_line): - return [SimpleLine(tokens) for tokens in tokens_by_line] + return [SimpleLine(tokens) for tokens in tokens_by_line] + def _document_for_tokens(tokens_by_line): - return SimpleStructuredDocument(lines=_lines_for_tokens(tokens_by_line)) + return SimpleStructuredDocument(lines=_lines_for_tokens(tokens_by_line)) + class TestNormaliseStr(object): - def test_should_replace_thin_space_with_regular_space(self): - assert normalise_str(THIN_SPACE) == ' ' + def test_should_replace_thin_space_with_regular_space(self): + assert normalise_str(THIN_SPACE) == ' ' - def test_should_replace_em_dash_with_hyphen(self): - assert normalise_str(EM_DASH) == '-' + def test_should_replace_em_dash_with_hyphen(self): + assert normalise_str(EM_DASH) == '-' + + def test_should_replace_en_dash_with_hyphen(self): + assert normalise_str(EN_DASH) == '-' - def test_should_replace_en_dash_with_hyphen(self): - assert normalise_str(EN_DASH) == '-' class TestSequenceWrapper(object): - def test_should_find_all_tokens_without_str_filter(self): - text = 'this is matching' - tokens = _tokens_for_text(text) - doc = _document_for_tokens([tokens]) - seq = SequenceWrapper(doc, tokens) - assert str(seq) == text - assert list(seq.tokens_between((0, len(text)))) == tokens - - def 
test_should_find_tokens_between_without_str_filter(self): - text = 'this is matching' - tokens = _tokens_for_text(text) - doc = _document_for_tokens([tokens]) - seq = SequenceWrapper(doc, tokens) - assert str(seq) == text - assert list(seq.tokens_between((6, 7))) == [tokens[1]] - - def test_should_find_tokens_between_adjusted_indices_due_to_str_filter(self): - text = 'this is matching' - tokens = _tokens_for_text(text) - doc = _document_for_tokens([tokens]) - seq = SequenceWrapper(doc, tokens, str_filter_f=lambda s: s.replace('th', '')) - assert str(seq) == 'is is matching' - assert list(seq.tokens_between((4, 5))) == [tokens[1]] + def test_should_find_all_tokens_without_str_filter(self): + text = 'this is matching' + tokens = _tokens_for_text(text) + doc = _document_for_tokens([tokens]) + seq = SequenceWrapper(doc, tokens) + assert str(seq) == text + assert list(seq.tokens_between((0, len(text)))) == tokens + + def test_should_find_tokens_between_without_str_filter(self): + text = 'this is matching' + tokens = _tokens_for_text(text) + doc = _document_for_tokens([tokens]) + seq = SequenceWrapper(doc, tokens) + assert str(seq) == text + assert list(seq.tokens_between((6, 7))) == [tokens[1]] + + def test_should_find_tokens_between_adjusted_indices_due_to_str_filter(self): + text = 'this is matching' + tokens = _tokens_for_text(text) + doc = _document_for_tokens([tokens]) + seq = SequenceWrapper(doc, tokens, str_filter_f=lambda s: s.replace('th', '')) + assert str(seq) == 'is is matching' + assert list(seq.tokens_between((4, 5))) == [tokens[1]] + class TestMatchingAnnotator(object): - def test_should_not_fail_on_empty_document(self): - doc = SimpleStructuredDocument(lines=[]) - MatchingAnnotator([]).annotate(doc) - - def test_should_not_fail_on_empty_line_with_blank_token(self): - target_annotations = [ - TargetAnnotation('this is. 
matching', TAG1) - ] - doc = _document_for_tokens([[SimpleToken('')]]) - MatchingAnnotator(target_annotations).annotate(doc) - - def test_should_annotate_exactly_matching(self): - matching_tokens = _tokens_for_text('this is matching') - target_annotations = [ - TargetAnnotation('this is matching', TAG1) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - - def test_should_use_begin_prefix_if_enabled(self): - matching_tokens = _tokens_for_text('this is matching') - target_annotations = [ - TargetAnnotation('this is matching', TAG1) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [B_TAG_1, I_TAG_1, I_TAG_1] - - def test_should_match_normalised_characters(self): - matching_tokens = [ - SimpleToken('this'), - SimpleToken('is' + THIN_SPACE + EN_DASH + EM_DASH), - SimpleToken('matching') - ] - target_annotations = [ - TargetAnnotation('this is -- matching', TAG1) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - - def test_should_match_case_insensitive(self): - matching_tokens = _tokens_for_text('This Is Matching') - target_annotations = [ - TargetAnnotation('tHIS iS mATCHING', TAG1) - ] - doc = SimpleStructuredDocument(lines=[SimpleLine(matching_tokens)]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - - def test_should_prefer_word_boundaries(self): - pre_tokens = _tokens_for_text('this') - matching_tokens = _tokens_for_text('is') - post_tokens = _tokens_for_text('miss') - target_annotations = [ - TargetAnnotation('is', TAG1) - ] - doc = _document_for_tokens([ - pre_tokens + 
matching_tokens + post_tokens - ]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens) - assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens) - - def test_should_annotate_multiple_value_target_annotation(self): - matching_tokens = _tokens_for_text('this may match') - target_annotations = [ - TargetAnnotation([ - 'this', 'may', 'match' - ], TAG1) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - - def test_should_annotate_multiple_value_target_annotation_with_begin_prefix(self): - matching_tokens = _tokens_for_text('this may match') - target_annotations = [ - TargetAnnotation([ - 'this', 'may', 'match' - ], TAG1) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == ( - [B_TAG_1] + [I_TAG_1] * (len(matching_tokens) - 1) - ) - - def test_should_annotate_multiple_value_target_annotation_rev_order_with_begin_prefix(self): - matching_tokens = _tokens_for_text('this may match') - target_annotations = [ - TargetAnnotation(list(reversed([ - 'this', 'may', 'match' - ])), TAG1) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == ( - [B_TAG_1] + [I_TAG_1] * (len(matching_tokens) - 1) - ) - - def test_should_annotate_multiple_value_target_annotation_over_multiple_lines(self): - tokens_by_line = [ - _tokens_for_text('this may'), - _tokens_for_text('match') - ] - matching_tokens = flatten(tokens_by_line) - target_annotations = [ - TargetAnnotation([ - 'this', 'may', 'match' - ], TAG1) - ] - doc = 
_document_for_tokens(tokens_by_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - - def test_should_annotate_mult_value_target_annot_rev_order_over_mult_lines_with_b_prefix(self): - tokens_by_line = [ - _tokens_for_text('this may'), - _tokens_for_text('match') - ] - matching_tokens = flatten(tokens_by_line) - target_annotations = [ - TargetAnnotation(list(reversed([ - 'this', 'may', 'match' - ])), TAG1) - ] - doc = _document_for_tokens(tokens_by_line) - MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == ( - [B_TAG_1] + - [I_TAG_1] * (len(matching_tokens) - 1) - ) - - def test_should_annotate_not_match_distant_value_of_multiple_value_target_annotation(self): - matching_tokens = _tokens_for_text('this may match') - distant_matching_tokens = _tokens_for_text('not') - distance_in_lines = 10 - tokens_by_line = [matching_tokens] + [ - _tokens_for_text('other') for _ in range(distance_in_lines) - ] + [distant_matching_tokens] - target_annotations = [ - TargetAnnotation([ - 'this', 'may', 'match', 'not' - ], TAG1) - ] - doc = _document_for_tokens(tokens_by_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - assert _get_tags_of_tokens(distant_matching_tokens) == [None] * len(distant_matching_tokens) - - def test_should_annotate_not_match_distant_value_of_target_annotation_with_bonding(self): - matching_tokens = _tokens_for_text('this may match') - distant_matching_tokens = _tokens_for_text('not') - distance_in_lines = 10 - tokens_by_line = [matching_tokens] + [ - _tokens_for_text('other') for _ in range(distance_in_lines) - ] + [distant_matching_tokens] - target_annotations = [ - TargetAnnotation('this may match', TAG1, bonding=True), - TargetAnnotation('not', TAG1, bonding=True) - ] - doc = 
_document_for_tokens(tokens_by_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - assert _get_tags_of_tokens(distant_matching_tokens) == [None] * len(distant_matching_tokens) - - def test_should_annotate_fuzzily_matching(self): - matching_tokens = _tokens_for_text('this is matching') - target_annotations = [ - TargetAnnotation('this is. matching', TAG1) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - - def test_should_annotate_ignoring_space_after_dot_short_sequence(self): - matching_tokens = [ - SimpleToken('A.B.,') - ] - target_annotations = [ - TargetAnnotation('A. B.', TAG1) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - - def test_should_annotate_ignoring_comma_after_short_sequence(self): - matching_tokens = [ - SimpleToken('Name,'), - ] - target_annotations = [ - TargetAnnotation('Name', TAG1) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - - def test_should_annotate_ignoring_dots_after_capitals_in_target_annotation(self): - matching_tokens = _tokens_for_text('PO Box 12345') - target_annotations = [ - TargetAnnotation('P.O. Box 12345', TAG1) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - - def test_should_annotate_ignoring_dots_after_capitals_in_document(self): - matching_tokens = _tokens_for_text('P.O. 
Box 12345') - target_annotations = [ - TargetAnnotation('PO Box 12345', TAG1) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - - def test_should_annotate_with_local_matching_smaller_gaps(self): - matching_tokens = _tokens_for_text('this is matching') - target_annotations = [ - TargetAnnotation('this is. matching indeed matching', TAG1) - ] - # this should align with 'this is_ matching' with one gap' - # instead of globally 'this is_ ________ ______ matching' - # (which would result in a worse b_gap_ratio) - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - - def test_should_not_annotate_fuzzily_matching_with_many_differences(self): - matching_tokens = _tokens_for_text('this is matching') - target_annotations = [ - TargetAnnotation('txhxixsx ixsx mxaxtxcxhxixnxgx', TAG1) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [None] * len(matching_tokens) - - def test_should_annotate_fuzzily_matching_longer_matches_based_on_ratio(self): - long_matching_text = 'this is matching and is really really long match that we can trust' - matching_tokens = _tokens_for_text(long_matching_text) - no_matching_tokens = _tokens_for_text('what comes next is different') - target_annotations = [ - TargetAnnotation(long_matching_text + ' but this is not and is another matter', TAG1) - ] - doc = _document_for_tokens([ - matching_tokens + no_matching_tokens - ]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - assert _get_tags_of_tokens(no_matching_tokens) == [None] * len(no_matching_tokens) - - def 
test_should_not_annotate_not_matching(self): - not_matching_tokens = _tokens_for_text('something completely different') - target_annotations = [ - TargetAnnotation('this is matching', TAG1) - ] - doc = _document_for_tokens([not_matching_tokens]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(not_matching_tokens) == [None] * len(not_matching_tokens) - - def test_should_annotate_exactly_matching_across_multiple_lines(self): - matching_tokens_per_line = [ - _tokens_for_text('this is matching'), - _tokens_for_text('and continues here') - ] - matching_tokens = flatten(matching_tokens_per_line) - target_annotations = [ - TargetAnnotation('this is matching and continues here', TAG1) - ] - doc = _document_for_tokens(matching_tokens_per_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - - def test_should_annotate_exactly_matching_across_multiple_lines_with_begin_prefix(self): - matching_tokens_per_line = [ - _tokens_for_text('this is matching'), - _tokens_for_text('and continues here') - ] - matching_tokens = flatten(matching_tokens_per_line) - target_annotations = [ - TargetAnnotation('this is matching and continues here', TAG1) - ] - doc = _document_for_tokens(matching_tokens_per_line) - MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == ( - [B_TAG_1] + [I_TAG_1] * (len(matching_tokens) - 1) - ) - - def test_should_not_annotate_shorter_sequence_if_next_line_does_not_match(self): - tokens_per_line = [ - _tokens_for_text('this is'), - _tokens_for_text('something completely different') - ] - tokens = flatten(tokens_per_line) - target_annotations = [ - TargetAnnotation('this is not matching', TAG1) - ] - doc = _document_for_tokens(tokens_per_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(tokens) == [None] * len(tokens) - - def 
test_should_annotate_over_multiple_lines_with_tag_transition(self): - tag1_tokens_by_line = [ - _tokens_for_text('this may'), - _tokens_for_text('match') - ] - tag1_tokens = flatten(tag1_tokens_by_line) - tag2_tokens_by_line = [ - _tokens_for_text('another'), - _tokens_for_text('tag here') - ] - tag2_tokens = flatten(tag2_tokens_by_line) - tokens_by_line = [ - tag1_tokens_by_line[0], - tag1_tokens_by_line[1] + tag2_tokens_by_line[0], - tag2_tokens_by_line[1] - ] - target_annotations = [ - TargetAnnotation('this may match', TAG1), - TargetAnnotation('another tag here', TAG2) - ] - doc = _document_for_tokens(tokens_by_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(tag1_tokens) == [TAG1] * len(tag1_tokens) - assert _get_tags_of_tokens(tag2_tokens) == [TAG2] * len(tag2_tokens) - - def test_should_annotate_over_multiple_lines_with_tag_transition_with_begin_prefix(self): - tag1_tokens_by_line = [ - _tokens_for_text('this may'), - _tokens_for_text('match') - ] - tag1_tokens = flatten(tag1_tokens_by_line) - tag2_tokens_by_line = [ - _tokens_for_text('another'), - _tokens_for_text('tag here') - ] - tag2_tokens = flatten(tag2_tokens_by_line) - tokens_by_line = [ - tag1_tokens_by_line[0], - tag1_tokens_by_line[1] + tag2_tokens_by_line[0], - tag2_tokens_by_line[1] - ] - target_annotations = [ - TargetAnnotation('this may match', TAG1), - TargetAnnotation('another tag here', TAG2) - ] - doc = _document_for_tokens(tokens_by_line) - MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) - assert _get_tags_of_tokens(tag1_tokens) == ( - [B_TAG_1] + [I_TAG_1] * (len(tag1_tokens) - 1) - ) - assert _get_tags_of_tokens(tag2_tokens) == ( - [B_TAG_2] + [I_TAG_2] * (len(tag2_tokens) - 1) - ) - - def test_should_annotate_short_section_title_followed_by_paragraph(self): - section_title_text = 'section title' - section_paragraph_text = 'paragraph text to come here.' 
- section_title_tokens = _tokens_for_text(section_title_text + '.') - section_paragraph_tokens = _tokens_for_text(section_paragraph_text) - tokens_per_line = [ - section_title_tokens + section_paragraph_tokens - ] - target_annotations = [ - TargetAnnotation(section_title_text, 'section_title', require_next=True), - TargetAnnotation(section_paragraph_text, 'section_paragraph') - ] - doc = _document_for_tokens(tokens_per_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert ( - _get_tags_of_tokens(section_title_tokens) == - ['section_title'] * len(section_title_tokens) - ) - assert ( - _get_tags_of_tokens(section_paragraph_tokens) == - ['section_paragraph'] * len(section_paragraph_tokens) - ) - - def test_should_not_annotate_short_section_title_not_followed_by_paragraph(self): - section_title_text = 'section title' - section_title_tokens = _tokens_for_text(section_title_text + '.') - section_paragraph_text = 'paragraph text to come here.' - section_paragraph_tokens = _tokens_for_text(section_paragraph_text) - tokens_per_line = [ - section_title_tokens + _tokens_for_text('other text to come here.'), - _tokens_for_text('more unrelated text.'), - _tokens_for_text('even more.'), - section_paragraph_tokens - ] - target_annotations = [ - TargetAnnotation(section_title_text, 'section_title', require_next=True), - TargetAnnotation(section_paragraph_text, 'section_paragraph') - ] - doc = _document_for_tokens(tokens_per_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert ( - _get_tags_of_tokens(section_title_tokens) == - [None] * len(section_title_tokens) - ) - - def test_should_not_annotate_short_section_title_if_paragraph_follows_later(self): - section_title_text = 'section title' - section_title_tokens = _tokens_for_text(section_title_text + '.') - other_tokens = _tokens_for_text('other text to come here.') - tokens_per_line = [ - section_title_tokens + other_tokens - ] - target_annotations = [ - TargetAnnotation(section_title_text, 
'section_title', require_next=True) - ] - doc = _document_for_tokens(tokens_per_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert ( - _get_tags_of_tokens(section_title_tokens) == - [None] * len(section_title_tokens) - ) - - def test_should_annotate_short_reference_item_followed_by_other_reference_items(self): - reference_item_texts = ['ref_id', 'ref_title'] - reference_item_tokens = _tokens_for_text(' '.join(reference_item_texts)) - tokens_per_line = [ - reference_item_tokens - ] - target_annotations = [ - TargetAnnotation(reference_item_texts, 'reference', bonding=True) - ] - doc = _document_for_tokens(tokens_per_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert ( - _get_tags_of_tokens(reference_item_tokens) == - ['reference'] * len(reference_item_tokens) - ) - - def test_should_not_annotate_short_reference_item_not_followed_by_other_reference_items(self): - matching_reference_item_text = 'ref_id' - reference_item_texts = [matching_reference_item_text] + ['ref_title'] - matching_reference_item_tokens = _tokens_for_text(matching_reference_item_text) - other_tokens = _tokens_for_text('other') - tokens_per_line = [ - matching_reference_item_tokens + other_tokens - ] - target_annotations = [ - TargetAnnotation(reference_item_texts, 'reference', bonding=True) - ] - doc = _document_for_tokens(tokens_per_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert ( - _get_tags_of_tokens(matching_reference_item_tokens) == - [None] * len(matching_reference_item_tokens) - ) - - def test_should_annotate_last_line_of_block_followed_by_other_text(self): - block_text_lines = [ - 'this is the first row', - 'second row follows', - 'here we are on the third', - 'last line of block' - ] - block_tokens_per_line = _tokens_for_text_lines(block_text_lines) - block_tokens = flatten(block_tokens_per_line) - tokens_per_line = block_tokens_per_line + [ - _tokens_for_text('other text') - ] - target_annotations = [ - 
TargetAnnotation('\n'.join(block_text_lines), TAG1) - ] - doc = _document_for_tokens(tokens_per_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert ( - _get_tags_of_tokens(block_tokens) == - [TAG1] * len(block_tokens) - ) - - def test_should_annotate_longer_sequence_over_multiple_lines_considering_next_line(self): - # we need a long enough sequence to fall into the first branch - # and match the partial match threshold - exact_matching_text_lines = ('this may', 'indeed match very well without the slightest doubt') - # add a short prefix that doesn't affect the score much - # but would be skipped if we only matched the second line - matching_text_lines = (exact_matching_text_lines[0], 'x ' + exact_matching_text_lines[1]) - matching_tokens_by_line = _tokens_for_text_lines(matching_text_lines) - matching_tokens = flatten(matching_tokens_by_line) - pre_tokens = _tokens_for_text(matching_text_lines[0] + ' this may not') - post_tokens = _tokens_for_text('or not') - tokens_by_line = [ - pre_tokens + matching_tokens_by_line[0], - matching_tokens_by_line[1] + post_tokens - ] - target_annotations = [ - TargetAnnotation(' '.join(exact_matching_text_lines), TAG1) - ] - doc = _document_for_tokens(tokens_by_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens) - assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens) - - def test_should_annotate_shorter_sequence_over_multiple_lines_considering_next_line(self): - # use a short sequence that wouldn't get matched on it's own - matching_text_lines = ('this may', 'match') - matching_tokens_by_line = _tokens_for_text_lines(matching_text_lines) - matching_tokens = flatten(matching_tokens_by_line) - # repeat the same text on the two lines, only by combining the lines would it be clear - # which tokens to match - pre_tokens = 
_tokens_for_text(matching_text_lines[0] + ' be some other longer preceeding text') - post_tokens = _tokens_for_text('this is some text after but no ' + matching_text_lines[1]) - tokens_by_line = [ - pre_tokens + matching_tokens_by_line[0], - matching_tokens_by_line[1] + post_tokens - ] - target_annotations = [ - TargetAnnotation('this may match', TAG1) - ] - doc = _document_for_tokens(tokens_by_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens) - assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens) - - def test_should_not_annotate_too_short_match_of_longer_sequence(self): - matching_tokens = _tokens_for_text('this is matching') - too_short_tokens = _tokens_for_text('1') - tokens_per_line = [ - too_short_tokens, - matching_tokens - ] - target_annotations = [ - TargetAnnotation('this is matching 1', TAG1) - ] - doc = _document_for_tokens(tokens_per_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(too_short_tokens) == [None] * len(too_short_tokens) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - - def test_should_not_annotate_similar_sequence_multiple_times(self): - matching_tokens_per_line = [ - _tokens_for_text('this is matching'), - _tokens_for_text('and continues here') - ] - not_matching_tokens = _tokens_for_text('this is matching') - - matching_tokens = flatten(matching_tokens_per_line) - target_annotations = [ - TargetAnnotation('this is matching and continues here', TAG1) - ] - doc = _document_for_tokens( - matching_tokens_per_line + [not_matching_tokens] - ) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - assert _get_tags_of_tokens(not_matching_tokens) == [None] * len(not_matching_tokens) - - def 
test_should_annotate_same_sequence_multiple_times_if_enabled(self): - matching_tokens_per_line = [ - _tokens_for_text('this is matching'), - _tokens_for_text('this is matching') - ] - - matching_tokens = flatten(matching_tokens_per_line) - target_annotations = [ - TargetAnnotation('this is matching', TAG1, match_multiple=True) - ] - doc = _document_for_tokens(matching_tokens_per_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - - def test_should_annotate_same_sequence_multiple_times_with_begin_prefix(self): - matching_tokens_per_line = [ - _tokens_for_text('this is matching'), - _tokens_for_text('this is matching') - ] - - matching_tokens = flatten(matching_tokens_per_line) - target_annotations = [ - TargetAnnotation('this is matching', TAG1, match_multiple=True) - ] - doc = _document_for_tokens(matching_tokens_per_line) - MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) - # the begin tag should appear at the beginning of each match - assert ( - _get_tags_of_tokens(matching_tokens) == - [B_TAG_1, I_TAG_1, I_TAG_1, B_TAG_1, I_TAG_1, I_TAG_1] - ) - - def test_should_not_override_annotation(self): - matching_tokens_per_line = [ - _tokens_for_text('this is matching') - ] - - matching_tokens = flatten(matching_tokens_per_line) - target_annotations = [ - TargetAnnotation('this is matching', TAG1), - TargetAnnotation('matching', TAG2) - ] - doc = _document_for_tokens(matching_tokens_per_line) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - - def test_should_not_annotate_pre_annotated_tokens_on_separate_lines(self): - line_no_tokens = _tokens_for_text('1') - line_no_tokens[0].set_tag('line_no') - matching_tokens = _tokens_for_text('this is matching') - target_annotations = [ - TargetAnnotation('1', TAG2), - TargetAnnotation('this is matching', TAG1) - ] - doc = 
_document_for_tokens([ - line_no_tokens + matching_tokens - ]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(line_no_tokens) == ['line_no'] * len(line_no_tokens) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - - def test_should_annotate_shorter_target_annotation_in_longer_line(self): - pre_tokens = _tokens_for_text('pre') - matching_tokens = _tokens_for_text('this is matching') - post_tokens = _tokens_for_text('post') - target_annotations = [ - TargetAnnotation('this is matching', TAG1) - ] - doc = _document_for_tokens([ - pre_tokens + matching_tokens + post_tokens - ]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens) - - def test_should_annotate_shorter_target_annotation_fuzzily(self): - pre_tokens = _tokens_for_text('pre') - matching_tokens = _tokens_for_text('this is matching') - post_tokens = _tokens_for_text('post') - target_annotations = [ - TargetAnnotation('this is. 
matching', TAG1) - ] - doc = _document_for_tokens([ - pre_tokens + matching_tokens + post_tokens - ]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens) - - def test_should_annotate_multiple_shorter_target_annotation_in_longer_line(self): - pre_tokens = _tokens_for_text('pre') - matching_tokens_tag_1 = _tokens_for_text('this is matching') - mid_tokens = _tokens_for_text('mid') - matching_tokens_tag_2 = _tokens_for_text('also good') - post_tokens = _tokens_for_text('post') - target_annotations = [ - TargetAnnotation('this is matching', TAG1), - TargetAnnotation('also good', TAG2) - ] - doc = _document_for_tokens([ - pre_tokens + matching_tokens_tag_1 + mid_tokens + matching_tokens_tag_2 + post_tokens - ]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens) - assert _get_tags_of_tokens(matching_tokens_tag_1) == [TAG1] * len(matching_tokens_tag_1) - assert _get_tags_of_tokens(mid_tokens) == [None] * len(mid_tokens) - assert _get_tags_of_tokens(matching_tokens_tag_2) == [TAG2] * len(matching_tokens_tag_2) - assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens) - - def test_should_not_annotate_shorter_target_annotation_in_longer_line_multiple_times(self): - pre_tokens = _tokens_for_text('pre') - matching_tokens = _tokens_for_text('this is matching') - post_tokens = _tokens_for_text('post') - first_line_tokens = pre_tokens + matching_tokens + post_tokens - similar_line_tokens = _copy_tokens(first_line_tokens) - target_annotations = [ - TargetAnnotation('this is matching', TAG1) - ] - doc = _document_for_tokens([ - first_line_tokens, - similar_line_tokens - ]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * 
len(matching_tokens) - assert _get_tags_of_tokens(similar_line_tokens) == [None] * len(similar_line_tokens) - - def test_should_annotate_shorter_target_annotation_in_longer_line_multiple_times_if_enabled(self): - pre_tokens = _tokens_for_text('pre') - matching_tokens = _tokens_for_text('this is matching') - post_tokens = _tokens_for_text('post') - same_matching_tokens = _copy_tokens(matching_tokens) - target_annotations = [ - TargetAnnotation('this is matching', TAG1, match_multiple=True) - ] - doc = _document_for_tokens([ - pre_tokens + matching_tokens + post_tokens, - _copy_tokens(pre_tokens) + same_matching_tokens + _copy_tokens(post_tokens) - ]) - MatchingAnnotator(target_annotations).annotate(doc) - assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) - assert _get_tags_of_tokens(same_matching_tokens) == [TAG1] * len(same_matching_tokens) + def test_should_not_fail_on_empty_document(self): + doc = SimpleStructuredDocument(lines=[]) + MatchingAnnotator([]).annotate(doc) + + def test_should_not_fail_on_empty_line_with_blank_token(self): + target_annotations = [ + TargetAnnotation('this is. 
matching', TAG1) + ] + doc = _document_for_tokens([[SimpleToken('')]]) + MatchingAnnotator(target_annotations).annotate(doc) + + def test_should_annotate_exactly_matching(self): + matching_tokens = _tokens_for_text('this is matching') + target_annotations = [ + TargetAnnotation('this is matching', TAG1) + ] + doc = _document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + + def test_should_use_begin_prefix_if_enabled(self): + matching_tokens = _tokens_for_text('this is matching') + target_annotations = [ + TargetAnnotation('this is matching', TAG1) + ] + doc = _document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [B_TAG_1, I_TAG_1, I_TAG_1] + + def test_should_match_normalised_characters(self): + matching_tokens = [ + SimpleToken('this'), + SimpleToken('is' + THIN_SPACE + EN_DASH + EM_DASH), + SimpleToken('matching') + ] + target_annotations = [ + TargetAnnotation('this is -- matching', TAG1) + ] + doc = _document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + + def test_should_match_case_insensitive(self): + matching_tokens = _tokens_for_text('This Is Matching') + target_annotations = [ + TargetAnnotation('tHIS iS mATCHING', TAG1) + ] + doc = SimpleStructuredDocument(lines=[SimpleLine(matching_tokens)]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + + def test_should_prefer_word_boundaries(self): + pre_tokens = _tokens_for_text('this') + matching_tokens = _tokens_for_text('is') + post_tokens = _tokens_for_text('miss') + target_annotations = [ + TargetAnnotation('is', TAG1) + ] + doc = _document_for_tokens([ + pre_tokens + 
matching_tokens + post_tokens + ]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens) + assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens) + + def test_should_annotate_multiple_value_target_annotation(self): + matching_tokens = _tokens_for_text('this may match') + target_annotations = [ + TargetAnnotation([ + 'this', 'may', 'match' + ], TAG1) + ] + doc = _document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + + def test_should_annotate_multiple_value_target_annotation_with_begin_prefix(self): + matching_tokens = _tokens_for_text('this may match') + target_annotations = [ + TargetAnnotation([ + 'this', 'may', 'match' + ], TAG1) + ] + doc = _document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == ( + [B_TAG_1] + [I_TAG_1] * (len(matching_tokens) - 1) + ) + + def test_should_annotate_multiple_value_target_annotation_rev_order_with_begin_prefix(self): + matching_tokens = _tokens_for_text('this may match') + target_annotations = [ + TargetAnnotation(list(reversed([ + 'this', 'may', 'match' + ])), TAG1) + ] + doc = _document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == ( + [B_TAG_1] + [I_TAG_1] * (len(matching_tokens) - 1) + ) + + def test_should_annotate_multiple_value_target_annotation_over_multiple_lines(self): + tokens_by_line = [ + _tokens_for_text('this may'), + _tokens_for_text('match') + ] + matching_tokens = flatten(tokens_by_line) + target_annotations = [ + TargetAnnotation([ + 'this', 'may', 'match' + ], TAG1) + ] + doc = 
_document_for_tokens(tokens_by_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + + def test_should_annotate_mult_value_target_annot_rev_order_over_mult_lines_with_b_prefix(self): + tokens_by_line = [ + _tokens_for_text('this may'), + _tokens_for_text('match') + ] + matching_tokens = flatten(tokens_by_line) + target_annotations = [ + TargetAnnotation(list(reversed([ + 'this', 'may', 'match' + ])), TAG1) + ] + doc = _document_for_tokens(tokens_by_line) + MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == ( + [B_TAG_1] + + [I_TAG_1] * (len(matching_tokens) - 1) + ) + + def test_should_annotate_not_match_distant_value_of_multiple_value_target_annotation(self): + matching_tokens = _tokens_for_text('this may match') + distant_matching_tokens = _tokens_for_text('not') + distance_in_lines = 10 + tokens_by_line = [matching_tokens] + [ + _tokens_for_text('other') for _ in range(distance_in_lines) + ] + [distant_matching_tokens] + target_annotations = [ + TargetAnnotation([ + 'this', 'may', 'match', 'not' + ], TAG1) + ] + doc = _document_for_tokens(tokens_by_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + assert _get_tags_of_tokens(distant_matching_tokens) == [None] * len(distant_matching_tokens) + + def test_should_annotate_not_match_distant_value_of_target_annotation_with_bonding(self): + matching_tokens = _tokens_for_text('this may match') + distant_matching_tokens = _tokens_for_text('not') + distance_in_lines = 10 + tokens_by_line = [matching_tokens] + [ + _tokens_for_text('other') for _ in range(distance_in_lines) + ] + [distant_matching_tokens] + target_annotations = [ + TargetAnnotation('this may match', TAG1, bonding=True), + TargetAnnotation('not', TAG1, bonding=True) + ] + doc = 
_document_for_tokens(tokens_by_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + assert _get_tags_of_tokens(distant_matching_tokens) == [None] * len(distant_matching_tokens) + + def test_should_annotate_fuzzily_matching(self): + matching_tokens = _tokens_for_text('this is matching') + target_annotations = [ + TargetAnnotation('this is. matching', TAG1) + ] + doc = _document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + + def test_should_annotate_ignoring_space_after_dot_short_sequence(self): + matching_tokens = [ + SimpleToken('A.B.,') + ] + target_annotations = [ + TargetAnnotation('A. B.', TAG1) + ] + doc = _document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + + def test_should_annotate_ignoring_comma_after_short_sequence(self): + matching_tokens = [ + SimpleToken('Name,'), + ] + target_annotations = [ + TargetAnnotation('Name', TAG1) + ] + doc = _document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + + def test_should_annotate_ignoring_dots_after_capitals_in_target_annotation(self): + matching_tokens = _tokens_for_text('PO Box 12345') + target_annotations = [ + TargetAnnotation('P.O. Box 12345', TAG1) + ] + doc = _document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + + def test_should_annotate_ignoring_dots_after_capitals_in_document(self): + matching_tokens = _tokens_for_text('P.O. 
Box 12345') + target_annotations = [ + TargetAnnotation('PO Box 12345', TAG1) + ] + doc = _document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + + def test_should_annotate_with_local_matching_smaller_gaps(self): + matching_tokens = _tokens_for_text('this is matching') + target_annotations = [ + TargetAnnotation('this is. matching indeed matching', TAG1) + ] + # this should align with 'this is_ matching' with one gap' + # instead of globally 'this is_ ________ ______ matching' + # (which would result in a worse b_gap_ratio) + doc = _document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + + def test_should_not_annotate_fuzzily_matching_with_many_differences(self): + matching_tokens = _tokens_for_text('this is matching') + target_annotations = [ + TargetAnnotation('txhxixsx ixsx mxaxtxcxhxixnxgx', TAG1) + ] + doc = _document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [None] * len(matching_tokens) + + def test_should_annotate_fuzzily_matching_longer_matches_based_on_ratio(self): + long_matching_text = 'this is matching and is really really long match that we can trust' + matching_tokens = _tokens_for_text(long_matching_text) + no_matching_tokens = _tokens_for_text('what comes next is different') + target_annotations = [ + TargetAnnotation(long_matching_text + ' but this is not and is another matter', TAG1) + ] + doc = _document_for_tokens([ + matching_tokens + no_matching_tokens + ]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + assert _get_tags_of_tokens(no_matching_tokens) == [None] * len(no_matching_tokens) + + def 
test_should_not_annotate_not_matching(self): + not_matching_tokens = _tokens_for_text('something completely different') + target_annotations = [ + TargetAnnotation('this is matching', TAG1) + ] + doc = _document_for_tokens([not_matching_tokens]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(not_matching_tokens) == [None] * len(not_matching_tokens) + + def test_should_annotate_exactly_matching_across_multiple_lines(self): + matching_tokens_per_line = [ + _tokens_for_text('this is matching'), + _tokens_for_text('and continues here') + ] + matching_tokens = flatten(matching_tokens_per_line) + target_annotations = [ + TargetAnnotation('this is matching and continues here', TAG1) + ] + doc = _document_for_tokens(matching_tokens_per_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + + def test_should_annotate_exactly_matching_across_multiple_lines_with_begin_prefix(self): + matching_tokens_per_line = [ + _tokens_for_text('this is matching'), + _tokens_for_text('and continues here') + ] + matching_tokens = flatten(matching_tokens_per_line) + target_annotations = [ + TargetAnnotation('this is matching and continues here', TAG1) + ] + doc = _document_for_tokens(matching_tokens_per_line) + MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == ( + [B_TAG_1] + [I_TAG_1] * (len(matching_tokens) - 1) + ) + + def test_should_not_annotate_shorter_sequence_if_next_line_does_not_match(self): + tokens_per_line = [ + _tokens_for_text('this is'), + _tokens_for_text('something completely different') + ] + tokens = flatten(tokens_per_line) + target_annotations = [ + TargetAnnotation('this is not matching', TAG1) + ] + doc = _document_for_tokens(tokens_per_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(tokens) == [None] * len(tokens) + + def 
test_should_annotate_over_multiple_lines_with_tag_transition(self): + tag1_tokens_by_line = [ + _tokens_for_text('this may'), + _tokens_for_text('match') + ] + tag1_tokens = flatten(tag1_tokens_by_line) + tag2_tokens_by_line = [ + _tokens_for_text('another'), + _tokens_for_text('tag here') + ] + tag2_tokens = flatten(tag2_tokens_by_line) + tokens_by_line = [ + tag1_tokens_by_line[0], + tag1_tokens_by_line[1] + tag2_tokens_by_line[0], + tag2_tokens_by_line[1] + ] + target_annotations = [ + TargetAnnotation('this may match', TAG1), + TargetAnnotation('another tag here', TAG2) + ] + doc = _document_for_tokens(tokens_by_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(tag1_tokens) == [TAG1] * len(tag1_tokens) + assert _get_tags_of_tokens(tag2_tokens) == [TAG2] * len(tag2_tokens) + + def test_should_annotate_over_multiple_lines_with_tag_transition_with_begin_prefix(self): + tag1_tokens_by_line = [ + _tokens_for_text('this may'), + _tokens_for_text('match') + ] + tag1_tokens = flatten(tag1_tokens_by_line) + tag2_tokens_by_line = [ + _tokens_for_text('another'), + _tokens_for_text('tag here') + ] + tag2_tokens = flatten(tag2_tokens_by_line) + tokens_by_line = [ + tag1_tokens_by_line[0], + tag1_tokens_by_line[1] + tag2_tokens_by_line[0], + tag2_tokens_by_line[1] + ] + target_annotations = [ + TargetAnnotation('this may match', TAG1), + TargetAnnotation('another tag here', TAG2) + ] + doc = _document_for_tokens(tokens_by_line) + MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) + assert _get_tags_of_tokens(tag1_tokens) == ( + [B_TAG_1] + [I_TAG_1] * (len(tag1_tokens) - 1) + ) + assert _get_tags_of_tokens(tag2_tokens) == ( + [B_TAG_2] + [I_TAG_2] * (len(tag2_tokens) - 1) + ) + + def test_should_annotate_short_section_title_followed_by_paragraph(self): + section_title_text = 'section title' + section_paragraph_text = 'paragraph text to come here.' 
+ section_title_tokens = _tokens_for_text(section_title_text + '.') + section_paragraph_tokens = _tokens_for_text(section_paragraph_text) + tokens_per_line = [ + section_title_tokens + section_paragraph_tokens + ] + target_annotations = [ + TargetAnnotation(section_title_text, 'section_title', require_next=True), + TargetAnnotation(section_paragraph_text, 'section_paragraph') + ] + doc = _document_for_tokens(tokens_per_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert ( + _get_tags_of_tokens(section_title_tokens) == + ['section_title'] * len(section_title_tokens) + ) + assert ( + _get_tags_of_tokens(section_paragraph_tokens) == + ['section_paragraph'] * len(section_paragraph_tokens) + ) + + def test_should_not_annotate_short_section_title_not_followed_by_paragraph(self): + section_title_text = 'section title' + section_title_tokens = _tokens_for_text(section_title_text + '.') + section_paragraph_text = 'paragraph text to come here.' + section_paragraph_tokens = _tokens_for_text(section_paragraph_text) + tokens_per_line = [ + section_title_tokens + _tokens_for_text('other text to come here.'), + _tokens_for_text('more unrelated text.'), + _tokens_for_text('even more.'), + section_paragraph_tokens + ] + target_annotations = [ + TargetAnnotation(section_title_text, 'section_title', require_next=True), + TargetAnnotation(section_paragraph_text, 'section_paragraph') + ] + doc = _document_for_tokens(tokens_per_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert ( + _get_tags_of_tokens(section_title_tokens) == + [None] * len(section_title_tokens) + ) + + def test_should_not_annotate_short_section_title_if_paragraph_follows_later(self): + section_title_text = 'section title' + section_title_tokens = _tokens_for_text(section_title_text + '.') + other_tokens = _tokens_for_text('other text to come here.') + tokens_per_line = [ + section_title_tokens + other_tokens + ] + target_annotations = [ + TargetAnnotation(section_title_text, 
'section_title', require_next=True) + ] + doc = _document_for_tokens(tokens_per_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert ( + _get_tags_of_tokens(section_title_tokens) == + [None] * len(section_title_tokens) + ) + + def test_should_annotate_short_reference_item_followed_by_other_reference_items(self): + reference_item_texts = ['ref_id', 'ref_title'] + reference_item_tokens = _tokens_for_text(' '.join(reference_item_texts)) + tokens_per_line = [ + reference_item_tokens + ] + target_annotations = [ + TargetAnnotation(reference_item_texts, 'reference', bonding=True) + ] + doc = _document_for_tokens(tokens_per_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert ( + _get_tags_of_tokens(reference_item_tokens) == + ['reference'] * len(reference_item_tokens) + ) + + def test_should_not_annotate_short_reference_item_not_followed_by_other_reference_items(self): + matching_reference_item_text = 'ref_id' + reference_item_texts = [matching_reference_item_text] + ['ref_title'] + matching_reference_item_tokens = _tokens_for_text(matching_reference_item_text) + other_tokens = _tokens_for_text('other') + tokens_per_line = [ + matching_reference_item_tokens + other_tokens + ] + target_annotations = [ + TargetAnnotation(reference_item_texts, 'reference', bonding=True) + ] + doc = _document_for_tokens(tokens_per_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert ( + _get_tags_of_tokens(matching_reference_item_tokens) == + [None] * len(matching_reference_item_tokens) + ) + + def test_should_annotate_last_line_of_block_followed_by_other_text(self): + block_text_lines = [ + 'this is the first row', + 'second row follows', + 'here we are on the third', + 'last line of block' + ] + block_tokens_per_line = _tokens_for_text_lines(block_text_lines) + block_tokens = flatten(block_tokens_per_line) + tokens_per_line = block_tokens_per_line + [ + _tokens_for_text('other text') + ] + target_annotations = [ + 
TargetAnnotation('\n'.join(block_text_lines), TAG1) + ] + doc = _document_for_tokens(tokens_per_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert ( + _get_tags_of_tokens(block_tokens) == + [TAG1] * len(block_tokens) + ) + + def test_should_annotate_longer_sequence_over_multiple_lines_considering_next_line(self): + # we need a long enough sequence to fall into the first branch + # and match the partial match threshold + exact_matching_text_lines = ( + 'this may', 'indeed match very well without the slightest doubt') + # add a short prefix that doesn't affect the score much + # but would be skipped if we only matched the second line + matching_text_lines = (exact_matching_text_lines[0], 'x ' + exact_matching_text_lines[1]) + matching_tokens_by_line = _tokens_for_text_lines(matching_text_lines) + matching_tokens = flatten(matching_tokens_by_line) + pre_tokens = _tokens_for_text(matching_text_lines[0] + ' this may not') + post_tokens = _tokens_for_text('or not') + tokens_by_line = [ + pre_tokens + matching_tokens_by_line[0], + matching_tokens_by_line[1] + post_tokens + ] + target_annotations = [ + TargetAnnotation(' '.join(exact_matching_text_lines), TAG1) + ] + doc = _document_for_tokens(tokens_by_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens) + assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens) + + def test_should_annotate_shorter_sequence_over_multiple_lines_considering_next_line(self): + # use a short sequence that wouldn't get matched on it's own + matching_text_lines = ('this may', 'match') + matching_tokens_by_line = _tokens_for_text_lines(matching_text_lines) + matching_tokens = flatten(matching_tokens_by_line) + # repeat the same text on the two lines, only by combining the lines would it be clear + # which tokens to match + pre_tokens = _tokens_for_text( + 
matching_text_lines[0] + ' be some other longer preceeding text') + post_tokens = _tokens_for_text('this is some text after but no ' + matching_text_lines[1]) + tokens_by_line = [ + pre_tokens + matching_tokens_by_line[0], + matching_tokens_by_line[1] + post_tokens + ] + target_annotations = [ + TargetAnnotation('this may match', TAG1) + ] + doc = _document_for_tokens(tokens_by_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens) + assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens) + + def test_should_not_annotate_too_short_match_of_longer_sequence(self): + matching_tokens = _tokens_for_text('this is matching') + too_short_tokens = _tokens_for_text('1') + tokens_per_line = [ + too_short_tokens, + matching_tokens + ] + target_annotations = [ + TargetAnnotation('this is matching 1', TAG1) + ] + doc = _document_for_tokens(tokens_per_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(too_short_tokens) == [None] * len(too_short_tokens) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + + def test_should_not_annotate_similar_sequence_multiple_times(self): + matching_tokens_per_line = [ + _tokens_for_text('this is matching'), + _tokens_for_text('and continues here') + ] + not_matching_tokens = _tokens_for_text('this is matching') + + matching_tokens = flatten(matching_tokens_per_line) + target_annotations = [ + TargetAnnotation('this is matching and continues here', TAG1) + ] + doc = _document_for_tokens( + matching_tokens_per_line + [not_matching_tokens] + ) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + assert _get_tags_of_tokens(not_matching_tokens) == [None] * len(not_matching_tokens) + + def 
test_should_annotate_same_sequence_multiple_times_if_enabled(self): + matching_tokens_per_line = [ + _tokens_for_text('this is matching'), + _tokens_for_text('this is matching') + ] + + matching_tokens = flatten(matching_tokens_per_line) + target_annotations = [ + TargetAnnotation('this is matching', TAG1, match_multiple=True) + ] + doc = _document_for_tokens(matching_tokens_per_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + + def test_should_annotate_same_sequence_multiple_times_with_begin_prefix(self): + matching_tokens_per_line = [ + _tokens_for_text('this is matching'), + _tokens_for_text('this is matching') + ] + + matching_tokens = flatten(matching_tokens_per_line) + target_annotations = [ + TargetAnnotation('this is matching', TAG1, match_multiple=True) + ] + doc = _document_for_tokens(matching_tokens_per_line) + MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) + # the begin tag should appear at the beginning of each match + assert ( + _get_tags_of_tokens(matching_tokens) == + [B_TAG_1, I_TAG_1, I_TAG_1, B_TAG_1, I_TAG_1, I_TAG_1] + ) + + def test_should_not_override_annotation(self): + matching_tokens_per_line = [ + _tokens_for_text('this is matching') + ] + + matching_tokens = flatten(matching_tokens_per_line) + target_annotations = [ + TargetAnnotation('this is matching', TAG1), + TargetAnnotation('matching', TAG2) + ] + doc = _document_for_tokens(matching_tokens_per_line) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + + def test_should_not_annotate_pre_annotated_tokens_on_separate_lines(self): + line_no_tokens = _tokens_for_text('1') + line_no_tokens[0].set_tag('line_no') + matching_tokens = _tokens_for_text('this is matching') + target_annotations = [ + TargetAnnotation('1', TAG2), + TargetAnnotation('this is matching', TAG1) + ] + doc = 
_document_for_tokens([ + line_no_tokens + matching_tokens + ]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(line_no_tokens) == ['line_no'] * len(line_no_tokens) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + + def test_should_annotate_shorter_target_annotation_in_longer_line(self): + pre_tokens = _tokens_for_text('pre') + matching_tokens = _tokens_for_text('this is matching') + post_tokens = _tokens_for_text('post') + target_annotations = [ + TargetAnnotation('this is matching', TAG1) + ] + doc = _document_for_tokens([ + pre_tokens + matching_tokens + post_tokens + ]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens) + + def test_should_annotate_shorter_target_annotation_fuzzily(self): + pre_tokens = _tokens_for_text('pre') + matching_tokens = _tokens_for_text('this is matching') + post_tokens = _tokens_for_text('post') + target_annotations = [ + TargetAnnotation('this is. 
matching', TAG1) + ] + doc = _document_for_tokens([ + pre_tokens + matching_tokens + post_tokens + ]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens) + + def test_should_annotate_multiple_shorter_target_annotation_in_longer_line(self): + pre_tokens = _tokens_for_text('pre') + matching_tokens_tag_1 = _tokens_for_text('this is matching') + mid_tokens = _tokens_for_text('mid') + matching_tokens_tag_2 = _tokens_for_text('also good') + post_tokens = _tokens_for_text('post') + target_annotations = [ + TargetAnnotation('this is matching', TAG1), + TargetAnnotation('also good', TAG2) + ] + doc = _document_for_tokens([ + pre_tokens + matching_tokens_tag_1 + mid_tokens + matching_tokens_tag_2 + post_tokens + ]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(pre_tokens) == [None] * len(pre_tokens) + assert _get_tags_of_tokens(matching_tokens_tag_1) == [TAG1] * len(matching_tokens_tag_1) + assert _get_tags_of_tokens(mid_tokens) == [None] * len(mid_tokens) + assert _get_tags_of_tokens(matching_tokens_tag_2) == [TAG2] * len(matching_tokens_tag_2) + assert _get_tags_of_tokens(post_tokens) == [None] * len(post_tokens) + + def test_should_not_annotate_shorter_target_annotation_in_longer_line_multiple_times(self): + pre_tokens = _tokens_for_text('pre') + matching_tokens = _tokens_for_text('this is matching') + post_tokens = _tokens_for_text('post') + first_line_tokens = pre_tokens + matching_tokens + post_tokens + similar_line_tokens = _copy_tokens(first_line_tokens) + target_annotations = [ + TargetAnnotation('this is matching', TAG1) + ] + doc = _document_for_tokens([ + first_line_tokens, + similar_line_tokens + ]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * 
len(matching_tokens) + assert _get_tags_of_tokens(similar_line_tokens) == [None] * len(similar_line_tokens) + + def test_should_annotate_shorter_target_annotation_in_longer_line_multiple_times_if_enabled( + self): + pre_tokens = _tokens_for_text('pre') + matching_tokens = _tokens_for_text('this is matching') + post_tokens = _tokens_for_text('post') + same_matching_tokens = _copy_tokens(matching_tokens) + target_annotations = [ + TargetAnnotation('this is matching', TAG1, match_multiple=True) + ] + doc = _document_for_tokens([ + pre_tokens + matching_tokens + post_tokens, + _copy_tokens(pre_tokens) + same_matching_tokens + _copy_tokens(post_tokens) + ]) + MatchingAnnotator(target_annotations).annotate(doc) + assert _get_tags_of_tokens(matching_tokens) == [TAG1] * len(matching_tokens) + assert _get_tags_of_tokens(same_matching_tokens) == [TAG1] * len(same_matching_tokens) + class TestMatchingAnnotatorSubAnnotations(object): - def test_should_annotate_sub_tag_exactly_matching_without_begin_prefix(self): - matching_tokens = _tokens_for_text('this is matching') - target_annotations = [ - TargetAnnotation('this is matching', TAG2, sub_annotations=[ - TargetAnnotation('this', TAG1) - ]) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc) - assert [doc.get_sub_tag(x) for x in matching_tokens] == [TAG1, None, None] - - def test_should_annotate_sub_tag_exactly_matching_with_begin_prefix(self): - matching_tokens = _tokens_for_text('this is matching') - target_annotations = [ - TargetAnnotation('this is matching', TAG2, sub_annotations=[ - TargetAnnotation('this', TAG1) - ]) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) - assert [doc.get_sub_tag(x) for x in matching_tokens] == [B_TAG_1, None, None] - - def test_should_annotate_sub_tag_case_insensitive(self): - matching_tokens = _tokens_for_text('this is 
matching') - target_annotations = [ - TargetAnnotation('this is matching', TAG2, sub_annotations=[ - TargetAnnotation('This', TAG1) - ]) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc) - assert [doc.get_sub_tag(x) for x in matching_tokens] == [TAG1, None, None] - - def test_should_annotate_multiple_sub_tag_exactly_matching(self): - matching_tokens = _tokens_for_text('this is matching') - target_annotations = [ - TargetAnnotation('this is matching', TAG2, sub_annotations=[ - TargetAnnotation('this', TAG1), - TargetAnnotation('is', TAG2) - ]) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc) - assert [doc.get_sub_tag(x) for x in matching_tokens] == [TAG1, TAG2, None] - - def test_should_annotate_multiple_sub_annotations_with_same_sub_tag(self): - matching_tokens = _tokens_for_text('this is matching') - target_annotations = [ - TargetAnnotation('this is matching', TAG2, sub_annotations=[ - TargetAnnotation('this', TAG1), - TargetAnnotation('is', TAG1) - ]) - ] - doc = _document_for_tokens([matching_tokens]) - MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc) - assert [doc.get_sub_tag(x) for x in matching_tokens] == [TAG1, TAG1, None] - - def test_should_annotate_same_sub_annotations_multiple_times_with_begin_prefic(self): - matching_tokens_by_line = [ - _tokens_for_text('this is matching'), - _tokens_for_text('this is matching') - ] - matching_tokens = flatten(matching_tokens_by_line) - target_annotations = [ - TargetAnnotation('this is matching', TAG2, match_multiple=True, sub_annotations=[ - TargetAnnotation('this is', TAG1) - ]) - ] - doc = _document_for_tokens(matching_tokens_by_line) - MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) - assert ( - [doc.get_sub_tag(x) for x in matching_tokens] == - [B_TAG_1, I_TAG_1, None, B_TAG_1, 
I_TAG_1, None] - ) - - def test_should_annotate_sub_tag_across_multiple_tokens(self): - sub_matching_tokens = _tokens_for_text('this is matching') - tag_matching_tokens = ( - _tokens_for_text('something before') + - sub_matching_tokens + - _tokens_for_text('more text to come') - ) - all_tokens = ( - _tokens_for_text('not matching') + tag_matching_tokens + _tokens_for_text('and there') - ) - target_annotations = [ - TargetAnnotation(_tokens_to_text(tag_matching_tokens), TAG2, sub_annotations=[ - TargetAnnotation(_tokens_to_text(sub_matching_tokens), TAG1) - ]) - ] - doc = _document_for_tokens([all_tokens]) - MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc) - assert [doc.get_sub_tag(x) for x in sub_matching_tokens] == [TAG1, TAG1, TAG1] - - def test_should_annotate_sub_tag_across_multiple_tokens_with_junk_characters(self): - junk_char = ',' - sub_matching_tokens = _tokens_for_text('this is matching' + junk_char) - tag_matching_tokens = ( - _tokens_for_text('something before') + - sub_matching_tokens + - _tokens_for_text('more text to come') - ) - all_tokens = ( - _tokens_for_text('not matching') + tag_matching_tokens + _tokens_for_text('and there') - ) - tag_matching_text = _tokens_to_text(tag_matching_tokens).replace(junk_char, '') - sub_matching_text = _tokens_to_text(sub_matching_tokens).replace(junk_char, '') - target_annotations = [ - TargetAnnotation(tag_matching_text, TAG2, sub_annotations=[ - TargetAnnotation(sub_matching_text, TAG1) - ]) - ] - doc = _document_for_tokens([all_tokens]) - MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc) - assert [doc.get_sub_tag(x) for x in sub_matching_tokens] == [TAG1, TAG1, TAG1] + def test_should_annotate_sub_tag_exactly_matching_without_begin_prefix(self): + matching_tokens = _tokens_for_text('this is matching') + target_annotations = [ + TargetAnnotation('this is matching', TAG2, sub_annotations=[ + TargetAnnotation('this', TAG1) + ]) + ] + doc = 
_document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc) + assert [doc.get_sub_tag(x) for x in matching_tokens] == [TAG1, None, None] + + def test_should_annotate_sub_tag_exactly_matching_with_begin_prefix(self): + matching_tokens = _tokens_for_text('this is matching') + target_annotations = [ + TargetAnnotation('this is matching', TAG2, sub_annotations=[ + TargetAnnotation('this', TAG1) + ]) + ] + doc = _document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) + assert [doc.get_sub_tag(x) for x in matching_tokens] == [B_TAG_1, None, None] + + def test_should_annotate_sub_tag_case_insensitive(self): + matching_tokens = _tokens_for_text('this is matching') + target_annotations = [ + TargetAnnotation('this is matching', TAG2, sub_annotations=[ + TargetAnnotation('This', TAG1) + ]) + ] + doc = _document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc) + assert [doc.get_sub_tag(x) for x in matching_tokens] == [TAG1, None, None] + + def test_should_annotate_multiple_sub_tag_exactly_matching(self): + matching_tokens = _tokens_for_text('this is matching') + target_annotations = [ + TargetAnnotation('this is matching', TAG2, sub_annotations=[ + TargetAnnotation('this', TAG1), + TargetAnnotation('is', TAG2) + ]) + ] + doc = _document_for_tokens([matching_tokens]) + MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc) + assert [doc.get_sub_tag(x) for x in matching_tokens] == [TAG1, TAG2, None] + + def test_should_annotate_multiple_sub_annotations_with_same_sub_tag(self): + matching_tokens = _tokens_for_text('this is matching') + target_annotations = [ + TargetAnnotation('this is matching', TAG2, sub_annotations=[ + TargetAnnotation('this', TAG1), + TargetAnnotation('is', TAG1) + ]) + ] + doc = _document_for_tokens([matching_tokens]) + 
MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc) + assert [doc.get_sub_tag(x) for x in matching_tokens] == [TAG1, TAG1, None] + + def test_should_annotate_same_sub_annotations_multiple_times_with_begin_prefic(self): + matching_tokens_by_line = [ + _tokens_for_text('this is matching'), + _tokens_for_text('this is matching') + ] + matching_tokens = flatten(matching_tokens_by_line) + target_annotations = [ + TargetAnnotation('this is matching', TAG2, match_multiple=True, sub_annotations=[ + TargetAnnotation('this is', TAG1) + ]) + ] + doc = _document_for_tokens(matching_tokens_by_line) + MatchingAnnotator(target_annotations, use_tag_begin_prefix=True).annotate(doc) + assert ( + [doc.get_sub_tag(x) for x in matching_tokens] == + [B_TAG_1, I_TAG_1, None, B_TAG_1, I_TAG_1, None] + ) + + def test_should_annotate_sub_tag_across_multiple_tokens(self): + sub_matching_tokens = _tokens_for_text('this is matching') + tag_matching_tokens = ( + _tokens_for_text('something before') + + sub_matching_tokens + + _tokens_for_text('more text to come') + ) + all_tokens = ( + _tokens_for_text('not matching') + tag_matching_tokens + _tokens_for_text('and there') + ) + target_annotations = [ + TargetAnnotation(_tokens_to_text(tag_matching_tokens), TAG2, sub_annotations=[ + TargetAnnotation(_tokens_to_text(sub_matching_tokens), TAG1) + ]) + ] + doc = _document_for_tokens([all_tokens]) + MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc) + assert [doc.get_sub_tag(x) for x in sub_matching_tokens] == [TAG1, TAG1, TAG1] + + def test_should_annotate_sub_tag_across_multiple_tokens_with_junk_characters(self): + junk_char = ',' + sub_matching_tokens = _tokens_for_text('this is matching' + junk_char) + tag_matching_tokens = ( + _tokens_for_text('something before') + + sub_matching_tokens + + _tokens_for_text('more text to come') + ) + all_tokens = ( + _tokens_for_text('not matching') + tag_matching_tokens + _tokens_for_text('and there') + ) 
+ tag_matching_text = _tokens_to_text(tag_matching_tokens).replace(junk_char, '') + sub_matching_text = _tokens_to_text(sub_matching_tokens).replace(junk_char, '') + target_annotations = [ + TargetAnnotation(tag_matching_text, TAG2, sub_annotations=[ + TargetAnnotation(sub_matching_text, TAG1) + ]) + ] + doc = _document_for_tokens([all_tokens]) + MatchingAnnotator(target_annotations, use_tag_begin_prefix=False).annotate(doc) + assert [doc.get_sub_tag(x) for x in sub_matching_tokens] == [TAG1, TAG1, TAG1] diff --git a/sciencebeam_gym/preprocess/annotation/target_annotation.py b/sciencebeam_gym/preprocess/annotation/target_annotation.py index b1ff43d..82f2b2c 100644 --- a/sciencebeam_gym/preprocess/annotation/target_annotation.py +++ b/sciencebeam_gym/preprocess/annotation/target_annotation.py @@ -1,411 +1,448 @@ +from __future__ import absolute_import + import logging import json import re from itertools import chain +from six.moves.configparser import ConfigParser import six -from six.moves.configparser import ConfigParser # pylint: disable=E0401 from lxml import etree from sciencebeam_utils.utils.compat import ( - python_2_unicode_compatible + python_2_unicode_compatible ) from sciencebeam_utils.utils.string import ( - LazyStr + LazyStr ) from sciencebeam_utils.utils.xml import ( - get_text_content, - get_immediate_text + get_text_content, + get_immediate_text ) from sciencebeam_utils.utils.collection import ( - filter_truthy, - strip_all + filter_truthy, + strip_all ) + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + class XmlMappingSuffix(object): - REGEX = '.regex' - EXTRACT_REGEX = '.extract-regex' - MATCH_MULTIPLE = '.match-multiple' - BONDING = '.bonding' - REQUIRE_NEXT = '.require-next' - CHILDREN = '.children' - CHILDREN_CONCAT = '.children.concat' - CHILDREN_RANGE = '.children.range' - UNMATCHED_PARENT_TEXT = '.unmatched-parent-text' - PRIORITY = '.priority' - SUB = '.sub' + REGEX = '.regex' + EXTRACT_REGEX = 
'.extract-regex' + MATCH_MULTIPLE = '.match-multiple' + BONDING = '.bonding' + REQUIRE_NEXT = '.require-next' + CHILDREN = '.children' + CHILDREN_CONCAT = '.children.concat' + CHILDREN_RANGE = '.children.range' + UNMATCHED_PARENT_TEXT = '.unmatched-parent-text' + PRIORITY = '.priority' + SUB = '.sub' + @python_2_unicode_compatible class TargetAnnotation(object): - def __init__( - self, value, name, - match_multiple=False, bonding=False, require_next=False, - sub_annotations=None): - - self.value = value - self.name = name - self.match_multiple = match_multiple - self.bonding = bonding - self.require_next = require_next - self.sub_annotations = sub_annotations - - def __repr__(self): - return u'TA {} (match_multiple={}, bonding={}, require_next={}): {}'.format( - self.name, self.match_multiple, self.bonding, self.require_next, - self.value - ) + def __init__( + self, value, name, + match_multiple=False, bonding=False, require_next=False, + sub_annotations=None): + + self.value = value + self.name = name + self.match_multiple = match_multiple + self.bonding = bonding + self.require_next = require_next + self.sub_annotations = sub_annotations + + def __repr__(self): + return u'TA {} (match_multiple={}, bonding={}, require_next={}): {}'.format( + self.name, self.match_multiple, self.bonding, self.require_next, + self.value + ) + def parse_xml_mapping(xml_mapping_filename): - with open(xml_mapping_filename, 'r') as f: - config = ConfigParser() - if six.PY3: - config.read_file(f) - else: - config.readfp(f) - return { - k: dict(config.items(k)) - for k in config.sections() - } + with open(xml_mapping_filename, 'r') as f: + config = ConfigParser() + if six.PY3: + config.read_file(f) # pylint: disable=no-member + else: + config.readfp(f) + return { + k: dict(config.items(k)) + for k in config.sections() + } + def relace_recursive(s, old, new): - previous = None - while s != previous: - previous = s - s = s.replace(old, new) - return s + previous = None + while s != 
previous: + previous = s + s = s.replace(old, new) + return s + def strip_whitespace(s): - replacements = [ - ('\t', ' '), - (' ', ' '), - ('\r', '\n'), - (' \n', '\n'), - ('\n ', '\n'), - ('\n\n', '\n') - ] - for old, new in replacements: - s = relace_recursive(s, old, new) - return s + replacements = [ + ('\t', ' '), + (' ', ' '), + ('\r', '\n'), + (' \n', '\n'), + ('\n ', '\n'), + ('\n\n', '\n') + ] + for old, new in replacements: + s = relace_recursive(s, old, new) + return s + def get_stripped_text_content(node, **kwargs): - return strip_whitespace(get_text_content(node, **kwargs).strip()) + return strip_whitespace(get_text_content(node, **kwargs).strip()) + def get_stripped_text_content_list(nodes, **kwargs): - return [get_stripped_text_content(node, **kwargs) for node in nodes] + return [get_stripped_text_content(node, **kwargs) for node in nodes] + def iter_flatten_if_nested(a): - for x in a: - if isinstance(x, list): - for y in iter_flatten_if_nested(x): - yield y - else: - yield x + for x in a: + if isinstance(x, list): + for y in iter_flatten_if_nested(x): + yield y + else: + yield x + def flatten_if_nested(a): - if not a: - return a - return list(iter_flatten_if_nested(a)) + if not a: + return a + return list(iter_flatten_if_nested(a)) + def apply_pattern(s, compiled_pattern): - m = compiled_pattern.match(s) - if m: - get_logger().debug('regex match: %s -> %s', compiled_pattern, m.groups()) - return m.group(1) - return s + m = compiled_pattern.match(s) + if m: + get_logger().debug('regex match: %s -> %s', compiled_pattern, m.groups()) + return m.group(1) + return s + def iter_parents(children): - for child in children: - p = child.getparent() - if p is not None: - yield p + for child in children: + p = child.getparent() + if p is not None: + yield p + def exclude_parents(children): - if not isinstance(children, list): - children = list(children) - all_parents = set(iter_parents(children)) - return [child for child in children if not child in 
all_parents] + if not isinstance(children, list): + children = list(children) + all_parents = set(iter_parents(children)) + return [child for child in children if child not in all_parents] + def extract_children_source_list(parent, children_source_list): - used_nodes = set() - values = [] - for children_source in children_source_list: - xpath = children_source.get('xpath') - if xpath: - matching_nodes = exclude_parents(parent.xpath(xpath)) - if not matching_nodes: - get_logger().debug( - 'child xpath does not match any item, skipping: xpath=%s (xml=%s)', - xpath, - LazyStr(lambda: str(etree.tostring(parent))) - ) - used_nodes = set() - values = [] - break - used_nodes |= set(matching_nodes) - value = ' '.join(get_stripped_text_content_list(matching_nodes)) - else: - value = children_source.get('value') - values.append(value or '') - return values, used_nodes + used_nodes = set() + values = [] + for children_source in children_source_list: + xpath = children_source.get('xpath') + if xpath: + matching_nodes = exclude_parents(parent.xpath(xpath)) + if not matching_nodes: + get_logger().debug( + 'child xpath does not match any item, skipping: xpath=%s (xml=%s)', + xpath, + LazyStr(lambda: str(etree.tostring(parent))) + ) + used_nodes = set() + values = [] + break + used_nodes |= set(matching_nodes) + value = ' '.join(get_stripped_text_content_list(matching_nodes)) + else: + value = children_source.get('value') + values.append(value or '') + return values, used_nodes + def extract_children_concat(parent, children_concat): - used_nodes = set() - values = [] - get_logger().debug('children_concat: %s', children_concat) - for children_concat_item in children_concat: - temp_values, temp_used_nodes = extract_children_source_list( - parent, children_concat_item - ) - used_nodes |= temp_used_nodes - if temp_values: - values.append(''.join(temp_values)) - return values, used_nodes + used_nodes = set() + values = [] + get_logger().debug('children_concat: %s', children_concat) + 
for children_concat_item in children_concat: + temp_values, temp_used_nodes = extract_children_source_list( + parent, children_concat_item + ) + used_nodes |= temp_used_nodes + if temp_values: + values.append(''.join(temp_values)) + return values, used_nodes + def extract_children_range(parent, children_range): - used_nodes = set() - values = [] - standalone_values = [] - get_logger().debug('children_range: %s', children_range) - for range_item in children_range: - temp_values, temp_used_nodes = extract_children_source_list( - parent, [range_item.get('min'), range_item.get('max')] - ) - if len(temp_values) == 2: - temp_values = strip_all(temp_values) - if all(s.isdigit() for s in temp_values): - num_values = [int(s) for s in temp_values] - range_values = [str(x) for x in range(num_values[0], num_values[1] + 1)] - if range_item.get('standalone'): - standalone_values.extend(range_values) - else: - values.extend(range_values) - used_nodes |= temp_used_nodes - else: - get_logger().info('values not integers: %s', temp_values) - return values, standalone_values, used_nodes + used_nodes = set() + values = [] + standalone_values = [] + get_logger().debug('children_range: %s', children_range) + for range_item in children_range: + temp_values, temp_used_nodes = extract_children_source_list( + parent, [range_item.get('min'), range_item.get('max')] + ) + if len(temp_values) == 2: + temp_values = strip_all(temp_values) + if all(s.isdigit() for s in temp_values): + num_values = [int(s) for s in temp_values] + range_values = [str(x) for x in range(num_values[0], num_values[1] + 1)] + if range_item.get('standalone'): + standalone_values.extend(range_values) + else: + values.extend(range_values) + used_nodes |= temp_used_nodes + else: + get_logger().info('values not integers: %s', temp_values) + return values, standalone_values, used_nodes + def parse_xpaths(s): - return strip_all(s.strip().split('\n')) if s else None + return strip_all(s.strip().split('\n')) if s else None + def 
match_xpaths(parent, xpaths): - return chain(*[parent.xpath(s) for s in xpaths]) + return chain(*[parent.xpath(s) for s in xpaths]) + def extract_children( - parent, children_xpaths, children_concat, children_range, unmatched_parent_text): - - concat_values_list, concat_used_nodes = extract_children_concat(parent, children_concat) - range_values_list, standalone_values, range_used_nodes = ( - extract_children_range(parent, children_range) - ) - used_nodes = concat_used_nodes | range_used_nodes - - other_child_nodes = [ - node for node in match_xpaths(parent, children_xpaths) - if not node in used_nodes - ] - other_child_nodes_excl_parents = exclude_parents(other_child_nodes) - text_content_list = filter_truthy(strip_all( - get_stripped_text_content_list(other_child_nodes_excl_parents) + - concat_values_list + range_values_list - )) - if len(other_child_nodes_excl_parents) != len(other_child_nodes): - other_child_nodes_excl_parents_set = set(other_child_nodes_excl_parents) - for child in other_child_nodes: - if child not in other_child_nodes_excl_parents_set: - text_values = filter_truthy(strip_all(get_immediate_text(child))) - text_content_list.extend(text_values) - if unmatched_parent_text: - value = get_stripped_text_content( - parent, - exclude=set(other_child_nodes) | used_nodes - ).strip() - if value and not value in text_content_list: - text_content_list.append(value) - return text_content_list, standalone_values + parent, children_xpaths, children_concat, children_range, unmatched_parent_text): + + concat_values_list, concat_used_nodes = extract_children_concat(parent, children_concat) + range_values_list, standalone_values, range_used_nodes = ( + extract_children_range(parent, children_range) + ) + used_nodes = concat_used_nodes | range_used_nodes + + other_child_nodes = [ + node for node in match_xpaths(parent, children_xpaths) + if node not in used_nodes + ] + other_child_nodes_excl_parents = exclude_parents(other_child_nodes) + text_content_list = 
filter_truthy(strip_all( + get_stripped_text_content_list(other_child_nodes_excl_parents) + + concat_values_list + range_values_list + )) + if len(other_child_nodes_excl_parents) != len(other_child_nodes): + other_child_nodes_excl_parents_set = set(other_child_nodes_excl_parents) + for child in other_child_nodes: + if child not in other_child_nodes_excl_parents_set: + text_values = filter_truthy(strip_all(get_immediate_text(child))) + text_content_list.extend(text_values) + if unmatched_parent_text: + value = get_stripped_text_content( + parent, + exclude=set(other_child_nodes) | used_nodes + ).strip() + if value and value not in text_content_list: + text_content_list.append(value) + return text_content_list, standalone_values + def parse_json_with_default(s, default_value): - return json.loads(s) if s else default_value + return json.loads(s) if s else default_value + def get_prefixed_dict_values(d, key_prefix): - return { - k[len(key_prefix):]: v - for k, v in d.items() - if k.startswith(key_prefix) - } + return { + k[len(key_prefix):]: v + for k, v in d.items() + if k.startswith(key_prefix) + } + def get_sub_mapping(mapping, tag): - return { - k: v - for k, v in get_prefixed_dict_values(mapping, tag + XmlMappingSuffix.SUB + '.').items() - if '.' not in k - } + return { + k: v + for k, v in get_prefixed_dict_values(mapping, tag + XmlMappingSuffix.SUB + '.').items() + if '.' 
not in k + } + def re_compile_or_none(pattern): - return re.compile(pattern) if pattern else None + return re.compile(pattern) if pattern else None + def extract_using_regex(s, compiled_pattern): - result = None - start = 0 - for m in compiled_pattern.finditer(s): - m_start = m.start(1) - m_end = m.end(1) - m_text = m.group(1) - get_logger().debug('extract match: %d:%d, %s', m_start, m_end, m_text) + result = None + start = 0 + for m in compiled_pattern.finditer(s): + m_start = m.start(1) + m_end = m.end(1) + m_text = m.group(1) + get_logger().debug('extract match: %d:%d, %s', m_start, m_end, m_text) + if result is None: + result = [] + if start < m_start: + result.append(s[start:m_start].strip()) + result.append(m_text) + start = m_end + 1 if result is None: - result = [] - if start < m_start: - result.append(s[start:m_start].strip()) - result.append(m_text) - start = m_end + 1 - if result is None: - return s - if start < len(s): - result.append(s[start:].strip()) - if len(result) == 1: - return result[0] - # also include the full string - result.append(s) - return result + return s + if start < len(s): + result.append(s[start:].strip()) + if len(result) == 1: + return result[0] + # also include the full string + result.append(s) + return result + def extract_sub_annotations(parent_node, sub_xpaths, mapping, parent_key): - if not sub_xpaths: - return - sub_annotations = [] - for sub_tag, sub_xpath in sub_xpaths.items(): - sub_key_prefix = parent_key + XmlMappingSuffix.SUB + '.' + sub_tag - extract_re_compiled_pattern = re_compile_or_none( - mapping.get(sub_key_prefix + XmlMappingSuffix.EXTRACT_REGEX) - ) - get_logger().debug('sub_key_prefix: %s', sub_key_prefix) - get_logger().debug( - 'extract_re_compiled_pattern (%s, %s): %s', - parent_key, sub_tag, extract_re_compiled_pattern - ) + if not sub_xpaths: + return None + sub_annotations = [] + for sub_tag, sub_xpath in sub_xpaths.items(): + sub_key_prefix = parent_key + XmlMappingSuffix.SUB + '.' 
+ sub_tag + extract_re_compiled_pattern = re_compile_or_none( + mapping.get(sub_key_prefix + XmlMappingSuffix.EXTRACT_REGEX) + ) + get_logger().debug('sub_key_prefix: %s', sub_key_prefix) + get_logger().debug( + 'extract_re_compiled_pattern (%s, %s): %s', + parent_key, sub_tag, extract_re_compiled_pattern + ) + + for e in match_xpaths(parent_node, [sub_xpath]): + value = get_stripped_text_content(e) + if value: + value = strip_whitespace(value).strip() + if extract_re_compiled_pattern is not None and value: + value = extract_using_regex(value, extract_re_compiled_pattern) + if value: + sub_annotations.append(TargetAnnotation(value, sub_tag)) + return sub_annotations - for e in match_xpaths(parent_node, [sub_xpath]): - value = get_stripped_text_content(e) - if value: - value = strip_whitespace(value).strip() - if extract_re_compiled_pattern is not None and value: - value = extract_using_regex(value, extract_re_compiled_pattern) - if value: - sub_annotations.append(TargetAnnotation(value, sub_tag)) - return sub_annotations def xml_root_to_target_annotations(xml_root, xml_mapping): - if not xml_root.tag in xml_mapping: - raise Exception("unrecognised tag: {} (available: {})".format( - xml_root.tag, xml_mapping.sections()) - ) + if xml_root.tag not in xml_mapping: + raise Exception("unrecognised tag: {} (available: {})".format( + xml_root.tag, xml_mapping.sections()) + ) - mapping = xml_mapping[xml_root.tag] - - field_names = [k for k in mapping.keys() if '.' 
not in k] - get_mapping_flag = lambda k, suffix: mapping.get(k + suffix) == 'true' - get_match_multiple = lambda k: get_mapping_flag(k, XmlMappingSuffix.MATCH_MULTIPLE) - get_bonding_flag = lambda k: get_mapping_flag(k, XmlMappingSuffix.BONDING) - get_require_next_flag = lambda k: get_mapping_flag(k, XmlMappingSuffix.REQUIRE_NEXT) - get_unmatched_parent_text_flag = ( - lambda k: get_mapping_flag(k, XmlMappingSuffix.UNMATCHED_PARENT_TEXT) - ) - - get_logger().debug('fields: %s', field_names) - - target_annotations_with_pos = [] - xml_pos_by_node = {node: i for i, node in enumerate(xml_root.iter())} - for k in field_names: - match_multiple = get_match_multiple(k) - bonding = get_bonding_flag(k) - require_next = get_require_next_flag(k) - unmatched_parent_text = get_unmatched_parent_text_flag(k) - children_xpaths = parse_xpaths(mapping.get(k + XmlMappingSuffix.CHILDREN)) - children_concat = parse_json_with_default( - mapping.get(k + XmlMappingSuffix.CHILDREN_CONCAT), [] - ) - children_range = parse_json_with_default( - mapping.get(k + XmlMappingSuffix.CHILDREN_RANGE), [] - ) - re_compiled_pattern = re_compile_or_none( - mapping.get(k + XmlMappingSuffix.REGEX) - ) - extract_re_compiled_pattern = re_compile_or_none( - mapping.get(k + XmlMappingSuffix.EXTRACT_REGEX) - ) - get_logger().debug('extract_re_compiled_pattern (%s): %s', k, extract_re_compiled_pattern) + mapping = xml_mapping[xml_root.tag] + + field_names = [k for k in mapping.keys() if '.' 
not in k] + + def get_mapping_flag(k, suffix): + return mapping.get(k + suffix) == 'true' + + def get_match_multiple(k): + return get_mapping_flag(k, XmlMappingSuffix.MATCH_MULTIPLE) - priority = int(mapping.get(k + XmlMappingSuffix.PRIORITY, '0')) - sub_xpaths = get_sub_mapping(mapping, k) - get_logger().debug('sub_xpaths (%s): %s', k, sub_xpaths) + def get_bonding_flag(k): + return get_mapping_flag(k, XmlMappingSuffix.BONDING) - xpaths = parse_xpaths(mapping[k]) - get_logger().debug('xpaths(%s): %s', k, xpaths) - for e in match_xpaths(xml_root, xpaths): - e_pos = xml_pos_by_node.get(e) + def get_require_next_flag(k): + return get_mapping_flag(k, XmlMappingSuffix.REQUIRE_NEXT) - sub_annotations = extract_sub_annotations(e, sub_xpaths, mapping, k) - get_logger().debug('sub_annotations (%s): %s', k, sub_annotations) + get_unmatched_parent_text_flag = ( + lambda k: get_mapping_flag(k, XmlMappingSuffix.UNMATCHED_PARENT_TEXT) + ) - if children_xpaths: - text_content_list, standalone_values = extract_children( - e, children_xpaths, children_concat, children_range, unmatched_parent_text + get_logger().debug('fields: %s', field_names) + + target_annotations_with_pos = [] + xml_pos_by_node = {node: i for i, node in enumerate(xml_root.iter())} + for k in field_names: + match_multiple = get_match_multiple(k) + bonding = get_bonding_flag(k) + require_next = get_require_next_flag(k) + unmatched_parent_text = get_unmatched_parent_text_flag(k) + children_xpaths = parse_xpaths(mapping.get(k + XmlMappingSuffix.CHILDREN)) + children_concat = parse_json_with_default( + mapping.get(k + XmlMappingSuffix.CHILDREN_CONCAT), [] + ) + children_range = parse_json_with_default( + mapping.get(k + XmlMappingSuffix.CHILDREN_RANGE), [] ) - else: - text_content_list = filter_truthy(strip_all([get_stripped_text_content(e)])) - standalone_values = [] - if re_compiled_pattern: - text_content_list = filter_truthy([ - apply_pattern(s, re_compiled_pattern) for s in text_content_list - ]) - if 
extract_re_compiled_pattern: - text_content_list = filter_truthy([ - extract_using_regex(s, extract_re_compiled_pattern) for s in text_content_list - ]) - text_content_list = flatten_if_nested(text_content_list) - if text_content_list: - value = ( - text_content_list[0] - if len(text_content_list) == 1 - else sorted(text_content_list, key=lambda s: -len(s)) + re_compiled_pattern = re_compile_or_none( + mapping.get(k + XmlMappingSuffix.REGEX) ) - target_annotations_with_pos.append(( - (-priority, e_pos), - TargetAnnotation( - value, - k, - match_multiple=match_multiple, - bonding=bonding, - require_next=require_next, - sub_annotations=sub_annotations - ) - )) - if standalone_values: - for i, standalone_value in enumerate(standalone_values): - target_annotations_with_pos.append(( - (-priority, e_pos, i), - TargetAnnotation( - standalone_value, - k, - match_multiple=match_multiple, - bonding=bonding, - sub_annotations=sub_annotations - ) - )) - target_annotations_with_pos = sorted( - target_annotations_with_pos, - key=lambda x: x[0] - ) - get_logger().debug('target_annotations_with_pos:\n%s', target_annotations_with_pos) - target_annotations = [ - x[1] for x in target_annotations_with_pos - ] - get_logger().debug('target_annotations:\n%s', '\n'.join([ - ' ' + str(a) for a in target_annotations - ])) - return target_annotations + extract_re_compiled_pattern = re_compile_or_none( + mapping.get(k + XmlMappingSuffix.EXTRACT_REGEX) + ) + get_logger().debug('extract_re_compiled_pattern (%s): %s', k, extract_re_compiled_pattern) + + priority = int(mapping.get(k + XmlMappingSuffix.PRIORITY, '0')) + sub_xpaths = get_sub_mapping(mapping, k) + get_logger().debug('sub_xpaths (%s): %s', k, sub_xpaths) + + xpaths = parse_xpaths(mapping[k]) + get_logger().debug('xpaths(%s): %s', k, xpaths) + for e in match_xpaths(xml_root, xpaths): + e_pos = xml_pos_by_node.get(e) + + sub_annotations = extract_sub_annotations(e, sub_xpaths, mapping, k) + get_logger().debug('sub_annotations (%s): 
%s', k, sub_annotations) + + if children_xpaths: + text_content_list, standalone_values = extract_children( + e, children_xpaths, children_concat, children_range, unmatched_parent_text + ) + else: + text_content_list = filter_truthy(strip_all([get_stripped_text_content(e)])) + standalone_values = [] + if re_compiled_pattern: + text_content_list = filter_truthy([ + apply_pattern(s, re_compiled_pattern) for s in text_content_list + ]) + if extract_re_compiled_pattern: + text_content_list = filter_truthy([ + extract_using_regex(s, extract_re_compiled_pattern) for s in text_content_list + ]) + text_content_list = flatten_if_nested(text_content_list) + if text_content_list: + value = ( + text_content_list[0] + if len(text_content_list) == 1 + else sorted(text_content_list, key=lambda s: -len(s)) + ) + target_annotations_with_pos.append(( + (-priority, e_pos), + TargetAnnotation( + value, + k, + match_multiple=match_multiple, + bonding=bonding, + require_next=require_next, + sub_annotations=sub_annotations + ) + )) + if standalone_values: + for i, standalone_value in enumerate(standalone_values): + target_annotations_with_pos.append(( + (-priority, e_pos, i), + TargetAnnotation( + standalone_value, + k, + match_multiple=match_multiple, + bonding=bonding, + sub_annotations=sub_annotations + ) + )) + target_annotations_with_pos = sorted( + target_annotations_with_pos, + key=lambda x: x[0] + ) + get_logger().debug('target_annotations_with_pos:\n%s', target_annotations_with_pos) + target_annotations = [ + x[1] for x in target_annotations_with_pos + ] + get_logger().debug('target_annotations:\n%s', '\n'.join([ + ' ' + str(a) for a in target_annotations + ])) + return target_annotations diff --git a/sciencebeam_gym/preprocess/annotation/target_annotation_test.py b/sciencebeam_gym/preprocess/annotation/target_annotation_test.py index a3fdadf..9d80098 100644 --- a/sciencebeam_gym/preprocess/annotation/target_annotation_test.py +++ 
b/sciencebeam_gym/preprocess/annotation/target_annotation_test.py @@ -5,9 +5,9 @@ import json from lxml.builder import E from sciencebeam_gym.preprocess.annotation.target_annotation import ( - strip_whitespace, - xml_root_to_target_annotations, - XmlMappingSuffix + strip_whitespace, + xml_root_to_target_annotations, + XmlMappingSuffix ) TAG1 = 'tag1' @@ -18,603 +18,608 @@ SOME_VALUE_2 = 'some value2' SOME_LONGER_VALUE = 'some longer value1' SOME_SHORTER_VALUE = 'value1' + class TestStripWhitespace(object): - def test_should_replace_tab_with_space(self): - assert strip_whitespace(SOME_VALUE + '\t' + SOME_VALUE_2) == SOME_VALUE + ' ' + SOME_VALUE_2 + def test_should_replace_tab_with_space(self): + assert strip_whitespace(SOME_VALUE + '\t' + SOME_VALUE_2) == SOME_VALUE + ' ' + SOME_VALUE_2 + + def test_should_strip_double_space(self): + assert strip_whitespace(SOME_VALUE + ' ' + SOME_VALUE_2) == SOME_VALUE + ' ' + SOME_VALUE_2 - def test_should_strip_double_space(self): - assert strip_whitespace(SOME_VALUE + ' ' + SOME_VALUE_2) == SOME_VALUE + ' ' + SOME_VALUE_2 + def test_should_strip_double_line_feed(self): + assert strip_whitespace(SOME_VALUE + '\n\n' + + SOME_VALUE_2) == SOME_VALUE + '\n' + SOME_VALUE_2 - def test_should_strip_double_line_feed(self): - assert strip_whitespace(SOME_VALUE + '\n\n' + SOME_VALUE_2) == SOME_VALUE + '\n' + SOME_VALUE_2 + def test_should_replace_cr_with_line_feed(self): + assert strip_whitespace( + SOME_VALUE + '\r' + SOME_VALUE_2) == SOME_VALUE + '\n' + SOME_VALUE_2 - def test_should_replace_cr_with_line_feed(self): - assert strip_whitespace(SOME_VALUE + '\r' + SOME_VALUE_2) == SOME_VALUE + '\n' + SOME_VALUE_2 + def test_should_strip_spaces_around_line_feed(self): + assert strip_whitespace(SOME_VALUE + ' \n ' + + SOME_VALUE_2) == SOME_VALUE + '\n' + SOME_VALUE_2 - def test_should_strip_spaces_around_line_feed(self): - assert strip_whitespace(SOME_VALUE + ' \n ' + SOME_VALUE_2) == SOME_VALUE + '\n' + SOME_VALUE_2 + def 
test_should_strip_multiple_lines_with_blanks(self): + assert ( + strip_whitespace(SOME_VALUE + ' \n \n \n ' + SOME_VALUE_2) == + SOME_VALUE + '\n' + SOME_VALUE_2 + ) - def test_should_strip_multiple_lines_with_blanks(self): - assert ( - strip_whitespace(SOME_VALUE + ' \n \n \n ' + SOME_VALUE_2) == - SOME_VALUE + '\n' + SOME_VALUE_2 - ) class TestXmlRootToTargetAnnotations(object): - def test_should_return_empty_target_annotations_for_empty_xml(self): - xml_root = E.article( - ) - xml_mapping = { - 'article': { - 'title': 'title' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert target_annotations == [] - - def test_should_return_empty_target_annotations_for_no_matching_annotations(self): - xml_root = E.article( - E.other(SOME_VALUE) - ) - xml_mapping = { - 'article': { - TAG1: 'title' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert target_annotations == [] - - def test_should_return_matching_target_annotations(self): - xml_root = E.article( - E.title(SOME_VALUE) - ) - xml_mapping = { - 'article': { - TAG1: 'title' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert len(target_annotations) == 1 - assert target_annotations[0].name == TAG1 - assert target_annotations[0].value == SOME_VALUE - - def test_should_strip_extra_space(self): - xml_root = E.article( - E.abstract(SOME_VALUE + ' ' + SOME_VALUE_2) - ) - xml_mapping = { - 'article': { - TAG1: 'abstract' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert len(target_annotations) == 1 - assert target_annotations[0].name == TAG1 - assert target_annotations[0].value == SOME_VALUE + ' ' + SOME_VALUE_2 - - def test_should_apply_regex_to_result(self): - xml_root = E.article( - E.title('1.1. 
' + SOME_VALUE) - ) - xml_mapping = { - 'article': { - TAG1: 'title', - TAG1 + XmlMappingSuffix.REGEX: r'(?:\d+\.?)* ?(.*)' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert len(target_annotations) == 1 - assert target_annotations[0].name == TAG1 - assert target_annotations[0].value == SOME_VALUE - - def test_should_apply_match_multiple_flag(self): - xml_root = E.article( - E.title(SOME_VALUE) - ) - xml_mapping = { - 'article': { - TAG1: 'title', - TAG1 + XmlMappingSuffix.MATCH_MULTIPLE: 'true' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [t.match_multiple for t in target_annotations] == [True] - - def test_should_not_apply_match_multiple_flag_if_not_set(self): - xml_root = E.article( - E.title(SOME_VALUE) - ) - xml_mapping = { - 'article': { - TAG1: 'title' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [t.match_multiple for t in target_annotations] == [False] - - def test_should_apply_match_bonding_flag(self): - xml_root = E.article( - E.title(SOME_VALUE) - ) - xml_mapping = { - 'article': { - TAG1: 'title', - TAG1 + XmlMappingSuffix.BONDING: 'true' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [t.bonding for t in target_annotations] == [True] - - def test_should_not_apply_match_bonding_flag_if_not_set(self): - xml_root = E.article( - E.title(SOME_VALUE) - ) - xml_mapping = { - 'article': { - TAG1: 'title' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [t.bonding for t in target_annotations] == [False] - - def test_should_apply_match_require_next_flag(self): - xml_root = E.article( - E.title(SOME_VALUE) - ) - xml_mapping = { - 'article': { - TAG1: 'title', - TAG1 + XmlMappingSuffix.REQUIRE_NEXT: 'true' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [t.require_next for t in 
target_annotations] == [True] - - def test_should_not_apply_match_require_next_flag_if_not_set(self): - xml_root = E.article( - E.title(SOME_VALUE) - ) - xml_mapping = { - 'article': { - TAG1: 'title' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [t.require_next for t in target_annotations] == [False] - - def test_should_use_multiple_xpaths(self): - xml_root = E.article( - E.entry( - E.child1(SOME_VALUE), - E.child2(SOME_VALUE_2) - ) - ) - xml_mapping = { - 'article': { - TAG1: '\n{}\n{}\n'.format( - 'entry/child1', - 'entry/child2' + def test_should_return_empty_target_annotations_for_empty_xml(self): + xml_root = E.article( + ) + xml_mapping = { + 'article': { + 'title': 'title' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert target_annotations == [] + + def test_should_return_empty_target_annotations_for_no_matching_annotations(self): + xml_root = E.article( + E.other(SOME_VALUE) + ) + xml_mapping = { + 'article': { + TAG1: 'title' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert target_annotations == [] + + def test_should_return_matching_target_annotations(self): + xml_root = E.article( + E.title(SOME_VALUE) + ) + xml_mapping = { + 'article': { + TAG1: 'title' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert len(target_annotations) == 1 + assert target_annotations[0].name == TAG1 + assert target_annotations[0].value == SOME_VALUE + + def test_should_strip_extra_space(self): + xml_root = E.article( + E.abstract(SOME_VALUE + ' ' + SOME_VALUE_2) + ) + xml_mapping = { + 'article': { + TAG1: 'abstract' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert len(target_annotations) == 1 + assert target_annotations[0].name == TAG1 + assert target_annotations[0].value == SOME_VALUE + ' ' + SOME_VALUE_2 + + def test_should_apply_regex_to_result(self): 
+ xml_root = E.article( + E.title('1.1. ' + SOME_VALUE) + ) + xml_mapping = { + 'article': { + TAG1: 'title', + TAG1 + XmlMappingSuffix.REGEX: r'(?:\d+\.?)* ?(.*)' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert len(target_annotations) == 1 + assert target_annotations[0].name == TAG1 + assert target_annotations[0].value == SOME_VALUE + + def test_should_apply_match_multiple_flag(self): + xml_root = E.article( + E.title(SOME_VALUE) + ) + xml_mapping = { + 'article': { + TAG1: 'title', + TAG1 + XmlMappingSuffix.MATCH_MULTIPLE: 'true' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [t.match_multiple for t in target_annotations] == [True] + + def test_should_not_apply_match_multiple_flag_if_not_set(self): + xml_root = E.article( + E.title(SOME_VALUE) + ) + xml_mapping = { + 'article': { + TAG1: 'title' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [t.match_multiple for t in target_annotations] == [False] + + def test_should_apply_match_bonding_flag(self): + xml_root = E.article( + E.title(SOME_VALUE) + ) + xml_mapping = { + 'article': { + TAG1: 'title', + TAG1 + XmlMappingSuffix.BONDING: 'true' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [t.bonding for t in target_annotations] == [True] + + def test_should_not_apply_match_bonding_flag_if_not_set(self): + xml_root = E.article( + E.title(SOME_VALUE) + ) + xml_mapping = { + 'article': { + TAG1: 'title' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [t.bonding for t in target_annotations] == [False] + + def test_should_apply_match_require_next_flag(self): + xml_root = E.article( + E.title(SOME_VALUE) + ) + xml_mapping = { + 'article': { + TAG1: 'title', + TAG1 + XmlMappingSuffix.REQUIRE_NEXT: 'true' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + 
assert [t.require_next for t in target_annotations] == [True] + + def test_should_not_apply_match_require_next_flag_if_not_set(self): + xml_root = E.article( + E.title(SOME_VALUE) + ) + xml_mapping = { + 'article': { + TAG1: 'title' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [t.require_next for t in target_annotations] == [False] + + def test_should_use_multiple_xpaths(self): + xml_root = E.article( + E.entry( + E.child1(SOME_VALUE), + E.child2(SOME_VALUE_2) + ) + ) + xml_mapping = { + 'article': { + TAG1: '\n{}\n{}\n'.format( + 'entry/child1', + 'entry/child2' + ) + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [(t.name, t.value) for t in target_annotations] == [ + (TAG1, SOME_VALUE), + (TAG1, SOME_VALUE_2) + ] + + def test_should_apply_children_xpaths_and_sort_by_value_descending(self): + xml_root = E.article( + E.entry( + E.child1(SOME_SHORTER_VALUE), + E.child2(SOME_LONGER_VALUE) + ), + E.entry( + E.child1(SOME_LONGER_VALUE) + ) + ) + xml_mapping = { + 'article': { + TAG1: 'entry', + TAG1 + XmlMappingSuffix.CHILDREN: './/*' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [(t.name, t.value) for t in target_annotations] == [ + (TAG1, [SOME_LONGER_VALUE, SOME_SHORTER_VALUE]), + (TAG1, SOME_LONGER_VALUE) + ] + + def test_should_apply_children_xpaths_and_exclude_parents(self): + xml_root = E.article( + E.entry( + E.parent( + E.child2(SOME_LONGER_VALUE), + E.child1(SOME_SHORTER_VALUE) + ) + ) + ) + xml_mapping = { + 'article': { + TAG1: 'entry', + TAG1 + XmlMappingSuffix.CHILDREN: './/*' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [(t.name, t.value) for t in target_annotations] == [ + (TAG1, [SOME_LONGER_VALUE, SOME_SHORTER_VALUE]) + ] + + def test_should_apply_children_xpaths_and_include_parent_text_between_matched_children(self): + xml_root = E.article( + E.entry( + E.parent( 
+ E.child2(SOME_LONGER_VALUE), + SOME_VALUE, + E.child1(SOME_SHORTER_VALUE) + ) + ) + ) + xml_mapping = { + 'article': { + TAG1: 'entry', + TAG1 + XmlMappingSuffix.CHILDREN: './/*' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [(t.name, t.value) for t in target_annotations] == [ + (TAG1, [SOME_LONGER_VALUE, SOME_VALUE, SOME_SHORTER_VALUE]) + ] + + def test_should_apply_multiple_children_xpaths_and_include_parent_text_if_enabled(self): + xml_root = E.article( + E.entry( + E.child1(SOME_SHORTER_VALUE), + SOME_LONGER_VALUE + ) + ) + xml_mapping = { + 'article': { + TAG1: 'entry', + TAG1 + XmlMappingSuffix.CHILDREN: '\n{}\n{}\n'.format('.//*', '.'), + TAG1 + XmlMappingSuffix.UNMATCHED_PARENT_TEXT: 'true' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [(t.name, t.value) for t in target_annotations] == [ + (TAG1, [SOME_LONGER_VALUE, SOME_SHORTER_VALUE]) + ] + + def test_should_apply_concat_children(self): + num_values = ['101', '202'] + xml_root = E.article( + E.entry( + E.parent( + E.child1(SOME_VALUE), + E.fpage(num_values[0]), + E.lpage(num_values[1]) + ) + ) + ) + xml_mapping = { + 'article': { + TAG1: 'entry', + TAG1 + XmlMappingSuffix.CHILDREN: './/*', + TAG1 + XmlMappingSuffix.CHILDREN_CONCAT: json.dumps([[{ + 'xpath': './/fpage' + }, { + 'value': '-' + }, { + 'xpath': './/lpage' + }]]) + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [(t.name, t.value) for t in target_annotations] == [ + (TAG1, [SOME_VALUE, '-'.join(num_values)]) + ] + + def test_should_not_apply_concat_children_if_one_node_was_not_found(self): + num_values = ['101', '202'] + xml_root = E.article( + E.entry( + E.parent( + E.child1(SOME_VALUE), + E.fpage(num_values[0]), + E.lpage(num_values[1]) + ) + ) + ) + xml_mapping = { + 'article': { + TAG1: 'entry', + TAG1 + XmlMappingSuffix.CHILDREN: './/*', + TAG1 + XmlMappingSuffix.CHILDREN_CONCAT: json.dumps([[{ + 
'xpath': './/fpage' + }, { + 'value': '-' + }, { + 'xpath': './/unknown' + }]]) + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [(t.name, t.value) for t in target_annotations] == [ + (TAG1, [SOME_VALUE, num_values[0], num_values[1]]) + ] + + def test_should_apply_range_children(self): + num_values = [101, 102, 103, 104, 105, 106, 107] + xml_root = E.article( + E.entry( + E.child1(SOME_VALUE), + E.fpage(str(min(num_values))), + E.lpage(str(max(num_values))) + ) + ) + xml_mapping = { + 'article': { + TAG1: 'entry', + TAG1 + XmlMappingSuffix.CHILDREN: 'fpage|lpage', + TAG1 + XmlMappingSuffix.CHILDREN_RANGE: json.dumps([{ + 'min': { + 'xpath': 'fpage' + }, + 'max': { + 'xpath': 'lpage' + } + }]) + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [(t.name, t.value) for t in target_annotations] == [ + (TAG1, [str(x) for x in num_values]) + ] + + def test_should_apply_range_children_as_separate_target_annotations(self): + num_values = [101, 102, 103, 104, 105, 106, 107] + xml_root = E.article( + E.entry( + E.child1(SOME_VALUE), + E.fpage(str(min(num_values))), + E.lpage(str(max(num_values))) + ) + ) + xml_mapping = { + 'article': { + TAG1: 'entry', + TAG1 + XmlMappingSuffix.CHILDREN: 'fpage|lpage', + TAG1 + XmlMappingSuffix.CHILDREN_RANGE: json.dumps([{ + 'min': { + 'xpath': 'fpage' + }, + 'max': { + 'xpath': 'lpage' + }, + 'standalone': True + }]) + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [(t.name, t.value) for t in target_annotations] == [ + (TAG1, str(x)) + for x in num_values + ] + + def test_should_not_apply_range_children_if_xpath_not_matching(self): + num_values = [101, 102, 103, 104, 105, 106, 107] + fpage = str(min(num_values)) + lpage = str(max(num_values)) + xml_root = E.article( + E.entry( + E.child1(SOME_VALUE), + E.fpage(fpage), + E.lpage(lpage) + ) + ) + xml_mapping = { + 'article': { + TAG1: 'entry', + TAG1 + 
XmlMappingSuffix.CHILDREN: 'fpage|unknown', + TAG1 + XmlMappingSuffix.CHILDREN_RANGE: json.dumps([{ + 'min': { + 'xpath': 'fpage' + }, + 'max': { + 'xpath': 'unknown' + } + }]) + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [(t.name, t.value) for t in target_annotations] == [ + (TAG1, fpage) + ] + + def test_should_not_apply_range_children_if_value_is_not_integer(self): + fpage = 'abc' + lpage = 'xyz' + xml_root = E.article( + E.entry( + E.child1(SOME_VALUE), + E.fpage(fpage), + E.lpage(lpage) + ) + ) + xml_mapping = { + 'article': { + TAG1: 'entry', + TAG1 + XmlMappingSuffix.CHILDREN: 'fpage|lpage', + TAG1 + XmlMappingSuffix.CHILDREN_RANGE: json.dumps([{ + 'min': { + 'xpath': 'fpage' + }, + 'max': { + 'xpath': 'lpage' + } + }]) + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [(t.name, t.value) for t in target_annotations] == [ + (TAG1, [fpage, lpage]) + ] + + def test_should_add_sub_annotations(self): + xml_root = E.article( + E.entry( + E.firstname(SOME_VALUE), + E.givennames(SOME_VALUE_2) + ) ) - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [(t.name, t.value) for t in target_annotations] == [ - (TAG1, SOME_VALUE), - (TAG1, SOME_VALUE_2) - ] - - def test_should_apply_children_xpaths_and_sort_by_value_descending(self): - xml_root = E.article( - E.entry( - E.child1(SOME_SHORTER_VALUE), - E.child2(SOME_LONGER_VALUE) - ), - E.entry( - E.child1(SOME_LONGER_VALUE) - ) - ) - xml_mapping = { - 'article': { - TAG1: 'entry', - TAG1 + XmlMappingSuffix.CHILDREN: './/*' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [(t.name, t.value) for t in target_annotations] == [ - (TAG1, [SOME_LONGER_VALUE, SOME_SHORTER_VALUE]), - (TAG1, SOME_LONGER_VALUE) - ] - - def test_should_apply_children_xpaths_and_exclude_parents(self): - xml_root = E.article( - E.entry( - E.parent( - 
E.child2(SOME_LONGER_VALUE), - E.child1(SOME_SHORTER_VALUE) + xml_mapping = { + 'article': { + TAG1: 'entry', + TAG1 + XmlMappingSuffix.SUB + '.firstname': './firstname', + TAG1 + XmlMappingSuffix.SUB + '.givennames': './givennames', + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [(t.name, t.value) for t in target_annotations[0].sub_annotations] == [ + ('firstname', SOME_VALUE), + ('givennames', SOME_VALUE_2) + ] + + def test_should_add_sub_annotations_with_multiple_values(self): + xml_root = E.article( + E.entry( + E.value(SOME_VALUE), + E.value(SOME_VALUE_2) + ) ) - ) - ) - xml_mapping = { - 'article': { - TAG1: 'entry', - TAG1 + XmlMappingSuffix.CHILDREN: './/*' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [(t.name, t.value) for t in target_annotations] == [ - (TAG1, [SOME_LONGER_VALUE, SOME_SHORTER_VALUE]) - ] - - def test_should_apply_children_xpaths_and_include_parent_text_between_matched_children(self): - xml_root = E.article( - E.entry( - E.parent( - E.child2(SOME_LONGER_VALUE), - SOME_VALUE, - E.child1(SOME_SHORTER_VALUE) + xml_mapping = { + 'article': { + TAG1: 'entry', + TAG1 + XmlMappingSuffix.SUB + '.value': './value' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [(t.name, t.value) for t in target_annotations[0].sub_annotations] == [ + ('value', SOME_VALUE), + ('value', SOME_VALUE_2) + ] + + def test_should_extract_numbers_from_value_after_text(self): + xml_root = E.article(E.entry( + E.value(SOME_VALUE + ' 12345') + )) + xml_mapping = { + 'article': { + TAG1: 'entry', + TAG1 + XmlMappingSuffix.EXTRACT_REGEX: r'.*\b(\d+)\b.*' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert len(target_annotations) == 1 + assert [(t.name, set(t.value)) for t in target_annotations] == [ + (TAG1, {SOME_VALUE + ' 12345', SOME_VALUE, '12345'}) + ] + + def 
test_should_extract_single_value_if_its_the_only_value(self): + xml_root = E.article(E.entry( + E.value('12345') + )) + xml_mapping = { + 'article': { + TAG1: 'entry', + TAG1 + XmlMappingSuffix.EXTRACT_REGEX: r'.*\b(\d+)\b.*' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert len(target_annotations) == 1 + assert [(t.name, t.value) for t in target_annotations] == [ + (TAG1, '12345') + ] + + def test_should_unnest_extract_value_from_children(self): + xml_root = E.article(E.entry( + E.value(SOME_VALUE + ' 12345'), + E.value(SOME_VALUE_2 + ' 54321') + )) + xml_mapping = { + 'article': { + TAG1: 'entry', + TAG1 + XmlMappingSuffix.CHILDREN: r'.//*', + TAG1 + XmlMappingSuffix.EXTRACT_REGEX: r'.*\b(\d+)\b.*' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert len(target_annotations) == 1 + assert [(t.name, set(t.value)) for t in target_annotations] == [ + (TAG1, { + SOME_VALUE + ' 12345', SOME_VALUE, '12345', + SOME_VALUE_2 + ' 54321', SOME_VALUE_2, '54321' + }) + ] + + def test_should_extract_numbers_from_sub_value_after_text(self): + xml_root = E.article(E.entry( + E.value(SOME_VALUE + ' 12345') + )) + sub_key = TAG1 + XmlMappingSuffix.SUB + '.value' + xml_mapping = { + 'article': { + TAG1: 'entry', + sub_key: './value', + sub_key + XmlMappingSuffix.EXTRACT_REGEX: r'.*\b(\d+)\b.*' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert len(target_annotations) == 1 + assert [(t.name, set(t.value)) for t in target_annotations[0].sub_annotations] == [ + ('value', {SOME_VALUE + ' 12345', SOME_VALUE, '12345'}) + ] + + def test_should_return_full_text(self): + xml_root = E.article( + E.title( + 'some ', + E.other('embedded'), + ' text' + ) ) - ) - ) - xml_mapping = { - 'article': { - TAG1: 'entry', - TAG1 + XmlMappingSuffix.CHILDREN: './/*' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [(t.name, t.value) for 
t in target_annotations] == [ - (TAG1, [SOME_LONGER_VALUE, SOME_VALUE, SOME_SHORTER_VALUE]) - ] - - def test_should_apply_multiple_children_xpaths_and_include_parent_text_if_enabled(self): - xml_root = E.article( - E.entry( - E.child1(SOME_SHORTER_VALUE), - SOME_LONGER_VALUE - ) - ) - xml_mapping = { - 'article': { - TAG1: 'entry', - TAG1 + XmlMappingSuffix.CHILDREN: '\n{}\n{}\n'.format('.//*', '.'), - TAG1 + XmlMappingSuffix.UNMATCHED_PARENT_TEXT: 'true' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [(t.name, t.value) for t in target_annotations] == [ - (TAG1, [SOME_LONGER_VALUE, SOME_SHORTER_VALUE]) - ] - - def test_should_apply_concat_children(self): - num_values = ['101', '202'] - xml_root = E.article( - E.entry( - E.parent( - E.child1(SOME_VALUE), - E.fpage(num_values[0]), - E.lpage(num_values[1]) + xml_mapping = { + 'article': { + TAG1: 'title' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert len(target_annotations) == 1 + assert target_annotations[0].name == TAG1 + assert target_annotations[0].value == 'some embedded text' + + def test_should_return_target_annotations_in_order_of_xml(self): + xml_root = E.article( + E.tag1('tag1.1'), E.tag2('tag2.1'), E.tag1('tag1.2'), E.tag2('tag2.2'), ) - ) - ) - xml_mapping = { - 'article': { - TAG1: 'entry', - TAG1 + XmlMappingSuffix.CHILDREN: './/*', - TAG1 + XmlMappingSuffix.CHILDREN_CONCAT: json.dumps([[{ - 'xpath': './/fpage' - }, { - 'value': '-' - }, { - 'xpath': './/lpage' - }]]) - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [(t.name, t.value) for t in target_annotations] == [ - (TAG1, [SOME_VALUE, '-'.join(num_values)]) - ] - - def test_should_not_apply_concat_children_if_one_node_was_not_found(self): - num_values = ['101', '202'] - xml_root = E.article( - E.entry( - E.parent( - E.child1(SOME_VALUE), - E.fpage(num_values[0]), - E.lpage(num_values[1]) + xml_mapping = { + 
'article': { + TAG1: 'tag1', + TAG2: 'tag2' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [(ta.name, ta.value) for ta in target_annotations] == [ + (TAG1, 'tag1.1'), (TAG2, 'tag2.1'), (TAG1, 'tag1.2'), (TAG2, 'tag2.2') + ] + + def test_should_return_target_annotations_in_order_of_priority_first(self): + xml_root = E.article( + E.tag1('tag1.1'), E.tag2('tag2.1'), E.tag1('tag1.2'), E.tag2('tag2.2'), ) - ) - ) - xml_mapping = { - 'article': { - TAG1: 'entry', - TAG1 + XmlMappingSuffix.CHILDREN: './/*', - TAG1 + XmlMappingSuffix.CHILDREN_CONCAT: json.dumps([[{ - 'xpath': './/fpage' - }, { - 'value': '-' - }, { - 'xpath': './/unknown' - }]]) - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [(t.name, t.value) for t in target_annotations] == [ - (TAG1, [SOME_VALUE, num_values[0], num_values[1]]) - ] - - def test_should_apply_range_children(self): - num_values = [101, 102, 103, 104, 105, 106, 107] - xml_root = E.article( - E.entry( - E.child1(SOME_VALUE), - E.fpage(str(min(num_values))), - E.lpage(str(max(num_values))) - ) - ) - xml_mapping = { - 'article': { - TAG1: 'entry', - TAG1 + XmlMappingSuffix.CHILDREN: 'fpage|lpage', - TAG1 + XmlMappingSuffix.CHILDREN_RANGE: json.dumps([{ - 'min': { - 'xpath': 'fpage' - }, - 'max': { - 'xpath': 'lpage' - } - }]) - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [(t.name, t.value) for t in target_annotations] == [ - (TAG1, [str(x) for x in num_values]) - ] - - def test_should_apply_range_children_as_separate_target_annotations(self): - num_values = [101, 102, 103, 104, 105, 106, 107] - xml_root = E.article( - E.entry( - E.child1(SOME_VALUE), - E.fpage(str(min(num_values))), - E.lpage(str(max(num_values))) - ) - ) - xml_mapping = { - 'article': { - TAG1: 'entry', - TAG1 + XmlMappingSuffix.CHILDREN: 'fpage|lpage', - TAG1 + XmlMappingSuffix.CHILDREN_RANGE: json.dumps([{ - 'min': { - 'xpath': 'fpage' 
- }, - 'max': { - 'xpath': 'lpage' - }, - 'standalone': True - }]) - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [(t.name, t.value) for t in target_annotations] == [ - (TAG1, str(x)) - for x in num_values - ] - - def test_should_not_apply_range_children_if_xpath_not_matching(self): - num_values = [101, 102, 103, 104, 105, 106, 107] - fpage = str(min(num_values)) - lpage = str(max(num_values)) - xml_root = E.article( - E.entry( - E.child1(SOME_VALUE), - E.fpage(fpage), - E.lpage(lpage) - ) - ) - xml_mapping = { - 'article': { - TAG1: 'entry', - TAG1 + XmlMappingSuffix.CHILDREN: 'fpage|unknown', - TAG1 + XmlMappingSuffix.CHILDREN_RANGE: json.dumps([{ - 'min': { - 'xpath': 'fpage' - }, - 'max': { - 'xpath': 'unknown' - } - }]) - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [(t.name, t.value) for t in target_annotations] == [ - (TAG1, fpage) - ] - - def test_should_not_apply_range_children_if_value_is_not_integer(self): - fpage = 'abc' - lpage = 'xyz' - xml_root = E.article( - E.entry( - E.child1(SOME_VALUE), - E.fpage(fpage), - E.lpage(lpage) - ) - ) - xml_mapping = { - 'article': { - TAG1: 'entry', - TAG1 + XmlMappingSuffix.CHILDREN: 'fpage|lpage', - TAG1 + XmlMappingSuffix.CHILDREN_RANGE: json.dumps([{ - 'min': { - 'xpath': 'fpage' - }, - 'max': { - 'xpath': 'lpage' - } - }]) - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [(t.name, t.value) for t in target_annotations] == [ - (TAG1, [fpage, lpage]) - ] - - def test_should_add_sub_annotations(self): - xml_root = E.article( - E.entry( - E.firstname(SOME_VALUE), - E.givennames(SOME_VALUE_2) - ) - ) - xml_mapping = { - 'article': { - TAG1: 'entry', - TAG1 + XmlMappingSuffix.SUB + '.firstname': './firstname', - TAG1 + XmlMappingSuffix.SUB + '.givennames': './givennames', - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [(t.name, t.value) 
for t in target_annotations[0].sub_annotations] == [ - ('firstname', SOME_VALUE), - ('givennames', SOME_VALUE_2) - ] - - def test_should_add_sub_annotations_with_multiple_values(self): - xml_root = E.article( - E.entry( - E.value(SOME_VALUE), - E.value(SOME_VALUE_2) - ) - ) - xml_mapping = { - 'article': { - TAG1: 'entry', - TAG1 + XmlMappingSuffix.SUB + '.value': './value' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [(t.name, t.value) for t in target_annotations[0].sub_annotations] == [ - ('value', SOME_VALUE), - ('value', SOME_VALUE_2) - ] - - def test_should_extract_numbers_from_value_after_text(self): - xml_root = E.article(E.entry( - E.value(SOME_VALUE + ' 12345') - )) - xml_mapping = { - 'article': { - TAG1: 'entry', - TAG1 + XmlMappingSuffix.EXTRACT_REGEX: r'.*\b(\d+)\b.*' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert len(target_annotations) == 1 - assert [(t.name, set(t.value)) for t in target_annotations] == [ - (TAG1, {SOME_VALUE + ' 12345', SOME_VALUE, '12345'}) - ] - - def test_should_extract_single_value_if_its_the_only_value(self): - xml_root = E.article(E.entry( - E.value('12345') - )) - xml_mapping = { - 'article': { - TAG1: 'entry', - TAG1 + XmlMappingSuffix.EXTRACT_REGEX: r'.*\b(\d+)\b.*' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert len(target_annotations) == 1 - assert [(t.name, t.value) for t in target_annotations] == [ - (TAG1, '12345') - ] - - def test_should_unnest_extract_value_from_children(self): - xml_root = E.article(E.entry( - E.value(SOME_VALUE + ' 12345'), - E.value(SOME_VALUE_2 + ' 54321') - )) - xml_mapping = { - 'article': { - TAG1: 'entry', - TAG1 + XmlMappingSuffix.CHILDREN: r'.//*', - TAG1 + XmlMappingSuffix.EXTRACT_REGEX: r'.*\b(\d+)\b.*' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert len(target_annotations) == 1 - assert [(t.name, 
set(t.value)) for t in target_annotations] == [ - (TAG1, { - SOME_VALUE + ' 12345', SOME_VALUE, '12345', - SOME_VALUE_2 + ' 54321', SOME_VALUE_2, '54321' - }) - ] - - def test_should_extract_numbers_from_sub_value_after_text(self): - xml_root = E.article(E.entry( - E.value(SOME_VALUE + ' 12345') - )) - sub_key = TAG1 + XmlMappingSuffix.SUB + '.value' - xml_mapping = { - 'article': { - TAG1: 'entry', - sub_key: './value', - sub_key + XmlMappingSuffix.EXTRACT_REGEX: r'.*\b(\d+)\b.*' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert len(target_annotations) == 1 - assert [(t.name, set(t.value)) for t in target_annotations[0].sub_annotations] == [ - ('value', {SOME_VALUE + ' 12345', SOME_VALUE, '12345'}) - ] - - def test_should_return_full_text(self): - xml_root = E.article( - E.title( - 'some ', - E.other('embedded'), - ' text' - ) - ) - xml_mapping = { - 'article': { - TAG1: 'title' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert len(target_annotations) == 1 - assert target_annotations[0].name == TAG1 - assert target_annotations[0].value == 'some embedded text' - - def test_should_return_target_annotations_in_order_of_xml(self): - xml_root = E.article( - E.tag1('tag1.1'), E.tag2('tag2.1'), E.tag1('tag1.2'), E.tag2('tag2.2'), - ) - xml_mapping = { - 'article': { - TAG1: 'tag1', - TAG2: 'tag2' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) - assert [(ta.name, ta.value) for ta in target_annotations] == [ - (TAG1, 'tag1.1'), (TAG2, 'tag2.1'), (TAG1, 'tag1.2'), (TAG2, 'tag2.2') - ] - - def test_should_return_target_annotations_in_order_of_priority_first(self): - xml_root = E.article( - E.tag1('tag1.1'), E.tag2('tag2.1'), E.tag1('tag1.2'), E.tag2('tag2.2'), - ) - xml_mapping = { - 'article': { - TAG1: 'tag1', - TAG2: 'tag2', - TAG2 + XmlMappingSuffix.PRIORITY: '1' - } - } - target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) 
- assert [(ta.name, ta.value) for ta in target_annotations] == [ - (TAG2, 'tag2.1'), (TAG2, 'tag2.2'), (TAG1, 'tag1.1'), (TAG1, 'tag1.2') - ] + xml_mapping = { + 'article': { + TAG1: 'tag1', + TAG2: 'tag2', + TAG2 + XmlMappingSuffix.PRIORITY: '1' + } + } + target_annotations = xml_root_to_target_annotations(xml_root, xml_mapping) + assert [(ta.name, ta.value) for ta in target_annotations] == [ + (TAG2, 'tag2.1'), (TAG2, 'tag2.2'), (TAG1, 'tag1.1'), (TAG1, 'tag1.2') + ] diff --git a/sciencebeam_gym/preprocess/blockify_annotations.py b/sciencebeam_gym/preprocess/blockify_annotations.py index 355604d..7f30755 100644 --- a/sciencebeam_gym/preprocess/blockify_annotations.py +++ b/sciencebeam_gym/preprocess/blockify_annotations.py @@ -10,267 +10,289 @@ from pyqtree import Index as PqtreeIndex from PIL import Image, ImageDraw, ImageColor from sciencebeam_gym.structured_document.svg import ( - SVG_NSMAP, - SVG_DOC, - SVG_RECT, + SVG_NSMAP, + SVG_DOC, + SVG_RECT, ) DEFAULT_NEARBY_TOLERANCE = 5 + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + class AnnotationBlock(object): - def __init__(self, tag, bounding_box): - self.tag = tag - self.bounding_box = bounding_box - - def merge_with(self, other): - return AnnotationBlock( - self.tag, - self.bounding_box.include(other.bounding_box) - ) + def __init__(self, tag, bounding_box): + self.tag = tag + self.bounding_box = bounding_box + + def merge_with(self, other): + return AnnotationBlock( + self.tag, + self.bounding_box.include(other.bounding_box) + ) + + def __str__(self): + return 'AnnotationBlock({}, {})'.format(self.tag, self.bounding_box) - def __str__(self): - return 'AnnotationBlock({}, {})'.format(self.tag, self.bounding_box) + def __repr__(self): + return str(self) - def __repr__(self): - return str(self) class BlockPoint(object): - def __init__(self, block, x, y): - self.block = block - self.point = (x, y) + def __init__(self, block, x, y): + self.block = block + self.point 
= (x, y) - def __str__(self): - return 'BlockPoint({}, {})'.format(self.block, self.point) + def __str__(self): + return 'BlockPoint({}, {})'.format(self.block, self.point) - def __repr__(self): - return str(self) + def __repr__(self): + return str(self) - def __len__(self): - return len(self.point) + def __len__(self): + return len(self.point) + + def __getitem__(self, index): + return self.point[index] - def __getitem__(self, index): - return self.point[index] def _to_bbox(bb): - return (bb.x, bb.y, bb.x + bb.width - 1, bb.y + bb.height - 1) + return (bb.x, bb.y, bb.x + bb.width - 1, bb.y + bb.height - 1) + ProcessedWrapper = namedtuple('ProcessedWrapper', field_names=['data', 'deleted']) -class DeletableWrapper(object): - def __init__(self, data): - self.data = data - self.deleted = False - def __hash__(self): - return hash(self.data) +class DeletableWrapper(object): + def __init__(self, data): + self.data = data + self.deleted = False - def __eq__(self, other): - return self.data == other.data + def __hash__(self): + return hash(self.data) -class BlockSearch(object): - def __init__(self, blocks): - bboxs = [block.bounding_box for block in blocks] - xmax = max([bb.x + bb.width for bb in bboxs]) - ymax = max([bb.y + bb.height for bb in bboxs]) - self.spindex = PqtreeIndex(bbox=(0, 0, xmax, ymax)) - self.wrapper_map = {} - for block in blocks: - wrapper = DeletableWrapper(block) - self.wrapper_map[block] = wrapper - self.spindex.insert(wrapper, _to_bbox(block.bounding_box)) + def __eq__(self, other): + return self.data == other.data - def find_intersection_with(self, search_bounding_box): - return [ - wrapper.data - for wrapper in self.spindex.intersect(_to_bbox(search_bounding_box)) - if not wrapper.deleted - ] - def remove(self, block): - wrapper = self.wrapper_map.get(block) - if wrapper is not None: - wrapper.deleted = True +class BlockSearch(object): + def __init__(self, blocks): + bboxs = [block.bounding_box for block in blocks] + xmax = max([bb.x + bb.width 
for bb in bboxs]) + ymax = max([bb.y + bb.height for bb in bboxs]) + self.spindex = PqtreeIndex(bbox=(0, 0, xmax, ymax)) + self.wrapper_map = {} + for block in blocks: + wrapper = DeletableWrapper(block) + self.wrapper_map[block] = wrapper + self.spindex.insert(wrapper, _to_bbox(block.bounding_box)) + + def find_intersection_with(self, search_bounding_box): + return [ + wrapper.data + for wrapper in self.spindex.intersect(_to_bbox(search_bounding_box)) + if not wrapper.deleted + ] + + def remove(self, block): + wrapper = self.wrapper_map.get(block) + if wrapper is not None: + wrapper.deleted = True def merge_blocks(blocks, nearby_tolerance=0): - if len(blocks) <= 1: - return blocks - merged_blocks = deque() - logger = get_logger() - logger.debug('nearby_tolerance: %s', nearby_tolerance) - logger.debug('blocks: %s', blocks) - logger.debug('bboxs: %s', [_to_bbox(block.bounding_box) for block in blocks]) - tags = sorted(set([b.tag for b in blocks])) - logger.debug('tags: %s', tags) - remaining_blocks = deque(blocks) - search_by_tag = { - tag: BlockSearch([b for b in remaining_blocks if b.tag == tag]) - for tag in tags - } - while len(remaining_blocks) >= 2: - merged_block = remaining_blocks.popleft() - search = search_by_tag[merged_block.tag] - search.remove(merged_block) - search_bounding_box = merged_block.bounding_box.with_margin(1 + nearby_tolerance, 0) - logger.debug('search_bounding_box: %s (%s)', search_bounding_box, _to_bbox(search_bounding_box)) - neighbours = search.find_intersection_with(search_bounding_box) - logger.debug('neighbours: %s', neighbours) - neighbours_blocks_count = 0 - for neighbour in neighbours: - if neighbour.tag == merged_block.tag: - merged_block = merged_block.merge_with(neighbour) - search.remove(neighbour) - remaining_blocks.remove(neighbour) - neighbours_blocks_count += 1 - if neighbours_blocks_count == 0 or len(remaining_blocks) == 0: - logger.debug( - 'no or all remaining blocks merged, mark block as merged: %d', - 
neighbours_blocks_count - ) - merged_blocks.append(merged_block) - else: - logger.debug( - 'some but not all remaining blocks merged, continue search: %d', - neighbours_blocks_count - ) - remaining_blocks.appendleft(merged_block) - result = list(merged_blocks) + list(remaining_blocks) - return result + if len(blocks) <= 1: + return blocks + merged_blocks = deque() + logger = get_logger() + logger.debug('nearby_tolerance: %s', nearby_tolerance) + logger.debug('blocks: %s', blocks) + logger.debug('bboxs: %s', [_to_bbox(block.bounding_box) for block in blocks]) + tags = sorted(set([b.tag for b in blocks])) + logger.debug('tags: %s', tags) + remaining_blocks = deque(blocks) + search_by_tag = { + tag: BlockSearch([b for b in remaining_blocks if b.tag == tag]) + for tag in tags + } + while len(remaining_blocks) >= 2: + merged_block = remaining_blocks.popleft() + search = search_by_tag[merged_block.tag] + search.remove(merged_block) + search_bounding_box = merged_block.bounding_box.with_margin(1 + nearby_tolerance, 0) + logger.debug('search_bounding_box: %s (%s)', + search_bounding_box, _to_bbox(search_bounding_box)) + neighbours = search.find_intersection_with(search_bounding_box) + logger.debug('neighbours: %s', neighbours) + neighbours_blocks_count = 0 + for neighbour in neighbours: + if neighbour.tag == merged_block.tag: + merged_block = merged_block.merge_with(neighbour) + search.remove(neighbour) + remaining_blocks.remove(neighbour) + neighbours_blocks_count += 1 + if neighbours_blocks_count == 0 or len(remaining_blocks) == 0: + logger.debug( + 'no or all remaining blocks merged, mark block as merged: %d', + neighbours_blocks_count + ) + merged_blocks.append(merged_block) + else: + logger.debug( + 'some but not all remaining blocks merged, continue search: %d', + neighbours_blocks_count + ) + remaining_blocks.appendleft(merged_block) + result = list(merged_blocks) + list(remaining_blocks) + return result + def expand_bounding_box(bb): - return bb.with_margin(4, 2) + 
return bb.with_margin(4, 2) + def expand_block(block): - return AnnotationBlock(block.tag, expand_bounding_box(block.bounding_box)) + return AnnotationBlock(block.tag, expand_bounding_box(block.bounding_box)) + def expand_blocks(blocks): - return [expand_block(block) for block in blocks] + return [expand_block(block) for block in blocks] + def annotation_document_page_to_annotation_blocks(structured_document, page): - tags_and_tokens = ( - (structured_document.get_tag_value(token), token) - for line in structured_document.get_lines_of_page(page) - for token in structured_document.get_tokens_of_line(line) - ) - tags_and_bounding_boxes = ( - (tag, structured_document.get_bounding_box(token)) - for tag, token in tags_and_tokens - if tag - ) - return [ - AnnotationBlock(tag, bounding_box) - for tag, bounding_box in tags_and_bounding_boxes - if bounding_box - ] + tags_and_tokens = ( + (structured_document.get_tag_value(token), token) + for line in structured_document.get_lines_of_page(page) + for token in structured_document.get_tokens_of_line(line) + ) + tags_and_bounding_boxes = ( + (tag, structured_document.get_bounding_box(token)) + for tag, token in tags_and_tokens + if tag + ) + return [ + AnnotationBlock(tag, bounding_box) + for tag, bounding_box in tags_and_bounding_boxes + if bounding_box + ] + def annotation_document_page_to_merged_blocks(structured_document, page, **kwargs): - return merge_blocks( - annotation_document_page_to_annotation_blocks(structured_document, page), - **kwargs - ) + return merge_blocks( + annotation_document_page_to_annotation_blocks(structured_document, page), + **kwargs + ) + def extend_color_map_for_tags(color_map, tags): - updated_color_map = dict(color_map) - for tag in tags: - if tag not in updated_color_map: - updated_color_map[tag] = ( - max(updated_color_map.values()) + 1 if len(updated_color_map) > 0 else 1 - ) - return updated_color_map + updated_color_map = dict(color_map) + for tag in tags: + if tag not in 
updated_color_map: + updated_color_map[tag] = ( + max(updated_color_map.values()) + 1 if len(updated_color_map) > 0 else 1 + ) + return updated_color_map + def extend_color_map_for_blocks(color_map, blocks): - return extend_color_map_for_tags( - color_map, - sorted(set([b.tag for b in blocks])) - ) + return extend_color_map_for_tags( + color_map, + sorted(set([b.tag for b in blocks])) + ) + class AbstractSurface(object, with_metaclass(ABCMeta)): - @abstractmethod - def rect(self, bounding_box, color, tag=None): - pass - -class SvgSurface(object): - def __init__(self, width, height, background): - if not (width and height): - raise AttributeError('width and height are required') - - self.svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP, attrib={ - 'width': str(width), - 'height': str(height) - }) - - if background: - self.svg_root.append(etree.Element(SVG_RECT, attrib={ - 'width': '100%', - 'height': '100%', - 'fill': background, - 'class': 'background' - })) - - def rect(self, bounding_box, color, tag=None): - attrib = { - 'class': str(tag), - 'shape-rendering': 'crispEdges', - 'x': str(bounding_box.x), - 'y': str(bounding_box.y), - 'width': str(bounding_box.width), - 'height': str(bounding_box.height) - } - if color: - attrib['fill'] = str(color) - rect = etree.Element(SVG_RECT, attrib=attrib) - self.svg_root.append(rect) - return rect + @abstractmethod + def rect(self, bounding_box, color, tag=None): + pass + + +class SvgSurface(AbstractSurface): + def __init__(self, width, height, background): + if not (width and height): + raise AttributeError('width and height are required') + + self.svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP, attrib={ + 'width': str(width), + 'height': str(height) + }) + + if background: + self.svg_root.append(etree.Element(SVG_RECT, attrib={ + 'width': '100%', + 'height': '100%', + 'fill': background, + 'class': 'background' + })) + + def rect(self, bounding_box, color, tag=None): + attrib = { + 'class': str(tag), + 
'shape-rendering': 'crispEdges', + 'x': str(bounding_box.x), + 'y': str(bounding_box.y), + 'width': str(bounding_box.width), + 'height': str(bounding_box.height) + } + if color: + attrib['fill'] = str(color) + rect = etree.Element(SVG_RECT, attrib=attrib) + self.svg_root.append(rect) + return rect + def color_to_tuple(color): - if isinstance(color, tuple): - return color - return ImageColor.getrgb(color) - -class ImageSurface(object): - def __init__(self, width, height, background): - if not (width and height): - raise AttributeError('width and height are required') - - width = int(math.ceil(width)) - height = int(math.ceil(height)) - if background: - self.image = Image.new('RGB', (width, height), color_to_tuple(background)) - else: - self.image = Image.new('RGBA', (width, height), (255, 255, 255, 0)) - self._draw = ImageDraw.Draw(self.image) - - def rect(self, bounding_box, color, tag=None): - if color is None: - return - self._draw.rectangle( - ( - (bounding_box.x, bounding_box.y), - (bounding_box.x + bounding_box.width, bounding_box.y + bounding_box.height) - ), - fill=color_to_tuple(color) - ) + if isinstance(color, tuple): + return color + return ImageColor.getrgb(color) + + +class ImageSurface(AbstractSurface): + def __init__(self, width, height, background): + if not (width and height): + raise AttributeError('width and height are required') + + width = int(math.ceil(width)) + height = int(math.ceil(height)) + if background: + self.image = Image.new('RGB', (width, height), color_to_tuple(background)) + else: + self.image = Image.new('RGBA', (width, height), (255, 255, 255, 0)) + self._draw = ImageDraw.Draw(self.image) + + def rect(self, bounding_box, color, tag=None): + if color is None: + return + self._draw.rectangle( + ( + (bounding_box.x, bounding_box.y), + (bounding_box.x + bounding_box.width, bounding_box.y + bounding_box.height) + ), + fill=color_to_tuple(color) + ) + def annotated_blocks_to_surface(blocks, surface, color_map): - for block in blocks: 
- color = color_map.get(block.tag) - surface.rect(block.bounding_box, color, block.tag) + for block in blocks: + color = color_map.get(block.tag) + surface.rect(block.bounding_box, color, block.tag) + def annotated_blocks_to_svg(blocks, color_map, width=None, height=None, background=None): - surface = SvgSurface(width, height, background) - annotated_blocks_to_surface(blocks, surface, color_map) - return surface.svg_root + surface = SvgSurface(width, height, background) + annotated_blocks_to_surface(blocks, surface, color_map) + return surface.svg_root + def annotated_blocks_to_image( - blocks, color_map, width=None, height=None, background=None, - scale_to_size=None): - - surface = ImageSurface(width, height, background) - annotated_blocks_to_surface(blocks, surface, color_map) - image = surface.image - if scale_to_size: - image = image.resize(scale_to_size, Image.NEAREST) - return image + blocks, color_map, width=None, height=None, background=None, + scale_to_size=None): + + surface = ImageSurface(width, height, background) + annotated_blocks_to_surface(blocks, surface, color_map) + image = surface.image + if scale_to_size: + image = image.resize(scale_to_size, Image.NEAREST) + return image diff --git a/sciencebeam_gym/preprocess/blockify_annotations_test.py b/sciencebeam_gym/preprocess/blockify_annotations_test.py index c00e875..6d23c09 100644 --- a/sciencebeam_gym/preprocess/blockify_annotations_test.py +++ b/sciencebeam_gym/preprocess/blockify_annotations_test.py @@ -1,27 +1,27 @@ import logging from sciencebeam_gym.utils.bounding_box import ( - BoundingBox + BoundingBox ) from sciencebeam_gym.structured_document import ( - SimpleStructuredDocument, - SimpleLine, - SimpleToken, - B_TAG_PREFIX + SimpleStructuredDocument, + SimpleLine, + SimpleToken, + B_TAG_PREFIX ) from sciencebeam_gym.structured_document.svg import ( - SVG_NS + SVG_NS ) from sciencebeam_gym.preprocess.blockify_annotations import ( - annotation_document_page_to_annotation_blocks, - 
annotation_document_page_to_merged_blocks, - annotated_blocks_to_svg, - annotated_blocks_to_image, - merge_blocks, - AnnotationBlock + annotation_document_page_to_annotation_blocks, + annotation_document_page_to_merged_blocks, + annotated_blocks_to_svg, + annotated_blocks_to_image, + merge_blocks, + AnnotationBlock ) TAG1 = 'tag1' @@ -35,245 +35,251 @@ DEFAULT_BOUNDING_BOX = BoundingBox(0, 0, 16, 10) DEFAULT_NEARBY_TOLERANCE = 10 + def setup_module(): - logging.basicConfig(level=logging.DEBUG) + logging.basicConfig(level=logging.DEBUG) + class TestAnnotatedBlocksToSvg(object): - def test_should_create_rect_for_single_annotated_block(self): - blocks = [ - AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) - ] - - result_svg = annotated_blocks_to_svg(blocks, color_map={ - TAG1: DEFAULT_COLOR - }, width=100, height=100) - result_rect_elements = result_svg.xpath('svg:rect', namespaces={'svg': SVG_NS}) - assert len(result_rect_elements) == 1 - assert result_rect_elements[0].attrib['class'] == TAG1 - assert result_rect_elements[0].attrib['fill'] == DEFAULT_COLOR - - def test_should_add_background(self): - blocks = [ - AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) - ] - - result_svg = annotated_blocks_to_svg(blocks, color_map={ - TAG1: '#123' - }, width=100, height=100, background=DEFAULT_COLOR) - background_elements = result_svg.xpath( - 'svg:rect[@class="background"]', namespaces={'svg': SVG_NS} - ) - assert len(background_elements) == 1 - assert background_elements[0].attrib['fill'] == DEFAULT_COLOR + def test_should_create_rect_for_single_annotated_block(self): + blocks = [ + AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) + ] + + result_svg = annotated_blocks_to_svg(blocks, color_map={ + TAG1: DEFAULT_COLOR + }, width=100, height=100) + result_rect_elements = result_svg.xpath('svg:rect', namespaces={'svg': SVG_NS}) + assert len(result_rect_elements) == 1 + assert result_rect_elements[0].attrib['class'] == TAG1 + assert result_rect_elements[0].attrib['fill'] == DEFAULT_COLOR + + 
def test_should_add_background(self): + blocks = [ + AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) + ] + + result_svg = annotated_blocks_to_svg(blocks, color_map={ + TAG1: '#123' + }, width=100, height=100, background=DEFAULT_COLOR) + background_elements = result_svg.xpath( + 'svg:rect[@class="background"]', namespaces={'svg': SVG_NS} + ) + assert len(background_elements) == 1 + assert background_elements[0].attrib['fill'] == DEFAULT_COLOR + class TestAnnotatedBlocksToImage(object): - def test_should_create_rect_for_single_annotated_block(self): - blocks = [ - AnnotationBlock(TAG1, BoundingBox(0, 0, 1, 1)) - ] - - image = annotated_blocks_to_image(blocks, color_map={ - TAG1: (0, 255, 0) - }, width=3, height=3) - assert image.getpixel((0, 0)) == (0, 255, 0, 255) - - def test_should_accept_float_image_size(self): - blocks = [ - AnnotationBlock(TAG1, BoundingBox(0, 0, 1, 1)) - ] - - image = annotated_blocks_to_image(blocks, color_map={ - TAG1: (0, 255, 0) - }, width=3.1, height=3.9) - assert image.size == (4, 4) - - def test_should_convert_rect_color_name(self): - blocks = [ - AnnotationBlock(TAG1, BoundingBox(0, 0, 1, 1)) - ] - - image = annotated_blocks_to_image(blocks, color_map={ - TAG1: 'green' - }, width=3, height=3) - assert image.getpixel((0, 0)) == (0, 128, 0, 255) - - def test_should_ignore_unmapped_tag(self): - blocks = [ - AnnotationBlock(TAG1, BoundingBox(0, 0, 1, 1)) - ] - - image = annotated_blocks_to_image(blocks, color_map={ - }, width=3, height=3) - assert image.getpixel((0, 0)) == (255, 255, 255, 0) - - def test_should_add_background(self): - width = 3 - height = 2 - image = annotated_blocks_to_image([], color_map={ - TAG1: '#123' - }, width=width, height=height, background=(255, 0, 0)) - data = list(image.getdata()) - assert data == [(255, 0, 0)] * (width * height) - - def test_should_convert_background_color_name(self): - width = 3 - height = 2 - image = annotated_blocks_to_image([], color_map={ - TAG1: '#123' - }, width=width, height=height, 
background='red') - data = list(image.getdata()) - assert data == [(255, 0, 0)] * (width * height) + def test_should_create_rect_for_single_annotated_block(self): + blocks = [ + AnnotationBlock(TAG1, BoundingBox(0, 0, 1, 1)) + ] + + image = annotated_blocks_to_image(blocks, color_map={ + TAG1: (0, 255, 0) + }, width=3, height=3) + assert image.getpixel((0, 0)) == (0, 255, 0, 255) + + def test_should_accept_float_image_size(self): + blocks = [ + AnnotationBlock(TAG1, BoundingBox(0, 0, 1, 1)) + ] + + image = annotated_blocks_to_image(blocks, color_map={ + TAG1: (0, 255, 0) + }, width=3.1, height=3.9) + assert image.size == (4, 4) + + def test_should_convert_rect_color_name(self): + blocks = [ + AnnotationBlock(TAG1, BoundingBox(0, 0, 1, 1)) + ] + + image = annotated_blocks_to_image(blocks, color_map={ + TAG1: 'green' + }, width=3, height=3) + assert image.getpixel((0, 0)) == (0, 128, 0, 255) + + def test_should_ignore_unmapped_tag(self): + blocks = [ + AnnotationBlock(TAG1, BoundingBox(0, 0, 1, 1)) + ] + + image = annotated_blocks_to_image(blocks, color_map={ + }, width=3, height=3) + assert image.getpixel((0, 0)) == (255, 255, 255, 0) + + def test_should_add_background(self): + width = 3 + height = 2 + image = annotated_blocks_to_image([], color_map={ + TAG1: '#123' + }, width=width, height=height, background=(255, 0, 0)) + data = list(image.getdata()) + assert data == [(255, 0, 0)] * (width * height) + + def test_should_convert_background_color_name(self): + width = 3 + height = 2 + image = annotated_blocks_to_image([], color_map={ + TAG1: '#123' + }, width=width, height=height, background='red') + data = list(image.getdata()) + assert data == [(255, 0, 0)] * (width * height) + class TestAnnotationDocumentPageToAnnotationBlocks(object): - def test_should_convert_single_token_to_block_with_same_bounding_box(self): - token = SimpleToken('test', tag=TAG1, bounding_box=DEFAULT_BOUNDING_BOX) - structured_document = SimpleStructuredDocument(lines=[SimpleLine([token])]) - 
- blocks = annotation_document_page_to_annotation_blocks( - structured_document, - structured_document.get_pages()[0] - ) - assert len(blocks) == 1 - - block = blocks[0] - assert block.tag == TAG1 - assert block.bounding_box == DEFAULT_BOUNDING_BOX - - def test_should_strip_tag_prefix(self): - token = SimpleToken( - 'test', tag=TAG1, tag_prefix=B_TAG_PREFIX, - bounding_box=DEFAULT_BOUNDING_BOX - ) - assert token.get_tag() == B_TAG_PREFIX + TAG1 - structured_document = SimpleStructuredDocument(lines=[SimpleLine([token])]) - - blocks = annotation_document_page_to_annotation_blocks( - structured_document, - structured_document.get_pages()[0] - ) - assert [b.tag for b in blocks] == [TAG1] - - def test_should_ignore_block_without_bounding_box(self): - token = SimpleToken('test') - structured_document = SimpleStructuredDocument(lines=[SimpleLine([token])]) - structured_document.set_tag(token, TAG1) - - blocks = annotation_document_page_to_annotation_blocks( - structured_document, - structured_document.get_pages()[0] - ) - assert len(blocks) == 0 + def test_should_convert_single_token_to_block_with_same_bounding_box(self): + token = SimpleToken('test', tag=TAG1, bounding_box=DEFAULT_BOUNDING_BOX) + structured_document = SimpleStructuredDocument(lines=[SimpleLine([token])]) + + blocks = annotation_document_page_to_annotation_blocks( + structured_document, + structured_document.get_pages()[0] + ) + assert len(blocks) == 1 + + block = blocks[0] + assert block.tag == TAG1 + assert block.bounding_box == DEFAULT_BOUNDING_BOX + + def test_should_strip_tag_prefix(self): + token = SimpleToken( + 'test', tag=TAG1, tag_prefix=B_TAG_PREFIX, + bounding_box=DEFAULT_BOUNDING_BOX + ) + assert token.get_tag() == B_TAG_PREFIX + TAG1 + structured_document = SimpleStructuredDocument(lines=[SimpleLine([token])]) + + blocks = annotation_document_page_to_annotation_blocks( + structured_document, + structured_document.get_pages()[0] + ) + assert [b.tag for b in blocks] == [TAG1] + + def 
test_should_ignore_block_without_bounding_box(self): + token = SimpleToken('test') + structured_document = SimpleStructuredDocument(lines=[SimpleLine([token])]) + structured_document.set_tag(token, TAG1) + + blocks = annotation_document_page_to_annotation_blocks( + structured_document, + structured_document.get_pages()[0] + ) + assert len(blocks) == 0 + class TestAnnotationDocumentPageToMergedBlocks(object): - def test_should_convert_single_token_to_block_with_same_bounding_box(self): - token = SimpleToken('test', tag=TAG1, bounding_box=DEFAULT_BOUNDING_BOX) - structured_document = SimpleStructuredDocument(lines=[SimpleLine([token])]) + def test_should_convert_single_token_to_block_with_same_bounding_box(self): + token = SimpleToken('test', tag=TAG1, bounding_box=DEFAULT_BOUNDING_BOX) + structured_document = SimpleStructuredDocument(lines=[SimpleLine([token])]) + + blocks = annotation_document_page_to_merged_blocks( + structured_document, + structured_document.get_pages()[0] + ) + assert len(blocks) == 1 - blocks = annotation_document_page_to_merged_blocks( - structured_document, - structured_document.get_pages()[0] - ) - assert len(blocks) == 1 + block = blocks[0] + assert block.tag == TAG1 + assert block.bounding_box == DEFAULT_BOUNDING_BOX - block = blocks[0] - assert block.tag == TAG1 - assert block.bounding_box == DEFAULT_BOUNDING_BOX class TestMergeBlocks(object): - def test_should_return_same_single_blocks(self): - block = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) - - merged_blocks = merge_blocks([block]) - assert merged_blocks == [block] - - def test_should_merge_right_adjacent_block_with_same_tag(self): - block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) - block2 = AnnotationBlock( - TAG1, - block1.bounding_box.move_by(block1.bounding_box.width, 0) - ) - - merged_blocks = merge_blocks([block1, block2]) - assert [b.tag for b in merged_blocks] == [TAG1] - assert merged_blocks[0].bounding_box == block1.bounding_box.include(block2.bounding_box) - - def 
test_should_not_merge_right_adjacent_block_with_same_different_tag(self): - block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) - block2 = AnnotationBlock( - TAG2, - block1.bounding_box.move_by(block1.bounding_box.width, 0) - ) - - merged_blocks = merge_blocks([block1, block2]) - assert [b.tag for b in merged_blocks] == [TAG1, TAG2] - - def test_should_merge_multiple_separate_right_adjacent_blocks(self): - block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) - block2 = AnnotationBlock( - TAG1, - block1.bounding_box.move_by(block1.bounding_box.width, 0) - ) - - block3 = AnnotationBlock( - TAG2, - block1.bounding_box.move_by(block1.bounding_box.width * 2, 0) - ) - block4 = AnnotationBlock( - TAG2, - block3.bounding_box.move_by(block3.bounding_box.width, 0) - ) - - merged_blocks = merge_blocks([block1, block2, block3, block4]) - assert [b.tag for b in merged_blocks] == [TAG1, TAG2] - assert merged_blocks[0].bounding_box == block1.bounding_box.include(block2.bounding_box) - assert merged_blocks[1].bounding_box == block3.bounding_box.include(block4.bounding_box) - - def test_should_merge_multiple_sequential_right_adjacent_blocks(self): - block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) - block2 = AnnotationBlock( - TAG1, - block1.bounding_box.move_by(block1.bounding_box.width, 0) - ) - block3 = AnnotationBlock( - TAG1, - block2.bounding_box.move_by(block2.bounding_box.width, 0) - ) - - merged_blocks = merge_blocks([block1, block2, block3]) - assert [b.tag for b in merged_blocks] == [TAG1] - - assert merged_blocks[0].bounding_box == ( - block1.bounding_box.include(block2.bounding_box).include(block3.bounding_box) - ) - - def test_should_merge_right_nearby_block_with_same_tag(self): - block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) - block2 = AnnotationBlock( - TAG1, - block1.bounding_box.move_by(block1.bounding_box.width + DEFAULT_NEARBY_TOLERANCE, 0) - ) - - merged_blocks = merge_blocks([block1, block2], nearby_tolerance=DEFAULT_NEARBY_TOLERANCE) - assert 
[b.tag for b in merged_blocks] == [TAG1] - assert merged_blocks[0].bounding_box == block1.bounding_box.include(block2.bounding_box) - - def test_should_not_merge_too_far_away_block_with_same_tag(self): - block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) - block2 = AnnotationBlock( - TAG1, - block1.bounding_box.move_by(block1.bounding_box.width + DEFAULT_NEARBY_TOLERANCE + 1, 0) - ) - - merged_blocks = merge_blocks([block1, block2], nearby_tolerance=DEFAULT_NEARBY_TOLERANCE) - assert [b.tag for b in merged_blocks] == [TAG1, TAG1] - - def test_should_merge_right_nearby_block_with_same_tag_using_fractions(self): - block1 = AnnotationBlock(TAG1, BoundingBox(10.5, 10.5, 10.9, 26.3)) - block2 = AnnotationBlock( - TAG1, - block1.bounding_box.move_by(block1.bounding_box.width + DEFAULT_NEARBY_TOLERANCE, 0) - ) - - merged_blocks = merge_blocks([block1, block2], nearby_tolerance=DEFAULT_NEARBY_TOLERANCE) - assert [b.tag for b in merged_blocks] == [TAG1] - assert merged_blocks[0].bounding_box == block1.bounding_box.include(block2.bounding_box) + def test_should_return_same_single_blocks(self): + block = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) + + merged_blocks = merge_blocks([block]) + assert merged_blocks == [block] + + def test_should_merge_right_adjacent_block_with_same_tag(self): + block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) + block2 = AnnotationBlock( + TAG1, + block1.bounding_box.move_by(block1.bounding_box.width, 0) + ) + + merged_blocks = merge_blocks([block1, block2]) + assert [b.tag for b in merged_blocks] == [TAG1] + assert merged_blocks[0].bounding_box == block1.bounding_box.include(block2.bounding_box) + + def test_should_not_merge_right_adjacent_block_with_same_different_tag(self): + block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) + block2 = AnnotationBlock( + TAG2, + block1.bounding_box.move_by(block1.bounding_box.width, 0) + ) + + merged_blocks = merge_blocks([block1, block2]) + assert [b.tag for b in merged_blocks] == [TAG1, TAG2] + + def 
test_should_merge_multiple_separate_right_adjacent_blocks(self): + block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) + block2 = AnnotationBlock( + TAG1, + block1.bounding_box.move_by(block1.bounding_box.width, 0) + ) + + block3 = AnnotationBlock( + TAG2, + block1.bounding_box.move_by(block1.bounding_box.width * 2, 0) + ) + block4 = AnnotationBlock( + TAG2, + block3.bounding_box.move_by(block3.bounding_box.width, 0) + ) + + merged_blocks = merge_blocks([block1, block2, block3, block4]) + assert [b.tag for b in merged_blocks] == [TAG1, TAG2] + assert merged_blocks[0].bounding_box == block1.bounding_box.include(block2.bounding_box) + assert merged_blocks[1].bounding_box == block3.bounding_box.include(block4.bounding_box) + + def test_should_merge_multiple_sequential_right_adjacent_blocks(self): + block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) + block2 = AnnotationBlock( + TAG1, + block1.bounding_box.move_by(block1.bounding_box.width, 0) + ) + block3 = AnnotationBlock( + TAG1, + block2.bounding_box.move_by(block2.bounding_box.width, 0) + ) + + merged_blocks = merge_blocks([block1, block2, block3]) + assert [b.tag for b in merged_blocks] == [TAG1] + + assert merged_blocks[0].bounding_box == ( + block1.bounding_box.include(block2.bounding_box).include(block3.bounding_box) + ) + + def test_should_merge_right_nearby_block_with_same_tag(self): + block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) + block2 = AnnotationBlock( + TAG1, + block1.bounding_box.move_by(block1.bounding_box.width + DEFAULT_NEARBY_TOLERANCE, 0) + ) + + merged_blocks = merge_blocks([block1, block2], nearby_tolerance=DEFAULT_NEARBY_TOLERANCE) + assert [b.tag for b in merged_blocks] == [TAG1] + assert merged_blocks[0].bounding_box == block1.bounding_box.include(block2.bounding_box) + + def test_should_not_merge_too_far_away_block_with_same_tag(self): + block1 = AnnotationBlock(TAG1, DEFAULT_BOUNDING_BOX) + block2 = AnnotationBlock( + TAG1, + block1.bounding_box.move_by(block1.bounding_box.width 
+ DEFAULT_NEARBY_TOLERANCE + 1, 0) + ) + + merged_blocks = merge_blocks([block1, block2], nearby_tolerance=DEFAULT_NEARBY_TOLERANCE) + assert [b.tag for b in merged_blocks] == [TAG1, TAG1] + + def test_should_merge_right_nearby_block_with_same_tag_using_fractions(self): + block1 = AnnotationBlock(TAG1, BoundingBox(10.5, 10.5, 10.9, 26.3)) + block2 = AnnotationBlock( + TAG1, + block1.bounding_box.move_by(block1.bounding_box.width + DEFAULT_NEARBY_TOLERANCE, 0) + ) + + merged_blocks = merge_blocks([block1, block2], nearby_tolerance=DEFAULT_NEARBY_TOLERANCE) + assert [b.tag for b in merged_blocks] == [TAG1] + assert merged_blocks[0].bounding_box == block1.bounding_box.include(block2.bounding_box) diff --git a/sciencebeam_gym/preprocess/color_map.py b/sciencebeam_gym/preprocess/color_map.py index ce8ba2e..ecd7271 100644 --- a/sciencebeam_gym/preprocess/color_map.py +++ b/sciencebeam_gym/preprocess/color_map.py @@ -1,37 +1,40 @@ from __future__ import absolute_import import re -from six import string_types from six.moves.configparser import ConfigParser +from six import string_types + def parse_color_map_from_configparser(color_map_config): - num_pattern = re.compile(r'(\d+)') - rgb_pattern = re.compile(r'\((\d+),\s*(\d+),\s*(\d+)\)') - - def parse_color(s): - m = num_pattern.match(s) - if m: - x = int(m.group(1)) - return (x, x, x) - else: - m = rgb_pattern.match(s) - if m: - return (int(m.group(1)), int(m.group(2)), int(m.group(3))) - raise Exception('invalid color value: {}'.format(s)) + num_pattern = re.compile(r'(\d+)') + rgb_pattern = re.compile(r'\((\d+),\s*(\d+),\s*(\d+)\)') + + def parse_color(s): + m = num_pattern.match(s) + if m: + x = int(m.group(1)) + return (x, x, x) + else: + m = rgb_pattern.match(s) + if m: + return (int(m.group(1)), int(m.group(2)), int(m.group(3))) + raise Exception('invalid color value: {}'.format(s)) + + color_map = dict() + for k, v in color_map_config.items('color_map'): + color_map[k] = parse_color(v) + return color_map - 
color_map = dict() - for k, v in color_map_config.items('color_map'): - color_map[k] = parse_color(v) - return color_map def parse_color_map_from_file(f): - color_map_config = ConfigParser() - if isinstance(f, string_types): - with open(f, 'r') as fp: - color_map_config.readfp(fp) - else: - color_map_config.readfp(f) - return parse_color_map_from_configparser(color_map_config) + color_map_config = ConfigParser() + if isinstance(f, string_types): + with open(f, 'r') as fp: + color_map_config.readfp(fp) + else: + color_map_config.readfp(f) + return parse_color_map_from_configparser(color_map_config) + def parse_color_map(f): - return parse_color_map_from_file(f) + return parse_color_map_from_file(f) diff --git a/sciencebeam_gym/preprocess/color_map_test.py b/sciencebeam_gym/preprocess/color_map_test.py index 5cc3207..1bad14a 100644 --- a/sciencebeam_gym/preprocess/color_map_test.py +++ b/sciencebeam_gym/preprocess/color_map_test.py @@ -1,16 +1,17 @@ from six import BytesIO from sciencebeam_gym.preprocess.color_map import ( - parse_color_map_from_file + parse_color_map_from_file ) + class TestParseColorMapFromFile(object): - def test_should_parse_rgb_color_values(self): - data=b'\n'.join([ - b'[color_map]', - b'tag1 = (255, 0, 0)' - ]) - color_map = parse_color_map_from_file(BytesIO(data)) - assert color_map == { - 'tag1': (255, 0, 0) - } + def test_should_parse_rgb_color_values(self): + data = b'\n'.join([ + b'[color_map]', + b'tag1 = (255, 0, 0)' + ]) + color_map = parse_color_map_from_file(BytesIO(data)) + assert color_map == { + 'tag1': (255, 0, 0) + } diff --git a/sciencebeam_gym/preprocess/lxml_to_svg.py b/sciencebeam_gym/preprocess/lxml_to_svg.py index 8186739..887695b 100644 --- a/sciencebeam_gym/preprocess/lxml_to_svg.py +++ b/sciencebeam_gym/preprocess/lxml_to_svg.py @@ -5,232 +5,240 @@ import os from lxml import etree from sciencebeam_utils.utils.csv import ( - open_csv_output, - write_dict_csv + open_csv_output, + write_dict_csv ) from 
sciencebeam_gym.utils.bounding_box import ( - BoundingBox + BoundingBox ) from sciencebeam_gym.preprocess.annotation.annotator import ( - Annotator, - DEFAULT_ANNOTATORS + Annotator, + DEFAULT_ANNOTATORS ) from sciencebeam_gym.preprocess.annotation.matching_annotator import ( - MatchingAnnotator, - CsvMatchDetailReporter + MatchingAnnotator, + CsvMatchDetailReporter ) from sciencebeam_gym.preprocess.annotation.target_annotation import ( - parse_xml_mapping, - xml_root_to_target_annotations + parse_xml_mapping, + xml_root_to_target_annotations ) from sciencebeam_gym.preprocess.annotation.annotation_evaluation import ( - evaluate_document_by_page, - DEFAULT_EVALUATION_COLUMNS, - to_csv_dict_rows as to_annotation_evaluation_csv_dict_rows + evaluate_document_by_page, + DEFAULT_EVALUATION_COLUMNS, + to_csv_dict_rows as to_annotation_evaluation_csv_dict_rows ) from sciencebeam_gym.structured_document.svg import ( - SVG_TEXT, - SVG_G, - SVG_RECT, - SVG_DOC, - SVG_NSMAP, - SVGE_BOUNDING_BOX, - SvgStructuredDocument, - SvgStyleClasses, - format_bounding_box as svg_format_bounding_box + SVG_TEXT, + SVG_G, + SVG_RECT, + SVG_DOC, + SVG_NSMAP, + SVGE_BOUNDING_BOX, + SvgStructuredDocument, + SvgStyleClasses, + format_bounding_box as svg_format_bounding_box ) from sciencebeam_gym.preprocess.visualize_svg_annotation import ( - visualize_svg_annotations + visualize_svg_annotations ) + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def ElementWithText(tag, text, **kwargs): - node = etree.Element(tag, **kwargs) - node.text = text - return node + node = etree.Element(tag, **kwargs) + node.text = text + return node + def svg_pattern_for_lxml_path(lxml_path): - name, _ = os.path.splitext(lxml_path) - return name + '-page{}.svg' + name, _ = os.path.splitext(lxml_path) + return name + '-page{}.svg' + def parse_args(argv=None): - parser = argparse.ArgumentParser( - description='Convert to LXML (pdftoxml) to SVG' - ) - parser.add_argument( - 
'--lxml-path', type=str, required=True, - help='path to lxml file' - ) - parser.add_argument( - '--svg-path', type=str, required=False, - help='path to svg file' - ) - parser.add_argument( - '--xml-path', type=str, required=False, - help='path to xml file' - ) - parser.add_argument( - '--xml-mapping-path', type=str, default='annot-xml-front.conf', - help='path to xml mapping file' - ) - parser.add_argument( - '--annotate', action='store_true', required=False, - help='enable annotation' - ) - parser.add_argument( - '--debug', action='store_true', required=False, - help='enable debug logging' - ) - parser.add_argument( - '--debug-match', type=str, required=False, - help='debug matches and save as csv' - ) - parser.add_argument( - '--annotation-evaluation-csv', type=str, required=False, - help='Annotation evaluation CSV output file' - ) - args = parser.parse_args(argv) - return args + parser = argparse.ArgumentParser( + description='Convert to LXML (pdftoxml) to SVG' + ) + parser.add_argument( + '--lxml-path', type=str, required=True, + help='path to lxml file' + ) + parser.add_argument( + '--svg-path', type=str, required=False, + help='path to svg file' + ) + parser.add_argument( + '--xml-path', type=str, required=False, + help='path to xml file' + ) + parser.add_argument( + '--xml-mapping-path', type=str, default='annot-xml-front.conf', + help='path to xml mapping file' + ) + parser.add_argument( + '--annotate', action='store_true', required=False, + help='enable annotation' + ) + parser.add_argument( + '--debug', action='store_true', required=False, + help='enable debug logging' + ) + parser.add_argument( + '--debug-match', type=str, required=False, + help='debug matches and save as csv' + ) + parser.add_argument( + '--annotation-evaluation-csv', type=str, required=False, + help='Annotation evaluation CSV output file' + ) + args = parser.parse_args(argv) + return args + def iter_svg_pages_for_lxml(lxml_root, add_background=True): - previous_block = None - 
previous_svg_block = None - for page in lxml_root.xpath('//DOCUMENT/PAGE'): - svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP) - page_width = page.attrib.get('width') - page_height = page.attrib.get('height') - if page_width and page_height: - svg_root.attrib['viewBox'] = '0 0 %s %s' % (page_width, page_height) - if add_background: - svg_root.append(etree.Element(SVG_RECT, attrib={ - 'width': '100%', - 'height': '100%', - 'fill': 'white', - 'class': 'background' - })) - for text in page.xpath('.//TEXT'): - svg_g = etree.Element(SVG_G, nsmap=SVG_NSMAP, attrib={ - 'class': SvgStyleClasses.LINE - }) - for token in text.xpath('./TOKEN'): - x = float(token.attrib.get('x')) - y = float(token.attrib.get('y')) - height = float(token.attrib.get('height')) - width = float(token.attrib.get('width')) - base = float(token.attrib.get('base', y)) - y_center = y + height / 2.0 - attrib = { - 'x': str(x), - 'y': str(base), - 'font-size': token.attrib.get('font-size'), - 'font-family': token.attrib.get('font-name'), - 'fill': token.attrib.get('font-color'), - SVGE_BOUNDING_BOX: svg_format_bounding_box(BoundingBox( - x, y, width, height - )) - } - angle = float(token.attrib.get('angle', '0')) - if token.attrib.get('rotation') == '1' and angle == 90.0: - attrib['x'] = '0' - attrib['y'] = '0' - attrib['transform'] = 'translate({x} {y}) rotate({angle})'.format( - x=str(x), - y=str(y_center), - angle=str(-angle) - ) - svg_g.append( - ElementWithText(SVG_TEXT, token.text, attrib=attrib) - ) - text_parent = text.getparent() - if text_parent.tag == 'BLOCK': - if text_parent != previous_block: - previous_svg_block = etree.Element(SVG_G, nsmap=SVG_NSMAP, attrib={ - 'class': SvgStyleClasses.BLOCK - }) - svg_root.append(previous_svg_block) - previous_block = text_parent - previous_svg_block.append(svg_g) - else: - previous_block = None - previous_svg_block = None - svg_root.append(svg_g) - yield svg_root + previous_block = None + previous_svg_block = None + for page in 
lxml_root.xpath('//DOCUMENT/PAGE'): + svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP) + page_width = page.attrib.get('width') + page_height = page.attrib.get('height') + if page_width and page_height: + svg_root.attrib['viewBox'] = '0 0 %s %s' % (page_width, page_height) + if add_background: + svg_root.append(etree.Element(SVG_RECT, attrib={ + 'width': '100%', + 'height': '100%', + 'fill': 'white', + 'class': 'background' + })) + for text in page.xpath('.//TEXT'): + svg_g = etree.Element(SVG_G, nsmap=SVG_NSMAP, attrib={ + 'class': SvgStyleClasses.LINE + }) + for token in text.xpath('./TOKEN'): + x = float(token.attrib.get('x')) + y = float(token.attrib.get('y')) + height = float(token.attrib.get('height')) + width = float(token.attrib.get('width')) + base = float(token.attrib.get('base', y)) + y_center = y + height / 2.0 + attrib = { + 'x': str(x), + 'y': str(base), + 'font-size': token.attrib.get('font-size'), + 'font-family': token.attrib.get('font-name'), + 'fill': token.attrib.get('font-color'), + SVGE_BOUNDING_BOX: svg_format_bounding_box(BoundingBox( + x, y, width, height + )) + } + angle = float(token.attrib.get('angle', '0')) + if token.attrib.get('rotation') == '1' and angle == 90.0: + attrib['x'] = '0' + attrib['y'] = '0' + attrib['transform'] = 'translate({x} {y}) rotate({angle})'.format( + x=str(x), + y=str(y_center), + angle=str(-angle) + ) + svg_g.append( + ElementWithText(SVG_TEXT, token.text, attrib=attrib) + ) + text_parent = text.getparent() + if text_parent.tag == 'BLOCK': + if text_parent != previous_block: + previous_svg_block = etree.Element(SVG_G, nsmap=SVG_NSMAP, attrib={ + 'class': SvgStyleClasses.BLOCK + }) + svg_root.append(previous_svg_block) + previous_block = text_parent + previous_svg_block.append(svg_g) + else: + previous_block = None + previous_svg_block = None + svg_root.append(svg_g) + yield svg_root + def convert(args): - logger = get_logger() - svg_filename_pattern = args.svg_path - if not svg_filename_pattern: - 
svg_filename_pattern = svg_pattern_for_lxml_path(args.lxml_path) - logger.debug('svg_filename_pattern: %s', svg_filename_pattern) - lxml_root = etree.parse(args.lxml_path).getroot() - - match_detail_reporter = None - if args.annotate: - annotators = DEFAULT_ANNOTATORS - if args.debug_match: - match_detail_reporter = CsvMatchDetailReporter( - open_csv_output(args.debug_match), - args.debug_match - ) - if args.xml_path: - xml_mapping = parse_xml_mapping(args.xml_mapping_path) - target_annotations = xml_root_to_target_annotations( - etree.parse(args.xml_path).getroot(), - xml_mapping - ) - annotators = annotators + [MatchingAnnotator( - target_annotations, match_detail_reporter=match_detail_reporter, - use_tag_begin_prefix=True - )] - annotator = Annotator(annotators) - else: - annotator = None - - if annotator: - svg_roots = list(iter_svg_pages_for_lxml(lxml_root)) - annotator.annotate(SvgStructuredDocument(svg_roots)) - else: - svg_roots = iter_svg_pages_for_lxml(lxml_root) - for page_index, svg_root in enumerate(svg_roots): + logger = get_logger() + svg_filename_pattern = args.svg_path + if not svg_filename_pattern: + svg_filename_pattern = svg_pattern_for_lxml_path(args.lxml_path) + logger.debug('svg_filename_pattern: %s', svg_filename_pattern) + lxml_root = etree.parse(args.lxml_path).getroot() + + match_detail_reporter = None + if args.annotate: + annotators = DEFAULT_ANNOTATORS + if args.debug_match: + match_detail_reporter = CsvMatchDetailReporter( + open_csv_output(args.debug_match), + args.debug_match + ) + if args.xml_path: + xml_mapping = parse_xml_mapping(args.xml_mapping_path) + target_annotations = xml_root_to_target_annotations( + etree.parse(args.xml_path).getroot(), + xml_mapping + ) + annotators = annotators + [MatchingAnnotator( + target_annotations, match_detail_reporter=match_detail_reporter, + use_tag_begin_prefix=True + )] + annotator = Annotator(annotators) + else: + annotator = None + + if annotator: + svg_roots = 
list(iter_svg_pages_for_lxml(lxml_root)) + annotator.annotate(SvgStructuredDocument(svg_roots)) + else: + svg_roots = iter_svg_pages_for_lxml(lxml_root) + for page_index, svg_root in enumerate(svg_roots): + if annotator: + svg_root = visualize_svg_annotations(svg_root) + svg_filename = svg_filename_pattern.format(1 + page_index) + logger.info('writing to: %s', svg_filename) + with open(svg_filename, 'wb') as f: + etree.ElementTree(svg_root).write(f, pretty_print=True) if annotator: - svg_root = visualize_svg_annotations(svg_root) - svg_filename = svg_filename_pattern.format(1 + page_index) - logger.info('writing to: %s', svg_filename) - with open(svg_filename, 'wb') as f: - etree.ElementTree(svg_root).write(f, pretty_print=True) - if annotator: - tagging_evaluation_results = evaluate_document_by_page(SvgStructuredDocument(svg_roots)) - logger.info('tagging evaluation:\n%s', '\n'.join([ - 'page{}: {}'.format(1 + i, r) for i, r in enumerate(tagging_evaluation_results) - ])) - if args.annotation_evaluation_csv: - write_dict_csv( - args.annotation_evaluation_csv, - DEFAULT_EVALUATION_COLUMNS, - to_annotation_evaluation_csv_dict_rows( - tagging_evaluation_results, - document=os.path.basename(args.lxml_path) - ) - ) - if match_detail_reporter: - match_detail_reporter.close() + tagging_evaluation_results = evaluate_document_by_page(SvgStructuredDocument(svg_roots)) + logger.info('tagging evaluation:\n%s', '\n'.join([ + 'page{}: {}'.format(1 + i, r) for i, r in enumerate(tagging_evaluation_results) + ])) + if args.annotation_evaluation_csv: + write_dict_csv( + args.annotation_evaluation_csv, + DEFAULT_EVALUATION_COLUMNS, + to_annotation_evaluation_csv_dict_rows( + tagging_evaluation_results, + document=os.path.basename(args.lxml_path) + ) + ) + if match_detail_reporter: + match_detail_reporter.close() + def main(): - args = parse_args() - if args.debug: - logging.basicConfig(level=logging.DEBUG) - else: - logging.basicConfig(level=logging.INFO) - convert(args) + args = 
parse_args() + if args.debug: + logging.basicConfig(level=logging.DEBUG) + else: + logging.basicConfig(level=logging.INFO) + convert(args) + if __name__ == "__main__": - main() + main() diff --git a/sciencebeam_gym/preprocess/lxml_to_svg_test.py b/sciencebeam_gym/preprocess/lxml_to_svg_test.py index c754cfa..98a6398 100644 --- a/sciencebeam_gym/preprocess/lxml_to_svg_test.py +++ b/sciencebeam_gym/preprocess/lxml_to_svg_test.py @@ -1,20 +1,20 @@ from lxml.builder import E from sciencebeam_gym.utils.bounding_box import ( - BoundingBox + BoundingBox ) from sciencebeam_gym.structured_document.svg import ( - SVG_TEXT, - SVG_G, - SVG_DOC, - SVG_NS, - SVGE_BOUNDING_BOX, - parse_bounding_box + SVG_TEXT, + SVG_G, + SVG_DOC, + SVG_NS, + SVGE_BOUNDING_BOX, + parse_bounding_box ) from sciencebeam_gym.preprocess.lxml_to_svg import ( - iter_svg_pages_for_lxml + iter_svg_pages_for_lxml ) SOME_TEXT = "some text" @@ -27,177 +27,182 @@ SOME_FONT_SIZE = "40" SOME_FONT_FAMILY = "Fontastic" SOME_FONT_COLOR = '#123' + class LXML(object): - X = 'x' - Y = 'y' - BASE = 'base' - WIDTH = 'width' - HEIGHT = 'height' - FONT_SIZE = 'font-size' - FONT_NAME = 'font-name' - FONT_COLOR = 'font-color' + X = 'x' + Y = 'y' + BASE = 'base' + WIDTH = 'width' + HEIGHT = 'height' + FONT_SIZE = 'font-size' + FONT_NAME = 'font-name' + FONT_COLOR = 'font-color' + class SVG(object): - X = 'x' - Y = 'y' - HEIGHT = 'height' - FONT_SIZE = 'font-size' - FONT_FAMILY = 'font-family' - FILL = 'fill' - BOUNDING_BOX = SVGE_BOUNDING_BOX + X = 'x' + Y = 'y' + HEIGHT = 'height' + FONT_SIZE = 'font-size' + FONT_FAMILY = 'font-family' + FILL = 'fill' + BOUNDING_BOX = SVGE_BOUNDING_BOX + COMMON_LXML_TOKEN_ATTRIBS = { - LXML.X: SOME_X, - LXML.Y: SOME_Y, - LXML.WIDTH: SOME_WIDTH, - LXML.HEIGHT: SOME_HEIGHT, - LXML.FONT_SIZE: SOME_FONT_SIZE, - LXML.FONT_NAME: SOME_FONT_FAMILY, - LXML.FONT_COLOR: SOME_FONT_COLOR + LXML.X: SOME_X, + LXML.Y: SOME_Y, + LXML.WIDTH: SOME_WIDTH, + LXML.HEIGHT: SOME_HEIGHT, + LXML.FONT_SIZE: 
SOME_FONT_SIZE, + LXML.FONT_NAME: SOME_FONT_FAMILY, + LXML.FONT_COLOR: SOME_FONT_COLOR } + def dict_extend(*dicts): - d = dict() - for x in dicts: - d.update(x) - return d + d = dict() + for x in dicts: + d.update(x) + return d + class TestIterSvgPagesForLxml(object): - def test_should_return_one_page(self): - lxml_root = E.DOCUMENT( - E.PAGE( - ) - ) - svg_pages = list(iter_svg_pages_for_lxml(lxml_root)) - assert len(svg_pages) == 1 - - def test_should_return_multiple_pages(self): - lxml_root = E.DOCUMENT( - E.PAGE( - ), - E.PAGE( - ), - E.PAGE( - ) - ) - svg_pages = list(iter_svg_pages_for_lxml(lxml_root)) - assert len(svg_pages) == 3 - - def test_should_set_svg_dimensions(self): - lxml_root = E.DOCUMENT( - E.PAGE( - width='600', - height='800' - ) - ) - svg_pages = list(iter_svg_pages_for_lxml(lxml_root)) - assert len(svg_pages) == 1 - assert svg_pages[0].attrib.get('viewBox') == '0 0 600 800' - - def test_should_add_background_rect(self): - lxml_root = E.DOCUMENT( - E.PAGE( - width='600', - height='800' - ) - ) - svg_pages = list(iter_svg_pages_for_lxml(lxml_root)) - assert len(svg_pages) == 1 - background_rect = svg_pages[0].xpath( - 'svg:rect[@class="background"]', - namespaces={'svg': SVG_NS} - ) - assert len(background_rect) == 1 - - def test_should_create_text_node_with_common_attributes(self): - lxml_root = E.DOCUMENT( - E.PAGE( - E.TEXT( - E.TOKEN( - SOME_TEXT, - COMMON_LXML_TOKEN_ATTRIBS - ) + def test_should_return_one_page(self): + lxml_root = E.DOCUMENT( + E.PAGE( + ) ) - ) - ) - svg_pages = list(iter_svg_pages_for_lxml(lxml_root)) - assert len(svg_pages) == 1 - first_page = svg_pages[0] - svg_text = first_page.find('.//' + SVG_TEXT) - assert svg_text is not None - assert svg_text.text == SOME_TEXT - assert float(svg_text.attrib[SVG.X]) == float(SOME_X) - assert float(svg_text.attrib[SVG.Y]) == float(SOME_Y) - assert float(svg_text.attrib[SVG.FONT_SIZE]) == float(SOME_FONT_SIZE) - assert svg_text.attrib[SVG.FONT_FAMILY] == SOME_FONT_FAMILY - assert 
svg_text.attrib[SVG.FILL] == SOME_FONT_COLOR - assert parse_bounding_box(svg_text.attrib.get(SVG.BOUNDING_BOX)) == BoundingBox( - float(COMMON_LXML_TOKEN_ATTRIBS[LXML.X]), - float(COMMON_LXML_TOKEN_ATTRIBS[LXML.Y]), - float(COMMON_LXML_TOKEN_ATTRIBS[LXML.WIDTH]), - float(COMMON_LXML_TOKEN_ATTRIBS[LXML.HEIGHT]) - ) - - def test_should_use_base_as_y_in_svg_if_available(self): - lxml_root = E.DOCUMENT( - E.PAGE( - E.TEXT( - E.TOKEN( - SOME_TEXT, - dict_extend(COMMON_LXML_TOKEN_ATTRIBS, { - LXML.BASE: SOME_BASE - }) - ) + svg_pages = list(iter_svg_pages_for_lxml(lxml_root)) + assert len(svg_pages) == 1 + + def test_should_return_multiple_pages(self): + lxml_root = E.DOCUMENT( + E.PAGE( + ), + E.PAGE( + ), + E.PAGE( + ) ) - ) - ) - svg_pages = list(iter_svg_pages_for_lxml(lxml_root)) - assert len(svg_pages) == 1 - first_page = svg_pages[0] - svg_text = first_page.find('.//' + SVG_TEXT) - assert float(svg_text.attrib[SVG.Y]) == float(SOME_BASE) - - def test_should_keep_text_block_structure_without_block(self): - lxml_root = E.DOCUMENT( - E.PAGE( - E.TEXT( - E.TOKEN( - SOME_TEXT, - dict_extend(COMMON_LXML_TOKEN_ATTRIBS, { - LXML.BASE: SOME_BASE - }) - ) + svg_pages = list(iter_svg_pages_for_lxml(lxml_root)) + assert len(svg_pages) == 3 + + def test_should_set_svg_dimensions(self): + lxml_root = E.DOCUMENT( + E.PAGE( + width='600', + height='800' + ) + ) + svg_pages = list(iter_svg_pages_for_lxml(lxml_root)) + assert len(svg_pages) == 1 + assert svg_pages[0].attrib.get('viewBox') == '0 0 600 800' + + def test_should_add_background_rect(self): + lxml_root = E.DOCUMENT( + E.PAGE( + width='600', + height='800' + ) + ) + svg_pages = list(iter_svg_pages_for_lxml(lxml_root)) + assert len(svg_pages) == 1 + background_rect = svg_pages[0].xpath( + 'svg:rect[@class="background"]', + namespaces={'svg': SVG_NS} + ) + assert len(background_rect) == 1 + + def test_should_create_text_node_with_common_attributes(self): + lxml_root = E.DOCUMENT( + E.PAGE( + E.TEXT( + E.TOKEN( + SOME_TEXT, 
+ COMMON_LXML_TOKEN_ATTRIBS + ) + ) + ) + ) + svg_pages = list(iter_svg_pages_for_lxml(lxml_root)) + assert len(svg_pages) == 1 + first_page = svg_pages[0] + svg_text = first_page.find('.//' + SVG_TEXT) + assert svg_text is not None + assert svg_text.text == SOME_TEXT + assert float(svg_text.attrib[SVG.X]) == float(SOME_X) + assert float(svg_text.attrib[SVG.Y]) == float(SOME_Y) + assert float(svg_text.attrib[SVG.FONT_SIZE]) == float(SOME_FONT_SIZE) + assert svg_text.attrib[SVG.FONT_FAMILY] == SOME_FONT_FAMILY + assert svg_text.attrib[SVG.FILL] == SOME_FONT_COLOR + assert parse_bounding_box(svg_text.attrib.get(SVG.BOUNDING_BOX)) == BoundingBox( + float(COMMON_LXML_TOKEN_ATTRIBS[LXML.X]), + float(COMMON_LXML_TOKEN_ATTRIBS[LXML.Y]), + float(COMMON_LXML_TOKEN_ATTRIBS[LXML.WIDTH]), + float(COMMON_LXML_TOKEN_ATTRIBS[LXML.HEIGHT]) + ) + + def test_should_use_base_as_y_in_svg_if_available(self): + lxml_root = E.DOCUMENT( + E.PAGE( + E.TEXT( + E.TOKEN( + SOME_TEXT, + dict_extend(COMMON_LXML_TOKEN_ATTRIBS, { + LXML.BASE: SOME_BASE + }) + ) + ) + ) + ) + svg_pages = list(iter_svg_pages_for_lxml(lxml_root)) + assert len(svg_pages) == 1 + first_page = svg_pages[0] + svg_text = first_page.find('.//' + SVG_TEXT) + assert float(svg_text.attrib[SVG.Y]) == float(SOME_BASE) + + def test_should_keep_text_block_structure_without_block(self): + lxml_root = E.DOCUMENT( + E.PAGE( + E.TEXT( + E.TOKEN( + SOME_TEXT, + dict_extend(COMMON_LXML_TOKEN_ATTRIBS, { + LXML.BASE: SOME_BASE + }) + ) + ) + ) ) - ) - ) - svg_pages = list(iter_svg_pages_for_lxml(lxml_root)) - assert len(svg_pages) == 1 - first_page = svg_pages[0] - svg_text = first_page.find('.//' + SVG_TEXT) - assert svg_text is not None - assert svg_text.getparent().tag == SVG_G - assert svg_text.getparent().getparent().tag == SVG_DOC - - def test_should_keep_text_block_structure_with_block(self): - lxml_root = E.DOCUMENT( - E.PAGE( - E.BLOCK( - E.TEXT( - E.TOKEN( - SOME_TEXT, - dict_extend(COMMON_LXML_TOKEN_ATTRIBS, { - LXML.BASE: 
SOME_BASE - }) + svg_pages = list(iter_svg_pages_for_lxml(lxml_root)) + assert len(svg_pages) == 1 + first_page = svg_pages[0] + svg_text = first_page.find('.//' + SVG_TEXT) + assert svg_text is not None + assert svg_text.getparent().tag == SVG_G + assert svg_text.getparent().getparent().tag == SVG_DOC + + def test_should_keep_text_block_structure_with_block(self): + lxml_root = E.DOCUMENT( + E.PAGE( + E.BLOCK( + E.TEXT( + E.TOKEN( + SOME_TEXT, + dict_extend(COMMON_LXML_TOKEN_ATTRIBS, { + LXML.BASE: SOME_BASE + }) + ) + ) + ) ) - ) ) - ) - ) - svg_pages = list(iter_svg_pages_for_lxml(lxml_root)) - assert len(svg_pages) == 1 - first_page = svg_pages[0] - svg_text = first_page.find('.//' + SVG_TEXT) - assert svg_text is not None - assert svg_text.getparent().tag == SVG_G - assert svg_text.getparent().getparent().tag == SVG_G - assert svg_text.getparent().getparent().getparent().tag == SVG_DOC + svg_pages = list(iter_svg_pages_for_lxml(lxml_root)) + assert len(svg_pages) == 1 + first_page = svg_pages[0] + svg_text = first_page.find('.//' + SVG_TEXT) + assert svg_text is not None + assert svg_text.getparent().tag == SVG_G + assert svg_text.getparent().getparent().tag == SVG_G + assert svg_text.getparent().getparent().getparent().tag == SVG_DOC diff --git a/sciencebeam_gym/preprocess/preprocessing_pipeline.py b/sciencebeam_gym/preprocess/preprocessing_pipeline.py index 9f36244..9f3749e 100644 --- a/sciencebeam_gym/preprocess/preprocessing_pipeline.py +++ b/sciencebeam_gym/preprocess/preprocessing_pipeline.py @@ -10,550 +10,563 @@ from apache_beam.io.filesystems import FileSystems from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions from sciencebeam_utils.beam_utils.utils import ( - TransformAndCount, - TransformAndLog, - MapOrLog, - PreventFusion + TransformAndCount, + TransformAndLog, + MapOrLog, + PreventFusion ) from sciencebeam_utils.beam_utils.csv import ( - WriteDictCsv, - ReadDictCsv + WriteDictCsv, + ReadDictCsv ) from 
sciencebeam_utils.beam_utils.io import ( - read_all_from_path, - basename, - save_file_content + read_all_from_path, + basename, + save_file_content ) from sciencebeam_utils.beam_utils.main import ( - add_cloud_args, - process_cloud_args + add_cloud_args, + process_cloud_args ) from sciencebeam_utils.utils.collection import ( - extend_dict, - remove_keys_from_dict + extend_dict, + remove_keys_from_dict ) from sciencebeam_utils.utils.file_path import ( - change_ext, - relative_path, - join_if_relative_path + change_ext, + relative_path, + join_if_relative_path ) from sciencebeam_utils.utils.file_pairs import ( - find_file_pairs_grouped_by_parent_directory_or_name, + find_file_pairs_grouped_by_parent_directory_or_name, ) from sciencebeam_gym.structured_document.svg import ( - SvgStructuredDocument + SvgStructuredDocument ) from sciencebeam_gym.preprocess.annotation.target_annotation import ( - parse_xml_mapping + parse_xml_mapping ) from sciencebeam_gym.preprocess.color_map import ( - parse_color_map_from_file + parse_color_map_from_file ) from sciencebeam_gym.preprocess.annotation.annotation_evaluation import ( - evaluate_document_by_page, - DEFAULT_EVALUATION_COLUMNS, - to_csv_dict_rows as to_annotation_evaluation_csv_dict_rows + evaluate_document_by_page, + DEFAULT_EVALUATION_COLUMNS, + to_csv_dict_rows as to_annotation_evaluation_csv_dict_rows ) from sciencebeam_gym.preprocess.preprocessing_utils import ( - convert_pdf_bytes_to_lxml, - convert_and_annotate_lxml_content, - pdf_bytes_to_png_pages, - svg_page_to_blockified_png_bytes, - save_pages, - save_svg_roots, - filter_list_props_by_indices, - get_page_indices_with_min_annotation_percentage, - parse_page_range + convert_pdf_bytes_to_lxml, + convert_and_annotate_lxml_content, + pdf_bytes_to_png_pages, + svg_page_to_blockified_png_bytes, + save_pages, + save_svg_roots, + filter_list_props_by_indices, + get_page_indices_with_min_annotation_percentage, + parse_page_range ) from 
sciencebeam_gym.preprocess.preprocessing_transforms import ( - WritePropsToTFRecord + WritePropsToTFRecord ) + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + class MetricCounters(object): - FILE_PAIR = 'file_pair_count' - PAGE = 'page_count' - FILTERED_PAGE = 'filtered_page_count' - CONVERT_PDF_TO_LXML_ERROR = 'ConvertPdfToLxml_error_count' - CONVERT_PDF_TO_PNG_ERROR = 'ConvertPdfToPng_error_count' - CONVERT_LXML_TO_SVG_ANNOT_ERROR = 'ConvertPdfToSvgAnnot_error_count' + FILE_PAIR = 'file_pair_count' + PAGE = 'page_count' + FILTERED_PAGE = 'filtered_page_count' + CONVERT_PDF_TO_LXML_ERROR = 'ConvertPdfToLxml_error_count' + CONVERT_PDF_TO_PNG_ERROR = 'ConvertPdfToPng_error_count' + CONVERT_LXML_TO_SVG_ANNOT_ERROR = 'ConvertPdfToSvgAnnot_error_count' + def configure_pipeline(p, opt): - image_size = ( - (opt.image_width, opt.image_height) - if opt.image_width and opt.image_height - else None - ) - page_range = opt.pages - first_page = page_range[0] if page_range else 1 - xml_mapping = parse_xml_mapping(opt.xml_mapping_path) - if opt.lxml_path: - lxml_xml_file_pairs = ( - p | - beam.Create([[ - join_if_relative_path(opt.base_data_path, s) - for s in [opt.lxml_path, opt.xml_path] - ]]) | - "FindFilePairs" >> TransformAndLog( - beam.FlatMap( - lambda patterns: islice( - find_file_pairs_grouped_by_parent_directory_or_name(patterns), - opt.limit - ) - ), - log_prefix='file pairs: ', - log_level='debug' - ) | - PreventFusion() | - "ReadFileContent" >> beam.Map(lambda filenames: { - 'source_filename': filenames[0], - 'xml_filename': filenames[1], - 'lxml_content': read_all_from_path(filenames[0]), - 'xml_content': read_all_from_path(filenames[1]) - }) + image_size = ( + (opt.image_width, opt.image_height) + if opt.image_width and opt.image_height + else None ) - elif opt.pdf_path or opt.pdf_xml_file_list: - if opt.pdf_xml_file_list: - pdf_xml_url_pairs = ( - p | - "ReadFilePairUrls" >> ReadDictCsv(opt.pdf_xml_file_list, 
limit=opt.limit) | - "TranslateFilePairUrls" >> beam.Map(lambda row: (row['source_url'], row['xml_url'])) - ) - else: - pdf_xml_url_pairs = ( - p | - beam.Create([[ - join_if_relative_path(opt.base_data_path, s) - for s in [opt.pdf_path, opt.xml_path] - ]]) | - "FindFilePairs" >> TransformAndLog( - beam.FlatMap( - lambda patterns: islice( - find_file_pairs_grouped_by_parent_directory_or_name(patterns), - opt.limit + page_range = opt.pages + first_page = page_range[0] if page_range else 1 + xml_mapping = parse_xml_mapping(opt.xml_mapping_path) + if opt.lxml_path: + lxml_xml_file_pairs = ( + p | + beam.Create([[ + join_if_relative_path(opt.base_data_path, s) + for s in [opt.lxml_path, opt.xml_path] + ]]) | + "FindFilePairs" >> TransformAndLog( + beam.FlatMap( + lambda patterns: islice( + find_file_pairs_grouped_by_parent_directory_or_name(patterns), + opt.limit + ) + ), + log_prefix='file pairs: ', + log_level='debug' + ) | + PreventFusion() | + "ReadFileContent" >> beam.Map(lambda filenames: { + 'source_filename': filenames[0], + 'xml_filename': filenames[1], + 'lxml_content': read_all_from_path(filenames[0]), + 'xml_content': read_all_from_path(filenames[1]) + }) + ) + elif opt.pdf_path or opt.pdf_xml_file_list: + if opt.pdf_xml_file_list: + pdf_xml_url_pairs = ( + p | + "ReadFilePairUrls" >> ReadDictCsv(opt.pdf_xml_file_list, limit=opt.limit) | + "TranslateFilePairUrls" >> beam.Map(lambda row: (row['source_url'], row['xml_url'])) + ) + else: + pdf_xml_url_pairs = ( + p | + beam.Create([[ + join_if_relative_path(opt.base_data_path, s) + for s in [opt.pdf_path, opt.xml_path] + ]]) | + "FindFilePairs" >> TransformAndLog( + beam.FlatMap( + lambda patterns: islice( + find_file_pairs_grouped_by_parent_directory_or_name(patterns), + opt.limit + ) + ), + log_prefix='file pairs: ', + log_level='debug' + ) + ) + pdf_xml_file_pairs = ( + pdf_xml_url_pairs | + PreventFusion() | + "ReadFileContent" >> TransformAndCount( + beam.Map(lambda filenames: { + 'source_filename': 
filenames[0], + 'xml_filename': filenames[1], + 'pdf_content': read_all_from_path(filenames[0]), + 'xml_content': read_all_from_path(filenames[1]) + }), + MetricCounters.FILE_PAIR ) - ), - log_prefix='file pairs: ', - log_level='debug' ) - ) - pdf_xml_file_pairs = ( - pdf_xml_url_pairs | - PreventFusion() | - "ReadFileContent" >> TransformAndCount( - beam.Map(lambda filenames: { - 'source_filename': filenames[0], - 'xml_filename': filenames[1], - 'pdf_content': read_all_from_path(filenames[0]), - 'xml_content': read_all_from_path(filenames[1]) - }), - MetricCounters.FILE_PAIR - ) - ) - lxml_xml_file_pairs = ( - pdf_xml_file_pairs | - "ConvertPdfToLxml" >> MapOrLog(lambda v: remove_keys_from_dict( - extend_dict(v, { - 'lxml_content': convert_pdf_bytes_to_lxml( - v['pdf_content'], path=v['source_filename'], - page_range=page_range - ) - }), - # we don't need the pdf_content unless we are writing tf_records - None if opt.save_tfrecords else {'pdf_content'} - ), log_fn=lambda e, v: ( - get_logger().warning( - 'caught exception (ignoring item): %s, pdf: %s, xml: %s', - e, v['source_filename'], v['xml_filename'], exc_info=e + lxml_xml_file_pairs = ( + pdf_xml_file_pairs | + "ConvertPdfToLxml" >> MapOrLog(lambda v: remove_keys_from_dict( + extend_dict(v, { + 'lxml_content': convert_pdf_bytes_to_lxml( + v['pdf_content'], path=v['source_filename'], + page_range=page_range + ) + }), + # we don't need the pdf_content unless we are writing tf_records + None if opt.save_tfrecords else {'pdf_content'} + ), log_fn=lambda e, v: ( + get_logger().warning( + 'caught exception (ignoring item): %s, pdf: %s, xml: %s', + e, v['source_filename'], v['xml_filename'], exc_info=e + ) + ), error_count=MetricCounters.CONVERT_PDF_TO_LXML_ERROR) ) - ), error_count=MetricCounters.CONVERT_PDF_TO_LXML_ERROR) - ) - else: - raise RuntimeError('either lxml-path or pdf-path required') - - if opt.save_png or opt.save_tfrecords: - with_pdf_png_pages = ( - (lxml_xml_file_pairs if opt.save_tfrecords else 
pdf_xml_file_pairs) | - "ConvertPdfToPng" >> MapOrLog(lambda v: remove_keys_from_dict( - extend_dict(v, { - 'pdf_png_pages': list(pdf_bytes_to_png_pages( - v['pdf_content'], - dpi=opt.png_dpi, - image_size=image_size, - page_range=page_range - )) - }), - {'pdf_content'} # we no longer need the pdf_content - ), error_count=MetricCounters.CONVERT_PDF_TO_PNG_ERROR) - ) - - if opt.save_png: - _ = ( - with_pdf_png_pages | - "SavePdfToPng" >> TransformAndLog( - beam.Map(lambda v: save_pages( - FileSystems.join( - opt.output_path, - change_ext( - relative_path(opt.base_data_path, v['source_filename']), - None, '.png.zip' - ) - ), - '.png', - v['pdf_png_pages'] - )), - log_fn=lambda x: get_logger().info('saved result: %s', x) + else: + raise RuntimeError('either lxml-path or pdf-path required') + + if opt.save_png or opt.save_tfrecords: + with_pdf_png_pages = ( + (lxml_xml_file_pairs if opt.save_tfrecords else pdf_xml_file_pairs) | + "ConvertPdfToPng" >> MapOrLog(lambda v: remove_keys_from_dict( + extend_dict(v, { + 'pdf_png_pages': list(pdf_bytes_to_png_pages( + v['pdf_content'], + dpi=opt.png_dpi, + image_size=image_size, + page_range=page_range + )) + }), + {'pdf_content'} # we no longer need the pdf_content + ), error_count=MetricCounters.CONVERT_PDF_TO_PNG_ERROR) ) - ) - - if opt.save_lxml: - _ = ( - lxml_xml_file_pairs | - "SaveLxml" >> TransformAndLog( - beam.Map(lambda v: save_file_content( - FileSystems.join( - opt.output_path, - change_ext( - relative_path(opt.base_data_path, v['source_filename']), - None, '.lxml.gz' + + if opt.save_png: + _ = ( + with_pdf_png_pages | + "SavePdfToPng" >> TransformAndLog( + beam.Map(lambda v: save_pages( + FileSystems.join( + opt.output_path, + change_ext( + relative_path(opt.base_data_path, v['source_filename']), + None, '.png.zip' + ) + ), + '.png', + v['pdf_png_pages'] + )), + log_fn=lambda x: get_logger().info('saved result: %s', x) + ) ) - ), - v['lxml_content'] - )), - log_fn=lambda x: get_logger().info('saved lxml: %s', x) 
- ) - ) - annotation_results = ( - (with_pdf_png_pages if opt.save_tfrecords else lxml_xml_file_pairs) | - "ConvertLxmlToSvgAndAnnotate" >> TransformAndCount( - MapOrLog(lambda v: remove_keys_from_dict( - extend_dict(v, { - 'svg_pages': list(convert_and_annotate_lxml_content( - v['lxml_content'], v['xml_content'], xml_mapping, - name=v['source_filename'] - )) - }), - # Won't need the XML anymore - {'lxml_content', 'xml_content'} - ), log_fn=lambda e, v: ( - get_logger().warning( - 'caught exception (ignoring item): %s, source: %s, xml: %s', - e, v['source_filename'], v['xml_filename'], exc_info=e - ) - ), error_count=MetricCounters.CONVERT_LXML_TO_SVG_ANNOT_ERROR), - MetricCounters.PAGE, - lambda v: len(v['svg_pages']) - ) - ) - - if opt.save_svg: - _ = ( - annotation_results | - "SaveSvgPages" >> TransformAndLog( - beam.Map(lambda v: save_svg_roots( - FileSystems.join( - opt.output_path, - change_ext( - relative_path(opt.base_data_path, v['source_filename']), - None, '.svg.zip' + if opt.save_lxml: + _ = ( + lxml_xml_file_pairs | + "SaveLxml" >> TransformAndLog( + beam.Map(lambda v: save_file_content( + FileSystems.join( + opt.output_path, + change_ext( + relative_path(opt.base_data_path, v['source_filename']), + None, '.lxml.gz' + ) + ), + v['lxml_content'] + )), + log_fn=lambda x: get_logger().info('saved lxml: %s', x) ) - ), - v['svg_pages'] - )), - log_fn=lambda x: get_logger().info('saved result: %s', x) - ) + ) + + annotation_results = ( + (with_pdf_png_pages if opt.save_tfrecords else lxml_xml_file_pairs) | + "ConvertLxmlToSvgAndAnnotate" >> TransformAndCount( + MapOrLog(lambda v: remove_keys_from_dict( + extend_dict(v, { + 'svg_pages': list(convert_and_annotate_lxml_content( + v['lxml_content'], v['xml_content'], xml_mapping, + name=v['source_filename'] + )) + }), + # Won't need the XML anymore + {'lxml_content', 'xml_content'} + ), log_fn=lambda e, v: ( + get_logger().warning( + 'caught exception (ignoring item): %s, source: %s, xml: %s', + e, 
v['source_filename'], v['xml_filename'], exc_info=e + ) + ), error_count=MetricCounters.CONVERT_LXML_TO_SVG_ANNOT_ERROR), + MetricCounters.PAGE, + lambda v: len(v['svg_pages']) + ) ) - if opt.annotation_evaluation_csv or opt.min_annotation_percentage: - annotation_evaluation_results = ( - annotation_results | - "EvaluateAnnotations" >> TransformAndLog( - beam.Map(lambda v: remove_keys_from_dict( - extend_dict(v, { - 'annotation_evaluation': evaluate_document_by_page( - SvgStructuredDocument(v['svg_pages']) + if opt.save_svg: + _ = ( + annotation_results | + "SaveSvgPages" >> TransformAndLog( + beam.Map(lambda v: save_svg_roots( + FileSystems.join( + opt.output_path, + change_ext( + relative_path(opt.base_data_path, v['source_filename']), + None, '.svg.zip' + ) + ), + v['svg_pages'] + )), + log_fn=lambda x: get_logger().info('saved result: %s', x) ) - }), - None if opt.min_annotation_percentage else {'svg_pages'} - )), - log_fn=lambda x: get_logger().info( - 'annotation evaluation result: %s: %s', - x['source_filename'], x['annotation_evaluation'] ) - ) - ) - if opt.save_block_png or opt.save_tfrecords: - color_map = parse_color_map_from_file(opt.color_map) - with_block_png_pages = ( - (annotation_evaluation_results if opt.min_annotation_percentage else annotation_results) | - "GenerateBlockPng" >> beam.Map(lambda v: remove_keys_from_dict( - extend_dict(v, { - 'block_png_pages': [ - svg_page_to_blockified_png_bytes(svg_page, color_map, image_size=image_size) - for svg_page in v['svg_pages'] - ] - }), - {'svg_pages'} - )) - ) + if opt.annotation_evaluation_csv or opt.min_annotation_percentage: + annotation_evaluation_results = ( + annotation_results | + "EvaluateAnnotations" >> TransformAndLog( + beam.Map(lambda v: remove_keys_from_dict( + extend_dict(v, { + 'annotation_evaluation': evaluate_document_by_page( + SvgStructuredDocument(v['svg_pages']) + ) + }), + None if opt.min_annotation_percentage else {'svg_pages'} + )), + log_fn=lambda x: get_logger().info( + 
'annotation evaluation result: %s: %s', + x['source_filename'], x['annotation_evaluation'] + ) + ) + ) - if opt.save_block_png: - _ = ( - with_block_png_pages | - "SaveBlockPng" >> TransformAndLog( - beam.Map(lambda v: save_pages( - FileSystems.join( - opt.output_path, - change_ext( - relative_path(opt.base_data_path, v['source_filename']), - None, '.block-png.zip' - ) - ), - '.png', - v['block_png_pages'] - )), - log_fn=lambda x: get_logger().info('saved result: %s', x) + if opt.save_block_png or opt.save_tfrecords: + color_map = parse_color_map_from_file(opt.color_map) + with_block_png_pages = ( + ( + annotation_evaluation_results + if opt.min_annotation_percentage + else annotation_results + ) | + "GenerateBlockPng" >> beam.Map(lambda v: remove_keys_from_dict( + extend_dict(v, { + 'block_png_pages': [ + svg_page_to_blockified_png_bytes(svg_page, color_map, image_size=image_size) + for svg_page in v['svg_pages'] + ] + }), + {'svg_pages'} + )) ) - ) - - if opt.save_tfrecords: - if opt.min_annotation_percentage: - filtered_pages = ( - with_block_png_pages | - "FilterPages" >> TransformAndCount( - beam.Map( - lambda v: filter_list_props_by_indices( - v, - get_page_indices_with_min_annotation_percentage( - v['annotation_evaluation'], - opt.min_annotation_percentage - ), - {'pdf_png_pages', 'block_png_pages'} - ) - ), - MetricCounters.FILTERED_PAGE, - lambda v: len(v['block_png_pages']) - ) + + if opt.save_block_png: + _ = ( + with_block_png_pages | + "SaveBlockPng" >> TransformAndLog( + beam.Map(lambda v: save_pages( + FileSystems.join( + opt.output_path, + change_ext( + relative_path(opt.base_data_path, v['source_filename']), + None, '.block-png.zip' + ) + ), + '.png', + v['block_png_pages'] + )), + log_fn=lambda x: get_logger().info('saved result: %s', x) + ) + ) + + if opt.save_tfrecords: + if opt.min_annotation_percentage: + filtered_pages = ( + with_block_png_pages | + "FilterPages" >> TransformAndCount( + beam.Map( + lambda v: filter_list_props_by_indices( + v, 
+ get_page_indices_with_min_annotation_percentage( + v['annotation_evaluation'], + opt.min_annotation_percentage + ), + {'pdf_png_pages', 'block_png_pages'} + ) + ), + MetricCounters.FILTERED_PAGE, + lambda v: len(v['block_png_pages']) + ) + ) + else: + filtered_pages = with_block_png_pages + _ = ( + filtered_pages | + "WriteTFRecords" >> WritePropsToTFRecord( + FileSystems.join(opt.output_path, 'data'), + lambda v: ( + { + 'input_uri': v['source_filename'] + '#page%d' % (first_page + i), + 'input_image': pdf_png_page, + 'annotation_uri': ( + v['source_filename'] + '.annot' + '#page%d' % (first_page + i) + ), + 'annotation_image': block_png_page, + 'page_no': first_page + i + } + for i, pdf_png_page, block_png_page in zip( + range(len(v['pdf_png_pages'])), v['pdf_png_pages'], v['block_png_pages'] + ) + ) + ) + ) + + if opt.annotation_evaluation_csv: + annotation_evaluation_csv_name, annotation_evaluation_ext = ( + os.path.splitext(opt.annotation_evaluation_csv) ) - else: - filtered_pages = with_block_png_pages - _ = ( - filtered_pages | - "WriteTFRecords" >> WritePropsToTFRecord( - FileSystems.join(opt.output_path, 'data'), - lambda v: ( - { - 'input_uri': v['source_filename'] + '#page%d' % (first_page + i), - 'input_image': pdf_png_page, - 'annotation_uri': v['source_filename'] + '.annot' + '#page%d' % (first_page + i), - 'annotation_image': block_png_page, - 'page_no': first_page + i - } - for i, pdf_png_page, block_png_page in zip( - range(len(v['pdf_png_pages'])), v['pdf_png_pages'], v['block_png_pages'] + _ = ( # flake8: noqa + annotation_evaluation_results | + "FlattenAnotationEvaluationResults" >> beam.FlatMap( + lambda v: to_annotation_evaluation_csv_dict_rows( + v['annotation_evaluation'], + document=basename(v['source_filename']) + ) + ) | + "WriteAnnotationEvaluationToCsv" >> WriteDictCsv( + join_if_relative_path(opt.output_path, annotation_evaluation_csv_name), + file_name_suffix=annotation_evaluation_ext, + columns=DEFAULT_EVALUATION_COLUMNS ) - ) ) - 
) - if opt.annotation_evaluation_csv: - annotation_evaluation_csv_name, annotation_evaluation_ext = ( - os.path.splitext(opt.annotation_evaluation_csv) + +def add_main_args(parser): + parser.add_argument( + '--data-path', type=str, required=True, + help='base data path' ) - _ = ( - annotation_evaluation_results | - "FlattenAnotationEvaluationResults" >> beam.FlatMap( - lambda v: to_annotation_evaluation_csv_dict_rows( - v['annotation_evaluation'], - document=basename(v['source_filename']) - ) - ) | - "WriteAnnotationEvaluationToCsv" >> WriteDictCsv( - join_if_relative_path(opt.output_path, annotation_evaluation_csv_name), - file_name_suffix=annotation_evaluation_ext, - columns=DEFAULT_EVALUATION_COLUMNS - ) + + source_group = parser.add_mutually_exclusive_group(required=True) + source_group.add_argument( + '--lxml-path', type=str, required=False, + help='path to lxml file(s)' + ) + source_group.add_argument( + '--pdf-path', type=str, required=False, + help='path to pdf file(s) (alternative to lxml)' + ) + source_group.add_argument( + '--pdf-xml-file-list', type=str, required=False, + help='path to pdf-xml csv/tsv file list' + ) + parser.add_argument( + '--limit', type=int, required=False, + help='limit the number of file pairs to process' ) -def add_main_args(parser): - parser.add_argument( - '--data-path', type=str, required=True, - help='base data path' - ) - - source_group = parser.add_mutually_exclusive_group(required=True) - source_group.add_argument( - '--lxml-path', type=str, required=False, - help='path to lxml file(s)' - ) - source_group.add_argument( - '--pdf-path', type=str, required=False, - help='path to pdf file(s) (alternative to lxml)' - ) - source_group.add_argument( - '--pdf-xml-file-list', type=str, required=False, - help='path to pdf-xml csv/tsv file list' - ) - parser.add_argument( - '--limit', type=int, required=False, - help='limit the number of file pairs to process' - ) - - parser.add_argument( - '--save-lxml', default=False, 
action='store_true', - help='save generated lxml (if using pdf as an input)' - ) - - parser.add_argument( - '--save-svg', default=False, action='store_true', - help='save svg pages with annotation tags' - ) - - parser.add_argument( - '--save-png', default=False, action='store_true', - help='save png pages of the original pdf' - ) - parser.add_argument( - '--png-dpi', type=int, default=90, - help='dpi of rendered pdf pages' - ) - - parser.add_argument( - '--image-width', type=int, required=False, - help='image width of resulting PNGs' - ) - parser.add_argument( - '--image-height', type=int, required=False, - help='image height of resulting PNGs' - ) - - parser.add_argument( - '--save-block-png', default=False, action='store_true', - help='save blockified version of the svg as a png' - ) - parser.add_argument( - '--color-map', default='color_map.conf', - help='color map to use (see save-block-png)' - ) - - parser.add_argument( - '--xml-path', type=str, required=False, - help='path to xml file(s)' - ) - parser.add_argument( - '--xml-mapping-path', type=str, default='annot-xml-front.conf', - help='path to xml mapping file' - ) - - parser.add_argument( - '--pages', type=parse_page_range, default=None, - help='only processes the selected pages' - ) - - parser.add_argument( - '--save-tfrecords', default=False, action='store_true', - help='Save TFRecords with PDF PNG and Annotation PNG' - ' (--image-width and --image-height recommended)' - ) - - parser.add_argument( - '--min-annotation-percentage', type=float, required=False, - help='Minimum percentage of annotations per page' - ' (pages below that threshold will get dropped)' - ) - - parser.add_argument( - '--annotation-evaluation-csv', type=str, required=False, - help='Annotation evaluation CSV output file' - ) - parser.add_argument( - '--output-path', required=False, - help='Output directory to write results to.' 
- ) + parser.add_argument( + '--save-lxml', default=False, action='store_true', + help='save generated lxml (if using pdf as an input)' + ) -def process_main_args(parser, args): - args.base_data_path = args.data_path.replace('/*/', '/') + parser.add_argument( + '--save-svg', default=False, action='store_true', + help='save svg pages with annotation tags' + ) - if not args.output_path: - args.output_path = os.path.join( - os.path.dirname(args.base_data_path), - os.path.basename(args.base_data_path + '-results') + parser.add_argument( + '--save-png', default=False, action='store_true', + help='save png pages of the original pdf' + ) + parser.add_argument( + '--png-dpi', type=int, default=90, + help='dpi of rendered pdf pages' ) - if not args.xml_path and not args.pdf_xml_file_list: - parser.error('--xml-path required unless --pdf-xml-file-list is specified') + parser.add_argument( + '--image-width', type=int, required=False, + help='image width of resulting PNGs' + ) + parser.add_argument( + '--image-height', type=int, required=False, + help='image height of resulting PNGs' + ) - pdf_path_or_pdf_xml_file_list = args.pdf_path or args.pdf_xml_file_list + parser.add_argument( + '--save-block-png', default=False, action='store_true', + help='save blockified version of the svg as a png' + ) + parser.add_argument( + '--color-map', default='color_map.conf', + help='color map to use (see save-block-png)' + ) - if args.save_lxml and not pdf_path_or_pdf_xml_file_list: - parser.error('--save-lxml only valid with --pdf-path or --pdf-xml-file-list') + parser.add_argument( + '--xml-path', type=str, required=False, + help='path to xml file(s)' + ) + parser.add_argument( + '--xml-mapping-path', type=str, default='annot-xml-front.conf', + help='path to xml mapping file' + ) - if args.save_png and not pdf_path_or_pdf_xml_file_list: - parser.error('--save-png only valid with --pdf-path or --pdf-xml-file-list') + parser.add_argument( + '--pages', type=parse_page_range, default=None, + 
help='only processes the selected pages' + ) - if args.save_tfrecords and not pdf_path_or_pdf_xml_file_list: - parser.error('--save-tfrecords only valid with --pdf-path or --pdf-xml-file-list') + parser.add_argument( + '--save-tfrecords', default=False, action='store_true', + help='Save TFRecords with PDF PNG and Annotation PNG' + ' (--image-width and --image-height recommended)' + ) - if sum(1 if x else 0 for x in (args.image_width, args.image_height)) == 1: - parser.error('--image-width and --image-height need to be specified together') + parser.add_argument( + '--min-annotation-percentage', type=float, required=False, + help='Minimum percentage of annotations per page' + ' (pages below that threshold will get dropped)' + ) - if not (args.save_lxml or args.save_svg or args.save_png or args.save_tfrecords): - parser.error( - 'at least one of the output options required:' - ' --save-lxml --save-svg --save-png or --save-tfrecords' + parser.add_argument( + '--annotation-evaluation-csv', type=str, required=False, + help='Annotation evaluation CSV output file' ) + parser.add_argument( + '--output-path', required=False, + help='Output directory to write results to.' 
+ ) + + +def process_main_args(parser, args): + args.base_data_path = args.data_path.replace('/*/', '/') + + if not args.output_path: + args.output_path = os.path.join( + os.path.dirname(args.base_data_path), + os.path.basename(args.base_data_path + '-results') + ) + + if not args.xml_path and not args.pdf_xml_file_list: + parser.error('--xml-path required unless --pdf-xml-file-list is specified') + + pdf_path_or_pdf_xml_file_list = args.pdf_path or args.pdf_xml_file_list + + if args.save_lxml and not pdf_path_or_pdf_xml_file_list: + parser.error('--save-lxml only valid with --pdf-path or --pdf-xml-file-list') + + if args.save_png and not pdf_path_or_pdf_xml_file_list: + parser.error('--save-png only valid with --pdf-path or --pdf-xml-file-list') + + if args.save_tfrecords and not pdf_path_or_pdf_xml_file_list: + parser.error('--save-tfrecords only valid with --pdf-path or --pdf-xml-file-list') + + if sum(1 if x else 0 for x in (args.image_width, args.image_height)) == 1: + parser.error('--image-width and --image-height need to be specified together') + + if not (args.save_lxml or args.save_svg or args.save_png or args.save_tfrecords): + parser.error( + 'at least one of the output options required:' + ' --save-lxml --save-svg --save-png or --save-tfrecords' + ) + def parse_args(argv=None): - parser = argparse.ArgumentParser() - add_main_args(parser) - add_cloud_args(parser) + parser = argparse.ArgumentParser() + add_main_args(parser) + add_cloud_args(parser) - # parsed_args, other_args = parser.parse_known_args(argv) - parsed_args = parser.parse_args(argv) + # parsed_args, other_args = parser.parse_known_args(argv) + parsed_args = parser.parse_args(argv) + + process_main_args(parser, parsed_args) + process_cloud_args( + parsed_args, parsed_args.output_path, + name='sciencbeam-gym-preprocessing' + ) - process_main_args(parser, parsed_args) - process_cloud_args( - parsed_args, parsed_args.output_path, - name='sciencbeam-gym-preprocessing' - ) + 
get_logger().info('parsed_args: %s', parsed_args) - get_logger().info('parsed_args: %s', parsed_args) + return parsed_args - return parsed_args def run(argv=None): - """Main entry point; defines and runs the tfidf pipeline.""" - known_args = parse_args(argv) + """Main entry point; defines and runs the tfidf pipeline.""" + known_args = parse_args(argv) - # We use the save_main_session option because one or more DoFn's in this - # workflow rely on global context (e.g., a module imported at module level). - pipeline_options = PipelineOptions.from_dictionary(vars(known_args)) - pipeline_options.view_as(SetupOptions).save_main_session = True + # We use the save_main_session option because one or more DoFn's in this + # workflow rely on global context (e.g., a module imported at module level). + pipeline_options = PipelineOptions.from_dictionary(vars(known_args)) + pipeline_options.view_as(SetupOptions).save_main_session = True - with beam.Pipeline(known_args.runner, options=pipeline_options) as p: - configure_pipeline(p, known_args) + with beam.Pipeline(known_args.runner, options=pipeline_options) as p: + configure_pipeline(p, known_args) - # Execute the pipeline and wait until it is completed. + # Execute the pipeline and wait until it is completed. 
if __name__ == '__main__': - logging.basicConfig(level='INFO') + logging.basicConfig(level='INFO') - run() + run() diff --git a/sciencebeam_gym/preprocess/preprocessing_pipeline_test.py b/sciencebeam_gym/preprocess/preprocessing_pipeline_test.py index f444ce7..382c30a 100644 --- a/sciencebeam_gym/preprocess/preprocessing_pipeline_test.py +++ b/sciencebeam_gym/preprocess/preprocessing_pipeline_test.py @@ -7,24 +7,24 @@ import pytest import apache_beam as beam from sciencebeam_utils.beam_utils.utils import ( - TransformAndLog + TransformAndLog ) from sciencebeam_utils.beam_utils.testing import ( - BeamTest, - TestPipeline, - get_current_test_context, - get_counter_value + BeamTest, + TestPipeline, + get_current_test_context, + get_counter_value ) from sciencebeam_utils.utils.collection import ( - extend_dict + extend_dict ) from sciencebeam_gym.preprocess.preprocessing_pipeline import ( - parse_args, - configure_pipeline, - MetricCounters + parse_args, + configure_pipeline, + MetricCounters ) PREPROCESSING_PIPELINE = 'sciencebeam_gym.preprocess.preprocessing_pipeline' @@ -39,386 +39,410 @@ PDF_FILE_2 = '2/file.pdf' XML_FILE_2 = '2/file.xml' PDF_XML_FILE_LIST_FILE_1 = 'pdf-xml-files.tsv' + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def fake_content(path): - return 'fake content: %s' % path + return 'fake content: %s' % path + def fake_lxml_for_pdf(pdf, path, page_range=None): - return 'fake lxml for pdf: %s (%s) [%s]' % (pdf, path, page_range) + return 'fake lxml for pdf: %s (%s) [%s]' % (pdf, path, page_range) + + +def fake_svg_page(i=0): + return 'fake svg page: %d' % i + + +def fake_pdf_png_page(i=0): + return 'fake pdf png page: %d' % i + + +def fake_block_png_page(i=0): + return 'fake block png page: %d' % i -fake_svg_page = lambda i=0: 'fake svg page: %d' % i -fake_pdf_png_page = lambda i=0: 'fake pdf png page: %d' % i -fake_block_png_page = lambda i=0: 'fake block png page: %d' % i def 
get_global_tfrecords_mock(): - # workaround for mock that would get serialized/deserialized before being invoked - return get_current_test_context().tfrecords_mock + # workaround for mock that would get serialized/deserialized before being invoked + return get_current_test_context().tfrecords_mock + @contextmanager def patch_preprocessing_pipeline(**kwargs): - always_mock = { - 'find_file_pairs_grouped_by_parent_directory_or_name', - 'read_all_from_path', - 'pdf_bytes_to_png_pages', - 'convert_pdf_bytes_to_lxml', - 'convert_and_annotate_lxml_content', - 'svg_page_to_blockified_png_bytes', - 'save_svg_roots', - 'save_pages', - 'evaluate_document_by_page', - 'ReadDictCsv' - } - tfrecords_mock = Mock(name='tfrecords_mock') - - def DummyWritePropsToTFRecord(file_path, extract_props): - return TransformAndLog(beam.Map( - lambda v: tfrecords_mock(file_path, list(extract_props(v))) - ), log_fn=lambda x: get_logger().info('tfrecords: %s', x)) - - with patch.multiple( - PREPROCESSING_PIPELINE, - WritePropsToTFRecord=DummyWritePropsToTFRecord, - **{ - k: kwargs.get(k, DEFAULT) - for k in always_mock + always_mock = { + 'find_file_pairs_grouped_by_parent_directory_or_name', + 'read_all_from_path', + 'pdf_bytes_to_png_pages', + 'convert_pdf_bytes_to_lxml', + 'convert_and_annotate_lxml_content', + 'svg_page_to_blockified_png_bytes', + 'save_svg_roots', + 'save_pages', + 'evaluate_document_by_page', + 'ReadDictCsv' } - ) as mocks: - get_current_test_context().mocks = mocks - mocks['read_all_from_path'].side_effect = fake_content - mocks['convert_pdf_bytes_to_lxml'].side_effect = fake_lxml_for_pdf - yield extend_dict( - mocks, - {'tfrecords': tfrecords_mock} - ) + tfrecords_mock = Mock(name='tfrecords_mock') + + def DummyWritePropsToTFRecord(file_path, extract_props): + return TransformAndLog(beam.Map( + lambda v: tfrecords_mock(file_path, list(extract_props(v))) + ), log_fn=lambda x: get_logger().info('tfrecords: %s', x)) + + with patch.multiple( + PREPROCESSING_PIPELINE, + 
WritePropsToTFRecord=DummyWritePropsToTFRecord, + **{ + k: kwargs.get(k, DEFAULT) + for k in always_mock + } + ) as mocks: + get_current_test_context().mocks = mocks + mocks['read_all_from_path'].side_effect = fake_content + mocks['convert_pdf_bytes_to_lxml'].side_effect = fake_lxml_for_pdf + yield extend_dict( + mocks, + {'tfrecords': tfrecords_mock} + ) -MIN_ARGV = [ - '--data-path=' + BASE_DATA_PATH, - '--pdf-path=' + PDF_PATH, - '--xml-path=' + XML_PATH, - '--save-svg' -] -def get_default_args(): - return parse_args([ +MIN_ARGV = [ '--data-path=' + BASE_DATA_PATH, '--pdf-path=' + PDF_PATH, '--xml-path=' + XML_PATH, '--save-svg' - ]) +] + + +def get_default_args(): + return parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-path=' + PDF_PATH, + '--xml-path=' + XML_PATH, + '--save-svg' + ]) + def page_uri_suffix(page_no): - return '#page%d' % page_no + return '#page%d' % page_no + def _expected_tfrecord_props(pdf_file, page_no=1): - return { - 'input_uri': pdf_file + page_uri_suffix(page_no), - 'annotation_uri': pdf_file + '.annot' + page_uri_suffix(page_no), - 'input_image': fake_pdf_png_page(page_no), - 'annotation_image': fake_block_png_page(page_no), - 'page_no': page_no - } + return { + 'input_uri': pdf_file + page_uri_suffix(page_no), + 'annotation_uri': pdf_file + '.annot' + page_uri_suffix(page_no), + 'input_image': fake_pdf_png_page(page_no), + 'annotation_image': fake_block_png_page(page_no), + 'page_no': page_no + } + def _setup_mocks_for_pages(mocks, page_no_list, file_count=1): - mocks['convert_and_annotate_lxml_content'].return_value = [ - fake_svg_page(i) for i in page_no_list - ] - mocks['pdf_bytes_to_png_pages'].return_value = [ - fake_pdf_png_page(i) for i in page_no_list - ] - mocks['svg_page_to_blockified_png_bytes'].side_effect = [ - fake_block_png_page(i) - for _ in range(file_count) - for i in page_no_list - ] + mocks['convert_and_annotate_lxml_content'].return_value = [ + fake_svg_page(i) for i in page_no_list + ] + 
mocks['pdf_bytes_to_png_pages'].return_value = [ + fake_pdf_png_page(i) for i in page_no_list + ] + mocks['svg_page_to_blockified_png_bytes'].side_effect = [ + fake_block_png_page(i) + for _ in range(file_count) + for i in page_no_list + ] + @pytest.mark.slow class TestConfigurePipeline(BeamTest): - def test_should_pass_pdf_and_xml_patterns_to_find_file_pairs_grouped_by_parent_directory(self): - with patch_preprocessing_pipeline() as mocks: - opt = get_default_args() - opt.base_data_path = 'base' - opt.pdf_path = 'pdf' - opt.xml_path = 'xml' - with TestPipeline() as p: - mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [] - configure_pipeline(p, opt) - - mocks['find_file_pairs_grouped_by_parent_directory_or_name'].assert_called_with( - ['base/pdf', 'base/xml'] - ) - - def test_should_pass_lxml_and_xml_patterns_to_find_file_pairs_grouped_by_parent_directory(self): - with patch_preprocessing_pipeline() as mocks: - opt = get_default_args() - opt.base_data_path = 'base' - opt.pdf_path = '' - opt.lxml_path = 'lxml' - opt.xml_path = 'xml' - with TestPipeline() as p: - mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [] - configure_pipeline(p, opt) - - mocks['find_file_pairs_grouped_by_parent_directory_or_name'].assert_called_with( - ['base/lxml', 'base/xml'] - ) - - def test_should_write_tfrecords_from_pdf_xml_file_list(self): - with patch_preprocessing_pipeline() as mocks: - opt = get_default_args() - opt.pdf_path = None - opt.xml_path = None - opt.pdf_xml_file_list = '.temp/file-list.tsv' - opt.save_tfrecords = True - with TestPipeline() as p: - mocks['ReadDictCsv'].return_value = beam.Create([{ - 'source_url': PDF_FILE_1, - 'xml_url': XML_FILE_1 - }]) - _setup_mocks_for_pages(mocks, [1]) - configure_pipeline(p, opt) - - mocks['ReadDictCsv'].assert_called_with(opt.pdf_xml_file_list, limit=None) - mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [ - _expected_tfrecord_props(PDF_FILE_1) - ]) - - def 
test_should_write_multiple_tfrecords_from_pdf_xml_file_list(self): - with patch_preprocessing_pipeline() as mocks: - opt = get_default_args() - opt.pdf_path = None - opt.xml_path = None - opt.pdf_xml_file_list = '.temp/file-list.tsv' - opt.save_tfrecords = True - with TestPipeline() as p: - mocks['ReadDictCsv'].return_value = beam.Create([{ - 'source_url': PDF_FILE_1, - 'xml_url': XML_FILE_1 - }, { - 'source_url': PDF_FILE_2, - 'xml_url': XML_FILE_2 - }]) - _setup_mocks_for_pages(mocks, [1], file_count=2) - configure_pipeline(p, opt) - - mocks['ReadDictCsv'].assert_called_with(opt.pdf_xml_file_list, limit=None) - for pdf_file in [PDF_FILE_1, PDF_FILE_2]: - mocks['tfrecords'].assert_any_call(opt.output_path + '/data', [ - _expected_tfrecord_props(pdf_file) - ]) - assert mocks['tfrecords'].call_count == 2 - - def test_should_pass_limit_to_read_dict_csv(self): - with patch_preprocessing_pipeline() as mocks: - opt = get_default_args() - opt.pdf_path = None - opt.xml_path = None - opt.pdf_xml_file_list = '.temp/file-list.tsv' - opt.limit = 1 - opt.save_tfrecords = True - with TestPipeline() as p: - mocks['ReadDictCsv'].return_value = beam.Create([{ - 'source_url': PDF_FILE_1, - 'xml_url': XML_FILE_1 - }]) - _setup_mocks_for_pages(mocks, [1]) - configure_pipeline(p, opt) - - mocks['ReadDictCsv'].assert_called_with(opt.pdf_xml_file_list, limit=opt.limit) - assert mocks['tfrecords'].call_count == 1 - - def test_should_pass_limit_to_find_file_pairs_grouped_by_parent_directory_or_name(self): - with patch_preprocessing_pipeline() as mocks: - opt = get_default_args() - opt.base_data_path = 'base' - opt.pdf_path = '' - opt.lxml_path = 'lxml' - opt.xml_path = 'xml' - opt.save_tfrecords = True - opt.limit = 1 - with TestPipeline() as p: - mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [ - (PDF_FILE_1, XML_FILE_1), - (PDF_FILE_2, XML_FILE_2) - ] - configure_pipeline(p, opt) - - mocks['tfrecords'].call_count == 1 - - def 
test_should_write_tfrecords_from_pdf_xml_path(self): - with patch_preprocessing_pipeline() as mocks: - opt = get_default_args() - opt.save_tfrecords = True - with TestPipeline() as p: - mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [ - (PDF_FILE_1, XML_FILE_1) - ] - _setup_mocks_for_pages(mocks, [1]) - configure_pipeline(p, opt) - - mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [ - _expected_tfrecord_props(PDF_FILE_1) - ]) - - def test_should_write_multiple_tfrecords_and_count_pages(self): - with patch_preprocessing_pipeline() as mocks: - opt = get_default_args() - opt.save_tfrecords = True - with TestPipeline() as p: - mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [ - (PDF_FILE_1, XML_FILE_1) - ] - _setup_mocks_for_pages(mocks, [1, 2]) - configure_pipeline(p, opt) - - p_result = p.run() - assert get_counter_value(p_result, MetricCounters.FILE_PAIR) == 1 - assert get_counter_value(p_result, MetricCounters.PAGE) == 2 - assert get_counter_value(p_result, MetricCounters.FILTERED_PAGE) is None - - mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [ - _expected_tfrecord_props(PDF_FILE_1, page_no=i) - for i in [1, 2] - ]) - - def test_should_not_write_tfrecord_below_annotation_threshold_and_count_pages(self): - custom_mocks = dict( - evaluate_document_by_page=lambda _: [{ - 'percentage': { - # low percentage of None (no annotation, include) - None: 0.1 - } - }, { - 'percentage': { - # low percentage of None (no annotation, exclude) - None: 0.9 - } - }] - ) - with patch_preprocessing_pipeline(**custom_mocks) as mocks: - opt = get_default_args() - opt.save_tfrecords = True - opt.min_annotation_percentage = 0.5 - with TestPipeline() as p: - mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [ - (PDF_FILE_1, XML_FILE_1) - ] - _setup_mocks_for_pages(mocks, [1, 2]) - configure_pipeline(p, opt) - - p_result = p.run() - assert get_counter_value(p_result, 
MetricCounters.FILE_PAIR) == 1 - assert get_counter_value(p_result, MetricCounters.PAGE) == 2 - assert get_counter_value(p_result, MetricCounters.FILTERED_PAGE) == 1 - - mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [ - _expected_tfrecord_props(PDF_FILE_1, page_no=i) - for i in [1] - ]) - - def test_should_only_process_selected_pages(self): - with patch_preprocessing_pipeline() as mocks: - opt = get_default_args() - opt.save_tfrecords = True - opt.save_png = True - opt.pages = (1, 3) - with TestPipeline() as p: - mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [ - (PDF_FILE_1, XML_FILE_1) - ] - _setup_mocks_for_pages(mocks, [1, 2]) - configure_pipeline(p, opt) - - assert mocks['convert_pdf_bytes_to_lxml'].called - assert mocks['convert_pdf_bytes_to_lxml'].call_args[1].get('page_range') == opt.pages - - assert mocks['pdf_bytes_to_png_pages'].called - assert mocks['pdf_bytes_to_png_pages'].call_args[1].get('page_range') == opt.pages + def test_should_pass_pdf_and_xml_patterns_to_find_file_pairs_grouped_by_parent_directory(self): + with patch_preprocessing_pipeline() as mocks: + opt = get_default_args() + opt.base_data_path = 'base' + opt.pdf_path = 'pdf' + opt.xml_path = 'xml' + with TestPipeline() as p: + mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [] + configure_pipeline(p, opt) + + mocks['find_file_pairs_grouped_by_parent_directory_or_name'].assert_called_with( + ['base/pdf', 'base/xml'] + ) + + def test_should_pass_lxml_and_xml_patterns_to_find_file_pairs_grouped_by_parent_directory(self): + with patch_preprocessing_pipeline() as mocks: + opt = get_default_args() + opt.base_data_path = 'base' + opt.pdf_path = '' + opt.lxml_path = 'lxml' + opt.xml_path = 'xml' + with TestPipeline() as p: + mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [] + configure_pipeline(p, opt) + + mocks['find_file_pairs_grouped_by_parent_directory_or_name'].assert_called_with( + 
['base/lxml', 'base/xml'] + ) + + def test_should_write_tfrecords_from_pdf_xml_file_list(self): + with patch_preprocessing_pipeline() as mocks: + opt = get_default_args() + opt.pdf_path = None + opt.xml_path = None + opt.pdf_xml_file_list = '.temp/file-list.tsv' + opt.save_tfrecords = True + with TestPipeline() as p: + mocks['ReadDictCsv'].return_value = beam.Create([{ + 'source_url': PDF_FILE_1, + 'xml_url': XML_FILE_1 + }]) + _setup_mocks_for_pages(mocks, [1]) + configure_pipeline(p, opt) + + mocks['ReadDictCsv'].assert_called_with(opt.pdf_xml_file_list, limit=None) + mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [ + _expected_tfrecord_props(PDF_FILE_1) + ]) + + def test_should_write_multiple_tfrecords_from_pdf_xml_file_list(self): + with patch_preprocessing_pipeline() as mocks: + opt = get_default_args() + opt.pdf_path = None + opt.xml_path = None + opt.pdf_xml_file_list = '.temp/file-list.tsv' + opt.save_tfrecords = True + with TestPipeline() as p: + mocks['ReadDictCsv'].return_value = beam.Create([{ + 'source_url': PDF_FILE_1, + 'xml_url': XML_FILE_1 + }, { + 'source_url': PDF_FILE_2, + 'xml_url': XML_FILE_2 + }]) + _setup_mocks_for_pages(mocks, [1], file_count=2) + configure_pipeline(p, opt) + + mocks['ReadDictCsv'].assert_called_with(opt.pdf_xml_file_list, limit=None) + for pdf_file in [PDF_FILE_1, PDF_FILE_2]: + mocks['tfrecords'].assert_any_call(opt.output_path + '/data', [ + _expected_tfrecord_props(pdf_file) + ]) + assert mocks['tfrecords'].call_count == 2 + + def test_should_pass_limit_to_read_dict_csv(self): + with patch_preprocessing_pipeline() as mocks: + opt = get_default_args() + opt.pdf_path = None + opt.xml_path = None + opt.pdf_xml_file_list = '.temp/file-list.tsv' + opt.limit = 1 + opt.save_tfrecords = True + with TestPipeline() as p: + mocks['ReadDictCsv'].return_value = beam.Create([{ + 'source_url': PDF_FILE_1, + 'xml_url': XML_FILE_1 + }]) + _setup_mocks_for_pages(mocks, [1]) + configure_pipeline(p, opt) + + 
mocks['ReadDictCsv'].assert_called_with(opt.pdf_xml_file_list, limit=opt.limit) + assert mocks['tfrecords'].call_count == 1 + + def test_should_pass_limit_to_find_file_pairs_grouped_by_parent_directory_or_name(self): + with patch_preprocessing_pipeline() as mocks: + opt = get_default_args() + opt.base_data_path = 'base' + opt.pdf_path = 'pdf' + opt.lxml_path = '' + opt.xml_path = 'xml' + opt.save_tfrecords = True + opt.limit = 1 + with TestPipeline() as p: + mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [ + (PDF_FILE_1, XML_FILE_1), + (PDF_FILE_2, XML_FILE_2) + ] + configure_pipeline(p, opt) + + assert mocks['tfrecords'].call_count == 1 + + def test_should_write_tfrecords_from_pdf_xml_path(self): + with patch_preprocessing_pipeline() as mocks: + opt = get_default_args() + opt.save_tfrecords = True + with TestPipeline() as p: + mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [ + (PDF_FILE_1, XML_FILE_1) + ] + _setup_mocks_for_pages(mocks, [1]) + configure_pipeline(p, opt) + + mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [ + _expected_tfrecord_props(PDF_FILE_1) + ]) + + def test_should_write_multiple_tfrecords_and_count_pages(self): + with patch_preprocessing_pipeline() as mocks: + opt = get_default_args() + opt.save_tfrecords = True + with TestPipeline() as p: + mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [ + (PDF_FILE_1, XML_FILE_1) + ] + _setup_mocks_for_pages(mocks, [1, 2]) + configure_pipeline(p, opt) + + p_result = p.run() + assert get_counter_value(p_result, MetricCounters.FILE_PAIR) == 1 + assert get_counter_value(p_result, MetricCounters.PAGE) == 2 + assert get_counter_value(p_result, MetricCounters.FILTERED_PAGE) is None + + mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [ + _expected_tfrecord_props(PDF_FILE_1, page_no=i) + for i in [1, 2] + ]) + + def test_should_not_write_tfrecord_below_annotation_threshold_and_count_pages(self): 
+ custom_mocks = dict( + evaluate_document_by_page=lambda _: [{ + 'percentage': { + # low percentage of None (no annotation, include) + None: 0.1 + } + }, { + 'percentage': { + # low percentage of None (no annotation, exclude) + None: 0.9 + } + }] + ) + with patch_preprocessing_pipeline(**custom_mocks) as mocks: + opt = get_default_args() + opt.save_tfrecords = True + opt.min_annotation_percentage = 0.5 + with TestPipeline() as p: + mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [ + (PDF_FILE_1, XML_FILE_1) + ] + _setup_mocks_for_pages(mocks, [1, 2]) + configure_pipeline(p, opt) + + p_result = p.run() + assert get_counter_value(p_result, MetricCounters.FILE_PAIR) == 1 + assert get_counter_value(p_result, MetricCounters.PAGE) == 2 + assert get_counter_value(p_result, MetricCounters.FILTERED_PAGE) == 1 + + mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [ + _expected_tfrecord_props(PDF_FILE_1, page_no=i) + for i in [1] + ]) + + def test_should_only_process_selected_pages(self): + with patch_preprocessing_pipeline() as mocks: + opt = get_default_args() + opt.save_tfrecords = True + opt.save_png = True + opt.pages = (1, 3) + with TestPipeline() as p: + mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [ + (PDF_FILE_1, XML_FILE_1) + ] + _setup_mocks_for_pages(mocks, [1, 2]) + configure_pipeline(p, opt) + + assert mocks['convert_pdf_bytes_to_lxml'].called + assert mocks['convert_pdf_bytes_to_lxml'].call_args[1].get('page_range') == opt.pages + + assert mocks['pdf_bytes_to_png_pages'].called + assert mocks['pdf_bytes_to_png_pages'].call_args[1].get('page_range') == opt.pages -class TestParseArgs(object): - def test_should_raise_error_without_arguments(self): - with pytest.raises(SystemExit): - parse_args([]) - def test_should_not_raise_error_with_minimum_arguments(self): - parse_args(['--data-path=test', '--pdf-path=test', '--xml-path=test', '--save-svg']) +class TestParseArgs(object): + def 
test_should_raise_error_without_arguments(self): + with pytest.raises(SystemExit): + parse_args([]) - def test_should_not_raise_error_with_lxml_path_instead_of_pdf_path(self): - parse_args(['--data-path=test', '--lxml-path=test', '--xml-path=test', '--save-svg']) + def test_should_not_raise_error_with_minimum_arguments(self): + parse_args(['--data-path=test', '--pdf-path=test', '--xml-path=test', '--save-svg']) - def test_should_raise_error_if_no_output_option_specified(self): - with pytest.raises(SystemExit): - parse_args(['--data-path=test', '--pdf-path=test', '--xml-path=test']) + def test_should_not_raise_error_with_lxml_path_instead_of_pdf_path(self): + parse_args(['--data-path=test', '--lxml-path=test', '--xml-path=test', '--save-svg']) - def test_should_raise_error_if_pdf_and_lxml_path_are_specified(self): - with pytest.raises(SystemExit): - parse_args([ - '--data-path=test', '--pdf-path=test', '--lxml-path=test', '--xml-path=test', - '--save-svg' - ]) + def test_should_raise_error_if_no_output_option_specified(self): + with pytest.raises(SystemExit): + parse_args(['--data-path=test', '--pdf-path=test', '--xml-path=test']) - def test_should_raise_error_if_pdf_path_specified_without_xml_path(self): - with pytest.raises(SystemExit): - parse_args(['--data-path=test', '--pdf-path=test', '--save-svg']) + def test_should_raise_error_if_pdf_and_lxml_path_are_specified(self): + with pytest.raises(SystemExit): + parse_args([ + '--data-path=test', '--pdf-path=test', '--lxml-path=test', '--xml-path=test', + '--save-svg' + ]) - def test_should_not_raise_error_if_pdf_xml_file_list_specified_without_xml_path(self): - parse_args(['--data-path=test', '--pdf-xml-file-list=test', '--save-svg']) + def test_should_raise_error_if_pdf_path_specified_without_xml_path(self): + with pytest.raises(SystemExit): + parse_args(['--data-path=test', '--pdf-path=test', '--save-svg']) - def test_should_not_raise_error_with_save_lxml_path_together_with_pdf_path(self): - 
parse_args(['--data-path=test', '--pdf-path=test', '--save-lxml', '--xml-path=test']) + def test_should_not_raise_error_if_pdf_xml_file_list_specified_without_xml_path(self): + parse_args(['--data-path=test', '--pdf-xml-file-list=test', '--save-svg']) - def test_should_not_raise_error_with_save_lxml_path_together_with_pdf_xml_file_list(self): - parse_args(['--data-path=test', '--pdf-xml-file-list=test', '--save-lxml', '--xml-path=test']) + def test_should_not_raise_error_with_save_lxml_path_together_with_pdf_path(self): + parse_args(['--data-path=test', '--pdf-path=test', '--save-lxml', '--xml-path=test']) - def test_should_raise_error_if_save_lxml_specified_without_pdf_path(self): - with pytest.raises(SystemExit): - parse_args(['--data-path=test', '--lxml-path=test', '--save-lxml', '--xml-path=test']) + def test_should_not_raise_error_with_save_lxml_path_together_with_pdf_xml_file_list(self): + parse_args(['--data-path=test', '--pdf-xml-file-list=test', + '--save-lxml', '--xml-path=test']) - def test_should_raise_error_if_save_png_is_specified_without_pdf_path(self): - with pytest.raises(SystemExit): - parse_args(['--data-path=test', '--lxml-path=test', '--save-png', '--xml-path=test']) + def test_should_raise_error_if_save_lxml_specified_without_pdf_path(self): + with pytest.raises(SystemExit): + parse_args(['--data-path=test', '--lxml-path=test', '--save-lxml', '--xml-path=test']) - def test_should_not_raise_error_with_save_png_path_together_with_pdf_path(self): - parse_args(['--data-path=test', '--pdf-path=test', '--save-png', '--xml-path=test']) + def test_should_raise_error_if_save_png_is_specified_without_pdf_path(self): + with pytest.raises(SystemExit): + parse_args(['--data-path=test', '--lxml-path=test', '--save-png', '--xml-path=test']) - def test_should_not_raise_error_with_save_png_path_together_with_pdf_xml_file_list(self): - parse_args(['--data-path=test', '--pdf-xml-file-list=test', '--save-png', '--xml-path=test']) + def 
test_should_not_raise_error_with_save_png_path_together_with_pdf_path(self): + parse_args(['--data-path=test', '--pdf-path=test', '--save-png', '--xml-path=test']) - def test_should_raise_error_if_image_width_was_specified_without_image_height(self): - with pytest.raises(SystemExit): - parse_args([ - '--data-path=test', '--pdf-path=test', '--xml-path=test', - '--save-png', '--image-width=100' - ]) - - def test_should_raise_error_if_image_height_was_specified_without_image_width(self): - with pytest.raises(SystemExit): - parse_args([ - '--data-path=test', '--pdf-path=test', '--xml-path=test', - '--save-png', '--image-height=100' - ]) + def test_should_not_raise_error_with_save_png_path_together_with_pdf_xml_file_list(self): + parse_args([ + '--data-path=test', '--pdf-xml-file-list=test', '--save-png', '--xml-path=test' + ]) - def test_should_not_raise_error_if_both_image_width_and_height_are_specified(self): - parse_args([ - '--data-path=test', '--pdf-path=test', '--xml-path=test', - '--save-png', '--image-width=100', '--image-height=100' - ]) + def test_should_raise_error_if_image_width_was_specified_without_image_height(self): + with pytest.raises(SystemExit): + parse_args([ + '--data-path=test', '--pdf-path=test', '--xml-path=test', + '--save-png', '--image-width=100' + ]) + + def test_should_raise_error_if_image_height_was_specified_without_image_width(self): + with pytest.raises(SystemExit): + parse_args([ + '--data-path=test', '--pdf-path=test', '--xml-path=test', + '--save-png', '--image-height=100' + ]) + + def test_should_not_raise_error_if_both_image_width_and_height_are_specified(self): + parse_args([ + '--data-path=test', '--pdf-path=test', '--xml-path=test', + '--save-png', '--image-width=100', '--image-height=100' + ]) - def test_should_raise_error_if_save_tfrecords_specified_without_pdf_path(self): - with pytest.raises(SystemExit): - parse_args(['--data-path=test', '--lxml-path=test', '--xml-path=test', '--save-tfrecords']) + def 
test_should_raise_error_if_save_tfrecords_specified_without_pdf_path(self): + with pytest.raises(SystemExit): + parse_args(['--data-path=test', '--lxml-path=test', + '--xml-path=test', '--save-tfrecords']) - def test_should_not_raise_error_if_save_tfrecords_specified_with_pdf_path(self): - parse_args(['--data-path=test', '--pdf-path=test', '--xml-path=test', '--save-tfrecords']) + def test_should_not_raise_error_if_save_tfrecords_specified_with_pdf_path(self): + parse_args(['--data-path=test', '--pdf-path=test', '--xml-path=test', '--save-tfrecords']) - def test_should_not_raise_error_if_save_tfrecords_specified_with_pdf_xml_file_list(self): - parse_args([ - '--data-path=test', '--pdf-xml-file-list=test', '--xml-path=test', '--save-tfrecords' - ]) + def test_should_not_raise_error_if_save_tfrecords_specified_with_pdf_xml_file_list(self): + parse_args([ + '--data-path=test', '--pdf-xml-file-list=test', '--xml-path=test', '--save-tfrecords' + ]) - def test_should_have_none_page_range_by_default(self): - assert parse_args(MIN_ARGV).pages is None + def test_should_have_none_page_range_by_default(self): + assert parse_args(MIN_ARGV).pages is None - def test_should_parse_pages_as_list(self): - assert parse_args(MIN_ARGV + ['--pages=1-3']).pages == (1, 3) + def test_should_parse_pages_as_list(self): + assert parse_args(MIN_ARGV + ['--pages=1-3']).pages == (1, 3) diff --git a/sciencebeam_gym/preprocess/preprocessing_transforms.py b/sciencebeam_gym/preprocess/preprocessing_transforms.py index 20bfe32..5cbf7bf 100644 --- a/sciencebeam_gym/preprocess/preprocessing_transforms.py +++ b/sciencebeam_gym/preprocess/preprocessing_transforms.py @@ -1,34 +1,34 @@ import apache_beam as beam try: - import tensorflow as tf + import tensorflow as tf - from sciencebeam_gym.utils.tfrecord import ( - dict_to_example - ) + from sciencebeam_gym.utils.tfrecord import ( + dict_to_example + ) except ImportError: - # Make tensorflow optional - tf = None + # Make tensorflow optional + tf = None 
class WritePropsToTFRecord(beam.PTransform): - def __init__(self, file_path, extract_props): - super(WritePropsToTFRecord, self).__init__() - self.file_path = file_path - self.extract_props = extract_props - if tf is None: - raise RuntimeError('TensorFlow required for this transform') + def __init__(self, file_path, extract_props): + super(WritePropsToTFRecord, self).__init__() + self.file_path = file_path + self.extract_props = extract_props + if tf is None: + raise RuntimeError('TensorFlow required for this transform') - def expand(self, pcoll): # pylint: disable=W0221 - return ( - pcoll | - 'ConvertToTfExamples' >> beam.FlatMap(lambda v: ( - dict_to_example(props) - for props in self.extract_props(v) - )) | - 'SerializeToString' >> beam.Map(lambda x: x.SerializeToString()) | - 'SaveToTfRecords' >> beam.io.WriteToTFRecord( - self.file_path, - file_name_suffix='.tfrecord.gz' - ) - ) + def expand(self, pcoll): # pylint: disable=W0221 + return ( + pcoll | + 'ConvertToTfExamples' >> beam.FlatMap(lambda v: ( + dict_to_example(props) + for props in self.extract_props(v) + )) | + 'SerializeToString' >> beam.Map(lambda x: x.SerializeToString()) | + 'SaveToTfRecords' >> beam.io.WriteToTFRecord( + self.file_path, + file_name_suffix='.tfrecord.gz' + ) + ) diff --git a/sciencebeam_gym/preprocess/preprocessing_transforms_test.py b/sciencebeam_gym/preprocess/preprocessing_transforms_test.py index 732700c..9b1a170 100644 --- a/sciencebeam_gym/preprocess/preprocessing_transforms_test.py +++ b/sciencebeam_gym/preprocess/preprocessing_transforms_test.py @@ -4,20 +4,20 @@ import apache_beam as beam from apache_beam.io.filesystems import FileSystems from sciencebeam_utils.beam_utils.io import ( - find_matching_filenames + find_matching_filenames ) from sciencebeam_utils.beam_utils.testing import ( - BeamTest, - TestPipeline + BeamTest, + TestPipeline ) from sciencebeam_gym.utils.tfrecord import ( - iter_read_tfrecord_file_as_dict_list + iter_read_tfrecord_file_as_dict_list ) from 
sciencebeam_gym.preprocess.preprocessing_transforms import ( - WritePropsToTFRecord + WritePropsToTFRecord ) TFRECORDS_PATH = '.temp/test-data' @@ -28,35 +28,36 @@ KEY_2 = b'key2' VALUE_1 = b'value 1' VALUE_2 = b'value 2' + @pytest.mark.slow class TestWritePropsToTFRecord(BeamTest): - def test_should_write_single_entry(self): - dict_list = [{KEY_1: VALUE_1}] - with TestPipeline() as p: - _ = ( - p | - beam.Create(dict_list) | - WritePropsToTFRecord(TFRECORDS_PATH, lambda x: [x]) - ) - filenames = list(find_matching_filenames(TFRECORDS_PATH + '*')) - assert len(filenames) == 1 - records = list(iter_read_tfrecord_file_as_dict_list(filenames[0])) - assert records == dict_list - FileSystems.delete(filenames) - - def test_should_write_multiple_entries(self): - dict_list = [ - {KEY_1: VALUE_1}, - {KEY_2: VALUE_2} - ] - with TestPipeline() as p: - _ = ( - p | - beam.Create(dict_list) | - WritePropsToTFRecord(TFRECORDS_PATH, lambda x: [x]) - ) - filenames = list(find_matching_filenames(TFRECORDS_PATH + '*')) - assert len(filenames) == 1 - records = list(iter_read_tfrecord_file_as_dict_list(filenames[0])) - assert records == dict_list - FileSystems.delete(filenames) + def test_should_write_single_entry(self): + dict_list = [{KEY_1: VALUE_1}] + with TestPipeline() as p: + _ = ( # flake8: noqa + p | + beam.Create(dict_list) | + WritePropsToTFRecord(TFRECORDS_PATH, lambda x: [x]) + ) + filenames = list(find_matching_filenames(TFRECORDS_PATH + '*')) + assert len(filenames) == 1 + records = list(iter_read_tfrecord_file_as_dict_list(filenames[0])) + assert records == dict_list + FileSystems.delete(filenames) + + def test_should_write_multiple_entries(self): + dict_list = [ + {KEY_1: VALUE_1}, + {KEY_2: VALUE_2} + ] + with TestPipeline() as p: + _ = ( # flake8: noqa + p | + beam.Create(dict_list) | + WritePropsToTFRecord(TFRECORDS_PATH, lambda x: [x]) + ) + filenames = list(find_matching_filenames(TFRECORDS_PATH + '*')) + assert len(filenames) == 1 + records = 
list(iter_read_tfrecord_file_as_dict_list(filenames[0])) + assert records == dict_list + FileSystems.delete(filenames) diff --git a/sciencebeam_gym/preprocess/preprocessing_utils.py b/sciencebeam_gym/preprocess/preprocessing_utils.py index 62d7be8..0800000 100644 --- a/sciencebeam_gym/preprocess/preprocessing_utils.py +++ b/sciencebeam_gym/preprocess/preprocessing_utils.py @@ -1,217 +1,208 @@ from __future__ import absolute_import -import os import logging from io import BytesIO -from functools import reduce # pylint: disable=W0622 from six import iteritems from lxml import etree -from apache_beam.io.filesystems import FileSystems - from sciencebeam_alignment.align import ( - native_enabled as align_native_enabled -) - -from sciencebeam_utils.beam_utils.io import ( - find_matching_filenames + native_enabled as align_native_enabled ) from sciencebeam_utils.utils.xml import ( - xml_from_string_with_recover + xml_from_string_with_recover ) from sciencebeam_utils.utils.stopwatch import ( - StopWatchRecorder -) - -from sciencebeam_utils.utils.collection import ( - groupby_to_dict, - sort_and_groupby_to_dict -) - -from sciencebeam_utils.utils.file_path import ( - relative_path + StopWatchRecorder ) from sciencebeam_gym.utils.pages_zip import ( - save_pages + save_pages ) from sciencebeam_gym.preprocess.lxml_to_svg import ( - iter_svg_pages_for_lxml + iter_svg_pages_for_lxml ) from sciencebeam_gym.structured_document.svg import ( - SvgStructuredDocument + SvgStructuredDocument ) from sciencebeam_gym.preprocess.annotation.annotator import ( - Annotator, - DEFAULT_ANNOTATORS + Annotator, + DEFAULT_ANNOTATORS ) from sciencebeam_gym.preprocess.annotation.matching_annotator import ( - MatchingAnnotator + MatchingAnnotator ) from sciencebeam_gym.preprocess.annotation.target_annotation import ( - xml_root_to_target_annotations + xml_root_to_target_annotations ) from sciencebeam_gym.preprocess.visualize_svg_annotation import ( - visualize_svg_annotations + 
visualize_svg_annotations ) from sciencebeam_gym.preprocess.blockify_annotations import ( - annotation_document_page_to_annotation_blocks, - merge_blocks, - expand_blocks, - annotated_blocks_to_image + annotation_document_page_to_annotation_blocks, + merge_blocks, + expand_blocks, + annotated_blocks_to_image ) from sciencebeam_gym.pdf import ( - PdfToLxmlWrapper, - PdfToPng + PdfToLxmlWrapper, + PdfToPng ) def get_logger(): - return logging.getLogger(__name__) - -def convert_pdf_bytes_to_lxml(pdf_content, path=None, page_range=None): - stop_watch_recorder = StopWatchRecorder() + return logging.getLogger(__name__) - args = '-blocks -noImageInline -noImage -fullFontName'.split() - if page_range: - args += ['-f', str(page_range[0]), '-l', str(page_range[1])] - stop_watch_recorder.start('convert to lxml') - lxml_content = PdfToLxmlWrapper().process_input( - pdf_content, - args - ) - stop_watch_recorder.stop() - - get_logger().info( - 'converted to lxml: path=%s, pdf size=%s, lxml size=%s, timings=[%s]', - path, format(len(pdf_content), ','), format(len(lxml_content), ','), - stop_watch_recorder - ) +def convert_pdf_bytes_to_lxml(pdf_content, path=None, page_range=None): + stop_watch_recorder = StopWatchRecorder() - return lxml_content + args = '-blocks -noImageInline -noImage -fullFontName'.split() + if page_range: + args += ['-f', str(page_range[0]), '-l', str(page_range[1])] -def convert_and_annotate_lxml_content(lxml_content, xml_content, xml_mapping, name=None): - stop_watch_recorder = StopWatchRecorder() + stop_watch_recorder.start('convert to lxml') + lxml_content = PdfToLxmlWrapper().process_input( + pdf_content, + args + ) + stop_watch_recorder.stop() - stop_watch_recorder.start('parse lxml') - lxml_root = etree.fromstring(lxml_content) + get_logger().info( + 'converted to lxml: path=%s, pdf size=%s, lxml size=%s, timings=[%s]', + path, format(len(pdf_content), ','), format(len(lxml_content), ','), + stop_watch_recorder + ) - # use a more lenient way to parse 
xml as xml errors are not uncomment - stop_watch_recorder.start('parse xml') - xml_root = xml_from_string_with_recover(xml_content) + return lxml_content - stop_watch_recorder.start('extract target annotations') - target_annotations = xml_root_to_target_annotations( - xml_root, - xml_mapping - ) - stop_watch_recorder.stop() - annotators = DEFAULT_ANNOTATORS + [MatchingAnnotator( - target_annotations, - use_tag_begin_prefix=True - )] - annotator = Annotator(annotators) +def convert_and_annotate_lxml_content(lxml_content, xml_content, xml_mapping, name=None): + stop_watch_recorder = StopWatchRecorder() - stop_watch_recorder.start('convert to svg') - svg_roots = list(iter_svg_pages_for_lxml(lxml_root)) + stop_watch_recorder.start('parse lxml') + lxml_root = etree.fromstring(lxml_content) - stop_watch_recorder.start('annotate svg') - annotator.annotate(SvgStructuredDocument(svg_roots)) + # use a more lenient way to parse xml as xml errors are not uncomment + stop_watch_recorder.start('parse xml') + xml_root = xml_from_string_with_recover(xml_content) - stop_watch_recorder.start('add visualisation') - svg_roots = [ - visualize_svg_annotations(svg_root) - for svg_root in svg_roots - ] - stop_watch_recorder.stop() + stop_watch_recorder.start('extract target annotations') + target_annotations = xml_root_to_target_annotations( + xml_root, + xml_mapping + ) + stop_watch_recorder.stop() + + annotators = DEFAULT_ANNOTATORS + [MatchingAnnotator( + target_annotations, + use_tag_begin_prefix=True + )] + annotator = Annotator(annotators) + + stop_watch_recorder.start('convert to svg') + svg_roots = list(iter_svg_pages_for_lxml(lxml_root)) + + stop_watch_recorder.start('annotate svg') + annotator.annotate(SvgStructuredDocument(svg_roots)) + + stop_watch_recorder.start('add visualisation') + svg_roots = [ + visualize_svg_annotations(svg_root) + for svg_root in svg_roots + ] + stop_watch_recorder.stop() + + get_logger().info( + 'processed: name=%s, lxml size=%s, xml size=%s, 
timings=[%s] (native align impl=%s)', + name, format(len(lxml_content), ','), format(len(xml_content), ','), + stop_watch_recorder, align_native_enabled + ) - get_logger().info( - 'processed: name=%s, lxml size=%s, xml size=%s, timings=[%s] (native align impl=%s)', - name, format(len(lxml_content), ','), format(len(xml_content), ','), - stop_watch_recorder, align_native_enabled - ) + return svg_roots - return svg_roots def save_svg_roots(output_filename, svg_pages): - return save_pages(output_filename, '.svg', ( - etree.tostring(svg_page) - for svg_page in svg_pages - )) + return save_pages(output_filename, '.svg', ( + etree.tostring(svg_page) + for svg_page in svg_pages + )) + def pdf_bytes_to_png_pages(pdf_bytes, dpi, image_size, page_range=None): - pdf_to_png = PdfToPng(dpi=dpi, image_size=image_size, page_range=page_range) - return ( - fp.read() - for fp in pdf_to_png.iter_pdf_bytes_to_png_fp(pdf_bytes) - ) + pdf_to_png = PdfToPng(dpi=dpi, image_size=image_size, page_range=page_range) + return ( + fp.read() + for fp in pdf_to_png.iter_pdf_bytes_to_png_fp(pdf_bytes) + ) + def svg_page_to_blockified_png_bytes(svg_page, color_map, image_size=None): - structured_document = SvgStructuredDocument(svg_page) - blocks = expand_blocks( - merge_blocks( - annotation_document_page_to_annotation_blocks( - structured_document, - structured_document.get_pages()[0] - ) + structured_document = SvgStructuredDocument(svg_page) + blocks = expand_blocks( + merge_blocks( + annotation_document_page_to_annotation_blocks( + structured_document, + structured_document.get_pages()[0] + ) + ) ) - ) - viewbox = svg_page.attrib.get('viewBox') - if not viewbox: - raise RuntimeError( - 'viewbox missing on svg, available attributes: %s' % svg_page.attrib.keys() + viewbox = svg_page.attrib.get('viewBox') + if not viewbox: + raise RuntimeError( + 'viewbox missing on svg, available attributes: %s' % svg_page.attrib.keys() + ) + _, _, width, height = viewbox.split() + image = 
annotated_blocks_to_image( + blocks, color_map, + width=float(width), height=float(height), background='white', + scale_to_size=image_size ) - _, _, width, height = viewbox.split() - image = annotated_blocks_to_image( - blocks, color_map, - width=float(width), height=float(height), background='white', - scale_to_size=image_size - ) - out = BytesIO() - image.save(out, 'png') - return out.getvalue() + out = BytesIO() + image.save(out, 'png') + return out.getvalue() + def filter_list_props_by_indices(d, indices, list_props): - return { - k: ( - [x for i, x in enumerate(v) if i in indices] - if k in list_props - else v - ) - for k, v in iteritems(d) - } + return { + k: ( + [x for i, x in enumerate(v) if i in indices] + if k in list_props + else v + ) + for k, v in iteritems(d) + } + def get_page_indices_with_min_annotation_percentage( - annotation_evaluation, min_annotation_percentage): + annotation_evaluation, min_annotation_percentage): + + return [ + i + for i, page_evaluation in enumerate(annotation_evaluation) + if page_evaluation['percentage'].get(None) <= (1 - min_annotation_percentage) + ] - return [ - i - for i, page_evaluation in enumerate(annotation_evaluation) - if page_evaluation['percentage'].get(None) <= (1 - min_annotation_percentage) - ] def parse_page_range(s): - s = s.strip() - if not s: - return None - a = tuple([int(x) for x in s.split('-')]) - if len(a) == 1: - return (a[0], a[0]) - elif len(a) == 2: - return a - else: - raise TypeError('invalid page range: %s' % s) + s = s.strip() + if not s: + return None + a = tuple([int(x) for x in s.split('-')]) + if len(a) == 1: + return (a[0], a[0]) + elif len(a) == 2: + return a + else: + raise TypeError('invalid page range: %s' % s) diff --git a/sciencebeam_gym/preprocess/preprocessing_utils_test.py b/sciencebeam_gym/preprocess/preprocessing_utils_test.py index a287dd3..c6ac391 100644 --- a/sciencebeam_gym/preprocess/preprocessing_utils_test.py +++ b/sciencebeam_gym/preprocess/preprocessing_utils_test.py 
@@ -3,68 +3,72 @@ from mock import patch, MagicMock, DEFAULT from lxml import etree from sciencebeam_gym.structured_document.svg import ( - SVG_DOC + SVG_DOC ) from sciencebeam_gym.preprocess.preprocessing_utils import ( - svg_page_to_blockified_png_bytes, - convert_pdf_bytes_to_lxml, - parse_page_range, + svg_page_to_blockified_png_bytes, + convert_pdf_bytes_to_lxml, + parse_page_range, ) PROCESSING_UTILS = 'sciencebeam_gym.preprocess.preprocessing_utils' PDF_CONTENT_1 = b'pdf content 1' + class TestSvgPageToBlockifiedPngBytes(object): - def test_should_parse_viewbox_and_pass_width_and_height_to_annotated_blocks_to_image(self): - with patch.multiple(PROCESSING_UTILS, annotated_blocks_to_image=DEFAULT) as mocks: - svg_page = etree.Element(SVG_DOC, attrib={ - 'viewBox': '0 0 100.1 200.9' - }) - color_map = {} - image_size = (100, 200) - svg_page_to_blockified_png_bytes(svg_page, color_map, image_size) - call_args = mocks['annotated_blocks_to_image'].call_args - kwargs = call_args[1] - assert (kwargs.get('width'), kwargs.get('height')) == (100.1, 200.9) + def test_should_parse_viewbox_and_pass_width_and_height_to_annotated_blocks_to_image(self): + with patch.multiple(PROCESSING_UTILS, annotated_blocks_to_image=DEFAULT) as mocks: + svg_page = etree.Element(SVG_DOC, attrib={ + 'viewBox': '0 0 100.1 200.9' + }) + color_map = {} + image_size = (100, 200) + svg_page_to_blockified_png_bytes(svg_page, color_map, image_size) + call_args = mocks['annotated_blocks_to_image'].call_args + kwargs = call_args[1] + assert (kwargs.get('width'), kwargs.get('height')) == (100.1, 200.9) + DEFAULT_PDF_TO_LXML_ARGS = ['-blocks', '-noImageInline', '-noImage', '-fullFontName'] LXML_CONTENT_1 = b'lxml content 1' + class TestConvertPdfBytesToLxml(object): - def test_should_pass_pdf_content_and_default_args_to_process_input(self): - mock = MagicMock() - with patch.multiple(PROCESSING_UTILS, PdfToLxmlWrapper=mock): - mock.return_value.process_input.return_value = LXML_CONTENT_1 - lxml_content 
= convert_pdf_bytes_to_lxml(PDF_CONTENT_1) - mock.return_value.process_input.assert_called_with( - PDF_CONTENT_1, - DEFAULT_PDF_TO_LXML_ARGS - ) - assert lxml_content == LXML_CONTENT_1 - - def test_should_pass_include_page_range_in_args(self): - mock = MagicMock() - with patch.multiple(PROCESSING_UTILS, PdfToLxmlWrapper=mock): - mock.return_value.process_input.return_value = LXML_CONTENT_1 - lxml_content = convert_pdf_bytes_to_lxml(PDF_CONTENT_1, page_range=(1, 3)) - mock.return_value.process_input.assert_called_with( - PDF_CONTENT_1, - DEFAULT_PDF_TO_LXML_ARGS + ['-f', '1', '-l', '3'] - ) - assert lxml_content == LXML_CONTENT_1 + def test_should_pass_pdf_content_and_default_args_to_process_input(self): + mock = MagicMock() + with patch.multiple(PROCESSING_UTILS, PdfToLxmlWrapper=mock): + mock.return_value.process_input.return_value = LXML_CONTENT_1 + lxml_content = convert_pdf_bytes_to_lxml(PDF_CONTENT_1) + mock.return_value.process_input.assert_called_with( + PDF_CONTENT_1, + DEFAULT_PDF_TO_LXML_ARGS + ) + assert lxml_content == LXML_CONTENT_1 + + def test_should_pass_include_page_range_in_args(self): + mock = MagicMock() + with patch.multiple(PROCESSING_UTILS, PdfToLxmlWrapper=mock): + mock.return_value.process_input.return_value = LXML_CONTENT_1 + lxml_content = convert_pdf_bytes_to_lxml(PDF_CONTENT_1, page_range=(1, 3)) + mock.return_value.process_input.assert_called_with( + PDF_CONTENT_1, + DEFAULT_PDF_TO_LXML_ARGS + ['-f', '1', '-l', '3'] + ) + assert lxml_content == LXML_CONTENT_1 + class TestPageRange(object): - def test_should_parse_single_page_number_as_range(self): - assert parse_page_range('1') == (1, 1) + def test_should_parse_single_page_number_as_range(self): + assert parse_page_range('1') == (1, 1) - def test_should_parse_range_with_hyphen(self): - assert parse_page_range('1-3') == (1, 3) + def test_should_parse_range_with_hyphen(self): + assert parse_page_range('1-3') == (1, 3) - def test_should_parse_range_with_spaces(self): - assert 
parse_page_range(' 1 - 3 ') == (1, 3) + def test_should_parse_range_with_spaces(self): + assert parse_page_range(' 1 - 3 ') == (1, 3) - def test_should_return_none_for_empty_range(self): - assert parse_page_range('') is None + def test_should_return_none_for_empty_range(self): + assert parse_page_range('') is None diff --git a/sciencebeam_gym/preprocess/visualize_svg_annotation.py b/sciencebeam_gym/preprocess/visualize_svg_annotation.py index fcd18ec..90439ee 100644 --- a/sciencebeam_gym/preprocess/visualize_svg_annotation.py +++ b/sciencebeam_gym/preprocess/visualize_svg_annotation.py @@ -2,87 +2,97 @@ import hashlib from lxml import etree +from sciencebeam_utils.utils.collection import flatten + from sciencebeam_gym.structured_document.svg import ( - SVG_TAG_ATTRIB + SVG_TAG_ATTRIB ) DEFAULT_COLORS = [ - 'maroon', 'red', 'purple', '#c0c', 'green', '#0c0', - 'olive', '#cc0', 'navy', 'blue', 'teal', '#0cc' + 'maroon', 'red', 'purple', '#c0c', 'green', '#0c0', + 'olive', '#cc0', 'navy', 'blue', 'teal', '#0cc' ] # colors replaced due readability issues: # fuchsia, lime, yellow, aqua -flatten = lambda l: [item for sublist in l for item in sublist] - def color_for_tag(tag, colors=None): - if colors is None: - colors = DEFAULT_COLORS - h = int(hashlib.md5(tag.encode('utf8')).hexdigest(), 16) - return colors[h % len(colors)] + if colors is None: + colors = DEFAULT_COLORS + h = int(hashlib.md5(tag.encode('utf8')).hexdigest(), 16) + return colors[h % len(colors)] + def style_props_for_tag(tag): - return { - 'fill': color_for_tag(tag) - } + return { + 'fill': color_for_tag(tag) + } + def style_props_for_tags(tags): - return { - tag: style_props_for_tag(tag) - for tag in tags - } + return { + tag: style_props_for_tag(tag) + for tag in tags + } + def render_style_props(style_props): - return '\n'.join([ - ' {}: {}'.format(k, v) - for k, v in style_props.items() - ]) + return '\n'.join([ + ' {}: {}'.format(k, v) + for k, v in style_props.items() + ]) + def 
style_block_for_tag(tag, style_props): - return 'text[{tag_attrib}~="{tag}"] {{\n{content}\n}}'.format( - tag_attrib=SVG_TAG_ATTRIB, - tag=tag, - content=render_style_props(style_props) - ) + return 'text[{tag_attrib}~="{tag}"] {{\n{content}\n}}'.format( + tag_attrib=SVG_TAG_ATTRIB, + tag=tag, + content=render_style_props(style_props) + ) + def style_block_for_tags(tags): - style_props_map = style_props_for_tags(tags) - return '\n\n'.join([ - style_block_for_tag(tag, style_props_map[tag]) - for tag in tags - ]) + style_props_map = style_props_for_tags(tags) + return '\n\n'.join([ + style_block_for_tag(tag, style_props_map[tag]) + for tag in tags + ]) + def tags_for_node(node): - svga_tags = node.attrib.get(SVG_TAG_ATTRIB, '').strip() - if len(svga_tags) == 0: - return [] - return svga_tags.split(' ') + svga_tags = node.attrib.get(SVG_TAG_ATTRIB, '').strip() + if len(svga_tags) == 0: + return [] + return svga_tags.split(' ') + def tags_for_nodes(nodes): - return sorted(set(flatten([ - tags_for_node(node) - for node in nodes - ]))) + return sorted(set(flatten([ + tags_for_node(node) + for node in nodes + ]))) + def nodes_with_tags(svg_root): - for node in svg_root.findall('*[@{}]'.format(SVG_TAG_ATTRIB)): - yield node - for nested_node in nodes_with_tags(node): - yield nested_node + for node in svg_root.findall('*[@{}]'.format(SVG_TAG_ATTRIB)): + yield node + for nested_node in nodes_with_tags(node): + yield nested_node + def add_title_to_nodes(nodes): - for node in nodes: - tags = tags_for_node(node) - if len(tags) > 0: - title = etree.Element('title') - title.text = ' '.join(tags) - node.append(title) + for node in nodes: + tags = tags_for_node(node) + if len(tags) > 0: + title = etree.Element('title') + title.text = ' '.join(tags) + node.append(title) + def visualize_svg_annotations(svg_root): - nodes = list(nodes_with_tags(svg_root)) - style_block = etree.Element('style') - style_block.text = style_block_for_tags(tags_for_nodes(nodes)) - svg_root.insert(0, 
style_block) - add_title_to_nodes(nodes) - return svg_root + nodes = list(nodes_with_tags(svg_root)) + style_block = etree.Element('style') + style_block.text = style_block_for_tags(tags_for_nodes(nodes)) + svg_root.insert(0, style_block) + add_title_to_nodes(nodes) + return svg_root diff --git a/sciencebeam_gym/preprocess/visualize_svg_annotation_test.py b/sciencebeam_gym/preprocess/visualize_svg_annotation_test.py index 67b573b..3dfbf1f 100644 --- a/sciencebeam_gym/preprocess/visualize_svg_annotation_test.py +++ b/sciencebeam_gym/preprocess/visualize_svg_annotation_test.py @@ -3,19 +3,19 @@ import logging from lxml import etree from sciencebeam_gym.structured_document.svg import ( - SVG_NSMAP, - SVG_DOC, - SVG_TEXT, - SVG_TAG_ATTRIB + SVG_NSMAP, + SVG_DOC, + SVG_TEXT, + SVG_TAG_ATTRIB ) from sciencebeam_gym.preprocess.visualize_svg_annotation import ( - visualize_svg_annotations, - style_block_for_tags, - style_block_for_tag, - style_props_for_tags, - render_style_props, - color_for_tag + visualize_svg_annotations, + style_block_for_tags, + style_block_for_tag, + style_props_for_tags, + render_style_props, + color_for_tag ) TAG1 = 'tag1' @@ -28,92 +28,100 @@ logger = logging.getLogger('test') def _create_xml_node(tag, text=None, attrib=None): - node = etree.Element(tag) - if text is not None: - node.text = text - if attrib is not None: - for k, v in attrib.items(): - node.attrib[k] = str(v) - return node + node = etree.Element(tag) + if text is not None: + node.text = text + if attrib is not None: + for k, v in attrib.items(): + node.attrib[k] = str(v) + return node + def _create_tagged_text(svga_tags): - return _create_xml_node(SVG_TEXT, attrib={ - SVG_TAG_ATTRIB: svga_tags - }) + return _create_xml_node(SVG_TEXT, attrib={ + SVG_TAG_ATTRIB: svga_tags + }) def test_add_style_block_for_single_tag_on_multiple_nodes(): - svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP) + svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP) + + 
svg_root.append(_create_tagged_text(TAG1)) + svg_root.append(_create_tagged_text(TAG1)) - svg_root.append(_create_tagged_text(TAG1)) - svg_root.append(_create_tagged_text(TAG1)) + result_svg = visualize_svg_annotations(svg_root) + style_block = result_svg.find('style') - result_svg = visualize_svg_annotations(svg_root) - style_block = result_svg.find('style') + assert style_block is not None + assert style_block.text == style_block_for_tags([TAG1]) - assert style_block is not None - assert style_block.text == style_block_for_tags([TAG1]) def test_add_style_block_for_multiple_tags_on_separate_nodes(): - svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP) + svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP) - svg_root.append(_create_tagged_text(TAG1)) - svg_root.append(_create_tagged_text(TAG2)) + svg_root.append(_create_tagged_text(TAG1)) + svg_root.append(_create_tagged_text(TAG2)) - result_svg = visualize_svg_annotations(svg_root) - style_block = result_svg.find('style') + result_svg = visualize_svg_annotations(svg_root) + style_block = result_svg.find('style') + + assert style_block is not None + assert style_block.text == style_block_for_tags([TAG1, TAG2]) - assert style_block is not None - assert style_block.text == style_block_for_tags([TAG1, TAG2]) def test_add_style_block_for_multiple_tags_on_same_node(): - svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP) + svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP) + + svg_root.append(_create_tagged_text(' '.join([TAG1, TAG2]))) - svg_root.append(_create_tagged_text(' '.join([TAG1, TAG2]))) + result_svg = visualize_svg_annotations(svg_root) + style_block = result_svg.find('style') - result_svg = visualize_svg_annotations(svg_root) - style_block = result_svg.find('style') + assert style_block is not None + assert style_block.text == style_block_for_tags([TAG1, TAG2]) - assert style_block is not None - assert style_block.text == style_block_for_tags([TAG1, TAG2]) def test_add_title_with_tags(): - svg_root = 
etree.Element(SVG_DOC, nsmap=SVG_NSMAP) + svg_root = etree.Element(SVG_DOC, nsmap=SVG_NSMAP) - svg_root.append(_create_tagged_text(TAG1)) + svg_root.append(_create_tagged_text(TAG1)) - result_svg = visualize_svg_annotations(svg_root) - text_node = result_svg.find(SVG_TEXT + '/title') + result_svg = visualize_svg_annotations(svg_root) + text_node = result_svg.find(SVG_TEXT + '/title') + + assert text_node is not None + assert text_node.text == TAG1 - assert text_node is not None - assert text_node.text == TAG1 def test_style_block_for_single_tag(): - style_block_text = style_block_for_tags([TAG1]) + style_block_text = style_block_for_tags([TAG1]) + + assert style_block_text == ( + style_block_for_tag(TAG1, style_props_for_tags([TAG1])[TAG1]) + ) - assert style_block_text == ( - style_block_for_tag(TAG1, style_props_for_tags([TAG1])[TAG1]) - ) def test_style_block_for_multiple_tags(): - style_block_text = style_block_for_tags([TAG1, TAG2]) - style_props_map = style_props_for_tags([TAG1, TAG2]) + style_block_text = style_block_for_tags([TAG1, TAG2]) + style_props_map = style_props_for_tags([TAG1, TAG2]) + + assert style_block_text == ( + '\n\n'.join([ + style_block_for_tag(TAG1, style_props_map[TAG1]), + style_block_for_tag(TAG2, style_props_map[TAG2]) + ]) + ) - assert style_block_text == ( - '\n\n'.join([ - style_block_for_tag(TAG1, style_props_map[TAG1]), - style_block_for_tag(TAG2, style_props_map[TAG2]) - ]) - ) def test_style_block_for_tag(): - style_props = style_props_for_tags([TAG1])[TAG1] - style_block_text = style_block_for_tag(TAG1, style_props) + style_props = style_props_for_tags([TAG1])[TAG1] + style_block_text = style_block_for_tag(TAG1, style_props) + + assert ( + style_block_text == + 'text[class~="' + TAG1 + '"] {\n' + render_style_props(style_props) + '\n}' + ) - assert ( - style_block_text == - 'text[class~="' + TAG1 + '"] {\n' + render_style_props(style_props) + '\n}' - ) def test_color_for_tag_should_be_different_for_different_tags(): - assert 
color_for_tag(TAG1) != color_for_tag(TAG2) + assert color_for_tag(TAG1) != color_for_tag(TAG2) diff --git a/sciencebeam_gym/structured_document/__init__.py b/sciencebeam_gym/structured_document/__init__.py index 6f8ccf1..f7aba87 100644 --- a/sciencebeam_gym/structured_document/__init__.py +++ b/sciencebeam_gym/structured_document/__init__.py @@ -11,230 +11,244 @@ LEVEL_ATTRIB_SEP = '_' SIMPLE_TAG_ATTRIB_NAME = 'tag' + def merge_token_tag( - merged_structured_document, merged_token, - other_structured_document, other_token, - source_scope=None, target_scope=None): - - tag = other_structured_document.get_tag(other_token, scope=source_scope) - if tag: - merged_structured_document.set_tag( - merged_token, - tag, - scope=target_scope - ) + merged_structured_document, merged_token, + other_structured_document, other_token, + source_scope=None, target_scope=None): + + tag = other_structured_document.get_tag(other_token, scope=source_scope) + if tag: + merged_structured_document.set_tag( + merged_token, + tag, + scope=target_scope + ) + def get_scoped_attrib_name(name, scope=None, level=None): - if level: - name = 'level%s%s%s' % (level, LEVEL_ATTRIB_SEP, name) - return '%s%s%s' % (scope, SCOPE_ATTRIB_SEP, name) if scope else name + if level: + name = 'level%s%s%s' % (level, LEVEL_ATTRIB_SEP, name) + return '%s%s%s' % (scope, SCOPE_ATTRIB_SEP, name) if scope else name + def get_attrib_by_scope(attrib, name): - suffix = '%s%s' % (SCOPE_ATTRIB_SEP, name) - return { - (None if k == name else k[:-len(suffix)]): v - for k, v in attrib.items() - if k.endswith(suffix) or k == name - } + suffix = '%s%s' % (SCOPE_ATTRIB_SEP, name) + return { + (None if k == name else k[:-len(suffix)]): v + for k, v in attrib.items() + if k.endswith(suffix) or k == name + } + def get_simple_tag_attrib_name(scope, level=None): - return get_scoped_attrib_name(SIMPLE_TAG_ATTRIB_NAME, scope, level) + return get_scoped_attrib_name(SIMPLE_TAG_ATTRIB_NAME, scope, level) + def split_tag_prefix(tag): - if 
tag: - if tag.startswith(B_TAG_PREFIX): - return B_TAG_PREFIX, tag[len(B_TAG_PREFIX):] - if tag.startswith(I_TAG_PREFIX): - return I_TAG_PREFIX, tag[len(I_TAG_PREFIX):] - return None, tag + if tag: + if tag.startswith(B_TAG_PREFIX): + return B_TAG_PREFIX, tag[len(B_TAG_PREFIX):] + if tag.startswith(I_TAG_PREFIX): + return I_TAG_PREFIX, tag[len(I_TAG_PREFIX):] + return None, tag + def strip_tag_prefix(tag): - return split_tag_prefix(tag)[1] + return split_tag_prefix(tag)[1] + def add_tag_prefix(tag, prefix): - return prefix + tag if prefix and tag else tag + return prefix + tag if prefix and tag else tag + class AbstractStructuredDocument(object, with_metaclass(ABCMeta)): - def clone(self): - return deepcopy(self) - - def iter_all_tokens(self): - for page in self.get_pages(): - for line in self.get_lines_of_page(page): - for token in self.get_tokens_of_line(line): - yield token - - def merge_with( - self, - other_structured_document, - merge_fn): - """ - Merges this structured document with another structured document using the merge fn. 
- - Note: this document will be changed (operate on a clone if that is undesired) - """ - - for merged_token, other_token in zip( - self.iter_all_tokens(), - other_structured_document.iter_all_tokens() - ): - assert ( - self.get_text(merged_token) == - other_structured_document.get_text(other_token) - ) - merge_fn( - self, merged_token, - other_structured_document, other_token - ) - - def get_tag_prefix_and_value(self, parent, scope=None, level=None): - return split_tag_prefix(self.get_tag(parent, scope=scope, level=level)) - - def get_tag_value(self, parent, scope=None): - return self.get_tag_prefix_and_value(parent, scope=scope)[1] - - def set_tag_with_prefix(self, parent, tag, scope=None, prefix=None): - self.set_tag(parent, add_tag_prefix(tag, prefix), scope=scope) - - def get_sub_tag(self, parent, scope=None): - return self.get_tag(parent, scope=scope, level=2) - - def set_sub_tag(self, parent, tag, scope=None): - self.set_tag(parent, tag, scope=scope, level=2) - - def set_sub_tag_with_prefix(self, parent, tag, scope=None, prefix=None): - self.set_sub_tag(parent, add_tag_prefix(tag, prefix), scope=scope) - - @abstractmethod - def get_pages(self): - pass - - @abstractmethod - def get_lines_of_page(self, page): - pass - - @abstractmethod - def get_tokens_of_line(self, line): - pass - - @abstractmethod - def get_x(self, parent): - pass - - @abstractmethod - def get_text(self, parent): - pass - - @abstractmethod - def get_tag(self, parent, scope=None, level=None): - pass - - @abstractmethod - def set_tag(self, parent, tag, scope=None, level=None): - pass - - @abstractmethod - def get_tag_by_scope(self, parent): - pass - - @abstractmethod - def get_bounding_box(self, parent): - pass - - @abstractmethod - def set_bounding_box(self, parent, bounding_box): - pass + def clone(self): + return deepcopy(self) + + def iter_all_tokens(self): + for page in self.get_pages(): + for line in self.get_lines_of_page(page): + for token in self.get_tokens_of_line(line): + yield 
token + + def merge_with( + self, + other_structured_document, + merge_fn): + """ + Merges this structured document with another structured document using the merge fn. + + Note: this document will be changed (operate on a clone if that is undesired) + """ + + for merged_token, other_token in zip( + self.iter_all_tokens(), + other_structured_document.iter_all_tokens() + ): + assert ( + self.get_text(merged_token) == + other_structured_document.get_text(other_token) + ) + merge_fn( + self, merged_token, + other_structured_document, other_token + ) + + def get_tag_prefix_and_value(self, parent, scope=None, level=None): + return split_tag_prefix(self.get_tag(parent, scope=scope, level=level)) + + def get_tag_value(self, parent, scope=None): + return self.get_tag_prefix_and_value(parent, scope=scope)[1] + + def set_tag_with_prefix(self, parent, tag, scope=None, prefix=None): + self.set_tag(parent, add_tag_prefix(tag, prefix), scope=scope) + + def get_sub_tag(self, parent, scope=None): + return self.get_tag(parent, scope=scope, level=2) + + def set_sub_tag(self, parent, tag, scope=None): + self.set_tag(parent, tag, scope=scope, level=2) + + def set_sub_tag_with_prefix(self, parent, tag, scope=None, prefix=None): + self.set_sub_tag(parent, add_tag_prefix(tag, prefix), scope=scope) + + @abstractmethod + def get_pages(self): + pass + + @abstractmethod + def get_lines_of_page(self, page): + pass + + @abstractmethod + def get_tokens_of_line(self, line): + pass + + @abstractmethod + def get_x(self, parent): + pass + + @abstractmethod + def get_text(self, parent): + pass + + @abstractmethod + def get_tag(self, parent, scope=None, level=None): + pass + + @abstractmethod + def set_tag(self, parent, tag, scope=None, level=None): + pass + + @abstractmethod + def get_tag_by_scope(self, parent): + pass + + @abstractmethod + def get_bounding_box(self, parent): + pass + + @abstractmethod + def set_bounding_box(self, parent, bounding_box): + pass + class SimpleElement(object): - def 
__init__(self, bounding_box=None): - self._bounding_box = bounding_box + def __init__(self, bounding_box=None): + self._bounding_box = bounding_box - def get_bounding_box(self): - return self._bounding_box + def get_bounding_box(self): + return self._bounding_box + + def set_bounding_box(self, bounding_box): + self._bounding_box = bounding_box - def set_bounding_box(self, bounding_box): - self._bounding_box = bounding_box class SimpleToken(SimpleElement): - def __init__( - self, text, attrib=None, tag=None, tag_scope=None, tag_prefix=None, **kwargs): - super(SimpleToken, self).__init__(**kwargs) - self.text = text - if attrib is None: - attrib = {} - self.attrib = attrib - if tag is not None: - self.set_tag(tag, scope=tag_scope, prefix=tag_prefix) + def __init__( + self, text, attrib=None, tag=None, tag_scope=None, tag_prefix=None, **kwargs): + super(SimpleToken, self).__init__(**kwargs) + self.text = text + if attrib is None: + attrib = {} + self.attrib = attrib + if tag is not None: + self.set_tag(tag, scope=tag_scope, prefix=tag_prefix) + + def get_x(self): + return self.attrib.get('x') - def get_x(self): - return self.attrib.get('x') + def get_y(self): + return self.attrib.get('y') - def get_y(self): - return self.attrib.get('y') + def get_tag(self, scope=None, level=None): + return self.attrib.get(get_simple_tag_attrib_name(scope=scope, level=level)) - def get_tag(self, scope=None, level=None): - return self.attrib.get(get_simple_tag_attrib_name(scope=scope, level=level)) + def set_tag(self, tag, scope=None, level=None, prefix=None): + self.attrib[get_simple_tag_attrib_name( + scope=scope, level=level)] = add_tag_prefix(tag, prefix) - def set_tag(self, tag, scope=None, level=None, prefix=None): - self.attrib[get_simple_tag_attrib_name(scope=scope, level=level)] = add_tag_prefix(tag, prefix) + def get_tag_by_scope(self): + return get_attrib_by_scope(self.attrib, SIMPLE_TAG_ATTRIB_NAME) - def get_tag_by_scope(self): - return get_attrib_by_scope(self.attrib, 
SIMPLE_TAG_ATTRIB_NAME) + def get_text(self): + return self.text - def get_text(self): - return self.text + def __repr__(self): + return '%s(%s)' % (type(self).__name__, self.text) - def __repr__(self): - return '%s(%s)' % (type(self).__name__, self.text) class SimpleLine(SimpleElement): - def __init__(self, tokens): - super(SimpleLine, self).__init__() - self.tokens = tokens + def __init__(self, tokens): + super(SimpleLine, self).__init__() + self.tokens = tokens + class SimplePage(SimpleElement): - def __init__(self, lines, **kwargs): - super(SimplePage, self).__init__(**kwargs) - self.lines = lines + def __init__(self, lines, **kwargs): + super(SimplePage, self).__init__(**kwargs) + self.lines = lines + class SimpleStructuredDocument(AbstractStructuredDocument): - def __init__(self, page_or_pages=None, lines=None): - if lines is not None: - pages = [SimplePage(lines)] - elif page_or_pages is None: - pages = [] - elif isinstance(page_or_pages, list): - pages = page_or_pages - else: - pages = [page_or_pages] - self._pages = pages + def __init__(self, page_or_pages=None, lines=None): + if lines is not None: + pages = [SimplePage(lines)] + elif page_or_pages is None: + pages = [] + elif isinstance(page_or_pages, list): + pages = page_or_pages + else: + pages = [page_or_pages] + self._pages = pages - def get_pages(self): - return self._pages + def get_pages(self): + return self._pages - def get_lines_of_page(self, page): - return page.lines + def get_lines_of_page(self, page): + return page.lines - def get_tokens_of_line(self, line): - return line.tokens + def get_tokens_of_line(self, line): + return line.tokens - def get_x(self, parent): - return parent.get_x() + def get_x(self, parent): + return parent.get_x() - def get_text(self, parent): - return parent.get_text() + def get_text(self, parent): + return parent.get_text() - def get_tag(self, parent, scope=None, level=None): - return parent.get_tag(scope=scope, level=level) + def get_tag(self, parent, scope=None, 
level=None): + return parent.get_tag(scope=scope, level=level) - def set_tag(self, parent, tag, scope=None, level=None): - return parent.set_tag(tag, scope=scope, level=level) + def set_tag(self, parent, tag, scope=None, level=None): + return parent.set_tag(tag, scope=scope, level=level) - def get_tag_by_scope(self, parent): - return parent.get_tag_by_scope() + def get_tag_by_scope(self, parent): + return parent.get_tag_by_scope() - def get_bounding_box(self, parent): - return parent.get_bounding_box() + def get_bounding_box(self, parent): + return parent.get_bounding_box() - def set_bounding_box(self, parent, bounding_box): - parent.set_bounding_box(bounding_box) + def set_bounding_box(self, parent, bounding_box): + parent.set_bounding_box(bounding_box) diff --git a/sciencebeam_gym/structured_document/__init___test.py b/sciencebeam_gym/structured_document/__init___test.py index 4cb537c..3ecb681 100644 --- a/sciencebeam_gym/structured_document/__init___test.py +++ b/sciencebeam_gym/structured_document/__init___test.py @@ -3,10 +3,10 @@ from functools import partial import pytest from sciencebeam_gym.structured_document import ( - SimpleStructuredDocument, - SimpleLine, - SimpleToken, - merge_token_tag + SimpleStructuredDocument, + SimpleLine, + SimpleToken, + merge_token_tag ) TEXT_1 = 'text 1' @@ -17,109 +17,111 @@ TAG_2 = 'tag2' SCOPE_1 = 'scope1' -class TestAbstractStructuredDocumentMergeWith(object): - def test_should_merge_single_token_and_add_prefix(self): - merged_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ - SimpleToken(TEXT_1, tag=TAG_1) - ])]) - other_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ - SimpleToken(TEXT_1, tag=TAG_2) - ])]) - merged_structured_document.merge_with( - other_structured_document, - partial( - merge_token_tag, - target_scope=SCOPE_1 - ) - ) - merged_tokens = list(merged_structured_document.iter_all_tokens()) - assert ( - [merged_structured_document.get_text(t) for t in merged_tokens] == - 
[TEXT_1] - ) - assert ( - [merged_structured_document.get_tag(t) for t in merged_tokens] == - [TAG_1] - ) - assert ( - [merged_structured_document.get_tag(t, scope=SCOPE_1) for t in merged_tokens] == - [TAG_2] - ) - def test_should_not_fail_with_absent_tags(self): - merged_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ - SimpleToken(TEXT_1, tag=TAG_1) - ])]) - other_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ - SimpleToken(TEXT_1) - ])]) - merged_structured_document.merge_with( - other_structured_document, - partial( - merge_token_tag, - target_scope=SCOPE_1 - ) - ) - merged_tokens = list(merged_structured_document.iter_all_tokens()) - assert ( - [merged_structured_document.get_tag(t, scope=SCOPE_1) for t in merged_tokens] == - [None] - ) +class TestAbstractStructuredDocumentMergeWith(object): + def test_should_merge_single_token_and_add_prefix(self): + merged_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ + SimpleToken(TEXT_1, tag=TAG_1) + ])]) + other_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ + SimpleToken(TEXT_1, tag=TAG_2) + ])]) + merged_structured_document.merge_with( + other_structured_document, + partial( + merge_token_tag, + target_scope=SCOPE_1 + ) + ) + merged_tokens = list(merged_structured_document.iter_all_tokens()) + assert ( + [merged_structured_document.get_text(t) for t in merged_tokens] == + [TEXT_1] + ) + assert ( + [merged_structured_document.get_tag(t) for t in merged_tokens] == + [TAG_1] + ) + assert ( + [merged_structured_document.get_tag(t, scope=SCOPE_1) for t in merged_tokens] == + [TAG_2] + ) - def test_should_not_override_with_empty_tags(self): - merged_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ - SimpleToken(TEXT_1, tag=TAG_1) - ])]) - other_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ - SimpleToken(TEXT_1) - ])]) - merged_structured_document.merge_with( - other_structured_document, - partial( - 
merge_token_tag - ) - ) - merged_tokens = list(merged_structured_document.iter_all_tokens()) - assert ( - [merged_structured_document.get_tag(t) for t in merged_tokens] == - [TAG_1] - ) + def test_should_not_fail_with_absent_tags(self): + merged_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ + SimpleToken(TEXT_1, tag=TAG_1) + ])]) + other_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ + SimpleToken(TEXT_1) + ])]) + merged_structured_document.merge_with( + other_structured_document, + partial( + merge_token_tag, + target_scope=SCOPE_1 + ) + ) + merged_tokens = list(merged_structured_document.iter_all_tokens()) + assert ( + [merged_structured_document.get_tag(t, scope=SCOPE_1) for t in merged_tokens] == + [None] + ) - def test_should_raise_assertion_error_if_tokens_mismatch(self): - merged_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ - SimpleToken(TEXT_1, tag=TAG_1) - ])]) - other_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ - SimpleToken(TEXT_2, tag=TAG_2) - ])]) - with pytest.raises(AssertionError): - merged_structured_document.merge_with( - other_structured_document, - partial( - merge_token_tag, - target_scope=SCOPE_1 + def test_should_not_override_with_empty_tags(self): + merged_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ + SimpleToken(TEXT_1, tag=TAG_1) + ])]) + other_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ + SimpleToken(TEXT_1) + ])]) + merged_structured_document.merge_with( + other_structured_document, + partial( + merge_token_tag + ) + ) + merged_tokens = list(merged_structured_document.iter_all_tokens()) + assert ( + [merged_structured_document.get_tag(t) for t in merged_tokens] == + [TAG_1] ) - ) + + def test_should_raise_assertion_error_if_tokens_mismatch(self): + merged_structured_document = SimpleStructuredDocument(lines=[SimpleLine([ + SimpleToken(TEXT_1, tag=TAG_1) + ])]) + other_structured_document = 
SimpleStructuredDocument(lines=[SimpleLine([ + SimpleToken(TEXT_2, tag=TAG_2) + ])]) + with pytest.raises(AssertionError): + merged_structured_document.merge_with( + other_structured_document, + partial( + merge_token_tag, + target_scope=SCOPE_1 + ) + ) + class TestSimpleStructuredDocument(object): - def test_should_set_tag_without_scope(self): - token = SimpleToken(TEXT_1) - doc = SimpleStructuredDocument(lines=[SimpleLine([token])]) - doc.set_tag(token, TAG_1) - assert doc.get_tag(token) == TAG_1 + def test_should_set_tag_without_scope(self): + token = SimpleToken(TEXT_1) + doc = SimpleStructuredDocument(lines=[SimpleLine([token])]) + doc.set_tag(token, TAG_1) + assert doc.get_tag(token) == TAG_1 - def test_should_set_tag_with_scope(self): - token = SimpleToken(TEXT_1) - doc = SimpleStructuredDocument(lines=[SimpleLine([token])]) - doc.set_tag(token, TAG_1, scope=SCOPE_1) - assert doc.get_tag(token, scope=SCOPE_1) == TAG_1 - assert doc.get_tag(token) is None + def test_should_set_tag_with_scope(self): + token = SimpleToken(TEXT_1) + doc = SimpleStructuredDocument(lines=[SimpleLine([token])]) + doc.set_tag(token, TAG_1, scope=SCOPE_1) + assert doc.get_tag(token, scope=SCOPE_1) == TAG_1 + assert doc.get_tag(token) is None - def test_should_return_all_tag_by_scope(self): - token = SimpleToken(TEXT_1) - doc = SimpleStructuredDocument(lines=[SimpleLine([token])]) - doc.set_tag(token, TAG_1) - doc.set_tag(token, TAG_2, scope=SCOPE_1) - assert doc.get_tag(token) == TAG_1 - assert doc.get_tag(token, scope=SCOPE_1) == TAG_2 - assert doc.get_tag_by_scope(token) == {None: TAG_1, SCOPE_1: TAG_2} + def test_should_return_all_tag_by_scope(self): + token = SimpleToken(TEXT_1) + doc = SimpleStructuredDocument(lines=[SimpleLine([token])]) + doc.set_tag(token, TAG_1) + doc.set_tag(token, TAG_2, scope=SCOPE_1) + assert doc.get_tag(token) == TAG_1 + assert doc.get_tag(token, scope=SCOPE_1) == TAG_2 + assert doc.get_tag_by_scope(token) == {None: TAG_1, SCOPE_1: TAG_2} diff --git 
a/sciencebeam_gym/structured_document/lxml.py b/sciencebeam_gym/structured_document/lxml.py index 5219529..1b04f65 100644 --- a/sciencebeam_gym/structured_document/lxml.py +++ b/sciencebeam_gym/structured_document/lxml.py @@ -1,63 +1,66 @@ from sciencebeam_utils.utils.xml import ( - set_or_remove_attrib + set_or_remove_attrib ) from sciencebeam_gym.utils.bounding_box import ( - BoundingBox + BoundingBox ) from sciencebeam_gym.structured_document import ( - AbstractStructuredDocument, - get_scoped_attrib_name, - get_attrib_by_scope + AbstractStructuredDocument, + get_scoped_attrib_name, + get_attrib_by_scope ) TAG_ATTRIB_NAME = 'tag' + def get_node_bounding_box(t): - return BoundingBox( - float(t.attrib.get('x', 0)), - float(t.attrib.get('y', 0)), - float(t.attrib['width']), - float(t.attrib['height']) - ) + return BoundingBox( + float(t.attrib.get('x', 0)), + float(t.attrib.get('y', 0)), + float(t.attrib['width']), + float(t.attrib['height']) + ) + def _get_tag_attrib_name(scope, level): - return get_scoped_attrib_name(TAG_ATTRIB_NAME, scope=scope, level=level) + return get_scoped_attrib_name(TAG_ATTRIB_NAME, scope=scope, level=level) + class LxmlStructuredDocument(AbstractStructuredDocument): - def __init__(self, root): - self.root = root + def __init__(self, root): + self.root = root - def get_pages(self): - return self.root.findall('.//PAGE') + def get_pages(self): + return self.root.findall('.//PAGE') - def get_lines_of_page(self, page): - return page.findall('.//TEXT') + def get_lines_of_page(self, page): + return page.findall('.//TEXT') - def get_tokens_of_line(self, line): - return line.findall('./TOKEN') + def get_tokens_of_line(self, line): + return line.findall('./TOKEN') - def get_x(self, parent): - return parent.attrib.get('x') + def get_x(self, parent): + return parent.attrib.get('x') - def get_text(self, parent): - return parent.text + def get_text(self, parent): + return parent.text - def get_tag(self, parent, scope=None, level=None): - return 
parent.attrib.get(_get_tag_attrib_name(scope, level)) + def get_tag(self, parent, scope=None, level=None): + return parent.attrib.get(_get_tag_attrib_name(scope, level)) - def set_tag(self, parent, tag, scope=None, level=None): - set_or_remove_attrib(parent.attrib, _get_tag_attrib_name(scope, level), tag) + def set_tag(self, parent, tag, scope=None, level=None): + set_or_remove_attrib(parent.attrib, _get_tag_attrib_name(scope, level), tag) - def get_tag_by_scope(self, parent): - return get_attrib_by_scope(parent.attrib, TAG_ATTRIB_NAME) + def get_tag_by_scope(self, parent): + return get_attrib_by_scope(parent.attrib, TAG_ATTRIB_NAME) - def get_bounding_box(self, parent): - return get_node_bounding_box(parent) + def get_bounding_box(self, parent): + return get_node_bounding_box(parent) - def set_bounding_box(self, parent, bounding_box): - parent.attrib['x'] = str(bounding_box.x) - parent.attrib['y'] = str(bounding_box.y) - parent.attrib['width'] = str(bounding_box.width) - parent.attrib['height'] = str(bounding_box.height) + def set_bounding_box(self, parent, bounding_box): + parent.attrib['x'] = str(bounding_box.x) + parent.attrib['y'] = str(bounding_box.y) + parent.attrib['width'] = str(bounding_box.width) + parent.attrib['height'] = str(bounding_box.height) diff --git a/sciencebeam_gym/structured_document/lxml_test.py b/sciencebeam_gym/structured_document/lxml_test.py index 774b991..ef361fc 100644 --- a/sciencebeam_gym/structured_document/lxml_test.py +++ b/sciencebeam_gym/structured_document/lxml_test.py @@ -3,11 +3,11 @@ from __future__ import absolute_import from lxml.builder import E from sciencebeam_gym.utils.bounding_box import ( - BoundingBox + BoundingBox ) from sciencebeam_gym.structured_document.lxml import ( - LxmlStructuredDocument + LxmlStructuredDocument ) TAG_1 = 'tag1' @@ -15,142 +15,143 @@ TAG_2 = 'tag2' SCOPE_1 = 'scope1' + class TestLxmlStructuredDocument(object): - def test_should_find_pages(self): - pages = [ - E.PAGE(), - E.PAGE() - ] - doc 
= LxmlStructuredDocument( - E.DOCUMENT( - *pages - ) - ) - assert list(doc.get_pages()) == pages - - def test_should_find_lines_of_page_without_blocks(self): - lines = [ - E.TEXT(), - E.TEXT() - ] - page = E.PAGE(*lines) - doc = LxmlStructuredDocument( - E.DOCUMENT( - page, - # add another page just for effect - E.PAGE( - E.TEXT() + def test_should_find_pages(self): + pages = [ + E.PAGE(), + E.PAGE() + ] + doc = LxmlStructuredDocument( + E.DOCUMENT( + *pages + ) + ) + assert list(doc.get_pages()) == pages + + def test_should_find_lines_of_page_without_blocks(self): + lines = [ + E.TEXT(), + E.TEXT() + ] + page = E.PAGE(*lines) + doc = LxmlStructuredDocument( + E.DOCUMENT( + page, + # add another page just for effect + E.PAGE( + E.TEXT() + ) + ) ) - ) - ) - assert list(doc.get_lines_of_page(page)) == lines - - def test_should_find_lines_of_page_with_blocks(self): - lines = [ - E.TEXT(), - E.TEXT() - ] - page = E.PAGE(E.BLOCK(*lines)) - doc = LxmlStructuredDocument( - E.DOCUMENT( - page, - # add another page just for effect - E.PAGE( - E.BLOCK(E.TEXT()) + assert list(doc.get_lines_of_page(page)) == lines + + def test_should_find_lines_of_page_with_blocks(self): + lines = [ + E.TEXT(), + E.TEXT() + ] + page = E.PAGE(E.BLOCK(*lines)) + doc = LxmlStructuredDocument( + E.DOCUMENT( + page, + # add another page just for effect + E.PAGE( + E.BLOCK(E.TEXT()) + ) + ) ) - ) - ) - assert list(doc.get_lines_of_page(page)) == lines - - def test_should_find_tokens_of_line(self): - tokens = [ - E.TOKEN(), - E.TOKEN() - ] - line = E.TEXT(*tokens) - doc = LxmlStructuredDocument( - E.DOCUMENT( - E.PAGE( - line, - E.TEXT(E.TOKEN) + assert list(doc.get_lines_of_page(page)) == lines + + def test_should_find_tokens_of_line(self): + tokens = [ + E.TOKEN(), + E.TOKEN() + ] + line = E.TEXT(*tokens) + doc = LxmlStructuredDocument( + E.DOCUMENT( + E.PAGE( + line, + E.TEXT(E.TOKEN) + ) + ) ) - ) - ) - assert list(doc.get_tokens_of_line(line)) == tokens - - def 
test_should_calculate_default_bounding_box(self): - token = E.TOKEN({ - 'x': '10', - 'y': '11', - 'width': '100', - 'height': '101' - }) - doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.TEXT(token)))) - assert doc.get_bounding_box(token) == BoundingBox(10, 11, 100, 101) - - def test_should_be_able_to_set_bounding_box(self): - bounding_box = BoundingBox(10, 11, 100, 101) - token = E.TOKEN({ - 'x': '20', - 'y': '21', - 'width': '200', - 'height': '201' - }) - doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.TEXT(token)))) - doc.set_bounding_box(token, bounding_box) - assert doc.get_bounding_box(token) == bounding_box - - def test_should_calculate_bounding_box_of_page_without_xy(self): - page = E.PAGE({ - 'width': '100', - 'height': '101' - }) - doc = LxmlStructuredDocument(E.DOCUMENT(page)) - assert doc.get_bounding_box(page) == BoundingBox(0, 0, 100, 101) - - def test_should_set_tag_without_scope(self): - token = E.TEXT() - doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token)))) - doc.set_tag(token, TAG_1) - assert doc.get_tag(token) == TAG_1 - - def test_should_set_tag_with_scope(self): - token = E.TEXT() - doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token)))) - doc.set_tag(token, TAG_1, scope=SCOPE_1) - assert doc.get_tag(token, scope=SCOPE_1) == TAG_1 - assert doc.get_tag(token) is None - - def test_should_set_tag_with_level(self): - token = E.TEXT() - doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token)))) - doc.set_tag(token, TAG_1, level=2) - assert doc.get_tag(token, level=2) == TAG_1 - assert doc.get_tag(token) is None - - def test_should_clear_tag_when_setting_tag_to_none(self): - token = E.TEXT() - doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token)))) - doc.set_tag(token, TAG_1) - doc.set_tag(token, TAG_1, scope=SCOPE_1) - doc.set_tag(token, None) - doc.set_tag(token, None, scope=SCOPE_1) - assert doc.get_tag(token) is None - assert doc.get_tag(token, scope=SCOPE_1) is None - - def 
test_should_not_fail_setting_empty_tag_to_none(self): - token = E.TEXT() - doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token)))) - doc.set_tag(token, None) - doc.set_tag(token, None, scope=SCOPE_1) - assert doc.get_tag(token) is None - assert doc.get_tag(token, scope=SCOPE_1) is None - - def test_should_return_all_tag_by_scope(self): - token = E.TEXT() - doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token)))) - doc.set_tag(token, TAG_1) - doc.set_tag(token, TAG_2, scope=SCOPE_1) - assert doc.get_tag(token) == TAG_1 - assert doc.get_tag(token, scope=SCOPE_1) == TAG_2 - assert doc.get_tag_by_scope(token) == {None: TAG_1, SCOPE_1: TAG_2} + assert list(doc.get_tokens_of_line(line)) == tokens + + def test_should_calculate_default_bounding_box(self): + token = E.TOKEN({ + 'x': '10', + 'y': '11', + 'width': '100', + 'height': '101' + }) + doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.TEXT(token)))) + assert doc.get_bounding_box(token) == BoundingBox(10, 11, 100, 101) + + def test_should_be_able_to_set_bounding_box(self): + bounding_box = BoundingBox(10, 11, 100, 101) + token = E.TOKEN({ + 'x': '20', + 'y': '21', + 'width': '200', + 'height': '201' + }) + doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.TEXT(token)))) + doc.set_bounding_box(token, bounding_box) + assert doc.get_bounding_box(token) == bounding_box + + def test_should_calculate_bounding_box_of_page_without_xy(self): + page = E.PAGE({ + 'width': '100', + 'height': '101' + }) + doc = LxmlStructuredDocument(E.DOCUMENT(page)) + assert doc.get_bounding_box(page) == BoundingBox(0, 0, 100, 101) + + def test_should_set_tag_without_scope(self): + token = E.TEXT() + doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token)))) + doc.set_tag(token, TAG_1) + assert doc.get_tag(token) == TAG_1 + + def test_should_set_tag_with_scope(self): + token = E.TEXT() + doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token)))) + doc.set_tag(token, TAG_1, scope=SCOPE_1) + assert doc.get_tag(token, 
scope=SCOPE_1) == TAG_1 + assert doc.get_tag(token) is None + + def test_should_set_tag_with_level(self): + token = E.TEXT() + doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token)))) + doc.set_tag(token, TAG_1, level=2) + assert doc.get_tag(token, level=2) == TAG_1 + assert doc.get_tag(token) is None + + def test_should_clear_tag_when_setting_tag_to_none(self): + token = E.TEXT() + doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token)))) + doc.set_tag(token, TAG_1) + doc.set_tag(token, TAG_1, scope=SCOPE_1) + doc.set_tag(token, None) + doc.set_tag(token, None, scope=SCOPE_1) + assert doc.get_tag(token) is None + assert doc.get_tag(token, scope=SCOPE_1) is None + + def test_should_not_fail_setting_empty_tag_to_none(self): + token = E.TEXT() + doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token)))) + doc.set_tag(token, None) + doc.set_tag(token, None, scope=SCOPE_1) + assert doc.get_tag(token) is None + assert doc.get_tag(token, scope=SCOPE_1) is None + + def test_should_return_all_tag_by_scope(self): + token = E.TEXT() + doc = LxmlStructuredDocument(E.DOCUMENT(E.PAGE(E.BLOCK(token)))) + doc.set_tag(token, TAG_1) + doc.set_tag(token, TAG_2, scope=SCOPE_1) + assert doc.get_tag(token) == TAG_1 + assert doc.get_tag(token, scope=SCOPE_1) == TAG_2 + assert doc.get_tag_by_scope(token) == {None: TAG_1, SCOPE_1: TAG_2} diff --git a/sciencebeam_gym/structured_document/structured_document_loader.py b/sciencebeam_gym/structured_document/structured_document_loader.py index 89fd82f..1798dd0 100644 --- a/sciencebeam_gym/structured_document/structured_document_loader.py +++ b/sciencebeam_gym/structured_document/structured_document_loader.py @@ -1,59 +1,66 @@ from __future__ import absolute_import -from zipfile import ZipFile - from lxml import etree from lxml.builder import E from apache_beam.io.filesystems import FileSystems from sciencebeam_gym.utils.pages_zip import ( - load_pages + load_pages ) from sciencebeam_gym.structured_document.lxml import ( - 
LxmlStructuredDocument + LxmlStructuredDocument ) from sciencebeam_gym.structured_document.svg import ( - SvgStructuredDocument + SvgStructuredDocument ) + class StructuredDocumentType(object): - LXML = 'lxml' - SVG_PAGES = 'svg-pages' + LXML = 'lxml' + SVG_PAGES = 'svg-pages' + def get_structuctured_document_type(filename): - if filename.endswith('.zip'): - return StructuredDocumentType.SVG_PAGES - return StructuredDocumentType.LXML + if filename.endswith('.zip'): + return StructuredDocumentType.SVG_PAGES + return StructuredDocumentType.LXML + def load_lxml_structured_document(filename, page_range=None): - with FileSystems.open(filename) as f: - structured_document = LxmlStructuredDocument(etree.parse(f).getroot()) - if page_range: - structured_document = LxmlStructuredDocument( - E.DOCUMENT( - *structured_document.get_pages()[ - max(0, page_range[0] - 1): - page_range[1] - ] - ) - ) - return structured_document + with FileSystems.open(filename) as f: + structured_document = LxmlStructuredDocument(etree.parse(f).getroot()) + if page_range: + structured_document = LxmlStructuredDocument( + E.DOCUMENT( + *structured_document.get_pages()[ + max(0, page_range[0] - 1): + page_range[1] + ] + ) + ) + return structured_document + def load_svg_pages_structured_document(filename, page_range=None): - return SvgStructuredDocument([ - etree.parse(svg_f).getroot() - for svg_f in load_pages(filename, page_range=page_range) - ]) + return SvgStructuredDocument([ + etree.parse(svg_f).getroot() + for svg_f in load_pages(filename, page_range=page_range) + ]) + def load_structured_document(filename, page_range=None): - structured_document_type = get_structuctured_document_type(filename) - if structured_document_type == StructuredDocumentType.LXML: - return load_lxml_structured_document(filename, page_range=page_range) - if structured_document_type == StructuredDocumentType.SVG_PAGES: - return load_svg_pages_structured_document(filename, page_range=page_range) + 
structured_document_type = get_structuctured_document_type(filename) + if structured_document_type == StructuredDocumentType.LXML: + return load_lxml_structured_document(filename, page_range=page_range) + if structured_document_type == StructuredDocumentType.SVG_PAGES: + return load_svg_pages_structured_document(filename, page_range=page_range) + raise RuntimeError('unsupported structured_document_type: %s (%s)' % ( + structured_document_type, filename + )) + def load_structured_documents_from_file_list(file_list, page_range=None): - return (load_structured_document(s, page_range=page_range) for s in file_list) + return (load_structured_document(s, page_range=page_range) for s in file_list) diff --git a/sciencebeam_gym/structured_document/structured_document_loader_test.py b/sciencebeam_gym/structured_document/structured_document_loader_test.py index 651802a..5842ccb 100644 --- a/sciencebeam_gym/structured_document/structured_document_loader_test.py +++ b/sciencebeam_gym/structured_document/structured_document_loader_test.py @@ -10,12 +10,12 @@ from lxml.builder import E import sciencebeam_gym.structured_document.structured_document_loader as structured_document_loader from sciencebeam_gym.structured_document.structured_document_loader import ( - StructuredDocumentType, - get_structuctured_document_type, - load_structured_documents_from_file_list, - load_lxml_structured_document, - load_svg_pages_structured_document, - load_structured_document + StructuredDocumentType, + get_structuctured_document_type, + load_structured_documents_from_file_list, + load_lxml_structured_document, + load_svg_pages_structured_document, + load_structured_document ) FILE_1 = 'file1.pdf' @@ -23,93 +23,98 @@ FILE_2 = 'file2.pdf' PAGE_RANGE = (2, 3) + class TestLoadLxmlStructuredDocument(object): - def test_should_load_file(self): - lxml_content = etree.tostring(E.test('test')) - with NamedTemporaryFile() as f: - f.write(lxml_content) - f.flush() - structured_document = 
load_lxml_structured_document(f.name) - assert etree.tostring(structured_document.root) == lxml_content - assert hasattr(structured_document.root, 'attrib') - - def test_should_limit_page_range(self): - lxml_content = etree.tostring(E.DOCUMENT( - E.PAGE('page 1'), - E.PAGE('page 2'), - E.PAGE('page 3'), - E.PAGE('page 4') - )) - with NamedTemporaryFile() as f: - f.write(lxml_content) - f.flush() - structured_document = load_lxml_structured_document(f.name, page_range=(2, 3)) - assert [x.text for x in structured_document.get_pages()] == ['page 2', 'page 3'] + def test_should_load_file(self): + lxml_content = etree.tostring(E.test('test')) + with NamedTemporaryFile() as f: + f.write(lxml_content) + f.flush() + structured_document = load_lxml_structured_document(f.name) + assert etree.tostring(structured_document.root) == lxml_content + assert hasattr(structured_document.root, 'attrib') + + def test_should_limit_page_range(self): + lxml_content = etree.tostring(E.DOCUMENT( + E.PAGE('page 1'), + E.PAGE('page 2'), + E.PAGE('page 3'), + E.PAGE('page 4') + )) + with NamedTemporaryFile() as f: + f.write(lxml_content) + f.flush() + structured_document = load_lxml_structured_document(f.name, page_range=(2, 3)) + assert [x.text for x in structured_document.get_pages()] == ['page 2', 'page 3'] + class TestLoadSvgPagesStructuredDocument(object): - def test_should_load_file_with_multiple_pages(self): - svg_pages_content = [ - etree.tostring(E.svg('page 1')), - etree.tostring(E.svg('page 2')) - ] - with NamedTemporaryFile() as f: - with ZipFile(f, 'w') as zf: - for i, svg_page_content in enumerate(svg_pages_content): - zf.writestr('page-%d.svg' % (1 + i), svg_page_content) - f.flush() - structured_document = load_svg_pages_structured_document(f.name) - assert ( - [etree.tostring(x) for x in structured_document.page_roots] == - svg_pages_content - ) - assert hasattr(structured_document.page_roots[0], 'attrib') - - def test_should_limit_page_range(self): - svg_pages_content = [ - 
etree.tostring(E.svg('page 1')), - etree.tostring(E.svg('page 2')), - etree.tostring(E.svg('page 3')), - etree.tostring(E.svg('page 4')) - ] - with NamedTemporaryFile() as f: - with ZipFile(f, 'w') as zf: - for i, svg_page_content in enumerate(svg_pages_content): - zf.writestr('page-%d.svg' % (1 + i), svg_page_content) - f.flush() - structured_document = load_svg_pages_structured_document(f.name, page_range=(2, 3)) - assert ( - [x.text for x in structured_document.page_roots] == - ['page 2', 'page 3'] - ) + def test_should_load_file_with_multiple_pages(self): + svg_pages_content = [ + etree.tostring(E.svg('page 1')), + etree.tostring(E.svg('page 2')) + ] + with NamedTemporaryFile() as f: + with ZipFile(f, 'w') as zf: + for i, svg_page_content in enumerate(svg_pages_content): + zf.writestr('page-%d.svg' % (1 + i), svg_page_content) + f.flush() + structured_document = load_svg_pages_structured_document(f.name) + assert ( + [etree.tostring(x) for x in structured_document.page_roots] == + svg_pages_content + ) + assert hasattr(structured_document.page_roots[0], 'attrib') + + def test_should_limit_page_range(self): + svg_pages_content = [ + etree.tostring(E.svg('page 1')), + etree.tostring(E.svg('page 2')), + etree.tostring(E.svg('page 3')), + etree.tostring(E.svg('page 4')) + ] + with NamedTemporaryFile() as f: + with ZipFile(f, 'w') as zf: + for i, svg_page_content in enumerate(svg_pages_content): + zf.writestr('page-%d.svg' % (1 + i), svg_page_content) + f.flush() + structured_document = load_svg_pages_structured_document(f.name, page_range=(2, 3)) + assert ( + [x.text for x in structured_document.page_roots] == + ['page 2', 'page 3'] + ) + class TestGetStructuredDocumentType(object): - def test_should_return_lxml_for_lxml_file(self): - assert get_structuctured_document_type('file.lxml') == StructuredDocumentType.LXML + def test_should_return_lxml_for_lxml_file(self): + assert get_structuctured_document_type('file.lxml') == StructuredDocumentType.LXML - def 
test_should_return_lxml_for_lxml_gz_file(self): - assert get_structuctured_document_type('file.lxml.gz') == StructuredDocumentType.LXML + def test_should_return_lxml_for_lxml_gz_file(self): + assert get_structuctured_document_type('file.lxml.gz') == StructuredDocumentType.LXML + + def test_should_return_lxml_for_svg_zip_file(self): + assert get_structuctured_document_type('file.svg.zip') == StructuredDocumentType.SVG_PAGES - def test_should_return_lxml_for_svg_zip_file(self): - assert get_structuctured_document_type('file.svg.zip') == StructuredDocumentType.SVG_PAGES class TestLoadStructuredDocument(object): - def test_should_call_load_plain_file_list_if_file(self): - with patch.object(structured_document_loader, 'load_lxml_structured_document') as mock: - result = load_structured_document('file.lxml', page_range=PAGE_RANGE) - mock.assert_called_with('file.lxml', page_range=PAGE_RANGE) - assert result == mock.return_value - - def test_should_call_load_csv_or_tsv_file_list_if_file(self): - with patch.object(structured_document_loader, 'load_svg_pages_structured_document') as mock: - result = load_structured_document('file.svg.zip', page_range=PAGE_RANGE) - mock.assert_called_with('file.svg.zip', page_range=PAGE_RANGE) - assert result == mock.return_value + def test_should_call_load_plain_file_list_if_file(self): + with patch.object(structured_document_loader, 'load_lxml_structured_document') as mock: + result = load_structured_document('file.lxml', page_range=PAGE_RANGE) + mock.assert_called_with('file.lxml', page_range=PAGE_RANGE) + assert result == mock.return_value + + def test_should_call_load_csv_or_tsv_file_list_if_file(self): + with patch.object(structured_document_loader, 'load_svg_pages_structured_document') as mock: + result = load_structured_document('file.svg.zip', page_range=PAGE_RANGE) + mock.assert_called_with('file.svg.zip', page_range=PAGE_RANGE) + assert result == mock.return_value + class TestLoadStructuredDocumentsFromFileList(object): - def 
test_should_single_file(self): - with patch.object(structured_document_loader, 'load_structured_document') as mock: - assert ( - list(load_structured_documents_from_file_list([FILE_1], page_range=PAGE_RANGE)) == - [mock.return_value] - ) - mock.assert_called_with(FILE_1, page_range=PAGE_RANGE) + def test_should_single_file(self): + with patch.object(structured_document_loader, 'load_structured_document') as mock: + assert ( + list(load_structured_documents_from_file_list([FILE_1], page_range=PAGE_RANGE)) == + [mock.return_value] + ) + mock.assert_called_with(FILE_1, page_range=PAGE_RANGE) diff --git a/sciencebeam_gym/structured_document/structured_document_saver.py b/sciencebeam_gym/structured_document/structured_document_saver.py index 5ca17bd..7b541df 100644 --- a/sciencebeam_gym/structured_document/structured_document_saver.py +++ b/sciencebeam_gym/structured_document/structured_document_saver.py @@ -3,33 +3,36 @@ from __future__ import absolute_import from lxml import etree from sciencebeam_utils.beam_utils.io import ( - save_file_content + save_file_content ) from sciencebeam_gym.utils.pages_zip import ( - save_pages + save_pages ) from sciencebeam_gym.structured_document.lxml import ( - LxmlStructuredDocument + LxmlStructuredDocument ) from sciencebeam_gym.structured_document.svg import ( - SvgStructuredDocument + SvgStructuredDocument ) + def save_lxml_structured_document(filename, lxml_structured_document): - save_file_content(filename, etree.tostring(lxml_structured_document.root)) + save_file_content(filename, etree.tostring(lxml_structured_document.root)) + def save_svg_structured_document(filename, svg_structured_document): - return save_pages(filename, '.svg', ( - etree.tostring(svg_page) - for svg_page in svg_structured_document.get_pages() - )) + return save_pages(filename, '.svg', ( + etree.tostring(svg_page) + for svg_page in svg_structured_document.get_pages() + )) + def save_structured_document(filename, structured_document): - if 
isinstance(structured_document, LxmlStructuredDocument): - return save_lxml_structured_document(filename, structured_document) - if isinstance(structured_document, SvgStructuredDocument): - return save_svg_structured_document(filename, structured_document) - raise RuntimeError('unsupported type: %s' % type(structured_document)) + if isinstance(structured_document, LxmlStructuredDocument): + return save_lxml_structured_document(filename, structured_document) + if isinstance(structured_document, SvgStructuredDocument): + return save_svg_structured_document(filename, structured_document) + raise RuntimeError('unsupported type: %s' % type(structured_document)) diff --git a/sciencebeam_gym/structured_document/structured_document_saver_test.py b/sciencebeam_gym/structured_document/structured_document_saver_test.py index fdbece8..1cbc455 100644 --- a/sciencebeam_gym/structured_document/structured_document_saver_test.py +++ b/sciencebeam_gym/structured_document/structured_document_saver_test.py @@ -5,53 +5,56 @@ from mock import patch, ANY from lxml.builder import E from sciencebeam_gym.structured_document.lxml import ( - LxmlStructuredDocument + LxmlStructuredDocument ) from sciencebeam_gym.structured_document.svg import ( - SvgStructuredDocument + SvgStructuredDocument ) import sciencebeam_gym.structured_document.structured_document_saver as structured_document_saver from sciencebeam_gym.structured_document.structured_document_saver import ( - save_lxml_structured_document, - save_svg_structured_document, - save_structured_document + save_lxml_structured_document, + save_svg_structured_document, + save_structured_document ) FILE_1 = 'file1' + class TestSaveLxmlStructuredDocument(object): - def test_should_call_save_file_content(self): - m = structured_document_saver - root = E.DOCUMENT() - with patch.object(m, 'save_file_content') as save_file_content: - with patch.object(m, 'etree') as etree: - save_lxml_structured_document(FILE_1, LxmlStructuredDocument(root)) - 
save_file_content.assert_called_with(FILE_1, etree.tostring(root)) + def test_should_call_save_file_content(self): + m = structured_document_saver + root = E.DOCUMENT() + with patch.object(m, 'save_file_content') as save_file_content: + with patch.object(m, 'etree') as etree: + save_lxml_structured_document(FILE_1, LxmlStructuredDocument(root)) + save_file_content.assert_called_with(FILE_1, etree.tostring(root)) + class TestSaveSvgStructuredDocument(object): - def test_should_call_save_pages(self): - m = structured_document_saver - root = E.svg() - with patch.object(m, 'save_pages') as save_pages: - with patch.object(m, 'etree') as etree: - save_svg_structured_document(FILE_1, SvgStructuredDocument(root)) - save_pages.assert_called_with(FILE_1, '.svg', ANY) - args, _ = save_pages.call_args - assert list(args[2]) == [etree.tostring(root)] + def test_should_call_save_pages(self): + m = structured_document_saver + root = E.svg() + with patch.object(m, 'save_pages') as save_pages: + with patch.object(m, 'etree') as etree: + save_svg_structured_document(FILE_1, SvgStructuredDocument(root)) + save_pages.assert_called_with(FILE_1, '.svg', ANY) + args, _ = save_pages.call_args + assert list(args[2]) == [etree.tostring(root)] + class TestSaveStructuredDocument(object): - def test_should_call_save_lxml_structured_document(self): - structured_document = LxmlStructuredDocument(E.DOCUMENT) - m = structured_document_saver - with patch.object(m, 'save_lxml_structured_document') as save_lxml_structured_document_mock: - save_structured_document(FILE_1, structured_document) - save_lxml_structured_document_mock.assert_called_with(FILE_1, structured_document) - - def test_should_call_save_svg_structured_document(self): - structured_document = SvgStructuredDocument(E.svg) - m = structured_document_saver - with patch.object(m, 'save_svg_structured_document') as save_svg_structured_document_mock: - save_structured_document(FILE_1, structured_document) - 
save_svg_structured_document_mock.assert_called_with(FILE_1, structured_document) + def test_should_call_save_lxml_structured_document(self): + structured_document = LxmlStructuredDocument(E.DOCUMENT) + m = structured_document_saver + with patch.object(m, 'save_lxml_structured_document') as save_lxml_structured_document_mock: + save_structured_document(FILE_1, structured_document) + save_lxml_structured_document_mock.assert_called_with(FILE_1, structured_document) + + def test_should_call_save_svg_structured_document(self): + structured_document = SvgStructuredDocument(E.svg) + m = structured_document_saver + with patch.object(m, 'save_svg_structured_document') as save_svg_structured_document_mock: + save_structured_document(FILE_1, structured_document) + save_svg_structured_document_mock.assert_called_with(FILE_1, structured_document) diff --git a/sciencebeam_gym/structured_document/svg.py b/sciencebeam_gym/structured_document/svg.py index 194ccf9..c9ecd43 100644 --- a/sciencebeam_gym/structured_document/svg.py +++ b/sciencebeam_gym/structured_document/svg.py @@ -1,15 +1,15 @@ from sciencebeam_utils.utils.xml import ( - set_or_remove_attrib + set_or_remove_attrib ) from sciencebeam_gym.utils.bounding_box import ( - BoundingBox + BoundingBox ) from sciencebeam_gym.structured_document import ( - AbstractStructuredDocument, - get_scoped_attrib_name, - get_attrib_by_scope + AbstractStructuredDocument, + get_scoped_attrib_name, + get_attrib_by_scope ) SVG_NS = 'http://www.w3.org/2000/svg' @@ -30,89 +30,95 @@ SVGE_BOUNDING_BOX = SVGE_NS_PREFIX + 'bounding-box' SCOPED_TAG_ATTRIB_SUFFIX = 'tag' SVG_NSMAP = { - None : SVG_NS, - 'svge': SVGE_NS + None: SVG_NS, + 'svge': SVGE_NS } + class SvgStyleClasses(object): - LINE = 'line' - BLOCK = 'block' - LINE_NO = 'line_no' + LINE = 'line' + BLOCK = 'block' + LINE_NO = 'line_no' + def format_bounding_box(bounding_box): - return '%s %s %s %s' % (bounding_box.x, bounding_box.y, bounding_box.width, bounding_box.height) + return '%s 
%s %s %s' % (bounding_box.x, bounding_box.y, bounding_box.width, bounding_box.height) + def parse_bounding_box(bounding_box_str): - if not bounding_box_str: - return None - x, y, width, height = bounding_box_str.split() - return BoundingBox(float(x), float(y), float(width), float(height)) + if not bounding_box_str: + return None + x, y, width, height = bounding_box_str.split() + return BoundingBox(float(x), float(y), float(width), float(height)) + def get_node_bounding_box(t): - attrib = t.attrib - if SVGE_BOUNDING_BOX in attrib: - return parse_bounding_box(attrib[SVGE_BOUNDING_BOX]) - if SVG_VIEWBOX_ATTRIB in attrib: - return parse_bounding_box(attrib[SVG_VIEWBOX_ATTRIB]) - if not ('font-size' in attrib and 'x' in attrib and 'y' in attrib): - return None - font_size = float(attrib['font-size']) - width = font_size * 0.8 * max(1, len(t.text)) - return BoundingBox( - float(attrib['x']), - float(attrib['y']), - width, - font_size - ) + attrib = t.attrib + if SVGE_BOUNDING_BOX in attrib: + return parse_bounding_box(attrib[SVGE_BOUNDING_BOX]) + if SVG_VIEWBOX_ATTRIB in attrib: + return parse_bounding_box(attrib[SVG_VIEWBOX_ATTRIB]) + if not ('font-size' in attrib and 'x' in attrib and 'y' in attrib): + return None + font_size = float(attrib['font-size']) + width = font_size * 0.8 * max(1, len(t.text)) + return BoundingBox( + float(attrib['x']), + float(attrib['y']), + width, + font_size + ) + def _get_tag_attrib_name(scope, level): - return ( - SVGE_NS_PREFIX + get_scoped_attrib_name(SCOPED_TAG_ATTRIB_SUFFIX, scope=scope, level=level) - if scope or level - else SVG_TAG_ATTRIB - ) + return ( + SVGE_NS_PREFIX + get_scoped_attrib_name(SCOPED_TAG_ATTRIB_SUFFIX, scope=scope, level=level) + if scope or level + else SVG_TAG_ATTRIB + ) + class SvgStructuredDocument(AbstractStructuredDocument): - def __init__(self, root_or_roots): - if isinstance(root_or_roots, list): - self.page_roots = root_or_roots - else: - self.page_roots = [root_or_roots] + def __init__(self, 
root_or_roots): + if isinstance(root_or_roots, list): + self.page_roots = root_or_roots + else: + self.page_roots = [root_or_roots] - def get_pages(self): - return self.page_roots + def get_pages(self): + return self.page_roots - def get_lines_of_page(self, page): - return page.findall('.//{}[@class="{}"]'.format(SVG_G, SvgStyleClasses.LINE)) + def get_lines_of_page(self, page): + return page.findall('.//{}[@class="{}"]'.format(SVG_G, SvgStyleClasses.LINE)) - def get_tokens_of_line(self, line): - return line.findall('./{}'.format(SVG_TEXT)) + def get_tokens_of_line(self, line): + return line.findall('./{}'.format(SVG_TEXT)) - def get_x(self, parent): - return parent.attrib.get('x') + def get_x(self, parent): + return parent.attrib.get('x') - def get_text(self, parent): - return parent.text + def get_text(self, parent): + return parent.text - def get_tag(self, parent, scope=None, level=None): - return parent.attrib.get(_get_tag_attrib_name(scope, level)) + def get_tag(self, parent, scope=None, level=None): + return parent.attrib.get(_get_tag_attrib_name(scope, level)) - def set_tag(self, parent, tag, scope=None, level=None): - set_or_remove_attrib(parent.attrib, _get_tag_attrib_name(scope, level), tag) + def set_tag(self, parent, tag, scope=None, level=None): + set_or_remove_attrib(parent.attrib, _get_tag_attrib_name(scope, level), tag) - def get_tag_by_scope(self, parent): - d = { - k[len(SVGE_NS_PREFIX):]: v - for k, v in get_attrib_by_scope(parent.attrib, SCOPED_TAG_ATTRIB_SUFFIX).items() - if k.startswith(SVGE_NS_PREFIX) - } - tag = self.get_tag(parent) - if tag: - d[None] = tag - return d + def get_tag_by_scope(self, parent): + d = { + k[len(SVGE_NS_PREFIX):]: v + for k, v in get_attrib_by_scope(parent.attrib, SCOPED_TAG_ATTRIB_SUFFIX).items() + if k.startswith(SVGE_NS_PREFIX) + } + tag = self.get_tag(parent) + if tag: + d[None] = tag + return d - def get_bounding_box(self, parent): - return get_node_bounding_box(parent) + def get_bounding_box(self, parent): + 
return get_node_bounding_box(parent) - def set_bounding_box(self, parent, bounding_box): - parent.attrib[SVGE_BOUNDING_BOX] = format_bounding_box(bounding_box) + def set_bounding_box(self, parent, bounding_box): + parent.attrib[SVGE_BOUNDING_BOX] = format_bounding_box(bounding_box) diff --git a/sciencebeam_gym/structured_document/svg_test.py b/sciencebeam_gym/structured_document/svg_test.py index 879597c..db8ef2c 100644 --- a/sciencebeam_gym/structured_document/svg_test.py +++ b/sciencebeam_gym/structured_document/svg_test.py @@ -3,16 +3,16 @@ from __future__ import absolute_import from lxml.builder import ElementMaker from sciencebeam_gym.utils.bounding_box import ( - BoundingBox + BoundingBox ) from sciencebeam_gym.structured_document.svg import ( - SvgStructuredDocument, - SvgStyleClasses, - SVG_NS, - SVG_VIEWBOX_ATTRIB, - SVGE_BOUNDING_BOX, - format_bounding_box + SvgStructuredDocument, + SvgStyleClasses, + SVG_NS, + SVG_VIEWBOX_ATTRIB, + SVGE_BOUNDING_BOX, + format_bounding_box ) E = ElementMaker(namespace=SVG_NS) @@ -25,166 +25,167 @@ TAG_2 = 'tag2' SCOPE_1 = 'scope1' + class TestSvgStructuredDocument(object): - def test_should_return_root_as_pages(self): - root = E.svg() - doc = SvgStructuredDocument(root) - assert list(doc.get_pages()) == [root] - - def test_should_find_lines_of_page_without_blocks(self): - lines = [ - SVG_TEXT_LINE(), - SVG_TEXT_LINE() - ] - doc = SvgStructuredDocument( - E.svg( - *lines - ) - ) - page = doc.get_pages()[0] - assert list(doc.get_lines_of_page(page)) == lines - - def test_should_find_lines_of_page_with_blocks(self): - lines = [ - SVG_TEXT_LINE(), - SVG_TEXT_LINE() - ] - doc = SvgStructuredDocument( - E.svg( - SVG_TEXT_BLOCK( - *lines + def test_should_return_root_as_pages(self): + root = E.svg() + doc = SvgStructuredDocument(root) + assert list(doc.get_pages()) == [root] + + def test_should_find_lines_of_page_without_blocks(self): + lines = [ + SVG_TEXT_LINE(), + SVG_TEXT_LINE() + ] + doc = SvgStructuredDocument( + E.svg( + 
*lines + ) + ) + page = doc.get_pages()[0] + assert list(doc.get_lines_of_page(page)) == lines + + def test_should_find_lines_of_page_with_blocks(self): + lines = [ + SVG_TEXT_LINE(), + SVG_TEXT_LINE() + ] + doc = SvgStructuredDocument( + E.svg( + SVG_TEXT_BLOCK( + *lines + ) + ) + ) + page = doc.get_pages()[0] + assert list(doc.get_lines_of_page(page)) == lines + + def test_should_find_tokens_of_line(self): + tokens = [ + SVG_TEXT(), + SVG_TEXT() + ] + line = SVG_TEXT_LINE(*tokens) + doc = SvgStructuredDocument( + E.svg( + line, + SVG_TEXT_LINE(SVG_TEXT()) + ) ) - ) - ) - page = doc.get_pages()[0] - assert list(doc.get_lines_of_page(page)) == lines - - def test_should_find_tokens_of_line(self): - tokens = [ - SVG_TEXT(), - SVG_TEXT() - ] - line = SVG_TEXT_LINE(*tokens) - doc = SvgStructuredDocument( - E.svg( - line, - SVG_TEXT_LINE(SVG_TEXT()) - ) - ) - assert list(doc.get_tokens_of_line(line)) == tokens - - def test_should_tag_text_as_line_no(self): - text = SVG_TEXT() - doc = SvgStructuredDocument( - E.svg( - SVG_TEXT_LINE(text) - ) - ) - doc.set_tag(text, SvgStyleClasses.LINE_NO) - assert text.attrib['class'] == SvgStyleClasses.LINE_NO - - def test_should_calculate_default_bounding_box(self): - text = SVG_TEXT('a', { - 'x': '10', - 'y': '11', - 'font-size': '100' - }) - doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text))) - assert doc.get_bounding_box(text) == BoundingBox(10, 11, 100 * 0.8, 100) - - def test_should_estimate_width_based_on_number_of_characters(self): - s = 'abc' - text = SVG_TEXT(s, { - 'x': '10', - 'y': '11', - 'font-size': '100' - }) - doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text))) - assert doc.get_bounding_box(text) == BoundingBox( - 10, 11, 100 * 0.8 * len(s), 100 - ) - - def test_should_not_return_bounding_box_if_font_size_is_missing(self): - text = SVG_TEXT({ - 'x': '10', - 'y': '11' - }) - doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text))) - assert doc.get_bounding_box(text) is None - - def 
test_should_use_bounding_box_if_available(self): - bounding_box = BoundingBox(11, 12, 101, 102) - text = SVG_TEXT('a', { - 'x': '10', - 'y': '11', - 'font-size': '100', - SVGE_BOUNDING_BOX: format_bounding_box(bounding_box) - }) - doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text))) - assert doc.get_bounding_box(text) == bounding_box - - def test_should_be_able_to_set_bounding_box(self): - bounding_box = BoundingBox(11, 12, 101, 102) - text = SVG_TEXT('a', { - 'x': '10', - 'y': '11', - 'font-size': '100' - }) - doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text))) - doc.set_bounding_box(text, bounding_box) - assert text.attrib[SVGE_BOUNDING_BOX] == format_bounding_box(bounding_box) - - def test_should_use_viewbox_if_available(self): - bounding_box = BoundingBox(11, 12, 101, 102) - page = E.svg({ - SVG_VIEWBOX_ATTRIB: format_bounding_box(bounding_box) - }) - doc = SvgStructuredDocument(page) - assert doc.get_bounding_box(page) == bounding_box - - def test_should_set_tag_without_scope(self): - token = SVG_TEXT('test') - doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token))) - doc.set_tag(token, TAG_1) - assert doc.get_tag(token) == TAG_1 - - def test_should_set_tag_with_scope(self): - token = SVG_TEXT('test') - doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token))) - doc.set_tag(token, TAG_1, scope=SCOPE_1) - assert doc.get_tag(token, scope=SCOPE_1) == TAG_1 - assert doc.get_tag(token) is None - - def test_should_set_tag_with_level(self): - token = SVG_TEXT('test') - doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token))) - doc.set_tag(token, TAG_1, level=2) - assert doc.get_tag(token, level=2) == TAG_1 - assert doc.get_tag(token) is None - - def test_should_return_all_tag_by_scope(self): - token = SVG_TEXT('test') - doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token))) - doc.set_tag(token, TAG_1) - doc.set_tag(token, TAG_2, scope=SCOPE_1) - assert doc.get_tag(token) == TAG_1 - assert doc.get_tag(token, scope=SCOPE_1) == TAG_2 - assert 
doc.get_tag_by_scope(token) == {None: TAG_1, SCOPE_1: TAG_2} - - def test_should_clear_tag_when_setting_tag_to_none(self): - token = SVG_TEXT('test') - doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token))) - doc.set_tag(token, TAG_1) - doc.set_tag(token, TAG_1, scope=SCOPE_1) - doc.set_tag(token, None) - doc.set_tag(token, None, scope=SCOPE_1) - assert doc.get_tag(token) is None - assert doc.get_tag(token, scope=SCOPE_1) is None - - def test_should_not_fail_setting_empty_tag_to_none(self): - token = SVG_TEXT('test') - doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token))) - doc.set_tag(token, None) - doc.set_tag(token, None, scope=SCOPE_1) - assert doc.get_tag(token) is None - assert doc.get_tag(token, scope=SCOPE_1) is None + assert list(doc.get_tokens_of_line(line)) == tokens + + def test_should_tag_text_as_line_no(self): + text = SVG_TEXT() + doc = SvgStructuredDocument( + E.svg( + SVG_TEXT_LINE(text) + ) + ) + doc.set_tag(text, SvgStyleClasses.LINE_NO) + assert text.attrib['class'] == SvgStyleClasses.LINE_NO + + def test_should_calculate_default_bounding_box(self): + text = SVG_TEXT('a', { + 'x': '10', + 'y': '11', + 'font-size': '100' + }) + doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text))) + assert doc.get_bounding_box(text) == BoundingBox(10, 11, 100 * 0.8, 100) + + def test_should_estimate_width_based_on_number_of_characters(self): + s = 'abc' + text = SVG_TEXT(s, { + 'x': '10', + 'y': '11', + 'font-size': '100' + }) + doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text))) + assert doc.get_bounding_box(text) == BoundingBox( + 10, 11, 100 * 0.8 * len(s), 100 + ) + + def test_should_not_return_bounding_box_if_font_size_is_missing(self): + text = SVG_TEXT({ + 'x': '10', + 'y': '11' + }) + doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text))) + assert doc.get_bounding_box(text) is None + + def test_should_use_bounding_box_if_available(self): + bounding_box = BoundingBox(11, 12, 101, 102) + text = SVG_TEXT('a', { + 'x': '10', + 'y': '11', + 
'font-size': '100', + SVGE_BOUNDING_BOX: format_bounding_box(bounding_box) + }) + doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text))) + assert doc.get_bounding_box(text) == bounding_box + + def test_should_be_able_to_set_bounding_box(self): + bounding_box = BoundingBox(11, 12, 101, 102) + text = SVG_TEXT('a', { + 'x': '10', + 'y': '11', + 'font-size': '100' + }) + doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(text))) + doc.set_bounding_box(text, bounding_box) + assert text.attrib[SVGE_BOUNDING_BOX] == format_bounding_box(bounding_box) + + def test_should_use_viewbox_if_available(self): + bounding_box = BoundingBox(11, 12, 101, 102) + page = E.svg({ + SVG_VIEWBOX_ATTRIB: format_bounding_box(bounding_box) + }) + doc = SvgStructuredDocument(page) + assert doc.get_bounding_box(page) == bounding_box + + def test_should_set_tag_without_scope(self): + token = SVG_TEXT('test') + doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token))) + doc.set_tag(token, TAG_1) + assert doc.get_tag(token) == TAG_1 + + def test_should_set_tag_with_scope(self): + token = SVG_TEXT('test') + doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token))) + doc.set_tag(token, TAG_1, scope=SCOPE_1) + assert doc.get_tag(token, scope=SCOPE_1) == TAG_1 + assert doc.get_tag(token) is None + + def test_should_set_tag_with_level(self): + token = SVG_TEXT('test') + doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token))) + doc.set_tag(token, TAG_1, level=2) + assert doc.get_tag(token, level=2) == TAG_1 + assert doc.get_tag(token) is None + + def test_should_return_all_tag_by_scope(self): + token = SVG_TEXT('test') + doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token))) + doc.set_tag(token, TAG_1) + doc.set_tag(token, TAG_2, scope=SCOPE_1) + assert doc.get_tag(token) == TAG_1 + assert doc.get_tag(token, scope=SCOPE_1) == TAG_2 + assert doc.get_tag_by_scope(token) == {None: TAG_1, SCOPE_1: TAG_2} + + def test_should_clear_tag_when_setting_tag_to_none(self): + token = SVG_TEXT('test') + doc = 
SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token))) + doc.set_tag(token, TAG_1) + doc.set_tag(token, TAG_1, scope=SCOPE_1) + doc.set_tag(token, None) + doc.set_tag(token, None, scope=SCOPE_1) + assert doc.get_tag(token) is None + assert doc.get_tag(token, scope=SCOPE_1) is None + + def test_should_not_fail_setting_empty_tag_to_none(self): + token = SVG_TEXT('test') + doc = SvgStructuredDocument(E.svg(SVG_TEXT_LINE(token))) + doc.set_tag(token, None) + doc.set_tag(token, None, scope=SCOPE_1) + assert doc.get_tag(token) is None + assert doc.get_tag(token, scope=SCOPE_1) is None diff --git a/sciencebeam_gym/tools/calculate_class_weights.py b/sciencebeam_gym/tools/calculate_class_weights.py index b8446ca..fed0464 100644 --- a/sciencebeam_gym/tools/calculate_class_weights.py +++ b/sciencebeam_gym/tools/calculate_class_weights.py @@ -7,254 +7,276 @@ import json import numpy as np import tensorflow as tf -from tensorflow.python.lib.io import file_io # pylint: disable=E0611 +from tensorflow.python.lib.io import file_io # pylint: disable=E0611 from tqdm import tqdm from sciencebeam_gym.preprocess.color_map import ( - parse_color_map_from_file + parse_color_map_from_file ) from sciencebeam_gym.utils.tfrecord import ( - iter_read_tfrecord_file_as_dict_list + iter_read_tfrecord_file_as_dict_list ) from sciencebeam_gym.model_utils.channels import ( - color_equals_mask_as_float, - calculate_color_masks + color_equals_mask_as_float, + calculate_color_masks ) + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def color_frequency(image, color): - return tf.reduce_sum(color_equals_mask_as_float(image, color)) + return tf.reduce_sum(color_equals_mask_as_float(image, color)) + def get_shape(x): - try: - return x.shape - except AttributeError: - return tf.constant(x).shape + try: + return x.shape + except AttributeError: + return tf.constant(x).shape + def calculate_sample_frequencies(image, colors, use_unknown_class=False): - color_masks = 
calculate_color_masks(image, colors, use_unknown_class) - return [ - tf.reduce_sum(color_mask) - for color_mask in color_masks - ] + color_masks = calculate_color_masks(image, colors, use_unknown_class) + return [ + tf.reduce_sum(color_mask) + for color_mask in color_masks + ] + def iter_calculate_sample_frequencies( - images, colors, image_shape=None, image_format=None, use_unknown_class=False): + images, colors, image_shape=None, image_format=None, use_unknown_class=False): + + with tf.Graph().as_default(): + if image_format == 'png': + image_tensor = tf.placeholder(tf.string, shape=[], name='image') + decoded_image_tensor = tf.image.decode_png(image_tensor, channels=3) + else: + if image_shape is None: + image_shape = (None, None, 3) + image_tensor = tf.placeholder(tf.uint8, shape=image_shape, name='image') + decoded_image_tensor = image_tensor + get_logger().debug('decoded_image_tensor: %s', decoded_image_tensor) + frequency_tensors = calculate_sample_frequencies( + decoded_image_tensor, colors, use_unknown_class=use_unknown_class + ) + with tf.Session() as session: + for image in images: + frequencies = session.run(frequency_tensors, { + image_tensor: image + }) + get_logger().debug('frequencies: %s', frequencies) + yield frequencies - with tf.Graph().as_default(): - if image_format == 'png': - image_tensor = tf.placeholder(tf.string, shape=[], name='image') - decoded_image_tensor = tf.image.decode_png(image_tensor, channels=3) - else: - if image_shape is None: - image_shape = (None, None, 3) - image_tensor = tf.placeholder(tf.uint8, shape=image_shape, name='image') - decoded_image_tensor = image_tensor - get_logger().debug('decoded_image_tensor: %s', decoded_image_tensor) - frequency_tensors = calculate_sample_frequencies( - decoded_image_tensor, colors, use_unknown_class=use_unknown_class - ) - with tf.Session() as session: - for image in images: - frequencies = session.run(frequency_tensors, { - image_tensor: image - }) - get_logger().debug('frequencies: 
%s', frequencies) - yield frequencies def tf_calculate_efnet_weights_for_frequency_by_label( - frequency_by_label, return_zero_for_zero_frequency=True): + frequency_by_label, return_zero_for_zero_frequency=True): + + total_frequency = tf.reduce_sum(frequency_by_label) + class_weights = 1.0 / tf.log(1.02 + frequency_by_label / total_frequency) + return ( + class_weights * tf.cast(tf.minimum(frequency_by_label, 1), class_weights.dtype) + if return_zero_for_zero_frequency + else class_weights + ) - total_frequency = tf.reduce_sum(frequency_by_label) - class_weights = 1.0 / tf.log(1.02 + frequency_by_label / total_frequency) - return ( - class_weights * tf.cast(tf.minimum(frequency_by_label, 1), class_weights.dtype) - if return_zero_for_zero_frequency - else class_weights - ) def calculate_efnet_weights_for_frequency_by_label(frequency_by_label): - total_frequency = sum(frequency_by_label) - return [ - 1 / np.log(1.02 + (frequency / total_frequency)) - if frequency > 0 else 0.0 - for frequency in frequency_by_label - ] + total_frequency = sum(frequency_by_label) + return [ + 1 / np.log(1.02 + (frequency / total_frequency)) + if frequency > 0 else 0.0 + for frequency in frequency_by_label + ] + def sum_frequencies_by_label(frequencies_by_label): - return [sum(x) for x in frequencies_by_label] + return [sum(x) for x in frequencies_by_label] + def calculate_efnet_weights_for_frequencies_by_label(frequencies_by_label): - return calculate_efnet_weights_for_frequency_by_label( - sum_frequencies_by_label(frequencies_by_label) - ) + return calculate_efnet_weights_for_frequency_by_label( + sum_frequencies_by_label(frequencies_by_label) + ) + def calculate_median_class_weight(class_frequencies): - """ - Perform median frequency balancing on the image files, given by the formula: - f = Median_freq_c / total_freq_c - where median_freq_c is the median frequency of the class - for all pixels of C that appeared in images - and total_freq_c is the total number of pixels of c in the 
total pixels - of the images where c appeared. - """ - - non_zero_frequencies = [f for f in class_frequencies if f != 0.0] - if not non_zero_frequencies: - return 0.0 - get_logger().debug('non_zero_frequencies: %s', non_zero_frequencies) - total_freq_c = sum(non_zero_frequencies) - get_logger().debug('total_freq_c: %s', total_freq_c) - median_freq_c = np.median(non_zero_frequencies) - get_logger().debug('median_freq_c: %s', median_freq_c) - return median_freq_c / total_freq_c + """ + Perform median frequency balancing on the image files, given by the formula: + f = Median_freq_c / total_freq_c + where median_freq_c is the median frequency of the class + for all pixels of C that appeared in images + and total_freq_c is the total number of pixels of c in the total pixels + of the images where c appeared. + """ + + non_zero_frequencies = [f for f in class_frequencies if f != 0.0] + if not non_zero_frequencies: + return 0.0 + get_logger().debug('non_zero_frequencies: %s', non_zero_frequencies) + total_freq_c = sum(non_zero_frequencies) + get_logger().debug('total_freq_c: %s', total_freq_c) + median_freq_c = np.median(non_zero_frequencies) + get_logger().debug('median_freq_c: %s', median_freq_c) + return median_freq_c / total_freq_c + def calculate_median_weights_for_frequencies(frequencies): - median_frequencies_balanced = [ - calculate_median_class_weight(f) - for f in frequencies - ] - total = sum(median_frequencies_balanced) - return [ - f / total - for f in median_frequencies_balanced - ] + median_frequencies_balanced = [ + calculate_median_class_weight(f) + for f in frequencies + ] + total = sum(median_frequencies_balanced) + return [ + f / total + for f in median_frequencies_balanced + ] + def parse_color_map(color_map_filename): - with file_io.FileIO(color_map_filename, 'r') as config_f: - return parse_color_map_from_file( - config_f - ) + with file_io.FileIO(color_map_filename, 'r') as config_f: + return parse_color_map_from_file( + config_f + ) + def 
transpose(m): - return zip(*m) + return zip(*m) + def iter_images_for_tfrecord_paths(tfrecord_paths, image_key, progress=False): - for tfrecord_path in tfrecord_paths: - get_logger().info('tfrecord_path: %s', tfrecord_path) - filenames = file_io.get_matching_files(tfrecord_path) - with tqdm(list(filenames), leave=False, disable=not progress) as pbar: - for tfrecord_filename in pbar: - pbar.set_description('%-40s' % tfrecord_filename) - get_logger().debug('tfrecord_filename: %s', tfrecord_filename) - for d in iter_read_tfrecord_file_as_dict_list(tfrecord_filename, keys={image_key}): - yield d[image_key] + for tfrecord_path in tfrecord_paths: + get_logger().info('tfrecord_path: %s', tfrecord_path) + filenames = file_io.get_matching_files(tfrecord_path) + with tqdm(list(filenames), leave=False, disable=not progress) as pbar: + for tfrecord_filename in pbar: + pbar.set_description('%-40s' % tfrecord_filename) + get_logger().debug('tfrecord_filename: %s', tfrecord_filename) + for d in iter_read_tfrecord_file_as_dict_list(tfrecord_filename, keys={image_key}): + yield d[image_key] + def calculate_median_class_weights_for_tfrecord_paths_and_colors( - tfrecord_paths, image_key, colors, use_unknown_class=False, progress=False): - - get_logger().debug('colors: %s', colors) - get_logger().info('loading tfrecords: %s', tfrecord_paths) - images = iter_images_for_tfrecord_paths(tfrecord_paths, image_key, progress=progress) - if progress: - images = list(images) - images = tqdm(images, 'analysing images', leave=False) - frequency_list = list(iter_calculate_sample_frequencies( - images, colors, image_format='png', use_unknown_class=use_unknown_class - )) - get_logger().debug('frequency_list: %s', frequency_list) - frequencies = transpose(frequency_list) - get_logger().debug('frequencies: %s', frequencies) - class_weights = calculate_median_weights_for_frequencies(frequencies) - return class_weights + tfrecord_paths, image_key, colors, use_unknown_class=False, progress=False): + + 
get_logger().debug('colors: %s', colors) + get_logger().info('loading tfrecords: %s', tfrecord_paths) + images = iter_images_for_tfrecord_paths(tfrecord_paths, image_key, progress=progress) + if progress: + images = list(images) + images = tqdm(images, 'analysing images', leave=False) + frequency_list = list(iter_calculate_sample_frequencies( + images, colors, image_format='png', use_unknown_class=use_unknown_class + )) + get_logger().debug('frequency_list: %s', frequency_list) + frequencies = transpose(frequency_list) + get_logger().debug('frequencies: %s', frequencies) + class_weights = calculate_median_weights_for_frequencies(frequencies) + return class_weights + def calculate_median_class_weights_for_tfrecord_paths_and_color_map( - tfrecord_paths, image_key, color_map, channels=None, - use_unknown_class=False, unknown_class_label='unknown', - progress=False): - if not channels: - channels = sorted(color_map.keys()) - colors = [color_map[k] for k in channels] - class_weights = calculate_median_class_weights_for_tfrecord_paths_and_colors( - tfrecord_paths, - image_key, - colors, - progress=progress, - use_unknown_class=use_unknown_class - ) - if use_unknown_class: - channels += [unknown_class_label] - return { - k: class_weight for k, class_weight in zip(channels, class_weights) - } + tfrecord_paths, image_key, color_map, channels=None, + use_unknown_class=False, unknown_class_label='unknown', + progress=False): + if not channels: + channels = sorted(color_map.keys()) + colors = [color_map[k] for k in channels] + class_weights = calculate_median_class_weights_for_tfrecord_paths_and_colors( + tfrecord_paths, + image_key, + colors, + progress=progress, + use_unknown_class=use_unknown_class + ) + if use_unknown_class: + channels += [unknown_class_label] + return { + k: class_weight for k, class_weight in zip(channels, class_weights) + } + def str_to_bool(s): - return s.lower() in ('yes', 'true', '1') + return s.lower() in ('yes', 'true', '1') + def str_to_list(s): - 
s = s.strip() - if not s: - return [] - return [x.strip() for x in s.split(',')] + s = s.strip() + if not s: + return [] + return [x.strip() for x in s.split(',')] + def get_args_parser(): - parser = argparse.ArgumentParser() - parser.add_argument( - '--tfrecord-paths', - required=True, - type=str, - action='append', - help='The paths to the tf-records files to analyse.' - ) - parser.add_argument( - '--image-key', - required=False, - type=str, - help='The name of the image key to do the class weights on.' - ) - parser.add_argument( - '--color-map', - required=True, - type=str, - help='The color-map filename.' - ) - parser.add_argument( - '--channels', - type=str_to_list, - help='The channels to use (subset of color map), otherwise all of the labels will be used' - ) - parser.add_argument( - '--use-unknown-class', - type=str_to_bool, - default=True, - help='Use unknown class channel' - ) - parser.add_argument( - '--out', - required=False, - type=str, - help='The filename the output file (json), otherwise the output will be written to stdout.' - ) - return parser + parser = argparse.ArgumentParser() + parser.add_argument( + '--tfrecord-paths', + required=True, + type=str, + action='append', + help='The paths to the tf-records files to analyse.' + ) + parser.add_argument( + '--image-key', + required=False, + type=str, + help='The name of the image key to do the class weights on.' + ) + parser.add_argument( + '--color-map', + required=True, + type=str, + help='The color-map filename.' + ) + parser.add_argument( + '--channels', + type=str_to_list, + help='The channels to use (subset of color map), otherwise all of the labels will be used' + ) + parser.add_argument( + '--use-unknown-class', + type=str_to_bool, + default=True, + help='Use unknown class channel' + ) + parser.add_argument( + '--out', + required=False, + type=str, + help='The filename the output file (json), otherwise the output will be written to stdout.' 
+ ) + return parser + def parse_args(argv=None): - parser = get_args_parser() - parsed_args = parser.parse_args(argv) - return parsed_args + parser = get_args_parser() + parsed_args = parser.parse_args(argv) + return parsed_args + def main(argv=None): - args = parse_args(argv) - color_map = parse_color_map(args.color_map) - class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map( - args.tfrecord_paths, - args.image_key, - color_map, - channels=args.channels, - use_unknown_class=args.use_unknown_class, - progress=True - ) - get_logger().info('class_weights: %s', class_weights_map) - json_str = json.dumps(class_weights_map, indent=2) - if args.out: - with file_io.FileIO(args.out, 'wb') as out_f: - out_f.write(json_str) - else: - print(json_str) + args = parse_args(argv) + color_map = parse_color_map(args.color_map) + class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map( + args.tfrecord_paths, + args.image_key, + color_map, + channels=args.channels, + use_unknown_class=args.use_unknown_class, + progress=True + ) + get_logger().info('class_weights: %s', class_weights_map) + json_str = json.dumps(class_weights_map, indent=2) + if args.out: + with file_io.FileIO(args.out, 'wb') as out_f: + out_f.write(json_str) + else: + print(json_str) + if __name__ == '__main__': - logging.basicConfig(level='INFO') - main() + logging.basicConfig(level='INFO') + main() diff --git a/sciencebeam_gym/tools/calculate_class_weights_test.py b/sciencebeam_gym/tools/calculate_class_weights_test.py index 5dd9935..6fcd14b 100644 --- a/sciencebeam_gym/tools/calculate_class_weights_test.py +++ b/sciencebeam_gym/tools/calculate_class_weights_test.py @@ -6,346 +6,359 @@ from io import BytesIO from backports.tempfile import TemporaryDirectory +import tensorflow as tf +import numpy as np +from PIL import Image + from sciencebeam_utils.utils.num import ( - assert_close, - assert_all_close + assert_close, + assert_all_close ) from 
sciencebeam_gym.utils.tfrecord import ( - dict_to_example, - write_examples_to_tfrecord + dict_to_example, + write_examples_to_tfrecord ) from sciencebeam_gym.tools.calculate_class_weights import ( - calculate_sample_frequencies, - iter_calculate_sample_frequencies, - calculate_median_class_weight, - calculate_median_weights_for_frequencies, - calculate_median_class_weights_for_tfrecord_paths_and_colors, - calculate_median_class_weights_for_tfrecord_paths_and_color_map, - calculate_efnet_weights_for_frequencies_by_label, - tf_calculate_efnet_weights_for_frequency_by_label + calculate_sample_frequencies, + iter_calculate_sample_frequencies, + calculate_median_class_weight, + calculate_median_weights_for_frequencies, + calculate_median_class_weights_for_tfrecord_paths_and_colors, + calculate_median_class_weights_for_tfrecord_paths_and_color_map, + calculate_efnet_weights_for_frequencies_by_label, + tf_calculate_efnet_weights_for_frequency_by_label ) -import tensorflow as tf -import numpy as np -from PIL import Image def color(i): - return (i, i, i) + return (i, i, i) + COLOR_0 = color(0) COLOR_1 = color(1) COLOR_2 = color(2) COLOR_3 = color(3) + def setup_module(): - logging.basicConfig(level='DEBUG') + logging.basicConfig(level='DEBUG') + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + class TestCalculateSampleFrequencies(object): - def test_should_return_zero_for_single_not_matching_color(self): - with tf.Session() as session: - assert session.run(calculate_sample_frequencies([[ - COLOR_0 - ]], [COLOR_1])) == [0.0] - - def test_should_return_one_for_single_matching_color(self): - with tf.Session() as session: - assert session.run(calculate_sample_frequencies([[ - COLOR_1 - ]], [COLOR_1])) == [1.0] - - def test_should_return_total_count_for_multiple_all_matching_color(self): - with tf.Session() as session: - assert session.run(calculate_sample_frequencies([[ - COLOR_1, COLOR_1, COLOR_1 - ]], [COLOR_1])) == [3.0] - - def 
test_should_return_total_count_for_multiple_mixed_color(self): - with tf.Session() as session: - assert session.run(calculate_sample_frequencies([[ - COLOR_1, COLOR_1, COLOR_2 - ]], [COLOR_1, COLOR_2])) == [2.0, 1.0] - - def test_should_include_unknown_class_count_if_enabled(self): - with tf.Session() as session: - assert session.run(calculate_sample_frequencies([[ - COLOR_1, COLOR_2, COLOR_3 - ]], [COLOR_1], use_unknown_class=True)) == [1.0, 2.0] + def test_should_return_zero_for_single_not_matching_color(self): + with tf.Session() as session: + assert session.run(calculate_sample_frequencies([[ + COLOR_0 + ]], [COLOR_1])) == [0.0] + + def test_should_return_one_for_single_matching_color(self): + with tf.Session() as session: + assert session.run(calculate_sample_frequencies([[ + COLOR_1 + ]], [COLOR_1])) == [1.0] + + def test_should_return_total_count_for_multiple_all_matching_color(self): + with tf.Session() as session: + assert session.run(calculate_sample_frequencies([[ + COLOR_1, COLOR_1, COLOR_1 + ]], [COLOR_1])) == [3.0] + + def test_should_return_total_count_for_multiple_mixed_color(self): + with tf.Session() as session: + assert session.run(calculate_sample_frequencies([[ + COLOR_1, COLOR_1, COLOR_2 + ]], [COLOR_1, COLOR_2])) == [2.0, 1.0] + + def test_should_include_unknown_class_count_if_enabled(self): + with tf.Session() as session: + assert session.run(calculate_sample_frequencies([[ + COLOR_1, COLOR_2, COLOR_3 + ]], [COLOR_1], use_unknown_class=True)) == [1.0, 2.0] + def encode_png(data): - out = BytesIO() - data = np.array(data, dtype=np.uint8) - image_size = data.shape[:-1] - get_logger().debug('data type: %s', data.dtype) - get_logger().debug('image_size: %s', image_size) - mode = 'RGB' - image = Image.fromarray(data, mode) - image.save(out, 'png') - image_bytes = out.getvalue() - return image_bytes + out = BytesIO() + data = np.array(data, dtype=np.uint8) + image_size = data.shape[:-1] + get_logger().debug('data type: %s', data.dtype) + 
get_logger().debug('image_size: %s', image_size) + mode = 'RGB' + image = Image.fromarray(data, mode) + image.save(out, 'png') + image_bytes = out.getvalue() + return image_bytes + class TestIterCalculateSampleFrequencies(object): - def test_should_return_zero_for_single_not_matching_color(self): - assert list(iter_calculate_sample_frequencies([ - [[ - COLOR_0 - ]] - ], [COLOR_1], image_shape=(1, 1, 3))) == [[0.0]] - - def test_should_infer_image_shape(self): - assert list(iter_calculate_sample_frequencies([ - [[ - COLOR_0 - ]] - ], [COLOR_1])) == [[0.0]] - - def test_should_include_unknown_class_if_enabled(self): - assert list(iter_calculate_sample_frequencies([ - [[ - COLOR_0 - ]] - ], [COLOR_1], image_shape=(1, 1, 3), use_unknown_class=True)) == [[0.0, 1.0]] - - def test_should_include_unknown_class_if_enabled_and_infer_shape(self): - assert list(iter_calculate_sample_frequencies([ - [[ - COLOR_0 - ]] - ], [COLOR_1], use_unknown_class=True)) == [[0.0, 1.0]] - - def test_should_return_total_count_for_multiple_mixed_color(self): - assert list(iter_calculate_sample_frequencies([ - [[ - COLOR_0, COLOR_0, COLOR_0 - ]], [[ - COLOR_0, COLOR_1, COLOR_2 - ]], [[ - COLOR_1, COLOR_1, COLOR_2 - ]] - ], [COLOR_1, COLOR_2])) == [ - [0.0, 0.0], - [1.0, 1.0], - [2.0, 1.0] - ] - - def test_should_decode_png(self): - assert list(iter_calculate_sample_frequencies([ - encode_png([[ - COLOR_1 - ]]) - ], [COLOR_1], image_shape=(1, 1, 3), image_format='png')) == [[1.0]] - - def test_should_infer_shape_when_decoding_png(self): - assert list(iter_calculate_sample_frequencies([ - encode_png([[ - COLOR_1 - ]]) - ], [COLOR_1], image_format='png')) == [[1.0]] - - def test_should_infer_shape_when_decoding_png_and_include_unknown_class(self): - assert list(iter_calculate_sample_frequencies([ - encode_png([[ - COLOR_1, COLOR_2, COLOR_3 - ]]) - ], [COLOR_1], image_format='png', use_unknown_class=True)) == [[1.0, 2.0]] + def test_should_return_zero_for_single_not_matching_color(self): + assert 
list(iter_calculate_sample_frequencies([ + [[ + COLOR_0 + ]] + ], [COLOR_1], image_shape=(1, 1, 3))) == [[0.0]] + + def test_should_infer_image_shape(self): + assert list(iter_calculate_sample_frequencies([ + [[ + COLOR_0 + ]] + ], [COLOR_1])) == [[0.0]] + + def test_should_include_unknown_class_if_enabled(self): + assert list(iter_calculate_sample_frequencies([ + [[ + COLOR_0 + ]] + ], [COLOR_1], image_shape=(1, 1, 3), use_unknown_class=True)) == [[0.0, 1.0]] + + def test_should_include_unknown_class_if_enabled_and_infer_shape(self): + assert list(iter_calculate_sample_frequencies([ + [[ + COLOR_0 + ]] + ], [COLOR_1], use_unknown_class=True)) == [[0.0, 1.0]] + + def test_should_return_total_count_for_multiple_mixed_color(self): + assert list(iter_calculate_sample_frequencies([ + [[ + COLOR_0, COLOR_0, COLOR_0 + ]], [[ + COLOR_0, COLOR_1, COLOR_2 + ]], [[ + COLOR_1, COLOR_1, COLOR_2 + ]] + ], [COLOR_1, COLOR_2])) == [ + [0.0, 0.0], + [1.0, 1.0], + [2.0, 1.0] + ] + + def test_should_decode_png(self): + assert list(iter_calculate_sample_frequencies([ + encode_png([[ + COLOR_1 + ]]) + ], [COLOR_1], image_shape=(1, 1, 3), image_format='png')) == [[1.0]] + + def test_should_infer_shape_when_decoding_png(self): + assert list(iter_calculate_sample_frequencies([ + encode_png([[ + COLOR_1 + ]]) + ], [COLOR_1], image_format='png')) == [[1.0]] + + def test_should_infer_shape_when_decoding_png_and_include_unknown_class(self): + assert list(iter_calculate_sample_frequencies([ + encode_png([[ + COLOR_1, COLOR_2, COLOR_3 + ]]) + ], [COLOR_1], image_format='png', use_unknown_class=True)) == [[1.0, 2.0]] + class TestTfCalculateEfnetForFrequencyByLabel(object): - def test_should_return_same_value_for_classes_with_same_frequencies(self): - with tf.Graph().as_default(): - with tf.Session(): - frequencies = [1, 1] - result = tf_calculate_efnet_weights_for_frequency_by_label(frequencies).eval() + def test_should_return_same_value_for_classes_with_same_frequencies(self): + with 
tf.Graph().as_default(): + with tf.Session(): + frequencies = [1, 1] + result = tf_calculate_efnet_weights_for_frequency_by_label(frequencies).eval() + assert result[0] == result[1] + + def test_should_return_higher_value_for_less_frequent_occuring_class(self): + with tf.Graph().as_default(): + with tf.Session(): + frequencies = [2, 1] + result = tf_calculate_efnet_weights_for_frequency_by_label(frequencies).eval() + assert result[0] < result[1] + + def test_should_return_zero_value_for_not_occuring_class(self): + with tf.Graph().as_default(): + with tf.Session(): + frequencies = [1, 0] + result = tf_calculate_efnet_weights_for_frequency_by_label(frequencies).eval() + assert result[-1] == 0.0 + + +class TestCalculateEfnetForFrequenciesByLabel(object): + def test_should_return_same_value_for_classes_with_same_frequencies(self): + frequencies = [ + [0, 1], + [0, 1] + ] + result = calculate_efnet_weights_for_frequencies_by_label(frequencies) assert result[0] == result[1] - def test_should_return_higher_value_for_less_frequent_occuring_class(self): - with tf.Graph().as_default(): - with tf.Session(): - frequencies = [2, 1] - result = tf_calculate_efnet_weights_for_frequency_by_label(frequencies).eval() + def test_should_return_higher_value_for_less_frequent_occuring_class(self): + frequencies = [ + [1, 1], + [0, 1] + ] + result = calculate_efnet_weights_for_frequencies_by_label(frequencies) assert result[0] < result[1] - def test_should_return_zero_value_for_not_occuring_class(self): - with tf.Graph().as_default(): - with tf.Session(): - frequencies = [1, 0] - result = tf_calculate_efnet_weights_for_frequency_by_label(frequencies).eval() + def test_should_return_zero_value_for_not_occuring_class(self): + frequencies = [ + [1, 1], + [0, 0] + ] + result = calculate_efnet_weights_for_frequencies_by_label(frequencies) assert result[-1] == 0.0 -class TestCalculateEfnetForFrequenciesByLabel(object): - def test_should_return_same_value_for_classes_with_same_frequencies(self): 
- frequencies = [ - [0, 1], - [0, 1] - ] - result = calculate_efnet_weights_for_frequencies_by_label(frequencies) - assert result[0] == result[1] - - def test_should_return_higher_value_for_less_frequent_occuring_class(self): - frequencies = [ - [1, 1], - [0, 1] - ] - result = calculate_efnet_weights_for_frequencies_by_label(frequencies) - assert result[0] < result[1] - - def test_should_return_zero_value_for_not_occuring_class(self): - frequencies = [ - [1, 1], - [0, 0] - ] - result = calculate_efnet_weights_for_frequencies_by_label(frequencies) - assert result[-1] == 0.0 class TestCalculateMedianClassWeight(object): - def test_should_return_median_frequency_balanced_for_same_frequencies(self): - assert calculate_median_class_weight([3, 3, 3]) == 1 / 3 + def test_should_return_median_frequency_balanced_for_same_frequencies(self): + assert calculate_median_class_weight([3, 3, 3]) == 1 / 3 + + def test_should_return_median_frequence_balanced_for_different_frequencies(self): + assert calculate_median_class_weight([1, 3, 5]) == 1 / 3 - def test_should_return_median_frequence_balanced_for_different_frequencies(self): - assert calculate_median_class_weight([1, 3, 5]) == 1 / 3 + def test_should_return_zero_for_all_zero_frequencies(self): + assert calculate_median_class_weight([0, 0, 0]) == 0.0 - def test_should_return_zero_for_all_zero_frequencies(self): - assert calculate_median_class_weight([0, 0, 0]) == 0.0 class TestCalculateWeightsForFrequencies(object): - def test_should_return_one_for_single_class(self): - assert calculate_median_weights_for_frequencies([ - [3, 3, 3] - ]) == [1.0] - - def test_should_return_50p_for_classes_with_same_frequencies(self): - assert calculate_median_weights_for_frequencies([ - [3, 3, 3], - [3, 3, 3] - ]) == [0.5, 0.5] - - def test_should_return_higher_value_for_less_frequent_occuring_class(self): - frequencies = [ - [1, 1], - [1, 1], - [0, 1] - ] - result = calculate_median_weights_for_frequencies(frequencies) - 
get_logger().debug('result: %s', result) - assert_close(sum(result), 1.0) - assert_all_close(result, [0.25, 0.25, 0.5], atol=0.001) - - def test_should_return_zero_value_for_not_occuring_class(self): - frequencies = [ - [1, 1], - [1, 1], - [0, 0] - ] - result = calculate_median_weights_for_frequencies(frequencies) - get_logger().debug('result: %s', result) - assert_close(sum(result), 1.0) - assert_all_close(result, [0.5, 0.5, 0.0], atol=0.001) + def test_should_return_one_for_single_class(self): + assert calculate_median_weights_for_frequencies([ + [3, 3, 3] + ]) == [1.0] + + def test_should_return_50p_for_classes_with_same_frequencies(self): + assert calculate_median_weights_for_frequencies([ + [3, 3, 3], + [3, 3, 3] + ]) == [0.5, 0.5] + + def test_should_return_higher_value_for_less_frequent_occuring_class(self): + frequencies = [ + [1, 1], + [1, 1], + [0, 1] + ] + result = calculate_median_weights_for_frequencies(frequencies) + get_logger().debug('result: %s', result) + assert_close(sum(result), 1.0) + assert_all_close(result, [0.25, 0.25, 0.5], atol=0.001) + + def test_should_return_zero_value_for_not_occuring_class(self): + frequencies = [ + [1, 1], + [1, 1], + [0, 0] + ] + result = calculate_median_weights_for_frequencies(frequencies) + get_logger().debug('result: %s', result) + assert_close(sum(result), 1.0) + assert_all_close(result, [0.5, 0.5, 0.0], atol=0.001) + class TestCalculateMedianClassWeightsForFfrecordPathsAndColors(object): - def test_should_calculate_median_class_weights_for_single_image_and_single_color(self): - with TemporaryDirectory() as path: - tfrecord_filename = os.path.join(path, 'data.tfrecord') - get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename) - write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({ - 'image': encode_png([[ - COLOR_1 - ]]) - })]) - class_weights = calculate_median_class_weights_for_tfrecord_paths_and_colors( - [tfrecord_filename], 'image', [COLOR_1] - ) - assert class_weights == 
[1.0] - - def test_should_calculate_median_class_weights_for_multiple_image_and_multiple_images(self): - with TemporaryDirectory() as path: - tfrecord_filename = os.path.join(path, 'data.tfrecord') - get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename) - write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({ - 'image': encode_png([[ - COLOR_0, COLOR_1, COLOR_2 - ]]) - }), dict_to_example({ - 'image': encode_png([[ - COLOR_1, COLOR_2, COLOR_3 - ]]) - })]) - class_weights = calculate_median_class_weights_for_tfrecord_paths_and_colors( - [tfrecord_filename], 'image', [COLOR_1, COLOR_2, COLOR_3] - ) - assert class_weights == [0.25, 0.25, 0.5] - - def test_should_return_zero_for_non_occuring_class(self): - with TemporaryDirectory() as path: - tfrecord_filename = os.path.join(path, 'data.tfrecord') - get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename) - write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({ - 'image': encode_png([[ - COLOR_1 - ]]) - })]) - class_weights = calculate_median_class_weights_for_tfrecord_paths_and_colors( - [tfrecord_filename], 'image', [COLOR_1, COLOR_2] - ) - assert class_weights == [1.0, 0.0] + def test_should_calculate_median_class_weights_for_single_image_and_single_color(self): + with TemporaryDirectory() as path: + tfrecord_filename = os.path.join(path, 'data.tfrecord') + get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename) + write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({ + 'image': encode_png([[ + COLOR_1 + ]]) + })]) + class_weights = calculate_median_class_weights_for_tfrecord_paths_and_colors( + [tfrecord_filename], 'image', [COLOR_1] + ) + assert class_weights == [1.0] + + def test_should_calculate_median_class_weights_for_multiple_image_and_multiple_images(self): + with TemporaryDirectory() as path: + tfrecord_filename = os.path.join(path, 'data.tfrecord') + get_logger().debug('writing to test tfrecord_filename: 
%s', tfrecord_filename) + write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({ + 'image': encode_png([[ + COLOR_0, COLOR_1, COLOR_2 + ]]) + }), dict_to_example({ + 'image': encode_png([[ + COLOR_1, COLOR_2, COLOR_3 + ]]) + })]) + class_weights = calculate_median_class_weights_for_tfrecord_paths_and_colors( + [tfrecord_filename], 'image', [COLOR_1, COLOR_2, COLOR_3] + ) + assert class_weights == [0.25, 0.25, 0.5] + + def test_should_return_zero_for_non_occuring_class(self): + with TemporaryDirectory() as path: + tfrecord_filename = os.path.join(path, 'data.tfrecord') + get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename) + write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({ + 'image': encode_png([[ + COLOR_1 + ]]) + })]) + class_weights = calculate_median_class_weights_for_tfrecord_paths_and_colors( + [tfrecord_filename], 'image', [COLOR_1, COLOR_2] + ) + assert class_weights == [1.0, 0.0] + class TestCalculateMedianClassWeightsForFfrecordPathsAndColorMap(object): - def test_should_calculate_median_class_weights_for_single_image_and_single_color(self): - with TemporaryDirectory() as path: - tfrecord_filename = os.path.join(path, 'data.tfrecord') - get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename) - write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({ - 'image': encode_png([[ - COLOR_1, COLOR_2 - ]]) - })]) - class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map( - [tfrecord_filename], 'image', { - 'color1': COLOR_1, - 'color2': COLOR_2, - 'color3': COLOR_3 - }, - channels=['color1', 'color2'] - ) - assert class_weights_map == { - 'color1': 0.5, - 'color2': 0.5 - } - - def test_should_use_color_map_keys_as_channels_by_default(self): - with TemporaryDirectory() as path: - tfrecord_filename = os.path.join(path, 'data.tfrecord') - get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename) - 
write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({ - 'image': encode_png([[ - COLOR_1, COLOR_2 - ]]) - })]) - class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map( - [tfrecord_filename], 'image', { - 'color1': COLOR_1, - 'color2': COLOR_2 - } - ) - assert set(class_weights_map.keys()) == {'color1', 'color2'} - - def test_should_include_unknown_class_if_enabled(self): - with TemporaryDirectory() as path: - tfrecord_filename = os.path.join(path, 'data.tfrecord') - get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename) - write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({ - 'image': encode_png([[ - COLOR_0, COLOR_1, COLOR_2, COLOR_3 - ]]) - })]) - class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map( - [tfrecord_filename], 'image', { - 'color1': COLOR_1, - 'color2': COLOR_2 - }, - use_unknown_class=True, - unknown_class_label='unknown' - ) - assert set(class_weights_map.keys()) == {'color1', 'color2', 'unknown'} + def test_should_calculate_median_class_weights_for_single_image_and_single_color(self): + with TemporaryDirectory() as path: + tfrecord_filename = os.path.join(path, 'data.tfrecord') + get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename) + write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({ + 'image': encode_png([[ + COLOR_1, COLOR_2 + ]]) + })]) + class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map( + [tfrecord_filename], 'image', { + 'color1': COLOR_1, + 'color2': COLOR_2, + 'color3': COLOR_3 + }, + channels=['color1', 'color2'] + ) + assert class_weights_map == { + 'color1': 0.5, + 'color2': 0.5 + } + + def test_should_use_color_map_keys_as_channels_by_default(self): + with TemporaryDirectory() as path: + tfrecord_filename = os.path.join(path, 'data.tfrecord') + get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename) + 
write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({ + 'image': encode_png([[ + COLOR_1, COLOR_2 + ]]) + })]) + class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map( + [tfrecord_filename], 'image', { + 'color1': COLOR_1, + 'color2': COLOR_2 + } + ) + assert set(class_weights_map.keys()) == {'color1', 'color2'} + + def test_should_include_unknown_class_if_enabled(self): + with TemporaryDirectory() as path: + tfrecord_filename = os.path.join(path, 'data.tfrecord') + get_logger().debug('writing to test tfrecord_filename: %s', tfrecord_filename) + write_examples_to_tfrecord(tfrecord_filename, [dict_to_example({ + 'image': encode_png([[ + COLOR_0, COLOR_1, COLOR_2, COLOR_3 + ]]) + })]) + class_weights_map = calculate_median_class_weights_for_tfrecord_paths_and_color_map( + [tfrecord_filename], 'image', { + 'color1': COLOR_1, + 'color2': COLOR_2 + }, + use_unknown_class=True, + unknown_class_label='unknown' + ) + assert set(class_weights_map.keys()) == {'color1', 'color2', 'unknown'} diff --git a/sciencebeam_gym/tools/colorize_image.py b/sciencebeam_gym/tools/colorize_image.py index 438d7fa..8306996 100644 --- a/sciencebeam_gym/tools/colorize_image.py +++ b/sciencebeam_gym/tools/colorize_image.py @@ -9,93 +9,101 @@ from six.moves.configparser import ConfigParser from sciencebeam_gym.utils.tf import FileIO + def get_args_parser(): - parser = argparse.ArgumentParser() - parser.add_argument( - '--color_map', - default='color_map.conf', - type=str, - help='The path to the color map configuration.' - ) - parser.add_argument( - '--input_image', - required=True, - type=str, - help='The path to the input image.' - ) - parser.add_argument( - '--output_image', - required=False, - type=str, - help='The path to the output image.' - ) - return parser + parser = argparse.ArgumentParser() + parser.add_argument( + '--color_map', + default='color_map.conf', + type=str, + help='The path to the color map configuration.' 
+ ) + parser.add_argument( + '--input_image', + required=True, + type=str, + help='The path to the input image.' + ) + parser.add_argument( + '--output_image', + required=False, + type=str, + help='The path to the output image.' + ) + return parser + def parse_args(argv=None): - parser = get_args_parser() - parsed_args, _ = parser.parse_known_args(argv) - return parsed_args + parser = get_args_parser() + parsed_args, _ = parser.parse_known_args(argv) + return parsed_args + def parse_color_map_from_configparser(color_map_config): - num_pattern = re.compile(r'(\d+)') - rgb_pattern = re.compile(r'\((\d+),(\d+),(\d+)\)') - - def parse_color(s): - m = num_pattern.match(s) - if m: - x = int(m.group(1)) - return (x, x, x) - else: - m = rgb_pattern.match(s) - if m: - return (int(m.group(1)), int(m.group(2)), int(m.group(3))) - raise Exception('invalid color value: {}'.format(s)) - - color_map = dict() - for k, v in color_map_config.items('color_map'): - color_map[parse_color(k)] = parse_color(v) - return color_map + num_pattern = re.compile(r'(\d+)') + rgb_pattern = re.compile(r'\((\d+),(\d+),(\d+)\)') + + def parse_color(s): + m = num_pattern.match(s) + if m: + x = int(m.group(1)) + return (x, x, x) + else: + m = rgb_pattern.match(s) + if m: + return (int(m.group(1)), int(m.group(2)), int(m.group(3))) + raise Exception('invalid color value: {}'.format(s)) + + color_map = dict() + for k, v in color_map_config.items('color_map'): + color_map[parse_color(k)] = parse_color(v) + return color_map + def parse_color_map_from_file(f): - color_map_config = ConfigParser() - color_map_config.readfp(f) - return parse_color_map_from_configparser(color_map_config) + color_map_config = ConfigParser() + color_map_config.readfp(f) + return parse_color_map_from_configparser(color_map_config) + def parse_color_map(f): - return parse_color_map_from_file(f) + return parse_color_map_from_file(f) + def map_colors(img, color_map): - if color_map is None or len(color_map) == 0: + if color_map is 
None or len(color_map) == 0: + return img + original_data = img.getdata() + mapped_data = [ + color_map.get(color, color) + for color in original_data + ] + img.putdata(mapped_data) return img - original_data = img.getdata() - mapped_data = [ - color_map.get(color, color) - for color in original_data - ] - img.putdata(mapped_data) - return img + def main(): - from PIL import Image + from PIL import Image + + logger = logging.getLogger(__name__) + args = parse_args() - logger = logging.getLogger(__name__) - args = parse_args() + with FileIO(args.color_map, 'r') as config_f: + color_map = parse_color_map(config_f) - with FileIO(args.color_map, 'r') as config_f: - color_map = parse_color_map(config_f) + logger.info('read %s color mappings', len(color_map)) - logger.info('read {} color mappings'.format(len(color_map))) + with FileIO(args.input_image, 'rb') as input_f: + image_bytes = input_f.read() + img = Image.open(io.BytesIO(image_bytes)).convert('RGB') - with FileIO(args.input_image, 'rb') as input_f: - image_bytes = input_f.read() - img = Image.open(io.BytesIO(image_bytes)).convert('RGB') + img = map_colors(img, color_map) - img = map_colors(img, color_map) + with FileIO(args.output_image, 'wb') as output_f: + img.save(output_f, 'png') - with FileIO(args.output_image, 'wb') as output_f: - img.save(output_f, 'png') if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=logging.INFO) - main() + main() diff --git a/sciencebeam_gym/tools/inspect_tfrecords.py b/sciencebeam_gym/tools/inspect_tfrecords.py index c3373b1..19416e8 100644 --- a/sciencebeam_gym/tools/inspect_tfrecords.py +++ b/sciencebeam_gym/tools/inspect_tfrecords.py @@ -7,94 +7,100 @@ import tensorflow as tf from tensorflow.python.lib.io import file_io from sciencebeam_gym.utils.tf import ( - FileIO + FileIO ) + def get_args_parser(): - parser = argparse.ArgumentParser() - parser.add_argument( - '--records_paths', - required=True, - type=str, - action='append', - 
help='The paths to the tf-records files to inspect.' - ) - parser.add_argument( - '--inspect_key', - required=False, - type=str, - help='The name of the key to further inspect.' - ) - parser.add_argument( - '--extract_dir', - required=False, - default=".", - type=str, - help='The directory to extract to.' - ) - parser.add_argument( - '--extract_image', - required=False, - action='append', - type=str, - help='The name of the key to extract as an image.' - ) - return parser + parser = argparse.ArgumentParser() + parser.add_argument( + '--records_paths', + required=True, + type=str, + action='append', + help='The paths to the tf-records files to inspect.' + ) + parser.add_argument( + '--inspect_key', + required=False, + type=str, + help='The name of the key to further inspect.' + ) + parser.add_argument( + '--extract_dir', + required=False, + default=".", + type=str, + help='The directory to extract to.' + ) + parser.add_argument( + '--extract_image', + required=False, + action='append', + type=str, + help='The name of the key to extract as an image.' 
+ ) + return parser + def parse_args(argv=None): - parser = get_args_parser() - parsed_args, _ = parser.parse_known_args(argv) - return parsed_args + parser = get_args_parser() + parsed_args, _ = parser.parse_known_args(argv) + return parsed_args + def get_matching_files_for_paths(paths): - files = [] - for path in paths: - files.extend(file_io.get_matching_files(path)) - # logging.info('files: %s (%s)', files, paths) - return files + files = [] + for path in paths: + files.extend(file_io.get_matching_files(path)) + # logging.info('files: %s (%s)', files, paths) + return files + def main(): - args = parse_args() - - files = get_matching_files_for_paths(args.records_paths) - - if args.extract_image: - file_io.recursive_create_dir(args.extract_dir) - - total_count = 0 - for f in files: - options = None - if f.endswith('.gz'): - options = tf.python_io.TFRecordOptions( - compression_type=tf.python_io.TFRecordCompressionType.GZIP) - print("file:", f) - count = 0 - for i, example in enumerate(tf.python_io.tf_record_iterator(f, options=options)): - result = tf.train.Example.FromString(example) - if i == 0: - print(" features:", result.features.feature.keys()) - if args.inspect_key: - print(" first value of {}:\n {}".format( - args.inspect_key, - result.features.feature.get(args.inspect_key).bytes_list.value[0] - )) - if args.extract_image: - for extract_image_key in args.extract_image: - image_bytes = result.features.feature.get(extract_image_key).bytes_list.value[0] - print(" image size %d bytes (%s)" % (len(image_bytes), type(image_bytes))) - image_filename = os.path.join( - args.extract_dir, - '{}-{}-{}.png'.format(os.path.basename(f), extract_image_key, i) - ) - print(" extracting image to {}".format(image_filename)) - with FileIO(image_filename, 'wb') as image_f: - image_f.write(image_bytes) - count += 1 - print(" found {} records".format(count)) - total_count += count - print("found total of {} records in {} files".format(total_count, len(files))) + args = 
parse_args() + + files = get_matching_files_for_paths(args.records_paths) + + if args.extract_image: + file_io.recursive_create_dir(args.extract_dir) + + total_count = 0 + for f in files: + options = None + if f.endswith('.gz'): + options = tf.python_io.TFRecordOptions( + compression_type=tf.python_io.TFRecordCompressionType.GZIP) + print("file:", f) + count = 0 + for i, example in enumerate(tf.python_io.tf_record_iterator(f, options=options)): + result = tf.train.Example.FromString(example) # pylint: disable=no-member + if i == 0: + print(" features:", result.features.feature.keys()) + if args.inspect_key: + print(" first value of {}:\n {}".format( + args.inspect_key, + result.features.feature.get(args.inspect_key).bytes_list.value[0] + )) + if args.extract_image: + for extract_image_key in args.extract_image: + image_bytes = result.features.feature.get( + extract_image_key).bytes_list.value[0] + print(" image size %d bytes (%s)" % (len(image_bytes), type(image_bytes))) + image_filename = os.path.join( + args.extract_dir, + '{}-{}-{}.png'.format(os.path.basename(f), extract_image_key, i) + ) + print(" extracting image to {}".format(image_filename)) + with FileIO(image_filename, 'wb') as image_f: + image_f.write(image_bytes) + count += 1 + print(" found {} records".format(count)) + total_count += count + print("found total of {} records in {} files".format(total_count, len(files))) + if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=logging.INFO) - main() + main() diff --git a/sciencebeam_gym/tools/resize_image.py b/sciencebeam_gym/tools/resize_image.py index d2f250f..d407a28 100644 --- a/sciencebeam_gym/tools/resize_image.py +++ b/sciencebeam_gym/tools/resize_image.py @@ -3,59 +3,62 @@ from __future__ import absolute_import import argparse import io -from six.moves.configparser import ConfigParser - from PIL import Image from sciencebeam_gym.utils.tf import FileIO + def get_args_parser(): - parser = 
argparse.ArgumentParser() - parser.add_argument( - '--image_width', - type=int, - required=True, - help='Resize images to the specified width') - parser.add_argument( - '--image_height', - type=int, - required=True, - help='Resize images to the specified height') - parser.add_argument( - '--input_image', - required=True, - type=str, - help='The path to the input image.' - ) - parser.add_argument( - '--output_image', - required=False, - type=str, - help='The path to the output image.' - ) - return parser + parser = argparse.ArgumentParser() + parser.add_argument( + '--image_width', + type=int, + required=True, + help='Resize images to the specified width') + parser.add_argument( + '--image_height', + type=int, + required=True, + help='Resize images to the specified height') + parser.add_argument( + '--input_image', + required=True, + type=str, + help='The path to the input image.' + ) + parser.add_argument( + '--output_image', + required=False, + type=str, + help='The path to the output image.' 
+ ) + return parser + def parse_args(argv=None): - parser = get_args_parser() - parsed_args, _ = parser.parse_known_args(argv) - return parsed_args + parser = get_args_parser() + parsed_args, _ = parser.parse_known_args(argv) + return parsed_args + def image_resize_bicubic(image, size): - return image.resize(size, Image.BICUBIC) + return image.resize(size, Image.BICUBIC) + def main(): - args = parse_args() + args = parse_args() + + image_size = (args.image_width, args.image_height) - image_size = (args.image_width, args.image_height) + with FileIO(args.input_image, 'rb') as input_f: + image_bytes = input_f.read() + image = Image.open(io.BytesIO(image_bytes)).convert('RGB') - with FileIO(args.input_image, 'rb') as input_f: - image_bytes = input_f.read() - image = Image.open(io.BytesIO(image_bytes)).convert('RGB') + image = image_resize_bicubic(image, image_size) - image = image_resize_bicubic(image, image_size) + with FileIO(args.output_image, 'wb') as output_f: + image.save(output_f, 'png') - with FileIO(args.output_image, 'wb') as output_f: - image.save(output_f, 'png') if __name__ == "__main__": - main() + main() diff --git a/sciencebeam_gym/trainer/checkpoint.py b/sciencebeam_gym/trainer/checkpoint.py index 1c4fd80..a849858 100644 --- a/sciencebeam_gym/trainer/checkpoint.py +++ b/sciencebeam_gym/trainer/checkpoint.py @@ -3,38 +3,40 @@ import logging import tensorflow as tf from sciencebeam_gym.trainer.models.pix2pix.pix2pix_model import ( - batch_dimensions_to_most_likely_colors_list + batch_dimensions_to_most_likely_colors_list ) from sciencebeam_gym.inference_model import ( - InferenceModel + InferenceModel ) + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def load_last_checkpoint_as_inference_model(model, checkpoint_path, session=None): - if session is None: - session = tf.get_default_session() - assert session is not None - last_checkpoint = tf.train.latest_checkpoint(checkpoint_path) - get_logger().info( - 
'last_checkpoint: %s (%s)', last_checkpoint, checkpoint_path - ) - tensors = model.build_predict_graph() - inputs_tensor = tensors.inputs['image'] - outputs_tensor = tensors.pred - if model.use_separate_channels: - outputs_tensor = batch_dimensions_to_most_likely_colors_list( - outputs_tensor, - model.dimension_colors_with_unknown + if session is None: + session = tf.get_default_session() + assert session is not None + last_checkpoint = tf.train.latest_checkpoint(checkpoint_path) + get_logger().info( + 'last_checkpoint: %s (%s)', last_checkpoint, checkpoint_path + ) + tensors = model.build_predict_graph() + inputs_tensor = tensors.inputs['image'] + outputs_tensor = tensors.pred + if model.use_separate_channels: + outputs_tensor = batch_dimensions_to_most_likely_colors_list( + outputs_tensor, + model.dimension_colors_with_unknown + ) + saver = tf.train.Saver() + saver.restore(session, last_checkpoint) + labels = tf.constant(model.dimension_labels_with_unknown) + colors = tf.constant(model.dimension_colors_with_unknown) + inference_model = InferenceModel( + inputs_tensor, outputs_tensor, + labels, colors ) - saver = tf.train.Saver() - saver.restore(session, last_checkpoint) - labels = tf.constant(model.dimension_labels_with_unknown) - colors = tf.constant(model.dimension_colors_with_unknown) - inference_model = InferenceModel( - inputs_tensor, outputs_tensor, - labels, colors - ) - return inference_model + return inference_model diff --git a/sciencebeam_gym/trainer/data/examples.py b/sciencebeam_gym/trainer/data/examples.py index 0c13b04..d71c650 100644 --- a/sciencebeam_gym/trainer/data/examples.py +++ b/sciencebeam_gym/trainer/data/examples.py @@ -2,125 +2,133 @@ import logging from functools import partial import tensorflow as tf -from tensorflow.python.lib.io import file_io # pylint: disable=E0611 +from tensorflow.python.lib.io import file_io # pylint: disable=E0611 from sciencebeam_utils.utils.collection import ( - extend_dict + extend_dict ) from 
sciencebeam_gym.model_utils.channels import ( - calculate_color_masks + calculate_color_masks ) try: - # TensorFlow 1.4+ - tf_data = tf.data + # TensorFlow 1.4+ + tf_data = tf.data except AttributeError: - tf_data = tf.contrib.data + tf_data = tf.contrib.data Dataset = tf_data.Dataset TFRecordDataset = tf_data.TFRecordDataset DEFAULT_FEATURE_MAP = { - 'input_uri': tf.FixedLenFeature( - shape=[], dtype=tf.string, default_value=[''] - ), - 'annotation_uri': tf.FixedLenFeature( - shape=[], dtype=tf.string, default_value=[''] - ), - 'input_image': tf.FixedLenFeature( - shape=[], dtype=tf.string - ), - 'annotation_image': tf.FixedLenFeature( - shape=[], dtype=tf.string - ) + 'input_uri': tf.FixedLenFeature( + shape=[], dtype=tf.string, default_value=[''] + ), + 'annotation_uri': tf.FixedLenFeature( + shape=[], dtype=tf.string, default_value=[''] + ), + 'input_image': tf.FixedLenFeature( + shape=[], dtype=tf.string + ), + 'annotation_image': tf.FixedLenFeature( + shape=[], dtype=tf.string + ) } PAGE_NO_FEATURE = { - 'page_no': tf.FixedLenFeature( - shape=[], dtype=tf.int64 - ) + 'page_no': tf.FixedLenFeature( + shape=[], dtype=tf.int64 + ) } + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def get_matching_files(paths): - files = [] - for e in paths: - for path in e.split(','): - files.extend(file_io.get_matching_files(path)) - return files + files = [] + for e in paths: + for path in e.split(','): + files.extend(file_io.get_matching_files(path)) + return files + def parse_example(example, feature_map=None): - if feature_map is None: - feature_map = DEFAULT_FEATURE_MAP - get_logger().info('example: %s', example) - return tf.parse_single_example(example, features=feature_map) + if feature_map is None: + feature_map = DEFAULT_FEATURE_MAP + get_logger().info('example: %s', example) + return tf.parse_single_example(example, features=feature_map) # Workaround for Tensorflow 1.2 not supporting dicts + + class 
MapKeysTracker(object): - def __init__(self): - self.keys = None + def __init__(self): + self.keys = None - def wrap(self, fn): - def wrapper(x): - x = fn(x) - if self.keys is not None: - get_logger().warn('keys already set: %s', self.keys) - self.keys = sorted(x.keys()) - return [x[k] for k in self.keys] - return wrapper + def wrap(self, fn): + def wrapper(x): + x = fn(x) + if self.keys is not None: + get_logger().warn('keys already set: %s', self.keys) + self.keys = sorted(x.keys()) + return [x[k] for k in self.keys] + return wrapper + + def unwrap(self, result): + return {k: v for k, v in zip(self.keys, result)} - def unwrap(self, result): - return {k: v for k, v in zip(self.keys, result)} def page_no_is_within(page_no, page_range): - get_logger().debug('page_no: %s, page_range: %s', page_no, page_range) - return tf.logical_and(page_no >= page_range[0], page_no <= page_range[1]) + get_logger().debug('page_no: %s, page_range: %s', page_no, page_range) + return tf.logical_and(page_no >= page_range[0], page_no <= page_range[1]) + def image_contains_any_of_the_colors(image, colors): - decoded_image = tf.image.decode_png(image, channels=3) - color_masks = calculate_color_masks(decoded_image, colors) - return tf.reduce_any([ - tf.reduce_any(color_mask >= 0.5) - for color_mask in color_masks - ]) + decoded_image = tf.image.decode_png(image, channels=3) + color_masks = calculate_color_masks(decoded_image, colors) + return tf.reduce_any([ + tf.reduce_any(color_mask >= 0.5) + for color_mask in color_masks + ]) + def read_examples( - filenames, - shuffle, - num_epochs=None, - page_range=None, - channel_colors=None): - - # Convert num_epochs == 0 -> num_epochs is None, if necessary - num_epochs = num_epochs or None - - feature_map = DEFAULT_FEATURE_MAP - if page_range is not None: - feature_map = extend_dict(feature_map, PAGE_NO_FEATURE) - - map_keys_tracker = MapKeysTracker() - - dataset = TFRecordDataset(filenames, compression_type='GZIP') - dataset = 
dataset.map(map_keys_tracker.wrap( - partial(parse_example, feature_map=feature_map) - )) - if page_range is not None: - dataset = dataset.filter(lambda *x: page_no_is_within( - map_keys_tracker.unwrap(x)['page_no'], - page_range - )) - if channel_colors is not None: - dataset = dataset.filter(lambda *x: image_contains_any_of_the_colors( - map_keys_tracker.unwrap(x)['annotation_image'], - channel_colors - )) - if shuffle: - dataset = dataset.shuffle(buffer_size=10000) - dataset = dataset.repeat(num_epochs) + filenames, + shuffle, + num_epochs=None, + page_range=None, + channel_colors=None): + + # Convert num_epochs == 0 -> num_epochs is None, if necessary + num_epochs = num_epochs or None - return map_keys_tracker.unwrap( - dataset.make_one_shot_iterator().get_next() - ) + feature_map = DEFAULT_FEATURE_MAP + if page_range is not None: + feature_map = extend_dict(feature_map, PAGE_NO_FEATURE) + + map_keys_tracker = MapKeysTracker() + + dataset = TFRecordDataset(filenames, compression_type='GZIP') + dataset = dataset.map(map_keys_tracker.wrap( + partial(parse_example, feature_map=feature_map) + )) + if page_range is not None: + dataset = dataset.filter(lambda *x: page_no_is_within( + map_keys_tracker.unwrap(x)['page_no'], + page_range + )) + if channel_colors is not None: + dataset = dataset.filter(lambda *x: image_contains_any_of_the_colors( + map_keys_tracker.unwrap(x)['annotation_image'], + channel_colors + )) + if shuffle: + dataset = dataset.shuffle(buffer_size=10000) + dataset = dataset.repeat(num_epochs) + + return map_keys_tracker.unwrap( + dataset.make_one_shot_iterator().get_next() + ) diff --git a/sciencebeam_gym/trainer/data/examples_test.py b/sciencebeam_gym/trainer/data/examples_test.py index 7a2886a..cb99ef6 100644 --- a/sciencebeam_gym/trainer/data/examples_test.py +++ b/sciencebeam_gym/trainer/data/examples_test.py @@ -1,26 +1,24 @@ import logging from mock import patch -import pytest - import tensorflow as tf from sciencebeam_utils.utils.collection 
import ( - extend_dict + extend_dict ) from sciencebeam_gym.utils.tfrecord import ( - dict_to_example + dict_to_example ) from sciencebeam_gym.tools.calculate_class_weights_test import ( - encode_png + encode_png ) import sciencebeam_gym.trainer.data.examples as examples_module from sciencebeam_gym.trainer.data.examples import ( - read_examples, - tf_data + read_examples, + tf_data ) DATA_PATH = '.temp/data/*.tfrecord' @@ -28,78 +26,87 @@ DATA_PATH = '.temp/data/*.tfrecord' IMAGE_SHAPE = (5, 5) EXAMPLE_PROPS_1 = { - 'input_uri': 'input.png', - 'input_image': b'input image', - 'annotation_uri': 'annotation.png', - 'annotation_image': b'annotation image' + 'input_uri': 'input.png', + 'input_image': b'input image', + 'annotation_uri': 'annotation.png', + 'annotation_image': b'annotation image' } RECORD_1 = dict_to_example(EXAMPLE_PROPS_1).SerializeToString() + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def setup_module(): - logging.basicConfig(level='DEBUG') + logging.basicConfig(level='DEBUG') + def list_dataset(data, dtype): - data = tf.constant(data, dtype=dtype) - return tf_data.Dataset.from_tensor_slices(data) + data = tf.constant(data, dtype=dtype) + return tf_data.Dataset.from_tensor_slices(data) + def fetch_examples(session, examples_tensor, max_examples=1000): - try: - for _ in range(max_examples): - yield session.run([examples_tensor])[0] - except tf.errors.OutOfRangeError: - get_logger().debug('end of dataset') + try: + for _ in range(max_examples): + yield session.run([examples_tensor])[0] + except tf.errors.OutOfRangeError: + get_logger().debug('end of dataset') + def some_color(i): - return (i, i, i) + return (i, i, i) + def image_with_color(color, shape=IMAGE_SHAPE): - return encode_png([[color] * shape[1]] * shape[0]) + return encode_png([[color] * shape[1]] * shape[0]) # @pytest.mark.slow + + class TestReadExamples(object): - def test_should_read_single_example(self): - with 
patch.object(examples_module, 'TFRecordDataset') as TFRecordDataset: - with tf.Graph().as_default(): - TFRecordDataset.return_value = list_dataset([RECORD_1], tf.string) - examples = read_examples(DATA_PATH, shuffle=False) - TFRecordDataset.assert_called_with(DATA_PATH, compression_type='GZIP') - with tf.Session() as session: - next_example = session.run([examples])[0] - get_logger().info('next_example: %s', next_example) - assert next_example == EXAMPLE_PROPS_1 - - def test_should_filter_by_page_no(self): - with patch.object(examples_module, 'TFRecordDataset') as TFRecordDataset: - with tf.Graph().as_default(): - TFRecordDataset.return_value = list_dataset([ - dict_to_example(extend_dict(EXAMPLE_PROPS_1, page_no=page_no)).SerializeToString() - for page_no in [1, 2, 3, 4] - ], tf.string) - examples = read_examples(DATA_PATH, shuffle=False, num_epochs=1, page_range=(2, 3)) - TFRecordDataset.assert_called_with(DATA_PATH, compression_type='GZIP') - with tf.Session() as session: - assert [x['page_no'] for x in fetch_examples(session, examples)] == [2, 3] - - def test_should_filter_by_channel_colors(self): - with patch.object(examples_module, 'TFRecordDataset') as TFRecordDataset: - with tf.Graph().as_default(): - TFRecordDataset.return_value = list_dataset([ - dict_to_example(extend_dict( - EXAMPLE_PROPS_1, - page_no=page_no, - annotation_image=image_with_color(some_color(page_no)) - )).SerializeToString() - for page_no in [1, 2, 3, 4] - ], tf.string) - examples = read_examples( - DATA_PATH, shuffle=False, num_epochs=1, - page_range=(0, 100), - channel_colors=[some_color(i) for i in [2, 3]] - ) - TFRecordDataset.assert_called_with(DATA_PATH, compression_type='GZIP') - with tf.Session() as session: - assert [x['page_no'] for x in fetch_examples(session, examples)] == [2, 3] + def test_should_read_single_example(self): + with patch.object(examples_module, 'TFRecordDataset') as TFRecordDataset: + with tf.Graph().as_default(): + TFRecordDataset.return_value = 
list_dataset([RECORD_1], tf.string) + examples = read_examples(DATA_PATH, shuffle=False) + TFRecordDataset.assert_called_with(DATA_PATH, compression_type='GZIP') + with tf.Session() as session: + next_example = session.run([examples])[0] + get_logger().info('next_example: %s', next_example) + assert next_example == EXAMPLE_PROPS_1 + + def test_should_filter_by_page_no(self): + with patch.object(examples_module, 'TFRecordDataset') as TFRecordDataset: + with tf.Graph().as_default(): + TFRecordDataset.return_value = list_dataset([ + dict_to_example(extend_dict(EXAMPLE_PROPS_1, page_no=page_no) + ).SerializeToString() + for page_no in [1, 2, 3, 4] + ], tf.string) + examples = read_examples(DATA_PATH, shuffle=False, num_epochs=1, page_range=(2, 3)) + TFRecordDataset.assert_called_with(DATA_PATH, compression_type='GZIP') + with tf.Session() as session: + assert [x['page_no'] for x in fetch_examples(session, examples)] == [2, 3] + + def test_should_filter_by_channel_colors(self): + with patch.object(examples_module, 'TFRecordDataset') as TFRecordDataset: + with tf.Graph().as_default(): + TFRecordDataset.return_value = list_dataset([ + dict_to_example(extend_dict( + EXAMPLE_PROPS_1, + page_no=page_no, + annotation_image=image_with_color(some_color(page_no)) + )).SerializeToString() + for page_no in [1, 2, 3, 4] + ], tf.string) + examples = read_examples( + DATA_PATH, shuffle=False, num_epochs=1, + page_range=(0, 100), + channel_colors=[some_color(i) for i in [2, 3]] + ) + TFRecordDataset.assert_called_with(DATA_PATH, compression_type='GZIP') + with tf.Session() as session: + assert [x['page_no'] for x in fetch_examples(session, examples)] == [2, 3] diff --git a/sciencebeam_gym/trainer/evaluator.py b/sciencebeam_gym/trainer/evaluator.py index 67d62cc..1fa78bf 100644 --- a/sciencebeam_gym/trainer/evaluator.py +++ b/sciencebeam_gym/trainer/evaluator.py @@ -3,478 +3,489 @@ import logging import json from io import BytesIO -import matplotlib as mpl -# this is important to run 
on the cloud - we won't have python-tk installed -mpl.use("Agg") - -# pylint: disable=C0413 -from matplotlib import pyplot as plt import numpy as np import six import tensorflow as tf -from tensorflow.python.lib.io import file_io # pylint: disable=E0611 +from tensorflow.python.lib.io import file_io from PIL import Image +from sciencebeam_gym.utils.pyplot import pyplot as plt + from sciencebeam_gym.utils.tf import ( - FileIO + FileIO ) from sciencebeam_gym.trainer.util import ( - CustomSupervisor, - get_graph_size + CustomSupervisor, + get_graph_size ) + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def plot_image(ax, image, label): - if len(image.shape) == 3: - get_logger().info('image shape: %s (%s)', image.shape, image.shape[-1]) - if image.shape[-1] == 1: - ax.imshow(image.squeeze(), aspect='auto', vmin=0, vmax=255, cmap=plt.get_cmap('gray')) + if len(image.shape) == 3: + get_logger().info('image shape: %s (%s)', image.shape, image.shape[-1]) + if image.shape[-1] == 1: + ax.imshow(image.squeeze(), aspect='auto', vmin=0, vmax=255, cmap=plt.get_cmap('gray')) + else: + ax.imshow(image, aspect='auto') else: - ax.imshow(image, aspect='auto') - else: - ax.imshow(np.dstack((image.astype(np.uint8),)*3)*100, aspect='auto') - ax.set_title(label, color=(0.5, 0.5, 0.5), y=0.995) - ax.set_axis_off() - ax.set(xlim=[0, 255], ylim=[255, 0], aspect=1) - ax.axes.get_xaxis().set_visible(False) - ax.axes.get_yaxis().set_visible(False) + ax.imshow(np.dstack((image.astype(np.uint8),) * 3) * 100, aspect='auto') + ax.set_title(label, color=(0.5, 0.5, 0.5), y=0.995) + ax.set_axis_off() + ax.set(xlim=[0, 255], ylim=[255, 0], aspect=1) + ax.axes.get_xaxis().set_visible(False) + ax.axes.get_yaxis().set_visible(False) + def show_result_images3(input_image, annot, output_image): - figsize = plt.figaspect(1.1 / 3.0) - fig, (ax_img, ax_annot, ax_out) = plt.subplots( - 1, - 3, - sharey=True, - figsize=figsize, - frameon=False, - facecolor=None, - 
dpi=60 - ) - - plot_image(ax_img, input_image, 'input') - plot_image(ax_annot, annot, 'target') - plot_image(ax_out, output_image, 'prediction') - - margin = 0.01 - fig.subplots_adjust( - left=margin, - right=1.0 - margin, - top=0.95 - margin, - bottom=margin, - wspace=0.05 - ) - - return fig + figsize = plt.figaspect(1.1 / 3.0) + fig, (ax_img, ax_annot, ax_out) = plt.subplots( + 1, + 3, + sharey=True, + figsize=figsize, + frameon=False, + facecolor=None, + dpi=60 + ) + + plot_image(ax_img, input_image, 'input') + plot_image(ax_annot, annot, 'target') + plot_image(ax_out, output_image, 'prediction') + + margin = 0.01 + fig.subplots_adjust( + left=margin, + right=1.0 - margin, + top=0.95 - margin, + bottom=margin, + wspace=0.05 + ) + + return fig + def save_image_data(image_filename, image_data): - image = Image.fromarray(np.array(image_data), "RGB") - with FileIO(image_filename, 'wb') as image_f: - image.save(image_f, 'png') + image = Image.fromarray(np.array(image_data), "RGB") + with FileIO(image_filename, 'wb') as image_f: + image.save(image_f, 'png') + def save_file(filename, data): - with FileIO(filename, 'wb') as f: - f.write(data) + with FileIO(filename, 'wb') as f: + f.write(data) + def precision_from_tp_fp(tp, fp): - return tp / (tp + fp) + return tp / (tp + fp) + def recall_from_tp_fn(tp, fn): - return tp / (tp + fn) + return tp / (tp + fn) + def f1_from_precision_recall(precision, recall): - return 2 * precision * recall / (precision + recall) + return 2 * precision * recall / (precision + recall) + def f1_from_tp_fp_fn(tp, fp, fn): - return f1_from_precision_recall( - precision_from_tp_fp(tp, fp), - recall_from_tp_fn(tp, fn) - ) + return f1_from_precision_recall( + precision_from_tp_fp(tp, fp), + recall_from_tp_fn(tp, fn) + ) + def to_list_if_not_none(x): - return x.tolist() if x is not None else x + return x.tolist() if x is not None else x + IMAGE_PREFIX = 'image_' + class Evaluator(object): - """Loads variables from latest checkpoint and performs 
model evaluation.""" - - def __init__( - self, args, model, - checkpoint_path, - data_paths, - dataset='eval', - eval_batch_size=None, - eval_set_size=None, - qualitative_set_size=None, - run_async=None): - - self.eval_batch_size = eval_batch_size or args.eval_batch_size - self.num_eval_batches = (eval_set_size or args.eval_set_size) // self.eval_batch_size - self.num_detail_eval_batches = ( - min((qualitative_set_size or 10), args.eval_set_size) // self.eval_batch_size - ) - self.checkpoint_path = checkpoint_path - self.output_path = os.path.join(args.output_path, dataset) - self.eval_data_paths = data_paths - self.batch_size = args.batch_size - self.stream = args.streaming_eval - self.model = model - self.results_dir = os.path.join(self.output_path, 'results') - self.graph_size = None - self.run_async = run_async - if not run_async: - self.run_async = lambda f, args: f(*args) - - def init(self): - file_io.recursive_create_dir(self.results_dir) - - def _check_fetches(self, fetches): - for k, v in six.iteritems(fetches): - if v is None: - raise Exception('fetches tensor is None: {}'.format(k)) - - def _get_default_fetches(self, tensors): - return { - 'global_step': tensors.global_step, - 'input_uri': tensors.input_uri, - 'input_image': tensors.image_tensor, - 'annotation_image': tensors.annotation_tensor, - 'output_image': tensors.summaries.get('output_image'), - 'metric_values': tensors.metric_values - } - - def _add_image_fetches(self, fetches, tensors): - for k, v in six.iteritems(tensors.image_tensors): - fetches[IMAGE_PREFIX + k] = v - return fetches - - def _add_evaluation_result_fetches(self, fetches, tensors): - if tensors.evaluation_result: - if tensors.output_layer_labels is not None: - fetches['output_layer_labels'] = tensors.output_layer_labels - fetches['confusion_matrix'] = tensors.evaluation_result.confusion_matrix - fetches['tp'] = tensors.evaluation_result.tp - fetches['fp'] = tensors.evaluation_result.fp - fetches['fn'] = 
tensors.evaluation_result.fn - fetches['tn'] = tensors.evaluation_result.tn - fetches['accuracy'] = tensors.evaluation_result.accuracy - fetches['micro_f1'] = tensors.evaluation_result.micro_f1 - return fetches - - def _accumulate_evaluation_results(self, results, accumulated_results=None): - if results.get('confusion_matrix') is None: - return accumulated_results - if accumulated_results is None: - accumulated_results = [] - accumulated_results.append({ - 'output_layer_labels': results.get('output_layer_labels'), - 'confusion_matrix': results['confusion_matrix'], - 'tp': results['tp'], - 'fp': results['fp'], - 'fn': results['fn'], - 'tn': results['tn'], - 'accuracy': results['accuracy'], - 'micro_f1': results['micro_f1'], - 'count': self.batch_size, - 'global_step': results['global_step'] - }) - return accumulated_results - - def _save_accumulate_evaluation_results(self, accumulated_results): - if accumulated_results: - first_result = accumulated_results[0] - global_step = first_result['global_step'] - graph_size = get_graph_size() - output_layer_labels = to_list_if_not_none(first_result.get('output_layer_labels')) - scores_file = os.path.join( - self.results_dir, 'result_{}_scores.json'.format( - global_step - ) - ) - tp = np.sum([r['tp'] for r in accumulated_results], axis=0) - fp = np.sum([r['fp'] for r in accumulated_results], axis=0) - fn = np.sum([r['fn'] for r in accumulated_results], axis=0) - tn = np.sum([r['tn'] for r in accumulated_results], axis=0) - f1 = f1_from_tp_fp_fn(tp.astype(float), fp, fn) - meta = { - 'global_step': global_step, - 'batch_size': self.batch_size - } - if self.graph_size: - meta['graph_size'] = self.graph_size - scores = { - 'accuracy': float(np.mean([r['accuracy'] for r in accumulated_results])), - 'output_layer_labels': output_layer_labels, - 'confusion_matrix': sum([r['confusion_matrix'] for r in accumulated_results]).tolist(), - 'tp': to_list_if_not_none(tp), - 'fp': to_list_if_not_none(fp), - 'fn': to_list_if_not_none(fn), - 
'tn': to_list_if_not_none(tn), - 'f1': to_list_if_not_none(f1), - 'micro_f1': float(np.mean([r['micro_f1'] for r in accumulated_results])), - 'macro_f1': float(np.mean(f1)), - 'count': sum([r['count'] for r in accumulated_results]), - 'meta': meta - } - scores_str = json.dumps(scores, indent=2) - with FileIO(scores_file, 'w') as f: - f.write(scores_str) - - def _save_prediction_summary_image_for( - self, eval_index, global_step, inputs, targets, outputs, name): - - for batch_index, input_image, target_image, output_image in zip( - range(len(inputs)), inputs, targets, outputs - ): - - fig = show_result_images3( - input_image, - target_image, - output_image - ) - result_file = os.path.join( - self.results_dir, 'result_{}_{}_{}_{}.png'.format( - global_step, eval_index, batch_index, name + """Loads variables from latest checkpoint and performs model evaluation.""" + + def __init__( + self, args, model, + checkpoint_path, + data_paths, + dataset='eval', + eval_batch_size=None, + eval_set_size=None, + qualitative_set_size=None, + run_async=None): + + self.eval_batch_size = eval_batch_size or args.eval_batch_size + self.num_eval_batches = (eval_set_size or args.eval_set_size) // self.eval_batch_size + self.num_detail_eval_batches = ( + min((qualitative_set_size or 10), args.eval_set_size) // self.eval_batch_size ) - ) - logging.info('result_file: %s', result_file) - bio = BytesIO() - plt.savefig(bio, format='png', transparent=False, frameon=True, dpi='figure') - plt.close(fig) - self.run_async(save_file, (result_file, bio.getvalue())) - - def _save_prediction_summary_image(self, eval_index, results): - global_step = results['global_step'] - self._save_prediction_summary_image_for( - eval_index, - global_step, - results['input_image'], - results['annotation_image'], - results['output_image'], - 'summary_output' - ) - - if results.get(IMAGE_PREFIX + 'outputs_most_likely') is not None: - self._save_prediction_summary_image_for( - eval_index, - global_step, - 
results['input_image'], - results['annotation_image'], - results[IMAGE_PREFIX + 'outputs_most_likely'], - 'summary_most_likely' - ) - outputs_key_needle = '_outputs_' - for k in six.iterkeys(results): - outputs_key_needle_index = k.find(outputs_key_needle) - if k.startswith(IMAGE_PREFIX) and outputs_key_needle_index >= 0: - targets_key = k.replace(outputs_key_needle, '_targets_') - if not targets_key in results: - continue + self.checkpoint_path = checkpoint_path + self.output_path = os.path.join(args.output_path, dataset) + self.eval_data_paths = data_paths + self.batch_size = args.batch_size + self.stream = args.streaming_eval + self.model = model + self.results_dir = os.path.join(self.output_path, 'results') + self.graph_size = None + self.run_async = run_async + if not run_async: + self.run_async = lambda f, args: f(*args) + + def init(self): + file_io.recursive_create_dir(self.results_dir) + + def _check_fetches(self, fetches): + for k, v in six.iteritems(fetches): + if v is None: + raise Exception('fetches tensor is None: {}'.format(k)) + + def _get_default_fetches(self, tensors): + return { + 'global_step': tensors.global_step, + 'input_uri': tensors.input_uri, + 'input_image': tensors.image_tensor, + 'annotation_image': tensors.annotation_tensor, + 'output_image': tensors.summaries.get('output_image'), + 'metric_values': tensors.metric_values + } + + def _add_image_fetches(self, fetches, tensors): + for k, v in six.iteritems(tensors.image_tensors): + fetches[IMAGE_PREFIX + k] = v + return fetches + + def _add_evaluation_result_fetches(self, fetches, tensors): + if tensors.evaluation_result: + if tensors.output_layer_labels is not None: + fetches['output_layer_labels'] = tensors.output_layer_labels + fetches['confusion_matrix'] = tensors.evaluation_result.confusion_matrix + fetches['tp'] = tensors.evaluation_result.tp + fetches['fp'] = tensors.evaluation_result.fp + fetches['fn'] = tensors.evaluation_result.fn + fetches['tn'] = tensors.evaluation_result.tn + 
fetches['accuracy'] = tensors.evaluation_result.accuracy + fetches['micro_f1'] = tensors.evaluation_result.micro_f1 + return fetches + + def _accumulate_evaluation_results(self, results, accumulated_results=None): + if results.get('confusion_matrix') is None: + return accumulated_results + if accumulated_results is None: + accumulated_results = [] + accumulated_results.append({ + 'output_layer_labels': results.get('output_layer_labels'), + 'confusion_matrix': results['confusion_matrix'], + 'tp': results['tp'], + 'fp': results['fp'], + 'fn': results['fn'], + 'tn': results['tn'], + 'accuracy': results['accuracy'], + 'micro_f1': results['micro_f1'], + 'count': self.batch_size, + 'global_step': results['global_step'] + }) + return accumulated_results + + def _save_accumulate_evaluation_results(self, accumulated_results): + if accumulated_results: + first_result = accumulated_results[0] + global_step = first_result['global_step'] + output_layer_labels = to_list_if_not_none(first_result.get('output_layer_labels')) + scores_file = os.path.join( + self.results_dir, 'result_{}_scores.json'.format( + global_step + ) + ) + tp = np.sum([r['tp'] for r in accumulated_results], axis=0) + fp = np.sum([r['fp'] for r in accumulated_results], axis=0) + fn = np.sum([r['fn'] for r in accumulated_results], axis=0) + tn = np.sum([r['tn'] for r in accumulated_results], axis=0) + f1 = f1_from_tp_fp_fn(tp.astype(float), fp, fn) + meta = { + 'global_step': global_step, + 'batch_size': self.batch_size + } + if self.graph_size: + meta['graph_size'] = self.graph_size + scores = { + 'accuracy': float(np.mean([r['accuracy'] for r in accumulated_results])), + 'output_layer_labels': output_layer_labels, + 'confusion_matrix': sum( + [r['confusion_matrix'] for r in accumulated_results] + ).tolist(), + 'tp': to_list_if_not_none(tp), + 'fp': to_list_if_not_none(fp), + 'fn': to_list_if_not_none(fn), + 'tn': to_list_if_not_none(tn), + 'f1': to_list_if_not_none(f1), + 'micro_f1': 
float(np.mean([r['micro_f1'] for r in accumulated_results])), + 'macro_f1': float(np.mean(f1)), + 'count': sum([r['count'] for r in accumulated_results]), + 'meta': meta + } + scores_str = json.dumps(scores, indent=2) + with FileIO(scores_file, 'w') as f: + f.write(scores_str) + + def _save_prediction_summary_image_for( + self, eval_index, global_step, inputs, targets, outputs, name): + + for batch_index, input_image, target_image, output_image in zip( + range(len(inputs)), inputs, targets, outputs + ): + + fig = show_result_images3( + input_image, + target_image, + output_image + ) + result_file = os.path.join( + self.results_dir, 'result_{}_{}_{}_{}.png'.format( + global_step, eval_index, batch_index, name + ) + ) + logging.info('result_file: %s', result_file) + bio = BytesIO() + plt.savefig(bio, format='png', transparent=False, frameon=True, dpi='figure') + plt.close(fig) + self.run_async(save_file, (result_file, bio.getvalue())) + + def _save_prediction_summary_image(self, eval_index, results): + global_step = results['global_step'] self._save_prediction_summary_image_for( - eval_index, - global_step, - results['input_image'], - results[targets_key], - results[k], - 'summary_' + k[(outputs_key_needle_index + len(outputs_key_needle)):] + eval_index, + global_step, + results['input_image'], + results['annotation_image'], + results['output_image'], + 'summary_output' ) - def _save_result_images(self, eval_index, results): - global_step = results['global_step'] - for k in six.iterkeys(results): - if k.startswith(IMAGE_PREFIX): - batch_image_data = results[k] - name = k[len(IMAGE_PREFIX):] - for batch_index, image_data in enumerate(batch_image_data): - image_filename = os.path.join( - self.results_dir, 'result_{}_{}_{}_{}.png'.format( - global_step, eval_index, batch_index, name + if results.get(IMAGE_PREFIX + 'outputs_most_likely') is not None: + self._save_prediction_summary_image_for( + eval_index, + global_step, + results['input_image'], + 
results['annotation_image'], + results[IMAGE_PREFIX + 'outputs_most_likely'], + 'summary_most_likely' ) - ) - self.run_async(save_image_data, (image_filename, image_data)) - - def _save_meta(self, eval_index, results): - global_step = results['global_step'] - metric_values = results['metric_values'] - for batch_index, input_uri in enumerate(results['input_uri']): - meta_file = os.path.join( - self.results_dir, 'result_{}_{}_{}_meta.json'.format( - global_step, eval_index, batch_index - ) - ) - meta_str = json.dumps({ - 'global_step': int(global_step), - 'eval_index': eval_index, - 'batch_index': batch_index, - 'metric_values': [float(x) for x in metric_values], - 'input_uri': input_uri - }, indent=2) - with FileIO(meta_file, 'w') as meta_f: - meta_f.write(meta_str) - - def evaluate_in_session(self, session, tensors, num_eval_batches=None): - summary_writer = tf.summary.FileWriter(self.output_path) - num_eval_batches = num_eval_batches or self.num_eval_batches - num_detailed_eval_batches = min(self.num_detail_eval_batches, num_eval_batches) - if self.stream: - for _ in range(num_eval_batches): - session.run(tensors.metric_updates, feed_dict={ - tensors.is_training: False - }) - else: - get_logger().info('tensors.examples: %s', tensors.examples) - - metric_values = None - accumulated_results = None - - for eval_index in range(num_eval_batches): - detailed_evaluation = eval_index < num_detailed_eval_batches - - fetches = self._get_default_fetches(tensors) - self._add_evaluation_result_fetches(fetches, tensors) - if detailed_evaluation: - self._add_image_fetches(fetches, tensors) - fetches['summary_value'] = tensors.summary - self._check_fetches(fetches) - results = session.run(fetches, feed_dict={ - tensors.is_training: False - }) - - accumulated_results = self._accumulate_evaluation_results(results, accumulated_results) - if detailed_evaluation: - self._save_prediction_summary_image(eval_index, results) - self._save_result_images(eval_index, results) - 
self._save_meta(eval_index, results) - - global_step = results['global_step'] - summary_value = results['summary_value'] - summary_writer.add_summary(summary_value, global_step) - summary_writer.flush() - + outputs_key_needle = '_outputs_' + for k in six.iterkeys(results): + outputs_key_needle_index = k.find(outputs_key_needle) + if k.startswith(IMAGE_PREFIX) and outputs_key_needle_index >= 0: + targets_key = k.replace(outputs_key_needle, '_targets_') + if targets_key not in results: + continue + self._save_prediction_summary_image_for( + eval_index, + global_step, + results['input_image'], + results[targets_key], + results[k], + 'summary_' + k[(outputs_key_needle_index + len(outputs_key_needle)):] + ) + + def _save_result_images(self, eval_index, results): + global_step = results['global_step'] + for k in six.iterkeys(results): + if k.startswith(IMAGE_PREFIX): + batch_image_data = results[k] + name = k[len(IMAGE_PREFIX):] + for batch_index, image_data in enumerate(batch_image_data): + image_filename = os.path.join( + self.results_dir, 'result_{}_{}_{}_{}.png'.format( + global_step, eval_index, batch_index, name + ) + ) + self.run_async(save_image_data, (image_filename, image_data)) + + def _save_meta(self, eval_index, results): + global_step = results['global_step'] metric_values = results['metric_values'] + for batch_index, input_uri in enumerate(results['input_uri']): + meta_file = os.path.join( + self.results_dir, 'result_{}_{}_{}_meta.json'.format( + global_step, eval_index, batch_index + ) + ) + meta_str = json.dumps({ + 'global_step': int(global_step), + 'eval_index': eval_index, + 'batch_index': batch_index, + 'metric_values': [float(x) for x in metric_values], + 'input_uri': input_uri + }, indent=2) + with FileIO(meta_file, 'w') as meta_f: + meta_f.write(meta_str) + + def evaluate_in_session(self, session, tensors, num_eval_batches=None): + summary_writer = tf.summary.FileWriter(self.output_path) + num_eval_batches = num_eval_batches or 
self.num_eval_batches + num_detailed_eval_batches = min(self.num_detail_eval_batches, num_eval_batches) + if self.stream: + for _ in range(num_eval_batches): + session.run(tensors.metric_updates, feed_dict={ + tensors.is_training: False + }) + else: + get_logger().info('tensors.examples: %s', tensors.examples) - self._save_accumulate_evaluation_results(accumulated_results) - - logging.info('eval done') - return metric_values + metric_values = None + accumulated_results = None - metric_values = session.run(tensors.metric_values, feed_dict={ - tensors.is_training: False - }) - return metric_values + for eval_index in range(num_eval_batches): + detailed_evaluation = eval_index < num_detailed_eval_batches - def evaluate(self, num_eval_batches=None, session=None): - """Run one round of evaluation, return loss and accuracy.""" + fetches = self._get_default_fetches(tensors) + self._add_evaluation_result_fetches(fetches, tensors) + if detailed_evaluation: + self._add_image_fetches(fetches, tensors) + fetches['summary_value'] = tensors.summary + self._check_fetches(fetches) + results = session.run(fetches, feed_dict={ + tensors.is_training: False + }) - num_eval_batches = num_eval_batches or self.num_eval_batches - with tf.Graph().as_default() as graph: - tensors = self.model.build_eval_graph( - self.eval_data_paths, - self.eval_batch_size - ) - self.graph_size = get_graph_size() + accumulated_results = self._accumulate_evaluation_results( + results, accumulated_results) + if detailed_evaluation: + self._save_prediction_summary_image(eval_index, results) + self._save_result_images(eval_index, results) + self._save_meta(eval_index, results) - saver = tf.train.Saver() + global_step = results['global_step'] + summary_value = results['summary_value'] + summary_writer.add_summary(summary_value, global_step) + summary_writer.flush() - sv = CustomSupervisor( - model=self.model, - graph=graph, - logdir=self.output_path, - summary_op=None, - global_step=None, - saver=saver - ) - 
try: - last_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path) - logging.info('last_checkpoint: %s (%s)', last_checkpoint, self.checkpoint_path) + metric_values = results['metric_values'] - file_io.recursive_create_dir(self.results_dir) + self._save_accumulate_evaluation_results(accumulated_results) - with sv.managed_session( - master='', start_standard_services=False) as session: - sv.saver.restore(session, last_checkpoint) + logging.info('eval done') + return metric_values - logging.info('session restored') + metric_values = session.run(tensors.metric_values, feed_dict={ + tensors.is_training: False + }) + return metric_values - if self.stream: - logging.info('start queue runners (stream)') - sv.start_queue_runners(session) - for _ in range(num_eval_batches): - session.run(self.tensors.metric_updates, feed_dict={ - tensors.is_training: False - }) - else: - logging.info('start queue runners (batch)') - sv.start_queue_runners(session) - - logging.info('evaluate_in_session') - return self.evaluate_in_session(session, tensors) - finally: - sv.stop() - - def write_predictions(self): - """Run one round of predictions and write predictions to csv file.""" - num_eval_batches = self.num_eval_batches - num_detailed_eval_batches = self.num_detail_eval_batches - with tf.Graph().as_default() as graph: - tensors = self.model.build_eval_graph( - self.eval_data_paths, - self.batch_size - ) - self.graph_size = get_graph_size() - saver = tf.train.Saver() - - sv = CustomSupervisor( - model=self.model, - graph=graph, - logdir=self.output_path, - summary_op=None, - global_step=None, - saver=saver - ) + def evaluate(self, num_eval_batches=None): + """Run one round of evaluation, return loss and accuracy.""" - file_io.recursive_create_dir(self.results_dir) - - accumulated_results = None - - last_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path) - with sv.managed_session( - master='', start_standard_services=False) as session: - sv.saver.restore(session, 
last_checkpoint) - predictions_filename = os.path.join(self.output_path, 'predictions.csv') - with FileIO(predictions_filename, 'w') as csv_f: - sv.start_queue_runners(session) - last_log_progress = 0 - for eval_index in range(num_eval_batches): - progress = eval_index * 100 // num_eval_batches - if progress > last_log_progress: - logging.info('%3d%% predictions processed', progress) - last_log_progress = progress - - detailed_evaluation = eval_index < num_detailed_eval_batches - - fetches = self._get_default_fetches(tensors) - self._add_evaluation_result_fetches(fetches, tensors) - if detailed_evaluation: - self._add_image_fetches(fetches, tensors) - self._check_fetches(fetches) - results = session.run(fetches, feed_dict={ - tensors.is_training: False - }) + num_eval_batches = num_eval_batches or self.num_eval_batches + with tf.Graph().as_default() as graph: + tensors = self.model.build_eval_graph( + self.eval_data_paths, + self.eval_batch_size + ) + self.graph_size = get_graph_size() - accumulated_results = self._accumulate_evaluation_results(results, accumulated_results) - if detailed_evaluation: - self._save_prediction_summary_image(eval_index, results) - self._save_result_images(eval_index, results) - self._save_meta(eval_index, results) + saver = tf.train.Saver() - input_uri = results['input_uri'] - metric_values = results['metric_values'] - csv_f.write('{},{}\n'.format(input_uri, metric_values[0])) + sv = CustomSupervisor( + model=self.model, + graph=graph, + logdir=self.output_path, + summary_op=None, + global_step=None, + saver=saver + ) + try: + last_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path) + logging.info('last_checkpoint: %s (%s)', last_checkpoint, self.checkpoint_path) + + file_io.recursive_create_dir(self.results_dir) + + with sv.managed_session( + master='', start_standard_services=False) as session: + sv.saver.restore(session, last_checkpoint) + + logging.info('session restored') + + if self.stream: + logging.info('start queue 
runners (stream)') + sv.start_queue_runners(session) + for _ in range(num_eval_batches): + session.run(tensors.metric_updates, feed_dict={ + tensors.is_training: False + }) + else: + logging.info('start queue runners (batch)') + sv.start_queue_runners(session) + + logging.info('evaluate_in_session') + return self.evaluate_in_session(session, tensors) + finally: + sv.stop() + + def write_predictions(self): + """Run one round of predictions and write predictions to csv file.""" + num_eval_batches = self.num_eval_batches + num_detailed_eval_batches = self.num_detail_eval_batches + with tf.Graph().as_default() as graph: + tensors = self.model.build_eval_graph( + self.eval_data_paths, + self.batch_size + ) + self.graph_size = get_graph_size() + saver = tf.train.Saver() + + sv = CustomSupervisor( + model=self.model, + graph=graph, + logdir=self.output_path, + summary_op=None, + global_step=None, + saver=saver + ) - self._save_accumulate_evaluation_results(accumulated_results) + file_io.recursive_create_dir(self.results_dir) + + accumulated_results = None + + last_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path) + with sv.managed_session( + master='', start_standard_services=False) as session: + sv.saver.restore(session, last_checkpoint) + predictions_filename = os.path.join(self.output_path, 'predictions.csv') + with FileIO(predictions_filename, 'w') as csv_f: + sv.start_queue_runners(session) + last_log_progress = 0 + for eval_index in range(num_eval_batches): + progress = eval_index * 100 // num_eval_batches + if progress > last_log_progress: + logging.info('%3d%% predictions processed', progress) + last_log_progress = progress + + detailed_evaluation = eval_index < num_detailed_eval_batches + + fetches = self._get_default_fetches(tensors) + self._add_evaluation_result_fetches(fetches, tensors) + if detailed_evaluation: + self._add_image_fetches(fetches, tensors) + self._check_fetches(fetches) + results = session.run(fetches, feed_dict={ + 
tensors.is_training: False + }) + + accumulated_results = self._accumulate_evaluation_results( + results, accumulated_results) + if detailed_evaluation: + self._save_prediction_summary_image(eval_index, results) + self._save_result_images(eval_index, results) + self._save_meta(eval_index, results) + + input_uri = results['input_uri'] + metric_values = results['metric_values'] + csv_f.write('{},{}\n'.format(input_uri, metric_values[0])) + + self._save_accumulate_evaluation_results(accumulated_results) diff --git a/sciencebeam_gym/trainer/evaluator_test.py b/sciencebeam_gym/trainer/evaluator_test.py index 3f257f0..1a891fe 100644 --- a/sciencebeam_gym/trainer/evaluator_test.py +++ b/sciencebeam_gym/trainer/evaluator_test.py @@ -4,25 +4,25 @@ import pytest import tensorflow as tf from sciencebeam_utils.utils.collection import ( - to_namedtuple + to_namedtuple ) from sciencebeam_gym.utils.tfrecord import ( - dict_to_example + dict_to_example ) from sciencebeam_gym.trainer.data.examples import ( - parse_example, - MapKeysTracker + parse_example, + MapKeysTracker ) from sciencebeam_gym.trainer.data.examples_test import ( - EXAMPLE_PROPS_1, - list_dataset + EXAMPLE_PROPS_1, + list_dataset ) from sciencebeam_gym.trainer.evaluator import ( - Evaluator + Evaluator ) TEST_PATH = '.temp/test/evaluator' @@ -34,102 +34,108 @@ BATCH_SIZE = 10 EVAL_SET_SIZE = 10 DEFAULT_ARGS = dict( - batch_size=BATCH_SIZE, - eval_set_size=EVAL_SET_SIZE, - streaming_eval=False, - output_path=OUTPUT_PATH + batch_size=BATCH_SIZE, + eval_set_size=EVAL_SET_SIZE, + streaming_eval=False, + output_path=OUTPUT_PATH ) DEFAULT_KWARGS = dict( - checkpoint_path=CHECKPOINT_PATH, - data_paths=DATA_PATHS, - eval_batch_size=BATCH_SIZE + checkpoint_path=CHECKPOINT_PATH, + data_paths=DATA_PATHS, + eval_batch_size=BATCH_SIZE ) + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def setup_module(): - logging.basicConfig(level='DEBUG') + logging.basicConfig(level='DEBUG') + 
class GraphMode(object): - TRAIN = 'train' - EVALUATE = 'eval' + TRAIN = 'train' + EVALUATE = 'eval' + def example_dataset(map_keys_tracker, examples): - dataset = list_dataset([ - dict_to_example(example).SerializeToString() - for example in examples - ], tf.string) - dataset = dataset.map(map_keys_tracker.wrap(parse_example)) - return dataset + dataset = list_dataset([ + dict_to_example(example).SerializeToString() + for example in examples + ], tf.string) + dataset = dataset.map(map_keys_tracker.wrap(parse_example)) + return dataset + class ExampleModel(object): - def __init__(self, examples): - self.examples = examples - - def build_graph(self, data_paths, batch_size, graph_mode): - tensors = dict() - tensors['is_training'] = tf.placeholder(tf.bool) - map_keys_tracker = MapKeysTracker() - dataset = example_dataset(map_keys_tracker, self.examples) - iterator = dataset.make_one_shot_iterator() - parsed = map_keys_tracker.unwrap(iterator.get_next()) - get_logger().debug('parsed: %s', parsed) - tensors['examples'] = parsed - tensors['metric_values'] = [] - tensors['metric_updates'] = [] - tensors['global_step'] = tf.constant(100, tf.int32) - tensors['summaries'] = dict() - tensors['image_tensors'] = dict() - tensors['evaluation_result'] = None - image_shape = (10, 10, 3) - pre_batch_tensors = { - 'input_uri': tf.squeeze(parsed['input_uri']), - 'annotation_uri': tf.squeeze(parsed['annotation_uri']), - 'image_tensor': tf.zeros(image_shape), - 'annotation_tensor': tf.zeros(image_shape), - 'output_image': tf.zeros(image_shape) - } - post_batch_tensors = tf.train.batch(pre_batch_tensors, batch_size=batch_size) - tensors.update(post_batch_tensors) - for name in ['image_tensor', 'annotation_tensor', 'output_image']: - image_tensor = tensors[name] - get_logger().debug('name=%s, image_tensor=%s', name, image_tensor) - tf.summary.image(name, image_tensor) - tensors['image_tensors'][name] = image_tensor - tensors['summaries'][name] = image_tensor - tensors['summary'] = 
tf.summary.merge_all() - tensors['initializer'] = [tf.global_variables_initializer()] - return to_namedtuple(tensors, name='Tensors') - - def build_train_graph(self, data_paths, batch_size): - return self.build_graph(data_paths, batch_size, GraphMode.TRAIN) - - def build_eval_graph(self, data_paths, batch_size): - return self.build_graph(data_paths, batch_size, GraphMode.EVALUATE) + def __init__(self, examples): + self.examples = examples + + def build_graph(self, data_paths, batch_size, graph_mode): # pylint: disable=unused-argument + tensors = dict() + tensors['is_training'] = tf.placeholder(tf.bool) + map_keys_tracker = MapKeysTracker() + dataset = example_dataset(map_keys_tracker, self.examples) + iterator = dataset.make_one_shot_iterator() + parsed = map_keys_tracker.unwrap(iterator.get_next()) + get_logger().debug('parsed: %s', parsed) + tensors['examples'] = parsed + tensors['metric_values'] = [] + tensors['metric_updates'] = [] + tensors['global_step'] = tf.constant(100, tf.int32) + tensors['summaries'] = dict() + tensors['image_tensors'] = dict() + tensors['evaluation_result'] = None + image_shape = (10, 10, 3) + pre_batch_tensors = { + 'input_uri': tf.squeeze(parsed['input_uri']), + 'annotation_uri': tf.squeeze(parsed['annotation_uri']), + 'image_tensor': tf.zeros(image_shape), + 'annotation_tensor': tf.zeros(image_shape), + 'output_image': tf.zeros(image_shape) + } + post_batch_tensors = tf.train.batch(pre_batch_tensors, batch_size=batch_size) + tensors.update(post_batch_tensors) + for name in ['image_tensor', 'annotation_tensor', 'output_image']: + image_tensor = tensors[name] + get_logger().debug('name=%s, image_tensor=%s', name, image_tensor) + tf.summary.image(name, image_tensor) + tensors['image_tensors'][name] = image_tensor + tensors['summaries'][name] = image_tensor + tensors['summary'] = tf.summary.merge_all() + tensors['initializer'] = [tf.global_variables_initializer()] + return to_namedtuple(tensors, name='Tensors') + + def 
build_train_graph(self, data_paths, batch_size): + return self.build_graph(data_paths, batch_size, GraphMode.TRAIN) + + def build_eval_graph(self, data_paths, batch_size): + return self.build_graph(data_paths, batch_size, GraphMode.EVALUATE) + @pytest.mark.slow class TestEvaluator(object): - def test_should_not_fail_eval_in_session(self): - with tf.Graph().as_default(): - model = ExampleModel([EXAMPLE_PROPS_1] * BATCH_SIZE) - tensors = model.build_train_graph( - DATA_PATHS, BATCH_SIZE - ) - - evaluator = Evaluator( - args=to_namedtuple(DEFAULT_ARGS, name='args'), - model=model, - **DEFAULT_KWARGS - ) - evaluator.init() - - get_logger().info('starting session') - with tf.Session() as session: - coord = tf.train.Coordinator() - tf.train.start_queue_runners(sess=session, coord=coord) - get_logger().info('evaluating') - session.run(tensors.initializer) - evaluator.evaluate_in_session(session, tensors) - get_logger().info('done') + def test_should_not_fail_eval_in_session(self): + with tf.Graph().as_default(): + model = ExampleModel([EXAMPLE_PROPS_1] * BATCH_SIZE) + tensors = model.build_train_graph( + DATA_PATHS, BATCH_SIZE + ) + + evaluator = Evaluator( + args=to_namedtuple(DEFAULT_ARGS, name='args'), + model=model, + **DEFAULT_KWARGS + ) + evaluator.init() + + get_logger().info('starting session') + with tf.Session() as session: + coord = tf.train.Coordinator() + tf.train.start_queue_runners(sess=session, coord=coord) + get_logger().info('evaluating') + session.run(tensors.initializer) + evaluator.evaluate_in_session(session, tensors) + get_logger().info('done') diff --git a/sciencebeam_gym/trainer/models/pix2pix/evaluate.py b/sciencebeam_gym/trainer/models/pix2pix/evaluate.py index c8a7846..5c79b05 100644 --- a/sciencebeam_gym/trainer/models/pix2pix/evaluate.py +++ b/sciencebeam_gym/trainer/models/pix2pix/evaluate.py @@ -6,114 +6,122 @@ import collections import tensorflow as tf EvaluationTensors = collections.namedtuple( - "EvaluationTensors", [ - 
"confusion_matrix", - "tp", - "fp", - "fn", - "tn", - "precision", - "recall", - "f1", - "accuracy", - "micro_precision", - "micro_recall", - "micro_f1", - "macro_precision", - "macro_recall", - "macro_f1" - ] + "EvaluationTensors", [ + "confusion_matrix", + "tp", + "fp", + "fn", + "tn", + "precision", + "recall", + "f1", + "accuracy", + "micro_precision", + "micro_recall", + "micro_f1", + "macro_precision", + "macro_recall", + "macro_f1" + ] ) + def output_probabilities_to_class(outputs): - return tf.argmax(outputs, 3) + return tf.argmax(outputs, 3) + def to_1d_vector(tensor): - return tf.reshape(tensor, [-1]) + return tf.reshape(tensor, [-1]) + def precision_from_tp_fp(tp, fp): - return tp / (tp + fp) + return tp / (tp + fp) + def recall_from_tp_fn(tp, fn): - return tp / (tp + fn) + return tp / (tp + fn) + def f1_from_precision_recall(precision, recall): - return 2 * precision * recall / (precision + recall) + return 2 * precision * recall / (precision + recall) + def _evaluate_from_confusion_matrix(confusion, accuracy=None): - total = tf.reduce_sum(confusion) - actual_p = tf.reduce_sum(confusion, axis=0) - pred_p = tf.reduce_sum(confusion, axis=1) - tp = tf.diag_part(confusion) - fp = actual_p - tp - fn = pred_p - tp - tn = total - tp - fp - fn - precision = precision_from_tp_fp(tp, fp) - recall = recall_from_tp_fn(tp, fn) - f1 = f1_from_precision_recall(precision, recall) - total_tp = tf.reduce_sum(tp) - total_fp = tf.reduce_sum(fp) - total_fn = tf.reduce_sum(fn) - # Note: micro averages (with equal weights) will lead to the same precision, recall, f1 - micro_precision = precision_from_tp_fp(total_tp, total_fp) - micro_recall = recall_from_tp_fn(total_tp, total_fn) - micro_f1 = f1_from_precision_recall(micro_precision, micro_recall) - macro_precision = tf.reduce_sum(precision) - macro_recall = tf.reduce_sum(recall) - macro_f1 = tf.reduce_sum(f1) - return EvaluationTensors( - confusion_matrix=confusion, - tp=tp, - fp=fp, - fn=fn, - tn=tn, - precision=precision, 
- recall=recall, - f1=f1, - accuracy=accuracy, - micro_precision=micro_precision, - micro_recall=micro_recall, - micro_f1=micro_f1, - macro_precision=macro_precision, - macro_recall=macro_recall, - macro_f1=macro_f1 - ) + total = tf.reduce_sum(confusion) + actual_p = tf.reduce_sum(confusion, axis=0) + pred_p = tf.reduce_sum(confusion, axis=1) + tp = tf.diag_part(confusion) + fp = actual_p - tp + fn = pred_p - tp + tn = total - tp - fp - fn + precision = precision_from_tp_fp(tp, fp) + recall = recall_from_tp_fn(tp, fn) + f1 = f1_from_precision_recall(precision, recall) + total_tp = tf.reduce_sum(tp) + total_fp = tf.reduce_sum(fp) + total_fn = tf.reduce_sum(fn) + # Note: micro averages (with equal weights) will lead to the same precision, recall, f1 + micro_precision = precision_from_tp_fp(total_tp, total_fp) + micro_recall = recall_from_tp_fn(total_tp, total_fn) + micro_f1 = f1_from_precision_recall(micro_precision, micro_recall) + macro_precision = tf.reduce_sum(precision) + macro_recall = tf.reduce_sum(recall) + macro_f1 = tf.reduce_sum(f1) + return EvaluationTensors( + confusion_matrix=confusion, + tp=tp, + fp=fp, + fn=fn, + tn=tn, + precision=precision, + recall=recall, + f1=f1, + accuracy=accuracy, + micro_precision=micro_precision, + micro_recall=micro_recall, + micro_f1=micro_f1, + macro_precision=macro_precision, + macro_recall=macro_recall, + macro_f1=macro_f1 + ) + def evaluate_predictions(labels, predictions, n_classes): - labels = to_1d_vector(labels) - predictions = to_1d_vector(predictions) + labels = to_1d_vector(labels) + predictions = to_1d_vector(predictions) + + correct_prediction = tf.equal(labels, predictions) + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) - correct_prediction = tf.equal(labels, predictions) - accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + confusion = tf.contrib.metrics.confusion_matrix( + labels=labels, + predictions=predictions, + num_classes=n_classes + ) - confusion = 
tf.contrib.metrics.confusion_matrix( - labels=labels, - predictions=predictions, - num_classes=n_classes - ) + return _evaluate_from_confusion_matrix( + confusion=confusion, + accuracy=accuracy + ) - return _evaluate_from_confusion_matrix( - confusion=confusion, - accuracy=accuracy - ) def evaluate_separate_channels(targets, outputs): - n_classes = targets.shape[-1] + n_classes = targets.shape[-1] - labels = output_probabilities_to_class(targets) - predictions = output_probabilities_to_class(outputs) - return evaluate_predictions( - labels=labels, - predictions=predictions, - n_classes=n_classes - ) + labels = output_probabilities_to_class(targets) + predictions = output_probabilities_to_class(outputs) + return evaluate_predictions( + labels=labels, + predictions=predictions, + n_classes=n_classes + ) def evaluation_summary(evaluation_tensors, layer_labels): - tf.summary.scalar("micro_precision", evaluation_tensors.micro_precision) - tf.summary.scalar("micro_recall", evaluation_tensors.micro_recall) - tf.summary.scalar("micro_f1", evaluation_tensors.micro_f1) - tf.summary.scalar("macro_f1", evaluation_tensors.macro_f1) - tf.summary.scalar("accuracy", evaluation_tensors.accuracy) - for i, layer_label in enumerate(layer_labels): - tf.summary.scalar("f1_{}_{}".format(i, layer_label), evaluation_tensors.f1[i]) + tf.summary.scalar("micro_precision", evaluation_tensors.micro_precision) + tf.summary.scalar("micro_recall", evaluation_tensors.micro_recall) + tf.summary.scalar("micro_f1", evaluation_tensors.micro_f1) + tf.summary.scalar("macro_f1", evaluation_tensors.macro_f1) + tf.summary.scalar("accuracy", evaluation_tensors.accuracy) + for i, layer_label in enumerate(layer_labels): + tf.summary.scalar("f1_{}_{}".format(i, layer_label), evaluation_tensors.f1[i]) diff --git a/sciencebeam_gym/trainer/models/pix2pix/evaluate_test.py b/sciencebeam_gym/trainer/models/pix2pix/evaluate_test.py index 99e9c6f..f6a4877 100644 --- 
a/sciencebeam_gym/trainer/models/pix2pix/evaluate_test.py +++ b/sciencebeam_gym/trainer/models/pix2pix/evaluate_test.py @@ -7,49 +7,50 @@ import tensorflow as tf import numpy as np from sciencebeam_utils.utils.num import ( - assert_close + assert_close ) from sciencebeam_gym.trainer.models.pix2pix.evaluate import ( - evaluate_predictions + evaluate_predictions ) + @pytest.mark.slow def test_evaluate_predictions(): - n_classes = 4 - predictions = tf.constant(np.array([0, 1, 1, 2, 3, 3])) - labels = tf.constant(np.array([0, 1, 2, 3, 3, 3])) + n_classes = 4 + predictions = tf.constant(np.array([0, 1, 1, 2, 3, 3])) + labels = tf.constant(np.array([0, 1, 2, 3, 3, 3])) - evaluation_tensors = evaluate_predictions( - labels=labels, - predictions=predictions, - n_classes=n_classes - ) - with tf.Session() as session: - assert np.array_equal(session.run(evaluation_tensors.confusion_matrix), np.array([ - [1, 0, 0, 0], - [0, 1, 0, 0], - [0, 1, 0, 0], - [0, 0, 1, 2] - ])) - assert np.array_equal(session.run(evaluation_tensors.tp), np.array([1, 1, 0, 2])) - assert np.array_equal(session.run(evaluation_tensors.fp), np.array([0, 1, 1, 0])) - assert np.array_equal(session.run(evaluation_tensors.fn), np.array([0, 0, 1, 1])) - expected_micro_precision = 4.0 / (4 + 2) - expected_micro_recall = 4.0 / (4 + 2) - expected_micro_f1 = ( - 2 * expected_micro_precision * expected_micro_recall / - (expected_micro_precision + expected_micro_recall) - ) - assert_close( - session.run(evaluation_tensors.micro_precision), - expected_micro_precision - ) - assert_close( - session.run(evaluation_tensors.micro_recall), - expected_micro_recall - ) - assert_close( - session.run(evaluation_tensors.micro_f1), - expected_micro_f1 + evaluation_tensors = evaluate_predictions( + labels=labels, + predictions=predictions, + n_classes=n_classes ) + with tf.Session() as session: + assert np.array_equal(session.run(evaluation_tensors.confusion_matrix), np.array([ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 1, 0, 0], + [0, 
0, 1, 2] + ])) + assert np.array_equal(session.run(evaluation_tensors.tp), np.array([1, 1, 0, 2])) + assert np.array_equal(session.run(evaluation_tensors.fp), np.array([0, 1, 1, 0])) + assert np.array_equal(session.run(evaluation_tensors.fn), np.array([0, 0, 1, 1])) + expected_micro_precision = 4.0 / (4 + 2) + expected_micro_recall = 4.0 / (4 + 2) + expected_micro_f1 = ( + 2 * expected_micro_precision * expected_micro_recall / + (expected_micro_precision + expected_micro_recall) + ) + assert_close( + session.run(evaluation_tensors.micro_precision), + expected_micro_precision + ) + assert_close( + session.run(evaluation_tensors.micro_recall), + expected_micro_recall + ) + assert_close( + session.run(evaluation_tensors.micro_f1), + expected_micro_f1 + ) diff --git a/sciencebeam_gym/trainer/models/pix2pix/loss.py b/sciencebeam_gym/trainer/models/pix2pix/loss.py index 2a60587..a685fdf 100644 --- a/sciencebeam_gym/trainer/models/pix2pix/loss.py +++ b/sciencebeam_gym/trainer/models/pix2pix/loss.py @@ -1,36 +1,39 @@ import tensorflow as tf + def l1_loss(labels, outputs): - with tf.name_scope("l1_loss"): - # abs(labels - outputs) => 0 - return tf.reduce_mean(tf.abs(labels - outputs)) + with tf.name_scope("l1_loss"): + # abs(labels - outputs) => 0 + return tf.reduce_mean(tf.abs(labels - outputs)) + def cross_entropy_loss(labels, logits): - with tf.name_scope("cross_entropy"): - return tf.reduce_mean( - tf.nn.softmax_cross_entropy_with_logits( - logits=logits, - labels=labels - ) - ) + with tf.name_scope("cross_entropy"): + return tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits( + logits=logits, + labels=labels + ) + ) + def weighted_cross_entropy_loss(labels, logits, pos_weight, scalar=True): - with tf.name_scope("weighted_cross_entropy"): - softmax_loss = tf.nn.softmax_cross_entropy_with_logits( - logits=logits, - labels=labels - ) - # calculate weight per sample using the labels and weight per class - weight_per_sample = tf.reduce_sum( - tf.multiply( - labels, - 
pos_weight - ), - axis=-1 - ) - # weight each loss per sample - value = tf.multiply( - softmax_loss, - weight_per_sample - ) - return tf.reduce_mean(value) if scalar else value + with tf.name_scope("weighted_cross_entropy"): + softmax_loss = tf.nn.softmax_cross_entropy_with_logits( + logits=logits, + labels=labels + ) + # calculate weight per sample using the labels and weight per class + weight_per_sample = tf.reduce_sum( + tf.multiply( + labels, + pos_weight + ), + axis=-1 + ) + # weight each loss per sample + value = tf.multiply( + softmax_loss, + weight_per_sample + ) + return tf.reduce_mean(value) if scalar else value diff --git a/sciencebeam_gym/trainer/models/pix2pix/loss_test.py b/sciencebeam_gym/trainer/models/pix2pix/loss_test.py index 25916b3..e5cc2cf 100644 --- a/sciencebeam_gym/trainer/models/pix2pix/loss_test.py +++ b/sciencebeam_gym/trainer/models/pix2pix/loss_test.py @@ -1,137 +1,138 @@ import logging -from six import raise_from - import tensorflow as tf -import numpy as np from sciencebeam_utils.utils.num import ( - assert_close + assert_close ) from sciencebeam_gym.trainer.models.pix2pix.loss import ( - l1_loss, - cross_entropy_loss, - weighted_cross_entropy_loss + l1_loss, + cross_entropy_loss, + weighted_cross_entropy_loss ) + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + class TestL1Loss(object): - def test_should_return_abs_diff_for_single_value(self): - with tf.Graph().as_default(): - labels = tf.constant([0.9]) - outputs = tf.constant([0.1]) - loss = l1_loss(labels, outputs) - with tf.Session() as session: - assert_close(session.run([loss])[0], 0.8) + def test_should_return_abs_diff_for_single_value(self): + with tf.Graph().as_default(): + labels = tf.constant([0.9]) + outputs = tf.constant([0.1]) + loss = l1_loss(labels, outputs) + with tf.Session() as session: + assert_close(session.run([loss])[0], 0.8) + class TestCrossEntropyLoss(object): - def 
test_should_return_zero_if_logits_are_matching_labels_with_neg_pos_value(self): - with tf.Graph().as_default(): - labels = tf.constant([ - [[0.0, 1.0]] - ]) - logits = tf.constant([ - [[-10.0, 10.0]] - ]) - loss = cross_entropy_loss(labels, logits) - with tf.Session() as session: - assert_close(session.run([loss])[0], 0.0) + def test_should_return_zero_if_logits_are_matching_labels_with_neg_pos_value(self): + with tf.Graph().as_default(): + labels = tf.constant([ + [[0.0, 1.0]] + ]) + logits = tf.constant([ + [[-10.0, 10.0]] + ]) + loss = cross_entropy_loss(labels, logits) + with tf.Session() as session: + assert_close(session.run([loss])[0], 0.0) + + def test_should_return_not_zero_if_logits_are_not_matching_labels(self): + with tf.Graph().as_default(): + labels = tf.constant([ + [[0.0, 1.0]] + ]) + logits = tf.constant([ + [[10.0, 10.0]] + ]) + loss = cross_entropy_loss(labels, logits) + with tf.Session() as session: + assert session.run([loss])[0] > 0.5 - def test_should_return_not_zero_if_logits_are_not_matching_labels(self): - with tf.Graph().as_default(): - labels = tf.constant([ - [[0.0, 1.0]] - ]) - logits = tf.constant([ - [[10.0, 10.0]] - ]) - loss = cross_entropy_loss(labels, logits) - with tf.Session() as session: - assert session.run([loss])[0] > 0.5 class TestWeightedCrossEntropyLoss(object): - def test_should_return_zero_if_logits_are_matching_labels_with_neg_pos_value(self): - with tf.Graph().as_default(): - labels = tf.constant([ - [[0.0, 1.0]] - ]) - logits = tf.constant([ - [[-10.0, 10.0]] - ]) - pos_weight = tf.constant([ - 1.0 - ]) - loss = weighted_cross_entropy_loss(labels, logits, pos_weight) - with tf.Session() as session: - assert_close(session.run([loss])[0], 0.0, atol=0.0001) + def test_should_return_zero_if_logits_are_matching_labels_with_neg_pos_value(self): + with tf.Graph().as_default(): + labels = tf.constant([ + [[0.0, 1.0]] + ]) + logits = tf.constant([ + [[-10.0, 10.0]] + ]) + pos_weight = tf.constant([ + 1.0 + ]) + loss = 
weighted_cross_entropy_loss(labels, logits, pos_weight) + with tf.Session() as session: + assert_close(session.run([loss])[0], 0.0, atol=0.0001) - def test_should_return_not_zero_if_logits_are_not_matching_labels(self): - with tf.Graph().as_default(): - labels = tf.constant([ - [[0.0, 1.0]] - ]) - logits = tf.constant([ - [[10.0, 10.0]] - ]) - pos_weight = tf.constant([ - 1.0 - ]) - loss = weighted_cross_entropy_loss(labels, logits, pos_weight) - with tf.Session() as session: - assert session.run([loss])[0] > 0.5 + def test_should_return_not_zero_if_logits_are_not_matching_labels(self): + with tf.Graph().as_default(): + labels = tf.constant([ + [[0.0, 1.0]] + ]) + logits = tf.constant([ + [[10.0, 10.0]] + ]) + pos_weight = tf.constant([ + 1.0 + ]) + loss = weighted_cross_entropy_loss(labels, logits, pos_weight) + with tf.Session() as session: + assert session.run([loss])[0] > 0.5 - def test_should_return_higher_loss_for_value_with_greater_weight(self): - with tf.Graph().as_default(): - labels = tf.constant([ - [[0.0, 1.0]] - ]) - logits = tf.constant([ - [[10.0, 10.0]] - ]) - pos_weight_1 = tf.constant([ - 0.5 - ]) - pos_weight_2 = tf.constant([ - 1.0 - ]) - loss_1 = weighted_cross_entropy_loss(labels, logits, pos_weight_1) - loss_2 = weighted_cross_entropy_loss(labels, logits, pos_weight_2) - with tf.Session() as session: - loss_1_value, loss_2_value = session.run([loss_1, loss_2]) - assert loss_1_value < loss_2_value + def test_should_return_higher_loss_for_value_with_greater_weight(self): + with tf.Graph().as_default(): + labels = tf.constant([ + [[0.0, 1.0]] + ]) + logits = tf.constant([ + [[10.0, 10.0]] + ]) + pos_weight_1 = tf.constant([ + 0.5 + ]) + pos_weight_2 = tf.constant([ + 1.0 + ]) + loss_1 = weighted_cross_entropy_loss(labels, logits, pos_weight_1) + loss_2 = weighted_cross_entropy_loss(labels, logits, pos_weight_2) + with tf.Session() as session: + loss_1_value, loss_2_value = session.run([loss_1, loss_2]) + assert loss_1_value < loss_2_value - def 
test_should_support_batch_example_pos_weights(self): - batch_size = 3 - with tf.Graph().as_default(): - labels = tf.constant([[0.0, 1.0]] * batch_size) - logits = tf.constant([[10.0, 10.0]] * batch_size) - pos_weight_1 = tf.constant([ - [0.5, 0.5], - [1.0, 1.0], - [1.0, 1.0] - ]) - pos_weight_2 = tf.constant([ - [1.0, 1.0], - [0.5, 0.5], - [1.0, 1.0] - ]) - loss_1 = weighted_cross_entropy_loss( - labels, logits, pos_weight_1, scalar=False - ) - loss_2 = weighted_cross_entropy_loss( - labels, logits, pos_weight_2, scalar=False - ) - with tf.Session() as session: - get_logger().debug('labels=\n%s', labels.eval()) - get_logger().debug('logits=\n%s', logits.eval()) - loss_1_value, loss_2_value = session.run([loss_1, loss_2]) - get_logger().debug( - '\nloss_1_value=\n%s\nloss_2_value=\n%s', - loss_1_value, loss_2_value - ) - assert loss_1_value[0] < loss_2_value[0] - assert loss_1_value[1] > loss_2_value[1] - assert loss_1_value[2] == loss_2_value[2] + def test_should_support_batch_example_pos_weights(self): + batch_size = 3 + with tf.Graph().as_default(): + labels = tf.constant([[0.0, 1.0]] * batch_size) + logits = tf.constant([[10.0, 10.0]] * batch_size) + pos_weight_1 = tf.constant([ + [0.5, 0.5], + [1.0, 1.0], + [1.0, 1.0] + ]) + pos_weight_2 = tf.constant([ + [1.0, 1.0], + [0.5, 0.5], + [1.0, 1.0] + ]) + loss_1 = weighted_cross_entropy_loss( + labels, logits, pos_weight_1, scalar=False + ) + loss_2 = weighted_cross_entropy_loss( + labels, logits, pos_weight_2, scalar=False + ) + with tf.Session() as session: + get_logger().debug('labels=\n%s', labels.eval()) + get_logger().debug('logits=\n%s', logits.eval()) + loss_1_value, loss_2_value = session.run([loss_1, loss_2]) + get_logger().debug( + '\nloss_1_value=\n%s\nloss_2_value=\n%s', + loss_1_value, loss_2_value + ) + assert loss_1_value[0] < loss_2_value[0] + assert loss_1_value[1] > loss_2_value[1] + assert loss_1_value[2] == loss_2_value[2] diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py 
b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py index effd89c..912b65f 100644 --- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py +++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core.py @@ -10,560 +10,598 @@ import collections import tensorflow as tf from sciencebeam_gym.trainer.models.pix2pix.tf_utils import ( - blank_other_channels, - get_channel_slice + blank_other_channels, + get_channel_slice ) from sciencebeam_gym.trainer.models.pix2pix.loss import ( - l1_loss, - cross_entropy_loss, - weighted_cross_entropy_loss + l1_loss, + cross_entropy_loss, + weighted_cross_entropy_loss ) EPS = 1e-12 + class BaseLoss(object): - L1 = "L1" - CROSS_ENTROPY = "CE" - # Loss with pre-calculated weights (and cross entropy) - # (weights calculated across entire dataset) - WEIGHTED_CROSS_ENTROPY = "WCE" - # Loss with per sample weights (and cross entropy) - # (weights calculated as part of the graph) - SAMPLE_WEIGHTED_CROSS_ENTROPY = "SWCE" - # Combination of the above; i.e. sample weights multiplied by pre-calculated weights - WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY = "WSWCE" + L1 = "L1" + CROSS_ENTROPY = "CE" + # Loss with pre-calculated weights (and cross entropy) + # (weights calculated across entire dataset) + WEIGHTED_CROSS_ENTROPY = "WCE" + # Loss with per sample weights (and cross entropy) + # (weights calculated as part of the graph) + SAMPLE_WEIGHTED_CROSS_ENTROPY = "SWCE" + # Combination of the above; i.e. 
sample weights multiplied by pre-calculated weights + WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY = "WSWCE" + ALL_BASE_LOSS = [ - BaseLoss.L1, BaseLoss.CROSS_ENTROPY, BaseLoss.WEIGHTED_CROSS_ENTROPY, - BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY, BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY + BaseLoss.L1, BaseLoss.CROSS_ENTROPY, BaseLoss.WEIGHTED_CROSS_ENTROPY, + BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY, BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY ] ALL_CE_BASE_LOSS = { - BaseLoss.CROSS_ENTROPY, - BaseLoss.WEIGHTED_CROSS_ENTROPY, - BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY, - BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY + BaseLoss.CROSS_ENTROPY, + BaseLoss.WEIGHTED_CROSS_ENTROPY, + BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY, + BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY } ALL_WCE_BASE_LOSS = ALL_CE_BASE_LOSS - {BaseLoss.CROSS_ENTROPY} Pix2PixModel = collections.namedtuple( - "Pix2PixModel", [ - "inputs", - "targets", - "outputs", - "predict_real", - "predict_fake", - "discrim_loss", - "discrim_grads_and_vars", - "gen_loss_GAN", - "gen_loss_L1", - "gen_grads_and_vars", - "global_step", - "train" - ] + "Pix2PixModel", [ + "inputs", + "targets", + "outputs", + "predict_real", + "predict_fake", + "discrim_loss", + "discrim_grads_and_vars", + "gen_loss_GAN", + "gen_loss_L1", + "gen_grads_and_vars", + "global_step", + "train" + ] ) + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def lrelu(x, a): - with tf.name_scope("lrelu"): - # adding these together creates the leak part and linear part - # then cancels them out by subtracting/adding an absolute value term - # leak: a*x/2 - a*abs(x)/2 - # linear: x/2 + abs(x)/2 + with tf.name_scope("lrelu"): + # adding these together creates the leak part and linear part + # then cancels them out by subtracting/adding an absolute value term + # leak: a*x/2 - a*abs(x)/2 + # linear: x/2 + abs(x)/2 + + # this block looks like it has 2 inputs on the graph unless we do this + x = tf.identity(x) + 
return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x) - # this block looks like it has 2 inputs on the graph unless we do this - x = tf.identity(x) - return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x) def batchnorm(batch_input): - with tf.variable_scope("batchnorm"): - # this block looks like it has 3 inputs on the graph unless we do this - batch_input = tf.identity(batch_input) - - input_shape = batch_input.get_shape() - get_logger().debug('batchnorm, input_shape: %s', input_shape) - channels = input_shape[-1] - offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer()) - scale = tf.get_variable("scale", [channels], dtype=tf.float32, initializer=tf.random_normal_initializer(1.0, 0.02)) - mean, variance = tf.nn.moments(batch_input, axes=[0, 1, 2], keep_dims=False) - variance_epsilon = 1e-5 - normalized = tf.nn.batch_normalization(batch_input, mean, variance, offset, scale, variance_epsilon=variance_epsilon) - return normalized + with tf.variable_scope("batchnorm"): + # this block looks like it has 3 inputs on the graph unless we do this + batch_input = tf.identity(batch_input) + + input_shape = batch_input.get_shape() + get_logger().debug('batchnorm, input_shape: %s', input_shape) + channels = input_shape[-1] + offset = tf.get_variable( + "offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer() + ) + scale = tf.get_variable( + "scale", [channels], + dtype=tf.float32, initializer=tf.random_normal_initializer(1.0, 0.02) + ) + mean, variance = tf.nn.moments(batch_input, axes=[0, 1, 2], keep_dims=False) + variance_epsilon = 1e-5 + normalized = tf.nn.batch_normalization( + batch_input, mean, variance, offset, scale, variance_epsilon=variance_epsilon + ) + return normalized + def conv(batch_input, out_channels, stride): - with tf.variable_scope("conv"): - input_shape = batch_input.get_shape() - get_logger().debug('conv, input_shape: %s', input_shape) - in_channels = input_shape[-1] - conv_filter = 
tf.get_variable( - "filter", - [4, 4, in_channels, out_channels], - dtype=tf.float32, - initializer=tf.random_normal_initializer(0, 0.02) - ) - # [batch, in_height, in_width, in_channels], - # [filter_width, filter_height, in_channels, out_channels] - # => [batch, out_height, out_width, out_channels] - padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT") - return tf.nn.conv2d(padded_input, conv_filter, [1, stride, stride, 1], padding="VALID") + with tf.variable_scope("conv"): + input_shape = batch_input.get_shape() + get_logger().debug('conv, input_shape: %s', input_shape) + in_channels = input_shape[-1] + conv_filter = tf.get_variable( + "filter", + [4, 4, in_channels, out_channels], + dtype=tf.float32, + initializer=tf.random_normal_initializer(0, 0.02) + ) + # [batch, in_height, in_width, in_channels], + # [filter_width, filter_height, in_channels, out_channels] + # => [batch, out_height, out_width, out_channels] + padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT") + return tf.nn.conv2d(padded_input, conv_filter, [1, stride, stride, 1], padding="VALID") + def deconv(batch_input, out_channels): - with tf.variable_scope("deconv"): - input_shape = batch_input.get_shape() - get_logger().debug('deconv, input_shape: %s', input_shape) - - batch, in_height, in_width, in_channels = input_shape.as_list() - conv_filter = tf.get_variable( - "filter", - [4, 4, out_channels, in_channels], - dtype=tf.float32, - initializer=tf.random_normal_initializer(0, 0.02) - ) - # [batch, in_height, in_width, in_channels], - # [filter_width, filter_height, out_channels, in_channels] - # => [batch, out_height, out_width, out_channels] - # tf.shape creates a tensor to allow calculation of shape at runtime - dynamic_batch_size = tf.shape(batch_input)[0] - output_shape = [batch, in_height * 2, in_width * 2, out_channels] - return tf.reshape( - tf.nn.conv2d_transpose( - batch_input, - conv_filter, - 
tf.stack([dynamic_batch_size] + output_shape[1:]), - [1, 2, 2, 1], - padding="SAME" - ), - [batch or -1] + output_shape[1:] - ) + with tf.variable_scope("deconv"): + input_shape = batch_input.get_shape() + get_logger().debug('deconv, input_shape: %s', input_shape) + + batch, in_height, in_width, in_channels = input_shape.as_list() + conv_filter = tf.get_variable( + "filter", + [4, 4, out_channels, in_channels], + dtype=tf.float32, + initializer=tf.random_normal_initializer(0, 0.02) + ) + # [batch, in_height, in_width, in_channels], + # [filter_width, filter_height, out_channels, in_channels] + # => [batch, out_height, out_width, out_channels] + # tf.shape creates a tensor to allow calculation of shape at runtime + dynamic_batch_size = tf.shape(batch_input)[0] + output_shape = [batch, in_height * 2, in_width * 2, out_channels] + return tf.reshape( + tf.nn.conv2d_transpose( + batch_input, + conv_filter, + tf.stack([dynamic_batch_size] + output_shape[1:]), + [1, 2, 2, 1], + padding="SAME" + ), + [batch or -1] + output_shape[1:] + ) + def create_encoder_layers( - generator_inputs, - layer_specs): - """ - Creates encoders with every layer halfing width and height, - using an output depth specified by layer_specs. - - Args: - generator_inputs: A `Tensor`. The input to the generator. - layer_specs: A list of number of output channels for each layer. - Returns: - A list of `Tensor`. The output tensor of each layer. 
- """ - - layers = [] - for out_channels in layer_specs: - with tf.variable_scope("encoder_%d" % (len(layers) + 1)): - if not layers: - # very first layer is connected to inputs and doesn't use batch norm - output = conv(generator_inputs, out_channels, stride=2) - layers.append(output) - else: - rectified = lrelu(layers[-1], 0.2) - # [batch, in_height, in_width, in_channels] => - # [batch, in_height/2, in_width/2, out_channels] - convolved = conv(rectified, out_channels, stride=2) - output = batchnorm(convolved) - layers.append(output) - return layers + generator_inputs, + layer_specs): + """ + Creates encoders with every layer halfing width and height, + using an output depth specified by layer_specs. + + Args: + generator_inputs: A `Tensor`. The input to the generator. + layer_specs: A list of number of output channels for each layer. + Returns: + A list of `Tensor`. The output tensor of each layer. + """ + + layers = [] + for out_channels in layer_specs: + with tf.variable_scope("encoder_%d" % (len(layers) + 1)): + if not layers: + # very first layer is connected to inputs and doesn't use batch norm + output = conv(generator_inputs, out_channels, stride=2) + layers.append(output) + else: + rectified = lrelu(layers[-1], 0.2) + # [batch, in_height, in_width, in_channels] => + # [batch, in_height/2, in_width/2, out_channels] + convolved = conv(rectified, out_channels, stride=2) + output = batchnorm(convolved) + layers.append(output) + return layers + def conditional_dropout(cond, x, **kwargs): - return tf.cond(cond, lambda: tf.nn.dropout(x, **kwargs), lambda: x) + return tf.cond(cond, lambda: tf.nn.dropout(x, **kwargs), lambda: x) + def create_decoder_layers( - decoder_inputs, - layer_specs, - skip_connection_layers, - is_training): - """ - Creates decoders with every layer doubling width and height, - using an output depth and dropout specified by layer_specs. - - Args: - generator_inputs: A `Tensor`. The input to the generator. 
- layer_specs: A list of tuples for each layer. Each tuple containing - number of output channels and dropout. - Returns: - A list of `Tensor`. The output tensor of each layer. - """ - - layers = [] - num_encoder_layers = len(skip_connection_layers) - for decoder_layer, (out_channels, dropout) in enumerate(layer_specs): - skip_layer = num_encoder_layers - decoder_layer - 1 - is_last_layer = decoder_layer == len(layer_specs) - 1 - with tf.variable_scope("decoder_%d" % (skip_layer + 1)): - if decoder_layer == 0: - # first decoder layer doesn't have skip connections - # since it is directly connected to the skip_layer - layer_input = decoder_inputs - else: - layer_input = tf.concat([layers[-1], skip_connection_layers[skip_layer]], axis=3) - - rectified = tf.nn.relu(layer_input) - # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels] - output = deconv(rectified, out_channels) - if not is_last_layer: - # very last layer does not use batch norm - output = batchnorm(output) - - if dropout > 0.0: - # only apply dropout in training mode - output = conditional_dropout( - tf.convert_to_tensor(is_training), - output, - keep_prob=1 - dropout - ) + decoder_inputs, + layer_specs, + skip_connection_layers, + is_training): + """ + Creates decoders with every layer doubling width and height, + using an output depth and dropout specified by layer_specs. + + Args: + generator_inputs: A `Tensor`. The input to the generator. + layer_specs: A list of tuples for each layer. Each tuple containing + number of output channels and dropout. + Returns: + A list of `Tensor`. The output tensor of each layer. 
+ """ + + layers = [] + num_encoder_layers = len(skip_connection_layers) + for decoder_layer, (out_channels, dropout) in enumerate(layer_specs): + skip_layer = num_encoder_layers - decoder_layer - 1 + is_last_layer = decoder_layer == len(layer_specs) - 1 + with tf.variable_scope("decoder_%d" % (skip_layer + 1)): + if decoder_layer == 0: + # first decoder layer doesn't have skip connections + # since it is directly connected to the skip_layer + layer_input = decoder_inputs + else: + layer_input = tf.concat([layers[-1], skip_connection_layers[skip_layer]], axis=3) + + rectified = tf.nn.relu(layer_input) + # [batch, in_height, in_width, in_channels] + # => [batch, in_height*2, in_width*2, out_channels] + output = deconv(rectified, out_channels) + if not is_last_layer: + # very last layer does not use batch norm + output = batchnorm(output) + + if dropout > 0.0: + # only apply dropout in training mode + output = conditional_dropout( + tf.convert_to_tensor(is_training), + output, + keep_prob=1 - dropout + ) + + layers.append(output) + return layers - layers.append(output) - return layers def create_encoder_decoder( - generator_inputs, - encoder_layer_specs, - decoder_layer_specs, - is_training): - - encoder_layers = create_encoder_layers( - generator_inputs, - encoder_layer_specs - ) - decoder_layers = create_decoder_layers( - encoder_layers[-1], - decoder_layer_specs, - encoder_layers, - is_training - ) - return decoder_layers[-1] + generator_inputs, + encoder_layer_specs, + decoder_layer_specs, + is_training): + + encoder_layers = create_encoder_layers( + generator_inputs, + encoder_layer_specs + ) + decoder_layers = create_decoder_layers( + encoder_layers[-1], + decoder_layer_specs, + encoder_layers, + is_training + ) + return decoder_layers[-1] + def create_generator(generator_inputs, generator_outputs_channels, a, is_training): - encoder_layer_specs = [ - a.ngf, # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf] - a.ngf * 2, # encoder_2: [batch, 
128, 128, ngf] => [batch, 64, 64, ngf * 2] - a.ngf * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4] - a.ngf * 8, # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8] - a.ngf * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8] - a.ngf * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8] - a.ngf * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8] - a.ngf * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8] - ] - decoder_layer_specs = [ - (a.ngf * 8, 0.5), # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2] - (a.ngf * 8, 0.5), # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2] - (a.ngf * 8, 0.5), # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2] - (a.ngf * 8, 0.0), # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2] - (a.ngf * 4, 0.0), # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2] - (a.ngf * 2, 0.0), # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2] - (a.ngf, 0.0), # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2] - # decoder_1 [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels] - (generator_outputs_channels, 0.0) - ] - return create_encoder_decoder( - generator_inputs, - encoder_layer_specs, - decoder_layer_specs, - is_training - ) + encoder_layer_specs = [ + a.ngf, # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf] + a.ngf * 2, # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2] + a.ngf * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4] + a.ngf * 8, # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8] + a.ngf * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8] + a.ngf * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8] + a.ngf * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8] + a.ngf * 8, # 
encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8] + ] + decoder_layer_specs = [ + (a.ngf * 8, 0.5), # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2] + (a.ngf * 8, 0.5), # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2] + (a.ngf * 8, 0.5), # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2] + (a.ngf * 8, 0.0), # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2] + # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2] + (a.ngf * 4, 0.0), + # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2] + (a.ngf * 2, 0.0), + (a.ngf, 0.0), # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2] + # decoder_1 [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels] + (generator_outputs_channels, 0.0) + ] + return create_encoder_decoder( + generator_inputs, + encoder_layer_specs, + decoder_layer_specs, + is_training + ) + def create_discriminator(discrim_inputs, discrim_targets, a, out_channels=1): - n_layers = 3 - layers = [] - - # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2] - layer_input = tf.concat([discrim_inputs, discrim_targets], axis=3) - - # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf] - with tf.variable_scope("layer_1"): - convolved = conv(layer_input, a.ndf, stride=2) - rectified = lrelu(convolved, 0.2) - layers.append(rectified) - - # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2] - # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4] - # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8] - for i in range(n_layers): + n_layers = 3 + layers = [] + + # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2] + layer_input = tf.concat([discrim_inputs, discrim_targets], axis=3) + + # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf] + with tf.variable_scope("layer_1"): + 
convolved = conv(layer_input, a.ndf, stride=2) + rectified = lrelu(convolved, 0.2) + layers.append(rectified) + + # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2] + # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4] + # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8] + for i in range(n_layers): + with tf.variable_scope("layer_%d" % (len(layers) + 1)): + layer_out_channels = a.ndf * min(2 ** (i + 1), 8) + stride = 1 if i == n_layers - 1 else 2 # last layer here has stride 1 + convolved = conv(layers[-1], layer_out_channels, stride=stride) + normalized = batchnorm(convolved) + rectified = lrelu(normalized, 0.2) + layers.append(rectified) + + # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1] with tf.variable_scope("layer_%d" % (len(layers) + 1)): - layer_out_channels = a.ndf * min(2**(i+1), 8) - stride = 1 if i == n_layers - 1 else 2 # last layer here has stride 1 - convolved = conv(layers[-1], layer_out_channels, stride=stride) - normalized = batchnorm(convolved) - rectified = lrelu(normalized, 0.2) - layers.append(rectified) + convolved = conv(rectified, out_channels=out_channels, stride=1) + output = tf.sigmoid(convolved) + layers.append(output) - # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1] - with tf.variable_scope("layer_%d" % (len(layers) + 1)): - convolved = conv(rectified, out_channels=out_channels, stride=1) - output = tf.sigmoid(convolved) - layers.append(output) + return layers[-1] - return layers[-1] def with_variable_scope(name, fn, reuse=False, **kwargs): - with tf.variable_scope(name, reuse=reuse): - return fn(**kwargs) + with tf.variable_scope(name, reuse=reuse): + return fn(**kwargs) + def create_separate_discriminators(discrim_inputs, discrim_targets, a, out_channels=1): - if out_channels == 1: - return create_discriminator(discrim_inputs, discrim_targets, a, out_channels=1) - n_targets_channels = discrim_targets.shape[-1] - reuse = tf.get_variable_scope().reuse - return tf.concat([ - 
with_variable_scope( - 'layer_{}'.format(channel_index), - lambda i: create_discriminator( - discrim_inputs, - get_channel_slice(discrim_targets, i), - a, - out_channels=1 - ), - reuse=reuse, - i=channel_index - ) - for channel_index in range(n_targets_channels) - ], axis=0) + if out_channels == 1: + return create_discriminator(discrim_inputs, discrim_targets, a, out_channels=1) + n_targets_channels = discrim_targets.shape[-1] + reuse = tf.get_variable_scope().reuse + return tf.concat([ + with_variable_scope( + 'layer_{}'.format(channel_index), + lambda i: create_discriminator( + discrim_inputs, + get_channel_slice(discrim_targets, i), + a, + out_channels=1 + ), + reuse=reuse, + i=channel_index + ) + for channel_index in range(n_targets_channels) + ], axis=0) + def create_separate_channel_discriminator_by_blanking_out_channels(inputs, targets, a): - # We need to teach the discriminator to detect the real channels, - # by just looking at the real channel. - # For each channel: - # - let the discriminator only see the current channel, blank out all other channels - # - expect output to not be fake for the not blanked out channel - n_targets_channels = int(targets.shape[-1]) - predict_real_channels = [] - predict_real_blanked_list = [] - for i in range(n_targets_channels): - masked_targets = blank_other_channels( - targets, - i + # We need to teach the discriminator to detect the real channels, + # by just looking at the real channel. 
+ # For each channel: + # - let the discriminator only see the current channel, blank out all other channels + # - expect output to not be fake for the not blanked out channel + n_targets_channels = int(targets.shape[-1]) + predict_real_channels = [] + predict_real_blanked_list = [] + for i in range(n_targets_channels): + masked_targets = blank_other_channels( + targets, + i + ) + with tf.variable_scope("discriminator", reuse=(i > 0)): + # 2x [batch, height, width, channels] => [batch, 30, 30, n_targets_channels] + predict_real_i = create_discriminator( + inputs, masked_targets, a, + out_channels=n_targets_channels + ) + predict_real_channels.append(predict_real_i[:, :, :, i]) + for j in range(n_targets_channels): + if j != i: + predict_real_blanked_list.append(predict_real_i[:, :, :, j]) + predict_real = tf.stack( + predict_real_channels, + axis=-1, + name='predict_real' ) - with tf.variable_scope("discriminator", reuse=(i > 0)): - # 2x [batch, height, width, channels] => [batch, 30, 30, n_targets_channels] - predict_real_i = create_discriminator( - inputs, masked_targets, a, - out_channels=n_targets_channels - ) - predict_real_channels.append(predict_real_i[:, :, :, i]) - for j in range(n_targets_channels): - if j != i: - predict_real_blanked_list.append(predict_real_i[:, :, :, j]) - predict_real = tf.stack( - predict_real_channels, - axis=-1, - name='predict_real' - ) - predict_real_blanked = tf.stack( - predict_real_blanked_list, - axis=-1, - name='predict_real_blanked' - ) - return predict_real, predict_real_blanked + predict_real_blanked = tf.stack( + predict_real_blanked_list, + axis=-1, + name='predict_real_blanked' + ) + return predict_real, predict_real_blanked def create_pix2pix_model(inputs, targets, a, is_training, pos_weight=None, n_output_channels=None): - get_logger().info('gan_weight: %s, l1_weight: %s', a.gan_weight, a.l1_weight) - gan_enabled = abs(a.gan_weight) > 0.000001 - - is_predict = targets is None - if n_output_channels is None: - 
n_output_channels = int(targets.shape[-1]) - with tf.variable_scope("generator"): - outputs = create_generator(inputs, n_output_channels, a, is_training) - if a.base_loss in ALL_CE_BASE_LOSS: - output_logits = outputs - outputs = tf.nn.softmax(output_logits) - else: - outputs = tf.tanh(outputs) - - if is_predict: - return Pix2PixModel( - inputs=inputs, - targets=None, - predict_real=None, - predict_fake=None, - discrim_loss=None, - discrim_grads_and_vars=None, - gen_loss_GAN=None, - gen_loss_L1=None, - gen_grads_and_vars=None, - outputs=outputs, - global_step=None, - train=None, - ) + get_logger().info('gan_weight: %s, l1_weight: %s', a.gan_weight, a.l1_weight) + gan_enabled = abs(a.gan_weight) > 0.000001 + + is_predict = targets is None + if n_output_channels is None: + n_output_channels = int(targets.shape[-1]) + with tf.variable_scope("generator"): + outputs = create_generator(inputs, n_output_channels, a, is_training) + if a.base_loss in ALL_CE_BASE_LOSS: + output_logits = outputs + outputs = tf.nn.softmax(output_logits) + else: + outputs = tf.tanh(outputs) + + if is_predict: + return Pix2PixModel( + inputs=inputs, + targets=None, + predict_real=None, + predict_fake=None, + discrim_loss=None, + discrim_grads_and_vars=None, + gen_loss_GAN=None, + gen_loss_L1=None, + gen_grads_and_vars=None, + outputs=outputs, + global_step=None, + train=None, + ) - n_targets_channels = n_output_channels + n_targets_channels = n_output_channels - if gan_enabled: - discrim_out_channels = ( - n_targets_channels - if a.use_separate_discriminator_channels - else 1 - ) - get_logger().info('discrim_out_channels: %s', discrim_out_channels) - - # create two copies of discriminator, one for real pairs and one for fake pairs - # they share the same underlying variables - with tf.name_scope("real_discriminator"): - if discrim_out_channels > 1: - predict_real, predict_real_blanked = ( - create_separate_channel_discriminator_by_blanking_out_channels( - inputs, targets, a - ) + if gan_enabled: 
+ discrim_out_channels = ( + n_targets_channels + if a.use_separate_discriminator_channels + else 1 ) - else: - with tf.variable_scope("discriminator"): - if a.use_separate_discriminators: - get_logger().info('using separate discriminators: %s', n_targets_channels) - predict_real = create_separate_discriminators( - inputs, targets, a, out_channels=n_targets_channels + get_logger().info('discrim_out_channels: %s', discrim_out_channels) + + # create two copies of discriminator, one for real pairs and one for fake pairs + # they share the same underlying variables + with tf.name_scope("real_discriminator"): + if discrim_out_channels > 1: + predict_real, predict_real_blanked = ( + create_separate_channel_discriminator_by_blanking_out_channels( + inputs, targets, a + ) + ) + else: + with tf.variable_scope("discriminator"): + if a.use_separate_discriminators: + get_logger().info('using separate discriminators: %s', n_targets_channels) + predict_real = create_separate_discriminators( + inputs, targets, a, out_channels=n_targets_channels + ) + else: + # 2x [batch, height, width, channels] => [batch, 30, 30, 1] + predict_real = create_discriminator(inputs, targets, a) + + with tf.name_scope("fake_discriminator"): + with tf.variable_scope("discriminator", reuse=True): + # 2x [batch, height, width, channels] => [batch, 30, 30, discrim_out_channels] + # We don't need to split the channels, + # the discriminator should detect them all as fake + if a.use_separate_discriminators: + predict_fake = create_separate_discriminators( + inputs, outputs, a, out_channels=n_targets_channels + ) + else: + predict_fake = create_discriminator( + inputs, outputs, a, + out_channels=discrim_out_channels + ) + + with tf.name_scope("discriminator_loss"): + # minimizing -tf.log will try to get inputs to 1 + # predict_real => 1 + # predict_fake => 0 + discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + + tf.log(1 - predict_fake + EPS))) + if discrim_out_channels > 1: + discrim_loss += 
tf.reduce_mean( + -tf.log(1 - tf.reshape(predict_real_blanked, [-1]) + EPS) + ) + + with tf.name_scope("discriminator_train"): + discrim_tvars = [ + var for var in tf.trainable_variables() if var.name.startswith("discriminator") + ] + discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1) + discrim_grads_and_vars = discrim_optim.compute_gradients( + discrim_loss, var_list=discrim_tvars) + discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars) + else: + with tf.name_scope("gan_disabled"): + discrim_loss = tf.constant(0.0) + predict_real = None + predict_fake = None + discrim_grads_and_vars = [] + + with tf.name_scope("generator_loss"): + get_logger().info('using loss: %s', a.base_loss) + if a.base_loss == BaseLoss.L1: + gen_base_loss = l1_loss(labels=targets, outputs=outputs) + elif a.base_loss == BaseLoss.CROSS_ENTROPY: + gen_base_loss = cross_entropy_loss( + logits=output_logits, + labels=targets + ) + elif a.base_loss in ALL_WCE_BASE_LOSS: + if pos_weight is None: + raise ValueError('pos_weight missing') + pos_weight = tf.convert_to_tensor(pos_weight) + gen_base_loss = weighted_cross_entropy_loss( + logits=output_logits, + labels=targets, + pos_weight=pos_weight ) - else: - # 2x [batch, height, width, channels] => [batch, 30, 30, 1] - predict_real = create_discriminator(inputs, targets, a) - - with tf.name_scope("fake_discriminator"): - with tf.variable_scope("discriminator", reuse=True): - # 2x [batch, height, width, channels] => [batch, 30, 30, discrim_out_channels] - # We don't need to split the channels, the discriminator should detect them all as fake - if a.use_separate_discriminators: - predict_fake = create_separate_discriminators( - inputs, outputs, a, out_channels=n_targets_channels - ) else: - predict_fake = create_discriminator( - inputs, outputs, a, - out_channels=discrim_out_channels - ) - - with tf.name_scope("discriminator_loss"): - # minimizing -tf.log will try to get inputs to 1 - # predict_real => 1 - # predict_fake => 0 - 
discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS))) - if discrim_out_channels > 1: - discrim_loss += tf.reduce_mean(-tf.log(1 - tf.reshape(predict_real_blanked, [-1]) + EPS)) - - with tf.name_scope("discriminator_train"): - discrim_tvars = [ - var for var in tf.trainable_variables() if var.name.startswith("discriminator") - ] - discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1) - discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars) - discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars) - else: - with tf.name_scope("gan_disabled"): - discrim_loss = tf.constant(0.0) - predict_real = None - predict_fake = None - discrim_grads_and_vars = [] - - with tf.name_scope("generator_loss"): - get_logger().info('using loss: %s', a.base_loss) - if a.base_loss == BaseLoss.L1: - gen_base_loss = l1_loss(labels=targets, outputs=outputs) - elif a.base_loss == BaseLoss.CROSS_ENTROPY: - gen_base_loss = cross_entropy_loss( - logits=output_logits, - labels=targets - ) - elif a.base_loss in ALL_WCE_BASE_LOSS: - if pos_weight is None: - raise ValueError('pos_weight missing') - pos_weight = tf.convert_to_tensor(pos_weight) - gen_base_loss = weighted_cross_entropy_loss( - logits=output_logits, - labels=targets, - pos_weight=pos_weight - ) - else: - raise ValueError('unrecognised base loss: %s' % a.base_loss) - gen_loss = gen_base_loss * a.l1_weight + raise ValueError('unrecognised base loss: %s' % a.base_loss) + gen_loss = gen_base_loss * a.l1_weight - if gan_enabled: - # predict_fake => 1 - gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS)) - gen_loss += gen_loss_GAN * a.gan_weight - else: - gen_loss_GAN = tf.constant(0.0) + if gan_enabled: + # predict_fake => 1 + gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS)) + gen_loss += gen_loss_GAN * a.gan_weight + else: + gen_loss_GAN = tf.constant(0.0) + + with tf.name_scope("generator_train"): + generator_train_dependencies = ( 
+ [discrim_train] if gan_enabled + else [] + ) + with tf.control_dependencies(generator_train_dependencies): + gen_tvars = [var for var in tf.trainable_variables( + ) if var.name.startswith("generator")] + gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1) + gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars) + gen_train = gen_optim.apply_gradients(gen_grads_and_vars) + + ema = tf.train.ExponentialMovingAverage(decay=0.99) + update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_base_loss]) + + global_step = tf.contrib.framework.get_or_create_global_step() + incr_global_step = tf.assign(global_step, global_step + 1) - with tf.name_scope("generator_train"): - generator_train_dependencies = ( - [discrim_train] if gan_enabled - else [] + # TODO change gen_loss_L1 name + return Pix2PixModel( + inputs=inputs, + targets=targets, + predict_real=predict_real, + predict_fake=predict_fake, + discrim_loss=ema.average(discrim_loss), + discrim_grads_and_vars=discrim_grads_and_vars, + gen_loss_GAN=ema.average(gen_loss_GAN), + gen_loss_L1=ema.average(gen_base_loss), + gen_grads_and_vars=gen_grads_and_vars, + outputs=outputs, + global_step=global_step, + train=tf.group(update_losses, incr_global_step, gen_train), ) - with tf.control_dependencies(generator_train_dependencies): - gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")] - gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1) - gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars) - gen_train = gen_optim.apply_gradients(gen_grads_and_vars) - - ema = tf.train.ExponentialMovingAverage(decay=0.99) - update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_base_loss]) - - global_step = tf.contrib.framework.get_or_create_global_step() - incr_global_step = tf.assign(global_step, global_step+1) - - # TODO change gen_loss_L1 name - return Pix2PixModel( - inputs=inputs, - targets=targets, - predict_real=predict_real, - 
predict_fake=predict_fake, - discrim_loss=ema.average(discrim_loss), - discrim_grads_and_vars=discrim_grads_and_vars, - gen_loss_GAN=ema.average(gen_loss_GAN), - gen_loss_L1=ema.average(gen_base_loss), - gen_grads_and_vars=gen_grads_and_vars, - outputs=outputs, - global_step=global_step, - train=tf.group(update_losses, incr_global_step, gen_train), - ) + def create_image_summaries(model): - def convert(image): - return tf.image.convert_image_dtype( - image, - dtype=tf.uint8, - saturate=True - ) - summaries = {} + def convert(image): + return tf.image.convert_image_dtype( + image, + dtype=tf.uint8, + saturate=True + ) + summaries = {} - # reverse any processing on images so they can be written to disk or displayed to user - with tf.name_scope("convert_inputs"): - converted_inputs = convert(model.inputs) + # reverse any processing on images so they can be written to disk or displayed to user + with tf.name_scope("convert_inputs"): + converted_inputs = convert(model.inputs) - with tf.name_scope("convert_targets"): - converted_targets = convert(model.targets) + with tf.name_scope("convert_targets"): + converted_targets = convert(model.targets) - with tf.name_scope("convert_outputs"): - converted_outputs = convert(model.outputs) + with tf.name_scope("convert_outputs"): + converted_outputs = convert(model.outputs) - with tf.name_scope("inputs_summary"): - tf.summary.image("inputs", converted_inputs) + with tf.name_scope("inputs_summary"): + tf.summary.image("inputs", converted_inputs) - with tf.name_scope("targets_summary"): - tf.summary.image("targets", converted_targets) + with tf.name_scope("targets_summary"): + tf.summary.image("targets", converted_targets) - with tf.name_scope("outputs_summary"): - tf.summary.image("outputs", converted_outputs) - summaries['output_image'] = converted_outputs + with tf.name_scope("outputs_summary"): + tf.summary.image("outputs", converted_outputs) + summaries['output_image'] = converted_outputs - with 
tf.name_scope("predict_real_summary"): - tf.summary.image("predict_real", tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8)) + with tf.name_scope("predict_real_summary"): + tf.summary.image( + "predict_real", tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8) + ) + + with tf.name_scope("predict_fake_summary"): + tf.summary.image( + "predict_fake", tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8) + ) + return summaries - with tf.name_scope("predict_fake_summary"): - tf.summary.image("predict_fake", tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8)) - return summaries def create_other_summaries(model): - tf.summary.scalar("discriminator_loss", model.discrim_loss) - tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN) - tf.summary.scalar("generator_loss_L1", model.gen_loss_L1) + tf.summary.scalar("discriminator_loss", model.discrim_loss) + tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN) + tf.summary.scalar("generator_loss_L1", model.gen_loss_L1) - for var in tf.trainable_variables(): - tf.summary.histogram(var.op.name + "/values", var) + for var in tf.trainable_variables(): + tf.summary.histogram(var.op.name + "/values", var) - for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars: - tf.summary.histogram(var.op.name + "/gradients", grad) + for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars: + tf.summary.histogram(var.op.name + "/gradients", grad) diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core_test.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core_test.py index b514198..b71197a 100644 --- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core_test.py +++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_core_test.py @@ -7,32 +7,32 @@ import numpy as np import pytest from sciencebeam_utils.utils.num import ( - assert_all_close, - assert_all_not_close + assert_all_close, + assert_all_not_close ) from 
sciencebeam_utils.utils.collection import ( - extend_dict + extend_dict ) import sciencebeam_gym.trainer.models.pix2pix.pix2pix_core as pix2pix_core from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core import ( - create_encoder_decoder, - create_pix2pix_model, - BaseLoss + create_encoder_decoder, + create_pix2pix_model, + BaseLoss ) DEFAULT_ARGS = dict( - ngf=2, # use small ngf and ndf to keep the graph smaller - ndf=2, - lr=0.0002, - beta1=0.5, - l1_weight=1.0, - gan_weight=0.0, - base_loss=BaseLoss.L1, - use_separate_discriminator_channels=False, - use_separate_discriminators=False + ngf=2, # use small ngf and ndf to keep the graph smaller + ndf=2, + lr=0.0002, + beta1=0.5, + l1_weight=1.0, + gan_weight=0.0, + base_loss=BaseLoss.L1, + use_separate_discriminator_channels=False, + use_separate_discriminators=False ) BATCH_SIZE = 10 @@ -40,164 +40,180 @@ WIDTH = 256 HEIGHT = 256 CHANNELS = 3 + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def setup_module(): - logging.basicConfig(level='DEBUG') + logging.basicConfig(level='DEBUG') + def create_args(*args, **kwargs): - d = extend_dict(*list(args) + [kwargs]) - return namedtuple('args', d.keys())(**d) + d = extend_dict(*list(args) + [kwargs]) + return namedtuple('args', d.keys())(**d) + def patch_spy_object(o, name): - return patch.object(o, name, wraps=getattr(o, name)) + return patch.object(o, name, wraps=getattr(o, name)) + class TestCreateEncoderDecoder(object): - def test_should_add_dropout_in_training_mode_using_constant(self): - with tf.Graph().as_default(): - encoder_inputs = tf.ones((1, 8, 8, 3)) - encoder_layer_specs = [5, 10] - decoder_layer_specs = [(5, 0.5), (3, 0.0)] - outputs = create_encoder_decoder( - encoder_inputs, - encoder_layer_specs, - decoder_layer_specs, - is_training=True - ) - with tf.Session() as session: - session.run(tf.global_variables_initializer()) - # with dropout, the outputs are expected to be different for every run - 
assert_all_not_close(session.run(outputs), session.run(outputs)) - - def test_should_not_add_dropout_not_in_training_mode_using_constant(self): - with tf.Graph().as_default(): - encoder_inputs = tf.ones((1, 8, 8, 3)) - encoder_layer_specs = [5, 10] - decoder_layer_specs = [(5, 0.5), (3, 0.0)] - outputs = create_encoder_decoder( - encoder_inputs, - encoder_layer_specs, - decoder_layer_specs, - is_training=False - ) - with tf.Session() as session: - session.run(tf.global_variables_initializer()) - # without dropout, the outputs are expected to the same for every run - assert_all_close(session.run(outputs), session.run(outputs)) - - def test_should_add_dropout_in_training_mode_using_placeholder(self): - with tf.Graph().as_default(): - is_training = tf.placeholder(tf.bool) - encoder_inputs = tf.ones((1, 8, 8, 3)) - encoder_layer_specs = [5, 10] - decoder_layer_specs = [(5, 0.5), (3, 0.0)] - outputs = create_encoder_decoder( - encoder_inputs, - encoder_layer_specs, - decoder_layer_specs, - is_training=is_training - ) - with tf.Session() as session: - session.run(tf.global_variables_initializer()) - feed_dict = {is_training: True} - # with dropout, the outputs are expected to be different for every run - assert_all_not_close( - session.run(outputs, feed_dict=feed_dict), - session.run(outputs, feed_dict=feed_dict) - ) - - def test_should_not_add_dropout_not_in_training_mode_using_placeholder(self): - with tf.Graph().as_default(): - is_training = tf.placeholder(tf.bool) - encoder_inputs = tf.ones((1, 8, 8, 3)) - encoder_layer_specs = [5, 10] - decoder_layer_specs = [(5, 0.5), (3, 0.0)] - outputs = create_encoder_decoder( - encoder_inputs, - encoder_layer_specs, - decoder_layer_specs, - is_training=is_training - ) - with tf.Session() as session: - session.run(tf.global_variables_initializer()) - feed_dict = {is_training: False} - # without dropout, the outputs are expected to the same for every run - assert_all_close( - session.run(outputs, feed_dict=feed_dict), - 
session.run(outputs, feed_dict=feed_dict) - ) - - def test_should_allow_undefined_batch_size(self): - with tf.Graph().as_default(): - input_shape = [None, 8, 8, 3] - encoder_inputs = tf.placeholder(tf.float32, input_shape) - encoder_layer_specs = [5, 10] - decoder_layer_specs = [(5, 0.5), (3, 0.0)] - outputs = create_encoder_decoder( - encoder_inputs, - encoder_layer_specs, - decoder_layer_specs, - is_training=False - ) - assert outputs.get_shape().as_list() == input_shape - with tf.Session() as session: - session.run(tf.global_variables_initializer()) - feed_dict = {encoder_inputs: np.ones([1] + input_shape[1:])} - assert_all_close( - session.run(outputs, feed_dict=feed_dict), - session.run(outputs, feed_dict=feed_dict) - ) + def test_should_add_dropout_in_training_mode_using_constant(self): + with tf.Graph().as_default(): + encoder_inputs = tf.ones((1, 8, 8, 3)) + encoder_layer_specs = [5, 10] + decoder_layer_specs = [(5, 0.5), (3, 0.0)] + outputs = create_encoder_decoder( + encoder_inputs, + encoder_layer_specs, + decoder_layer_specs, + is_training=True + ) + with tf.Session() as session: + session.run(tf.global_variables_initializer()) + # with dropout, the outputs are expected to be different for every run + assert_all_not_close(session.run(outputs), session.run(outputs)) + + def test_should_not_add_dropout_not_in_training_mode_using_constant(self): + with tf.Graph().as_default(): + encoder_inputs = tf.ones((1, 8, 8, 3)) + encoder_layer_specs = [5, 10] + decoder_layer_specs = [(5, 0.5), (3, 0.0)] + outputs = create_encoder_decoder( + encoder_inputs, + encoder_layer_specs, + decoder_layer_specs, + is_training=False + ) + with tf.Session() as session: + session.run(tf.global_variables_initializer()) + # without dropout, the outputs are expected to the same for every run + assert_all_close(session.run(outputs), session.run(outputs)) + + def test_should_add_dropout_in_training_mode_using_placeholder(self): + with tf.Graph().as_default(): + is_training = 
tf.placeholder(tf.bool) + encoder_inputs = tf.ones((1, 8, 8, 3)) + encoder_layer_specs = [5, 10] + decoder_layer_specs = [(5, 0.5), (3, 0.0)] + outputs = create_encoder_decoder( + encoder_inputs, + encoder_layer_specs, + decoder_layer_specs, + is_training=is_training + ) + with tf.Session() as session: + session.run(tf.global_variables_initializer()) + feed_dict = {is_training: True} + # with dropout, the outputs are expected to be different for every run + assert_all_not_close( + session.run(outputs, feed_dict=feed_dict), + session.run(outputs, feed_dict=feed_dict) + ) + + def test_should_not_add_dropout_not_in_training_mode_using_placeholder(self): + with tf.Graph().as_default(): + is_training = tf.placeholder(tf.bool) + encoder_inputs = tf.ones((1, 8, 8, 3)) + encoder_layer_specs = [5, 10] + decoder_layer_specs = [(5, 0.5), (3, 0.0)] + outputs = create_encoder_decoder( + encoder_inputs, + encoder_layer_specs, + decoder_layer_specs, + is_training=is_training + ) + with tf.Session() as session: + session.run(tf.global_variables_initializer()) + feed_dict = {is_training: False} + # without dropout, the outputs are expected to the same for every run + assert_all_close( + session.run(outputs, feed_dict=feed_dict), + session.run(outputs, feed_dict=feed_dict) + ) + + def test_should_allow_undefined_batch_size(self): + with tf.Graph().as_default(): + input_shape = [None, 8, 8, 3] + encoder_inputs = tf.placeholder(tf.float32, input_shape) + encoder_layer_specs = [5, 10] + decoder_layer_specs = [(5, 0.5), (3, 0.0)] + outputs = create_encoder_decoder( + encoder_inputs, + encoder_layer_specs, + decoder_layer_specs, + is_training=False + ) + assert outputs.get_shape().as_list() == input_shape + with tf.Session() as session: + session.run(tf.global_variables_initializer()) + feed_dict = {encoder_inputs: np.ones([1] + input_shape[1:])} + assert_all_close( + session.run(outputs, feed_dict=feed_dict), + session.run(outputs, feed_dict=feed_dict) + ) + @pytest.mark.slow 
@pytest.mark.very_slow class TestCreatePix2pixModel(object): - def test_should_be_able_to_construct_graph_with_defaults_without_gan(self): - with tf.Graph().as_default(): - inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) - targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) - a = create_args(DEFAULT_ARGS, gan_weight=0.0) - create_pix2pix_model(inputs, targets, a, is_training=True) - - def test_should_be_able_to_construct_graph_with_defaults_and_gan(self): - with tf.Graph().as_default(): - inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) - targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) - a = create_args(DEFAULT_ARGS, gan_weight=1.0) - create_pix2pix_model(inputs, targets, a, is_training=True) - - def test_should_be_able_to_construct_graph_with_gan_and_sep_discrim_channels(self): - with patch_spy_object(pix2pix_core, 'l1_loss') as l1_loss: - with tf.Graph().as_default(): - inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) - targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) - a = create_args(DEFAULT_ARGS, gan_weight=1.0, use_separate_discriminator_channels=True) - create_pix2pix_model(inputs, targets, a, is_training=True) - assert l1_loss.called - - def test_should_be_able_to_construct_graph_with_sep_discrim_channels_and_cross_entropy_loss(self): - with patch_spy_object(pix2pix_core, 'cross_entropy_loss') as cross_entropy_loss: - with tf.Graph().as_default(): - inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) - targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) - a = create_args( - DEFAULT_ARGS, - gan_weight=1.0, use_separate_discriminator_channels=True, base_loss=BaseLoss.CROSS_ENTROPY - ) - create_pix2pix_model(inputs, targets, a, is_training=True) - assert 
cross_entropy_loss.called - - def test_should_be_able_to_construct_graph_with_weighted_cross_entropy_loss(self): - with patch_spy_object(pix2pix_core, 'weighted_cross_entropy_loss') \ - as weighted_cross_entropy_loss: - - with tf.Graph().as_default(): - inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) - targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) - a = create_args( - DEFAULT_ARGS, - gan_weight=1.0, use_separate_discriminator_channels=True, - base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY - ) - create_pix2pix_model(inputs, targets, a, is_training=True, pos_weight=[1.0] * CHANNELS) - assert weighted_cross_entropy_loss.called + def test_should_be_able_to_construct_graph_with_defaults_without_gan(self): + with tf.Graph().as_default(): + inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) + targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) + a = create_args(DEFAULT_ARGS, gan_weight=0.0) + create_pix2pix_model(inputs, targets, a, is_training=True) + + def test_should_be_able_to_construct_graph_with_defaults_and_gan(self): + with tf.Graph().as_default(): + inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) + targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) + a = create_args(DEFAULT_ARGS, gan_weight=1.0) + create_pix2pix_model(inputs, targets, a, is_training=True) + + def test_should_be_able_to_construct_graph_with_gan_and_sep_discrim_channels(self): + with patch_spy_object(pix2pix_core, 'l1_loss') as l1_loss: + with tf.Graph().as_default(): + inputs = tf.constant( + np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) + targets = tf.constant( + np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) + a = create_args(DEFAULT_ARGS, gan_weight=1.0, + use_separate_discriminator_channels=True) + create_pix2pix_model(inputs, 
targets, a, is_training=True) + assert l1_loss.called + + def test_should_be_able_to_construct_graph_with_sep_discrim_channels_and_cross_entropy_loss( + self): + with patch_spy_object(pix2pix_core, 'cross_entropy_loss') as cross_entropy_loss: + with tf.Graph().as_default(): + inputs = tf.constant( + np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) + targets = tf.constant( + np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) + a = create_args( + DEFAULT_ARGS, + gan_weight=1.0, use_separate_discriminator_channels=True, + base_loss=BaseLoss.CROSS_ENTROPY + ) + create_pix2pix_model(inputs, targets, a, is_training=True) + assert cross_entropy_loss.called + + def test_should_be_able_to_construct_graph_with_weighted_cross_entropy_loss(self): + with patch_spy_object(pix2pix_core, 'weighted_cross_entropy_loss') \ + as weighted_cross_entropy_loss: + + with tf.Graph().as_default(): + inputs = tf.constant( + np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) + targets = tf.constant( + np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)) + a = create_args( + DEFAULT_ARGS, + gan_weight=1.0, use_separate_discriminator_channels=True, + base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY + ) + create_pix2pix_model(inputs, targets, a, is_training=True, + pos_weight=[1.0] * CHANNELS) + assert weighted_cross_entropy_loss.called diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py index e4e2e5f..805be05 100644 --- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py +++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model.py @@ -12,44 +12,44 @@ from six.moves import reduce import tensorflow as tf -from tensorflow.python.lib.io.file_io import FileIO # pylint: disable=E0611 +from tensorflow.python.lib.io.file_io import FileIO # pylint: disable=E0611 from sciencebeam_gym.trainer.data.examples import ( - get_matching_files, - read_examples + 
get_matching_files, + read_examples ) from sciencebeam_gym.preprocess.color_map import ( - parse_color_map_from_file + parse_color_map_from_file ) from sciencebeam_gym.tools.calculate_class_weights import ( - tf_calculate_efnet_weights_for_frequency_by_label + tf_calculate_efnet_weights_for_frequency_by_label ) from sciencebeam_gym.trainer.models.pix2pix.tf_utils import ( - find_nearest_centroid_indices + find_nearest_centroid_indices ) from sciencebeam_gym.preprocess.preprocessing_utils import ( - parse_page_range + parse_page_range ) from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core import ( - BaseLoss, - ALL_BASE_LOSS, - create_pix2pix_model, - create_other_summaries + BaseLoss, + ALL_BASE_LOSS, + create_pix2pix_model, + create_other_summaries ) from sciencebeam_gym.trainer.models.pix2pix.evaluate import ( - evaluate_separate_channels, - evaluate_predictions, - evaluation_summary + evaluate_separate_channels, + evaluate_predictions, + evaluation_summary ) from sciencebeam_gym.model_utils.channels import ( - calculate_color_masks + calculate_color_masks ) @@ -58,628 +58,669 @@ UNKNOWN_LABEL = 'unknown' DEFAULT_UNKNOWN_CLASS_WEIGHT = 0.1 + class GraphMode(object): - TRAIN = 1 - EVALUATE = 2 - PREDICT = 3 + TRAIN = 1 + EVALUATE = 2 + PREDICT = 3 + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) class GraphReferences(object): - """Holder of base tensors used for training model using common task.""" - - def __init__(self): - self.is_training = None - self.inputs = dict() - self.examples = None - self.train = None - self.global_step = None - self.metric_updates = [] - self.metric_values = [] - self.predictions = [] - self.input_jpeg = None - self.input_uri = None - self.image_tensor = None - self.annotation_uri = None - self.annotation_tensor = None - self.separate_channel_annotation_tensor = None - self.class_labels_tensor = None - self.pred = None - self.probabilities = None - self.summary = None - self.summaries = 
None - self.image_tensors = None - self.targets_class_indices = None - self.outputs_class_indices = None - self.output_layer_labels = None - self.evaluation_result = None - self.pos_weight = None + """Holder of base tensors used for training model using common task.""" + + def __init__(self): + self.is_training = None + self.inputs = dict() + self.examples = None + self.train = None + self.global_step = None + self.metric_updates = [] + self.metric_values = [] + self.predictions = [] + self.input_jpeg = None + self.input_uri = None + self.image_tensor = None + self.annotation_uri = None + self.annotation_tensor = None + self.separate_channel_annotation_tensor = None + self.class_labels_tensor = None + self.pred = None + self.probabilities = None + self.summary = None + self.summaries = None + self.image_tensors = None + self.targets_class_indices = None + self.outputs_class_indices = None + self.output_layer_labels = None + self.evaluation_result = None + self.pos_weight = None + def batch_dimensions_to_colors_list(image_tensor, colors): - batch_images = [] - for i, single_label_color in enumerate(colors): - batch_images.append( - tf.expand_dims( - image_tensor[:, :, :, i], - axis=-1 - ) * ([x / 255.0 for x in single_label_color]) - ) - return batch_images + batch_images = [] + for i, single_label_color in enumerate(colors): + batch_images.append( + tf.expand_dims( + image_tensor[:, :, :, i], + axis=-1 + ) * ([x / 255.0 for x in single_label_color]) + ) + return batch_images + def batch_dimensions_to_most_likely_colors_list(image_tensor, colors): - with tf.variable_scope("batch_dimensions_to_most_likely_colors_list"): - colors_tensor = tf.constant(colors, dtype=tf.uint8, name='colors') - most_likely_class_index = tf.argmax(image_tensor, 3) - return tf.gather(params=colors_tensor, indices=most_likely_class_index) + with tf.variable_scope("batch_dimensions_to_most_likely_colors_list"): + colors_tensor = tf.constant(colors, dtype=tf.uint8, name='colors') + 
most_likely_class_index = tf.argmax(image_tensor, 3) + return tf.gather(params=colors_tensor, indices=most_likely_class_index) + def add_summary_image(tensors, name, image): - tensors.image_tensors[name] = image - tf.summary.image(name, image) + tensors.image_tensors[name] = image + tf.summary.image(name, image) + def convert_image(image_tensor): - return tf.image.convert_image_dtype( - image_tensor, - dtype=tf.uint8, - saturate=True - ) + return tf.image.convert_image_dtype( + image_tensor, + dtype=tf.uint8, + saturate=True + ) + def add_simple_summary_image(tensors, name, image_tensor): - with tf.name_scope(name): - add_summary_image( - tensors, - name, - convert_image(image_tensor) - ) + with tf.name_scope(name): + add_summary_image( + tensors, + name, + convert_image(image_tensor) + ) + def replace_black_with_white_color(image_tensor): - is_black = tf.reduce_all( - tf.equal(image_tensor, (0, 0, 0)), - axis=-1 - ) - is_black = tf.stack([is_black] * 3, axis=-1) - return tf.where( - is_black, - 255 * tf.ones_like(image_tensor), - image_tensor - ) + is_black = tf.reduce_all( + tf.equal(image_tensor, (0, 0, 0)), + axis=-1 + ) + is_black = tf.stack([is_black] * 3, axis=-1) + return tf.where( + is_black, + 255 * tf.ones_like(image_tensor), + image_tensor + ) + def combine_image(batch_images, replace_black_with_white=False): - clipped_batch_images = [ - tf.clip_by_value(batch_image, 0.0, 1.0) - for batch_image in batch_images - ] - combined_image = convert_image( - reduce( - lambda a, b: a + b, - clipped_batch_images + clipped_batch_images = [ + tf.clip_by_value(batch_image, 0.0, 1.0) + for batch_image in batch_images + ] + combined_image = convert_image( + reduce( + lambda a, b: a + b, + clipped_batch_images + ) ) - ) - if replace_black_with_white: - combined_image = replace_black_with_white_color(combined_image) - return combined_image + if replace_black_with_white: + combined_image = replace_black_with_white_color(combined_image) + return combined_image + def 
remove_last(a): - return a[:-1] + return a[:-1] + def add_model_summary_images( - tensors, dimension_colors, dimension_labels, - use_separate_channels=False, - has_unknown_class=False): - - tensors.summaries = {} - add_simple_summary_image( - tensors, 'input', tensors.image_tensor - ) - add_simple_summary_image( - tensors, 'target', tensors.annotation_tensor - ) - if (has_unknown_class or not use_separate_channels) and dimension_labels is not None: - dimension_labels_with_unknown = dimension_labels + [UNKNOWN_LABEL] - dimension_colors_with_unknown = dimension_colors + [(255, 255, 255)] - else: - dimension_labels_with_unknown = dimension_labels - dimension_colors_with_unknown = dimension_colors - if use_separate_channels: - for name, outputs in [ - ('targets', tensors.separate_channel_annotation_tensor), - ('outputs', tensors.pred) - ]: - - batch_images = batch_dimensions_to_colors_list( - outputs, - dimension_colors_with_unknown - ) - batch_images_excluding_unknown = ( - remove_last(batch_images) - if has_unknown_class - else batch_images - ) - for i, (batch_image, dimension_label) in enumerate(zip( - batch_images, dimension_labels_with_unknown)): - - suffix = "_{}_{}".format( - i, dimension_label if dimension_label else 'unknown_label' - ) - add_simple_summary_image( - tensors, name + suffix, batch_image - ) - with tf.name_scope(name + "_combined"): - combined_image = combine_image(batch_images_excluding_unknown) - if name == 'outputs': - tensors.summaries['output_image'] = combined_image - add_summary_image( - tensors, - name + "_combined", - combined_image - ) + tensors, dimension_colors, dimension_labels, + use_separate_channels=False, + has_unknown_class=False): - if name == 'outputs': - with tf.name_scope(name + "_most_likely"): - add_summary_image( - tensors, - name + "_most_likely", - batch_dimensions_to_most_likely_colors_list( - outputs, - dimension_colors_with_unknown) - ) - else: + tensors.summaries = {} add_simple_summary_image( - tensors, - "output", 
- tensors.pred + tensors, 'input', tensors.image_tensor ) - if tensors.outputs_class_indices is not None: - outputs = tensors.pred - with tf.name_scope("outputs_most_likely"): - colors_tensor = tf.constant( - dimension_colors_with_unknown, - dtype=tf.uint8, name='colors' - ) - add_summary_image( - tensors, - "outputs_most_likely", - tf.gather( - params=colors_tensor, - indices=tensors.outputs_class_indices - ) + add_simple_summary_image( + tensors, 'target', tensors.annotation_tensor + ) + if (has_unknown_class or not use_separate_channels) and dimension_labels is not None: + dimension_labels_with_unknown = dimension_labels + [UNKNOWN_LABEL] + dimension_colors_with_unknown = dimension_colors + [(255, 255, 255)] + else: + dimension_labels_with_unknown = dimension_labels + dimension_colors_with_unknown = dimension_colors + if use_separate_channels: + for name, outputs in [ + ('targets', tensors.separate_channel_annotation_tensor), + ('outputs', tensors.pred) + ]: + + batch_images = batch_dimensions_to_colors_list( + outputs, + dimension_colors_with_unknown + ) + batch_images_excluding_unknown = ( + remove_last(batch_images) + if has_unknown_class + else batch_images + ) + for i, (batch_image, dimension_label) in enumerate(zip( + batch_images, dimension_labels_with_unknown)): + + suffix = "_{}_{}".format( + i, dimension_label if dimension_label else 'unknown_label' + ) + add_simple_summary_image( + tensors, name + suffix, batch_image + ) + with tf.name_scope(name + "_combined"): + combined_image = combine_image(batch_images_excluding_unknown) + if name == 'outputs': + tensors.summaries['output_image'] = combined_image + add_summary_image( + tensors, + name + "_combined", + combined_image + ) + + if name == 'outputs': + with tf.name_scope(name + "_most_likely"): + add_summary_image( + tensors, + name + "_most_likely", + batch_dimensions_to_most_likely_colors_list( + outputs, + dimension_colors_with_unknown) + ) + else: + add_simple_summary_image( + tensors, + "output", 
+ tensors.pred ) - tensors.summaries['output_image'] = tensors.image_tensors['output'] + if tensors.outputs_class_indices is not None: + outputs = tensors.pred + with tf.name_scope("outputs_most_likely"): + colors_tensor = tf.constant( + dimension_colors_with_unknown, + dtype=tf.uint8, name='colors' + ) + add_summary_image( + tensors, + "outputs_most_likely", + tf.gather( + params=colors_tensor, + indices=tensors.outputs_class_indices + ) + ) + tensors.summaries['output_image'] = tensors.image_tensors['output'] + def parse_json_file(filename): - with FileIO(filename, 'r') as f: - return json.load(f) + with FileIO(filename, 'r') as f: + return json.load(f) + def class_weights_to_pos_weight( - class_weights, labels, - use_unknown_class, unknown_class_weight=DEFAULT_UNKNOWN_CLASS_WEIGHT): + class_weights, labels, + use_unknown_class, unknown_class_weight=DEFAULT_UNKNOWN_CLASS_WEIGHT): + + pos_weight = [class_weights[k] for k in labels] + return pos_weight + [unknown_class_weight] if use_unknown_class else pos_weight - pos_weight = [class_weights[k] for k in labels] - return pos_weight + [unknown_class_weight] if use_unknown_class else pos_weight def parse_color_map(color_map_filename): - with FileIO(color_map_filename, 'r') as config_f: - return parse_color_map_from_file( - config_f - ) + with FileIO(color_map_filename, 'r') as config_f: + return parse_color_map_from_file( + config_f + ) + def color_map_to_labels(color_map, labels=None): - if labels: - if not all(color_map.has_key(k) for k in labels): - raise ValueError( - 'not all lables found in color map, labels=%s, available keys=%s' % - (labels, color_map.keys()) - ) - return labels - return sorted(color_map.keys()) + if labels: + if not all(k in color_map for k in labels): + raise ValueError( + 'not all lables found in color map, labels=%s, available keys=%s' % + (labels, color_map.keys()) + ) + return labels + return sorted(color_map.keys()) + def color_map_to_colors(color_map, labels): - return [color_map[k] 
for k in labels] + return [color_map[k] for k in labels] + def colors_and_labels_with_unknown_class(colors, labels, use_unknown_class): - if use_unknown_class or not colors: - return ( - colors + [UNKNOWN_COLOR], - labels + [UNKNOWN_LABEL] - ) - else: - return colors, labels + if use_unknown_class or not colors: + return ( + colors + [UNKNOWN_COLOR], + labels + [UNKNOWN_LABEL] + ) + else: + return colors, labels -def remove_none_from_dict(d): - return {k: v for k, v in iteritems(d) if v is not None} -class Model(object): - def __init__(self, args): - self.args = args - self.image_width = 256 - self.image_height = 256 - self.color_map = None - self.pos_weight = None - self.dimension_colors = None - self.dimension_labels = None - self.use_unknown_class = args.use_unknown_class - self.use_separate_channels = args.use_separate_channels and self.args.color_map is not None - logger = get_logger() - logger.info('use_separate_channels: %s', self.use_separate_channels) - if self.args.color_map: - color_map = parse_color_map(args.color_map) - class_weights = ( - parse_json_file(self.args.class_weights) - if ( - self.args.class_weights and - self.args.base_loss in { - BaseLoss.WEIGHTED_CROSS_ENTROPY, - BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY - } - ) - else None - ) - available_labels = color_map_to_labels(color_map) - if class_weights: - # remove labels with zero class weights - available_labels = [k for k in available_labels if class_weights.get(k, 0.0) != 0.0] - self.dimension_labels = args.channels if args.channels else available_labels - self.dimension_colors = color_map_to_colors(color_map, self.dimension_labels) - self.dimension_colors_with_unknown, self.dimension_labels_with_unknown = ( - colors_and_labels_with_unknown_class( - self.dimension_colors, - self.dimension_labels, - self.use_unknown_class - ) - ) - logger.debug("dimension_colors: %s", self.dimension_colors) - logger.debug("dimension_labels: %s", self.dimension_labels) - if class_weights: - 
self.pos_weight = class_weights_to_pos_weight( - class_weights, - self.dimension_labels, - self.use_separate_channels, - class_weights.get(UNKNOWN_LABEL, DEFAULT_UNKNOWN_CLASS_WEIGHT) - ) - logger.info("pos_weight: %s", self.pos_weight) +def remove_none_from_dict(d): + return {k: v for k, v in iteritems(d) if v is not None} - def _build_predict_graph(self): - tensors = GraphReferences() - input_image_tensor = tf.placeholder( - tf.uint8, (None, None, None, 3), - name='inputs_image' - ) - tensors.inputs = dict( - image=input_image_tensor - ) - tensors.image_tensor = tf.image.resize_images( - tf.image.convert_image_dtype(input_image_tensor, tf.float32), - (self.image_height, self.image_width), - method=tf.image.ResizeMethod.BILINEAR - ) +def _create_pos_weights_tensor( + base_loss, + separate_channel_annotation_tensor, + pos_weight_values, + input_uri, + debug): - if self.use_separate_channels: - n_output_channels = len(self.dimension_labels_with_unknown) - else: - n_output_channels = 3 - pix2pix_model = create_pix2pix_model( - tensors.image_tensor, - None, - self.args, - is_training=False, - pos_weight=tensors.pos_weight, - n_output_channels=n_output_channels + frequency_by_label = tf.reduce_sum( + separate_channel_annotation_tensor, + axis=[0, 1], + keep_dims=True, + name='frequency_by_channel' ) - tensors.pred = pix2pix_model.outputs - return tensors - - def build_graph(self, data_paths, batch_size, graph_mode): - if graph_mode == GraphMode.PREDICT: - return self._build_predict_graph() - - logger = get_logger() - logger.debug('batch_size: %s', batch_size) - tensors = GraphReferences() - tensors.is_training = tf.constant(graph_mode == GraphMode.TRAIN) - is_training = ( - graph_mode == GraphMode.TRAIN or - graph_mode == GraphMode.EVALUATE + pos_weight_sample = tf_calculate_efnet_weights_for_frequency_by_label( + frequency_by_label ) - - if not data_paths: - raise ValueError('data_paths required') - get_logger().info('reading examples from %s', data_paths) - 
tensors.examples = read_examples( - get_matching_files(data_paths), - shuffle=(graph_mode == GraphMode.TRAIN), - num_epochs=None if is_training else 2, - page_range=self.args.pages, - channel_colors=( - self.dimension_colors if self.args.filter_annotated - else None - ) + pos_weight = ( + pos_weight_sample * pos_weight_values + if base_loss == BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY + else pos_weight_sample ) - parsed = tensors.examples - - tensors.image_tensors = {} - - tensors.input_uri = tf.squeeze(parsed['input_uri']) - tensors.annotation_uri = tf.squeeze(parsed['annotation_uri']) - raw_input_image = tf.squeeze(parsed['input_image']) - logging.info('raw_input_image: %s', raw_input_image) - raw_annotation_image = tf.squeeze(parsed['annotation_image']) - tensors.image_tensor = tf.image.decode_png(raw_input_image, channels=3) - tensors.annotation_tensor = tf.image.decode_png(raw_annotation_image, channels=3) - - # TODO resize_images and tf.cast did not work on input image - # but did work on annotation image - tensors.image_tensor = tf.image.resize_image_with_crop_or_pad( - tensors.image_tensor, self.image_height, self.image_width + if debug: + pos_weight = tf.Print( + pos_weight, [ + pos_weight, + pos_weight_sample, + frequency_by_label, + input_uri + ], + 'pos weights, sample, frequency, uri: ', + summarize=1000 + ) + get_logger().debug( + 'pos_weight before batch: %s (frequency_by_label: %s)', + pos_weight, frequency_by_label ) + return pos_weight - tensors.image_tensor = tf.image.convert_image_dtype(tensors.image_tensor, tf.float32) - - tensors.annotation_tensor = tf.image.resize_image_with_crop_or_pad( - tensors.annotation_tensor, self.image_height, self.image_width - ) - if self.use_separate_channels: - with tf.variable_scope('channels'): - color_masks = calculate_color_masks( - tensors.annotation_tensor, - self.dimension_colors, - use_unknown_class=self.use_unknown_class - ) - tensors.separate_channel_annotation_tensor = tf.stack(color_masks, 
axis=-1) - if self.args.base_loss == BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY: - with tf.variable_scope('class_weights'): - frequency_by_label = tf.reduce_sum( - tensors.separate_channel_annotation_tensor, - axis=[0, 1], - keep_dims=True, - name='frequency_by_channel' +class Model(object): + def __init__(self, args): + self.args = args + self.image_width = 256 + self.image_height = 256 + self.color_map = None + self.pos_weight = None + self.dimension_colors = None + self.dimension_labels = None + self.use_unknown_class = args.use_unknown_class + self.use_separate_channels = args.use_separate_channels and self.args.color_map is not None + logger = get_logger() + logger.info('use_separate_channels: %s', self.use_separate_channels) + if self.args.color_map: + color_map = parse_color_map(args.color_map) + class_weights = ( + parse_json_file(self.args.class_weights) + if ( + self.args.class_weights and + self.args.base_loss in { + BaseLoss.WEIGHTED_CROSS_ENTROPY, + BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY + } + ) + else None ) - pos_weight_sample = tf_calculate_efnet_weights_for_frequency_by_label( - frequency_by_label + available_labels = color_map_to_labels(color_map) + if class_weights: + # remove labels with zero class weights + available_labels = [k for k in available_labels if class_weights.get(k, 0.0) != 0.0] + self.dimension_labels = args.channels if args.channels else available_labels + self.dimension_colors = color_map_to_colors(color_map, self.dimension_labels) + self.dimension_colors_with_unknown, self.dimension_labels_with_unknown = ( + colors_and_labels_with_unknown_class( + self.dimension_colors, + self.dimension_labels, + self.use_unknown_class + ) ) - tensors.pos_weight = ( - pos_weight_sample * self.pos_weight - if self.args.base_loss == BaseLoss.WEIGHTED_SAMPLE_WEIGHTED_CROSS_ENTROPY - else pos_weight_sample + logger.debug("dimension_colors: %s", self.dimension_colors) + logger.debug("dimension_labels: %s", self.dimension_labels) + if 
class_weights: + self.pos_weight = class_weights_to_pos_weight( + class_weights, + self.dimension_labels, + self.use_separate_channels, + class_weights.get(UNKNOWN_LABEL, DEFAULT_UNKNOWN_CLASS_WEIGHT) + ) + logger.info("pos_weight: %s", self.pos_weight) + + def _build_predict_graph(self): + tensors = GraphReferences() + input_image_tensor = tf.placeholder( + tf.uint8, (None, None, None, 3), + name='inputs_image' + ) + tensors.inputs = dict( + image=input_image_tensor + ) + + tensors.image_tensor = tf.image.resize_images( + tf.image.convert_image_dtype(input_image_tensor, tf.float32), + (self.image_height, self.image_width), + method=tf.image.ResizeMethod.BILINEAR + ) + + if self.use_separate_channels: + n_output_channels = len(self.dimension_labels_with_unknown) + else: + n_output_channels = 3 + pix2pix_model = create_pix2pix_model( + tensors.image_tensor, + None, + self.args, + is_training=False, + pos_weight=tensors.pos_weight, + n_output_channels=n_output_channels + ) + tensors.pred = pix2pix_model.outputs + return tensors + + def build_graph(self, data_paths, batch_size, graph_mode): + if graph_mode == GraphMode.PREDICT: + return self._build_predict_graph() + + logger = get_logger() + logger.debug('batch_size: %s', batch_size) + tensors = GraphReferences() + tensors.is_training = tf.constant(graph_mode == GraphMode.TRAIN) + is_training = ( + graph_mode == GraphMode.TRAIN or + graph_mode == GraphMode.EVALUATE + ) + + if not data_paths: + raise ValueError('data_paths required') + get_logger().info('reading examples from %s', data_paths) + tensors.examples = read_examples( + get_matching_files(data_paths), + shuffle=(graph_mode == GraphMode.TRAIN), + num_epochs=None if is_training else 2, + page_range=self.args.pages, + channel_colors=( + self.dimension_colors if self.args.filter_annotated + else None ) - if self.args.debug: - tensors.pos_weight = tf.Print( - tensors.pos_weight, [ - tensors.pos_weight, - pos_weight_sample, - frequency_by_label, - tensors.input_uri 
- ], - 'pos weights, sample, frequency, uri: ', - summarize=1000 - ) - get_logger().debug( - 'pos_weight before batch: %s (frequency_by_label: %s)', - tensors.pos_weight, frequency_by_label + ) + parsed = tensors.examples + + tensors.image_tensors = {} + + tensors.input_uri = tf.squeeze(parsed['input_uri']) + tensors.annotation_uri = tf.squeeze(parsed['annotation_uri']) + raw_input_image = tf.squeeze(parsed['input_image']) + logging.info('raw_input_image: %s', raw_input_image) + raw_annotation_image = tf.squeeze(parsed['annotation_image']) + tensors.image_tensor = tf.image.decode_png(raw_input_image, channels=3) + tensors.annotation_tensor = tf.image.decode_png(raw_annotation_image, channels=3) + + # TODO resize_images and tf.cast did not work on input image + # but did work on annotation image + tensors.image_tensor = tf.image.resize_image_with_crop_or_pad( + tensors.image_tensor, self.image_height, self.image_width + ) + + tensors.image_tensor = tf.image.convert_image_dtype(tensors.image_tensor, tf.float32) + + tensors.annotation_tensor = tf.image.resize_image_with_crop_or_pad( + tensors.annotation_tensor, self.image_height, self.image_width + ) + + if self.use_separate_channels: + with tf.variable_scope('channels'): + color_masks = calculate_color_masks( + tensors.annotation_tensor, + self.dimension_colors, + use_unknown_class=self.use_unknown_class + ) + tensors.separate_channel_annotation_tensor = tf.stack(color_masks, axis=-1) + if self.args.base_loss == BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY: + with tf.variable_scope('class_weights'): + tensors.pos_weight = _create_pos_weights_tensor( + base_loss=self.args.base_loss, + separate_channel_annotation_tensor=( + tensors.separate_channel_annotation_tensor + ), + pos_weight_values=self.pos_weight, + input_uri=tensors.input_uri, + debug=self.args.debug + ) + else: + tensors.annotation_tensor = tf.image.convert_image_dtype( + tensors.annotation_tensor, tf.float32 ) - else: - tensors.annotation_tensor = 
tf.image.convert_image_dtype( - tensors.annotation_tensor, tf.float32 - ) - tensors.separate_channel_annotation_tensor = tensors.annotation_tensor - - batched_tensors = tf.train.batch( - remove_none_from_dict({ - k: getattr(tensors, k) - for k in { - 'input_uri', - 'annotation_uri', - 'image_tensor', - 'annotation_tensor', - 'separate_channel_annotation_tensor', - 'pos_weight' - } - }), - batch_size=batch_size - ) - for k, v in iteritems(batched_tensors): - setattr(tensors, k, v) - - if tensors.pos_weight is None: - tensors.pos_weight = self.pos_weight - - pix2pix_model = create_pix2pix_model( - tensors.image_tensor, - tensors.separate_channel_annotation_tensor, - self.args, - is_training=tensors.is_training, - pos_weight=tensors.pos_weight - ) + tensors.separate_channel_annotation_tensor = tensors.annotation_tensor + + batched_tensors = tf.train.batch( + remove_none_from_dict({ + k: getattr(tensors, k) + for k in { + 'input_uri', + 'annotation_uri', + 'image_tensor', + 'annotation_tensor', + 'separate_channel_annotation_tensor', + 'pos_weight' + } + }), + batch_size=batch_size + ) + for k, v in iteritems(batched_tensors): + setattr(tensors, k, v) + + if tensors.pos_weight is None: + tensors.pos_weight = self.pos_weight + + pix2pix_model = create_pix2pix_model( + tensors.image_tensor, + tensors.separate_channel_annotation_tensor, + self.args, + is_training=tensors.is_training, + pos_weight=tensors.pos_weight + ) - if self.use_separate_channels: - with tf.name_scope("evaluation"): - tensors.output_layer_labels = tf.constant(self.dimension_labels_with_unknown) - evaluation_result = evaluate_separate_channels( - targets=pix2pix_model.targets, - outputs=pix2pix_model.outputs + if self.use_separate_channels: + with tf.name_scope("evaluation"): + tensors.output_layer_labels = tf.constant(self.dimension_labels_with_unknown) + evaluation_result = evaluate_separate_channels( + targets=pix2pix_model.targets, + outputs=pix2pix_model.outputs + ) + tensors.evaluation_result = 
evaluation_result + evaluation_summary(evaluation_result, self.dimension_labels_with_unknown) + else: + with tf.name_scope('evaluation'): + if self.dimension_colors: + tensors.output_layer_labels = tf.constant(self.dimension_labels) + colors_tensor = tf.constant( + self.dimension_colors_with_unknown, + dtype=tf.float32 + ) / 255.0 + tensors.outputs_class_indices = find_nearest_centroid_indices( + predictions=pix2pix_model.outputs, + centroids=colors_tensor + ) + tensors.targets_class_indices = find_nearest_centroid_indices( + predictions=pix2pix_model.targets, + centroids=colors_tensor + ) + evaluation_result = evaluate_predictions( + labels=tensors.targets_class_indices, + predictions=tensors.outputs_class_indices, + n_classes=len(self.dimension_colors_with_unknown) + ) + tensors.evaluation_result = evaluation_result + evaluation_summary(evaluation_result, self.dimension_labels) + + tensors.global_step = pix2pix_model.global_step + tensors.train = pix2pix_model.train + tensors.class_labels_tensor = tensors.annotation_tensor + tensors.pred = pix2pix_model.outputs + tensors.probabilities = pix2pix_model.outputs + tensors.metric_values = [pix2pix_model.discrim_loss] + + add_model_summary_images( + tensors, + self.dimension_colors, + self.dimension_labels, + use_separate_channels=self.use_separate_channels, + has_unknown_class=self.use_unknown_class ) - tensors.evaluation_result = evaluation_result - evaluation_summary(evaluation_result, self.dimension_labels_with_unknown) - else: - with tf.name_scope('evaluation'): - if self.dimension_colors: - tensors.output_layer_labels = tf.constant(self.dimension_labels) - colors_tensor = tf.constant( - self.dimension_colors_with_unknown, - dtype=tf.float32 - ) / 255.0 - tensors.outputs_class_indices = find_nearest_centroid_indices( - predictions=pix2pix_model.outputs, - centroids=colors_tensor - ) - tensors.targets_class_indices = find_nearest_centroid_indices( - predictions=pix2pix_model.targets, - centroids=colors_tensor - ) - 
evaluation_result = evaluate_predictions( - labels=tensors.targets_class_indices, - predictions=tensors.outputs_class_indices, - n_classes=len(self.dimension_colors_with_unknown) - ) - tensors.evaluation_result = evaluation_result - evaluation_summary(evaluation_result, self.dimension_labels) - - tensors.global_step = pix2pix_model.global_step - tensors.train = pix2pix_model.train - tensors.class_labels_tensor = tensors.annotation_tensor - tensors.pred = pix2pix_model.outputs - tensors.probabilities = pix2pix_model.outputs - tensors.metric_values = [pix2pix_model.discrim_loss] - - add_model_summary_images( - tensors, - self.dimension_colors, - self.dimension_labels, - use_separate_channels=self.use_separate_channels, - has_unknown_class=self.use_unknown_class - ) - # tensors.summaries = create_summaries(pix2pix_model) - create_other_summaries(pix2pix_model) + # tensors.summaries = create_summaries(pix2pix_model) + create_other_summaries(pix2pix_model) + + if ( + self.args.base_loss == BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY and + tensors.pos_weight is not None + ): + with tf.variable_scope('pos_weight_summary'): + tf.summary.text('pos_weight', tf.as_string(tf.reshape( + tensors.pos_weight, [-1, int(tensors.pos_weight.shape[-1])] + ))) - if ( - self.args.base_loss == BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY and - tensors.pos_weight is not None - ): - with tf.variable_scope('pos_weight_summary'): - tf.summary.text('pos_weight', tf.as_string(tf.reshape( - tensors.pos_weight, [-1, int(tensors.pos_weight.shape[-1])] - ))) + tensors.summary = tf.summary.merge_all() + return tensors - tensors.summary = tf.summary.merge_all() - return tensors + def build_train_graph(self, data_paths, batch_size): + return self.build_graph(data_paths, batch_size, GraphMode.TRAIN) - def build_train_graph(self, data_paths, batch_size): - return self.build_graph(data_paths, batch_size, GraphMode.TRAIN) + def build_eval_graph(self, data_paths, batch_size): + return self.build_graph(data_paths, 
batch_size, GraphMode.EVALUATE) - def build_eval_graph(self, data_paths, batch_size): - return self.build_graph(data_paths, batch_size, GraphMode.EVALUATE) + def build_predict_graph(self): + return self.build_graph(None, None, GraphMode.PREDICT) - def build_predict_graph(self): - return self.build_graph(None, None, GraphMode.PREDICT) + def initialize(self, session): + pass - def initialize(self, session): - pass + def format_metric_values(self, metric_values): + """Formats metric values - used for logging purpose.""" - def format_metric_values(self, metric_values): - """Formats metric values - used for logging purpose.""" + # Early in training, metric_values may actually be None. + loss_str = 'N/A' + accuracy_str = 'N/A' + try: + loss_str = '%.3f' % metric_values[0] + accuracy_str = '%.3f' % metric_values[1] + except (TypeError, IndexError): + pass - # Early in training, metric_values may actually be None. - loss_str = 'N/A' - accuracy_str = 'N/A' - try: - loss_str = '%.3f' % metric_values[0] - accuracy_str = '%.3f' % metric_values[1] - except (TypeError, IndexError): - pass + return '%s, %s' % (loss_str, accuracy_str) - return '%s, %s' % (loss_str, accuracy_str) def str_to_bool(s): - return s.lower() in ('yes', 'true', '1') + return s.lower() in ('yes', 'true', '1') + def str_to_list(s): - s = s.strip() - if not s: - return [] - return [x.strip() for x in s.split(',')] + s = s.strip() + if not s: + return [] + return [x.strip() for x in s.split(',')] + def model_args_parser(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--ngf", type=int, default=64, help="number of generator filters in first conv layer" - ) - parser.add_argument( - "--ndf", type=int, default=64, help="number of discriminator filters in first conv layer" - ) - parser.add_argument( - "--lr", type=float, default=0.0002, help="initial learning rate for adam" - ) - parser.add_argument( - "--beta1", type=float, default=0.5, help="momentum term of adam" - ) - parser.add_argument( - 
"--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient" - ) - parser.add_argument( - "--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient" - ) - - parser.add_argument( - '--pages', type=parse_page_range, default=None, - help='only processes the selected pages' - ) - parser.add_argument( - '--color_map', - type=str, - help='The path to the color map configuration.' - ) - parser.add_argument( - '--class_weights', - type=str, - help='The path to the class weights configuration.' - ) - parser.add_argument( - '--channels', - type=str_to_list, - help='The channels to use (subset of color map), otherwise all of the labels will be used' - ) - parser.add_argument( - '--filter_annotated', - type=str_to_bool, - default=False, - help='Only include pages that have annotations for the selected channels' - ' (if color map is provided)' - ) - parser.add_argument( - '--use_unknown_class', - type=str_to_bool, - default=True, - help='Use unknown class channel (if color map is provided)' - ) - parser.add_argument( - '--use_separate_channels', - type=str_to_bool, - default=False, - help='The separate output channels per annotation (if color map is provided)' - ) - parser.add_argument( - '--use_separate_discriminator_channels', - type=str_to_bool, - default=False, - help='The separate discriminator channels per annotation (if color map is provided)' - ) - parser.add_argument( - '--use_separate_discriminators', - type=str_to_bool, - default=False, - help='The separate discriminators per annotation (if color map is provided)' - ) - parser.add_argument( - '--base_loss', - type=str, - default=BaseLoss.L1, - choices=ALL_BASE_LOSS, - help='The base loss function to use' - ) - parser.add_argument( - '--debug', - type=str_to_bool, - default=True, - help='Enable debug mode' - ) - return parser + parser = argparse.ArgumentParser() + parser.add_argument( + "--ngf", type=int, default=64, help="number of generator filters in first 
conv layer" + ) + parser.add_argument( + "--ndf", type=int, default=64, help="number of discriminator filters in first conv layer" + ) + parser.add_argument( + "--lr", type=float, default=0.0002, help="initial learning rate for adam" + ) + parser.add_argument( + "--beta1", type=float, default=0.5, help="momentum term of adam" + ) + parser.add_argument( + "--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient" + ) + parser.add_argument( + "--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient" + ) + + parser.add_argument( + '--pages', type=parse_page_range, default=None, + help='only processes the selected pages' + ) + parser.add_argument( + '--color_map', + type=str, + help='The path to the color map configuration.' + ) + parser.add_argument( + '--class_weights', + type=str, + help='The path to the class weights configuration.' + ) + parser.add_argument( + '--channels', + type=str_to_list, + help='The channels to use (subset of color map), otherwise all of the labels will be used' + ) + parser.add_argument( + '--filter_annotated', + type=str_to_bool, + default=False, + help='Only include pages that have annotations for the selected channels' + ' (if color map is provided)' + ) + parser.add_argument( + '--use_unknown_class', + type=str_to_bool, + default=True, + help='Use unknown class channel (if color map is provided)' + ) + parser.add_argument( + '--use_separate_channels', + type=str_to_bool, + default=False, + help='The separate output channels per annotation (if color map is provided)' + ) + parser.add_argument( + '--use_separate_discriminator_channels', + type=str_to_bool, + default=False, + help='The separate discriminator channels per annotation (if color map is provided)' + ) + parser.add_argument( + '--use_separate_discriminators', + type=str_to_bool, + default=False, + help='The separate discriminators per annotation (if color map is provided)' + ) + parser.add_argument( + '--base_loss', 
+ type=str, + default=BaseLoss.L1, + choices=ALL_BASE_LOSS, + help='The base loss function to use' + ) + parser.add_argument( + '--debug', + type=str_to_bool, + default=True, + help='Enable debug mode' + ) + return parser def create_model(argv=None): - """Factory method that creates model to be used by generic task.py.""" - parser = model_args_parser() - args, task_args = parser.parse_known_args(argv) - return Model(args), task_args + """Factory method that creates model to be used by generic task.py.""" + parser = model_args_parser() + args, task_args = parser.parse_known_args(argv) + return Model(args), task_args diff --git a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model_test.py b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model_test.py index 7377812..3d59547 100644 --- a/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model_test.py +++ b/sciencebeam_gym/trainer/models/pix2pix/pix2pix_model_test.py @@ -7,36 +7,36 @@ from pytest import raises import tensorflow as tf from sciencebeam_utils.utils.collection import ( - extend_dict + extend_dict ) from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core import ( - BaseLoss, - ALL_BASE_LOSS + BaseLoss, + ALL_BASE_LOSS ) from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core_test import ( - DEFAULT_ARGS as CORE_DEFAULT_ARGS + DEFAULT_ARGS as CORE_DEFAULT_ARGS ) from sciencebeam_gym.trainer.data.examples_test import ( - EXAMPLE_PROPS_1 + EXAMPLE_PROPS_1 ) import sciencebeam_gym.trainer.models.pix2pix.pix2pix_model as pix2pix_model from sciencebeam_gym.trainer.models.pix2pix.pix2pix_model import ( - parse_color_map, - color_map_to_labels, - color_map_to_colors, - colors_and_labels_with_unknown_class, - UNKNOWN_COLOR, - UNKNOWN_LABEL, - DEFAULT_UNKNOWN_CLASS_WEIGHT, - Model, - str_to_list, - model_args_parser, - class_weights_to_pos_weight + parse_color_map, + color_map_to_labels, + color_map_to_colors, + colors_and_labels_with_unknown_class, + UNKNOWN_COLOR, + UNKNOWN_LABEL, + DEFAULT_UNKNOWN_CLASS_WEIGHT, + 
Model, + str_to_list, + model_args_parser, + class_weights_to_pos_weight ) COLOR_MAP_FILENAME = 'color_map.conf' @@ -44,384 +44,452 @@ CLASS_WEIGHTS_FILENAME = 'class-weights.json' DATA_PATH = 'some/where/*.tfrecord' BATCH_SIZE = 2 + def some_color(i): - return (i, i, i) + return (i, i, i) + SOME_COLORS = [some_color(1), some_color(2), some_color(3)] SOME_LABELS = ['a', 'b', 'c'] SOME_COLOR_MAP = { - k: v for k, v in zip(SOME_LABELS, SOME_COLORS) + k: v for k, v in zip(SOME_LABELS, SOME_COLORS) } SOME_CLASS_WEIGHTS = { - k: float(1 + i) for i, k in enumerate(SOME_LABELS) + k: float(1 + i) for i, k in enumerate(SOME_LABELS) } -class TestParseColorMap(object): - def test_should_use_fileio_to_load_file_and_pass_to_parser(self): + +@pytest.fixture(name='FileIO_mock') +def _FileIO_mock(): with patch.object(pix2pix_model, 'FileIO') as FileIO: - with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file: + yield FileIO + + +@pytest.fixture(name='parse_color_map_from_file_mock') +def _parse_color_map_from_file_mock(): + with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file: + yield parse_color_map_from_file + + +@pytest.fixture(name='parse_json_file_mock') +def _parse_json_file_mock(): + with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file: + yield parse_json_file + + +@pytest.fixture(name='get_matching_files_mock') +def _get_matching_files_mock(): + with patch.object(pix2pix_model, 'get_matching_files') as get_matching_files: + yield get_matching_files + + +@pytest.fixture(name='read_examples_mock') +def _read_examples_mock(): + with patch.object(pix2pix_model, 'read_examples') as read_examples: + yield read_examples + + +@pytest.fixture(name='default_graph') +def _graph(): + with tf.Graph().as_default() as graph: + yield graph + + +class TestParseColorMap(object): + def test_should_use_fileio_to_load_file_and_pass_to_parser( + self, FileIO_mock, parse_color_map_from_file_mock): + 
parse_color_map(COLOR_MAP_FILENAME) - FileIO.assert_called_with(COLOR_MAP_FILENAME, 'r') - parse_color_map_from_file.assert_called_with(FileIO.return_value.__enter__.return_value) + FileIO_mock.assert_called_with(COLOR_MAP_FILENAME, 'r') + parse_color_map_from_file_mock.assert_called_with( + FileIO_mock.return_value.__enter__.return_value) + class TestColorMapToLabels(object): - def test_should_use_color_maps_keys_by_default(self): - color_map = { - 'a': some_color(1), - 'b': some_color(2), - 'c': some_color(3) - } - assert color_map_to_labels(color_map) == ['a', 'b', 'c'] - - def test_should_return_specified_labels(self): - color_map = { - 'a': some_color(1), - 'b': some_color(2), - 'c': some_color(3) - } - assert color_map_to_labels(color_map, ['b', 'a']) == ['b', 'a'] - - def test_should_raise_error_if_specified_label_not_in_color_map(self): - color_map = { - 'a': some_color(1), - 'c': some_color(3) - } - with raises(ValueError): - color_map_to_labels(color_map, ['a', 'b']) + def test_should_use_color_maps_keys_by_default(self): + color_map = { + 'a': some_color(1), + 'b': some_color(2), + 'c': some_color(3) + } + assert color_map_to_labels(color_map) == ['a', 'b', 'c'] + + def test_should_return_specified_labels(self): + color_map = { + 'a': some_color(1), + 'b': some_color(2), + 'c': some_color(3) + } + assert color_map_to_labels(color_map, ['b', 'a']) == ['b', 'a'] + + def test_should_raise_error_if_specified_label_not_in_color_map(self): + color_map = { + 'a': some_color(1), + 'c': some_color(3) + } + with raises(ValueError): + color_map_to_labels(color_map, ['a', 'b']) + class TestColorMapToColors(object): - def test_should_return_colors_for_labels(self): - color_map = { - 'a': some_color(1), - 'b': some_color(2), - 'c': some_color(3) - } - assert color_map_to_colors(color_map, ['a', 'b']) == [ - some_color(1), - some_color(2) - ] + def test_should_return_colors_for_labels(self): + color_map = { + 'a': some_color(1), + 'b': some_color(2), + 'c': 
some_color(3) + } + assert color_map_to_colors(color_map, ['a', 'b']) == [ + some_color(1), + some_color(2) + ] + class TestColorsAndLabelsWithUnknownClass(object): - def test_should_not_add_unknown_class_if_not_enabled(self): - colors_with_unknown, labels_with_unknown = colors_and_labels_with_unknown_class( - SOME_COLORS, - SOME_LABELS, - use_unknown_class=False - ) - assert colors_with_unknown == SOME_COLORS - assert labels_with_unknown == SOME_LABELS - - def test_should_add_unknown_class_if_enabled(self): - colors_with_unknown, labels_with_unknown = colors_and_labels_with_unknown_class( - SOME_COLORS, - SOME_LABELS, - use_unknown_class=True - ) - assert colors_with_unknown == SOME_COLORS + [UNKNOWN_COLOR] - assert labels_with_unknown == SOME_LABELS + [UNKNOWN_LABEL] - - def test_should_add_unknown_class_if_colors_are_empty(self): - colors_with_unknown, labels_with_unknown = colors_and_labels_with_unknown_class( - [], - [], - use_unknown_class=False - ) - assert colors_with_unknown == [UNKNOWN_COLOR] - assert labels_with_unknown == [UNKNOWN_LABEL] + def test_should_not_add_unknown_class_if_not_enabled(self): + colors_with_unknown, labels_with_unknown = colors_and_labels_with_unknown_class( + SOME_COLORS, + SOME_LABELS, + use_unknown_class=False + ) + assert colors_with_unknown == SOME_COLORS + assert labels_with_unknown == SOME_LABELS + + def test_should_add_unknown_class_if_enabled(self): + colors_with_unknown, labels_with_unknown = colors_and_labels_with_unknown_class( + SOME_COLORS, + SOME_LABELS, + use_unknown_class=True + ) + assert colors_with_unknown == SOME_COLORS + [UNKNOWN_COLOR] + assert labels_with_unknown == SOME_LABELS + [UNKNOWN_LABEL] + + def test_should_add_unknown_class_if_colors_are_empty(self): + colors_with_unknown, labels_with_unknown = colors_and_labels_with_unknown_class( + [], + [], + use_unknown_class=False + ) + assert colors_with_unknown == [UNKNOWN_COLOR] + assert labels_with_unknown == [UNKNOWN_LABEL] + class 
TestClassWeightsToPosWeight(object): - def test_should_extract_selected_weights(self): - assert class_weights_to_pos_weight({ - 'a': 0.1, - 'b': 0.2, - 'c': 0.3 - }, ['a', 'b'], False) == [0.1, 0.2] - - def test_should_add_zero_if_unknown_class_is_true(self): - assert class_weights_to_pos_weight({ - 'a': 0.1, - 'b': 0.2, - 'c': 0.3 - }, ['a', 'b'], True, DEFAULT_UNKNOWN_CLASS_WEIGHT) == ( - [0.1, 0.2, DEFAULT_UNKNOWN_CLASS_WEIGHT] - ) + def test_should_extract_selected_weights(self): + assert class_weights_to_pos_weight({ + 'a': 0.1, + 'b': 0.2, + 'c': 0.3 + }, ['a', 'b'], False) == [0.1, 0.2] + + def test_should_add_zero_if_unknown_class_is_true(self): + assert class_weights_to_pos_weight({ + 'a': 0.1, + 'b': 0.2, + 'c': 0.3 + }, ['a', 'b'], True, DEFAULT_UNKNOWN_CLASS_WEIGHT) == ( + [0.1, 0.2, DEFAULT_UNKNOWN_CLASS_WEIGHT] + ) + DEFAULT_ARGS = extend_dict( - CORE_DEFAULT_ARGS, - dict( - pages=None, - color_map=None, - class_weights=None, - channels=None, - filter_annotated=False, - use_separate_channels=False, - use_unknown_class=False, - debug=False - ) + CORE_DEFAULT_ARGS, + dict( + pages=None, + color_map=None, + class_weights=None, + channels=None, + filter_annotated=False, + use_separate_channels=False, + use_unknown_class=False, + debug=False + ) ) + def create_args(*args, **kwargs): - d = extend_dict(*list(args) + [kwargs]) - return namedtuple('args', d.keys())(**d) + d = extend_dict(*list(args) + [kwargs]) + return namedtuple('args', d.keys())(**d) + +@pytest.mark.usefixtures( + 'parse_json_file_mock' +) class TestModel(object): - def test_parse_separate_channels_with_color_map_without_class_weights(self): - with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file: - parse_color_map_from_file.return_value = { - 'a': some_color(1), - 'b': some_color(2), - 'c': some_color(3) - } - args = create_args( - DEFAULT_ARGS, - color_map=COLOR_MAP_FILENAME, - class_weights=None, - channels=['a', 'b'], - use_separate_channels=True, - 
use_unknown_class=True - ) - model = Model(args) - assert model.dimension_colors == [some_color(1), some_color(2)] - assert model.dimension_labels == ['a', 'b'] - assert model.dimension_colors_with_unknown == [some_color(1), some_color(2), UNKNOWN_COLOR] - assert model.dimension_labels_with_unknown == ['a', 'b', UNKNOWN_LABEL] - assert model.pos_weight is None - - def test_parse_separate_channels_with_color_map_and_class_weights(self): - with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file: - with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file: - parse_color_map_from_file.return_value = { - 'a': some_color(1), - 'b': some_color(2), - 'c': some_color(3) + def test_parse_separate_channels_with_color_map_without_class_weights( + self, + parse_color_map_from_file_mock): + + parse_color_map_from_file_mock.return_value = { + 'a': some_color(1), + 'b': some_color(2), + 'c': some_color(3) } - parse_json_file.return_value = { - 'a': 0.1, - 'b': 0.2, - 'c': 0.3 + args = create_args( + DEFAULT_ARGS, + color_map=COLOR_MAP_FILENAME, + class_weights=None, + channels=['a', 'b'], + use_separate_channels=True, + use_unknown_class=True + ) + model = Model(args) + assert model.dimension_colors == [some_color(1), some_color(2)] + assert model.dimension_labels == ['a', 'b'] + assert model.dimension_colors_with_unknown == [ + some_color(1), some_color(2), UNKNOWN_COLOR] + assert model.dimension_labels_with_unknown == ['a', 'b', UNKNOWN_LABEL] + assert model.pos_weight is None + + def test_parse_separate_channels_with_color_map_and_class_weights( + self, + parse_color_map_from_file_mock, + parse_json_file_mock): + + parse_color_map_from_file_mock.return_value = { + 'a': some_color(1), + 'b': some_color(2), + 'c': some_color(3) + } + parse_json_file_mock.return_value = { + 'a': 0.1, + 'b': 0.2, + 'c': 0.3 } args = create_args( - DEFAULT_ARGS, - base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY, - color_map=COLOR_MAP_FILENAME, - 
class_weights=CLASS_WEIGHTS_FILENAME, - channels=['a', 'b'], - use_separate_channels=True, - use_unknown_class=True + DEFAULT_ARGS, + base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY, + color_map=COLOR_MAP_FILENAME, + class_weights=CLASS_WEIGHTS_FILENAME, + channels=['a', 'b'], + use_separate_channels=True, + use_unknown_class=True ) model = Model(args) assert model.dimension_colors == [some_color(1), some_color(2)] assert model.dimension_labels == ['a', 'b'] assert ( - model.dimension_colors_with_unknown == [some_color(1), some_color(2), UNKNOWN_COLOR] + model.dimension_colors_with_unknown == [ + some_color(1), some_color(2), UNKNOWN_COLOR] ) assert model.dimension_labels_with_unknown == ['a', 'b', UNKNOWN_LABEL] assert model.pos_weight == [0.1, 0.2, DEFAULT_UNKNOWN_CLASS_WEIGHT] - def test_should_only_include_labels_with_non_zero_class_labels_by_default(self): - with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file: - with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file: - parse_color_map_from_file.return_value = { - 'a': some_color(1), - 'b': some_color(2), - 'c': some_color(3) + def test_should_only_include_labels_with_non_zero_class_labels_by_default( + self, + parse_color_map_from_file_mock, + parse_json_file_mock): + + parse_color_map_from_file_mock.return_value = { + 'a': some_color(1), + 'b': some_color(2), + 'c': some_color(3) } - parse_json_file.return_value = { - 'a': 0.1, - 'b': 0.0, - 'c': 0.3 + parse_json_file_mock.return_value = { + 'a': 0.1, + 'b': 0.0, + 'c': 0.3 } args = create_args( - DEFAULT_ARGS, - base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY, - color_map=COLOR_MAP_FILENAME, - class_weights=CLASS_WEIGHTS_FILENAME, - use_separate_channels=True, - use_unknown_class=True + DEFAULT_ARGS, + base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY, + color_map=COLOR_MAP_FILENAME, + class_weights=CLASS_WEIGHTS_FILENAME, + use_separate_channels=True, + use_unknown_class=True ) model = Model(args) assert 
model.dimension_labels == ['a', 'c'] assert model.dimension_colors == [some_color(1), some_color(3)] assert model.pos_weight == [0.1, 0.3, DEFAULT_UNKNOWN_CLASS_WEIGHT] - def test_should_use_unknown_class_weight_from_configuration(self): - with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file: - with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file: - parse_color_map_from_file.return_value = SOME_COLOR_MAP - parse_json_file.return_value = extend_dict(SOME_CLASS_WEIGHTS, { - 'unknown': 0.99 + def test_should_use_unknown_class_weight_from_configuration( + self, + parse_color_map_from_file_mock, + parse_json_file_mock): + + parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP + parse_json_file_mock.return_value = extend_dict(SOME_CLASS_WEIGHTS, { + 'unknown': 0.99 }) args = create_args( - DEFAULT_ARGS, - base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY, - color_map=COLOR_MAP_FILENAME, - class_weights=CLASS_WEIGHTS_FILENAME, - use_separate_channels=True, - use_unknown_class=True + DEFAULT_ARGS, + base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY, + color_map=COLOR_MAP_FILENAME, + class_weights=CLASS_WEIGHTS_FILENAME, + use_separate_channels=True, + use_unknown_class=True ) model = Model(args) assert model.pos_weight[-1] == 0.99 - def test_should_not_load_class_weights_for_cross_entropy(self): - with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file: - with patch.object(pix2pix_model, 'parse_json_file'): - parse_color_map_from_file.return_value = SOME_COLOR_MAP + def test_should_not_load_class_weights_for_cross_entropy( + self, + parse_color_map_from_file_mock): + + parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP args = create_args( - DEFAULT_ARGS, - base_loss=BaseLoss.CROSS_ENTROPY, - color_map=COLOR_MAP_FILENAME, - class_weights=CLASS_WEIGHTS_FILENAME, - use_separate_channels=True, - use_unknown_class=True + DEFAULT_ARGS, + base_loss=BaseLoss.CROSS_ENTROPY, + 
color_map=COLOR_MAP_FILENAME, + class_weights=CLASS_WEIGHTS_FILENAME, + use_separate_channels=True, + use_unknown_class=True ) model = Model(args) assert model.pos_weight is None - def test_should_not_load_class_weights_for_sample_weighted_cross_entropy(self): - with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file: - with patch.object(pix2pix_model, 'parse_json_file'): - parse_color_map_from_file.return_value = SOME_COLOR_MAP + def test_should_not_load_class_weights_for_sample_weighted_cross_entropy( + self, + parse_color_map_from_file_mock): + + parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP args = create_args( - DEFAULT_ARGS, - base_loss=BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY, - color_map=COLOR_MAP_FILENAME, - class_weights=CLASS_WEIGHTS_FILENAME, - use_separate_channels=True, - use_unknown_class=True + DEFAULT_ARGS, + base_loss=BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY, + color_map=COLOR_MAP_FILENAME, + class_weights=CLASS_WEIGHTS_FILENAME, + use_separate_channels=True, + use_unknown_class=True ) model = Model(args) assert model.pos_weight is None + @pytest.mark.slow @pytest.mark.very_slow +@pytest.mark.usefixtures( + 'get_matching_files_mock', + 'default_graph' +) class TestModelBuildGraph(object): - def test_should_build_train_graph_with_defaults(self): - with tf.Graph().as_default(): - with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file: - with patch.object(pix2pix_model, 'get_matching_files'): - with patch.object(pix2pix_model, 'read_examples') as read_examples: - parse_color_map_from_file.return_value = SOME_COLOR_MAP - read_examples.return_value = EXAMPLE_PROPS_1 - args = create_args( - DEFAULT_ARGS - ) - model = Model(args) - model.build_train_graph(DATA_PATH, BATCH_SIZE) - - def test_should_build_train_graph_with_class_weights(self): - with tf.Graph().as_default(): - with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file: - with 
patch.object(pix2pix_model, 'parse_json_file') as parse_json_file: - with patch.object(pix2pix_model, 'get_matching_files'): - with patch.object(pix2pix_model, 'read_examples') as read_examples: - parse_color_map_from_file.return_value = SOME_COLOR_MAP - parse_json_file.return_value = SOME_CLASS_WEIGHTS - read_examples.return_value = EXAMPLE_PROPS_1 - args = create_args( - DEFAULT_ARGS, - base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY, - color_map=COLOR_MAP_FILENAME, - class_weights=CLASS_WEIGHTS_FILENAME, - channels=['a', 'b'], - use_separate_channels=True, - use_unknown_class=True - ) - model = Model(args) - tensors = model.build_train_graph(DATA_PATH, BATCH_SIZE) - assert tensors.pos_weight is not None - - def test_should_build_train_graph_with_sample_class_weights(self): - with tf.Graph().as_default(): - with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file: - with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file: - with patch.object(pix2pix_model, 'get_matching_files'): - with patch.object(pix2pix_model, 'read_examples') as read_examples: - parse_color_map_from_file.return_value = SOME_COLOR_MAP - parse_json_file.return_value = SOME_CLASS_WEIGHTS - read_examples.return_value = EXAMPLE_PROPS_1 - args = create_args( - DEFAULT_ARGS, - base_loss=BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY, - color_map=COLOR_MAP_FILENAME, - channels=SOME_LABELS, - use_separate_channels=True, - use_unknown_class=True - ) - model = Model(args) - tensors = model.build_train_graph(DATA_PATH, BATCH_SIZE) - n_output_channels = len(SOME_LABELS) + 1 - assert ( - tensors.separate_channel_annotation_tensor.shape.as_list() == - [BATCH_SIZE, model.image_height, model.image_width, n_output_channels] - ) - assert tensors.pos_weight.shape.as_list() == [BATCH_SIZE, 1, 1, n_output_channels] - - def test_should_build_predict_graph_with_defaults(self): - with tf.Graph().as_default(): - with patch.object(pix2pix_model, 'parse_color_map_from_file') as 
parse_color_map_from_file: - with patch.object(pix2pix_model, 'get_matching_files'): - with patch.object(pix2pix_model, 'read_examples') as read_examples: - parse_color_map_from_file.return_value = SOME_COLOR_MAP - read_examples.return_value = EXAMPLE_PROPS_1 - args = create_args( - DEFAULT_ARGS - ) - model = Model(args) - tensors = model.build_predict_graph() - n_output_channels = 3 - assert ( - tensors.pred.shape.as_list() == - [None, model.image_height, model.image_width, n_output_channels] - ) - - def test_should_build_predict_graph_with_sample_class_weights(self): - with tf.Graph().as_default(): - with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file: - with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file: - with patch.object(pix2pix_model, 'get_matching_files'): - with patch.object(pix2pix_model, 'read_examples') as read_examples: - parse_color_map_from_file.return_value = SOME_COLOR_MAP - parse_json_file.return_value = SOME_CLASS_WEIGHTS - read_examples.return_value = EXAMPLE_PROPS_1 - args = create_args( - DEFAULT_ARGS, - base_loss=BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY, - color_map=COLOR_MAP_FILENAME, - channels=SOME_LABELS, - use_separate_channels=True, - use_unknown_class=True - ) - model = Model(args) - tensors = model.build_predict_graph() - n_output_channels = len(SOME_LABELS) + 1 - assert ( - tensors.pred.shape.as_list() == - [None, model.image_height, model.image_width, n_output_channels] - ) + def test_should_build_train_graph_with_defaults( + self, + parse_color_map_from_file_mock, + read_examples_mock): + + parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP + read_examples_mock.return_value = EXAMPLE_PROPS_1 + args = create_args( + DEFAULT_ARGS + ) + model = Model(args) + model.build_train_graph(DATA_PATH, BATCH_SIZE) + + def test_should_build_train_graph_with_class_weights( + self, + parse_color_map_from_file_mock, + parse_json_file_mock, + read_examples_mock): + + 
parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP + parse_json_file_mock.return_value = SOME_CLASS_WEIGHTS + read_examples_mock.return_value = EXAMPLE_PROPS_1 + args = create_args( + DEFAULT_ARGS, + base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY, + color_map=COLOR_MAP_FILENAME, + class_weights=CLASS_WEIGHTS_FILENAME, + channels=['a', 'b'], + use_separate_channels=True, + use_unknown_class=True + ) + model = Model(args) + tensors = model.build_train_graph(DATA_PATH, BATCH_SIZE) + assert tensors.pos_weight is not None + + def test_should_build_train_graph_with_sample_class_weights( + self, + parse_color_map_from_file_mock, + parse_json_file_mock, + read_examples_mock): + + parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP + parse_json_file_mock.return_value = SOME_CLASS_WEIGHTS + read_examples_mock.return_value = EXAMPLE_PROPS_1 + args = create_args( + DEFAULT_ARGS, + base_loss=BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY, + color_map=COLOR_MAP_FILENAME, + channels=SOME_LABELS, + use_separate_channels=True, + use_unknown_class=True + ) + model = Model(args) + tensors = model.build_train_graph(DATA_PATH, BATCH_SIZE) + n_output_channels = len(SOME_LABELS) + 1 + assert ( + tensors.separate_channel_annotation_tensor.shape.as_list() == + [BATCH_SIZE, model.image_height, model.image_width, n_output_channels] + ) + assert tensors.pos_weight.shape.as_list( + ) == [BATCH_SIZE, 1, 1, n_output_channels] + + def test_should_build_predict_graph_with_defaults( + self, + parse_color_map_from_file_mock, + read_examples_mock): + parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP + read_examples_mock.return_value = EXAMPLE_PROPS_1 + args = create_args( + DEFAULT_ARGS + ) + model = Model(args) + tensors = model.build_predict_graph() + n_output_channels = 3 + assert ( + tensors.pred.shape.as_list() == + [None, model.image_height, model.image_width, n_output_channels] + ) + + def test_should_build_predict_graph_with_sample_class_weights( + self, + 
parse_color_map_from_file_mock, + parse_json_file_mock, + read_examples_mock): + parse_color_map_from_file_mock.return_value = SOME_COLOR_MAP + parse_json_file_mock.return_value = SOME_CLASS_WEIGHTS + read_examples_mock.return_value = EXAMPLE_PROPS_1 + args = create_args( + DEFAULT_ARGS, + base_loss=BaseLoss.SAMPLE_WEIGHTED_CROSS_ENTROPY, + color_map=COLOR_MAP_FILENAME, + channels=SOME_LABELS, + use_separate_channels=True, + use_unknown_class=True + ) + model = Model(args) + tensors = model.build_predict_graph() + n_output_channels = len(SOME_LABELS) + 1 + assert ( + tensors.pred.shape.as_list() == + [None, model.image_height, model.image_width, n_output_channels] + ) + class TestStrToList(object): - def test_should_parse_empty_string_as_empty_list(self): - assert str_to_list('') == [] + def test_should_parse_empty_string_as_empty_list(self): + assert str_to_list('') == [] + + def test_should_parse_blank_string_as_empty_list(self): + assert str_to_list(' ') == [] - def test_should_parse_blank_string_as_empty_list(self): - assert str_to_list(' ') == [] + def test_should_parse_comma_separated_list(self): + assert str_to_list('a,b,c') == ['a', 'b', 'c'] - def test_should_parse_comma_separated_list(self): - assert str_to_list('a,b,c') == ['a', 'b', 'c'] + def test_should_ignore_white_space_around_values(self): + assert str_to_list(' a , b , c ') == ['a', 'b', 'c'] - def test_should_ignore_white_space_around_values(self): - assert str_to_list(' a , b , c ') == ['a', 'b', 'c'] class TestModelArgsParser(object): - def test_should_parse_channels(self): - args = model_args_parser().parse_args(['--channels', 'a,b,c']) - assert args.channels == ['a', 'b', 'c'] - - def test_should_set_channels_to_none_by_default(self): - args = model_args_parser().parse_args([]) - assert args.channels is None - - def test_should_allow_all_base_loss_options(self): - for base_loss in ALL_BASE_LOSS: - args = model_args_parser().parse_args(['--base_loss', base_loss]) - assert args.base_loss == 
base_loss + def test_should_parse_channels(self): + args = model_args_parser().parse_args(['--channels', 'a,b,c']) + assert args.channels == ['a', 'b', 'c'] + + def test_should_set_channels_to_none_by_default(self): + args = model_args_parser().parse_args([]) + assert args.channels is None + + def test_should_allow_all_base_loss_options(self): + for base_loss in ALL_BASE_LOSS: + args = model_args_parser().parse_args(['--base_loss', base_loss]) + assert args.base_loss == base_loss diff --git a/sciencebeam_gym/trainer/models/pix2pix/tf_utils.py b/sciencebeam_gym/trainer/models/pix2pix/tf_utils.py index 440c100..ba14e6c 100644 --- a/sciencebeam_gym/trainer/models/pix2pix/tf_utils.py +++ b/sciencebeam_gym/trainer/models/pix2pix/tf_utils.py @@ -3,49 +3,54 @@ from __future__ import division import tensorflow as tf + def pairwise_squared_distance_with(predictions, centroids): - centroid_count = centroids.shape[0] - return tf.stack([ - tf.reduce_sum( - tf.square(predictions - centroids[i]), - axis=-1 - ) - for i in range(centroid_count) - ], axis=-1) + centroid_count = centroids.shape[0] + return tf.stack([ + tf.reduce_sum( + tf.square(predictions - centroids[i]), + axis=-1 + ) + for i in range(centroid_count) + ], axis=-1) + def find_nearest_centroid_indices(predictions, centroids): - return tf.argmin( - pairwise_squared_distance_with(predictions, centroids), - axis=-1 - ) + return tf.argmin( + pairwise_squared_distance_with(predictions, centroids), + axis=-1 + ) + def find_nearest_centroid(predictions, centroids): - return tf.gather( - params=centroids, - indices=find_nearest_centroid_indices(predictions, centroids) - ) + return tf.gather( + params=centroids, + indices=find_nearest_centroid_indices(predictions, centroids) + ) + def get_channel_slice(tensor, channel_index): - rank = len(tensor.shape) - return tf.slice( - tensor, - begin=[0] * (rank - 1) + [channel_index], - size=[-1] * (rank - 1) + [1] - ) + rank = len(tensor.shape) + return tf.slice( + tensor, + 
begin=[0] * (rank - 1) + [channel_index], + size=[-1] * (rank - 1) + [1] + ) + def blank_other_channels(tensor, keep_index): - tensor_shape = tensor.shape - n_channels = int(tensor_shape[-1]) - rank = len(tensor_shape) - tensor_slice = get_channel_slice( - tensor, - keep_index - ) - paddings = tf.constant( - [[0, 0]] * (rank - 1) + - [[keep_index, n_channels - keep_index - 1]] - ) - padded = tf.pad( - tensor_slice, paddings, "CONSTANT" - ) - return padded + tensor_shape = tensor.shape + n_channels = int(tensor_shape[-1]) + rank = len(tensor_shape) + tensor_slice = get_channel_slice( + tensor, + keep_index + ) + paddings = tf.constant( + [[0, 0]] * (rank - 1) + + [[keep_index, n_channels - keep_index - 1]] + ) + padded = tf.pad( + tensor_slice, paddings, "CONSTANT" + ) + return padded diff --git a/sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py b/sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py index 22cfc0d..d5e0887 100644 --- a/sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py +++ b/sciencebeam_gym/trainer/models/pix2pix/tf_utils_test.py @@ -2,90 +2,92 @@ from __future__ import absolute_import from __future__ import division import tensorflow as tf -import numpy as np from sciencebeam_utils.utils.num import ( - assert_all_close + assert_all_close ) from sciencebeam_gym.trainer.models.pix2pix.tf_utils import ( - find_nearest_centroid, - blank_other_channels + find_nearest_centroid, + blank_other_channels ) + def test_find_nearest_centroid(): - colors = tf.constant([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]) - outputs = tf.constant([[[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]]]) - nearest_color = find_nearest_centroid(outputs, colors) + colors = tf.constant([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]) + outputs = tf.constant([[[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]]]) + nearest_color = find_nearest_centroid(outputs, colors) - with tf.Session() 
as session: - assert_all_close( - session.run(nearest_color), - [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]]] - ) + with tf.Session() as session: + assert_all_close( + session.run(nearest_color), + [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]]] + ) -def test_find_nearest_centroid1(): - colors = tf.constant([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]) - outputs = tf.constant([ - [ - [[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]], - [[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]] - ], - [ - [[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]], - [[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]] - ] - ]) - nearest_color = find_nearest_centroid(outputs, colors) - with tf.Session() as session: - assert_all_close( - session.run(nearest_color), - [ +def test_find_nearest_centroid1(): + colors = tf.constant([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]) + outputs = tf.constant([ [ - [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]], - [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]] + [[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]], + [[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]] ], [ - [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]], - [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]] + [[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]], + [[0.1, 0.1, 0.1], [0.9, 0.1, 0.1], [0.9, 0.9, 0.9], [0.1, 0.1, 0.9]] ] - ] - ) + ]) + nearest_color = find_nearest_centroid(outputs, colors) + + with tf.Session() as session: + assert_all_close( + session.run(nearest_color), + [ + [ + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]], + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]] + ], + [ + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 
1.0, 1.0], [0.0, 0.0, 1.0]], + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]] + ] + ] + ) + def test_blank_other_channels(): - tensor = tf.constant([ - [ - [5, 5, 5, 5], - [6, 6, 6, 6], - [7, 7, 7, 7], - [8, 8, 8, 8] - ], - [ - [5, 5, 5, 5], - [6, 6, 6, 6], - [7, 7, 7, 7], - [8, 8, 8, 8] - ] - ]) - padded = blank_other_channels( - tensor, 1 - ) - with tf.Session() as session: - assert_all_close( - session.run(padded), - [ + tensor = tf.constant([ [ - [0, 5, 0, 0], - [0, 6, 0, 0], - [0, 7, 0, 0], - [0, 8, 0, 0] + [5, 5, 5, 5], + [6, 6, 6, 6], + [7, 7, 7, 7], + [8, 8, 8, 8] ], [ - [0, 5, 0, 0], - [0, 6, 0, 0], - [0, 7, 0, 0], - [0, 8, 0, 0] + [5, 5, 5, 5], + [6, 6, 6, 6], + [7, 7, 7, 7], + [8, 8, 8, 8] ] - ] + ]) + padded = blank_other_channels( + tensor, 1 ) + with tf.Session() as session: + assert_all_close( + session.run(padded), + [ + [ + [0, 5, 0, 0], + [0, 6, 0, 0], + [0, 7, 0, 0], + [0, 8, 0, 0] + ], + [ + [0, 5, 0, 0], + [0, 6, 0, 0], + [0, 7, 0, 0], + [0, 8, 0, 0] + ] + ] + ) diff --git a/sciencebeam_gym/trainer/predict.py b/sciencebeam_gym/trainer/predict.py index 441a528..a81de41 100644 --- a/sciencebeam_gym/trainer/predict.py +++ b/sciencebeam_gym/trainer/predict.py @@ -8,48 +8,52 @@ import tensorflow as tf from sciencebeam_gym.utils.tf import FileIO from sciencebeam_gym.inference_model import ( - load_inference_model + load_inference_model ) from sciencebeam_gym.trainer.checkpoint import ( - load_last_checkpoint_as_inference_model + load_last_checkpoint_as_inference_model ) + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def predict_using_inference_model( - inference_model, predict_filename, output_image_filename): + inference_model, predict_filename, output_image_filename): - with FileIO(predict_filename, 'rb') as input_f: - image_bytes = input_f.read() - img = Image.open(BytesIO(image_bytes)).convert('RGB') - img_data = np.asarray(img, dtype=np.uint8) - img_data_batch = 
np.reshape(img_data, tuple([1] + list(img_data.shape))) + with FileIO(predict_filename, 'rb') as input_f: + image_bytes = input_f.read() + img = Image.open(BytesIO(image_bytes)).convert('RGB') + img_data = np.asarray(img, dtype=np.uint8) + img_data_batch = np.reshape(img_data, tuple([1] + list(img_data.shape))) - output_img_data_batch = inference_model(img_data_batch) - output_img_data = output_img_data_batch[0] - output_image = Image.fromarray(output_img_data, 'RGB') - out = BytesIO() - output_image.save(out, 'png') - output_image_bytes = out.getvalue() + output_img_data_batch = inference_model(img_data_batch) + output_img_data = output_img_data_batch[0] + output_image = Image.fromarray(output_img_data, 'RGB') + out = BytesIO() + output_image.save(out, 'png') + output_image_bytes = out.getvalue() + + get_logger().info('writing to %s', output_image_filename) + with FileIO(output_image_filename, 'wb') as output_f: + output_f.write(output_image_bytes) - get_logger().info('writing to %s', output_image_filename) - with FileIO(output_image_filename, 'wb') as output_f: - output_f.write(output_image_bytes) def load_saved_model_and_predict(export_dir, predict_filename, output_image_filename): - with tf.Session(graph=tf.Graph()): - predict_using_inference_model( - load_inference_model(export_dir), - predict_filename, - output_image_filename - ) + with tf.Session(graph=tf.Graph()): + predict_using_inference_model( + load_inference_model(export_dir), + predict_filename, + output_image_filename + ) + def load_checkpoint_and_predict(model, checkpoint_path, predict_filename, output_image_filename): - with tf.Session(graph=tf.Graph()): - predict_using_inference_model( - load_last_checkpoint_as_inference_model(model, checkpoint_path), - predict_filename, - output_image_filename - ) + with tf.Session(graph=tf.Graph()): + predict_using_inference_model( + load_last_checkpoint_as_inference_model(model, checkpoint_path), + predict_filename, + output_image_filename + ) diff --git 
a/sciencebeam_gym/trainer/saver.py b/sciencebeam_gym/trainer/saver.py index ef9797e..1ad9b04 100644 --- a/sciencebeam_gym/trainer/saver.py +++ b/sciencebeam_gym/trainer/saver.py @@ -3,17 +3,19 @@ import logging import tensorflow as tf from sciencebeam_gym.inference_model import ( - save_inference_model + save_inference_model ) from sciencebeam_gym.trainer.checkpoint import ( - load_last_checkpoint_as_inference_model + load_last_checkpoint_as_inference_model ) + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def load_checkpoint_and_save_model(model, checkpoint_path, export_dir): - with tf.Session(graph=tf.Graph()): - inference_model = load_last_checkpoint_as_inference_model(model, checkpoint_path) - save_inference_model(export_dir, inference_model) + with tf.Session(graph=tf.Graph()): + inference_model = load_last_checkpoint_as_inference_model(model, checkpoint_path) + save_inference_model(export_dir, inference_model) diff --git a/sciencebeam_gym/trainer/task.py b/sciencebeam_gym/trainer/task.py index 192f483..3adf822 100644 --- a/sciencebeam_gym/trainer/task.py +++ b/sciencebeam_gym/trainer/task.py @@ -21,748 +21,771 @@ from tensorflow.python.client.device_lib import list_local_devices from sciencebeam_gym.trainer.evaluator import Evaluator from sciencebeam_gym.trainer.util import ( - CustomSupervisor, - SimpleStepScheduler, - get_graph_size + CustomSupervisor, + SimpleStepScheduler, + get_graph_size ) from sciencebeam_gym.trainer.predict import ( - load_checkpoint_and_predict, - load_saved_model_and_predict + load_checkpoint_and_predict, + load_saved_model_and_predict ) from sciencebeam_gym.trainer.saver import ( - load_checkpoint_and_save_model + load_checkpoint_and_save_model ) + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + class TrainingProgressLogger(object): - def __init__(self, start_time, start_step, task): - self.start_time = start_time - self.start_step = 
start_step - self.last_log_time = start_time - self.last_global_step = start_step - self.last_local_step = 0 - self.task = task - - def get_last_log_time(self): - return self.last_log_time - - def log(self, now, global_step, local_step): - """Logs training progress.""" - logging.info( - 'Train [%s/%d], step %d (%.3f sec) %.1f ' - 'global steps/s, %.1f local steps/s', - self.task.type, - self.task.index, - global_step, - (now - self.start_time), - (global_step - self.last_global_step) / - (now - self.last_log_time), - (local_step - self.last_local_step) / - (now - self.last_log_time) - ) - self.last_log_time = now - self.last_global_step = global_step - self.last_local_step = local_step + def __init__(self, start_time, start_step, task): + self.start_time = start_time + self.start_step = start_step + self.last_log_time = start_time + self.last_global_step = start_step + self.last_local_step = 0 + self.task = task + + def get_last_log_time(self): + return self.last_log_time + + def log(self, now, global_step, local_step): + """Logs training progress.""" + logging.info( + 'Train [%s/%d], step %d (%.3f sec) %.1f ' + 'global steps/s, %.1f local steps/s', + self.task.type, + self.task.index, + global_step, + (now - self.start_time), + (global_step - self.last_global_step) / + (now - self.last_log_time), + (local_step - self.last_local_step) / + (now - self.last_log_time) + ) + self.last_log_time = now + self.last_global_step = global_step + self.last_local_step = local_step + def get_qualitative_evaluator(args, model, run_async): - if args.qualitative_data_paths: - return Evaluator( - args, - model, - train_dir(args.output_path), - args.qualitative_data_paths, - dataset='qualitative_set', - eval_set_size=args.qualitative_set_size or args.eval_set_size, - qualitative_set_size=args.qualitative_set_size, - run_async=run_async - ) - else: - return None + if args.qualitative_data_paths: + return Evaluator( + args, + model, + train_dir(args.output_path), + 
args.qualitative_data_paths, + dataset='qualitative_set', + eval_set_size=args.qualitative_set_size or args.eval_set_size, + qualitative_set_size=args.qualitative_set_size, + run_async=run_async + ) + else: + return None + def set_random_seed(seed): - get_logger().info('setting random seed to: %s', seed) - random.seed(seed) - tf.set_random_seed(seed) + get_logger().info('setting random seed to: %s', seed) + random.seed(seed) + tf.set_random_seed(seed) -class Trainer(object): - """Performs model training and optionally evaluation.""" - - def __init__(self, args, model, cluster, task): - self.args = args - self.model = model - self.cluster = cluster - self.task = task - self.run_async = None - default_run_async = lambda f, args: f(*args) - run_async = lambda f, args: (self.run_async or default_run_async)(f, args) - self.evaluator = Evaluator( - self.args, self.model, - train_dir(self.args.output_path), - self.args.eval_data_paths, - 'eval_set', - run_async=run_async - ) - self.train_evaluator = Evaluator( - self.args, self.model, - train_dir(self.args.output_path), - self.args.train_data_paths, - 'train_set', - run_async=run_async - ) - self.qualitative_evaluator = get_qualitative_evaluator( - self.args, - self.model, - run_async=run_async - ) - self.min_train_eval_rate = args.min_train_eval_rate - self.global_step = None - self.last_save = 0 - - def run_training(self): - get_logger().info('creating async pool, pool size: %d', self.args.pool_size) - pool = Pool(processes=self.args.pool_size) - self.run_async = lambda f, args: pool.apply_async(f, args) - self._do_run_training() - get_logger().info('Waiting for tasks to complete') - pool.close() - pool.join() - self.run_async = None - def _do_run_training(self): - """Runs a Master.""" - logger = get_logger() +class Trainer(object): + """Performs model training and optionally evaluation.""" + + def __init__(self, args, model, cluster, task): + self.args = args + self.model = model + self.cluster = cluster + self.task = 
task + self.run_async = None + + def default_run_async(f, args): + return f(*args) + + def run_async(f, args): + return (self.run_async or default_run_async)(f, args) + + self.evaluator = Evaluator( + self.args, self.model, + train_dir(self.args.output_path), + self.args.eval_data_paths, + 'eval_set', + run_async=run_async + ) + self.train_evaluator = Evaluator( + self.args, self.model, + train_dir(self.args.output_path), + self.args.train_data_paths, + 'train_set', + run_async=run_async + ) + self.qualitative_evaluator = get_qualitative_evaluator( + self.args, + self.model, + run_async=run_async + ) + self.min_train_eval_rate = args.min_train_eval_rate + self.global_step = None + self.last_save = 0 + + def run_training(self): + get_logger().info('creating async pool, pool size: %d', self.args.pool_size) + pool = Pool(processes=self.args.pool_size) + self.run_async = lambda f, args: pool.apply_async(f, args) + self._do_run_training() + get_logger().info('Waiting for tasks to complete') + pool.close() + pool.join() + self.run_async = None + + def _do_run_training(self): + """Runs a Master.""" + logger = get_logger() + + logger.info('tensorflow version: %s', tf.__version__) + + if self.args.seed is not None: + set_random_seed(self.args.seed) + self.train_evaluator.init() + self.evaluator.init() + if self.qualitative_evaluator: + self.qualitative_evaluator.init() + ensure_output_path(self.args.output_path) + train_path = train_dir(self.args.output_path) + # model_path = model_dir(self.args.output_path) + is_master = self.task.type != 'worker' + log_interval = self.args.log_interval_secs + save_interval = self.args.save_interval_secs + eval_interval = self.args.eval_interval_secs + summary_interval = log_interval + summary_freq = self.args.log_freq + if is_master and self.task.index > 0: + raise StandardError('Only one replica of master expected') + + if self.cluster: + logging.info('Starting %s/%d', self.task.type, self.task.index) + server = 
start_server(self.cluster, self.task) + target = server.target + device_fn = tf.train.replica_device_setter( + ps_device='/job:ps', + worker_device='/job:%s/task:%d' % (self.task.type, self.task.index), + cluster=self.cluster + ) + # We use a device_filter to limit the communication between this job + # and the parameter servers, i.e., there is no need to directly + # communicate with the other workers; attempting to do so can result + # in reliability problems. + device_filters = [ + '/job:ps', '/job:%s/task:%d' % (self.task.type, self.task.index) + ] + config = tf.ConfigProto(device_filters=device_filters) + else: + target = '' + device_fn = '' + config = None + + logger.info('batch_size: %s', self.args.batch_size) + + logger.info( + 'available devices: %s', + ', '.join([ + '%s (%s)' % (x.name, x.device_type) + for x in list_local_devices() + ]) + ) - logger.info('tensorflow version: %s', tf.__version__) - - if self.args.seed is not None: - set_random_seed(self.args.seed) - self.train_evaluator.init() - self.evaluator.init() - if self.qualitative_evaluator: - self.qualitative_evaluator.init() - ensure_output_path(self.args.output_path) - train_path = train_dir(self.args.output_path) - # model_path = model_dir(self.args.output_path) - is_master = self.task.type != 'worker' - log_interval = self.args.log_interval_secs - save_interval = self.args.save_interval_secs - eval_interval = self.args.eval_interval_secs - summary_interval = log_interval - summary_freq = self.args.log_freq - if is_master and self.task.index > 0: - raise StandardError('Only one replica of master expected') - - if self.cluster: - logging.info('Starting %s/%d', self.task.type, self.task.index) - server = start_server(self.cluster, self.task) - target = server.target - device_fn = tf.train.replica_device_setter( - ps_device='/job:ps', - worker_device='/job:%s/task:%d' % (self.task.type, self.task.index), - cluster=self.cluster - ) - # We use a device_filter to limit the communication between this 
job - # and the parameter servers, i.e., there is no need to directly - # communicate with the other workers; attempting to do so can result - # in reliability problems. - device_filters = [ - '/job:ps', '/job:%s/task:%d' % (self.task.type, self.task.index) - ] - config = tf.ConfigProto(device_filters=device_filters) - else: - target = '' - device_fn = '' - config = None - - logger.info('batch_size: %s', self.args.batch_size) - - logger.info( - 'available devices: %s', - ', '.join([ - '%s (%s)' % (x.name, x.device_type) - for x in list_local_devices() - ]) - ) - - with tf.Graph().as_default() as graph: - with tf.device(device_fn): - # Build the training graph. - logger.info('building graph...') - tensors = self.model.build_train_graph( - self.args.train_data_paths, - self.args.batch_size) - - logger.info('done building graph, calculating graph size...') - logger.info('graph_size: %s bytes', '{:,}'.format(get_graph_size())) - - # Create a saver for writing training checkpoints. - saver = tf.train.Saver( - max_to_keep=self.args.save_max_to_keep + with tf.Graph().as_default() as graph: + with tf.device(device_fn): + # Build the training graph. + logger.info('building graph...') + tensors = self.model.build_train_graph( + self.args.train_data_paths, + self.args.batch_size) + + logger.info('done building graph, calculating graph size...') + logger.info('graph_size: %s bytes', '{:,}'.format(get_graph_size())) + + # Create a saver for writing training checkpoints. + saver = tf.train.Saver( + max_to_keep=self.args.save_max_to_keep + ) + + # Create a "supervisor", which oversees the training process. + sv = CustomSupervisor( + model=self.model, + graph=graph, + is_chief=is_master, + logdir=train_path, + saver=saver, + # Write summary_ops by hand. + summary_op=None, + global_step=tensors.global_step ) - # Create a "supervisor", which oversees the training process. 
- sv = CustomSupervisor( - model=self.model, - graph=graph, - is_chief=is_master, - logdir=train_path, - saver=saver, - # Write summary_ops by hand. - summary_op=None, - global_step=tensors.global_step - ) - - save_path = sv.save_path - - should_retry = True - local_step = 0 - - while should_retry: - try: - should_retry = False - with sv.managed_session(target, config=config) as session: - start_time = time.time() - now = start_time - - global_step = session.run(tensors.global_step) - training_progress_logger = TrainingProgressLogger( - start_time, - global_step, - self.task - ) - - log_scheduler = SimpleStepScheduler( - lambda: training_progress_logger.log(now, global_step, local_step), - min_interval=log_interval, - min_freq=self.args.log_freq, - step=global_step, - last_run=start_time - ) - - def do_save(): - logger.info('saving model to %s (%s)', save_path, global_step) - saver.save(session, save_path, tensors.global_step) - - save_scheduler = SimpleStepScheduler( - do_save, - min_interval=save_interval, - min_freq=self.args.save_freq, - step=global_step, - last_run=start_time - ) - - eval_train_scheduler = SimpleStepScheduler( - lambda: self.eval_train(session, tensors, global_step), - min_interval=eval_interval, - min_freq=self.args.eval_freq, - step=global_step, - last_run=start_time - ) - - schedulers = [ - log_scheduler, - save_scheduler, - eval_train_scheduler - ] - - if is_master: - eval_scheduler = SimpleStepScheduler( - lambda: self.eval(global_step=global_step), - min_interval=save_interval, - min_freq=self.args.save_freq, - step=global_step, - last_run=start_time - ) - schedulers = schedulers + [eval_scheduler] - - summary_op = sv.summary_op if tensors.summary is None else tensors.summary - if summary_op is not None: - schedulers.append(SimpleStepScheduler( - lambda: sv.summary_writer.add_summary( - *session.run([summary_op, tensors.global_step]) - ), - min_interval=summary_interval, - min_freq=summary_freq, - step=global_step, - last_run=start_time 
- )) - - # Loop until the supervisor shuts down or args.max_steps have - # completed. - max_steps = self.args.max_steps - while not sv.should_stop() and global_step < max_steps: - logging.info("global_step: %s", global_step) - try: - # Run one step of the model. - global_step = session.run([tensors.global_step, tensors.train])[0] - logging.info("global_step: %s", global_step) - local_step += 1 + save_path = sv.save_path - now = time.time() - for scheduler in schedulers: - scheduler.step(now) + should_retry = True + local_step = 0 + + while should_retry: + try: + should_retry = False + with sv.managed_session(target, config=config) as session: + start_time = time.time() + now = start_time + + global_step = session.run(tensors.global_step) + training_progress_logger = TrainingProgressLogger( + start_time, + global_step, + self.task + ) + + log_scheduler = SimpleStepScheduler( + lambda: training_progress_logger.log(now, global_step, local_step), + min_interval=log_interval, + min_freq=self.args.log_freq, + step=global_step, + last_run=start_time + ) + + def do_save(): + logger.info('saving model to %s (%s)', save_path, global_step) + saver.save(session, save_path, tensors.global_step) + + save_scheduler = SimpleStepScheduler( + do_save, + min_interval=save_interval, + min_freq=self.args.save_freq, + step=global_step, + last_run=start_time + ) + + eval_train_scheduler = SimpleStepScheduler( + lambda: self.eval_train(session, tensors, global_step), + min_interval=eval_interval, + min_freq=self.args.eval_freq, + step=global_step, + last_run=start_time + ) + + schedulers = [ + log_scheduler, + save_scheduler, + eval_train_scheduler + ] + + if is_master: + eval_scheduler = SimpleStepScheduler( + lambda: self.eval(global_step=global_step), + min_interval=save_interval, + min_freq=self.args.save_freq, + step=global_step, + last_run=start_time + ) + schedulers = schedulers + [eval_scheduler] + + summary_op = sv.summary_op if tensors.summary is None else tensors.summary + if 
summary_op is not None: + schedulers.append(SimpleStepScheduler( + lambda: sv.summary_writer.add_summary( + *session.run([summary_op, tensors.global_step]) + ), + min_interval=summary_interval, + min_freq=summary_freq, + step=global_step, + last_run=start_time + )) + + # Loop until the supervisor shuts down or args.max_steps have + # completed. + max_steps = self.args.max_steps + while not sv.should_stop() and global_step < max_steps: + logging.info("global_step: %s", global_step) + try: + # Run one step of the model. + global_step = session.run([tensors.global_step, tensors.train])[0] + logging.info("global_step: %s", global_step) + local_step += 1 + + now = time.time() + for scheduler in schedulers: + scheduler.step(now) + + except tf.errors.AbortedError as e: + should_retry = True + logging.info('AbortedError (%s)', e) + except (KeyboardInterrupt, tf.errors.CancelledError): + logging.info('cancelled') + should_retry = False + + logging.info('finished (is_master: %s)', is_master) + + if is_master: + # Take the final checkpoint and compute the final accuracy. + now = time.time() + for scheduler in schedulers: + scheduler.flush(now) except tf.errors.AbortedError as e: - should_retry = True - logging.info('AbortedError (%s)', e) - except (KeyboardInterrupt, tf.errors.CancelledError): - logging.info('cancelled') - should_retry = False + should_retry = True + logging.info('AbortedError (%s)', e) - logging.info('finished (is_master: %s)', is_master) + # Ask for all the services to stop. + sv.stop() - if is_master: - # Take the final checkpoint and compute the final accuracy. 
- now = time.time() - for scheduler in schedulers: - scheduler.flush(now) + def eval_train(self, session, tensors, global_step): + """Runs evaluation loop.""" + logging.info( + 'Eval, step %d:\n- on train set %s', + global_step, + self.model.format_metric_values( + self.train_evaluator.evaluate_in_session( + session=session, + tensors=tensors + ) + ) + ) - except tf.errors.AbortedError as e: - should_retry = True - logging.info('AbortedError (%s)', e) - - # Ask for all the services to stop. - sv.stop() - - def eval_train(self, session, tensors, global_step): - """Runs evaluation loop.""" - logging.info( - 'Eval, step %d:\n- on train set %s', - global_step, - self.model.format_metric_values( - self.train_evaluator.evaluate_in_session( - session=session, - tensors=tensors + def eval(self, global_step=None): + """Runs evaluation loop.""" + if self.qualitative_evaluator: + logging.info( + 'Quantitive Eval, step %s:\n- on eval set %s', + global_step, + self.model.format_metric_values(self.qualitative_evaluator.evaluate()) + ) + logging.info( + 'Eval, step %s:\n- on eval set %s', + global_step, + self.model.format_metric_values(self.evaluator.evaluate()) ) - ) - ) - def eval(self, global_step=None): - """Runs evaluation loop.""" - if self.qualitative_evaluator: - logging.info( - 'Quantitive Eval, step %s:\n- on eval set %s', - global_step, - self.model.format_metric_values(self.qualitative_evaluator.evaluate()) - ) - logging.info( - 'Eval, step %s:\n- on eval set %s', - global_step, - self.model.format_metric_values(self.evaluator.evaluate()) - ) def copy_data_to_tmp(input_files): - """Copies data to /tmp/ and returns glob matching the files.""" - files = [] - for e in input_files: - for path in e.split(','): - files.extend(file_io.get_matching_files(path)) + """Copies data to /tmp/ and returns glob matching the files.""" + files = [] + for e in input_files: + for path in e.split(','): + files.extend(file_io.get_matching_files(path)) - for path in files: - if not 
path.startswith('gs://'): - return input_files + for path in files: + if not path.startswith('gs://'): + return input_files - tmp_path = os.path.join('/tmp/', str(uuid.uuid4())) - os.makedirs(tmp_path) - subprocess.check_call(['gsutil', '-m', '-q', 'cp', '-r'] + files + [tmp_path]) - return [os.path.join(tmp_path, '*')] + tmp_path = os.path.join('/tmp/', str(uuid.uuid4())) + os.makedirs(tmp_path) + subprocess.check_call(['gsutil', '-m', '-q', 'cp', '-r'] + files + [tmp_path]) + return [os.path.join(tmp_path, '*')] def write_predictions(args, model, cluster, task): - if not cluster or not task or task.type == 'master': - pass # Run locally. - else: - raise ValueError('invalid task_type %s' % (task.type,)) - - if args.seed is not None: - set_random_seed(args.seed) - - logger = get_logger() - logger.info('Starting to write predictions on %s/%d', task.type, task.index) - pool = Pool(processes=args.pool_size) - run_async = lambda f, args: pool.apply_async(f, args) - - qualitative_evaluator = get_qualitative_evaluator( - args, - model, - run_async=run_async - ) - if qualitative_evaluator: - qualitative_evaluator.init() - qualitative_evaluator.write_predictions() - - evaluator = Evaluator( - args, model, train_dir(args.output_path), args.eval_data_paths, - run_async=run_async - ) - evaluator.init() - evaluator.write_predictions() - - logger.info('Waiting for background tasks to finish') - pool.close() - pool.join() - logger.info('Done writing predictions on %s/%d', task.type, task.index) + if not cluster or not task or task.type == 'master': + pass # Run locally. 
+ else: + raise ValueError('invalid task_type %s' % (task.type,)) + + if args.seed is not None: + set_random_seed(args.seed) + + logger = get_logger() + logger.info('Starting to write predictions on %s/%d', task.type, task.index) + pool = Pool(processes=args.pool_size) + + def run_async(f, args): + return pool.apply_async(f, args) + + qualitative_evaluator = get_qualitative_evaluator( + args, + model, + run_async=run_async + ) + if qualitative_evaluator: + qualitative_evaluator.init() + qualitative_evaluator.write_predictions() + + evaluator = Evaluator( + args, model, train_dir(args.output_path), args.eval_data_paths, + run_async=run_async + ) + evaluator.init() + evaluator.write_predictions() + + logger.info('Waiting for background tasks to finish') + pool.close() + pool.join() + logger.info('Done writing predictions on %s/%d', task.type, task.index) + def predict(args, model, cluster, task): - if not cluster or not task or task.type == 'master': - pass # Run locally. - else: - raise ValueError('invalid task_type %s' % (task.type,)) - - if args.seed is not None: - set_random_seed(args.seed) - - predict_filename = args.predict - output_image_filename = args.predict_output - assert output_image_filename - if args.model_export_path: - load_saved_model_and_predict(args.model_export_path, predict_filename, output_image_filename) - else: - checkpoint_path = train_dir(args.output_path) - load_checkpoint_and_predict(model, checkpoint_path, predict_filename, output_image_filename) + if not cluster or not task or task.type == 'master': + pass # Run locally. 
+ else: + raise ValueError('invalid task_type %s' % (task.type,)) + + if args.seed is not None: + set_random_seed(args.seed) + + predict_filename = args.predict + output_image_filename = args.predict_output + assert output_image_filename + if args.model_export_path: + load_saved_model_and_predict(args.model_export_path, + predict_filename, output_image_filename) + else: + checkpoint_path = train_dir(args.output_path) + load_checkpoint_and_predict(model, checkpoint_path, predict_filename, output_image_filename) + def save_model(args, model, cluster, task): - if not cluster or not task or task.type == 'master': - pass # Run locally. - else: - raise ValueError('invalid task_type %s' % (task.type,)) + if not cluster or not task or task.type == 'master': + pass # Run locally. + else: + raise ValueError('invalid task_type %s' % (task.type,)) - if args.seed is not None: - set_random_seed(args.seed) + if args.seed is not None: + set_random_seed(args.seed) + + export_dir = args.save_model + checkpoint_path = train_dir(args.output_path) + load_checkpoint_and_save_model(model, checkpoint_path, export_dir) - export_dir = args.save_model - checkpoint_path = train_dir(args.output_path) - load_checkpoint_and_save_model(model, checkpoint_path, export_dir) def dispatch(args, model, cluster, task): - if not cluster or not task or task.type == 'master': - # Run locally. - Trainer(args, model, cluster, task).run_training() - elif task.type == 'ps': - run_parameter_server(cluster, task) - elif task.type == 'worker': - Trainer(args, model, cluster, task).run_training() - else: - raise ValueError('invalid task_type %s' % (task.type,)) + if not cluster or not task or task.type == 'master': + # Run locally. 
+ Trainer(args, model, cluster, task).run_training() + elif task.type == 'ps': + run_parameter_server(cluster, task) + elif task.type == 'worker': + Trainer(args, model, cluster, task).run_training() + else: + raise ValueError('invalid task_type %s' % (task.type,)) def run_parameter_server(cluster, task): - logging.info('Starting parameter server %d', task.index) - server = start_server(cluster, task) - server.join() + logging.info('Starting parameter server %d', task.index) + server = start_server(cluster, task) + server.join() def start_server(cluster, task): - if not task.type: - raise ValueError('--task_type must be specified.') - if task.index is None: - raise ValueError('--task_index must be specified.') - - # Create and start a server. - return tf.train.Server( - tf.train.ClusterSpec(cluster), - protocol='grpc', - job_name=task.type, - task_index=task.index - ) + if not task.type: + raise ValueError('--task_type must be specified.') + if task.index is None: + raise ValueError('--task_index must be specified.') + + # Create and start a server. + return tf.train.Server( + tf.train.ClusterSpec(cluster), + protocol='grpc', + job_name=task.type, + task_index=task.index + ) def ensure_output_path(output_path): - if not output_path: - raise ValueError('output_path must be specified') + if not output_path: + raise ValueError('output_path must be specified') - # GCS doesn't have real directories. - if output_path.startswith('gs://'): - return + # GCS doesn't have real directories. + if output_path.startswith('gs://'): + return - ensure_dir(output_path) + ensure_dir(output_path) def ensure_dir(path): - try: - os.makedirs(path) - except OSError as e: - # If the directory already existed, ignore the error. - if e.args[0] == 17: - pass - else: - raise + try: + os.makedirs(path) + except OSError as e: + # If the directory already existed, ignore the error. 
+ if e.args[0] == 17: + pass + else: + raise def train_dir(output_path): - return os.path.join(output_path, 'train') + return os.path.join(output_path, 'train') def eval_dir(output_path): - return os.path.join(output_path, 'eval') + return os.path.join(output_path, 'eval') def model_dir(output_path): - return os.path.join(output_path, 'model') + return os.path.join(output_path, 'model') + def run(model, argv): - """Runs the training loop.""" - parser = argparse.ArgumentParser() - parser.add_argument( - '--train_data_paths', - nargs='+', - type=str, - help='The paths to the training data files. ' - 'Can be comma separated list of files or glob pattern.' - ) - parser.add_argument( - '--eval_data_paths', - nargs='+', - type=str, - help='The path to the files used for evaluation. ' - 'Can be comma separated list of files or glob pattern.' - ) - parser.add_argument( - '--qualitative_data_paths', - nargs='+', - type=str, - help='The path to the files used for qualitative evaluation. ' - 'You may choose a different set for the qualitative analysis to keep the results consistent.' - ) - parser.add_argument( - '--output_path', - type=str, - help='The path to which checkpoints and other outputs ' - 'should be saved. This can be either a local or GCS ' - 'path.' - ) - parser.add_argument( - '--max_steps', - type=int, - default=1000 - ) - parser.add_argument( - '--batch_size', - type=int, - default=100, - help='Number of examples to be processed per mini-batch.' - ) - parser.add_argument( - '--eval_set_size', type=int, default=370, - help='Number of examples in the eval set.' - ) - parser.add_argument( - '--qualitative_set_size', - type=int, - help='Number of examples in the qualitative eval set.' - ) - parser.add_argument( - '--eval_batch_size', type=int, help='Number of examples per eval batch.' - ) - parser.add_argument( - '--eval_interval_secs', - type=float, - default=5, - help='Minimal interval between calculating evaluation metrics and saving' - ' evaluation summaries.' 
- ) - parser.add_argument( - '--eval_freq', - type=int, - default=100, - help='Frequancy in steps between calculating evaluation metrics and saving' - ' evaluation summaries.' - ) - parser.add_argument( - '--save_interval_secs', - type=float, - default=300, - help='Minimal interval between saving the model checkpoint' - ) - parser.add_argument( - '--save_freq', - type=int, - default=1000, - help='Frequancy in steps between saving the model checkpoint' - ) - parser.add_argument( - '--save_max_to_keep', - type=int, - default=2, - help='Maximum number of recent checkpoint files to keep' - ) - parser.add_argument( - '--log_interval_secs', - type=float, - default=5, - help='Minimal interval between logging training metrics and saving ' - 'training summaries.' - ) - parser.add_argument( - '--log_freq', - type=int, - default=500, - help='Frequancy in steps between logging training metrics and saving ' - 'training summaries.' - ) - parser.add_argument( - '--write_predictions', - action='store_true', - default=False, - help='If set, model is restored from latest checkpoint ' - 'and predictions are written to a csv file and no training is performed.' - ) - parser.add_argument( - '--predict', - type=str, - default=False, - help='If set, predicts output for a given image using the latest checkpoint.' - ) - parser.add_argument( - '--predict-output', - type=str, - default=False, - help='The output file for the prediction.' - ) - parser.add_argument( - '--model_export_path', - type=str, - default=False, - help='If specified, predict using the "saved model".' - ) - parser.add_argument( - '--save_model', - type=str, - default=False, - help='If specified, export directory for the latest checkpoint' - ' to be saved as a more portable "saved model".' - ) - parser.add_argument( - '--min_train_eval_rate', - type=int, - default=20, - help='Minimal train / eval time ratio on master. ' - 'Default value 20 means that 20x more time is used for training than ' - 'for evaluation. 
If evaluation takes more time the eval_interval_secs ' - 'is increased.' - ) - parser.add_argument( - '--write_to_tmp', - action='store_true', - default=False, - help='If set, all checkpoints and summaries are written to ' - 'local filesystem (/tmp/) and copied to gcs once training is done. ' - 'This can speed up training but if training job fails all the summaries ' - 'and checkpoints are lost.' - ) - parser.add_argument( - '--copy_train_data_to_tmp', - action='store_true', - default=False, - help='If set, training data is copied to local filesystem ' - '(/tmp/). This can speed up training but requires extra space on the ' - 'local filesystem.' - ) - parser.add_argument( - '--copy_eval_data_to_tmp', - action='store_true', - default=False, - help='If set, evaluation data is copied to local filesystem ' - '(/tmp/). This can speed up training but requires extra space on the ' - 'local filesystem.' - ) - parser.add_argument( - '--streaming_eval', - action='store_true', - default=False, - help='If set to True the evaluation is performed in streaming mode. ' - 'During each eval cycle the evaluation data is read and parsed from ' - 'files. This allows for having very large evaluation set. ' - 'If set to False (default) evaluation data is read once and cached in ' - 'memory. This results in faster evaluation cycle but can potentially ' - 'use more memory (in streaming mode large per-file read-ahead buffer is ' - 'used - which may exceed eval data size).' - ) - parser.add_argument( - '--pool_size', - type=int, - default=50, - help='Number of examples in the eval set.' - ) - parser.add_argument( - '--seed', - type=int, - help='The random seed to use' - ) - - args = parser.parse_args(argv) - - env = json.loads(os.environ.get('TF_CONFIG', '{}')) - - # Print the job data as provided by the service. - logging.info('Original job data: %s', env.get('job', {})) - - # First find out if there's a task value on the environment variable. 
- # If there is none or it is empty define a default one. - task_data = env.get('task', None) or {'type': 'master', 'index': 0} - task = type('TaskSpec', (object,), task_data) - trial = task_data.get('trial') - if trial is not None: - args.output_path = os.path.join(args.output_path, trial) - if args.write_to_tmp and args.output_path.startswith('gs://'): - output_path = args.output_path - args.output_path = os.path.join('/tmp/', str(uuid.uuid4())) - os.makedirs(args.output_path) - else: - output_path = None - - if args.copy_train_data_to_tmp: - args.train_data_paths = copy_data_to_tmp(args.train_data_paths) - if args.copy_eval_data_to_tmp: - args.eval_data_paths = copy_data_to_tmp(args.eval_data_paths) - - if not args.eval_batch_size: - # If eval_batch_size not set, use min of batch_size and eval_set_size - args.eval_batch_size = min(args.batch_size, args.eval_set_size) - logging.info("setting eval batch size to %s", args.eval_batch_size) - - cluster_data = env.get('cluster', None) - cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None - if args.write_predictions: - write_predictions(args, model, cluster, task) - elif args.save_model: - save_model(args, model, cluster, task) - elif args.predict: - predict(args, model, cluster, task) - else: - dispatch(args, model, cluster, task) - - if output_path and (not cluster or not task or task.type == 'master'): - subprocess.check_call([ - 'gsutil', '-m', '-q', 'cp', '-r', args.output_path + '/*', output_path - ]) - shutil.rmtree(args.output_path, ignore_errors=True) + """Runs the training loop.""" + parser = argparse.ArgumentParser() + parser.add_argument( + '--train_data_paths', + nargs='+', + type=str, + help='The paths to the training data files. ' + 'Can be comma separated list of files or glob pattern.' + ) + parser.add_argument( + '--eval_data_paths', + nargs='+', + type=str, + help='The path to the files used for evaluation. ' + 'Can be comma separated list of files or glob pattern.' 
+ ) + parser.add_argument( + '--qualitative_data_paths', + nargs='+', + type=str, + help='The path to the files used for qualitative evaluation.' + ' You may choose a different set for the qualitative analysis' + ' to keep the results consistent.' + ) + parser.add_argument( + '--output_path', + type=str, + help='The path to which checkpoints and other outputs ' + 'should be saved. This can be either a local or GCS ' + 'path.' + ) + parser.add_argument( + '--max_steps', + type=int, + default=1000 + ) + parser.add_argument( + '--batch_size', + type=int, + default=100, + help='Number of examples to be processed per mini-batch.' + ) + parser.add_argument( + '--eval_set_size', type=int, default=370, + help='Number of examples in the eval set.' + ) + parser.add_argument( + '--qualitative_set_size', + type=int, + help='Number of examples in the qualitative eval set.' + ) + parser.add_argument( + '--eval_batch_size', type=int, help='Number of examples per eval batch.' + ) + parser.add_argument( + '--eval_interval_secs', + type=float, + default=5, + help='Minimal interval between calculating evaluation metrics and saving' + ' evaluation summaries.' + ) + parser.add_argument( + '--eval_freq', + type=int, + default=100, + help='Frequency in steps between calculating evaluation metrics and saving' + ' evaluation summaries.' + ) + parser.add_argument( + '--save_interval_secs', + type=float, + default=300, + help='Minimal interval between saving the model checkpoint' + ) + parser.add_argument( + '--save_freq', + type=int, + default=1000, + help='Frequency in steps between saving the model checkpoint' + ) + parser.add_argument( + '--save_max_to_keep', + type=int, + default=2, + help='Maximum number of recent checkpoint files to keep' + ) + parser.add_argument( + '--log_interval_secs', + type=float, + default=5, + help='Minimal interval between logging training metrics and saving ' + 'training summaries.'
+ ) + parser.add_argument( + '--log_freq', + type=int, + default=500, + help='Frequency in steps between logging training metrics and saving ' + 'training summaries.' + ) + parser.add_argument( + '--write_predictions', + action='store_true', + default=False, + help='If set, model is restored from latest checkpoint ' + 'and predictions are written to a csv file and no training is performed.' + ) + parser.add_argument( + '--predict', + type=str, + default=False, + help='If set, predicts output for a given image using the latest checkpoint.' + ) + parser.add_argument( + '--predict-output', + type=str, + default=False, + help='The output file for the prediction.' + ) + parser.add_argument( + '--model_export_path', + type=str, + default=False, + help='If specified, predict using the "saved model".' + ) + parser.add_argument( + '--save_model', + type=str, + default=False, + help='If specified, export directory for the latest checkpoint' + ' to be saved as a more portable "saved model".' + ) + parser.add_argument( + '--min_train_eval_rate', + type=int, + default=20, + help='Minimal train / eval time ratio on master. ' + 'Default value 20 means that 20x more time is used for training than ' + 'for evaluation. If evaluation takes more time the eval_interval_secs ' + 'is increased.' + ) + parser.add_argument( + '--write_to_tmp', + action='store_true', + default=False, + help='If set, all checkpoints and summaries are written to ' + 'local filesystem (/tmp/) and copied to gcs once training is done. ' + 'This can speed up training but if training job fails all the summaries ' + 'and checkpoints are lost.' + ) + parser.add_argument( + '--copy_train_data_to_tmp', + action='store_true', + default=False, + help='If set, training data is copied to local filesystem ' + '(/tmp/). This can speed up training but requires extra space on the ' + 'local filesystem.'
+ ) + parser.add_argument( + '--copy_eval_data_to_tmp', + action='store_true', + default=False, + help='If set, evaluation data is copied to local filesystem ' + '(/tmp/). This can speed up training but requires extra space on the ' + 'local filesystem.' + ) + parser.add_argument( + '--streaming_eval', + action='store_true', + default=False, + help='If set to True the evaluation is performed in streaming mode. ' + 'During each eval cycle the evaluation data is read and parsed from ' + 'files. This allows for having very large evaluation set. ' + 'If set to False (default) evaluation data is read once and cached in ' + 'memory. This results in faster evaluation cycle but can potentially ' + 'use more memory (in streaming mode large per-file read-ahead buffer is ' + 'used - which may exceed eval data size).' + ) + parser.add_argument( + '--pool_size', + type=int, + default=50, + help='Number of examples in the eval set.' + ) + parser.add_argument( + '--seed', + type=int, + help='The random seed to use' + ) + + args = parser.parse_args(argv) + + env = json.loads(os.environ.get('TF_CONFIG', '{}')) + + # Print the job data as provided by the service. + logging.info('Original job data: %s', env.get('job', {})) + + # First find out if there's a task value on the environment variable. + # If there is none or it is empty define a default one. 
+ task_data = env.get('task', None) or {'type': 'master', 'index': 0} + task = type('TaskSpec', (object,), task_data) + trial = task_data.get('trial') + if trial is not None: + args.output_path = os.path.join(args.output_path, trial) + if args.write_to_tmp and args.output_path.startswith('gs://'): + output_path = args.output_path + args.output_path = os.path.join('/tmp/', str(uuid.uuid4())) + os.makedirs(args.output_path) + else: + output_path = None + + if args.copy_train_data_to_tmp: + args.train_data_paths = copy_data_to_tmp(args.train_data_paths) + if args.copy_eval_data_to_tmp: + args.eval_data_paths = copy_data_to_tmp(args.eval_data_paths) + + if not args.eval_batch_size: + # If eval_batch_size not set, use min of batch_size and eval_set_size + args.eval_batch_size = min(args.batch_size, args.eval_set_size) + logging.info("setting eval batch size to %s", args.eval_batch_size) + + cluster_data = env.get('cluster', None) + cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None + if args.write_predictions: + write_predictions(args, model, cluster, task) + elif args.save_model: + save_model(args, model, cluster, task) + elif args.predict: + predict(args, model, cluster, task) + else: + dispatch(args, model, cluster, task) + + is_master = not task or task.type == 'master' # pylint: disable=no-member + if output_path and (not cluster or is_master): + subprocess.check_call([ + 'gsutil', '-m', '-q', 'cp', '-r', args.output_path + '/*', output_path + ]) + shutil.rmtree(args.output_path, ignore_errors=True) + def get_model_factory(model_name): - if model_name == 'pix2pix': - import sciencebeam_gym.trainer.models.pix2pix.pix2pix_model as model_factory - return model_factory - raise Exception('unsupported model: {}'.format(model_name)) + if model_name == 'pix2pix': + import sciencebeam_gym.trainer.models.pix2pix.pix2pix_model as model_factory + return model_factory + raise Exception('unsupported model: {}'.format(model_name)) + def main(_): - parser = 
argparse.ArgumentParser() - parser.add_argument( - '--model', - type=str, - required=True, - help='The name of the model' - ) - args, other_args = parser.parse_known_args() + parser = argparse.ArgumentParser() + parser.add_argument( + '--model', + type=str, + required=True, + help='The name of the model' + ) + args, other_args = parser.parse_known_args() + + model_factory = get_model_factory(args.model) - model_factory = get_model_factory(args.model) + model, task_args = model_factory.create_model(other_args) + run(model, task_args) - model, task_args = model_factory.create_model(other_args) - run(model, task_args) if __name__ == '__main__': - logging.basicConfig(level=logging.INFO) - tf.app.run() + logging.basicConfig(level=logging.INFO) + tf.app.run() diff --git a/sciencebeam_gym/trainer/util.py b/sciencebeam_gym/trainer/util.py index 4d2b75c..557b914 100644 --- a/sciencebeam_gym/trainer/util.py +++ b/sciencebeam_gym/trainer/util.py @@ -3,119 +3,128 @@ import logging import numpy as np import tensorflow as tf -from tensorflow.python.framework import ops # pylint: disable=E0611 -from tensorflow.python.client.session import Session # pylint: disable=E0611 -from tensorflow.python.training.saver import get_checkpoint_state # pylint: disable=E0611 +from tensorflow.python.framework import ops # pylint: disable=E0611 +from tensorflow.python.client.session import Session # pylint: disable=E0611 +from tensorflow.python.training.saver import get_checkpoint_state # pylint: disable=E0611 + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + class CustomSessionManager(object): - def __init__(self, session_init_fn, graph=None): - self._session_init_fn = session_init_fn - if graph is None: - graph = ops.get_default_graph() - self._graph = graph - - def prepare_session(self, master, checkpoint_dir=None, saver=None, config=None, **_): - logger = get_logger() - logger.info('prepare_session') - session = Session(master, graph=self._graph, 
config=config) - self._session_init_fn(session) - if saver and checkpoint_dir: - ckpt = get_checkpoint_state(checkpoint_dir) - if ckpt and ckpt.model_checkpoint_path: - logger.info('restoring from %s', ckpt.model_checkpoint_path) - saver.restore(session, ckpt.model_checkpoint_path) - saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths) - else: - logger.info('no valid checkpoint in %s', checkpoint_dir) - return session + def __init__(self, session_init_fn, graph=None): + self._session_init_fn = session_init_fn + if graph is None: + graph = ops.get_default_graph() + self._graph = graph + + def prepare_session(self, master, checkpoint_dir=None, saver=None, config=None, **_): + logger = get_logger() + logger.info('prepare_session') + session = Session(master, graph=self._graph, config=config) + self._session_init_fn(session) + if saver and checkpoint_dir: + ckpt = get_checkpoint_state(checkpoint_dir) + if ckpt and ckpt.model_checkpoint_path: # pylint: disable=no-member + logger.info('restoring from %s', + ckpt.model_checkpoint_path) # pylint: disable=no-member + saver.restore(session, ckpt.model_checkpoint_path) # pylint: disable=no-member + saver.recover_last_checkpoints( + ckpt.all_model_checkpoint_paths) # pylint: disable=no-member + else: + logger.info('no valid checkpoint in %s', checkpoint_dir) + return session + class CustomSupervisor(tf.train.Supervisor): - def __init__(self, model, graph, init_op=None, ready_op=None, save_model_secs=0, **kwargs): - with graph.as_default(): - init_op = tf.global_variables_initializer() - - def custom_init(session): - logging.info('initializing, session: %s', session) - session.run(init_op) - model.initialize(session) - return True - - session_manager = CustomSessionManager( - session_init_fn=custom_init, - graph=graph - ) - super(CustomSupervisor, self).__init__( - session_manager=session_manager, - graph=graph, - init_op=init_op, - ready_op=ready_op, - save_model_secs=save_model_secs, - **kwargs - ) + def 
__init__(self, model, graph, init_op=None, ready_op=None, save_model_secs=0, **kwargs): + with graph.as_default(): + init_op = tf.global_variables_initializer() + + def custom_init(session): + logging.info('initializing, session: %s', session) + session.run(init_op) + model.initialize(session) + return True + + session_manager = CustomSessionManager( + session_init_fn=custom_init, + graph=graph + ) + super(CustomSupervisor, self).__init__( + session_manager=session_manager, + graph=graph, + init_op=init_op, + ready_op=ready_op, + save_model_secs=save_model_secs, + **kwargs + ) + class SimpleStepScheduler(object): - """ - Rather than using threads, with this scheduler the client has full control. - For example it can be triggered any time intentionally or at the end. - """ - def __init__(self, do_fn, min_interval, min_freq=0, step=0, last_run=None): - self.do_fn = do_fn - self.min_interval = min_interval - self.min_freq = min_freq - self.current_step = step - self.last_run = last_run - self.dirty = False - - def run_now(self, now): - self.do_fn() - self.last_run = now - self.dirty = False - - def should_trigger(self, now): - result = ( - ( - (self.min_freq > 0) and - (self.current_step % self.min_freq == 0) - ) or - ( - (self.min_interval > 0) and - (self.last_run is None or (now - self.last_run) >= self.min_interval) - ) - ) - if result: - get_logger().info( - 'should_trigger: current_step:%s, min_freq=%s, now=%s, ' - 'last_run=%s, min_interval=%s, result=%s', - self.current_step, self.min_freq, now, - self.last_run, self.min_interval, result - ) - return result - - def step(self, now): - self.current_step += 1 - if self.should_trigger(now=now): - self.run_now(now=now) - else: - self.dirty = True - - def flush(self, now): - if self.dirty: - self.run_now(now) + """ + Rather than using threads, with this scheduler the client has full control. + For example it can be triggered any time intentionally or at the end. 
+ """ + + def __init__(self, do_fn, min_interval, min_freq=0, step=0, last_run=None): + self.do_fn = do_fn + self.min_interval = min_interval + self.min_freq = min_freq + self.current_step = step + self.last_run = last_run + self.dirty = False + + def run_now(self, now): + self.do_fn() + self.last_run = now + self.dirty = False + + def should_trigger(self, now): + result = ( + ( + (self.min_freq > 0) and + (self.current_step % self.min_freq == 0) + ) or + ( + (self.min_interval > 0) and + (self.last_run is None or (now - self.last_run) >= self.min_interval) + ) + ) + if result: + get_logger().info( + 'should_trigger: current_step:%s, min_freq=%s, now=%s, ' + 'last_run=%s, min_interval=%s, result=%s', + self.current_step, self.min_freq, now, + self.last_run, self.min_interval, result + ) + return result + + def step(self, now): + self.current_step += 1 + if self.should_trigger(now=now): + self.run_now(now=now) + else: + self.dirty = True + + def flush(self, now): + if self.dirty: + self.run_now(now) + def loss(loss_value): - """Calculates aggregated mean loss.""" - total_loss = tf.Variable(0.0, False) - loss_count = tf.Variable(0, False) - total_loss_update = tf.assign_add(total_loss, loss_value) - loss_count_update = tf.assign_add(loss_count, 1) - loss_op = total_loss / tf.cast(loss_count, tf.float32) - return [total_loss_update, loss_count_update], loss_op + """Calculates aggregated mean loss.""" + total_loss = tf.Variable(0.0, False) + loss_count = tf.Variable(0, False) + total_loss_update = tf.assign_add(total_loss, loss_value) + loss_count_update = tf.assign_add(loss_count, 1) + loss_op = total_loss / tf.cast(loss_count, tf.float32) + return [total_loss_update, loss_count_update], loss_op + def get_graph_size(): - return sum([ - int(np.product(v.get_shape().as_list()) * v.dtype.size) - for v in tf.global_variables() - ]) + return sum([ + int(np.product(v.get_shape().as_list()) * v.dtype.size) + for v in tf.global_variables() + ]) diff --git 
a/sciencebeam_gym/utils/bounding_box.py b/sciencebeam_gym/utils/bounding_box.py index 5531c4b..4a4c719 100644 --- a/sciencebeam_gym/utils/bounding_box.py +++ b/sciencebeam_gym/utils/bounding_box.py @@ -1,103 +1,111 @@ class BoundingRange(object): - def __init__(self, start, length): - self.start = start - self.length = length - if length < 0: - raise ValueError('length must not be less than zero, was: ' + str(length)) + def __init__(self, start, length): + self.start = start + self.length = length + if length < 0: + raise ValueError('length must not be less than zero, was: ' + str(length)) - def __str__(self): - return '({}, {})'.format(self.start, self.length) + def __str__(self): + return '({}, {})'.format(self.start, self.length) - def __len__(self): - return self.length + def __len__(self): + return self.length - def empty(self): - return self.length == 0 + def empty(self): + return self.length == 0 - def intersects(self, other): - return (self.start < other.start + other.length) and (other.start < self.start + self.length) + def intersects(self, other): + return ( + (self.start < other.start + other.length) and + (other.start < self.start + self.length) + ) - def include(self, other): - if other.empty(): - return self - if self.empty(): - return other - start = min(self.start, other.start) - length = max(self.start + self.length, other.start + other.length) - start - return BoundingRange(start, length) + def include(self, other): + if other.empty(): + return self + if self.empty(): + return other + start = min(self.start, other.start) + length = max(self.start + self.length, other.start + other.length) - start + return BoundingRange(start, length) + + def __add__(self, other): + return self.include(other) - def __add__(self, other): - return self.include(other) class BoundingBox(object): - def __init__(self, x, y, width, height): - self.x = x - self.y = y - self.width = width - self.height = height - if width < 0: - raise ValueError('width must not be less 
than zero, was: ' + str(width)) - if height < 0: - raise ValueError('height must not be less than zero, was: ' + str(height)) - - def __str__(self): - return '({}, {}, {}, {})'.format(self.x, self.y, self.width, self.height) - - def __repr__(self): - return 'BB({}, {}, {}, {})'.format(self.x, self.y, self.width, self.height) - - def empty(self): - return self.width == 0 or self.height == 0 - - def move_by(self, rx, ry): - return BoundingBox(self.x + rx, self.y + ry, self.width, self.height) - - def scale_by(self, rx, ry): - return BoundingBox(self.x * rx, self.y * ry, self.width * rx, self.height * ry) - - def include(self, other): - if other.empty(): - return self - if self.empty(): - return other - x = min(self.x, other.x) - y = min(self.y, other.y) - w = max(self.x + self.width, other.x + other.width) - x - h = max(self.y + self.height, other.y + other.height) - y - return BoundingBox(x, y, w, h) - - def with_margin(self, dx, dy=None): - if dy is None: - dy = dx - return BoundingBox( - self.x - dx, - self.y - dy, - self.width + 2 * dx, - self.height + 2 * dy - ) - - def intersects(self, other): - return self.x_range().intersects(other.x_range()) and self.y_range().intersects(other.y_range()) - - def __add__(self, bb): - return self.include(bb) - - def x_range(self): - return BoundingRange(self.x, self.width) - - def y_range(self): - return BoundingRange(self.y, self.height) - - def __eq__(self, other): - return ( - other is not None and - self.x == other.x and - self.y == other.y and - self.width == other.width and - self.height == other.height - ) - - def __hash__(self): - return hash((self.x, self.y, self.width, self.height)) + def __init__(self, x, y, width, height): + self.x = x + self.y = y + self.width = width + self.height = height + if width < 0: + raise ValueError('width must not be less than zero, was: ' + str(width)) + if height < 0: + raise ValueError('height must not be less than zero, was: ' + str(height)) + + def __str__(self): + return '({}, {}, 
{}, {})'.format(self.x, self.y, self.width, self.height) + + def __repr__(self): + return 'BB({}, {}, {}, {})'.format(self.x, self.y, self.width, self.height) + + def empty(self): + return self.width == 0 or self.height == 0 + + def move_by(self, rx, ry): + return BoundingBox(self.x + rx, self.y + ry, self.width, self.height) + + def scale_by(self, rx, ry): + return BoundingBox(self.x * rx, self.y * ry, self.width * rx, self.height * ry) + + def include(self, other): + if other.empty(): + return self + if self.empty(): + return other + x = min(self.x, other.x) + y = min(self.y, other.y) + w = max(self.x + self.width, other.x + other.width) - x + h = max(self.y + self.height, other.y + other.height) - y + return BoundingBox(x, y, w, h) + + def with_margin(self, dx, dy=None): + if dy is None: + dy = dx + return BoundingBox( + self.x - dx, + self.y - dy, + self.width + 2 * dx, + self.height + 2 * dy + ) + + def intersects(self, other): + return ( + self.x_range().intersects(other.x_range()) and + self.y_range().intersects(other.y_range()) + ) + + def __add__(self, bb): + return self.include(bb) + + def x_range(self): + return BoundingRange(self.x, self.width) + + def y_range(self): + return BoundingRange(self.y, self.height) + + def __eq__(self, other): + return ( + other is not None and + self.x == other.x and + self.y == other.y and + self.width == other.width and + self.height == other.height + ) + + def __hash__(self): + return hash((self.x, self.y, self.width, self.height)) + BoundingBox.EMPTY = BoundingBox(0, 0, 0, 0) diff --git a/sciencebeam_gym/utils/bounding_box_test.py b/sciencebeam_gym/utils/bounding_box_test.py index d7767ad..457ce53 100644 --- a/sciencebeam_gym/utils/bounding_box_test.py +++ b/sciencebeam_gym/utils/bounding_box_test.py @@ -1,46 +1,44 @@ from sciencebeam_gym.utils.bounding_box import ( - BoundingBox + BoundingBox ) -class TestBoundingBox(object): - def test_should_indicate_empty_with_zero_width(self): - assert BoundingBox(0, 0, 0, 
100).empty() - def test_should_indicate_empty_with_zero_height(self): - assert BoundingBox(0, 0, 100, 0).empty() +class TestBoundingBox(object): + def test_should_indicate_empty_with_zero_width(self): + assert BoundingBox(0, 0, 0, 100).empty() - def test_should_indicate_not_be_empty_with_non_zero_width_and_height(self): - assert not BoundingBox(0, 0, 100, 100).empty() + def test_should_indicate_empty_with_zero_height(self): + assert BoundingBox(0, 0, 100, 0).empty() - def test_should_equal_same_bounding_boxes(self): - assert BoundingBox(11, 12, 101, 102) == BoundingBox(11, 12, 101, 102) + def test_should_indicate_not_be_empty_with_non_zero_width_and_height(self): + assert not BoundingBox(0, 0, 100, 100).empty() - def test_should_not_equal_bounding_boxes_with_different_x(self): - assert BoundingBox(11, 12, 101, 102) != BoundingBox(99, 12, 101, 102) + def test_should_equal_same_bounding_boxes(self): + assert BoundingBox(11, 12, 101, 102) == BoundingBox(11, 12, 101, 102) - def test_should_not_equal_bounding_boxes_with_different_y(self): - assert BoundingBox(11, 12, 101, 102) != BoundingBox(11, 99, 101, 102) + def test_should_not_equal_bounding_boxes_with_different_x(self): + assert BoundingBox(11, 12, 101, 102) != BoundingBox(99, 12, 101, 102) - def test_should_not_equal_bounding_boxes_with_different_width(self): - assert BoundingBox(11, 12, 101, 102) != BoundingBox(11, 12, 999, 102) + def test_should_not_equal_bounding_boxes_with_different_y(self): + assert BoundingBox(11, 12, 101, 102) != BoundingBox(11, 99, 101, 102) - def test_should_not_equal_bounding_boxes_with_different_height(self): - assert BoundingBox(11, 12, 101, 102) != BoundingBox(11, 12, 101, 999) + def test_should_not_equal_bounding_boxes_with_different_width(self): + assert BoundingBox(11, 12, 101, 102) != BoundingBox(11, 12, 999, 102) - def test_should_not_equal_none(self): - assert BoundingBox(11, 12, 101, 102) != None + def test_should_not_equal_bounding_boxes_with_different_height(self): + assert 
BoundingBox(11, 12, 101, 102) != BoundingBox(11, 12, 101, 999) - def test_should_not_equal_to_none(self): - assert None != BoundingBox(11, 12, 101, 102) + def test_should_not_equal_none(self): + assert not BoundingBox(11, 12, 101, 102).__eq__(None) - def test_should_include_another_bounding_box_to_the_bottom_right(self): - assert ( - BoundingBox(10, 20, 50, 100).include(BoundingBox(100, 100, 200, 200)) == - BoundingBox(10, 20, 100 + 200 - 10, 100 + 200 - 20) - ) + def test_should_include_another_bounding_box_to_the_bottom_right(self): + assert ( + BoundingBox(10, 20, 50, 100).include(BoundingBox(100, 100, 200, 200)) == + BoundingBox(10, 20, 100 + 200 - 10, 100 + 200 - 20) + ) - def test_should_include_another_bounding_box_to_the_top_left(self): - assert ( - BoundingBox(100, 100, 200, 200).include(BoundingBox(10, 20, 50, 100)) == - BoundingBox(10, 20, 100 + 200 - 10, 100 + 200 - 20) - ) + def test_should_include_another_bounding_box_to_the_top_left(self): + assert ( + BoundingBox(100, 100, 200, 200).include(BoundingBox(10, 20, 50, 100)) == + BoundingBox(10, 20, 100 + 200 - 10, 100 + 200 - 20) + ) diff --git a/sciencebeam_gym/utils/pages_zip.py b/sciencebeam_gym/utils/pages_zip.py index f992400..f2f979b 100644 --- a/sciencebeam_gym/utils/pages_zip.py +++ b/sciencebeam_gym/utils/pages_zip.py @@ -4,32 +4,35 @@ from zipfile import ZipFile, ZIP_DEFLATED from apache_beam.io.filesystems import FileSystems from sciencebeam_utils.beam_utils.io import ( - dirname, - mkdirs_if_not_exists + dirname, + mkdirs_if_not_exists ) + def get_logger(): - return logging.getLogger(__name__) + return logging.getLogger(__name__) + def load_pages(filename, page_range=None): - with FileSystems.open(filename) as f: - with ZipFile(f) as zf: - filenames = zf.namelist() - if page_range: - filenames = filenames[ - max(0, page_range[0] - 1): - page_range[1] - ] - for filename in filenames: - with zf.open(filename) as f: - yield f + with FileSystems.open(filename) as f: + with ZipFile(f) as zf: + 
filenames = zf.namelist() + if page_range: + filenames = filenames[ + max(0, page_range[0] - 1): + page_range[1] + ] + for member_filename in filenames: + with zf.open(member_filename) as f: + yield f + def save_pages(output_filename, ext, bytes_by_page): - mkdirs_if_not_exists(dirname(output_filename)) - with FileSystems.create(output_filename) as f: - with ZipFile(f, 'w', compression=ZIP_DEFLATED) as zf: - for i, data in enumerate(bytes_by_page): - page_filename = 'page-%s%s' % (1 + i, ext) - get_logger().debug('page_filename: %s', page_filename) - zf.writestr(page_filename, data) - return output_filename + mkdirs_if_not_exists(dirname(output_filename)) + with FileSystems.create(output_filename) as f: + with ZipFile(f, 'w', compression=ZIP_DEFLATED) as zf: + for i, data in enumerate(bytes_by_page): + page_filename = 'page-%s%s' % (1 + i, ext) + get_logger().debug('page_filename: %s', page_filename) + zf.writestr(page_filename, data) + return output_filename diff --git a/sciencebeam_gym/utils/pyplot.py b/sciencebeam_gym/utils/pyplot.py new file mode 100644 index 0000000..622d0a9 --- /dev/null +++ b/sciencebeam_gym/utils/pyplot.py @@ -0,0 +1,6 @@ +import matplotlib as mpl +# this is important to run on the cloud - we won't have python-tk installed +mpl.use("Agg") + +# pylint: disable=unused-import, wrong-import-position +from matplotlib import pyplot # flake8: noqa diff --git a/sciencebeam_gym/utils/tf.py b/sciencebeam_gym/utils/tf.py index dab93df..1f78b4b 100644 --- a/sciencebeam_gym/utils/tf.py +++ b/sciencebeam_gym/utils/tf.py @@ -5,19 +5,22 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.framework import errors as tf_errors # pylint: enable=E0611 -def variable_scoped(name, fn): - with tf.variable_scope(name): - return fn() + +def variable_scoped(name, fn, *args, **kwargs): + with tf.variable_scope(name): + return fn(*args, **kwargs) + def tf_print(x, message=None, **kwargs): - return tf.Print(x, [x], message=message, **kwargs) + return 
tf.Print(x, [x], message=message, **kwargs) + def FileIO(filename, mode): - try: - return file_io.FileIO(filename, mode) - except tf_errors.InvalidArgumentError: - if 'b' in mode: - # older versions of TF don't support the 'b' flag as such - return file_io.FileIO(filename, mode.replace('b', '')) - else: - raise + try: + return file_io.FileIO(filename, mode) + except tf_errors.InvalidArgumentError: + if 'b' in mode: + # older versions of TF don't support the 'b' flag as such + return file_io.FileIO(filename, mode.replace('b', '')) + else: + raise diff --git a/sciencebeam_gym/utils/tfrecord.py b/sciencebeam_gym/utils/tfrecord.py index 8cd7606..5a827c3 100644 --- a/sciencebeam_gym/utils/tfrecord.py +++ b/sciencebeam_gym/utils/tfrecord.py @@ -4,50 +4,56 @@ from six import iteritems, raise_from, text_type, binary_type import tensorflow as tf -def encode_value_as_feature(value, name): - try: - if isinstance(value, text_type): - value = value.encode('utf-8') - if isinstance(value, binary_type): - return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) - if isinstance(value, int): - return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) - raise TypeError('unsupported type: %s' % type(value)) - except TypeError as e: - raise_from(TypeError('failed to convert %s due to %s' % (name, e)), e) + +def encode_value_as_feature(value, name): # pylint: disable=inconsistent-return-statements + try: + if isinstance(value, text_type): + value = value.encode('utf-8') + if isinstance(value, binary_type): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + if isinstance(value, int): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + raise TypeError('unsupported type: %s' % type(value)) + except TypeError as e: + raise_from(TypeError('failed to convert %s due to %s' % (name, e)), e) + def decode_feature_value(feature): - if feature.bytes_list.value: - return feature.bytes_list.value[0] - if feature.int64_list.value: - 
return feature.int64_list.value[0] - raise TypeError('unsupported feature: %s' % feature) + if feature.bytes_list.value: + return feature.bytes_list.value[0] + if feature.int64_list.value: + return feature.int64_list.value[0] + raise TypeError('unsupported feature: %s' % feature) + def iter_examples_to_dict_list(examples, keys=None): - for example in examples: - result = tf.train.Example.FromString(example) - yield { - key: decode_feature_value(result.features.feature.get(key)) - for key in result.features.feature.keys() - if keys is None or key in keys - } + for example in examples: + result = tf.train.Example.FromString(example) # pylint: disable=no-member + yield { + key: decode_feature_value(result.features.feature.get(key)) + for key in result.features.feature.keys() + if keys is None or key in keys + } + def iter_read_tfrecord_file_as_dict_list(filename, keys=None): - options = None - if filename.endswith('.gz'): - options = tf.python_io.TFRecordOptions( - compression_type=tf.python_io.TFRecordCompressionType.GZIP - ) - examples = tf.python_io.tf_record_iterator(filename, options=options) - return iter_examples_to_dict_list(examples, keys=keys) + options = None + if filename.endswith('.gz'): + options = tf.python_io.TFRecordOptions( + compression_type=tf.python_io.TFRecordCompressionType.GZIP + ) + examples = tf.python_io.tf_record_iterator(filename, options=options) + return iter_examples_to_dict_list(examples, keys=keys) + def dict_to_example(props): - return tf.train.Example(features=tf.train.Features(feature={ - k: encode_value_as_feature(v, name=k) - for k, v in iteritems(props) - })) + return tf.train.Example(features=tf.train.Features(feature={ + k: encode_value_as_feature(v, name=k) + for k, v in iteritems(props) + })) + def write_examples_to_tfrecord(tfrecord_filename, examples): - with tf.python_io.TFRecordWriter(tfrecord_filename) as writer: - for example in examples: - writer.write(example.SerializeToString()) + with 
tf.python_io.TFRecordWriter(tfrecord_filename) as writer: + for example in examples: + writer.write(example.SerializeToString()) diff --git a/sciencebeam_gym/utils/tfrecord_test.py b/sciencebeam_gym/utils/tfrecord_test.py index 3ed63fc..c1689ad 100644 --- a/sciencebeam_gym/utils/tfrecord_test.py +++ b/sciencebeam_gym/utils/tfrecord_test.py @@ -1,30 +1,33 @@ from six import text_type from sciencebeam_gym.utils.tfrecord import ( - iter_examples_to_dict_list, - dict_to_example + iter_examples_to_dict_list, + dict_to_example ) + def dict_to_example_and_reverse(props): - return list(iter_examples_to_dict_list([ - dict_to_example(props).SerializeToString() - ]))[0] + return list(iter_examples_to_dict_list([ + dict_to_example(props).SerializeToString() + ]))[0] + def assert_dict_to_example_and_reverse(props): - assert dict_to_example_and_reverse(props) == props + assert dict_to_example_and_reverse(props) == props + class TestDictToExampleAndIterExamplesToDictList(object): - def test_should_handle_bytes(self): - assert_dict_to_example_and_reverse({ - b'a': b'data' - }) - - def test_should_handle_unicode(self): - assert_dict_to_example_and_reverse({ - b'a': text_type('data') - }) - - def test_should_handle_int(self): - assert_dict_to_example_and_reverse({ - b'a': 1 - }) + def test_should_handle_bytes(self): + assert_dict_to_example_and_reverse({ + b'a': b'data' + }) + + def test_should_handle_unicode(self): + assert_dict_to_example_and_reverse({ + b'a': text_type('data') + }) + + def test_should_handle_int(self): + assert_dict_to_example_and_reverse({ + b'a': 1 + }) diff --git a/setup.py b/setup.py index 13a8d24..b32f1b2 100644 --- a/setup.py +++ b/setup.py @@ -4,91 +4,95 @@ import os import subprocess import shlex +from distutils.command.build import build # pylint: disable=import-error, no-name-in-module + from setuptools import ( - find_packages, - setup, - Command, - Extension + find_packages, + setup, + Command, + Extension ) -from distutils.command.build import build 
- import numpy as np CUSTOM_COMMANDS = [ - shlex.split(command_line) for command_line in [ - 'apt-get update', - 'apt-get --assume-yes install libxml2', - 'apt-get --assume-yes install poppler-utils' - ] + shlex.split(command_line) for command_line in [ + 'apt-get update', + 'apt-get --assume-yes install libxml2', + 'apt-get --assume-yes install poppler-utils' + ] ] with open(os.path.join('requirements.txt'), 'r') as f: - REQUIRED_PACKAGES = f.readlines() + REQUIRED_PACKAGES = f.readlines() packages = find_packages() # This class handles the pip install mechanism. + + class CustomBuild(build): - """A build command class that will be invoked during package install. - The package built using the current setup.py will be staged and later - installed in the worker using `pip install package'. This class will be - instantiated during install for this specific scenario and will trigger - running the custom commands specified. - """ - sub_commands = build.sub_commands + [('CustomCommands', None)] + """A build command class that will be invoked during package install. + The package built using the current setup.py will be staged and later + installed in the worker using `pip install package'. This class will be + instantiated during install for this specific scenario and will trigger + running the custom commands specified. + """ + sub_commands = build.sub_commands + [('CustomCommands', None)] + class CustomCommands(Command): - """A setuptools Command class able to run arbitrary commands.""" - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def _run_custom_command(self, command_list): - print('Running command: %s' % command_list) - p = subprocess.Popen( - command_list, - stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT - ) - # Can use communicate(input='y\n'.encode()) if the command run requires - # some confirmation. 
- stdout_data, _ = p.communicate() - print('Command output: %s' % stdout_data) - if p.returncode != 0: - raise RuntimeError( - 'Command %s failed: exit code: %s (output: %s)' % - (command_list, p.returncode, stdout_data) - ) - - def run(self): - for command in CUSTOM_COMMANDS: - self._run_custom_command(command) + """A setuptools Command class able to run arbitrary commands.""" + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def _run_custom_command(self, command_list): + print('Running command: %s' % command_list) + p = subprocess.Popen( + command_list, + stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT + ) + # Can use communicate(input='y\n'.encode()) if the command run requires + # some confirmation. + stdout_data, _ = p.communicate() + print('Command output: %s' % stdout_data) + if p.returncode != 0: + raise RuntimeError( + 'Command %s failed: exit code: %s (output: %s)' % + (command_list, p.returncode, stdout_data) + ) + + def run(self): + for command in CUSTOM_COMMANDS: + self._run_custom_command(command) + setup( - name='sciencebeam_gym', - version='0.0.1', - install_requires=REQUIRED_PACKAGES, - packages=packages, - include_package_data=True, - description='ScienceBeam Gym', - setup_requires=[ - # Setuptools 18.0 properly handles Cython extensions. - 'setuptools>=18.0', - 'cython', - ], - ext_modules=[ - Extension( - 'sciencebeam_gym.alignment.align_fast_utils', - sources=['sciencebeam_gym/alignment/align_fast_utils.pyx'], - ), - ], - include_dirs=[np.get_include()], - cmdclass={ - 'build': CustomBuild, - 'CustomCommands': CustomCommands - } + name='sciencebeam_gym', + version='0.0.1', + install_requires=REQUIRED_PACKAGES, + packages=packages, + include_package_data=True, + description='ScienceBeam Gym', + setup_requires=[ + # Setuptools 18.0 properly handles Cython extensions. 
+ 'setuptools>=18.0', + 'cython', + ], + ext_modules=[ + Extension( + 'sciencebeam_gym.alignment.align_fast_utils', + sources=['sciencebeam_gym/alignment/align_fast_utils.pyx'], + ), + ], + include_dirs=[np.get_include()], + cmdclass={ + 'build': CustomBuild, + 'CustomCommands': CustomCommands + } ) -- GitLab