Skip to content
Snippets Groups Projects
Select Git revision
  • 96a3de1f1f44d6159c26e636b928ea5904aa90e7
  • develop default protected
  • dependabot/pip/gcsfs-2022.1.0
  • dependabot/pip/numpy-1.21.0
  • dependabot/pip/types-cachetools-4.2.8
  • master
  • dependabot/pip/fsspec-2022.1.0
  • dependabot/pip/cachetools-5.0.0
  • dependabot/pip/pylint-2.12.2
  • dependabot/pip/tensorflow-transform-1.5.0
  • dependabot/pip/apache-beam-gcp--2.35.0
  • dependabot/pip/typing-extensions-4.0.1
  • dependabot/pip/google-cloud-bigquery-lt-2.32.0
  • v0.1.1
  • v0.1.0
  • v0.0.1
16 results

predict.sh

Blame
  • predict.sh 1.59 KiB
    #!/bin/bash
    
    set -e
    
    source prepare-shell.sh
    
    PREDICT_INPUT="$1"
    PREDICT_OUTPUT="$2"
    
    if [ -z "$PREDICT_INPUT" ]; then
      echo "Usage: $0 <predict input> [<predict output>]"
      exit 1
    fi
    
    if [ -z "$PREDICT_OUTPUT" ]; then
      PREDICT_OUTPUT=$PREDICT_INPUT.out.png
    fi
    
    COMMON_ARGS=(
      --output_path "${TRAIN_MODEL_PATH}/"
      --train_data_paths "${TRAIN_PREPROC_PATH}/*tfrecord*"
      --eval_data_paths "${EVAL_PREPROC_PATH}/*tfrecord*"
      --model "${MODEL_NAME}"
      --color_map "${CONFIG_PATH}/${COLOR_MAP_FILENAME}"
      --class_weights="${CLASS_WEIGHTS_URL}"
      --channels="$CHANNEL_NAMES"
      --use_separate_channels $USE_SEPARATE_CHANNELS
      --batch_size $BATCH_SIZE
      --eval_set_size $EVAL_SET_SIZE
      --max_steps 0
      --base_loss $BASE_LOSS
      --seed $RANDOM_SEED
      --predict="$PREDICT_INPUT"
      --predict-output="$PREDICT_OUTPUT"
      ${TRAINING_ARGS[@]}
    )
    
    if [ "$USE_MODEL_EXPORT_PATH" == true ]; then
      COMMON_ARGS=(
        ${COMMON_ARGS[@]}
        --model_export_path="${MODEL_EXPORT_PATH}"
      )
    fi
    
    if [ $USE_SEPARATE_CHANNELS == true ]; then
      COMMON_ARGS=(
        ${COMMON_ARGS[@]}
        --color_map "${CONFIG_PATH}/${COLOR_MAP_FILENAME}"
      )
    fi
    
    if [ $USE_CLOUD == true ]; then
      JOB_ID="predict_$JOB_ID"
      gcloud ml-engine jobs submit training "$JOB_ID" \
        --stream-logs \
        --module-name sciencebeam_gym.trainer.task \
        --package-path sciencebeam_gym \
        --staging-bucket "$TEMP_BUCKET" \
        --region us-central1 \
        --runtime-version=1.2 \
        -- \
        ${COMMON_ARGS[@]}
    else
      gcloud ml-engine local train \
        --module-name sciencebeam_gym.trainer.task \
        --package-path sciencebeam_gym.trainer \
        -- \
        ${COMMON_ARGS[@]}
    fi