diff options
author | 2018-04-12 10:47:26 -0700 | |
---|---|---|
committer | 2018-04-12 10:50:24 -0700 | |
commit | 151c31ce75f4370fd3749f3b07ac8297d3b2e203 (patch) | |
tree | e5d5fab632100e0dbdc7ddef8efa85db1805daeb | |
parent | 844b8cae970d835850a75f8063324224b2de0df0 (diff) |
Make default weights initializer in `base_layers.Layer` suitable for their dtype.
PiperOrigin-RevId: 192634133
-rw-r--r-- | tensorflow/python/keras/_impl/keras/engine/base_layer.py | 20 | ||||
-rw-r--r-- | tensorflow/python/layers/base_test.py | 6 |
2 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py index 3b3af7d092..6c68d25127 100644 --- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py +++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py @@ -473,16 +473,30 @@ class Layer(checkpointable.CheckpointableBase): Raises: RuntimeError: If called with partitioned variable regularization and eager execution is enabled. + ValueError: When giving unsupported dtype and no initializer. """ if dtype is None: dtype = self.dtype or backend.floatx() + else: + dtype = dtypes.as_dtype(dtype) initializer = initializers.get(initializer) - if initializer is None: - # Default TensorFlow initializer. - initializer = initializers.glorot_uniform() regularizer = regularizers.get(regularizer) constraint = constraints.get(constraint) + # Initialize variable when no initializer provided + if initializer is None: + # If dtype is DT_FLOAT, provide a uniform unit scaling initializer + if dtype.is_floating: + initializer = initializers.glorot_uniform() + # If dtype is DT_INT/DT_UINT, provide a default value `zero` + # If dtype is DT_BOOL, provide a default value `FALSE` + elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool: + initializer = initializers.zeros() + # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
+ else: + raise ValueError('An initializer for variable %s of type %s is required' + ' for layer %s' % (name, dtype.base_dtype, self.name)) + variable = self._add_variable_with_custom_getter( name=name, shape=shape, diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index c05c675263..f08b552840 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -53,6 +53,12 @@ class BaseLayerTest(test.TestCase): self.assertEqual(layer.trainable, False) @test_util.run_in_graph_and_eager_modes() + def testInt64Layer(self): + layer = base_layers.Layer(name='my_layer', dtype='int64') + layer.add_variable('my_var', [2, 2]) + self.assertEqual(layer.name, 'my_layer') + + @test_util.run_in_graph_and_eager_modes() def testAddWeight(self): layer = base_layers.Layer(name='my_layer') |