diff options
author | 2018-05-14 16:07:33 -0700 | |
---|---|---|
committer | 2018-05-14 16:10:12 -0700 | |
commit | 2451eef12c6b6b09dbf6b5b4a19d95272e197409 (patch) | |
tree | c7715c9fc03eaf23dec3cccd746a345114e71cfd | |
parent | 1761e1dde7d874888eb01af7cef2d18488ff7b60 (diff) |
Fix bug where custom layers could crash.
Layer.add_weight would crash when called without a dtype or initializer.
PiperOrigin-RevId: 196583182
-rw-r--r-- | tensorflow/python/keras/_impl/keras/engine/base_layer.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py index 72ab77fbbd..5dc93806f4 100644 --- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py +++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py @@ -485,8 +485,7 @@ class Layer(checkpointable.CheckpointableBase): """ if dtype is None: dtype = self.dtype or backend.floatx() - else: - dtype = dtypes.as_dtype(dtype) + dtype = dtypes.as_dtype(dtype) initializer = initializers.get(initializer) regularizer = regularizers.get(regularizer) constraint = constraints.get(constraint) @@ -514,7 +513,7 @@ class Layer(checkpointable.CheckpointableBase): # Manage errors in Layer rather than Checkpointable. overwrite=True, initializer=initializer, - dtype=dtypes.as_dtype(dtype), + dtype=dtype, constraint=constraint, trainable=trainable and self.trainable, partitioner=partitioner, |