import pytest import apache_beam as beam from apache_beam.io.filesystems import FileSystems from sciencebeam_gym.beam_utils.io import ( find_matching_filenames ) from sciencebeam_gym.utils.tfrecord import ( iter_read_tfrecord_file_as_dict_list ) from sciencebeam_gym.beam_utils.testing import ( BeamTest, TestPipeline ) from sciencebeam_gym.preprocess.preprocessing_transforms import ( WritePropsToTFRecord ) TFRECORDS_PATH = '.temp/test-data' KEY_1 = b'key1' KEY_2 = b'key2' VALUE_1 = b'value 1' VALUE_2 = b'value 2' @pytest.mark.slow class TestWritePropsToTFRecord(BeamTest): def test_should_write_single_entry(self): dict_list = [{KEY_1: VALUE_1}] with TestPipeline() as p: _ = ( p | beam.Create(dict_list) | WritePropsToTFRecord(TFRECORDS_PATH, lambda x: [x]) ) filenames = list(find_matching_filenames(TFRECORDS_PATH + '*')) assert len(filenames) == 1 records = list(iter_read_tfrecord_file_as_dict_list(filenames[0])) assert records == dict_list FileSystems.delete(filenames) def test_should_write_multiple_entries(self): dict_list = [ {KEY_1: VALUE_1}, {KEY_2: VALUE_2} ] with TestPipeline() as p: _ = ( p | beam.Create(dict_list) | WritePropsToTFRecord(TFRECORDS_PATH, lambda x: [x]) ) filenames = list(find_matching_filenames(TFRECORDS_PATH + '*')) assert len(filenames) == 1 records = list(iter_read_tfrecord_file_as_dict_list(filenames[0])) assert records == dict_list FileSystems.delete(filenames)