diff options
author | 2018-06-15 15:27:11 -0700 | |
---|---|---|
committer | 2018-06-15 15:32:43 -0700 | |
commit | b8861afe21d8d654c2a726cabd82069faca04532 (patch) | |
tree | 6d967f76284235e5ae43e19d23320f3098ebfecb /tensorflow/python/layers | |
parent | 5e9a39d6ad6eee207a7af88bb1bbe1deefb8bbb2 (diff) |
Automatically cast layer inputs to the layer's dtype.
This makes it more convenient to use layers of different dtypes in a model. Instead of having to manually cast intermediate tensors between layers of different dtypes, they will automatically be cast.
This is also useful for the upcoming mixed precision API.
PiperOrigin-RevId: 200783477
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r-- | tensorflow/python/layers/base.py | 12 | ||||
-rw-r--r-- | tensorflow/python/layers/base_test.py | 59 |
2 files changed, 66 insertions, 5 deletions
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index eda036ece4..abbe9d0c56 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -43,13 +43,15 @@ class Layer(base_layer.Layer): Arguments: trainable: Boolean, whether the layer's variables should be trainable. name: String name of the layer. - dtype: Default dtype of the layer's weights (default of `None` means use the - type of the first input). + dtype: Default dtype of the layer's weights and computations (default of + `None` means use the type of the first input). If not None, inputs will be + casted to this dtype. Read-only properties: name: The name of the layer (string). - dtype: Default dtype of the layer's weights (default of `None` means use the - type of the first input). + dtype: Default dtype of the layer's weights and computations. (default of + `None` means use the type of the first input). If not None, inputs will be + casted to this dtype. trainable_variables: List of trainable variables. non_trainable_variables: List of non-trainable variables. variables: List of all variables of this layer, trainable and @@ -191,7 +193,7 @@ class Layer(base_layer.Layer): RuntimeError: If called with partioned variable regularization and eager execution is enabled. 
""" - + def _should_add_regularizer(variable, existing_variable_set): if isinstance(variable, tf_variables.PartitionedVariable): for var in variable: diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index ab49e37b90..15448c6be8 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -25,6 +25,8 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.keras import backend +from tensorflow.python.keras.engine import base_layer as keras_base_layer from tensorflow.python.layers import base as base_layers from tensorflow.python.layers import core as core_layers from tensorflow.python.ops import array_ops @@ -589,6 +591,63 @@ class BaseLayerTest(test.TestCase): ValueError, 'Input graph and Layer graph are not the same'): layer.apply(constant_op.constant([[1.]])) + @test_util.run_in_graph_and_eager_modes() + def testOnlyCastInputsWhenDtypeSpecified(self): + class MyLayerBase(keras_base_layer.Layer): + + def call(self, inputs): + self.x = inputs[0] + self.y = inputs[1] + return self.x + 1, self.y + 2 + + # Inherit from both the Keras Layer and base_layers.Layer to ensure we + # still get the base_layers.Layer behavior when directly inheriting from + # the Keras Layer. + class MyLayer(MyLayerBase, base_layers.Layer): + pass + + # Test inputs are casted. + input1 = array_ops.constant(1.0, dtype=dtypes.float64) + input2 = array_ops.constant(1.0, dtype=dtypes.float32) + layer = MyLayer(dtype=dtypes.float16) + output1, output2 = layer([input1, input2]) + self.assertEqual(output1.dtype, dtypes.float16) + self.assertEqual(output2.dtype, dtypes.float16) + + # Test inputs are not casted. 
+ input1 = array_ops.constant(1.0, dtype=dtypes.float64) + input2 = array_ops.constant(1.0, dtype=dtypes.float32) + layer = MyLayer() + output1, output2 = layer([input1, input2]) + self.assertEqual(output1.dtype, dtypes.float64) + self.assertEqual(output2.dtype, dtypes.float32) + + @test_util.run_in_graph_and_eager_modes() + def testVariablesDefaultToFloat32(self): + class MyLayerBase(keras_base_layer.Layer): + + def build(self, input_shape): + self.x = self.add_weight('x', ()) + + def call(self, inputs): + return inputs + self.x + + # Inherit from both the Keras Layer and base_layers.Layer to ensure we + # still get the base_layers.Layer behavior when directly inheriting from + # the Keras Layer. + class MyLayer(MyLayerBase, base_layers.Layer): + pass + + try: + # The behavior of Keras Layers is to default to floatx. Ensure that this + # behavior is overridden to instead default to float32. + backend.set_floatx('float16') + layer = MyLayer() + layer.build(()) + self.assertEqual(layer.dtype, None) + self.assertEqual(layer.x.dtype.base_dtype, dtypes.float32) + finally: + backend.set_floatx('float32') if __name__ == '__main__': test.main() |