Skip to content
Snippets Groups Projects
Commit a7c9dd7c authored by Daniel Ecer
Browse files

Added a main entry point for annotating an LXML document using a CRF model; predictions equal to the 'none' tag are now ignored.

parent f5e69e91
No related branches found
No related tags found
No related merge requests found
import argparse
import logging
import pickle
from itertools import repeat
from lxml import etree
from sciencebeam_gym.utils.tf import (
FileIO
)
from sciencebeam_gym.models.text.feature_extractor import (
structured_document_to_token_props,
token_props_list_to_features
token_props_list_to_features,
NONE_TAG
)
from sciencebeam_gym.structured_document.lxml import (
LxmlStructuredDocument
)
def get_logger():
    """Return the logger named after this module."""
    module_logger = logging.getLogger(__name__)
    return module_logger
def _iter_tokens(structured_document):
for page in structured_document.get_pages():
for line in structured_document.get_lines_of_page(page):
......@@ -32,7 +49,7 @@ def annotate_structured_document_using_predictions(
if token_props:
assert structured_document.get_text(token) == token_props['text']
if prediction:
if prediction and prediction != NONE_TAG:
structured_document.set_tag(token, prediction)
def predict_and_annotate_structured_document(structured_document, model):
......@@ -40,3 +57,56 @@ def predict_and_annotate_structured_document(structured_document, model):
x = token_props_list_to_features(token_props)
y_pred = model.predict([x])[0]
annotate_structured_document_using_predictions(structured_document, y_pred, token_props)
def parse_args(argv=None):
    """Parse the command line arguments of the annotation script.

    Args:
        argv: list of argument strings to parse; when None, argparse falls
            back to the process arguments.

    Returns:
        The argparse namespace with the attributes ``lxml_path``,
        ``crf_model``, ``output_path`` and ``debug``.
    """
    # The text describes the tool, so it has to be passed as `description`;
    # the first positional parameter of ArgumentParser is `prog`, which would
    # put the whole sentence into the usage line as the program name.
    parser = argparse.ArgumentParser(
        description='Annotated LXML using CRF model'
    )

    # Required group of input sources; it currently only contains the
    # lxml path, so that option has to be provided.
    source = parser.add_mutually_exclusive_group(required=True)
    source.add_argument(
        '--lxml-path', type=str, required=False,
        help='path to lxml document'
    )

    parser.add_argument(
        '--crf-model', type=str, required=True,
        help='path to saved crf model'
    )

    parser.add_argument(
        '--output-path', type=str, required=True,
        help='output path to annotated document'
    )

    parser.add_argument(
        '--debug', action='store_true', default=False,
        help='enable debug output'
    )

    return parser.parse_args(argv)
def main(argv=None):
    """Annotate an LXML document with the predictions of a pickled model.

    Reads the document from ``--lxml-path``, loads the model from
    ``--crf-model``, annotates the document via
    ``predict_and_annotate_structured_document`` and writes the serialized
    document root to ``--output-path``.

    Args:
        argv: optional list of command line arguments, forwarded to
            ``parse_args``.
    """
    options = parse_args(argv)

    if options.debug:
        # raise verbosity on the root logger so all modules emit debug output
        logging.getLogger().setLevel('DEBUG')

    with FileIO(options.lxml_path, 'rb') as source_fp:
        document_tree = etree.parse(source_fp)
        structured_document = LxmlStructuredDocument(document_tree)

    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only point --crf-model at model files from a trusted source.
    with FileIO(options.crf_model, 'rb') as model_fp:
        crf_model = pickle.load(model_fp)

    predict_and_annotate_structured_document(structured_document, crf_model)

    logger = get_logger()
    logger.info('writing result to: %s', options.output_path)
    with FileIO(options.output_path, 'w') as output_fp:
        serialized_document = etree.tostring(structured_document.root)
        output_fp.write(serialized_document)
# Script entry point: set up INFO level logging before running main, so the
# progress messages logged by main are shown when run from the command line.
if __name__ == '__main__':
    logging.basicConfig(level='INFO')

    main()
......@@ -11,7 +11,8 @@ from sciencebeam_gym.structured_document import (
from sciencebeam_gym.models.text.feature_extractor import (
structured_document_to_token_props,
token_props_list_to_features
token_props_list_to_features,
NONE_TAG
)
from sciencebeam_gym.utils.bounding_box import (
......@@ -47,6 +48,15 @@ class TestAnnotateStructuredDocumentUsingPredictions(object):
)
assert structured_document.get_tag(token_1) == TAG_1
def test_should_not_tag_using_none_tag(self):
    # A prediction equal to NONE_TAG must leave the token without a tag.
    token = SimpleToken(TOKEN_TEXT_1)
    single_line = SimpleLine([token])
    document = SimpleStructuredDocument(lines=[single_line])
    predictions = [NONE_TAG]

    annotate_structured_document_using_predictions(document, predictions)

    assert document.get_tag(token) is None
def test_should_tag_single_token_using_prediction_and_check_token_props(self):
token_1 = SimpleToken(TOKEN_TEXT_1, bounding_box=BOUNDING_BOX)
structured_document = SimpleStructuredDocument(SimplePage(
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment