diff options
author | 2016-11-01 10:11:34 -0800 | |
---|---|---|
committer | 2016-11-01 11:40:50 -0700 | |
commit | 155994b58ab116c2fb76a0f7ee7f069a5546ada4 (patch) | |
tree | c2d875b289ce115c6e5a2e12faf62728cda39975 | |
parent | 39f5cc1caa396e1e8a13341847be602fb2609023 (diff) |
Load convolution node definitions to transfer graphs to SOC
Change: 137845970
-rw-r--r-- | tensorflow/core/kernels/hexagon/graph_transferer.cc | 85 | ||||
-rw-r--r-- | tensorflow/core/kernels/hexagon/graph_transferer.h | 11 | ||||
-rw-r--r-- | tensorflow/core/kernels/hexagon/graph_transferer_test.cc | 67 |
3 files changed, 148 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc index bf91c7678f..eed6d57861 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc @@ -27,6 +27,12 @@ static constexpr bool DBG = false; static constexpr const char* const INPUTS_NODE_PREFIX = "inputs_for_"; static constexpr const char* const OUTPUTS_NODE_PREFIX = "outputs_for_"; static constexpr const char* const DATA_NODE_PREFIX = "data_for_op_"; +static constexpr const char* const CONST_SHAPE_PREFIX = "const_shape_"; +static constexpr const char* const PADDING_PREFIX = "NN_PAD_"; +static constexpr const char* const PADDING_ATTR_NAME = "padding"; +static constexpr const char* const STRIDES_ATTR_NAME = "strides"; +static constexpr const char* const PADDING_VALID_STR = "VALID"; +static constexpr const char* const PADDING_SAME_STR = "SAME"; void GraphTransferer::LoadGraphFromProto(const GraphDef& graph_def) { ImportGraphDefOptions opts; @@ -63,6 +69,11 @@ GraphTransferer::GetConstNodeParams() const { return const_node_transfer_params_list_; } +const std::vector<GraphTransferer::NodeTransferParams>& +GraphTransferer::GetOpNodeParams() const { + return node_transfer_params_list_; +} + int GraphTransferer::CacheNode(const Node& node) { if (node_name_to_id_cache_map_.count(node.name()) > 0) { if (DBG) { @@ -107,6 +118,8 @@ void GraphTransferer::RegisterNode(const ShapeRefiner& shape_refiner, } } else if (node.IsConstant()) { RegisterConstantNode(shape_refiner, node); + } else if (HasPaddingAndStrides(node)) { + RegisterNodeWithPaddingAndStrides(shape_refiner, node); } else { // TODO(satok): register params for nodes which are supported by SOC if (DBG) { @@ -134,8 +147,6 @@ void GraphTransferer::RegisterConstantNode(const ShapeRefiner& shape_refiner, CHECK(context->ValueKnown(num_elements_dim)); const int64 num_output_elements = context->Value(num_elements_dim); const int data_size = max_bytes_per_data * num_output_elements; - const int rank = context->Rank(shape_handle); - CHECK(rank == 0); const std::array<int64, SHAPE_ARRAY_SIZE> shape = BuildShapeArray(shape_handle, context); const_node_transfer_params_list_.emplace_back( @@ -146,6 +157,46 @@ void GraphTransferer::RegisterConstantNode(const ShapeRefiner& shape_refiner, data_size}); } +int GraphTransferer::RegisterConstantShape(const std::vector<int>& shape) { + // TODO(satok): Handle non-4dim strides + CHECK(shape.size() == 4); + const string shape_name = + std::string(CONST_SHAPE_PREFIX) + std::to_string(shape.at(0)) + 'x' + + std::to_string(shape.at(1)) + 'x' + std::to_string(shape.at(2)) + 'x' + + std::to_string(shape.at(3)); + if (node_name_to_id_cache_map_.count(shape_name) <= 0) { + node_name_cache_list_.emplace_back(nullptr); + const int id = node_name_cache_list_.size() - 1; + node_name_to_id_cache_map_.emplace(shape_name, id); + const_node_transfer_params_list_.emplace_back(ConstNodeTransferParams{ + shape_name, id, {{shape[0], shape[1], shape[2], shape[3]}}, "", 0}); + } + return node_name_to_id_cache_map_[shape_name]; +} + +bool GraphTransferer::HasPaddingAndStrides(const Node& node) { + return node.def().attr().count(PADDING_ATTR_NAME) > 0 && + node.def().attr().count(STRIDES_ATTR_NAME) > 0; +} + +void GraphTransferer::RegisterNodeWithPaddingAndStrides( + const ShapeRefiner& shape_refiner, const Node& node) { + CHECK(node_name_to_id_cache_map_.count(node.name()) == 1); + const int id = node_name_to_id_cache_map_[node.name()]; + shape_inference::InferenceContext* context = shape_refiner.GetContext(&node); + CHECK(node.def().attr().count(PADDING_ATTR_NAME) > 0); + // TODO(satok): Use context->GetAttr(...) instead? + Padding padding; + context->GetAttr(PADDING_ATTR_NAME, &padding); + CHECK(node.def().attr().count(STRIDES_ATTR_NAME) > 0); + std::vector<int32> strides; + context->GetAttr(STRIDES_ATTR_NAME, &strides); + const int stride_id = RegisterConstantShape(strides); + std::vector<int> extra_inputs{stride_id, 0}; + AppendNodeParams(node.name(), id, node.type_string(), padding, + node.num_inputs(), extra_inputs, node.num_outputs()); +} + bool GraphTransferer::RegisterNodeIfAllInputsAreCached( const ShapeRefiner& shape_refiner, const Node& node, const bool only_register_const_node) { @@ -161,14 +212,21 @@ bool GraphTransferer::RegisterNodeIfAllInputsAreCached( void GraphTransferer::AppendNodeParams(const string& name, const int id, const string& type, - const string& padding, + const Padding& padding, const int inputs_size, + const std::vector<int>& extra_inputs, const int outputs_size) { + // TODO(satok): register inputs + // TODO(satok): register outputs + // TODO(satok): store padding as Padding? node_transfer_params_list_.emplace_back(NodeTransferParams{ - name, id, type, padding, - string(INPUTS_NODE_PREFIX) + std::to_string(inputs_size), inputs_size, - string(OUTPUTS_NODE_PREFIX) + std::to_string(outputs_size), - outputs_size}); + name, id, type, + string(PADDING_PREFIX) + + string(padding == VALID ? PADDING_VALID_STR : PADDING_SAME_STR), + string(INPUTS_NODE_PREFIX) + std::to_string(id), + inputs_size + static_cast<int>(extra_inputs.size()), + string(OUTPUTS_NODE_PREFIX) + std::to_string(id), + static_cast<int>(outputs_size)}); } /* static */ std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE> @@ -205,6 +263,7 @@ GraphTransferer::BuildShapeArray( void GraphTransferer::DumpNodeTransferParams() const { // TODO(satok): Dump all params + LOG(INFO) << "*** Const Nodes ***"; for (const ConstNodeTransferParams& params : const_node_transfer_params_list_) { LOG(INFO) << "[ " << params.id << " \"" << params.name << "\" (Const)"; @@ -214,6 +273,18 @@ void GraphTransferer::DumpNodeTransferParams() const { LOG(INFO) << " data_size: " << params.data_size << " bytes" << " ]"; } + LOG(INFO) << "******"; + LOG(INFO) << "*** Op Nodes ***"; + for (const NodeTransferParams& params : node_transfer_params_list_) { + LOG(INFO) << "[ " << params.id << " \"" << params.name; + LOG(INFO) << " type: " << params.type; + LOG(INFO) << " padding: " << params.padding; + LOG(INFO) << " inputs: " << params.inputs_name + << ", size = " << params.inputs_size; + LOG(INFO) << " outputs: " << params.outputs_name + << ", size = " << params.outputs_size << " ]"; + } + LOG(INFO) << "******"; } } // namespace tensorflow diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h index 3e3ee3d49b..99e1952f60 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.h +++ b/tensorflow/core/kernels/hexagon/graph_transferer.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/util/padding.h" namespace tensorflow { @@ -67,17 +68,25 @@ class GraphTransferer { // Return const node parameters for transfer const std::vector<ConstNodeTransferParams>& GetConstNodeParams() const; + // Return op node parameters for transfer + const std::vector<NodeTransferParams>& GetOpNodeParams() const; + private: int CacheNode(const Node& node); bool AreAllInputsCached(const Node& node) const; void RegisterConstantNode(const ShapeRefiner& shape_refiner, const Node& node); + int RegisterConstantShape(const std::vector<int>& shape); + bool HasPaddingAndStrides(const Node& node); + void RegisterNodeWithPaddingAndStrides(const ShapeRefiner& shape_refiner, + const Node& node); void RegisterNode(const ShapeRefiner& shape_refiner, const Node& node); bool RegisterNodeIfAllInputsAreCached(const ShapeRefiner& shape_refiner, const Node& node, const bool only_register_const_node); void AppendNodeParams(const string& name, const int id, const string& type, - const string& padding, const int inputs_size, + const Padding& padding, const int inputs_size, + const std::vector<int>& extra_inputs, const int outputs_size); static std::array<int64, SHAPE_ARRAY_SIZE> BuildShapeArray( const shape_inference::ShapeHandle& shape_handle, diff --git a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc index 4c386f4a62..272305717d 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc @@ -16,7 +16,9 @@ limitations under the License. #include <memory> #include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/nn_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/kernels/hexagon/graph_transferer.h" #include "tensorflow/core/lib/core/status.h" @@ -40,12 +42,30 @@ class GraphTransfererTest : public ::testing::Test { std::unique_ptr<Session> _session; }; -static GraphDef CreateSmallGraphDef() { +static GraphDef CreateAddGraphDef() { Scope root = Scope::NewRootScope(); ops::Output node_a = ops::Const(root.WithOpName(NAME_A), 1); ops::Output node_b = ops::Const(root.WithOpName(NAME_B), 2); - ops::Add(root.WithOpName("a_plus_b"), node_a, node_b); + ops::Output node_add = ops::Add(root.WithOpName("a_plus_b"), node_a, node_b); + GraphDef def; + TF_CHECK_OK(root.ToGraphDef(&def)); + return def; +} +static GraphDef CreateConvGraphDef() { + Scope root = Scope::NewRootScope(); + Tensor input_data(DT_FLOAT, TensorShape({1, 1, 1, 1})); + test::FillIota<float>(&input_data, 1.0f); + ops::Output input = + ops::Const(root.WithOpName("input"), ops::Input::Initializer(input_data)); + const int stride = 1; + Tensor filter_data(DT_FLOAT, TensorShape({1, 1, 1, 1})); + test::FillIota<float>(&filter_data, 1.0f); + ops::Output filter = ops::Const(root.WithOpName("filter"), + ops::Input::Initializer(filter_data)); + const std::vector<int> padding{0, 0, 0, 0}; + ops::Output conv = ops::Conv2D(root.WithOpName("conv"), input, filter, + {1, stride, stride, 1}, "SAME"); GraphDef def; TF_CHECK_OK(root.ToGraphDef(&def)); return def; @@ -62,17 +82,29 @@ static const GraphTransferer::ConstNodeTransferParams* FindConstNodeParams( return nullptr; } -TEST_F(GraphTransfererTest, LoadGraph) { - GraphDef def = CreateSmallGraphDef(); +static const GraphTransferer::NodeTransferParams* FindOpNodeParams( + const GraphTransferer& gt, const string& name) { + for (const GraphTransferer::NodeTransferParams& params : + gt.GetOpNodeParams()) { + if (params.name == name) { + return ¶ms; + } + } + return nullptr; +} + +TEST_F(GraphTransfererTest, LoadAddGraph) { + GraphDef def = CreateAddGraphDef(); _session->Create(def); GraphTransferer gt; gt.LoadGraphFromProto(def); - ASSERT_EQ(2, gt.GetConstNodeParams().size()); + const int const_node_count = gt.GetConstNodeParams().size(); + ASSERT_EQ(2, const_node_count); const GraphTransferer::ConstNodeTransferParams* params_a = FindConstNodeParams(gt, NAME_A); ASSERT_TRUE(params_a != nullptr); - EXPECT_TRUE(params_a->id > 0 && params_a->id <= 2); + EXPECT_TRUE(params_a->id > 0 && params_a->id <= const_node_count); EXPECT_EQ(NAME_A, params_a->name); EXPECT_EQ(1, params_a->shape[0]); EXPECT_EQ(1, params_a->shape[1]); @@ -83,7 +115,7 @@ TEST_F(GraphTransfererTest, LoadGraph) { const GraphTransferer::ConstNodeTransferParams* params_b = FindConstNodeParams(gt, NAME_B); ASSERT_TRUE(params_b != nullptr); - EXPECT_TRUE(params_b->id > 0 && params_b->id <= 2); + EXPECT_TRUE(params_b->id > 0 && params_b->id <= const_node_count); EXPECT_EQ(1, params_b->shape[0]); EXPECT_EQ(1, params_b->shape[1]); EXPECT_EQ(1, params_b->shape[2]); @@ -91,4 +123,25 @@ TEST_F(GraphTransfererTest, LoadGraph) { EXPECT_EQ(10, params_b->data_size); } +TEST_F(GraphTransfererTest, LoadConvGraph) { + GraphDef def = CreateConvGraphDef(); + _session->Create(def); + + GraphTransferer gt; + gt.LoadGraphFromProto(def); + const int const_node_count = gt.GetConstNodeParams().size(); + ASSERT_EQ(3, const_node_count); + const int op_node_count = gt.GetOpNodeParams().size(); + ASSERT_EQ(1, op_node_count); + const GraphTransferer::NodeTransferParams* params_conv = + FindOpNodeParams(gt, "conv"); + ASSERT_TRUE(params_conv != nullptr); + const int id = params_conv->id; + EXPECT_TRUE(id > 0 && id <= (const_node_count + op_node_count)); + EXPECT_EQ("Conv2D", params_conv->type); + EXPECT_EQ(4, params_conv->inputs_size); + EXPECT_EQ(1, params_conv->outputs_size); + EXPECT_EQ("NN_PAD_SAME", params_conv->padding); +} + } // namespace tensorflow |