 tensorflow/python/keras/engine/base_layer.py    | 38
 tensorflow/python/keras/engine/topology_test.py | 42
 tensorflow/python/layers/base_test.py           | 16
 3 files changed, 57 insertions(+), 39 deletions(-)
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 751cc5a8d5..b05bc96e28 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -89,11 +89,19 @@ class Layer(checkpointable.CheckpointableBase):
once. Should actually perform the logic of applying the layer to the
input tensors (which should be passed in as the first argument).
- By default, layers will cast all their inputs and arguments to the layer's
- dtype, if set. This is useful for creating a model with multiple dtypes, as
- the user does not need to explicitly cast tensors. If a `Layer` descendant
- wants only a subset of inputs/arguments to be casted, or none of them,
- `_cast_inputs_and_args()` should be overridden.
+ A note on a layer's `dtype` property:
+ A layer's dtype can be specified via the constructor `dtype` argument, and
+ defaults to the dtype of the first input when the layer is called. The dtype
+ cannot be changed once set.
+
+ All floating point tensor inputs and arguments are cast to the layer's
+ dtype before the body of the layer computation runs. For models whose
+ layers have different dtypes, this removes the need for explicit casts
+ between layers.
+
+ The casting behavior can be customized in subclasses by overriding the
+ `_cast_inputs_and_args()` method, which is useful if some or all inputs
+ should not be cast.
Arguments:
trainable: Boolean, whether the layer's variables should be trainable.
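
[Editor's note: to illustrate the casting rule documented above, a minimal
sketch, not part of this patch; `Passthrough` is a hypothetical layer, and
the printed dtype is what the patched behavior implies.]

  import tensorflow as tf

  class Passthrough(tf.keras.layers.Layer):  # hypothetical illustration layer

    def call(self, inputs):
      return inputs

  layer = Passthrough(dtype='float16')  # dtype fixed at construction
  x = tf.ones((4,), dtype=tf.float32)   # float32 tensor input
  y = layer(x)                          # input is cast before call() runs
  print(y.dtype)                        # expected under this patch: float16
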
@@ -675,10 +683,9 @@ class Layer(checkpointable.CheckpointableBase):
kwargs['mask'] = previous_mask
input_shapes = None
- # We only cast inputs if self.dtype was previous set, which occurs when
- # a dtype was passed to the constructor, or when this layer has previously
- # been called. We cast floating point inputs to self.dtype to ensure the
- # layer runs with the correct dtype.
+ # Inputs are only cast if a dtype was passed to the constructor, or if the
+ # layer's __call__() has been previously invoked. At present, only floating
+ # point tensor inputs are affected.
# TODO(b/77478433): Perhaps we should only cast inputs if a dtype was passed
# to the constructor, not when the layer has previously been called.
inputs_should_be_cast = (self.dtype is not None)
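
[Editor's note: a sketch of the default-dtype rule this comment describes,
reusing the hypothetical Passthrough layer from above; released TensorFlow
versions may behave differently.]

  layer = Passthrough()                        # no dtype passed to constructor
  y1 = layer(tf.ones((4,), dtype=tf.float64))  # first call locks in the dtype
  print(layer.dtype)                           # expected: float64
  y2 = layer(tf.ones((4,), dtype=tf.float32))  # later float inputs are cast
  print(y2.dtype)                              # expected: float64
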
@@ -810,10 +817,13 @@ class Layer(checkpointable.CheckpointableBase):
def _cast_inputs_and_args(self, inputs, *args, **kwargs):
"""Casts the inputs, args, and kwargs of a layer to the layer's dtype.
- This is intended to be potentially overridden by layer subclasses. By
- default, inputs, args, and kwargs are automatically casted to the layer's
- dtype. Overriding this method allows only some of the inputs, args, and
- kwargs (or none of them) to be casted.
+ This is intended to be potentially overridden by subclasses. By default,
+ inputs, args, and kwargs are automatically cast to the layer's dtype.
+ Overriding this method allows a subset of the inputs, args, and kwargs
+ (or none of them) to be cast.
+
+ Currently, this only casts floating point tensors to floating point dtypes,
+ but more types may be cast in the future.
Does not modify inputs, args, or kwargs.
@@ -823,7 +833,7 @@ class Layer(checkpointable.CheckpointableBase):
**kwargs: The kwargs to self.__call__.
Returns:
- The tuple (new_inputs, new_args, new_kwargs), where tensors in inputs,
+ A tuple (new_inputs, new_args, new_kwargs), where tensors in inputs,
args, and kwargs have been casted to self.dtype.
"""
new_inputs = nest.map_structure(self._cast_fn, inputs)
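
[Editor's note: for layers that should opt out of the automatic cast, the
docstring above suggests overriding this method. A minimal sketch, assuming
the `_cast_inputs_and_args()` hook exists as introduced by this patch; it is
internal and may not exist under this name in released versions.]

  class NoAutoCast(tf.keras.layers.Layer):  # hypothetical subclass

    def _cast_inputs_and_args(self, inputs, *args, **kwargs):
      # Return everything unchanged so no inputs or arguments are cast.
      return inputs, args, kwargs

    def call(self, inputs):
      return inputs
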
diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py
index 7fbe6b80ad..d28c30cb7d 100644
--- a/tensorflow/python/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/engine/topology_test.py
@@ -1057,24 +1057,30 @@ class TopologyConstructionTest(test.TestCase):
def compute_output_shape(self, input_shapes):
return input_shapes[0]
- x = keras.layers.Input((32,), dtype='float64')
- layer1 = SingleInputLayer()
- layer2 = SingleInputLayer(dtype='float32')
- layer3 = MultiInputLayer(dtype='float16')
- i1 = layer1(x)
- i2 = layer2(i1)
- y = layer3((i1, i2))
- network = keras.engine.Network(x, y)
- x2 = array_ops.ones((32,), dtype='float16')
- y2 = network(x2)
- self.assertEqual(layer1.dtype, dtypes.float64)
- self.assertEqual(layer1.a.dtype, dtypes.float64)
- self.assertEqual(layer2.dtype, dtypes.float32)
- self.assertEqual(layer2.a.dtype, dtypes.float32)
- self.assertEqual(layer3.dtype, dtypes.float16)
- self.assertEqual(layer3.a.dtype, dtypes.float16)
- self.assertEqual(layer3.b.dtype, dtypes.float16)
- self.assertEqual(y2.dtype, dtypes.float16)
+ default_layer = SingleInputLayer()
+ fp32_layer = SingleInputLayer(dtype='float32')
+ fp16_layer = MultiInputLayer(dtype='float16')
+
+ input_t = keras.layers.Input((32,), dtype='float64')
+ o1 = default_layer(input_t)
+ o2 = fp32_layer(o1)
+ # fp16_layer has inputs of different dtypes.
+ output_t = fp16_layer((o1, o2))
+ network = keras.engine.Network(input_t, output_t)
+
+ x = array_ops.ones((32,), dtype='float16')
+ y = network(x)
+ self.assertEqual(default_layer.dtype, dtypes.float64)
+ self.assertEqual(default_layer.a.dtype, dtypes.float64)
+
+ self.assertEqual(fp32_layer.dtype, dtypes.float32)
+ self.assertEqual(fp32_layer.a.dtype, dtypes.float32)
+
+ self.assertEqual(fp16_layer.dtype, dtypes.float16)
+ self.assertEqual(fp16_layer.a.dtype, dtypes.float16)
+ self.assertEqual(fp16_layer.b.dtype, dtypes.float16)
+
+ self.assertEqual(y.dtype, dtypes.float16)
class DeferredModeTest(test.TestCase):
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index 15448c6be8..ad44328aab 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -593,7 +593,8 @@ class BaseLayerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testOnlyCastInputsWhenDtypeSpecified(self):
- class MyLayerBase(keras_base_layer.Layer):
+
+ class MyKerasLayer(keras_base_layer.Layer):
def call(self, inputs):
self.x = inputs[0]
@@ -603,13 +604,13 @@ class BaseLayerTest(test.TestCase):
# Inherit from both the Keras Layer and base_layers.Layer to ensure we
# still get the base_layers.Layer behavior when directly inheriting from
# the Keras Layer.
- class MyLayer(MyLayerBase, base_layers.Layer):
+ class MyTFLayer(MyKerasLayer, base_layers.Layer):
pass
# Test inputs are casted.
input1 = array_ops.constant(1.0, dtype=dtypes.float64)
input2 = array_ops.constant(1.0, dtype=dtypes.float32)
- layer = MyLayer(dtype=dtypes.float16)
+ layer = MyTFLayer(dtype=dtypes.float16)
output1, output2 = layer([input1, input2])
self.assertEqual(output1.dtype, dtypes.float16)
self.assertEqual(output2.dtype, dtypes.float16)
@@ -617,14 +618,15 @@ class BaseLayerTest(test.TestCase):
# Test inputs are not casted.
input1 = array_ops.constant(1.0, dtype=dtypes.float64)
input2 = array_ops.constant(1.0, dtype=dtypes.float32)
- layer = MyLayer()
+ layer = MyTFLayer()
output1, output2 = layer([input1, input2])
self.assertEqual(output1.dtype, dtypes.float64)
self.assertEqual(output2.dtype, dtypes.float32)
@test_util.run_in_graph_and_eager_modes()
def testVariablesDefaultToFloat32(self):
- class MyLayerBase(keras_base_layer.Layer):
+
+ class MyKerasLayer(keras_base_layer.Layer):
def build(self, input_shape):
self.x = self.add_weight('x', ())
@@ -635,14 +637,14 @@ class BaseLayerTest(test.TestCase):
# Inherit from both the Keras Layer and base_layers.Layer to ensure we
# still get the base_layers.Layer behavior when directly inheriting from
# the Keras Layer.
- class MyLayer(MyLayerBase, base_layers.Layer):
+ class MyTFLayer(MyKerasLayer, base_layers.Layer):
pass
try:
# The behavior of Keras Layers is to default to floatx. Ensure that this
# behavior is overridden to instead default to float32.
backend.set_floatx('float16')
- layer = MyLayer()
+ layer = MyTFLayer()
layer.build(())
self.assertEqual(layer.dtype, None)
self.assertEqual(layer.x.dtype.base_dtype, dtypes.float32)