diff options
author | 2018-09-20 13:56:49 -0700 | |
---|---|---|
committer | 2018-09-20 14:01:17 -0700 | |
commit | 17dbe77f5ad47e8fd71924f12b3bc53c05afbacf (patch) | |
tree | 46142d37c97ca378139cb73785171903a74f3516 | |
parent | d388770922ad1afa95e55597a33836fe74035c75 (diff) |
Fix bug in Pow optimizer rule when broadcasting is involved.
Minor cleanup by moving the helper function ShapesEqual to GraphProperties and adding unit tests for it.
PiperOrigin-RevId: 213876779
-rw-r--r-- | tensorflow/core/grappler/optimizers/BUILD | 35 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 61 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 19 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.cc | 27 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/shape_optimizer.cc | 3 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/BUILD | 29 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/symbolic_shapes.cc (renamed from tensorflow/core/grappler/optimizers/symbolic_shapes.cc) | 2 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/symbolic_shapes.h (renamed from tensorflow/core/grappler/optimizers/symbolic_shapes.h) | 6 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/symbolic_shapes_test.cc (renamed from tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc) | 2 |
9 files changed, 89 insertions, 95 deletions
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 029205248b..261dee4382 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -8,10 +8,6 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") # Platform specific build config load( - "//tensorflow/core:platform/default/build_config.bzl", - "tf_protos_grappler", -) -load( "//tensorflow/core:platform/default/build_config_root.bzl", "if_static", ) @@ -97,7 +93,6 @@ cc_library( deps = [ ":evaluation_utils", ":graph_optimizer", - ":symbolic_shapes", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -107,6 +102,7 @@ cc_library( "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/utils:symbolic_shapes", ], ) @@ -261,7 +257,6 @@ cc_library( ":constant_folding", ":graph_optimizer", ":graph_optimizer_stage", - ":symbolic_shapes", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -270,6 +265,7 @@ cc_library( "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/utils:symbolic_shapes", "//tensorflow/core/grappler/utils:topological_sort", ], ) @@ -648,7 +644,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":graph_optimizer", - ":symbolic_shapes", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -658,6 +653,7 @@ cc_library( "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/grappler/utils:frame", + "//tensorflow/core/grappler/utils:symbolic_shapes", ], ) @@ -715,31 +711,6 @@ tf_cuda_cc_test( ) cc_library( - name = "symbolic_shapes", - srcs = ["symbolic_shapes.cc"], - hdrs = ["symbolic_shapes.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ] + tf_protos_grappler(), -) - -tf_cc_test( - name = "symbolic_shapes_test", - srcs = ["symbolic_shapes_test.cc"], - deps = [ - ":symbolic_shapes", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -cc_library( name = "debug_stripper", srcs = ["debug_stripper.cc"], hdrs = [ diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 992e85d2c6..76a9dca73b 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -35,8 +35,8 @@ limitations under the License. #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h" -#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/symbolic_shapes.h" #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -2367,26 +2367,24 @@ class ConvertPowStage : public ArithmeticOptimizerStage { } Status TrySimplify(NodeDef* node, string* simplified_node_name) override { - const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1]; - for (int i = 0; i < p.shape().dim_size(); ++i) { - if (p.shape().dim(i).size() < 0) { + const auto& pow_props = + ctx().graph_properties->GetInputProperties(node->name())[1]; + for (int i = 0; i < pow_props.shape().dim_size(); ++i) { + if (pow_props.shape().dim(i).size() < 0) { // skip if p is is not fully defined. return Status::OK(); } } - if (TensorShape::IsValid(p.shape()) && p.has_value()) { - Tensor pow(p.dtype(), p.shape()); - if (!pow.FromProto(p.value())) { + if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) { + Tensor pow(pow_props.dtype(), pow_props.shape()); + if (!pow.FromProto(pow_props.value())) { return errors::InvalidArgument("Cannot parse tensor from proto: ", - p.value().DebugString()); + pow_props.value().DebugString()); } complex128 prev, curr; for (int i = 0; i < pow.NumElements(); ++i) { - if (!GetElementUnexhaustive(pow, i, - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_COMPLEX128}, - &curr)) { + if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) { // input data type is not supported by Pow. Skip. return Status::OK(); } @@ -2399,12 +2397,19 @@ class ConvertPowStage : public ArithmeticOptimizerStage { NodeDef *x, *y; TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x)); TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y)); + const auto& value_props = + ctx().graph_properties->GetInputProperties(node->name())[0]; + const TensorShapeProto& output_shape = + ctx().graph_properties->GetOutputProperties(node->name())[0].shape(); if (curr == complex128(2, 0)) { node->set_op("Square"); node->set_input(1, AsControlDependency(y->name())); AddToOptimizationQueue(node); AddToOptimizationQueue(y); - } else if (curr == complex128(1, 0)) { + } else if (curr == complex128(1, 0) && + ShapesSymbolicallyEqual(value_props.shape(), output_shape)) { + // Pow could be used to broadcast, so make sure the shapes of the two + // arguments are identical before replacing Pow with Identity. node->set_op("Identity"); node->set_input(1, AsControlDependency(y->name())); AddToOptimizationQueue(node); @@ -2414,20 +2419,20 @@ class ConvertPowStage : public ArithmeticOptimizerStage { node->set_input(1, AsControlDependency(y->name())); AddToOptimizationQueue(node); AddToOptimizationQueue(y); - } else if (curr == complex128(0, 0)) { - const auto& b = - ctx().graph_properties->GetInputProperties(node->name())[0]; - for (int i = 0; i < b.shape().dim_size(); ++i) { - if (b.shape().dim(i).size() < 0) { + } else if (curr == complex128(0, 0) && + ShapesSymbolicallyEqual(value_props.shape(), output_shape)) { + for (int i = 0; i < value_props.shape().dim_size(); ++i) { + if (value_props.shape().dim(i).size() < 0) { // skip if b is is not fully defined. return Status::OK(); } } - if (TensorShape::IsValid(b.shape()) && b.has_value()) { - Tensor base(b.dtype(), b.shape()); - if (!base.FromProto(b.value())) { + if (TensorShape::IsValid(value_props.shape()) && + value_props.has_value()) { + Tensor base(value_props.dtype(), value_props.shape()); + if (!base.FromProto(value_props.value())) { return errors::InvalidArgument("Cannot parse tensor from proto: ", - b.value().DebugString()); + value_props.value().DebugString()); } node->set_op("Const"); Tensor c(base.dtype(), base.shape()); @@ -2585,12 +2590,10 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage { ~ConvertExpm1Stage() override = default; bool IsSupported(const NodeDef* node) const override { - if (!IsSub(*node)) - return false; + if (!IsSub(*node)) return false; NodeDef* input; - if (!GetInputNode(node->input(0), &input).ok()) - return false; + if (!GetInputNode(node->input(0), &input).ok()) return false; return IsExp(*input); } @@ -2610,10 +2613,8 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage { return Status::OK(); } - const auto& t = - ctx().graph_properties->GetInputProperties(exp->name())[0]; - const auto& c = - ctx().graph_properties->GetInputProperties(node->name())[1]; + const auto& t = ctx().graph_properties->GetInputProperties(exp->name())[0]; + const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1]; for (int k = 0; k < c.shape().dim_size(); ++k) { // Skip if c shape is not fully determined. if (c.shape().dim(k).size() < 0) { diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 88839d944c..77f3c64c65 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -2474,6 +2474,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) { auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2}); auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2}); auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2}); + auto z = ops::Const(s.WithOpName("z"), {42.0f}, {}); + auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3}); + auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3}); Output out2 = ops::Pow(s.WithOpName("out2"), x, y2); Output out1 = ops::Pow(s.WithOpName("out1"), x, y1); Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5); @@ -2481,21 +2484,24 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) { Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5); Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1); Output out = ops::Pow(s.WithOpName("out"), x, y); + Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones); + Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros); GrapplerItem item; - item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"}; + item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", + "out_1", "out", "out_bcast1", "out_bcast2"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - EXPECT_EQ(7, tensors_expected.size()); + EXPECT_EQ(9, tensors_expected.size()); GraphDef got; ArithmeticOptimizer optimizer; EnableOnlyConvertPow(&optimizer); OptimizeAndPrune(&optimizer, &item, &got); auto tensors = EvaluateNodes(got, item.fetch); - EXPECT_EQ(7, tensors.size()); + EXPECT_EQ(9, tensors.size()); - for (int i = 0; i < 7; ++i) { + for (int i = 0; i < tensors.size(); ++i) { EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements()); test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6); } @@ -2509,6 +2515,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) { AddNode("y_.5", "Const", {}, {}, &want); AddNode("y_1", "Const", {}, {}, &want); AddNode("y", "Const", {}, {}, &want); + AddNode("z", "Const", {}, {}, &want); + AddNode("ones", "Const", {}, {}, &want); + AddNode("zeros", "Const", {}, {}, &want); AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want); AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want); AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want); @@ -2517,6 +2526,8 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) { AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want); AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want); AddNode("out", "Pow", {"x", "y"}, {}, &want); + AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want); + AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want); CompareGraphs(want, got); } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 99737a71eb..cfbd298f11 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -32,8 +32,8 @@ limitations under the License. #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/evaluation_utils.h" -#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/symbolic_shapes.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -437,25 +437,6 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { } namespace { -bool ShapesEqual(const TensorShapeProto& shape1, - const TensorShapeProto& shape2) { - if (shape1.unknown_rank() || shape2.unknown_rank()) { - return false; - } - if (shape1.dim_size() != shape2.dim_size()) { - return false; - } - for (int i = 0; i < shape1.dim_size(); ++i) { - if (shape1.dim(i).size() != shape2.dim(i).size()) { - return false; - } - if (shape1.dim(i).size() == -1 || shape2.dim(i).size() == -1) { - return false; - } - } - return true; -} - bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties, BCast::Vec* shape, int64* min_id) { if (shape_node.op() == "Shape") { @@ -2348,7 +2329,8 @@ Status ConstantFolding::SimplifyArithmeticOperations( properties.GetInputProperties(node->name())[1].shape(); const bool x_is_zero = IsZeros(*x); const bool x_is_one = x_is_zero ? false : IsOnes(*x); - const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape); + const bool y_matches_output_shape = + ShapesSymbolicallyEqual(output_shape, y_shape); if (y_matches_output_shape && ((is_mul && x_is_one) || (is_add && x_is_zero))) { // 1 * y = y or 0 + y = y. @@ -2378,7 +2360,8 @@ Status ConstantFolding::SimplifyArithmeticOperations( properties.GetInputProperties(node->name())[0].shape(); const bool y_is_zero = IsZeros(*y); const bool y_is_one = y_is_zero ? false : IsOnes(*y); - const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape); + const bool x_matches_output_shape = + ShapesSymbolicallyEqual(output_shape, x_shape); if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) || ((is_add || is_sub) && y_is_zero))) { // x * 1 = x or x / 1 = x or x +/- 0 = x diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc index caa0b7b0cb..4542d17ccc 100644 --- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc @@ -20,10 +20,9 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/grappler_item.h" -#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h" - #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/symbolic_shapes.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index e540cc0476..bdbb8836e1 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -1,6 +1,10 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_protos_grappler", +) cc_library( name = "scc", @@ -210,3 +214,28 @@ tf_cc_test( "//tensorflow/core:testlib", ], ) + +cc_library( + name = "symbolic_shapes", + srcs = ["symbolic_shapes.cc"], + hdrs = ["symbolic_shapes.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ] + tf_protos_grappler(), +) + +tf_cc_test( + name = "symbolic_shapes_test", + srcs = ["symbolic_shapes_test.cc"], + deps = [ + ":symbolic_shapes", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc b/tensorflow/core/grappler/utils/symbolic_shapes.cc index 155843a744..1666de4b80 100644 --- a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc +++ b/tensorflow/core/grappler/utils/symbolic_shapes.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h" +#include "tensorflow/core/grappler/utils/symbolic_shapes.h" #include "tensorflow/core/util/bcast.h" namespace tensorflow { diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.h b/tensorflow/core/grappler/utils/symbolic_shapes.h index ace7bd1fe7..0a7d8ac82b 100644 --- a/tensorflow/core/grappler/optimizers/symbolic_shapes.h +++ b/tensorflow/core/grappler/utils/symbolic_shapes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_ -#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_ +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_ #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/costs/op_performance_data.pb.h" @@ -74,4 +74,4 @@ int64 ComputeSizeRatio(const TensorShapeProto& numerator, } // namespace grappler } // end namespace tensorflow -#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_ +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_ diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc b/tensorflow/core/grappler/utils/symbolic_shapes_test.cc index 7ce995d1c5..6ac644cdb1 100644 --- a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc +++ b/tensorflow/core/grappler/utils/symbolic_shapes_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h" +#include "tensorflow/core/grappler/utils/symbolic_shapes.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/platform/test.h" |