Diffstat (limited to 'tensorflow/python/ops/data_flow_ops.py')
-rw-r--r--  tensorflow/python/ops/data_flow_ops.py  31
1 file changed, 26 insertions, 5 deletions
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 4eead79531..829aa99284 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -1582,7 +1582,7 @@ class StagingArea(BaseStagingArea):
   This is mostly useful for limiting the number of tensors on
   devices such as GPUs.
 
-  All get() and peek() commands block if the the requested data
+  All get() and peek() commands block if the requested data
   is not present in the Staging Area.
 
   """
@@ -2155,7 +2155,8 @@ class RecordInput(object):
                parallelism=1,
                shift_ratio=0,
                seed=0,
-               name=None):
+               name=None,
+               batches=None):
     """Constructs a RecordInput Op.
 
     Args:
@@ -2169,12 +2170,18 @@
       seed: Specify the random number seed used by generator that randomizes
        records.
      name: Optional name for the operation.
+      batches: None by default, creating a single batch op. Otherwise specifies
+        how many batches to create, which are returned as a list when
+        `get_yield_op()` is called. An example use case is to split processing
+        between devices on one computer.
 
     Raises:
       ValueError: If one of the arguments is invalid.
     """
-    self._batch_size = batch_size
+    if batches is not None:
+      self._batch_size *= batches
+    self._batches = batches
     self._file_pattern = file_pattern
     self._buffer_size = buffer_size
     self._parallelism = parallelism
@@ -2183,8 +2190,11 @@
     self._name = name
 
   def get_yield_op(self):
-    """Add a node that yields a minibatch every time it is executed."""
-    return gen_data_flow_ops.record_input(
+    """Adds a node that yields a group of records every time it is executed.
+    If RecordInput `batches` parameter is not None, it yields a list of
+    record batches with the specified `batch_size`.
+    """
+    records = gen_data_flow_ops.record_input(
         file_pattern=self._file_pattern,
         file_buffer_size=self._buffer_size,
         file_parallelism=self._parallelism,
@@ -2192,3 +2202,14 @@
         batch_size=self._batch_size,
         file_random_seed=self._seed,
         name=self._name)
+    if self._batches is None:
+      return records
+    else:
+      with ops.name_scope(self._name):
+        batch_list = [[] for i in six.moves.range(self._batches)]
+        records = array_ops.split(records, self._batch_size, 0)
+        records = [array_ops.reshape(record, []) for record in records]
+        for index, protobuf in zip(six.moves.range(len(records)), records):
+          batch_index = index % self._batches
+          batch_list[batch_index].append(protobuf)
+        return batch_list
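The core of the change is the round-robin split at the end of get_yield_op(): when `batches` is set, the op pulls `batch_size * batches` records in a single call and deals them out into `batches` separate lists, each ending up with `batch_size` records. A minimal pure-Python sketch of that distribution logic (the function name `split_records` and the sample values are illustrative, not part of the TensorFlow API):

import six


def split_records(records, batches):
  # Mirrors the `index % self._batches` round-robin added to get_yield_op():
  # record i goes to batch i % batches, so each of the `batches` lists
  # receives len(records) / batches entries.
  batch_list = [[] for _ in six.moves.range(batches)]
  for index, record in enumerate(records):
    batch_list[index % batches].append(record)
  return batch_list


# With batch_size=3 and batches=2 the op yields 6 records per step:
print(split_records(['r0', 'r1', 'r2', 'r3', 'r4', 'r5'], 2))
# -> [['r0', 'r2', 'r4'], ['r1', 'r3', 'r5']]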
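A hedged usage sketch of the docstring's example use case, splitting processing between devices on one machine. The file pattern, sizes, and device strings are assumptions, not from the commit, and the sketch assumes `self._batch_size` is initialized from `batch_size` before the `*= batches` line, as in current TensorFlow sources:

import tensorflow as tf
from tensorflow.python.ops import data_flow_ops

# Hypothetical values; any record files matching the pattern would do.
record_input = data_flow_ops.RecordInput(
    file_pattern='/tmp/train-*.tfrecord',
    batch_size=32,
    buffer_size=1000,
    parallelism=4,
    batches=2,  # get_yield_op() returns two lists of 32 records each
    name='record_input')

batch_lists = record_input.get_yield_op()
for device_index, records in enumerate(batch_lists):
  with tf.device('/gpu:%d' % device_index):
    # Each entry of `records` is a scalar string tensor holding one
    # serialized record, to be parsed and processed on this device.
    pass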