about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author: Jingyue Wu <jingyue@google.com> 2017-10-09 17:05:20 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-10-09 17:16:15 -0700
commit 8ff5070392bd0066930d11e3e39d21d3fa84bb2e (patch)
tree 5f06f35f6a4c16f903cab24b08f18062fc95faf8
parent fdb2b12d1ad84392df09dc5dcd457ca7e96cb423 (diff)
[Grappler] Optimize bitcasts.
Two optimizations:
1. If dst_type == type(x), Bitcast(x, dst_type) => No-op
2. Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)

PiperOrigin-RevId: 171608976
-rw-r--r-- tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 68
-rw-r--r-- tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 61
2 files changed, 127 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 3ec62b5a00..971163eadf 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -289,6 +289,44 @@ static DataType GetDataTypeFromAttr(const NodeDef& node,
return attr.type();
}
+static void SetDataTypeToAttr(DataType dtype, const string& attr_name,
+ NodeDef* node) {
+ (*node->mutable_attr())[attr_name].set_type(dtype);
+}
+
+static string SourceDataTypeAttrName(const NodeDef& node) {
+ if (node.op() == "Bitcast") {
+ return "T";
+ } else if (node.op() == "Cast") {
+ return "SrcT";
+ } else {
+ LOG(FATAL) << "SourceDataTypeAttrName not implemented for op " << node.op();
+ }
+}
+
+static string DestinationDataTypeAttrName(const NodeDef& node) {
+ if (node.op() == "Bitcast") {
+ return "type";
+ } else if (node.op() == "Cast") {
+ return "DstT";
+ } else {
+ LOG(FATAL) << "DestinationDataTypeAttrName not implemented for op "
+ << node.op();
+ }
+}
+
+static DataType GetSourceDataType(const NodeDef& node) {
+ return GetDataTypeFromAttr(node, SourceDataTypeAttrName(node));
+}
+
+static DataType GetDestinationDataType(const NodeDef& node) {
+ return GetDataTypeFromAttr(node, DestinationDataTypeAttrName(node));
+}
+
+static void SetSourceDataType(DataType dtype, NodeDef* node) {
+ SetDataTypeToAttr(dtype, SourceDataTypeAttrName(*node), node);
+}
+
static bool IsNumberType(DataType dtype) {
DataTypeVector number_types = NumberTypes();
return std::find(number_types.begin(), number_types.end(), dtype) !=
@@ -369,8 +407,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
const NodeDef* cast = node_map->GetNode(transpose->input(0));
if (cast->op() == "Cast") {
const NodeDef* input = node_map->GetNode(cast->input(0));
- const DataType src_type = GetDataTypeFromAttr(*cast, "SrcT");
- const DataType dst_type = GetDataTypeFromAttr(*cast, "DstT");
+ const DataType src_type = GetSourceDataType(*cast);
+ const DataType dst_type = GetDestinationDataType(*cast);
if (IsNumberType(src_type) && IsNumberType(dst_type) &&
DataTypeSize(src_type) < DataTypeSize(dst_type)) {
NodeDef* new_transpose = graph_def->add_node();
@@ -401,6 +439,32 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
}
+ if (node->op() == "Bitcast") {
+ NodeDef* bitcast = node_map->GetNode(node->name());
+ // Bypass bitcasts whose source type and destination type are equal.
+ if (GetSourceDataType(*bitcast) == GetDestinationDataType(*bitcast)) {
+ return bitcast->input(0);
+ }
+
+ const NodeDef* operand = node_map->GetNode(bitcast->input(0));
+ if (operand->op() == bitcast->op()) {
+ // Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
+ bitcast->set_input(0, operand->input(0));
+ SetSourceDataType(GetSourceDataType(*operand), bitcast);
+ node_map->UpdateInput(bitcast->name(), bitcast->input(0),
+ operand->input(0));
+ new_nodes->push_back(bitcast);
+ return bitcast->name();
+ }
+ }
+
+ if (node->op() == "Cast") {
+ // Bypass casts whose source type and destination type are equal.
+ if (GetSourceDataType(*node) == GetDestinationDataType(*node)) {
+ return node->input(0);
+ }
+ }
+
// Fold a multiply of a scalar into the following convolution. This folding
// can jump across nodes that merely reorders data (such as reshape and
// transpose). For example, we can optimize
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 234c096073..39b4999808 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -450,6 +450,67 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
EXPECT_EQ(conv_node->input(1), weights_node->name());
}
+TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output inputs =
+ ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({2, 3}));
+ Output bc1 = ops::Bitcast(s, inputs, DT_QINT8);
+ Output bc2 = ops::Bitcast(s, bc1, DT_INT8);
+ Output outputs = ops::Identity(s.WithOpName("outputs"), bc2);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
+ item.graph = output;
+ TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(1, std::count_if(
+ output.node().begin(), output.node().end(),
+ [](const NodeDef& node) { return node.op() == "Bitcast"; }));
+}
+
+TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output inputs = ops::Placeholder(s, DT_INT8, ops::Placeholder::Shape({2, 3}));
+ Output bc1 = ops::Bitcast(s, inputs, DT_QINT8);
+ Output bc2 = ops::Bitcast(s, bc1, DT_INT8);
+ Output outputs = ops::Identity(s.WithOpName("outputs"), bc2);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphDef output;
+ TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
+ item.graph = output;
+ TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(0, std::count_if(
+ output.node().begin(), output.node().end(),
+ [](const NodeDef& node) { return node.op() == "Bitcast"; }));
+}
+
+TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output inputs = ops::Placeholder(s, DT_INT8, ops::Placeholder::Shape({2, 3}));
+ Output cast = ops::Cast(s, inputs, DT_INT8);
+ Output outputs = ops::Identity(s.WithOpName("outputs"), cast);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphDef output;
+ TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
+ item.graph = output;
+ TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(0, std::count_if(
+ output.node().begin(), output.node().end(),
+ [](const NodeDef& node) { return node.op() == "Cast"; }));
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow