From bc0c20852d7977c4df30897c7a968c94828d8d85 Mon Sep 17 00:00:00 2001
From: Daniel Ecer <de-code@users.noreply.github.com>
Date: Thu, 1 Feb 2018 08:54:48 +0000
Subject: [PATCH] fixed token indices incorrect when str filter removes characters

---
 .../annotation/matching_annotator.py      |  4 ++--
 .../annotation/matching_annotator_test.py | 26 +++++++++++++++++++
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/sciencebeam_gym/preprocess/annotation/matching_annotator.py b/sciencebeam_gym/preprocess/annotation/matching_annotator.py
index 83cc965..b148828 100644
--- a/sciencebeam_gym/preprocess/annotation/matching_annotator.py
+++ b/sciencebeam_gym/preprocess/annotation/matching_annotator.py
@@ -73,9 +73,9 @@ class SequenceWrapper(object):
     self.str_filter_f = str_filter_f
     self.tokens = tokens
     self.token_str_list = [structured_document.get_text(t) or '' for t in tokens]
-    self.tokens_as_str = ' '.join(self.token_str_list)
     if str_filter_f:
-      self.tokens_as_str = str_filter_f(self.tokens_as_str)
+      self.token_str_list = [str_filter_f(s) for s in self.token_str_list]
+    self.tokens_as_str = ' '.join(self.token_str_list)
 
   def tokens_between(self, index_range):
     start, end = index_range
diff --git a/sciencebeam_gym/preprocess/annotation/matching_annotator_test.py b/sciencebeam_gym/preprocess/annotation/matching_annotator_test.py
index 1af4e3b..0bfb597 100644
--- a/sciencebeam_gym/preprocess/annotation/matching_annotator_test.py
+++ b/sciencebeam_gym/preprocess/annotation/matching_annotator_test.py
@@ -15,6 +15,7 @@ from sciencebeam_gym.preprocess.annotation.target_annotation import (
 from sciencebeam_gym.preprocess.annotation.matching_annotator import (
   normalise_str,
   MatchingAnnotator,
+  SequenceWrapper,
   THIN_SPACE,
   EN_DASH,
   EM_DASH
@@ -72,6 +73,31 @@ class TestNormaliseStr(object):
   def test_should_replace_en_dash_with_hyphen(self):
     assert normalise_str(EN_DASH) == '-'
 
+class TestSequenceWrapper(object):
+  def test_should_find_all_tokens_without_str_filter(self):
+    text = 'this is matching'
+    tokens = _tokens_for_text(text)
+    doc = _document_for_tokens([tokens])
+    seq = SequenceWrapper(doc, tokens)
+    assert str(seq) == text
+    assert list(seq.tokens_between((0, len(text)))) == tokens
+
+  def test_should_find_tokens_between_without_str_filter(self):
+    text = 'this is matching'
+    tokens = _tokens_for_text(text)
+    doc = _document_for_tokens([tokens])
+    seq = SequenceWrapper(doc, tokens)
+    assert str(seq) == text
+    assert list(seq.tokens_between((6, 7))) == [tokens[1]]
+
+  def test_should_find_tokens_between_adjusted_indices_due_to_str_filter(self):
+    text = 'this is matching'
+    tokens = _tokens_for_text(text)
+    doc = _document_for_tokens([tokens])
+    seq = SequenceWrapper(doc, tokens, str_filter_f=lambda s: s.replace('th', ''))
+    assert str(seq) == 'is is matching'
+    assert list(seq.tokens_between((4, 5))) == [tokens[1]]
+
 class TestMatchingAnnotator(object):
   def test_should_not_fail_on_empty_document(self):
     doc = SimpleStructuredDocument(lines=[])
--
GitLab
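
A minimal, standalone Python sketch of the behaviour change: it models token offsets with a simple space-join helper (the token_offsets function below is illustrative only and is not part of the sciencebeam_gym API) to show why filtering the joined string shifts the indices that tokens_between relies on, while filtering each token string first keeps them aligned.

# Sketch: index shift caused by filtering after joining vs. filtering per token.

def token_offsets(token_strs):
  # (start, end) offset of each token within ' '.join(token_strs)
  offsets, pos = [], 0
  for s in token_strs:
    offsets.append((pos, pos + len(s)))
    pos += len(s) + 1  # account for the joining space
  return offsets

token_strs = ['this', 'is', 'matching']
strip_th = lambda s: s.replace('th', '')

# old behaviour: join first, then filter the whole string;
# offsets were computed before filtering, so they no longer line up
old_text = strip_th(' '.join(token_strs))   # 'is is matching'
old_offsets = token_offsets(token_strs)     # [(0, 4), (5, 7), (8, 16)]
print(old_text[slice(*old_offsets[1])])     # ' m' -> second token not found

# new behaviour (what the patch does): filter each token, then join
new_tokens = [strip_th(s) for s in token_strs]
new_text = ' '.join(new_tokens)             # 'is is matching'
new_offsets = token_offsets(new_tokens)     # [(0, 2), (3, 5), (6, 14)]
print(new_text[slice(*new_offsets[1])])     # 'is' -> indices line up

This mirrors the added test: with str_filter_f removing 'th', the filtered tokens are joined to 'is is matching' and the character range (4, 5) still maps back to the second token.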