Skip to content
Snippets Groups Projects
Commit 10675447 authored by Daniel Ecer's avatar Daniel Ecer
Browse files

optionally use cv predictions when predicting using the crf model

parent 072a625b
No related branches found
No related tags found
No related merge requests found
......@@ -12,11 +12,12 @@ from sciencebeam_gym.utils.tf import (
from sciencebeam_gym.models.text.feature_extractor import (
structured_document_to_token_props,
token_props_list_to_features,
merge_with_cv_structured_document,
NONE_TAG
)
from sciencebeam_gym.structured_document.lxml import (
LxmlStructuredDocument
from sciencebeam_gym.structured_document.structured_document_loader import (
load_lxml_structured_document
)
def get_logger():
......@@ -60,12 +61,18 @@ def predict_and_annotate_structured_document(structured_document, model):
def parse_args(argv=None):
parser = argparse.ArgumentParser('Annotated LXML using CRF model')
source = parser.add_mutually_exclusive_group(required=True)
source = parser.add_argument_group('source')
source.add_argument(
'--lxml-path', type=str, required=False,
help='path to lxml document'
)
cv_source = parser.add_argument_group('CV source')
cv_source.add_argument(
'--cv-lxml-path', type=str, required=False,
help='path to lxml document with cv predicted tags'
)
parser.add_argument(
'--crf-model', type=str, required=True,
help='path to saved crf model'
......@@ -83,19 +90,26 @@ def parse_args(argv=None):
return parser.parse_args(argv)
def load_crf_model(path):
with FileIO(path, 'rb') as crf_model_f:
return pickle.load(crf_model_f)
def main(argv=None):
args = parse_args(argv)
if args.debug:
logging.getLogger().setLevel('DEBUG')
with FileIO(args.lxml_path, 'rb') as lxml_f:
structured_document = LxmlStructuredDocument(
etree.parse(lxml_f)
structured_document = load_lxml_structured_document(args.lxml_path)
if args.cv_lxml_path:
cv_structured_document = load_lxml_structured_document(args.cv_lxml_path)
structured_document = merge_with_cv_structured_document(
structured_document,
cv_structured_document
)
with FileIO(args.crf_model, 'rb') as crf_model_f:
model = pickle.load(crf_model_f)
model = load_crf_model(args.crf_model)
predict_and_annotate_structured_document(
structured_document,
......
......@@ -31,7 +31,8 @@ from sciencebeam_gym.preprocess.preprocessing_utils import (
from sciencebeam_gym.models.text.feature_extractor import (
structured_document_to_token_props,
token_props_list_to_features,
token_props_list_to_labels
token_props_list_to_labels,
merge_with_cv_structured_document
)
from sciencebeam_gym.models.text.crf.crfsuite_model import (
......@@ -42,8 +43,6 @@ from sciencebeam_gym.beam_utils.io import (
save_file_content
)
CV_TAG_SCOPE = 'cv'
def get_logger():
return logging.getLogger(__name__)
......@@ -99,12 +98,9 @@ def load_and_convert_to_token_props(filename, cv_filename, page_range=None):
structured_document = load_structured_document(filename, page_range=page_range)
if cv_filename:
cv_structured_document = load_structured_document(cv_filename, page_range=page_range)
structured_document.merge_with(
cv_structured_document,
partial(
merge_token_tag,
target_scope=CV_TAG_SCOPE
)
structured_document = merge_with_cv_structured_document(
structured_document,
cv_structured_document
)
return list(structured_document_to_token_props(
structured_document
......
......@@ -10,14 +10,17 @@ from sciencebeam_gym.structured_document import (
SimpleToken
)
from sciencebeam_gym.models.text.feature_extractor import (
CV_TAG_SCOPE
)
import sciencebeam_gym.models.text.crf.crfsuite_training_pipeline as crfsuite_training_pipeline
from sciencebeam_gym.models.text.crf.crfsuite_training_pipeline import (
load_and_convert_to_token_props,
train_model,
save_model,
run,
main,
CV_TAG_SCOPE
main
)
SOURCE_FILE_LIST_PATH = '.temp/source-file-list.lst'
......
from functools import partial
from sciencebeam_gym.structured_document import (
merge_token_tag
)
NONE_TAG = 'O'
CV_TAG_SCOPE = 'cv'
def structured_document_to_token_props(structured_document):
pages = list(structured_document.get_pages())
......@@ -94,3 +100,13 @@ def remove_labels_from_token_props_list(token_props_list):
def token_props_list_to_labels(token_props_list):
return [token_props.get('tag') or NONE_TAG for token_props in token_props_list]
def merge_with_cv_structured_document(structured_document, cv_structured_document):
structured_document.merge_with(
cv_structured_document,
partial(
merge_token_tag,
target_scope=CV_TAG_SCOPE
)
)
return structured_document
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment