-
Daniel Ecer authored 9bc7ff91
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
preprocessing_pipeline_test.py 14.59 KiB
from contextlib import contextmanager
import logging
from mock import Mock, patch, DEFAULT, MagicMock
import pytest
import apache_beam as beam
from sciencebeam_gym.utils.collection import (
extend_dict
)
from sciencebeam_gym.beam_utils.utils import (
TransformAndLog
)
from sciencebeam_gym.beam_utils.testing import (
BeamTest,
TestPipeline,
get_current_test_context,
get_counter_value
)
from sciencebeam_gym.preprocess.preprocessing_pipeline import (
parse_args,
configure_pipeline,
MetricCounters
)
# Dotted path of the module under test; used as the target for mock patching.
PREPROCESSING_PIPELINE = 'sciencebeam_gym.preprocess.preprocessing_pipeline'

# Data location and the glob patterns for PDF / XML inputs.
BASE_DATA_PATH = 'base'
PDF_PATH, XML_PATH = '*/*.pdf', '*/*.xml'

# Two example PDF / XML pairs, each in its own sub-directory.
PDF_FILE_1, XML_FILE_1 = '1/file.pdf', '1/file.xml'
PDF_FILE_2, XML_FILE_2 = '2/file.pdf', '2/file.xml'

# Example name for a PDF/XML pair list file.
PDF_XML_FILE_LIST_FILE_1 = 'pdf-xml-files.tsv'
def get_logger():
    """Return the logger named after this module."""
    return logging.getLogger(name=__name__)
def fake_content(path):
    """Return a deterministic placeholder string standing in for the content at ``path``."""
    template = 'fake content: %s'
    return template % path
def fake_lxml_for_pdf(pdf, path, page_range=None):
    """Return a placeholder LXML string that embeds all of its arguments.

    Embedding ``page_range`` lets tests see which page range was requested.
    """
    parts = (pdf, path, page_range)
    return 'fake lxml for pdf: %s (%s) [%s]' % parts
# Named functions rather than lambdas assigned to names (PEP 8 / E731):
# they get real ``__name__`` values and can carry docstrings.
def fake_svg_page(i=0):
    """Return the placeholder SVG page content for page number ``i``."""
    return 'fake svg page: %d' % i


def fake_pdf_png_page(i=0):
    """Return the placeholder rendered-PDF PNG content for page number ``i``."""
    return 'fake pdf png page: %d' % i


def fake_block_png_page(i=0):
    """Return the placeholder blockified PNG content for page number ``i``."""
    return 'fake block png page: %d' % i
def get_global_tfrecords_mock():
    """Return the ``tfrecords_mock`` attribute of the current Beam test context.

    Fetching it from the test context, rather than closing over the mock, is a
    workaround for the mock being serialized/deserialized before it is invoked.
    """
    test_context = get_current_test_context()
    return test_context.tfrecords_mock
@contextmanager
def patch_preprocessing_pipeline(**kwargs):
    """Patch the preprocessing pipeline's collaborators for the duration of a test.

    Every name in ``always_mock`` is patched on the pipeline module. Names not
    given in ``kwargs`` are replaced by mocks created by ``patch.multiple``
    (via ``DEFAULT``). Any keyword argument overrides the replacement for that
    attribute name, including names outside ``always_mock``.
    ``WritePropsToTFRecord`` is replaced by a transform that forwards the
    extracted props to a ``tfrecords`` mock instead of writing files.

    Args:
        **kwargs: replacement objects keyed by the module attribute to patch.

    Yields:
        dict: the mocks created by ``patch.multiple``, keyed by name, plus
        ``'tfrecords'``, which is called with ``(file_path, props_list)``.
    """
    always_mock = {
        'find_file_pairs_grouped_by_parent_directory_or_name',
        'read_all_from_path',
        'pdf_bytes_to_png_pages',
        'convert_pdf_bytes_to_lxml',
        'convert_and_annotate_lxml_content',
        'svg_page_to_blockified_png_bytes',
        'save_svg_roots',
        'save_pages',
        'evaluate_document_by_page',
        'ReadDictCsv'
    }

    tfrecords_mock = Mock(name='tfrecords_mock')

    def DummyWritePropsToTFRecord(file_path, extract_props):
        # Same factory signature as the real transform, but records each
        # element's extracted props on tfrecords_mock instead of writing them.
        return TransformAndLog(beam.Map(
            lambda v: tfrecords_mock(file_path, list(extract_props(v)))
        ), log_fn=lambda x: get_logger().info('tfrecords: %s', x))

    patch_targets = dict.fromkeys(always_mock, DEFAULT)
    patch_targets['WritePropsToTFRecord'] = DummyWritePropsToTFRecord
    # Caller-supplied replacements take precedence. Previously any keyword
    # whose name was not in always_mock was silently ignored.
    patch_targets.update(kwargs)

    with patch.multiple(PREPROCESSING_PIPELINE, **patch_targets) as mocks:
        test_context = get_current_test_context()
        test_context.mocks = mocks
        # Make the mock reachable through get_global_tfrecords_mock(), which
        # reads it from the test context.
        test_context.tfrecords_mock = tfrecords_mock
        # patch.multiple only returns the mocks it created itself (DEFAULT).
        # Only install the default fakes for names the caller did not
        # override; indexing them unconditionally raised KeyError.
        if 'read_all_from_path' in mocks:
            mocks['read_all_from_path'].side_effect = fake_content
        if 'convert_pdf_bytes_to_lxml' in mocks:
            mocks['convert_pdf_bytes_to_lxml'].side_effect = fake_lxml_for_pdf
        yield extend_dict(
            mocks,
            {'tfrecords': tfrecords_mock}
        )
# Minimal command line accepted by parse_args: input locations plus a single
# output option (--save-svg).
MIN_ARGV = [
    '--data-path=%s' % BASE_DATA_PATH,
    '--pdf-path=%s' % PDF_PATH,
    '--xml-path=%s' % XML_PATH,
    '--save-svg',
]
def get_default_args():
    """Return parsed options for the minimal command line ``MIN_ARGV``.

    Reuses ``MIN_ARGV`` instead of repeating the same argument list, so the
    two cannot drift apart. A copy is passed so the shared module-level list
    is never handed to the parser directly.
    """
    return parse_args(list(MIN_ARGV))
def _expected_tfrecord_props(pdf_file, page_no=1):
    """Build the props dict expected in the tfrecord output for one page of ``pdf_file``.

    The annotation URI is the PDF path with an ``.annot`` suffix. The images
    are the fake PDF-PNG and blockified-PNG placeholders for ``page_no``.
    """
    return dict(
        input_uri=pdf_file,
        annotation_uri=pdf_file + '.annot',
        input_image=fake_pdf_png_page(page_no),
        annotation_image=fake_block_png_page(page_no),
    )
def _setup_mocks_for_pages(mocks, page_no_list, file_count=1):
    """Configure the page-level mocks so each processed file yields ``page_no_list``.

    ``convert_and_annotate_lxml_content`` and ``pdf_bytes_to_png_pages`` return
    the same page list on every call. ``svg_page_to_blockified_png_bytes`` is
    given one result per page for each of ``file_count`` files, which
    ``side_effect`` hands out in order.
    """
    svg_pages = [fake_svg_page(page_no) for page_no in page_no_list]
    png_pages = [fake_pdf_png_page(page_no) for page_no in page_no_list]
    block_pages = [fake_block_png_page(page_no) for page_no in page_no_list]

    mocks['convert_and_annotate_lxml_content'].return_value = svg_pages
    mocks['pdf_bytes_to_png_pages'].return_value = png_pages
    # Repeat the per-file page sequence once per expected file.
    mocks['svg_page_to_blockified_png_bytes'].side_effect = block_pages * file_count
