diff options
Diffstat (limited to 'tensorflow/python/ops/data_flow_ops.py')
-rw-r--r-- | tensorflow/python/ops/data_flow_ops.py | 743 |
1 files changed, 639 insertions, 104 deletions
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index c272a7115d..4eead79531 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -516,7 +516,7 @@ class QueueBase(object): that would block will fail immediately. If `cancel_pending_enqueues` is `True`, all pending requests will also - be cancelled. + be canceled. Args: cancel_pending_enqueues: (Optional.) A boolean, defaulting to @@ -988,7 +988,7 @@ class Barrier(object): TakeMany operations that would block will fail immediately. If `cancel_pending_enqueues` is `True`, all pending requests to the - underlying queue will also be cancelled, and completing of already + underlying queue will also be canceled, and completing of already started values is also not acceptable anymore. Args: @@ -1344,72 +1344,30 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase): dense_shape=return_val.shape) -class StagingArea(object): - """Class for staging inputs. No ordering guarantees. - - A `StagingArea` is a TensorFlow data structure that stores tensors across - multiple steps, and exposes operations that can put and get tensors. - - Each `StagingArea` element is a tuple of one or more tensors, where each - tuple component has a static dtype, and may have a static shape. - - The capacity of a `StagingArea` is unbounded and supports multiple - concurrent producers and consumers; and provides exactly-once delivery. - - Each element of a `StagingArea` is a fixed-length tuple of tensors whose - dtypes are described by `dtypes`, and whose shapes are optionally described - by the `shapes` argument. - - If the `shapes` argument is specified, each component of a staging area - element must have the respective fixed shape. If it is - unspecified, different elements may have different shapes, - """ - +class BaseStagingArea(object): + """Base class for Staging Areas.""" _identifier = 0 _lock = threading.Lock() - def __init__(self, dtypes, shapes=None, names=None, shared_name=None): - """Constructs a staging area object. - - The two optional lists, `shapes` and `names`, must be of the same length - as `dtypes` if provided. The values at a given index `i` indicate the - shape and name to use for the corresponding queue component in `dtypes`. - - The device scope at the time of object creation determines where the - storage for the `StagingArea` will reside. Calls to `put` will incur a copy - to this memory space, if necessary. Tensors returned by `get` will be - placed according to the device scope when `get` is called. - - Args: - dtypes: A list of types. The length of dtypes must equal the number - of tensors in each element. - shapes: (Optional.) Constraints on the shapes of tensors in an element. - A list of shape tuples or None. This list is the same length - as dtypes. If the shape of any tensors in the element are constrained, - all must be; shapes can be None if the shapes should not be constrained. - names: (Optional.) If provided, the `get()` and - `put()` methods will use dictionaries with these names as keys. - Must be None or a list or tuple of the same length as `dtypes`. - shared_name: (Optional.) A name to be used for the shared object. By - passing the same name to two different python objects they will share - the underlying staging area. Must be a string. - - Raises: - ValueError: If one of the arguments is invalid. - """ + def __init__(self, dtypes, shapes=None, names=None, shared_name=None, + capacity=0, memory_limit=0): if shared_name is None: - self._name = ops.get_default_graph().unique_name("StagingArea") + self._name = (ops.get_default_graph() + .unique_name(self.__class__.__name__)) elif isinstance(shared_name, six.string_types): self._name = shared_name else: raise ValueError("shared_name must be a string") + self._dtypes = dtypes + if shapes is not None: if len(shapes) != len(dtypes): raise ValueError("StagingArea shapes must be the same length as dtypes") self._shapes = [tensor_shape.TensorShape(s) for s in shapes] else: self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes] + if names is not None: if len(names) != len(dtypes): raise ValueError("StagingArea names must be the same length as dtypes") @@ -1417,6 +1375,9 @@ class StagingArea(object): else: self._names = None + self._capacity = capacity + self._memory_limit = memory_limit + # all get and put ops must colocate with this op with ops.name_scope("%s_root" % self._name): self._coloc_op = control_flow_ops.no_op() @@ -1441,52 +1402,141 @@ class StagingArea(object): """The list of names for each component of a staging area element.""" return self._names - def _check_put_dtypes(self, vals): + @property + def capacity(self): + """The maximum number of elements of this staging area.""" + return self._capacity + + @property + def memory_limit(self): + """The maximum number of bytes of this staging area.""" + return self._memory_limit + + def _check_put_dtypes(self, vals, indices=None): """Validate and convert `vals` to a list of `Tensor`s. The `vals` argument can be a Tensor, a list or tuple of tensors, or a dictionary with tensor values. + If `vals` is a list, then the appropriate indices associated with the + values must be provided. + If it is a dictionary, the staging area must have been constructed with a `names` attribute and the dictionary keys must match the staging area names. + `indices` will be inferred from the dictionary keys. If the staging area was constructed with a `names` attribute, `vals` must be a dictionary. + Checks that the dtype and shape of each value matches that + of the staging area. + Args: vals: A tensor, a list or tuple of tensors, or a dictionary.. Returns: - A list of `Tensor` objects. + A (tensors, indices) tuple where `tensors` is a list of `Tensor` objects + and `indices` is a list of indices associed with the tensors. Raises: - ValueError: If `vals` is invalid. + ValueError: If `vals` or `indices` is invalid. """ if isinstance(vals, dict): if not self._names: raise ValueError( "Staging areas must have names to enqueue a dictionary") - if sorted(self._names) != sorted(vals.keys()): + if not set(vals.keys()).issubset(self._names): raise ValueError("Keys in dictionary to put do not match names " "of staging area. Dictionary: (%s), Queue: (%s)" % (sorted(vals.keys()), sorted(self._names))) # The order of values in `self._names` indicates the order in which the # tensors in the dictionary `vals` must be listed. - vals = [vals[k] for k in self._names] + vals, indices, n = zip(*[(vals[k], i, k) for i, k in enumerate(self._names) + if k in vals]) else: if self._names: raise ValueError("You must enqueue a dictionary in a staging area " "with names") + + if indices is None: + raise ValueError("Indices must be supplied when inserting a list " + "of tensors") + + if len(indices) != len(vals): + raise ValueError("Number of indices '%s' doesn't match " + "number of values '%s'") + if not isinstance(vals, (list, tuple)): vals = [vals] + indices = [0] + + # Sanity check number of values + if not len(vals) <= len(self._dtypes): + raise ValueError("Unexpected number of inputs '%s' vs '%s'" % ( + len(values), len(self._dtypes))) tensors = [] - for i, (val, dtype) in enumerate(zip(vals, self._dtypes)): - tensors.append( - ops.convert_to_tensor( - val, dtype=dtype, name="component_%d" % i)) + + for val, i in zip(vals, indices): + dtype, shape = self._dtypes[i], self._shapes[i] + # Check dtype + if not val.dtype == dtype: + raise ValueError("Datatypes do not match. '%s' != '%s'" %( + str(val.dtype), str(dtype))) + + # Check shape + val.get_shape().assert_is_compatible_with(shape) + + tensors.append(ops.convert_to_tensor(val, dtype=dtype, + name="component_%d" % i)) + + return tensors, indices + + def _create_device_transfers(self, tensors): + """Encode inter-device transfers if the current device + is not the same as the Staging Area's device + """ + + if not isinstance(tensors, (tuple, list)): + tensors = [tensors] + + curr_device_scope = control_flow_ops.no_op().device + + if curr_device_scope != self._coloc_op.device: + tensors = [array_ops.identity(t) for t in tensors] return tensors + def _get_return_value(self, tensors, indices): + """Return the value to return from a get op. + + If the staging area has names, return a dictionary with the + names as keys. Otherwise return either a single tensor + or a list of tensors depending on the length of `tensors`. + + Args: + tensors: List of tensors from the get op. + indices: Indices of associated names and shapes + + Returns: + A single tensor, a list of tensors, or a dictionary + of tensors. + """ + + tensors = self._create_device_transfers(tensors) + + # Sets shape + for output, i in zip(tensors, indices): + output.set_shape(self._shapes[i]) + + if self._names: + # The returned values in `tensors` are in the same order as + # the names in `self._names`. + return {self._names[i]: t for t, i in zip(tensors, indices)} + elif len(tensors) == 1: + return tensors[0] + else: + return tensors + def _scope_vals(self, vals): """Return a list of values to pass to `name_scope()`. @@ -1503,9 +1553,86 @@ class StagingArea(object): else: return [vals] +class StagingArea(BaseStagingArea): + """Class for staging inputs. No ordering guarantees. + + A `StagingArea` is a TensorFlow data structure that stores tensors across + multiple steps, and exposes operations that can put and get tensors. + + Each `StagingArea` element is a tuple of one or more tensors, where each + tuple component has a static dtype, and may have a static shape. + + The capacity of a `StagingArea` may be bounded or unbounded. + It supports multiple concurrent producers and consumers; and + provides exactly-once delivery. + + Each element of a `StagingArea` is a fixed-length tuple of tensors whose + dtypes are described by `dtypes`, and whose shapes are optionally described + by the `shapes` argument. + + If the `shapes` argument is specified, each component of a staging area + element must have the respective fixed shape. If it is + unspecified, different elements may have different shapes, + + It can be configured with a capacity in which case + put(values) will block until space becomes available. + + Similarly, it can be configured with a memory limit which + will block put(values) until space is available. + This is mostly useful for limiting the number of tensors on + devices such as GPUs. + + All get() and peek() commands block if the the requested data + is not present in the Staging Area. + + """ + + def __init__(self, dtypes, shapes=None, names=None, shared_name=None, + capacity=0, memory_limit=0): + """Constructs a staging area object. + + The two optional lists, `shapes` and `names`, must be of the same length + as `dtypes` if provided. The values at a given index `i` indicate the + shape and name to use for the corresponding queue component in `dtypes`. + + The device scope at the time of object creation determines where the + storage for the `StagingArea` will reside. Calls to `put` will incur a copy + to this memory space, if necessary. Tensors returned by `get` will be + placed according to the device scope when `get` is called. + + Args: + dtypes: A list of types. The length of dtypes must equal the number + of tensors in each element. + capacity: (Optional.) Maximum number of elements. + An integer. If zero, the Staging Area is unbounded + memory_limit: (Optional.) Maximum number of bytes of all tensors + in the Staging Area. + An integer. If zero, the Staging Area is unbounded + shapes: (Optional.) Constraints on the shapes of tensors in an element. + A list of shape tuples or None. This list is the same length + as dtypes. If the shape of any tensors in the element are constrained, + all must be; shapes can be None if the shapes should not be constrained. + names: (Optional.) If provided, the `get()` and + `put()` methods will use dictionaries with these names as keys. + Must be None or a list or tuple of the same length as `dtypes`. + shared_name: (Optional.) A name to be used for the shared object. By + passing the same name to two different python objects they will share + the underlying staging area. Must be a string. + + Raises: + ValueError: If one of the arguments is invalid. + """ + + super(StagingArea, self).__init__(dtypes, shapes, + names, shared_name, + capacity, memory_limit) + def put(self, values, name=None): """Create an op that places a value into the staging area. + This operation will block if the `StagingArea` has reached + its capacity. + Args: values: Tensor (or a tuple of Tensors) to place into the staging area. name: A name for the operation (optional). @@ -1518,46 +1645,25 @@ class StagingArea(object): """ with ops.name_scope(name, "%s_put" % self._name, self._scope_vals(values)) as scope: - vals = self._check_put_dtypes(values) - if len(values) != len(self._dtypes): - raise ValueError("Unexpected number of inputs " + str(len(values)) + - "vs " + str(len(self._dtypes))) - for val, dtype in zip(vals, self._dtypes): - if val.dtype != dtype: - raise ValueError("Datatypes do not match. " + str(val.dtype) + " != " - + str(dtype)) - for val, shape in zip(vals, self._shapes): - val.get_shape().assert_is_compatible_with(shape) + # Hard-code indices for this staging area + indices = (list(six.moves.range(len(values))) + if isinstance(values, (list, tuple)) else None) + vals, _ = self._check_put_dtypes(values, indices) with ops.colocate_with(self._coloc_op): op = gen_data_flow_ops.stage(values=vals, shared_name=self._name, - name=scope) + name=scope, capacity=self._capacity, + memory_limit=self._memory_limit) return op - def _get_return_value(self, tensors): - """Return the value to return from a get op. - - If the staging area has names, return a dictionary with the - names as keys. Otherwise return either a single tensor - or a list of tensors depending on the length of `tensors`. - - Args: - tensors: List of tensors from the get op. + def __internal_get(self, get_fn, name): + with ops.colocate_with(self._coloc_op): + ret = get_fn() - Returns: - A single tensor, a list of tensors, or a dictionary - of tensors. - """ - if self._names: - # The returned values in `tensors` are in the same order as - # the names in `self._names`. - return {n: tensors[i] for i, n in enumerate(self._names)} - elif len(tensors) == 1: - return tensors[0] - else: - return tensors + indices = list(six.moves.range(len(self._dtypes))) # Hard coded + return self._get_return_value(ret, indices) def get(self, name=None): """Gets one element from this staging area. @@ -1584,19 +1690,448 @@ class StagingArea(object): if name is None: name = "%s_get" % self._name + fn = lambda: gen_data_flow_ops.unstage(dtypes=self._dtypes, + shared_name=self._name, name=name, + capacity=self._capacity, + memory_limit=self._memory_limit) + + return self.__internal_get(fn, name) + + def peek(self, index, name=None): + """Peeks at an element in the staging area. + + If the staging area is too small to contain the element at + the specified index, it will block until enough elements + are inserted to complete the operation. + + The placement of the returned tensor will be determined by + the current device scope when this function is called. + + Args: + index: The index of the tensor within the staging area + to look up. + name: A name for the operation (optional). + + Returns: + The tuple of tensors that was gotten. + """ + if name is None: + name = "%s_peek" % self._name + + fn = lambda: gen_data_flow_ops.stage_peek(index, + dtypes=self._dtypes, shared_name=self._name, + name=name, capacity=self._capacity, + memory_limit=self._memory_limit) + + return self.__internal_get(fn, name) + + def size(self, name=None): + """Returns the number of elements in the staging area. + + Args: + name: A name for the operation (optional) + + Returns: + The created op + """ + if name is None: + name = "%s_size" % self._name + + return gen_data_flow_ops.stage_size(name=name, shared_name=self._name, + dtypes=self._dtypes, capacity=self._capacity, + memory_limit=self._memory_limit) + + def clear(self, name=None): + """Clears the staging area. + + Args: + name: A name for the operation (optional) + + Returns: + The created op + """ + if name is None: + name = "%s_clear" % self._name + + return gen_data_flow_ops.stage_clear(name=name, shared_name=self._name, + dtypes=self._dtypes, capacity=self._capacity, + memory_limit=self._memory_limit) + +class MapStagingArea(BaseStagingArea): + """ + A `MapStagingArea` is a TensorFlow data structure that stores tensors across + multiple steps, and exposes operations that can put and get tensors. + + Each `MapStagingArea` element is a (key, value) pair. + Only int64 keys are supported, other types should be + hashed to produce a key. + Values are a tuple of one or more tensors. + Each tuple component has a static dtype, + and may have a static shape. + + The capacity of a `MapStagingArea` may be bounded or unbounded. + It supports multiple concurrent producers and consumers; and + provides exactly-once delivery. + + Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors whose + dtypes are described by `dtypes`, and whose shapes are optionally described + by the `shapes` argument. + + If the `shapes` argument is specified, each component of a staging area + element must have the respective fixed shape. If it is + unspecified, different elements may have different shapes, + + It behaves like an associative container with support for: + + - put(key, values) + - peek(key) like dict.get(key) + - get(key) like dict.pop(key) + - get(key=None) like dict.popitem() + - size() + - clear() + + If ordered a tree structure ordered by key will be used and + get(key=None) will remove (key, value) pairs in increasing key order. + Otherwise a hashtable + + It can be configured with a capacity in which case + put(key, values) will block until space becomes available. + + Similarly, it can be configured with a memory limit which + will block put(key, values) until space is available. + This is mostly useful for limiting the number of tensors on + devices such as GPUs. + + All get() and peek() commands block if the requested + (key, value) pair is not present in the staging area. + + Partial puts are supported and will be placed in an incomplete + map until such time as all values associated with the key have + been inserted. Once completed, this (key, value) pair will be + inserted into the map. Data in the incomplete map + counts towards the memory limit, but not towards capacity limit. + + Partial gets from the map are also supported. + This removes the partially requested tensors from the entry, + but the entry is only removed from the map once all tensors + associated with it are removed. + """ + + def __init__(self, dtypes, shapes=None, names=None, shared_name=None, + ordered=False, capacity=0, memory_limit=0): + """ + Args: + dtypes: A list of types. The length of dtypes must equal the number + of tensors in each element. + capacity: (Optional.) Maximum number of elements. + An integer. If zero, the Staging Area is unbounded + memory_limit: (Optional.) Maximum number of bytes of all tensors + in the Staging Area (excluding keys). + An integer. If zero, the Staging Area is unbounded + ordered: (Optional.) If True the underlying data structure + is a tree ordered on key. Otherwise assume a hashtable. + shapes: (Optional.) Constraints on the shapes of tensors in an element. + A list of shape tuples or None. This list is the same length + as dtypes. If the shape of any tensors in the element are constrained, + all must be; shapes can be None if the shapes should not be constrained. + names: (Optional.) If provided, the `get()` and + `put()` methods will use dictionaries with these names as keys. + Must be None or a list or tuple of the same length as `dtypes`. + shared_name: (Optional.) A name to be used for the shared object. By + passing the same name to two different python objects they will share + the underlying staging area. Must be a string. + + Raises: + ValueError: If one of the arguments is invalid. + + """ + + super(MapStagingArea, self).__init__(dtypes, shapes, + names, shared_name, + capacity, memory_limit) + + # Defer to different methods depending if the map is ordered + self._ordered = ordered + + if ordered: + self._put_fn = gen_data_flow_ops.ordered_map_stage + self._pop_fn = gen_data_flow_ops.ordered_map_unstage + self._popitem_fn = gen_data_flow_ops.ordered_map_unstage_no_key + self._peek_fn = gen_data_flow_ops.ordered_map_peek + self._size_fn = gen_data_flow_ops.ordered_map_size + self._incomplete_size_fn = gen_data_flow_ops.ordered_map_incomplete_size + self._clear_fn = gen_data_flow_ops.ordered_map_clear + else: + self._put_fn = gen_data_flow_ops.map_stage + self._pop_fn = gen_data_flow_ops.map_unstage + self._popitem_fn = gen_data_flow_ops.map_unstage_no_key + self._peek_fn = gen_data_flow_ops.map_peek + self._size_fn = gen_data_flow_ops.map_size + self._incomplete_size_fn = gen_data_flow_ops.map_incomplete_size + self._clear_fn = gen_data_flow_ops.map_clear + + def put(self, key, vals, indices=None, name=None): + """ + Create an op that stores the (key, vals) pair in the staging area. + + Incomplete puts are possible, preferably using a dictionary for vals + as the appropriate dtypes and shapes can be inferred from the value names + dictionary key values. If vals is a list or tuple, indices must + also be specified so that the op knows at which element position + to perform the insert. + + This operation will block if the capacity or memory limit of this + container is reached. + + Args: + key: Key associated with the data + vals: Tensor (or a dict/tuple of Tensors) to place + into the staging area. + indices: (Optional) if vals is a tuple/list, this is required. + name: A name for the operation (optional) + + Returns: + The created op + + Raises: + ValueError: If the number or type of inputs don't match the staging area. + """ + + with ops.name_scope(name, "%s_put" % self._name, + self._scope_vals(vals)) as scope: + + vals, indices = self._check_put_dtypes(vals, indices) + + with ops.colocate_with(self._coloc_op): + op = self._put_fn(key, indices, vals, dtypes=self._dtypes, + shared_name=self._name, name=scope, + capacity=self._capacity, + memory_limit=self._memory_limit) + return op + + def _get_indices_and_dtypes(self, indices=None): + if indices is None: + indices = list(six.moves.range(len(self._dtypes))) + + if not isinstance(indices, (tuple, list)): + raise TypeError("Invalid indices type '%s'" % type(indices)) + + if len(indices) == 0: + raise ValueError("Empty indices") + + if all(isinstance(i, str) for i in indices): + if self._names is None: + raise ValueError("String indices provided '%s', but this Staging Area " + "was not created with names." % indices) + + try: + indices = [self._names.index(n) for n in indices] + except ValueError: + raise ValueError("Named index '%s' not in " + "Staging Area names '%s'" % (n, self._names)) + elif all(isinstance(i, int) for i in indices): + pass + else: + raise TypeError("Mixed types in indices '%s'. " + "May only be str or int" % indices) + + dtypes = [self._dtypes[i] for i in indices] + + return indices, dtypes + + + def peek(self, key, indices=None, name=None): + """ + Peeks at staging area data associated with the key. + + If the key is not in the staging area, it will block + until the associated (key, value) is inserted. + + Args: + key: Key associated with the required data + indices: Partial list of tensors to retrieve (optional). + A list of integer or string indices. + String indices are only valid if the Staging Area + has names associated with it. + name: A name for the operation (optional) + + Returns: + The created op + """ + + if name is None: + name = "%s_pop" % self._name + + indices, dtypes = self._get_indices_and_dtypes(indices) + + with ops.colocate_with(self._coloc_op): + result = self._peek_fn(key, shared_name=self._name, + indices=indices, + dtypes=dtypes, + name=name, + capacity=self._capacity, + memory_limit=self._memory_limit) + + return self._get_return_value(result, indices) + + def get(self, key=None, indices=None, name=None): + """ + If the key is provided, the associated (key, value) + is returned from the staging area. If the key is not + in the staging area, this method will block until + the associated (key, value) is inserted. + + If no key is provided and the staging area is ordered, + the (key, value) with the smallest key will be returned. + Otherwise, a random (key, value) will be returned. + + If the staging area is empty when this operation executes, + it will block until there is an element to dequeue. + + Args: + key: Key associated with the required data (Optional) + indices: Partial list of tensors to retrieve (optional). + A list of integer or string indices. + String indices are only valid if the Staging Area + has names associated with it. + name: A name for the operation (optional) + + Returns: + The created op + """ + if key is None: + return self._popitem(indices=indices, name=name) + else: + return self._pop(key, indices=indices, name=name) + + def _pop(self, key, indices=None, name=None): + """ + Remove and return the associated (key, value) + is returned from the staging area. If the key is not + in the staging area, this method will block until + the associated (key, value) is inserted. + + Args: + key: Key associated with the required data + indices: Partial list of tensors to retrieve (optional). + A list of integer or string indices. + String indices are only valid if the Staging Area + has names associated with it. + name: A name for the operation (optional) + + Returns: + The created op + """ + if name is None: + name = "%s_get" % self._name + + indices, dtypes = self._get_indices_and_dtypes(indices) + with ops.colocate_with(self._coloc_op): - ret = gen_data_flow_ops.unstage(dtypes=self._dtypes, - shared_name=self._name, name=name) + result = self._pop_fn(key, shared_name=self._name, + indices=indices, + dtypes=dtypes, + name=name, + capacity=self._capacity, + memory_limit=self._memory_limit) - curr_device_scope = control_flow_ops.no_op().device - if curr_device_scope != self._coloc_op.device: - for i in range(len(ret)): - ret[i] = array_ops.identity(ret[i]) + return key, self._get_return_value(result, indices) - for output, shape in zip(ret, self._shapes): - output.set_shape(shape) + def _popitem(self, indices=None, name=None): + """ + If the staging area is ordered, + the (key, value) with the smallest key will be returned. + Otherwise, a random (key, value) will be returned. + + If the staging area is empty when this operation executes, + it will block until there is an element to dequeue. + + Args: + key: Key associated with the required data + indices: Partial list of tensors to retrieve (optional). + A list of integer or string indices. + String indices are only valid if the Staging Area + has names associated with it. + name: A name for the operation (optional) + + Returns: + The created op + """ + if name is None: + name = "%s_get_nokey" % self._name + + indices, dtypes = self._get_indices_and_dtypes(indices) + + with ops.colocate_with(self._coloc_op): + key, result = self._popitem_fn(shared_name=self._name, + indices=indices, + dtypes=dtypes, + name=name, + capacity=self._capacity, + memory_limit=self._memory_limit) + + # Separate keys and results out from + # underlying namedtuple + key = self._create_device_transfers(key)[0] + result = self._get_return_value(result, indices) + + return key, result + + def size(self, name=None): + """ + Returns the number of elements in the staging area. + + Args: + name: A name for the operation (optional) + + Returns: + The created op + """ + if name is None: + name = "%s_size" % self._name + + return self._size_fn(shared_name=self._name, + name=name, dtypes=self._dtypes, + capacity=self._capacity, + memory_limit=self._memory_limit) + + def incomplete_size(self, name=None): + """ + Returns the number of incomplete elements in the staging area. + + Args: + name: A name for the operation (optional) + + Returns: + The created op + """ + if name is None: + name = "%s_incomplete_size" % self._name + + return self._incomplete_size_fn(shared_name=self._name, + name=name, dtypes=self._dtypes, + capacity=self._capacity, + memory_limit=self._memory_limit) + + + + def clear(self, name=None): + """ + Clears the staging area. + + Args: + name: A name for the operation (optional) + + Returns: + The created op + """ + if name is None: + name = "%s_clear" % self._name - return self._get_return_value(ret) + return self._clear_fn(shared_name=self._name, + name=name, dtypes=self._dtypes, + capacity=self._capacity, + memory_limit=self._memory_limit) class RecordInput(object): |