aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/lib
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-30 01:56:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 01:59:58 -0700
commit1a046ed54bea276db1121c6d8f92fd817dc18077 (patch)
tree19b0086ca43ee548a8571ababf32a0ab456e422d /tensorflow/compiler/tf2xla/lib
parentc73d4e56eb2ac66e8fb519cbe83c5f7bddbfc80a (diff)
[TF:XLA] Implement full_matrices=False case of QR decomposition
PiperOrigin-RevId: 210870412
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib')
-rw-r--r--tensorflow/compiler/tf2xla/lib/qr.cc9
-rw-r--r--tensorflow/compiler/tf2xla/lib/qr.h2
2 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc
index b6f30d8d49..df2504a0f9 100644
--- a/tensorflow/compiler/tf2xla/lib/qr.cc
+++ b/tensorflow/compiler/tf2xla/lib/qr.cc
@@ -331,7 +331,7 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
// TODO(phawkins): consider using UT transformations (in the form I - V U V')
// rather than WY transformations.
xla::StatusOr<QRDecompositionResult> QRDecomposition(
- xla::XlaOp a, int64 block_size,
+ xla::XlaOp a, bool full_matrices, int64 block_size,
xla::PrecisionConfigProto::Precision precision) {
xla::XlaBuilder* builder = a.builder();
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
@@ -396,6 +396,13 @@ xla::StatusOr<QRDecompositionResult> QRDecomposition(
q = UpdateSliceInMinorDims(q, q_panel, {0, i});
}
QRDecompositionResult result;
+
+ // full_matrices is false when only a partial result in needed. Slice to the
+ // needed dimensions here.
+ if (!full_matrices) {
+ q = SliceInMinorDims(q, {0, 0}, {m, p});
+ a = SliceInMinorDims(a, {0, 0}, {p, n});
+ }
result.q = q;
result.r = a;
return result;
diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h
index 05565477b6..8a389fb7b0 100644
--- a/tensorflow/compiler/tf2xla/lib/qr.h
+++ b/tensorflow/compiler/tf2xla/lib/qr.h
@@ -34,7 +34,7 @@ struct QRDecompositionResult {
};
xla::StatusOr<QRDecompositionResult> QRDecomposition(
- xla::XlaOp a, int64 block_size = 128,
+ xla::XlaOp a, bool full_matrices, int64 block_size = 128,
xla::PrecisionConfigProto::Precision precision =
xla::PrecisionConfigProto::HIGHEST);