diff options
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r-- | tensorflow/compiler/tests/BUILD | 14 | ||||
-rw-r--r-- | tensorflow/compiler/tests/segment_reduction_ops_test.py | 139 | ||||
-rw-r--r-- | tensorflow/compiler/tests/tensor_array_ops_test.py | 4 |
3 files changed, 156 insertions, 1 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 4f0137e8d9..c693f58f8b 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -354,6 +354,20 @@ tf_xla_py_test( ) tf_xla_py_test( + name = "segment_reduction_ops_test", + size = "small", + srcs = ["segment_reduction_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:math_ops_gen", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( name = "spacetobatch_op_test", size = "medium", srcs = ["spacetobatch_op_test.py"], diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py new file mode 100644 index 0000000000..260a04421b --- /dev/null +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test cases for segment reduction ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +class SegmentReductionOpsTest(XLATestCase): + """Test cases for segment reduction ops.""" + + def UnsortedSegmentSum(self, data, indices, num_segments): + with self.test_session() as sess, self.test_scope(): + d = array_ops.placeholder(data.dtype, shape=data.shape) + if isinstance(indices, int): + i = array_ops.placeholder(np.int32, shape=[]) + else: + i = array_ops.placeholder(indices.dtype, shape=indices.shape) + return sess.run( + math_ops.unsorted_segment_sum(d, i, num_segments), + {d: data, + i: indices}) + + def testUnsortedSegmentSum0DIndices1DData(self): + for dtype in self.numeric_types: + self.assertAllClose( + np.array( + [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 4, 5], + [0, 0, 0, 0, 0, 0]], + dtype=dtype), + self.UnsortedSegmentSum( + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 2, 4)) + + def testUnsortedSegmentSum1DIndices1DData(self): + for dtype in self.numeric_types: + self.assertAllClose( + np.array([1, 3, 2, 9], dtype=dtype), + self.UnsortedSegmentSum( + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), + np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4)) + + def testUnsortedSegmentSum1DIndices2DDataDisjoint(self): + for dtype in self.numeric_types: + data = np.array( + [[0, 1, 2, 3], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43], + [50, 51, 52, 53]], + dtype=dtype) + indices = np.array([8, 1, 0, 3, 7], dtype=np.int32) + num_segments = 10 + y = self.UnsortedSegmentSum(data, indices, num_segments) + self.assertAllClose( + np.array( + [[30, 31, 32, 33], [20, 21, 22, 23], [0, 0, 0, 0], + [40, 41, 42, 43], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], + [50, 51, 52, 53], [0, 1, 2, 3], [0, 0, 0, 0]], + dtype=dtype), y) + + def testUnsortedSegmentSum1DIndices2DDataNonDisjoint(self): + for dtype in self.numeric_types: + data = np.array( + [[0, 1, 2, 3], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43], + [50, 51, 52, 53]], + dtype=dtype) + indices = np.array([0, 1, 2, 0, 1], dtype=np.int32) + num_segments = 4 + y = self.UnsortedSegmentSum(data, indices, num_segments) + self.assertAllClose( + np.array( + [[40, 42, 44, 46], [70, 72, 74, 76], [30, 31, 32, 33], + [0, 0, 0, 0]], + dtype=dtype), y) + + def testUnsortedSegmentSum2DIndices3DData(self): + for dtype in self.numeric_types: + data = np.array( + [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], + [[200, 201, 202], [210, 211, 212]], [[300, 301, 302], + [310, 311, 312]]], + dtype=dtype) + indices = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32) + num_segments = 8 + y = self.UnsortedSegmentSum(data, indices, num_segments) + self.assertAllClose( + np.array( + [[210, 211, 212], [110, 111, 112], [310, 311, 312], + [100, 102, 104], [0, 0, 0.], [210, 212, 214], [300, 301, + 302], [0, 0, 0]], + dtype=dtype), y) + + def testUnsortedSegmentSum1DIndices3DData(self): + for dtype in self.numeric_types: + data = np.array( + [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], + [[200, 201, 202], [210, 211, 212]], [[300, 301, 302], + [310, 311, 312]]], + dtype=dtype) + indices = np.array([3, 0, 2, 5], dtype=np.int32) + num_segments = 6 + y = self.UnsortedSegmentSum(data, indices, num_segments) + self.assertAllClose( + np.array( + [[[100, 101, 102.], [110, 111, 112]], [[0, 0, 0], [0, 0, 0]], + [[200, 201, 202], [210, 211, 212]], [[0, 1, 2.], [10, 11, 12]], + [[0, 0, 0], [0, 0, 0]], [[300, 301, 302], [310, 311, 312]]], + dtype=dtype), y) + + def testUnsortedSegmentSumShapeError(self): + for dtype in self.numeric_types: + data = np.ones((4, 8, 7), dtype=dtype) + indices = np.ones((3, 2), dtype=np.int32) + num_segments = 4 + self.assertRaises(ValueError, + functools.partial(self.UnsortedSegmentSum, data, + indices, num_segments)) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index f277314352..ac039e0162 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -57,11 +57,13 @@ class TensorArrayTest(xla_test.XLATestCase): r0 = w2.read(0) r1 = w2.read(1) r2 = w2.read(2) + flow = w2.flow - d0, d1, d2 = session.run([r0, r1, r2]) + d0, d1, d2, flow_val = session.run([r0, r1, r2, flow]) self.assertAllEqual([[4.0, 5.0]], d0) self.assertAllEqual([[1.0, 3.0]], d1) self.assertAllEqual([[7.0, -8.5]], d2) + self.assertAllEqual([], flow_val.shape) def _testTensorArrayWritePack(self, tf_dtype): with self.test_session(), self.test_scope(): |