diff options
Diffstat (limited to 'tensorflow/python/ops/data_flow_ops.py')
-rw-r--r-- | tensorflow/python/ops/data_flow_ops.py | 680 |
1 file changed, 680 insertions, 0 deletions
"""Data Flow Operations."""
# pylint: disable=g-bad-name
import re

from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_data_flow_ops import *


def _as_type_list(dtypes):
  """Convert dtypes to a list of types."""
  assert dtypes is not None
  if not isinstance(dtypes, (list, tuple)):
    # We have a single type.
    return [dtypes]
  else:
    # We have a list or tuple of types.
    return list(dtypes)


def _as_shape_list(shapes, dtypes):
  """Convert shapes to a list of tuples of int (or None)."""
  if shapes is None:
    return None
  if isinstance(shapes, tensor_shape.TensorShape):
    shapes = [shapes]
  if not isinstance(shapes, (tuple, list)):
    raise TypeError(
        "shapes must be a TensorShape or a list or tuple of TensorShapes.")
  if all(isinstance(shape, int) for shape in shapes):
    # We have a single shape.
    shapes = [shapes]
  shapes = [tensor_shape.as_shape(shape) for shape in shapes]
  if any(not shape.is_fully_defined() for shape in shapes):
    raise ValueError("All shapes must be fully defined.")
  return shapes


# pylint: disable=protected-access
class QueueBase(object):
  """Base class for queue implementations.

  A queue is a TensorFlow data structure that stores tensors across
  multiple steps, and exposes operations that enqueue and dequeue
  tensors.

  Each queue element is a tuple of one or more tensors, where each
  tuple component has a static dtype, and may have a static shape. The
  queue implementations support versions of enqueue and dequeue that
  handle single elements, and versions that support enqueuing and
  dequeuing a batch of elements at once.

  See [`tf.FIFOQueue`](#FIFOQueue) and
  [`tf.RandomShuffleQueue`](#RandomShuffleQueue) for concrete
  implementations of this class, and instructions on how to create
  them.

  @@enqueue
  @@enqueue_many

  @@dequeue
  @@dequeue_many

  @@size

  @@close

  """

  def __init__(self, dtypes, shapes, queue_ref):
    """Constructs a queue object from a queue reference.

    Args:
      dtypes: A list of types.  The length of dtypes must equal the number
        of tensors in each element.
      shapes: Constraints on the shapes of tensors in an element:
        A list of shape tuples or None. This list is the same length
        as dtypes.  If the shape of any tensors in the element are constrained,
        all must be; shapes can be None if the shapes should not be constrained.
      queue_ref: The queue reference, i.e. the output of the queue op.
    """
    self._dtypes = dtypes
    if shapes is not None:
      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
    else:
      # Unconstrained shapes: one unknown shape per component.
      self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
    self._queue_ref = queue_ref
    # Short name (last path component) used to derive default op names.
    self._name = self._queue_ref.op.name.split("/")[-1]

  @staticmethod
  def from_list(index, queues):
    """Create a queue using the queue reference from `queues[index]`.

    Args:
      index: An integer scalar tensor that determines the input that gets
        selected.
      queues: A list of `QueueBase` objects.

    Returns:
      A `QueueBase` object.

    Raises:
      TypeError: when `queues` is not a list of `QueueBase` objects,
        or when the data types of `queues` are not all the same.
    """
    if ((not queues) or
        (not isinstance(queues, list)) or
        (not all(isinstance(x, QueueBase) for x in queues))):
      raise TypeError("A list of queues expected")

    dtypes = queues[0].dtypes
    if not all(dtypes == q.dtypes for q in queues[1:]):
      raise TypeError("Queues do not have matching component dtypes.")

    queue_refs = [x.queue_ref for x in queues]
    selected_queue = control_flow_ops.ref_select(index, queue_refs)
    # TODO(josh11b): Unify the shapes of the queues too?
    return QueueBase(dtypes=dtypes, shapes=None, queue_ref=selected_queue)

  @property
  def queue_ref(self):
    """The underlying queue reference."""
    return self._queue_ref

  @property
  def name(self):
    """The name of the underlying queue."""
    return self._queue_ref.op.name

  @property
  def dtypes(self):
    """The list of dtypes for each component of a queue element."""
    return self._dtypes

  def enqueue(self, vals, name=None):
    """Enqueues one element to this queue.

    If the queue is full when this operation executes, it will block
    until the element has been enqueued.

    Args:
      vals: The tuple of `Tensor` objects to be enqueued.
      name: A name for the operation (optional).

    Returns:
      The operation that enqueues a new tuple of tensors to the queue.
    """
    if name is None:
      # NOTE(review): lowercase "enqueue" is inconsistent with the CamelCase
      # suffixes used by the other methods, but it names a graph op and so
      # must not be changed without a compatibility plan.
      name = "%s_enqueue" % self._name
    ret = gen_data_flow_ops._queue_enqueue(self._queue_ref, vals, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the Queue object.
    for val, shape in zip(ret.inputs[1:], self._shapes):
      val.get_shape().assert_is_compatible_with(shape)

    return ret

  def enqueue_many(self, vals, name=None):
    """Enqueues zero or more elements to this queue.

    This operation slices each component tensor along the 0th dimension to
    make multiple queue elements. All of the tensors in `vals` must have the
    same size in the 0th dimension.

    If the queue is full when this operation executes, it will block
    until all of the elements have been enqueued.

    Args:
      vals: The tensor or tuple of tensors from which the queue elements
        are taken.
      name: A name for the operation (optional).

    Returns:
      The operation that enqueues a batch of tuples of tensors to the queue.
    """
    if name is None:
      name = "%s_EnqueueMany" % self._name

    ret = gen_data_flow_ops._queue_enqueue_many(
        self._queue_ref, vals, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the `QueueBase` object.
    batch_dim = ret.inputs[1].get_shape()[0]
    for val, shape in zip(ret.inputs[1:], self._shapes):
      # All components must agree on the 0th (batch) dimension.
      batch_dim.merge_with(val.get_shape()[0])
      val.get_shape()[1:].assert_is_compatible_with(shape)

    return ret

  def dequeue(self, name=None):
    """Dequeues one element from this queue.

    If the queue is empty when this operation executes, it will block
    until there is an element to dequeue.

    Args:
      name: A name for the operation (optional).

    Returns:
      The tuple of tensors that was dequeued.
    """
    if name is None:
      name = "%s_Dequeue" % self._name
    ret = gen_data_flow_ops._queue_dequeue(
        self._queue_ref, self._dtypes, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the `QueueBase` object.
    op = ret[0].op
    for output, shape in zip(op.values(), self._shapes):
      output.set_shape(shape)

    # Single-component queues return the bare tensor, not a 1-tuple.
    return ret if len(ret) != 1 else ret[0]

  def dequeue_many(self, n, name=None):
    """Dequeues and concatenates `n` elements from this queue.

    This operation concatenates queue-element component tensors along
    the 0th dimension to make a single component tensor.  All of the
    components in the dequeued tuple will have size `n` in the 0th dimension.

    If the queue contains fewer than `n` elements when this operation
    executes, it will block until `n` elements have been dequeued.

    Args:
      n: A scalar `Tensor` containing the number of elements to dequeue.
      name: A name for the operation (optional).

    Returns:
      The tuple of concatenated tensors that was dequeued.
    """
    if name is None:
      name = "%s_DequeueMany" % self._name

    ret = gen_data_flow_ops._queue_dequeue_many(
        self._queue_ref, n, self._dtypes, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the Queue object.
    op = ret[0].op
    # If `n` is a constant, propagate it as the static batch dimension.
    batch_dim = tensor_shape.Dimension(tensor_util.ConstantValue(op.inputs[1]))
    for output, shape in zip(op.values(), self._shapes):
      output.set_shape(tensor_shape.TensorShape([batch_dim]).concatenate(shape))

    return ret if len(ret) != 1 else ret[0]

  def close(self, cancel_pending_enqueues=False, name=None):
    """Closes this queue.

    This operation signals that no more elements will be enqueued in
    the given queue. Subsequent `enqueue` and `enqueue_many`
    operations will fail. Subsequent `dequeue` and `dequeue_many`
    operations will continue to succeed if sufficient elements remain
    in the queue. Subsequent `dequeue` and `dequeue_many` operations
    that would block will fail immediately.

    If `cancel_pending_enqueues` is `True`, all pending requests will also
    be cancelled.

    Args:
      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
        `False` (described above).
      name: A name for the operation (optional).

    Returns:
      The operation that closes the queue.
    """
    if name is None:
      name = "%s_Close" % self._name
    return gen_data_flow_ops._queue_close(
        self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues,
        name=name)

  def size(self, name=None):
    """Compute the number of elements in this queue.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this queue.
    """
    if name is None:
      name = "%s_Size" % self._name
    return gen_data_flow_ops._queue_size(self._queue_ref, name=name)


class RandomShuffleQueue(QueueBase):
  """A queue implementation that dequeues elements in a random order.

  See [`tf.QueueBase`](#QueueBase) for a description of the methods on
  this class.

  @@__init__
  """

  def __init__(self, capacity, min_after_dequeue, dtypes, shapes=None,
               seed=None, shared_name=None, name="random_shuffle_queue"):
    """Create a queue that dequeues elements in a random order.

    A `RandomShuffleQueue` has bounded capacity; supports multiple
    concurrent producers and consumers; and provides exactly-once
    delivery.

    A `RandomShuffleQueue` holds a list of up to `capacity`
    elements. Each element is a fixed-length tuple of tensors whose
    dtypes are described by `dtypes`, and whose shapes are optionally
    described by the `shapes` argument.

    If the `shapes` argument is specified, each component of a queue
    element must have the respective fixed shape. If it is
    unspecified, different queue elements may have different shapes,
    but the use of `dequeue_many` is disallowed.

    The `min_after_dequeue` argument allows the caller to specify a
    minimum number of elements that will remain in the queue after a
    `dequeue` or `dequeue_many` operation completes, to ensure a
    minimum level of mixing of elements. This invariant is maintained
    by blocking those operations until sufficient elements have been
    enqueued. The `min_after_dequeue` argument is ignored after the
    queue has been closed.

    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      min_after_dequeue: An integer (described above).
      dtypes: A list of `DType` objects. The length of `dtypes` must equal
        the number of tensors in each queue element.
      shapes: (Optional.) A list of fully-defined `TensorShape` objects,
        with the same length as `dtypes` or `None`.
      seed: A Python integer. Used to create a random seed.
        See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.
    """
    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes)
    seed1, seed2 = random_seed.get_seed(seed)
    queue_ref = gen_data_flow_ops._random_shuffle_queue(
        component_types=dtypes, shapes=shapes, capacity=capacity,
        min_after_dequeue=min_after_dequeue, seed=seed1, seed2=seed2,
        shared_name=shared_name, name=name)

    super(RandomShuffleQueue, self).__init__(dtypes, shapes, queue_ref)


class FIFOQueue(QueueBase):
  """A queue implementation that dequeues elements in first-in-first out order.

  See [`tf.QueueBase`](#QueueBase) for a description of the methods on
  this class.

  @@__init__
  """

  def __init__(self, capacity, dtypes, shapes=None, shared_name=None,
               name="fifo_queue"):
    """Creates a queue that dequeues elements in a first-in first-out order.

    A `FIFOQueue` has bounded capacity; supports multiple concurrent
    producers and consumers; and provides exactly-once delivery.

    A `FIFOQueue` holds a list of up to `capacity` elements. Each
    element is a fixed-length tuple of tensors whose dtypes are
    described by `dtypes`, and whose shapes are optionally described
    by the `shapes` argument.

    If the `shapes` argument is specified, each component of a queue
    element must have the respective fixed shape. If it is
    unspecified, different queue elements may have different shapes,
    but the use of `dequeue_many` is disallowed.

    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      dtypes: A list of `DType` objects. The length of `dtypes` must equal
        the number of tensors in each queue element.
      shapes: (Optional.) A list of fully-defined `TensorShape` objects,
        with the same length as `dtypes` or `None`.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.
    """
    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes)
    queue_ref = gen_data_flow_ops._fifo_queue(
        component_types=dtypes, shapes=shapes, capacity=capacity,
        shared_name=shared_name, name=name)

    super(FIFOQueue, self).__init__(dtypes, shapes, queue_ref)


# TODO(josh11b): class BatchQueue(QueueBase):


# pylint: disable=protected-access
class LookupTableBase(object):
  """Represents a lookup table that persists across different steps."""

  def __init__(self, key_dtype, value_dtype, default_value, table_ref):
    """Construct a table object from a table reference.

    Args:
      key_dtype: The key data type of the table.
      value_dtype: The value data type of the table.
      default_value: The scalar tensor to be used when a key is not present in
        the table.
      table_ref: The table reference, i.e. the output of the lookup table ops.
    """
    self._key_dtype = types.as_dtype(key_dtype)
    self._value_dtype = types.as_dtype(value_dtype)
    self._shapes = [tensor_shape.TensorShape([1])]
    self._table_ref = table_ref
    self._name = self._table_ref.op.name.split("/")[-1]
    self._default_value = ops.convert_to_tensor(default_value,
                                                dtype=self._value_dtype)
    # The default value must be a scalar.
    self._default_value.get_shape().merge_with(tensor_shape.scalar())

  @property
  def table_ref(self):
    """Get the underlying table reference."""
    return self._table_ref

  @property
  def key_dtype(self):
    """The key dtype supported by the table."""
    return self._key_dtype

  @property
  def value_dtype(self):
    """The value dtype supported by the table."""
    return self._value_dtype

  @property
  def name(self):
    """The name of the table."""
    return self._name

  @property
  def default_value(self):
    """The default value of the table."""
    return self._default_value

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    if name is None:
      name = "%s_Size" % self._name
    return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name)

  def lookup(self, keys, name=None):
    """Returns the values for the given 'keys' tensor.

    If an element on the key tensor is not found in the table, the default_value
    is used.

    Args:
      keys: The tensor for the keys.
      name: Optional name for the op.

    Returns:
      The operation that looks up the keys.

    Raises:
      TypeError: when 'keys' or 'default_value' doesn't match the table data
        types.
    """
    if name is None:
      name = "%s_lookup_table_find" % self._name

    if keys.dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (
          self._key_dtype, keys.dtype))

    return gen_data_flow_ops._lookup_table_find(
        self._table_ref, keys, self._default_value, name=name)

  def initialize_from(self, keys, values, name=None):
    """Initialize the lookup table with the provided keys and values tensors.

    Construct an initializer object from keys and value tensors.

    Args:
      keys: The tensor for the keys.
      values: The tensor for the values.
      name: Optional name for the op.

    Returns:
      The operation that initializes a lookup table.

    Raises:
      TypeError: when the 'keys' and 'values' data type do not match the table
        key and value data types.
    """
    if name is None:
      name = "%s_initialize_table" % self.name
    with ops.op_scope([keys, values], None, name):
      keys = ops.convert_to_tensor(keys, dtype=self.key_dtype, name="keys")
      values = ops.convert_to_tensor(values, dtype=self.value_dtype,
                                     name="values")

    init_op = gen_data_flow_ops._initialize_table(
        self.table_ref, keys, values, name=name)
    # Registering in this collection lets initialize_all_tables() find it.
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op

  def _check_table_dtypes(self, key_dtype, value_dtype):
    """Check that the given key_dtype and value_dtype matches the table dtypes'.

    Args:
      key_dtype: The key data type to check.
      value_dtype: The value data type to check.

    Raises:
      TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
        types.
    """
    if key_dtype != self.key_dtype:
      raise TypeError("Invalid key dtype, expected %s but got %s." % (
          self.key_dtype, key_dtype))
    if value_dtype != self.value_dtype:
      raise TypeError("Invalid value dtype, expected %s but got %s." % (
          self.value_dtype, value_dtype))


class HashTable(LookupTableBase):
  """A generic hash table implementation."""

  def __init__(self, key_dtype, value_dtype, default_value, shared_name=None,
               name="hash_table"):
    """Create a generic hash table.

    A table holds key-value pairs.  The key and value types are
    described by key_dtype and value_dtype respectively.

    Args:
      key_dtype: The key data type of the table.
      value_dtype: The value data type of the table.
      default_value: The scalar tensor to be used when a key is not present in
        the table.
      shared_name: Optional. If non-empty, this table will be shared under
        the given name across multiple sessions.
      name: Optional name for the hash table op.

    Returns:
      A table object that can be used to lookup data.
    """
    table_ref = gen_data_flow_ops._hash_table(
        shared_name=shared_name, key_dtype=key_dtype,
        value_dtype=value_dtype, name=name)

    super(HashTable, self).__init__(key_dtype, value_dtype, default_value,
                                    table_ref)


def initialize_all_tables(name="init_all_tables"):
  """Returns an Op that initializes all tables of the default graph.

  Args:
    name: Optional name for the initialization op.

  Returns:
    An Op that initializes all tables.  Note that if there are
    no tables the returned Op is a NoOp.
  """
  initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS)
  if initializers:
    return control_flow_ops.group(*initializers, name=name)
  return control_flow_ops.no_op(name=name)


