From 7ab54c4c48f35a4107e6170cefe5c93245595601 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 30 Nov 2017 17:37:21 -0800 Subject: Support compressed TensorProto format in constant folding for types iny16, int8, uint8, and bool, in addition to float ,double, int32, and int64, which were already supported. Add unit test for all types. PiperOrigin-RevId: 177533200 --- .../core/grappler/optimizers/constant_folding.cc | 24 ++++++--- .../grappler/optimizers/constant_folding_test.cc | 63 +++++++++++++++++++++- 2 files changed, 80 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index cf913d6f48..e0f39c2931 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -657,9 +657,9 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { namespace { -#define SET_TENSOR_VAL_CASE(DTYPE, TYPE) \ +#define SET_TENSOR_VAL_CASE(DTYPE, TYPE, NAME) \ case DTYPE: \ - t->add_##TYPE##_val(static_cast(value)); \ + t->add_##NAME##_val(static_cast(value)); \ break; Status CreateConstantTensorAttrValue(DataType type, double value, @@ -668,10 +668,14 @@ Status CreateConstantTensorAttrValue(DataType type, double value, TensorProto* t = attr_tensor->mutable_tensor(); *t->mutable_tensor_shape() = shape; switch (type) { - SET_TENSOR_VAL_CASE(DT_FLOAT, float); - SET_TENSOR_VAL_CASE(DT_DOUBLE, double); - SET_TENSOR_VAL_CASE(DT_INT64, int64); - SET_TENSOR_VAL_CASE(DT_INT32, int); + SET_TENSOR_VAL_CASE(DT_FLOAT, float, float); + SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double); + SET_TENSOR_VAL_CASE(DT_INT64, int64, int64); + SET_TENSOR_VAL_CASE(DT_INT32, int32, int); + SET_TENSOR_VAL_CASE(DT_INT16, int32, int); + SET_TENSOR_VAL_CASE(DT_INT8, int32, int); + SET_TENSOR_VAL_CASE(DT_UINT8, int32, int); + SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool); default: return errors::InvalidArgument("Unsupported type: ", type); } @@ -721,6 +725,14 @@ NodeDef ConstantFolding::CreateNodeDef(const string& name, POPULATE_TENSOR_PROTO(tensor, t, int64, int64) } else if (tensor->dtype() == DT_INT32) { POPULATE_TENSOR_PROTO(tensor, t, int32, int) + } else if (tensor->dtype() == DT_INT16) { + POPULATE_TENSOR_PROTO(tensor, t, int16, int) + } else if (tensor->dtype() == DT_INT8) { + POPULATE_TENSOR_PROTO(tensor, t, int8, int) + } else if (tensor->dtype() == DT_UINT8) { + POPULATE_TENSOR_PROTO(tensor, t, uint8, int) + } else if (tensor->dtype() == DT_BOOL) { + POPULATE_TENSOR_PROTO(tensor, t, bool, bool) } } if (optimized) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index c72ed96520..32a691d3ee 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -173,11 +173,70 @@ TEST_F(ConstantFoldingTest, NeutralElement) { } } +TEST_F(ConstantFoldingTest, CreateConstNodes) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + +#define MAKE_TEST_GRAPH(TYPE) \ + Output TYPE##_const = \ + ops::Const(s.WithOpName(#TYPE "_const"), static_cast(10), {5}); \ + Output TYPE##_mul = \ + ops::Mul(s.WithOpName(#TYPE "_mul"), TYPE##_const, TYPE##_const); \ + Output TYPE##_id = ops::Identity(s.WithOpName(#TYPE "_id"), TYPE##_mul) + + MAKE_TEST_GRAPH(float); + MAKE_TEST_GRAPH(double); + MAKE_TEST_GRAPH(int64); + MAKE_TEST_GRAPH(int32); + MAKE_TEST_GRAPH(int16); + MAKE_TEST_GRAPH(int8); + MAKE_TEST_GRAPH(uint8); +#undef MAKE_TEST_GRAPH + + Output bool_const = ops::Const(s.WithOpName("bool_const"), true, {5}); + Output bool_and = + ops::LogicalAnd(s.WithOpName("bool_and"), bool_const, bool_const); + Output bool_id = ops::Identity(s.WithOpName("bool_id"), bool_and); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + ConstantFolding fold(nullptr /* cpu_device */); + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(24, output.node_size()); + for (const NodeDef& node : output.node()) { +#define CHECK_RESULT(TYPE, FIELD) \ + if (node.name() == #TYPE "_mul") { \ + EXPECT_EQ(5, \ + node.attr().at("value").tensor().tensor_shape().dim(0).size()); \ + EXPECT_EQ(1, node.attr().at("value").tensor().FIELD##_val_size()); \ + EXPECT_EQ(10 * 10, node.attr().at("value").tensor().FIELD##_val(0)); \ + } + + CHECK_RESULT(float, float); + CHECK_RESULT(double, double); + CHECK_RESULT(int64, int64); + CHECK_RESULT(int32, int); + CHECK_RESULT(int16, int); + CHECK_RESULT(int8, int); + CHECK_RESULT(uint8, int); +#undef CHECK_RESULT + + if (node.name() == "bool_and") { + EXPECT_EQ(5, + node.attr().at("value").tensor().tensor_shape().dim(0).size()); + EXPECT_EQ(1, node.attr().at("value").tensor().bool_val_size()); + EXPECT_EQ(true && true, node.attr().at("value").tensor().bool_val(0)); + } + } +} + TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) { // Build a simple graph with a few trivially prunable ops. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output a = ops::Const(s.WithOpName("a"), 10, {3}); + Output a = ops::Const(s.WithOpName("a"), 10, {5}); auto b = ops::Unique(s.WithOpName("b"), {a}); Output c = ops::Identity(s.WithOpName("c"), {b.y}); Output d = ops::Identity(s.WithOpName("d"), {b.idx}); @@ -1059,3 +1118,5 @@ TEST_F(ConstantFoldingTest, MaterializeReductionIndices) { } // namespace } // namespace grappler } // namespace tensorflow + +// LocalWords: NewRootScope -- cgit v1.2.3