aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms/backports.cc
diff options
context:
space:
mode:
authorGravatar Pete Warden <petewarden@google.com>2017-03-20 09:46:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-20 11:04:56 -0700
commit331f2727d20cb23e26f39348d67e70202fc45e02 (patch)
treefa7c1c4fa0e7a766996fcb3850b8dd600be9d394 /tensorflow/tools/graph_transforms/backports.cc
parentb25d1c7d3e9f30925aa132ba62e79d281191a3dc (diff)
Added backporting rules for TensorArrayV3 and friends
Change: 150645199
Diffstat (limited to 'tensorflow/tools/graph_transforms/backports.cc')
-rw-r--r--tensorflow/tools/graph_transforms/backports.cc70
1 files changed, 70 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/backports.cc b/tensorflow/tools/graph_transforms/backports.cc
index 3b1d57146b..c5ec1fdd7b 100644
--- a/tensorflow/tools/graph_transforms/backports.cc
+++ b/tensorflow/tools/graph_transforms/backports.cc
@@ -61,5 +61,75 @@ Status BackportConcatV2Transform(const GraphDef& input_graph_def,
REGISTER_GRAPH_TRANSFORM("backport_concatv2", BackportConcatV2Transform);
+// Switch any TensorArrayV3 nodes to the v2 version, removing the second output.
+Status BackportTensorArrayV3Transform(const GraphDef& input_graph_def,
+ const TransformFuncContext& context,
+ GraphDef* output_graph_def) {
+ std::map<string, string> inputs_to_rename;
+ GraphDef replaced_graph_def;
+ TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
+ input_graph_def, {"TensorArrayV3|TensorArrayGradV3"},
+ [&inputs_to_rename](const NodeMatch& match,
+ const std::set<string>& input_nodes,
+ const std::set<string>& output_nodes,
+ std::vector<NodeDef>* new_nodes) {
+ const NodeDef& tensor_array_v3_node = match.node;
+
+ // All we need to do here is rename the op type, since the attributes
+ // remain the same.
+ NodeDef tensor_array_v2_node = tensor_array_v3_node;
+ if (tensor_array_v3_node.op() == "TensorArrayV3") {
+ tensor_array_v2_node.set_op("TensorArrayV2");
+ } else {
+ tensor_array_v2_node.set_op("TensorArrayGradV2");
+ }
+
+ // The v3 version has a second 'flow' output that's not present in v2,
+ // so substitute a dummy constant instead in any places that use it.
+ NodeDef replacement_flow_node;
+ replacement_flow_node.set_op("Const");
+ replacement_flow_node.set_name(tensor_array_v3_node.name() +
+ "/replacement_flow_node");
+ Tensor replacement_flow_tensor(DT_FLOAT, {});
+ // I'm picking an arbitrary value for the gradient flow here, for lack
+ // of a better alternative.
+ replacement_flow_tensor.flat<float>()(0) = 1.0f;
+ SetNodeTensorAttr<float>("value", replacement_flow_tensor,
+ &replacement_flow_node);
+ inputs_to_rename[tensor_array_v3_node.name() + ":1"] =
+ replacement_flow_node.name();
+
+ new_nodes->push_back(tensor_array_v2_node);
+ new_nodes->push_back(replacement_flow_node);
+ return Status::OK();
+ },
+ {true}, &replaced_graph_def));
+ // Update the graph so that any nodes that referred to removed inputs now
+ // pull from the substitute constants we've added.
+ GraphDef renamed_graph_def;
+ TF_RETURN_IF_ERROR(RenameNodeInputs(replaced_graph_def, inputs_to_rename,
+ std::unordered_set<string>(),
+ &renamed_graph_def));
+ TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
+ renamed_graph_def,
+ {"TensorArrayWriteV3|TensorArrayReadV3|TensorArrayGatherV3|"
+ "TensorArrayScatterV3|TensorArrayConcatV3|TensorArraySplitV3|"
+ "TensorArraySizeV3|TensorArrayCloseV3"},
+ [](const NodeMatch& match, const std::set<string>& input_nodes,
+ const std::set<string>& output_nodes,
+ std::vector<NodeDef>* new_nodes) {
+ const NodeDef& v3_node = match.node;
+ NodeDef v2_node = v3_node;
+ v2_node.set_op(v3_node.op().substr(0, v3_node.op().size() - 1) + "2");
+ new_nodes->push_back(v2_node);
+ return Status::OK();
+ },
+ {true}, output_graph_def));
+ return Status::OK();
+}
+
+REGISTER_GRAPH_TRANSFORM("backport_tensor_array_v3",
+ BackportTensorArrayV3Transform);
+
} // namespace graph_transforms
} // namespace tensorflow