author A. Unique TensorFlower <gardener@tensorflow.org> 2018-04-23 14:58:58 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-04-23 15:01:48 -0700
commit f97fec3cf5d361103d21989b78a74dd1820620d8 (patch)
tree 5bdecdf30ea8ff3a758b08a7036a93eb23925b35 /tensorflow/compiler/tf2xla/lib
parent d12244894aa0cdd068b46ebed407ced1915272b2 (diff)
Refactoring triangular_solve.cc to use the new common utility functions.
PiperOrigin-RevId: 193990473
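
The patch below swaps file-local lambdas (prepend_batch_dims, prepend_zeros, dynamic_slice, update) for shared helpers (PrependMajorDims, DynamicSliceInMinorDims, DynamicUpdateSliceInMinorDims). The helpers' declarations are not part of this diff; the sketch below reconstructs plausible signatures purely from the call sites that follow, so treat the parameter types as assumptions rather than the actual utility header.

// Sketch only: signatures inferred from the call sites in this diff,
// not copied from the shared utility header.

// Returns [b1, ..., bn, indices[0], indices[1]]: `major_dims` with the
// two minor (matrix) dimensions appended; replaces the old local
// prepend_batch_dims lambda.
std::vector<int64> PrependMajorDims(xla::ComputationBuilder* builder,
                                    gtl::ArraySlice<int64> major_dims,
                                    gtl::ArraySlice<int64> indices);

// Dynamically slices the two minor dimensions of `x` at offsets `starts`
// with extents `sizes`, using zero offsets and full extents for the batch
// dimensions; replaces the old prepend_zeros/dynamic_slice pair.
xla::StatusOr<xla::ComputationDataHandle> DynamicSliceInMinorDims(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
    gtl::ArraySlice<xla::ComputationDataHandle> starts,
    gtl::ArraySlice<int64> sizes);

// Writes `update` into `x` at minor-dimension offsets `starts`;
// replaces the old local update lambda.
xla::StatusOr<xla::ComputationDataHandle> DynamicUpdateSliceInMinorDims(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
    const xla::ComputationDataHandle& update,
    gtl::ArraySlice<xla::ComputationDataHandle> starts);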
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib')
-rw-r--r--  tensorflow/compiler/tf2xla/lib/triangular_solve.cc | 82
1 file changed, 25 insertions(+), 57 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index 7f72a6073d..9bf5821b54 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -83,15 +83,6 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
block_size);
}
- // Returns [b1, b2, ... , bn, indices[0], indices[1]].
- auto prepend_batch_dims = [&](std::array<int64, 2> indices) {
- std::vector<int64> output(ndims);
- std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin());
- std::copy(indices.begin(), indices.end(),
- output.begin() + batch_dimensions.size());
- return output;
- };
-
// Applies a complex conjugation operation if `a` is complex and `conjugate_a`
// is true, otherwise returns its argument.
auto maybe_conj = [&](xla::ComputationBuilder* builder,
@@ -108,11 +99,12 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
std::unique_ptr<xla::ComputationBuilder> sub = builder->CreateSubBuilder(
tensorflow::strings::StrCat("trsm_base_", k));
- auto a_param =
- sub->Parameter(0,
- xla::ShapeUtil::MakeShape(b_shape->element_type(),
- prepend_batch_dims({k, k})),
- "a");
+ auto a_param = sub->Parameter(
+ 0,
+ xla::ShapeUtil::MakeShape(
+ b_shape->element_type(),
+ PrependMajorDims(sub.get(), batch_dimensions, {k, k})),
+ "a");
std::array<int64, 2> b_lastd;
if (left_side) {
@@ -120,11 +112,12 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
} else {
b_lastd = {m, k};
}
- auto b_param =
- sub->Parameter(1,
- xla::ShapeUtil::MakeShape(b_shape->element_type(),
- prepend_batch_dims(b_lastd)),
- "b");
+ auto b_param = sub->Parameter(
+ 1,
+ xla::ShapeUtil::MakeShape(
+ b_shape->element_type(),
+ PrependMajorDims(sub.get(), batch_dimensions, b_lastd)),
+ "b");
// We use a left-looking subroutine on the block diagonal in some common
// cases, while falling back to a recursive call in unsupported cases. The
@@ -380,14 +373,6 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
batch_dimensions.push_back(a_size);
}
- auto prepend_batch_dims = [&](std::array<int64, 2> indices) {
- std::vector<int64> output(ndims);
- std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin());
- std::copy(indices.begin(), indices.end(),
- output.begin() + batch_dimensions.size());
- return output;
- };
-
auto maybe_conj = [&](xla::ComputationBuilder* builder,
xla::ComputationDataHandle x) {
auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
@@ -479,30 +464,6 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
auto body_b = bodyb->GetTupleElement(input_tuple, 3);
auto zero = bodyb->ConstantR0<int32>(0);
- // Set up some helper functions.
- auto prepend_zeros = [&](std::array<xla::ComputationDataHandle, 2> starts) {
- auto zero = bodyb->Reshape(bodyb->ConstantR0<int32>(0), {1});
- std::vector<xla::ComputationDataHandle> padded_starts(ndims, zero);
- padded_starts[ndims - 2] = bodyb->Reshape(starts[0], {1});
- padded_starts[ndims - 1] = bodyb->Reshape(starts[1], {1});
- return bodyb->ConcatInDim(padded_starts, 0);
- };
-
- auto dynamic_slice = [&](xla::ComputationDataHandle x,
- std::array<xla::ComputationDataHandle, 2> starts,
- std::array<int64, 2> sizes) {
- auto padded_starts = prepend_zeros(starts);
- auto padded_sizes = prepend_batch_dims(sizes);
- return bodyb->DynamicSlice(x, padded_starts, padded_sizes);
- };
-
- auto update = [&](xla::ComputationDataHandle x,
- xla::ComputationDataHandle update,
- std::array<xla::ComputationDataHandle, 2> starts) {
- auto padded_starts = prepend_zeros(starts);
- return bodyb->DynamicUpdateSlice(x, update, padded_starts);
- };
-
// We'd like to implement this:
// if transpose_a:
// a_row = T(a[..., i+1:, i:i+1])
@@ -516,22 +477,29 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
// all zeros and use that as zero-padding (doing unnecessary FLOPs).
xla::ComputationDataHandle a_row;
if (transpose_a) {
- a_row = dynamic_slice(body_a, {zero, i}, {m, 1});
+ TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a,
+ {zero, i}, {m, 1}));
} else {
- a_row = dynamic_slice(body_a, {i, zero}, {1, m});
+ TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a,
+ {i, zero}, {1, m}));
}
TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out,
/*transpose_x=*/transpose_a,
/*transpose_y=*/false,
/*conjugate_x=*/conjugate_a,
/*conjugate_y=*/false));
- auto result_row =
- bodyb->Sub(dynamic_slice(body_b, {i, zero}, {1, n}), b_update);
+ TF_ASSIGN_OR_RETURN(
+ auto result_row_slice,
+ DynamicSliceInMinorDims(bodyb.get(), body_b, {i, zero}, {1, n}));
+ auto result_row = bodyb->Sub(result_row_slice, b_update);
// body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
- auto a_elt = dynamic_slice(body_a, {i, i}, {1, 1});
+ TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a,
+ {i, i}, {1, 1}));
auto div_result = bodyb->Div(result_row, maybe_conj(bodyb.get(), a_elt));
- body_out = update(body_out, div_result, {i, zero});
+ TF_ASSIGN_OR_RETURN(body_out,
+ DynamicUpdateSliceInMinorDims(bodyb.get(), body_out,
+ div_result, {i, zero}));
// if transpose_a:
// return (i - 1, body_out, a, b)
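
One pattern worth noting in the rewritten call sites: the shared helpers return xla::StatusOr<xla::ComputationDataHandle> rather than a bare handle, which is why plain assignments like `a_row = dynamic_slice(...)` become TF_ASSIGN_OR_RETURN. Roughly (a sketch, not the literal macro expansion), each such line behaves like:

// Approximate behavior of
// TF_ASSIGN_OR_RETURN(auto a_row, DynamicSliceInMinorDims(...));
auto statusor =
    DynamicSliceInMinorDims(bodyb.get(), body_a, {zero, i}, {m, 1});
if (!statusor.ok()) {
  return statusor.status();  // Propagate the error to the caller.
}
auto a_row = statusor.ConsumeValueOrDie();  // Take ownership of the value.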