aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/hvx
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-15 17:03:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-15 17:07:28 -0700
commit351e1673beffa8583ff75046eb516893b9e5c79d (patch)
treece43351b12727dbeaae0ebfd9d91473bcf5e671d /tensorflow/contrib/hvx
parent6af52579e4b93d47ae6658d7d9d7144f76547290 (diff)
Generarize TF HVX runtime in order to benchmark models on HVX
PiperOrigin-RevId: 159174734
Diffstat (limited to 'tensorflow/contrib/hvx')
-rw-r--r--tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc41
1 files changed, 33 insertions, 8 deletions
diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc
index 6ae7c4a742..6af608396a 100644
--- a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc
+++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc
@@ -33,10 +33,15 @@ limitations under the License.
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
+
namespace {
-static int ParseFlags(int argc, char* argv[], string* in_graph) {
+static int ParseFlags(int argc, char* argv[], string* in_graph,
+ bool* dump_all_nodes, bool* dump_shape_and_type) {
std::vector<Flag> flag_list = {
- Flag("in_graph", in_graph, "input graph file name"),
+ Flag("in_graph", in_graph, "Input graph file name to check hvx support."),
+ Flag("dump_all_nodes", dump_all_nodes, "Dump all nodes in the model."),
+ Flag("dump_shape_and_type", dump_shape_and_type,
+ "Dump shape and type of nodes"),
};
CHECK(Flags::Parse(&argc, argv, flag_list));
// We need to call this to set up global state for TensorFlow.
@@ -48,12 +53,25 @@ static int ParseFlags(int argc, char* argv[], string* in_graph) {
return 0;
}
-static void SummarizeNode(const NodeDef& node_def) {
+static void SummarizeNode(const NodeDef& node_def,
+ const bool dump_shape_and_type) {
LOG(INFO) << "Node(" << node_def.name() << ")";
LOG(INFO) << " op: " << node_def.op();
for (const string& input : node_def.input()) {
LOG(INFO) << " Input: " << input;
}
+ std::vector<DataType> data_types;
+ std::vector<TensorShape> shapes;
+ const Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
+ node_def, &data_types, &shapes);
+ if (data_types.empty() || shapes.empty()) {
+ return;
+ }
+ CHECK_EQ(data_types.size(), shapes.size());
+ for (int i = 0; i < data_types.size(); ++i) {
+ LOG(INFO) << " Output(" << i << "): " << DataType_Name(data_types.at(i))
+ << ", " << shapes.at(i).DebugString();
+ }
}
static void DumpRemoteFusedGraph(const NodeDef& node_def) {
@@ -89,10 +107,14 @@ static void DumpRemoteFusedGraph(const NodeDef& node_def) {
}
}
-static void CheckOpsSupport(const GraphDef& graph_def) {
+static void CheckOpsSupport(const GraphDef& graph_def,
+ const bool dump_all_nodes,
+ const bool dump_shape_and_type) {
const IGraphTransferOpsDefinitions& ops_definition =
HexagonOpsDefinitions::getInstance();
LOG(INFO) << "Checking " << graph_def.node_size() << " nodes";
+ LOG(INFO) << "dump_all_nodes = " << dump_all_nodes
+ << ", dump_shape_and_tpye = " << dump_shape_and_type;
std::unordered_set<string> unsupported_ops;
bool all_supported = true;
@@ -125,9 +147,9 @@ static void CheckOpsSupport(const GraphDef& graph_def) {
LOG(INFO) << count << " ops are not supported.";
}
- if (contains_remote_graph) {
+ if (contains_remote_graph || dump_all_nodes) {
for (const NodeDef& node : graph_def.node()) {
- SummarizeNode(node);
+ SummarizeNode(node, dump_shape_and_type);
}
}
}
@@ -137,7 +159,10 @@ static void CheckOpsSupport(const GraphDef& graph_def) {
int main(int argc, char** argv) {
tensorflow::string in_graph;
- const int ret = tensorflow::ParseFlags(argc, argv, &in_graph);
+ bool dump_all_nodes;
+ bool dump_shape_and_type;
+ const int ret = tensorflow::ParseFlags(argc, argv, &in_graph, &dump_all_nodes,
+ &dump_shape_and_type);
if (ret != 0) {
return ret;
}
@@ -146,6 +171,6 @@ int main(int argc, char** argv) {
TF_CHECK_OK(tensorflow::graph_transforms::LoadTextOrBinaryGraphFile(
in_graph, &graph_def));
- tensorflow::CheckOpsSupport(graph_def);
+ tensorflow::CheckOpsSupport(graph_def, dump_all_nodes, dump_shape_and_type);
return 0;
}