aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/graph_editor
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-11 09:26:05 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-11 10:32:03 -0700
commit95287a7b8365ad02c2d10b8d919a6c7463a332b9 (patch)
tree4fd00d1dfba8d7020a18b6df384dc89836cdc565 /tensorflow/contrib/graph_editor
parentbfb577a15b055abf9a239a1114dfe1bd26c67234 (diff)
Copying an op now copy the shape of its output tensors as well.
Change: 130002137
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r--tensorflow/contrib/graph_editor/transform.py10
1 files changed, 9 insertions, 1 deletions
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py
index 4d64987429..5f4213b763 100644
--- a/tensorflow/contrib/graph_editor/transform.py
+++ b/tensorflow/contrib/graph_editor/transform.py
@@ -119,12 +119,13 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True):
return None
-def copy_op_handler(info, op):
+def copy_op_handler(info, op, copy_shape=True):
"""Copy a tf.Operation.
Args:
info: Transform._Info instance.
op: the tf.Operation to be copied.
+ copy_shape: also copy the shape of the tensor
Returns:
A copy of op.
"""
@@ -161,6 +162,13 @@ def copy_op_handler(info, op):
# Initialize a new Operation instance
op_ = tf_ops.Operation(node_def_, info.graph_, inputs_, output_types_,
control_inputs_, input_types_, original_op_, op_def_)
+
+ # copy the shape over
+ if copy_shape:
+ for t, t_ in zip(op.outputs, op_.outputs):
+ t_.set_shape(t.get_shape())
+
+ # Add op to the graph
info.graph_._add_op(op_)
# pylint: enable=protected-access