author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-11-01 10:11:34 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-11-01 11:40:50 -0700
commit    155994b58ab116c2fb76a0f7ee7f069a5546ada4 (patch)
tree      c2d875b289ce115c6e5a2e12faf62728cda39975
parent    39f5cc1caa396e1e8a13341847be602fb2609023 (diff)
Load convolution node definitions to transfer graphs to SOC
Change: 137845970
-rw-r--r--  tensorflow/core/kernels/hexagon/graph_transferer.cc       |  85
-rw-r--r--  tensorflow/core/kernels/hexagon/graph_transferer.h        |  11
-rw-r--r--  tensorflow/core/kernels/hexagon/graph_transferer_test.cc  |  67
3 files changed, 148 insertions(+), 15 deletions(-)
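
This change adds op-node bookkeeping alongside the existing const-node bookkeeping: nodes carrying both "padding" and "strides" attributes (e.g. Conv2D) are registered via the new RegisterNodeWithPaddingAndStrides, their strides are deduplicated into synthetic const-shape nodes, and the collected parameters are exposed through the new GetOpNodeParams() accessor. As a minimal usage sketch (not part of this commit, and assuming only the NodeTransferParams fields visible in the diff below):

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// Hypothetical helper: load a GraphDef and print the op-node params
// that this commit starts collecting for SOC transfer.
void DumpOpNodesForSoc(const GraphDef& graph_def) {
  GraphTransferer gt;
  gt.LoadGraphFromProto(graph_def);
  for (const GraphTransferer::NodeTransferParams& params :
       gt.GetOpNodeParams()) {
    LOG(INFO) << params.name << " (" << params.type << ")"
              << " padding = " << params.padding
              << ", inputs = " << params.inputs_size
              << ", outputs = " << params.outputs_size;
  }
}

}  // namespace tensorflow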
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc
index bf91c7678f..eed6d57861 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.cc
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc
@@ -27,6 +27,12 @@ static constexpr bool DBG = false;
static constexpr const char* const INPUTS_NODE_PREFIX = "inputs_for_";
static constexpr const char* const OUTPUTS_NODE_PREFIX = "outputs_for_";
static constexpr const char* const DATA_NODE_PREFIX = "data_for_op_";
+static constexpr const char* const CONST_SHAPE_PREFIX = "const_shape_";
+static constexpr const char* const PADDING_PREFIX = "NN_PAD_";
+static constexpr const char* const PADDING_ATTR_NAME = "padding";
+static constexpr const char* const STRIDES_ATTR_NAME = "strides";
+static constexpr const char* const PADDING_VALID_STR = "VALID";
+static constexpr const char* const PADDING_SAME_STR = "SAME";
void GraphTransferer::LoadGraphFromProto(const GraphDef& graph_def) {
ImportGraphDefOptions opts;
@@ -63,6 +69,11 @@ GraphTransferer::GetConstNodeParams() const {
return const_node_transfer_params_list_;
}
+const std::vector<GraphTransferer::NodeTransferParams>&
+GraphTransferer::GetOpNodeParams() const {
+ return node_transfer_params_list_;
+}
+
int GraphTransferer::CacheNode(const Node& node) {
if (node_name_to_id_cache_map_.count(node.name()) > 0) {
if (DBG) {
@@ -107,6 +118,8 @@ void GraphTransferer::RegisterNode(const ShapeRefiner& shape_refiner,
}
} else if (node.IsConstant()) {
RegisterConstantNode(shape_refiner, node);
+ } else if (HasPaddingAndStrides(node)) {
+ RegisterNodeWithPaddingAndStrides(shape_refiner, node);
} else {
// TODO(satok): register params for nodes which are supported by SOC
if (DBG) {
@@ -134,8 +147,6 @@ void GraphTransferer::RegisterConstantNode(const ShapeRefiner& shape_refiner,
CHECK(context->ValueKnown(num_elements_dim));
const int64 num_output_elements = context->Value(num_elements_dim);
const int data_size = max_bytes_per_data * num_output_elements;
- const int rank = context->Rank(shape_handle);
- CHECK(rank == 0);
const std::array<int64, SHAPE_ARRAY_SIZE> shape =
BuildShapeArray(shape_handle, context);
const_node_transfer_params_list_.emplace_back(
@@ -146,6 +157,46 @@ void GraphTransferer::RegisterConstantNode(const ShapeRefiner& shape_refiner,
data_size});
}
+int GraphTransferer::RegisterConstantShape(const std::vector<int>& shape) {
+ // TODO(satok): Handle non-4dim strides
+ CHECK(shape.size() == 4);
+ const string shape_name =
+ std::string(CONST_SHAPE_PREFIX) + std::to_string(shape.at(0)) + 'x' +
+ std::to_string(shape.at(1)) + 'x' + std::to_string(shape.at(2)) + 'x' +
+ std::to_string(shape.at(3));
+ if (node_name_to_id_cache_map_.count(shape_name) <= 0) {
+ node_name_cache_list_.emplace_back(nullptr);
+ const int id = node_name_cache_list_.size() - 1;
+ node_name_to_id_cache_map_.emplace(shape_name, id);
+ const_node_transfer_params_list_.emplace_back(ConstNodeTransferParams{
+ shape_name, id, {{shape[0], shape[1], shape[2], shape[3]}}, "", 0});
+ }
+ return node_name_to_id_cache_map_[shape_name];
+}
+
+bool GraphTransferer::HasPaddingAndStrides(const Node& node) {
+ return node.def().attr().count(PADDING_ATTR_NAME) > 0 &&
+ node.def().attr().count(STRIDES_ATTR_NAME) > 0;
+}
+
+void GraphTransferer::RegisterNodeWithPaddingAndStrides(
+ const ShapeRefiner& shape_refiner, const Node& node) {
+ CHECK(node_name_to_id_cache_map_.count(node.name()) == 1);
+ const int id = node_name_to_id_cache_map_[node.name()];
+ shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
+ CHECK(node.def().attr().count(PADDING_ATTR_NAME) > 0);
+ // TODO(satok): Use context->GetAttr(...) instead?
+ Padding padding;
+ context->GetAttr(PADDING_ATTR_NAME, &padding);
+ CHECK(node.def().attr().count(STRIDES_ATTR_NAME) > 0);
+ std::vector<int32> strides;
+ context->GetAttr(STRIDES_ATTR_NAME, &strides);
+ const int stride_id = RegisterConstantShape(strides);
+ std::vector<int> extra_inputs{stride_id, 0};
+ AppendNodeParams(node.name(), id, node.type_string(), padding,
+ node.num_inputs(), extra_inputs, node.num_outputs());
+}
+
bool GraphTransferer::RegisterNodeIfAllInputsAreCached(
const ShapeRefiner& shape_refiner, const Node& node,
const bool only_register_const_node) {
@@ -161,14 +212,21 @@ bool GraphTransferer::RegisterNodeIfAllInputsAreCached(
void GraphTransferer::AppendNodeParams(const string& name, const int id,
const string& type,
- const string& padding,
+ const Padding& padding,
const int inputs_size,
+ const std::vector<int>& extra_inputs,
const int outputs_size) {
+ // TODO(satok): register inputs
+ // TODO(satok): register outputs
+ // TODO(satok): store padding as Padding?
node_transfer_params_list_.emplace_back(NodeTransferParams{
- name, id, type, padding,
- string(INPUTS_NODE_PREFIX) + std::to_string(inputs_size), inputs_size,
- string(OUTPUTS_NODE_PREFIX) + std::to_string(outputs_size),
- outputs_size});
+ name, id, type,
+ string(PADDING_PREFIX) +
+ string(padding == VALID ? PADDING_VALID_STR : PADDING_SAME_STR),
+ string(INPUTS_NODE_PREFIX) + std::to_string(id),
+ inputs_size + static_cast<int>(extra_inputs.size()),
+ string(OUTPUTS_NODE_PREFIX) + std::to_string(id),
+ static_cast<int>(outputs_size)});
}
/* static */ std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>
@@ -205,6 +263,7 @@ GraphTransferer::BuildShapeArray(
void GraphTransferer::DumpNodeTransferParams() const {
// TODO(satok): Dump all params
+ LOG(INFO) << "*** Const Nodes ***";
for (const ConstNodeTransferParams& params :
const_node_transfer_params_list_) {
LOG(INFO) << "[ " << params.id << " \"" << params.name << "\" (Const)";
@@ -214,6 +273,18 @@ void GraphTransferer::DumpNodeTransferParams() const {
LOG(INFO) << " data_size: " << params.data_size << " bytes"
<< " ]";
}
+ LOG(INFO) << "******";
+ LOG(INFO) << "*** Op Nodes ***";
+ for (const NodeTransferParams& params : node_transfer_params_list_) {
+ LOG(INFO) << "[ " << params.id << " \"" << params.name;
+ LOG(INFO) << " type: " << params.type;
+ LOG(INFO) << " padding: " << params.padding;
+ LOG(INFO) << " inputs: " << params.inputs_name
+ << ", size = " << params.inputs_size;
+ LOG(INFO) << " outputs: " << params.outputs_name
+ << ", size = " << params.outputs_size << " ]";
+ }
+ LOG(INFO) << "******";
}
} // namespace tensorflow
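
For reference, RegisterConstantShape above deduplicates stride shapes by name: a 4-element strides vector such as {1, 1, 1, 1} maps to the cache key "const_shape_1x1x1x1", so repeated nodes with identical strides share a single synthetic const node. A standalone sketch of that key construction (ConstShapeName is a hypothetical helper, not part of the commit):

#include <string>
#include <vector>

// Illustration of the cache key built in RegisterConstantShape:
// CONST_SHAPE_PREFIX followed by the four dims joined by 'x'.
std::string ConstShapeName(const std::vector<int>& shape) {
  std::string name = "const_shape_";
  for (size_t i = 0; i < shape.size(); ++i) {
    if (i > 0) name += 'x';
    name += std::to_string(shape[i]);
  }
  return name;  // e.g. {1, 1, 1, 1} -> "const_shape_1x1x1x1"
}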
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h
index 3e3ee3d49b..99e1952f60 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.h
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/util/padding.h"
namespace tensorflow {
@@ -67,17 +68,25 @@ class GraphTransferer {
// Return const node parameters for transfer
const std::vector<ConstNodeTransferParams>& GetConstNodeParams() const;
+ // Return op node parameters for transfer
+ const std::vector<NodeTransferParams>& GetOpNodeParams() const;
+
private:
int CacheNode(const Node& node);
bool AreAllInputsCached(const Node& node) const;
void RegisterConstantNode(const ShapeRefiner& shape_refiner,
const Node& node);
+ int RegisterConstantShape(const std::vector<int>& shape);
+ bool HasPaddingAndStrides(const Node& node);
+ void RegisterNodeWithPaddingAndStrides(const ShapeRefiner& shape_refiner,
+ const Node& node);
void RegisterNode(const ShapeRefiner& shape_refiner, const Node& node);
bool RegisterNodeIfAllInputsAreCached(const ShapeRefiner& shape_refiner,
const Node& node,
const bool only_register_const_node);
void AppendNodeParams(const string& name, const int id, const string& type,
- const string& padding, const int inputs_size,
+ const Padding& padding, const int inputs_size,
+ const std::vector<int>& extra_inputs,
const int outputs_size);
static std::array<int64, SHAPE_ARRAY_SIZE> BuildShapeArray(
const shape_inference::ShapeHandle& shape_handle,
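
The Padding type pulled in from tensorflow/core/util/padding.h is the standard VALID/SAME enum; AppendNodeParams in the .cc diff above serializes it with the NN_PAD_ prefix for the Hexagon side. A sketch of that mapping in isolation (ToHexagonPadName is a hypothetical name for illustration):

#include <string>
#include "tensorflow/core/util/padding.h"

// Mirrors the string built in AppendNodeParams:
// VALID -> "NN_PAD_VALID", otherwise (SAME) -> "NN_PAD_SAME".
std::string ToHexagonPadName(tensorflow::Padding padding) {
  return std::string("NN_PAD_") +
         (padding == tensorflow::VALID ? "VALID" : "SAME");
}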
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc
index 4c386f4a62..272305717d 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc
+++ b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc
@@ -16,7 +16,9 @@ limitations under the License.
#include <memory>
#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/lib/core/status.h"
@@ -40,12 +42,30 @@ class GraphTransfererTest : public ::testing::Test {
std::unique_ptr<Session> _session;
};
-static GraphDef CreateSmallGraphDef() {
+static GraphDef CreateAddGraphDef() {
Scope root = Scope::NewRootScope();
ops::Output node_a = ops::Const(root.WithOpName(NAME_A), 1);
ops::Output node_b = ops::Const(root.WithOpName(NAME_B), 2);
- ops::Add(root.WithOpName("a_plus_b"), node_a, node_b);
+ ops::Output node_add = ops::Add(root.WithOpName("a_plus_b"), node_a, node_b);
+ GraphDef def;
+ TF_CHECK_OK(root.ToGraphDef(&def));
+ return def;
+}
+static GraphDef CreateConvGraphDef() {
+ Scope root = Scope::NewRootScope();
+ Tensor input_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
+ test::FillIota<float>(&input_data, 1.0f);
+ ops::Output input =
+ ops::Const(root.WithOpName("input"), ops::Input::Initializer(input_data));
+ const int stride = 1;
+ Tensor filter_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
+ test::FillIota<float>(&filter_data, 1.0f);
+ ops::Output filter = ops::Const(root.WithOpName("filter"),
+ ops::Input::Initializer(filter_data));
+ const std::vector<int> padding{0, 0, 0, 0};
+ ops::Output conv = ops::Conv2D(root.WithOpName("conv"), input, filter,
+ {1, stride, stride, 1}, "SAME");
GraphDef def;
TF_CHECK_OK(root.ToGraphDef(&def));
return def;
@@ -62,17 +82,29 @@ static const GraphTransferer::ConstNodeTransferParams* FindConstNodeParams(
return nullptr;
}
-TEST_F(GraphTransfererTest, LoadGraph) {
- GraphDef def = CreateSmallGraphDef();
+static const GraphTransferer::NodeTransferParams* FindOpNodeParams(
+ const GraphTransferer& gt, const string& name) {
+ for (const GraphTransferer::NodeTransferParams& params :
+ gt.GetOpNodeParams()) {
+ if (params.name == name) {
+ return &params;
+ }
+ }
+ return nullptr;
+}
+
+TEST_F(GraphTransfererTest, LoadAddGraph) {
+ GraphDef def = CreateAddGraphDef();
_session->Create(def);
GraphTransferer gt;
gt.LoadGraphFromProto(def);
- ASSERT_EQ(2, gt.GetConstNodeParams().size());
+ const int const_node_count = gt.GetConstNodeParams().size();
+ ASSERT_EQ(2, const_node_count);
const GraphTransferer::ConstNodeTransferParams* params_a =
FindConstNodeParams(gt, NAME_A);
ASSERT_TRUE(params_a != nullptr);
- EXPECT_TRUE(params_a->id > 0 && params_a->id <= 2);
+ EXPECT_TRUE(params_a->id > 0 && params_a->id <= const_node_count);
EXPECT_EQ(NAME_A, params_a->name);
EXPECT_EQ(1, params_a->shape[0]);
EXPECT_EQ(1, params_a->shape[1]);
@@ -83,7 +115,7 @@ TEST_F(GraphTransfererTest, LoadGraph) {
const GraphTransferer::ConstNodeTransferParams* params_b =
FindConstNodeParams(gt, NAME_B);
ASSERT_TRUE(params_b != nullptr);
- EXPECT_TRUE(params_b->id > 0 && params_b->id <= 2);
+ EXPECT_TRUE(params_b->id > 0 && params_b->id <= const_node_count);
EXPECT_EQ(1, params_b->shape[0]);
EXPECT_EQ(1, params_b->shape[1]);
EXPECT_EQ(1, params_b->shape[2]);
@@ -91,4 +123,25 @@ TEST_F(GraphTransfererTest, LoadGraph) {
EXPECT_EQ(10, params_b->data_size);
}
+TEST_F(GraphTransfererTest, LoadConvGraph) {
+ GraphDef def = CreateConvGraphDef();
+ _session->Create(def);
+
+ GraphTransferer gt;
+ gt.LoadGraphFromProto(def);
+ const int const_node_count = gt.GetConstNodeParams().size();
+ ASSERT_EQ(3, const_node_count);
+ const int op_node_count = gt.GetOpNodeParams().size();
+ ASSERT_EQ(1, op_node_count);
+ const GraphTransferer::NodeTransferParams* params_conv =
+ FindOpNodeParams(gt, "conv");
+ ASSERT_TRUE(params_conv != nullptr);
+ const int id = params_conv->id;
+ EXPECT_TRUE(id > 0 && id <= (const_node_count + op_node_count));
+ EXPECT_EQ("Conv2D", params_conv->type);
+ EXPECT_EQ(4, params_conv->inputs_size);
+ EXPECT_EQ(1, params_conv->outputs_size);
+ EXPECT_EQ("NN_PAD_SAME", params_conv->padding);
+}
+
} // namespace tensorflow
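
The expected counts in LoadConvGraph follow from the registration logic above: the three const nodes are "input", "filter", and the synthetic "const_shape_1x1x1x1" registered for the strides; the single op node is the Conv2D, whose inputs_size of 4 is its two graph inputs plus the two extra inputs appended by RegisterNodeWithPaddingAndStrides (the strides-shape id and a placeholder 0), and whose "SAME" attribute round-trips as "NN_PAD_SAME".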