From 10675447e2bf2b34bfcd607a5cb40a8bb355816b Mon Sep 17 00:00:00 2001
From: Daniel Ecer <de-code@users.noreply.github.com>
Date: Wed, 10 Jan 2018 19:51:08 +0000
Subject: [PATCH] optionally use cv predictions when predicting using the crf
 model

---
 .../text/crf/annotate_using_predictions.py    | 30 ++++++++++++++-----
 .../text/crf/crfsuite_training_pipeline.py    | 14 ++++-----
 .../crf/crfsuite_training_pipeline_test.py    |  7 +++--
 .../models/text/feature_extractor.py          | 16 ++++++++++
 4 files changed, 48 insertions(+), 19 deletions(-)

diff --git a/sciencebeam_gym/models/text/crf/annotate_using_predictions.py b/sciencebeam_gym/models/text/crf/annotate_using_predictions.py
index e809e7c..2dc8585 100644
--- a/sciencebeam_gym/models/text/crf/annotate_using_predictions.py
+++ b/sciencebeam_gym/models/text/crf/annotate_using_predictions.py
@@ -12,11 +12,12 @@ from sciencebeam_gym.utils.tf import (
 from sciencebeam_gym.models.text.feature_extractor import (
   structured_document_to_token_props,
   token_props_list_to_features,
+  merge_with_cv_structured_document,
   NONE_TAG
 )
 
-from sciencebeam_gym.structured_document.lxml import (
-  LxmlStructuredDocument
+from sciencebeam_gym.structured_document.structured_document_loader import (
+  load_lxml_structured_document
 )
 
 def get_logger():
@@ -60,12 +61,18 @@ def predict_and_annotate_structured_document(structured_document, model):
 
 def parse_args(argv=None):
   parser = argparse.ArgumentParser('Annotated LXML using CRF model')
-  source = parser.add_mutually_exclusive_group(required=True)
+  source = parser.add_argument_group('source')
   source.add_argument(
-    '--lxml-path', type=str, required=False,
+    '--lxml-path', type=str, required=True,  # main() always loads this path
     help='path to lxml document'
   )
 
+  cv_source = parser.add_argument_group('CV source')
+  cv_source.add_argument(
+    '--cv-lxml-path', type=str, required=False,
+    help='path to lxml document with cv predicted tags'
+  )
+
   parser.add_argument(
     '--crf-model', type=str, required=True,
     help='path to saved crf model'
@@ -83,19 +90,26 @@ def parse_args(argv=None):
 
   return parser.parse_args(argv)
 
+def load_crf_model(path):  # returns whatever object was pickled into the file at path
+  with FileIO(path, 'rb') as model_fp:
+    return pickle.load(model_fp)  # NOTE: unpickling can run arbitrary code; trusted files only
+
 def main(argv=None):
   args = parse_args(argv)
 
   if args.debug:
     logging.getLogger().setLevel('DEBUG')
 
-  with FileIO(args.lxml_path, 'rb') as lxml_f:
-    structured_document = LxmlStructuredDocument(
-      etree.parse(lxml_f)
+  structured_document = load_lxml_structured_document(args.lxml_path)
+
+  if args.cv_lxml_path:
+    cv_structured_document = load_lxml_structured_document(args.cv_lxml_path)
+    structured_document = merge_with_cv_structured_document(
+      structured_document,
+      cv_structured_document
     )
 
-  with FileIO(args.crf_model, 'rb') as crf_model_f:
-    model = pickle.load(crf_model_f)
+  model = load_crf_model(args.crf_model)
 
   predict_and_annotate_structured_document(
     structured_document,
diff --git a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py
index 411dbac..0506347 100644
--- a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py
+++ b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py
@@ -31,7 +31,8 @@ from sciencebeam_gym.preprocess.preprocessing_utils import (
 from sciencebeam_gym.models.text.feature_extractor import (
   structured_document_to_token_props,
   token_props_list_to_features,
-  token_props_list_to_labels
+  token_props_list_to_labels,
+  merge_with_cv_structured_document
 )
 
 from sciencebeam_gym.models.text.crf.crfsuite_model import (
@@ -42,8 +43,6 @@ from sciencebeam_gym.beam_utils.io import (
   save_file_content
 )
 
-CV_TAG_SCOPE = 'cv'
-
 def get_logger():
   return logging.getLogger(__name__)
 
@@ -99,12 +98,9 @@ def load_and_convert_to_token_props(filename, cv_filename, page_range=None):
     structured_document = load_structured_document(filename, page_range=page_range)
     if cv_filename:
       cv_structured_document = load_structured_document(cv_filename, page_range=page_range)
-      structured_document.merge_with(
-        cv_structured_document,
-        partial(
-          merge_token_tag,
-          target_scope=CV_TAG_SCOPE
-        )
+      structured_document = merge_with_cv_structured_document(
+        structured_document,
+        cv_structured_document
       )
     return list(structured_document_to_token_props(
       structured_document
diff --git a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline_test.py b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline_test.py
index 5e65799..d2ff85e 100644
--- a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline_test.py
+++ b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline_test.py
@@ -10,14 +10,17 @@ from sciencebeam_gym.structured_document import (
   SimpleToken
 )
 
+from sciencebeam_gym.models.text.feature_extractor import (
+  CV_TAG_SCOPE
+)
+
 import sciencebeam_gym.models.text.crf.crfsuite_training_pipeline as crfsuite_training_pipeline
 from sciencebeam_gym.models.text.crf.crfsuite_training_pipeline import (
   load_and_convert_to_token_props,
   train_model,
   save_model,
   run,
-  main,
-  CV_TAG_SCOPE
+  main
 )
 
 SOURCE_FILE_LIST_PATH = '.temp/source-file-list.lst'
diff --git a/sciencebeam_gym/models/text/feature_extractor.py b/sciencebeam_gym/models/text/feature_extractor.py
index 3c05918..7c9d70b 100644
--- a/sciencebeam_gym/models/text/feature_extractor.py
+++ b/sciencebeam_gym/models/text/feature_extractor.py
@@ -1,5 +1,11 @@
+from functools import partial
+
+from sciencebeam_gym.structured_document import (
+  merge_token_tag
+)
 
 NONE_TAG = 'O'
+CV_TAG_SCOPE = 'cv'  # target_scope handed to merge_token_tag by merge_with_cv_structured_document
 
 def structured_document_to_token_props(structured_document):
   pages = list(structured_document.get_pages())
@@ -94,3 +100,13 @@ def remove_labels_from_token_props_list(token_props_list):
 
 def token_props_list_to_labels(token_props_list):
   return [token_props.get('tag') or NONE_TAG for token_props in token_props_list]
+
+def merge_with_cv_structured_document(structured_document, cv_structured_document):
+  """
+  Calls structured_document.merge_with using cv_structured_document and merge_token_tag
+  (with target_scope=CV_TAG_SCOPE), then returns the structured_document passed in.
+  """
+  # bind the scope first so that merge_with receives a ready-to-use merge function
+  merge_cv_token_tag = partial(merge_token_tag, target_scope=CV_TAG_SCOPE)
+  structured_document.merge_with(cv_structured_document, merge_cv_token_tag)
+  return structured_document
-- 
GitLab