class TestConfigurePipeline(BeamTest):
    """Tests for ``configure_pipeline`` run against ``patch_preprocessing_pipeline`` mocks.

    Mock assertions are made after the ``with TestPipeline()`` block has
    exited, i.e. after the pipeline has had the chance to run.
    """

    def test_should_pass_pdf_and_xml_patterns_to_find_file_pairs_grouped_by_parent_directory(self):
        with patch_preprocessing_pipeline() as mocks:
            opt = get_default_args()
            opt.base_data_path = 'base'
            opt.pdf_path = 'pdf'
            opt.xml_path = 'xml'
            with TestPipeline() as p:
                mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = []
                configure_pipeline(p, opt)

            mocks['find_file_pairs_grouped_by_parent_directory_or_name'].assert_called_with(
                ['base/pdf', 'base/xml']
            )

    def test_should_pass_lxml_and_xml_patterns_to_find_file_pairs_grouped_by_parent_directory(self):
        with patch_preprocessing_pipeline() as mocks:
            opt = get_default_args()
            opt.base_data_path = 'base'
            opt.pdf_path = ''
            opt.lxml_path = 'lxml'
            opt.xml_path = 'xml'
            with TestPipeline() as p:
                mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = []
                configure_pipeline(p, opt)

            mocks['find_file_pairs_grouped_by_parent_directory_or_name'].assert_called_with(
                ['base/lxml', 'base/xml']
            )

    def test_should_write_tfrecords_from_pdf_xml_file_list(self):
        with patch_preprocessing_pipeline() as mocks:
            opt = get_default_args()
            opt.pdf_path = None
            opt.xml_path = None
            opt.pdf_xml_file_list = '.temp/file-list.tsv'
            opt.save_tfrecords = True
            with TestPipeline() as p:
                mocks['ReadDictCsv'].return_value = beam.Create([{
                    'pdf_url': PDF_FILE_1,
                    'xml_url': XML_FILE_1
                }])
                _setup_mocks_for_pages(mocks, [1])
                configure_pipeline(p, opt)

            mocks['ReadDictCsv'].assert_called_with(opt.pdf_xml_file_list, limit=None)
            mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [
                _expected_tfrecord_props(PDF_FILE_1)
            ])

    def test_should_write_multiple_tfrecords_from_pdf_xml_file_list(self):
        with patch_preprocessing_pipeline() as mocks:
            opt = get_default_args()
            opt.pdf_path = None
            opt.xml_path = None
            opt.pdf_xml_file_list = '.temp/file-list.tsv'
            opt.save_tfrecords = True
            with TestPipeline() as p:
                mocks['ReadDictCsv'].return_value = beam.Create([{
                    'pdf_url': PDF_FILE_1,
                    'xml_url': XML_FILE_1
                }, {
                    'pdf_url': PDF_FILE_2,
                    'xml_url': XML_FILE_2
                }])
                _setup_mocks_for_pages(mocks, [1], file_count=2)
                configure_pipeline(p, opt)

            mocks['ReadDictCsv'].assert_called_with(opt.pdf_xml_file_list, limit=None)
            for pdf_file in [PDF_FILE_1, PDF_FILE_2]:
                mocks['tfrecords'].assert_any_call(opt.output_path + '/data', [
                    _expected_tfrecord_props(pdf_file)
                ])
            assert mocks['tfrecords'].call_count == 2

    def test_should_pass_limit_to_read_dict_csv(self):
        with patch_preprocessing_pipeline() as mocks:
            opt = get_default_args()
            opt.pdf_path = None
            opt.xml_path = None
            opt.pdf_xml_file_list = '.temp/file-list.tsv'
            opt.limit = 1
            opt.save_tfrecords = True
            with TestPipeline() as p:
                mocks['ReadDictCsv'].return_value = beam.Create([{
                    'pdf_url': PDF_FILE_1,
                    'xml_url': XML_FILE_1
                }])
                _setup_mocks_for_pages(mocks, [1])
                configure_pipeline(p, opt)

            mocks['ReadDictCsv'].assert_called_with(opt.pdf_xml_file_list, limit=opt.limit)
            assert mocks['tfrecords'].call_count == 1

    def test_should_pass_limit_to_find_file_pairs_grouped_by_parent_directory_or_name(self):
        with patch_preprocessing_pipeline() as mocks:
            opt = get_default_args()
            opt.base_data_path = 'base'
            opt.pdf_path = ''
            opt.lxml_path = 'lxml'
            opt.xml_path = 'xml'
            opt.save_tfrecords = True
            opt.limit = 1
            with TestPipeline() as p:
                mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [
                    (PDF_FILE_1, XML_FILE_1),
                    (PDF_FILE_2, XML_FILE_2)
                ]
                # Page mocks for a single file, matching the limit of 1
                # (same setup as the ReadDictCsv limit test).
                _setup_mocks_for_pages(mocks, [1])
                configure_pipeline(p, opt)

            # Was a bare comparison (a no-op statement); it must be asserted.
            assert mocks['tfrecords'].call_count == 1

    def test_should_write_tfrecords_from_pdf_xml_path(self):
        with patch_preprocessing_pipeline() as mocks:
            opt = get_default_args()
            opt.save_tfrecords = True
            with TestPipeline() as p:
                mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [
                    (PDF_FILE_1, XML_FILE_1)
                ]
                _setup_mocks_for_pages(mocks, [1])
                configure_pipeline(p, opt)

            mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [
                _expected_tfrecord_props(PDF_FILE_1)
            ])

    def test_should_write_multiple_tfrecords_and_count_pages(self):
        with patch_preprocessing_pipeline() as mocks:
            opt = get_default_args()
            opt.save_tfrecords = True
            with TestPipeline() as p:
                mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [
                    (PDF_FILE_1, XML_FILE_1)
                ]
                _setup_mocks_for_pages(mocks, [1, 2])
                configure_pipeline(p, opt)
                p_result = p.run()

            assert get_counter_value(p_result, MetricCounters.FILE_PAIR) == 1
            assert get_counter_value(p_result, MetricCounters.PAGE) == 2
            assert get_counter_value(p_result, MetricCounters.FILTERED_PAGE) is None
            mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [
                _expected_tfrecord_props(PDF_FILE_1, page_no=i)
                for i in [1, 2]
            ])

    def test_should_not_write_tfrecord_below_annotation_threshold_and_count_pages(self):
        custom_mocks = dict(
            evaluate_document_by_page=lambda _: [{
                'percentage': {
                    # low percentage of None (mostly annotated, include)
                    None: 0.1
                }
            }, {
                'percentage': {
                    # high percentage of None (mostly unannotated, exclude)
                    None: 0.9
                }
            }]
        )
        with patch_preprocessing_pipeline(**custom_mocks) as mocks:
            opt = get_default_args()
            opt.save_tfrecords = True
            opt.min_annotation_percentage = 0.5
            with TestPipeline() as p:
                mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [
                    (PDF_FILE_1, XML_FILE_1)
                ]
                _setup_mocks_for_pages(mocks, [1, 2])
                configure_pipeline(p, opt)
                p_result = p.run()

            assert get_counter_value(p_result, MetricCounters.FILE_PAIR) == 1
            assert get_counter_value(p_result, MetricCounters.PAGE) == 2
            assert get_counter_value(p_result, MetricCounters.FILTERED_PAGE) == 1
            mocks['tfrecords'].assert_called_with(opt.output_path + '/data', [
                _expected_tfrecord_props(PDF_FILE_1, page_no=i)
                for i in [1]
            ])

    def test_should_only_process_selected_pages(self):
        with patch_preprocessing_pipeline() as mocks:
            opt = get_default_args()
            opt.save_tfrecords = True
            opt.save_png = True
            opt.pages = (1, 3)
            with TestPipeline() as p:
                mocks['find_file_pairs_grouped_by_parent_directory_or_name'].return_value = [
                    (PDF_FILE_1, XML_FILE_1)
                ]
                _setup_mocks_for_pages(mocks, [1, 2])
                configure_pipeline(p, opt)

            assert mocks['convert_pdf_bytes_to_lxml'].called
            assert mocks['convert_pdf_bytes_to_lxml'].call_args[1].get('page_range') == opt.pages
            assert mocks['pdf_bytes_to_png_pages'].called
            assert mocks['pdf_bytes_to_png_pages'].call_args[1].get('page_range') == opt.pages
