diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-05-18 10:38:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-18 10:43:12 -0700 |
commit | 9ad851e54d014532dd3b3c8308396769f9a7aeee (patch) | |
tree | ff777c97939a6f4d9de6ebd7226b19e9ffb2ff2f /tensorflow/contrib/hvx | |
parent | 7916e22e954fc893e673f74b4088b9e9c3a9be97 (diff) |
Add graph transform rewriter for remote fused graph
PiperOrigin-RevId: 156448934
Diffstat (limited to 'tensorflow/contrib/hvx')
-rw-r--r-- | tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc | 55 |
2 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD index 922996a686..fa75943d78 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD @@ -29,6 +29,7 @@ cc_binary( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", + "//tensorflow/core/kernels:remote_fused_graph_execute_utils", "//tensorflow/core/kernels/hexagon:graph_transferer", "//tensorflow/tools/graph_transforms:transform_utils", ], diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc index 3a219bb3e6..6ae7c4a742 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc @@ -22,7 +22,9 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" #include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h" +#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" @@ -46,6 +48,47 @@ static int ParseFlags(int argc, char* argv[], string* in_graph) { return 0; } +static void SummarizeNode(const NodeDef& node_def) { + LOG(INFO) << "Node(" << node_def.name() << ")"; + LOG(INFO) << " op: " << node_def.op(); + for (const string& input : node_def.input()) { + LOG(INFO) << " Input: " << input; + } +} + +static void DumpRemoteFusedGraph(const NodeDef& node_def) { + LOG(INFO) << "Remote fused graph found."; + RemoteFusedGraphExecuteInfo info; + string serialized_proto; + GetNodeAttr(node_def, + RemoteFusedGraphExecuteUtils:: + ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO, + &serialized_proto) + .IgnoreError(); + info.ParseFromString(serialized_proto); + LOG(INFO) << "Node name: " << node_def.name(); + LOG(INFO) << "Executor name: " << info.executor_name(); + for (const string& input : info.graph_input_node_name()) { + LOG(INFO) << "Input: " << input; + } + for (const RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_type : + info.default_graph_input_tensor_shape()) { + LOG(INFO) << "Input shape type: " << shape_type.DebugString(); + } + for (const string& output : info.graph_output_node_name()) { + LOG(INFO) << "Output: " << output; + } + for (const RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_type : + info.default_graph_output_tensor_shape()) { + LOG(INFO) << "Output shape type: " << shape_type.DebugString(); + } + const int subgraph_node_size = info.remote_graph().node_size(); + LOG(INFO) << "Nodes in the graph: " << subgraph_node_size; + for (int i = 0; i < subgraph_node_size; ++i) { + LOG(INFO) << "node(" << i << "): " << info.remote_graph().node(i).name(); + } +} + static void CheckOpsSupport(const GraphDef& graph_def) { const IGraphTransferOpsDefinitions& ops_definition = HexagonOpsDefinitions::getInstance(); @@ -53,7 +96,13 @@ static void CheckOpsSupport(const GraphDef& graph_def) { std::unordered_set<string> unsupported_ops; bool all_supported = true; + bool contains_remote_graph = false; for (const NodeDef& node : graph_def.node()) { + if (node.op() == "RemoteFusedGraphExecute") { + contains_remote_graph = true; + DumpRemoteFusedGraph(node); + continue; + } // TODO(satok): Set correct data type if it's given. const int op_id = ops_definition.GetOpIdFor(node.op(), {}); if (op_id == IGraphTransferOpsDefinitions::INVALID_OP_ID) { @@ -75,6 +124,12 @@ static void CheckOpsSupport(const GraphDef& graph_def) { } else { LOG(INFO) << count << " ops are not supported."; } + + if (contains_remote_graph) { + for (const NodeDef& node : graph_def.node()) { + SummarizeNode(node); + } + } } } // namespace |