From a7c9dd7c68c651c59392a456bce0696edae9c689 Mon Sep 17 00:00:00 2001 From: Daniel Ecer <de-code@users.noreply.github.com> Date: Mon, 8 Jan 2018 20:15:12 +0000 Subject: [PATCH] added main to annotation of lxml document using crf model; ignore 'none tag' --- .../text/crf/annotate_using_predictions.py | 74 ++++++++++++++++++- .../crf/annotate_using_predictions_test.py | 12 ++- 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/sciencebeam_gym/models/text/crf/annotate_using_predictions.py b/sciencebeam_gym/models/text/crf/annotate_using_predictions.py index 180e8c0..e809e7c 100644 --- a/sciencebeam_gym/models/text/crf/annotate_using_predictions.py +++ b/sciencebeam_gym/models/text/crf/annotate_using_predictions.py @@ -1,10 +1,27 @@ +import argparse +import logging +import pickle from itertools import repeat +from lxml import etree + +from sciencebeam_gym.utils.tf import ( + FileIO +) + from sciencebeam_gym.models.text.feature_extractor import ( structured_document_to_token_props, - token_props_list_to_features + token_props_list_to_features, + NONE_TAG +) + +from sciencebeam_gym.structured_document.lxml import ( + LxmlStructuredDocument ) +def get_logger(): + return logging.getLogger(__name__) + def _iter_tokens(structured_document): for page in structured_document.get_pages(): for line in structured_document.get_lines_of_page(page): @@ -32,7 +49,7 @@ def annotate_structured_document_using_predictions( if token_props: assert structured_document.get_text(token) == token_props['text'] - if prediction: + if prediction and prediction != NONE_TAG: structured_document.set_tag(token, prediction) def predict_and_annotate_structured_document(structured_document, model): @@ -40,3 +57,56 @@ def predict_and_annotate_structured_document(structured_document, model): x = token_props_list_to_features(token_props) y_pred = model.predict([x])[0] annotate_structured_document_using_predictions(structured_document, y_pred, token_props) + +def parse_args(argv=None): + parser = argparse.ArgumentParser('Annotated LXML using CRF model') + source = parser.add_mutually_exclusive_group(required=True) + source.add_argument( + '--lxml-path', type=str, required=False, + help='path to lxml document' + ) + + parser.add_argument( + '--crf-model', type=str, required=True, + help='path to saved crf model' + ) + + parser.add_argument( + '--output-path', type=str, required=True, + help='output path to annotated document' + ) + + parser.add_argument( + '--debug', action='store_true', default=False, + help='enable debug output' + ) + + return parser.parse_args(argv) + +def main(argv=None): + args = parse_args(argv) + + if args.debug: + logging.getLogger().setLevel('DEBUG') + + with FileIO(args.lxml_path, 'rb') as lxml_f: + structured_document = LxmlStructuredDocument( + etree.parse(lxml_f) + ) + + with FileIO(args.crf_model, 'rb') as crf_model_f: + model = pickle.load(crf_model_f) + + predict_and_annotate_structured_document( + structured_document, + model + ) + + get_logger().info('writing result to: %s', args.output_path) + with FileIO(args.output_path, 'w') as out_f: + out_f.write(etree.tostring(structured_document.root)) + +if __name__ == '__main__': + logging.basicConfig(level='INFO') + + main() diff --git a/sciencebeam_gym/models/text/crf/annotate_using_predictions_test.py b/sciencebeam_gym/models/text/crf/annotate_using_predictions_test.py index d781f47..a652251 100644 --- a/sciencebeam_gym/models/text/crf/annotate_using_predictions_test.py +++ b/sciencebeam_gym/models/text/crf/annotate_using_predictions_test.py @@ -11,7 +11,8 @@ from sciencebeam_gym.structured_document import ( from sciencebeam_gym.models.text.feature_extractor import ( structured_document_to_token_props, - token_props_list_to_features + token_props_list_to_features, + NONE_TAG ) from sciencebeam_gym.utils.bounding_box import ( @@ -47,6 +48,15 @@ class TestAnnotateStructuredDocumentUsingPredictions(object): ) assert structured_document.get_tag(token_1) == TAG_1 + def test_should_not_tag_using_none_tag(self): + token_1 = SimpleToken(TOKEN_TEXT_1) + structured_document = SimpleStructuredDocument(lines=[SimpleLine([token_1])]) + annotate_structured_document_using_predictions( + structured_document, + [NONE_TAG] + ) + assert structured_document.get_tag(token_1) is None + def test_should_tag_single_token_using_prediction_and_check_token_props(self): token_1 = SimpleToken(TOKEN_TEXT_1, bounding_box=BOUNDING_BOX) structured_document = SimpleStructuredDocument(SimplePage( -- GitLab