aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-02-02 14:44:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-02 15:04:42 -0800
commitb44959b879661204c500763ea7f1f66b39c2bc88 (patch)
tree5bec34f1deafa9e65de9e13578cfa6ba83593afa
parent4eb0164dbcf690b9e33160004198905e96b3f049 (diff)
Add proto for graph transfer to serialize soc node
Change: 146409586
-rw-r--r--tensorflow/contrib/cmake/tf_core_framework.cmake1
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_cc_files.txt1
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_h_files.txt1
-rw-r--r--tensorflow/contrib/makefile/tf_pb_text_files.txt1
-rw-r--r--tensorflow/contrib/makefile/tf_proto_files.txt1
-rw-r--r--tensorflow/core/BUILD1
-rw-r--r--tensorflow/core/framework/graph_transfer_info.proto45
-rw-r--r--tensorflow/core/kernels/hexagon/graph_transferer.cc84
-rw-r--r--tensorflow/core/kernels/hexagon/graph_transferer.h29
-rw-r--r--tensorflow/core/kernels/hexagon/graph_transferer_test.cc60
-rw-r--r--tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc16
11 files changed, 147 insertions, 93 deletions
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index 2dd00c5574..a60e4e74d1 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -88,6 +88,7 @@ set(tf_proto_text_srcs
"tensorflow/core/framework/device_attributes.proto"
"tensorflow/core/framework/function.proto"
"tensorflow/core/framework/graph.proto"
+ "tensorflow/core/framework/graph_transfer_info.proto"
"tensorflow/core/framework/kernel_def.proto"
"tensorflow/core/framework/log_memory.proto"
"tensorflow/core/framework/node_def.proto"
diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
index 395b8bde60..c17b26eccc 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
@@ -24,6 +24,7 @@ tensorflow/core/framework/op_def.pb.cc
tensorflow/core/framework/node_def.pb.cc
tensorflow/core/framework/log_memory.pb.cc
tensorflow/core/framework/kernel_def.pb.cc
+tensorflow/core/framework/graph_transfer_info.pb.cc
tensorflow/core/framework/graph.pb.cc
tensorflow/core/framework/function.pb.cc
tensorflow/core/framework/device_attributes.pb.cc
diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
index 4bd371f4dc..f389742186 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
@@ -25,6 +25,7 @@ tensorflow/core/framework/op_def.pb.h
tensorflow/core/framework/node_def.pb.h
tensorflow/core/framework/log_memory.pb.h
tensorflow/core/framework/kernel_def.pb.h
+tensorflow/core/framework/graph_transfer_info.pb.h
tensorflow/core/framework/graph.pb.h
tensorflow/core/framework/function.pb.h
tensorflow/core/framework/device_attributes.pb.h
diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt
index f1657793b2..d4b34e7146 100644
--- a/tensorflow/contrib/makefile/tf_pb_text_files.txt
+++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt
@@ -18,6 +18,7 @@ tensorflow/core/framework/op_def.pb_text.cc
tensorflow/core/framework/node_def.pb_text.cc
tensorflow/core/framework/log_memory.pb_text.cc
tensorflow/core/framework/kernel_def.pb_text.cc
+tensorflow/core/framework/graph_transfer_info.pb_text.cc
tensorflow/core/framework/graph.pb_text.cc
tensorflow/core/framework/function.pb_text.cc
tensorflow/core/framework/device_attributes.pb_text.cc
diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt
index f60e7f2360..27d16ac144 100644
--- a/tensorflow/contrib/makefile/tf_proto_files.txt
+++ b/tensorflow/contrib/makefile/tf_proto_files.txt
@@ -25,6 +25,7 @@ tensorflow/core/framework/op_def.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/log_memory.proto
tensorflow/core/framework/kernel_def.proto
+tensorflow/core/framework/graph_transfer_info.proto
tensorflow/core/framework/graph.proto
tensorflow/core/framework/function.proto
tensorflow/core/framework/device_attributes.proto
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 19ac04b9e6..780a97af18 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -128,6 +128,7 @@ CORE_PROTO_SRCS = [
"framework/device_attributes.proto",
"framework/function.proto",
"framework/graph.proto",
+ "framework/graph_transfer_info.proto",
"framework/kernel_def.proto",
"framework/log_memory.proto",
"framework/node_def.proto",
diff --git a/tensorflow/core/framework/graph_transfer_info.proto b/tensorflow/core/framework/graph_transfer_info.proto
new file mode 100644
index 0000000000..1a1ce6d97a
--- /dev/null
+++ b/tensorflow/core/framework/graph_transfer_info.proto
@@ -0,0 +1,45 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "GraphTransferInfoProto";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+// Protocol buffer representing a handle to a tensorflow resource. Handles are
+// not valid across executions, but can be serialized back and forth from within
+// a single run.
+message GraphTransferInfo {
+ message NodeInput {
+ int32 node_id = 1;
+ int32 output_port = 2;
+ }
+ message NodeInfo {
+ string name = 1;
+ int32 node_id = 2;
+ string type_name = 3;
+ int32 soc_op_id = 4;
+ int32 padding_id = 5;
+ int32 input_count = 6;
+ int32 output_count = 7;
+ };
+ message ConstNodeInfo {
+ string name = 1;
+ int32 node_id = 2;
+ repeated int64 shape = 3;
+ int32 data_size = 4;
+ bytes data = 5;
+ };
+ message NodeInputInfo {
+ int32 node_id = 1;
+ repeated NodeInput node_input = 2;
+ };
+ message NodeOutputInfo {
+ int32 node_id = 1;
+ repeated int32 max_byte_size = 2;
+ };
+ repeated NodeInfo node_info = 1;
+ repeated ConstNodeInfo const_node_info = 2;
+ repeated NodeInputInfo node_input_info = 3;
+ repeated NodeOutputInfo node_output_info = 4;
+};
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc
index 662b935b90..b597d67b6f 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.cc
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc
@@ -279,17 +279,18 @@ void GraphTransferer::SortParams(const std::vector<string>& output_node_names) {
// Setup dependency map placeholder
std::vector<int> output_node_ids;
std::unordered_map<int, std::unordered_set<int>> dependency_map;
- for (const NodeTransferParams& params : node_transfer_params_list_) {
- const int node_id = params.node_id;
+ for (const GraphTransferInfo::NodeInfo& params :
+ graph_transfer_info_.node_info()) {
+ const int node_id = params.node_id();
for (const string& output_node_name : output_node_names) {
- if (params.name == output_node_name) {
+ if (params.name() == output_node_name) {
output_node_ids.emplace_back(node_id);
}
}
dependency_map.emplace(std::piecewise_construct, std::make_tuple(node_id),
std::make_tuple());
- if (params.inputs_size == 0) {
+ if (params.input_count() == 0) {
continue;
}
CHECK(input_map.count(node_id) == 1);
@@ -305,8 +306,8 @@ void GraphTransferer::SortParams(const std::vector<string>& output_node_names) {
FillDependencyRec(output_node_id, dependency_map, completed);
}
- std::sort(node_transfer_params_list_.begin(),
- node_transfer_params_list_.end(),
+ std::sort(graph_transfer_info_.mutable_node_info()->begin(),
+ graph_transfer_info_.mutable_node_info()->end(),
TransferParamsComparator(dependency_map));
}
@@ -319,11 +320,6 @@ GraphTransferer::GetConstNodeParams() const {
return const_node_transfer_params_list_;
}
-const std::vector<GraphTransferer::NodeTransferParams>&
-GraphTransferer::GetOpNodeParams() const {
- return node_transfer_params_list_;
-}
-
const std::vector<GraphTransferer::NodeInputParams>&
GraphTransferer::GetNodeInputParams() const {
return node_input_params_list_;
@@ -334,6 +330,10 @@ GraphTransferer::GetNodeOutputParams() const {
return node_output_params_list_;
}
+const GraphTransferInfo& GraphTransferer::GetGraphTransferInfo() const {
+ return graph_transfer_info_;
+}
+
int GraphTransferer::CacheNode(const Node& node) {
if (node_name_to_id_cache_map_.count(node.name()) > 0) {
VLOG(1) << "Emplace node to cache failed";
@@ -643,10 +643,16 @@ void GraphTransferer::AppendNodeParams(const string& name, const int id,
const std::vector<int>& extra_inputs,
const int outputs_size) {
VLOG(1) << "Append node params: " << name;
- node_transfer_params_list_.emplace_back(
- NodeTransferParams{name, id, type, type_id, padding,
- inputs_size + static_cast<int>(extra_inputs.size()),
- static_cast<int>(outputs_size)});
+ GraphTransferInfo::NodeInfo& node_info =
+ *graph_transfer_info_.add_node_info();
+ node_info.set_name(name);
+ node_info.set_node_id(id);
+ node_info.set_type_name(type);
+ node_info.set_soc_op_id(type_id);
+ node_info.set_padding_id(padding);
+ node_info.set_input_count(inputs_size +
+ static_cast<int>(extra_inputs.size()));
+ node_info.set_output_count(static_cast<int>(outputs_size));
}
void GraphTransferer::AppendNodeInputParams(
@@ -833,10 +839,10 @@ GraphTransferer::TransferParamsComparator::TransferParamsComparator(
: dependency_map_(dep_map) {}
bool GraphTransferer::TransferParamsComparator::operator()(
- const GraphTransferer::NodeTransferParams& obj0,
- const GraphTransferer::NodeTransferParams& obj1) {
- const int node_id0 = obj0.node_id;
- const int node_id1 = obj1.node_id;
+ const GraphTransferInfo::NodeInfo& obj0,
+ const GraphTransferInfo::NodeInfo& obj1) {
+ const int node_id0 = obj0.node_id();
+ const int node_id1 = obj1.node_id();
bool obj0_uses_obj1 = false;
if (dependency_map_.count(node_id0)) {
obj0_uses_obj1 = dependency_map_.at(node_id0).count(node_id1) > 0;
@@ -916,17 +922,18 @@ void GraphTransferer::DumpNodeTransferParams() const {
}
LOG(INFO) << "******\n";
LOG(INFO) << "*** Op Nodes ***";
- for (const NodeTransferParams& params : node_transfer_params_list_) {
- LOG(INFO) << "[ " << params.node_id << " \"" << params.name;
- LOG(INFO) << " type: " << params.type_name;
- LOG(INFO) << " padding: " << ToPaddingDebugString(params.padding);
- LOG(INFO) << " inputs: " << INPUTS_NODE_PREFIX + ToString(params.node_id)
- << ", size = " << params.inputs_size;
+ for (const GraphTransferInfo::NodeInfo& params :
+ graph_transfer_info_.node_info()) {
+ LOG(INFO) << "[ " << params.node_id() << " \"" << params.name();
+ LOG(INFO) << " type: " << params.type_name();
+ LOG(INFO) << " padding: " << ToPaddingDebugString(params.padding_id());
+ LOG(INFO) << " inputs: " << INPUTS_NODE_PREFIX + ToString(params.node_id())
+ << ", size = " << params.input_count();
LOG(INFO) << " outputs: "
- << (params.outputs_size <= 0
+ << (params.output_count() <= 0
? NULL_OUTPUT_NAME
- : (OUTPUTS_NODE_PREFIX + ToString(params.node_id)))
- << ", size = " << params.outputs_size << " ]";
+ : (OUTPUTS_NODE_PREFIX + ToString(params.node_id())))
+ << ", size = " << params.output_count() << " ]";
}
LOG(INFO) << "******\n";
LOG(INFO) << "*** Node input params ***";
@@ -963,20 +970,21 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
LOG(INFO) << sstream.str();
}
LOG(INFO) << "Const node count = " << const_node_transfer_params_list_.size();
- for (const NodeTransferParams& params : node_transfer_params_list_) {
+ for (const GraphTransferInfo::NodeInfo& params :
+ graph_transfer_info_.node_info()) {
std::stringstream sstream;
- sstream << "---(OP) [" << params.name.c_str() << "," << std::hex
- << params.node_id << std::dec << "," << params.soc_op_id << ","
- << ToPaddingDebugString(params.padding) << ","
- << INPUTS_NODE_PREFIX + ToString(params.node_id) << ","
- << params.inputs_size << ","
- << (params.outputs_size <= 0
+ sstream << "---(OP) [" << params.name().c_str() << "," << std::hex
+ << params.node_id() << std::dec << "," << params.soc_op_id() << ","
+ << ToPaddingDebugString(params.padding_id()) << ","
+ << INPUTS_NODE_PREFIX + ToString(params.node_id()) << ","
+ << params.input_count() << ","
+ << (params.output_count() <= 0
? NULL_OUTPUT_NAME
- : (OUTPUTS_NODE_PREFIX + ToString(params.node_id)))
- << "," << params.outputs_size << "," << params.type_name << "]";
+ : (OUTPUTS_NODE_PREFIX + ToString(params.node_id())))
+ << "," << params.output_count() << "," << params.type_name() << "]";
LOG(INFO) << sstream.str();
}
- LOG(INFO) << "Op node count = " << node_transfer_params_list_.size();
+ LOG(INFO) << "Op node count = " << graph_transfer_info_.node_info_size();
for (const NodeInputParams& params : node_input_params_list_) {
std::stringstream sstream;
sstream << "---(INPUT) [" << std::hex << params.node_id << std::dec;
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h
index d86452905f..f8abd08bd4 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.h
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h"
@@ -39,7 +40,9 @@ namespace tensorflow {
// to avoid unsupported ops in SoC.
class GraphTransferer {
public:
+ // TODO(satok): Remove. Use proto definition instead.
static constexpr int MAX_SUPPORTED_RANK = 5;
+ // TODO(satok): Remove. Use proto definition instead.
static constexpr int SHAPE_ARRAY_SIZE = MAX_SUPPORTED_RANK - 1;
using OutputTensorMap = std::unordered_map<string, Tensor*>;
@@ -48,18 +51,8 @@ class GraphTransferer {
Tensor tensor;
};
- // Node parameters for transfer
- struct NodeTransferParams {
- string name;
- int node_id;
- string type_name;
- int soc_op_id;
- int padding;
- int inputs_size;
- int outputs_size;
- };
-
// Const node parameters for transfer
+ // TODO(satok): Remove. Use proto definition instead.
struct ConstNodeTransferParams {
string name;
int node_id;
@@ -69,12 +62,14 @@ class GraphTransferer {
};
// Input parameters of a node for transfer
+ // TODO(satok): Remove. Use proto definition instead.
struct NodeInputParams {
int node_id;
std::vector<std::tuple<int, int>> input_node_id_and_output_port_list;
};
// Output parameters of a node for transfer
+ // TODO(satok): Remove. Use proto definition instead.
struct NodeOutputParams {
int node_id;
std::vector<int> max_sizes;
@@ -130,22 +125,22 @@ class GraphTransferer {
// Return const node parameters for transfer
const std::vector<ConstNodeTransferParams>& GetConstNodeParams() const;
- // Return op node parameters for transfer
- const std::vector<NodeTransferParams>& GetOpNodeParams() const;
-
// Return input params of nodes
const std::vector<NodeInputParams>& GetNodeInputParams() const;
// Return output params of nodes
const std::vector<NodeOutputParams>& GetNodeOutputParams() const;
+ // Return parameters for transfer
+ const GraphTransferInfo& GetGraphTransferInfo() const;
+
private:
class TransferParamsComparator {
public:
TransferParamsComparator(
const std::unordered_map<int, std::unordered_set<int>>& dep_map);
- bool operator()(const GraphTransferer::NodeTransferParams& obj0,
- const GraphTransferer::NodeTransferParams& obj1);
+ bool operator()(const GraphTransferInfo::NodeInfo& obj0,
+ const GraphTransferInfo::NodeInfo& obj1);
const std::unordered_map<int, std::unordered_set<int>>& dependency_map_;
};
@@ -258,7 +253,7 @@ class GraphTransferer {
// Dump verification string of parameters to verify with offline tools
void DumpVerificationStringOfNodeTransferParams() const;
- std::vector<NodeTransferParams> node_transfer_params_list_;
+ GraphTransferInfo graph_transfer_info_;
std::vector<ConstNodeTransferParams> const_node_transfer_params_list_;
std::vector<NodeInputParams> node_input_params_list_;
std::vector<NodeOutputParams> node_output_params_list_;
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc
index 92b58083b9..6b0e57ed97 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc
+++ b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
@@ -128,11 +129,11 @@ static const GraphTransferer::ConstNodeTransferParams* FindConstNodeParams(
return nullptr;
}
-static const GraphTransferer::NodeTransferParams* FindOpNodeParams(
+static const GraphTransferInfo::NodeInfo* FindNodeInfo(
const GraphTransferer& gt, const string& name) {
- for (const GraphTransferer::NodeTransferParams& params :
- gt.GetOpNodeParams()) {
- if (params.name == name) {
+ for (const GraphTransferInfo::NodeInfo& params :
+ gt.GetGraphTransferInfo().node_info()) {
+ if (params.name() == name) {
return &params;
}
}
@@ -162,26 +163,26 @@ static const GraphTransferer::NodeOutputParams* FindNodeOutputParams(
}
static void SanityCheckNodes(const GraphTransferer& gt) {
- for (const GraphTransferer::NodeTransferParams& params :
- gt.GetOpNodeParams()) {
- if (params.inputs_size > 0) {
+ for (const GraphTransferInfo::NodeInfo& params :
+ gt.GetGraphTransferInfo().node_info()) {
+ if (params.input_count() > 0) {
const GraphTransferer::NodeInputParams* input_params =
- FindNodeInputParams(gt, params.node_id);
+ FindNodeInputParams(gt, params.node_id());
ASSERT_NE(nullptr, input_params);
- EXPECT_EQ(params.inputs_size,
+ EXPECT_EQ(params.input_count(),
input_params->input_node_id_and_output_port_list.size());
- EXPECT_EQ(params.node_id, input_params->node_id);
+ EXPECT_EQ(params.node_id(), input_params->node_id);
for (const std::tuple<int, int>& pair :
input_params->input_node_id_and_output_port_list) {
EXPECT_GE(std::get<1>(pair), 0);
}
}
- if (params.outputs_size > 0) {
+ if (params.output_count() > 0) {
const GraphTransferer::NodeOutputParams* output_params =
- FindNodeOutputParams(gt, params.node_id);
+ FindNodeOutputParams(gt, params.node_id());
ASSERT_NE(nullptr, output_params);
- EXPECT_EQ(params.outputs_size, output_params->max_sizes.size());
- EXPECT_EQ(params.node_id, output_params->node_id);
+ EXPECT_EQ(params.output_count(), output_params->max_sizes.size());
+ EXPECT_EQ(params.node_id(), output_params->node_id);
for (const int max_size : output_params->max_sizes) {
EXPECT_GE(max_size, 0);
}
@@ -344,17 +345,16 @@ TEST_F(GraphTransfererTest, LoadConvGraph) {
SanityCheckNodes(gt_);
const int const_node_count = gt_.GetConstNodeParams().size();
ASSERT_EQ(2, const_node_count);
- const int op_node_count = gt_.GetOpNodeParams().size();
+ const int op_node_count = gt_.GetGraphTransferInfo().node_info_size();
ASSERT_EQ(3, op_node_count);
- const GraphTransferer::NodeTransferParams* params_conv =
- FindOpNodeParams(gt_, "conv");
+ const GraphTransferInfo::NodeInfo* params_conv = FindNodeInfo(gt_, "conv");
ASSERT_TRUE(params_conv != nullptr);
- const int id = params_conv->node_id;
+ const int id = params_conv->node_id();
EXPECT_GE(id, 0);
- EXPECT_EQ("Conv2D", params_conv->type_name);
- EXPECT_EQ(3, params_conv->inputs_size);
- EXPECT_EQ(1, params_conv->outputs_size);
- EXPECT_EQ(Padding::SAME, params_conv->padding);
+ EXPECT_EQ("Conv2D", params_conv->type_name());
+ EXPECT_EQ(3, params_conv->input_count());
+ EXPECT_EQ(1, params_conv->output_count());
+ EXPECT_EQ(Padding::SAME, params_conv->padding_id());
}
TEST_F(GraphTransfererTest, LoadMaxPoolGraph) {
@@ -370,17 +370,17 @@ TEST_F(GraphTransfererTest, LoadMaxPoolGraph) {
SanityCheckNodes(gt_);
const int const_node_count = gt_.GetConstNodeParams().size();
ASSERT_EQ(2, const_node_count);
- const int op_node_count = gt_.GetOpNodeParams().size();
+ const int op_node_count = gt_.GetGraphTransferInfo().node_info_size();
ASSERT_EQ(3, op_node_count);
- const GraphTransferer::NodeTransferParams* params_max_pool =
- FindOpNodeParams(gt_, "maxpool");
+ const GraphTransferInfo::NodeInfo* params_max_pool =
+ FindNodeInfo(gt_, "maxpool");
ASSERT_TRUE(params_max_pool != nullptr);
- const int id = params_max_pool->node_id;
+ const int id = params_max_pool->node_id();
EXPECT_GE(id, 0);
- EXPECT_EQ("MaxPool", params_max_pool->type_name);
- EXPECT_EQ(3, params_max_pool->inputs_size);
- EXPECT_EQ(1, params_max_pool->outputs_size);
- EXPECT_EQ(Padding::SAME, params_max_pool->padding);
+ EXPECT_EQ("MaxPool", params_max_pool->type_name());
+ EXPECT_EQ(3, params_max_pool->input_count());
+ EXPECT_EQ(1, params_max_pool->output_count());
+ EXPECT_EQ(Padding::SAME, params_max_pool->padding_id());
}
TEST(HexagonOpsDefinitions, CheckOpsDefinitions) {
diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
index ca29fcdd47..52ec1a4e9d 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
+++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
@@ -122,10 +122,10 @@ bool HexagonControlWrapper::SetupGraph(
}
// 2. Setup op nodes
- for (const GraphTransferer::NodeTransferParams& params :
- graph_transferer.GetOpNodeParams()) {
- const int node_id = params.node_id;
- const int op_id = params.soc_op_id;
+ for (const GraphTransferInfo::NodeInfo& params :
+ graph_transferer.GetGraphTransferInfo().node_info()) {
+ const int node_id = params.node_id();
+ const int op_id = params.soc_op_id();
CHECK(inputs_map.count(node_id) == 1);
CHECK(outputs_map.count(node_id) <= 1);
// Only output node doesn't have output
@@ -142,16 +142,16 @@ bool HexagonControlWrapper::SetupGraph(
CHECK(output_count > 0);
}
int padding_id = -1;
- if (params.padding == 0) {
+ if (params.padding_id() == 0) {
padding_id = 0;
- } else if (params.padding == Padding::SAME) {
+ } else if (params.padding_id() == Padding::SAME) {
padding_id = 1;
- } else if (params.padding == Padding::VALID) {
+ } else if (params.padding_id() == Padding::VALID) {
padding_id = 2;
} else {
CHECK(false);
}
- soc_interface_AppendNode(params.name.c_str(), node_id + NODE_ID_OFFSET,
+ soc_interface_AppendNode(params.name().c_str(), node_id + NODE_ID_OFFSET,
op_id, padding_id, input_ptr, input_count,
output_ptr, output_count);
}