aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r--tensorflow/python/kernel_tests/BUILD13
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py21
-rw-r--r--tensorflow/python/kernel_tests/constant_op_test.py13
-rw-r--r--tensorflow/python/kernel_tests/conv_ops_test.py3
-rw-r--r--tensorflow/python/kernel_tests/decode_jpeg_op_test.py1
-rw-r--r--tensorflow/python/kernel_tests/io_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/losses_test.py16
-rw-r--r--tensorflow/python/kernel_tests/manip_ops_test.py138
-rw-r--r--tensorflow/python/kernel_tests/rnn_test.py1
-rw-r--r--tensorflow/python/kernel_tests/tensordot_op_test.py54
-rw-r--r--tensorflow/python/kernel_tests/topk_op_test.py2
11 files changed, 220 insertions, 44 deletions
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index c87b7652ad..3a6058054b 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1602,6 +1602,19 @@ cuda_py_test(
)
cuda_py_test(
+ name = "manip_ops_test",
+ size = "small",
+ srcs = ["manip_ops_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:manip_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ ],
+ tags = ["no_windows_gpu"],
+)
+
+cuda_py_test(
name = "matmul_op_test",
size = "small",
srcs = ["matmul_op_test.py"],
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index aae6d0a36e..7ec4624310 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -1162,6 +1162,27 @@ class InvertPermutationTest(test_util.TensorFlowTestCase):
self.assertAllEqual(y.eval(), [2, 4, 3, 0, 1])
+class UnravelIndexTest(test_util.TensorFlowTestCase):
+
+ def testUnravelIndex(self):
+ with self.test_session():
+ for dtype in [dtypes.int32, dtypes.int64]:
+ indices_1 = constant_op.constant(1621, dtype=dtype)
+ dims_1 = constant_op.constant([6, 7, 8, 9], dtype=dtype)
+ out_1 = array_ops.unravel_index(indices_1, dims_1)
+ self.assertAllEqual(out_1.eval(), [3, 1, 4, 1])
+
+ indices_2 = constant_op.constant([1621], dtype=dtype)
+ dims_2 = constant_op.constant([6, 7, 8, 9], dtype=dtype)
+ out_2 = array_ops.unravel_index(indices_2, dims_2)
+ self.assertAllEqual(out_2.eval(), [[3], [1], [4], [1]])
+
+ indices_3 = constant_op.constant([22, 41, 37], dtype=dtype)
+ dims_3 = constant_op.constant([7, 6], dtype=dtype)
+ out_3 = array_ops.unravel_index(indices_3, dims_3)
+ self.assertAllEqual(out_3.eval(), [[3, 6, 6], [4, 5, 1]])
+
+
class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
def testSimple(self):
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 030c690167..16e56349c4 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -454,18 +454,19 @@ class ZerosLikeTest(test.TestCase):
def testZerosLikeCPU(self):
for dtype in [
- dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int8,
- dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.uint16, dtypes_lib.int32,
- dtypes_lib.int64, dtypes_lib.bool, dtypes_lib.complex64,
- dtypes_lib.complex128, dtypes_lib.string
+ dtypes_lib.half, dtypes_lib.float32, dtypes_lib.float64,
+ dtypes_lib.int8, dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.uint16,
+ dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.bool,
+ dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.string
]:
self._compareZeros(dtype, fully_defined_shape=False, use_gpu=False)
self._compareZeros(dtype, fully_defined_shape=True, use_gpu=False)
def testZerosLikeGPU(self):
for dtype in [
- dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
- dtypes_lib.bool, dtypes_lib.int64, dtypes_lib.string
+ dtypes_lib.half, dtypes_lib.float32, dtypes_lib.float64,
+ dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.complex64,
+ dtypes_lib.complex128, dtypes_lib.bool
]:
self._compareZeros(dtype, fully_defined_shape=False, use_gpu=True)
self._compareZeros(dtype, fully_defined_shape=True, use_gpu=True)
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 3e9bd3dade..edfb20d6a2 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -24,6 +24,7 @@ import time
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib import layers
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
@@ -519,7 +520,7 @@ class Conv2DTest(test.TestCase):
dilations=[2, 2],
padding="VALID")
- # TODO this currently fails.
+ # TODO(yzhwang): this currently fails.
# self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1],
# filter_in_sizes=[2, 2, 1, 1],
# strides=[4, 4], padding="SAME",
diff --git a/tensorflow/python/kernel_tests/decode_jpeg_op_test.py b/tensorflow/python/kernel_tests/decode_jpeg_op_test.py
index ead55cd03b..89fd26c544 100644
--- a/tensorflow/python/kernel_tests/decode_jpeg_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_jpeg_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import os
import time
+from six.moves import xrange
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
diff --git a/tensorflow/python/kernel_tests/io_ops_test.py b/tensorflow/python/kernel_tests/io_ops_test.py
index f91875c6f0..61944f7e31 100644
--- a/tensorflow/python/kernel_tests/io_ops_test.py
+++ b/tensorflow/python/kernel_tests/io_ops_test.py
@@ -1,4 +1,4 @@
-# -*- coding: utf-8 -*-
+# -*- coding: utf-8 -*-
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index 00c6706593..197dbf44af 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -953,14 +953,14 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
# Compute the expected loss 'manually'.
total = np.zeros((batch_size,))
for b in range(batch_size):
- for i in range(dims):
- for j in range(dims):
+ for i in range(dims - 1):
+ for j in range(i + 1, dims):
x = self._predictions[b, i].item() - self._predictions[b, j].item()
y = self._labels[b, i].item() - self._labels[b, j].item()
diff = (x - y)
total[b] += (diff * diff)
- self._expected_losses = np.divide(total, 9.0)
+ self._expected_losses = np.divide(total, 3.0)
def testValueErrorThrownWhenWeightIsNone(self):
with self.test_session():
@@ -1059,8 +1059,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
[[4, 8, 12], [1, 2, 3], [4, 5, 6]],
[[8, 1, 3], [7, 8, 9], [10, 11, 12]],
])
- self._test_valid_weights(
- labels, predictions, expected_loss=122.22222)
+ self._test_valid_weights(labels, predictions, expected_loss=137.5)
def test3dWeightedScalar(self):
labels = np.array([
@@ -1073,8 +1072,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
])
weight = 3.0
self._test_valid_weights(
- labels, predictions, expected_loss=weight * 122.22222,
- weights=weight)
+ labels, predictions, expected_loss=weight * 137.5, weights=weight)
def _test_invalid_weights(
self, labels, predictions, weights=1.0):
@@ -1124,7 +1122,9 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
])
self._test_valid_weights(
# TODO(ptucker): This doesn't look right.
- labels, predictions, expected_loss=9 * 122.22222,
+ labels,
+ predictions,
+ expected_loss=9 * 137.5,
weights=np.ones((2, 3, 3)))
def testLossWithAllZeroBatchSpecificWeights(self):
diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py
new file mode 100644
index 0000000000..b8200ac0cb
--- /dev/null
+++ b/tensorflow/python/kernel_tests/manip_ops_test.py
@@ -0,0 +1,138 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for manip_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import manip_ops
+from tensorflow.python.platform import test as test_lib
+
+# pylint: disable=g-import-not-at-top
+try:
+ from distutils.version import StrictVersion as Version
+ # numpy.roll for multiple shifts was introduced in numpy version 1.12.0
+ NP_ROLL_CAN_MULTISHIFT = Version(np.version.version) >= Version("1.12.0")
+except ImportError:
+ NP_ROLL_CAN_MULTISHIFT = False
+# pylint: enable=g-import-not-at-top
+
+
+class RollTest(test_util.TensorFlowTestCase):
+
+ def _testRoll(self, np_input, shift, axis):
+ expected_roll = np.roll(np_input, shift, axis)
+ with self.test_session():
+ roll = manip_ops.roll(np_input, shift, axis)
+ self.assertAllEqual(roll.eval(), expected_roll)
+
+ def _testGradient(self, np_input, shift, axis):
+ with self.test_session():
+ inx = constant_op.constant(np_input.tolist())
+ xs = list(np_input.shape)
+ y = manip_ops.roll(inx, shift, axis)
+ # Expected y's shape to be the same
+ ys = xs
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ inx, xs, y, ys, x_init_value=np_input)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
+
+ def _testAll(self, np_input, shift, axis):
+ self._testRoll(np_input, shift, axis)
+ if np_input.dtype == np.float32:
+ self._testGradient(np_input, shift, axis)
+
+ def testIntTypes(self):
+ for t in [np.int32, np.int64]:
+ self._testAll(np.random.randint(-100, 100, (5)).astype(t), 3, 0)
+ if NP_ROLL_CAN_MULTISHIFT:
+ self._testAll(
+ np.random.randint(-100, 100, (4, 4, 3)).astype(t), [1, -2, 3],
+ [0, 1, 2])
+ self._testAll(
+ np.random.randint(-100, 100, (4, 2, 1, 3)).astype(t), [0, 1, -2],
+ [1, 2, 3])
+
+ def testFloatTypes(self):
+ for t in [np.float32, np.float64]:
+ self._testAll(np.random.rand(5).astype(t), 2, 0)
+ if NP_ROLL_CAN_MULTISHIFT:
+ self._testAll(np.random.rand(3, 4).astype(t), [1, 2], [1, 0])
+ self._testAll(np.random.rand(1, 3, 4).astype(t), [1, 0, -3], [0, 1, 2])
+
+ def testComplexTypes(self):
+ for t in [np.complex64, np.complex128]:
+ x = np.random.rand(4, 4).astype(t)
+ self._testAll(x + 1j * x, 2, 0)
+ if NP_ROLL_CAN_MULTISHIFT:
+ x = np.random.rand(2, 5).astype(t)
+ self._testAll(x + 1j * x, [1, 2], [1, 0])
+ x = np.random.rand(3, 2, 1, 1).astype(t)
+ self._testAll(x + 1j * x, [2, 1, 1, 0], [0, 3, 1, 2])
+
+ def testRollInputMustVectorHigherRaises(self):
+ tensor = 7
+ shift = 1
+ axis = 0
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "input must be 1-D or higher"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+ def testRollAxisMustBeScalarOrVectorRaises(self):
+ tensor = [[1, 2], [3, 4]]
+ shift = 1
+ axis = [[0, 1]]
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "axis must be a scalar or a 1-D vector"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+ def testRollShiftMustBeScalarOrVectorRaises(self):
+ tensor = [[1, 2], [3, 4]]
+ shift = [[0, 1]]
+ axis = 1
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "shift must be a scalar or a 1-D vector"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+ def testRollShiftAndAxisMustBeSameSizeRaises(self):
+ tensor = [[1, 2], [3, 4]]
+ shift = [1]
+ axis = [0, 1]
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "shift and axis must have the same size"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+ def testRollAxisOutOfRangeRaises(self):
+ tensor = [1, 2]
+ shift = 1
+ axis = 1
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "is out of range"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+
+if __name__ == "__main__":
+ test_lib.main()
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 0c77d1db92..daa42938e6 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -23,6 +23,7 @@ import timeit
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib import rnn as contrib_rnn
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py
index f1670a47f5..8ad29afd0a 100644
--- a/tensorflow/python/kernel_tests/tensordot_op_test.py
+++ b/tensorflow/python/kernel_tests/tensordot_op_test.py
@@ -66,7 +66,7 @@ class TensordotTest(test_lib.TestCase):
a = [[1, 2], [3, 4]]
b = [[1, 2], [3, 4]]
# Invalid static axes.
- for axes_value in -1, 0, [1], [[1]], [[1], [0, 1]]:
+ for axes_value in -1, 3, [1], [[1]], [[1], [0, 1]]:
with self.assertRaises(ValueError):
math_ops.tensordot(a, b, axes_value)
@@ -91,7 +91,7 @@ class TensordotTest(test_lib.TestCase):
# Test case for 11950
def test_valid_axis(self):
- for axes_value in [1, 2], [[1], [2]]:
+ for axes_value in [1, 2], [[1], [2]], [[], []], 0:
with self.test_session() as sess:
np_a = np.ones((3, 3))
np_b = np.array([2, 3, 1])[None, None]
@@ -105,29 +105,29 @@ class TensordotTest(test_lib.TestCase):
self.assertAllEqual(tf_ans, np_ans)
def test_partial_shape_inference(self):
- a = array_ops.placeholder(dtypes.float32)
- b = array_ops.placeholder(dtypes.float32)
- axes = ([1], [0])
- output = math_ops.tensordot(a, b, axes)
- self.assertEqual(output.get_shape().ndims, None)
- a.set_shape([None, 2])
- b.set_shape([2, 3])
- output = math_ops.tensordot(a, b, axes)
- output_shape = output.get_shape()
- self.assertEqual(output_shape.ndims, 2)
- output_shape = output_shape.as_list()
- self.assertEqual(output_shape[0], None)
- self.assertEqual(output_shape[1], 3)
- a = array_ops.placeholder(dtypes.float32)
- b = array_ops.placeholder(dtypes.float32)
- a.set_shape([2, 2])
- b.set_shape([2, None])
- output = math_ops.tensordot(a, b, axes)
- output_shape = output.get_shape()
- self.assertEqual(output_shape.ndims, 2)
- output_shape = output_shape.as_list()
- self.assertEqual(output_shape[0], 2)
- self.assertEqual(output_shape[1], None)
+ for axes in ([1], [0]), 1:
+ a = array_ops.placeholder(dtypes.float32)
+ b = array_ops.placeholder(dtypes.float32)
+ output = math_ops.tensordot(a, b, axes)
+ self.assertEqual(output.get_shape().ndims, None)
+ a.set_shape([None, 2])
+ b.set_shape([2, 3])
+ output = math_ops.tensordot(a, b, axes)
+ output_shape = output.get_shape()
+ self.assertEqual(output_shape.ndims, 2)
+ output_shape = output_shape.as_list()
+ self.assertEqual(output_shape[0], None)
+ self.assertEqual(output_shape[1], 3)
+ a = array_ops.placeholder(dtypes.float32)
+ b = array_ops.placeholder(dtypes.float32)
+ a.set_shape([2, 2])
+ b.set_shape([2, None])
+ output = math_ops.tensordot(a, b, axes)
+ output_shape = output.get_shape()
+ self.assertEqual(output_shape.ndims, 2)
+ output_shape = output_shape.as_list()
+ self.assertEqual(output_shape[0], 2)
+ self.assertEqual(output_shape[1], None)
def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
@@ -196,8 +196,8 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_)
b_np = np.random.uniform(
low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_)
- all_axes = [1]
- if a_np.ndim > 1:
+ all_axes = [0, 1]
+ if a_np.ndim > 2:
all_axes.append(a_np.ndim - 1)
for axes in all_axes:
np_ans = np.tensordot(a_np, b_np, axes=axes)
diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py
index efb5b9f364..6ab931fdb9 100644
--- a/tensorflow/python/kernel_tests/topk_op_test.py
+++ b/tensorflow/python/kernel_tests/topk_op_test.py
@@ -58,7 +58,7 @@ class TopKTest(test.TestCase):
# Do some special casing of equality of indices: if indices
# are not the same, but values are floating type, ensure that
# the values are within epsilon of each other.
- if not np.issubdtype(np_expected_values.dtype, np.float):
+ if not np.issubdtype(np_expected_values.dtype, np.floating):
# Values are not floating point type; check indices exactly
self.assertAllEqual(np_expected_indices, indices)
else: