diff --git a/sciencebeam_gym/inference_model/extract_from_annotated_document.py b/sciencebeam_gym/inference_model/extract_from_annotated_document.py index 2bb3fd3868567d0bb388fd0249c8a596eef93209..aa06c4c2650818aa1ca1703a5c76458c4cf6265c 100644 --- a/sciencebeam_gym/inference_model/extract_from_annotated_document.py +++ b/sciencebeam_gym/inference_model/extract_from_annotated_document.py @@ -8,16 +8,20 @@ def get_logger(): return logging.getLogger(__name__) class ExtractedItem(object): - def __init__(self, tag, text, tag_prefix=None): + def __init__(self, tag, text, tokens=None, tag_prefix=None, sub_items=None): self.tag = tag self.tag_prefix = tag_prefix self.text = text + self.tokens = tokens or [] + self.sub_items = sub_items or [] def extend(self, other_item): return ExtractedItem( self.tag, self.text + '\n' + other_item.text, - tag_prefix=self.tag_prefix + tokens=self.tokens + other_item.tokens, + tag_prefix=self.tag_prefix, + sub_items=self.sub_items + other_item.sub_items ) def get_lines(structured_document): @@ -25,12 +29,14 @@ def get_lines(structured_document): for line in structured_document.get_lines_of_page(page): yield line -def extract_from_annotated_tokens(structured_document, tokens, tag_scope=None): +def extract_from_annotated_tokens(structured_document, tokens, tag_scope=None, level=None): previous_tokens = [] previous_tag = None previous_tag_prefix = None for token in tokens: - tag_prefix, tag = structured_document.get_tag_prefix_and_value(token, scope=tag_scope) + tag_prefix, tag = structured_document.get_tag_prefix_and_value( + token, scope=tag_scope, level=level + ) if not previous_tokens: previous_tokens = [token] previous_tag = tag @@ -41,6 +47,7 @@ def extract_from_annotated_tokens(structured_document, tokens, tag_scope=None): yield ExtractedItem( previous_tag, ' '.join(structured_document.get_text(t) for t in previous_tokens), + tokens=previous_tokens, tag_prefix=previous_tag_prefix ) previous_tokens = [token] @@ -50,9 +57,22 @@ def extract_from_annotated_tokens(structured_document, tokens, tag_scope=None): yield ExtractedItem( previous_tag, ' '.join(structured_document.get_text(t) for t in previous_tokens), - tag_prefix=previous_tag_prefix + tokens=previous_tokens, + tag_prefix=previous_tag_prefix ) +def with_sub_items(structured_document, extracted_item, tag_scope=None): + return ExtractedItem( + extracted_item.tag, + extracted_item.text, + tokens=extracted_item.tokens, + tag_prefix=extracted_item.tag_prefix, + sub_items=list(extract_from_annotated_tokens( + structured_document, extracted_item.tokens, + tag_scope=tag_scope, level=2 + )) + ) + def extract_from_annotated_lines(structured_document, lines, tag_scope=None): previous_item = None for line in lines: @@ -62,12 +82,12 @@ def extract_from_annotated_lines(structured_document, lines, tag_scope=None): if previous_item.tag == item.tag and item.tag_prefix != B_TAG_PREFIX: previous_item = previous_item.extend(item) else: - yield previous_item + yield with_sub_items(structured_document, previous_item, tag_scope=tag_scope) previous_item = item else: previous_item = item if previous_item is not None: - yield previous_item + yield with_sub_items(structured_document, previous_item, tag_scope=tag_scope) def extract_from_annotated_document(structured_document, tag_scope=None): return extract_from_annotated_lines( diff --git a/sciencebeam_gym/inference_model/extract_from_annotated_document_test.py b/sciencebeam_gym/inference_model/extract_from_annotated_document_test.py index c46a8fa9fe9783553af3efa87af7273e099c3d35..847f0e87081039e9be1dbe3760bfe71d76913e45 100644 --- a/sciencebeam_gym/inference_model/extract_from_annotated_document_test.py +++ b/sciencebeam_gym/inference_model/extract_from_annotated_document_test.py @@ -53,6 +53,12 @@ def annotated_tokens(tokens, tag): def annotated_line(tokens, tag): return with_tag(to_line(tokens), tag) +def _token_with_sub_tag(text, tag=None, tag_prefix=None, sub_tag=None, sub_tag_prefix=None): + token = SimpleToken(text, tag=tag, tag_prefix=tag_prefix) + if sub_tag: + token.set_tag(sub_tag, prefix=sub_tag_prefix, level=2) + return token + class TestExtractFromAnnotatedDocument(object): def test_should_not_fail_on_empty_document(self): structured_document = SimpleStructuredDocument() @@ -153,3 +159,98 @@ class TestExtractFromAnnotatedDocument(object): (TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3])), (TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3])) ] + + def test_should_extract_sub_tags_from_single_item(self): + tokens = [ + _token_with_sub_tag( + VALUE_1, + tag=TAG_1, tag_prefix=B_TAG_PREFIX, + sub_tag=TAG_2, sub_tag_prefix=B_TAG_PREFIX + ), + _token_with_sub_tag( + VALUE_2, + tag=TAG_1, tag_prefix=I_TAG_PREFIX, + sub_tag=TAG_2, sub_tag_prefix=I_TAG_PREFIX + ), + _token_with_sub_tag( + VALUE_3, + tag=TAG_1, tag_prefix=I_TAG_PREFIX, + sub_tag=TAG_3, sub_tag_prefix=B_TAG_PREFIX + ) + ] + structured_document = SimpleStructuredDocument(lines=[SimpleLine(tokens)]) + extracted_items = list(extract_from_annotated_document(structured_document)) + result = [ + ( + x.tag, x.text, [(sub.tag, sub.text) for sub in x.sub_items] + ) for x in extracted_items + ] + get_logger().debug('result: %s', result) + assert result == [ + ( + TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]), + [ + (TAG_2, ' '.join([VALUE_1, VALUE_2])), + (TAG_3, VALUE_3) + ] + ) + ] + + def test_should_extract_sub_tags_from_multiple_items(self): + tokens = [ + _token_with_sub_tag( + VALUE_1, + tag=TAG_1, tag_prefix=B_TAG_PREFIX, + sub_tag=TAG_2, sub_tag_prefix=B_TAG_PREFIX + ), + _token_with_sub_tag( + VALUE_2, + tag=TAG_1, tag_prefix=I_TAG_PREFIX, + sub_tag=TAG_2, sub_tag_prefix=I_TAG_PREFIX + ), + _token_with_sub_tag( + VALUE_3, + tag=TAG_1, tag_prefix=I_TAG_PREFIX, + sub_tag=TAG_3, sub_tag_prefix=B_TAG_PREFIX + ), + + _token_with_sub_tag( + VALUE_1, + tag=TAG_1, tag_prefix=B_TAG_PREFIX, + sub_tag=TAG_2, sub_tag_prefix=B_TAG_PREFIX + ), + _token_with_sub_tag( + VALUE_2, + tag=TAG_1, tag_prefix=I_TAG_PREFIX, + sub_tag=TAG_3, sub_tag_prefix=B_TAG_PREFIX + ), + _token_with_sub_tag( + VALUE_3, + tag=TAG_1, tag_prefix=I_TAG_PREFIX, + sub_tag=TAG_3, sub_tag_prefix=I_TAG_PREFIX + ) + ] + structured_document = SimpleStructuredDocument(lines=[SimpleLine(tokens)]) + extracted_items = list(extract_from_annotated_document(structured_document)) + result = [ + ( + x.tag, x.text, [(sub.tag, sub.text) for sub in x.sub_items] + ) for x in extracted_items + ] + get_logger().debug('result: %s', result) + assert result == [ + ( + TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]), + [ + (TAG_2, ' '.join([VALUE_1, VALUE_2])), + (TAG_3, VALUE_3) + ] + ), + ( + TAG_1, ' '.join([VALUE_1, VALUE_2, VALUE_3]), + [ + (TAG_2, VALUE_1), + (TAG_3, ' '.join([VALUE_2, VALUE_3])) + ] + ) + ] diff --git a/sciencebeam_gym/inference_model/extract_to_xml.py b/sciencebeam_gym/inference_model/extract_to_xml.py index bd66ee97b25b8509c5378bdf956873af45a99c39..1696056a23bb96ac34a185da55cdba9c39594c9f 100644 --- a/sciencebeam_gym/inference_model/extract_to_xml.py +++ b/sciencebeam_gym/inference_model/extract_to_xml.py @@ -28,6 +28,14 @@ class XmlPaths(object): AUTHOR = 'front/article-meta/contrib-group/contrib/name' AUTHOR_AFF = 'front/article-meta/contrib-group/aff' +class SubTags(object): + AUTHOR_SURNAME = 'surname' + AUTHOR_GIVEN_NAMES = 'givennames' + +class SubXmlPaths(object): + AUTHOR_SURNAME = 'surname' + AUTHOR_GIVEN_NAMES = 'given-names' + def get_logger(): return logging.getLogger(__name__) @@ -53,27 +61,29 @@ def create_node_recursive(xml_root, path, exists_ok=False): parent_node.append(node) return node -def create_xml_text(xml_root, path, text): +def create_and_append_xml_node(xml_root, path): parent, base = rsplit_xml_path(path) - parent_node = create_node_recursive(xml_root, parent, exists_ok=True) + parent_node = ( + create_node_recursive(xml_root, parent, exists_ok=True) + if parent + else xml_root + ) node = etree.Element(base) - node.text = text parent_node.append(node) return node +def create_xml_text(xml_root, path, text): + node = create_and_append_xml_node(xml_root, path) + node.text = text + return node + class XmlMapping(object): - def __init__(self, xml_path, single_node=False): + def __init__(self, xml_path, single_node=False, sub_mapping=None): self.xml_path = xml_path self.single_node = single_node + self.sub_mapping = sub_mapping -def extracted_items_to_xml(extracted_items): - xml_mapping = { - Tags.TITLE: XmlMapping(XmlPaths.TITLE, single_node=True), - Tags.ABSTRACT: XmlMapping(XmlPaths.ABSTRACT, single_node=True), - Tags.AUTHOR: XmlMapping(XmlPaths.AUTHOR), - Tags.AUTHOR_AFF: XmlMapping(XmlPaths.AUTHOR_AFF) - } - xml_root = E.article() +def _extract_items(parent_node, extracted_items, xml_mapping): previous_tag = None for extracted_item in extracted_items: tag = extracted_item.tag @@ -83,8 +93,11 @@ def extracted_items_to_xml(extracted_items): get_logger().warning('tag not configured: %s', tag) continue path = mapping_entry.xml_path - if mapping_entry.single_node: - node = create_node_recursive(xml_root, path, exists_ok=True) + if extracted_item.sub_items and mapping_entry.sub_mapping: + node = create_and_append_xml_node(parent_node, path) + _extract_items(node, extracted_item.sub_items, mapping_entry.sub_mapping) + elif mapping_entry.single_node: + node = create_node_recursive(parent_node, path, exists_ok=True) if node.text is None: node.text = extracted_item.text elif previous_tag == tag: @@ -92,8 +105,21 @@ def extracted_items_to_xml(extracted_items): else: get_logger().debug('ignoring tag %s, after tag %s', tag, previous_tag) else: - create_xml_text(xml_root, path, extracted_item.text) + create_xml_text(parent_node, path, extracted_item.text) previous_tag = tag + +def extracted_items_to_xml(extracted_items): + xml_mapping = { + Tags.TITLE: XmlMapping(XmlPaths.TITLE, single_node=True), + Tags.ABSTRACT: XmlMapping(XmlPaths.ABSTRACT, single_node=True), + Tags.AUTHOR: XmlMapping(XmlPaths.AUTHOR, sub_mapping={ + SubTags.AUTHOR_GIVEN_NAMES: XmlMapping(SubXmlPaths.AUTHOR_GIVEN_NAMES), + SubTags.AUTHOR_SURNAME: XmlMapping(SubXmlPaths.AUTHOR_SURNAME) + }), + Tags.AUTHOR_AFF: XmlMapping(XmlPaths.AUTHOR_AFF) + } + xml_root = E.article() + _extract_items(xml_root, extracted_items, xml_mapping) return xml_root def extract_structured_document_to_xml(structured_document, tag_scope=None): diff --git a/sciencebeam_gym/inference_model/extract_to_xml_test.py b/sciencebeam_gym/inference_model/extract_to_xml_test.py index f2552578b15ef3d7b0aadc5ce85c7fd0094e50f5..ce006d22f51fa5d582676d63790e6ec01abb19b8 100644 --- a/sciencebeam_gym/inference_model/extract_to_xml_test.py +++ b/sciencebeam_gym/inference_model/extract_to_xml_test.py @@ -18,6 +18,8 @@ from sciencebeam_gym.inference_model.extract_to_xml import ( extracted_items_to_xml, Tags, XmlPaths, + SubTags, + SubXmlPaths, main ) @@ -71,6 +73,19 @@ class TestExtractedItemsToXml(object): assert xml_root is not None assert get_text_content_list(xml_root.findall(XmlPaths.AUTHOR)) == [TEXT_1, TEXT_2] + def test_should_extract_author_surname_and_given_names_from_single_author(self): + xml_root = extracted_items_to_xml([ + ExtractedItem(Tags.AUTHOR, TEXT_1, sub_items=[ + ExtractedItem(SubTags.AUTHOR_GIVEN_NAMES, TEXT_2), + ExtractedItem(SubTags.AUTHOR_SURNAME, TEXT_3) + ]) + ]) + assert xml_root is not None + author = xml_root.find(XmlPaths.AUTHOR) + assert author is not None + assert get_text_content(author.find(SubXmlPaths.AUTHOR_GIVEN_NAMES)) == TEXT_2 + assert get_text_content(author.find(SubXmlPaths.AUTHOR_SURNAME)) == TEXT_3 + class TestMain(object): def test_should_extract_from_simple_annotated_document(self): with TemporaryDirectory() as path: diff --git a/sciencebeam_gym/structured_document/__init__.py b/sciencebeam_gym/structured_document/__init__.py index 79b0c0c40d736c6237e6cc63eb84ea7c77f250dd..6f8ccf1a9b4782e6476b68a98c2d1007ab81ba5f 100644 --- a/sciencebeam_gym/structured_document/__init__.py +++ b/sciencebeam_gym/structured_document/__init__.py @@ -87,8 +87,8 @@ class AbstractStructuredDocument(object, with_metaclass(ABCMeta)): other_structured_document, other_token ) - def get_tag_prefix_and_value(self, parent, scope=None): - return split_tag_prefix(self.get_tag(parent, scope=scope)) + def get_tag_prefix_and_value(self, parent, scope=None, level=None): + return split_tag_prefix(self.get_tag(parent, scope=scope, level=level)) def get_tag_value(self, parent, scope=None): return self.get_tag_prefix_and_value(parent, scope=scope)[1]