diff options
Diffstat (limited to 'tensorflow/python/keras/layers/core.py')
-rw-r--r-- | tensorflow/python/keras/layers/core.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index 2bf6229ccb..f28cade474 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -26,6 +26,7 @@ import warnings import numpy as np from tensorflow.python.eager import context +from tensorflow.python.framework import common_shapes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import activations @@ -929,13 +930,13 @@ class Dense(Layer): def call(self, inputs): inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) - shape = inputs.get_shape().as_list() - if len(shape) > 2: + rank = common_shapes.rank(inputs) + if rank > 2: # Broadcasting is required for the inputs. - outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1], - [0]]) + outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]]) # Reshape the output back to the original ndim of the input. if not context.executing_eagerly(): + shape = inputs.get_shape().as_list() output_shape = shape[:-1] + [self.units] outputs.set_shape(output_shape) else: |