 tensorflow/python/ops/data_flow_ops.py | 31 ++++++++++++++++++++++++++-----
 1 file changed, 26 insertions(+), 5 deletions(-)
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 4eead79531..829aa99284 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -1582,7 +1582,7 @@ class StagingArea(BaseStagingArea):
   This is mostly useful for limiting the number of tensors on
   devices such as GPUs.
 
-  All get() and peek() commands block if the the requested data
+  All get() and peek() commands block if the requested data
   is not present in the Staging Area.
 
   """
@@ -2155,7 +2155,8 @@ class RecordInput(object):
                parallelism=1,
                shift_ratio=0,
                seed=0,
-               name=None):
+               name=None,
+               batches=None):
     """Constructs a RecordInput Op.
 
     Args:
@@ -2169,12 +2170,18 @@ class RecordInput(object):
       seed: Specify the random number seed used by generator that randomizes
         records.
       name: Optional name for the operation.
+      batches: None by default, creating a single batch op. Otherwise specifies
+        how many batches to create, which are returned as a list when
+        `get_yield_op()` is called. An example use case is to split processing
+        between devices on one computer.
 
     Raises:
       ValueError: If one of the arguments is invalid.
     """
-
     self._batch_size = batch_size
+    if batches is not None:
+      self._batch_size *= batches
+    self._batches = batches
     self._file_pattern = file_pattern
     self._buffer_size = buffer_size
     self._parallelism = parallelism
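
Given the bookkeeping above, setting `batches` simply asks the underlying kernel for `batch_size * batches` records per step; `get_yield_op()` then deals them back out into separate batches. A minimal usage sketch (the file pattern and sizes are hypothetical, not part of this change):

from tensorflow.python.ops import data_flow_ops

# With batch_size=32 and batches=2, the op pulls 64 records per step and
# get_yield_op() returns them as two batches of 32.
record_input = data_flow_ops.RecordInput(
    file_pattern="data/train-*.tfrecord",  # hypothetical dataset path
    batch_size=32,
    batches=2)

batch_list = record_input.get_yield_op()  # a list of 2 lists of 32 tensors
assert len(batch_list) == 2
assert all(len(batch) == 32 for batch in batch_list)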
@@ -2183,8 +2190,11 @@ class RecordInput(object):
     self._name = name
 
   def get_yield_op(self):
-    """Add a node that yields a minibatch every time it is executed."""
-    return gen_data_flow_ops.record_input(
+    """Adds a node that yields a group of records every time it is executed.
+    If the RecordInput `batches` parameter is not None, it yields a list of
+    record batches, each with the specified `batch_size`.
+    """
+    records = gen_data_flow_ops.record_input(
         file_pattern=self._file_pattern,
         file_buffer_size=self._buffer_size,
         file_parallelism=self._parallelism,
@@ -2192,3 +2202,14 @@ class RecordInput(object):
         batch_size=self._batch_size,
         file_random_seed=self._seed,
         name=self._name)
+    if self._batches is None:
+      return records
+    else:
+      with ops.name_scope(self._name):
+        batch_list = [[] for _ in six.moves.range(self._batches)]
+        records = array_ops.split(records, self._batch_size, 0)
+        records = [array_ops.reshape(record, []) for record in records]
+        for index, protobuf in enumerate(records):
+          batch_index = index % self._batches
+          batch_list[batch_index].append(protobuf)
+        return batch_list
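
The loop above deals records round-robin, so batch b ends up with records b, b + batches, b + 2 * batches, and so on. A sketch of the use case named in the docstring, splitting one reader's output across devices on the same machine (the device strings and the stacking step are illustrative assumptions, not part of this change):

import tensorflow as tf
from tensorflow.python.ops import data_flow_ops

record_input = data_flow_ops.RecordInput(
    file_pattern="data/train-*.tfrecord",  # hypothetical path
    batch_size=32,
    batches=2)
batch_list = record_input.get_yield_op()

for device_index, records in enumerate(batch_list):
  with tf.device("/gpu:%d" % device_index):
    # Each batch is a Python list of scalar tf.string tensors holding
    # serialized records; stack it back into a [batch_size] vector.
    serialized = tf.stack(records)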