author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-11-30 17:37:21 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-11-30 17:40:33 -0800
commit    7ab54c4c48f35a4107e6170cefe5c93245595601 (patch)
tree      b1d471137286ecc081de7321ed396af98a5733f7
parent    b2db981a6731e978453862a73dab892bc674db68 (diff)
Support compressed TensorProto format in constant folding for types int16, int8, uint8, and bool, in addition to float, double, int32, and int64, which were already supported.
Add unit test for all types.

PiperOrigin-RevId: 177533200
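The "compressed" format here is a TensorProto whose repeated value field holds a single element that is splat to the full shape when the tensor is materialized. A minimal sketch of that encoding, independent of this patch (Tensor::FromProto and the splat semantics are standard TensorFlow behavior; the function name is illustrative):

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"

// Builds a compressed int16 constant: shape {5}, one stored value.
tensorflow::Tensor MakeSplatExample() {
  tensorflow::TensorProto proto;
  proto.set_dtype(tensorflow::DT_INT16);
  proto.mutable_tensor_shape()->add_dim()->set_size(5);
  proto.add_int_val(100);  // int16/int8/uint8 values travel in int_val.

  tensorflow::Tensor t;
  // FromProto expands the single stored value to all five elements.
  CHECK(t.FromProto(proto));
  return t;  // [100, 100, 100, 100, 100]
}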
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc       | 24
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding_test.cc  | 63
2 files changed, 80 insertions(+), 7 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index cf913d6f48..e0f39c2931 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -657,9 +657,9 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
namespace {
-#define SET_TENSOR_VAL_CASE(DTYPE, TYPE) \
+#define SET_TENSOR_VAL_CASE(DTYPE, TYPE, NAME) \
case DTYPE: \
- t->add_##TYPE##_val(static_cast<TYPE>(value)); \
+ t->add_##NAME##_val(static_cast<TYPE>(value)); \
break;
Status CreateConstantTensorAttrValue(DataType type, double value,
@@ -668,10 +668,14 @@ Status CreateConstantTensorAttrValue(DataType type, double value,
TensorProto* t = attr_tensor->mutable_tensor();
*t->mutable_tensor_shape() = shape;
switch (type) {
- SET_TENSOR_VAL_CASE(DT_FLOAT, float);
- SET_TENSOR_VAL_CASE(DT_DOUBLE, double);
- SET_TENSOR_VAL_CASE(DT_INT64, int64);
- SET_TENSOR_VAL_CASE(DT_INT32, int);
+ SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
+ SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
+ SET_TENSOR_VAL_CASE(DT_INT64, int64, int64);
+ SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
+ SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
+ SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
+ SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
+ SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
default:
return errors::InvalidArgument("Unsupported type: ", type);
}
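For reference, TensorProto has no dedicated int16/int8/uint8 value fields; those dtypes all share the repeated int32 `int_val` field, which is why the new cases pass `int` as NAME while keeping `int32` as the cast TYPE. The DT_INT16 case above expands mechanically to:

case DT_INT16:
  t->add_int_val(static_cast<int32>(value));
  break;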
@@ -721,6 +725,14 @@ NodeDef ConstantFolding::CreateNodeDef(const string& name,
POPULATE_TENSOR_PROTO(tensor, t, int64, int64)
} else if (tensor->dtype() == DT_INT32) {
POPULATE_TENSOR_PROTO(tensor, t, int32, int)
+ } else if (tensor->dtype() == DT_INT16) {
+ POPULATE_TENSOR_PROTO(tensor, t, int16, int)
+ } else if (tensor->dtype() == DT_INT8) {
+ POPULATE_TENSOR_PROTO(tensor, t, int8, int)
+ } else if (tensor->dtype() == DT_UINT8) {
+ POPULATE_TENSOR_PROTO(tensor, t, uint8, int)
+ } else if (tensor->dtype() == DT_BOOL) {
+ POPULATE_TENSOR_PROTO(tensor, t, bool, bool)
}
}
if (optimized) {
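POPULATE_TENSOR_PROTO is defined earlier in constant_folding.cc and is not shown in this hunk; conceptually, each new branch copies the tensor's values into the matching repeated field, storing a single element when every entry is equal. A simplified, hypothetical stand-in for the DT_INT16 branch (the name and structure are illustrative, not the macro's actual body):

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"

void PopulateInt16Val(const tensorflow::Tensor& tensor,
                      tensorflow::TensorProto* t) {
  const auto flat = tensor.flat<tensorflow::int16>();
  bool all_equal = true;
  for (int i = 1; i < flat.size(); ++i) {
    if (flat(i) != flat(0)) { all_equal = false; break; }
  }
  if (all_equal) {
    t->add_int_val(flat(0));  // compressed: one value, splat on read
  } else {
    for (int i = 0; i < flat.size(); ++i) t->add_int_val(flat(i));
  }
}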
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index c72ed96520..32a691d3ee 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -173,11 +173,70 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
}
}
+TEST_F(ConstantFoldingTest, CreateConstNodes) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+#define MAKE_TEST_GRAPH(TYPE) \
+ Output TYPE##_const = \
+ ops::Const(s.WithOpName(#TYPE "_const"), static_cast<TYPE>(10), {5}); \
+ Output TYPE##_mul = \
+ ops::Mul(s.WithOpName(#TYPE "_mul"), TYPE##_const, TYPE##_const); \
+ Output TYPE##_id = ops::Identity(s.WithOpName(#TYPE "_id"), TYPE##_mul)
+
+ MAKE_TEST_GRAPH(float);
+ MAKE_TEST_GRAPH(double);
+ MAKE_TEST_GRAPH(int64);
+ MAKE_TEST_GRAPH(int32);
+ MAKE_TEST_GRAPH(int16);
+ MAKE_TEST_GRAPH(int8);
+ MAKE_TEST_GRAPH(uint8);
+#undef MAKE_TEST_GRAPH
+
+ Output bool_const = ops::Const(s.WithOpName("bool_const"), true, {5});
+ Output bool_and =
+ ops::LogicalAnd(s.WithOpName("bool_and"), bool_const, bool_const);
+ Output bool_id = ops::Identity(s.WithOpName("bool_id"), bool_and);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ ConstantFolding fold(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = fold.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
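+  // 7 typed subgraphs x (const, mul, id) + 3 bool nodes = 24 nodes.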
+ EXPECT_EQ(24, output.node_size());
+ for (const NodeDef& node : output.node()) {
+#define CHECK_RESULT(TYPE, FIELD) \
+ if (node.name() == #TYPE "_mul") { \
+ EXPECT_EQ(5, \
+ node.attr().at("value").tensor().tensor_shape().dim(0).size()); \
+ EXPECT_EQ(1, node.attr().at("value").tensor().FIELD##_val_size()); \
+ EXPECT_EQ(10 * 10, node.attr().at("value").tensor().FIELD##_val(0)); \
+ }
+
+ CHECK_RESULT(float, float);
+ CHECK_RESULT(double, double);
+ CHECK_RESULT(int64, int64);
+ CHECK_RESULT(int32, int);
+ CHECK_RESULT(int16, int);
+ CHECK_RESULT(int8, int);
+ CHECK_RESULT(uint8, int);
+#undef CHECK_RESULT
+
+ if (node.name() == "bool_and") {
+ EXPECT_EQ(5,
+ node.attr().at("value").tensor().tensor_shape().dim(0).size());
+ EXPECT_EQ(1, node.attr().at("value").tensor().bool_val_size());
+ EXPECT_EQ(true && true, node.attr().at("value").tensor().bool_val(0));
+ }
+ }
+}
+
TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
// Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output a = ops::Const(s.WithOpName("a"), 10, {3});
+ Output a = ops::Const(s.WithOpName("a"), 10, {5});
auto b = ops::Unique(s.WithOpName("b"), {a});
Output c = ops::Identity(s.WithOpName("c"), {b.y});
Output d = ops::Identity(s.WithOpName("d"), {b.idx});
@@ -1059,3 +1118,5 @@ TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
} // namespace
} // namespace grappler
} // namespace tensorflow
+
+// LocalWords: NewRootScope