aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/hvx
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-18 10:38:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-18 10:43:12 -0700
commit9ad851e54d014532dd3b3c8308396769f9a7aeee (patch)
treeff777c97939a6f4d9de6ebd7226b19e9ffb2ff2f /tensorflow/contrib/hvx
parent7916e22e954fc893e673f74b4088b9e9c3a9be97 (diff)
Add graph transform rewriter for remote fused graph
PiperOrigin-RevId: 156448934
Diffstat (limited to 'tensorflow/contrib/hvx')
-rw-r--r--tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD1
-rw-r--r--tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc55
2 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD
index 922996a686..fa75943d78 100644
--- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD
+++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD
@@ -29,6 +29,7 @@ cc_binary(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
+ "//tensorflow/core/kernels:remote_fused_graph_execute_utils",
"//tensorflow/core/kernels/hexagon:graph_transferer",
"//tensorflow/tools/graph_transforms:transform_utils",
],
diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc
index 3a219bb3e6..6ae7c4a742 100644
--- a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc
+++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc
@@ -22,7 +22,9 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
+#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/init_main.h"
@@ -46,6 +48,47 @@ static int ParseFlags(int argc, char* argv[], string* in_graph) {
return 0;
}
+static void SummarizeNode(const NodeDef& node_def) {
+ LOG(INFO) << "Node(" << node_def.name() << ")";
+ LOG(INFO) << " op: " << node_def.op();
+ for (const string& input : node_def.input()) {
+ LOG(INFO) << " Input: " << input;
+ }
+}
+
+static void DumpRemoteFusedGraph(const NodeDef& node_def) {
+ LOG(INFO) << "Remote fused graph found.";
+ RemoteFusedGraphExecuteInfo info;
+ string serialized_proto;
+ GetNodeAttr(node_def,
+ RemoteFusedGraphExecuteUtils::
+ ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO,
+ &serialized_proto)
+ .IgnoreError();
+ info.ParseFromString(serialized_proto);
+ LOG(INFO) << "Node name: " << node_def.name();
+ LOG(INFO) << "Executor name: " << info.executor_name();
+ for (const string& input : info.graph_input_node_name()) {
+ LOG(INFO) << "Input: " << input;
+ }
+ for (const RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_type :
+ info.default_graph_input_tensor_shape()) {
+ LOG(INFO) << "Input shape type: " << shape_type.DebugString();
+ }
+ for (const string& output : info.graph_output_node_name()) {
+ LOG(INFO) << "Output: " << output;
+ }
+ for (const RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_type :
+ info.default_graph_output_tensor_shape()) {
+ LOG(INFO) << "Output shape type: " << shape_type.DebugString();
+ }
+ const int subgraph_node_size = info.remote_graph().node_size();
+ LOG(INFO) << "Nodes in the graph: " << subgraph_node_size;
+ for (int i = 0; i < subgraph_node_size; ++i) {
+ LOG(INFO) << "node(" << i << "): " << info.remote_graph().node(i).name();
+ }
+}
+
static void CheckOpsSupport(const GraphDef& graph_def) {
const IGraphTransferOpsDefinitions& ops_definition =
HexagonOpsDefinitions::getInstance();
@@ -53,7 +96,13 @@ static void CheckOpsSupport(const GraphDef& graph_def) {
std::unordered_set<string> unsupported_ops;
bool all_supported = true;
+ bool contains_remote_graph = false;
for (const NodeDef& node : graph_def.node()) {
+ if (node.op() == "RemoteFusedGraphExecute") {
+ contains_remote_graph = true;
+ DumpRemoteFusedGraph(node);
+ continue;
+ }
// TODO(satok): Set correct data type if it's given.
const int op_id = ops_definition.GetOpIdFor(node.op(), {});
if (op_id == IGraphTransferOpsDefinitions::INVALID_OP_ID) {
@@ -75,6 +124,12 @@ static void CheckOpsSupport(const GraphDef& graph_def) {
} else {
LOG(INFO) << count << " ops are not supported.";
}
+
+ if (contains_remote_graph) {
+ for (const NodeDef& node : graph_def.node()) {
+ SummarizeNode(node);
+ }
+ }
}
} // namespace