aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r--tensorflow/python/kernel_tests/BUILD26
-rw-r--r--tensorflow/python/kernel_tests/broadcast_to_ops_test.py85
-rw-r--r--tensorflow/python/kernel_tests/confusion_matrix_test.py7
-rw-r--r--tensorflow/python/kernel_tests/constant_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/conv3d_transpose_test.py12
-rw-r--r--tensorflow/python/kernel_tests/manip_ops_test.py55
-rw-r--r--tensorflow/python/kernel_tests/norm_op_test.py16
-rw-r--r--tensorflow/python/kernel_tests/py_func_test.py32
-rw-r--r--tensorflow/python/kernel_tests/random/multinomial_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/random/random_ops_test.py11
-rw-r--r--tensorflow/python/kernel_tests/string_strip_op_test.py56
11 files changed, 284 insertions, 23 deletions
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index ebbec39cf3..c03c514699 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -918,6 +918,20 @@ tf_py_test(
)
tf_py_test(
+ name = "string_strip_op_test",
+ size = "small",
+ srcs = ["string_strip_op_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:string_ops",
+ ],
+)
+
+tf_py_test(
name = "substr_op_test",
size = "small",
srcs = ["substr_op_test.py"],
@@ -1196,6 +1210,18 @@ cuda_py_test(
)
cuda_py_test(
+ name = "broadcast_to_ops_test",
+ size = "small",
+ srcs = ["broadcast_to_ops_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+cuda_py_test(
name = "inplace_ops_test",
size = "small",
srcs = ["inplace_ops_test.py"],
diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
new file mode 100644
index 0000000000..6a1bd958ba
--- /dev/null
+++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
@@ -0,0 +1,85 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for broadcast_to ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test as test_lib
+
+
+class BroadcastToTest(test_util.TensorFlowTestCase):
+
+ def testBroadcastToBasic(self):
+ for dtype in [np.uint8, np.uint16, np.int8, np.int16, np.int32, np.int64]:
+ with self.test_session(use_gpu=True):
+ x = np.array([1, 2, 3], dtype=dtype)
+ v_tf = array_ops.broadcast_to(constant_op.constant(x), [3, 3])
+ v_np = np.broadcast_to(x, [3, 3])
+ self.assertAllEqual(v_tf.eval(), v_np)
+
+ def testBroadcastToString(self):
+ with self.test_session(use_gpu=True):
+ x = np.array([b"1", b"2", b"3"])
+ v_tf = array_ops.broadcast_to(constant_op.constant(x), [3, 3])
+ v_np = np.broadcast_to(x, [3, 3])
+ self.assertAllEqual(v_tf.eval(), v_np)
+
+ def testBroadcastToBool(self):
+ with self.test_session(use_gpu=True):
+ x = np.array([True, False, True], dtype=np.bool)
+ v_tf = array_ops.broadcast_to(constant_op.constant(x), [3, 3])
+ v_np = np.broadcast_to(x, [3, 3])
+ self.assertAllEqual(v_tf.eval(), v_np)
+
+ def testBroadcastToShape(self):
+ for input_dim in range(1, 6):
+ for output_dim in range(input_dim, 6):
+ with self.test_session(use_gpu=True):
+ input_shape = [2] * input_dim
+ output_shape = [2] * output_dim
+ x = np.array(np.random.randint(5, size=input_shape), dtype=np.int32)
+ v_tf = array_ops.broadcast_to(constant_op.constant(x), output_shape)
+ v_np = np.broadcast_to(x, output_shape)
+ self.assertAllEqual(v_tf.eval(), v_np)
+
+ def testBroadcastToScalar(self):
+ with self.test_session(use_gpu=True):
+ x = np.array(1, dtype=np.int32)
+ v_tf = array_ops.broadcast_to(constant_op.constant(x), [3, 3])
+ v_np = np.broadcast_to(x, [3, 3])
+ self.assertAllEqual(v_tf.eval(), v_np)
+
+ def testBroadcastToShapeTypeAndInference(self):
+ for dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session(use_gpu=True):
+ x = np.array([1, 2, 3])
+ v_tf = array_ops.broadcast_to(
+ constant_op.constant(x),
+ constant_op.constant([3, 3], dtype=dtype))
+ shape = v_tf.get_shape().as_list()
+ v_np = np.broadcast_to(x, [3, 3])
+ self.assertAllEqual(v_tf.eval(), v_np)
+ # check shape inference when shape input is constant
+ self.assertAllEqual(shape, v_np.shape)
+
+if __name__ == "__main__":
+ test_lib.main()
diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py
index 670a625f0f..79e419867d 100644
--- a/tensorflow/python/kernel_tests/confusion_matrix_test.py
+++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -104,11 +105,7 @@ class ConfusionMatrixTest(test.TestCase):
d, l, cm_out = sess.run([data, lab, cm], {m_neg: 0.0, m_pos: 1.0, s: 1.0})
truth = np.zeros([2, 2], dtype=np_dtype)
- try:
- range_builder = xrange
- except NameError: # In Python 3.
- range_builder = range
- for i in range_builder(len(d)):
+ for i in xrange(len(d)):
truth[l[i], d[i]] += 1
self.assertEqual(cm_out.dtype, np_dtype)
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 749313b00d..107ee37fab 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -65,6 +65,11 @@ class ConstantTest(test.TestCase):
self._testCpu(x)
self._testGpu(x)
+ def testInvalidDType(self):
+ # Test case for GitHub issue 18474
+ with self.assertRaises(TypeError):
+ constant_op.constant(dtypes_lib.string, "[,]")
+
def testBFloat16(self):
bfloat16 = dtypes_lib.bfloat16.as_numpy_dtype
self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(bfloat16))
diff --git a/tensorflow/python/kernel_tests/conv3d_transpose_test.py b/tensorflow/python/kernel_tests/conv3d_transpose_test.py
index a8b3af5096..8973a450fa 100644
--- a/tensorflow/python/kernel_tests/conv3d_transpose_test.py
+++ b/tensorflow/python/kernel_tests/conv3d_transpose_test.py
@@ -119,6 +119,18 @@ class Conv3DTransposeTest(test.TestCase):
target = 3.0
self.assertAllClose(target, value[n, d, h, w, k])
+ def testConv3DTransposeShapeMismatch(self):
+ # Test case for GitHub issue 18460
+ x_shape = [2, 2, 3, 4, 3]
+ f_shape = [3, 3, 3, 2, 2]
+ y_shape = [2, 2, 6, 8, 6]
+ strides = [1, 1, 2, 2, 2]
+ np.random.seed(1)
+ x_value = np.random.random_sample(x_shape).astype(np.float64)
+ f_value = np.random.random_sample(f_shape).astype(np.float64)
+ nn_ops.conv3d_transpose(
+ x_value, f_value, y_shape, strides, data_format='NCDHW')
+
def testConv3DTransposeValid(self):
with self.test_session():
strides = [1, 2, 2, 2, 1]
diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py
index b8200ac0cb..f31426713c 100644
--- a/tensorflow/python/kernel_tests/manip_ops_test.py
+++ b/tensorflow/python/kernel_tests/manip_ops_test.py
@@ -20,8 +20,10 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import manip_ops
from tensorflow.python.platform import test as test_lib
@@ -88,41 +90,78 @@ class RollTest(test_util.TensorFlowTestCase):
x = np.random.rand(3, 2, 1, 1).astype(t)
self._testAll(x + 1j * x, [2, 1, 1, 0], [0, 3, 1, 2])
+ def testNegativeAxis(self):
+ self._testAll(np.random.randint(-100, 100, (5)).astype(np.int32), 3, -1)
+ self._testAll(np.random.randint(-100, 100, (4, 4)).astype(np.int32), 3, -2)
+ # Make sure negative axis shoudl be 0 <= axis + dims < dims
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "is out of range"):
+ manip_ops.roll(np.random.randint(-100, 100, (4, 4)).astype(np.int32),
+ 3, -10).eval()
+
+ def testInvalidInputShape(self):
+ # The input should be 1-D or higher, checked in shape function.
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be at least rank 1 but is rank 0"):
+ manip_ops.roll(7, 1, 0)
+
def testRollInputMustVectorHigherRaises(self):
- tensor = 7
+ # The input should be 1-D or higher, checked in kernel.
+ tensor = array_ops.placeholder(dtype=dtypes.int32)
shift = 1
axis = 0
with self.test_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"input must be 1-D or higher"):
- manip_ops.roll(tensor, shift, axis).eval()
+ manip_ops.roll(tensor, shift, axis).eval(feed_dict={tensor: 7})
+
+ def testInvalidAxisShape(self):
+ # The axis should be a scalar or 1-D, checked in shape function.
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be at most rank 1 but is rank 2"):
+ manip_ops.roll([[1, 2], [3, 4]], 1, [[0, 1]])
def testRollAxisMustBeScalarOrVectorRaises(self):
+ # The axis should be a scalar or 1-D, checked in kernel.
tensor = [[1, 2], [3, 4]]
shift = 1
- axis = [[0, 1]]
+ axis = array_ops.placeholder(dtype=dtypes.int32)
with self.test_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"axis must be a scalar or a 1-D vector"):
- manip_ops.roll(tensor, shift, axis).eval()
+ manip_ops.roll(tensor, shift, axis).eval(feed_dict={axis: [[0, 1]]})
+
+ def testInvalidShiftShape(self):
+ # The shift should be a scalar or 1-D, checked in shape function.
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be at most rank 1 but is rank 2"):
+ manip_ops.roll([[1, 2], [3, 4]], [[0, 1]], 1)
def testRollShiftMustBeScalarOrVectorRaises(self):
+ # The shift should be a scalar or 1-D, checked in kernel.
tensor = [[1, 2], [3, 4]]
- shift = [[0, 1]]
+ shift = array_ops.placeholder(dtype=dtypes.int32)
axis = 1
with self.test_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"shift must be a scalar or a 1-D vector"):
- manip_ops.roll(tensor, shift, axis).eval()
+ manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [[0, 1]]})
+
+ def testInvalidShiftAndAxisNotEqualShape(self):
+ # The shift and axis must be same size, checked in shape function.
+ with self.assertRaisesRegexp(ValueError, "both shapes must be equal"):
+ manip_ops.roll([[1, 2], [3, 4]], [1], [0, 1])
def testRollShiftAndAxisMustBeSameSizeRaises(self):
+ # The shift and axis must be same size, checked in kernel.
tensor = [[1, 2], [3, 4]]
- shift = [1]
+ shift = array_ops.placeholder(dtype=dtypes.int32)
axis = [0, 1]
with self.test_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"shift and axis must have the same size"):
- manip_ops.roll(tensor, shift, axis).eval()
+ manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [1]})
def testRollAxisOutOfRangeRaises(self):
tensor = [1, 2]
diff --git a/tensorflow/python/kernel_tests/norm_op_test.py b/tensorflow/python/kernel_tests/norm_op_test.py
index d85512fae6..3f71b326a2 100644
--- a/tensorflow/python/kernel_tests/norm_op_test.py
+++ b/tensorflow/python/kernel_tests/norm_op_test.py
@@ -37,17 +37,17 @@ class NormOpTest(test_lib.TestCase):
def testBadOrder(self):
matrix = [[0., 1.], [2., 3.]]
- for ord_ in "foo", -7, -1.1, 0:
+ for ord_ in "fro", -7, -1.1, 0:
with self.assertRaisesRegexp(ValueError,
"'ord' must be a supported vector norm"):
- linalg_ops.norm(matrix, ord="fro")
+ linalg_ops.norm(matrix, ord=ord_)
- for ord_ in "foo", -7, -1.1, 0:
+ for ord_ in "fro", -7, -1.1, 0:
with self.assertRaisesRegexp(ValueError,
"'ord' must be a supported vector norm"):
linalg_ops.norm(matrix, ord=ord_, axis=-1)
- for ord_ in 1.1, 2:
+ for ord_ in "foo", -7, -1.1, 1.1:
with self.assertRaisesRegexp(ValueError,
"'ord' must be a supported matrix norm"):
linalg_ops.norm(matrix, ord=ord_, axis=[-2, -1])
@@ -69,14 +69,14 @@ def _GetNormOpTest(dtype_, shape_, ord_, axis_, keep_dims_, use_static_shape_):
if use_static_shape_:
tf_matrix = constant_op.constant(matrix)
tf_norm = linalg_ops.norm(
- tf_matrix, ord=ord_, axis=axis_, keep_dims=keep_dims_)
+ tf_matrix, ord=ord_, axis=axis_, keepdims=keep_dims_)
tf_norm_val = sess.run(tf_norm)
else:
tf_matrix = array_ops.placeholder(dtype_)
tf_norm = linalg_ops.norm(
- tf_matrix, ord=ord_, axis=axis_, keep_dims=keep_dims_)
+ tf_matrix, ord=ord_, axis=axis_, keepdims=keep_dims_)
tf_norm_val = sess.run(tf_norm, feed_dict={tf_matrix: matrix})
- self.assertAllClose(np_norm, tf_norm_val)
+ self.assertAllClose(np_norm, tf_norm_val, rtol=1e-5, atol=1e-5)
def Test(self):
is_matrix_norm = (isinstance(axis_, tuple) or
@@ -85,8 +85,6 @@ def _GetNormOpTest(dtype_, shape_, ord_, axis_, keep_dims_, use_static_shape_):
if ((not is_matrix_norm and ord_ == "fro") or
(is_matrix_norm and is_fancy_p_norm)):
self.skipTest("Not supported by neither numpy.linalg.norm nor tf.norm")
- if is_matrix_norm and ord_ == 2:
- self.skipTest("Not supported by tf.norm")
if ord_ == 'euclidean' or (axis_ is None and len(shape) > 2):
self.skipTest("Not supported by numpy.linalg.norm")
matrix = np.random.randn(*shape_).astype(dtype_)
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index 5b508b7c0e..b9f44d728a 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -52,6 +52,38 @@ class PyFuncTest(test.TestCase):
"""Encapsulates tests for py_func and eager_py_func."""
# ----- Tests for py_func -----
+ def testRealDataTypes(self):
+ def sum_func(x, y):
+ return x + y
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64,
+ dtypes.uint8, dtypes.int8, dtypes.uint16, dtypes.int16,
+ dtypes.int32, dtypes.int64]:
+ with self.test_session():
+ x = constant_op.constant(1, dtype=dtype)
+ y = constant_op.constant(2, dtype=dtype)
+ z = self.evaluate(script_ops.py_func(sum_func, [x, y], dtype))
+ self.assertEqual(z, 3)
+
+ def testComplexDataTypes(self):
+ def sub_func(x, y):
+ return x - y
+ for dtype in [dtypes.complex64, dtypes.complex128]:
+ with self.test_session():
+ x = constant_op.constant(1 + 1j, dtype=dtype)
+ y = constant_op.constant(2 - 2j, dtype=dtype)
+ z = self.evaluate(script_ops.py_func(sub_func, [x, y], dtype))
+ self.assertEqual(z, -1 + 3j)
+
+ def testBoolDataTypes(self):
+ def and_func(x, y):
+ return x and y
+ dtype = dtypes.bool
+ with self.test_session():
+ x = constant_op.constant(True, dtype=dtype)
+ y = constant_op.constant(False, dtype=dtype)
+ z = self.evaluate(script_ops.py_func(and_func, [x, y], dtype))
+ self.assertEqual(z, False)
+
def testSingleType(self):
with self.test_session():
x = constant_op.constant(1.0, dtypes.float32)
diff --git a/tensorflow/python/kernel_tests/random/multinomial_op_test.py b/tensorflow/python/kernel_tests/random/multinomial_op_test.py
index a9dc7b7de0..051c7d86bf 100644
--- a/tensorflow/python/kernel_tests/random/multinomial_op_test.py
+++ b/tensorflow/python/kernel_tests/random/multinomial_op_test.py
@@ -46,7 +46,7 @@ def composed_sampler(logits, num_samples):
logits = array_ops.expand_dims(logits, -1)
# [batch size, num samples]
- return math_ops.argmax(logits + noise, dimension=1)
+ return math_ops.argmax(logits + noise, axis=1)
native_sampler = random_ops.multinomial
diff --git a/tensorflow/python/kernel_tests/random/random_ops_test.py b/tensorflow/python/kernel_tests/random/random_ops_test.py
index df37dd98ec..e4b5c3832a 100644
--- a/tensorflow/python/kernel_tests/random/random_ops_test.py
+++ b/tensorflow/python/kernel_tests/random/random_ops_test.py
@@ -228,6 +228,17 @@ class RandomUniformTest(test.TestCase):
print("count = ", count)
self.assertTrue(count < count_limit)
+ def testUniformIntsWithInvalidShape(self):
+ for dtype in dtypes.int32, dtypes.int64:
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ random_ops.random_uniform(
+ [1000], minval=[1, 2], maxval=3, dtype=dtype)
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ random_ops.random_uniform(
+ [1000], minval=1, maxval=[2, 3], dtype=dtype)
+
# Check that uniform ints actually follow a uniform distribution.
def testUniformInts(self):
minv = -2
diff --git a/tensorflow/python/kernel_tests/string_strip_op_test.py b/tensorflow/python/kernel_tests/string_strip_op_test.py
new file mode 100644
index 0000000000..30fd477ff4
--- /dev/null
+++ b/tensorflow/python/kernel_tests/string_strip_op_test.py
@@ -0,0 +1,56 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for string_strip_op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class StringStripOpTest(test.TestCase):
+ """ Test cases for tf.string_strip."""
+
+ def test_string_strip(self):
+ strings = ["pigs on the wing", "animals"]
+
+ with self.test_session() as sess:
+ output = string_ops.string_strip(strings)
+ output = sess.run(output)
+ self.assertAllEqual(output, [b"pigs on the wing", b"animals"])
+
+ def test_string_strip_2d(self):
+ strings = [["pigs on the wing", "animals"],
+ [" hello ", "\n\tworld \r \n"]]
+
+ with self.test_session() as sess:
+ output = string_ops.string_strip(strings)
+ output = sess.run(output)
+ self.assertAllEqual(output, [[b"pigs on the wing", b"animals"],
+ [b"hello", b"world"]])
+
+ def test_string_strip_with_empty_strings(self):
+ strings = [" hello ", "", "world ", " \t \r \n "]
+
+ with self.test_session() as sess:
+ output = string_ops.string_strip(strings)
+ output = sess.run(output)
+ self.assertAllEqual(output, [b"hello", b"", b"world", b""])
+
+
+if __name__ == "__main__":
+ test.main()