author    Yuefeng Zhou <yuefengz@google.com>  2018-04-18 16:35:44 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-04-18 16:38:07 -0700
commit    fddfa9f8dcd1a922ade5362c0538ca39e99472a7 (patch)
tree      7320668c20041c85bcce4e0c6614f7270dce44c6 /tensorflow/contrib
parent    e9d47fbff0d644a75c6f3dcdcb852685ef515b64 (diff)
Change distribution.distribute_dataset to accept an input_fn instead of a dataset.
PiperOrigin-RevId: 193437651
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--  tensorflow/contrib/distribute/python/minimize_loss_test.py | 31
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py | 5
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | 4
-rw-r--r--  tensorflow/contrib/distribute/python/one_device_strategy.py | 4
-rw-r--r--  tensorflow/contrib/distribute/python/optimizer_v2_test.py | 4
-rw-r--r--  tensorflow/contrib/distribute/python/single_loss_example.py | 33
-rw-r--r--  tensorflow/contrib/distribute/python/step_fn.py | 14
7 files changed, 53 insertions(+), 42 deletions(-)
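For context: distribute_dataset now takes a function that constructs the input pipeline, rather than an already-built tf.data.Dataset, so each strategy decides where and when the dataset is created (e.g. batching it for TPU, or building it per device). A minimal sketch of the new calling convention, modeled on the test changes below; the MirroredStrategy construction and the single-CPU device list are illustrative assumptions, not part of this commit:

    from tensorflow.contrib.distribute.python import mirrored_strategy
    from tensorflow.python.data.ops import dataset_ops

    distribution = mirrored_strategy.MirroredStrategy(["/device:CPU:0"])

    def dataset_fn():
      # The pipeline is built inside the function so the strategy can invoke
      # it in the appropriate graph/device context.
      return dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)

    # Previously a Dataset instance was passed: distribute_dataset(dataset_fn())
    # Now the function itself is passed and the strategy calls it.
    iterator = distribution.distribute_dataset(dataset_fn).make_one_shot_iterator()
    features = iterator.get_next()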
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index d7fbf7f379..6c73250ded 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -54,21 +54,18 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss,
is_tpu):
with distribution.scope():
- model_fn, dataset, layer = minimize_loss_example(
- optimizer_fn,
- use_bias=True,
- use_callable_loss=use_callable_loss)
+ model_fn, dataset_fn, layer = minimize_loss_example(
+ optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
+ def tpu_dataset_fn():
+ return dataset_fn().batch(2)
# TODO(isaprykin): Eliminate `is_tpu`. Probably add a
# `DistributionStrategy.create_monitor` so that each DistributionStrategy
# could influence its training loop. That method would return an instance
# of Monitor. TPUMonitor would execute tpu.initialize_system() and
# tpu.shutdown_system().
- if is_tpu:
- dataset = dataset.batch(2)
-
iterator = distribution.distribute_dataset(
- dataset).make_one_shot_iterator()
+ tpu_dataset_fn if is_tpu else dataset_fn).make_one_shot_iterator()
def run_step():
# TODO(isaprykin): Make iterator get_next() return a list of sub-
@@ -122,14 +119,14 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
# `distribution.scope`.
with variable_scope.variable_creator_scope(
appending_creator), distribution.scope():
- model_fn, dataset, layer = minimize_loss_example(
+ model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn,
use_bias=True,
use_callable_loss=True,
create_optimizer_inside_model_fn=True)
iterator = distribution.distribute_dataset(
- dataset).make_one_shot_iterator()
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.group(
@@ -176,7 +173,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
"""Verifies that moving mean updates are reduced across towers."""
with distribution.scope():
num_towers = len(distribution.worker_devices)
- model_fn, dataset, batchnorm = batchnorm_example(
+ model_fn, dataset_fn, batchnorm = batchnorm_example(
optimizer_fn,
batch_per_epoch=num_towers,
momentum=momentum,
@@ -188,7 +185,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
if isinstance(distribution, mirrored_strategy.MirroredStrategy):
distribution._prefetch_on_device = False
iterator = distribution.distribute_dataset(
- dataset).make_one_shot_iterator()
+ dataset_fn).make_one_shot_iterator()
def run_step():
return control_flow_ops.group(
@@ -260,11 +257,13 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
else:
return optimizer.minimize(loss_fn())
- features = dataset_ops.Dataset.from_tensors([[2.], [7.]])
- labels = dataset_ops.Dataset.from_tensors([[6.], [21.]])
- dataset = dataset_ops.Dataset.zip((features, labels)).repeat()
+ def dataset_fn():
+ features = dataset_ops.Dataset.from_tensors([[2.], [7.]])
+ labels = dataset_ops.Dataset.from_tensors([[6.], [21.]])
+ return dataset_ops.Dataset.zip((features, labels)).repeat()
+
iterator = distribution.distribute_dataset(
- dataset).make_one_shot_iterator()
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.group(
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index d5e22e8100..6efd578a77 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -140,9 +140,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
g.add_to_collections(collections, result)
return result
- def distribute_dataset(self, dataset):
+ def distribute_dataset(self, dataset_fn):
return values.PerDeviceDataset(
- dataset, self._devices, self._prefetch_on_device)
+ self._call_dataset_fn(dataset_fn), self._devices,
+ self._prefetch_on_device)
def _broadcast(self, tensor, destinations):
# TODO(josh11b): In eager mode, use one thread per device, or async mode.
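Both MirroredStrategy and OneDeviceStrategy now defer dataset construction to a _call_dataset_fn helper on the base DistributionStrategy class, which is not shown in this diff. As a rough sketch, and purely as an assumption about that helper, it invokes the supplied function and returns the resulting dataset:

    # Hypothetical sketch of the base-class helper referenced above; the real
    # implementation lives in distribute_lib and may add validation.
    def _call_dataset_fn(self, dataset_fn):
      """Calls `dataset_fn` and returns the tf.data Dataset it builds."""
      return dataset_fn()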
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 59cd6703b9..6c5c055070 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -247,9 +247,9 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
- features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
features = dist.distribute_dataset(
- features).make_one_shot_iterator().get_next()
+ lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
+ ).make_one_shot_iterator().get_next()
with dist.scope():
result = dist.call_for_each_tower(
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 2002266dd5..646d2a5c3b 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -60,8 +60,8 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
with ops.colocate_with(colocate_with):
return next_creator(*args, **kwargs)
- def distribute_dataset(self, dataset):
- return dataset
+ def distribute_dataset(self, dataset_fn):
+ return self._call_dataset_fn(dataset_fn)
def _broadcast(self, tensor, destinations):
return tensor
diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
index 6e4d050073..abd3a65ac4 100644
--- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py
+++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
@@ -39,11 +39,11 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
def testTrainNetwork(self, distribution, optimizer_fn,
use_callable_loss=True):
with distribution.scope():
- model_fn, dataset, layer = minimize_loss_example(
+ model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
iterator = distribution.distribute_dataset(
- dataset).make_one_shot_iterator()
+ dataset_fn).make_one_shot_iterator()
def run_step():
return control_flow_ops.group(distribution.unwrap(
diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py
index cef5fd2f89..9e8f919c8a 100644
--- a/tensorflow/contrib/distribute/python/single_loss_example.py
+++ b/tensorflow/contrib/distribute/python/single_loss_example.py
@@ -29,7 +29,10 @@ from tensorflow.python.ops import math_ops
def single_loss_example(optimizer_fn, distribution, use_bias=False):
"""Build a very simple network to use in tests and examples."""
- dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
+
+ def dataset_fn():
+ return dataset_ops.Dataset.from_tensors([[1.]]).repeat()
+
optimizer = optimizer_fn()
layer = core.Dense(1, use_bias=use_bias)
@@ -37,8 +40,8 @@ def single_loss_example(optimizer_fn, distribution, use_bias=False):
y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
return y * y
- single_loss_step = step_fn.StandardSingleLossStep(dataset, loss_fn, optimizer,
- distribution)
+ single_loss_step = step_fn.StandardSingleLossStep(dataset_fn, loss_fn,
+ optimizer, distribution)
# Layer is returned for inspecting the kernels in tests.
return single_loss_step, layer
@@ -49,7 +52,10 @@ def minimize_loss_example(optimizer_fn,
use_callable_loss=True,
create_optimizer_inside_model_fn=False):
"""Example of non-distribution-aware legacy code."""
- dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
+
+ def dataset_fn():
+ return dataset_ops.Dataset.from_tensors([[1.]]).repeat()
+
# An Optimizer instance is created either outside or inside model_fn.
outer_optimizer = None
if not create_optimizer_inside_model_fn:
@@ -71,7 +77,7 @@ def minimize_loss_example(optimizer_fn,
else:
return optimizer.minimize(loss_fn())
- return model_fn, dataset, layer
+ return model_fn, dataset_fn, layer
def batchnorm_example(optimizer_fn,
@@ -79,12 +85,15 @@ def batchnorm_example(optimizer_fn,
momentum=0.9,
renorm=False):
"""Example of non-distribution-aware legacy code with batch normalization."""
- # input shape is [16, 8], input values are increasing in both dimensions.
- dataset = dataset_ops.Dataset.from_tensor_slices(
- [[[float(x * 8 + y + z * 100)
- for y in range(8)]
- for x in range(16)]
- for z in range(batch_per_epoch)]).repeat()
+
+ def dataset_fn():
+ # input shape is [16, 8], input values are increasing in both dimensions.
+ return dataset_ops.Dataset.from_tensor_slices(
+ [[[float(x * 8 + y + z * 100)
+ for y in range(8)]
+ for x in range(16)]
+ for z in range(batch_per_epoch)]).repeat()
+
optimizer = optimizer_fn()
batchnorm = normalization.BatchNormalization(
renorm=renorm, momentum=momentum, fused=False)
@@ -99,4 +108,4 @@ def batchnorm_example(optimizer_fn,
# Callable loss.
return optimizer.minimize(loss_fn)
- return model_fn, dataset, batchnorm
+ return model_fn, dataset_fn, batchnorm
diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py
index 68b8f4d626..d1910622b3 100644
--- a/tensorflow/contrib/distribute/python/step_fn.py
+++ b/tensorflow/contrib/distribute/python/step_fn.py
@@ -49,13 +49,14 @@ class StandardInputStep(Step):
"""Step with a standard implementation of input handling.
Args:
- input_dataset: a tf.data Dataset that provides input.
+ dataset_fn: a function that returns a tf.data Dataset that produces the
+ input for the model.
"""
- def __init__(self, input_dataset, distribution):
+ def __init__(self, dataset_fn, distribution):
Step.__init__(self, distribution)
self._distributed_input = distribution.distribute_dataset(
- input_dataset).make_one_shot_iterator()
+ dataset_fn).make_one_shot_iterator()
def inputs(self):
return self._distributed_input.get_next()
@@ -77,14 +78,15 @@ class StandardSingleLossStep(StandardInputStep):
```
Args:
- input_dataset: a tf.data Dataset that provides input.
+ dataset_fn: a function that returns a tf.data Dataset that produces the
+ input for the model.
loss_fn: a function that returns loss.
optimizer: an optimizer that implements an update rule.
distribution: a `DistributionStrategy` object.
"""
- def __init__(self, input_dataset, loss_fn, optimizer, distribution):
- StandardInputStep.__init__(self, input_dataset, distribution)
+ def __init__(self, dataset_fn, loss_fn, optimizer, distribution):
+ StandardInputStep.__init__(self, dataset_fn, distribution)
self._loss_fn = loss_fn
self._optimizer = optimizer
self._is_run_concurrently = False
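Taken together, a caller of the updated StandardSingleLossStep now supplies a dataset-building function, mirroring single_loss_example.py above. A minimal sketch; loss_fn, optimizer, and distribution are assumed to already exist, as in that example:

    from tensorflow.contrib.distribute.python import step_fn
    from tensorflow.python.data.ops import dataset_ops

    def dataset_fn():
      # Built lazily by the strategy via distribution.distribute_dataset().
      return dataset_ops.Dataset.from_tensors([[1.]]).repeat()

    single_loss_step = step_fn.StandardSingleLossStep(
        dataset_fn, loss_fn, optimizer, distribution)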