diff options
author | Yong Tang <yong.tang.github@outlook.com> | 2018-06-25 15:39:43 +0000 |
---|---|---|
committer | Yong Tang <yong.tang.github@outlook.com> | 2018-06-29 21:53:55 +0000 |
commit | 7e47b7733549b99dfa14aa0592eadbf384f0f036 (patch) | |
tree | 88439a9cbb080e9811b7cf6744fde6c1d2b39d42 /tensorflow/contrib/hadoop | |
parent | 21bf4812c5079764c817cb2e193f136d4e590d17 (diff) |
Address review comments
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/hadoop')
3 files changed, 22 insertions, 27 deletions
diff --git a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc index 36057bbd9b..c2ddbb858d 100644 --- a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc +++ b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc @@ -30,7 +30,7 @@ class SequenceFileReader { new io::BufferedInputStream(file, kSequenceFileBufferSize)) {} Status ReadHeader() { - std::string version; + string version; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(4, &version)); if (version.substr(0, 3) != "SEQ" || version[3] != 6) { return errors::InvalidArgument( @@ -49,7 +49,7 @@ class SequenceFileReader { "' is currently not supported"); } - std::string buffer; + string buffer; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(2, &buffer)); compression_ = buffer[0]; block_compression_ = buffer[1]; @@ -83,12 +83,12 @@ class SequenceFileReader { return Status::OK(); } - Status ReadRecord(std::string* key, std::string* value) { + Status ReadRecord(string* key, string* value) { uint32 length = 0; TF_RETURN_IF_ERROR(ReadUInt32(&length)); if (length == static_cast<uint32>(-1)) { // Sync marker. - std::string sync_marker; + string sync_marker; TF_RETURN_IF_ERROR( input_stream_->ReadNBytes(kSyncMarkerSize, &sync_marker)); if (sync_marker != sync_marker_) { @@ -114,7 +114,7 @@ class SequenceFileReader { return Status::OK(); } - Status ReadString(std::string* value) { + Status ReadString(string* value) { int64 length = 0; TF_RETURN_IF_ERROR(ReadVInt(&length)); if (value == nullptr) { @@ -124,7 +124,7 @@ class SequenceFileReader { } Status ReadUInt32(uint32* value) { - std::string buffer; + string buffer; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(4, &buffer)); *value = (uint32(buffer[0]) << 24) | (uint32(buffer[1]) << 16) | (uint32(buffer[2]) << 8) | uint32(buffer[3]); @@ -132,7 +132,7 @@ class SequenceFileReader { } Status ReadVInt(int64* value) { - std::string buffer; + string buffer; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(1, &buffer)); if (buffer[0] >= -112) { *value = static_cast<int64>(buffer[0]); @@ -165,15 +165,14 @@ class SequenceFileReader { private: std::unique_ptr<io::InputStreamInterface> input_stream_; - std::string key_class_name_; - std::string value_class_name_; - std::string sync_marker_; + string key_class_name_; + string value_class_name_; + string sync_marker_; bool compression_; bool block_compression_; - std::string compression_codec_class_name_; + string compression_codec_class_name_; TF_DISALLOW_COPY_AND_ASSIGN(SequenceFileReader); }; -} class SequenceFileDatasetOp : public DatasetOpKernel { public: using DatasetOpKernel::DatasetOpKernel; @@ -252,17 +251,17 @@ class SequenceFileDatasetOp : public DatasetOpKernel { do { // We are currently processing a file, so try to read the next record. if (reader_) { - std::string key, value; + string key, value; Status status = reader_->ReadRecord(&key, &value); if (!errors::IsOutOfRange(status)) { TF_RETURN_IF_ERROR(status); Tensor key_tensor(ctx->allocator({}), DT_STRING, {}); - key_tensor.scalar<std::string>()() = key; + key_tensor.scalar<string>()() = key; out_tensors->emplace_back(std::move(key_tensor)); Tensor value_tensor(ctx->allocator({}), DT_STRING, {}); - value_tensor.scalar<std::string>()() = value; + value_tensor.scalar<string>()() = value; out_tensors->emplace_back(std::move(value_tensor)); *end_of_sequence = false; @@ -312,7 +311,7 @@ class SequenceFileDatasetOp : public DatasetOpKernel { return reader_->ReadHeader(); } - // Resets all Parquet streams. + // Resets all Hadoop SequenceFile streams. void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { reader_.reset(); file_.reset(); @@ -330,6 +329,8 @@ class SequenceFileDatasetOp : public DatasetOpKernel { DataTypeVector output_types_; }; +} + REGISTER_KERNEL_BUILDER(Name("SequenceFileDataset").Device(DEVICE_CPU), SequenceFileDatasetOp); diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py index 108aaecdb7..8f6f4a20e3 100644 --- a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py +++ b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py @@ -42,8 +42,7 @@ class SequenceFileDatasetTest(test.TestCase): filename = os.path.join( resource_loader.get_data_files_path(), 'testdata/string.seq') - filenames = array_ops.placeholder_with_default( - constant_op.constant([filename], dtypes.string), shape=[None]) + filenames = constant_op.constant([filename], dtypes.string) num_repeats = 2 dataset = hadoop_dataset_ops.SequenceFileDataset( diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py index 661e298756..7e9e8094a8 100644 --- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py +++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py @@ -29,7 +29,7 @@ from tensorflow.python.framework import tensor_shape class SequenceFileDataset(Dataset): """A Sequence File Dataset that reads the sequence file.""" - def __init__(self, filenames, output_types=(dtypes.string, dtypes.string)): + def __init__(self, filenames): """Create a `SequenceFileDataset`. `SequenceFileDataset` allows a user to read data from a hadoop sequence @@ -40,8 +40,7 @@ class SequenceFileDataset(Dataset): For example: ```python - dataset = tf.contrib.hadoop.SequenceFileDataset( - "/foo/bar.seq", (tf.string, tf.string)) + dataset = tf.contrib.hadoop.SequenceFileDataset("/foo/bar.seq") iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() # Prints the (key, value) pairs inside a hadoop sequence file. @@ -54,14 +53,10 @@ class SequenceFileDataset(Dataset): Args: filenames: A `tf.string` tensor containing one or more filenames. - output_types: A tuple of `tf.DType` objects representing the types of the - key-value pairs returned. Only `(tf.string, tf.string)` is supported - at the moment. """ super(SequenceFileDataset, self).__init__() self._filenames = ops.convert_to_tensor( filenames, dtype=dtypes.string, name="filenames") - self._output_types = output_types def _as_variant_tensor(self): return gen_dataset_ops.sequence_file_dataset( @@ -69,7 +64,7 @@ class SequenceFileDataset(Dataset): @property def output_classes(self): - return nest.map_structure(lambda _: ops.Tensor, self._output_types) + return ops.Tensor, ops.Tensor @property def output_shapes(self): @@ -77,4 +72,4 @@ class SequenceFileDataset(Dataset): @property def output_types(self): - return self._output_types + return dtypes.string, dtypes.string |