diff options
author    | 2018-04-18 16:35:44 -0700
committer | 2018-04-18 16:38:07 -0700
commit    | fddfa9f8dcd1a922ade5362c0538ca39e99472a7 (patch)
tree      | 7320668c20041c85bcce4e0c6614f7270dce44c6 /tensorflow/contrib
parent    | e9d47fbff0d644a75c6f3dcdcb852685ef515b64 (diff)
Change distribution.distribute_dataset to accept an input_fn instead of a dataset.
PiperOrigin-RevId: 193437651
Diffstat (limited to 'tensorflow/contrib')
7 files changed, 53 insertions, 42 deletions
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index d7fbf7f379..6c73250ded 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -54,21 +54,18 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
   def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss,
                        is_tpu):
     with distribution.scope():
-      model_fn, dataset, layer = minimize_loss_example(
-          optimizer_fn,
-          use_bias=True,
-          use_callable_loss=use_callable_loss)
+      model_fn, dataset_fn, layer = minimize_loss_example(
+          optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
+
+      def tpu_dataset_fn():
+        return dataset_fn().batch(2)
       # TODO(isaprykin): Eliminate `is_tpu`. Probably add a
       # `DistributionStrategy.create_monitor` so that each DistributionStrategy
       # could influence its training loop. That method would return an instance
       # of Monitor. TPUMonitor would execute tpu.initialize_system() and
       # tpu.shutdown_system().
-      if is_tpu:
-        dataset = dataset.batch(2)
-      iterator = distribution.distribute_dataset(
-          dataset).make_one_shot_iterator()
+      iterator = distribution.distribute_dataset(
+          tpu_dataset_fn if is_tpu else dataset_fn).make_one_shot_iterator()

       def run_step():
         # TODO(isaprykin): Make iterator get_next() return a list of sub-
@@ -122,14 +119,14 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
         # `distribution.scope`.
with variable_scope.variable_creator_scope( appending_creator), distribution.scope(): - model_fn, dataset, layer = minimize_loss_example( + model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=True, create_optimizer_inside_model_fn=True) iterator = distribution.distribute_dataset( - dataset).make_one_shot_iterator() + dataset_fn).make_one_shot_iterator() def run_step(): return distribution.group( @@ -176,7 +173,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): """Verifies that moving mean updates are reduced across towers.""" with distribution.scope(): num_towers = len(distribution.worker_devices) - model_fn, dataset, batchnorm = batchnorm_example( + model_fn, dataset_fn, batchnorm = batchnorm_example( optimizer_fn, batch_per_epoch=num_towers, momentum=momentum, @@ -188,7 +185,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): if isinstance(distribution, mirrored_strategy.MirroredStrategy): distribution._prefetch_on_device = False iterator = distribution.distribute_dataset( - dataset).make_one_shot_iterator() + dataset_fn).make_one_shot_iterator() def run_step(): return control_flow_ops.group( @@ -260,11 +257,13 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): else: return optimizer.minimize(loss_fn()) - features = dataset_ops.Dataset.from_tensors([[2.], [7.]]) - labels = dataset_ops.Dataset.from_tensors([[6.], [21.]]) - dataset = dataset_ops.Dataset.zip((features, labels)).repeat() + def dataset_fn(): + features = dataset_ops.Dataset.from_tensors([[2.], [7.]]) + labels = dataset_ops.Dataset.from_tensors([[6.], [21.]]) + return dataset_ops.Dataset.zip((features, labels)).repeat() + iterator = distribution.distribute_dataset( - dataset).make_one_shot_iterator() + dataset_fn).make_one_shot_iterator() def run_step(): return distribution.group( diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py 
b/tensorflow/contrib/distribute/python/mirrored_strategy.py index d5e22e8100..6efd578a77 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -140,9 +140,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): g.add_to_collections(collections, result) return result - def distribute_dataset(self, dataset): + def distribute_dataset(self, dataset_fn): return values.PerDeviceDataset( - dataset, self._devices, self._prefetch_on_device) + self._call_dataset_fn(dataset_fn), self._devices, + self._prefetch_on_device) def _broadcast(self, tensor, destinations): # TODO(josh11b): In eager mode, use one thread per device, or async mode. diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 59cd6703b9..6c5c055070 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -247,9 +247,9 @@ class MirroredStrategyVariableCreationTest(test.TestCase): dist = mirrored_strategy.MirroredStrategy( ["/device:GPU:0", "/device:CPU:0"]) - features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10) features = dist.distribute_dataset( - features).make_one_shot_iterator().get_next() + lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10) + ).make_one_shot_iterator().get_next() with dist.scope(): result = dist.call_for_each_tower( diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 2002266dd5..646d2a5c3b 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -60,8 +60,8 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): with ops.colocate_with(colocate_with): return next_creator(*args, 
**kwargs) - def distribute_dataset(self, dataset): - return dataset + def distribute_dataset(self, dataset_fn): + return self._call_dataset_fn(dataset_fn) def _broadcast(self, tensor, destinations): return tensor diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index 6e4d050073..abd3a65ac4 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -39,11 +39,11 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss=True): with distribution.scope(): - model_fn, dataset, layer = minimize_loss_example( + model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) iterator = distribution.distribute_dataset( - dataset).make_one_shot_iterator() + dataset_fn).make_one_shot_iterator() def run_step(): return control_flow_ops.group(distribution.unwrap( diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py index cef5fd2f89..9e8f919c8a 100644 --- a/tensorflow/contrib/distribute/python/single_loss_example.py +++ b/tensorflow/contrib/distribute/python/single_loss_example.py @@ -29,7 +29,10 @@ from tensorflow.python.ops import math_ops def single_loss_example(optimizer_fn, distribution, use_bias=False): """Build a very simple network to use in tests and examples.""" - dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() + + def dataset_fn(): + return dataset_ops.Dataset.from_tensors([[1.]]).repeat() + optimizer = optimizer_fn() layer = core.Dense(1, use_bias=use_bias) @@ -37,8 +40,8 @@ def single_loss_example(optimizer_fn, distribution, use_bias=False): y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) 
return y * y - single_loss_step = step_fn.StandardSingleLossStep(dataset, loss_fn, optimizer, - distribution) + single_loss_step = step_fn.StandardSingleLossStep(dataset_fn, loss_fn, + optimizer, distribution) # Layer is returned for inspecting the kernels in tests. return single_loss_step, layer @@ -49,7 +52,10 @@ def minimize_loss_example(optimizer_fn, use_callable_loss=True, create_optimizer_inside_model_fn=False): """Example of non-distribution-aware legacy code.""" - dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() + + def dataset_fn(): + return dataset_ops.Dataset.from_tensors([[1.]]).repeat() + # An Optimizer instance is created either outside or inside model_fn. outer_optimizer = None if not create_optimizer_inside_model_fn: @@ -71,7 +77,7 @@ def minimize_loss_example(optimizer_fn, else: return optimizer.minimize(loss_fn()) - return model_fn, dataset, layer + return model_fn, dataset_fn, layer def batchnorm_example(optimizer_fn, @@ -79,12 +85,15 @@ def batchnorm_example(optimizer_fn, momentum=0.9, renorm=False): """Example of non-distribution-aware legacy code with batch normalization.""" - # input shape is [16, 8], input values are increasing in both dimensions. - dataset = dataset_ops.Dataset.from_tensor_slices( - [[[float(x * 8 + y + z * 100) - for y in range(8)] - for x in range(16)] - for z in range(batch_per_epoch)]).repeat() + + def dataset_fn(): + # input shape is [16, 8], input values are increasing in both dimensions. + return dataset_ops.Dataset.from_tensor_slices( + [[[float(x * 8 + y + z * 100) + for y in range(8)] + for x in range(16)] + for z in range(batch_per_epoch)]).repeat() + optimizer = optimizer_fn() batchnorm = normalization.BatchNormalization( renorm=renorm, momentum=momentum, fused=False) @@ -99,4 +108,4 @@ def batchnorm_example(optimizer_fn, # Callable loss. 
return optimizer.minimize(loss_fn) - return model_fn, dataset, batchnorm + return model_fn, dataset_fn, batchnorm diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index 68b8f4d626..d1910622b3 100644 --- a/tensorflow/contrib/distribute/python/step_fn.py +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -49,13 +49,14 @@ class StandardInputStep(Step): """Step with a standard implementation of input handling. Args: - input_dataset: a tf.data Dataset that provides input. + dataset_fn: a function that returns a tf.data Dataset that produces the + input for the model. """ - def __init__(self, input_dataset, distribution): + def __init__(self, dataset_fn, distribution): Step.__init__(self, distribution) self._distributed_input = distribution.distribute_dataset( - input_dataset).make_one_shot_iterator() + dataset_fn).make_one_shot_iterator() def inputs(self): return self._distributed_input.get_next() @@ -77,14 +78,15 @@ class StandardSingleLossStep(StandardInputStep): ``` Args: - input_dataset: a tf.data Dataset that provides input. + dataset_fn: a function that returns a tf.data Dataset that produces the + input for the model. loss_fn: a function that returns loss. optimizer: an optimizer that implements an update rule. distribution: a `DistributionStrategy` object. """ - def __init__(self, input_dataset, loss_fn, optimizer, distribution): - StandardInputStep.__init__(self, input_dataset, distribution) + def __init__(self, dataset_fn, loss_fn, optimizer, distribution): + StandardInputStep.__init__(self, dataset_fn, distribution) self._loss_fn = loss_fn self._optimizer = optimizer self._is_run_concurrently = False |