Skip to content
Snippets Groups Projects
Unverified Commit 60d8e0c7 authored by Daniel Ecer's avatar Daniel Ecer Committed by GitHub
Browse files

improve memory consumption when finding bounding boxes (2nd iteration) (#426)

* log progress while matching figure image to each pdf page

* yield empty image object match to keep total

* added simple MultiLevelCache

* implemented disk cache

* added TestRegisterPickleFunction

* increased default memory cache size
parent 082e5182
No related branches found
No related tags found
No related merge requests found
......@@ -6,4 +6,5 @@ mypy==0.910
nose==1.3.7
pytest==6.2.5
pytest-watch==4.2.0
types-cachetools==4.2.4
types-requests==2.25.9
cachetools==4.2.4
diskcache==5.2.1
Flask==2.0.2
gevent==21.8.0
gunicorn==20.1.0
......@@ -23,4 +25,5 @@ scikit-image==0.18.3
sklearn-crfsuite==0.3.6
scikit-learn>=0.24.2
tensorflow-transform==1.3.0
typing-extensions==3.10.0.2
tqdm==4.62.3
......@@ -7,8 +7,10 @@ import os
from datetime import datetime
from io import BytesIO
from tempfile import TemporaryDirectory
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple, cast
from typing import Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple, cast
import cachetools
import diskcache
import matplotlib.cm
import PIL.Image
import pdf2image
......@@ -21,6 +23,7 @@ from sciencebeam_utils.utils.progress_logger import logging_tqdm
from sciencebeam_utils.utils.file_list import load_file_list
from sciencebeam_gym.utils.bounding_box import BoundingBox
from sciencebeam_gym.utils.cache import MultiLevelCache
from sciencebeam_gym.utils.collections import get_inverted_dict
from sciencebeam_gym.utils.cv import load_pil_image_from_file
from sciencebeam_gym.utils.io import copy_file, read_bytes, write_bytes, write_text
......@@ -28,10 +31,12 @@ from sciencebeam_gym.utils.image_object_matching import (
DEFAULT_MAX_BOUNDING_BOX_ADJUSTMENT_ITERATIONS,
DEFAULT_MAX_HEIGHT,
DEFAULT_MAX_WIDTH,
EMPTY_IMAGE_LIST_OBJECT_MATCH_RESULT,
get_bounding_box_for_image,
get_image_list_object_match,
get_sift_detector_matcher
get_sift_detector_matcher,
iter_current_best_image_list_object_match
)
from sciencebeam_gym.utils.pickle_reg import register_pickle_functions
from sciencebeam_gym.utils.visualize_bounding_box import draw_bounding_box
from sciencebeam_gym.utils.pipeline import (
AbstractPipelineFactory,
......@@ -56,6 +61,8 @@ DEFAULT_OUTPUT_JSON_FILE_SUFFIX = '.annotation.coco.json'
DEFAULT_OUTPUT_XML_FILE_SUFFIX = '.annotated.xml'
DEFAULT_OUTPUT_ANNOTATED_IMAGES_DIR__SUFFIX = '-annotated-images'
DEFAULT_MEMORY_CACHE_SIZE = 512
def get_images_from_pdf(pdf_path: str, pdf_scale_to: Optional[int]) -> List[PIL.Image.Image]:
with TemporaryDirectory(suffix='-pdf') as temp_dir:
......@@ -298,6 +305,12 @@ def get_args_parser():
type=int,
help='If specified, rendered PDF pages will be scaled to specified value (longest side)'
)
parser.add_argument(
'--memory-cache-size',
type=int,
default=DEFAULT_MEMORY_CACHE_SIZE,
help='Number of items to keep in the memory cache'
)
parser.add_argument(
'--max-internal-width',
type=int,
......@@ -419,6 +432,15 @@ def get_xml_root_with_update_nsmap(
return updated_root
def get_cache(temp_dir: str, memory_cache_size: int):
register_pickle_functions()
LOGGER.info('using cache dir: %r (memory_cache_size: %r)', temp_dir, memory_cache_size)
return MultiLevelCache([
cachetools.LRUCache(maxsize=memory_cache_size),
diskcache.Cache(directory=temp_dir)
])
def process_single_document(
pdf_path: str,
image_paths: Optional[List[str]],
......@@ -430,6 +452,8 @@ def process_single_document(
use_grayscale: bool,
ignore_unmatched_graphics: bool,
max_bounding_box_adjustment_iterations: int,
temp_dir: str,
memory_cache_size: int,
selected_categories: Sequence[str] = tuple([]),
output_xml_path: Optional[str] = None,
output_annotated_images_path: Optional[str] = None
......@@ -469,70 +493,74 @@ def process_single_document(
category_id_by_name: Dict[str, int] = {}
annotations: List[dict] = []
missing_annotations: List[dict] = []
image_cache: Dict[Any, Any] = {}
image_cache = get_cache(temp_dir, memory_cache_size=memory_cache_size)
LOGGER.info(
'start processing images(%r): %d',
os.path.basename(pdf_path), len(image_descriptors)
)
for image_descriptor in logging_tqdm(
image_descriptors,
with logging_tqdm(
total=len(image_descriptors) * len(pdf_images),
logger=LOGGER,
desc='processing images(%r):' % os.path.basename(pdf_path)
):
LOGGER.debug('processing article image: %r', image_descriptor.href)
template_image = PIL.Image.open(BytesIO(read_bytes_with_optional_gz_extension(
image_descriptor.path
)))
LOGGER.debug('template_image: %s x %s', template_image.width, template_image.height)
image_list_match_result = get_image_list_object_match(
pdf_images,
template_image,
object_detector_matcher=object_detector_matcher,
image_cache=image_cache,
template_image_id=f'{id(image_descriptor)}-{image_descriptor.href}',
max_width=max_internal_width,
max_height=max_internal_height,
use_grayscale=use_grayscale,
max_bounding_box_adjustment_iterations=max_bounding_box_adjustment_iterations
)
category_id = category_id_by_name.get(image_descriptor.category_name)
if category_id is None:
category_id = 1 + len(category_id_by_name)
category_id_by_name[image_descriptor.category_name] = category_id
annotation = {
'file_name': image_descriptor.href,
'category_id': category_id
}
if image_descriptor.related_element_id:
annotation['related_element_id'] = image_descriptor.related_element_id
if not image_list_match_result:
if not ignore_unmatched_graphics:
raise GraphicImageNotFoundError(
'image bounding box not found for: %r' % image_descriptor.href
)
missing_annotations.append(annotation)
continue
page_index = image_list_match_result.target_image_index
pdf_image = pdf_images[page_index]
pdf_page_bounding_box = get_bounding_box_for_image(pdf_image)
bounding_box = image_list_match_result.target_bounding_box
assert bounding_box
LOGGER.debug('bounding_box: %s', bounding_box)
normalized_bounding_box = bounding_box.intersection(pdf_page_bounding_box).round()
annotation = {
**annotation,
'image_id': (1 + page_index),
'bbox': normalized_bounding_box.to_list(),
'_score': image_list_match_result.score
}
annotations.append(annotation)
if image_descriptor.element is not None:
image_descriptor.element.attrib[COORDS_ATTRIB_NAME] = (
format_coords_attribute_value(
page_number=1 + page_index,
bounding_box=normalized_bounding_box
) as pbar:
for image_descriptor in image_descriptors:
LOGGER.debug('processing article image: %r', image_descriptor.href)
template_image = PIL.Image.open(BytesIO(read_bytes_with_optional_gz_extension(
image_descriptor.path
)))
LOGGER.debug('template_image: %s x %s', template_image.width, template_image.height)
image_list_match_result = EMPTY_IMAGE_LIST_OBJECT_MATCH_RESULT
for _image_list_match_result in iter_current_best_image_list_object_match(
pdf_images,
template_image,
object_detector_matcher=object_detector_matcher,
image_cache=image_cache,
template_image_id=f'{id(image_descriptor)}-{image_descriptor.href}',
max_width=max_internal_width,
max_height=max_internal_height,
use_grayscale=use_grayscale,
max_bounding_box_adjustment_iterations=max_bounding_box_adjustment_iterations
):
image_list_match_result = _image_list_match_result
pbar.update(1)
category_id = category_id_by_name.get(image_descriptor.category_name)
if category_id is None:
category_id = 1 + len(category_id_by_name)
category_id_by_name[image_descriptor.category_name] = category_id
annotation = {
'file_name': image_descriptor.href,
'category_id': category_id
}
if image_descriptor.related_element_id:
annotation['related_element_id'] = image_descriptor.related_element_id
if not image_list_match_result:
if not ignore_unmatched_graphics:
raise GraphicImageNotFoundError(
'image bounding box not found for: %r' % image_descriptor.href
)
missing_annotations.append(annotation)
continue
page_index = image_list_match_result.target_image_index
pdf_image = pdf_images[page_index]
pdf_page_bounding_box = get_bounding_box_for_image(pdf_image)
bounding_box = image_list_match_result.target_bounding_box
assert bounding_box
LOGGER.debug('bounding_box: %s', bounding_box)
normalized_bounding_box = bounding_box.intersection(pdf_page_bounding_box).round()
annotation = {
**annotation,
'image_id': (1 + page_index),
'bbox': normalized_bounding_box.to_list(),
'_score': image_list_match_result.score
}
annotations.append(annotation)
if image_descriptor.element is not None:
image_descriptor.element.attrib[COORDS_ATTRIB_NAME] = (
format_coords_attribute_value(
page_number=1 + page_index,
bounding_box=normalized_bounding_box
)
)
)
if output_annotated_images_path:
LOGGER.info('saving annotated images to: %r', output_annotated_images_path)
save_annotated_images(
......@@ -594,6 +622,7 @@ class FindBoundingBoxPipelineFactory(AbstractPipelineFactory[FindBoundingBoxItem
self.save_annotated_images_enabled = args.save_annotated_images
self.selected_categories = args.categories
self.pdf_scale_to = args.pdf_scale_to
self.memory_cache_size = args.memory_cache_size
self.max_internal_width = args.max_internal_width
self.max_internal_height = args.max_internal_height
self.use_grayscale = args.use_grayscale
......@@ -612,21 +641,24 @@ class FindBoundingBoxPipelineFactory(AbstractPipelineFactory[FindBoundingBoxItem
if self.save_annotated_images_enabled
else None
)
process_single_document(
pdf_path=item.pdf_file,
image_paths=item.image_files,
xml_path=item.xml_file,
output_json_path=output_json_file,
selected_categories=self.selected_categories,
pdf_scale_to=self.pdf_scale_to,
max_internal_width=self.max_internal_width,
max_internal_height=self.max_internal_height,
use_grayscale=self.use_grayscale,
ignore_unmatched_graphics=self.ignore_unmatched_graphics,
output_xml_path=output_xml_file,
output_annotated_images_path=output_annotated_images_path,
max_bounding_box_adjustment_iterations=self.max_bounding_box_adjustment_iterations
)
with TemporaryDirectory(suffix='-find-bbox') as temp_dir:
process_single_document(
temp_dir=temp_dir,
pdf_path=item.pdf_file,
image_paths=item.image_files,
xml_path=item.xml_file,
output_json_path=output_json_file,
selected_categories=self.selected_categories,
pdf_scale_to=self.pdf_scale_to,
memory_cache_size=self.memory_cache_size,
max_internal_width=self.max_internal_width,
max_internal_height=self.max_internal_height,
use_grayscale=self.use_grayscale,
ignore_unmatched_graphics=self.ignore_unmatched_graphics,
output_xml_path=output_xml_file,
output_annotated_images_path=output_annotated_images_path,
max_bounding_box_adjustment_iterations=self.max_bounding_box_adjustment_iterations
)
def get_item_list(self):
args = self.args
......
from typing import Optional, Sequence, TypeVar
from typing_extensions import Protocol
T_Key = TypeVar('T_Key', contravariant=True)
T_Value = TypeVar('T_Value')
class CacheProtocol(Protocol[T_Key, T_Value]):
def get(self, key: T_Key) -> Optional[T_Value]:
pass
def __setitem__(self, key: T_Key, value: T_Value):
pass
def __delitem__(self, key: T_Key):
pass
class SimpleDictCache(CacheProtocol):
def __init__(self):
super().__init__()
self._data = {}
def get(self, key: T_Key) -> Optional[T_Value]:
return self._data.get(key)
def __setitem__(self, key: T_Key, value: T_Value):
self._data[key] = value
def __delitem__(self, key: T_Key):
del self._data[key]
class MultiLevelCache(CacheProtocol):
def __init__(self, cache_list: Sequence[CacheProtocol]):
super().__init__()
self.cache_list = cache_list
def get(self, key: T_Key) -> Optional[T_Value]:
for cache in self.cache_list:
value = cache.get(key)
if value is not None:
return value
return None
def __setitem__(self, key: T_Key, value: T_Value):
for cache in self.cache_list:
cache[key] = value
def __delitem__(self, key: T_Key):
for cache in self.cache_list:
try:
del cache[key]
except KeyError:
pass
......@@ -522,6 +522,11 @@ def iter_image_list_object_match(
**kwargs
)
if not match_result:
# we need to yield empty result to keep the total iterations
yield ImageListObjectMatchResult(
target_image_index=target_image_index,
match_result=EMPTY_IMAGE_OBJECT_MATCH_RESULT
)
continue
yield ImageListObjectMatchResult(
target_image_index=target_image_index,
......@@ -529,10 +534,10 @@ def iter_image_list_object_match(
)
def get_image_list_object_match(
def iter_current_best_image_list_object_match(
*args,
**kwargs
) -> ImageListObjectMatchResult:
) -> Iterable[ImageListObjectMatchResult]:
best_image_list_object_match = EMPTY_IMAGE_LIST_OBJECT_MATCH_RESULT
for image_list_object_match in iter_image_list_object_match(*args, **kwargs):
LOGGER.debug(
......@@ -544,4 +549,14 @@ def get_image_list_object_match(
> best_image_list_object_match.match_result.sort_key
):
best_image_list_object_match = image_list_object_match
yield best_image_list_object_match
def get_image_list_object_match(
*args,
**kwargs
) -> ImageListObjectMatchResult:
best_image_list_object_match = EMPTY_IMAGE_LIST_OBJECT_MATCH_RESULT
for image_list_object_match in iter_current_best_image_list_object_match(*args, **kwargs):
best_image_list_object_match = image_list_object_match
return best_image_list_object_match
import copyreg
import cv2
# based on: https://stackoverflow.com/a/48832618/8676953
def _pickle_keypoints(point):
return (
cv2.KeyPoint,
(
*point.pt, point.size, point.angle,
point.response, point.octave, point.class_id
)
)
def register_pickle_functions():
copyreg.pickle(cv2.KeyPoint().__class__, _pickle_keypoints)
from typing import Sequence
import pytest
from sciencebeam_gym.utils.cache import (
MultiLevelCache,
SimpleDictCache
)
KEY_1 = 'key1'
VALUE_1 = 'value1'
@pytest.fixture(name='dict_cache_list')
def _dict_cache_list() -> Sequence[SimpleDictCache]:
return [SimpleDictCache() for _ in range(10)]
class TestMultiLevelCache:
def test_should_return_none_if_not_in_any_cache(
self,
dict_cache_list: Sequence[SimpleDictCache]
):
cache = MultiLevelCache(dict_cache_list[:2])
assert cache.get(KEY_1) is None
def test_should_populate_cache_and_retrieve_item(
self,
dict_cache_list: Sequence[SimpleDictCache]
):
cache = MultiLevelCache(dict_cache_list[:2])
cache[KEY_1] = VALUE_1
assert cache.get(KEY_1) == VALUE_1
def test_should_retrieve_item_from_next_level(
self,
dict_cache_list: Sequence[SimpleDictCache]
):
cache = MultiLevelCache(dict_cache_list[:3])
cache[KEY_1] = VALUE_1
del dict_cache_list[0][KEY_1]
del dict_cache_list[2][KEY_1]
assert cache.get(KEY_1) == VALUE_1
import pickle
import cv2 as cv
from sciencebeam_gym.utils.pickle_reg import register_pickle_functions
class TestRegisterPickleFunction:
def test_should_be_able_to_pickle_and_unpickle_cv_keypoint(self):
register_pickle_functions()
key_point = cv.KeyPoint(x=1, y=2, size=3, angle=4, response=5, octave=6, class_id=7)
unpickled_key_point = pickle.loads(pickle.dumps(key_point))
assert unpickled_key_point.pt == key_point.pt
assert unpickled_key_point.size == key_point.size
assert unpickled_key_point.angle == key_point.angle
assert unpickled_key_point.response == key_point.response
assert unpickled_key_point.octave == key_point.octave
assert unpickled_key_point.class_id == key_point.class_id
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment