diff --git a/sciencebeam_gym/preprocess/get_output_files.py b/sciencebeam_gym/preprocess/get_output_files.py index 65fa9da80fb93045a63a63771c57d363e6959048..864a9a64bb655e8a5282e9738776f16357054616 100644 --- a/sciencebeam_gym/preprocess/get_output_files.py +++ b/sciencebeam_gym/preprocess/get_output_files.py @@ -7,6 +7,7 @@ from sciencebeam_gym.utils.file_list import ( ) from sciencebeam_gym.preprocess.preprocessing_utils import ( + get_or_validate_base_path, get_output_file ) @@ -29,7 +30,7 @@ def parse_args(argv=None): help='csv/tsv column (ignored for plain file list)' ) source.add_argument( - '--source-base-path', type=str, required=True, + '--source-base-path', type=str, required=False, help='base data path for source file urls' ) @@ -75,9 +76,12 @@ def run(opt): column=opt.source_file_column, limit=opt.limit ) + source_base_path = get_or_validate_base_path( + source_file_list, opt.source_base_path + ) target_file_list = get_output_file_list( - source_file_list, opt.source_base_path, opt.output_base_path, opt.output_file_suffix + source_file_list, source_base_path, opt.output_base_path, opt.output_file_suffix ) save_file_list( diff --git a/sciencebeam_gym/preprocess/get_output_files_test.py b/sciencebeam_gym/preprocess/get_output_files_test.py index 3fa0b5a98ee084e909498e3f35c1d50e445dee4a..295db38f6395fbb5c640d80b7951ec7f94df7c37 100644 --- a/sciencebeam_gym/preprocess/get_output_files_test.py +++ b/sciencebeam_gym/preprocess/get_output_files_test.py @@ -1,4 +1,6 @@ -from mock import patch +from mock import patch, ANY + +import pytest import sciencebeam_gym.preprocess.get_output_files as get_output_files from sciencebeam_gym.preprocess.get_output_files import ( @@ -10,11 +12,15 @@ from sciencebeam_gym.preprocess.get_output_files import ( SOME_ARGV = [ '--source-file-list=source.csv', - '--source-base-path=/source', '--output-file-list=output.csv', '--limit=10' ] +BASE_SOURCE_PATH = '/source' + +FILE_1 = BASE_SOURCE_PATH + '/file1' +FILE_2 = BASE_SOURCE_PATH + '/file2' + class TestGetOutputFileList(object): def test_should_return_output_file_with_path_and_change_ext(self): assert get_output_file_list( @@ -31,6 +37,7 @@ class TestRun(object): with patch.object(m, 'load_file_list') as load_file_list: with patch.object(m, 'get_output_file_list') as get_output_file_list_mock: with patch.object(m, 'save_file_list') as save_file_list: + load_file_list.return_value = [FILE_1, FILE_2] run(opt) load_file_list.assert_called_with( opt.source_file_list, @@ -39,7 +46,7 @@ class TestRun(object): ) get_output_file_list_mock.assert_called_with( load_file_list.return_value, - opt.source_base_path, + BASE_SOURCE_PATH, opt.output_base_path, opt.output_file_suffix ) @@ -49,6 +56,33 @@ class TestRun(object): column=opt.source_file_column ) + def test_should_raise_error_if_source_path_is_invalid(self): + m = get_output_files + opt = parse_args(SOME_ARGV) + opt.source_base_path = '/other/path' + with patch.object(m, 'load_file_list') as load_file_list: + with patch.object(m, 'get_output_file_list'): + with patch.object(m, 'save_file_list'): + with pytest.raises(AssertionError): + load_file_list.return_value = [FILE_1, FILE_2] + run(opt) + + def test_should_use_passed_in_source_path_if_valid(self): + m = get_output_files + opt = parse_args(SOME_ARGV) + opt.source_base_path = '/base' + with patch.object(m, 'load_file_list') as load_file_list: + with patch.object(m, 'get_output_file_list') as get_output_file_list_mock: + with patch.object(m, 'save_file_list'): + load_file_list.return_value = ['/base/source/file1', '/base/source/file2'] + run(opt) + get_output_file_list_mock.assert_called_with( + ANY, + opt.source_base_path, + ANY, + ANY + ) + class TestMain(object): def test_should_parse_args_and_call_run(self): m = get_output_files diff --git a/sciencebeam_gym/preprocess/preprocessing_utils.py b/sciencebeam_gym/preprocess/preprocessing_utils.py index 8db5c55370671794a35f8e2e0363e74720d1efd3..a35b15902d42d296b56f8db8c4107936c3ee48cf 100644 --- a/sciencebeam_gym/preprocess/preprocessing_utils.py +++ b/sciencebeam_gym/preprocess/preprocessing_utils.py @@ -228,6 +228,25 @@ def change_ext(path, old_ext, new_ext): else: return path + new_ext +def base_path_for_file_list(file_list): + common_prefix = os.path.commonprefix(file_list) + i = max(common_prefix.rfind('/'), common_prefix.rfind('\\')) + if i >= 0: + return common_prefix[:i] + else: + return '' + +def get_or_validate_base_path(file_list, base_path): + common_path = base_path_for_file_list(file_list) + if base_path: + if not common_path.startswith(base_path): + raise AssertionError( + "invalid base path '%s', common path is: '%s'" % (base_path, common_path) + ) + return base_path + else: + return common_path + def get_output_file(filename, source_base_path, output_base_path, output_file_suffix): return FileSystems.join( output_base_path, diff --git a/sciencebeam_gym/preprocess/preprocessing_utils_test.py b/sciencebeam_gym/preprocess/preprocessing_utils_test.py index 9edef8946bd6784118f1a6c7a4274378e8828809..210cefdbcfa388ae685ea4a4f102375f88f7dea0 100644 --- a/sciencebeam_gym/preprocess/preprocessing_utils_test.py +++ b/sciencebeam_gym/preprocess/preprocessing_utils_test.py @@ -1,5 +1,7 @@ from mock import patch, MagicMock, DEFAULT +import pytest + from lxml import etree from sciencebeam_gym.structured_document.svg import ( @@ -11,6 +13,8 @@ from sciencebeam_gym.preprocess.preprocessing_utils import ( group_file_pairs_by_parent_directory_or_name, convert_pdf_bytes_to_lxml, change_ext, + base_path_for_file_list, + get_or_validate_base_path, get_output_file, parse_page_range ) @@ -121,6 +125,55 @@ class TestChangeExt(object): def test_should_remove_gz_ext_before_replacing_ext(self): assert change_ext('file.pdf.gz', None, '.svg.zip') == 'file.svg.zip' +class TestBasePathForFileList(object): + def test_should_return_empty_string_if_file_list_is_empty(self): + assert base_path_for_file_list([]) == '' + + def test_should_return_empty_string_if_filename_is_empty(self): + assert base_path_for_file_list(['']) == '' + + def test_should_return_parent_directory_of_single_file(self): + assert base_path_for_file_list(['/base/path/1/file']) == '/base/path/1' + + def test_should_return_common_path_of_two_files(self): + assert base_path_for_file_list(['/base/path/1/file', '/base/path/2/file']) == '/base/path' + + def test_should_return_common_path_of_two_files_using_protocol(self): + assert base_path_for_file_list([ + 'a://base/path/1/file', 'a://base/path/2/file' + ]) == 'a://base/path' + + def test_should_return_common_path_of_two_files_using_forward_slash(self): + assert base_path_for_file_list([ + '\\base\\path\\1\\file', '\\base\\path\\2\\file' + ]) == '\\base\\path' + + def test_should_return_empty_string_if_no_common_path_was_found(self): + assert base_path_for_file_list(['a://base/path/1/file', 'b://base/path/2/file']) == '' + + def test_should_return_common_path_ignoring_partial_name_match(self): + assert base_path_for_file_list(['/base/path/file1', '/base/path/file2']) == '/base/path' + +class TestGetOrValidateBasePath(object): + def test_should_return_base_path_of_two_files_if_no_base_path_was_provided(self): + assert get_or_validate_base_path( + ['/base/path/1/file', '/base/path/2/file'], + None + ) == '/base/path' + + def test_should_return_passed_in_base_path_if_valid(self): + assert get_or_validate_base_path( + ['/base/path/1/file', '/base/path/2/file'], + '/base' + ) == '/base' + + def test_should_raise_error_if_passed_in_base_path_is_invalid(self): + with pytest.raises(AssertionError): + get_or_validate_base_path( + ['/base/path/1/file', '/base/path/2/file'], + '/base/other' + ) + class TestGetOutputFile(object): def test_should_return_output_file_with_path_and_change_ext(self): assert get_output_file(