aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-01 02:41:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 02:45:01 -0700
commit0fd21d8c34e15bc3013e93014d101b672e1f3687 (patch)
tree207d412e7af182fd0c4f4a2566fdccce1fd760e1 /tensorflow/compiler
parent03c5f9cdce62f6711b91fe81505e3c085e54a771 (diff)
[TF:XLA] Teach deadness analysis more of distributive property.
PiperOrigin-RevId: 215183847
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/jit/deadness_analysis.cc107
-rw-r--r--tensorflow/compiler/jit/deadness_analysis_test.cc31
2 files changed, 112 insertions, 26 deletions
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index 9128b48da3..25e2e9a7af 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/deadness_analysis.h"
+#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/core/graph/algorithm.h"
@@ -383,6 +384,8 @@ class PredicateFactory {
}
Predicate* MakeAndOrImpl(absl::Span<Predicate* const> operands, bool is_and);
+ Predicate* MakeInternedAndOr(std::vector<Predicate*> simplified_ops,
+ Predicate::Kind pred_kind);
// Predicate instances are interned, meaning that there is only a single
// instance of a Predicate object with a given content. This makes checking
@@ -429,11 +432,40 @@ class PredicateFactory {
interned_symbol_instances_;
};
+Predicate* PredicateFactory::MakeInternedAndOr(
+ std::vector<Predicate*> simplified_ops, Predicate::Kind pred_kind) {
+ std::stable_sort(
+ simplified_ops.begin(), simplified_ops.end(),
+ [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); });
+
+ auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
+ if (it != interned_and_or_instances_.end()) {
+ return it->second.get();
+ }
+
+ simplified_ops.shrink_to_fit();
+ // NB! Because we'll use a non-owning reference to simplified_ops in the
+ // key for interned_and_or_instances_ we need to be careful to std::move()
+ // it all the way through.
+ absl::Span<Predicate* const> operands_slice = simplified_ops;
+ std::unique_ptr<Predicate> new_pred =
+ pred_kind == Predicate::Kind::kAnd
+ ? Make<AndPredicate>(std::move(simplified_ops))
+ : Make<OrPredicate>(std::move(simplified_ops));
+
+ Predicate* new_pred_ptr = new_pred.get();
+ interned_and_or_instances_.emplace(
+ SignatureForAndOr(pred_kind, operands_slice), std::move(new_pred));
+ return new_pred_ptr;
+}
+
// Common code to create AndPredicate or OrPredicate instances.
Predicate* PredicateFactory::MakeAndOrImpl(
absl::Span<Predicate* const> operands, bool is_and) {
Predicate::Kind pred_kind =
is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
+ Predicate::Kind other_pred_kind =
+ is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd;
gtl::FlatSet<Predicate*> simplified_ops_set;
std::vector<Predicate*> simplified_ops;
for (Predicate* op : operands) {
@@ -472,30 +504,63 @@ Predicate* PredicateFactory::MakeAndOrImpl(
}
}
- std::stable_sort(
- simplified_ops.begin(), simplified_ops.end(),
- [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); });
+ // If all ops contain the same subop, then factor it out thanks to the
+ // distributive property. Such as:
+ // - (A & B) | (A & C) | (A & D) => A & (B | C | D)
+ // - (A | B) & (A | C) & (A | D) => A | (B & C & D)
+ //
+ // First find any predicates contained in all subops.
+ std::vector<Predicate*> common_inner_operands;
+ gtl::FlatSet<Predicate*> common_inner_operands_set;
+ for (Predicate* op : simplified_ops) {
+ if (op->kind() != other_pred_kind) {
+ common_inner_operands.clear();
+ break;
+ }
- auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
- if (it == interned_and_or_instances_.end()) {
- simplified_ops.shrink_to_fit();
- // NB! Because we'll use a non-owning reference to simplified_ops in the
- // key for interned_and_or_instances_ we need to be careful to std::move()
- // it all the way through.
- absl::Span<Predicate* const> operands_slice = simplified_ops;
- std::unique_ptr<Predicate> new_pred =
- is_and ? Make<AndPredicate>(std::move(simplified_ops))
- : Make<OrPredicate>(std::move(simplified_ops));
+ if (common_inner_operands.empty()) {
+ common_inner_operands.insert(common_inner_operands.end(),
+ op->GetOperands().begin(),
+ op->GetOperands().end());
+ } else {
+ std::vector<Predicate*> sub_ops_intersection;
+ common_inner_operands.clear();
+ absl::c_copy_if(op->GetOperands(),
+ std::back_inserter(common_inner_operands),
+ [&](Predicate* sub_op) {
+ return common_inner_operands_set.count(sub_op) == 1;
+ });
+ }
+ if (common_inner_operands.empty()) break;
+ common_inner_operands_set.clear();
+ common_inner_operands_set.insert(common_inner_operands.begin(),
+ common_inner_operands.end());
+ }
- Predicate* new_pred_ptr = new_pred.get();
- CHECK(interned_and_or_instances_
- .emplace(SignatureForAndOr(pred_kind, operands_slice),
- std::move(new_pred))
- .second);
- return new_pred_ptr;
- } else {
- return it->second.get();
+ if (common_inner_operands.empty()) {
+ return MakeInternedAndOr(std::move(simplified_ops), pred_kind);
}
+
+ // For all predicates that can be factored out, remove them and recreate the
+ // subops.
+ std::vector<Predicate*> factored_ops;
+ for (Predicate* op : simplified_ops) {
+ std::vector<Predicate*> new_sub_op_ops;
+ absl::c_copy_if(op->GetOperands(), std::back_inserter(new_sub_op_ops),
+ [&](Predicate* sub_op) {
+ return std::find(common_inner_operands.begin(),
+ common_inner_operands.end(),
+ sub_op) == common_inner_operands.end();
+ });
+ factored_ops.push_back(MakeAndOrImpl(new_sub_op_ops, !is_and));
+ }
+
+ Predicate* new_inner_op = MakeAndOrImpl(factored_ops, is_and);
+ std::vector<Predicate*> outer_ops;
+ outer_ops.push_back(new_inner_op);
+ outer_ops.insert(outer_ops.end(), common_inner_operands.begin(),
+ common_inner_operands.end());
+ return MakeAndOrImpl(outer_ops, !is_and);
}
class DeadnessAnalysisImpl : public DeadnessAnalysis {
diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc
index 28a56044d5..617e31488c 100644
--- a/tensorflow/compiler/jit/deadness_analysis_test.cc
+++ b/tensorflow/compiler/jit/deadness_analysis_test.cc
@@ -384,10 +384,31 @@ TEST(DeadnessAnalysisTest, OrOfAnd) {
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
}
-TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) {
- // This demonstrates one of the weaknesses in the current approach -- since we
- // only do some basic simplifications we can't see that "(A|B)&C" ==
- // "(A&C)|(B&C)".
+TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) {
+ // (*A | (~*A & ((~*B & ~*A) | (~*A & *B)))) == #true
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "A");
+ ops::Switch sw_1 = CreateSwitch(root, "B");
+ Output add0 =
+ ops::Add(root.WithOpName("and0"), sw_0.output_false, sw_1.output_true);
+ Output add1 =
+ ops::Add(root.WithOpName("and1"), sw_0.output_false, sw_1.output_false);
+ ops::Merge or2(root.WithOpName("or2"), {add0, add1});
+ Output add3 =
+ ops::Add(root.WithOpName("and3"), or2.output, sw_0.output_false);
+ ops::Merge or4(root.WithOpName("or4"), {add3, sw_0.output_true});
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ PredicateMapTy predicate_map;
+ TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
+ EXPECT_EQ(predicate_map[ControlOutputFor(or4.output)], "#true");
+}
+
+TEST(DeadnessAnalysisTest, AndOrDistributive) {
+ // (A|B)&C == (A&C)|(B&C)
Scope root = Scope::NewRootScope().ExitOnError();
ops::Switch sw_0 = CreateSwitch(root, "0");
@@ -408,7 +429,7 @@ TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
- EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add2.node()));
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add3.node()));
}
TEST(DeadnessAnalysisTest, Ternary) {