author    Yong Tang <yong.tang.github@outlook.com>  2018-06-25 15:39:43 +0000
committer Yong Tang <yong.tang.github@outlook.com>  2018-06-29 21:53:55 +0000
commit    7e47b7733549b99dfa14aa0592eadbf384f0f036 (patch)
tree      88439a9cbb080e9811b7cf6744fde6c1d2b39d42 /tensorflow/contrib/hadoop
parent    21bf4812c5079764c817cb2e193f136d4e590d17 (diff)
Address review comments
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/hadoop')
-rw-r--r--  tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc       | 33
-rw-r--r--  tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py  |  3
-rw-r--r--  tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py    | 13
3 files changed, 22 insertions(+), 27 deletions(-)
diff --git a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
index 36057bbd9b..c2ddbb858d 100644
--- a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
+++ b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
@@ -30,7 +30,7 @@ class SequenceFileReader {
new io::BufferedInputStream(file, kSequenceFileBufferSize)) {}
Status ReadHeader() {
- std::string version;
+ string version;
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(4, &version));
if (version.substr(0, 3) != "SEQ" || version[3] != 6) {
return errors::InvalidArgument(
@@ -49,7 +49,7 @@ class SequenceFileReader {
"' is currently not supported");
}
- std::string buffer;
+ string buffer;
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(2, &buffer));
compression_ = buffer[0];
block_compression_ = buffer[1];
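
For orientation, the header these two hunks parse is laid out as: the 4-byte magic ("SEQ" plus a one-byte version, 6 for current files), the key and value class names, one boolean byte each for record and block compression, then a metadata section and the 16-byte sync marker. Below is a minimal Python sketch of the same walk, assuming the `read_string`/`read_vint` helpers sketched after the ReadVInt hunk further down; this is an illustration, not code from the patch:

```python
def read_header(stream):
    # Magic: the bytes "SEQ" followed by a one-byte version number.
    magic = stream.read(4)
    if magic[:3] != b"SEQ" or magic[3] != 6:
        raise ValueError("not a version-6 Hadoop SequenceFile")
    key_class = read_string(stream)    # e.g. "org.apache.hadoop.io.Text"
    value_class = read_string(stream)
    compressed = stream.read(1)[0] != 0
    block_compressed = stream.read(1)[0] != 0
    if compressed or block_compressed:
        # Mirrors the kernel: only uncompressed files are supported.
        raise NotImplementedError("compressed SequenceFiles not supported")
    # A metadata section and a 16-byte sync marker complete the header;
    # ReadRecord later matches that marker to detect sync points.
    ...
```
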
@@ -83,12 +83,12 @@ class SequenceFileReader {
return Status::OK();
}
- Status ReadRecord(std::string* key, std::string* value) {
+ Status ReadRecord(string* key, string* value) {
uint32 length = 0;
TF_RETURN_IF_ERROR(ReadUInt32(&length));
if (length == static_cast<uint32>(-1)) {
// Sync marker.
- std::string sync_marker;
+ string sync_marker;
TF_RETURN_IF_ERROR(
input_stream_->ReadNBytes(kSyncMarkerSize, &sync_marker));
if (sync_marker != sync_marker_) {
@@ -114,7 +114,7 @@ class SequenceFileReader {
return Status::OK();
}
- Status ReadString(std::string* value) {
+ Status ReadString(string* value) {
int64 length = 0;
TF_RETURN_IF_ERROR(ReadVInt(&length));
if (value == nullptr) {
@@ -124,7 +124,7 @@ class SequenceFileReader {
}
Status ReadUInt32(uint32* value) {
- std::string buffer;
+ string buffer;
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(4, &buffer));
*value = (uint32(buffer[0]) << 24) | (uint32(buffer[1]) << 16) |
(uint32(buffer[2]) << 8) | uint32(buffer[3]);
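
The shift/OR chain above is a plain big-endian 32-bit decode. For reference, a Python equivalent using the standard struct module (illustration only, not part of the patch):

```python
import struct

def read_uint32(stream):
    # ">I" = big-endian unsigned 32-bit, the same byte order as ReadUInt32.
    (value,) = struct.unpack(">I", stream.read(4))
    return value
```
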
@@ -132,7 +132,7 @@ class SequenceFileReader {
}
Status ReadVInt(int64* value) {
- std::string buffer;
+ string buffer;
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(1, &buffer));
if (buffer[0] >= -112) {
*value = static_cast<int64>(buffer[0]);
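
ReadVInt decodes Hadoop's WritableUtils variable-length integers: a first byte of -112 or above is the value itself; anything lower encodes the sign and the count of big-endian payload bytes that follow. A minimal Python sketch of the decoder, together with the string helper assumed by the header sketch above (`stream` is any file-like object; illustration only):

```python
def read_vint(stream):
    first = int.from_bytes(stream.read(1), "big", signed=True)
    if first >= -112:
        return first                            # value stored in the first byte
    if first >= -120:
        length, negative = -112 - first, False  # 1-8 payload bytes, positive
    else:
        length, negative = -120 - first, True   # 1-8 payload bytes, negative
    value = 0
    for _ in range(length):
        value = (value << 8) | stream.read(1)[0]
    return ~value if negative else value


def read_string(stream):
    # A Hadoop string: a VInt length followed by that many UTF-8 bytes.
    return stream.read(read_vint(stream)).decode("utf-8")
```
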
@@ -165,15 +165,14 @@ class SequenceFileReader {
private:
std::unique_ptr<io::InputStreamInterface> input_stream_;
- std::string key_class_name_;
- std::string value_class_name_;
- std::string sync_marker_;
+ string key_class_name_;
+ string value_class_name_;
+ string sync_marker_;
bool compression_;
bool block_compression_;
- std::string compression_codec_class_name_;
+ string compression_codec_class_name_;
TF_DISALLOW_COPY_AND_ASSIGN(SequenceFileReader);
};
-}
class SequenceFileDatasetOp : public DatasetOpKernel {
public:
using DatasetOpKernel::DatasetOpKernel;
@@ -252,17 +251,17 @@ class SequenceFileDatasetOp : public DatasetOpKernel {
do {
// We are currently processing a file, so try to read the next record.
if (reader_) {
- std::string key, value;
+ string key, value;
Status status = reader_->ReadRecord(&key, &value);
if (!errors::IsOutOfRange(status)) {
TF_RETURN_IF_ERROR(status);
Tensor key_tensor(ctx->allocator({}), DT_STRING, {});
- key_tensor.scalar<std::string>()() = key;
+ key_tensor.scalar<string>()() = key;
out_tensors->emplace_back(std::move(key_tensor));
Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
- value_tensor.scalar<std::string>()() = value;
+ value_tensor.scalar<string>()() = value;
out_tensors->emplace_back(std::move(value_tensor));
*end_of_sequence = false;
@@ -312,7 +311,7 @@ class SequenceFileDatasetOp : public DatasetOpKernel {
return reader_->ReadHeader();
}
- // Resets all Parquet streams.
+ // Resets all Hadoop SequenceFile streams.
void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
reader_.reset();
file_.reset();
@@ -330,6 +329,8 @@ class SequenceFileDatasetOp : public DatasetOpKernel {
DataTypeVector output_types_;
};
+}
+
REGISTER_KERNEL_BUILDER(Name("SequenceFileDataset").Device(DEVICE_CPU),
SequenceFileDatasetOp);
diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py
index 108aaecdb7..8f6f4a20e3 100644
--- a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py
+++ b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py
@@ -42,8 +42,7 @@ class SequenceFileDatasetTest(test.TestCase):
filename = os.path.join(
resource_loader.get_data_files_path(), 'testdata/string.seq')
- filenames = array_ops.placeholder_with_default(
- constant_op.constant([filename], dtypes.string), shape=[None])
+ filenames = constant_op.constant([filename], dtypes.string)
num_repeats = 2
dataset = hadoop_dataset_ops.SequenceFileDataset(
diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
index 661e298756..7e9e8094a8 100644
--- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
+++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
@@ -29,7 +29,7 @@ from tensorflow.python.framework import tensor_shape
class SequenceFileDataset(Dataset):
"""A Sequence File Dataset that reads the sequence file."""
- def __init__(self, filenames, output_types=(dtypes.string, dtypes.string)):
+ def __init__(self, filenames):
"""Create a `SequenceFileDataset`.
`SequenceFileDataset` allows a user to read data from a hadoop sequence
@@ -40,8 +40,7 @@ class SequenceFileDataset(Dataset):
For example:
```python
- dataset = tf.contrib.hadoop.SequenceFileDataset(
- "/foo/bar.seq", (tf.string, tf.string))
+ dataset = tf.contrib.hadoop.SequenceFileDataset("/foo/bar.seq")
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
# Prints the (key, value) pairs inside a hadoop sequence file.
@@ -54,14 +53,10 @@ class SequenceFileDataset(Dataset):
Args:
filenames: A `tf.string` tensor containing one or more filenames.
- output_types: A tuple of `tf.DType` objects representing the types of the
- key-value pairs returned. Only `(tf.string, tf.string)` is supported
- at the moment.
"""
super(SequenceFileDataset, self).__init__()
self._filenames = ops.convert_to_tensor(
filenames, dtype=dtypes.string, name="filenames")
- self._output_types = output_types
def _as_variant_tensor(self):
return gen_dataset_ops.sequence_file_dataset(
@@ -69,7 +64,7 @@ class SequenceFileDataset(Dataset):
@property
def output_classes(self):
- return nest.map_structure(lambda _: ops.Tensor, self._output_types)
+ return ops.Tensor, ops.Tensor
@property
def output_shapes(self):
@@ -77,4 +72,4 @@ class SequenceFileDataset(Dataset):
@property
def output_types(self):
- return self._output_types
+ return dtypes.string, dtypes.string