diff options
author | 2016-07-28 11:15:26 -0800 | |
---|---|---|
committer | 2016-07-28 12:17:38 -0700 | |
commit | 49a99586cdd04d6b8b1488b032b8fc4baac26019 (patch) | |
tree | cf79b177a6e1634ef6f3c16f9f3aeedf68ebf09c | |
parent | 34c820dcc6a6c2cf2df448b17b2a801fbcb16d6b (diff) |
Fail fast on bad dtype.
Change: 128727035
-rw-r--r-- | tensorflow/contrib/layers/python/layers/initializers.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/layers/python/layers/initializers_test.py | 5 |
2 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/contrib/layers/python/layers/initializers.py b/tensorflow/contrib/layers/python/layers/initializers.py index 03665a4951..1786b71dcf 100644 --- a/tensorflow/contrib/layers/python/layers/initializers.py +++ b/tensorflow/contrib/layers/python/layers/initializers.py @@ -102,12 +102,13 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False, TypeError: if `mode` is not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG']. """ if not dtype.is_floating: - raise TypeError('Cannot create initializer for non-floating point ' - 'type.') + raise TypeError('Cannot create initializer for non-floating point type.') if mode not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG']: raise TypeError('Unknow mode %s [FAN_IN, FAN_OUT, FAN_AVG]', mode) def _initializer(shape, dtype=dtype): """Initializer function.""" + if not dtype.is_floating: + raise TypeError('Cannot create initializer for non-floating point type.') # Estimating fan_in and fan_out is not possible to do perfectly, but we try. # This is the right thing for matrix multiply and convolutions. fan_in = float(shape[-2]) diff --git a/tensorflow/contrib/layers/python/layers/initializers_test.py b/tensorflow/contrib/layers/python/layers/initializers_test.py index bacf16c1ad..d619dd8ee0 100644 --- a/tensorflow/contrib/layers/python/layers/initializers_test.py +++ b/tensorflow/contrib/layers/python/layers/initializers_test.py @@ -64,6 +64,11 @@ class VarianceScalingInitializerTest(tf.test.TestCase): TypeError, 'Cannot create initializer for non-floating point type.'): tf.contrib.layers.variance_scaling_initializer(dtype=tf.int32) + initializer = tf.contrib.layers.variance_scaling_initializer() + with self.assertRaisesRegexp( + TypeError, + 'Cannot create initializer for non-floating point type.'): + initializer([], dtype=tf.int32) def _test_variance(self, initializer, shape, variance, factor, mode, uniform): with tf.Graph().as_default() as g: |