#!/bin/bash
#
# Runs model training through `gcloud ml-engine`: submits a cloud training job
# when USE_CLOUD is "true", otherwise runs the trainer locally.
#
# Configuration is read from shell variables (TRAIN_PREPROC_PATH,
# EVAL_PREPROC_PATH, QUALITATIVE_PREPROC_PATH, TRAIN_MODEL_PATH, ...).
# Presumably these are defined by prepare-shell.sh -- TODO confirm.

source prepare-shell.sh

# Print the resolved locations before anything is launched, so a
# misconfigured path is visible in the log.
# NOTE(review): the first three labels do not match the names of the variables
# they print (e.g. TRAIN_PREPROC_TRAIN_PATH vs. TRAIN_PREPROC_PATH). The text is
# left unchanged in case something parses this output -- confirm and align.
echo "TRAIN_PREPROC_TRAIN_PATH: $TRAIN_PREPROC_PATH"
echo "EVAL_PREPROC_EVAL_PATH: $EVAL_PREPROC_PATH"
echo "QUALITATIVE_PREPROC_EVAL_PATH: $QUALITATIVE_PREPROC_PATH"
echo "TRAIN_MODEL_PATH: $TRAIN_MODEL_PATH"
echo "CLASS_WEIGHTS_URL: ${CLASS_WEIGHTS_URL}"

# Trainer arguments shared by the cloud and the local invocation.
#
# The "*tfrecord*" patterns are quoted so that they are stored literally and
# handed to the trainer as patterns rather than being expanded by the shell
# here. Scalar settings are quoted so that a value containing whitespace or
# glob characters stays a single argument.
#
# TRAINING_ARGS is deliberately expanded unquoted: it is defined outside this
# file and may be a plain space-separated string rather than an array, in
# which case word splitting is what turns it into separate arguments.
# shellcheck disable=SC2068,SC2206
COMMON_ARGS=(
  --output_path "${TRAIN_MODEL_PATH}/"
  --train_data_paths "${TRAIN_PREPROC_PATH}/*tfrecord*"
  --eval_data_paths "${EVAL_PREPROC_PATH}/*tfrecord*"
  --model "${MODEL_NAME}"
  --color_map "${CONFIG_PATH}/${COLOR_MAP_FILENAME}"
  --use_separate_channels "${USE_SEPARATE_CHANNELS}"
  --batch_size "${BATCH_SIZE}"
  --eval_set_size "${EVAL_SET_SIZE}"
  --seed "${RANDOM_SEED}"
  --base_loss "${BASE_LOSS}"
  ${TRAINING_ARGS[@]}
)

# Add the qualitative data set arguments only when a path for it is configured.
if [[ -n "$QUALITATIVE_PREPROC_PATH" ]]; then
  # Append with += instead of rebuilding the array from an unquoted
  # ${COMMON_ARGS[@]}: re-expanding the array unquoted would word-split its
  # elements, drop empty ones and let the shell expand the stored "*tfrecord*"
  # patterns against the local filesystem.
  COMMON_ARGS+=(
    --qualitative_data_paths "${QUALITATIVE_PREPROC_PATH}/*tfrecord*"
    --qualitative_set_size "${QUALITATIVE_SET_SIZE}"
  )
fi

# NOTE(review): the lines that originally followed the block above (including
# its closing ")" and "fi") were missing from the captured copy of this file
# and have been restored minimally. CLASS_WEIGHTS_URL is echoed at the top of
# the script but never passed to the trainer in the visible code -- check the
# upstream file for further arguments that belong here.

# Launch the trainer: as a cloud job when USE_CLOUD is "true", locally
# otherwise (including when USE_CLOUD is empty or unset).
#
# In both cases everything after the bare "--" is passed to the trainer module
# rather than interpreted by gcloud. COMMON_ARGS is expanded quoted so every
# element arrives as exactly one argument and the "*tfrecord*" patterns are not
# expanded by the shell.
if [[ "$USE_CLOUD" == true ]]; then
  JOB_ID="train_$JOB_ID"
  gcloud ml-engine jobs submit training "$JOB_ID" \
    --stream-logs \
    --module-name sciencebeam_gym.trainer.task \
    --package-path sciencebeam_gym \
    --staging-bucket "$TEMP_BUCKET" \
    --region us-central1 \
    --runtime-version=1.2 \
    --scale-tier=BASIC_GPU \
    -- \
    --save_max_to_keep 10 \
    --log_interval_secs 100000 \
    --eval_interval_secs 100000 \
    --save_interval_secs 100000 \
    --log_freq 500 \
    --eval_freq 500 \
    --save_freq 500 \
    "${COMMON_ARGS[@]}"
else
  # NOTE(review): --package-path is "sciencebeam_gym.trainer" here but
  # "sciencebeam_gym" for the cloud job above; confirm the dotted value is
  # intended, since the option expects a path to a package directory.
  gcloud ml-engine local train \
    --module-name sciencebeam_gym.trainer.task \
    --package-path sciencebeam_gym.trainer \
    -- \
    --save_max_to_keep 3 \
    --log_interval_secs 600 \
    --eval_interval_secs 300 \
    --save_interval_secs 300 \
    --log_freq 50 \
    --eval_freq 50 \
    --save_freq 50 \
    --max_steps 3 \
    "${COMMON_ARGS[@]}"
fi