From 6d7a715d3b02fd3841ab0f3704e74aa239e35c9b Mon Sep 17 00:00:00 2001
From: Daniel Ecer <de-code@users.noreply.github.com>
Date: Thu, 7 Dec 2017 18:35:06 +0000
Subject: [PATCH] added MAX_TRAIN_STEPS script configuration; fixed passing in
 CLASS_WEIGHTS_FILENAME

---
 calculate-class-weights.sh | 6 +++---
 prepare-shell.sh           | 5 +++--
 train.sh                   | 8 ++++----
 3 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/calculate-class-weights.sh b/calculate-class-weights.sh
index 307313d..06c6397 100755
--- a/calculate-class-weights.sh
+++ b/calculate-class-weights.sh
@@ -4,13 +4,13 @@ set -e
 
 source prepare-shell.sh
 
-echo "output will be written to: ${CLASS_WEIGHTS_FILENAME}"
+echo "output will be written to: ${CLASS_WEIGHTS_URL}"
 
 python -m sciencebeam_gym.tools.calculate_class_weights \
   --tfrecord-paths "${TRAIN_PREPROC_PATH}/*tfrecord*" \
   --image-key "annotation_image" \
   --color-map "${CONFIG_PATH}/${COLOR_MAP_FILENAME}" \
   --channels="$CHANNEL_NAMES" \
-  --out "${CLASS_WEIGHTS_FILENAME}"
+  --out "${CLASS_WEIGHTS_URL}"
 
-echo "output written to: ${CLASS_WEIGHTS_FILENAME}"
+echo "output written to: ${CLASS_WEIGHTS_URL}"
diff --git a/prepare-shell.sh b/prepare-shell.sh
index 8e95748..7620581 100755
--- a/prepare-shell.sh
+++ b/prepare-shell.sh
@@ -30,6 +30,7 @@ QUALITATIVE_FILE_LIMIT=10
 QUALITATIVE_PREPROC_PATH=
 NUM_WORKERS=1
 CLASS_WEIGHTS_FILENAME=
+MAX_TRAIN_STEPS=1000
 
 USE_CLOUD=false
 
@@ -96,8 +97,8 @@ if [ ! -z "$QUALITATIVE_FOLDER_NAME" ]; then
   QUALITATIVE_PREPROC_PATH=${PREPROC_PATH}/$QUALITATIVE_FOLDER_NAME
 fi
 
-if [ -z "$CLASS_WEIGHTS_FILENAME" ]; then
-  CLASS_WEIGHTS_FILENAME="${TRAIN_PREPROC_PATH}/class-weights.json"
+if [ ! -z "$CLASS_WEIGHTS_FILENAME" ]; then
+  CLASS_WEIGHTS_URL="${TRAIN_PREPROC_PATH}/${CLASS_WEIGHTS_FILENAME}"
 fi
 
 if [ ! -z "$POST_CONFIG_FILE" ]; then
diff --git a/train.sh b/train.sh
index f8c6702..42bdeee 100755
--- a/train.sh
+++ b/train.sh
@@ -11,7 +11,7 @@ echo "TRAIN_PREPROC_TRAIN_PATH: $TRAIN_PREPROC_PATH"
 echo "EVAL_PREPROC_EVAL_PATH: $EVAL_PREPROC_PATH"
 echo "QUALITATIVE_PREPROC_EVAL_PATH: $QUALITATIVE_PREPROC_PATH"
 echo "TRAIN_MODEL_PATH: $TRAIN_MODEL_PATH"
-echo "CLASS_WEIGHTS_FILENAME: ${CLASS_WEIGHTS_FILENAME}"
+echo "CLASS_WEIGHTS_URL: ${CLASS_WEIGHTS_URL}"
 
 COMMON_ARGS=(
   --output_path "${TRAIN_MODEL_PATH}/"
@@ -20,14 +20,13 @@ COMMON_ARGS=(
   --model "${MODEL_NAME}"
   --color_map "${CONFIG_PATH}/${COLOR_MAP_FILENAME}"
   --channels="$CHANNEL_NAMES"
-  --class_weights="${CLASS_WEIGHTS_FILENAME}"
+  --class_weights="${CLASS_WEIGHTS_URL}"
   --use_separate_channels $USE_SEPARATE_CHANNELS
   --batch_size $BATCH_SIZE
   --eval_set_size $EVAL_SET_SIZE
   --seed $RANDOM_SEED
   --base_loss $BASE_LOSS
   ${TRAINING_ARGS[@]}
-  $@
 )
 
 if [ ! -z "$QUALITATIVE_PREPROC_PATH" ]; then
@@ -39,6 +38,7 @@ if [ ! -z "$QUALITATIVE_PREPROC_PATH" ]; then
 fi
 
 if [ $USE_CLOUD == true ]; then
+  echo "MAX_TRAIN_STEPS: $MAX_TRAIN_STEPS"
   gcloud ml-engine jobs submit training "$JOB_ID" \
     --stream-logs \
     --module-name sciencebeam_gym.trainer.task \
@@ -55,7 +55,7 @@ if [ $USE_CLOUD == true ]; then
     --log_freq 500 \
     --eval_freq 500 \
     --save_freq 500 \
-    --max_steps 1000 \
+    --max_steps ${MAX_TRAIN_STEPS} \
     ${COMMON_ARGS[@]}
 else
   gcloud ml-engine local train \
-- 
GitLab