diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-30 01:56:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-30 01:59:58 -0700 |
commit | 1a046ed54bea276db1121c6d8f92fd817dc18077 (patch) | |
tree | 19b0086ca43ee548a8571ababf32a0ab456e422d /tensorflow/compiler/tf2xla/lib | |
parent | c73d4e56eb2ac66e8fb519cbe83c5f7bddbfc80a (diff) |
[TF:XLA] Implement full_matrices=False case of QR decomposition
PiperOrigin-RevId: 210870412
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib')
-rw-r--r-- | tensorflow/compiler/tf2xla/lib/qr.cc | 9 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/lib/qr.h | 2 |
2 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index b6f30d8d49..df2504a0f9 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -331,7 +331,7 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation( // TODO(phawkins): consider using UT transformations (in the form I - V U V') // rather than WY transformations. xla::StatusOr<QRDecompositionResult> QRDecomposition( - xla::XlaOp a, int64 block_size, + xla::XlaOp a, bool full_matrices, int64 block_size, xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); @@ -396,6 +396,13 @@ xla::StatusOr<QRDecompositionResult> QRDecomposition( q = UpdateSliceInMinorDims(q, q_panel, {0, i}); } QRDecompositionResult result; + + // full_matrices is false when only a partial result in needed. Slice to the + // needed dimensions here. + if (!full_matrices) { + q = SliceInMinorDims(q, {0, 0}, {m, p}); + a = SliceInMinorDims(a, {0, 0}, {p, n}); + } result.q = q; result.r = a; return result; diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h index 05565477b6..8a389fb7b0 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/tf2xla/lib/qr.h @@ -34,7 +34,7 @@ struct QRDecompositionResult { }; xla::StatusOr<QRDecompositionResult> QRDecomposition( - xla::XlaOp a, int64 block_size = 128, + xla::XlaOp a, bool full_matrices, int64 block_size = 128, xla::PrecisionConfigProto::Precision precision = xla::PrecisionConfigProto::HIGHEST); |