Skip to content
Snippets Groups Projects
Commit d424348e authored by Daniel Ecer's avatar Daniel Ecer
Browse files

optionally print debug information (pos_weight)

parent 9ac3e055
No related branches found
No related tags found
No related merge requests found
......@@ -407,6 +407,12 @@ class Model(object):
tensors.pos_weight = tf_calculate_efnet_weights_for_frequency_by_label(
frequency_by_label
)
if self.args.debug:
tensors.pos_weight = tf.Print(
tensors.pos_weight, [tensors.pos_weight, frequency_by_label, tensors.input_uri],
'pos weights, frequency, uri: ',
summarize=1000
)
get_logger().debug(
'pos_weight before batch: %s (frequency_by_label: %s)',
tensors.pos_weight, frequency_by_label
......@@ -608,6 +614,12 @@ def model_args_parser():
choices=ALL_BASE_LOSS,
help='The base loss function to use'
)
parser.add_argument(
'--debug',
type=str_to_bool,
default=True,
help='Enable debug mode'
)
return parser
......
......@@ -149,7 +149,8 @@ DEFAULT_ARGS = extend_dict(
class_weights=None,
channels=None,
use_separate_channels=False,
use_unknown_class=False
use_unknown_class=False,
debug=False
)
)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment