diff options
author | 2017-02-02 14:44:30 -0800 | |
---|---|---|
committer | 2017-02-02 15:04:42 -0800 | |
commit | b44959b879661204c500763ea7f1f66b39c2bc88 (patch) | |
tree | 5bec34f1deafa9e65de9e13578cfa6ba83593afa | |
parent | 4eb0164dbcf690b9e33160004198905e96b3f049 (diff) |
Add proto for graph transfer to serialize soc node
Change: 146409586
-rw-r--r-- | tensorflow/contrib/cmake/tf_core_framework.cmake | 1 | ||||
-rw-r--r-- | tensorflow/contrib/makefile/proto_text_pb_cc_files.txt | 1 | ||||
-rw-r--r-- | tensorflow/contrib/makefile/proto_text_pb_h_files.txt | 1 | ||||
-rw-r--r-- | tensorflow/contrib/makefile/tf_pb_text_files.txt | 1 | ||||
-rw-r--r-- | tensorflow/contrib/makefile/tf_proto_files.txt | 1 | ||||
-rw-r--r-- | tensorflow/core/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/framework/graph_transfer_info.proto | 45 | ||||
-rw-r--r-- | tensorflow/core/kernels/hexagon/graph_transferer.cc | 84 | ||||
-rw-r--r-- | tensorflow/core/kernels/hexagon/graph_transferer.h | 29 | ||||
-rw-r--r-- | tensorflow/core/kernels/hexagon/graph_transferer_test.cc | 60 | ||||
-rw-r--r-- | tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc | 16 |
11 files changed, 147 insertions, 93 deletions
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 2dd00c5574..a60e4e74d1 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -88,6 +88,7 @@ set(tf_proto_text_srcs "tensorflow/core/framework/device_attributes.proto" "tensorflow/core/framework/function.proto" "tensorflow/core/framework/graph.proto" + "tensorflow/core/framework/graph_transfer_info.proto" "tensorflow/core/framework/kernel_def.proto" "tensorflow/core/framework/log_memory.proto" "tensorflow/core/framework/node_def.proto" diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt index 395b8bde60..c17b26eccc 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt @@ -24,6 +24,7 @@ tensorflow/core/framework/op_def.pb.cc tensorflow/core/framework/node_def.pb.cc tensorflow/core/framework/log_memory.pb.cc tensorflow/core/framework/kernel_def.pb.cc +tensorflow/core/framework/graph_transfer_info.pb.cc tensorflow/core/framework/graph.pb.cc tensorflow/core/framework/function.pb.cc tensorflow/core/framework/device_attributes.pb.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt index 4bd371f4dc..f389742186 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt @@ -25,6 +25,7 @@ tensorflow/core/framework/op_def.pb.h tensorflow/core/framework/node_def.pb.h tensorflow/core/framework/log_memory.pb.h tensorflow/core/framework/kernel_def.pb.h +tensorflow/core/framework/graph_transfer_info.pb.h tensorflow/core/framework/graph.pb.h tensorflow/core/framework/function.pb.h tensorflow/core/framework/device_attributes.pb.h diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt index f1657793b2..d4b34e7146 100644 --- a/tensorflow/contrib/makefile/tf_pb_text_files.txt +++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt @@ -18,6 +18,7 @@ tensorflow/core/framework/op_def.pb_text.cc tensorflow/core/framework/node_def.pb_text.cc tensorflow/core/framework/log_memory.pb_text.cc tensorflow/core/framework/kernel_def.pb_text.cc +tensorflow/core/framework/graph_transfer_info.pb_text.cc tensorflow/core/framework/graph.pb_text.cc tensorflow/core/framework/function.pb_text.cc tensorflow/core/framework/device_attributes.pb_text.cc diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt index f60e7f2360..27d16ac144 100644 --- a/tensorflow/contrib/makefile/tf_proto_files.txt +++ b/tensorflow/contrib/makefile/tf_proto_files.txt @@ -25,6 +25,7 @@ tensorflow/core/framework/op_def.proto tensorflow/core/framework/node_def.proto tensorflow/core/framework/log_memory.proto tensorflow/core/framework/kernel_def.proto +tensorflow/core/framework/graph_transfer_info.proto tensorflow/core/framework/graph.proto tensorflow/core/framework/function.proto tensorflow/core/framework/device_attributes.proto diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 19ac04b9e6..780a97af18 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -128,6 +128,7 @@ CORE_PROTO_SRCS = [ "framework/device_attributes.proto", "framework/function.proto", "framework/graph.proto", + "framework/graph_transfer_info.proto", "framework/kernel_def.proto", "framework/log_memory.proto", "framework/node_def.proto", diff --git a/tensorflow/core/framework/graph_transfer_info.proto b/tensorflow/core/framework/graph_transfer_info.proto new file mode 100644 index 0000000000..1a1ce6d97a --- /dev/null +++ b/tensorflow/core/framework/graph_transfer_info.proto @@ -0,0 +1,45 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "GraphTransferInfoProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +message GraphTransferInfo { + message NodeInput { + int32 node_id = 1; + int32 output_port = 2; + } + message NodeInfo { + string name = 1; + int32 node_id = 2; + string type_name = 3; + int32 soc_op_id = 4; + int32 padding_id = 5; + int32 input_count = 6; + int32 output_count = 7; + }; + message ConstNodeInfo { + string name = 1; + int32 node_id = 2; + repeated int64 shape = 3; + int32 data_size = 4; + bytes data = 5; + }; + message NodeInputInfo { + int32 node_id = 1; + repeated NodeInput node_input = 2; + }; + message NodeOutputInfo { + int32 node_id = 1; + repeated int32 max_byte_size = 2; + }; + repeated NodeInfo node_info = 1; + repeated ConstNodeInfo const_node_info = 2; + repeated NodeInputInfo node_input_info = 3; + repeated NodeOutputInfo node_output_info = 4; +}; diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc index 662b935b90..b597d67b6f 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc @@ -279,17 +279,18 @@ void GraphTransferer::SortParams(const std::vector<string>& output_node_names) { // Setup dependency map placeholder std::vector<int> output_node_ids; std::unordered_map<int, std::unordered_set<int>> dependency_map; - for (const NodeTransferParams& params : node_transfer_params_list_) { - const int node_id = params.node_id; + for (const GraphTransferInfo::NodeInfo& params : + graph_transfer_info_.node_info()) { + const int node_id = params.node_id(); for (const string& output_node_name : output_node_names) { - if (params.name == output_node_name) { + if (params.name() == output_node_name) { output_node_ids.emplace_back(node_id); } } dependency_map.emplace(std::piecewise_construct, std::make_tuple(node_id), std::make_tuple()); - if (params.inputs_size == 0) { + if (params.input_count() == 0) { continue; } CHECK(input_map.count(node_id) == 1); @@ -305,8 +306,8 @@ void GraphTransferer::SortParams(const std::vector<string>& output_node_names) { FillDependencyRec(output_node_id, dependency_map, completed); } - std::sort(node_transfer_params_list_.begin(), - node_transfer_params_list_.end(), + std::sort(graph_transfer_info_.mutable_node_info()->begin(), + graph_transfer_info_.mutable_node_info()->end(), TransferParamsComparator(dependency_map)); } @@ -319,11 +320,6 @@ GraphTransferer::GetConstNodeParams() const { return const_node_transfer_params_list_; } -const std::vector<GraphTransferer::NodeTransferParams>& -GraphTransferer::GetOpNodeParams() const { - return node_transfer_params_list_; -} - const std::vector<GraphTransferer::NodeInputParams>& GraphTransferer::GetNodeInputParams() const { return node_input_params_list_; @@ -334,6 +330,10 @@ GraphTransferer::GetNodeOutputParams() const { return node_output_params_list_; } +const GraphTransferInfo& GraphTransferer::GetGraphTransferInfo() const { + return graph_transfer_info_; +} + int GraphTransferer::CacheNode(const Node& node) { if (node_name_to_id_cache_map_.count(node.name()) > 0) { VLOG(1) << "Emplace node to cache failed"; @@ -643,10 +643,16 @@ void GraphTransferer::AppendNodeParams(const string& name, const int id, const std::vector<int>& extra_inputs, const int outputs_size) { VLOG(1) << "Append node params: " << name; - node_transfer_params_list_.emplace_back( - NodeTransferParams{name, id, type, type_id, padding, - inputs_size + static_cast<int>(extra_inputs.size()), - static_cast<int>(outputs_size)}); + GraphTransferInfo::NodeInfo& node_info = + *graph_transfer_info_.add_node_info(); + node_info.set_name(name); + node_info.set_node_id(id); + node_info.set_type_name(type); + node_info.set_soc_op_id(type_id); + node_info.set_padding_id(padding); + node_info.set_input_count(inputs_size + + static_cast<int>(extra_inputs.size())); + node_info.set_output_count(static_cast<int>(outputs_size)); } void GraphTransferer::AppendNodeInputParams( @@ -833,10 +839,10 @@ GraphTransferer::TransferParamsComparator::TransferParamsComparator( : dependency_map_(dep_map) {} bool GraphTransferer::TransferParamsComparator::operator()( - const GraphTransferer::NodeTransferParams& obj0, - const GraphTransferer::NodeTransferParams& obj1) { - const int node_id0 = obj0.node_id; - const int node_id1 = obj1.node_id; + const GraphTransferInfo::NodeInfo& obj0, + const GraphTransferInfo::NodeInfo& obj1) { + const int node_id0 = obj0.node_id(); + const int node_id1 = obj1.node_id(); bool obj0_uses_obj1 = false; if (dependency_map_.count(node_id0)) { obj0_uses_obj1 = dependency_map_.at(node_id0).count(node_id1) > 0; @@ -916,17 +922,18 @@ void GraphTransferer::DumpNodeTransferParams() const { } LOG(INFO) << "******\n"; LOG(INFO) << "*** Op Nodes ***"; - for (const NodeTransferParams& params : node_transfer_params_list_) { - LOG(INFO) << "[ " << params.node_id << " \"" << params.name; - LOG(INFO) << " type: " << params.type_name; - LOG(INFO) << " padding: " << ToPaddingDebugString(params.padding); - LOG(INFO) << " inputs: " << INPUTS_NODE_PREFIX + ToString(params.node_id) - << ", size = " << params.inputs_size; + for (const GraphTransferInfo::NodeInfo& params : + graph_transfer_info_.node_info()) { + LOG(INFO) << "[ " << params.node_id() << " \"" << params.name(); + LOG(INFO) << " type: " << params.type_name(); + LOG(INFO) << " padding: " << ToPaddingDebugString(params.padding_id()); + LOG(INFO) << " inputs: " << INPUTS_NODE_PREFIX + ToString(params.node_id()) + << ", size = " << params.input_count(); LOG(INFO) << " outputs: " - << (params.outputs_size <= 0 + << (params.output_count() <= 0 ? NULL_OUTPUT_NAME - : (OUTPUTS_NODE_PREFIX + ToString(params.node_id))) - << ", size = " << params.outputs_size << " ]"; + : (OUTPUTS_NODE_PREFIX + ToString(params.node_id()))) + << ", size = " << params.output_count() << " ]"; } LOG(INFO) << "******\n"; LOG(INFO) << "*** Node input params ***"; @@ -963,20 +970,21 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const { LOG(INFO) << sstream.str(); } LOG(INFO) << "Const node count = " << const_node_transfer_params_list_.size(); - for (const NodeTransferParams& params : node_transfer_params_list_) { + for (const GraphTransferInfo::NodeInfo& params : + graph_transfer_info_.node_info()) { std::stringstream sstream; - sstream << "---(OP) [" << params.name.c_str() << "," << std::hex - << params.node_id << std::dec << "," << params.soc_op_id << "," - << ToPaddingDebugString(params.padding) << "," - << INPUTS_NODE_PREFIX + ToString(params.node_id) << "," - << params.inputs_size << "," - << (params.outputs_size <= 0 + sstream << "---(OP) [" << params.name().c_str() << "," << std::hex + << params.node_id() << std::dec << "," << params.soc_op_id() << "," + << ToPaddingDebugString(params.padding_id()) << "," + << INPUTS_NODE_PREFIX + ToString(params.node_id()) << "," + << params.input_count() << "," + << (params.output_count() <= 0 ? NULL_OUTPUT_NAME - : (OUTPUTS_NODE_PREFIX + ToString(params.node_id))) - << "," << params.outputs_size << "," << params.type_name << "]"; + : (OUTPUTS_NODE_PREFIX + ToString(params.node_id()))) + << "," << params.output_count() << "," << params.type_name() << "]"; LOG(INFO) << sstream.str(); } - LOG(INFO) << "Op node count = " << node_transfer_params_list_.size(); + LOG(INFO) << "Op node count = " << graph_transfer_info_.node_info_size(); for (const NodeInputParams& params : node_input_params_list_) { std::stringstream sstream; sstream << "---(INPUT) [" << std::hex << params.node_id << std::dec; diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h index d86452905f..f8abd08bd4 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.h +++ b/tensorflow/core/kernels/hexagon/graph_transferer.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_transfer_info.pb.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h" @@ -39,7 +40,9 @@ namespace tensorflow { // to avoid unsupported ops in SoC. class GraphTransferer { public: + // TODO(satok): Remove. Use proto definition instead. static constexpr int MAX_SUPPORTED_RANK = 5; + // TODO(satok): Remove. Use proto definition instead. static constexpr int SHAPE_ARRAY_SIZE = MAX_SUPPORTED_RANK - 1; using OutputTensorMap = std::unordered_map<string, Tensor*>; @@ -48,18 +51,8 @@ class GraphTransferer { Tensor tensor; }; - // Node parameters for transfer - struct NodeTransferParams { - string name; - int node_id; - string type_name; - int soc_op_id; - int padding; - int inputs_size; - int outputs_size; - }; - // Const node parameters for transfer + // TODO(satok): Remove. Use proto definition instead. struct ConstNodeTransferParams { string name; int node_id; @@ -69,12 +62,14 @@ class GraphTransferer { }; // Input parameters of a node for transfer + // TODO(satok): Remove. Use proto definition instead. struct NodeInputParams { int node_id; std::vector<std::tuple<int, int>> input_node_id_and_output_port_list; }; // Output parameters of a node for transfer + // TODO(satok): Remove. Use proto definition instead. struct NodeOutputParams { int node_id; std::vector<int> max_sizes; @@ -130,22 +125,22 @@ class GraphTransferer { // Return const node parameters for transfer const std::vector<ConstNodeTransferParams>& GetConstNodeParams() const; - // Return op node parameters for transfer - const std::vector<NodeTransferParams>& GetOpNodeParams() const; - // Return input params of nodes const std::vector<NodeInputParams>& GetNodeInputParams() const; // Return output params of nodes const std::vector<NodeOutputParams>& GetNodeOutputParams() const; + // Return parameters for transfer + const GraphTransferInfo& GetGraphTransferInfo() const; + private: class TransferParamsComparator { public: TransferParamsComparator( const std::unordered_map<int, std::unordered_set<int>>& dep_map); - bool operator()(const GraphTransferer::NodeTransferParams& obj0, - const GraphTransferer::NodeTransferParams& obj1); + bool operator()(const GraphTransferInfo::NodeInfo& obj0, + const GraphTransferInfo::NodeInfo& obj1); const std::unordered_map<int, std::unordered_set<int>>& dependency_map_; }; @@ -258,7 +253,7 @@ class GraphTransferer { // Dump verification string of parameters to verify with offline tools void DumpVerificationStringOfNodeTransferParams() const; - std::vector<NodeTransferParams> node_transfer_params_list_; + GraphTransferInfo graph_transfer_info_; std::vector<ConstNodeTransferParams> const_node_transfer_params_list_; std::vector<NodeInputParams> node_input_params_list_; std::vector<NodeOutputParams> node_output_params_list_; diff --git a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc index 92b58083b9..6b0e57ed97 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/nn_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/graph_transfer_info.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/kernels/hexagon/graph_transferer.h" @@ -128,11 +129,11 @@ static const GraphTransferer::ConstNodeTransferParams* FindConstNodeParams( return nullptr; } -static const GraphTransferer::NodeTransferParams* FindOpNodeParams( +static const GraphTransferInfo::NodeInfo* FindNodeInfo( const GraphTransferer& gt, const string& name) { - for (const GraphTransferer::NodeTransferParams& params : - gt.GetOpNodeParams()) { - if (params.name == name) { + for (const GraphTransferInfo::NodeInfo& params : + gt.GetGraphTransferInfo().node_info()) { + if (params.name() == name) { return ¶ms; } } @@ -162,26 +163,26 @@ static const GraphTransferer::NodeOutputParams* FindNodeOutputParams( } static void SanityCheckNodes(const GraphTransferer& gt) { - for (const GraphTransferer::NodeTransferParams& params : - gt.GetOpNodeParams()) { - if (params.inputs_size > 0) { + for (const GraphTransferInfo::NodeInfo& params : + gt.GetGraphTransferInfo().node_info()) { + if (params.input_count() > 0) { const GraphTransferer::NodeInputParams* input_params = - FindNodeInputParams(gt, params.node_id); + FindNodeInputParams(gt, params.node_id()); ASSERT_NE(nullptr, input_params); - EXPECT_EQ(params.inputs_size, + EXPECT_EQ(params.input_count(), input_params->input_node_id_and_output_port_list.size()); - EXPECT_EQ(params.node_id, input_params->node_id); + EXPECT_EQ(params.node_id(), input_params->node_id); for (const std::tuple<int, int>& pair : input_params->input_node_id_and_output_port_list) { EXPECT_GE(std::get<1>(pair), 0); } } - if (params.outputs_size > 0) { + if (params.output_count() > 0) { const GraphTransferer::NodeOutputParams* output_params = - FindNodeOutputParams(gt, params.node_id); + FindNodeOutputParams(gt, params.node_id()); ASSERT_NE(nullptr, output_params); - EXPECT_EQ(params.outputs_size, output_params->max_sizes.size()); - EXPECT_EQ(params.node_id, output_params->node_id); + EXPECT_EQ(params.output_count(), output_params->max_sizes.size()); + EXPECT_EQ(params.node_id(), output_params->node_id); for (const int max_size : output_params->max_sizes) { EXPECT_GE(max_size, 0); } @@ -344,17 +345,16 @@ TEST_F(GraphTransfererTest, LoadConvGraph) { SanityCheckNodes(gt_); const int const_node_count = gt_.GetConstNodeParams().size(); ASSERT_EQ(2, const_node_count); - const int op_node_count = gt_.GetOpNodeParams().size(); + const int op_node_count = gt_.GetGraphTransferInfo().node_info_size(); ASSERT_EQ(3, op_node_count); - const GraphTransferer::NodeTransferParams* params_conv = - FindOpNodeParams(gt_, "conv"); + const GraphTransferInfo::NodeInfo* params_conv = FindNodeInfo(gt_, "conv"); ASSERT_TRUE(params_conv != nullptr); - const int id = params_conv->node_id; + const int id = params_conv->node_id(); EXPECT_GE(id, 0); - EXPECT_EQ("Conv2D", params_conv->type_name); - EXPECT_EQ(3, params_conv->inputs_size); - EXPECT_EQ(1, params_conv->outputs_size); - EXPECT_EQ(Padding::SAME, params_conv->padding); + EXPECT_EQ("Conv2D", params_conv->type_name()); + EXPECT_EQ(3, params_conv->input_count()); + EXPECT_EQ(1, params_conv->output_count()); + EXPECT_EQ(Padding::SAME, params_conv->padding_id()); } TEST_F(GraphTransfererTest, LoadMaxPoolGraph) { @@ -370,17 +370,17 @@ TEST_F(GraphTransfererTest, LoadMaxPoolGraph) { SanityCheckNodes(gt_); const int const_node_count = gt_.GetConstNodeParams().size(); ASSERT_EQ(2, const_node_count); - const int op_node_count = gt_.GetOpNodeParams().size(); + const int op_node_count = gt_.GetGraphTransferInfo().node_info_size(); ASSERT_EQ(3, op_node_count); - const GraphTransferer::NodeTransferParams* params_max_pool = - FindOpNodeParams(gt_, "maxpool"); + const GraphTransferInfo::NodeInfo* params_max_pool = + FindNodeInfo(gt_, "maxpool"); ASSERT_TRUE(params_max_pool != nullptr); - const int id = params_max_pool->node_id; + const int id = params_max_pool->node_id(); EXPECT_GE(id, 0); - EXPECT_EQ("MaxPool", params_max_pool->type_name); - EXPECT_EQ(3, params_max_pool->inputs_size); - EXPECT_EQ(1, params_max_pool->outputs_size); - EXPECT_EQ(Padding::SAME, params_max_pool->padding); + EXPECT_EQ("MaxPool", params_max_pool->type_name()); + EXPECT_EQ(3, params_max_pool->input_count()); + EXPECT_EQ(1, params_max_pool->output_count()); + EXPECT_EQ(Padding::SAME, params_max_pool->padding_id()); } TEST(HexagonOpsDefinitions, CheckOpsDefinitions) { diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc index ca29fcdd47..52ec1a4e9d 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc @@ -122,10 +122,10 @@ bool HexagonControlWrapper::SetupGraph( } // 2. Setup op nodes - for (const GraphTransferer::NodeTransferParams& params : - graph_transferer.GetOpNodeParams()) { - const int node_id = params.node_id; - const int op_id = params.soc_op_id; + for (const GraphTransferInfo::NodeInfo& params : + graph_transferer.GetGraphTransferInfo().node_info()) { + const int node_id = params.node_id(); + const int op_id = params.soc_op_id(); CHECK(inputs_map.count(node_id) == 1); CHECK(outputs_map.count(node_id) <= 1); // Only output node doesn't have output @@ -142,16 +142,16 @@ bool HexagonControlWrapper::SetupGraph( CHECK(output_count > 0); } int padding_id = -1; - if (params.padding == 0) { + if (params.padding_id() == 0) { padding_id = 0; - } else if (params.padding == Padding::SAME) { + } else if (params.padding_id() == Padding::SAME) { padding_id = 1; - } else if (params.padding == Padding::VALID) { + } else if (params.padding_id() == Padding::VALID) { padding_id = 2; } else { CHECK(false); } - soc_interface_AppendNode(params.name.c_str(), node_id + NODE_ID_OFFSET, + soc_interface_AppendNode(params.name().c_str(), node_id + NODE_ID_OFFSET, op_id, padding_id, input_ptr, input_count, output_ptr, output_count); } |