aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers
diff options
context:
space:
mode:
authorGravatar Reed Wanderman-Milne <reedwm@google.com>2018-06-15 15:27:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-15 15:32:43 -0700
commitb8861afe21d8d654c2a726cabd82069faca04532 (patch)
tree6d967f76284235e5ae43e19d23320f3098ebfecb /tensorflow/python/layers
parent5e9a39d6ad6eee207a7af88bb1bbe1deefb8bbb2 (diff)
Automatically cast layer inputs to the layer's dtype.
This makes it more convenient to use layers of different dtypes in a model. Instead of having to manually cast intermediate tensors between layers of different dtypes, they will automatically be cast. This is also useful for the upcoming mixed precision API. PiperOrigin-RevId: 200783477
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r--tensorflow/python/layers/base.py12
-rw-r--r--tensorflow/python/layers/base_test.py59
2 files changed, 66 insertions, 5 deletions
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index eda036ece4..abbe9d0c56 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -43,13 +43,15 @@ class Layer(base_layer.Layer):
Arguments:
trainable: Boolean, whether the layer's variables should be trainable.
name: String name of the layer.
- dtype: Default dtype of the layer's weights (default of `None` means use the
- type of the first input).
+ dtype: Default dtype of the layer's weights and computations (default of
+ `None` means use the type of the first input). If not None, inputs will be
+ casted to this dtype.
Read-only properties:
name: The name of the layer (string).
- dtype: Default dtype of the layer's weights (default of `None` means use the
- type of the first input).
+ dtype: Default dtype of the layer's weights and computations. (default of
+ `None` means use the type of the first input). If not None, inputs will be
+ casted to this dtype.
trainable_variables: List of trainable variables.
non_trainable_variables: List of non-trainable variables.
variables: List of all variables of this layer, trainable and
@@ -191,7 +193,7 @@ class Layer(base_layer.Layer):
RuntimeError: If called with partioned variable regularization and
eager execution is enabled.
"""
-
+
def _should_add_regularizer(variable, existing_variable_set):
if isinstance(variable, tf_variables.PartitionedVariable):
for var in variable:
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index ab49e37b90..15448c6be8 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -25,6 +25,8 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.keras import backend
+from tensorflow.python.keras.engine import base_layer as keras_base_layer
from tensorflow.python.layers import base as base_layers
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import array_ops
@@ -589,6 +591,63 @@ class BaseLayerTest(test.TestCase):
ValueError, 'Input graph and Layer graph are not the same'):
layer.apply(constant_op.constant([[1.]]))
+ @test_util.run_in_graph_and_eager_modes()
+ def testOnlyCastInputsWhenDtypeSpecified(self):
+ class MyLayerBase(keras_base_layer.Layer):
+
+ def call(self, inputs):
+ self.x = inputs[0]
+ self.y = inputs[1]
+ return self.x + 1, self.y + 2
+
+ # Inherit from both the Keras Layer and base_layers.Layer to ensure we
+ # still get the base_layers.Layer behavior when directly inheriting from
+ # the Keras Layer.
+ class MyLayer(MyLayerBase, base_layers.Layer):
+ pass
+
+ # Test inputs are casted.
+ input1 = array_ops.constant(1.0, dtype=dtypes.float64)
+ input2 = array_ops.constant(1.0, dtype=dtypes.float32)
+ layer = MyLayer(dtype=dtypes.float16)
+ output1, output2 = layer([input1, input2])
+ self.assertEqual(output1.dtype, dtypes.float16)
+ self.assertEqual(output2.dtype, dtypes.float16)
+
+ # Test inputs are not casted.
+ input1 = array_ops.constant(1.0, dtype=dtypes.float64)
+ input2 = array_ops.constant(1.0, dtype=dtypes.float32)
+ layer = MyLayer()
+ output1, output2 = layer([input1, input2])
+ self.assertEqual(output1.dtype, dtypes.float64)
+ self.assertEqual(output2.dtype, dtypes.float32)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testVariablesDefaultToFloat32(self):
+ class MyLayerBase(keras_base_layer.Layer):
+
+ def build(self, input_shape):
+ self.x = self.add_weight('x', ())
+
+ def call(self, inputs):
+ return inputs + self.x
+
+ # Inherit from both the Keras Layer and base_layers.Layer to ensure we
+ # still get the base_layers.Layer behavior when directly inheriting from
+ # the Keras Layer.
+ class MyLayer(MyLayerBase, base_layers.Layer):
+ pass
+
+ try:
+ # The behavior of Keras Layers is to default to floatx. Ensure that this
+ # behavior is overridden to instead default to float32.
+ backend.set_floatx('float16')
+ layer = MyLayer()
+ layer.build(())
+ self.assertEqual(layer.dtype, None)
+ self.assertEqual(layer.x.dtype.base_dtype, dtypes.float32)
+ finally:
+ backend.set_floatx('float32')
if __name__ == '__main__':
test.main()