diff --git a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py
index 41d8b9d2e5a402b345b56d7acc305d4a603871ea..7f4cb32595484ea2e3ebd674f4a684c06b53f60a 100644
--- a/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py
+++ b/sciencebeam_gym/models/text/crf/crfsuite_training_pipeline.py
@@ -2,6 +2,7 @@ import logging
 import argparse
 import pickle
 from functools import partial
+from concurrent.futures import ThreadPoolExecutor
 
 from six import raise_from
 
@@ -117,18 +118,24 @@ def train_model(file_list, cv_file_list, page_range=None, progress=True):
   token_props_list_by_document = []
   total = len(file_list)
   with tqdm(total=total, leave=False, desc='loading files', disable=not progress) as pbar:
-    for filename, cv_filename in zip(file_list, cv_file_list):
-      token_props_list_by_document.append(
-        load_and_convert_to_token_props(filename, cv_filename, page_range=page_range)
-      )
-      pbar.update(1)
+    # Load documents concurrently: the work is I/O-bound, so threads overlap the
+    # waits. NOTE(review): max_workers=50 is a hard-coded magic number — consider
+    # making it a parameter.
+    with ThreadPoolExecutor(max_workers=50) as executor:
+      def process_fn(item):
+        # Tuple-parameter lambdas were removed in Python 3 (PEP 3113), so
+        # unpack the (filename, cv_filename) pair explicitly instead.
+        filename, cv_filename = item
+        return load_and_convert_to_token_props(filename, cv_filename, page_range=page_range)
+      # executor.map preserves input order, so results line up with file_list.
+      for result in executor.map(process_fn, zip(file_list, cv_file_list)):
+        token_props_list_by_document.append(result)
+        pbar.update(1)
   X = [token_props_list_to_features(x) for x in token_props_list_by_document]
   y = [token_props_list_to_labels(x) for x in token_props_list_by_document]
   model = CrfSuiteModel()
+  get_logger().info('training model (with %d documents)', len(X))
   model.fit(X, y)
   return serialize_model(model)
 
 def save_model(output_filename, model_bytes):
+  get_logger().info('saving model to %s', output_filename)
   save_file_content(output_filename, model_bytes)
 
 def run(opt):