# Module: sciencebeam_gym/models/text/crf/annotate_using_predictions.py
from itertools import repeat


def _iter_tokens(structured_document):
    """Yield every token of the document, walking pages, then lines, then tokens."""
    for page in structured_document.get_pages():
        for line in structured_document.get_lines_of_page(page):
            for token in structured_document.get_tokens_of_line(line):
                yield token


def annotate_structured_document_using_predictions(
        structured_document, predictions, token_props_list=None):
    """
    Annotates the structured document using the predicted tags.

    Tokens, predictions and (if given) token props are consumed in lockstep,
    so iteration stops as soon as the shortest of them is exhausted.
    Falsy predictions (e.g. None or an empty string) leave the token untouched.

    Args:
      structured_document: the document that will be tagged
      predictions: list of predicted tags
      token_props_list: optional, used to verify that the correct token is being tagged

    Raises:
      AssertionError: if the text of a token differs from the 'text' of the
        token props at the same position
    """
    if token_props_list is None:
        # No verification requested: pair every token with an empty placeholder.
        token_props_list = repeat(None)

    for token, prediction, token_props in zip(
            _iter_tokens(structured_document), predictions, token_props_list):

        if token_props:
            actual_text = structured_document.get_text(token)
            expected_text = token_props['text']
            if actual_text != expected_text:
                # Raised explicitly (rather than via an `assert` statement) so the
                # alignment check is not stripped when Python runs with -O, while
                # callers that catch AssertionError keep working.
                raise AssertionError(
                    'token text %r does not match token props text %r' % (
                        actual_text, expected_text
                    )
                )

        if prediction:
            structured_document.set_tag(token, prediction)


def predict_and_annotate_structured_document(structured_document, model):
    """
    Predicts the tags for the structured document and annotates it with them.

    The token props extracted from the document are converted to features and
    passed to the model as a batch containing a single item; the first item of
    the returned predictions is then used to tag the document. The same token
    props are passed on so that the token alignment is verified while tagging.

    Args:
      structured_document: the document that will be tagged
      model: object with a `predict` method that accepts a list of feature lists
    """
    # Imported here rather than at module level so that annotating from
    # precomputed predictions does not require loading the feature extractor.
    from sciencebeam_gym.models.text.feature_extractor import (
        structured_document_to_token_props,
        token_props_list_to_features
    )

    token_props = list(structured_document_to_token_props(structured_document))
    x = token_props_list_to_features(token_props)
    y_pred = model.predict([x])[0]
    annotate_structured_document_using_predictions(structured_document, y_pred, token_props)
# Module: sciencebeam_gym/models/text/crf/annotate_using_predictions_test.py
from mock import MagicMock

import pytest

from sciencebeam_gym.structured_document import (
    SimpleStructuredDocument,
    SimplePage,
    SimpleLine,
    SimpleToken
)

from sciencebeam_gym.models.text.feature_extractor import (
    structured_document_to_token_props,
    token_props_list_to_features
)

from sciencebeam_gym.utils.bounding_box import (
    BoundingBox
)

from sciencebeam_gym.models.text.crf.annotate_using_predictions import (
    annotate_structured_document_using_predictions,
    predict_and_annotate_structured_document
)

TAG_1 = 'tag1'

TOKEN_TEXT_1 = 'token 1'
TOKEN_TEXT_2 = 'token 2'

BOUNDING_BOX = BoundingBox(0, 0, 10, 10)


def _create_bounded_token_and_document():
    """Return a token with a bounding box and a one-page document holding only that token."""
    token = SimpleToken(TOKEN_TEXT_1, bounding_box=BOUNDING_BOX)
    page = SimplePage(
        lines=[SimpleLine([token])],
        bounding_box=BOUNDING_BOX
    )
    return token, SimpleStructuredDocument(page)


class TestAnnotateStructuredDocumentUsingPredictions(object):
    """Tests for annotate_structured_document_using_predictions."""

    def test_should_not_fail_with_empty_document(self):
        # An empty document paired with no predictions must simply be a no-op.
        annotate_structured_document_using_predictions(SimpleStructuredDocument(), [])

    def test_should_tag_single_token_using_prediction(self):
        token = SimpleToken(TOKEN_TEXT_1)
        document = SimpleStructuredDocument(lines=[SimpleLine([token])])
        annotate_structured_document_using_predictions(document, [TAG_1])
        assert document.get_tag(token) == TAG_1

    def test_should_tag_single_token_using_prediction_and_check_token_props(self):
        token, document = _create_bounded_token_and_document()
        # Deliberately passed on without wrapping it in a list first.
        matching_token_props = structured_document_to_token_props(document)
        annotate_structured_document_using_predictions(
            document, [TAG_1], matching_token_props
        )
        assert document.get_tag(token) == TAG_1

    def test_should_raise_error_if_token_props_do_not_match(self):
        _, document = _create_bounded_token_and_document()
        mismatching_token_props = list(structured_document_to_token_props(document))
        # Make the props disagree with the text of the only token.
        mismatching_token_props[0]['text'] = TOKEN_TEXT_2
        with pytest.raises(AssertionError):
            annotate_structured_document_using_predictions(
                document, [TAG_1], mismatching_token_props
            )


class TestPredictAndAnnotateStructuredDocument(object):
    """Tests for predict_and_annotate_structured_document."""

    def test_should_predict_and_annotate_single_token(self):
        token, document = _create_bounded_token_and_document()
        model = MagicMock()
        model.predict.return_value = [[TAG_1]]
        # The expected model input is computed before the document gets annotated.
        expected_token_props = list(structured_document_to_token_props(document))
        expected_model_input = [token_props_list_to_features(expected_token_props)]
        predict_and_annotate_structured_document(document, model)
        assert document.get_tag(token) == TAG_1
        model.predict.assert_called_with(expected_model_input)