# partially copied from tensorflow example project
from __future__ import absolute_import

import argparse
import datetime
import errno
import io
import logging
import os
import subprocess
import sys
import re

import six

import apache_beam as beam
from apache_beam.metrics import Metrics

# pylint: disable=g-import-not-at-top
# TODO(yxshi): Remove after Dataflow 0.4.5 SDK is released.
try:
  try:
    from apache_beam.options.pipeline_options import PipelineOptions
  except ImportError:
    from apache_beam.utils.pipeline_options import PipelineOptions
except ImportError:
  from apache_beam.utils.options import PipelineOptions

import tensorflow as tf

from tensorflow.python.framework import errors
from tensorflow.python.lib.io import file_io

from PIL import Image

# TODO copied functions due to pickling issue
# from .colorize_image import (
#   parse_color_map,
#   map_colors
# )

from six.moves.configparser import ConfigParser

slim = tf.contrib.slim

error_count = Metrics.counter('main', 'errorCount')


def parse_color_map(f, section_names=None):
  if section_names is None:
    section_names = ['color_map']
  color_map_config = ConfigParser()
  color_map_config.readfp(f)

  num_pattern = re.compile(r'(\d+)')
  rgb_pattern = re.compile(r'\((\d+),(\d+),(\d+)\)')

  def parse_color(s):
    m = num_pattern.match(s)
    if m:
      x = int(m.group(1))
      return (x, x, x)
    else:
      m = rgb_pattern.match(s)
      if m:
        return (int(m.group(1)), int(m.group(2)), int(m.group(3)))
    raise Exception('invalid color value: {}'.format(s))

  color_map = dict()
  for section_name in section_names:
    if color_map_config.has_section(section_name):
      for k, v in color_map_config.items(section_name):
        color_map[parse_color(k)] = parse_color(v)
  return color_map


def map_colors(img, color_map):
  if color_map is None or len(color_map) == 0:
    return img
  original_data = img.getdata()
  mapped_data = [
    color_map.get(color, color)
    for color in original_data
  ]
  img.putdata(mapped_data)
  return img
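
# Illustrative example (assumed file contents, not part of this module):
# a configuration accepted by parse_color_map could look like
#
#   [color_map]
#   0 = (255,255,255)
#   1 = (0,0,255)
#
# which would yield {(0, 0, 0): (255, 255, 255), (1, 1, 1): (0, 0, 255)}.
# Plain numbers are expanded to grey-scale RGB triples, and tuple values
# must be written without spaces to match rgb_pattern above.
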
def _open_file_read_binary(uri):
  # TF will enable 'rb' in future versions, but until then, 'r' is
  # required.
  try:
    return file_io.FileIO(uri, mode='rb')
  except errors.InvalidArgumentError:
    return file_io.FileIO(uri, mode='r')


def iter_read_image(uri):
  try:
    with _open_file_read_binary(uri) as f:
      image_bytes = f.read()
      yield Image.open(io.BytesIO(image_bytes)).convert('RGB')
  # A variety of different calling libraries throw different exceptions here.
  # They all correspond to an unreadable file so we treat them equivalently.
  except Exception as e:  # pylint: disable=broad-except
    logging.exception('Error processing image %s: %s', uri, str(e))
    error_count.inc()


def image_resize_nearest(image, size):
  return image.resize(size, Image.NEAREST)


def image_resize_bicubic(image, size):
  return image.resize(size, Image.BICUBIC)


def image_save_to_bytes(image, format):
  output = io.BytesIO()
  image.save(output, format)
  return output.getvalue()


def get_image_filenames_for_filenames(filenames, target_pattern):
  r_target_pattern = re.compile(target_pattern)
  png_filenames = [s for s in filenames if s.endswith('.png')]
  annot_filenames = sorted([s for s in png_filenames if r_target_pattern.search(s)])
  image_filenames = sorted([s for s in png_filenames if s not in annot_filenames])
  assert len(annot_filenames) == len(image_filenames)
  return image_filenames, annot_filenames


def get_image_filenames_for_patterns(data_paths, target_pattern):
  files = []
  for path in data_paths:
    files.extend(file_io.get_matching_files(path))
  return get_image_filenames_for_filenames(files, target_pattern)


def MapDictPropAs(in_key, out_key, fn, label='MapDictPropAs'):
  def wrapper_fn(d):
    d = d.copy()
    d.update({
      out_key: fn(d[in_key])
    })
    return d
  return label >> beam.Map(wrapper_fn)


class MapDictPropAsIfNotNone(beam.DoFn):
  def __init__(self, in_key, out_key, fn):
    self.in_key = in_key
    self.out_key = out_key
    self.fn = fn

  def process(self, element):
    if hasattr(element, 'element'):
      element = element.element
    if not isinstance(element, dict):
      raise Exception('expected dict, got: {} ({})'.format(type(element), element))
    d = element
    value = self.fn(d[self.in_key])
    if value is not None:
      d = d.copy()
      d.update({
        self.out_key: value
      })
      yield d


def FlatMapDictProp(in_key, out_key, fn, label='FlatMapDictProp'):
  def wrapper_fn(d):
    out_value = fn(d[in_key])
    for v in out_value:
      d_out = d.copy()
      d_out.update({
        out_key: v
      })
      yield d_out
  return label >> beam.FlatMap(wrapper_fn)


def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def WriteToLog(message_fn):
  def wrapper_fn(x):
    logging.info(message_fn(x))
    return x
  return beam.Map(wrapper_fn)


def ReadAndConvertInputImage(image_size):
  def convert_annotation_image(image):
    image = image_resize_bicubic(image, image_size)
    return image_save_to_bytes(image, 'png')

  return lambda uri: [
    convert_annotation_image(image)
    for image in iter_read_image(uri)
  ]


def ReadAndConvertAnnotationImage(image_size, color_map):
  def convert_annotation_image(image):
    image = image_resize_nearest(image, image_size)
    if color_map:
      image = map_colors(image, color_map)
    return image_save_to_bytes(image, 'png')

  return lambda uri: [
    convert_annotation_image(image)
    for image in iter_read_image(uri)
  ]
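
# Illustrative example (values are placeholders): FlatMapDictProp copies the
# incoming dict once per value yielded by fn, so an element such as
#   {'input_uri': 'page1.png', 'annotation_uri': 'page1.annot.png'}
# passed through
#   FlatMapDictProp('input_uri', 'input_image', ReadAndConvertInputImage(image_size))
# becomes
#   {'input_uri': 'page1.png', 'annotation_uri': 'page1.annot.png', 'input_image': <png bytes>}
# or produces no element at all if the image could not be read.
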
def configure_pipeline(p, opt):
  """Specify PCollection and transformations in pipeline."""
  logger = logging.getLogger(__name__)

  image_size = (opt.image_width, opt.image_height)

  color_map = None
  if opt.color_map:
    with file_io.FileIO(opt.color_map, 'r') as config_f:
      color_map = parse_color_map(config_f, section_names=['color_alias', 'color_map'])
  if color_map:
    logger.info('read {} color mappings'.format(len(color_map)))
  else:
    logger.info('no color mappings configured')

  train_image_filenames, train_annotation_filenames = (
    get_image_filenames_for_patterns(opt.data_paths, opt.target_pattern)
  )

  logging.info('train/annotation_filenames:\n%s', '\n'.join([
    '...{} - ...{}'.format(train_image_filename[-20:], train_annotation_filename[-40:])
    for train_image_filename, train_annotation_filename in zip(
      train_image_filenames, train_annotation_filenames
    )
  ]))

  file_io.recursive_create_dir(opt.output_path)

  _ = (
    p
    | beam.Create([
      {
        'input_uri': train_image_filename,
        'annotation_uri': train_annotation_filename
      }
      for train_image_filename, train_annotation_filename in zip(
        train_image_filenames, train_annotation_filenames
      )
    ])
    | 'ReadAndConvertInputImage' >> FlatMapDictProp(
      'input_uri', 'input_image',
      ReadAndConvertInputImage(image_size)
    )
    | 'ReadAndConvertAnnotationImage' >> FlatMapDictProp(
      'annotation_uri', 'annotation_image',
      ReadAndConvertAnnotationImage(image_size, color_map)
    )
    | 'Log' >> WriteToLog(lambda x: 'processed: {} ({})'.format(
      x['input_uri'], x['annotation_uri']
    ))
    | 'ConvertToExamples' >> beam.Map(
      lambda x: tf.train.Example(features=tf.train.Features(feature={
        k: _bytes_feature([v])
        for k, v in six.iteritems(x)
      }))
    )
    | 'SerializeToString' >> beam.Map(lambda x: x.SerializeToString())
    | 'SaveToDisk' >> beam.io.WriteToTFRecord(
      opt.output_path,
      file_name_suffix='.tfrecord.gz'
    )
  )


def run(in_args=None):
  """Runs the pre-processing pipeline."""
  pipeline_options = PipelineOptions.from_dictionary(vars(in_args))
  with beam.Pipeline(options=pipeline_options) as p:
    configure_pipeline(p, in_args)
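
# A minimal sketch (not used by this pipeline) of how the written
# .tfrecord.gz files could be read back with the TF 1.x API; the feature
# keys mirror the dict keys assembled in configure_pipeline above:
#
#   features = tf.parse_single_example(serialized_example, features={
#     'input_uri': tf.FixedLenFeature([], tf.string),
#     'annotation_uri': tf.FixedLenFeature([], tf.string),
#     'input_image': tf.FixedLenFeature([], tf.string),
#     'annotation_image': tf.FixedLenFeature([], tf.string)
#   })
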
def get_cloud_project():
  cmd = [
    'gcloud', '-q', 'config', 'list', 'project', '--format=value(core.project)'
  ]
  with open(os.devnull, 'w') as dev_null:
    try:
      res = subprocess.check_output(cmd, stderr=dev_null).strip()
      if not res:
        raise Exception(
          '--cloud specified but no Google Cloud Platform '
          'project found.\n'
          'Please specify your project name with the --project '
          'flag or set a default project: '
          'gcloud config set project YOUR_PROJECT_NAME'
        )
      return res
    except OSError as e:
      if e.errno == errno.ENOENT:
        raise Exception(
          'gcloud is not installed. The Google Cloud SDK is '
          'necessary to communicate with the Cloud ML service. '
          'Please install and set up gcloud.'
        )
      raise


def default_args(argv):
  """Provides default values for Workflow flags."""
  parser = argparse.ArgumentParser()

  parser.add_argument(
    '--data_paths',
    type=str,
    action='append',
    help='The paths to the training data files. '
    'Can be a comma separated list of files or a glob pattern.'
  )
  parser.add_argument(
    '--target_pattern',
    type=str,
    default=r'\btarget\b|\bannot',
    help='The regex pattern to identify target files.'
  )
  parser.add_argument(
    '--output_path',
    required=True,
    help='Output directory to write results to.'
  )
  parser.add_argument(
    '--project',
    type=str,
    help='The cloud project name to be used for running this pipeline'
  )
  parser.add_argument(
    '--job_name',
    type=str,
    default='sciencebeam-gym-' + datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
    help='A unique job identifier.'
  )
  parser.add_argument(
    '--num_workers',
    default=20,
    type=int,
    help='The number of workers.'
  )
  parser.add_argument('--cloud', default=False, action='store_true')
  parser.add_argument(
    '--runner',
    help='See Dataflow runners, may be blocking'
    ' or not, on cloud or not, etc.'
  )
  parser.add_argument(
    '--image_width',
    type=int,
    required=True,
    help='Resize images to the specified width'
  )
  parser.add_argument(
    '--image_height',
    type=int,
    required=True,
    help='Resize images to the specified height'
  )
  parser.add_argument(
    '--color_map',
    type=str,
    help='The path to the color map configuration.'
  )

  parsed_args, _ = parser.parse_known_args(argv)

  if parsed_args.cloud:
    # Flags which need to be set for cloud runs.
    default_values = {
      'project': get_cloud_project(),
      'temp_location': os.path.join(os.path.dirname(parsed_args.output_path), 'temp'),
      'runner': 'DataflowRunner',
      'save_main_session': True,
    }
  else:
    # Flags which need to be set for local runs.
    default_values = {
      'runner': 'DirectRunner',
    }

  for kk, vv in six.iteritems(default_values):
    if kk not in parsed_args or not vars(parsed_args)[kk]:
      vars(parsed_args)[kk] = vv

  return parsed_args


def main(argv):
  arg_dict = default_args(argv)
  run(arg_dict)


if __name__ == '__main__':
  logging.basicConfig(level=logging.INFO)
  main(sys.argv[1:])
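
# Example local invocation (script name, paths and sizes are placeholders):
#   python this_script.py \
#     --data_paths '/data/pages/*.png' \
#     --output_path /data/preproc/train \
#     --image_width 256 --image_height 256
# Adding --cloud switches to the DataflowRunner using the defaults set
# in default_args().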