aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Ben Lee <blee@google.com>2016-11-17 11:15:51 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-17 11:26:07 -0800
commit7d0573f0d7b551257dc013a98eaaead5f509c4d0 (patch)
tree35cddf13b6b91f486c45fa74c155b22fadd69676
parentadf0dc7c15b3d9114839a8aa101126d98f75710f (diff)
Remove spurious unpack/shape from fully_connected if not needed.
Change: 139483354
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 633d4f9c06..6b2bdb9970 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -1296,9 +1296,6 @@ def fully_connected(inputs,
static_shape = inputs_shape.as_list()
static_shape[-1] = num_outputs
- out_shape = array_ops.unpack(array_ops.shape(inputs))
- out_shape[-1] = num_outputs
-
weights_shape = [num_input_units, num_outputs]
weights_collections = utils.get_variable_collections(
variables_collections, 'weights')
@@ -1310,6 +1307,8 @@ def fully_connected(inputs,
collections=weights_collections,
trainable=trainable)
if len(static_shape) > 2:
+ out_shape = array_ops.unpack(array_ops.shape(inputs))
+ out_shape[-1] = num_outputs
# Reshape inputs
inputs = array_ops.reshape(inputs, [-1, num_input_units])
outputs = standard_ops.matmul(inputs, weights)