ops.NoGradient("LookupTableFind")
ops.NoGradient("LookupTableSize")
ops.NoGradient("HashTable")
ops.NoGradient("InitializeTable")


ops.RegisterShape("QueueSize")(common_shapes.scalar_shape)
ops.RegisterShape("Queue")(common_shapes.scalar_shape)
ops.RegisterShape("FIFOQueue")(common_shapes.scalar_shape)
ops.RegisterShape("RandomShuffleQueue")(common_shapes.scalar_shape)


# NOTE(mrry): The following ops use higher-level information in the
# Queue class to provide shape information.
ops.RegisterShape("QueueDequeue")(common_shapes.unknown_shape)
ops.RegisterShape("QueueDequeueMany")(common_shapes.unknown_shape)
ops.RegisterShape("QueueEnqueue")(common_shapes.unknown_shape)
ops.RegisterShape("QueueEnqueueMany")(common_shapes.unknown_shape)


@ops.RegisterShape("QueueClose")
def _ScalarToVoidShape(op):
  """Shape function for ops that take a scalar and produce no outputs."""
  unused_input_shape = op.inputs[0].get_shape().merge_with(
      tensor_shape.scalar())
  return []


@ops.RegisterShape("DynamicPartition")
def _DynamicPartitionShape(op):
  """Shape function for data_flow_ops.dynamic_partition."""
  data_shape = op.inputs[0].get_shape()
  partitions_shape = op.inputs[1].get_shape()
  # If we don't know the rank of partitions, we don't know anything.
  mid = partitions_shape.ndims
  if mid is None:
    result_shape = tensor_shape.unknown_shape()
  else:
    # data_shape must start with partitions_shape.
    partitions_shape.assert_is_compatible_with(data_shape[:mid])
    # The partition shape is dynamic in the 0th dimension, and matches
    # data_shape in the remaining dimensions.
    result_shape = tensor_shape.TensorShape([None]).concatenate(
        data_shape[mid:])
  return [result_shape] * op.get_attr("num_partitions")


