diff options
Diffstat (limited to 'tensorflow/python/layers/core.py')
-rw-r--r-- | tensorflow/python/layers/core.py | 21 |
1 files changed, 7 insertions, 14 deletions
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py index 92177f932e..fc027ca95c 100644 --- a/tensorflow/python/layers/core.py +++ b/tensorflow/python/layers/core.py @@ -133,24 +133,17 @@ class Dense(base._Layer): # pylint: disable=protected-access def call(self, inputs): shape = inputs.get_shape().as_list() - input_dim = shape[-1] output_shape = shape[:-1] + [self.units] if len(output_shape) > 2: - # Reshape the input to 2D. - output_shape_tensors = array_ops.unstack(array_ops.shape(inputs)) - output_shape_tensors[-1] = self.units - output_shape_tensor = array_ops.stack(output_shape_tensors) - inputs = array_ops.reshape(inputs, [-1, input_dim]) - - outputs = standard_ops.matmul(inputs, self.kernel) - if self.use_bias: - outputs = nn.bias_add(outputs, self.bias) - - if len(output_shape) > 2: + # Broadcasting is required for the inputs. + outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1], + [0]]) # Reshape the output back to the original ndim of the input. - outputs = array_ops.reshape(outputs, output_shape_tensor) outputs.set_shape(output_shape) - + else: + outputs = standard_ops.matmul(inputs, self.kernel) + if self.use_bias: + outputs = nn.bias_add(outputs, self.bias) if self.activation is not None: return self.activation(outputs) # pylint: disable=not-callable return outputs |