Skip to content
Snippets Groups Projects
Commit c5f49894 authored by Daniel Ecer's avatar Daniel Ecer
Browse files

train with class weights

parent 71c8b056
No related branches found
No related tags found
No related merge requests found
......@@ -4,8 +4,6 @@ set -e
source prepare-shell.sh
CLASS_WEIGHTS_FILENAME="${TRAIN_PREPROC_PATH}/class-weights.json"
echo "output will be written to: ${CLASS_WEIGHTS_FILENAME}"
python -m sciencebeam_gym.tools.calculate_class_weights \
......
......@@ -29,6 +29,7 @@ QUALITATIVE_PAGE_RANGE=1
QUALITATIVE_FILE_LIMIT=10
QUALITATIVE_PREPROC_PATH=
NUM_WORKERS=1
CLASS_WEIGHTS_FILENAME=
USE_CLOUD=false
......@@ -95,6 +96,10 @@ if [ ! -z "$QUALITATIVE_FOLDER_NAME" ]; then
QUALITATIVE_PREPROC_PATH=${PREPROC_PATH}/$QUALITATIVE_FOLDER_NAME
fi
if [ -z "$CLASS_WEIGHTS_FILENAME" ]; then
CLASS_WEIGHTS_FILENAME="${TRAIN_PREPROC_PATH}/class-weights.json"
fi
if [ ! -z "$POST_CONFIG_FILE" ]; then
source "${POST_CONFIG_FILE}"
fi
......@@ -10,7 +10,16 @@ def cross_entropy_loss(labels, logits):
return tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(
logits=logits,
labels=labels,
name='softmax_cross_entropy_with_logits'
labels=labels
)
)
def weighted_cross_entropy_loss(targets, logits, pos_weight):
  """Mean weighted sigmoid cross-entropy between `logits` and `targets`.

  `pos_weight` scales the positive-class term of the element-wise loss
  (see tf.nn.weighted_cross_entropy_with_logits), allowing rarer classes
  to contribute more to the average.
  """
  with tf.name_scope("weighted_cross_entropy"):
    elementwise_loss = tf.nn.weighted_cross_entropy_with_logits(
      targets=targets,
      logits=logits,
      pos_weight=pos_weight
    )
    return tf.reduce_mean(elementwise_loss)
......@@ -5,14 +5,15 @@ import numpy as np
from sciencebeam_gym.trainer.models.pix2pix.loss import (
l1_loss,
cross_entropy_loss
cross_entropy_loss,
weighted_cross_entropy_loss
)
def assert_close(a, b):
def assert_close(a, b, atol=1.e-8):
try:
assert np.allclose([a], [b])
assert np.allclose([a], [b], atol=atol)
except AssertionError as e:
raise_from(AssertionError('expected %s to be close to %s' % (a, b)), e)
raise_from(AssertionError('expected %s to be close to %s (atol=%s)' % (a, b, atol)), e)
class TestL1Loss(object):
def test_should_return_abs_diff_for_single_value(self):
......@@ -47,3 +48,54 @@ class TestCrossEntropyLoss(object):
loss = cross_entropy_loss(labels, logits)
with tf.Session() as session:
assert session.run([loss])[0] > 0.5
class TestWeightedCrossEntropyLoss(object):
  # Each test builds a fresh graph with a single [batch, 1, 2] example
  # and evaluates the loss in an explicit session.

  def test_should_return_zero_if_logits_are_matching_labels_with_neg_pos_value(self):
    with tf.Graph().as_default():
      labels = tf.constant([[[0.0, 1.0]]])
      logits = tf.constant([[[-10.0, 10.0]]])
      pos_weight = tf.constant([1.0])
      loss = weighted_cross_entropy_loss(labels, logits, pos_weight)
      with tf.Session() as session:
        # logits strongly agree with the labels, so the loss should vanish
        assert_close(session.run([loss])[0], 0.0, atol=0.0001)

  def test_should_return_not_zero_if_logits_are_not_matching_labels(self):
    with tf.Graph().as_default():
      labels = tf.constant([[[0.0, 1.0]]])
      logits = tf.constant([[[10.0, 10.0]]])
      pos_weight = tf.constant([1.0])
      loss = weighted_cross_entropy_loss(labels, logits, pos_weight)
      with tf.Session() as session:
        assert session.run([loss])[0] > 0.5

  def test_should_return_higher_loss_for_value_with_greater_weight(self):
    with tf.Graph().as_default():
      labels = tf.constant([[[0.0, 1.0]]])
      logits = tf.constant([[[10.0, 10.0]]])
      lower_weight = tf.constant([0.5])
      higher_weight = tf.constant([1.0])
      loss_for_lower = weighted_cross_entropy_loss(labels, logits, lower_weight)
      loss_for_higher = weighted_cross_entropy_loss(labels, logits, higher_weight)
      with tf.Session() as session:
        lower_value, higher_value = session.run([loss_for_lower, loss_for_higher])
        assert lower_value < higher_value
