diff --git a/sciencebeam_gym/preprocess/annotation/matching_annotator.py b/sciencebeam_gym/preprocess/annotation/matching_annotator.py
index 83cc9655b3a957290cbf6605500cef518d2b0b77..b148828c547d0a42a3b391e51f41b839486a83f5 100644
--- a/sciencebeam_gym/preprocess/annotation/matching_annotator.py
+++ b/sciencebeam_gym/preprocess/annotation/matching_annotator.py
@@ -73,9 +73,9 @@ class SequenceWrapper(object):
     self.str_filter_f = str_filter_f
     self.tokens = tokens
     self.token_str_list = [structured_document.get_text(t) or '' for t in tokens]
-    self.tokens_as_str = ' '.join(self.token_str_list)
     if str_filter_f:
-      self.tokens_as_str = str_filter_f(self.tokens_as_str)
+      self.token_str_list = [str_filter_f(s) for s in self.token_str_list]
+    self.tokens_as_str = ' '.join(self.token_str_list)
 
   def tokens_between(self, index_range):
     start, end = index_range
diff --git a/sciencebeam_gym/preprocess/annotation/matching_annotator_test.py b/sciencebeam_gym/preprocess/annotation/matching_annotator_test.py
index 1af4e3bcd3b0f8549a6e7fc73bd75c269ebfaf23..0bfb597279f8e53d067489966a03f621fbba49e1 100644
--- a/sciencebeam_gym/preprocess/annotation/matching_annotator_test.py
+++ b/sciencebeam_gym/preprocess/annotation/matching_annotator_test.py
@@ -15,6 +15,7 @@ from sciencebeam_gym.preprocess.annotation.target_annotation import (
 from sciencebeam_gym.preprocess.annotation.matching_annotator import (
   normalise_str,
   MatchingAnnotator,
+  SequenceWrapper,
   THIN_SPACE,
   EN_DASH,
   EM_DASH
@@ -72,6 +73,31 @@ class TestNormaliseStr(object):
   def test_should_replace_en_dash_with_hyphen(self):
     assert normalise_str(EN_DASH) == '-'
 
+class TestSequenceWrapper(object):
+  def test_should_find_all_tokens_without_str_filter(self):
+    text = 'this is matching'
+    tokens = _tokens_for_text(text)
+    doc = _document_for_tokens([tokens])
+    seq = SequenceWrapper(doc, tokens)
+    assert str(seq) == text
+    assert list(seq.tokens_between((0, len(text)))) == tokens
+
+  def test_should_find_tokens_between_without_str_filter(self):
+    text = 'this is matching'
+    tokens = _tokens_for_text(text)
+    doc = _document_for_tokens([tokens])
+    seq = SequenceWrapper(doc, tokens)
+    assert str(seq) == text
+    assert list(seq.tokens_between((6, 7))) == [tokens[1]]
+
+  def test_should_find_tokens_between_adjusted_indices_due_to_str_filter(self):
+    text = 'this is matching'
+    tokens = _tokens_for_text(text)
+    doc = _document_for_tokens([tokens])
+    seq = SequenceWrapper(doc, tokens, str_filter_f=lambda s: s.replace('th', ''))
+    assert str(seq) == 'is is matching'
+    assert list(seq.tokens_between((4, 5))) == [tokens[1]]
+
 class TestMatchingAnnotator(object):
   def test_should_not_fail_on_empty_document(self):
     doc = SimpleStructuredDocument(lines=[])