aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py')
-rw-r--r--tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py13
1 files changed, 4 insertions, 9 deletions
diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
index 661e298756..7e9e8094a8 100644
--- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
+++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
@@ -29,7 +29,7 @@ from tensorflow.python.framework import tensor_shape
class SequenceFileDataset(Dataset):
"""A Sequence File Dataset that reads the sequence file."""
- def __init__(self, filenames, output_types=(dtypes.string, dtypes.string)):
+ def __init__(self, filenames):
"""Create a `SequenceFileDataset`.
`SequenceFileDataset` allows a user to read data from a hadoop sequence
@@ -40,8 +40,7 @@ class SequenceFileDataset(Dataset):
For example:
```python
- dataset = tf.contrib.hadoop.SequenceFileDataset(
- "/foo/bar.seq", (tf.string, tf.string))
+ dataset = tf.contrib.hadoop.SequenceFileDataset("/foo/bar.seq")
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
# Prints the (key, value) pairs inside a hadoop sequence file.
@@ -54,14 +53,10 @@ class SequenceFileDataset(Dataset):
Args:
filenames: A `tf.string` tensor containing one or more filenames.
- output_types: A tuple of `tf.DType` objects representing the types of the
- key-value pairs returned. Only `(tf.string, tf.string)` is supported
- at the moment.
"""
super(SequenceFileDataset, self).__init__()
self._filenames = ops.convert_to_tensor(
filenames, dtype=dtypes.string, name="filenames")
- self._output_types = output_types
def _as_variant_tensor(self):
return gen_dataset_ops.sequence_file_dataset(
@@ -69,7 +64,7 @@ class SequenceFileDataset(Dataset):
@property
def output_classes(self):
- return nest.map_structure(lambda _: ops.Tensor, self._output_types)
+ return ops.Tensor, ops.Tensor
@property
def output_shapes(self):
@@ -77,4 +72,4 @@ class SequenceFileDataset(Dataset):
@property
def output_types(self):
- return self._output_types
+ return dtypes.string, dtypes.string