@ops.RegisterShape("DynamicStitch")
def _DynamicStitchShape(op):
  """Shape function for data_flow_ops.dynamic_stitch."""
  num_partitions = op.get_attr("N")
  indices_shapes = [t.get_shape() for t in op.inputs[0:num_partitions]]
  data_shapes = [t.get_shape() for t in op.inputs[num_partitions:]]
  output_shape = tensor_shape.unknown_shape()
  extra_shape = tensor_shape.TensorShape(None)
  for indices_shape, data_shape in zip(indices_shapes, data_shapes):
    indices_ndims = indices_shape.ndims
    if indices_ndims is not None:
      # Assert that data_shape starts with indices_shape.
      indices_shape.merge_with(data_shape[:indices_ndims])
      # The rest belongs to output.
      extra_shape = extra_shape.merge_with(data_shape[indices_ndims:])
  return [tensor_shape.TensorShape([None]).concatenate(extra_shape)]


@ops.RegisterShape("LookupTableFind")
def _LookupTableFindShape(op):
  """Shape function for data_flow_ops._lookup_table_find."""
  unused_table_shape = op.inputs[0].get_shape().merge_with(
      tensor_shape.scalar())
  shape_in = op.inputs[1].get_shape()
  return [shape_in]


@ops.RegisterShape("LookupTableSize")
def _LookupTableSizeShape(op):
  """Shape function for data_flow_ops._lookup_table_size."""
  unused_table_shape = op.inputs[0].get_shape().merge_with(
      tensor_shape.scalar())
  return [tensor_shape.scalar()]


@ops.RegisterShape("HashTable")
def _HashTableShape(unused_op):
  """Shape function for data_flow_ops._hash_table."""
  return [tensor_shape.scalar()]


@ops.RegisterShape("InitializeTable")
def _InitializeLookupTableShape(op):
  """Shape function for data_flow_ops._initialize_table."""
  unused_table_shape = op.inputs[0].get_shape().merge_with(
      tensor_shape.scalar())
  # Keys must be a vector, and values must match it element-for-element.
  keys_shape = op.inputs[1].get_shape().with_rank(1)
  unused_values_shape = op.inputs[2].get_shape().merge_with(keys_shape)
  return []