aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-12 10:47:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-12 10:50:24 -0700
commit151c31ce75f4370fd3749f3b07ac8297d3b2e203 (patch)
treee5d5fab632100e0dbdc7ddef8efa85db1805daeb
parent844b8cae970d835850a75f8063324224b2de0df0 (diff)
Make default weights initializer in `base_layers.Layer` suitable for their dtype.
PiperOrigin-RevId: 192634133
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/base_layer.py20
-rw-r--r--tensorflow/python/layers/base_test.py6
2 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
index 3b3af7d092..6c68d25127 100644
--- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
@@ -473,16 +473,30 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called with partitioned variable regularization and
eager execution is enabled.
+ ValueError: When giving unsupported dtype and no initializer.
"""
if dtype is None:
dtype = self.dtype or backend.floatx()
+ else:
+ dtype = dtypes.as_dtype(dtype)
initializer = initializers.get(initializer)
- if initializer is None:
- # Default TensorFlow initializer.
- initializer = initializers.glorot_uniform()
regularizer = regularizers.get(regularizer)
constraint = constraints.get(constraint)
+ # Initialize variable when no initializer provided
+ if initializer is None:
+      # If dtype is DT_FLOAT, provide a Glorot uniform initializer
+ if dtype.is_floating:
+ initializer = initializers.glorot_uniform()
+ # If dtype is DT_INT/DT_UINT, provide a default value `zero`
+ # If dtype is DT_BOOL, provide a default value `FALSE`
+ elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
+ initializer = initializers.zeros()
+      # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
+ else:
+ raise ValueError('An initializer for variable %s of type %s is required'
+ ' for layer %s' % (name, dtype.base_dtype, self.name))
+
variable = self._add_variable_with_custom_getter(
name=name,
shape=shape,
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index c05c675263..f08b552840 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -53,6 +53,12 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(layer.trainable, False)
@test_util.run_in_graph_and_eager_modes()
+ def testInt64Layer(self):
+ layer = base_layers.Layer(name='my_layer', dtype='int64')
+ layer.add_variable('my_var', [2, 2])
+ self.assertEqual(layer.name, 'my_layer')
+
+ @test_util.run_in_graph_and_eager_modes()
def testAddWeight(self):
layer = base_layers.Layer(name='my_layer')