aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/base_layer.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
index 72ab77fbbd..5dc93806f4 100644
--- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
@@ -485,8 +485,7 @@ class Layer(checkpointable.CheckpointableBase):
"""
if dtype is None:
dtype = self.dtype or backend.floatx()
- else:
- dtype = dtypes.as_dtype(dtype)
+ dtype = dtypes.as_dtype(dtype)
initializer = initializers.get(initializer)
regularizer = regularizers.get(regularizer)
constraint = constraints.get(constraint)
@@ -514,7 +513,7 @@ class Layer(checkpointable.CheckpointableBase):
# Manage errors in Layer rather than Checkpointable.
overwrite=True,
initializer=initializer,
- dtype=dtypes.as_dtype(dtype),
+ dtype=dtype,
constraint=constraint,
trainable=trainable and self.trainable,
partitioner=partitioner,