Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
crfsuite_training_pipeline_test.py 2.67 KiB
from mock import patch, ANY

from sciencebeam_gym.utils.collection import (
  to_namedtuple
)

import sciencebeam_gym.models.text.crf.crfsuite_training_pipeline as crfsuite_training_pipeline
from sciencebeam_gym.models.text.crf.crfsuite_training_pipeline import (
  train_model,
  save_model,
  run,
  main
)

# Path passed as the source file list option in the run and main tests.
SOURCE_FILE_LIST_PATH = '.temp/source-file-list.lst'

# Sample file names. FILE_1 is used both as a training input and as an output path.
# FILE_2 and UNICODE_FILE_1 are not referenced by the tests visible in this file.
FILE_1 = 'file1.pdf'
FILE_2 = 'file2.pdf'
UNICODE_FILE_1 = u'file1\u1234.pdf'

# Dummy model content handed to save_model.
MODEL_DATA = b'model data'

# Value used for the page_range argument / pages option in the tests.
PAGE_RANGE = (2, 3)

class TestTrainModel(object):
  """Tests for ``train_model`` with the pipeline module's collaborators patched."""

  def test_should_train_on_single_file(self):
    """Train on one file; check the document load, the model fit and the pickling."""
    pipeline = crfsuite_training_pipeline
    with patch.object(pipeline, 'load_structured_document') as load_document_mock, \
        patch.object(pipeline, 'CrfSuiteModel') as model_class_mock, \
        patch.object(pipeline, 'pickle') as pickle_mock, \
        patch.object(pipeline, 'token_props_list_to_features'):
      train_model([FILE_1], page_range=PAGE_RANGE)

      # The document must be loaded with the requested page range.
      load_document_mock.assert_called_with(FILE_1, page_range=PAGE_RANGE)

      # The instance created from the patched model class is fitted and pickled.
      model_instance = model_class_mock.return_value
      model_instance.fit.assert_called_with(ANY, ANY)
      pickle_mock.dumps.assert_called_with(model_instance)

class TestSaveModel(object):
  """Tests for ``save_model``."""

  def test_should_call_save_content(self):
    """``save_model`` should pass the path and the data on to ``save_file_content``."""
    save_patcher = patch.object(crfsuite_training_pipeline, 'save_file_content')
    with save_patcher as save_file_content_mock:
      save_model(FILE_1, MODEL_DATA)
      save_file_content_mock.assert_called_with(FILE_1, MODEL_DATA)

class TestRun(object):
  """Tests for ``run`` with file listing, training and saving patched out."""

  def test_should_train_on_single_file(self):
    """``run`` should load the file list, train on it and save the trained model."""
    pipeline = crfsuite_training_pipeline
    options = to_namedtuple(
      source_file_list=SOURCE_FILE_LIST_PATH,
      source_file_column='url',
      output_path=FILE_1,
      limit=2,
      pages=PAGE_RANGE
    )
    with patch.object(pipeline, 'load_file_list') as load_file_list_mock, \
        patch.object(pipeline, 'train_model') as train_model_mock, \
        patch.object(pipeline, 'save_model') as save_model_mock:
      run(options)

      # The file list is read from the configured list path and column, honouring the limit.
      load_file_list_mock.assert_called_with(
        options.source_file_list, options.source_file_column, limit=options.limit
      )

      # Training receives the loaded file list and the page range.
      loaded_file_list = load_file_list_mock.return_value
      train_model_mock.assert_called_with(loaded_file_list, page_range=PAGE_RANGE)

      # Whatever training returned is saved to the output path.
      trained_model = train_model_mock.return_value
      save_model_mock.assert_called_with(options.output_path, trained_model)

class TestMain(object):
  """Tests for the ``main`` entry point."""

  def test_should(self):
    """``main`` should parse the given argv and hand the parsed result to ``run``."""
    pipeline = crfsuite_training_pipeline
    cli_args = ['--source-file-list', SOURCE_FILE_LIST_PATH]
    with patch.object(pipeline, 'parse_args') as parse_args_mock, \
        patch.object(pipeline, 'run') as run_mock:
      main(cli_args)

      parse_args_mock.assert_called_with(cli_args)
      parsed_options = parse_args_mock.return_value
      run_mock.assert_called_with(parsed_options)