about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/client/lib/numeric.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/client/lib/numeric.cc')
-rw-r--r-- tensorflow/compiler/xla/client/lib/numeric.cc | 64
1 file changed, 61 insertions(+), 3 deletions(-)
diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc
index fd4e8fc390..1c91237ae1 100644
--- a/tensorflow/compiler/xla/client/lib/numeric.cc
+++ b/tensorflow/compiler/xla/client/lib/numeric.cc
@@ -13,11 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/client/lib/numeric.h"
-
#include <numeric>
#include <vector>
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
namespace xla {
namespace {
@@ -28,7 +31,7 @@ XlaOp MakeIota(XlaBuilder* builder, int64 size) {
for (int64 i = 0; i < size; ++i) {
values[i] = static_cast<T>(i);
}
- return xla::ConstantR1<T>(builder, values);
+ return ConstantR1<T>(builder, values);
}
} // namespace
@@ -76,4 +79,59 @@ XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m,
return ConvertElementType(indicator, type);
}
+XlaOp GetMatrixDiagonal(XlaOp x) {  // Extracts the main diagonal of the trailing two dims of x.
+ XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
+ const int64 n_dims = ShapeUtil::Rank(shape);
+ TF_RET_CHECK(n_dims >= 2);  // x must be at least a (batched) matrix.
+ const int64 m = shape.dimensions(n_dims - 2);  // Rows of each trailing 2-D slice.
+ const int64 n = shape.dimensions(n_dims - 1);  // Columns of each trailing 2-D slice.
+ tensorflow::gtl::ArraySlice<int64> major_dims(  // Batch dims preceding the trailing two.
+ AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2);
+ auto a = Iota(builder, U32, n);  // Column indices 0..n-1.
+ auto b = Iota(builder, U32, m);  // Row indices 0..m-1.
+ auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});  // True where row index == column index, i.e. on the diagonal.
+ auto mask = Broadcast(indicator, major_dims);  // Replicate the diagonal mask over all batch dims.
+
+ // TPUs don't support S64 add reduction at the moment. But fortunately
+ // OR-reductions work just as well for integers.
+ XlaComputation reducer =
+ primitive_util::IsIntegralType(shape.element_type())
+ ? CreateScalarOrComputation(shape.element_type(), builder)
+ : CreateScalarAddComputation(shape.element_type(), builder);
+
+ return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0),  // Zero out off-diagonal entries, then sum/OR away
+ reducer, {m >= n ? n_dims - 2 : n_dims - 1});  // the longer trailing dim, leaving min(m, n) diagonal values.
+ });
+}
+
<![CDATA[
+XlaOp Triangle(XlaOp x, bool lower) {  // Keeps the lower (lower=true) or upper triangle of the trailing two dims of x; all other entries become zero.
+ XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
+ const int64 n_dims = ShapeUtil::Rank(shape);
+ TF_RET_CHECK(n_dims >= 2);  // x must be at least a (batched) matrix.
+ const int64 m = shape.dimensions(n_dims - 2);  // Rows of each trailing 2-D slice.
+ const int64 n = shape.dimensions(n_dims - 1);  // Columns of each trailing 2-D slice.
+ tensorflow::gtl::ArraySlice<int64> major_dims(  // Batch dims preceding the trailing two.
+ AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2);
+ auto a = Iota(builder, U32, n);  // Column indices 0..n-1.
+ auto b = Iota(builder, U32, m);  // Row indices 0..m-1.
+ xla::XlaOp indicator;
+ if (lower) {
+ indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});  // row >= col: on or below the diagonal.
+ } else {
+ indicator = Le(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});  // row <= col: on or above the diagonal.
+ }
+ auto mask = Broadcast(indicator, major_dims);  // Replicate the triangle mask over all batch dims.
+
+ return Select(mask, x, Zeros(builder, shape));  // Zero entries outside the selected triangle.
+ });
+}
]]>
+
+XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); }  // Convenience wrapper: keep the upper triangle (diagonal included).
+
+XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }  // Convenience wrapper: keep the lower triangle (diagonal included).
+
} // namespace xla