diff --git a/README.md b/README.md
index 8a6bb4692fe1f19e52fe3c3163428ec2cca599af..a8d0a71c0b9245c243ed01b7dc03e5cb96b2df49 100644
--- a/README.md
+++ b/README.md
@@ -328,6 +328,10 @@ python -m sciencebeam_gym.models.text.crf.annotate_using_predictions \
   --output-path="path/to/file.crf-annot-100-p1.lxml"
 ```
 
+## Conversion Pipeline
+
+See [Conversion Pipeline](doc/conversion-pipeline.md).
+
 ## Tests
 
 Unit tests are written using [pytest](https://docs.pytest.org/). Run for example `pytest` or `pytest-watch`.
diff --git a/doc/conversion-pipeline.md b/doc/conversion-pipeline.md
new file mode 100644
index 0000000000000000000000000000000000000000..53c0c11251a7e63461133c996ee5af4c2d81717c
--- /dev/null
+++ b/doc/conversion-pipeline.md
@@ -0,0 +1,51 @@
+# Experimental Pipeline
+
+This pipeline is currently under development. It uses the CRF or computer vision model trained by
+[ScienceBeam Gym](https://github.com/elifesciences/sciencebeam-gym).
+
+What you need before you can proceed:
+
+- At least one of:
+  - Path to a [CRF model](https://github.com/elifesciences/sciencebeam-gym#training-crf-model)
+  - Path to an [exported computer vision model](https://github.com/elifesciences/sciencebeam-gym#export-inference-model)
+  - `--use-grobid` option
+- PDF files, as a file list (csv/tsv) or a glob pattern
+
+To use the CRF model together with the CV model, the CRF model has to have been trained with the CV predictions.
+
+The following command will process files locally:
+
+```bash
+python -m sciencebeam_gym.convert.conversion_pipeline \
+  --data-path=./data \
+  --pdf-file-list=./data/file-list-validation.tsv \
+  --crf-model=path/to/crf-model.pkl \
+  --cv-model-export-dir=./my-model/export \
+  --output-path=./data-results \
+  --pages=1 \
+  --limit=100
+```
+
+The following command will process the first 100 files in the cloud using Dataflow:
+
+```bash
+python -m sciencebeam_gym.convert.conversion_pipeline \
+  --data-path=gs://my-bucket/data \
+  --pdf-file-list=gs://my-bucket/data/file-list-validation.tsv \
+  --crf-model=path/to/crf-model.pkl \
+  --cv-model-export-dir=gs://my-bucket/my-model/export \
+  --output-path=gs://my-bucket/data-results \
+  --pages=1 \
+  --limit=100 \
+  --cloud
+```
+
+You can also enable post-processing of the extracted authors and affiliations using Grobid by adding `--use-grobid`; in that case Grobid will be started automatically. To use an existing instance, also add `--grobid-url=<api url>` with the URL of the Grobid API. If `--use-grobid` is used without a CRF or CV model, Grobid is used to translate the PDF to XML (see the example below).
+
+Note: using Grobid as part of this pipeline is considered deprecated and will likely be removed.
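+
+For example, a Grobid-only conversion (no CRF or CV model) might look like this; this is a sketch, and the data and file list paths are illustrative:
+
+```bash
+python -m sciencebeam_gym.convert.conversion_pipeline \
+  --data-path=./data \
+  --pdf-file-list=./data/file-list-validation.tsv \
+  --use-grobid \
+  --grobid-url=http://localhost:8080/api \
+  --output-path=./data-results \
+  --limit=100
+```
+
+If `--grobid-url` is omitted, a local Grobid service is started automatically.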
+ +For a full list of parameters: + +```bash +python -m sciencebeam.examples.conversion_pipeline --help +``` diff --git a/sciencebeam_gym/convert/__init__.py b/sciencebeam_gym/convert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sciencebeam_gym/convert/conversion_pipeline.py b/sciencebeam_gym/convert/conversion_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..ddafb5f8b5bc7f6bbc46f214b1c190d57d3083bd --- /dev/null +++ b/sciencebeam_gym/convert/conversion_pipeline.py @@ -0,0 +1,566 @@ +from __future__ import absolute_import + +import argparse +import os +import logging +import pickle +from io import BytesIO + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions + +from lxml import etree + +from sciencebeam_gym.utils.collection import ( + extend_dict, + remove_keys_from_dict +) + +from sciencebeam_gym.beam_utils.utils import ( + TransformAndCount, + TransformAndLog, + MapOrLog, + PreventFusion +) + +from sciencebeam_gym.beam_utils.files import ( + ReadFileList, + FindFiles +) + +from sciencebeam_gym.beam_utils.io import ( + read_all_from_path, + save_file_content +) + +from sciencebeam_gym.beam_utils.main import ( + add_cloud_args, + process_cloud_args, + process_sciencebeam_gym_dep_args +) + +from sciencebeam_gym.structured_document.structured_document_loader import ( + load_structured_document +) + +from sciencebeam_gym.structured_document.lxml import ( + LxmlStructuredDocument +) + +from sciencebeam_gym.preprocess.preprocessing_utils import ( + join_if_relative_path, + convert_pdf_bytes_to_lxml, + parse_page_range, + save_pages, + pdf_bytes_to_png_pages, + get_output_file +) + +from sciencebeam_gym.inference_model.extract_to_xml import ( + extract_structured_document_to_xml +) + +from sciencebeam_gym.models.text.crf.annotate_using_predictions import ( + predict_and_annotate_structured_document, + CRF_TAG_SCOPE +) + +from sciencebeam_gym.inference_model.annotate_using_predictions import ( + annotate_structured_document_using_predicted_images, + AnnotatedImage, + CV_TAG_SCOPE +) + +from .grobid.grobid_xml_enhancer import ( + GrobidXmlEnhancer +) + +from .cv_conversion_utils import ( + InferenceModelWrapper, + image_data_to_png +) + +from .grobid.grobid_service import ( + grobid_service, + GrobidApiPaths +) + + +def get_logger(): + return logging.getLogger(__name__) + +class MetricCounters(object): + FILES = 'files' + READ_LXML_ERROR = 'read_lxml_error_count' + CONVERT_PDF_TO_LXML_ERROR = 'ConvertPdfToLxml_error_count' + CONVERT_PDF_TO_PNG_ERROR = 'ConvertPdfToPng_error_count' + CONVERT_LXML_TO_SVG_ANNOT_ERROR = 'ConvertPdfToSvgAnnot_error_count' + CV_PREDICTION_ERROR = 'ComputerVisionPrediction_error_count' + ANNOTATE_USING_PREDICTION_ERROR = 'AnnotateLxmlUsingPrediction_error_count' + EXTRACT_TO_XML_ERROR = 'ExtractToXml_error_count' + GROBID_ERROR = 'Grobid_error_count' + +class OutputExt(object): + CRF_ANNOT_LXML = '.crf.lxml.gz' + CRF_CV_ANNOT_LXML = '.crf-cv.lxml.gz' + CV_ANNOT_LXML = '.cv.lxml.gz' + CV_PNG = '.cv-png.zip' + +class DataProps(object): + SOURCE_FILENAME = 'source_filename' + PDF_CONTENT = 'pdf_content' + STRUCTURED_DOCUMENT = 'structured_document' + PDF_PNG_PAGES = 'pdf_png_pages' + CV_PREDICTION_PNG_PAGES = 'cv_prediction_png_pages' + COLOR_MAP = 'color_map' + EXTRACTED_XML = 'extracted_xml' + +def 
convert_pdf_bytes_to_structured_document(pdf_content, path=None, page_range=None): + return LxmlStructuredDocument(etree.parse(BytesIO( + convert_pdf_bytes_to_lxml(pdf_content, path=path, page_range=page_range) + ))) + +def annotate_structured_document_using_predicted_image_data( + structured_document, prediction_images, color_map, tag_scope=None): + + return annotate_structured_document_using_predicted_images( + structured_document, ( + AnnotatedImage(prediction_image, color_map) + for prediction_image in prediction_images + ), tag_scope=tag_scope + ) + +def extract_annotated_structured_document_to_xml(structured_document, tag_scope=None): + xml_root = extract_structured_document_to_xml(structured_document, tag_scope=tag_scope) + return etree.tostring(xml_root, pretty_print=True) + +def load_crf_model(path): + with FileSystems.open(path) as crf_model_f: + return pickle.load(crf_model_f) + +def save_structured_document(filename, structured_document): + # only support saving lxml for now + assert isinstance(structured_document, LxmlStructuredDocument) + save_file_content(filename, etree.tostring(structured_document.root, pretty_print=True)) + return filename + +def get_annot_lxml_ext(crf_enabled, cv_enabled): + if crf_enabled and cv_enabled: + return OutputExt.CRF_CV_ANNOT_LXML + if crf_enabled: + return OutputExt.CRF_ANNOT_LXML + if cv_enabled: + return OutputExt.CV_ANNOT_LXML + raise AssertionError('at least one of crf or cv need to be enabled') + +def PdfUrlSource(opt): + if opt.pdf_file_list: + return ReadFileList(opt.pdf_file_list, column=opt.pdf_file_column, limit=opt.limit) + else: + return FindFiles(join_if_relative_path(opt.base_data_path, opt.pdf_path)) + +def ReadPdfContent(): + return "ReadPdfContent" >> TransformAndCount( + beam.Map(lambda pdf_url: { + DataProps.SOURCE_FILENAME: pdf_url, + DataProps.PDF_CONTENT: read_all_from_path(pdf_url) + }), + MetricCounters.FILES + ) + +def add_read_pdfs_to_annotated_lxml_pipeline_steps(p, opt, get_pipeline_output_file): + page_range = opt.pages + + cv_enabled = opt.cv_model_export_dir + + extract_tag_scope = None + + pdf_urls = p | PdfUrlSource(opt) + + lxml_content = ( + pdf_urls | + PreventFusion() | + + ReadPdfContent() | + + "ConvertPdfToLxml" >> MapOrLog(lambda v: extend_dict(v, { + DataProps.STRUCTURED_DOCUMENT: convert_pdf_bytes_to_structured_document( + v[DataProps.PDF_CONTENT], path=v[DataProps.SOURCE_FILENAME], + page_range=page_range + ) + }), log_fn=lambda e, v: ( + get_logger().warning( + 'caught exception (ignoring item): %s, pdf: %s', + e, v[DataProps.SOURCE_FILENAME], exc_info=e + ) + ), error_count=MetricCounters.CONVERT_PDF_TO_LXML_ERROR) + ) + + if cv_enabled: + image_size = ( + (opt.image_width, opt.image_height) + if opt.image_width and opt.image_height + else None + ) + inference_model_wrapper = InferenceModelWrapper(opt.cv_model_export_dir) + + cv_predictions = ( + lxml_content | + + "ConvertPdfToPng" >> MapOrLog(lambda v: remove_keys_from_dict( + extend_dict(v, { + DataProps.PDF_PNG_PAGES: list(pdf_bytes_to_png_pages( + v[DataProps.PDF_CONTENT], + dpi=90, # not used if the image is scaled + image_size=image_size, + page_range=page_range + )) + }), + keys_to_remove={DataProps.PDF_CONTENT} + ), error_count=MetricCounters.CONVERT_PDF_TO_PNG_ERROR) | + + "ComputerVisionPrediction" >> MapOrLog(lambda v: remove_keys_from_dict( + extend_dict(v, { + DataProps.CV_PREDICTION_PNG_PAGES: inference_model_wrapper(v[DataProps.PDF_PNG_PAGES]), + DataProps.COLOR_MAP: inference_model_wrapper.get_color_map() + }), + 
keys_to_remove={DataProps.PDF_PNG_PAGES} + ), error_count=MetricCounters.CV_PREDICTION_ERROR) + ) + + if opt.save_cv_output: + _ = ( + cv_predictions | + "SaveComputerVisionOutput" >> TransformAndLog( + beam.Map(lambda v: save_pages( + get_pipeline_output_file( + v[DataProps.SOURCE_FILENAME], + OutputExt.CV_PNG + ), + '.png', + [image_data_to_png(image_data) for image_data in v[DataProps.CV_PREDICTION_PNG_PAGES]] + )), + log_fn=lambda x: get_logger().info('saved cv output: %s', x) + ) + ) + + cv_annotated_lxml = ( + cv_predictions | + "AnnotateLxmlUsingCvPrediction" >> MapOrLog(lambda v: remove_keys_from_dict( + extend_dict(v, { + DataProps.STRUCTURED_DOCUMENT: annotate_structured_document_using_predicted_image_data( + v[DataProps.STRUCTURED_DOCUMENT], + v[DataProps.CV_PREDICTION_PNG_PAGES], + v[DataProps.COLOR_MAP], + tag_scope=CV_TAG_SCOPE + ) + }), + keys_to_remove={DataProps.PDF_PNG_PAGES} + ), error_count=MetricCounters.ANNOTATE_USING_PREDICTION_ERROR) + ) + + lxml_content = cv_annotated_lxml + extract_tag_scope = CV_TAG_SCOPE + + if opt.crf_model: + model = load_crf_model(opt.crf_model) + crf_annotated_lxml = ( + lxml_content | + "AnnotateLxmlUsingCrfPrediction" >> MapOrLog(lambda v: extend_dict(v, { + DataProps.STRUCTURED_DOCUMENT: predict_and_annotate_structured_document( + v[DataProps.STRUCTURED_DOCUMENT], model + ) + }), error_count=MetricCounters.ANNOTATE_USING_PREDICTION_ERROR) + ) + + lxml_content = crf_annotated_lxml + extract_tag_scope = CRF_TAG_SCOPE + + if opt.save_annot_lxml: + _ = ( + lxml_content | + "SaveAnnotLxml" >> TransformAndLog( + beam.Map(lambda v: save_structured_document( + get_pipeline_output_file( + v[DataProps.SOURCE_FILENAME], + get_annot_lxml_ext( + crf_enabled=opt.crf_model, + cv_enabled=cv_enabled + ) + ), + v[DataProps.STRUCTURED_DOCUMENT] + )), + log_fn=lambda x: get_logger().info('saved annoted lxml to: %s', x) + ) + ) + return lxml_content, extract_tag_scope + +def add_read_pdfs_to_grobid_xml_pipeline_steps(p, opt): + grobid_transformer = grobid_service( + opt.grobid_url, opt.grobid_action, start_service=opt.start_grobid_service + ) + + return ( + p | + PdfUrlSource(opt) | + PreventFusion() | + ReadPdfContent() | + "Grobid" >> MapOrLog(lambda v: extend_dict(v, { + DataProps.EXTRACTED_XML: grobid_transformer( + (v[DataProps.SOURCE_FILENAME], v[DataProps.PDF_CONTENT]) + )[1] + }), error_count=MetricCounters.GROBID_ERROR) + ) + +def add_read_source_to_extracted_xml_pipeline_steps(p, opt, get_pipeline_output_file): + if opt.lxml_file_list: + lxml_urls = p | ReadFileList(opt.lxml_file_list, column=opt.lxml_file_column, limit=opt.limit) + + annotated_lxml = ( + lxml_urls | + PreventFusion() | + + "ReadLxmlContent" >> TransformAndCount( + MapOrLog(lambda url: { + DataProps.SOURCE_FILENAME: url, + DataProps.STRUCTURED_DOCUMENT: load_structured_document(url) + }, error_count=MetricCounters.READ_LXML_ERROR), + MetricCounters.FILES + ) + ) + + extract_tag_scope = None + else: + annotated_lxml, extract_tag_scope = add_read_pdfs_to_annotated_lxml_pipeline_steps( + p, opt, get_pipeline_output_file + ) + + extracted_xml = ( + annotated_lxml | + "ExtractToXml" >> MapOrLog(lambda v: remove_keys_from_dict( + extend_dict(v, { + DataProps.EXTRACTED_XML: extract_annotated_structured_document_to_xml( + v[DataProps.STRUCTURED_DOCUMENT], + tag_scope=extract_tag_scope + ) + }), + keys_to_remove={DataProps.STRUCTURED_DOCUMENT} + ), error_count=MetricCounters.EXTRACT_TO_XML_ERROR) + ) + + if opt.use_grobid: + enhancer = GrobidXmlEnhancer( + opt.grobid_url, 
start_service=opt.start_grobid_service + ) + extracted_xml = ( + extracted_xml | + "GrobidEnhanceXml" >> MapOrLog(lambda v: extend_dict(v, { + DataProps.EXTRACTED_XML: enhancer( + v[DataProps.EXTRACTED_XML] + ) + }), error_count=MetricCounters.GROBID_ERROR) + ) + return extracted_xml + +def configure_pipeline(p, opt): + get_pipeline_output_file = lambda source_url, ext: get_output_file( + source_url, + opt.base_data_path, + opt.output_path, + ext + ) + + if ( + opt.use_grobid and not opt.crf_model and + not opt.cv_model_export_dir and not opt.lxml_file_list + ): + extracted_xml = add_read_pdfs_to_grobid_xml_pipeline_steps(p, opt) + else: + extracted_xml = add_read_source_to_extracted_xml_pipeline_steps( + p, opt, get_pipeline_output_file + ) + + _ = ( + extracted_xml | + "WriteXml" >> TransformAndLog( + beam.Map(lambda v: save_file_content( + get_pipeline_output_file( + v[DataProps.SOURCE_FILENAME], + opt.output_suffix + ), + v[DataProps.EXTRACTED_XML] + )), + log_fn=lambda x: get_logger().info('saved xml to: %s', x) + ) + ) + + +def add_main_args(parser): + parser.add_argument( + '--data-path', type=str, required=True, + help='base data path' + ) + + source_group = parser.add_argument_group('source') + source_one_of_group = source_group.add_mutually_exclusive_group(required=True) + source_one_of_group.add_argument( + '--pdf-path', type=str, required=False, + help='path to pdf file(s), relative to data-path' + ) + source_one_of_group.add_argument( + '--pdf-file-list', type=str, required=False, + help='path to pdf csv/tsv file list' + ) + source_group.add_argument( + '--pdf-file-column', type=str, required=False, default='pdf_url', + help='the column of the pdf file list to use' + ) + source_one_of_group.add_argument( + '--lxml-file-list', type=str, required=False, + help='path to annotated lxml or svg pages zip file list' + '; (CRF and CV models are not supported in this mode)' + ) + source_group.add_argument( + '--lxml-file-column', type=str, required=False, default='url', + help='the column of the lxml file list to use' + ) + + parser.add_argument( + '--limit', type=int, required=False, + help='limit the number of file pairs to process' + ) + + output_group = parser.add_argument_group('output') + output_group.add_argument( + '--output-path', required=False, + help='Output directory to write results to.' + ) + output_group.add_argument( + '--output-suffix', required=False, default='.crf.xml', + help='Output file suffix to add to the filename (excluding the file extension).' 
+ ) + + parser.add_argument( + '--save-annot-lxml', action='store_true', default=False, + help='enable saving of annotated lxml' + ) + + grobid_group = parser.add_argument_group('Grobid') + grobid_group.add_argument( + '--use-grobid', action='store_true', default=False, + help='enable the use of grobid' + ) + grobid_group.add_argument( + '--grobid-url', required=False, default=None, + help='Base URL to the Grobid service' + ) + parser.add_argument( + '--grobid-action', required=False, + default=GrobidApiPaths.PROCESS_HEADER_DOCUMENT, + help='Name of the Grobid action (if Grobid is used without CRF or CV model)' + ) + + parser.add_argument( + '--debug', action='store_true', default=False, + help='enable debug output' + ) + + parser.add_argument( + '--pages', type=parse_page_range, default=None, + help='only processes the selected pages' + ) + + crf_group = parser.add_argument_group('CRF') + crf_group.add_argument( + '--crf-model', type=str, required=False, + help='path to saved crf model' + ) + + cv_group = parser.add_argument_group('CV') + cv_group.add_argument( + '--cv-model-export-dir', type=str, required=False, + help='path to cv model export dir' + ) + cv_group.add_argument( + '--image-width', type=int, required=False, + default=256, + help='image width of resulting PNGs' + ) + cv_group.add_argument( + '--image-height', type=int, required=False, + default=256, + help='image height of resulting PNGs' + ) + cv_group.add_argument( + '--save-cv-output', action='store_true', default=False, + help='enable saving of computer vision output (png pages)' + ) + +def process_main_args(args, parser): + args.base_data_path = args.data_path.replace('/*/', '/') + + if not args.output_path: + args.output_path = os.path.join( + os.path.dirname(args.base_data_path), + os.path.basename(args.base_data_path + '-results') + ) + + if args.lxml_file_list: + if args.crf_model: + parser.error('--crf-model cannot be used in conjunction with --lxml-file-list') + + if args.cv_model_export_dir: + parser.error('--crf-model-export-dir cannot be used in conjunction with --lxml-file-list') + else: + if not args.crf_model and not args.cv_model_export_dir and not args.use_grobid: + parser.error( + '--crf-model, --cv-model-export-dir or --use-grobid required in conjunction' + ' with --pdf-file-list or --pdf-path' + ) + + if args.use_grobid and not args.grobid_url: + args.grobid_url = 'http://localhost:8080/api' + args.start_grobid_service = True + else: + args.start_grobid_service = False + +def parse_args(argv=None): + parser = argparse.ArgumentParser() + add_main_args(parser) + add_cloud_args(parser) + + args = parser.parse_args(argv) + + if args.debug: + logging.getLogger().setLevel('DEBUG') + + process_main_args(args, parser) + process_cloud_args( + args, args.output_path, + name='sciencebeam-convert' + ) + process_sciencebeam_gym_dep_args(args) + + get_logger().info('args: %s', args) + + return args + +def run(argv=None): + args = parse_args(argv) + + # We use the save_main_session option because one or more DoFn's in this + # workflow rely on global context (e.g., a module imported at module level). + pipeline_options = PipelineOptions.from_dictionary(vars(args)) + pipeline_options.view_as(SetupOptions).save_main_session = True + + with beam.Pipeline(args.runner, options=pipeline_options) as p: + configure_pipeline(p, args) + + # Execute the pipeline and wait until it is completed. 
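+
+# Example invocation (a sketch; paths are illustrative and assume a CRF model
+# trained with ScienceBeam Gym):
+#
+#   python -m sciencebeam_gym.convert.conversion_pipeline \
+#     --data-path=./data \
+#     --pdf-file-list=./data/file-list-validation.tsv \
+#     --crf-model=path/to/crf-model.pkl \
+#     --output-path=./data-results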
+ + +if __name__ == '__main__': + logging.basicConfig(level='INFO') + + run() diff --git a/sciencebeam_gym/convert/conversion_pipeline_test.py b/sciencebeam_gym/convert/conversion_pipeline_test.py new file mode 100644 index 0000000000000000000000000000000000000000..aec25b1136c5799aef1badde66721c2b0155db8e --- /dev/null +++ b/sciencebeam_gym/convert/conversion_pipeline_test.py @@ -0,0 +1,479 @@ +import logging +from mock import patch, DEFAULT + +import pytest + +import apache_beam as beam + +from sciencebeam_gym.beam_utils.testing import ( + BeamTest, + TestPipeline +) + +from . import conversion_pipeline as conversion_pipeline +from .conversion_pipeline import ( + get_annot_lxml_ext, + configure_pipeline, + parse_args, + OutputExt, + CV_TAG_SCOPE, + CRF_TAG_SCOPE +) + + +BASE_TEST_PATH = '.temp/test/conversion-pipeline' +BASE_DATA_PATH = BASE_TEST_PATH + '/data' +PDF_PATH = '*/*.pdf' +MODEL_EXPORT_DIR = BASE_TEST_PATH + '/model-export' +CV_MODEL_EXPORT_DIR = BASE_TEST_PATH + '/cv-model-export' +FILE_LIST_PATH = 'file-list.csv' +FILE_COLUMN = 'column1' + +REL_PDF_FILE_WITHOUT_EXT_1 = '1/file' +PDF_FILE_1 = BASE_DATA_PATH + '/' + REL_PDF_FILE_WITHOUT_EXT_1 + '.pdf' +LXML_FILE_1 = BASE_DATA_PATH + '/' + REL_PDF_FILE_WITHOUT_EXT_1 + '.lxml' + +OUTPUT_PATH = BASE_TEST_PATH + '/out' +OUTPUT_SUFFIX = '.cv.xml' +OUTPUT_XML_FILE_1 = OUTPUT_PATH + '/' + REL_PDF_FILE_WITHOUT_EXT_1 + OUTPUT_SUFFIX + +PDF_CONTENT_1 = b'pdf content' +LXML_CONTENT_1 = b'<LXML>lxml content</LXML>' +TEI_XML_CONTENT_1 = b'<TEI>tei content</TEI>' + +fake_pdf_png_page = lambda i=0: 'fake pdf png page: %d' % i + +MIN_ARGV = [ + '--data-path=' + BASE_DATA_PATH, + '--pdf-path=' + PDF_PATH, + '--crf-model=' + MODEL_EXPORT_DIR +] + +def setup_module(): + logging.basicConfig(level='DEBUG') + +def get_default_args(): + return parse_args(MIN_ARGV) + +def patch_conversion_pipeline(**kwargs): + always_mock = { + 'read_all_from_path', + 'load_structured_document', + 'convert_pdf_bytes_to_lxml', + 'convert_pdf_bytes_to_structured_document', + 'load_crf_model', + 'ReadFileList', + 'FindFiles', + 'predict_and_annotate_structured_document', + 'extract_annotated_structured_document_to_xml', + 'save_structured_document', + 'save_file_content', + 'GrobidXmlEnhancer', + 'grobid_service', + 'pdf_bytes_to_png_pages', + 'InferenceModelWrapper', + 'annotate_structured_document_using_predicted_image_data' + } + + return patch.multiple( + conversion_pipeline, + **{ + k: kwargs.get(k, DEFAULT) + for k in always_mock + } + ) + +def _setup_mocks_for_pages(mocks, page_no_list, file_count=1): + mocks['pdf_bytes_to_png_pages'].return_value = [ + fake_pdf_png_page(i) for i in page_no_list + ] + +class TestGetAnnotLxmlExt(object): + def test_should_return_crf_cv_annot_lxml(self): + assert get_annot_lxml_ext(crf_enabled=True, cv_enabled=True) == OutputExt.CRF_CV_ANNOT_LXML + + def test_should_return_crf_annot_lxml(self): + assert get_annot_lxml_ext(crf_enabled=True, cv_enabled=False) == OutputExt.CRF_ANNOT_LXML + + def test_should_return_cv_annot_lxml(self): + assert get_annot_lxml_ext(crf_enabled=False, cv_enabled=True) == OutputExt.CV_ANNOT_LXML + + def test_should_raise_error_if_neither_crf_or_cv_are_enabled(self): + with pytest.raises(AssertionError): + get_annot_lxml_ext(crf_enabled=False, cv_enabled=False) + +@pytest.mark.slow +class TestConfigurePipeline(BeamTest): + def test_should_pass_pdf_pattern_to_find_files_and_read_pdf_file(self): + with patch_conversion_pipeline() as mocks: + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + 
opt.pdf_path = PDF_PATH + opt.pdf_file_list = None + with TestPipeline() as p: + mocks['FindFiles'].return_value = beam.Create([PDF_FILE_1]) + configure_pipeline(p, opt) + + mocks['FindFiles'].assert_called_with( + BASE_DATA_PATH + '/' + PDF_PATH + ) + mocks['read_all_from_path'].assert_called_with( + PDF_FILE_1 + ) + + def test_should_pass_pdf_file_list_and_limit_to_read_file_list_and_read_pdf_file(self): + with patch_conversion_pipeline() as mocks: + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.limit = 100 + with TestPipeline() as p: + mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) + configure_pipeline(p, opt) + + mocks['ReadFileList'].assert_called_with( + opt.pdf_file_list, column='pdf_url', limit=opt.limit + ) + mocks['read_all_from_path'].assert_called_with( + PDF_FILE_1 + ) + + def test_should_pass_around_values_with_default_pipeline(self): + with patch_conversion_pipeline() as mocks: + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.output_path = OUTPUT_PATH + opt.output_suffix = OUTPUT_SUFFIX + with TestPipeline() as p: + mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) + mocks['read_all_from_path'].return_value = PDF_CONTENT_1 + configure_pipeline(p, opt) + + mocks['convert_pdf_bytes_to_structured_document'].assert_called_with( + PDF_CONTENT_1, page_range=None, path=PDF_FILE_1 + ) + mocks['predict_and_annotate_structured_document'].assert_called_with( + mocks['convert_pdf_bytes_to_structured_document'].return_value, + mocks['load_crf_model'].return_value + ) + mocks['extract_annotated_structured_document_to_xml'].assert_called_with( + mocks['predict_and_annotate_structured_document'].return_value, + tag_scope=CRF_TAG_SCOPE + ) + mocks['save_file_content'].assert_called_with( + OUTPUT_XML_FILE_1, + mocks['extract_annotated_structured_document_to_xml'].return_value + ) + + def test_should_save_annotated_lxml_if_enabled(self): + with patch_conversion_pipeline() as mocks: + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.output_path = OUTPUT_PATH + opt.output_suffix = OUTPUT_SUFFIX + opt.save_annot_lxml = True + with TestPipeline() as p: + mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) + configure_pipeline(p, opt) + + mocks['save_structured_document'].assert_called_with( + OUTPUT_PATH + '/' + REL_PDF_FILE_WITHOUT_EXT_1 + OutputExt.CRF_ANNOT_LXML, + mocks['predict_and_annotate_structured_document'].return_value + ) + + def test_should_use_lxml_file_list_if_provided_and_load_structured_documents(self): + with patch_conversion_pipeline() as mocks: + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = None + opt.lxml_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.output_path = OUTPUT_PATH + opt.output_suffix = OUTPUT_SUFFIX + with TestPipeline() as p: + mocks['ReadFileList'].return_value = beam.Create([LXML_FILE_1]) + configure_pipeline(p, opt) + + mocks['extract_annotated_structured_document_to_xml'].assert_called_with( + mocks['load_structured_document'].return_value, + tag_scope=None + ) + mocks['save_file_content'].assert_called_with( + OUTPUT_XML_FILE_1, + mocks['extract_annotated_structured_document_to_xml'].return_value + ) + + def 
test_should_use_crf_model_with_cv_model_if_enabled(self): + with patch_conversion_pipeline() as mocks: + inference_model_wrapper = mocks['InferenceModelWrapper'].return_value + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.output_path = OUTPUT_PATH + opt.output_suffix = OUTPUT_SUFFIX + opt.cv_model_export_dir = CV_MODEL_EXPORT_DIR + with TestPipeline() as p: + mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) + mocks['read_all_from_path'].return_value = PDF_CONTENT_1 + _setup_mocks_for_pages(mocks, [1, 2]) + configure_pipeline(p, opt) + + mocks['convert_pdf_bytes_to_structured_document'].assert_called_with( + PDF_CONTENT_1, page_range=None, path=PDF_FILE_1 + ) + + # cv model + inference_model_wrapper.assert_called_with( + [fake_pdf_png_page(i) for i in [1, 2]] + ) + mocks['annotate_structured_document_using_predicted_image_data'].assert_called_with( + mocks['convert_pdf_bytes_to_structured_document'].return_value, + inference_model_wrapper.return_value, + inference_model_wrapper.get_color_map.return_value, + tag_scope=CV_TAG_SCOPE + ) + + # crf model should receive output from cv model + mocks['predict_and_annotate_structured_document'].assert_called_with( + mocks['annotate_structured_document_using_predicted_image_data'].return_value, + mocks['load_crf_model'].return_value + ) + mocks['extract_annotated_structured_document_to_xml'].assert_called_with( + mocks['predict_and_annotate_structured_document'].return_value, + tag_scope=CRF_TAG_SCOPE + ) + + def test_should_use_cv_model_only_if_enabled(self): + with patch_conversion_pipeline() as mocks: + inference_model_wrapper = mocks['InferenceModelWrapper'].return_value + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.output_path = OUTPUT_PATH + opt.output_suffix = OUTPUT_SUFFIX + opt.crf_model = None + opt.cv_model_export_dir = CV_MODEL_EXPORT_DIR + with TestPipeline() as p: + mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) + mocks['read_all_from_path'].return_value = PDF_CONTENT_1 + _setup_mocks_for_pages(mocks, [1, 2]) + configure_pipeline(p, opt) + + mocks['convert_pdf_bytes_to_structured_document'].assert_called_with( + PDF_CONTENT_1, page_range=None, path=PDF_FILE_1 + ) + + # cv model + inference_model_wrapper.assert_called_with( + [fake_pdf_png_page(i) for i in [1, 2]] + ) + mocks['annotate_structured_document_using_predicted_image_data'].assert_called_with( + mocks['convert_pdf_bytes_to_structured_document'].return_value, + inference_model_wrapper.return_value, + inference_model_wrapper.get_color_map.return_value, + tag_scope=CV_TAG_SCOPE + ) + mocks['extract_annotated_structured_document_to_xml'].assert_called_with( + mocks['annotate_structured_document_using_predicted_image_data'].return_value, + tag_scope=CV_TAG_SCOPE + ) + + # crf model not be called + mocks['predict_and_annotate_structured_document'].assert_not_called() + + def test_should_use_grobid_if_enabled(self): + with patch_conversion_pipeline() as mocks: + grobid_xml_enhancer = mocks['GrobidXmlEnhancer'].return_value + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.output_path = OUTPUT_PATH + opt.output_suffix = OUTPUT_SUFFIX + opt.use_grobid = True + opt.grobid_url = 'http://test/api' + with TestPipeline() as p: + mocks['ReadFileList'].return_value 
= beam.Create([PDF_FILE_1]) + mocks['read_all_from_path'].return_value = PDF_CONTENT_1 + mocks['convert_pdf_bytes_to_lxml'].return_value = LXML_CONTENT_1 + configure_pipeline(p, opt) + + mocks['GrobidXmlEnhancer'].assert_called_with( + opt.grobid_url, + start_service=opt.start_grobid_service + ) + grobid_xml_enhancer.assert_called_with( + mocks['extract_annotated_structured_document_to_xml'].return_value + ) + mocks['save_file_content'].assert_called_with( + OUTPUT_XML_FILE_1, + grobid_xml_enhancer.return_value + ) + + def test_should_use_grobid_with_lxml_file_list_if_enabled(self): + with patch_conversion_pipeline() as mocks: + grobid_xml_enhancer = mocks['GrobidXmlEnhancer'].return_value + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = None + opt.lxml_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.output_path = OUTPUT_PATH + opt.output_suffix = OUTPUT_SUFFIX + opt.crf_model = None + opt.cv_model_export_dir = None + opt.use_grobid = True + opt.grobid_url = 'http://test/api' + with TestPipeline() as p: + mocks['ReadFileList'].return_value = beam.Create([LXML_FILE_1]) + configure_pipeline(p, opt) + + mocks['extract_annotated_structured_document_to_xml'].assert_called_with( + mocks['load_structured_document'].return_value, + tag_scope=None + ) + mocks['save_file_content'].assert_called_with( + OUTPUT_XML_FILE_1, + grobid_xml_enhancer.return_value + ) + + def test_should_use_grobid_only_if_crf_or_cv_model_are_not_enabled(self): + with patch_conversion_pipeline() as mocks: + opt = get_default_args() + opt.base_data_path = BASE_DATA_PATH + opt.pdf_path = None + opt.pdf_file_list = BASE_DATA_PATH + '/file-list.tsv' + opt.output_path = OUTPUT_PATH + opt.output_suffix = OUTPUT_SUFFIX + opt.crf_model = None + opt.cv_model_export_dir = None + opt.use_grobid = True + opt.grobid_url = 'http://test/api' + with TestPipeline() as p: + mocks['ReadFileList'].return_value = beam.Create([PDF_FILE_1]) + mocks['read_all_from_path'].return_value = PDF_CONTENT_1 + mocks['grobid_service'].return_value = lambda x: ( + PDF_FILE_1, TEI_XML_CONTENT_1 + ) + configure_pipeline(p, opt) + + mocks['grobid_service'].assert_called_with( + opt.grobid_url, + opt.grobid_action, + start_service=opt.start_grobid_service + ) + mocks['save_file_content'].assert_called_with( + OUTPUT_XML_FILE_1, + TEI_XML_CONTENT_1 + ) + +class TestParseArgs(object): + def test_should_parse_minimum_number_of_arguments(self): + parse_args(MIN_ARGV) + + def test_should_raise_error_if_no_source_argument_was_provided(self): + with pytest.raises(SystemExit): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--crf-model=' + MODEL_EXPORT_DIR + ]) + + def test_should_allow_pdf_path_to_be_specified(self): + args = parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-path=' + PDF_PATH, + '--crf-model=' + MODEL_EXPORT_DIR + ]) + assert args.pdf_path == PDF_PATH + + def test_should_allow_pdf_file_list_and_column_to_be_specified(self): + args = parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-file-list=' + FILE_LIST_PATH, + '--pdf-file-column=' + FILE_COLUMN, + '--crf-model=' + MODEL_EXPORT_DIR + ]) + assert args.pdf_file_list == FILE_LIST_PATH + assert args.pdf_file_column == FILE_COLUMN + + def test_should_allow_lxml_file_list_and_column_to_be_specified(self): + args = parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--lxml-file-list=' + FILE_LIST_PATH, + '--lxml-file-column=' + FILE_COLUMN + ]) + assert args.lxml_file_list == FILE_LIST_PATH + assert args.lxml_file_column == 
FILE_COLUMN + + def test_should_not_allow_crf_model_with_lxml_file_list(self): + with pytest.raises(SystemExit): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--lxml-file-list=' + FILE_LIST_PATH, + '--lxml-file-column=' + FILE_COLUMN, + '--crf-model=' + MODEL_EXPORT_DIR + ]) + + def test_should_not_allow_cv_model_with_lxml_file_list(self): + with pytest.raises(SystemExit): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--lxml-file-list=' + FILE_LIST_PATH, + '--lxml-file-column=' + FILE_COLUMN, + '--cv-model-export-dir=' + CV_MODEL_EXPORT_DIR + ]) + + def test_should_require_crf_or_cv_model_with_pdf_file_list(self): + with pytest.raises(SystemExit): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-file-list=' + FILE_LIST_PATH, + '--pdf-file-column=' + FILE_COLUMN + ]) + + def test_should_require_crf_or_cv_model_with_pdf_path(self): + with pytest.raises(SystemExit): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-path=' + PDF_PATH + ]) + + def test_should_allow_crf_model_only_with_pdf_file_list(self): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-file-list=' + FILE_LIST_PATH, + '--pdf-file-column=' + FILE_COLUMN, + '--crf-model=' + MODEL_EXPORT_DIR + ]) + + def test_should_allow_cv_model_only_with_pdf_file_list(self): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-file-list=' + FILE_LIST_PATH, + '--pdf-file-column=' + FILE_COLUMN, + '--cv-model-export-dir=' + CV_MODEL_EXPORT_DIR + ]) + + def test_should_allow_crf_and_cv_model_only_with_pdf_file_list(self): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-file-list=' + FILE_LIST_PATH, + '--pdf-file-column=' + FILE_COLUMN, + '--crf-model=' + MODEL_EXPORT_DIR, + '--cv-model-export-dir=' + CV_MODEL_EXPORT_DIR + ]) + + def test_should_allow_grobid_only_with_pdf_file_list(self): + parse_args([ + '--data-path=' + BASE_DATA_PATH, + '--pdf-file-list=' + FILE_LIST_PATH, + '--pdf-file-column=' + FILE_COLUMN, + '--use-grobid' + ]) diff --git a/sciencebeam_gym/convert/cv_conversion_utils.py b/sciencebeam_gym/convert/cv_conversion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d3cbc5674aa1ada83e336741c97ab49391911103 --- /dev/null +++ b/sciencebeam_gym/convert/cv_conversion_utils.py @@ -0,0 +1,49 @@ +from __future__ import absolute_import + +from io import BytesIO + +import tensorflow as tf +import numpy as np + +from PIL import Image + +from sciencebeam_gym.inference_model import ( + load_inference_model +) + +def lazy_cached_value(value_fn): + cache = {} + def wrapper(): + value = cache.get('value') + if value is None: + value = value_fn() + cache['value'] = value + return value + return wrapper + +def image_data_to_png(image_data): + image = Image.fromarray(image_data, 'RGB') + out = BytesIO() + image.save(out, 'png') + return out.getvalue() + +def png_bytes_to_image_data(png_bytes): + return np.asarray(Image.open(BytesIO(png_bytes)).convert('RGB'), dtype=np.uint8) + +class InferenceModelWrapper(object): + def __init__(self, export_dir): + self.session_cache = lazy_cached_value(lambda: tf.InteractiveSession()) + self.inference_model_cache = lazy_cached_value( + lambda: load_inference_model(export_dir, session=self.session_cache()) + ) + + def get_color_map(self): + return self.inference_model_cache().get_color_map(session=self.session_cache()) + + def __call__(self, png_pages): + input_data = [ + png_bytes_to_image_data(png_page) + for png_page in png_pages + ] + output_img_data_batch = self.inference_model_cache()(input_data, 
session=self.session_cache()) + return output_img_data_batch diff --git a/sciencebeam_gym/convert/cv_conversion_utils_test.py b/sciencebeam_gym/convert/cv_conversion_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ae828217c21717a674d2eb1aa887d04a3eef35bc --- /dev/null +++ b/sciencebeam_gym/convert/cv_conversion_utils_test.py @@ -0,0 +1,46 @@ +import logging +from mock import patch + +from . import cv_conversion_utils as cv_conversion_utils +from .cv_conversion_utils import ( + InferenceModelWrapper +) + +CV_MODEL_EXPORT_DIR = './model-export' +PNG_BYTES = b'dummy png bytes' + +def setup_module(): + logging.basicConfig(level='DEBUG') + +class TestInferenceModelWrapper(object): + def test_should_lazy_load_model(self): + with patch.object(cv_conversion_utils, 'load_inference_model') as load_inference_model: + with patch.object(cv_conversion_utils, 'png_bytes_to_image_data') as png_bytes_to_image_data: + with patch.object(cv_conversion_utils, 'tf') as tf: + inference_model_wrapper = InferenceModelWrapper(CV_MODEL_EXPORT_DIR) + load_inference_model.assert_not_called() + + output_image_data = inference_model_wrapper([PNG_BYTES]) + + tf.InteractiveSession.assert_called_with() + session = tf.InteractiveSession.return_value + + png_bytes_to_image_data.assert_called_with(PNG_BYTES) + + load_inference_model.assert_called_with(CV_MODEL_EXPORT_DIR, session=session) + inference_model = load_inference_model.return_value + + inference_model.assert_called_with([ + png_bytes_to_image_data.return_value + ], session=session) + + assert output_image_data == inference_model.return_value + + def test_should_load_model_only_once(self): + with patch.object(cv_conversion_utils, 'load_inference_model') as load_inference_model: + with patch.object(cv_conversion_utils, 'png_bytes_to_image_data') as _: + with patch.object(cv_conversion_utils, 'tf') as _: + inference_model_wrapper = InferenceModelWrapper(CV_MODEL_EXPORT_DIR) + inference_model_wrapper([PNG_BYTES]) + inference_model_wrapper([PNG_BYTES]) + load_inference_model.assert_called_once() diff --git a/sciencebeam_gym/convert/grobid/__init__.py b/sciencebeam_gym/convert/grobid/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03b8fa859e9a15f08ed6f397fdb0f2c918d6e9c5 --- /dev/null +++ b/sciencebeam_gym/convert/grobid/__init__.py @@ -0,0 +1 @@ +# temporary package to retain functionality after moving from ScienceBeam diff --git a/sciencebeam_gym/convert/grobid/grobid_service.py b/sciencebeam_gym/convert/grobid/grobid_service.py new file mode 100644 index 0000000000000000000000000000000000000000..e80fa086c03d12fd49c13a71196f1e7775d5bc35 --- /dev/null +++ b/sciencebeam_gym/convert/grobid/grobid_service.py @@ -0,0 +1,84 @@ +from io import BytesIO +import logging +from functools import partial + +import requests + +from .grobid_service_wrapper import ( + GrobidServiceWrapper +) + +class GrobidApiPaths(object): + PROCESS_HEADER_DOCUMENT = '/processHeaderDocument' + PROCESS_HEADER_NAMES = '/processHeaderNames' + PROCESS_CITATION_NAMES = '/processCitationNames' + PROCESS_AFFILIATIONS = '/processAffiliations' + PROCESS_CITATION = '/processCitation' + PROCESS_FULL_TEXT_DOCUMENT = '/processFulltextDocument' + +service_wrapper = GrobidServiceWrapper() + +def get_logger(): + return logging.getLogger(__name__) + +def start_service_if_not_running(): + service_wrapper.start_service_if_not_running() + +def run_grobid_service(item, base_url, path, start_service=True, field_name=None): + """ + Translates PDF content via 
the GROBID service. + + Args: + item: one of: + * tuple (filename, pdf content) + * pdf content + * field content (requires field name) + base_url: base url to the GROBID service + path: path of the GROBID endpoint + start_service: if true, a GROBID service will be started automatically and + kept running until the application ends + field_name: the field name the field content relates to + + Returns: + If item is tuple: + returns tuple (filename, xml result) + Otherwise: + returns xml result + """ + + url = base_url + path + + if start_service: + start_service_if_not_running() + + if field_name: + content = item + response = requests.post(url, + data={field_name: content} + ) + else: + filename = item[0] if isinstance(item, tuple) else 'unknown.pdf' + content = item[1] if isinstance(item, tuple) else item + get_logger().info('processing: %s (%d) - %s', filename, len(content), url) + response = requests.post(url, + files={'input': (filename, BytesIO(content))}, + data={ + 'consolidateHeader': '0', + 'consolidateCitations': '0' + } + ) + response.raise_for_status() + result_content = response.content + if isinstance(item, tuple): + return filename, result_content + else: + return result_content + +def grobid_service(base_url, path, start_service=True, field_name=None): + return partial( + run_grobid_service, + base_url=base_url, + path=path, + start_service=start_service, + field_name=field_name + ) diff --git a/sciencebeam_gym/convert/grobid/grobid_service_test.py b/sciencebeam_gym/convert/grobid/grobid_service_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fd02bfe06bf9672e174c32096f7b79ae9ed77c1e --- /dev/null +++ b/sciencebeam_gym/convert/grobid/grobid_service_test.py @@ -0,0 +1,67 @@ +from mock import patch, ANY + +import pytest + +from .grobid_service import grobid_service as create_grobid_service + +BASE_URL = 'http://grobid/api' +PATH_1 = '/path1' +PATH_2 = '/path2' + +FILENAME_1 = 'file1.pdf' +PDF_CONTENT_1 = b'pdf content1' + +FIELD_NAME_1 = 'field1' +FIELD_VALUE_1 = 'value1' + +@pytest.fixture(name='requests_post', autouse=True) +def _mock_requests_post(): + with patch('requests.post') as requests_post: + yield requests_post + +class TestCreateGrobidService(object): + def test_should_pass_url_and_data_as_file(self, requests_post): + create_grobid_service(BASE_URL, PATH_1, start_service=False)( + (FILENAME_1, PDF_CONTENT_1) + ) + requests_post.assert_called_with( + BASE_URL + PATH_1, + files=ANY, + data=ANY + ) + kwargs = requests_post.call_args[1] + assert kwargs['files']['input'][0] == FILENAME_1 + assert kwargs['files']['input'][1].read() == PDF_CONTENT_1 + + def test_should_be_able_to_override_path_on_call(self, requests_post): + create_grobid_service(BASE_URL, PATH_1, start_service=False)( + (FILENAME_1, PDF_CONTENT_1), + path=PATH_2 + ) + requests_post.assert_called_with( + BASE_URL + PATH_2, + files=ANY, + data=ANY + ) + + def test_should_pass_data_as_field(self, requests_post): + create_grobid_service(BASE_URL, PATH_1, start_service=False, field_name=FIELD_NAME_1)( + FIELD_VALUE_1 + ) + requests_post.assert_called_with( + BASE_URL + PATH_1, + data={FIELD_NAME_1: FIELD_VALUE_1} + ) + + def test_should_pass_consolidate_flags(self, requests_post): + create_grobid_service(BASE_URL, PATH_1, start_service=False)( + (FILENAME_1, PDF_CONTENT_1) + ) + requests_post.assert_called_with( + BASE_URL + PATH_1, + files=ANY, + data={ + 'consolidateHeader': '0', + 'consolidateCitations': '0' + } + ) diff --git 
a/sciencebeam_gym/convert/grobid/grobid_service_wrapper.py b/sciencebeam_gym/convert/grobid/grobid_service_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..f9e0c8e4ba3b2060b1bfaa9376feefdd06086e27 --- /dev/null +++ b/sciencebeam_gym/convert/grobid/grobid_service_wrapper.py @@ -0,0 +1,126 @@ +import logging +from threading import Thread +from functools import partial +import shlex +import subprocess +from subprocess import PIPE +import atexit +import os +from zipfile import ZipFile +from shutil import rmtree +from urllib import URLopener + +from sciencebeam_gym.utils.io import makedirs +from sciencebeam_gym.utils.zip import extract_all_with_executable_permission + +def get_logger(): + return logging.getLogger(__name__) + +def iter_read_lines(reader): + while True: + line = reader.readline() + if not line: + break + yield line + +def stream_lines_to_logger(lines, logger, prefix=''): + for line in lines: + line = line.strip() + if line: + logger.info('%s%s', prefix, line) + +class GrobidServiceWrapper(object): + def __init__(self): + self.grobid_service_instance = None + temp_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../.temp')) + self.grobid_service_target_directory = os.path.join(temp_dir, 'grobid-service') + self.grobid_service_zip_filename = os.path.join(temp_dir, 'grobid-service.zip') + self.grobid_service_zip_url = ( + 'https://storage.googleapis.com/elife-ml/artefacts/grobid-service.zip' + ) + + def stop_service_if_running(self): + if self.grobid_service_instance is not None: + get_logger().info('stopping instance: %s', self.grobid_service_instance) + self.grobid_service_instance.kill() + + def download__grobid_service_zip_if_not_exist(self): + if not os.path.isfile(self.grobid_service_zip_filename): + get_logger().info( + 'downloading %s to %s', + self.grobid_service_zip_url, + self.grobid_service_zip_filename + ) + + makedirs(os.path.dirname(self.grobid_service_zip_filename), exists_ok=True) + + temp_zip_filename = self.grobid_service_zip_filename + '.part' + if os.path.isfile(temp_zip_filename): + os.remove(temp_zip_filename) + URLopener().retrieve(self.grobid_service_zip_url, temp_zip_filename) + os.rename(temp_zip_filename, self.grobid_service_zip_filename) + + def unzip_grobid_service_zip_if_target_directory_does_not_exist(self): + if not os.path.isdir(self.grobid_service_target_directory): + self.download__grobid_service_zip_if_not_exist() + get_logger().info( + 'unzipping %s to %s', + self.grobid_service_zip_filename, + self.grobid_service_target_directory + ) + temp_target_directory = self.grobid_service_target_directory + '.part' + if os.path.isdir(temp_target_directory): + rmtree(temp_target_directory) + + with ZipFile(self.grobid_service_zip_filename, 'r') as zf: + extract_all_with_executable_permission(zf, temp_target_directory) + sub_dir = os.path.join(temp_target_directory, 'grobid-service') + if os.path.isdir(sub_dir): + os.rename(sub_dir, self.grobid_service_target_directory) + rmtree(temp_target_directory) + else: + os.rename(temp_target_directory, self.grobid_service_target_directory) + + def start_service_if_not_running(self): + get_logger().info('grobid_service_instance: %s', self.grobid_service_instance) + if self.grobid_service_instance is None: + self.unzip_grobid_service_zip_if_target_directory_does_not_exist() + grobid_service_home = os.path.abspath(self.grobid_service_target_directory) + cwd = grobid_service_home + '/bin' + grobid_service_home_jar_dir = grobid_service_home + '/lib' + command_line = 'java 
-cp "{}/*" org.grobid.service.main.GrobidServiceApplication'.format( + grobid_service_home_jar_dir + ) + args = shlex.split(command_line) + get_logger().info('command_line: %s', command_line) + get_logger().info('args: %s', args) + self.grobid_service_instance = subprocess.Popen( + args, cwd=cwd, stdout=PIPE, stderr=subprocess.STDOUT + ) + if self.grobid_service_instance is None: + raise RuntimeError('failed to start grobid service') + atexit.register(self.stop_service_if_running) + pstdout = self.grobid_service_instance.stdout + out_prefix = 'stdout: ' + while True: + line = pstdout.readline().strip() + if line: + get_logger().info('%s%s', out_prefix, line) + if 'jetty.server.Server: Started' in line: + get_logger().info('grobid service started successfully') + break + if 'ERROR' in line or 'Error' in line: + raise RuntimeError('failed to start grobid service due to {}'.format(line)) + t = Thread(target=partial( + stream_lines_to_logger, + lines=iter_read_lines(pstdout), + logger=get_logger(), + prefix=out_prefix + )) + t.daemon = True + t.start() + +if __name__ == '__main__': + logging.basicConfig(level='INFO') + + GrobidServiceWrapper().start_service_if_not_running() diff --git a/sciencebeam_gym/convert/grobid/grobid_xml_enhancer.py b/sciencebeam_gym/convert/grobid/grobid_xml_enhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..26646cb215d0b48f0056cdb22057c072ce9e3a43 --- /dev/null +++ b/sciencebeam_gym/convert/grobid/grobid_xml_enhancer.py @@ -0,0 +1,109 @@ +import logging +from io import BytesIO + +from lxml import etree +from lxml.builder import E + +from sciencebeam_gym.inference_model.extract_to_xml import ( + XmlPaths, + create_node_recursive, + rsplit_xml_path +) + +from .grobid_service import ( + grobid_service, + GrobidApiPaths +) + +TEI_NS = 'http://www.tei-c.org/ns/1.0' +TEI_NS_PREFIX = '{%s}' % TEI_NS +TEI_PERS_NAME = TEI_NS_PREFIX + 'persName' +TEI_FORNAME = TEI_NS_PREFIX + 'forename' +TEI_SURNAME = TEI_NS_PREFIX + 'surname' + +JATS_SURNAME = 'surname' +JATS_GIVEN_NAMES = 'given-names' +JATS_ADDR_LINE = 'addr-line' +JATS_NAMED_CONTENT = 'named-content' +JATS_INSTITUTION = 'institution' + +def get_logger(): + return logging.getLogger(__name__) + +def create_or_append(xml_root, path): + parent_path, tag_name = rsplit_xml_path(path) + parent_node = create_node_recursive(xml_root, parent_path, exists_ok=True) + node = E(tag_name) + parent_node.append(node) + return node + +class GrobidXmlEnhancer(object): + def __init__(self, grobid_url, start_service): + self.process_header_names = grobid_service( + grobid_url, + GrobidApiPaths.PROCESS_HEADER_NAMES, + start_service=start_service, + field_name='names' + ) + self.process_affiliations = grobid_service( + grobid_url, + GrobidApiPaths.PROCESS_AFFILIATIONS, + start_service=start_service, + field_name='affiliations' + ) + + def process_and_replace_authors(self, xml_root): + author_nodes = list(xml_root.findall(XmlPaths.AUTHOR)) + if author_nodes: + authors = '\n'.join(x.text for x in author_nodes) + get_logger().debug('authors: %s', authors) + grobid_response = self.process_header_names(authors) + get_logger().debug('grobid_response: %s', grobid_response) + response_xml_root = etree.parse(BytesIO('<dummy>%s</dummy>' % grobid_response)) + for author in author_nodes: + author.getparent().remove(author) + for pers_name in response_xml_root.findall(TEI_PERS_NAME): + get_logger().debug('pers_name: %s', pers_name) + node = create_or_append(xml_root, XmlPaths.AUTHOR) + for surname in 
pers_name.findall(TEI_SURNAME): + node.append(E(JATS_SURNAME, surname.text)) + forenames = [x.text for x in pers_name.findall(TEI_FORNAME)] + if forenames: + node.append(E(JATS_GIVEN_NAMES, ' '.join(forenames))) + return xml_root + + def process_and_replace_affiliations(self, xml_root): + aff_nodes = list(xml_root.findall(XmlPaths.AUTHOR_AFF)) + if aff_nodes: + affiliations = '\n'.join(x.text for x in aff_nodes) + get_logger().debug('affiliations: %s', affiliations) + grobid_response = self.process_affiliations(affiliations) + get_logger().debug('grobid_response: %s', grobid_response) + response_xml_root = etree.parse(BytesIO('<dummy>%s</dummy>' % grobid_response)) + for aff in aff_nodes: + aff.getparent().remove(aff) + for affiliation in response_xml_root.findall('affiliation'): + get_logger().debug('affiliation: %s', affiliation) + node = create_or_append(xml_root, XmlPaths.AUTHOR_AFF) + for department in affiliation.xpath('./orgName[@type="department"]'): + node.append(E( + JATS_ADDR_LINE, + E( + JATS_NAMED_CONTENT, + department.text, + { + 'content-type': 'department' + } + ) + )) + for institution in affiliation.xpath('./orgName[@type="institution"]'): + node.append(E( + JATS_INSTITUTION, + institution.text + )) + + def __call__(self, extracted_xml): + xml_root = etree.parse(BytesIO(extracted_xml)) + self.process_and_replace_authors(xml_root) + self.process_and_replace_affiliations(xml_root) + return etree.tostring(xml_root, pretty_print=True) diff --git a/sciencebeam_gym/convert/grobid/grobid_xml_enhancer_test.py b/sciencebeam_gym/convert/grobid/grobid_xml_enhancer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e2b705240cb1d40b413776c17be193254319fa31 --- /dev/null +++ b/sciencebeam_gym/convert/grobid/grobid_xml_enhancer_test.py @@ -0,0 +1,207 @@ +import logging +from io import BytesIO +from contextlib import contextmanager +from mock import patch, Mock + +from lxml import etree +from lxml.builder import E + +from sciencebeam_gym.inference_model.extract_to_xml import ( + XmlPaths, + create_xml_text +) + +from .grobid_service import ( + GrobidApiPaths +) + +from . 
import grobid_xml_enhancer as grobid_xml_enhancer +from .grobid_xml_enhancer import ( + GrobidXmlEnhancer +) + +GROBID_URL = 'http://localhost:8080/api' + +TEXT_1 = 'text 1' +TEXT_2 = 'text 2' + +FORENAME_1 = 'forename 1' +FORENAME_2 = 'forename 2' +SURNAME_1 = 'surname 1' +SURNAME_2 = 'surname 2' + +DEPARTMENT_1 = 'department 1' +DEPARTMENT_2 = 'department 2' +INSTITUTION_1 = 'institution 1' +INSTITUTION_2 = 'institution 2' + +DEPARTMENT_XPATH = './addr-line/named-content[@content-type="department"]' +INSTITUTION_XPATH = './institution' + +def get_logger(): + return logging.getLogger(__name__) + +def setup_module(): + logging.basicConfig(level='DEBUG') + +def pers_name(*names): + forenames = names[:-1] + surname = names[-1] + return ( + '<persName xmlns="http://www.tei-c.org/ns/1.0">' + ' %s' + ' <surname>%s</surname>' + '</persName>' + ) % ( + ' '.join( + '<forename type="%s">%s</forename>' % ( + 'first' if i == 0 else 'middle', + forename + ) + for i, forename in enumerate(forenames) + ), + surname + ) + +def tei_affiliation(department=None, institution=None): + affiliation = E.affiliation() + if department: + affiliation.append(E.orgName(department, type='department')) + if institution: + affiliation.append(E.orgName(institution, type='institution')) + return etree.tostring(affiliation) + +def get_text(node): + if isinstance(node, list): + return ''.join(get_text(x) for x in node) + return node.text if node is not None else None + +def get_child_text(node, name): + return get_text(node.find(name)) + +@contextmanager +def patch_grobid_service(): + with patch.object(grobid_xml_enhancer, 'grobid_service') as grobid_service: + process_header_names = Mock() + process_affiliations = Mock() + grobid_service.side_effect = [process_header_names, process_affiliations] + yield process_header_names, process_affiliations + +class TestGrobidXmlEnhancer(object): + def test_should_initialise_grobid_service(self): + with patch.object(grobid_xml_enhancer, 'grobid_service') as grobid_service: + GrobidXmlEnhancer(GROBID_URL, start_service=False) + grobid_service.assert_any_call( + GROBID_URL, GrobidApiPaths.PROCESS_HEADER_NAMES, start_service=False, field_name='names' + ) + grobid_service.assert_any_call( + GROBID_URL, GrobidApiPaths.PROCESS_AFFILIATIONS, start_service=False, + field_name='affiliations' + ) + + def test_should_convert_single_author(self): + logging.basicConfig(level='DEBUG') + with patch_grobid_service() as (process_header_names, process_affiliations): + _ = process_affiliations + process_header_names.return_value = pers_name(FORENAME_1, SURNAME_1) + enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False) + xml_root = E.article() + create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_1) + enhanced_xml = enhancer(etree.tostring(xml_root)) + get_logger().info('enhanced_xml: %s', enhanced_xml) + enhanced_xml_root = etree.parse(BytesIO(enhanced_xml)) + authors = enhanced_xml_root.findall(XmlPaths.AUTHOR) + assert [ + (get_child_text(author, 'given-names'), get_child_text(author, 'surname')) + for author in authors + ] == [(FORENAME_1, SURNAME_1)] + + def test_should_convert_single_affiliation(self): + logging.basicConfig(level='DEBUG') + with patch_grobid_service() as (process_header_names, process_affiliations): + _ = process_header_names + process_affiliations.return_value = tei_affiliation( + department=DEPARTMENT_1, + institution=INSTITUTION_1 + ) + enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False) + xml_root = E.article() + create_xml_text(xml_root, XmlPaths.AUTHOR_AFF, 
TEXT_1) + enhanced_xml = enhancer(etree.tostring(xml_root)) + get_logger().info('enhanced_xml: %s', enhanced_xml) + enhanced_xml_root = etree.parse(BytesIO(enhanced_xml)) + affiliations = enhanced_xml_root.findall(XmlPaths.AUTHOR_AFF) + assert [ + get_text(x.xpath(DEPARTMENT_XPATH)) for x in affiliations + ] == [DEPARTMENT_1] + assert [ + get_text(x.xpath(INSTITUTION_XPATH)) for x in affiliations + ] == [INSTITUTION_1] + + def test_should_convert_multiple_author(self): + logging.basicConfig(level='DEBUG') + with patch_grobid_service() as (process_header_names, process_affiliations): + _ = process_affiliations + process_header_names.return_value = ( + pers_name(FORENAME_1, SURNAME_1) + + pers_name(FORENAME_2, SURNAME_2) + ) + enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False) + xml_root = E.article() + create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_1) + create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_2) + enhanced_xml = enhancer(etree.tostring(xml_root)) + get_logger().info('enhanced_xml: %s', enhanced_xml) + enhanced_xml_root = etree.parse(BytesIO(enhanced_xml)) + authors = enhanced_xml_root.findall(XmlPaths.AUTHOR) + assert [ + (get_child_text(author, 'given-names'), get_child_text(author, 'surname')) + for author in authors + ] == [(FORENAME_1, SURNAME_1), (FORENAME_2, SURNAME_2)] + + def test_should_convert_multiple_affiliations(self): + logging.basicConfig(level='DEBUG') + with patch_grobid_service() as (process_header_names, process_affiliations): + _ = process_header_names + process_affiliations.return_value = ( + tei_affiliation( + department=DEPARTMENT_1, + institution=INSTITUTION_1 + ) + + tei_affiliation( + department=DEPARTMENT_2, + institution=INSTITUTION_2 + ) + ) + enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False) + xml_root = E.article() + create_xml_text(xml_root, XmlPaths.AUTHOR_AFF, TEXT_1) + enhanced_xml = enhancer(etree.tostring(xml_root)) + get_logger().info('enhanced_xml: %s', enhanced_xml) + enhanced_xml_root = etree.parse(BytesIO(enhanced_xml)) + affiliations = enhanced_xml_root.findall(XmlPaths.AUTHOR_AFF) + assert [ + get_text(x.xpath(DEPARTMENT_XPATH)) for x in affiliations + ] == [DEPARTMENT_1, DEPARTMENT_2] + assert [ + get_text(x.xpath(INSTITUTION_XPATH)) for x in affiliations + ] == [INSTITUTION_1, INSTITUTION_2] + + def test_should_combine_multiple_forenames(self): + logging.basicConfig(level='DEBUG') + with patch_grobid_service() as (process_header_names, process_affiliations): + _ = process_affiliations + process_header_names.return_value = pers_name( + FORENAME_1, FORENAME_2, SURNAME_1 + ) + enhancer = GrobidXmlEnhancer(GROBID_URL, start_service=False) + xml_root = E.article() + create_xml_text(xml_root, XmlPaths.AUTHOR, TEXT_1) + enhanced_xml = enhancer(etree.tostring(xml_root)) + get_logger().info('enhanced_xml: %s', enhanced_xml) + enhanced_xml_root = etree.parse(BytesIO(enhanced_xml)) + authors = enhanced_xml_root.findall(XmlPaths.AUTHOR) + assert [ + (get_child_text(author, 'given-names'), get_child_text(author, 'surname')) + for author in authors + ] == [(' '.join([FORENAME_1, FORENAME_2]), SURNAME_1)]