aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gunhan Gulsoy <gunan@google.com>2017-12-04 11:38:59 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-04 11:42:24 -0800
commit3ff72d4fd07783e28752e5ec0bc70966dffe2e3c (patch)
treeee97b57878941ac6fd85a80b16fe946d7265926f
parentff71c2792746262bbce936d78c3543cdea0c4b70 (diff)
Sanitize dtypes in filenames in normalization_test.
PiperOrigin-RevId: 177843964
-rw-r--r--tensorflow/python/layers/normalization_test.py14
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
index 7c91c3284e..e147f348b0 100644
--- a/tensorflow/python/layers/normalization_test.py
+++ b/tensorflow/python/layers/normalization_test.py
@@ -105,9 +105,17 @@ class BNTest(test.TestCase):
infer_use_gpu):
batch, height, width, input_channels = 2, 4, 5, 3
shape = [batch, height, width, input_channels]
- checkpoint = os.path.join(self.get_temp_dir(), 'cp_%s_%s_%s_%s' %
- (dtype, train1_use_gpu, train2_use_gpu,
- infer_use_gpu))
+
+ # Not all characters in a dtype string representation are allowed in
+ # filenames in all operating systems. This map will sanitize these.
+ dtype_to_valid_fn = {
+ dtypes.float16: 'float16',
+ dtypes.float32: 'float32',
+ }
+ checkpoint = os.path.join(
+ self.get_temp_dir(), 'cp_%s_%s_%s_%s' % (
+ dtype_to_valid_fn[dtype], train1_use_gpu, train2_use_gpu,
+ infer_use_gpu))
self._train(
checkpoint,