author Shivani Agrawal <shivaniagrawal@google.com> 2018-08-29 12:16:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-29 12:20:54 -0700
commit 8a0b57966dd19be02404512f3db445f88203d92f (patch)
tree 9a7c56a1ca8d23ed3fc8b6990bbd2bbd08fac335 /tensorflow/contrib/data
parent c69f998bf46bd3c67acb6a797c53f13dae5f85e5 (diff)
[tf.data] Adds an optional `label_key` argument to `make_batched_features_dataset()` for extracting the label from the feature dictionaries. If `label_key` is provided, each element of the returned dataset is a (feature dictionary, label) tuple.
PiperOrigin-RevId: 210766469
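For illustration, a minimal usage sketch of the new argument (the file pattern and feature spec below are hypothetical):

```
import tensorflow as tf

# Hypothetical feature spec; "label" must be one of the `features` keys.
features = {
    "age": tf.FixedLenFeature([], tf.int64),
    "label": tf.FixedLenFeature([], tf.string),
}

# With `label_key` set, each element is a (feature dict, label) tuple,
# matching the (features, labels) structure expected by an Estimator
# input_fn. Without it, each element is just the feature dict.
dataset = tf.contrib.data.make_batched_features_dataset(
    file_pattern="/path/to/data-*.tfrecord",  # hypothetical path
    batch_size=32,
    features=features,
    label_key="label")
```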
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py | 65
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py | 58
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py | 4
-rw-r--r-- tensorflow/contrib/data/python/ops/readers.py | 24
4 files changed, 119 insertions, 32 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index 64fe6dae24..fd00cdc5c6 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -47,22 +47,50 @@ class ReadBatchFeaturesTest(
# Basic test: read from file 0.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
+ label_key="label",
num_epochs=num_epochs,
batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(sess, batch_size, 0, num_epochs=num_epochs)
+ self.verify_records(
+ sess,
+ batch_size,
+ 0,
+ num_epochs=num_epochs,
+ label_key_provided=True)
with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess)
+ self._next_actual_batch(sess, label_key_provided=True)
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
# Basic test: read from file 1.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames[1],
+ label_key="label",
num_epochs=num_epochs,
batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(sess, batch_size, 1, num_epochs=num_epochs)
+ self.verify_records(
+ sess,
+ batch_size,
+ 1,
+ num_epochs=num_epochs,
+ label_key_provided=True)
with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess)
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from both files.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ label_key_provided=True)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
@@ -90,7 +118,7 @@ class ReadBatchFeaturesTest(
with self.test_session() as sess:
sess.run(init_op)
- for file_batch, _, _, _, record_batch in self._next_expected_batch(
+ for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
range(self._num_files), 2, 10):
actual_batch = sess.run(next_element)
self.assertAllEqual(file_batch, actual_batch["file"])
@@ -155,6 +183,25 @@ class ReadBatchFeaturesTest(
with self.session(graph=g) as sess:
self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ ).get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ label_key_provided=True,
+ interleave_cycle_length=reader_num_threads)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
num_epochs=num_epochs,
batch_size=batch_size,
reader_num_threads=reader_num_threads,
@@ -175,16 +222,20 @@ class ReadBatchFeaturesTest(
# Basic test: read from file 0.
outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
+ label_key="label",
num_epochs=num_epochs,
batch_size=batch_size,
drop_final_batch=True).make_one_shot_iterator().get_next()
- for _, tensor in outputs.items():
+ for tensor in nest.flatten(outputs):
if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
self.assertEqual(tensor.shape[0], batch_size)
def testIndefiniteRepeatShapeInference(self):
dataset = self.make_batch_feature(
- filenames=self.test_filenames[0], num_epochs=None, batch_size=32)
+ filenames=self.test_filenames[0],
+ label_key="label",
+ num_epochs=None,
+ batch_size=32)
for shape, clazz in zip(nest.flatten(dataset.output_shapes),
nest.flatten(dataset.output_classes)):
if issubclass(clazz, ops.Tensor):
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
index e63bc4c720..08b9f03816 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
@@ -76,6 +76,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
filenames,
num_epochs,
batch_size,
+ label_key=None,
reader_num_threads=1,
parser_num_threads=1,
shuffle=False,
@@ -91,8 +92,10 @@ class ReadBatchFeaturesTestBase(test.TestCase):
features={
"file": parsing_ops.FixedLenFeature([], dtypes.int64),
"record": parsing_ops.FixedLenFeature([], dtypes.int64),
- "keywords": parsing_ops.VarLenFeature(dtypes.string)
+ "keywords": parsing_ops.VarLenFeature(dtypes.string),
+ "label": parsing_ops.FixedLenFeature([], dtypes.string),
},
+ label_key=label_key,
reader=core_readers.TFRecordDataset,
num_epochs=self.num_epochs,
shuffle=shuffle,
@@ -101,7 +104,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
parser_num_threads=parser_num_threads,
drop_final_batch=drop_final_batch)
- def _record(self, f, r):
+ def _record(self, f, r, l):
example = example_pb2.Example(
features=feature_pb2.Features(
feature={
@@ -114,7 +117,11 @@ class ReadBatchFeaturesTestBase(test.TestCase):
"keywords":
feature_pb2.Feature(
bytes_list=feature_pb2.BytesList(
- value=self._get_keywords(f, r)))
+ value=self._get_keywords(f, r))),
+ "label":
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=[compat.as_bytes(l)]))
}))
return example.SerializeToString()
@@ -139,23 +146,30 @@ class ReadBatchFeaturesTestBase(test.TestCase):
filenames.append(fn)
writer = python_io.TFRecordWriter(fn)
for j in range(self._num_records):
- writer.write(self._record(i, j))
+ writer.write(self._record(i, j, "fake-label"))
writer.close()
return filenames
- def _run_actual_batch(self, outputs, sess):
- file_op = outputs["file"]
- keywords_indices_op = outputs["keywords"].indices
- keywords_values_op = outputs["keywords"].values
- keywords_dense_shape_op = outputs["keywords"].dense_shape
- record_op = outputs["record"]
+ def _run_actual_batch(self, outputs, sess, label_key_provided=False):
+ if label_key_provided:
+ # `outputs` is a (feature dict, label) tuple.
+ label_op = outputs[1]
+ features_op = outputs[0]
+ else:
+ features_op = outputs
+ label_op = features_op["label"]
+ file_op = features_op["file"]
+ keywords_indices_op = features_op["keywords"].indices
+ keywords_values_op = features_op["keywords"].values
+ keywords_dense_shape_op = features_op["keywords"].dense_shape
+ record_op = features_op["record"]
return sess.run([
file_op, keywords_indices_op, keywords_values_op,
- keywords_dense_shape_op, record_op
+ keywords_dense_shape_op, record_op, label_op
])
- def _next_actual_batch(self, sess):
- return self._run_actual_batch(self.outputs, sess)
+ def _next_actual_batch(self, sess, label_key_provided=False):
+ return self._run_actual_batch(self.outputs, sess, label_key_provided)
def _interleave(self, iterators, cycle_length):
pending_iterators = iterators
@@ -188,7 +202,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
def _next_record(file_indices):
for j in file_indices:
for i in range(self._num_records):
- yield j, i
+ yield j, i, compat.as_bytes("fake-label")
def _next_record_interleaved(file_indices, cycle_length):
return self._interleave([_next_record([i]) for i in file_indices],
@@ -200,6 +214,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
keywords_batch_max_len = 0
record_batch = []
batch_index = 0
+ label_batch = []
for _ in range(num_epochs):
if cycle_length == 1:
next_records = _next_record(file_indices)
@@ -208,6 +223,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
for record in next_records:
f = record[0]
r = record[1]
+ label_batch.append(record[2])
file_batch.append(f)
record_batch.append(r)
keywords = self._get_keywords(f, r)
@@ -219,7 +235,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
if len(file_batch) == batch_size:
yield [
file_batch, keywords_batch_indices, keywords_batch_values,
- [batch_size, keywords_batch_max_len], record_batch
+ [batch_size, keywords_batch_max_len], record_batch, label_batch
]
file_batch = []
keywords_batch_indices = []
@@ -227,10 +243,11 @@ class ReadBatchFeaturesTestBase(test.TestCase):
keywords_batch_max_len = 0
record_batch = []
batch_index = 0
+ label_batch = []
if file_batch:
yield [
file_batch, keywords_batch_indices, keywords_batch_values,
- [len(file_batch), keywords_batch_max_len], record_batch
+ [len(file_batch), keywords_batch_max_len], record_batch, label_batch
]
def verify_records(self,
@@ -238,6 +255,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
batch_size,
file_index=None,
num_epochs=1,
+ label_key_provided=False,
interleave_cycle_length=1):
if file_index is not None:
file_indices = [file_index]
@@ -245,8 +263,12 @@ class ReadBatchFeaturesTestBase(test.TestCase):
file_indices = range(self._num_files)
for expected_batch in self._next_expected_batch(
- file_indices, batch_size, num_epochs, interleave_cycle_length):
- actual_batch = self._next_actual_batch(sess)
+ file_indices,
+ batch_size,
+ num_epochs,
+ cycle_length=interleave_cycle_length):
+ actual_batch = self._next_actual_batch(
+ sess, label_key_provided=label_key_provided)
for i in range(len(expected_batch)):
self.assertAllEqual(expected_batch[i], actual_batch[i])
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index 53c22628c7..7abb610706 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -209,10 +209,10 @@ class FeatureStatsDatasetTest(
self._assertSummaryHasCount(
sess.run(summary_t), "record_stats:feature-values", total_records)
self._assertSummaryHasSum(
- sess.run(summary_t), "record_stats:features", total_records * 3)
+ sess.run(summary_t), "record_stats:features", total_records * 4)
self._assertSummaryHasSum(
sess.run(summary_t), "record_stats:feature-values",
- self._sum_keywords(1) * num_epochs + 2 * total_records)
+ self._sum_keywords(1) * num_epochs + 3 * total_records)
if __name__ == "__main__":
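For context on the two expected-sum changes above: each test `Example` now carries four features rather than three, and three of them ("file", "record", "label") are fixed-length scalars contributing one value per record. A small sketch of that arithmetic, with `total_records` and `sum_keywords` standing in for the fixture's values:

```
# Sketch of the expected-count arithmetic behind the assertions above;
# `total_records` and `sum_keywords` are stand-ins for the fixture values.
def expected_feature_sum(total_records):
  # Four features per record: "file", "record", "keywords", "label".
  return total_records * 4

def expected_feature_value_sum(sum_keywords, total_records):
  # "keywords" contributes a variable number of values; the three
  # fixed-length features ("file", "record", "label") add one value each.
  return sum_keywords + 3 * total_records
```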
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 29005859d7..7f09ba71dc 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -659,6 +659,7 @@ def make_batched_features_dataset(file_pattern,
batch_size,
features,
reader=core_readers.TFRecordDataset,
+ label_key=None,
reader_args=None,
num_epochs=None,
shuffle=True,
@@ -671,6 +672,9 @@ def make_batched_features_dataset(file_pattern,
drop_final_batch=False):
"""Returns a `Dataset` of feature dictionaries from `Example` protos.
+ If the `label_key` argument is provided, returns a `Dataset` of tuples,
+ each comprising a feature dictionary and a label.
+
Example:
```
@@ -721,6 +725,9 @@ def make_batched_features_dataset(file_pattern,
reader: A function or class that can be
called with a `filenames` tensor and (optional) `reader_args` and returns
a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`.
+ label_key: (Optional) A string corresponding to the key under which the
+ labels are stored in the `tf.Example` protos. If provided, it must be
+ one of the `features` keys, otherwise a `ValueError` is raised.
reader_args: Additional arguments to pass to the reader class.
num_epochs: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever. Defaults to `None`.
@@ -746,8 +753,11 @@ def make_batched_features_dataset(file_pattern,
`False`.
Returns:
- A dataset of `dict` elements. Each `dict` maps feature keys to
- `Tensor` or `SparseTensor` objects.
+ A dataset of `dict` elements (or a tuple of `dict` elements and label).
+ Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects.
+
+ Raises:
+ ValueError: If `label_key` is not one of the `features` keys.
"""
# Create dataset of all matching filenames
filenames = _get_file_names(file_pattern, False)
@@ -786,9 +796,13 @@ def make_batched_features_dataset(file_pattern,
parsing_ops.parse_example_dataset(
features, num_parallel_calls=parser_num_threads))
- # TODO(rachelim): Add an optional label_name argument for extracting the label
- # from the features dictionary, to comply with the type expected by the
- # input_fn to a `tf.Estimator.train` or `tf.Estimator.evaluate` function.
+ if label_key:
+ if label_key not in features:
+ raise ValueError(
+ "The `label_key` provided (%r) must be one of the `features` keys." %
+ label_key)
+ dataset = dataset.map(lambda x: (x, x.pop(label_key)))
+
dataset = dataset.prefetch(prefetch_buffer_size)
return dataset
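A note on the `label_key` implementation above: `dict.pop` both removes the label from the feature dictionary and returns it, so the label does not also remain among the features. A plain-Python sketch of the same split (names here are illustrative):

```
def split_label(features, label_key):
  # Mirrors `dataset.map(lambda x: (x, x.pop(label_key)))`: remove the
  # label from the feature dict and return it as the second element.
  features = dict(features)  # copy so the caller's dict is not mutated
  label = features.pop(label_key)
  return features, label

example = {"file": 0, "record": 1, "label": b"fake-label"}
features, label = split_label(example, "label")
assert "label" not in features and label == b"fake-label"
```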