diff options
author | 2018-08-29 12:16:39 -0700 | |
---|---|---|
committer | 2018-08-29 12:20:54 -0700 | |
commit | 8a0b57966dd19be02404512f3db445f88203d92f (patch) | |
tree | 9a7c56a1ca8d23ed3fc8b6990bbd2bbd08fac335 /tensorflow/contrib/data | |
parent | c69f998bf46bd3c67acb6a797c53f13dae5f85e5 (diff) |
[tf.data] Adds an optional label_key argument to `make_batched_features_dataset()` for extracting the label from the feature dictionaries. If label_key is provided, the returned dataset will be a tuple of feature dictionaries and label.
PiperOrigin-RevId: 210766469
Diffstat (limited to 'tensorflow/contrib/data')
4 files changed, 119 insertions, 32 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 64fe6dae24..fd00cdc5c6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -47,22 +47,50 @@ class ReadBatchFeaturesTest( # Basic test: read from file 0. self.outputs = self.make_batch_feature( filenames=self.test_filenames[0], + label_key="label", num_epochs=num_epochs, batch_size=batch_size).make_one_shot_iterator().get_next() - self.verify_records(sess, batch_size, 0, num_epochs=num_epochs) + self.verify_records( + sess, + batch_size, + 0, + num_epochs=num_epochs, + label_key_provided=True) with self.assertRaises(errors.OutOfRangeError): - self._next_actual_batch(sess) + self._next_actual_batch(sess, label_key_provided=True) with ops.Graph().as_default() as g: with self.session(graph=g) as sess: # Basic test: read from file 1. self.outputs = self.make_batch_feature( filenames=self.test_filenames[1], + label_key="label", num_epochs=num_epochs, batch_size=batch_size).make_one_shot_iterator().get_next() - self.verify_records(sess, batch_size, 1, num_epochs=num_epochs) + self.verify_records( + sess, + batch_size, + 1, + num_epochs=num_epochs, + label_key_provided=True) with self.assertRaises(errors.OutOfRangeError): - self._next_actual_batch(sess) + self._next_actual_batch(sess, label_key_provided=True) + + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + # Basic test: read from both files. 
+ self.outputs = self.make_batch_feature( + filenames=self.test_filenames, + label_key="label", + num_epochs=num_epochs, + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records( + sess, + batch_size, + num_epochs=num_epochs, + label_key_provided=True) + with self.assertRaises(errors.OutOfRangeError): + self._next_actual_batch(sess, label_key_provided=True) with ops.Graph().as_default() as g: with self.session(graph=g) as sess: @@ -90,7 +118,7 @@ class ReadBatchFeaturesTest( with self.test_session() as sess: sess.run(init_op) - for file_batch, _, _, _, record_batch in self._next_expected_batch( + for file_batch, _, _, _, record_batch, _ in self._next_expected_batch( range(self._num_files), 2, 10): actual_batch = sess.run(next_element) self.assertAllEqual(file_batch, actual_batch["file"]) @@ -155,6 +183,25 @@ class ReadBatchFeaturesTest( with self.session(graph=g) as sess: self.outputs = self.make_batch_feature( filenames=self.test_filenames, + label_key="label", + num_epochs=num_epochs, + batch_size=batch_size, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads).make_one_shot_iterator( + ).get_next() + self.verify_records( + sess, + batch_size, + num_epochs=num_epochs, + label_key_provided=True, + interleave_cycle_length=reader_num_threads) + with self.assertRaises(errors.OutOfRangeError): + self._next_actual_batch(sess, label_key_provided=True) + + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + self.outputs = self.make_batch_feature( + filenames=self.test_filenames, num_epochs=num_epochs, batch_size=batch_size, reader_num_threads=reader_num_threads, @@ -175,16 +222,20 @@ class ReadBatchFeaturesTest( # Basic test: read from file 0. 
outputs = self.make_batch_feature( filenames=self.test_filenames[0], + label_key="label", num_epochs=num_epochs, batch_size=batch_size, drop_final_batch=True).make_one_shot_iterator().get_next() - for _, tensor in outputs.items(): + for tensor in nest.flatten(outputs): if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. self.assertEqual(tensor.shape[0], batch_size) def testIndefiniteRepeatShapeInference(self): dataset = self.make_batch_feature( - filenames=self.test_filenames[0], num_epochs=None, batch_size=32) + filenames=self.test_filenames[0], + label_key="label", + num_epochs=None, + batch_size=32) for shape, clazz in zip(nest.flatten(dataset.output_shapes), nest.flatten(dataset.output_classes)): if issubclass(clazz, ops.Tensor): diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py index e63bc4c720..08b9f03816 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py @@ -76,6 +76,7 @@ class ReadBatchFeaturesTestBase(test.TestCase): filenames, num_epochs, batch_size, + label_key=None, reader_num_threads=1, parser_num_threads=1, shuffle=False, @@ -91,8 +92,10 @@ class ReadBatchFeaturesTestBase(test.TestCase): features={ "file": parsing_ops.FixedLenFeature([], dtypes.int64), "record": parsing_ops.FixedLenFeature([], dtypes.int64), - "keywords": parsing_ops.VarLenFeature(dtypes.string) + "keywords": parsing_ops.VarLenFeature(dtypes.string), + "label": parsing_ops.FixedLenFeature([], dtypes.string), }, + label_key=label_key, reader=core_readers.TFRecordDataset, num_epochs=self.num_epochs, shuffle=shuffle, @@ -101,7 +104,7 @@ class ReadBatchFeaturesTestBase(test.TestCase): parser_num_threads=parser_num_threads, drop_final_batch=drop_final_batch) - def _record(self, f, r): + def _record(self, f, r, l): example = example_pb2.Example( 
features=feature_pb2.Features( feature={ @@ -114,7 +117,11 @@ class ReadBatchFeaturesTestBase(test.TestCase): "keywords": feature_pb2.Feature( bytes_list=feature_pb2.BytesList( - value=self._get_keywords(f, r))) + value=self._get_keywords(f, r))), + "label": + feature_pb2.Feature( + bytes_list=feature_pb2.BytesList( + value=[compat.as_bytes(l)])) })) return example.SerializeToString() @@ -139,23 +146,30 @@ class ReadBatchFeaturesTestBase(test.TestCase): filenames.append(fn) writer = python_io.TFRecordWriter(fn) for j in range(self._num_records): - writer.write(self._record(i, j)) + writer.write(self._record(i, j, "fake-label")) writer.close() return filenames - def _run_actual_batch(self, outputs, sess): - file_op = outputs["file"] - keywords_indices_op = outputs["keywords"].indices - keywords_values_op = outputs["keywords"].values - keywords_dense_shape_op = outputs["keywords"].dense_shape - record_op = outputs["record"] + def _run_actual_batch(self, outputs, sess, label_key_provided=False): + if label_key_provided: + # outputs would be a tuple of (feature dict, label) + label_op = outputs[1] + features_op = outputs[0] + else: + features_op = outputs + label_op = features_op["label"] + file_op = features_op["file"] + keywords_indices_op = features_op["keywords"].indices + keywords_values_op = features_op["keywords"].values + keywords_dense_shape_op = features_op["keywords"].dense_shape + record_op = features_op["record"] return sess.run([ file_op, keywords_indices_op, keywords_values_op, - keywords_dense_shape_op, record_op + keywords_dense_shape_op, record_op, label_op ]) - def _next_actual_batch(self, sess): - return self._run_actual_batch(self.outputs, sess) + def _next_actual_batch(self, sess, label_key_provided=False): + return self._run_actual_batch(self.outputs, sess, label_key_provided) def _interleave(self, iterators, cycle_length): pending_iterators = iterators @@ -188,7 +202,7 @@ class ReadBatchFeaturesTestBase(test.TestCase): def 
_next_record(file_indices): for j in file_indices: for i in range(self._num_records): - yield j, i + yield j, i, compat.as_bytes("fake-label") def _next_record_interleaved(file_indices, cycle_length): return self._interleave([_next_record([i]) for i in file_indices], @@ -200,6 +214,7 @@ class ReadBatchFeaturesTestBase(test.TestCase): keywords_batch_max_len = 0 record_batch = [] batch_index = 0 + label_batch = [] for _ in range(num_epochs): if cycle_length == 1: next_records = _next_record(file_indices) @@ -208,6 +223,7 @@ class ReadBatchFeaturesTestBase(test.TestCase): for record in next_records: f = record[0] r = record[1] + label_batch.append(record[2]) file_batch.append(f) record_batch.append(r) keywords = self._get_keywords(f, r) @@ -219,7 +235,7 @@ class ReadBatchFeaturesTestBase(test.TestCase): if len(file_batch) == batch_size: yield [ file_batch, keywords_batch_indices, keywords_batch_values, - [batch_size, keywords_batch_max_len], record_batch + [batch_size, keywords_batch_max_len], record_batch, label_batch ] file_batch = [] keywords_batch_indices = [] @@ -227,10 +243,11 @@ class ReadBatchFeaturesTestBase(test.TestCase): keywords_batch_max_len = 0 record_batch = [] batch_index = 0 + label_batch = [] if file_batch: yield [ file_batch, keywords_batch_indices, keywords_batch_values, - [len(file_batch), keywords_batch_max_len], record_batch + [len(file_batch), keywords_batch_max_len], record_batch, label_batch ] def verify_records(self, @@ -238,6 +255,7 @@ class ReadBatchFeaturesTestBase(test.TestCase): batch_size, file_index=None, num_epochs=1, + label_key_provided=False, interleave_cycle_length=1): if file_index is not None: file_indices = [file_index] @@ -245,8 +263,12 @@ class ReadBatchFeaturesTestBase(test.TestCase): file_indices = range(self._num_files) for expected_batch in self._next_expected_batch( - file_indices, batch_size, num_epochs, interleave_cycle_length): - actual_batch = self._next_actual_batch(sess) + file_indices, + batch_size, + 
num_epochs, + cycle_length=interleave_cycle_length): + actual_batch = self._next_actual_batch( + sess, label_key_provided=label_key_provided) for i in range(len(expected_batch)): self.assertAllEqual(expected_batch[i], actual_batch[i]) diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 53c22628c7..7abb610706 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -209,10 +209,10 @@ class FeatureStatsDatasetTest( self._assertSummaryHasCount( sess.run(summary_t), "record_stats:feature-values", total_records) self._assertSummaryHasSum( - sess.run(summary_t), "record_stats:features", total_records * 3) + sess.run(summary_t), "record_stats:features", total_records * 4) self._assertSummaryHasSum( sess.run(summary_t), "record_stats:feature-values", - self._sum_keywords(1) * num_epochs + 2 * total_records) + self._sum_keywords(1) * num_epochs + 3 * total_records) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 29005859d7..7f09ba71dc 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -659,6 +659,7 @@ def make_batched_features_dataset(file_pattern, batch_size, features, reader=core_readers.TFRecordDataset, + label_key=None, reader_args=None, num_epochs=None, shuffle=True, @@ -671,6 +672,9 @@ def make_batched_features_dataset(file_pattern, drop_final_batch=False): """Returns a `Dataset` of feature dictionaries from `Example` protos. + If label_key argument is provided, returns a `Dataset` of tuple + comprising of feature dictionaries and label. 
+ Example: ``` @@ -721,6 +725,9 @@ def make_batched_features_dataset(file_pattern, reader: A function or class that can be called with a `filenames` tensor and (optional) `reader_args` and returns a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`. + label_key: (Optional) A string corresponding to the key labels are stored in + `tf.Examples`. If provided, it must be one of the `features` key, + otherwise results in `ValueError`. reader_args: Additional arguments to pass to the reader class. num_epochs: Integer specifying the number of times to read through the dataset. If None, cycles through the dataset forever. Defaults to `None`. @@ -746,8 +753,11 @@ def make_batched_features_dataset(file_pattern, `False`. Returns: - A dataset of `dict` elements. Each `dict` maps feature keys to - `Tensor` or `SparseTensor` objects. + A dataset of `dict` elements, (or a tuple of `dict` elements and label). + Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects. + + Raises: + ValueError: If `label_key` is not one of the `features` keys. """ # Create dataset of all matching filenames filenames = _get_file_names(file_pattern, False) @@ -786,9 +796,13 @@ def make_batched_features_dataset(file_pattern, parsing_ops.parse_example_dataset( features, num_parallel_calls=parser_num_threads)) - # TODO(rachelim): Add an optional label_name argument for extracting the label - # from the features dictionary, to comply with the type expected by the - # input_fn to a `tf.Estimator.train` or `tf.Estimator.evaluate` function. + if label_key: + if label_key not in features: + raise ValueError( + "The `label_key` provided (%r) must be one of the `features` keys." % + label_key) + dataset = dataset.map(lambda x: (x, x.pop(label_key))) + dataset = dataset.prefetch(prefetch_buffer_size) return dataset |