diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-01 02:41:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-01 02:45:01 -0700 |
commit | 0fd21d8c34e15bc3013e93014d101b672e1f3687 (patch) | |
tree | 207d412e7af182fd0c4f4a2566fdccce1fd760e1 /tensorflow/compiler | |
parent | 03c5f9cdce62f6711b91fe81505e3c085e54a771 (diff) |
[TF:XLA] Teach deadness analysis more of distributive property.
PiperOrigin-RevId: 215183847
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/jit/deadness_analysis.cc | 107 | ||||
-rw-r--r-- | tensorflow/compiler/jit/deadness_analysis_test.cc | 31 |
2 files changed, 112 insertions, 26 deletions
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 9128b48da3..25e2e9a7af 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/deadness_analysis.h" +#include "absl/algorithm/container.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/core/graph/algorithm.h" @@ -383,6 +384,8 @@ class PredicateFactory { } Predicate* MakeAndOrImpl(absl::Span<Predicate* const> operands, bool is_and); + Predicate* MakeInternedAndOr(std::vector<Predicate*> simplified_ops, + Predicate::Kind pred_kind); // Predicate instances are interned, meaning that there is only a single // instance of a Predicate object with a given content. This makes checking @@ -429,11 +432,40 @@ class PredicateFactory { interned_symbol_instances_; }; +Predicate* PredicateFactory::MakeInternedAndOr( + std::vector<Predicate*> simplified_ops, Predicate::Kind pred_kind) { + std::stable_sort( + simplified_ops.begin(), simplified_ops.end(), + [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); + + auto it = interned_and_or_instances_.find({pred_kind, simplified_ops}); + if (it != interned_and_or_instances_.end()) { + return it->second.get(); + } + + simplified_ops.shrink_to_fit(); + // NB! Because we'll use a non-owning reference to simplified_ops in the + // key for interned_and_or_instances_ we need to be careful to std::move() + // it all the way through. + absl::Span<Predicate* const> operands_slice = simplified_ops; + std::unique_ptr<Predicate> new_pred = + pred_kind == Predicate::Kind::kAnd + ? Make<AndPredicate>(std::move(simplified_ops)) + : Make<OrPredicate>(std::move(simplified_ops)); + + Predicate* new_pred_ptr = new_pred.get(); + interned_and_or_instances_.emplace( + SignatureForAndOr(pred_kind, operands_slice), std::move(new_pred)); + return new_pred_ptr; +} + // Common code to create AndPredicate or OrPredicate instances. Predicate* PredicateFactory::MakeAndOrImpl( absl::Span<Predicate* const> operands, bool is_and) { Predicate::Kind pred_kind = is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr; + Predicate::Kind other_pred_kind = + is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd; gtl::FlatSet<Predicate*> simplified_ops_set; std::vector<Predicate*> simplified_ops; for (Predicate* op : operands) { @@ -472,30 +504,63 @@ Predicate* PredicateFactory::MakeAndOrImpl( } } - std::stable_sort( - simplified_ops.begin(), simplified_ops.end(), - [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); + // If all ops contain the same subop, then factor it out thanks to the + // distributive property. Such as: + // - (A & B) | (A & C) | (A & D) => A & (B | C | D) + // - (A | B) & (A | C) & (A | D) => A | (B & C & D) + // + // First find any predicates contained in all subops. + std::vector<Predicate*> common_inner_operands; + gtl::FlatSet<Predicate*> common_inner_operands_set; + for (Predicate* op : simplified_ops) { + if (op->kind() != other_pred_kind) { + common_inner_operands.clear(); + break; + } - auto it = interned_and_or_instances_.find({pred_kind, simplified_ops}); - if (it == interned_and_or_instances_.end()) { - simplified_ops.shrink_to_fit(); - // NB! Because we'll use a non-owning reference to simplified_ops in the - // key for interned_and_or_instances_ we need to be careful to std::move() - // it all the way through. - absl::Span<Predicate* const> operands_slice = simplified_ops; - std::unique_ptr<Predicate> new_pred = - is_and ? Make<AndPredicate>(std::move(simplified_ops)) - : Make<OrPredicate>(std::move(simplified_ops)); + if (common_inner_operands.empty()) { + common_inner_operands.insert(common_inner_operands.end(), + op->GetOperands().begin(), + op->GetOperands().end()); + } else { + std::vector<Predicate*> sub_ops_intersection; + common_inner_operands.clear(); + absl::c_copy_if(op->GetOperands(), + std::back_inserter(common_inner_operands), + [&](Predicate* sub_op) { + return common_inner_operands_set.count(sub_op) == 1; + }); + } + if (common_inner_operands.empty()) break; + common_inner_operands_set.clear(); + common_inner_operands_set.insert(common_inner_operands.begin(), + common_inner_operands.end()); + } - Predicate* new_pred_ptr = new_pred.get(); - CHECK(interned_and_or_instances_ - .emplace(SignatureForAndOr(pred_kind, operands_slice), - std::move(new_pred)) - .second); - return new_pred_ptr; - } else { - return it->second.get(); + if (common_inner_operands.empty()) { + return MakeInternedAndOr(std::move(simplified_ops), pred_kind); } + + // For all predicates that can be factored out, remove them and recreate the + // subops. + std::vector<Predicate*> factored_ops; + for (Predicate* op : simplified_ops) { + std::vector<Predicate*> new_sub_op_ops; + absl::c_copy_if(op->GetOperands(), std::back_inserter(new_sub_op_ops), + [&](Predicate* sub_op) { + return std::find(common_inner_operands.begin(), + common_inner_operands.end(), + sub_op) == common_inner_operands.end(); + }); + factored_ops.push_back(MakeAndOrImpl(new_sub_op_ops, !is_and)); + } + + Predicate* new_inner_op = MakeAndOrImpl(factored_ops, is_and); + std::vector<Predicate*> outer_ops; + outer_ops.push_back(new_inner_op); + outer_ops.insert(outer_ops.end(), common_inner_operands.begin(), + common_inner_operands.end()); + return MakeAndOrImpl(outer_ops, !is_and); } class DeadnessAnalysisImpl : public DeadnessAnalysis { diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 28a56044d5..617e31488c 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -384,10 +384,31 @@ TEST(DeadnessAnalysisTest, OrOfAnd) { EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node())); } -TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) { - // This demonstrates one of the weaknesses in the current approach -- since we - // only do some basic simplifications we can't see that "(A|B)&C" == - // "(A&C)|(B&C)". +TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) { + // (*A | (~*A & ((~*B & ~*A) | (~*A & *B)))) == #true + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "A"); + ops::Switch sw_1 = CreateSwitch(root, "B"); + Output add0 = + ops::Add(root.WithOpName("and0"), sw_0.output_false, sw_1.output_true); + Output add1 = + ops::Add(root.WithOpName("and1"), sw_0.output_false, sw_1.output_false); + ops::Merge or2(root.WithOpName("or2"), {add0, add1}); + Output add3 = + ops::Add(root.WithOpName("and3"), or2.output, sw_0.output_false); + ops::Merge or4(root.WithOpName("or4"), {add3, sw_0.output_true}); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + EXPECT_EQ(predicate_map[ControlOutputFor(or4.output)], "#true"); +} + +TEST(DeadnessAnalysisTest, AndOrDistributive) { + // (A|B)&C == (A&C)|(B&C) Scope root = Scope::NewRootScope().ExitOnError(); ops::Switch sw_0 = CreateSwitch(root, "0"); @@ -408,7 +429,7 @@ TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) { std::unique_ptr<DeadnessAnalysis> result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add2.node())); + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add3.node())); } TEST(DeadnessAnalysisTest, Ternary) { |