aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/hadoop
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-09 14:44:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-09 14:44:20 -0700
commit0980e844c115bdffbbdd7d993355633c48a6100e (patch)
tree838b60294464aee13ed6dbb491dfd607d1da1a4c /tensorflow/contrib/hadoop
parent146ee1595820b0d7ac50f272ec741133b24ab4b3 (diff)
parent36db4d00e46c6ebecaf9e5a73ca1994df98086de (diff)
Merge pull request #19534 from yongtang:hadoop
PiperOrigin-RevId: 208112353
Diffstat (limited to 'tensorflow/contrib/hadoop')
-rw-r--r--tensorflow/contrib/hadoop/BUILD117
-rw-r--r--tensorflow/contrib/hadoop/__init__.py32
-rw-r--r--tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc339
-rw-r--r--tensorflow/contrib/hadoop/ops/dataset_ops.cc29
-rw-r--r--tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py66
-rwxr-xr-xtensorflow/contrib/hadoop/python/kernel_tests/testdata/string.seqbin0 -> 603 bytes
-rw-r--r--tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py75
-rw-r--r--tensorflow/contrib/hadoop/python/ops/hadoop_op_loader.py24
8 files changed, 682 insertions, 0 deletions
diff --git a/tensorflow/contrib/hadoop/BUILD b/tensorflow/contrib/hadoop/BUILD
new file mode 100644
index 0000000000..ccad31efa1
--- /dev/null
+++ b/tensorflow/contrib/hadoop/BUILD
@@ -0,0 +1,117 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_custom_op_library",
+ "tf_custom_op_py_library",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+ "tf_py_test",
+)
+
+filegroup(
+ name = "test_data",
+ srcs = glob(["python/kernel_tests/testdata/*"]),
+)
+
+py_library(
+ name = "hadoop",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_ops",
+ ],
+)
+
+tf_custom_op_library(
+ name = "_dataset_ops.so",
+ srcs = ["ops/dataset_ops.cc"],
+ deps = [
+ ":dataset_kernels",
+ ],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["dataset_ops"],
+)
+
+cc_library(
+ name = "dataset_kernels",
+ srcs = ["kernels/hadoop_dataset_ops.cc"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+ alwayslink = 1,
+)
+
+py_library(
+ name = "dataset_ops",
+ srcs = [
+ "python/ops/hadoop_dataset_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":hadoop_op_loader",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_dataset_ops",
+ out = "python/ops/gen_dataset_ops.py",
+ deps = ["//tensorflow/contrib/hadoop:dataset_ops_op_lib"],
+)
+
+tf_kernel_library(
+ name = "dataset_ops_kernels",
+ deps = [
+ ":dataset_kernels",
+ "//tensorflow/core:framework",
+ ],
+ alwayslink = 1,
+)
+
+tf_custom_op_py_library(
+ name = "hadoop_op_loader",
+ srcs = ["python/ops/hadoop_op_loader.py"],
+ dso = ["//tensorflow/contrib/hadoop:_dataset_ops.so"],
+ kernels = [
+ ":dataset_ops_kernels",
+ "//tensorflow/contrib/hadoop:dataset_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gen_dataset_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:platform",
+ ],
+)
+
+tf_py_test(
+ name = "hadoop_test",
+ srcs = ["python/kernel_tests/hadoop_test.py"],
+ additional_deps = [
+ ":hadoop",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+ data = [
+ ":test_data",
+ ],
+ tags = [
+ "notap",
+ ],
+)
diff --git a/tensorflow/contrib/hadoop/__init__.py b/tensorflow/contrib/hadoop/__init__.py
new file mode 100644
index 0000000000..abf8cd4845
--- /dev/null
+++ b/tensorflow/contrib/hadoop/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Sequence File Dataset.
+
+@@SequenceFileDataset
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.hadoop.python.ops.hadoop_dataset_ops import SequenceFileDataset
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ "SequenceFileDataset",
+]
+
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
new file mode 100644
index 0000000000..b510994152
--- /dev/null
+++ b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
@@ -0,0 +1,339 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow/core/platform/file_system.h"
+
+namespace tensorflow {
+namespace {
+
+static const size_t kSyncMarkerSize = 16;
+static const size_t kSequenceFileBufferSize = 1024 * 1024;
+
+class SequenceFileReader {
+ public:
+ explicit SequenceFileReader(RandomAccessFile* file)
+ : input_stream_(
+ new io::BufferedInputStream(file, kSequenceFileBufferSize)) {}
+
+ Status ReadHeader() {
+ string version;
+ TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(4, &version));
+ if (version.substr(0, 3) != "SEQ" || version[3] != 6) {
+ return errors::InvalidArgument(
+ "sequence file header must starts with `SEQ6`, received \"",
+ version.substr(0, 3), static_cast<int>(version[3]), "\"");
+ }
+ TF_RETURN_IF_ERROR(ReadString(&key_class_name_));
+ TF_RETURN_IF_ERROR(ReadString(&value_class_name_));
+
+ // At the moment we only support `org.apache.hadoop.io.Text` for key/value.
+ // TODO (yongtang): Add more class name support.
+ if (key_class_name_ != "org.apache.hadoop.io.Text" ||
+ value_class_name_ != "org.apache.hadoop.io.Text") {
+ return errors::Unimplemented("key/value of '", key_class_name_, "/",
+ value_class_name_,
+ "' is currently not supported");
+ }
+
+ string buffer;
+ TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(2, &buffer));
+ compression_ = buffer[0];
+ block_compression_ = buffer[1];
+ if (compression_ || block_compression_) {
+ TF_RETURN_IF_ERROR(ReadString(&compression_codec_class_name_));
+ }
+
+ // At the moment no compression is supported.
+ // TODO (yongtang): Add compression support.
+ if (compression_ || block_compression_) {
+ return errors::Unimplemented("compression is currently not supported");
+ }
+
+ // Not interested in metadata for now.
+ uint32 num_metadata_pairs = 0;
+ TF_RETURN_IF_ERROR(ReadUInt32(&num_metadata_pairs));
+ if (num_metadata_pairs > 1024) {
+ return errors::InvalidArgument(
+ "sequence file metadata should have key value pairs < 1024, "
+ "received ",
+ num_metadata_pairs);
+ }
+ for (int i = 0; i < num_metadata_pairs; i++) {
+ TF_RETURN_IF_ERROR(ReadString(nullptr));
+ TF_RETURN_IF_ERROR(ReadString(nullptr));
+ }
+
+ TF_RETURN_IF_ERROR(
+ input_stream_->ReadNBytes(kSyncMarkerSize, &sync_marker_));
+
+ return Status::OK();
+ }
+
+ Status ReadRecord(string* key, string* value) {
+ uint32 length = 0;
+ TF_RETURN_IF_ERROR(ReadUInt32(&length));
+ if (length == static_cast<uint32>(-1)) {
+ // Sync marker.
+ string sync_marker;
+ TF_RETURN_IF_ERROR(
+ input_stream_->ReadNBytes(kSyncMarkerSize, &sync_marker));
+ if (sync_marker != sync_marker_) {
+ return errors::InvalidArgument(
+ "sequence file should have sync marker \"", sync_marker_,
+ "\" at pos ", input_stream_->Tell() - kSyncMarkerSize,
+ ", received \"", sync_marker, "\"");
+ }
+ return ReadRecord(key, value);
+ }
+ uint32 key_length = 0;
+ TF_RETURN_IF_ERROR(ReadUInt32(&key_length));
+ if (key_length > length) {
+ return errors::InvalidArgument("key length (", key_length,
+ ") should be < record length (", length,
+ ")");
+ }
+ // At the moment we only support `org.apache.hadoop.io.Text` for key/value.
+ // TODO (yongtang): Expand supported format.
+ TF_RETURN_IF_ERROR(ReadString(key));
+ TF_RETURN_IF_ERROR(ReadString(value));
+ return Status::OK();
+ }
+
+ Status ReadString(string* value) {
+ int64 length = 0;
+ TF_RETURN_IF_ERROR(ReadVInt(&length));
+ if (value == nullptr) {
+ return input_stream_->SkipNBytes(length);
+ }
+ return input_stream_->ReadNBytes(length, value);
+ }
+
+ Status ReadUInt32(uint32* value) {
+ string buffer;
+ TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(4, &buffer));
+ *value = ((static_cast<uint32>(buffer[0]) << 24) |
+ static_cast<uint32>(buffer[1]) << 16) |
+ (static_cast<uint32>(buffer[2]) << 8) |
+ static_cast<uint32>(buffer[3]);
+ return Status::OK();
+ }
+
+ Status ReadVInt(int64* value) {
+ string buffer;
+ TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(1, &buffer));
+ if (buffer[0] >= -112) {
+ *value = static_cast<int64>(buffer[0]);
+ return Status::OK();
+ }
+
+ int64 remaining = 0;
+ bool negative = false;
+ if (buffer[0] >= -120) {
+ remaining = static_cast<int64>(-112) - static_cast<int64>(buffer[0]);
+ } else {
+ remaining = static_cast<int64>(-120) - static_cast<int64>(buffer[0]);
+ negative = true;
+ }
+ buffer.clear();
+ TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(remaining, &buffer));
+
+ uint64 v = 0;
+ for (int i = 0; i < buffer.size(); i++) {
+ v = (v << 8) | static_cast<uint64>(buffer[i]);
+ }
+ if (negative) {
+ v = ~v;
+ }
+ *value = static_cast<int64>(v);
+ return Status::OK();
+ }
+
+ virtual ~SequenceFileReader() = default;
+
+ private:
+ std::unique_ptr<io::InputStreamInterface> input_stream_;
+ string key_class_name_;
+ string value_class_name_;
+ string sync_marker_;
+ bool compression_;
+ bool block_compression_;
+ string compression_codec_class_name_;
+ TF_DISALLOW_COPY_AND_ASSIGN(SequenceFileReader);
+};
+class SequenceFileDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+ explicit SequenceFileDatasetOp(OpKernelConstruction* ctx)
+ : DatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ for (const DataType& dt : output_types_) {
+ OP_REQUIRES(ctx, dt == DT_STRING,
+ errors::InvalidArgument(
+ "Each element of `output_types_` must be one of: "
+ "DT_STRING"));
+ }
+ }
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ const Tensor* filenames_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
+ OP_REQUIRES(
+ ctx, filenames_tensor->dims() <= 1,
+ errors::InvalidArgument("`filenames` must be a scalar or a vector."));
+
+ std::vector<string> filenames;
+ filenames.reserve(filenames_tensor->NumElements());
+ for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
+ filenames.push_back(filenames_tensor->flat<string>()(i));
+ }
+
+ *output = new Dataset(ctx, filenames, output_types_);
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const std::vector<string>& filenames,
+ const DataTypeVector& output_types)
+ : GraphDatasetBase(ctx),
+ filenames_(filenames),
+ output_types_(output_types) {}
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::SequenceFile")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}, {}});
+ return *shapes;
+ }
+
+ string DebugString() const override {
+ return "SequenceFileDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* filenames = nullptr;
+ TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ do {
+ // We are currently processing a file, so try to read the next record.
+ if (reader_) {
+ string key, value;
+ Status status = reader_->ReadRecord(&key, &value);
+ if (!errors::IsOutOfRange(status)) {
+ TF_RETURN_IF_ERROR(status);
+
+ Tensor key_tensor(ctx->allocator({}), DT_STRING, {});
+ key_tensor.scalar<string>()() = key;
+ out_tensors->emplace_back(std::move(key_tensor));
+
+ Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
+ value_tensor.scalar<string>()() = value;
+ out_tensors->emplace_back(std::move(value_tensor));
+
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+ // We have reached the end of the current file, so maybe
+ // move on to next file.
+ ResetStreamsLocked();
+ ++current_file_index_;
+ }
+
+ // Iteration ends when there are no more files to process.
+ if (current_file_index_ == dataset()->filenames_.size()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+ } while (true);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ return errors::Unimplemented("SaveInternal is currently not supported");
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ return errors::Unimplemented(
+ "RestoreInternal is currently not supported");
+ }
+
+ private:
+ // Sets up SequenceFile streams to read from the topic at
+ // `current_file_index_`.
+ Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (current_file_index_ >= dataset()->filenames_.size()) {
+ return errors::InvalidArgument(
+ "current_file_index_:", current_file_index_,
+ " >= filenames_.size():", dataset()->filenames_.size());
+ }
+
+ // Actually move on to next file.
+ const string& filename = dataset()->filenames_[current_file_index_];
+ TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file_));
+ reader_.reset(new SequenceFileReader(file_.get()));
+ return reader_->ReadHeader();
+ }
+
+ // Resets all Hadoop SequenceFile streams.
+ void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ reader_.reset();
+ file_.reset();
+ }
+
+ mutex mu_;
+ size_t current_file_index_ GUARDED_BY(mu_) = 0;
+ std::unique_ptr<RandomAccessFile> file_ GUARDED_BY(mu_);
+ std::unique_ptr<SequenceFileReader> reader_ GUARDED_BY(mu_);
+ };
+
+ const std::vector<string> filenames_;
+ const DataTypeVector output_types_;
+ };
+ DataTypeVector output_types_;
+};
+} // namespace
+
+REGISTER_KERNEL_BUILDER(Name("SequenceFileDataset").Device(DEVICE_CPU),
+ SequenceFileDatasetOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/hadoop/ops/dataset_ops.cc b/tensorflow/contrib/hadoop/ops/dataset_ops.cc
new file mode 100644
index 0000000000..66ad549b47
--- /dev/null
+++ b/tensorflow/contrib/hadoop/ops/dataset_ops.cc
@@ -0,0 +1,29 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("SequenceFileDataset")
+ .Input("filenames: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ScalarShape);
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py
new file mode 100644
index 0000000000..d796e43d87
--- /dev/null
+++ b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py
@@ -0,0 +1,66 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for SequenceFileDataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.hadoop.python.ops import hadoop_dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import test
+
+
+class SequenceFileDatasetTest(test.TestCase):
+
+ def test_sequence_file_dataset(self):
+ """Test case for SequenceFileDataset.
+
+ The file is generated with `org.apache.hadoop.io.Text` for key/value.
+ There are 25 records in the file with the format of:
+ key = XXX
+ value = VALUEXXX
+ where XXX is replaced as the line number (starts with 001).
+ """
+ filename = os.path.join(resource_loader.get_data_files_path(),
+ "testdata", "string.seq")
+
+ filenames = constant_op.constant([filename], dtypes.string)
+ num_repeats = 2
+
+ dataset = hadoop_dataset_ops.SequenceFileDataset(filenames).repeat(
+ num_repeats)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(num_repeats): # Dataset is repeated.
+ for i in range(25): # 25 records.
+ v0 = b"%03d" % (i + 1)
+ v1 = b"VALUE%03d" % (i + 1)
+ self.assertEqual((v0, v1), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/testdata/string.seq b/tensorflow/contrib/hadoop/python/kernel_tests/testdata/string.seq
new file mode 100755
index 0000000000..b7175338af
--- /dev/null
+++ b/tensorflow/contrib/hadoop/python/kernel_tests/testdata/string.seq
Binary files differ
diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
new file mode 100644
index 0000000000..6e0e628655
--- /dev/null
+++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
@@ -0,0 +1,75 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SequenceFile Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.hadoop.python.ops import gen_dataset_ops
+from tensorflow.contrib.hadoop.python.ops import hadoop_op_loader # pylint: disable=unused-import
+from tensorflow.python.data.ops.dataset_ops import Dataset
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+
+class SequenceFileDataset(Dataset):
+ """A Sequence File Dataset that reads the sequence file."""
+
+ def __init__(self, filenames):
+ """Create a `SequenceFileDataset`.
+
+ `SequenceFileDataset` allows a user to read data from a hadoop sequence
+ file. A sequence file consists of (key value) pairs sequentially. At
+ the moment, `org.apache.hadoop.io.Text` is the only serialization type
+ being supported, and there is no compression support.
+
+ For example:
+
+ ```python
+ dataset = tf.contrib.hadoop.SequenceFileDataset("/foo/bar.seq")
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+ # Prints the (key, value) pairs inside a hadoop sequence file.
+ while True:
+ try:
+ print(sess.run(next_element))
+ except tf.errors.OutOfRangeError:
+ break
+ ```
+
+ Args:
+ filenames: A `tf.string` tensor containing one or more filenames.
+ """
+ super(SequenceFileDataset, self).__init__()
+ self._filenames = ops.convert_to_tensor(
+ filenames, dtype=dtypes.string, name="filenames")
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.sequence_file_dataset(
+ self._filenames, nest.flatten(self.output_types))
+
+ @property
+ def output_classes(self):
+ return ops.Tensor, ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([]))
+
+ @property
+ def output_types(self):
+ return dtypes.string, dtypes.string
diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_op_loader.py b/tensorflow/contrib/hadoop/python/ops/hadoop_op_loader.py
new file mode 100644
index 0000000000..6dbf1253f3
--- /dev/null
+++ b/tensorflow/contrib/hadoop/python/ops/hadoop_op_loader.py
@@ -0,0 +1,24 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python helper for loading hadoop ops and kernels."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.platform import resource_loader
+
+_dataset_ops = loader.load_op_library(
+ resource_loader.get_path_to_datafile("../../_dataset_ops.so"))