diff options
author | 2016-08-11 09:26:05 -0800 | |
---|---|---|
committer | 2016-08-11 10:32:03 -0700 | |
commit | 95287a7b8365ad02c2d10b8d919a6c7463a332b9 (patch) | |
tree | 4fd00d1dfba8d7020a18b6df384dc89836cdc565 /tensorflow/contrib/graph_editor | |
parent | bfb577a15b055abf9a239a1114dfe1bd26c67234 (diff) |
Copying an op now copy the shape of its output tensors as well.
Change: 130002137
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r-- | tensorflow/contrib/graph_editor/transform.py | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 4d64987429..5f4213b763 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -119,12 +119,13 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True): return None -def copy_op_handler(info, op): +def copy_op_handler(info, op, copy_shape=True): """Copy a tf.Operation. Args: info: Transform._Info instance. op: the tf.Operation to be copied. + copy_shape: also copy the shape of the tensor Returns: A copy of op. """ @@ -161,6 +162,13 @@ def copy_op_handler(info, op): # Initialize a new Operation instance op_ = tf_ops.Operation(node_def_, info.graph_, inputs_, output_types_, control_inputs_, input_types_, original_op_, op_def_) + + # copy the shape over + if copy_shape: + for t, t_ in zip(op.outputs, op_.outputs): + t_.set_shape(t.get_shape()) + + # Add op to the graph info.graph_._add_op(op_) # pylint: enable=protected-access |