diff options
author | Pete Warden <petewarden@google.com> | 2017-03-20 09:46:20 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-20 11:04:56 -0700 |
commit | 331f2727d20cb23e26f39348d67e70202fc45e02 (patch) | |
tree | fa7c1c4fa0e7a766996fcb3850b8dd600be9d394 /tensorflow/tools/graph_transforms/backports.cc | |
parent | b25d1c7d3e9f30925aa132ba62e79d281191a3dc (diff) |
Added backporting rules for TensorArrayV3 and friends
Change: 150645199
Diffstat (limited to 'tensorflow/tools/graph_transforms/backports.cc')
-rw-r--r-- | tensorflow/tools/graph_transforms/backports.cc | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/backports.cc b/tensorflow/tools/graph_transforms/backports.cc index 3b1d57146b..c5ec1fdd7b 100644 --- a/tensorflow/tools/graph_transforms/backports.cc +++ b/tensorflow/tools/graph_transforms/backports.cc @@ -61,5 +61,75 @@ Status BackportConcatV2Transform(const GraphDef& input_graph_def, REGISTER_GRAPH_TRANSFORM("backport_concatv2", BackportConcatV2Transform); +// Switch any TensorArrayV3 nodes to the v2 version, removing the second output. +Status BackportTensorArrayV3Transform(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + std::map<string, string> inputs_to_rename; + GraphDef replaced_graph_def; + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + input_graph_def, {"TensorArrayV3|TensorArrayGradV3"}, + [&inputs_to_rename](const NodeMatch& match, + const std::set<string>& input_nodes, + const std::set<string>& output_nodes, + std::vector<NodeDef>* new_nodes) { + const NodeDef& tensor_array_v3_node = match.node; + + // All we need to do here is rename the op type, since the attributes + // remain the same. + NodeDef tensor_array_v2_node = tensor_array_v3_node; + if (tensor_array_v3_node.op() == "TensorArrayV3") { + tensor_array_v2_node.set_op("TensorArrayV2"); + } else { + tensor_array_v2_node.set_op("TensorArrayGradV2"); + } + + // The v3 version has a second 'flow' output that's not present in v2, + // so substitute a dummy constant instead in any places that use it. + NodeDef replacement_flow_node; + replacement_flow_node.set_op("Const"); + replacement_flow_node.set_name(tensor_array_v3_node.name() + + "/replacement_flow_node"); + Tensor replacement_flow_tensor(DT_FLOAT, {}); + // I'm picking an arbitrary value for the gradient flow here, for lack + // of a better alternative. + replacement_flow_tensor.flat<float>()(0) = 1.0f; + SetNodeTensorAttr<float>("value", replacement_flow_tensor, + &replacement_flow_node); + inputs_to_rename[tensor_array_v3_node.name() + ":1"] = + replacement_flow_node.name(); + + new_nodes->push_back(tensor_array_v2_node); + new_nodes->push_back(replacement_flow_node); + return Status::OK(); + }, + {true}, &replaced_graph_def)); + // Update the graph so that any nodes that referred to removed inputs now + // pull from the substitute constants we've added. + GraphDef renamed_graph_def; + TF_RETURN_IF_ERROR(RenameNodeInputs(replaced_graph_def, inputs_to_rename, + std::unordered_set<string>(), + &renamed_graph_def)); + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + renamed_graph_def, + {"TensorArrayWriteV3|TensorArrayReadV3|TensorArrayGatherV3|" + "TensorArrayScatterV3|TensorArrayConcatV3|TensorArraySplitV3|" + "TensorArraySizeV3|TensorArrayCloseV3"}, + [](const NodeMatch& match, const std::set<string>& input_nodes, + const std::set<string>& output_nodes, + std::vector<NodeDef>* new_nodes) { + const NodeDef& v3_node = match.node; + NodeDef v2_node = v3_node; + v2_node.set_op(v3_node.op().substr(0, v3_node.op().size() - 1) + "2"); + new_nodes->push_back(v2_node); + return Status::OK(); + }, + {true}, output_graph_def)); + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("backport_tensor_array_v3", + BackportTensorArrayV3Transform); + } // namespace graph_transforms } // namespace tensorflow |