aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/layers/core.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/layers/core.py')
-rw-r--r--tensorflow/python/keras/layers/core.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index 2bf6229ccb..f28cade474 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -26,6 +26,7 @@ import warnings
import numpy as np
from tensorflow.python.eager import context
+from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
@@ -929,13 +930,13 @@ class Dense(Layer):
def call(self, inputs):
inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
- shape = inputs.get_shape().as_list()
- if len(shape) > 2:
+ rank = common_shapes.rank(inputs)
+ if rank > 2:
# Broadcasting is required for the inputs.
- outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1],
- [0]])
+ outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]])
# Reshape the output back to the original ndim of the input.
if not context.executing_eagerly():
+ shape = inputs.get_shape().as_list()
output_shape = shape[:-1] + [self.units]
outputs.set_shape(output_shape)
else: