aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/binary_ops_test.py
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-02-01 06:47:06 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-01 17:03:04 -0800
commit47fcca75bc8ec9e3c9d484e055c94facef280e21 (patch)
tree263bbebd7ced8527197f19204ea9e26c49a500e0 /tensorflow/compiler/tests/binary_ops_test.py
parentbaf490ba79acaacb458078370e4bad1c3fd17563 (diff)
[TF:XLA] Implement MatrixSetDiag and MatrixBandPart.
Add support for int32 indices to the MatrixBandPart operator. PiperOrigin-RevId: 184133343
Diffstat (limited to 'tensorflow/compiler/tests/binary_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py44
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 16856bd736..9d34cdfe10 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1181,6 +1181,50 @@ class BinaryOpsTest(XLATestCase):
np.array([4, 5, 6], dtype=np.int32),
expected=None)
+ def testMatrixSetDiag(self):
+ for dtype in self.numeric_types:
+ # Square
+ self._testBinary(
+ array_ops.matrix_set_diag,
+ np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
+ dtype=dtype),
+ np.array([1.0, 2.0, 3.0], dtype=dtype),
+ expected=np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 1.0, 3.0]],
+ dtype=dtype))
+
+ self._testBinary(
+ array_ops.matrix_set_diag,
+ np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
+ [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]],
+ dtype=dtype),
+ np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]], dtype=dtype),
+ expected=np.array(
+ [[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0], [1.0, 0.0, -3.0]],
+ [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0], [2.0, 0.0, -6.0]]],
+ dtype=dtype))
+
+ # Rectangular
+ self._testBinary(
+ array_ops.matrix_set_diag,
+ np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype),
+ np.array([3.0, 4.0], dtype=dtype),
+ expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype))
+
+ self._testBinary(
+ array_ops.matrix_set_diag,
+ np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype),
+ np.array([3.0, 4.0], dtype=dtype),
+ expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype))
+
+ self._testBinary(
+ array_ops.matrix_set_diag,
+ np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
+ [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype),
+ np.array([[-1.0, -2.0], [-4.0, -5.0]],
+ dtype=dtype),
+ expected=np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]],
+ [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]],
+ dtype=dtype))
if __name__ == "__main__":
googletest.main()