Diffstat (limited to 'tensorflow/python/layers/core.py')
-rw-r--r--  tensorflow/python/layers/core.py | 21
1 file changed, 7 insertions(+), 14 deletions(-)
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index 92177f932e..fc027ca95c 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -133,24 +133,17 @@ class Dense(base._Layer):  # pylint: disable=protected-access
 
   def call(self, inputs):
     shape = inputs.get_shape().as_list()
-    input_dim = shape[-1]
     output_shape = shape[:-1] + [self.units]
     if len(output_shape) > 2:
-      # Reshape the input to 2D.
-      output_shape_tensors = array_ops.unstack(array_ops.shape(inputs))
-      output_shape_tensors[-1] = self.units
-      output_shape_tensor = array_ops.stack(output_shape_tensors)
-      inputs = array_ops.reshape(inputs, [-1, input_dim])
-
-    outputs = standard_ops.matmul(inputs, self.kernel)
-    if self.use_bias:
-      outputs = nn.bias_add(outputs, self.bias)
-
-    if len(output_shape) > 2:
+      # Broadcasting is required for the inputs.
+      outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1],
+                                                             [0]])
       # Reshape the output back to the original ndim of the input.
-      outputs = array_ops.reshape(outputs, output_shape_tensor)
       outputs.set_shape(output_shape)
-
+    else:
+      outputs = standard_ops.matmul(inputs, self.kernel)
+    if self.use_bias:
+      outputs = nn.bias_add(outputs, self.bias)
     if self.activation is not None:
       return self.activation(outputs)  # pylint: disable=not-callable
     return outputs
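
For context (not part of the commit): the tensordot call contracts the last axis of inputs against the first axis of kernel, so an N-D input no longer has to be flattened to 2-D for the matmul and reshaped back afterwards. A minimal NumPy sketch of that equivalence, with hypothetical shapes chosen only for illustration:

import numpy as np

# Hypothetical shapes: batch=4, time=5, input_dim=3, units=2.
inputs = np.random.rand(4, 5, 3)
kernel = np.random.rand(3, 2)

# New path: contract the last axis of inputs with the first axis of kernel.
via_tensordot = np.tensordot(inputs, kernel, axes=[[2], [0]])

# Old path: flatten to 2-D, matmul, then restore the leading dimensions.
via_matmul = (inputs.reshape(-1, 3) @ kernel).reshape(4, 5, 2)

assert np.allclose(via_tensordot, via_matmul)  # both produce shape (4, 5, 2)

The tensor returned by tensordot already has the original ndim; the set_shape(output_shape) call in the new branch only re-attaches the static shape information for the non-contracted dimensions.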