aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/data_flow_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/data_flow_ops.py')
-rw-r--r--tensorflow/python/ops/data_flow_ops.py743
1 files changed, 639 insertions, 104 deletions
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index c272a7115d..4eead79531 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -516,7 +516,7 @@ class QueueBase(object):
that would block will fail immediately.
If `cancel_pending_enqueues` is `True`, all pending requests will also
- be cancelled.
+ be canceled.
Args:
cancel_pending_enqueues: (Optional.) A boolean, defaulting to
@@ -988,7 +988,7 @@ class Barrier(object):
TakeMany operations that would block will fail immediately.
If `cancel_pending_enqueues` is `True`, all pending requests to the
- underlying queue will also be cancelled, and completing of already
+ underlying queue will also be canceled, and completing of already
started values is also not acceptable anymore.
Args:
@@ -1344,72 +1344,30 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase):
dense_shape=return_val.shape)
-class StagingArea(object):
- """Class for staging inputs. No ordering guarantees.
-
- A `StagingArea` is a TensorFlow data structure that stores tensors across
- multiple steps, and exposes operations that can put and get tensors.
-
- Each `StagingArea` element is a tuple of one or more tensors, where each
- tuple component has a static dtype, and may have a static shape.
-
- The capacity of a `StagingArea` is unbounded and supports multiple
- concurrent producers and consumers; and provides exactly-once delivery.
-
- Each element of a `StagingArea` is a fixed-length tuple of tensors whose
- dtypes are described by `dtypes`, and whose shapes are optionally described
- by the `shapes` argument.
-
- If the `shapes` argument is specified, each component of a staging area
- element must have the respective fixed shape. If it is
- unspecified, different elements may have different shapes,
- """
-
+class BaseStagingArea(object):
+ """Base class for Staging Areas."""
_identifier = 0
_lock = threading.Lock()
- def __init__(self, dtypes, shapes=None, names=None, shared_name=None):
- """Constructs a staging area object.
-
- The two optional lists, `shapes` and `names`, must be of the same length
- as `dtypes` if provided. The values at a given index `i` indicate the
- shape and name to use for the corresponding queue component in `dtypes`.
-
- The device scope at the time of object creation determines where the
- storage for the `StagingArea` will reside. Calls to `put` will incur a copy
- to this memory space, if necessary. Tensors returned by `get` will be
- placed according to the device scope when `get` is called.
-
- Args:
- dtypes: A list of types. The length of dtypes must equal the number
- of tensors in each element.
- shapes: (Optional.) Constraints on the shapes of tensors in an element.
- A list of shape tuples or None. This list is the same length
- as dtypes. If the shape of any tensors in the element are constrained,
- all must be; shapes can be None if the shapes should not be constrained.
- names: (Optional.) If provided, the `get()` and
- `put()` methods will use dictionaries with these names as keys.
- Must be None or a list or tuple of the same length as `dtypes`.
- shared_name: (Optional.) A name to be used for the shared object. By
- passing the same name to two different python objects they will share
- the underlying staging area. Must be a string.
-
- Raises:
- ValueError: If one of the arguments is invalid.
- """
+ def __init__(self, dtypes, shapes=None, names=None, shared_name=None,
+ capacity=0, memory_limit=0):
if shared_name is None:
- self._name = ops.get_default_graph().unique_name("StagingArea")
+ self._name = (ops.get_default_graph()
+ .unique_name(self.__class__.__name__))
elif isinstance(shared_name, six.string_types):
self._name = shared_name
else:
raise ValueError("shared_name must be a string")
+
self._dtypes = dtypes
+
if shapes is not None:
if len(shapes) != len(dtypes):
raise ValueError("StagingArea shapes must be the same length as dtypes")
self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
else:
self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
+
if names is not None:
if len(names) != len(dtypes):
raise ValueError("StagingArea names must be the same length as dtypes")
@@ -1417,6 +1375,9 @@ class StagingArea(object):
else:
self._names = None
+ self._capacity = capacity
+ self._memory_limit = memory_limit
+
# all get and put ops must colocate with this op
with ops.name_scope("%s_root" % self._name):
self._coloc_op = control_flow_ops.no_op()
@@ -1441,52 +1402,141 @@ class StagingArea(object):
"""The list of names for each component of a staging area element."""
return self._names
- def _check_put_dtypes(self, vals):
+ @property
+ def capacity(self):
+ """The maximum number of elements of this staging area."""
+ return self._capacity
+
+ @property
+ def memory_limit(self):
+ """The maximum number of bytes of this staging area."""
+ return self._memory_limit
+
+ def _check_put_dtypes(self, vals, indices=None):
"""Validate and convert `vals` to a list of `Tensor`s.
The `vals` argument can be a Tensor, a list or tuple of tensors, or a
dictionary with tensor values.
+ If `vals` is a list, then the appropriate indices associated with the
+ values must be provided.
+
If it is a dictionary, the staging area must have been constructed with a
`names` attribute and the dictionary keys must match the staging area names.
+ `indices` will be inferred from the dictionary keys.
If the staging area was constructed with a `names` attribute, `vals` must
be a dictionary.
+ Checks that the dtype and shape of each value matches that
+ of the staging area.
+
Args:
vals: A tensor, a list or tuple of tensors, or a dictionary..
Returns:
- A list of `Tensor` objects.
+ A (tensors, indices) tuple where `tensors` is a list of `Tensor` objects
+ and `indices` is a list of indices associated with the tensors.
Raises:
- ValueError: If `vals` is invalid.
+ ValueError: If `vals` or `indices` is invalid.
"""
if isinstance(vals, dict):
if not self._names:
raise ValueError(
"Staging areas must have names to enqueue a dictionary")
- if sorted(self._names) != sorted(vals.keys()):
+ if not set(vals.keys()).issubset(self._names):
raise ValueError("Keys in dictionary to put do not match names "
"of staging area. Dictionary: (%s), Queue: (%s)" %
(sorted(vals.keys()), sorted(self._names)))
# The order of values in `self._names` indicates the order in which the
# tensors in the dictionary `vals` must be listed.
- vals = [vals[k] for k in self._names]
+ vals, indices, n = zip(*[(vals[k], i, k) for i, k in enumerate(self._names)
+ if k in vals])
else:
if self._names:
raise ValueError("You must enqueue a dictionary in a staging area "
"with names")
+
+ if indices is None:
+ raise ValueError("Indices must be supplied when inserting a list "
+ "of tensors")
+
+ if len(indices) != len(vals):
+ raise ValueError("Number of indices '%s' doesn't match "
+ "number of values '%s'")
+
if not isinstance(vals, (list, tuple)):
vals = [vals]
+ indices = [0]
+
+ # Sanity check number of values
+ if not len(vals) <= len(self._dtypes):
+ raise ValueError("Unexpected number of inputs '%s' vs '%s'" % (
+ len(values), len(self._dtypes)))
tensors = []
- for i, (val, dtype) in enumerate(zip(vals, self._dtypes)):
- tensors.append(
- ops.convert_to_tensor(
- val, dtype=dtype, name="component_%d" % i))
+
+ for val, i in zip(vals, indices):
+ dtype, shape = self._dtypes[i], self._shapes[i]
+ # Check dtype
+ if not val.dtype == dtype:
+ raise ValueError("Datatypes do not match. '%s' != '%s'" %(
+ str(val.dtype), str(dtype)))
+
+ # Check shape
+ val.get_shape().assert_is_compatible_with(shape)
+
+ tensors.append(ops.convert_to_tensor(val, dtype=dtype,
+ name="component_%d" % i))
+
+ return tensors, indices
+
+ def _create_device_transfers(self, tensors):
+ """Encode inter-device transfers if the current device
+ is not the same as the Staging Area's device
+ """
+
+ if not isinstance(tensors, (tuple, list)):
+ tensors = [tensors]
+
+ curr_device_scope = control_flow_ops.no_op().device
+
+ if curr_device_scope != self._coloc_op.device:
+ tensors = [array_ops.identity(t) for t in tensors]
return tensors
+ def _get_return_value(self, tensors, indices):
+ """Return the value to return from a get op.
+
+ If the staging area has names, return a dictionary with the
+ names as keys. Otherwise return either a single tensor
+ or a list of tensors depending on the length of `tensors`.
+
+ Args:
+ tensors: List of tensors from the get op.
+ indices: Indices of associated names and shapes
+
+ Returns:
+ A single tensor, a list of tensors, or a dictionary
+ of tensors.
+ """
+
+ tensors = self._create_device_transfers(tensors)
+
+ # Sets shape
+ for output, i in zip(tensors, indices):
+ output.set_shape(self._shapes[i])
+
+ if self._names:
+ # The returned values in `tensors` are in the same order as
+ # the names in `self._names`.
+ return {self._names[i]: t for t, i in zip(tensors, indices)}
+ elif len(tensors) == 1:
+ return tensors[0]
+ else:
+ return tensors
+
def _scope_vals(self, vals):
"""Return a list of values to pass to `name_scope()`.
@@ -1503,9 +1553,86 @@ class StagingArea(object):
else:
return [vals]
+class StagingArea(BaseStagingArea):
+ """Class for staging inputs. No ordering guarantees.
+
+ A `StagingArea` is a TensorFlow data structure that stores tensors across
+ multiple steps, and exposes operations that can put and get tensors.
+
+ Each `StagingArea` element is a tuple of one or more tensors, where each
+ tuple component has a static dtype, and may have a static shape.
+
+ The capacity of a `StagingArea` may be bounded or unbounded.
+ It supports multiple concurrent producers and consumers; and
+ provides exactly-once delivery.
+
+ Each element of a `StagingArea` is a fixed-length tuple of tensors whose
+ dtypes are described by `dtypes`, and whose shapes are optionally described
+ by the `shapes` argument.
+
+ If the `shapes` argument is specified, each component of a staging area
+ element must have the respective fixed shape. If it is
+ unspecified, different elements may have different shapes.
+
+ It can be configured with a capacity in which case
+ put(values) will block until space becomes available.
+
+ Similarly, it can be configured with a memory limit which
+ will block put(values) until space is available.
+ This is mostly useful for limiting the number of tensors on
+ devices such as GPUs.
+
+ All get() and peek() commands block if the requested data
+ is not present in the Staging Area.
+
+ """
+
+ def __init__(self, dtypes, shapes=None, names=None, shared_name=None,
+ capacity=0, memory_limit=0):
+ """Constructs a staging area object.
+
+ The two optional lists, `shapes` and `names`, must be of the same length
+ as `dtypes` if provided. The values at a given index `i` indicate the
+ shape and name to use for the corresponding queue component in `dtypes`.
+
+ The device scope at the time of object creation determines where the
+ storage for the `StagingArea` will reside. Calls to `put` will incur a copy
+ to this memory space, if necessary. Tensors returned by `get` will be
+ placed according to the device scope when `get` is called.
+
+ Args:
+ dtypes: A list of types. The length of dtypes must equal the number
+ of tensors in each element.
+ capacity: (Optional.) Maximum number of elements.
+ An integer. If zero, the Staging Area is unbounded
+ memory_limit: (Optional.) Maximum number of bytes of all tensors
+ in the Staging Area.
+ An integer. If zero, the Staging Area is unbounded
+ shapes: (Optional.) Constraints on the shapes of tensors in an element.
+ A list of shape tuples or None. This list is the same length
+ as dtypes. If the shape of any tensors in the element are constrained,
+ all must be; shapes can be None if the shapes should not be constrained.
+ names: (Optional.) If provided, the `get()` and
+ `put()` methods will use dictionaries with these names as keys.
+ Must be None or a list or tuple of the same length as `dtypes`.
+ shared_name: (Optional.) A name to be used for the shared object. By
+ passing the same name to two different python objects they will share
+ the underlying staging area. Must be a string.
+
+ Raises:
+ ValueError: If one of the arguments is invalid.
+ """
+
+ super(StagingArea, self).__init__(dtypes, shapes,
+ names, shared_name,
+ capacity, memory_limit)
+
def put(self, values, name=None):
"""Create an op that places a value into the staging area.
+ This operation will block if the `StagingArea` has reached
+ its capacity.
+
Args:
values: Tensor (or a tuple of Tensors) to place into the staging area.
name: A name for the operation (optional).
@@ -1518,46 +1645,25 @@ class StagingArea(object):
"""
with ops.name_scope(name, "%s_put" % self._name,
self._scope_vals(values)) as scope:
- vals = self._check_put_dtypes(values)
- if len(values) != len(self._dtypes):
- raise ValueError("Unexpected number of inputs " + str(len(values)) +
- "vs " + str(len(self._dtypes)))
- for val, dtype in zip(vals, self._dtypes):
- if val.dtype != dtype:
- raise ValueError("Datatypes do not match. " + str(val.dtype) + " != "
- + str(dtype))
- for val, shape in zip(vals, self._shapes):
- val.get_shape().assert_is_compatible_with(shape)
+ # Hard-code indices for this staging area
+ indices = (list(six.moves.range(len(values)))
+ if isinstance(values, (list, tuple)) else None)
+ vals, _ = self._check_put_dtypes(values, indices)
with ops.colocate_with(self._coloc_op):
op = gen_data_flow_ops.stage(values=vals, shared_name=self._name,
- name=scope)
+ name=scope, capacity=self._capacity,
+ memory_limit=self._memory_limit)
return op
- def _get_return_value(self, tensors):
- """Return the value to return from a get op.
-
- If the staging area has names, return a dictionary with the
- names as keys. Otherwise return either a single tensor
- or a list of tensors depending on the length of `tensors`.
-
- Args:
- tensors: List of tensors from the get op.
+ def __internal_get(self, get_fn, name):
+ with ops.colocate_with(self._coloc_op):
+ ret = get_fn()
- Returns:
- A single tensor, a list of tensors, or a dictionary
- of tensors.
- """
- if self._names:
- # The returned values in `tensors` are in the same order as
- # the names in `self._names`.
- return {n: tensors[i] for i, n in enumerate(self._names)}
- elif len(tensors) == 1:
- return tensors[0]
- else:
- return tensors
+ indices = list(six.moves.range(len(self._dtypes))) # Hard coded
+ return self._get_return_value(ret, indices)
def get(self, name=None):
"""Gets one element from this staging area.
@@ -1584,19 +1690,448 @@ class StagingArea(object):
if name is None:
name = "%s_get" % self._name
+ fn = lambda: gen_data_flow_ops.unstage(dtypes=self._dtypes,
+ shared_name=self._name, name=name,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
+
+ return self.__internal_get(fn, name)
+
+ def peek(self, index, name=None):
+ """Peeks at an element in the staging area.
+
+ If the staging area is too small to contain the element at
+ the specified index, it will block until enough elements
+ are inserted to complete the operation.
+
+ The placement of the returned tensor will be determined by
+ the current device scope when this function is called.
+
+ Args:
+ index: The index of the tensor within the staging area
+ to look up.
+ name: A name for the operation (optional).
+
+ Returns:
+ The tuple of tensors that was gotten.
+ """
+ if name is None:
+ name = "%s_peek" % self._name
+
+ fn = lambda: gen_data_flow_ops.stage_peek(index,
+ dtypes=self._dtypes, shared_name=self._name,
+ name=name, capacity=self._capacity,
+ memory_limit=self._memory_limit)
+
+ return self.__internal_get(fn, name)
+
+ def size(self, name=None):
+ """Returns the number of elements in the staging area.
+
+ Args:
+ name: A name for the operation (optional)
+
+ Returns:
+ The created op
+ """
+ if name is None:
+ name = "%s_size" % self._name
+
+ return gen_data_flow_ops.stage_size(name=name, shared_name=self._name,
+ dtypes=self._dtypes, capacity=self._capacity,
+ memory_limit=self._memory_limit)
+
+ def clear(self, name=None):
+ """Clears the staging area.
+
+ Args:
+ name: A name for the operation (optional)
+
+ Returns:
+ The created op
+ """
+ if name is None:
+ name = "%s_clear" % self._name
+
+ return gen_data_flow_ops.stage_clear(name=name, shared_name=self._name,
+ dtypes=self._dtypes, capacity=self._capacity,
+ memory_limit=self._memory_limit)
+
+class MapStagingArea(BaseStagingArea):
+ """
+ A `MapStagingArea` is a TensorFlow data structure that stores tensors across
+ multiple steps, and exposes operations that can put and get tensors.
+
+ Each `MapStagingArea` element is a (key, value) pair.
+ Only int64 keys are supported, other types should be
+ hashed to produce a key.
+ Values are a tuple of one or more tensors.
+ Each tuple component has a static dtype,
+ and may have a static shape.
+
+ The capacity of a `MapStagingArea` may be bounded or unbounded.
+ It supports multiple concurrent producers and consumers; and
+ provides exactly-once delivery.
+
+ Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors whose
+ dtypes are described by `dtypes`, and whose shapes are optionally described
+ by the `shapes` argument.
+
+ If the `shapes` argument is specified, each component of a staging area
+ element must have the respective fixed shape. If it is
+ unspecified, different elements may have different shapes.
+
+ It behaves like an associative container with support for:
+
+ - put(key, values)
+ - peek(key) like dict.get(key)
+ - get(key) like dict.pop(key)
+ - get(key=None) like dict.popitem()
+ - size()
+ - clear()
+
+ If ordered a tree structure ordered by key will be used and
+ get(key=None) will remove (key, value) pairs in increasing key order.
+ Otherwise a hashtable is used.
+
+ It can be configured with a capacity in which case
+ put(key, values) will block until space becomes available.
+
+ Similarly, it can be configured with a memory limit which
+ will block put(key, values) until space is available.
+ This is mostly useful for limiting the number of tensors on
+ devices such as GPUs.
+
+ All get() and peek() commands block if the requested
+ (key, value) pair is not present in the staging area.
+
+ Partial puts are supported and will be placed in an incomplete
+ map until such time as all values associated with the key have
+ been inserted. Once completed, this (key, value) pair will be
+ inserted into the map. Data in the incomplete map
+ counts towards the memory limit, but not towards capacity limit.
+
+ Partial gets from the map are also supported.
+ This removes the partially requested tensors from the entry,
+ but the entry is only removed from the map once all tensors
+ associated with it are removed.
+ """
+
+ def __init__(self, dtypes, shapes=None, names=None, shared_name=None,
+ ordered=False, capacity=0, memory_limit=0):
+ """
+ Args:
+ dtypes: A list of types. The length of dtypes must equal the number
+ of tensors in each element.
+ capacity: (Optional.) Maximum number of elements.
+ An integer. If zero, the Staging Area is unbounded
+ memory_limit: (Optional.) Maximum number of bytes of all tensors
+ in the Staging Area (excluding keys).
+ An integer. If zero, the Staging Area is unbounded
+ ordered: (Optional.) If True the underlying data structure
+ is a tree ordered on key. Otherwise assume a hashtable.
+ shapes: (Optional.) Constraints on the shapes of tensors in an element.
+ A list of shape tuples or None. This list is the same length
+ as dtypes. If the shape of any tensors in the element are constrained,
+ all must be; shapes can be None if the shapes should not be constrained.
+ names: (Optional.) If provided, the `get()` and
+ `put()` methods will use dictionaries with these names as keys.
+ Must be None or a list or tuple of the same length as `dtypes`.
+ shared_name: (Optional.) A name to be used for the shared object. By
+ passing the same name to two different python objects they will share
+ the underlying staging area. Must be a string.
+
+ Raises:
+ ValueError: If one of the arguments is invalid.
+
+ """
+
+ super(MapStagingArea, self).__init__(dtypes, shapes,
+ names, shared_name,
+ capacity, memory_limit)
+
+ # Defer to different methods depending if the map is ordered
+ self._ordered = ordered
+
+ if ordered:
+ self._put_fn = gen_data_flow_ops.ordered_map_stage
+ self._pop_fn = gen_data_flow_ops.ordered_map_unstage
+ self._popitem_fn = gen_data_flow_ops.ordered_map_unstage_no_key
+ self._peek_fn = gen_data_flow_ops.ordered_map_peek
+ self._size_fn = gen_data_flow_ops.ordered_map_size
+ self._incomplete_size_fn = gen_data_flow_ops.ordered_map_incomplete_size
+ self._clear_fn = gen_data_flow_ops.ordered_map_clear
+ else:
+ self._put_fn = gen_data_flow_ops.map_stage
+ self._pop_fn = gen_data_flow_ops.map_unstage
+ self._popitem_fn = gen_data_flow_ops.map_unstage_no_key
+ self._peek_fn = gen_data_flow_ops.map_peek
+ self._size_fn = gen_data_flow_ops.map_size
+ self._incomplete_size_fn = gen_data_flow_ops.map_incomplete_size
+ self._clear_fn = gen_data_flow_ops.map_clear
+
+ def put(self, key, vals, indices=None, name=None):
+ """
+ Create an op that stores the (key, vals) pair in the staging area.
+
+ Incomplete puts are possible, preferably using a dictionary for vals
+ as the appropriate dtypes and shapes can be inferred from the value names
+ dictionary key values. If vals is a list or tuple, indices must
+ also be specified so that the op knows at which element position
+ to perform the insert.
+
+ This operation will block if the capacity or memory limit of this
+ container is reached.
+
+ Args:
+ key: Key associated with the data
+ vals: Tensor (or a dict/tuple of Tensors) to place
+ into the staging area.
+ indices: (Optional) if vals is a tuple/list, this is required.
+ name: A name for the operation (optional)
+
+ Returns:
+ The created op
+
+ Raises:
+ ValueError: If the number or type of inputs don't match the staging area.
+ """
+
+ with ops.name_scope(name, "%s_put" % self._name,
+ self._scope_vals(vals)) as scope:
+
+ vals, indices = self._check_put_dtypes(vals, indices)
+
+ with ops.colocate_with(self._coloc_op):
+ op = self._put_fn(key, indices, vals, dtypes=self._dtypes,
+ shared_name=self._name, name=scope,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
+ return op
+
+ def _get_indices_and_dtypes(self, indices=None):
+ if indices is None:
+ indices = list(six.moves.range(len(self._dtypes)))
+
+ if not isinstance(indices, (tuple, list)):
+ raise TypeError("Invalid indices type '%s'" % type(indices))
+
+ if len(indices) == 0:
+ raise ValueError("Empty indices")
+
+ if all(isinstance(i, str) for i in indices):
+ if self._names is None:
+ raise ValueError("String indices provided '%s', but this Staging Area "
+ "was not created with names." % indices)
+
+ try:
+ indices = [self._names.index(n) for n in indices]
+ except ValueError:
+ raise ValueError("Named index '%s' not in "
+ "Staging Area names '%s'" % (n, self._names))
+ elif all(isinstance(i, int) for i in indices):
+ pass
+ else:
+ raise TypeError("Mixed types in indices '%s'. "
+ "May only be str or int" % indices)
+
+ dtypes = [self._dtypes[i] for i in indices]
+
+ return indices, dtypes
+
+
+ def peek(self, key, indices=None, name=None):
+ """
+ Peeks at staging area data associated with the key.
+
+ If the key is not in the staging area, it will block
+ until the associated (key, value) is inserted.
+
+ Args:
+ key: Key associated with the required data
+ indices: Partial list of tensors to retrieve (optional).
+ A list of integer or string indices.
+ String indices are only valid if the Staging Area
+ has names associated with it.
+ name: A name for the operation (optional)
+
+ Returns:
+ The created op
+ """
+
+ if name is None:
+ name = "%s_pop" % self._name
+
+ indices, dtypes = self._get_indices_and_dtypes(indices)
+
+ with ops.colocate_with(self._coloc_op):
+ result = self._peek_fn(key, shared_name=self._name,
+ indices=indices,
+ dtypes=dtypes,
+ name=name,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
+
+ return self._get_return_value(result, indices)
+
+ def get(self, key=None, indices=None, name=None):
+ """
+ If the key is provided, the associated (key, value)
+ is returned from the staging area. If the key is not
+ in the staging area, this method will block until
+ the associated (key, value) is inserted.
+
+ If no key is provided and the staging area is ordered,
+ the (key, value) with the smallest key will be returned.
+ Otherwise, a random (key, value) will be returned.
+
+ If the staging area is empty when this operation executes,
+ it will block until there is an element to dequeue.
+
+ Args:
+ key: Key associated with the required data (Optional)
+ indices: Partial list of tensors to retrieve (optional).
+ A list of integer or string indices.
+ String indices are only valid if the Staging Area
+ has names associated with it.
+ name: A name for the operation (optional)
+
+ Returns:
+ The created op
+ """
+ if key is None:
+ return self._popitem(indices=indices, name=name)
+ else:
+ return self._pop(key, indices=indices, name=name)
+
+ def _pop(self, key, indices=None, name=None):
+ """
+ Remove and return the associated (key, value)
+ is returned from the staging area. If the key is not
+ in the staging area, this method will block until
+ the associated (key, value) is inserted.
+
+ Args:
+ key: Key associated with the required data
+ indices: Partial list of tensors to retrieve (optional).
+ A list of integer or string indices.
+ String indices are only valid if the Staging Area
+ has names associated with it.
+ name: A name for the operation (optional)
+
+ Returns:
+ The created op
+ """
+ if name is None:
+ name = "%s_get" % self._name
+
+ indices, dtypes = self._get_indices_and_dtypes(indices)
+
with ops.colocate_with(self._coloc_op):
- ret = gen_data_flow_ops.unstage(dtypes=self._dtypes,
- shared_name=self._name, name=name)
+ result = self._pop_fn(key, shared_name=self._name,
+ indices=indices,
+ dtypes=dtypes,
+ name=name,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
- curr_device_scope = control_flow_ops.no_op().device
- if curr_device_scope != self._coloc_op.device:
- for i in range(len(ret)):
- ret[i] = array_ops.identity(ret[i])
+ return key, self._get_return_value(result, indices)
- for output, shape in zip(ret, self._shapes):
- output.set_shape(shape)
+ def _popitem(self, indices=None, name=None):
+ """
+ If the staging area is ordered,
+ the (key, value) with the smallest key will be returned.
+ Otherwise, a random (key, value) will be returned.
+
+ If the staging area is empty when this operation executes,
+ it will block until there is an element to dequeue.
+
+ Args:
+ indices: Partial list of tensors to retrieve (optional).
+ A list of integer or string indices.
+ String indices are only valid if the Staging Area
+ has names associated with it.
+ name: A name for the operation (optional)
+
+ Returns:
+ The created op
+ """
+ if name is None:
+ name = "%s_get_nokey" % self._name
+
+ indices, dtypes = self._get_indices_and_dtypes(indices)
+
+ with ops.colocate_with(self._coloc_op):
+ key, result = self._popitem_fn(shared_name=self._name,
+ indices=indices,
+ dtypes=dtypes,
+ name=name,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
+
+ # Separate keys and results out from
+ # underlying namedtuple
+ key = self._create_device_transfers(key)[0]
+ result = self._get_return_value(result, indices)
+
+ return key, result
+
+ def size(self, name=None):
+ """
+ Returns the number of elements in the staging area.
+
+ Args:
+ name: A name for the operation (optional)
+
+ Returns:
+ The created op
+ """
+ if name is None:
+ name = "%s_size" % self._name
+
+ return self._size_fn(shared_name=self._name,
+ name=name, dtypes=self._dtypes,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
+
+ def incomplete_size(self, name=None):
+ """
+ Returns the number of incomplete elements in the staging area.
+
+ Args:
+ name: A name for the operation (optional)
+
+ Returns:
+ The created op
+ """
+ if name is None:
+ name = "%s_incomplete_size" % self._name
+
+ return self._incomplete_size_fn(shared_name=self._name,
+ name=name, dtypes=self._dtypes,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
+
+
+
+ def clear(self, name=None):
+ """
+ Clears the staging area.
+
+ Args:
+ name: A name for the operation (optional)
+
+ Returns:
+ The created op
+ """
+ if name is None:
+ name = "%s_clear" % self._name
- return self._get_return_value(ret)
+ return self._clear_fn(shared_name=self._name,
+ name=name, dtypes=self._dtypes,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
class RecordInput(object):