Skip to content
Snippets Groups Projects
Commit 6d7a715d authored by Daniel Ecer's avatar Daniel Ecer
Browse files

added MAX_TRAIN_STEPS script configuration; fixed passing in CLASS_WEIGHTS_FILENAME

parent 7bf05aff
No related branches found
No related tags found
No related merge requests found
......@@ -4,13 +4,13 @@ set -e
source prepare-shell.sh
echo "output will be written to: ${CLASS_WEIGHTS_FILENAME}"
echo "output will be written to: ${CLASS_WEIGHTS_URL}"
python -m sciencebeam_gym.tools.calculate_class_weights \
--tfrecord-paths "${TRAIN_PREPROC_PATH}/*tfrecord*" \
--image-key "annotation_image" \
--color-map "${CONFIG_PATH}/${COLOR_MAP_FILENAME}" \
--channels="$CHANNEL_NAMES" \
--out "${CLASS_WEIGHTS_FILENAME}"
--out "${CLASS_WEIGHTS_URL}"
echo "output written to: ${CLASS_WEIGHTS_FILENAME}"
echo "output written to: ${CLASS_WEIGHTS_URL}"
......@@ -30,6 +30,7 @@ QUALITATIVE_FILE_LIMIT=10
QUALITATIVE_PREPROC_PATH=
NUM_WORKERS=1
CLASS_WEIGHTS_FILENAME=
MAX_TRAIN_STEPS=1000
USE_CLOUD=false
......@@ -96,8 +97,8 @@ if [ ! -z "$QUALITATIVE_FOLDER_NAME" ]; then
QUALITATIVE_PREPROC_PATH=${PREPROC_PATH}/$QUALITATIVE_FOLDER_NAME
fi
if [ -z "$CLASS_WEIGHTS_FILENAME" ]; then
CLASS_WEIGHTS_FILENAME="${TRAIN_PREPROC_PATH}/class-weights.json"
if [ ! -z "$CLASS_WEIGHTS_FILENAME" ]; then
CLASS_WEIGHTS_URL="${TRAIN_PREPROC_PATH}/${CLASS_WEIGHTS_FILENAME}"
fi
if [ ! -z "$POST_CONFIG_FILE" ]; then
......
......@@ -11,7 +11,7 @@ echo "TRAIN_PREPROC_TRAIN_PATH: $TRAIN_PREPROC_PATH"
echo "EVAL_PREPROC_EVAL_PATH: $EVAL_PREPROC_PATH"
echo "QUALITATIVE_PREPROC_EVAL_PATH: $QUALITATIVE_PREPROC_PATH"
echo "TRAIN_MODEL_PATH: $TRAIN_MODEL_PATH"
echo "CLASS_WEIGHTS_FILENAME: ${CLASS_WEIGHTS_FILENAME}"
echo "CLASS_WEIGHTS_URL: ${CLASS_WEIGHTS_URL}"
COMMON_ARGS=(
--output_path "${TRAIN_MODEL_PATH}/"
......@@ -20,14 +20,13 @@ COMMON_ARGS=(
--model "${MODEL_NAME}"
--color_map "${CONFIG_PATH}/${COLOR_MAP_FILENAME}"
--channels="$CHANNEL_NAMES"
--class_weights="${CLASS_WEIGHTS_FILENAME}"
--class_weights="${CLASS_WEIGHTS_URL}"
--use_separate_channels $USE_SEPARATE_CHANNELS
--batch_size $BATCH_SIZE
--eval_set_size $EVAL_SET_SIZE
--seed $RANDOM_SEED
--base_loss $BASE_LOSS
${TRAINING_ARGS[@]}
$@
)
if [ ! -z "$QUALITATIVE_PREPROC_PATH" ]; then
......@@ -39,6 +38,7 @@ if [ ! -z "$QUALITATIVE_PREPROC_PATH" ]; then
fi
if [ $USE_CLOUD == true ]; then
echo "MAX_TRAIN_STEPS: $MAX_TRAIN_STEPS"
gcloud ml-engine jobs submit training "$JOB_ID" \
--stream-logs \
--module-name sciencebeam_gym.trainer.task \
......@@ -55,7 +55,7 @@ if [ $USE_CLOUD == true ]; then
--log_freq 500 \
--eval_freq 500 \
--save_freq 500 \
--max_steps 1000 \
--max_steps ${MAX_TRAIN_STEPS} \
${COMMON_ARGS[@]}
else
gcloud ml-engine local train \
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment