about · summary · refs · log · tree · commit · diff · homepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-28 11:15:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-28 12:17:38 -0700
commit49a99586cdd04d6b8b1488b032b8fc4baac26019 (patch)
treecf79b177a6e1634ef6f3c16f9f3aeedf68ebf09c
parent34c820dcc6a6c2cf2df448b17b2a801fbcb16d6b (diff)
Fail fast on bad dtype.
Change: 128727035
-rw-r--r--tensorflow/contrib/layers/python/layers/initializers.py5
-rw-r--r--tensorflow/contrib/layers/python/layers/initializers_test.py5
2 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/contrib/layers/python/layers/initializers.py b/tensorflow/contrib/layers/python/layers/initializers.py
index 03665a4951..1786b71dcf 100644
--- a/tensorflow/contrib/layers/python/layers/initializers.py
+++ b/tensorflow/contrib/layers/python/layers/initializers.py
@@ -102,12 +102,13 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False,
TypeError: if `mode` is not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG'].
"""
if not dtype.is_floating:
- raise TypeError('Cannot create initializer for non-floating point '
- 'type.')
+ raise TypeError('Cannot create initializer for non-floating point type.')
if mode not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG']:
raise TypeError('Unknow mode %s [FAN_IN, FAN_OUT, FAN_AVG]', mode)
def _initializer(shape, dtype=dtype):
"""Initializer function."""
+ if not dtype.is_floating:
+ raise TypeError('Cannot create initializer for non-floating point type.')
# Estimating fan_in and fan_out is not possible to do perfectly, but we try.
# This is the right thing for matrix multiply and convolutions.
fan_in = float(shape[-2])
diff --git a/tensorflow/contrib/layers/python/layers/initializers_test.py b/tensorflow/contrib/layers/python/layers/initializers_test.py
index bacf16c1ad..d619dd8ee0 100644
--- a/tensorflow/contrib/layers/python/layers/initializers_test.py
+++ b/tensorflow/contrib/layers/python/layers/initializers_test.py
@@ -64,6 +64,11 @@ class VarianceScalingInitializerTest(tf.test.TestCase):
TypeError,
'Cannot create initializer for non-floating point type.'):
tf.contrib.layers.variance_scaling_initializer(dtype=tf.int32)
+ initializer = tf.contrib.layers.variance_scaling_initializer()
+ with self.assertRaisesRegexp(
+ TypeError,
+ 'Cannot create initializer for non-floating point type.'):
+ initializer([], dtype=tf.int32)
def _test_variance(self, initializer, shape, variance, factor, mode, uniform):
with tf.Graph().as_default() as g: