aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Reed Wanderman-Milne <reedwm@google.com>2018-05-14 16:07:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-14 16:10:12 -0700
commit2451eef12c6b6b09dbf6b5b4a19d95272e197409 (patch)
treec7715c9fc03eaf23dec3cccd746a345114e71cfd
parent1761e1dde7d874888eb01af7cef2d18488ff7b60 (diff)
Fix bug where custom layers could crash.
Layer.add_weight would crash when called without a dtype or initializer. PiperOrigin-RevId: 196583182
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/base_layer.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
index 72ab77fbbd..5dc93806f4 100644
--- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
@@ -485,8 +485,7 @@ class Layer(checkpointable.CheckpointableBase):
"""
if dtype is None:
dtype = self.dtype or backend.floatx()
- else:
- dtype = dtypes.as_dtype(dtype)
+ dtype = dtypes.as_dtype(dtype)
initializer = initializers.get(initializer)
regularizer = regularizers.get(regularizer)
constraint = constraints.get(constraint)
@@ -514,7 +513,7 @@ class Layer(checkpointable.CheckpointableBase):
# Manage errors in Layer rather than Checkpointable.
overwrite=True,
initializer=initializer,
- dtype=dtypes.as_dtype(dtype),
+ dtype=dtype,
constraint=constraint,
trainable=trainable and self.trainable,
partitioner=partitioner,