From 10675447e2bf2b34bfcd607a5cb40a8bb355816b Mon Sep 17 00:00:00 2001 From: Daniel Ecer <de-code@users.noreply.github.com> Date: Wed, 10 Jan 2018 19:51:08 +0000 Subject: [PATCH] optionally use cv predictions when predicting using the crf model --- .../text/crf/annotate_using_predictions.py | 30 ++++++++++++++----- .../text/crf/crfsuite_training_pipeline.py | 14 ++++----- .../crf/crfsuite_training_pipeline_test.py | 7 +++-- .../models/text/feature_extractor.py | 16 ++++++++++ 4 files changed, 48 insertions(+), 19 deletions(-) diff --git a/sciencebeam_gym/models/text/crf/annotate_using_predictions.py b/sciencebeam_gym/models/text/crf/annotate_using_predictions.py index e809e7c..2dc8585 100644 --- a/sciencebeam_gym/models/text/crf/annotate_using_predictions.py +++ b/sciencebeam_gym/models/text/crf/annotate_using_predictions.py @@ -12,11 +12,12 @@ from sciencebeam_gym.utils.tf import ( from sciencebeam_gym.models.text.feature_extractor import ( structured_document_to_token_props, token_props_list_to_features, + merge_with_cv_structured_document, NONE_TAG ) -from sciencebeam_gym.structured_document.lxml import ( - LxmlStructuredDocument +from sciencebeam_gym.structured_document.structured_document_loader import ( + load_lxml_structured_document ) def get_logger(): @@ -60,12 +61,18 @@ def predict_and_annotate_structured_document(structured_document, model): def parse_args(argv=None): parser = argparse.ArgumentParser('Annotated LXML using CRF model') - source = parser.add_mutually_exclusive_group(required=True) + source = parser.add_argument_group('source') source.add_argument( '--lxml-path', type=str, required=False, help='path to lxml document' ) + cv_source = parser.add_argument_group('CV source') + cv_source.add_argument( + '--cv-lxml-path', type=str, required=False, + help='path to lxml document with cv predicted tags' + ) + parser.add_argument( '--crf-model', type=str, required=True, help='path to saved crf model' @@ -83,19 +90,26 @@ def parse_args(argv=None): return parser.parse_args(argv) +def load_crf_model(path): + with FileIO(path, 'rb') as crf_model_f: + return pickle.load(crf_model_f) + def main(argv=None): args = parse_args(argv) if args.debug: logging.getLogger().setLevel('DEBUG') - with FileIO(args.lxml_path, 'rb') as lxml_f: - structured_document = LxmlStructuredDocument( - etree.parse(lxml_f) + structured_document = load_lxml_structured_document(args.lxml_path) + + if args.cv_lxml_path: + cv_structured_document = load_lxml_structured_document(args.cv_lxml_path) + structured_document = merge_with_cv_structured_document( + structured_document, + cv_structured_document ) - with FileIO(args.crf_model, 'rb') as crf_model_f: - model = pickle.load(crf_model_f) + model = load_crf_model(args.crf_model) predict_and_annotate_structured_document( structured_document, diff --git a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py index 411dbac..0506347 100644 --- a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py +++ b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py @@ -31,7 +31,8 @@ from sciencebeam_gym.preprocess.preprocessing_utils import ( from sciencebeam_gym.models.text.feature_extractor import ( structured_document_to_token_props, token_props_list_to_features, - token_props_list_to_labels + token_props_list_to_labels, + merge_with_cv_structured_document ) from sciencebeam_gym.models.text.crf.crfsuite_model import ( @@ -42,8 +43,6 @@ from sciencebeam_gym.beam_utils.io import ( save_file_content ) -CV_TAG_SCOPE = 'cv' - def get_logger(): return logging.getLogger(__name__) @@ -99,12 +98,9 @@ def load_and_convert_to_token_props(filename, cv_filename, page_range=None): structured_document = load_structured_document(filename, page_range=page_range) if cv_filename: cv_structured_document = load_structured_document(cv_filename, page_range=page_range) - structured_document.merge_with( - cv_structured_document, - partial( - merge_token_tag, - target_scope=CV_TAG_SCOPE - ) + structured_document = merge_with_cv_structured_document( + structured_document, + cv_structured_document ) return list(structured_document_to_token_props( structured_document diff --git a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline_test.py b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline_test.py index 5e65799..d2ff85e 100644 --- a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline_test.py +++ b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline_test.py @@ -10,14 +10,17 @@ from sciencebeam_gym.structured_document import ( SimpleToken ) +from sciencebeam_gym.models.text.feature_extractor import ( + CV_TAG_SCOPE +) + import sciencebeam_gym.models.text.crf.crfsuite_training_pipeline as crfsuite_training_pipeline from sciencebeam_gym.models.text.crf.crfsuite_training_pipeline import ( load_and_convert_to_token_props, train_model, save_model, run, - main, - CV_TAG_SCOPE + main ) SOURCE_FILE_LIST_PATH = '.temp/source-file-list.lst' diff --git a/sciencebeam_gym/models/text/feature_extractor.py b/sciencebeam_gym/models/text/feature_extractor.py index 3c05918..7c9d70b 100644 --- a/sciencebeam_gym/models/text/feature_extractor.py +++ b/sciencebeam_gym/models/text/feature_extractor.py @@ -1,5 +1,11 @@ +from functools import partial + +from sciencebeam_gym.structured_document import ( + merge_token_tag +) NONE_TAG = 'O' +CV_TAG_SCOPE = 'cv' def structured_document_to_token_props(structured_document): pages = list(structured_document.get_pages()) @@ -94,3 +100,13 @@ def remove_labels_from_token_props_list(token_props_list): def token_props_list_to_labels(token_props_list): return [token_props.get('tag') or NONE_TAG for token_props in token_props_list] + +def merge_with_cv_structured_document(structured_document, cv_structured_document): + structured_document.merge_with( + cv_structured_document, + partial( + merge_token_tag, + target_scope=CV_TAG_SCOPE + ) + ) + return structured_document -- GitLab