diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib')
-rw-r--r-- | tensorflow/compiler/tf2xla/lib/qr.cc | 9 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/lib/qr.h | 2 |
2 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index b6f30d8d49..df2504a0f9 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -331,7 +331,7 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation( // TODO(phawkins): consider using UT transformations (in the form I - V U V') // rather than WY transformations. xla::StatusOr<QRDecompositionResult> QRDecomposition( - xla::XlaOp a, int64 block_size, + xla::XlaOp a, bool full_matrices, int64 block_size, xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); @@ -396,6 +396,13 @@ xla::StatusOr<QRDecompositionResult> QRDecomposition( q = UpdateSliceInMinorDims(q, q_panel, {0, i}); } QRDecompositionResult result; + + // full_matrices is false when only a partial result in needed. Slice to the + // needed dimensions here. + if (!full_matrices) { + q = SliceInMinorDims(q, {0, 0}, {m, p}); + a = SliceInMinorDims(a, {0, 0}, {p, n}); + } result.q = q; result.r = a; return result; diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h index 05565477b6..8a389fb7b0 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/tf2xla/lib/qr.h @@ -34,7 +34,7 @@ struct QRDecompositionResult { }; xla::StatusOr<QRDecompositionResult> QRDecomposition( - xla::XlaOp a, int64 block_size = 128, + xla::XlaOp a, bool full_matrices, int64 block_size = 128, xla::PrecisionConfigProto::Precision precision = xla::PrecisionConfigProto::HIGHEST); |