aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py77
1 files changed, 58 insertions, 19 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
index e35be8a23f..cfef40e192 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
@@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.core.framework import graph_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
@@ -26,51 +25,91 @@ from tensorflow.python.platform import test
class OptimizeDatasetTest(test.TestCase):
+ def testAssertSuffix(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Map"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ self.assertEqual(0, sess.run(get_next))
+
+ def testAssertSuffixInvalid(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Whoops"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Asserted Whoops transformation at offset 0 but encountered "
+ "Map transformation instead."
+ ):
+ sess.run(get_next)
+
+ def testAssertSuffixShort(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Asserted next 2 transformations but encountered only 1."):
+ sess.run(get_next)
+
def testDefaultOptimizations(self):
- dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch(
- 10).apply(optimization.optimize())
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
+ optimization.optimize())
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with self.test_session() as sess:
- graph = graph_pb2.GraphDef().FromString(
- sess.run(dataset._as_serialized_graph()))
- self.assertTrue(
- all([node.op != "MapAndBatchDatasetV2" for node in graph.node]))
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testEmptyOptimizations(self):
- dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch(
- 10).apply(optimization.optimize([]))
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
+ optimization.optimize([]))
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with self.test_session() as sess:
- graph = graph_pb2.GraphDef().FromString(
- sess.run(dataset._as_serialized_graph()))
- self.assertTrue(
- all([node.op != "MapAndBatchDatasetV2" for node in graph.node]))
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testOptimization(self):
- dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch(
- 10).apply(optimization.optimize(["map_and_batch_fusion"]))
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply(
+ optimization.optimize(["map_and_batch_fusion"]))
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with self.test_session() as sess:
- graph = graph_pb2.GraphDef().FromString(
- sess.run(dataset._as_serialized_graph()))
- self.assertTrue(
- any([node.op == "MapAndBatchDatasetV2" for node in graph.node]))
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testFunctionLibraryDefinitionModification(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).map(lambda x: x).apply(
+ optimization.optimize(["_test_only_function_rename"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(errors.NotFoundError,
+ "Function .* is not defined."):
+ sess.run(get_next)
+
if __name__ == "__main__":
test.main()