author     Rohan Jain <rohanj@google.com>                    2018-09-28 13:50:12 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-28 14:32:29 -0700
commit     1724d155f00b49bc817189247cbfb0df2092a9da
tree       3f2606f84d779d8fca28a3d253c70176c4ed3fc1
parent     64be2ecc07c698df05d88051ec42a0409d1a9863
Automated rollback of commit 7f1d70d97f543d69a9f02cd6df0964f22f9278f3
PiperOrigin-RevId: 214989908
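
In effect, this rollback restores the prefetching_ops_v2-based input pipeline and lets distributed datasets be consumed through one-shot iterators in both graph and eager mode. A minimal sketch of the restored usage pattern, mirroring the updated tests below (the device list and toy dataset are illustrative, not part of the commit):

    from tensorflow.contrib.distribute.python import mirrored_strategy
    from tensorflow.python.data.ops import dataset_ops

    # Two-device mirrored strategy, as exercised in
    # mirrored_strategy_multigpu_test.py below.
    dist = mirrored_strategy.MirroredStrategy(["/device:GPU:0", "/device:CPU:0"])

    # After this change a one-shot iterator suffices in graph mode as well;
    # no separate initializer has to be run before fetching elements.
    features = dist.distribute_dataset(
        lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
    ).make_one_shot_iterator().get_next()
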
-rw-r--r--  tensorflow/contrib/distribute/python/BUILD                               |  28
-rw-r--r--  tensorflow/contrib/distribute/python/metrics_v1_test.py                  |   3
-rw-r--r--  tensorflow/contrib/distribute/python/minimize_loss_test.py               |  26
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py                |   3
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py  |  12
-rw-r--r--  tensorflow/contrib/distribute/python/monitor.py                          |   1
-rw-r--r--  tensorflow/contrib/distribute/python/optimizer_v2_test.py                |   8
-rw-r--r--  tensorflow/contrib/distribute/python/prefetching_ops_v2.py               | 232
-rw-r--r--  tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py          |  90
-rw-r--r--  tensorflow/contrib/distribute/python/step_fn.py                          |   7
-rw-r--r--  tensorflow/contrib/distribute/python/step_fn_test.py                     |   1
-rw-r--r--  tensorflow/contrib/distribute/python/values.py                           |  51
-rw-r--r--  tensorflow/contrib/distribute/python/values_test.py                      |  23
13 files changed, 396 insertions(+), 89 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index e329b964c4..422983dbef 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -22,6 +22,7 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":input_ops",
+ ":prefetching_ops_v2",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:device_util",
@@ -29,7 +30,6 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
- "//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"@six_archive//:six",
@@ -648,6 +648,32 @@ cuda_py_test(
)
py_library(
+ name = "prefetching_ops_v2",
+ srcs = ["prefetching_ops_v2.py"],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:prefetching_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+cuda_py_test(
+ name = "prefetching_ops_v2_test",
+ srcs = ["prefetching_ops_v2_test.py"],
+ additional_deps = [
+ ":prefetching_ops_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+)
+
+py_library(
name = "input_ops",
srcs = ["input_ops.py"],
visibility = ["//tensorflow:internal"],
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index f7773aff4f..8163494c8e 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -86,11 +86,10 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
with ops.Graph().as_default(), distribution.scope():
iterator = distribution.distribute_dataset(
- dataset_fn).make_initializable_iterator()
+ dataset_fn).make_one_shot_iterator()
value, update = distribution.call_for_each_tower(
metric_fn, iterator.get_next())
update = distribution.group(update)
- self.evaluate(iterator.initializer)
self.evaluate(variables.local_variables_initializer())
# TODO(josh11b): Once we switch to using a global batch size for input,
# replace "distribution.num_towers" with "1".
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index d082d5c419..ba147e7824 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -41,14 +41,6 @@ from tensorflow.python.ops.losses import losses_impl
class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
- def _get_iterator(self, ds):
- if context.executing_eagerly():
- iterator = ds.make_one_shot_iterator()
- else:
- iterator = ds.make_initializable_iterator()
- self.evaluate(iterator.initializer)
- return iterator
-
@combinations.generate(
combinations.times(
combinations.distributions_and_v1_optimizers(),
@@ -70,7 +62,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, *inputs, run_concurrently=layer.built))
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.run_steps_on_dataset(
@@ -106,7 +99,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.group(
@@ -165,7 +159,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, *inputs, run_concurrently=layer.built))
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.run_steps_on_dataset(
@@ -249,7 +244,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
return control_flow_ops.group(fetches)
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.run_steps_on_dataset(
@@ -342,7 +338,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, x, y, run_concurrently=False))
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.run_steps_on_dataset(
@@ -435,7 +432,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
output=loss)
return distribution.group(train_op)
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
initial_loss = lambda: constant_op.constant(1e7)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 93d42e09a2..4d7516063c 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -484,8 +484,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
self._prefetch_on_device, self._auto_shard_dataset)
else:
return values.PerDeviceDataset(
- self._call_dataset_fn(dataset_fn),
- self._devices,
+ self._call_dataset_fn(dataset_fn), self._devices,
self._prefetch_on_device)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 04c712ce1d..f51e543624 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -300,15 +300,9 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
- ds = dist.distribute_dataset(
- lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
- if context.executing_eagerly():
- iterator = ds.make_one_shot_iterator()
- else:
- iterator = ds.make_initializable_iterator()
- self.evaluate([iterator.initializer])
-
- features = iterator.get_next()
+ features = dist.distribute_dataset(
+ lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
+ ).make_one_shot_iterator().get_next()
with dist.scope():
result = dist.call_for_each_tower(
diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py
index 17b7ab74f6..7644acedc9 100644
--- a/tensorflow/contrib/distribute/python/monitor.py
+++ b/tensorflow/contrib/distribute/python/monitor.py
@@ -51,7 +51,6 @@ class Monitor(object):
else:
if session is None:
raise ValueError("Should provide a `session` in Graph mode.")
- session.run(step_callable._iterator.initializer) # pylint: disable=protected-access
self._run_step = session.make_callable(step_callable())
session.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
index 3064433129..6e9ba37a19 100644
--- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py
+++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
@@ -42,11 +42,8 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
- ds = distribution.distribute_dataset(dataset_fn)
- if context.executing_eagerly():
- iterator = ds.make_one_shot_iterator()
- else:
- iterator = ds.make_initializable_iterator()
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return control_flow_ops.group(distribution.unwrap(
@@ -55,7 +52,6 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
if not context.executing_eagerly():
with self.cached_session() as sess:
- sess.run(iterator.initializer)
run_step = sess.make_callable(run_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
new file mode 100644
index 0000000000..8d949943b7
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
@@ -0,0 +1,232 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Extension of prefetching_ops to support more than one device."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+
+from tensorflow.contrib.data.python.ops import prefetching_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.util import nest as data_nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.util import nest
+
+
+# pylint: disable=protected-access
+class _PrefetchToDeviceIterator(object):
+ """A replacement for `tf.data.Iterator` that prefetches to another device.
+
+ Args:
+ input_dataset: The input dataset.
+ one_shot: If `True`, creates a one-shot iterator that is already initialized.
+ devices: Devices on which to prefetch.
+ buffer_size: Size of the prefetching buffer.
+ shared_name: (Optional.) If non-empty, the returned iterator will be shared
+ under the given name across multiple sessions that share the same devices
+ (e.g. when using a remote server). Only used if one_shot is False.
+
+ Returns:
+ An Iterator type object.
+ """
+
+ def __init__(self,
+ input_dataset,
+ one_shot,
+ devices,
+ buffer_size,
+ shared_name=None):
+ self._input_dataset = input_dataset
+ self._get_next_call_count = 0
+ self._one_shot = one_shot
+ if shared_name is None:
+ shared_name = ""
+ self._devices = devices
+
+ if self._one_shot:
+ self._input_iterator = input_dataset.make_one_shot_iterator()
+ else:
+ self._input_iterator = iterator_ops.Iterator.from_structure(
+ self._input_dataset.output_types, self._input_dataset.output_shapes,
+ shared_name, self._input_dataset.output_classes)
+ input_iterator_handle = self._input_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _prefetch_fn(handle):
+ """Prefetches one element from `input_iterator`."""
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ handle, self._input_iterator.output_types,
+ self._input_iterator.output_shapes,
+ self._input_iterator.output_classes)
+ ret = remote_iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ target_device = ged_ops.experimental_iterator_get_device(
+ self._input_iterator._iterator_resource)
+ self._buffering_resources = []
+ for device in nest.flatten(self._devices):
+ with ops.device(device):
+ buffer_resource_handle = prefetching_ops.function_buffering_resource(
+ f=_prefetch_fn,
+ output_types=data_nest.flatten(
+ sparse.as_dense_types(self._input_dataset.output_types,
+ self._input_dataset.output_classes)),
+ target_device=target_device,
+ string_arg=input_iterator_handle,
+ buffer_size=buffer_size,
+ shared_name=shared_name)
+ self._buffering_resources.append(buffer_resource_handle)
+
+ if not self._one_shot:
+ reset_ops = []
+ for buffer_resource in self._buffering_resources:
+ reset_ops.append(
+ ged_ops.experimental_function_buffering_resource_reset(
+ buffer_resource))
+ with ops.control_dependencies(reset_ops):
+ self._initializer = self._input_iterator.make_initializer(
+ self._input_dataset)
+
+ def get_next(self, name=None):
+ """See `tf.data.Iterator.get_next`."""
+ self._get_next_call_count += 1
+ if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
+ warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
+
+ flat_result = []
+ # TODO(priyag): This will fail if the input size (typically number of
+ # batches) is not divisible by number of devices.
+ # How do we handle that more gracefully / let the user know?
+ for buffer_resource in self._buffering_resources:
+ flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
+ buffer_resource,
+ output_types=data_nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ name=name)
+
+ ret = sparse.deserialize_sparse_tensors(
+ data_nest.pack_sequence_as(self.output_types, flat_ret),
+ self.output_types, self.output_shapes, self.output_classes)
+
+ for tensor, shape in zip(
+ data_nest.flatten(ret), data_nest.flatten(self.output_shapes)):
+ if isinstance(tensor, ops.Tensor):
+ tensor.set_shape(shape)
+ flat_result.append(ret)
+
+ return nest.pack_sequence_as(self._devices, flat_result)
+
+ @property
+ def initializer(self):
+ if self._one_shot:
+ raise NotImplementedError("Can't initialize a one_shot_iterator")
+ return self._initializer
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+# pylint: enable=protected-access
+
+
+class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` whose iterator prefetches elements to other device(s)."""
+
+ def __init__(self, input_dataset, devices, buffer_size):
+ super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._devices = devices
+ self._buffer_size = buffer_size if buffer_size is not None else 1
+
+ def make_one_shot_iterator(self):
+ return _PrefetchToDeviceIterator(
+ self._input_dataset,
+ one_shot=True,
+ devices=self._devices,
+ buffer_size=self._buffer_size)
+
+ def make_initializable_iterator(self, shared_name=None):
+ if context.executing_eagerly():
+ raise RuntimeError(
+ "make_initializable_iterator is not supported when eager "
+ "execution is enabled.")
+
+ return _PrefetchToDeviceIterator(
+ self._input_dataset,
+ one_shot=False,
+ devices=self._devices,
+ buffer_size=self._buffer_size,
+ shared_name=shared_name)
+
+ def _as_variant_tensor(self):
+ # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
+ # transformation methods is called).
+ # TODO(mrry): Investigate support for chaining further transformations after
+ # the prefetch, including GPU support.
+ raise NotImplementedError("`prefetch_to_devices()` must be the last "
+ "transformation in a dataset pipeline.")
+
+ # TODO(priyag): Fix the output types, shapes and classes to match the result
+ # of get_next (which has the additional nesting layer of devices now).
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+
+def prefetch_to_devices(devices, buffer_size=None):
+ """A transformation that prefetches dataset values to the given `devices`.
+
+ NOTE: Although the transformation creates a `tf.data.Dataset`, the
+ transformation must be the final `Dataset` in the input pipeline.
+
+ Args:
+ devices: A nested structure of devices on which to prefetch the data. It can
+ be a single device name, or a tuple or list of device names.
+ buffer_size: (Optional.) The number of elements to buffer on each device.
+ Defaults to an automatically chosen value.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _PrefetchToDeviceDataset(dataset, devices, buffer_size)
+
+ return _apply_fn
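
For context, a minimal usage sketch of the restored transformation (illustrative; assumes a CPU plus one GPU, as in the tests that follow). `prefetch_to_devices` is applied last, and each `get_next()` then yields one element per device:

    from tensorflow.contrib.distribute.python import prefetching_ops_v2
    from tensorflow.python.data.ops import dataset_ops

    host_dataset = dataset_ops.Dataset.range(10)
    # prefetch_to_devices must be the final transformation in the pipeline;
    # chaining further transformations raises NotImplementedError.
    device_dataset = host_dataset.apply(
        prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
    # get_next() returns a structure with one element per target device.
    next_element = device_dataset.make_one_shot_iterator().get_next()
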
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
new file mode 100644
index 0000000000..16799104e8
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
@@ -0,0 +1,90 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for prefetching_ops_v2."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distribute.python import prefetching_ops_v2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class PrefetchingOpsV2Test(test.TestCase):
+
+ def testPrefetchToOneDevice(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices("/gpu:0"))
+
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToTwoDevicesInAList(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
+
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ output = []
+ # TODO(rohanj): Modify test to go till the end of the dataset when we
+ # switch to MultiDeviceIterator.
+ with self.cached_session() as sess:
+ for _ in range(4):
+ result = sess.run(next_element)
+ self.assertEqual(2, len(result))
+ output.extend(result)
+ self.assertEqual(set(range(8)), set(output))
+
+ def testPrefetchToTwoDevicesWithReinit(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
+
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ # TODO(rohanj): Modify test to go till the end of the dataset when we
+ # switch to MultiDeviceIterator.
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for _ in range(4):
+ sess.run(next_element)
+ sess.run(iterator.initializer)
+ for _ in range(4):
+ sess.run(next_element)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py
index 23bf36184f..1b5a4f64e5 100644
--- a/tensorflow/contrib/distribute/python/step_fn.py
+++ b/tensorflow/contrib/distribute/python/step_fn.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import backprop
-from tensorflow.python.eager import context
from tensorflow.python.training import optimizer as optimizer_lib
@@ -51,11 +50,7 @@ class StandardInputStep(Step):
def __init__(self, dataset_fn, distribution):
super(StandardInputStep, self).__init__(distribution)
self._distributed_input = distribution.distribute_dataset(dataset_fn)
- if context.executing_eagerly():
- self._iterator = self._distributed_input.make_one_shot_iterator()
- else:
- # TODO(priyag): Expose initializer via some initializer property.
- self._iterator = self._distributed_input.make_initializable_iterator()
+ self._iterator = self._distributed_input.make_one_shot_iterator()
class StandardSingleLossStep(StandardInputStep):
diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py
index 1ff9b9ceec..f1ada49fa3 100644
--- a/tensorflow/contrib/distribute/python/step_fn_test.py
+++ b/tensorflow/contrib/distribute/python/step_fn_test.py
@@ -50,7 +50,6 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
run_step = single_loss_step
else:
with self.cached_session() as sess:
- sess.run(single_loss_step._iterator.initializer)
run_step = sess.make_callable(single_loss_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 327775a729..4955ded4d5 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -26,7 +26,7 @@ import weakref
import six
from tensorflow.contrib.distribute.python import input_ops
-from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
@@ -683,7 +683,7 @@ class PerDeviceDataIterator(object):
def get_next(self, name=None):
"""Scatter the input across devices."""
if self._prefetch_on_device:
- data_list = self._iterator.get_next()
+ data_list = self._iterator.get_next(name=name)
index = dict(zip(self._devices, data_list))
else:
batch = self._iterator.get_next(name=name)
@@ -703,24 +703,21 @@ class PerDeviceDataIterator(object):
class PerDeviceDataset(object):
"""Like `tf.data.Dataset` split devices, producing `PerDevice` data."""
- def __init__(
- self,
- dataset,
- devices,
- prefetch_on_device=None,
- ):
+ def __init__(self, dataset, devices, prefetch_on_device=None):
self._devices = devices
# Default to using prefetching in graph mode, unless specified.
- # TODO(rohanj): Enable prefetching in eager mode.
+ # TODO(priyag): Enable prefetching in eager mode.
self._prefetch_on_device = prefetch_on_device
if self._prefetch_on_device is None:
self._prefetch_on_device = not context.executing_eagerly()
assert not (self._prefetch_on_device and context.executing_eagerly()), (
"Prefetching is only supported in graph mode currently")
- self._dataset = dataset
- if not self._prefetch_on_device:
+ if self._prefetch_on_device:
+ self._dataset = dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices(self._devices))
+ else:
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
# Possibly not an issue when we start using PartitionedDataset.
@@ -728,33 +725,15 @@ class PerDeviceDataset(object):
def make_one_shot_iterator(self):
"""Get a one time use iterator for the distributed PerDeviceDataset."""
- # Graph mode prefetching with one shot iterator is disabled.
- if not context.executing_eagerly():
- raise ValueError("Cannot create a one shot iterator. Please use "
- "`make_initializable_iterator()` instead.")
- # Eager mode prefetching would error out in constructor. Only remaining
- # cases are non-prefetching eager / graph mode. We delegate to
- # PerDeviceDataIterator to handle them.
dataset_iterator = self._dataset.make_one_shot_iterator()
- return PerDeviceDataIterator(
- dataset_iterator, self._devices, prefetch_on_device=False)
+ return PerDeviceDataIterator(dataset_iterator, self._devices,
+ self._prefetch_on_device)
def make_initializable_iterator(self):
"""Get an initializable iterator for the distributed PerDeviceDataset."""
- # Eager mode generates already initialized iterators. Hence we cannot create
- # an initializable iterator.
- if context.executing_eagerly():
- raise ValueError("Cannot create initializable iterator in Eager mode. "
- "Please use `make_one_shot_iterator` instead.")
- if self._prefetch_on_device:
- dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
- self._dataset, self._devices)
- else:
- dataset_iterator = self._dataset.make_initializable_iterator()
- return PerDeviceDataIterator(
- dataset_iterator,
- self._devices,
- prefetch_on_device=self._prefetch_on_device)
+ dataset_iterator = self._dataset.make_initializable_iterator()
+ return PerDeviceDataIterator(dataset_iterator, self._devices,
+ self._prefetch_on_device)
class MultiWorkerDataIterator(object):
@@ -837,9 +816,7 @@ class MultiWorkerDataset(object):
worker_input = input_ops.auto_shard_dataset(
worker_input, len(worker_device_map), i)
self._datasets[worker] = PerDeviceDataset(
- worker_input,
- worker_devices,
- prefetch_on_device=prefetch_on_device)
+ worker_input, worker_devices, prefetch_on_device=prefetch_on_device)
def make_one_shot_iterator(self):
iterators = {}
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 002d61f46e..ae3e134333 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -349,11 +349,7 @@ class PerDeviceDatasetTest(test.TestCase):
def _test_iterator_no_prefetch(self, devices, dataset, expected_values):
per_device_dataset = values.PerDeviceDataset(
dataset, devices, prefetch_on_device=False)
- if context.executing_eagerly():
- iterator = per_device_dataset.make_one_shot_iterator()
- else:
- iterator = per_device_dataset.make_initializable_iterator()
- self.evaluate([iterator.initializer])
+ iterator = per_device_dataset.make_one_shot_iterator()
for expected_value in expected_values:
next_element = iterator.get_next()
@@ -370,14 +366,21 @@ class PerDeviceDatasetTest(test.TestCase):
if not context.executing_eagerly():
per_device_dataset = values.PerDeviceDataset(
dataset, devices, prefetch_on_device=True)
- iterator = per_device_dataset.make_initializable_iterator()
- self.evaluate([iterator.initializer])
+ iterator = per_device_dataset.make_one_shot_iterator()
+ # With prefetching, we cannot guarantee which input ends up on which
+ # device, so we verify that the complete set seen on all devices is
+ # correct, and equal numbers are distributed to each device.
+ combined_actual = []
+ combined_expected = []
for expected_value in expected_values:
next_element = iterator.get_next()
- computed_value = self.evaluate(
- [values.select_device(d, next_element) for d in devices])
- self.assertEqual(expected_value, computed_value)
+ combined_actual.extend(
+ self.evaluate(
+ [values.select_device(d, next_element) for d in devices]))
+ combined_expected.extend(expected_value)
+
+ self.assertEqual(set(combined_expected), set(combined_actual))
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next()
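
For reference, a sketch of the consumption pattern the updated test verifies (graph mode with prefetch_on_device=True; the device list and range dataset are assumptions). Because prefetching does not guarantee which element lands on which device, results are compared as sets rather than element-by-element:

    from tensorflow.contrib.distribute.python import values
    from tensorflow.python.data.ops import dataset_ops

    devices = ["/device:CPU:0", "/device:GPU:0"]  # assumed device list
    dataset = dataset_ops.Dataset.range(8)
    per_device_dataset = values.PerDeviceDataset(
        dataset, devices, prefetch_on_device=True)
    iterator = per_device_dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    # One tensor per device; with prefetching the element-to-device
    # assignment is nondeterministic, so the test collects all fetched
    # values and compares them as a set against the expected range.
    per_device_values = [values.select_device(d, next_element) for d in devices]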