diff options
author | 2017-12-04 11:38:59 -0800 | |
---|---|---|
committer | 2017-12-04 11:42:24 -0800 | |
commit | 3ff72d4fd07783e28752e5ec0bc70966dffe2e3c (patch) | |
tree | ee97b57878941ac6fd85a80b16fe946d7265926f | |
parent | ff71c2792746262bbce936d78c3543cdea0c4b70 (diff) |
Sanitize dtypes in filenames in normalization_test.
PiperOrigin-RevId: 177843964
-rw-r--r-- | tensorflow/python/layers/normalization_test.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py index 7c91c3284e..e147f348b0 100644 --- a/tensorflow/python/layers/normalization_test.py +++ b/tensorflow/python/layers/normalization_test.py @@ -105,9 +105,17 @@ class BNTest(test.TestCase): infer_use_gpu): batch, height, width, input_channels = 2, 4, 5, 3 shape = [batch, height, width, input_channels] - checkpoint = os.path.join(self.get_temp_dir(), 'cp_%s_%s_%s_%s' % - (dtype, train1_use_gpu, train2_use_gpu, - infer_use_gpu)) + + # Not all characters in a dtype string representation are allowed in + # filenames in all operating systems. This map will sanitize these. + dtype_to_valid_fn = { + dtypes.float16: 'float16', + dtypes.float32: 'float32', + } + checkpoint = os.path.join( + self.get_temp_dir(), 'cp_%s_%s_%s_%s' % ( + dtype_to_valid_fn[dtype], train1_use_gpu, train2_use_gpu, + infer_use_gpu)) self._train( checkpoint, |