diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-15 17:03:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-15 17:07:28 -0700 |
commit | 351e1673beffa8583ff75046eb516893b9e5c79d (patch) | |
tree | ce43351b12727dbeaae0ebfd9d91473bcf5e671d /tensorflow/contrib/hvx | |
parent | 6af52579e4b93d47ae6658d7d9d7144f76547290 (diff) |
Generarize TF HVX runtime in order to benchmark models on HVX
PiperOrigin-RevId: 159174734
Diffstat (limited to 'tensorflow/contrib/hvx')
-rw-r--r-- | tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc | 41 |
1 files changed, 33 insertions, 8 deletions
diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc index 6ae7c4a742..6af608396a 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc @@ -33,10 +33,15 @@ limitations under the License. #include "tensorflow/tools/graph_transforms/transform_utils.h" namespace tensorflow { + namespace { -static int ParseFlags(int argc, char* argv[], string* in_graph) { +static int ParseFlags(int argc, char* argv[], string* in_graph, + bool* dump_all_nodes, bool* dump_shape_and_type) { std::vector<Flag> flag_list = { - Flag("in_graph", in_graph, "input graph file name"), + Flag("in_graph", in_graph, "Input graph file name to check hvx support."), + Flag("dump_all_nodes", dump_all_nodes, "Dump all nodes in the model."), + Flag("dump_shape_and_type", dump_shape_and_type, + "Dump shape and type of nodes"), }; CHECK(Flags::Parse(&argc, argv, flag_list)); // We need to call this to set up global state for TensorFlow. @@ -48,12 +53,25 @@ static int ParseFlags(int argc, char* argv[], string* in_graph) { return 0; } -static void SummarizeNode(const NodeDef& node_def) { +static void SummarizeNode(const NodeDef& node_def, + const bool dump_shape_and_type) { LOG(INFO) << "Node(" << node_def.name() << ")"; LOG(INFO) << " op: " << node_def.op(); for (const string& input : node_def.input()) { LOG(INFO) << " Input: " << input; } + std::vector<DataType> data_types; + std::vector<TensorShape> shapes; + const Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( + node_def, &data_types, &shapes); + if (data_types.empty() || shapes.empty()) { + return; + } + CHECK_EQ(data_types.size(), shapes.size()); + for (int i = 0; i < data_types.size(); ++i) { + LOG(INFO) << " Output(" << i << "): " << DataType_Name(data_types.at(i)) + << ", " << shapes.at(i).DebugString(); + } } static void DumpRemoteFusedGraph(const NodeDef& node_def) { @@ -89,10 +107,14 @@ static void DumpRemoteFusedGraph(const NodeDef& node_def) { } } -static void CheckOpsSupport(const GraphDef& graph_def) { +static void CheckOpsSupport(const GraphDef& graph_def, + const bool dump_all_nodes, + const bool dump_shape_and_type) { const IGraphTransferOpsDefinitions& ops_definition = HexagonOpsDefinitions::getInstance(); LOG(INFO) << "Checking " << graph_def.node_size() << " nodes"; + LOG(INFO) << "dump_all_nodes = " << dump_all_nodes + << ", dump_shape_and_tpye = " << dump_shape_and_type; std::unordered_set<string> unsupported_ops; bool all_supported = true; @@ -125,9 +147,9 @@ static void CheckOpsSupport(const GraphDef& graph_def) { LOG(INFO) << count << " ops are not supported."; } - if (contains_remote_graph) { + if (contains_remote_graph || dump_all_nodes) { for (const NodeDef& node : graph_def.node()) { - SummarizeNode(node); + SummarizeNode(node, dump_shape_and_type); } } } @@ -137,7 +159,10 @@ static void CheckOpsSupport(const GraphDef& graph_def) { int main(int argc, char** argv) { tensorflow::string in_graph; - const int ret = tensorflow::ParseFlags(argc, argv, &in_graph); + bool dump_all_nodes; + bool dump_shape_and_type; + const int ret = tensorflow::ParseFlags(argc, argv, &in_graph, &dump_all_nodes, + &dump_shape_and_type); if (ret != 0) { return ret; } @@ -146,6 +171,6 @@ int main(int argc, char** argv) { TF_CHECK_OK(tensorflow::graph_transforms::LoadTextOrBinaryGraphFile( in_graph, &graph_def)); - tensorflow::CheckOpsSupport(graph_def); + tensorflow::CheckOpsSupport(graph_def, dump_all_nodes, dump_shape_and_type); return 0; } |