aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-06-28 12:02:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-28 12:07:30 -0700
commitca6ad219889d6aa6c54be7c5aba2f67cca25643f (patch)
treefefb1d18580c8f56a520b71945540b06ab8c4605 /tensorflow/contrib/distribute/python/values.py
parent412292d0427c908f544730d7212246b755f7c203 (diff)
Add an output context that can be used to specify outputs to capture when running multiple steps at a time using the `run_steps_on_dataset` API. It allows the user's step function to specify which outputs to emit at what frequency. Currently it only supports capturing output from the last step, but will soon be augmented to support other use cases such as output each N steps.
PiperOrigin-RevId: 202520245
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py69
1 files changed, 69 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index ce95b718f6..95390041f4 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -861,3 +861,72 @@ class MapOutput(object):
def get(self):
return self._l
+
+
+class MultiStepContext(object):
+ """A context object that can be used to capture things when running steps.
+
+ This context object is useful when running multiple steps at a time using the
+ `run_steps_on_dataset` API. For e.g. it allows the user's step function to
+ specify which outputs to emit at what frequency. Currently it only supports
+ capturing output from the last step, but will soon be augmented to support
+ other use cases such as output each N steps.
+ """
+
+ def __init__(self, initial_loop_values=None):
+ """Initializes an output context.
+
+ Args:
+ initial_loop_values: Initial values passed to the run steps
+ while loop. The only purpose is to verify the shapes and types
+ when the actual output is set. This will be removed once we
+ automatically infer the output shapes and types (and do not need to
+ check for user error in specifying them manually).
+ Returns:
+ A context object.
+ """
+ self._last_step_outputs = None
+ self._non_tensor_outputs = None
+ self._initial_loop_values = initial_loop_values
+
+ @property
+ def last_step_outputs(self):
+ """Return the last step's outputs."""
+ return self._last_step_outputs
+
+ @last_step_outputs.setter
+ def last_step_outputs(self, outputs):
+ """Set the last step's outputs."""
+ self._verify_structure_shapes_types(outputs, self._initial_loop_values)
+ self._last_step_outputs = outputs
+
+ @property
+ def non_tensor_outputs(self):
+ """Return the non tensor outputs."""
+ return self._non_tensor_outputs
+
+ @non_tensor_outputs.setter
+ def non_tensor_outputs(self, outputs):
+ """Set any non tensor outputs."""
+ self._non_tensor_outputs = outputs
+
+ def _verify_structure_shapes_types(self, left, right):
+ """Verify that the structure, shapes and types of left are same as right."""
+ nest.assert_same_structure(left, right)
+ flat_left = nest.flatten(left)
+ flat_right = nest.flatten(right)
+ assert len(flat_left) == len(flat_right), (
+ "Length of left {} and right {} should be same.".
+ format(len(flat_left), len(flat_right)))
+
+ for o, i in zip(flat_left, flat_right):
+ # TODO(priyag): Add checks for other types like IndexedSlices.
+ if isinstance(o, ops.Tensor):
+ assert isinstance(i, ops.Tensor)
+ assert o.shape == i.shape, (
+ "Shape {} of left {} doesn't match shape {} of right {}.".
+ format(o.shape, o, i.shape, i))
+ assert o.dtype == i.dtype, (
+ "Dtype {} of left {} doesn't match dtype {} of right {}.".
+ format(o.dtype, o, i.dtype, i))
+