aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Blake Hechtman <blakehechtman@google.com>2017-03-24 12:17:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-24 14:08:40 -0700
commit45dbb0a02d2fa5b1eb20836fe854e19842b9593f (patch)
tree0bc88a0628e6983a9b75225c8b59e87bf74a00df
parentfdf32bc5ede067b000c42f6404df8ab98e56ec11 (diff)
Strength reduce Dot into broadcasting multiply and reduce. Also optimizes
transposes and reshapes that feed reductions. Change: 151162327
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc223
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h1
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc9
-rw-r--r--tensorflow/compiler/xla/tests/reduce_test.cc66
4 files changed, 298 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 9fe2d0e6b6..d171a8dfff 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -76,6 +76,24 @@ bool ReshapeIsBitcast(
return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()) &&
valid_bitcast_callback(operand->shape(), reshape->shape());
}
+
+// Adds a scalar computation to the module to enable optimizations with dot
+// converting into reduction.
+HloComputation* CreateScalarBinaryComputation(HloModule* module,
+ PrimitiveType primitive_type,
+ HloOpcode opcode) {
+ HloComputation::Builder b("scalar computation");
+ auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {}), "scalar lhs"));
+ auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {}), "scalar rhs"));
+ auto scalar_op = b.AddInstruction(
+ HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}),
+ opcode, scalar_lhs, scalar_rhs));
+ HloComputation* scalar_computation =
+ module->AddEmbeddedComputation(b.Build(scalar_op));
+ return scalar_computation;
+}
} // namespace
// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
@@ -105,6 +123,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleDivide(HloInstruction* divide, HloInstruction* lhs,
HloInstruction* rhs) override;
+ Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+
Status HandleGetTupleElement(HloInstruction* get_tuple_element,
HloInstruction* operand) override;
@@ -304,6 +325,140 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide,
return Status::OK();
}
+Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or
+ // below.
+ if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 ||
+ ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) {
+ return Status::OK();
+ }
+
+ // Replace a zero element dot with a broadcast of the constant 0.
+ if (ShapeUtil::HasZeroElements(dot->shape()) ||
+ ShapeUtil::HasZeroElements(lhs->shape()) ||
+ ShapeUtil::HasZeroElements(rhs->shape())) {
+ auto zero = computation_->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
+ changed_ = true;
+ return computation_->ReplaceWithNewInstruction(
+ dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
+ }
+
+ // Simplify dot(transpose(a), transpose(b)) to tranpose(dot(b,a)).
+ if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) {
+ auto new_dot = computation_->AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), HloOpcode::kDot,
+ rhs->mutable_operand(0), lhs->mutable_operand(0)));
+ changed_ = true;
+ return computation_->ReplaceWithNewInstruction(
+ dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
+ }
+
+ // Simplify outer product into multiply with implicit broadcasting.
+ //
+ // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N])
+ if (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(0) == 1) {
+ changed_ = true;
+ return computation_->ReplaceWithNewInstruction(
+ dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply,
+ lhs, rhs));
+ }
+
+ // The following graph transformations take Dots where at least one input is a
+ // vector or has a degenerate dimension and converts it into a multiply and
+ // reduce. This should enable more fusion than leaving the nodes as Dot
+ // operations.
+
+ // Strength reduce dot(a[K] , b[K]) =
+ // reshape(result.shape,
+ // reduce_sum(multiply(a, b), {0}))
+ if (ShapeUtil::Rank(rhs->shape()) == 1 &&
+ ShapeUtil::Rank(lhs->shape()) == 1) {
+ auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
+ rhs->shape(), HloOpcode::kMultiply, lhs, rhs));
+ HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
+ computation_->parent(), F32, HloOpcode::kAdd);
+ auto zero = computation_->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
+ auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
+ ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero,
+ {0}, add_reduce_computation));
+ changed_ = true;
+ return computation_->ReplaceWithNewInstruction(
+ dot, HloInstruction::CreateReshape(dot->shape(), reduce));
+ }
+
+ // Strength reduce dot(a[1, K], b) =
+ // reshape(result.shape,
+ // reduce_sum(
+ // multiply(broadcast(reshape(a, [K]), {0}), b),
+ // {0})
+ // )
+ // )
+ if (ShapeUtil::Rank(lhs->shape()) == 1 ||
+ (ShapeUtil::Rank(lhs->shape()) == 2 && lhs->shape().dimensions(0) == 1)) {
+ auto new_lhs = computation_->AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(lhs->shape().element_type(),
+ {ShapeUtil::ElementsIn(lhs->shape())}),
+ lhs));
+ HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
+ computation_->parent(), F32, HloOpcode::kAdd);
+ auto zero = computation_->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
+ HloInstruction* reduce;
+ if (ShapeUtil::Rank(rhs->shape()) == 1) {
+ auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
+ rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs));
+ reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
+ ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero,
+ {0}, add_reduce_computation));
+ } else {
+ new_lhs = computation_->AddInstruction(
+ HloInstruction::CreateBroadcast(rhs->shape(), new_lhs, {0}));
+ auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
+ rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs));
+
+ reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
+ ShapeUtil::MakeShape(dot->shape().element_type(),
+ {rhs->shape().dimensions(1)}),
+ multiply, zero, {0}, add_reduce_computation));
+ }
+ changed_ = true;
+ return computation_->ReplaceWithNewInstruction(
+ dot, HloInstruction::CreateReshape(dot->shape(), reduce));
+ }
+
+ // Strength reduce dot(a, b[K, 1]) =
+ // reshape(result.shape,
+ // reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0})
+ // )
+ if (ShapeUtil::Rank(rhs->shape()) == 1 ||
+ (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(1) == 1)) {
+ auto new_rhs = computation_->AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(rhs->shape().element_type(),
+ {ShapeUtil::ElementsIn(rhs->shape())}),
+ rhs));
+ new_rhs = computation_->AddInstruction(
+ HloInstruction::CreateBroadcast(lhs->shape(), new_rhs, {1}));
+ auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
+ lhs->shape(), HloOpcode::kMultiply, lhs, new_rhs));
+ HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
+ computation_->parent(), F32, HloOpcode::kAdd);
+ auto zero = computation_->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
+ auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
+ ShapeUtil::MakeShape(dot->shape().element_type(),
+ {lhs->shape().dimensions(0)}),
+ multiply, zero, {1}, add_reduce_computation));
+ changed_ = true;
+ return computation_->ReplaceWithNewInstruction(
+ dot, HloInstruction::CreateReshape(dot->shape(), reduce));
+ }
+ return Status::OK();
+}
+
Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply,
HloInstruction* lhs,
HloInstruction* rhs) {
@@ -858,8 +1013,74 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice,
Status AlgebraicSimplifierVisitor::HandleReduce(
HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function) {
+ if (ShapeUtil::HasZeroElements(arg->shape()) ||
+ ShapeUtil::HasZeroElements(reduce->shape())) {
+ return computation_->ReplaceWithNewInstruction(
+ reduce,
+ HloInstruction::CreateBroadcast(reduce->shape(), init_value, {}));
+ return Status::OK();
+ }
+ // A Transpose feeding a reduce can simply permute the reduction dimensions
+ // field.
+ if (arg->opcode() == HloOpcode::kTranspose) {
+ auto transpose_dimensions = arg->dimensions();
+ std::vector<int64> new_reduce_dimensions;
+ for (auto dim : dimensions) {
+ new_reduce_dimensions.push_back(transpose_dimensions[dim]);
+ }
+ return computation_->ReplaceWithNewInstruction(
+ reduce, HloInstruction::CreateReduce(
+ reduce->shape(), arg->mutable_operand(0), init_value,
+ new_reduce_dimensions, function));
+ }
+
+ // A reshape that collapses multiple dimensions into a dimension being reduced
+ // can just reduce all of those dimensions instead of doing a collapsing
+ // reshape before a reduction.
+ if (arg->opcode() == HloOpcode::kReshape) {
+ std::vector<std::pair<int64, int64>> unmodified_dims =
+ ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
+ arg->shape());
+ std::vector<bool> arg_dim_in_output(ShapeUtil::Rank(arg->shape()), true);
+ std::vector<bool> arg_dim_unmodified(ShapeUtil::Rank(arg->shape()), false);
+ for (auto dim : dimensions) {
+ arg_dim_in_output[dim] = false;
+ }
+ for (auto dim_pair : unmodified_dims) {
+ arg_dim_unmodified[dim_pair.second] = true;
+ }
+ // The goal is to verify that all dimensions that are not removed in the
+ // reduce are unmodified by the reshape. For example:
+ // reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2])
+ bool can_move_reshape_into_reduce = true;
+ for (int64 i = 0; i < arg_dim_in_output.size(); ++i) {
+ if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) {
+ can_move_reshape_into_reduce = false;
+ }
+ }
+ if (can_move_reshape_into_reduce) {
+ changed_ = true;
+ std::unordered_set<int64> dimensions_not_to_reduce;
+ for (auto dim_pair : unmodified_dims) {
+ if (arg_dim_in_output[dim_pair.second]) {
+ dimensions_not_to_reduce.insert(dim_pair.first);
+ }
+ }
+ std::vector<int64> new_reduce_dimensions;
+ for (int64 i = 0; i < ShapeUtil::Rank(arg->operand(0)->shape()); ++i) {
+ if (dimensions_not_to_reduce.count(i) == 0) {
+ new_reduce_dimensions.push_back(i);
+ }
+ }
+ return computation_->ReplaceWithNewInstruction(
+ reduce, HloInstruction::CreateReduce(
+ reduce->shape(), arg->mutable_operand(0), init_value,
+ new_reduce_dimensions, function));
+ }
+ }
if (ShapeUtil::ElementsIn(reduce->shape()) ==
- ShapeUtil::ElementsIn(arg->shape())) {
+ ShapeUtil::ElementsIn(arg->shape()) ||
+ ShapeUtil::HasZeroElements(arg->shape())) {
auto reshape = computation_->AddInstruction(
HloInstruction::CreateReshape(reduce->shape(), arg));
changed_ = true;
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index fc02b2c4ef..08cde9d032 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -194,6 +194,7 @@ class HloComputation {
// Set/get the module containing this computation.
void set_parent(HloModule* module) { parent_ = module; }
const HloModule* parent() const { return parent_; }
+ HloModule* parent() { return parent_; }
// Visit every node in the computation in DFS post-order with the given
// visitor. This is similar to calling HloInstruction::Accept on the root of
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 45df811453..64d2af3c39 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -67,6 +67,15 @@ XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) {
ComputeAndCompareR0<float>(&builder, 0.0, {}, error_spec_);
}
+XLA_TEST_F(DotOperationTest, TrivialMatrixVectorDotF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR2<float>({{3.0, 4.0}});
+ auto rhs = builder.ConstantR1<float>({3.0, 4.0});
+ auto result = builder.Dot(lhs, rhs);
+
+ ComputeAndCompareR1<float>(&builder, {25.0}, {}, error_spec_);
+}
+
template <typename Element>
void DotOperationTest::TestOneElementVectorDot() {
ComputationBuilder builder(client_, TestName());
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 3d61b624dc..34fce21758 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -320,6 +320,72 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
ErrorSpec(0.01, 1e-4));
}
+XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) {
+ const int64 rows = 111, cols = 50;
+
+ ComputationBuilder builder(client_, TestName());
+ Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto zero = builder.ConstantR0<float>(0.0);
+ auto log_ = builder.Log(input);
+ auto transpose = builder.Transpose(log_, {1, 0});
+ builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{1});
+
+ Array2D<float> input_data(rows, cols);
+ input_data.FillRandom(3.14f, 0.04);
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR2FromArray2D(input_data);
+ input_literal =
+ LiteralUtil::Relayout(*input_literal, LayoutUtil::MakeLayout({0, 1}));
+ std::unique_ptr<GlobalData> input_global_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ std::vector<float> expected;
+ for (int64 colno = 0; colno < cols; ++colno) {
+ float column_sum = 0;
+ for (int64 rowno = 0; rowno < rows; ++rowno) {
+ column_sum += log(input_data(rowno, colno));
+ }
+ expected.push_back(column_sum);
+ }
+ ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
+ ErrorSpec(0.01, 1e-4));
+}
+
+XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
+ const int64 rows = 111, cols = 50;
+
+ ComputationBuilder builder(client_, TestName());
+ Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, 2, cols / 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto zero = builder.ConstantR0<float>(0.0);
+ auto log_ = builder.Log(input);
+ auto reshape = builder.Reshape(log_, {rows, cols});
+ builder.Reduce(reshape, zero, add_f32, /*dimensions_to_reduce=*/{0});
+
+ Array3D<float> input_data(rows, 2, cols / 2);
+ input_data.FillRandom(3.14f, 0.04);
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR3FromArray3D(input_data);
+ std::unique_ptr<GlobalData> input_global_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ std::vector<float> expected;
+ for (int64 major = 0; major < 2; ++major) {
+ for (int64 colno = 0; colno < cols / 2; ++colno) {
+ float column_sum = 0;
+ for (int64 rowno = 0; rowno < rows; ++rowno) {
+ column_sum += log(input_data(rowno, major, colno));
+ }
+ expected.push_back(column_sum);
+ }
+ }
+ ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
+ ErrorSpec(0.01, 1e-4));
+}
+
struct BoundsLayout {
std::vector<int64> bounds;
std::vector<int64> layout;