diff options
author | 2017-10-09 17:05:20 -0700 | |
---|---|---|
committer | 2017-10-09 17:16:15 -0700 | |
commit | 8ff5070392bd0066930d11e3e39d21d3fa84bb2e (patch) | |
tree | 5f06f35f6a4c16f903cab24b08f18062fc95faf8 | |
parent | fdb2b12d1ad84392df09dc5dcd457ca7e96cb423 (diff) |
[Grappler] Optimize bitcasts.
Two optimizations:
1. If dst_type == type(x), Bitcast(x, dst_type) => No-op
2. Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
PiperOrigin-RevId: 171608976
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 68 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 61 |
2 files changed, 127 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 3ec62b5a00..971163eadf 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -289,6 +289,44 @@ static DataType GetDataTypeFromAttr(const NodeDef& node, return attr.type(); } +static void SetDataTypeToAttr(DataType dtype, const string& attr_name, + NodeDef* node) { + (*node->mutable_attr())[attr_name].set_type(dtype); +} + +static string SourceDataTypeAttrName(const NodeDef& node) { + if (node.op() == "Bitcast") { + return "T"; + } else if (node.op() == "Cast") { + return "SrcT"; + } else { + LOG(FATAL) << "SourceDataTypeAttrName not implemented for op " << node.op(); + } +} + +static string DestinationDataTypeAttrName(const NodeDef& node) { + if (node.op() == "Bitcast") { + return "type"; + } else if (node.op() == "Cast") { + return "DstT"; + } else { + LOG(FATAL) << "DestinationDataTypeAttrName not implemented for op " + << node.op(); + } +} + +static DataType GetSourceDataType(const NodeDef& node) { + return GetDataTypeFromAttr(node, SourceDataTypeAttrName(node)); +} + +static DataType GetDestinationDataType(const NodeDef& node) { + return GetDataTypeFromAttr(node, DestinationDataTypeAttrName(node)); +} + +static void SetSourceDataType(DataType dtype, NodeDef* node) { + SetDataTypeToAttr(dtype, SourceDataTypeAttrName(*node), node); +} + static bool IsNumberType(DataType dtype) { DataTypeVector number_types = NumberTypes(); return std::find(number_types.begin(), number_types.end(), dtype) != @@ -369,8 +407,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( const NodeDef* cast = node_map->GetNode(transpose->input(0)); if (cast->op() == "Cast") { const NodeDef* input = node_map->GetNode(cast->input(0)); - const DataType src_type = GetDataTypeFromAttr(*cast, "SrcT"); - const DataType dst_type = GetDataTypeFromAttr(*cast, "DstT"); + const DataType src_type = GetSourceDataType(*cast); + const DataType dst_type = GetDestinationDataType(*cast); if (IsNumberType(src_type) && IsNumberType(dst_type) && DataTypeSize(src_type) < DataTypeSize(dst_type)) { NodeDef* new_transpose = graph_def->add_node(); @@ -401,6 +439,32 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( } } + if (node->op() == "Bitcast") { + NodeDef* bitcast = node_map->GetNode(node->name()); + // Bypass bitcasts whose source type and destination type are equal. + if (GetSourceDataType(*bitcast) == GetDestinationDataType(*bitcast)) { + return bitcast->input(0); + } + + const NodeDef* operand = node_map->GetNode(bitcast->input(0)); + if (operand->op() == bitcast->op()) { + // Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2) + bitcast->set_input(0, operand->input(0)); + SetSourceDataType(GetSourceDataType(*operand), bitcast); + node_map->UpdateInput(bitcast->name(), bitcast->input(0), + operand->input(0)); + new_nodes->push_back(bitcast); + return bitcast->name(); + } + } + + if (node->op() == "Cast") { + // Bypass casts whose source type and destination type are equal. + if (GetSourceDataType(*node) == GetDestinationDataType(*node)) { + return node->input(0); + } + } + // Fold a multiply of a scalar into the following convolution. This folding // can jump across nodes that merely reorders data (such as reshape and // transpose). For example, we can optimize diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 234c096073..39b4999808 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -450,6 +450,67 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) { EXPECT_EQ(conv_node->input(1), weights_node->name()); } +TEST_F(ArithmeticOptimizerTest, CombineBitcasts) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output inputs = + ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({2, 3})); + Output bc1 = ops::Bitcast(s, inputs, DT_QINT8); + Output bc2 = ops::Bitcast(s, bc1, DT_INT8); + Output outputs = ops::Identity(s.WithOpName("outputs"), bc2); + + GrapplerItem item; + item.fetch = {"outputs"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphDef output; + TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); + item.graph = output; + TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + + EXPECT_EQ(1, std::count_if( + output.node().begin(), output.node().end(), + [](const NodeDef& node) { return node.op() == "Bitcast"; })); +} + +TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output inputs = ops::Placeholder(s, DT_INT8, ops::Placeholder::Shape({2, 3})); + Output bc1 = ops::Bitcast(s, inputs, DT_QINT8); + Output bc2 = ops::Bitcast(s, bc1, DT_INT8); + Output outputs = ops::Identity(s.WithOpName("outputs"), bc2); + + GrapplerItem item; + item.fetch = {"outputs"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphDef output; + TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); + item.graph = output; + TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + + EXPECT_EQ(0, std::count_if( + output.node().begin(), output.node().end(), + [](const NodeDef& node) { return node.op() == "Bitcast"; })); +} + +TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output inputs = ops::Placeholder(s, DT_INT8, ops::Placeholder::Shape({2, 3})); + Output cast = ops::Cast(s, inputs, DT_INT8); + Output outputs = ops::Identity(s.WithOpName("outputs"), cast); + + GrapplerItem item; + item.fetch = {"outputs"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphDef output; + TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); + item.graph = output; + TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + + EXPECT_EQ(0, std::count_if( + output.node().begin(), output.node().end(), + [](const NodeDef& node) { return node.op() == "Cast"; })); +} + } // namespace } // namespace grappler } // namespace tensorflow |