from __future__ import absolute_import
import os
import logging
from io import BytesIO
from functools import reduce # pylint: disable=W0622
from six import iteritems
from lxml import etree
from apache_beam.io.filesystems import FileSystems
from sciencebeam_gym.utils.xml import (
xml_from_string_with_recover
)
from sciencebeam_gym.utils.stopwatch import (
StopWatchRecorder
)
from sciencebeam_gym.utils.collection import (
groupby_to_dict,
sort_and_groupby_to_dict
)
from sciencebeam_gym.utils.pages_zip import (
    save_pages
)
# NOTE(review): the original `from ... import (` line for find_matching_filenames
# was lost in extraction; module path reconstructed - confirm against the repo
from sciencebeam_gym.beam_utils.io import (
    find_matching_filenames
)
from sciencebeam_gym.utils.file_path import (
relative_path
)
from sciencebeam_gym.preprocess.lxml_to_svg import (
iter_svg_pages_for_lxml
)
from sciencebeam_gym.structured_document.svg import (
SvgStructuredDocument
)
from sciencebeam_gym.preprocess.annotation.annotator import (
Annotator,
DEFAULT_ANNOTATORS
)
from sciencebeam_gym.alignment.align import (
native_enabled as align_native_enabled
)
from sciencebeam_gym.preprocess.annotation.matching_annotator import (
MatchingAnnotator
)
from sciencebeam_gym.preprocess.annotation.target_annotation import (
xml_root_to_target_annotations
)
from sciencebeam_gym.preprocess.visualize_svg_annotation import (
visualize_svg_annotations
)
from sciencebeam_gym.preprocess.blockify_annotations import (
annotation_document_page_to_annotation_blocks,
merge_blocks,
expand_blocks,
annotated_blocks_to_image
)
from sciencebeam_gym.pdf import (
PdfToLxmlWrapper,
PdfToPng
)
# deprecated, moved to sciencebeam_gym.utils.file_path
# pylint: disable=wrong-import-position, unused-import
from sciencebeam_gym.utils.file_path import (
join_if_relative_path,
)
# pylint: enable=wrong-import-position, unused-import
def get_logger():
    """Return the logger for this module (looked up lazily, not cached)."""
    return logging.getLogger(__name__)
def group_files_by_parent_directory(filenames):
    """Group *filenames* into a dict keyed by their parent directory."""
    return groupby_to_dict(sorted(filenames), os.path.dirname)
def get_ext(filename):
    """Return the extension of *filename*, treating '.gz' as part of a
    compound extension (e.g. 'a.xml.gz' -> '.xml.gz')."""
    stem, ext = os.path.splitext(filename)
    if ext != '.gz':
        return ext
    # recurse to pick up the extension preceding the '.gz'
    return get_ext(stem) + ext
def strip_ext(filename):
    """Return *filename* without its extension; a trailing '.gz' is dropped
    together with the extension before it (e.g. 'a.xml.gz' -> 'a')."""
    # strip off '.gz' first, assuming another extension precedes it
    base = filename[:-3] if filename.endswith('.gz') else filename
    return os.path.splitext(base)[0]
def group_files_by_name_excl_ext(filenames):
    """Group *filenames* into a dict keyed by the filename without its
    extension (see strip_ext)."""
    return sort_and_groupby_to_dict(filenames, strip_ext)
def zip_by_keys(*dict_list):
    """Yield, for every key appearing in any of *dict_list* (in sorted key
    order), the list of values each dict holds for that key (None where a
    dict lacks the key)."""
    all_keys = set()
    for d in dict_list:
        all_keys.update(d.keys())
    return (
        [d.get(key) for d in dict_list]
        for key in sorted(all_keys)
    )
def group_file_pairs_by_parent_directory_or_name(files_by_type):
    """Yield tuples pairing up files across the file lists in *files_by_type*.

    Files are first grouped by parent directory; groups that do not contain
    exactly one file per type fall back to grouping by filename excluding
    the extension. Groups that still do not match exclusively are logged
    and skipped.

    :param files_by_type: one list of filenames per file type
    :return: generator of tuples, one filename per type
    """
    grouped_files_by_pattern = [
        group_files_by_parent_directory(files) for files in files_by_type
    ]
    for files_in_group_by_pattern in zip_by_keys(*grouped_files_by_pattern):
        if all(len(files or []) == 1 for files in files_in_group_by_pattern):
            # directory grouping was unambiguous
            yield tuple(files[0] for files in files_in_group_by_pattern)
        else:
            # fall back to matching by name (excluding extension)
            grouped_by_name = [
                group_files_by_name_excl_ext(files or [])
                for files in files_in_group_by_pattern
            ]
            for files_by_name in zip_by_keys(*grouped_by_name):
                if all(len(files or []) == 1 for files in files_by_name):
                    yield tuple(files[0] for files in files_by_name)
                else:
                    get_logger().info(
                        'no exclusively matching files found: %s',
                        list(files_by_name)
                    )
def find_file_pairs_grouped_by_parent_directory_or_name(patterns, limit=None):
    """Find files matching *patterns* and pair them up by parent directory
    or by filename excluding the extension.

    :param patterns: one file pattern per file type
    :param limit: unused here - NOTE(review): presumably kept for interface
        compatibility with callers; confirm
    :return: generator of tuples, one filename per type
    :raises RuntimeError: if any pattern matches no files at all
    """
    files_per_pattern = [
        list(find_matching_filenames(pattern)) for pattern in patterns
    ]
    get_logger().info(
        'found number of files %s',
        ', '.join(
            '%s: %d' % (pattern, len(files))
            for pattern, files in zip(patterns, files_per_pattern)
        )
    )
    patterns_without_files = [
        pattern
        for pattern, files in zip(patterns, files_per_pattern)
        if not files
    ]
    if patterns_without_files:
        raise RuntimeError('no files found for: %s' % patterns_without_files)
    return group_file_pairs_by_parent_directory_or_name(files_per_pattern)
def convert_pdf_bytes_to_lxml(pdf_content, path=None, page_range=None):
    """Convert PDF bytes to LXML content via PdfToLxmlWrapper.

    :param pdf_content: raw bytes of the PDF document
    :param path: source path, used for logging only
    :param page_range: optional (first, last) page tuple restricting conversion
    :return: the LXML content returned by PdfToLxmlWrapper().process_input
    """
    # restored: the recorder must be created before the first .start() call
    # (the assignment was lost in the corrupted source)
    stop_watch_recorder = StopWatchRecorder()
    # pdftoxml-style command line arguments - TODO confirm against
    # PdfToLxmlWrapper's expected argument format
    args = '-blocks -noImageInline -noImage -fullFontName'.split()
    if page_range:
        args += ['-f', str(page_range[0]), '-l', str(page_range[1])]
    stop_watch_recorder.start('convert to lxml')
    lxml_content = PdfToLxmlWrapper().process_input(
        pdf_content,
        args
    )
    stop_watch_recorder.stop()
    get_logger().info(
        'converted to lxml: path=%s, pdf size=%s, lxml size=%s, timings=[%s]',
        path, format(len(pdf_content), ','), format(len(lxml_content), ','),
        stop_watch_recorder
    )
    return lxml_content
def convert_and_annotate_lxml_content(lxml_content, xml_content, xml_mapping, name=None):
    """Convert LXML page content to SVG pages annotated against the target XML.

    Parses the LXML and target XML, extracts target annotations, runs the
    default annotators plus a MatchingAnnotator over the SVG pages and adds
    annotation visualisation.

    :param lxml_content: LXML content of the document (bytes/str for etree)
    :param xml_content: target XML content to match against
    :param xml_mapping: mapping passed to xml_root_to_target_annotations
    :param name: document name, used for logging only
    :return: list of annotated SVG root elements, one per page
    """
    stop_watch_recorder = StopWatchRecorder()
    stop_watch_recorder.start('parse lxml')
    lxml_root = etree.fromstring(lxml_content)
    # use a more lenient way to parse xml as xml errors are not uncommon
    stop_watch_recorder.start('parse xml')
    xml_root = xml_from_string_with_recover(xml_content)
    stop_watch_recorder.start('extract target annotations')
    target_annotations = xml_root_to_target_annotations(
        xml_root,
        xml_mapping
    )
    stop_watch_recorder.stop()
    annotators = DEFAULT_ANNOTATORS + [MatchingAnnotator(
        target_annotations,
        use_tag_begin_prefix=True
    )]
    annotator = Annotator(annotators)
    stop_watch_recorder.start('convert to svg')
    svg_roots = list(iter_svg_pages_for_lxml(lxml_root))
    stop_watch_recorder.start('annotate svg')
    # annotate mutates the svg roots in place via the structured document
    annotator.annotate(SvgStructuredDocument(svg_roots))
    stop_watch_recorder.start('add visualisation')
    svg_roots = [
        visualize_svg_annotations(svg_root)
        for svg_root in svg_roots
    ]
    stop_watch_recorder.stop()
    get_logger().info(
        'processed: name=%s, lxml size=%s, xml size=%s, timings=[%s] (native align impl=%s)',
        name, format(len(lxml_content), ','), format(len(xml_content), ','),
        stop_watch_recorder, align_native_enabled
    )
    return svg_roots
def change_ext(path, old_ext, new_ext):
    """Replace *old_ext* at the end of *path* with *new_ext*.

    When *old_ext* is None it is derived from *path*, treating '.gz' as part
    of a compound extension (e.g. 'a.xml.gz' -> 'a' + new_ext). When *path*
    does not end with *old_ext*, *new_ext* is simply appended.
    """
    if old_ext is None:
        old_ext = os.path.splitext(path)[1]
        if old_ext == '.gz':
            path = path[:-len(old_ext)]
            old_ext = os.path.splitext(path)[1]
    has_old_ext = bool(old_ext) and path.endswith(old_ext)
    stripped = path[:-len(old_ext)] if has_old_ext else path
    return stripped + new_ext
def base_path_for_file_list(file_list):
    """Return the common parent path of *file_list* (handles both '/' and
    '\\' separators), or '' when there is no common directory."""
    prefix = os.path.commonprefix(file_list)
    sep_index = max(prefix.rfind('/'), prefix.rfind('\\'))
    return prefix[:sep_index] if sep_index >= 0 else ''
def get_or_validate_base_path(file_list, base_path):
    """Return *base_path* if given, otherwise the computed common path of
    *file_list*.

    :raises AssertionError: if *base_path* is given but is not a prefix of
        the files' common path
    """
    common_path = base_path_for_file_list(file_list)
    if not base_path:
        return common_path
    if not common_path.startswith(base_path):
        raise AssertionError(
            "invalid base path '%s', common path is: '%s'" % (base_path, common_path)
        )
    return base_path
def get_output_file(filename, source_base_path, output_base_path, output_file_suffix):
    """Map a source *filename* to its output path: keep the path relative to
    *source_base_path*, re-root it under *output_base_path* and swap the
    extension for *output_file_suffix*."""
    relative_filename = relative_path(source_base_path, filename)
    output_filename = change_ext(relative_filename, None, output_file_suffix)
    return FileSystems.join(output_base_path, output_filename)
def save_svg_roots(output_filename, svg_pages):
    """Serialise *svg_pages* and save them as '.svg' entries via save_pages."""
    serialised_pages = (etree.tostring(svg_page) for svg_page in svg_pages)
    return save_pages(output_filename, '.svg', serialised_pages)
def pdf_bytes_to_png_pages(pdf_bytes, dpi, image_size, page_range=None):
    """Render each page of *pdf_bytes* to PNG bytes.

    :param pdf_bytes: raw bytes of the PDF document
    :param dpi: rendering resolution
    :param image_size: target image size, passed through to PdfToPng
    :param page_range: optional (first, last) page tuple
    :return: generator yielding the PNG bytes of each page
    """
    pdf_to_png = PdfToPng(dpi=dpi, image_size=image_size, page_range=page_range)
    return (
        fp.read()
        for fp in pdf_to_png.iter_pdf_bytes_to_png_fp(pdf_bytes)
    )
def svg_page_to_blockified_png_bytes(svg_page, color_map, image_size=None):
    """Render the annotated blocks of a single SVG page as PNG bytes.

    The page's annotations are converted to blocks, merged and expanded,
    then drawn onto a white background sized from the page's viewBox
    (optionally scaled to *image_size*).

    :raises RuntimeError: if the svg root has no 'viewBox' attribute
    """
    structured_document = SvgStructuredDocument(svg_page)
    first_page = structured_document.get_pages()[0]
    page_blocks = annotation_document_page_to_annotation_blocks(
        structured_document, first_page
    )
    blocks = expand_blocks(merge_blocks(page_blocks))
    viewbox = svg_page.attrib.get('viewBox')
    if not viewbox:
        raise RuntimeError(
            'viewbox missing on svg, available attributes: %s' % svg_page.attrib.keys()
        )
    _, _, width, height = viewbox.split()
    image = annotated_blocks_to_image(
        blocks, color_map,
        width=float(width), height=float(height), background='white',
        scale_to_size=image_size
    )
    out = BytesIO()
    image.save(out, 'png')
    return out.getvalue()
def filter_list_props_by_indices(d, indices, list_props):
    """Return a copy of dict *d* where every value whose key is in
    *list_props* is filtered down to the elements at *indices*; all other
    values are kept unchanged.

    :param d: source dict
    :param indices: collection of indices to keep (membership-tested, so a
        set is most efficient)
    :param list_props: keys whose sequence values should be filtered
    """
    # dict.items() behaves identically here on Python 2 and 3,
    # so six.iteritems is unnecessary
    return {
        k: (
            [x for i, x in enumerate(v) if i in indices]
            if k in list_props
            else v
        )
        for k, v in d.items()
    }
def get_page_indices_with_min_annotation_percentage(
        annotation_evaluation, min_annotation_percentage):
    """Return the indices of pages whose untagged fraction is small enough
    to leave at least *min_annotation_percentage* annotated.

    Each entry of *annotation_evaluation* is expected to carry a
    'percentage' mapping whose None key holds the untagged fraction
    - NOTE(review): confirm against the evaluator producing these dicts.
    """
    max_untagged = 1 - min_annotation_percentage
    selected_indices = []
    for index, page_evaluation in enumerate(annotation_evaluation):
        if page_evaluation['percentage'].get(None) <= max_untagged:
            selected_indices.append(index)
    return selected_indices
def parse_page_range(s):
    """Parse a page range string like '3' or '2-5' into a (first, last)
    tuple; return None for a blank string.

    :raises TypeError: if more than two dash-separated numbers are given
    :raises ValueError: if a component is not an integer (propagated from int)
    """
    s = s.strip()
    if not s:
        return None
    parts = tuple(int(part) for part in s.split('-'))
    if len(parts) == 1:
        return (parts[0], parts[0])
    if len(parts) == 2:
        return parts
    raise TypeError('invalid page range: %s' % s)