aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-20 13:56:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 14:01:17 -0700
commit17dbe77f5ad47e8fd71924f12b3bc53c05afbacf (patch)
tree46142d37c97ca378139cb73785171903a74f3516
parentd388770922ad1afa95e55597a33836fe74035c75 (diff)
Fix bug in Pow optimizer rule when broadcasting is involved.
Minor cleanup by moving the helper function ShapesEqual to GraphProperties and adding unit tests for it. PiperOrigin-RevId: 213876779
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD35
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc61
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc19
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc27
-rw-r--r--tensorflow/core/grappler/optimizers/shape_optimizer.cc3
-rw-r--r--tensorflow/core/grappler/utils/BUILD29
-rw-r--r--tensorflow/core/grappler/utils/symbolic_shapes.cc (renamed from tensorflow/core/grappler/optimizers/symbolic_shapes.cc)2
-rw-r--r--tensorflow/core/grappler/utils/symbolic_shapes.h (renamed from tensorflow/core/grappler/optimizers/symbolic_shapes.h)6
-rw-r--r--tensorflow/core/grappler/utils/symbolic_shapes_test.cc (renamed from tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc)2
9 files changed, 89 insertions, 95 deletions
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 029205248b..261dee4382 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -8,10 +8,6 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
# Platform specific build config
load(
- "//tensorflow/core:platform/default/build_config.bzl",
- "tf_protos_grappler",
-)
-load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"if_static",
)
@@ -97,7 +93,6 @@ cc_library(
deps = [
":evaluation_utils",
":graph_optimizer",
- ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -107,6 +102,7 @@ cc_library(
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
],
)
@@ -261,7 +257,6 @@ cc_library(
":constant_folding",
":graph_optimizer",
":graph_optimizer_stage",
- ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -270,6 +265,7 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
"//tensorflow/core/grappler/utils:topological_sort",
],
)
@@ -648,7 +644,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":graph_optimizer",
- ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@@ -658,6 +653,7 @@ cc_library(
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:frame",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
],
)
@@ -715,31 +711,6 @@ tf_cuda_cc_test(
)
cc_library(
- name = "symbolic_shapes",
- srcs = ["symbolic_shapes.cc"],
- hdrs = ["symbolic_shapes.h"],
- visibility = ["//visibility:public"],
- deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- ] + tf_protos_grappler(),
-)
-
-tf_cc_test(
- name = "symbolic_shapes_test",
- srcs = ["symbolic_shapes_test.cc"],
- deps = [
- ":symbolic_shapes",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
-
-cc_library(
name = "debug_stripper",
srcs = ["debug_stripper.cc"],
hdrs = [
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 992e85d2c6..76a9dca73b 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -35,8 +35,8 @@ limitations under the License.
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -2367,26 +2367,24 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1];
- for (int i = 0; i < p.shape().dim_size(); ++i) {
- if (p.shape().dim(i).size() < 0) {
+ const auto& pow_props =
+ ctx().graph_properties->GetInputProperties(node->name())[1];
+ for (int i = 0; i < pow_props.shape().dim_size(); ++i) {
+ if (pow_props.shape().dim(i).size() < 0) {
// skip if p is is not fully defined.
return Status::OK();
}
}
- if (TensorShape::IsValid(p.shape()) && p.has_value()) {
- Tensor pow(p.dtype(), p.shape());
- if (!pow.FromProto(p.value())) {
+ if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) {
+ Tensor pow(pow_props.dtype(), pow_props.shape());
+ if (!pow.FromProto(pow_props.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
- p.value().DebugString());
+ pow_props.value().DebugString());
}
complex128 prev, curr;
for (int i = 0; i < pow.NumElements(); ++i) {
- if (!GetElementUnexhaustive(pow, i,
- {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_COMPLEX128},
- &curr)) {
+ if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) {
// input data type is not supported by Pow. Skip.
return Status::OK();
}
@@ -2399,12 +2397,19 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
NodeDef *x, *y;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
+ const auto& value_props =
+ ctx().graph_properties->GetInputProperties(node->name())[0];
+ const TensorShapeProto& output_shape =
+ ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
if (curr == complex128(2, 0)) {
node->set_op("Square");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
- } else if (curr == complex128(1, 0)) {
+ } else if (curr == complex128(1, 0) &&
+ ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
+ // Pow could be used to broadcast, so make sure the shapes of the two
+ // arguments are identical before replacing Pow with Identity.
node->set_op("Identity");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
@@ -2414,20 +2419,20 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
- } else if (curr == complex128(0, 0)) {
- const auto& b =
- ctx().graph_properties->GetInputProperties(node->name())[0];
- for (int i = 0; i < b.shape().dim_size(); ++i) {
- if (b.shape().dim(i).size() < 0) {
+ } else if (curr == complex128(0, 0) &&
+ ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
+ for (int i = 0; i < value_props.shape().dim_size(); ++i) {
+ if (value_props.shape().dim(i).size() < 0) {
// skip if b is is not fully defined.
return Status::OK();
}
}
- if (TensorShape::IsValid(b.shape()) && b.has_value()) {
- Tensor base(b.dtype(), b.shape());
- if (!base.FromProto(b.value())) {
+ if (TensorShape::IsValid(value_props.shape()) &&
+ value_props.has_value()) {
+ Tensor base(value_props.dtype(), value_props.shape());
+ if (!base.FromProto(value_props.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
- b.value().DebugString());
+ value_props.value().DebugString());
}
node->set_op("Const");
Tensor c(base.dtype(), base.shape());
@@ -2585,12 +2590,10 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
~ConvertExpm1Stage() override = default;
bool IsSupported(const NodeDef* node) const override {
- if (!IsSub(*node))
- return false;
+ if (!IsSub(*node)) return false;
NodeDef* input;
- if (!GetInputNode(node->input(0), &input).ok())
- return false;
+ if (!GetInputNode(node->input(0), &input).ok()) return false;
return IsExp(*input);
}
@@ -2610,10 +2613,8 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
return Status::OK();
}
- const auto& t =
- ctx().graph_properties->GetInputProperties(exp->name())[0];
- const auto& c =
- ctx().graph_properties->GetInputProperties(node->name())[1];
+ const auto& t = ctx().graph_properties->GetInputProperties(exp->name())[0];
+ const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
for (int k = 0; k < c.shape().dim_size(); ++k) {
// Skip if c shape is not fully determined.
if (c.shape().dim(k).size() < 0) {
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 88839d944c..77f3c64c65 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -2474,6 +2474,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
+ auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
+ auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
+ auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
@@ -2481,21 +2484,24 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
Output out = ops::Pow(s.WithOpName("out"), x, y);
+ Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
+ Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);
GrapplerItem item;
- item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"};
+ item.fetch = {"out2", "out1", "out.5", "out0", "out_.5",
+ "out_1", "out", "out_bcast1", "out_bcast2"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
- EXPECT_EQ(7, tensors_expected.size());
+ EXPECT_EQ(9, tensors_expected.size());
GraphDef got;
ArithmeticOptimizer optimizer;
EnableOnlyConvertPow(&optimizer);
OptimizeAndPrune(&optimizer, &item, &got);
auto tensors = EvaluateNodes(got, item.fetch);
- EXPECT_EQ(7, tensors.size());
+ EXPECT_EQ(9, tensors.size());
- for (int i = 0; i < 7; ++i) {
+ for (int i = 0; i < tensors.size(); ++i) {
EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
}
@@ -2509,6 +2515,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
AddNode("y_.5", "Const", {}, {}, &want);
AddNode("y_1", "Const", {}, {}, &want);
AddNode("y", "Const", {}, {}, &want);
+ AddNode("z", "Const", {}, {}, &want);
+ AddNode("ones", "Const", {}, {}, &want);
+ AddNode("zeros", "Const", {}, {}, &want);
AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want);
AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want);
AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want);
@@ -2517,6 +2526,8 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want);
AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want);
AddNode("out", "Pow", {"x", "y"}, {}, &want);
+ AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
+ AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);
CompareGraphs(want, got);
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 99737a71eb..cfbd298f11 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -32,8 +32,8 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -437,25 +437,6 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
}
namespace {
-bool ShapesEqual(const TensorShapeProto& shape1,
- const TensorShapeProto& shape2) {
- if (shape1.unknown_rank() || shape2.unknown_rank()) {
- return false;
- }
- if (shape1.dim_size() != shape2.dim_size()) {
- return false;
- }
- for (int i = 0; i < shape1.dim_size(); ++i) {
- if (shape1.dim(i).size() != shape2.dim(i).size()) {
- return false;
- }
- if (shape1.dim(i).size() == -1 || shape2.dim(i).size() == -1) {
- return false;
- }
- }
- return true;
-}
-
bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
BCast::Vec* shape, int64* min_id) {
if (shape_node.op() == "Shape") {
@@ -2348,7 +2329,8 @@ Status ConstantFolding::SimplifyArithmeticOperations(
properties.GetInputProperties(node->name())[1].shape();
const bool x_is_zero = IsZeros(*x);
const bool x_is_one = x_is_zero ? false : IsOnes(*x);
- const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
+ const bool y_matches_output_shape =
+ ShapesSymbolicallyEqual(output_shape, y_shape);
if (y_matches_output_shape &&
((is_mul && x_is_one) || (is_add && x_is_zero))) {
// 1 * y = y or 0 + y = y.
@@ -2378,7 +2360,8 @@ Status ConstantFolding::SimplifyArithmeticOperations(
properties.GetInputProperties(node->name())[0].shape();
const bool y_is_zero = IsZeros(*y);
const bool y_is_one = y_is_zero ? false : IsOnes(*y);
- const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
+ const bool x_matches_output_shape =
+ ShapesSymbolicallyEqual(output_shape, x_shape);
if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
((is_add || is_sub) && y_is_zero))) {
// x * 1 = x or x / 1 = x or x +/- 0 = x
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
index caa0b7b0cb..4542d17ccc 100644
--- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
@@ -20,10 +20,9 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
-
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index e540cc0476..bdbb8836e1 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -1,6 +1,10 @@
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_protos_grappler",
+)
cc_library(
name = "scc",
@@ -210,3 +214,28 @@ tf_cc_test(
"//tensorflow/core:testlib",
],
)
+
+cc_library(
+ name = "symbolic_shapes",
+ srcs = ["symbolic_shapes.cc"],
+ hdrs = ["symbolic_shapes.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ] + tf_protos_grappler(),
+)
+
+tf_cc_test(
+ name = "symbolic_shapes_test",
+ srcs = ["symbolic_shapes_test.cc"],
+ deps = [
+ ":symbolic_shapes",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc b/tensorflow/core/grappler/utils/symbolic_shapes.cc
index 155843a744..1666de4b80 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc
+++ b/tensorflow/core/grappler/utils/symbolic_shapes.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.h b/tensorflow/core/grappler/utils/symbolic_shapes.h
index ace7bd1fe7..0a7d8ac82b 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes.h
+++ b/tensorflow/core/grappler/utils/symbolic_shapes.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
-#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
+#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
+#define TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
@@ -74,4 +74,4 @@ int64 ComputeSizeRatio(const TensorShapeProto& numerator,
} // namespace grappler
} // end namespace tensorflow
-#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
+#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc b/tensorflow/core/grappler/utils/symbolic_shapes_test.cc
index 7ce995d1c5..6ac644cdb1 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc
+++ b/tensorflow/core/grappler/utils/symbolic_shapes_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/test.h"