aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/tests/BUILD13
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py44
-rw-r--r--tensorflow/compiler/tests/matrix_band_part_test.py64
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc98
-rw-r--r--tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc93
-rw-r--r--tensorflow/core/kernels/matrix_band_part_op.cc12
-rw-r--r--tensorflow/core/ops/array_ops.cc5
-rw-r--r--tensorflow/python/kernel_tests/matrix_band_part_op_test.py11
9 files changed, 335 insertions, 7 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 7277ba42ce..b0b038775f 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -354,6 +354,19 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "matrix_band_part_test",
+ size = "medium",
+ srcs = ["matrix_band_part_test.py"],
+ tags = ["optonly"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "momentum_test",
size = "small",
srcs = ["momentum_test.py"],
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 16856bd736..9d34cdfe10 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1181,6 +1181,50 @@ class BinaryOpsTest(XLATestCase):
np.array([4, 5, 6], dtype=np.int32),
expected=None)
+ def testMatrixSetDiag(self):
+ for dtype in self.numeric_types:
+ # Square
+ self._testBinary(
+ array_ops.matrix_set_diag,
+ np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
+ dtype=dtype),
+ np.array([1.0, 2.0, 3.0], dtype=dtype),
+ expected=np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 1.0, 3.0]],
+ dtype=dtype))
+
+ self._testBinary(
+ array_ops.matrix_set_diag,
+ np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
+ [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]],
+ dtype=dtype),
+ np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]], dtype=dtype),
+ expected=np.array(
+ [[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0], [1.0, 0.0, -3.0]],
+ [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0], [2.0, 0.0, -6.0]]],
+ dtype=dtype))
+
+ # Rectangular
+ self._testBinary(
+ array_ops.matrix_set_diag,
+ np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype),
+ np.array([3.0, 4.0], dtype=dtype),
+ expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype))
+
+ self._testBinary(
+ array_ops.matrix_set_diag,
+ np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype),
+ np.array([3.0, 4.0], dtype=dtype),
+ expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype))
+
+ self._testBinary(
+ array_ops.matrix_set_diag,
+ np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
+ [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype),
+ np.array([[-1.0, -2.0], [-4.0, -5.0]],
+ dtype=dtype),
+ expected=np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]],
+ [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]],
+ dtype=dtype))
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py
new file mode 100644
index 0000000000..29394f9ea5
--- /dev/null
+++ b/tensorflow/compiler/tests/matrix_band_part_test.py
@@ -0,0 +1,64 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class MatrixBandPartTest(XLATestCase):
+
+ def _testMatrixBandPart(self, dtype, shape):
+ with self.test_session():
+ batch_shape = shape[:-2]
+ mat = np.ones(shape).astype(dtype)
+ batch_mat = np.tile(mat, batch_shape + [1, 1])
+ for lower in -1, 0, 1, shape[-2] - 1:
+ for upper in -1, 0, 1, shape[-1] - 1:
+ band_np = mat
+ if lower >= 0:
+ band_np = np.triu(band_np, -lower)
+ if upper >= 0:
+ band_np = np.tril(band_np, upper)
+ if batch_shape:
+ band_np = np.tile(band_np, batch_shape + [1, 1])
+
+ placeholder = array_ops.placeholder(dtype)
+ with self.test_scope():
+ band = array_ops.matrix_band_part(
+ placeholder,
+ constant_op.constant(lower, dtype=dtypes.int32),
+ constant_op.constant(upper, dtype=dtypes.int32))
+ feed_dict = {placeholder: batch_mat}
+ self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))
+
+ def testMatrixBandPart(self):
+ for dtype in self.float_types:
+ for batch_shape in [[], [2,], [1, 3, 2]]:
+ for rows in 1, 2, 7:
+ for cols in 1, 2, 7:
+ self._testMatrixBandPart(dtype, batch_shape + [rows, cols])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 67be1a4ba6..e9be6f8476 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -44,6 +44,8 @@ tf_kernel_library(
"l2loss_op.cc",
"lrn_ops.cc",
"matmul_op.cc",
+ "matrix_band_part_op.cc",
+ "matrix_set_diag_op.cc",
"matrix_triangular_solve_op.cc",
"mirror_pad_op.cc",
"no_op.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
new file mode 100644
index 0000000000..faa415a97b
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
@@ -0,0 +1,98 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+namespace {
+
+class MatrixBandPartOp : public XlaOpKernel {
+ public:
+ explicit MatrixBandPartOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape(0);
+ // Preliminary validation of sizes.
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
+ errors::InvalidArgument(
+ "input must be at least 2-dim, received shape: ",
+ input_shape.DebugString()));
+
+ const TensorShape num_lower_in_shape = context->InputShape(1);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in_shape),
+ errors::InvalidArgument("num_lower must be scalar, got shape ",
+ num_lower_in_shape.DebugString()));
+
+ const TensorShape num_upper_in_shape = context->InputShape(2);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in_shape),
+ errors::InvalidArgument("num_upper must be scalar, got shape ",
+ num_upper_in_shape.DebugString()));
+
+ xla::ComputationBuilder* builder = context->builder();
+ xla::ComputationDataHandle input = context->Input(0);
+ xla::ComputationDataHandle num_lower = context->Input(1);
+ xla::ComputationDataHandle num_upper = context->Input(2);
+ DataType input_type = context->input_type(0);
+ DataType index_type = context->input_type(1);
+
+ TensorShape batch_shape = input_shape;
+ batch_shape.RemoveLastDims(2);
+ const int64 m = input_shape.dim_size(input_shape.dims() - 2);
+ const int64 n = input_shape.dim_size(input_shape.dims() - 1);
+
+ // Compute 'offset', which is how many diagonals we are above/below the
+ // diagonal.
+ xla::ComputationDataHandle iota_m;
+ OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m));
+
+ xla::ComputationDataHandle iota_n;
+ OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n));
+
+ auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m,
+ /*broadcast_dimensions=*/{0});
+
+ // If num_lower or num_upper are negative, include all lower/upper
+ // diagonals.
+ auto zero_index = XlaHelpers::Zero(builder, index_type);
+ num_lower = builder->Select(
+ builder->Lt(num_lower, zero_index),
+ XlaHelpers::IntegerLiteral(builder, index_type, m), num_lower);
+ num_upper = builder->Select(
+ builder->Lt(num_upper, zero_index),
+ XlaHelpers::IntegerLiteral(builder, index_type, n), num_upper);
+
+ auto indicator = builder->And(builder->Le(builder->Neg(num_lower), offset),
+ builder->Le(offset, num_upper));
+ indicator = builder->Broadcast(indicator, batch_shape.dim_sizes());
+
+ auto zero_input = XlaHelpers::Zero(builder, input_type);
+ auto output = builder->Select(
+ indicator, input,
+ builder->Broadcast(zero_input, input_shape.dim_sizes()));
+
+ context->SetOutput(0, output);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(MatrixBandPartOp);
+};
+REGISTER_XLA_OP(Name("MatrixBandPart"), MatrixBandPartOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
new file mode 100644
index 0000000000..b2940bdcff
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
@@ -0,0 +1,93 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+
+class MatrixSetDiagOp : public XlaOpKernel {
+ public:
+ explicit MatrixSetDiagOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape(0);
+ const TensorShape diag_shape = context->InputShape(1);
+
+ const int rank = input_shape.dims();
+
+ // Preliminary validation of sizes.
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
+ errors::InvalidArgument(
+ "input must be at least 2-dim, received shape: ",
+ input_shape.DebugString()));
+
+ // Check to make sure the last dimension of diag is equal to the smaller of
+ // the last two dimensions of input.
+ const int64 m = input_shape.dim_size(rank - 2);
+ const int64 n = input_shape.dim_size(rank - 1);
+ const int64 min_dim = std::min(m, n);
+
+ TensorShape batch_shape = input_shape;
+ batch_shape.RemoveLastDims(2);
+
+ TensorShape expected_diag_shape = batch_shape;
+ expected_diag_shape.AddDim(min_dim);
+ OP_REQUIRES(context, expected_diag_shape == diag_shape,
+ errors::InvalidArgument(
+ "must have diagonal.shape == input.shape[:-2] + "
+ "min(input.shape[-2:]), but received input shape: ",
+ input_shape.DebugString(),
+ " and diagonal shape: ", diag_shape.DebugString()));
+
+ xla::ComputationBuilder* builder = context->builder();
+ xla::ComputationDataHandle input = context->Input(0);
+ xla::ComputationDataHandle diag = context->Input(1);
+
+ auto zero = XlaHelpers::Zero(builder, context->input_type(0));
+
+ // Create an indicator tensor that is true only on the diagonal.
+ xla::ComputationDataHandle iota_m;
+ OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m));
+ xla::ComputationDataHandle iota_n;
+ OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n));
+ auto indicator = builder->Eq(iota_m,
+ builder->Broadcast(iota_n, {m}),
+ /*broadcast_dimensions=*/{0});
+ indicator = builder->Broadcast(indicator, batch_shape.dim_sizes());
+
+ // Broadcast diag up to the input shape. Use an implicit broadcast (Add)
+ // because we need to broadcast on the right.
+ std::vector<int64> diag_broadcast_dims(rank - 1);
+ std::iota(diag_broadcast_dims.begin(), diag_broadcast_dims.end(), 0);
+ if (min_dim != m) {
+ diag_broadcast_dims.back() = rank - 1;
+ }
+ diag = builder->Add(diag, builder->Broadcast(zero, input_shape.dim_sizes()),
+ /*broadcast_dimensions=*/diag_broadcast_dims);
+
+ auto output = builder->Select(indicator, diag, input);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
+};
+
+REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/matrix_band_part_op.cc b/tensorflow/core/kernels/matrix_band_part_op.cc
index d7fff4bb0c..1439141f64 100644
--- a/tensorflow/core/kernels/matrix_band_part_op.cc
+++ b/tensorflow/core/kernels/matrix_band_part_op.cc
@@ -62,7 +62,15 @@ class MatrixBandPartOp : public OpKernel {
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in.shape()),
errors::InvalidArgument("num_lower must be scalar, got shape ",
num_lower_in.shape().DebugString()));
- const int64 num_lower = num_lower_in.scalar<int64>()();
+
+ auto as_int64_scalar = [](const Tensor& tensor) -> int64 {
+ if (tensor.dtype() == DT_INT32) {
+ return tensor.scalar<int32>()();
+ } else {
+ return tensor.scalar<int64>()();
+ }
+ };
+ const int64 num_lower = as_int64_scalar(num_lower_in);
OP_REQUIRES(
context, num_lower <= input_reshaped.dimension(1),
errors::InvalidArgument(
@@ -73,7 +81,7 @@ class MatrixBandPartOp : public OpKernel {
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in.shape()),
errors::InvalidArgument("num_upper must be scalar, got shape ",
num_upper_in.shape().DebugString()));
- const int64 num_upper = num_upper_in.scalar<int64>()();
+ const int64 num_upper = as_int64_scalar(num_upper_in);
OP_REQUIRES(context, num_upper <= input_reshaped.dimension(2),
errors::InvalidArgument("num_upper must be negative or less or "
"equal to number of columns (",
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index fb9e8ad50c..87dfa77689 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -701,10 +701,11 @@ REGISTER_OP("MatrixDiagPart")
// --------------------------------------------------------------------------
REGISTER_OP("MatrixBandPart")
.Input("input: T")
- .Input("num_lower: int64")
- .Input("num_upper: int64")
+ .Input("num_lower: Tindex")
+ .Input("num_upper: Tindex")
.Output("band: T")
.Attr("T: type")
+ .Attr("Tindex: {int32, int64} = DT_INT64")
.SetShapeFn(shape_inference::UnchangedShape);
// --------------------------------------------------------------------------
diff --git a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
index 317b8dc05b..68d626de2c 100644
--- a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
@@ -21,6 +21,7 @@ import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -54,9 +55,13 @@ def _GetMatrixBandPartTest(dtype_, batch_shape_, shape_):
band_np = np.tril(band_np, upper)
if batch_shape_ is not ():
band_np = np.tile(band_np, batch_shape_ + (1, 1))
- with self.test_session(use_gpu=False):
- band = array_ops.matrix_band_part(batch_mat, lower, upper)
- self.assertAllEqual(band_np, band.eval())
+ for index_dtype in [dtypes_lib.int32, dtypes_lib.int64]:
+ with self.test_session(use_gpu=False):
+ band = array_ops.matrix_band_part(
+ batch_mat,
+ constant_op.constant(lower, index_dtype),
+ constant_op.constant(upper, index_dtype))
+ self.assertAllEqual(band_np, band.eval())
return Test