Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
S
Sciencebeam Gym
Manage
Activity
Members
Labels
Plan
Issues
0
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
0
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package Registry
Container Registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Sciencebeam
Sciencebeam Gym
Commits
7bf05aff
Commit
7bf05aff
authored
7 years ago
by
Daniel Ecer
Browse files
Options
Downloads
Patches
Plain Diff
optionally include unknown in class weights
parent
4400522a
No related branches found
Branches containing commit
No related tags found
Tags containing commit
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
sciencebeam_gym/tools/calculate_class_weights.py
+56
-19
56 additions, 19 deletions
sciencebeam_gym/tools/calculate_class_weights.py
sciencebeam_gym/tools/calculate_class_weights_test.py
+46
-0
46 additions, 0 deletions
sciencebeam_gym/tools/calculate_class_weights_test.py
with
102 additions
and
19 deletions
sciencebeam_gym/tools/calculate_class_weights.py
+
56
−
19
View file @
7bf05aff
...
...
@@ -21,31 +21,49 @@ from sciencebeam_gym.utils.tfrecord import (
def
get_logger
():
return
logging
.
getLogger
(
__name__
)
def
color_frequency
(
image
,
color
):
return
tf
.
reduce_sum
(
tf
.
cast
(
tf
.
reduce_all
(
tf
.
equal
(
image
,
color
),
axis
=-
1
,
name
=
'
is_color
'
),
tf
.
float32
)
def
color_equals_mask
(
image
,
color
):
return
tf
.
reduce_all
(
tf
.
equal
(
image
,
color
),
axis
=-
1
,
name
=
'
is_color
'
)
def
color_equals_mask_as_float
(
image
,
color
):
return
tf
.
cast
(
color_equals_mask
(
image
,
color
),
tf
.
float32
)
def
color_frequency
(
image
,
color
):
return
tf
.
reduce_sum
(
color_equals_mask_as_float
(
image
,
color
))
def
get_shape
(
x
):
try
:
return
x
.
shape
except
AttributeError
:
return
tf
.
constant
(
x
).
shape
def
calculate_sample_frequencies
(
image
,
colors
):
return
[
color_
fr
equ
ency
(
image
,
color
)
def
calculate_sample_frequencies
(
image
,
colors
,
use_unknown_class
=
False
):
color_masks
=
[
color_equ
als_mask_as_float
(
image
,
color
)
for
color
in
colors
]
if
use_unknown_class
:
shape
=
tf
.
shape
(
color_masks
[
0
])
ones
=
tf
.
fill
(
shape
,
1.0
,
name
=
'
ones
'
)
zeros
=
tf
.
fill
(
shape
,
0.0
,
name
=
'
zeros
'
)
color_masks
.
append
(
tf
.
where
(
tf
.
add_n
(
color_masks
)
<
0.5
,
ones
,
zeros
)
)
return
[
tf
.
reduce_sum
(
color_mask
)
for
color_mask
in
color_masks
]
def
iter_calculate_sample_frequencies
(
images
,
colors
,
image_shape
=
None
,
image_format
=
None
,
use_unknown_class
=
False
):
def
iter_calculate_sample_frequencies
(
images
,
colors
,
image_shape
=
None
,
image_format
=
None
):
with
tf
.
Graph
().
as_default
():
if
image_format
==
'
png
'
:
image_tensor
=
tf
.
placeholder
(
tf
.
string
,
shape
=
[],
name
=
'
image
'
)
...
...
@@ -56,7 +74,9 @@ def iter_calculate_sample_frequencies(images, colors, image_shape=None, image_fo
image_tensor
=
tf
.
placeholder
(
tf
.
uint8
,
shape
=
image_shape
,
name
=
'
image
'
)
decoded_image_tensor
=
image_tensor
get_logger
().
debug
(
'
decoded_image_tensor: %s
'
,
decoded_image_tensor
)
frequency_tensors
=
calculate_sample_frequencies
(
decoded_image_tensor
,
colors
)
frequency_tensors
=
calculate_sample_frequencies
(
decoded_image_tensor
,
colors
,
use_unknown_class
=
use_unknown_class
)
with
tf
.
Session
()
as
session
:
for
image
in
images
:
frequencies
=
session
.
run
(
frequency_tensors
,
{
...
...
@@ -117,7 +137,7 @@ def iter_images_for_tfrecord_paths(tfrecord_paths, image_key, progress=False):
yield
d
[
image_key
]
def
calculate_median_class_weights_for_tfrecord_paths_and_colors
(
tfrecord_paths
,
image_key
,
colors
,
progress
=
False
):
tfrecord_paths
,
image_key
,
colors
,
use_unknown_class
=
False
,
progress
=
False
):
get_logger
().
debug
(
'
colors: %s
'
,
colors
)
get_logger
().
info
(
'
loading tfrecords: %s
'
,
tfrecord_paths
)
...
...
@@ -125,7 +145,9 @@ def calculate_median_class_weights_for_tfrecord_paths_and_colors(
if
progress
:
images
=
list
(
images
)
images
=
tqdm
(
images
,
'
analysing images
'
,
leave
=
False
)
frequency_list
=
list
(
iter_calculate_sample_frequencies
(
images
,
colors
,
image_format
=
'
png
'
))
frequency_list
=
list
(
iter_calculate_sample_frequencies
(
images
,
colors
,
image_format
=
'
png
'
,
use_unknown_class
=
use_unknown_class
))
get_logger
().
debug
(
'
frequency_list: %s
'
,
frequency_list
)
frequencies
=
transpose
(
frequency_list
)
get_logger
().
debug
(
'
frequencies: %s
'
,
frequencies
)
...
...
@@ -133,7 +155,9 @@ def calculate_median_class_weights_for_tfrecord_paths_and_colors(
return
class_weights
def
calculate_median_class_weights_for_tfrecord_paths_and_color_map
(
tfrecord_paths
,
image_key
,
color_map
,
channels
=
None
,
progress
=
False
):
tfrecord_paths
,
image_key
,
color_map
,
channels
=
None
,
use_unknown_class
=
False
,
unknown_class_label
=
'
unknown
'
,
progress
=
False
):
if
not
channels
:
channels
=
sorted
(
color_map
.
keys
())
colors
=
[
color_map
[
k
]
for
k
in
channels
]
...
...
@@ -141,12 +165,18 @@ def calculate_median_class_weights_for_tfrecord_paths_and_color_map(
tfrecord_paths
,
image_key
,
colors
,
progress
=
progress
progress
=
progress
,
use_unknown_class
=
use_unknown_class
)
if
use_unknown_class
:
channels
+=
[
unknown_class_label
]
return
{
k
:
class_weight
for
k
,
class_weight
in
zip
(
channels
,
class_weights
)
}
def
str_to_bool
(
s
):
return
s
.
lower
()
in
(
'
yes
'
,
'
true
'
,
'
1
'
)
def
str_to_list
(
s
):
s
=
s
.
strip
()
if
not
s
:
...
...
@@ -179,6 +209,12 @@ def get_args_parser():
type
=
str_to_list
,
help
=
'
The channels to use (subset of color map), otherwise all of the labels will be used
'
)
parser
.
add_argument
(
'
--use-unknown-class
'
,
type
=
str_to_bool
,
default
=
True
,
help
=
'
Use unknown class channel
'
)
parser
.
add_argument
(
'
--out
'
,
required
=
False
,
...
...
@@ -200,6 +236,7 @@ def main(argv=None):
args
.
image_key
,
color_map
,
channels
=
args
.
channels
,
use_unknown_class
=
args
.
use_unknown_class
,
progress
=
True
)
get_logger
().
info
(
'
class_weights: %s
'
,
class_weights_map
)
...
...
This diff is collapsed.
Click to expand it.
sciencebeam_gym/tools/calculate_class_weights_test.py
+
46
−
0
View file @
7bf05aff
...
...
@@ -63,6 +63,12 @@ class TestCalculateSampleFrequencies(object):
COLOR_1
,
COLOR_1
,
COLOR_2
]],
[
COLOR_1
,
COLOR_2
]))
==
[
2.0
,
1.0
]
def
test_should_include_unknown_class_count_if_enabled
(
self
):
with
tf
.
Session
()
as
session
:
assert
session
.
run
(
calculate_sample_frequencies
([[
COLOR_1
,
COLOR_2
,
COLOR_3
]],
[
COLOR_1
],
use_unknown_class
=
True
))
==
[
1.0
,
2.0
]
def
encode_png
(
data
):
out
=
BytesIO
()
data
=
np
.
array
(
data
,
dtype
=
np
.
uint8
)
...
...
@@ -90,6 +96,20 @@ class TestIterCalculateSampleFrequencies(object):
]]
],
[
COLOR_1
]))
==
[[
0.0
]]
def
test_should_include_unknown_class_if_enabled
(
self
):
assert
list
(
iter_calculate_sample_frequencies
([
[[
COLOR_0
]]
],
[
COLOR_1
],
image_shape
=
(
1
,
1
,
3
),
use_unknown_class
=
True
))
==
[[
0.0
,
1.0
]]
def
test_should_include_unknown_class_if_enabled_and_infer_shape
(
self
):
assert
list
(
iter_calculate_sample_frequencies
([
[[
COLOR_0
]]
],
[
COLOR_1
],
use_unknown_class
=
True
))
==
[[
0.0
,
1.0
]]
def
test_should_return_total_count_for_multiple_mixed_color
(
self
):
assert
list
(
iter_calculate_sample_frequencies
([
[[
...
...
@@ -119,6 +139,13 @@ class TestIterCalculateSampleFrequencies(object):
]])
],
[
COLOR_1
],
image_format
=
'
png
'
))
==
[[
1.0
]]
def
_test_should_infer_shape_when_decoding_png_and_include_unknown_class
(
self
):
assert
list
(
iter_calculate_sample_frequencies
([
encode_png
([[
COLOR_1
,
COLOR_2
,
COLOR_3
]])
],
[
COLOR_1
],
image_format
=
'
png
'
,
use_unknown_class
=
True
))
==
[[
1.0
,
2.0
]]
class
TestCalculateMedianClassWeight
(
object
):
def
test_should_return_median_frequency_balanced_for_same_frequencies
(
self
):
assert
calculate_median_class_weight
([
3
,
3
,
3
])
==
1
/
3
...
...
@@ -249,3 +276,22 @@ class TestCalculateMedianClassWeightsForFfrecordPathsAndColorMap(object):
}
)
assert
set
(
class_weights_map
.
keys
())
==
{
'
color1
'
,
'
color2
'
}
def
test_should_include_unknown_class_if_enabled
(
self
):
with
TemporaryDirectory
()
as
path
:
tfrecord_filename
=
os
.
path
.
join
(
path
,
'
data.tfrecord
'
)
get_logger
().
debug
(
'
writing to test tfrecord_filename: %s
'
,
tfrecord_filename
)
write_examples_to_tfrecord
(
tfrecord_filename
,
[
dict_to_example
({
'
image
'
:
encode_png
([[
COLOR_0
,
COLOR_1
,
COLOR_2
,
COLOR_3
]])
})])
class_weights_map
=
calculate_median_class_weights_for_tfrecord_paths_and_color_map
(
[
tfrecord_filename
],
'
image
'
,
{
'
color1
'
:
COLOR_1
,
'
color2
'
:
COLOR_2
},
use_unknown_class
=
True
,
unknown_class_label
=
'
unknown
'
)
assert
set
(
class_weights_map
.
keys
())
==
{
'
color1
'
,
'
color2
'
,
'
unknown
'
}
This diff is collapsed.
Click to expand it.
Preview
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment