aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-08-31 20:19:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-31 20:24:08 -0700
commitc4ac04af7bb1ceb2e83cd6e7a3e1cac5f3d5c256 (patch)
treefb197ef6ea0498635fa4a91d04a7ea7006d97ed0 /tensorflow/contrib/data
parent68452bed16ce2cfee8f962e5abef92d27fb0796f (diff)
[tf.data] Avoiding serialization of (potentially large) tensors during optimization.
PiperOrigin-RevId: 211179990
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py41
2 files changed, 34 insertions, 8 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index b86a543fc3..34f594f741 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -293,6 +293,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
index 446bf8d749..089717156c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
@@ -18,10 +18,13 @@ from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
+import numpy as np
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -62,7 +65,7 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
"Asserted next 2 transformations but encountered only 1."):
sess.run(get_next)
- def testDefaultOptimizations(self):
+ def testOptimizationDefault(self):
dataset = dataset_ops.Dataset.range(10).apply(
optimization.assert_next(
["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
@@ -75,7 +78,7 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testEmptyOptimizations(self):
+ def testOptimizationEmpty(self):
dataset = dataset_ops.Dataset.range(10).apply(
optimization.assert_next(
["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
@@ -88,7 +91,7 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testOptimization(self):
+ def testOptimizationFusion(self):
dataset = dataset_ops.Dataset.range(10).apply(
optimization.assert_next(
["MapAndBatch"])).map(lambda x: x * x).batch(10).apply(
@@ -101,11 +104,9 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testStatefulFunctionOptimization(self):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next([
- "MapAndBatch"
- ])).map(lambda _: random_ops.random_uniform([])).batch(10).apply(
+ def testOptimizationStatefulFunction(self):
+ dataset = dataset_ops.Dataset.range(10).map(
+ lambda _: random_ops.random_uniform([])).batch(10).apply(
optimization.optimize(["map_and_batch_fusion"]))
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
@@ -113,6 +114,30 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
with self.test_session() as sess:
sess.run(get_next)
+ def testOptimizationLargeInputFromTensor(self):
+ input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
+ dataset = dataset_ops.Dataset.from_tensors(input_t).apply(
+ optimization.optimize())
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
+ sess.run(get_next)
+
+ def testOptimizationLargeInputFromTensorSlices(self):
+ input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_t).apply(
+ optimization.optimize())
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
+ sess.run(get_next)
+
if __name__ == "__main__":
test.main()