diff options
author | Jiri Simsa <jsimsa@google.com> | 2018-08-31 20:19:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-31 20:24:08 -0700 |
commit | c4ac04af7bb1ceb2e83cd6e7a3e1cac5f3d5c256 (patch) | |
tree | fb197ef6ea0498635fa4a91d04a7ea7006d97ed0 /tensorflow/contrib/data | |
parent | 68452bed16ce2cfee8f962e5abef92d27fb0796f (diff) |
[tf.data] Avoiding serialization of (potentially large) tensors during optimization.
PiperOrigin-RevId: 211179990
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py | 41 |
2 files changed, 34 insertions(+), 8 deletions(-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index b86a543fc3..34f594f741 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -293,6 +293,7 @@ py_test(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:errors",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//third_party/py/numpy",
         "@absl_py//absl/testing:parameterized",
     ],
 )
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
index 446bf8d749..089717156c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
@@ -18,10 +18,13 @@ from __future__ import division
 from __future__ import print_function
 
 from absl.testing import parameterized
+import numpy as np
 
 from tensorflow.contrib.data.python.ops import optimization
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
 
@@ -62,7 +65,7 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
           "Asserted next 2 transformations but encountered only 1."):
         sess.run(get_next)
 
-  def testDefaultOptimizations(self):
+  def testOptimizationDefault(self):
     dataset = dataset_ops.Dataset.range(10).apply(
         optimization.assert_next(
             ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
@@ -75,7 +78,7 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       sess.run(get_next)
 
-  def testEmptyOptimizations(self):
+  def testOptimizationEmpty(self):
     dataset = dataset_ops.Dataset.range(10).apply(
         optimization.assert_next(
             ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
@@ -88,7 +91,7 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       sess.run(get_next)
 
-  def testOptimization(self):
+  def testOptimizationFusion(self):
     dataset = dataset_ops.Dataset.range(10).apply(
         optimization.assert_next(
             ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply(
@@ -101,11 +104,9 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       sess.run(get_next)
 
-  def testStatefulFunctionOptimization(self):
-    dataset = dataset_ops.Dataset.range(10).apply(
-        optimization.assert_next([
-            "MapAndBatch"
-        ])).map(lambda _: random_ops.random_uniform([])).batch(10).apply(
+  def testOptimizationStatefulFunction(self):
+    dataset = dataset_ops.Dataset.range(10).map(
+        lambda _: random_ops.random_uniform([])).batch(10).apply(
             optimization.optimize(["map_and_batch_fusion"]))
     iterator = dataset.make_one_shot_iterator()
     get_next = iterator.get_next()
@@ -113,6 +114,30 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
   with self.test_session() as sess:
     sess.run(get_next)
 
+  def testOptimizationLargeInputFromTensor(self):
+    input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
+    dataset = dataset_ops.Dataset.from_tensors(input_t).apply(
+        optimization.optimize())
+    iterator = dataset.make_initializable_iterator()
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
+      sess.run(get_next)
+
+  def testOptimizationLargeInputFromTensorSlices(self):
+    input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
+    dataset = dataset_ops.Dataset.from_tensor_slices(input_t).apply(
+        optimization.optimize())
+    iterator = dataset.make_initializable_iterator()
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
+      sess.run(get_next)
+
 if __name__ == "__main__":
   test.main()