diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-11-30 17:37:21 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-30 17:40:33 -0800 |
commit | 7ab54c4c48f35a4107e6170cefe5c93245595601 (patch) | |
tree | b1d471137286ecc081de7321ed396af98a5733f7 | |
parent | b2db981a6731e978453862a73dab892bc674db68 (diff) |
Support compressed TensorProto format in constant folding for types iny16, int8, uint8, and bool, in addition to float ,double, int32, and int64, which were already supported.
Add unit test for all types.
PiperOrigin-RevId: 177533200
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.cc | 24 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding_test.cc | 63 |
2 files changed, 80 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index cf913d6f48..e0f39c2931 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -657,9 +657,9 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { namespace { -#define SET_TENSOR_VAL_CASE(DTYPE, TYPE) \ +#define SET_TENSOR_VAL_CASE(DTYPE, TYPE, NAME) \ case DTYPE: \ - t->add_##TYPE##_val(static_cast<TYPE>(value)); \ + t->add_##NAME##_val(static_cast<TYPE>(value)); \ break; Status CreateConstantTensorAttrValue(DataType type, double value, @@ -668,10 +668,14 @@ Status CreateConstantTensorAttrValue(DataType type, double value, TensorProto* t = attr_tensor->mutable_tensor(); *t->mutable_tensor_shape() = shape; switch (type) { - SET_TENSOR_VAL_CASE(DT_FLOAT, float); - SET_TENSOR_VAL_CASE(DT_DOUBLE, double); - SET_TENSOR_VAL_CASE(DT_INT64, int64); - SET_TENSOR_VAL_CASE(DT_INT32, int); + SET_TENSOR_VAL_CASE(DT_FLOAT, float, float); + SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double); + SET_TENSOR_VAL_CASE(DT_INT64, int64, int64); + SET_TENSOR_VAL_CASE(DT_INT32, int32, int); + SET_TENSOR_VAL_CASE(DT_INT16, int32, int); + SET_TENSOR_VAL_CASE(DT_INT8, int32, int); + SET_TENSOR_VAL_CASE(DT_UINT8, int32, int); + SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool); default: return errors::InvalidArgument("Unsupported type: ", type); } @@ -721,6 +725,14 @@ NodeDef ConstantFolding::CreateNodeDef(const string& name, POPULATE_TENSOR_PROTO(tensor, t, int64, int64) } else if (tensor->dtype() == DT_INT32) { POPULATE_TENSOR_PROTO(tensor, t, int32, int) + } else if (tensor->dtype() == DT_INT16) { + POPULATE_TENSOR_PROTO(tensor, t, int16, int) + } else if (tensor->dtype() == DT_INT8) { + POPULATE_TENSOR_PROTO(tensor, t, int8, int) + } else if (tensor->dtype() == DT_UINT8) { + POPULATE_TENSOR_PROTO(tensor, t, uint8, int) + } else if (tensor->dtype() == DT_BOOL) { + POPULATE_TENSOR_PROTO(tensor, t, bool, bool) } } if (optimized) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index c72ed96520..32a691d3ee 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -173,11 +173,70 @@ TEST_F(ConstantFoldingTest, NeutralElement) { } } +TEST_F(ConstantFoldingTest, CreateConstNodes) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + +#define MAKE_TEST_GRAPH(TYPE) \ + Output TYPE##_const = \ + ops::Const(s.WithOpName(#TYPE "_const"), static_cast<TYPE>(10), {5}); \ + Output TYPE##_mul = \ + ops::Mul(s.WithOpName(#TYPE "_mul"), TYPE##_const, TYPE##_const); \ + Output TYPE##_id = ops::Identity(s.WithOpName(#TYPE "_id"), TYPE##_mul) + + MAKE_TEST_GRAPH(float); + MAKE_TEST_GRAPH(double); + MAKE_TEST_GRAPH(int64); + MAKE_TEST_GRAPH(int32); + MAKE_TEST_GRAPH(int16); + MAKE_TEST_GRAPH(int8); + MAKE_TEST_GRAPH(uint8); +#undef MAKE_TEST_GRAPH + + Output bool_const = ops::Const(s.WithOpName("bool_const"), true, {5}); + Output bool_and = + ops::LogicalAnd(s.WithOpName("bool_and"), bool_const, bool_const); + Output bool_id = ops::Identity(s.WithOpName("bool_id"), bool_and); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + ConstantFolding fold(nullptr /* cpu_device */); + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(24, output.node_size()); + for (const NodeDef& node : output.node()) { +#define CHECK_RESULT(TYPE, FIELD) \ + if (node.name() == #TYPE "_mul") { \ + EXPECT_EQ(5, \ + node.attr().at("value").tensor().tensor_shape().dim(0).size()); \ + EXPECT_EQ(1, node.attr().at("value").tensor().FIELD##_val_size()); \ + EXPECT_EQ(10 * 10, node.attr().at("value").tensor().FIELD##_val(0)); \ + } + + CHECK_RESULT(float, float); + CHECK_RESULT(double, double); + CHECK_RESULT(int64, int64); + CHECK_RESULT(int32, int); + CHECK_RESULT(int16, int); + CHECK_RESULT(int8, int); + CHECK_RESULT(uint8, int); +#undef CHECK_RESULT + + if (node.name() == "bool_and") { + EXPECT_EQ(5, + node.attr().at("value").tensor().tensor_shape().dim(0).size()); + EXPECT_EQ(1, node.attr().at("value").tensor().bool_val_size()); + EXPECT_EQ(true && true, node.attr().at("value").tensor().bool_val(0)); + } + } +} + TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) { // Build a simple graph with a few trivially prunable ops. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output a = ops::Const(s.WithOpName("a"), 10, {3}); + Output a = ops::Const(s.WithOpName("a"), 10, {5}); auto b = ops::Unique(s.WithOpName("b"), {a}); Output c = ops::Identity(s.WithOpName("c"), {b.y}); Output d = ops::Identity(s.WithOpName("d"), {b.idx}); @@ -1059,3 +1118,5 @@ TEST_F(ConstantFoldingTest, MaterializeReductionIndices) { } // namespace } // namespace grappler } // namespace tensorflow + +// LocalWords: NewRootScope |