"""Data Flow Operations.""" # pylint: disable=g-bad-name import re from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import common_shapes from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_data_flow_ops # pylint: disable=wildcard-import from tensorflow.python.ops.gen_data_flow_ops import * def _as_type_list(dtypes): """Convert dtypes to a list of types.""" assert dtypes is not None if not (isinstance(dtypes, list) or isinstance(dtypes, tuple)): # We have a single type. return [dtypes] else: # We have a list or tuple of types. return list(dtypes) def _as_shape_list(shapes, dtypes): """Convert shapes to a list of tuples of int (or None).""" if shapes is None: return None if isinstance(shapes, tensor_shape.TensorShape): shapes = [shapes] if not isinstance(shapes, (tuple, list)): raise TypeError( "shapes must be a TensorShape or a list or tuple of TensorShapes.") if all(isinstance(shape, int) for shape in shapes): # We have a single shape. shapes = [shapes] shapes = [tensor_shape.as_shape(shape) for shape in shapes] if any(not shape.is_fully_defined() for shape in shapes): raise ValueError("All shapes must be fully defined.") return shapes # pylint: disable=protected-access class QueueBase(object): """Base class for queue implementations. A queue is a TensorFlow data structure that stores tensors across multiple steps, and exposes operations that enqueue and dequeue tensors. Each queue element is a tuple of one or more tensors, where each tuple component has a static dtype, and may have a static shape. The queue implementations support versions of enqueue and dequeue that handle single elements, versions that support enqueuing and dequeuing a batch of elements at once. See [`tf.FIFOQueue`](#FIFOQueue) and [`tf.RandomShuffleQueue`](#RandomShuffleQueue) for concrete implementations of this class, and instructions on how to create them. @@enqueue @@enqueue_many @@dequeue @@dequeue_many @@size @@close """ def __init__(self, dtypes, shapes, queue_ref): """Constructs a queue object from a queue reference. Args: dtypes: A list of types. The length of dtypes must equal the number of tensors in each element. shapes: Constraints on the shapes of tensors in an element: A list of shape tuples or None. This list is the same length as dtypes. If the shape of any tensors in the element are constrained, all must be; shapes can be None if the shapes should not be constrained. queue_ref: The queue reference, i.e. the output of the queue op. """ self._dtypes = dtypes if shapes is not None: self._shapes = [tensor_shape.TensorShape(s) for s in shapes] else: self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes] self._queue_ref = queue_ref self._name = self._queue_ref.op.name.split("/")[-1] @staticmethod def from_list(index, queues): """Create a queue using the queue reference from `queues[index]`. Args: index: An integer scalar tensor that determines the input that gets selected. queues: A list of `QueueBase` objects. Returns: A `QueueBase` object. Raises: TypeError: when `queues` is not a list of `QueueBase` objects, or when the data types of `queues` are not all the same. 
""" if ((not queues) or (not isinstance(queues, list)) or (not all([isinstance(x, QueueBase) for x in queues]))): raise TypeError("A list of queues expected") dtypes = queues[0].dtypes if not all([dtypes == q.dtypes for q in queues[1:]]): raise TypeError("Queues do not have matching component dtypes.") queue_refs = [x.queue_ref for x in queues] selected_queue = control_flow_ops.ref_select(index, queue_refs) # TODO(josh11b): Unify the shapes of the queues too? return QueueBase(dtypes=dtypes, shapes=None, queue_ref=selected_queue) @property def queue_ref(self): """The underlying queue reference.""" return self._queue_ref @property def name(self): """The name of the underlying queue.""" return self._queue_ref.op.name @property def dtypes(self): """The list of dtypes for each component of a queue element.""" return self._dtypes def enqueue(self, vals, name=None): """Enqueues one element to this queue. If the queue is full when this operation executes, it will block until the element has been enqueued. Args: vals: The tuple of `Tensor` objects to be enqueued. name: A name for the operation (optional). Returns: The operation that enqueues a new tuple of tensors to the queue. """ if name is None: name = "%s_enqueue" % self._name ret = gen_data_flow_ops._queue_enqueue(self._queue_ref, vals, name=name) # NOTE(mrry): Not using a shape function because we need access to # the Queue object. for val, shape in zip(ret.inputs[1:], self._shapes): val.get_shape().assert_is_compatible_with(shape) return ret def enqueue_many(self, vals, name=None): """Enqueues zero or elements to this queue. This operation slices each component tensor along the 0th dimension to make multiple queue elements. All of the tensors in `vals` must have the same size in the 0th dimension. If the queue is full when this operation executes, it will block until all of the elements have been enqueued. Args: vals: The tensor or tuple of tensors from which the queue elements are taken. name: A name for the operation (optional). Returns: The operation that enqueues a batch of tuples of tensors to the queue. """ if name is None: name = "%s_EnqueueMany" % self._name ret = gen_data_flow_ops._queue_enqueue_many( self._queue_ref, vals, name=name) # NOTE(mrry): Not using a shape function because we need access to # the `QueueBase` object. batch_dim = ret.inputs[1].get_shape()[0] for val, shape in zip(ret.inputs[1:], self._shapes): batch_dim.merge_with(val.get_shape()[0]) val.get_shape()[1:].assert_is_compatible_with(shape) return ret def dequeue(self, name=None): """Dequeues one element from this queue. If the queue is empty when this operation executes, it will block until there is an element to dequeue. Args: name: A name for the operation (optional). Returns: The tuple of tensors that was dequeued. """ if name is None: name = "%s_Dequeue" % self._name ret = gen_data_flow_ops._queue_dequeue( self._queue_ref, self._dtypes, name=name) # NOTE(mrry): Not using a shape function because we need access to # the `QueueBase` object. op = ret[0].op for output, shape in zip(op.values(), self._shapes): output.set_shape(shape) return ret if len(ret) != 1 else ret[0] def dequeue_many(self, n, name=None): """Dequeues and concatenates `n` elements from this queue. This operation concatenates queue-element component tensors along the 0th dimension to make a single component tensor. All of the components in the dequeued tuple will have size `n` in the 0th dimension. 
  def dequeue_many(self, n, name=None):
    """Dequeues and concatenates `n` elements from this queue.

    This operation concatenates queue-element component tensors along
    the 0th dimension to make a single component tensor. All of the
    components in the dequeued tuple will have size `n` in the 0th
    dimension.

    If the queue contains fewer than `n` elements when this operation
    executes, it will block until `n` elements have been dequeued.

    Args:
      n: A scalar `Tensor` containing the number of elements to dequeue.
      name: A name for the operation (optional).

    Returns:
      The tuple of concatenated tensors that was dequeued.
    """
    if name is None:
      name = "%s_DequeueMany" % self._name

    ret = gen_data_flow_ops._queue_dequeue_many(
        self._queue_ref, n, self._dtypes, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the Queue object.
    op = ret[0].op
    batch_dim = tensor_shape.Dimension(tensor_util.ConstantValue(op.inputs[1]))
    for output, shape in zip(op.values(), self._shapes):
      output.set_shape(tensor_shape.TensorShape([batch_dim]).concatenate(
          shape))

    return ret if len(ret) != 1 else ret[0]

  def close(self, cancel_pending_enqueues=False, name=None):
    """Closes this queue.

    This operation signals that no more elements will be enqueued in
    the given queue. Subsequent `enqueue` and `enqueue_many`
    operations will fail. Subsequent `dequeue` and `dequeue_many`
    operations will continue to succeed if sufficient elements remain
    in the queue. Subsequent `dequeue` and `dequeue_many` operations
    that would block will fail immediately.

    If `cancel_pending_enqueues` is `True`, all pending requests will also
    be cancelled.

    Args:
      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
        `False` (described above).
      name: A name for the operation (optional).

    Returns:
      The operation that closes the queue.
    """
    if name is None:
      name = "%s_Close" % self._name
    return gen_data_flow_ops._queue_close(
        self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues,
        name=name)

  def size(self, name=None):
    """Compute the number of elements in this queue.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this queue.
    """
    if name is None:
      name = "%s_Size" % self._name
    return gen_data_flow_ops._queue_size(self._queue_ref, name=name)
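# Example (illustrative sketch): `dequeue_many(n)` prepends a batch
# dimension to each component's static shape, as implemented by the
# set_shape() calls above. For a hypothetical queue `q` whose two
# components have shapes [2] and [4]:
#
#   x, y = q.dequeue_many(32)
#   # x.get_shape() == TensorShape([32, 2])
#   # y.get_shape() == TensorShape([32, 4])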
class RandomShuffleQueue(QueueBase):
  """A queue implementation that dequeues elements in a random order.

  See [`tf.QueueBase`](#QueueBase) for a description of the methods on
  this class.

  @@__init__
  """

  def __init__(self, capacity, min_after_dequeue, dtypes, shapes=None,
               seed=None, shared_name=None, name="random_shuffle_queue"):
    """Create a queue that dequeues elements in a random order.

    A `RandomShuffleQueue` has bounded capacity; supports multiple
    concurrent producers and consumers; and provides exactly-once
    delivery.

    A `RandomShuffleQueue` holds a list of up to `capacity`
    elements. Each element is a fixed-length tuple of tensors whose
    dtypes are described by `dtypes`, and whose shapes are optionally
    described by the `shapes` argument.

    If the `shapes` argument is specified, each component of a queue
    element must have the respective fixed shape. If it is
    unspecified, different queue elements may have different shapes,
    but the use of `dequeue_many` is disallowed.

    The `min_after_dequeue` argument allows the caller to specify a
    minimum number of elements that will remain in the queue after a
    `dequeue` or `dequeue_many` operation completes, to ensure a
    minimum level of mixing of elements. This invariant is maintained
    by blocking those operations until sufficient elements have been
    enqueued. The `min_after_dequeue` argument is ignored after the
    queue has been closed.

    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      min_after_dequeue: An integer (described above).
      dtypes: A list of `DType` objects. The length of `dtypes` must equal
        the number of tensors in each queue element.
      shapes: (Optional.) A list of fully-defined `TensorShape` objects,
        with the same length as `dtypes` or `None`.
      seed: A Python integer. Used to create a random seed.
        See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.
    """
    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes)
    seed1, seed2 = random_seed.get_seed(seed)
    queue_ref = gen_data_flow_ops._random_shuffle_queue(
        component_types=dtypes, shapes=shapes, capacity=capacity,
        min_after_dequeue=min_after_dequeue, seed=seed1, seed2=seed2,
        shared_name=shared_name, name=name)

    super(RandomShuffleQueue, self).__init__(dtypes, shapes, queue_ref)


class FIFOQueue(QueueBase):
  """A queue implementation that dequeues elements in first-in first-out
  order.

  See [`tf.QueueBase`](#QueueBase) for a description of the methods on
  this class.

  @@__init__
  """

  def __init__(self, capacity, dtypes, shapes=None,
               shared_name=None, name="fifo_queue"):
    """Creates a queue that dequeues elements in a first-in first-out order.

    A `FIFOQueue` has bounded capacity; supports multiple concurrent
    producers and consumers; and provides exactly-once delivery.

    A `FIFOQueue` holds a list of up to `capacity` elements. Each
    element is a fixed-length tuple of tensors whose dtypes are
    described by `dtypes`, and whose shapes are optionally described
    by the `shapes` argument.

    If the `shapes` argument is specified, each component of a queue
    element must have the respective fixed shape. If it is
    unspecified, different queue elements may have different shapes,
    but the use of `dequeue_many` is disallowed.

    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      dtypes: A list of `DType` objects. The length of `dtypes` must equal
        the number of tensors in each queue element.
      shapes: (Optional.) A list of fully-defined `TensorShape` objects,
        with the same length as `dtypes` or `None`.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.
    """
    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes)
    queue_ref = gen_data_flow_ops._fifo_queue(
        component_types=dtypes, shapes=shapes, capacity=capacity,
        shared_name=shared_name, name=name)

    super(FIFOQueue, self).__init__(dtypes, shapes, queue_ref)


# TODO(josh11b): class BatchQueue(QueueBase):
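# Example (illustrative sketch): creating and using a shuffling queue.
# `some_rank_1_tensor` is a hypothetical placeholder for any rank-1
# float tensor; `types.float32` comes from this module's imports:
#
#   q = RandomShuffleQueue(capacity=100, min_after_dequeue=10,
#                          dtypes=[types.float32], shapes=[[]])
#   enqueue_op = q.enqueue_many([some_rank_1_tensor])
#   sample = q.dequeue()
#   # `sample` blocks until more than `min_after_dequeue` elements are
#   # in the queue (unless the queue has been closed).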
""" self._key_dtype = types.as_dtype(key_dtype) self._value_dtype = types.as_dtype(value_dtype) self._shapes = [tensor_shape.TensorShape([1])] self._table_ref = table_ref self._name = self._table_ref.op.name.split("/")[-1] self._default_value = ops.convert_to_tensor(default_value, dtype=self._value_dtype) self._default_value.get_shape().merge_with(tensor_shape.scalar()) @property def table_ref(self): """Get the underlying table reference.""" return self._table_ref @property def key_dtype(self): """The key dtype supported by the table.""" return self._key_dtype @property def value_dtype(self): """The value dtype supported by the table.""" return self._value_dtype @property def name(self): """The name of the table.""" return self._name @property def default_value(self): """The default value of the table.""" return self._default_value def size(self, name=None): """Compute the number of elements in this table. Args: name: A name for the operation (optional). Returns: A scalar tensor containing the number of elements in this table. """ if name is None: name = "%s_Size" % self._name return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name) def lookup(self, keys, name=None): """Returns the values for the given 'keys' tensor. If an element on the key tensor is not found in the table, the default_value is used. Args: keys: The tensor for the keys. name: Optional name for the op. Returns: The operation that looks up the keys. Raises: TypeError: when 'keys' or 'default_value' doesn't match the table data types. """ if name is None: name = "%s_lookup_table_find" % self._name if keys.dtype != self._key_dtype: raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % ( self._key_dtype, keys.dtype)) return gen_data_flow_ops._lookup_table_find( self._table_ref, keys, self._default_value, name=name) def initialize_from(self, keys, values, name=None): """Initialize the lookup table with the provided keys and values tensors. Construct an initializer object from keys and value tensors. Args: keys: The tensor for the keys. values: The tensor for the values. name: Optional name for the op. Returns: The operation that initializes a lookup table. Raises: TypeError: when the 'keys' and 'values' data type do not match the table key and value data types. """ if name is None: name = "%s_initialize_table" % self.name with ops.op_scope([keys, values], None, name): keys = ops.convert_to_tensor(keys, dtype=self.key_dtype, name="keys") values = ops.convert_to_tensor(values, dtype=self.value_dtype, name="values") init_op = gen_data_flow_ops._initialize_table( self.table_ref, keys, values, name=name) ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) return init_op def _check_table_dtypes(self, key_dtype, value_dtype): """Check that the given key_dtype and value_dtype matches the table dtypes'. Args: key_dtype: The key data type to check. value_dtype: The value data type to check. Raises: TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data types. """ if key_dtype != self.key_dtype: raise TypeError("Invalid key dtype, expected %s but got %s." % ( self.key_dtype, key_dtype)) if value_dtype != self.value_dtype: raise TypeError("Invalid value dtype, expected %s but got %s." % ( self.value_dtype, value_dtype)) class HashTable(LookupTableBase): """A generic hash table implementation.""" def __init__(self, key_dtype, value_dtype, default_value, shared_name=None, name="hash_table"): """Create a generic hash table. A table holds a key-value pairs. 
class HashTable(LookupTableBase):
  """A generic hash table implementation."""

  def __init__(self, key_dtype, value_dtype, default_value,
               shared_name=None, name="hash_table"):
    """Create a generic hash table.

    A table holds key-value pairs. The key and value types are
    described by `key_dtype` and `value_dtype` respectively.

    Args:
      key_dtype: The key data type of the table.
      value_dtype: The value data type of the table.
      default_value: The scalar tensor to be used when a key is not
        present in the table.
      shared_name: Optional. If non-empty, this table will be shared under
        the given name across multiple sessions.
      name: Optional name for the hash table op.

    Returns:
      A table object that can be used to lookup data.
    """
    table_ref = gen_data_flow_ops._hash_table(
        shared_name=shared_name, key_dtype=key_dtype,
        value_dtype=value_dtype, name=name)

    super(HashTable, self).__init__(key_dtype, value_dtype, default_value,
                                    table_ref)


def initialize_all_tables(name="init_all_tables"):
  """Returns an Op that initializes all tables of the default graph.

  Returns:
    An Op that initializes all tables. Note that if there are no tables
    the returned Op is a NoOp.
  """
  initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS)
  if initializers:
    return control_flow_ops.group(*initializers, name=name)
  return control_flow_ops.no_op(name=name)


ops.NoGradient("LookupTableFind")
ops.NoGradient("LookupTableSize")
ops.NoGradient("HashTable")
ops.NoGradient("InitializeTable")

ops.RegisterShape("QueueSize")(common_shapes.scalar_shape)
ops.RegisterShape("Queue")(common_shapes.scalar_shape)
ops.RegisterShape("FIFOQueue")(common_shapes.scalar_shape)
ops.RegisterShape("RandomShuffleQueue")(common_shapes.scalar_shape)


# NOTE(mrry): The following ops use higher-level information in the
# Queue class to provide shape information.
ops.RegisterShape("QueueDequeue")(common_shapes.unknown_shape)
ops.RegisterShape("QueueDequeueMany")(common_shapes.unknown_shape)
ops.RegisterShape("QueueEnqueue")(common_shapes.unknown_shape)
ops.RegisterShape("QueueEnqueueMany")(common_shapes.unknown_shape)


@ops.RegisterShape("QueueClose")
def _ScalarToVoidShape(op):
  """Shape function for ops that take a scalar and produce no outputs."""
  unused_input_shape = op.inputs[0].get_shape().merge_with(
      tensor_shape.scalar())
  return []
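# Example (illustrative sketch): the shape function below mirrors the
# runtime behavior of `dynamic_partition`. With num_partitions=2:
#
#   data = [10, 20, 30, 40]
#   partitions = [0, 1, 1, 0]
#   # outputs[0] == [10, 40], outputs[1] == [20, 30]; each output has
#   # static shape [None], since partition sizes are data-dependent.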
@ops.RegisterShape("DynamicPartition")
def _DynamicPartitionShape(op):
  """Shape function for data_flow_ops.dynamic_partition."""
  data_shape = op.inputs[0].get_shape()
  partitions_shape = op.inputs[1].get_shape()
  # If we don't know the rank of partitions, we don't know anything
  mid = partitions_shape.ndims
  if mid is None:
    result_shape = tensor_shape.unknown_shape()
  else:
    # data_shape must start with partitions_shape
    partitions_shape.assert_is_compatible_with(data_shape[:mid])
    # The partition shape is dynamic in the 0th dimension, and matches
    # data_shape in the remaining dimensions.
    result_shape = tensor_shape.TensorShape([None]).concatenate(
        data_shape[mid:])
  return [result_shape] * op.get_attr("num_partitions")


@ops.RegisterShape("DynamicStitch")
def _DynamicStitchShape(op):
  """Shape function for data_flow_ops.dynamic_stitch."""
  num_partitions = op.get_attr("N")
  indices_shapes = [t.get_shape() for t in op.inputs[0:num_partitions]]
  data_shapes = [t.get_shape() for t in op.inputs[num_partitions:]]
  output_shape = tensor_shape.unknown_shape()
  extra_shape = tensor_shape.TensorShape(None)
  for indices_shape, data_shape in zip(indices_shapes, data_shapes):
    indices_ndims = indices_shape.ndims
    if indices_ndims is not None:
      # Assert that data_shape starts with indices_shape
      indices_shape.merge_with(data_shape[:indices_ndims])
      # The rest belongs to output
      extra_shape = extra_shape.merge_with(data_shape[indices_ndims:])
  return [tensor_shape.TensorShape([None]).concatenate(extra_shape)]


@ops.RegisterShape("LookupTableFind")
def _LookupTableFindShape(op):
  """Shape function for data_flow_ops._lookup_table_find."""
  unused_table_shape = op.inputs[0].get_shape().merge_with(
      tensor_shape.scalar())
  shape_in = op.inputs[1].get_shape()
  return [shape_in]


@ops.RegisterShape("LookupTableSize")
def _LookupTableSizeShape(op):
  """Shape function for data_flow_ops._lookup_table_size."""
  unused_table_shape = op.inputs[0].get_shape().merge_with(
      tensor_shape.scalar())
  return [tensor_shape.scalar()]


@ops.RegisterShape("HashTable")
def _HashTableShape(unused_op):
  """Shape function for data_flow_ops._hash_table."""
  return [tensor_shape.scalar()]


@ops.RegisterShape("InitializeTable")
def _InitializeLookupTableShape(op):
  """Shape function for data_flow_ops._initialize_table."""
  unused_table_shape = op.inputs[0].get_shape().merge_with(
      tensor_shape.scalar())
  keys_shape = op.inputs[1].get_shape().with_rank(1)
  unused_values_shape = op.inputs[2].get_shape().merge_with(keys_shape)
  return []
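# Example (illustrative sketch): `dynamic_stitch` is the inverse of
# `dynamic_partition`. Interleaving two data lists by their indices:
#
#   indices = [[0, 2], [1, 3]]
#   data = [[10, 30], [20, 40]]
#   # merged == [10, 20, 30, 40]; per the shape function above, the
#   # output's static shape is [None] concatenated with the common
#   # per-element shape of `data`.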