class TestParseArgs(object):
    """Command-line validation rules enforced by ``parse_args``."""

    @staticmethod
    def _argv(*options):
        # Most cases share the same leading data path; options follow in order.
        return ['--data-path=test'] + list(options)

    def test_should_raise_error_without_arguments(self):
        with pytest.raises(SystemExit):
            parse_args([])

    def test_should_not_raise_error_with_minimum_arguments(self):
        parse_args(self._argv('--pdf-path=test', '--xml-path=test', '--save-svg'))

    def test_should_not_raise_error_with_lxml_path_instead_of_pdf_path(self):
        parse_args(self._argv('--lxml-path=test', '--xml-path=test', '--save-svg'))

    def test_should_raise_error_if_no_output_option_specified(self):
        with pytest.raises(SystemExit):
            parse_args(self._argv('--pdf-path=test', '--xml-path=test'))

    def test_should_raise_error_if_pdf_and_lxml_path_are_specified(self):
        with pytest.raises(SystemExit):
            parse_args(self._argv(
                '--pdf-path=test', '--lxml-path=test', '--xml-path=test', '--save-svg'
            ))

    def test_should_raise_error_if_pdf_path_specified_without_xml_path(self):
        with pytest.raises(SystemExit):
            parse_args(self._argv('--pdf-path=test', '--save-svg'))

    def test_should_not_raise_error_if_pdf_xml_file_list_specified_without_xml_path(self):
        parse_args(self._argv('--pdf-xml-file-list=test', '--save-svg'))

    def test_should_not_raise_error_with_save_lxml_path_together_with_pdf_path(self):
        parse_args(self._argv('--pdf-path=test', '--save-lxml', '--xml-path=test'))

    def test_should_not_raise_error_with_save_lxml_path_together_with_pdf_xml_file_list(self):
        parse_args(self._argv('--pdf-xml-file-list=test', '--save-lxml', '--xml-path=test'))

    def test_should_raise_error_if_save_lxml_specified_without_pdf_path(self):
        with pytest.raises(SystemExit):
            parse_args(self._argv('--lxml-path=test', '--save-lxml', '--xml-path=test'))

    def test_should_raise_error_if_save_png_is_specified_without_pdf_path(self):
        with pytest.raises(SystemExit):
            parse_args(self._argv('--lxml-path=test', '--save-png', '--xml-path=test'))

    def test_should_not_raise_error_with_save_png_path_together_with_pdf_path(self):
        parse_args(self._argv('--pdf-path=test', '--save-png', '--xml-path=test'))

    def test_should_not_raise_error_with_save_png_path_together_with_pdf_xml_file_list(self):
        parse_args(self._argv('--pdf-xml-file-list=test', '--save-png', '--xml-path=test'))

    def test_should_raise_error_if_image_width_was_specified_without_image_height(self):
        with pytest.raises(SystemExit):
            parse_args(self._argv(
                '--pdf-path=test', '--xml-path=test',
                '--save-png', '--image-width=100'
            ))

    def test_should_raise_error_if_image_height_was_specified_without_image_width(self):
        with pytest.raises(SystemExit):
            parse_args(self._argv(
                '--pdf-path=test', '--xml-path=test',
                '--save-png', '--image-height=100'
            ))

    def test_should_not_raise_error_if_both_image_width_and_height_are_specified(self):
        parse_args(self._argv(
            '--pdf-path=test', '--xml-path=test',
            '--save-png', '--image-width=100', '--image-height=100'
        ))

    def test_should_raise_error_if_save_tfrecords_specified_without_pdf_path(self):
        with pytest.raises(SystemExit):
            parse_args(self._argv('--lxml-path=test', '--xml-path=test', '--save-tfrecords'))

    def test_should_not_raise_error_if_save_tfrecords_specified_with_pdf_path(self):
        parse_args(self._argv('--pdf-path=test', '--xml-path=test', '--save-tfrecords'))

    def test_should_not_raise_error_if_save_tfrecords_specified_with_pdf_xml_file_list(self):
        parse_args(self._argv(
            '--pdf-xml-file-list=test', '--xml-path=test', '--save-tfrecords'
        ))

    def test_should_have_none_page_range_by_default(self):
        parsed = parse_args(MIN_ARGV)
        assert parsed.pages is None

    def test_should_parse_pages_as_list(self):
        # Despite the test name, the expected value is a (start, end) tuple.
        argv = MIN_ARGV + ['--pages=1-3']
        assert parse_args(argv).pages == (1, 3)