aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/data_flow_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/data_flow_ops.py')
-rw-r--r--tensorflow/python/ops/data_flow_ops.py31
1 files changed, 5 insertions, 26 deletions
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 829aa99284..4eead79531 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -1582,7 +1582,7 @@ class StagingArea(BaseStagingArea):
This is mostly useful for limiting the number of tensors on
devices such as GPUs.
- All get() and peek() commands block if the requested data
+ All get() and peek() commands block if the the requested data
is not present in the Staging Area.
"""
@@ -2155,8 +2155,7 @@ class RecordInput(object):
parallelism=1,
shift_ratio=0,
seed=0,
- name=None,
- batches=None):
+ name=None):
"""Constructs a RecordInput Op.
Args:
@@ -2170,18 +2169,12 @@ class RecordInput(object):
seed: Specify the random number seed used by generator that randomizes
records.
name: Optional name for the operation.
- batches: None by default, creating a single batch op. Otherwise specifies
- how many batches to create, which are returned as a list when
- `get_yield_op()` is called. An example use case is to split processing
- between devices on one computer.
Raises:
ValueError: If one of the arguments is invalid.
"""
+
self._batch_size = batch_size
- if batches is not None:
- self._batch_size *= batches
- self._batches = batches
self._file_pattern = file_pattern
self._buffer_size = buffer_size
self._parallelism = parallelism
@@ -2190,11 +2183,8 @@ class RecordInput(object):
self._name = name
def get_yield_op(self):
- """Adds a node that yields a group of records every time it is executed.
- If RecordInput `batches` parameter is not None, it yields a list of
- record batches with the specified `batch_size`.
- """
- records = gen_data_flow_ops.record_input(
+ """Add a node that yields a minibatch every time it is executed."""
+ return gen_data_flow_ops.record_input(
file_pattern=self._file_pattern,
file_buffer_size=self._buffer_size,
file_parallelism=self._parallelism,
@@ -2202,14 +2192,3 @@ class RecordInput(object):
batch_size=self._batch_size,
file_random_seed=self._seed,
name=self._name)
- if self._batches is None:
- return records
- else:
- with ops.name_scope(self._name):
- batch_list = [[] for i in six.moves.range(self._batches)]
- records = array_ops.split(records, self._batch_size, 0)
- records = [array_ops.reshape(record, []) for record in records]
- for index, protobuf in zip(six.moves.range(len(records)), records):
- batch_index = index % self._batches
- batch_list[batch_index].append(protobuf)
- return batch_list