aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r--tensorflow/compiler/tests/BUILD14
-rw-r--r--tensorflow/compiler/tests/segment_reduction_ops_test.py139
-rw-r--r--tensorflow/compiler/tests/tensor_array_ops_test.py4
3 files changed, 156 insertions, 1 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 4f0137e8d9..c693f58f8b 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -354,6 +354,20 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "segment_reduction_ops_test",
+ size = "small",
+ srcs = ["segment_reduction_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:math_ops_gen",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "spacetobatch_op_test",
size = "medium",
srcs = ["spacetobatch_op_test.py"],
diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py
new file mode 100644
index 0000000000..260a04421b
--- /dev/null
+++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py
@@ -0,0 +1,139 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for segment reduction ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class SegmentReductionOpsTest(XLATestCase):
+ """Test cases for segment reduction ops."""
+
+ def UnsortedSegmentSum(self, data, indices, num_segments):
+ with self.test_session() as sess, self.test_scope():
+ d = array_ops.placeholder(data.dtype, shape=data.shape)
+ if isinstance(indices, int):
+ i = array_ops.placeholder(np.int32, shape=[])
+ else:
+ i = array_ops.placeholder(indices.dtype, shape=indices.shape)
+ return sess.run(
+ math_ops.unsorted_segment_sum(d, i, num_segments),
+ {d: data,
+ i: indices})
+
+ def testUnsortedSegmentSum0DIndices1DData(self):
+ for dtype in self.numeric_types:
+ self.assertAllClose(
+ np.array(
+ [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 4, 5],
+ [0, 0, 0, 0, 0, 0]],
+ dtype=dtype),
+ self.UnsortedSegmentSum(
+ np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 2, 4))
+
+ def testUnsortedSegmentSum1DIndices1DData(self):
+ for dtype in self.numeric_types:
+ self.assertAllClose(
+ np.array([1, 3, 2, 9], dtype=dtype),
+ self.UnsortedSegmentSum(
+ np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
+ np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4))
+
+ def testUnsortedSegmentSum1DIndices2DDataDisjoint(self):
+ for dtype in self.numeric_types:
+ data = np.array(
+ [[0, 1, 2, 3], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43],
+ [50, 51, 52, 53]],
+ dtype=dtype)
+ indices = np.array([8, 1, 0, 3, 7], dtype=np.int32)
+ num_segments = 10
+ y = self.UnsortedSegmentSum(data, indices, num_segments)
+ self.assertAllClose(
+ np.array(
+ [[30, 31, 32, 33], [20, 21, 22, 23], [0, 0, 0, 0],
+ [40, 41, 42, 43], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0],
+ [50, 51, 52, 53], [0, 1, 2, 3], [0, 0, 0, 0]],
+ dtype=dtype), y)
+
+ def testUnsortedSegmentSum1DIndices2DDataNonDisjoint(self):
+ for dtype in self.numeric_types:
+ data = np.array(
+ [[0, 1, 2, 3], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43],
+ [50, 51, 52, 53]],
+ dtype=dtype)
+ indices = np.array([0, 1, 2, 0, 1], dtype=np.int32)
+ num_segments = 4
+ y = self.UnsortedSegmentSum(data, indices, num_segments)
+ self.assertAllClose(
+ np.array(
+ [[40, 42, 44, 46], [70, 72, 74, 76], [30, 31, 32, 33],
+ [0, 0, 0, 0]],
+ dtype=dtype), y)
+
+ def testUnsortedSegmentSum2DIndices3DData(self):
+ for dtype in self.numeric_types:
+ data = np.array(
+ [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]],
+ [[200, 201, 202], [210, 211, 212]], [[300, 301, 302],
+ [310, 311, 312]]],
+ dtype=dtype)
+ indices = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32)
+ num_segments = 8
+ y = self.UnsortedSegmentSum(data, indices, num_segments)
+ self.assertAllClose(
+ np.array(
+ [[210, 211, 212], [110, 111, 112], [310, 311, 312],
+ [100, 102, 104], [0, 0, 0.], [210, 212, 214], [300, 301,
+ 302], [0, 0, 0]],
+ dtype=dtype), y)
+
+ def testUnsortedSegmentSum1DIndices3DData(self):
+ for dtype in self.numeric_types:
+ data = np.array(
+ [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]],
+ [[200, 201, 202], [210, 211, 212]], [[300, 301, 302],
+ [310, 311, 312]]],
+ dtype=dtype)
+ indices = np.array([3, 0, 2, 5], dtype=np.int32)
+ num_segments = 6
+ y = self.UnsortedSegmentSum(data, indices, num_segments)
+ self.assertAllClose(
+ np.array(
+ [[[100, 101, 102.], [110, 111, 112]], [[0, 0, 0], [0, 0, 0]],
+ [[200, 201, 202], [210, 211, 212]], [[0, 1, 2.], [10, 11, 12]],
+ [[0, 0, 0], [0, 0, 0]], [[300, 301, 302], [310, 311, 312]]],
+ dtype=dtype), y)
+
+ def testUnsortedSegmentSumShapeError(self):
+ for dtype in self.numeric_types:
+ data = np.ones((4, 8, 7), dtype=dtype)
+ indices = np.ones((3, 2), dtype=np.int32)
+ num_segments = 4
+ self.assertRaises(ValueError,
+ functools.partial(self.UnsortedSegmentSum, data,
+ indices, num_segments))
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index f277314352..ac039e0162 100644
--- a/tensorflow/compiler/tests/tensor_array_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -57,11 +57,13 @@ class TensorArrayTest(xla_test.XLATestCase):
r0 = w2.read(0)
r1 = w2.read(1)
r2 = w2.read(2)
+ flow = w2.flow
- d0, d1, d2 = session.run([r0, r1, r2])
+ d0, d1, d2, flow_val = session.run([r0, r1, r2, flow])
self.assertAllEqual([[4.0, 5.0]], d0)
self.assertAllEqual([[1.0, 3.0]], d1)
self.assertAllEqual([[7.0, -8.5]], d2)
+ self.assertAllEqual([], flow_val.shape)
def _testTensorArrayWritePack(self, tf_dtype):
with self.test_session(), self.test_scope():