diff options
author | 2016-11-17 11:15:51 -0800 | |
---|---|---|
committer | 2016-11-17 11:26:07 -0800 | |
commit | 7d0573f0d7b551257dc013a98eaaead5f509c4d0 (patch) | |
tree | 35cddf13b6b91f486c45fa74c155b22fadd69676 | |
parent | adf0dc7c15b3d9114839a8aa101126d98f75710f (diff) |
Remove spurious unpack/shape from fully_connected if not needed.
Change: 139483354
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 633d4f9c06..6b2bdb9970 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -1296,9 +1296,6 @@ def fully_connected(inputs, static_shape = inputs_shape.as_list() static_shape[-1] = num_outputs - out_shape = array_ops.unpack(array_ops.shape(inputs)) - out_shape[-1] = num_outputs - weights_shape = [num_input_units, num_outputs] weights_collections = utils.get_variable_collections( variables_collections, 'weights') @@ -1310,6 +1307,8 @@ def fully_connected(inputs, collections=weights_collections, trainable=trainable) if len(static_shape) > 2: + out_shape = array_ops.unpack(array_ops.shape(inputs)) + out_shape[-1] = num_outputs # Reshape inputs inputs = array_ops.reshape(inputs, [-1, num_input_units]) outputs = standard_ops.matmul(inputs, weights) |