aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-01-05 14:05:27 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2016-01-05 14:05:27 -0800
commit1c579361cd1e088dd5e05a394b1561a73e3667ba (patch)
treeec464b9ac18113dc052744b6714eebbc7c6cc34d /tensorflow/python/kernel_tests
parent208350a6092f9faa473daf8b6eb6a80e9f9518f1 (diff)
Added 'logging' import to control_flow_ops which is used in the file but not imported.
Change: 110842260
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r--tensorflow/python/kernel_tests/concat_op_test.py9
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py7
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py24
-rw-r--r--tensorflow/python/kernel_tests/fifo_queue_test.py27
-rw-r--r--tensorflow/python/kernel_tests/gradient_checker_test.py8
-rw-r--r--tensorflow/python/kernel_tests/parsing_ops_test.py77
-rw-r--r--tensorflow/python/kernel_tests/py_func_test.py84
-rw-r--r--tensorflow/python/kernel_tests/reader_ops_test.py13
-rw-r--r--tensorflow/python/kernel_tests/reduction_ops_test.py22
-rw-r--r--tensorflow/python/kernel_tests/shape_ops_test.py14
-rw-r--r--tensorflow/python/kernel_tests/sparse_ops_test.py239
-rw-r--r--tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py22
-rw-r--r--tensorflow/python/kernel_tests/transpose_op_test.py4
13 files changed, 538 insertions, 12 deletions
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py
index 0ea573932b..ab0676d9ec 100644
--- a/tensorflow/python/kernel_tests/concat_op_test.py
+++ b/tensorflow/python/kernel_tests/concat_op_test.py
@@ -364,5 +364,14 @@ class ConcatOpTest(tf.test.TestCase):
err = tf.test.compute_gradient_error(xs, x_shapes, output, output_shape)
self.assertLess(err, 1e-11)
+ def testConcatTuple(self):
+ c1 = np.random.rand(4, 4)
+ c2 = np.random.rand(4, 4)
+ with self.test_session():
+ concat_list_t = tf.concat(0, [c1, c2])
+ concat_tuple_t = tf.concat(0, (c1, c2))
+ self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval())
+
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 58302a683d..6de4c905b1 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -330,6 +330,13 @@ class ControlFlowTest(tf.test.TestCase):
result = exit_i.eval()
self.assertAllEqual(10, result)
+ def testCondBool(self):
+ values = tf.constant(10)
+ fn1 = lambda: tf.add(values, 1)
+ fn2 = lambda: tf.sub(values, 1)
+ with self.assertRaisesRegexp(TypeError, "must not be a Python bool"):
+ _ = control_flow_ops.cond(False, fn1, fn2)
+
def testCondIndexedSlices(self):
with self.test_session():
values = tf.constant(10)
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index a823250d51..8f2720f1cf 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import math
+
import tensorflow.python.platform
import numpy as np
@@ -55,6 +57,11 @@ class UnaryOpTest(tf.test.TestCase):
tf_cpu = y.eval()
self.assertShapeEqual(np_ans, y)
self.assertAllClose(np_ans, tf_cpu)
+
+ # TODO(ebrevdo): add gradient for lgamma (digamma) and remove lgamma here.
+ if tf_func in (tf.lgamma,):
+ return # Return early
+
if x.dtype == np.float32:
s = list(np.shape(x))
jacob_t, jacob_n = tf.test.compute_gradient(inx,
@@ -94,6 +101,17 @@ class UnaryOpTest(tf.test.TestCase):
def _sigmoid(self, x):
return 1.0 / (1.0 + np.exp(-x))
+ def _replace_domain_error_with_inf(self, fn):
+ def func(x):
+ try:
+ return fn(x)
+ except ValueError, e:
+ if "domain error" in e.message:
+ return np.inf * np.ones_like(x)
+ else:
+ raise e
+ return func
+
def testFloatBasic(self):
x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
y = (x + .5).astype(np.float32) # no zero
@@ -113,6 +131,12 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBoth(y, np.sign, tf.sign)
self._compareBoth(x, np.sin, tf.sin)
self._compareBoth(x, np.cos, tf.cos)
+ self._compareBoth(
+ x,
+ np.vectorize(self._replace_domain_error_with_inf(math.lgamma)),
+ tf.lgamma)
+ self._compareBoth(x, np.vectorize(math.erf), tf.erf)
+ self._compareBoth(x, np.vectorize(math.erfc), tf.erfc)
def testFloatTanhEdge(self):
x = np.arange(40, 40 + 6).reshape(6).astype(np.float32)
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
index f02e16a4ae..db8d4ba5c4 100644
--- a/tensorflow/python/kernel_tests/fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -1124,6 +1124,33 @@ class FIFOQueueTest(tf.test.TestCase):
thread.join()
self.assertAllEqual(elem, results)
+ def testDtypes(self):
+ with self.test_session() as sess:
+ dtypes = [tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8,
+ tf.int64, tf.bool, tf.complex64]
+ shape = (32, 4, 128)
+ q = tf.FIFOQueue(32, dtypes, [shape[1:]] * len(dtypes))
+
+ input_tuple = []
+ for dtype in dtypes:
+ np_dtype = dtype.as_numpy_dtype
+ np_array = np.random.randint(-10, 10, shape)
+ if dtype == tf.bool:
+ np_array = np_array > 0
+ elif dtype == tf.complex64:
+ np_array = np.sqrt(np_array.astype(np_dtype))
+ else:
+ np_array = np_array.astype(np_dtype)
+ input_tuple.append(np_array)
+
+ q.enqueue_many(input_tuple).run()
+
+ output_tuple_t = q.dequeue_many(32)
+ output_tuple = sess.run(output_tuple_t)
+
+ for (input_elem, output_elem) in zip(input_tuple, output_tuple):
+ self.assertAllEqual(input_elem, output_elem)
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/kernel_tests/gradient_checker_test.py b/tensorflow/python/kernel_tests/gradient_checker_test.py
index 2ded0375a8..bcaaa8cc4e 100644
--- a/tensorflow/python/kernel_tests/gradient_checker_test.py
+++ b/tensorflow/python/kernel_tests/gradient_checker_test.py
@@ -27,6 +27,7 @@ import tensorflow as tf
class GradientCheckerTest(tf.test.TestCase):
def testAddSimple(self):
+ np.random.seed(1) # Fix seed to avoid flakiness
with self.test_session(use_gpu=False):
# a test case for Add operation
size = (2, 3)
@@ -40,6 +41,7 @@ class GradientCheckerTest(tf.test.TestCase):
assert error < 1e-4
def testAddSimpleGPU(self):
+ np.random.seed(2) # Fix seed to avoid flakiness
with self.test_session(use_gpu=True):
# a test case for Add operation
size = (2, 3)
@@ -53,6 +55,7 @@ class GradientCheckerTest(tf.test.TestCase):
assert error < 1e-4
def testAddCustomized(self):
+ np.random.seed(3) # Fix seed to avoid flakiness
with self.test_session():
# a test case for Add operation
size = (2, 3)
@@ -74,6 +77,7 @@ class GradientCheckerTest(tf.test.TestCase):
assert error < 1e-10
def testGather(self):
+ np.random.seed(4) # Fix seed to avoid flakiness
with self.test_session():
p_shape = (4, 2)
p_size = 8
@@ -89,6 +93,7 @@ class GradientCheckerTest(tf.test.TestCase):
assert error < 1e-4
def testNestedGather(self):
+ np.random.seed(5) # Fix seed to avoid flakiness
with self.test_session():
p_shape = (8, 2)
p_size = 16
@@ -110,6 +115,9 @@ class GradientCheckerTest(tf.test.TestCase):
# Gradient checker for MNIST.
def BuildAndTestMiniMNIST(param_index, tag):
+ # Fix seed to avoid occasional flakiness
+ np.random.seed(6)
+
# Hyperparameters
batch = 3
inputs = 16
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py
index 5a0ffce6b4..a470fb7274 100644
--- a/tensorflow/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/kernel_tests/parsing_ops_test.py
@@ -642,6 +642,60 @@ class ParseSequenceExampleTest(tf.test.TestCase):
"feature_list_dense_defaults": {"d": None},
}, expected_feat_list_values=expected_feature_list_output)
+ def testSequenceExampleWithSparseAndDenseFeatureLists(self):
+ feature_list_dense_keys = ["a"]
+ feature_list_dense_types = [tf.int64]
+ feature_list_dense_shapes = [(2,)]
+
+ original = sequence_example(feature_lists=feature_lists({
+ "a": feature_list([
+ int64_feature([3, 4]),
+ int64_feature([1, 0])]),
+ "st_a": feature_list([
+ float_feature([3.0, 4.0]),
+ float_feature([5.0]),
+ float_feature([])]),
+ "st_b": feature_list([
+ bytes_feature([b"a"]),
+ bytes_feature([]),
+ bytes_feature([]),
+ bytes_feature([b"b", b"c"])])}))
+
+ serialized = original.SerializeToString()
+
+ expected_st_a = (
+ np.array([[0, 0], [0, 1], [1, 0]], dtype=np.int64), # indices
+ np.array([3.0, 4.0, 5.0], dtype=np.float32), # values
+ np.array([3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
+
+ expected_st_b = (
+ np.array([[0, 0], [3, 0], [3, 1]], dtype=np.int64), # indices
+ np.array(["a", "b", "c"], dtype=np.str), # values
+ np.array([4, 2], dtype=np.int64)) # shape: num_time = 4, max_feat = 2
+
+ expected_st_c = (
+ np.empty((0, 2), dtype=np.int64), # indices
+ np.empty((0,), dtype=np.int64), # values
+ np.array([0, 0], dtype=np.int64)) # shape: num_time = 0, max_feat = 0
+
+ expected_feature_list_output = {
+ "a": np.array([[3, 4], [1, 0]], dtype=np.int64),
+ "st_a": expected_st_a,
+ "st_b": expected_st_b,
+ "st_c": expected_st_c,
+ }
+
+ self._test(
+ {
+ "debug_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "feature_list_dense_types": feature_list_dense_types,
+ "feature_list_dense_keys": feature_list_dense_keys,
+ "feature_list_dense_shapes": feature_list_dense_shapes,
+ "feature_list_sparse_keys": ["st_a", "st_b", "st_c"],
+ "feature_list_sparse_types": [tf.float32, tf.string, tf.int64]
+ }, expected_feat_list_values=expected_feature_list_output)
+
def testSequenceExampleListWithInconsistentDataFails(self):
feature_list_dense_types = [tf.int64]
feature_list_dense_shapes = [(2,)]
@@ -687,6 +741,29 @@ class ParseSequenceExampleTest(tf.test.TestCase):
expected_err_re=("Feature list: a, Index: 0. Data types don't match. "
"Expected type: int64"))
+ def testSequenceExampleListWithWrongSparseDataTypeFails(self):
+ feature_list_sparse_types = [tf.int64]
+
+ original = sequence_example(feature_lists=feature_lists({
+ "a": feature_list([
+ int64_feature([3, 4]),
+ int64_feature([1, 2, 3]),
+ float_feature([2])])
+ }))
+
+ serialized = original.SerializeToString()
+
+ self._test(
+ {
+ "debug_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "feature_list_sparse_types": feature_list_sparse_types,
+ "feature_list_sparse_keys": ["a"]
+ },
+ expected_err_re=(
+ "Name: in1, Feature List: a, Index: 2. Data types don't match. "
+ "Expected type: int64 Feature is: float_list"))
+
def testSequenceExampleListWithWrongShapeFails(self):
feature_list_dense_types = [tf.int64]
feature_list_dense_shapes = [(2,)]
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
new file mode 100644
index 0000000000..742402b8b7
--- /dev/null
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -0,0 +1,84 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for py_func op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.ops import script_ops
+
+
+class PyOpTest(tf.test.TestCase):
+
+ def testBasic(self):
+
+ def my_func(x, y):
+ return np.sinh(x) + np.cosh(y)
+
+ # scalar
+ with self.test_session():
+ x = tf.constant(1.0, tf.float32)
+ y = tf.constant(2.0, tf.float32)
+ z = tf.py_func(my_func, [x, y], [tf.float32])
+ self.assertEqual(z[0].eval(), my_func(1.0, 2.0).astype(np.float32))
+
+ # array
+ with self.test_session():
+ x = tf.constant([1.0, 2.0], tf.float64)
+ y = tf.constant([2.0, 3.0], tf.float64)
+ z = tf.py_func(my_func, [x, y], [tf.float64])
+ self.assertAllEqual(
+ z[0].eval(),
+ my_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))
+
+ # a bit exotic type (complex64)
+ with self.test_session():
+ x = tf.constant(1+2j, tf.complex64)
+ y = tf.constant(3+4j, tf.complex64)
+ z, = tf.py_func(my_func, [x, y], [tf.complex64])
+ self.assertAllClose(z.eval(), my_func(1+2j, 3+4j))
+
+ # a bit excotic function (rfft)
+ with self.test_session():
+ x = tf.constant([1., 2., 3., 4.], tf.float32)
+ def rfft(x):
+ return np.fft.rfft(x).astype(np.complex64)
+ y, = tf.py_func(rfft, [x], [tf.complex64])
+ self.assertAllClose(y.eval(), np.fft.rfft([1., 2., 3., 4.]))
+
+ def testLarge(self):
+ with self.test_session() as sess:
+ x = tf.zeros([1000000], dtype=np.float32)
+ y = tf.py_func(lambda x: x + 1, [x], [tf.float32])
+ z = tf.py_func(lambda x: x * 2, [x], [tf.float32])
+ for _ in xrange(100):
+ sess.run([y[0].op, z[0].op])
+
+ def testCleanup(self):
+ for _ in range(1000):
+ g = tf.Graph()
+ with g.as_default():
+ c = tf.constant([1.], tf.float32)
+ _ = tf.py_func(lambda x: x + 1, [c], [tf.float32])
+ self.assertTrue(script_ops._py_funcs.size() < 100)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index b1188d0672..2882182a03 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -235,7 +235,7 @@ class TextLineReaderTest(tf.test.TestCase):
def _LineText(self, f, l):
return tf.compat.as_bytes("%d: %d" % (f, l))
- def _CreateFiles(self):
+ def _CreateFiles(self, crlf=False):
filenames = []
for i in range(self._num_files):
fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
@@ -246,11 +246,10 @@ class TextLineReaderTest(tf.test.TestCase):
# Always include a newline after the record unless it is
# at the end of the file, in which case we include it sometimes.
if j + 1 != self._num_lines or i == 0:
- f.write(b"\n")
+ f.write(b"\r\n" if crlf else b"\n")
return filenames
- def testOneEpoch(self):
- files = self._CreateFiles()
+ def _testOneEpoch(self, files):
with self.test_session() as sess:
reader = tf.TextLineReader(name="test_reader")
queue = tf.FIFOQueue(99, [tf.string], shapes=())
@@ -268,6 +267,12 @@ class TextLineReaderTest(tf.test.TestCase):
"\\(requested 1, current size 0\\)"):
k, v = sess.run([key, value])
+ def testOneEpochLF(self):
+ self._testOneEpoch(self._CreateFiles(crlf=False))
+
+ def testOneEpochCRLF(self):
+ self._testOneEpoch(self._CreateFiles(crlf=True))
+
def testSkipHeaderLines(self):
files = self._CreateFiles()
with self.test_session() as sess:
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index 7ff3851da7..3b79ae341b 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -174,6 +174,28 @@ class SumReductionTest(tf.test.TestCase):
def testGradient4(self):
self._compareGradient([2, 3, 4, 2], [], None)
+ def testHighRank(self):
+ # Do a bunch of random high dimensional reductions
+ np.random.seed(42)
+ for _ in range(20):
+ rank = np.random.randint(4, 10 + 1)
+ axes, = np.nonzero(np.random.randint(2, size=rank))
+ shape = tuple(np.random.randint(1, 3 + 1, size=rank))
+ data = np.random.randint(1024, size=shape)
+ self._compareAll(data, axes)
+ # Check some particular axis patterns
+ for rank in 4, 7, 10:
+ shape = tuple(np.random.randint(1, 3 + 1, size=rank))
+ data = np.random.randint(1024, size=shape)
+ for axes in ([], np.arange(rank), np.arange(0, rank, 2),
+ np.arange(1, rank, 2)):
+ self._compareAll(data, axes)
+
+ def testExpand(self):
+ # Reduce an empty tensor to a nonempty tensor
+ x = np.zeros((5, 0))
+ self._compareAll(x, [1])
+
class MeanReductionTest(tf.test.TestCase):
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py
index 81be48990b..38ba890c74 100644
--- a/tensorflow/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/python/kernel_tests/shape_ops_test.py
@@ -227,15 +227,23 @@ class TileTest(tf.test.TestCase):
def testSimple(self):
with self.test_session():
- inp = np.random.rand(4, 1).astype("f")
- a = tf.constant([float(x) for x in inp.ravel(order="C")],
- shape=[4, 1], dtype=tf.float32)
+ inp = np.random.rand(4, 1).astype(np.float32)
+ a = tf.constant(inp)
tiled = tf.tile(a, [1, 4])
result = tiled.eval()
self.assertEqual(result.shape, (4, 4))
self.assertEqual([4, 4], tiled.get_shape())
self.assertTrue((result == np.tile(inp, (1, 4))).all())
+ def testEmpty(self):
+ with self.test_session():
+ inp = np.random.rand(2, 3).astype(np.float32)
+ a = tf.constant(inp)
+ tiled = tf.tile(a, [5, 0])
+ result = tiled.eval()
+ self.assertEqual(result.shape, (10, 0))
+ self.assertEqual([10, 0], tiled.get_shape())
+
def testTypes(self):
types_to_test = {
"bool": (tf.bool, bool),
diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
new file mode 100644
index 0000000000..d6ee16c8e2
--- /dev/null
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -0,0 +1,239 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for Python ops defined in sparse_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import, g-bad-import-order
+import tensorflow.python.platform
+# pylint: enable=unused-import, g-bad-import-order
+
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import googletest
+
+
+class SparseToIndicatorTest(test_util.TensorFlowTestCase):
+
+ def _SparseTensor_5x6(self, dtype):
+ ind = np.array([
+ [0, 0],
+ [1, 0], [1, 3], [1, 4],
+ [3, 2], [3, 3]])
+ val = np.array([0, 10, 13, 14, 32, 33])
+ shape = np.array([5, 6])
+ return ops.SparseTensor(
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtype),
+ constant_op.constant(shape, dtypes.int64))
+
+ def _SparseTensor_2x3x4(self, dtype):
+ # Includes two entries with the form [1, 1, x] : 150.
+ ind = np.array([
+ [0, 0, 1],
+ [0, 1, 0],
+ [0, 1, 2],
+ [1, 0, 3],
+ [1, 1, 0],
+ [1, 1, 1],
+ [1, 1, 2],
+ [1, 2, 2]])
+ val = np.array([1, 10, 12, 103, 150, 149, 150, 122])
+ shape = np.array([2, 3, 4])
+ return ops.SparseTensor(
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtype),
+ constant_op.constant(shape, dtypes.int64))
+
+ def testInt32(self):
+ with self.test_session(use_gpu=False):
+ sp_input = self._SparseTensor_5x6(dtypes.int32)
+ output = sparse_ops.sparse_to_indicator(sp_input, 50).eval()
+
+ expected_output = np.zeros((5, 50), dtype=np.bool)
+ expected_trues = ((0, 0), (1, 10), (1, 13), (1, 14), (3, 32), (3, 33))
+ for expected_true in expected_trues:
+ expected_output[expected_true] = True
+
+ self.assertAllEqual(output, expected_output)
+
+ def testInt64(self):
+ with self.test_session(use_gpu=False):
+ sp_input = self._SparseTensor_5x6(dtypes.int64)
+ output = sparse_ops.sparse_to_indicator(sp_input, 50).eval()
+
+ expected_output = np.zeros((5, 50), dtype=np.bool)
+ expected_trues = [(0, 0), (1, 10), (1, 13), (1, 14), (3, 32), (3, 33)]
+ for expected_true in expected_trues:
+ expected_output[expected_true] = True
+
+ self.assertAllEqual(output, expected_output)
+
+ def testHigherRank(self):
+ with self.test_session(use_gpu=False):
+ sp_input = self._SparseTensor_2x3x4(dtypes.int64)
+ output = sparse_ops.sparse_to_indicator(sp_input, 200).eval()
+
+ expected_output = np.zeros((2, 3, 200), dtype=np.bool)
+ expected_trues = [(0, 0, 1), (0, 1, 10), (0, 1, 12),
+ (1, 0, 103), (1, 1, 149), (1, 1, 150),
+ (1, 2, 122)]
+ for expected_true in expected_trues:
+ expected_output[expected_true] = True
+
+ self.assertAllEqual(output, expected_output)
+
+
+class SparseRetainTest(test_util.TensorFlowTestCase):
+
+ def _SparseTensor_5x6(self):
+ ind = np.array([
+ [0, 0],
+ [1, 0], [1, 3], [1, 4],
+ [3, 2], [3, 3]])
+ val = np.array([0, 10, 13, 14, 32, 33])
+ shape = np.array([5, 6])
+ return ops.SparseTensor(
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.int32),
+ constant_op.constant(shape, dtypes.int64))
+
+ def testBasic(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensor_5x6()
+ to_retain = np.array([1, 0, 0, 1, 1, 0], dtype=np.bool)
+ sp_output = sparse_ops.sparse_retain(sp_input, to_retain)
+
+ output = sess.run(sp_output)
+
+ self.assertAllEqual(output.indices, [[0, 0], [1, 4], [3, 2]])
+ self.assertAllEqual(output.values, [0, 14, 32])
+ self.assertAllEqual(output.shape, [5, 6])
+
+ def testRetainNone(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensor_5x6()
+ to_retain = np.zeros((6,), dtype=np.bool)
+ sp_output = sparse_ops.sparse_retain(sp_input, to_retain)
+
+ output = sess.run(sp_output)
+
+ self.assertAllEqual(output.indices, np.array([]).reshape((0, 2)))
+ self.assertAllEqual(output.values, [])
+ self.assertAllEqual(output.shape, [5, 6])
+
+ def testMismatchedRetainShape(self):
+ with self.test_session(use_gpu=False):
+ sp_input = self._SparseTensor_5x6()
+ to_retain = np.array([1, 0, 0, 1, 0], dtype=np.bool)
+ with self.assertRaises(ValueError):
+ sparse_ops.sparse_retain(sp_input, to_retain)
+
+
+class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase):
+
+ def _SparseTensor_5x6(self):
+ ind = np.array([
+ [0, 0],
+ [1, 0], [1, 3], [1, 4],
+ [3, 2], [3, 3]])
+ val = np.array([0, 10, 13, 14, 32, 33])
+ shape = np.array([5, 6])
+ return ops.SparseTensor(
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.int32),
+ constant_op.constant(shape, dtypes.int64))
+
+ def _SparseTensor_String5x6(self):
+ ind = np.array([
+ [0, 0],
+ [1, 0], [1, 3], [1, 4],
+ [3, 2], [3, 3]])
+ val = np.array(["a", "b", "c", "d", "e", "f"])
+ shape = np.array([5, 6])
+ return ops.SparseTensor(
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.string),
+ constant_op.constant(shape, dtypes.int64))
+
+ def _SparseTensor_2x6(self):
+ ind = np.array([[0, 0], [1, 0], [1, 3], [1, 4]])
+ val = np.array([0, 10, 13, 14])
+ shape = np.array([2, 6])
+ return ops.SparseTensor(
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.int32),
+ constant_op.constant(shape, dtypes.int64))
+
+ def testFillNumber(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensor_5x6()
+ sp_output, empty_row_indicator = (
+ sparse_ops.sparse_fill_empty_rows(sp_input, -1))
+
+ output, empty_row_indicator_out = sess.run(
+ [sp_output, empty_row_indicator])
+
+ self.assertAllEqual(
+ output.indices,
+ [[0, 0], [1, 0], [1, 3], [1, 4], [2, 0], [3, 2], [3, 3], [4, 0]])
+ self.assertAllEqual(output.values, [0, 10, 13, 14, -1, 32, 33, -1])
+ self.assertAllEqual(output.shape, [5, 6])
+ self.assertAllEqual(empty_row_indicator_out,
+ np.array([0, 0, 1, 0, 1]).astype(np.bool))
+
+ def testFillString(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensor_String5x6()
+ sp_output, empty_row_indicator = (
+ sparse_ops.sparse_fill_empty_rows(sp_input, ""))
+
+ output, empty_row_indicator_out = sess.run(
+ [sp_output, empty_row_indicator])
+
+ self.assertAllEqual(
+ output.indices,
+ [[0, 0], [1, 0], [1, 3], [1, 4], [2, 0], [3, 2], [3, 3], [4, 0]])
+ self.assertAllEqual(output.values,
+ [b"a", b"b", b"c", b"d", b"", b"e", b"f", b""])
+ self.assertAllEqual(output.shape, [5, 6])
+ self.assertAllEqual(empty_row_indicator_out,
+ np.array([0, 0, 1, 0, 1]).astype(np.bool))
+
+ def testNoEmptyRows(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensor_2x6()
+ sp_output, empty_row_indicator = (
+ sparse_ops.sparse_fill_empty_rows(sp_input, -1))
+
+ output, empty_row_indicator_out = sess.run(
+ [sp_output, empty_row_indicator])
+
+ self.assertAllEqual(output.indices, [[0, 0], [1, 0], [1, 3], [1, 4]])
+ self.assertAllEqual(output.values, [0, 10, 13, 14])
+ self.assertAllEqual(output.shape, [2, 6])
+ self.assertAllEqual(empty_row_indicator_out, np.zeros(2).astype(np.bool))
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py b/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
index ee9a697a0b..6ea1e6d8eb 100644
--- a/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
+++ b/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
@@ -25,9 +25,11 @@ import tensorflow as tf
def _SparseToDense(sparse_indices, output_size, sparse_values,
- default_value):
+ default_value, validate_indices=True):
return tf.sparse_to_dense(sparse_indices, output_size,
- sparse_values, default_value)
+ sparse_values,
+ default_value=default_value,
+ validate_indices=validate_indices)
class SparseToDenseTest(tf.test.TestCase):
@@ -107,10 +109,24 @@ class SparseToDenseTest(tf.test.TestCase):
def testBadDefault(self):
with self.test_session():
- dense = _SparseToDense([1, 3], [5], [1, 2], [1, 2])
+ dense = _SparseToDense([1, 3], [5], [1, 2], [0])
with self.assertRaisesOpError("default_value should be a scalar"):
dense.eval()
+ def testInvalidIndicesWithWithoutValidation(self):
+ with self.test_session():
+ dense = _SparseToDense(
+ sparse_indices=[[1], [1]], output_size=[5],
+ sparse_values=[-1.0, 1.0], default_value=0.0)
+ with self.assertRaisesOpError(
+ "not lexicographically sorted or containing repeats"):
+ dense.eval()
+ # Disable checks
+ dense_without_validation = _SparseToDense(
+ sparse_indices=[[1], [1]], output_size=[5],
+ sparse_values=[-1.0, 1.0], default_value=0.0, validate_indices=False)
+ dense_without_validation.eval()
+
def testShapeInferenceKnownShape(self):
with self.test_session(use_gpu=False):
indices = tf.placeholder(tf.int64)
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index c6af05ff22..c769987a47 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -186,8 +186,8 @@ class TransposeTest(tf.test.TestCase):
def testError(self):
with self.assertRaises(ValueError):
tf.transpose(np.arange(0., 30).reshape([2, 3, 5]), [[0, 1], [2, 3]])
- self._testError(np.arange(0., 2 ** 10).reshape([2] * 10),
- np.arange(10),
+ self._testError(np.arange(0., 2 ** 11).reshape([2] * 11),
+ np.arange(11),
"not implemented")
with self.assertRaises(IndexError):
tf.transpose(np.arange(0., 30).reshape([2, 3, 5]), [0, 1, 3])