From e595927d05f4d3669be29efd5f9cb5702a63b1c0 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 8 Feb 2018 11:28:55 -0800 Subject: [tf.data] Remove deprecated `tf.contrib.data.Iterator` alias. This change removes the following class: * `tf.contrib.data.Iterator`. IF THIS BREAKS YOU: Replace `tf.contrib.data.Iterator` with `tf.data.Iterator` when explicitly constructing an iterator. The API for the resulting object is identical. PiperOrigin-RevId: 185024771 --- tensorflow/contrib/data/__init__.py | 2 - tensorflow/contrib/data/python/kernel_tests/BUILD | 44 +- .../python/kernel_tests/get_single_element_test.py | 57 ++ .../kernel_tests/iterator_ops_cluster_test.py | 108 ---- .../data/python/kernel_tests/iterator_ops_test.py | 625 --------------------- 5 files changed, 59 insertions(+), 777 deletions(-) create mode 100644 tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py delete mode 100644 tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py delete mode 100644 tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 22de13b558..21db1044b0 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -22,7 +22,6 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@Dataset @@Counter -@@Iterator @@batch_and_drop_remainder @@dense_to_sparse_batch @@ -68,7 +67,6 @@ from tensorflow.contrib.data.python.ops.readers import SqlDataset from tensorflow.contrib.data.python.ops.resampling import rejection_resample from tensorflow.contrib.data.python.ops.scan_ops import scan from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat -from tensorflow.python.data.ops.iterator_ops import Iterator # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 3ee82933bc..f58872f2a8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -208,59 +208,19 @@ py_test( ) tf_py_test( - name = "iterator_ops_cluster_test", + name = "get_single_element_test", size = "small", - srcs = ["iterator_ops_cluster_test.py"], - additional_deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", - "//tensorflow/python:session", - "//tensorflow/python/data/ops:iterator_ops", - ], - grpc_enabled = True, - tags = [ - "no_windows", - "oss_serial", - ], -) - -tf_py_test( - name = "iterator_ops_test", - size = "small", - srcs = ["iterator_ops_test.py"], + srcs = ["get_single_element_test.py"], additional_deps = [ "//third_party/py/numpy", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", - "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:io_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:script_ops", - "//tensorflow/python:session", - "//tensorflow/python:training", - "//tensorflow/python/data/ops:iterator_ops", ], - grpc_enabled = True, ) py_test( diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py new file mode 100644 index 0000000000..03d30bd100 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py @@ -0,0 +1,57 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class GetSingleElementTest(test.TestCase): + + def testGetSingleElement(self): + skip_value = array_ops.placeholder(dtypes.int64, shape=[]) + take_value = array_ops.placeholder_with_default( + constant_op.constant(1, dtype=dtypes.int64), shape=[]) + + dataset = (dataset_ops.Dataset.range(100) + .skip(skip_value) + .map(lambda x: x * x) + .take(take_value)) + + element = dataset_ops.get_single_element(dataset) + + with self.test_session() as sess: + self.assertEqual(0, sess.run(element, feed_dict={skip_value: 0})) + self.assertEqual(25, sess.run(element, feed_dict={skip_value: 5})) + self.assertEqual(100, sess.run(element, feed_dict={skip_value: 10})) + + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Dataset was empty."): + sess.run(element, feed_dict={skip_value: 100}) + + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Dataset had more than one element."): + sess.run(element, feed_dict={skip_value: 0, take_value: 2}) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py deleted file mode 100644 index 02379d064d..0000000000 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the experimental input pipeline ops that need test_util.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops -from tensorflow.python.platform import test - - -class IteratorClusterTest(test.TestCase): - - def testRemoteIteratorWithoutRemoteCallFail(self): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 - worker, _ = test_util.create_local_cluster( - 1, 1, worker_config=worker_config) - - with ops.device("/job:worker/replica:0/task:0/cpu:1"): - dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_3_handle = iterator_3.string_handle() - - with ops.device("/job:worker/replica:0/task:0/cpu:0"): - remote_it = iterator_ops.Iterator.from_string_handle( - iterator_3_handle, dataset_3.output_types, dataset_3.output_shapes) - get_next_op = remote_it.get_next() - - with session.Session(worker[0].target) as sess: - with self.assertRaises(errors.InvalidArgumentError): - sess.run(get_next_op) - - def _testRemoteIteratorHelper(self, device0, device1, target): - with ops.device(device1): - dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_3_handle = iterator_3.string_handle() - - @function.Defun(dtypes.string) - def _remote_fn(h): - remote_iterator = iterator_ops.Iterator.from_string_handle( - h, dataset_3.output_types, dataset_3.output_shapes) - return remote_iterator.get_next() - - with ops.device(device0): - target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - remote_op = functional_ops.remote_call( - args=[iterator_3_handle], - Tout=[dtypes.int32], - f=_remote_fn, - target=target_placeholder) - - with session.Session(target) as sess: - elem = sess.run(remote_op, feed_dict={target_placeholder: device1}) - self.assertEqual(elem, [1]) - # Fails when target is cpu:0 where the resource is not located. - with self.assertRaises(errors.InvalidArgumentError): - sess.run(remote_op, feed_dict={target_placeholder: device0}) - elem = sess.run(iterator_3.get_next()) - self.assertEqual(elem, [2]) - elem = sess.run(remote_op, feed_dict={target_placeholder: device1}) - self.assertEqual(elem, [3]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(remote_op, feed_dict={target_placeholder: device1}) - - def testRemoteIteratorUsingRemoteCallOp(self): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 - worker, _ = test_util.create_local_cluster( - 1, 1, worker_config=worker_config) - - self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0", - "/job:worker/replica:0/task:0/cpu:1", - worker[0].target) - - def testRemoteIteratorUsingRemoteCallOpCrossProcess(self): - workers, _ = test_util.create_local_cluster(2, 1) - - self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0", - "/job:worker/replica:0/task:1/cpu:0", - workers[0].target) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py deleted file mode 100644 index 9d11865dda..0000000000 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ /dev/null @@ -1,625 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the experimental input pipeline ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import numpy as np - -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.data.ops import readers -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import io_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import script_ops -from tensorflow.python.platform import test -from tensorflow.python.training import server_lib - - -class IteratorTest(test.TestCase): - - def testAttemptingGradientsRaiseExceptions(self): - component = constant_op.constant([1]) - side = constant_op.constant(0) - add = lambda x: x + side - dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add) - value = dataset.make_one_shot_iterator().get_next() - with self.assertRaisesRegexp(LookupError, "No gradient defined"): - gradients_impl.gradients(value, component) - with self.assertRaisesRegexp(LookupError, "No gradient defined"): - gradients_impl.gradients(value, side) - with self.assertRaisesRegexp(LookupError, "No gradient defined"): - gradients_impl.gradients(value, [component, side]) - - def testOneShotIterator(self): - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(14).make_one_shot_iterator()) - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - for _ in range(14): - for i in range(7): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i]**2, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testOneShotIteratorCaptureByValue(self): - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - tensor_components = tuple([ops.convert_to_tensor(c) for c in components]) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - iterator = (dataset_ops.Dataset.from_tensor_slices(tensor_components) - .map(_map_fn).repeat(14).make_one_shot_iterator()) - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - for _ in range(14): - for i in range(7): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i]**2, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testOneShotIteratorInsideContainer(self): - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - - def within_container(): - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .map(_map_fn).repeat(14).make_one_shot_iterator()) - return iterator.get_next() - - server = server_lib.Server.create_local_server() - - # Create two iterators within unique containers, and run them to - # make sure that the resources aren't shared. - # - # The test below would fail if cname were the same across both - # sessions. - for i in range(2): - with session.Session(server.target) as sess: - cname = "iteration%d" % i - with ops.container(cname): - get_next = within_container() - - for _ in range(14): - for i in range(7): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i]**2, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testOneShotIteratorNonBlocking(self): - dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - # Create a session with a single thread to ensure that the - # one-shot iterator initializer does not deadlock. - config = config_pb2.ConfigProto(inter_op_parallelism_threads=1, - use_per_session_threads=True) - with session.Session(config=config) as sess: - self.assertAllEqual([1, 4, 9], sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Test with multiple threads invoking the one-shot iterator concurrently. - with session.Session(config=config) as sess: - results = [] - def consumer_thread(): - try: - results.append(sess.run(next_element)) - except errors.OutOfRangeError: - results.append(None) - - num_threads = 8 - threads = [ - self.checkedThread(consumer_thread) for _ in range(num_threads)] - for t in threads: - t.start() - for t in threads: - t.join() - - self.assertEqual(num_threads, len(results)) - self.assertEqual(num_threads - 1, - len([None for r in results if r is None])) - self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None]) - - def testOneShotIteratorInitializerFails(self): - # Define a dataset whose initialization will always fail. - dataset = dataset_ops.Dataset.from_tensors( - array_ops.check_numerics( - constant_op.constant(1.0) / constant_op.constant(0.0), "oops")) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - with self.test_session() as sess: - with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): - sess.run(next_element) - - # Test that subsequent attempts to use the iterator also fail. - with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): - sess.run(next_element) - - with self.test_session() as sess: - def consumer_thread(): - with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): - sess.run(next_element) - - num_threads = 8 - threads = [ - self.checkedThread(consumer_thread) for _ in range(num_threads)] - for t in threads: - t.start() - for t in threads: - t.join() - - def testSimpleSharedResource(self): - components = ( - np.array(1, dtype=np.int64), - np.array([1, 2, 3], dtype=np.int64), - np.array(37.0, dtype=np.float64) - ) - - server = server_lib.Server.create_local_server() - - # Create two non-overlapping sessions that share the same iterator - # resource on the same server, and verify that an action of the - # first session (initializing the iterator) is visible in the - # second session. - with ops.Graph().as_default(): - iterator = (dataset_ops.Dataset.from_tensors(components) - .map(lambda x, y, z: (x, y, z)).make_initializable_iterator( - shared_name="shared_iterator")) - init_op = iterator.initializer - get_next = iterator.get_next() - - with session.Session(server.target) as sess: - sess.run(init_op) - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Re-initialize the iterator in the first session. - sess.run(init_op) - - with ops.Graph().as_default(): - # Re-define the iterator manually, without defining any of the - # functions in this graph, to ensure that we are not - # accidentally redefining functions with the same names in the - # new graph. - iterator = iterator_ops.Iterator.from_structure( - shared_name="shared_iterator", - output_types=(dtypes.int64, dtypes.int64, dtypes.float64), - output_shapes=([], [3], [])) - get_next = iterator.get_next() - - with session.Session(server.target) as sess: - # Use the iterator without re-initializing in the second session. - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testNotInitializedError(self): - components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) - iterator = (dataset_ops.Dataset.from_tensors(components) - .make_initializable_iterator()) - get_next = iterator.get_next() - - with self.test_session() as sess: - with self.assertRaisesRegexp(errors.FailedPreconditionError, - "iterator has not been initialized"): - sess.run(get_next) - - def testReinitializableIterator(self): - dataset_3 = dataset_ops.Dataset.from_tensors( - constant_op.constant([1, 2, 3])) - dataset_4 = dataset_ops.Dataset.from_tensors( - constant_op.constant([4, 5, 6, 7])) - iterator = iterator_ops.Iterator.from_structure(dataset_3.output_types, - [None]) - - dataset_3_init_op = iterator.make_initializer(dataset_3) - dataset_4_init_op = iterator.make_initializer(dataset_4) - get_next = iterator.get_next() - - self.assertEqual(dataset_3.output_types, iterator.output_types) - self.assertEqual(dataset_4.output_types, iterator.output_types) - self.assertEqual([None], iterator.output_shapes.as_list()) - - with self.test_session() as sess: - # The iterator is initially uninitialized. - with self.assertRaises(errors.FailedPreconditionError): - sess.run(get_next) - - # Initialize with one dataset. - sess.run(dataset_3_init_op) - self.assertAllEqual([1, 2, 3], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Initialize with a different dataset. - sess.run(dataset_4_init_op) - self.assertAllEqual([4, 5, 6, 7], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Reinitialize with the first dataset. - sess.run(dataset_3_init_op) - self.assertAllEqual([1, 2, 3], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testReinitializableIteratorStaticErrors(self): - # Non-matching structure for types and shapes. - with self.assertRaises(TypeError): - iterator = iterator_ops.Iterator.from_structure((dtypes.int64, - dtypes.float64), [None]) - - # Test validation of dataset argument. - iterator = iterator_ops.Iterator.from_structure((dtypes.int64, - dtypes.float64)) - - # Incompatible structure. - with self.assertRaises(ValueError): - iterator.make_initializer( - dataset_ops.Dataset.from_tensors(((constant_op.constant( - [1, 2, 3], dtype=dtypes.int64),), (constant_op.constant( - [4., 5., 6., 7.], dtype=dtypes.float64),)))) - - # Incompatible types. - with self.assertRaises(TypeError): - iterator.make_initializer( - dataset_ops.Dataset.from_tensors((constant_op.constant( - [1, 2, 3], dtype=dtypes.int32), constant_op.constant( - [4., 5., 6., 7.], dtype=dtypes.float32)))) - - # Incompatible shapes. - iterator = iterator_ops.Iterator.from_structure( - (dtypes.int64, dtypes.float64), ([None], [])) - with self.assertRaises(TypeError): - iterator.make_initializer( - dataset_ops.Dataset.from_tensors((constant_op.constant( - [1, 2, 3], dtype=dtypes.int64), constant_op.constant( - [4., 5., 6., 7.], dtype=dtypes.float64)))) - - def testIteratorStringHandle(self): - dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) - dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40]) - - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_4 = dataset_4.make_one_shot_iterator() - - handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - feedable_iterator = iterator_ops.Iterator.from_string_handle( - handle_placeholder, dataset_3.output_types, dataset_3.output_shapes) - next_element = feedable_iterator.get_next() - - self.assertEqual(dataset_3.output_types, feedable_iterator.output_types) - self.assertEqual(dataset_4.output_types, feedable_iterator.output_types) - self.assertEqual([], feedable_iterator.output_shapes) - - with self.test_session() as sess: - iterator_3_handle = sess.run(iterator_3.string_handle()) - iterator_4_handle = sess.run(iterator_4.string_handle()) - - self.assertEqual( - 10, sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle})) - self.assertEqual( - 1, sess.run(next_element, - feed_dict={handle_placeholder: iterator_3_handle})) - self.assertEqual( - 20, sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle})) - self.assertEqual( - 2, sess.run(next_element, - feed_dict={handle_placeholder: iterator_3_handle})) - self.assertEqual( - 30, sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle})) - self.assertEqual( - 3, sess.run(next_element, - feed_dict={handle_placeholder: iterator_3_handle})) - self.assertEqual( - 40, sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle})) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element, - feed_dict={handle_placeholder: iterator_3_handle}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle}) - - def testIteratorStringHandleError(self): - dataset_int_scalar = (dataset_ops.Dataset.from_tensor_slices([1, 2, - 3]).repeat()) - dataset_float_vector = (dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])) - - handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - - feedable_int_scalar = iterator_ops.Iterator.from_string_handle( - handle_placeholder, dtypes.int32, []) - feedable_int_vector = iterator_ops.Iterator.from_string_handle( - handle_placeholder, dtypes.int32, [None]) - feedable_int_any = iterator_ops.Iterator.from_string_handle( - handle_placeholder, dtypes.int32) - - with self.test_session() as sess: - handle_int_scalar = sess.run( - dataset_int_scalar.make_one_shot_iterator().string_handle()) - handle_float_vector = sess.run( - dataset_float_vector.make_one_shot_iterator().string_handle()) - - self.assertEqual(1, - sess.run( - feedable_int_scalar.get_next(), - feed_dict={handle_placeholder: handle_int_scalar})) - - self.assertEqual(2, - sess.run( - feedable_int_any.get_next(), - feed_dict={handle_placeholder: handle_int_scalar})) - - with self.assertRaises(errors.InvalidArgumentError): - print(sess.run( - feedable_int_vector.get_next(), - feed_dict={handle_placeholder: handle_int_scalar})) - - with self.assertRaises(errors.InvalidArgumentError): - print(sess.run( - feedable_int_vector.get_next(), - feed_dict={handle_placeholder: handle_float_vector})) - - def testRemoteIteratorUsingRemoteCallOpDirectSession(self): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 3 - - with ops.device("/job:localhost/replica:0/task:0/cpu:1"): - dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_3_handle = iterator_3.string_handle() - - @function.Defun(dtypes.string) - def _remote_fn(h): - remote_iterator = iterator_ops.Iterator.from_string_handle( - h, dataset_3.output_types, dataset_3.output_shapes) - return remote_iterator.get_next() - - with ops.device("/job:localhost/replica:0/task:0/cpu:0"): - target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - remote_op = functional_ops.remote_call( - args=[iterator_3_handle], - Tout=[dtypes.int32], - f=_remote_fn, - target=target_placeholder) - - with self.test_session(config=worker_config) as sess: - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" - }) - self.assertEqual(elem, [1]) - # Fails when target is cpu:2 where the resource is not located. - with self.assertRaises(errors.InvalidArgumentError): - sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:2" - }) - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" - }) - self.assertEqual(elem, [2]) - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" - }) - self.assertEqual(elem, [3]) - with self.assertRaises(errors.OutOfRangeError): - sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" - }) - - def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - with ops.device("/job:localhost/replica:0/task:0/cpu:0"): - dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_3_handle = iterator_3.string_handle() - - def _encode_raw(byte_array): - return bytes(bytearray(byte_array)) - - @function.Defun(dtypes.uint8) - def _remote_fn(h): - handle = script_ops.py_func(_encode_raw, [h], dtypes.string) - remote_iterator = iterator_ops.Iterator.from_string_handle( - handle, dataset_3.output_types, dataset_3.output_shapes) - return remote_iterator.get_next() - - with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): - target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - iterator_3_handle_uint8 = parsing_ops.decode_raw( - bytes=iterator_3_handle, out_type=dtypes.uint8) - remote_op = functional_ops.remote_call( - args=[iterator_3_handle_uint8], - Tout=[dtypes.int32], - f=_remote_fn, - target=target_placeholder) - - with self.test_session() as sess: - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" - }) - self.assertEqual(elem, [1]) - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" - }) - self.assertEqual(elem, [2]) - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" - }) - self.assertEqual(elem, [3]) - with self.assertRaises(errors.OutOfRangeError): - sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" - }) - - def testIncorrectIteratorRestore(self): - - def _path(): - return os.path.join(self.get_temp_dir(), "iterator") - - def _save_op(iterator_resource): - iterator_state_variant = gen_dataset_ops.serialize_iterator( - iterator_resource) - save_op = io_ops.write_file( - _path(), parsing_ops.serialize_tensor(iterator_state_variant)) - return save_op - - def _restore_op(iterator_resource): - iterator_state_variant = parsing_ops.parse_tensor( - io_ops.read_file(_path()), dtypes.variant) - restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, - iterator_state_variant) - return restore_op - - def _build_range_dataset_graph(): - start = 1 - stop = 10 - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = _save_op(iterator._iterator_resource) - restore_op = _restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - def _build_reader_dataset_graph(): - filenames = ["test"] # Does not exist but we don't care in this test. - iterator = readers.FixedLengthRecordDataset( - filenames, 1, 0, 0).make_initializable_iterator() - init_op = iterator.initializer - get_next_op = iterator.get_next() - save_op = _save_op(iterator._iterator_resource) - restore_op = _restore_op(iterator._iterator_resource) - return init_op, get_next_op, save_op, restore_op - - # Saving iterator for RangeDataset graph. - with ops.Graph().as_default() as g: - init_op, _, save_op, _ = _build_range_dataset_graph() - with self.test_session(graph=g) as sess: - sess.run(init_op) - sess.run(save_op) - - # Attempt to restore the saved iterator into an IteratorResource of - # incompatible type. An iterator of RangeDataset has output type int64, - # while an iterator of FixedLengthRecordDataset has output type string. - # So an InvalidArgumentError should be raised by - # IteratorResource::set_iterator. - with ops.Graph().as_default() as g: - _, _, _, restore_op = _build_reader_dataset_graph() - with self.test_session(graph=g) as sess: - with self.assertRaises(errors.InvalidArgumentError): - sess.run(restore_op) - - def testToSingleElement(self): - skip_value = array_ops.placeholder(dtypes.int64, shape=[]) - take_value = array_ops.placeholder_with_default( - constant_op.constant(1, dtype=dtypes.int64), shape=[]) - - dataset = (dataset_ops.Dataset.range(100) - .skip(skip_value) - .map(lambda x: x * x) - .take(take_value)) - - element = dataset_ops.get_single_element(dataset) - - with self.test_session() as sess: - self.assertEqual(0, sess.run(element, feed_dict={skip_value: 0})) - self.assertEqual(25, sess.run(element, feed_dict={skip_value: 5})) - self.assertEqual(100, sess.run(element, feed_dict={skip_value: 10})) - - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Dataset was empty."): - sess.run(element, feed_dict={skip_value: 100}) - - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Dataset had more than one element."): - sess.run(element, feed_dict={skip_value: 0, take_value: 2}) - - -if __name__ == "__main__": - test.main() -- cgit v1.2.3