......@@ -16,7 +16,8 @@ from sciencebeam_gym.trainer.models.pix2pix.tf_utils import (
from sciencebeam_gym.trainer.models.pix2pix.loss import (
l1_loss,
cross_entropy_loss
cross_entropy_loss,
weighted_cross_entropy_loss
)
EPS = 1e-12
......@@ -24,6 +25,9 @@ EPS = 1e-12
class BaseLoss(object):
L1 = "L1"
CROSS_ENTROPY = "CE"
WEIGHTED_CROSS_ENTROPY = "WCE"
ALL_BASE_LOSS = [BaseLoss.L1, BaseLoss.CROSS_ENTROPY, BaseLoss.WEIGHTED_CROSS_ENTROPY]
Pix2PixModel = collections.namedtuple(
"Pix2PixModel", [
......@@ -271,14 +275,14 @@ def create_separate_channel_discriminator_by_blanking_out_channels(inputs, targe
return predict_real, predict_real_blanked
def create_pix2pix_model(inputs, targets, a):
def create_pix2pix_model(inputs, targets, a, pos_weight=None):
get_logger().info('gan_weight: %s, l1_weight: %s', a.gan_weight, a.l1_weight)
gan_enabled = abs(a.gan_weight) > 0.000001
with tf.variable_scope("generator"):
out_channels = int(targets.get_shape()[-1])
outputs = create_generator(inputs, out_channels, a)
if a.base_loss == BaseLoss.CROSS_ENTROPY:
if a.base_loss == BaseLoss.CROSS_ENTROPY or a.base_loss == BaseLoss.WEIGHTED_CROSS_ENTROPY:
output_logits = outputs
outputs = tf.nn.softmax(output_logits)
else:
......@@ -351,18 +355,26 @@ def create_pix2pix_model(inputs, targets, a):
discrim_grads_and_vars = []
with tf.name_scope("generator_loss"):
if a.base_loss == BaseLoss.CROSS_ENTROPY:
get_logger().info('using cross entropy loss function')
# TODO change variable name
gen_loss_L1 = cross_entropy_loss(
get_logger().info('using loss: %s', a.base_loss)
if a.base_loss == BaseLoss.L1:
gen_base_loss = l1_loss(labels=targets, outputs=outputs)
elif a.base_loss == BaseLoss.CROSS_ENTROPY:
gen_base_loss = cross_entropy_loss(
logits=output_logits,
labels=targets
)
elif a.base_loss == BaseLoss.WEIGHTED_CROSS_ENTROPY:
if pos_weight is None:
raise ValueError('pos_weight missing')
pos_weight = tf.convert_to_tensor(pos_weight)
gen_base_loss = weighted_cross_entropy_loss(
logits=output_logits,
targets=targets,
pos_weight=pos_weight
)
else:
get_logger().info('using L1 loss function')
# abs(targets - outputs) => 0
gen_loss_L1 = l1_loss(labels=targets, outputs=outputs)
gen_loss = gen_loss_L1 * a.l1_weight
raise ValueError('unrecognised base loss: %s' % a.base_loss)
gen_loss = gen_base_loss * a.l1_weight
if gan_enabled:
# predict_fake => 1
......@@ -383,11 +395,12 @@ def create_pix2pix_model(inputs, targets, a):
gen_train = gen_optim.apply_gradients(gen_grads_and_vars)
ema = tf.train.ExponentialMovingAverage(decay=0.99)
update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1])
update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_base_loss])
global_step = tf.contrib.framework.get_or_create_global_step()
incr_global_step = tf.assign(global_step, global_step+1)
# TODO change gen_loss_L1 name
return Pix2PixModel(
inputs=inputs,
targets=targets,
......@@ -396,7 +409,7 @@ def create_pix2pix_model(inputs, targets, a):
discrim_loss=ema.average(discrim_loss),
discrim_grads_and_vars=discrim_grads_and_vars,
gen_loss_GAN=ema.average(gen_loss_GAN),
gen_loss_L1=ema.average(gen_loss_L1),
gen_loss_L1=ema.average(gen_base_loss),
gen_grads_and_vars=gen_grads_and_vars,
outputs=outputs,
global_step=global_step,
......
import logging
from collections import namedtuple
from mock import patch
import tensorflow as tf
import numpy as np
......@@ -9,6 +10,8 @@ from sciencebeam_gym.utils.collection import (
extend_dict
)
import sciencebeam_gym.trainer.models.pix2pix.pix2pix_core as pix2pix_core
from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core import (
create_pix2pix_model,
BaseLoss
......@@ -41,6 +44,9 @@ def create_args(*args, **kwargs):
d = extend_dict(*list(args) + [kwargs])
return namedtuple('args', d.keys())(**d)
def patch_spy_object(o, name):
  """Patch attribute `name` on `o` with a spy that records calls while
  still delegating to the original implementation."""
  original = getattr(o, name)
  return patch.object(o, name, wraps=original)
@pytest.mark.slow
@pytest.mark.very_slow
class TestCreatePix2pixModel(object):
......@@ -48,8 +54,6 @@ class TestCreatePix2pixModel(object):
with tf.Graph().as_default():
inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
get_logger().info('inputs: %s', inputs)
get_logger().info('targets: %s', targets)
a = create_args(DEFAULT_ARGS, gan_weight=0.0)
create_pix2pix_model(inputs, targets, a)
......@@ -57,28 +61,41 @@ class TestCreatePix2pixModel(object):
with tf.Graph().as_default():
inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
get_logger().info('inputs: %s', inputs)
get_logger().info('targets: %s', targets)
a = create_args(DEFAULT_ARGS, gan_weight=1.0)
create_pix2pix_model(inputs, targets, a)
def test_should_be_able_to_construct_graph_with_gan_and_sep_discrim_channels(self):
with tf.Graph().as_default():
inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
get_logger().info('inputs: %s', inputs)
get_logger().info('targets: %s', targets)
a = create_args(DEFAULT_ARGS, gan_weight=1.0, use_separate_discriminator_channels=True)
create_pix2pix_model(inputs, targets, a)
with patch_spy_object(pix2pix_core, 'l1_loss') as l1_loss:
with tf.Graph().as_default():
inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
a = create_args(DEFAULT_ARGS, gan_weight=1.0, use_separate_discriminator_channels=True)
create_pix2pix_model(inputs, targets, a)
assert l1_loss.called
def test_should_be_able_to_construct_graph_with_sep_discrim_channels_and_cross_entropy_loss(self):
with tf.Graph().as_default():
inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
get_logger().info('inputs: %s', inputs)
get_logger().info('targets: %s', targets)
a = create_args(
DEFAULT_ARGS,
gan_weight=1.0, use_separate_discriminator_channels=True, base_loss=BaseLoss.CROSS_ENTROPY
)
create_pix2pix_model(inputs, targets, a)
with patch_spy_object(pix2pix_core, 'cross_entropy_loss') as cross_entropy_loss:
with tf.Graph().as_default():
inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
a = create_args(
DEFAULT_ARGS,
gan_weight=1.0, use_separate_discriminator_channels=True, base_loss=BaseLoss.CROSS_ENTROPY
)
create_pix2pix_model(inputs, targets, a)
assert cross_entropy_loss.called
def test_should_be_able_to_construct_graph_with_weighted_cross_entropy_loss(self):
with patch_spy_object(pix2pix_core, 'weighted_cross_entropy_loss') \
as weighted_cross_entropy_loss:
with tf.Graph().as_default():
inputs = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
targets = tf.constant(np.zeros((BATCH_SIZE, HEIGHT, WIDTH, CHANNELS), dtype=np.float32))
a = create_args(
DEFAULT_ARGS,
gan_weight=1.0, use_separate_discriminator_channels=True,
base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY
)
create_pix2pix_model(inputs, targets, a, pos_weight=[1.0] * CHANNELS)
assert weighted_cross_entropy_loss.called
......@@ -4,13 +4,14 @@ from __future__ import print_function
import logging
import argparse
import json
from six.moves import reduce
import tensorflow as tf
from tensorflow.python.lib.io.file_io import FileIO
from tensorflow.python.lib.io.file_io import FileIO # pylint: disable=E0611
from sciencebeam_gym.trainer.util import (
read_examples
......@@ -26,6 +27,7 @@ from sciencebeam_gym.trainer.models.pix2pix.tf_utils import (
from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core import (
BaseLoss,
ALL_BASE_LOSS,
create_pix2pix_model,
create_other_summaries
)
......@@ -254,6 +256,14 @@ def add_model_summary_images(
)
tensors.summaries['output_image'] = tensors.image_tensors['output']
def parse_json_file(filename):
  """Read `filename` via TensorFlow's FileIO (supports local and remote
  paths) and return the parsed JSON content."""
  with FileIO(filename, 'r') as json_file:
    return json.loads(json_file.read())
def class_weights_to_pos_weight(class_weights, labels, use_unknown_class):
  """Map `class_weights` (label -> weight) onto the ordered `labels`.

  If `use_unknown_class` is true, a trailing 0.0 weight is appended for
  the implicit "unknown" channel. Raises KeyError if a label has no
  configured weight.
  """
  selected_weights = [class_weights[label] for label in labels]
  if use_unknown_class:
    return selected_weights + [0.0]
  return selected_weights
def parse_color_map(color_map_filename):
with FileIO(color_map_filename, 'r') as config_f:
return parse_color_map_from_file(
......@@ -300,6 +310,7 @@ class Model(object):
self.image_width = 256
self.image_height = 256
self.color_map = None
self.pos_weight = None
self.dimension_colors = None
self.dimension_labels = None
self.use_unknown_class = args.use_unknown_class
......@@ -320,6 +331,13 @@ class Model(object):
)
logger.debug("dimension_colors: %s", self.dimension_colors)
logger.debug("dimension_labels: %s", self.dimension_labels)
if self.args.class_weights:
self.pos_weight = class_weights_to_pos_weight(
parse_json_file(self.args.class_weights),
self.dimension_labels,
self.use_unknown_class
)
logger.info("pos_weight: %s", self.pos_weight)
def build_graph(self, data_paths, batch_size, graph_mode):
logger = get_logger()
......@@ -412,7 +430,8 @@ class Model(object):
pix2pix_model = create_pix2pix_model(
tensors.image_tensor,
tensors.separate_channel_annotation_tensor,
self.args
self.args,
pos_weight=self.pos_weight
)
if self.use_separate_channels:
......@@ -529,6 +548,11 @@ def model_args_parser():
type=str,
help='The path to the color map configuration.'
)
parser.add_argument(
'--class_weights',
type=str,
help='The path to the class weights configuration.'
)
parser.add_argument(
'--channels',
type=str_to_list,
......@@ -562,7 +586,7 @@ def model_args_parser():
'--base_loss',
type=str,
default=BaseLoss.L1,
choices=[BaseLoss.L1, BaseLoss.CROSS_ENTROPY],
choices=ALL_BASE_LOSS,
help='The base loss function to use'
)
return parser
......
......@@ -4,6 +4,21 @@ from mock import patch
import pytest
from pytest import raises
import tensorflow as tf
from sciencebeam_gym.utils.collection import (
extend_dict
)
from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core import (
BaseLoss,
ALL_BASE_LOSS
)
from sciencebeam_gym.trainer.models.pix2pix.pix2pix_core_test import (
DEFAULT_ARGS as CORE_DEFAULT_ARGS
)
import sciencebeam_gym.trainer.models.pix2pix.pix2pix_model as pix2pix_model
from sciencebeam_gym.trainer.models.pix2pix.pix2pix_model import (
......@@ -16,16 +31,26 @@ from sciencebeam_gym.trainer.models.pix2pix.pix2pix_model import (
UNKNOWN_LABEL,
Model,
str_to_list,
model_args_parser
model_args_parser,
class_weights_to_pos_weight
)
COLOR_MAP_FILENAME = 'color_map.conf'
CLASS_WEIGHTS_FILENAME = 'class-weights.json'
DATA_PATH = 'some/where/*.tfrecord'
BATCH_SIZE = 10
def some_color(i):
  """Return a grayscale RGB tuple derived from index `i` (test fixture)."""
  return (i,) * 3
SOME_COLORS = [some_color(1), some_color(2), some_color(3)]
SOME_LABELS = ['a', 'b', 'c']
SOME_COLOR_MAP = {
k: v for k, v in zip(SOME_LABELS, SOME_COLORS)
}
SOME_CLASS_WEIGHTS = {
k: float(i) for i, k in enumerate(SOME_LABELS)
}
class TestParseColorMap(object):
def test_should_use_fileio_to_load_file_and_pass_to_parser(self):
......@@ -121,21 +146,76 @@ class TestColorsAndLabelsWithUnknownClass(object):
assert colors_with_unknown == [UNKNOWN_COLOR]
assert labels_with_unknown == [UNKNOWN_LABEL]
def create_args(**kwargs):
return namedtuple('args', kwargs.keys())(**kwargs)
class TestClassWeightsToPosWeight(object):
  def test_should_extract_selected_weights(self):
    class_weights = {'a': 0.1, 'b': 0.2, 'c': 0.3}
    result = class_weights_to_pos_weight(class_weights, ['a', 'b'], False)
    assert result == [0.1, 0.2]

  def test_should_add_zero_if_unknown_class_is_true(self):
    class_weights = {'a': 0.1, 'b': 0.2, 'c': 0.3}
    result = class_weights_to_pos_weight(class_weights, ['a', 'b'], True)
    # unknown class always receives a neutral 0.0 weight at the end
    assert result == [0.1, 0.2, 0.0]
DEFAULT_ARGS = extend_dict(
CORE_DEFAULT_ARGS,
dict(
color_map=None,
class_weights=None,
channels=None,
use_separate_channels=False,
use_unknown_class=False
)
)
def create_args(*args, **kwargs):
  """Build a namedtuple of model args by merging the given dicts
  (positional) with keyword overrides; later values win."""
  merged = extend_dict(*(list(args) + [kwargs]))
  return namedtuple('args', merged.keys())(**merged)
@pytest.mark.slow
class TestModel(object):
def test_parse_separate_channels_with_color_map(self):
with patch.object(pix2pix_model, 'FileIO'):
with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
def test_parse_separate_channels_with_color_map_without_class_weights(self):
with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
parse_color_map_from_file.return_value = {
'a': some_color(1),
'b': some_color(2),
'c': some_color(3)
}
args = create_args(
DEFAULT_ARGS,
color_map=COLOR_MAP_FILENAME,
class_weights=None,
channels=['a', 'b'],
use_separate_channels=True,
use_unknown_class=True
)
model = Model(args)
assert model.dimension_colors == [some_color(1), some_color(2)]
assert model.dimension_labels == ['a', 'b']
assert model.dimension_colors_with_unknown == [some_color(1), some_color(2), UNKNOWN_COLOR]
assert model.dimension_labels_with_unknown == ['a', 'b', UNKNOWN_LABEL]
assert model.pos_weight is None
def test_parse_separate_channels_with_color_map_and_class_weights(self):
with patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_from_file:
with patch.object(pix2pix_model, 'parse_json_file') as parse_json_file:
parse_color_map_from_file.return_value = {
'a': some_color(1),
'b': some_color(2),
'c': some_color(3)
}
parse_json_file.return_value = {
'a': 0.1,
'b': 0.2,
'c': 0.3
}
args = create_args(
DEFAULT_ARGS,
color_map=COLOR_MAP_FILENAME,
class_weights=CLASS_WEIGHTS_FILENAME,
channels=['a', 'b'],
use_separate_channels=True,
use_unknown_class=True
......@@ -143,8 +223,52 @@ class TestModel(object):
model = Model(args)
assert model.dimension_colors == [some_color(1), some_color(2)]
assert model.dimension_labels == ['a', 'b']
assert model.dimension_colors_with_unknown == [some_color(1), some_color(2), UNKNOWN_COLOR]
assert (
model.dimension_colors_with_unknown == [some_color(1), some_color(2), UNKNOWN_COLOR]
)
assert model.dimension_labels_with_unknown == ['a', 'b', UNKNOWN_LABEL]
assert model.pos_weight == [0.1, 0.2, 0.0]
@pytest.mark.slow
@pytest.mark.very_slow
class TestModelBuildGraph(object):
  # Graph-construction smoke tests: patch out file access and example
  # reading, then verify the training graph builds without raising.

  def test_should_build_train_graph_with_defaults(self):
    with tf.Graph().as_default(), \
        patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_mock, \
        patch.object(pix2pix_model, 'read_examples') as read_examples_mock:
      parse_color_map_mock.return_value = SOME_COLOR_MAP
      read_examples_mock.return_value = (
        tf.constant(1),
        b'dummy'
      )
      model = Model(create_args(DEFAULT_ARGS))
      model.build_train_graph(DATA_PATH, BATCH_SIZE)

  def test_should_build_train_graph_with_class_weights(self):
    with tf.Graph().as_default(), \
        patch.object(pix2pix_model, 'parse_color_map_from_file') as parse_color_map_mock, \
        patch.object(pix2pix_model, 'parse_json_file') as parse_json_file_mock, \
        patch.object(pix2pix_model, 'read_examples') as read_examples_mock:
      parse_color_map_mock.return_value = SOME_COLOR_MAP
      parse_json_file_mock.return_value = SOME_CLASS_WEIGHTS
      read_examples_mock.return_value = (
        tf.constant(1),
        b'dummy'
      )
      args = create_args(
        DEFAULT_ARGS,
        base_loss=BaseLoss.WEIGHTED_CROSS_ENTROPY,
        color_map=COLOR_MAP_FILENAME,
        class_weights=CLASS_WEIGHTS_FILENAME,
        channels=['a', 'b'],
        use_separate_channels=True,
        use_unknown_class=True
      )
      model = Model(args)
      model.build_train_graph(DATA_PATH, BATCH_SIZE)
class TestStrToList(object):
def test_should_parse_empty_string_as_empty_list(self):
......@@ -159,7 +283,7 @@ class TestStrToList(object):
def test_should_ignore_white_space_around_values(self):
assert str_to_list(' a , b , c ') == ['a', 'b', 'c']
class Test(object):
class TestModelArgsParser(object):
def test_should_parse_channels(self):
args = model_args_parser().parse_args(['--channels', 'a,b,c'])
assert args.channels == ['a', 'b', 'c']
......@@ -167,3 +291,8 @@ class Test(object):
def test_should_set_channels_to_none_by_default(self):
args = model_args_parser().parse_args([])
assert args.channels is None
def test_should_allow_all_base_loss_options(self):
for base_loss in ALL_BASE_LOSS:
args = model_args_parser().parse_args(['--base_loss', base_loss])
assert args.base_loss == base_loss
......@@ -11,6 +11,7 @@ echo "TRAIN_PREPROC_TRAIN_PATH: $TRAIN_PREPROC_PATH"
echo "EVAL_PREPROC_EVAL_PATH: $EVAL_PREPROC_PATH"
echo "QUALITATIVE_PREPROC_EVAL_PATH: $QUALITATIVE_PREPROC_PATH"
echo "TRAIN_MODEL_PATH: $TRAIN_MODEL_PATH"
echo "CLASS_WEIGHTS_FILENAME: ${CLASS_WEIGHTS_FILENAME}"
COMMON_ARGS=(
--output_path "${TRAIN_MODEL_PATH}/"
......@@ -19,6 +20,7 @@ COMMON_ARGS=(
--model "${MODEL_NAME}"
--color_map "${CONFIG_PATH}/${COLOR_MAP_FILENAME}"
--channels="$CHANNEL_NAMES"
--class_weights="${CLASS_WEIGHTS_FILENAME}"
--use_separate_channels $USE_SEPARATE_CHANNELS
--batch_size $BATCH_SIZE
--eval_set_size $EVAL_SET_SIZE
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment