author    Rohan Jain <rohanj@google.com>                    2018-09-25 20:16:49 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-25 20:22:00 -0700
commit    7f1d70d97f543d69a9f02cd6df0964f22f9278f3 (patch)
tree      29612b6cd40203beba4f2b9689eef27a1f8da8d7
parent    3f4b8c138165cc9deb0ed931c5a6bb3d8ab556f0 (diff)
Switch Distribution strategies to use MultiDeviceIterator. This is currently supported only in graph mode, using initializable iterators; a subsequent change will add support for eager mode as well.

This removes the prefetching_ops_v2 code.

PiperOrigin-RevId: 214546754
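
For context, the iterator pattern the updated tests converge on looks roughly like the sketch below. This is illustrative only and not part of the commit; `distribution`, `dataset_fn`, and `sess` stand in for a concrete DistributionStrategy, input function, and session.

    # Illustrative sketch (not part of this commit) of the pattern the tests
    # now follow. Eager iterators are created already initialized; graph mode
    # requires an initializable iterator whose initializer must run before
    # the first get_next().
    ds = distribution.distribute_dataset(dataset_fn)
    if context.executing_eagerly():
      iterator = ds.make_one_shot_iterator()
    else:
      iterator = ds.make_initializable_iterator()
      sess.run(iterator.initializer)
    features = iterator.get_next()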
-rw-r--r--  tensorflow/contrib/distribute/python/BUILD                                28
-rw-r--r--  tensorflow/contrib/distribute/python/metrics_v1_test.py                    3
-rw-r--r--  tensorflow/contrib/distribute/python/minimize_loss_test.py                26
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py                  6
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py   12
-rw-r--r--  tensorflow/contrib/distribute/python/monitor.py                            1
-rw-r--r--  tensorflow/contrib/distribute/python/optimizer_v2_test.py                  8
-rw-r--r--  tensorflow/contrib/distribute/python/prefetching_ops_v2.py               229
-rw-r--r--  tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py           90
-rw-r--r--  tensorflow/contrib/distribute/python/step_fn.py                            7
-rw-r--r--  tensorflow/contrib/distribute/python/step_fn_test.py                       1
-rw-r--r--  tensorflow/contrib/distribute/python/values.py                            50
-rw-r--r--  tensorflow/contrib/distribute/python/values_test.py                       22
13 files changed, 92 insertions, 391 deletions
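
The graph-mode prefetching path now builds on MultiDeviceIterator (see the values.py hunks below). A minimal sketch of that API, under the assumption that `dataset`, `devices`, and `sess` are placeholders supplied by the caller:

    # Minimal sketch of the MultiDeviceIterator usage adopted in values.py
    # (illustrative; `dataset`, `devices`, and `sess` are placeholders).
    from tensorflow.python.data.ops import multi_device_iterator_ops

    mdi = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, devices, source_device="/device:CPU:0")
    data_list = mdi.get_next()    # one element per device in `devices`
    sess.run(mdi.initializer)     # graph mode: initialize before fetching
    results = sess.run(data_list)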
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 48a7593ab4..7eead6e472 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -22,7 +22,6 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":input_ops",
- ":prefetching_ops_v2",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:device_util",
@@ -30,6 +29,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
+ "//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"@six_archive//:six",
@@ -648,32 +648,6 @@ cuda_py_test(
)
py_library(
- name = "prefetching_ops_v2",
- srcs = ["prefetching_ops_v2.py"],
- deps = [
- "//tensorflow/contrib/data/python/ops:contrib_op_loader",
- "//tensorflow/contrib/data/python/ops:prefetching_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- ],
-)
-
-cuda_py_test(
- name = "prefetching_ops_v2_test",
- srcs = ["prefetching_ops_v2_test.py"],
- additional_deps = [
- ":prefetching_ops_v2",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- ],
-)
-
-py_library(
name = "input_ops",
srcs = ["input_ops.py"],
visibility = ["//tensorflow:internal"],
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index 8163494c8e..f7773aff4f 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -86,10 +86,11 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
with ops.Graph().as_default(), distribution.scope():
iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ dataset_fn).make_initializable_iterator()
value, update = distribution.call_for_each_tower(
metric_fn, iterator.get_next())
update = distribution.group(update)
+ self.evaluate(iterator.initializer)
self.evaluate(variables.local_variables_initializer())
# TODO(josh11b): Once we switch to using a global batch size for input,
# replace "distribution.num_towers" with "1".
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index ba147e7824..d082d5c419 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -41,6 +41,14 @@ from tensorflow.python.ops.losses import losses_impl
class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
+ def _get_iterator(self, ds):
+ if context.executing_eagerly():
+ iterator = ds.make_one_shot_iterator()
+ else:
+ iterator = ds.make_initializable_iterator()
+ self.evaluate(iterator.initializer)
+ return iterator
+
@combinations.generate(
combinations.times(
combinations.distributions_and_v1_optimizers(),
@@ -62,8 +70,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, *inputs, run_concurrently=layer.built))
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.run_steps_on_dataset(
@@ -99,8 +106,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.group(
@@ -159,8 +165,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, *inputs, run_concurrently=layer.built))
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.run_steps_on_dataset(
@@ -244,8 +249,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
return control_flow_ops.group(fetches)
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.run_steps_on_dataset(
@@ -338,8 +342,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, x, y, run_concurrently=False))
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.run_steps_on_dataset(
@@ -432,8 +435,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
output=loss)
return distribution.group(train_op)
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
initial_loss = lambda: constant_op.constant(1e7)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 0c6805d682..945f450387 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -480,8 +480,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
self._prefetch_on_device)
else:
return values.PerDeviceDataset(
- self._call_dataset_fn(dataset_fn), self._devices,
- self._prefetch_on_device)
+ self._call_dataset_fn(dataset_fn),
+ self._devices,
+ self._prefetch_on_device,
+ source_device=device_util.resolve("/device:CPU:0"))
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
def _run_steps_on_dataset(self, fn, iterator, iterations,
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index f51e543624..04c712ce1d 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -300,9 +300,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
- features = dist.distribute_dataset(
- lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
- ).make_one_shot_iterator().get_next()
+ ds = dist.distribute_dataset(
+ lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
+ if context.executing_eagerly():
+ iterator = ds.make_one_shot_iterator()
+ else:
+ iterator = ds.make_initializable_iterator()
+ self.evaluate([iterator.initializer])
+
+ features = iterator.get_next()
with dist.scope():
result = dist.call_for_each_tower(
diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py
index 7644acedc9..17b7ab74f6 100644
--- a/tensorflow/contrib/distribute/python/monitor.py
+++ b/tensorflow/contrib/distribute/python/monitor.py
@@ -51,6 +51,7 @@ class Monitor(object):
else:
if session is None:
raise ValueError("Should provide a `session` in Graph mode.")
+ session.run(step_callable._iterator.initializer) # pylint: disable=protected-access
self._run_step = session.make_callable(step_callable())
session.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
index 6e9ba37a19..3064433129 100644
--- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py
+++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
@@ -42,8 +42,11 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ ds = distribution.distribute_dataset(dataset_fn)
+ if context.executing_eagerly():
+ iterator = ds.make_one_shot_iterator()
+ else:
+ iterator = ds.make_initializable_iterator()
def run_step():
return control_flow_ops.group(distribution.unwrap(
@@ -52,6 +55,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
if not context.executing_eagerly():
with self.cached_session() as sess:
+ sess.run(iterator.initializer)
run_step = sess.make_callable(run_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
deleted file mode 100644
index 492d82f6a1..0000000000
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
+++ /dev/null
@@ -1,229 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Extension of prefetching_ops to support more than one device."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import warnings
-
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
-from tensorflow.contrib.data.python.ops import prefetching_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.util import nest as data_nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.eager import context
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
-from tensorflow.python.framework import ops
-from tensorflow.python.util import nest
-
-
-# pylint: disable=protected-access
-class _PrefetchToDeviceIterator(object):
- """A replacement for `tf.data.Iterator` that prefetches to another device.
-
- Args:
- input_dataset: The input dataset.
- one_shot: If true, we make a one shot iterator that's already initialized.
- devices: Devices on which to prefetch.
- buffer_size: Size of the prefetching buffer.
- shared_name: (Optional.) If non-empty, the returned iterator will be
- shared under the given name across multiple sessions that share the
- same devices (e.g. when using a remote server). Only used if one_shot
- is False.
-
- Returns:
- An Iterator type object.
- """
-
- def __init__(self,
- input_dataset,
- one_shot,
- devices,
- buffer_size,
- shared_name=None):
- self._input_dataset = input_dataset
- self._get_next_call_count = 0
- self._one_shot = one_shot
- if shared_name is None:
- shared_name = ""
- self._devices = devices
-
- if self._one_shot:
- self._input_iterator = input_dataset.make_one_shot_iterator()
- else:
- self._input_iterator = iterator_ops.Iterator.from_structure(
- self._input_dataset.output_types, self._input_dataset.output_shapes,
- shared_name, self._input_dataset.output_classes)
- input_iterator_handle = self._input_iterator.string_handle()
-
- @function.Defun(dtypes.string)
- def _prefetch_fn(handle):
- """Prefetches one element from `input_iterator`."""
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- handle, self._input_iterator.output_types,
- self._input_iterator.output_shapes,
- self._input_iterator.output_classes)
- ret = remote_iterator.get_next()
- return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
- target_device = gen_dataset_ops.iterator_get_device(
- self._input_iterator._iterator_resource)
- self._buffering_resources = []
- for device in nest.flatten(self._devices):
- with ops.device(device):
- buffer_resource_handle = prefetching_ops.function_buffering_resource(
- f=_prefetch_fn,
- output_types=data_nest.flatten(
- sparse.as_dense_types(self._input_dataset.output_types,
- self._input_dataset.output_classes)),
- target_device=target_device,
- string_arg=input_iterator_handle,
- buffer_size=buffer_size,
- shared_name=shared_name)
- self._buffering_resources.append(buffer_resource_handle)
-
- if not self._one_shot:
- reset_ops = []
- for buffer_resource in self._buffering_resources:
- reset_ops.append(
- prefetching_ops.function_buffering_resource_reset(buffer_resource))
- with ops.control_dependencies(reset_ops):
- self._initializer = self._input_iterator.make_initializer(
- self._input_dataset)
-
- def get_next(self, name=None):
- """See `tf.data.Iterator.get_next`."""
- self._get_next_call_count += 1
- if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
- warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
-
- flat_result = []
- # TODO(priyag): This will fail if the input size (typically number of
- # batches) is not divisible by number of devices.
- # How do we handle that more gracefully / let the user know?
- for buffer_resource in self._buffering_resources:
- flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
- buffer_resource,
- output_types=data_nest.flatten(sparse.as_dense_types(
- self.output_types, self.output_classes)), name=name)
-
- ret = sparse.deserialize_sparse_tensors(
- data_nest.pack_sequence_as(self.output_types, flat_ret),
- self.output_types, self.output_shapes, self.output_classes)
-
- for tensor, shape in zip(
- data_nest.flatten(ret), data_nest.flatten(self.output_shapes)):
- if isinstance(tensor, ops.Tensor):
- tensor.set_shape(shape)
- flat_result.append(ret)
-
- return nest.pack_sequence_as(self._devices, flat_result)
-
- @property
- def initializer(self):
- if self._one_shot:
- raise NotImplementedError("Can't initialize a one_shot_iterator")
- return self._initializer
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-# pylint: enable=protected-access
-
-
-class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
- """A `Dataset` whose iterator prefetches elements to other device(s)."""
-
- def __init__(self, input_dataset, devices, buffer_size):
- super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._devices = devices
- self._buffer_size = buffer_size if buffer_size is not None else 1
-
- def make_one_shot_iterator(self):
- return _PrefetchToDeviceIterator(
- self._input_dataset,
- one_shot=True,
- devices=self._devices,
- buffer_size=self._buffer_size)
-
- def make_initializable_iterator(self, shared_name=None):
- if context.executing_eagerly():
- raise RuntimeError(
- "make_initializable_iterator is not supported when eager "
- "execution is enabled.")
-
- return _PrefetchToDeviceIterator(
- self._input_dataset,
- one_shot=False,
- devices=self._devices,
- buffer_size=self._buffer_size,
- shared_name=shared_name)
-
- def _as_variant_tensor(self):
- # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
- # transformation methods is called.
- # TODO(mrry): Investigate support for chaining further transformations after
- # the prefetch, including GPU support.
- raise NotImplementedError("`prefetch_to_devices()` must be the last "
- "transformation in a dataset pipeline.")
-
- # TODO(priyag): Fix the output types, shapes and classes to match the result
- # of get_next (which has the additional nesting layer of devices now).
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
-
-def prefetch_to_devices(devices, buffer_size=None):
- """A transformation that prefetches dataset values to the given `devices`.
-
- NOTE: Although the transformation creates a `tf.data.Dataset`, the
- transformation must be the final `Dataset` in the input pipeline.
-
- Args:
- devices: A nested structure of devices on which to prefetch the data. It can
- be a single device name, or a tuple or list of device names.
- buffer_size: (Optional.) The number of elements to buffer on each device.
- Defaults to an automatically chosen value.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
- def _apply_fn(dataset):
- return _PrefetchToDeviceDataset(dataset, devices, buffer_size)
-
- return _apply_fn
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
deleted file mode 100644
index 16799104e8..0000000000
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for prefetching_ops_v2."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.distribute.python import prefetching_ops_v2
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import test_util
-from tensorflow.python.platform import test
-
-
-class PrefetchingOpsV2Test(test.TestCase):
-
- def testPrefetchToOneDevice(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops_v2.prefetch_to_devices("/gpu:0"))
-
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToTwoDevicesInAList(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
-
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- output = []
- # TODO(rohanj): Modify test to go till the end of the dataset when we
- # switch to MultiDeviceIterator.
- with self.cached_session() as sess:
- for _ in range(4):
- result = sess.run(next_element)
- self.assertEqual(2, len(result))
- output.extend(result)
- self.assertEquals(set(range(8)), set(output))
-
- def testPrefetchToTwoDevicesWithReinit(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
-
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- # TODO(rohanj): Modify test to go till the end of the dataset when we
- # switch to MultiDeviceIterator.
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for _ in range(4):
- sess.run(next_element)
- sess.run(iterator.initializer)
- for _ in range(4):
- sess.run(next_element)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py
index 1b5a4f64e5..23bf36184f 100644
--- a/tensorflow/contrib/distribute/python/step_fn.py
+++ b/tensorflow/contrib/distribute/python/step_fn.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
from tensorflow.python.training import optimizer as optimizer_lib
@@ -50,7 +51,11 @@ class StandardInputStep(Step):
def __init__(self, dataset_fn, distribution):
super(StandardInputStep, self).__init__(distribution)
self._distributed_input = distribution.distribute_dataset(dataset_fn)
- self._iterator = self._distributed_input.make_one_shot_iterator()
+ if context.executing_eagerly():
+ self._iterator = self._distributed_input.make_one_shot_iterator()
+ else:
+ # TODO(priyag): Expose initializer via some initializer property.
+ self._iterator = self._distributed_input.make_initializable_iterator()
class StandardSingleLossStep(StandardInputStep):
diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py
index f1ada49fa3..1ff9b9ceec 100644
--- a/tensorflow/contrib/distribute/python/step_fn_test.py
+++ b/tensorflow/contrib/distribute/python/step_fn_test.py
@@ -50,6 +50,7 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
run_step = single_loss_step
else:
with self.cached_session() as sess:
+ sess.run(single_loss_step._iterator.initializer)
run_step = sess.make_callable(single_loss_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index fafa6384a1..a0cd029f51 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -26,7 +26,7 @@ import weakref
import six
from tensorflow.contrib.distribute.python import input_ops
-from tensorflow.contrib.distribute.python import prefetching_ops_v2
+from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
@@ -683,7 +683,7 @@ class PerDeviceDataIterator(object):
def get_next(self, name=None):
"""Scatter the input across devices."""
if self._prefetch_on_device:
- data_list = self._iterator.get_next(name=name)
+ data_list = self._iterator.get_next()
index = dict(zip(self._devices, data_list))
else:
batch = self._iterator.get_next(name=name)
@@ -703,21 +703,26 @@ class PerDeviceDataIterator(object):
class PerDeviceDataset(object):
"""Like `tf.data.Dataset` split devices, producing `PerDevice` data."""
- def __init__(self, dataset, devices, prefetch_on_device=None):
+ def __init__(
+ self,
+ dataset,
+ devices,
+ prefetch_on_device=None,
+ source_device="/cpu:0",
+ ):
self._devices = devices
+ self._source_device = source_device if source_device is not None else "/cpu:0"
# Default to using prefetching in graph mode, unless specified.
- # TODO(priyag): Enable prefetching in eager mode.
+ # TODO(rohanj): Enable prefetching in eager mode.
self._prefetch_on_device = prefetch_on_device
if self._prefetch_on_device is None:
self._prefetch_on_device = not context.executing_eagerly()
assert not (self._prefetch_on_device and context.executing_eagerly()), (
"Prefetching is only supported in graph mode currently")
- if self._prefetch_on_device:
- self._dataset = dataset.apply(
- prefetching_ops_v2.prefetch_to_devices(self._devices))
- else:
+ self._dataset = dataset
+ if not self._prefetch_on_device:
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
# Possibly not an issue when we start using PartitionedDataset.
@@ -725,15 +730,33 @@ class PerDeviceDataset(object):
def make_one_shot_iterator(self):
"""Get a one time use iterator for the distributed PerDeviceDataset."""
+ # Graph mode prefetching with one shot iterator is disabled.
+ if not context.executing_eagerly():
+ raise ValueError("Cannot create a one shot iterator. Please use "
+ "`make_initializable_iterator()` instead.")
+ # Eager mode prefetching would error out in constructor. Only remaining
+ # cases are non-prefetching eager / graph mode. We delegate to
+ # PerDeviceDataIterator to handle them.
dataset_iterator = self._dataset.make_one_shot_iterator()
return PerDeviceDataIterator(
- dataset_iterator, self._devices, self._prefetch_on_device)
+ dataset_iterator, self._devices, prefetch_on_device=False)
def make_initializable_iterator(self):
"""Get an initializable iterator for the distributed PerDeviceDataset."""
- dataset_iterator = self._dataset.make_initializable_iterator()
+ # Eager mode generates already initialized iterators. Hence we cannot create
+ # an initializable iterator.
+ if context.executing_eagerly():
+ raise ValueError("Cannot create initializable iterator in Eager mode. "
+ "Please use `make_one_shot_iterator` instead.")
+ if self._prefetch_on_device:
+ dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ self._dataset, self._devices, source_device=self._source_device)
+ else:
+ dataset_iterator = self._dataset.make_initializable_iterator()
return PerDeviceDataIterator(
- dataset_iterator, self._devices, self._prefetch_on_device)
+ dataset_iterator,
+ self._devices,
+ prefetch_on_device=self._prefetch_on_device)
class MultiWorkerDataIterator(object):
@@ -813,7 +836,10 @@ class MultiWorkerDataset(object):
worker_input = input_ops.auto_shard_dataset(
worker_input, len(worker_device_map), i)
self._datasets[worker] = PerDeviceDataset(
- worker_input, worker_devices, prefetch_on_device=prefetch_on_device)
+ worker_input,
+ worker_devices,
+ source_device=worker,
+ prefetch_on_device=prefetch_on_device)
def make_one_shot_iterator(self):
iterators = {}
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 15a85a28f5..002d61f46e 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -349,7 +349,11 @@ class PerDeviceDatasetTest(test.TestCase):
def _test_iterator_no_prefetch(self, devices, dataset, expected_values):
per_device_dataset = values.PerDeviceDataset(
dataset, devices, prefetch_on_device=False)
- iterator = per_device_dataset.make_one_shot_iterator()
+ if context.executing_eagerly():
+ iterator = per_device_dataset.make_one_shot_iterator()
+ else:
+ iterator = per_device_dataset.make_initializable_iterator()
+ self.evaluate([iterator.initializer])
for expected_value in expected_values:
next_element = iterator.get_next()
@@ -366,20 +370,14 @@ class PerDeviceDatasetTest(test.TestCase):
if not context.executing_eagerly():
per_device_dataset = values.PerDeviceDataset(
dataset, devices, prefetch_on_device=True)
- iterator = per_device_dataset.make_one_shot_iterator()
+ iterator = per_device_dataset.make_initializable_iterator()
+ self.evaluate([iterator.initializer])
- # With prefetching, we cannot guarantee which input ends up on which
- # device, so we verify that the complete set seen on all devices is
- # correct, and equal numbers are distributed to each device.
- combined_actual = []
- combined_expected = []
for expected_value in expected_values:
next_element = iterator.get_next()
- combined_actual.extend(self.evaluate([
- values.select_device(d, next_element) for d in devices]))
- combined_expected.extend(expected_value)
-
- self.assertEqual(set(combined_expected), set(combined_actual))
+ computed_value = self.evaluate(
+ [values.select_device(d, next_element) for d in devices])
+ self.assertEqual(expected_value, computed_value)
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next()