diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-23 14:58:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-23 15:01:48 -0700 |
commit | f97fec3cf5d361103d21989b78a74dd1820620d8 (patch) | |
tree | 5bdecdf30ea8ff3a758b08a7036a93eb23925b35 /tensorflow/compiler/tf2xla/lib | |
parent | d12244894aa0cdd068b46ebed407ced1915272b2 (diff) |
Refactoring triangular_solve.cc to use the new common utility functions.
PiperOrigin-RevId: 193990473
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib')
-rw-r--r-- | tensorflow/compiler/tf2xla/lib/triangular_solve.cc | 82 |
1 files changed, 25 insertions, 57 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 7f72a6073d..9bf5821b54 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -83,15 +83,6 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve( block_size); } - // Returns [b1, b2, ... , bn, indices[0], indices[1]]. - auto prepend_batch_dims = [&](std::array<int64, 2> indices) { - std::vector<int64> output(ndims); - std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin()); - std::copy(indices.begin(), indices.end(), - output.begin() + batch_dimensions.size()); - return output; - }; - // Applies a complex conjugation operation if `a` is complex and `conjugate_a` // is true, otherwise returns its argument. auto maybe_conj = [&](xla::ComputationBuilder* builder, @@ -108,11 +99,12 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve( std::unique_ptr<xla::ComputationBuilder> sub = builder->CreateSubBuilder( tensorflow::strings::StrCat("trsm_base_", k)); - auto a_param = - sub->Parameter(0, - xla::ShapeUtil::MakeShape(b_shape->element_type(), - prepend_batch_dims({k, k})), - "a"); + auto a_param = sub->Parameter( + 0, + xla::ShapeUtil::MakeShape( + b_shape->element_type(), + PrependMajorDims(sub.get(), batch_dimensions, {k, k})), + "a"); std::array<int64, 2> b_lastd; if (left_side) { @@ -120,11 +112,12 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve( } else { b_lastd = {m, k}; } - auto b_param = - sub->Parameter(1, - xla::ShapeUtil::MakeShape(b_shape->element_type(), - prepend_batch_dims(b_lastd)), - "b"); + auto b_param = sub->Parameter( + 1, + xla::ShapeUtil::MakeShape( + b_shape->element_type(), + PrependMajorDims(sub.get(), batch_dimensions, b_lastd)), + "b"); // We use a left-looking subroutine on the block diagonal in some common // cases, while falling back to a recursive call in unsupported cases. The @@ -380,14 +373,6 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking( batch_dimensions.push_back(a_size); } - auto prepend_batch_dims = [&](std::array<int64, 2> indices) { - std::vector<int64> output(ndims); - std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin()); - std::copy(indices.begin(), indices.end(), - output.begin() + batch_dimensions.size()); - return output; - }; - auto maybe_conj = [&](xla::ComputationBuilder* builder, xla::ComputationDataHandle x) { auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a; @@ -479,30 +464,6 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking( auto body_b = bodyb->GetTupleElement(input_tuple, 3); auto zero = bodyb->ConstantR0<int32>(0); - // Set up some helper functions. - auto prepend_zeros = [&](std::array<xla::ComputationDataHandle, 2> starts) { - auto zero = bodyb->Reshape(bodyb->ConstantR0<int32>(0), {1}); - std::vector<xla::ComputationDataHandle> padded_starts(ndims, zero); - padded_starts[ndims - 2] = bodyb->Reshape(starts[0], {1}); - padded_starts[ndims - 1] = bodyb->Reshape(starts[1], {1}); - return bodyb->ConcatInDim(padded_starts, 0); - }; - - auto dynamic_slice = [&](xla::ComputationDataHandle x, - std::array<xla::ComputationDataHandle, 2> starts, - std::array<int64, 2> sizes) { - auto padded_starts = prepend_zeros(starts); - auto padded_sizes = prepend_batch_dims(sizes); - return bodyb->DynamicSlice(x, padded_starts, padded_sizes); - }; - - auto update = [&](xla::ComputationDataHandle x, - xla::ComputationDataHandle update, - std::array<xla::ComputationDataHandle, 2> starts) { - auto padded_starts = prepend_zeros(starts); - return bodyb->DynamicUpdateSlice(x, update, padded_starts); - }; - // We'd like to implement this: // if transpose_a: // a_row = T(a[..., i+1:, i:i+1]) @@ -516,22 +477,29 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking( // all zeros and use that as zero-padding (doing unnecessary FLOPs). xla::ComputationDataHandle a_row; if (transpose_a) { - a_row = dynamic_slice(body_a, {zero, i}, {m, 1}); + TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, + {zero, i}, {m, 1})); } else { - a_row = dynamic_slice(body_a, {i, zero}, {1, m}); + TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, + {i, zero}, {1, m})); } TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out, /*transpose_x=*/transpose_a, /*transpose_y=*/false, /*conjugate_x=*/conjugate_a, /*conjugate_y=*/false)); - auto result_row = - bodyb->Sub(dynamic_slice(body_b, {i, zero}, {1, n}), b_update); + TF_ASSIGN_OR_RETURN( + auto result_row_slice, + DynamicSliceInMinorDims(bodyb.get(), body_b, {i, zero}, {1, n})); + auto result_row = bodyb->Sub(result_row_slice, b_update); // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] - auto a_elt = dynamic_slice(body_a, {i, i}, {1, 1}); + TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a, + {i, i}, {1, 1})); auto div_result = bodyb->Div(result_row, maybe_conj(bodyb.get(), a_elt)); - body_out = update(body_out, div_result, {i, zero}); + TF_ASSIGN_OR_RETURN(body_out, + DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, + div_result, {i, zero})); // if transpose_a: // return (i - 1, body_out, a, b) |