aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kafka
diff options
context:
space:
mode:
authorGravatar Jacques Pienaar <jpienaar@google.com>2018-03-21 12:07:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-21 12:10:30 -0700
commit2d0531d72c7dcbb0e149cafdd3a16ee8c3ff357a (patch)
tree1179ecdd684d10c6549f85aa95f33dd79463a093 /tensorflow/contrib/kafka
parentcbede3ea7574b36f429710bc08617d08455bcc21 (diff)
Merge changes from github.
PiperOrigin-RevId: 189945839
Diffstat (limited to 'tensorflow/contrib/kafka')
-rw-r--r--tensorflow/contrib/kafka/BUILD108
-rw-r--r--tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc4
-rw-r--r--tensorflow/contrib/kafka/ops/dataset_ops.cc44
-rw-r--r--tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py9
-rw-r--r--tensorflow/contrib/kafka/python/ops/kafka_op_loader.py24
5 files changed, 143 insertions, 46 deletions
diff --git a/tensorflow/contrib/kafka/BUILD b/tensorflow/contrib/kafka/BUILD
index efb403462a..1c3974871c 100644
--- a/tensorflow/contrib/kafka/BUILD
+++ b/tensorflow/contrib/kafka/BUILD
@@ -1,66 +1,93 @@
-package(
- default_visibility = ["//visibility:private"],
-)
+package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
-load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
-load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
-load("//tensorflow:tensorflow.bzl", "tf_py_test")
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+ "tf_custom_op_library",
+ "tf_custom_op_py_library",
+ "tf_gen_op_libs",
+ "tf_py_test",
+)
-tf_kernel_library(
- name = "kafka_kernels",
+py_library(
+ name = "kafka",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_ops",
+ ],
+)
+
+tf_custom_op_library(
+ name = "_dataset_ops.so",
+ srcs = ["ops/dataset_ops.cc"],
+ deps = [":dataset_kernels"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["dataset_ops"],
+)
+
+cc_library(
+ name = "dataset_kernels",
srcs = ["kernels/kafka_dataset_ops.cc"],
- visibility = ["//visibility:public"],
deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core/kernels:bounds_check_lib",
- "//tensorflow/core/kernels:dataset",
+ "//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
"@kafka",
+ "@protobuf_archive//:protobuf_headers",
],
+ alwayslink = 1,
)
-tf_gen_op_libs(
- op_lib_names = ["kafka_ops"],
+py_library(
+ name = "dataset_ops",
+ srcs = [
+ "python/ops/kafka_dataset_ops.py",
+ ],
+ srcs_version = "PY2AND3",
deps = [
- "//tensorflow/core:lib",
+ ":kafka_op_loader",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
],
)
tf_gen_op_wrapper_py(
- name = "gen_kafka_ops",
- out = "python/ops/gen_kafka_ops.py",
- require_shape_functions = True,
- deps = [":kafka_ops_op_lib"],
+ name = "gen_dataset_ops",
+ out = "python/ops/gen_dataset_ops.py",
+ deps = ["//tensorflow/contrib/kafka:dataset_ops_op_lib"],
)
-py_library(
- name = "kafka",
- srcs = [
- "__init__.py",
- "python/ops/kafka_dataset_ops.py",
+tf_kernel_library(
+ name = "dataset_ops_kernels",
+ deps = [
+ ":dataset_kernels",
+ "//tensorflow/core:framework",
+ ],
+ alwayslink = 1,
+)
+
+tf_custom_op_py_library(
+ name = "kafka_op_loader",
+ srcs = ["python/ops/kafka_op_loader.py"],
+ dso = ["//tensorflow/contrib/kafka:_dataset_ops.so"],
+ kernels = [
+ ":dataset_ops_kernels",
+ "//tensorflow/contrib/kafka:dataset_ops_op_lib",
],
srcs_version = "PY2AND3",
- visibility = ["//visibility:public"],
deps = [
- ":gen_kafka_ops",
+ ":gen_dataset_ops",
"//tensorflow/contrib/util:util_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/ops:readers",
],
)
@@ -88,6 +115,7 @@ tf_py_test(
],
tags = [
"manual",
+ "no_windows",
"notap",
],
)
@@ -95,7 +123,9 @@ tf_py_test(
filegroup(
name = "all_files",
srcs = glob(
- ["**/*"],
+ include = [
+ "**/*",
+ ],
exclude = [
"**/METADATA",
"**/OWNERS",
diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc
index 88ef5f3571..a4cd4a2cc4 100644
--- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc
+++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc
@@ -13,9 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/kernels/dataset.h"
-
-#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/dataset.h"
#include "src-cpp/rdkafkacpp.h"
diff --git a/tensorflow/contrib/kafka/ops/dataset_ops.cc b/tensorflow/contrib/kafka/ops/dataset_ops.cc
new file mode 100644
index 0000000000..8cdf16103b
--- /dev/null
+++ b/tensorflow/contrib/kafka/ops/dataset_ops.cc
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("KafkaDataset")
+ .Input("topics: string")
+ .Input("servers: string")
+ .Input("group: string")
+ .Input("eof: bool")
+ .Input("timeout: int64")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that emits the messages of one or more Kafka topics.
+
+topics: A `tf.string` tensor containing one or more subscriptions,
+ in the format of [topic:partition:offset:length],
+ by default length is -1 for unlimited.
+servers: A list of bootstrap servers.
+group: The consumer group id.
+eof: If True, the kafka reader will stop on EOF.
+timeout: The timeout value for the Kafka Consumer to wait
+ (in millisecond).
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
index 8e51d27a34..a1624614d1 100644
--- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
+++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
@@ -17,8 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.kafka.python.ops import gen_kafka_ops
-from tensorflow.python.data.ops.readers import Dataset
+from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import
+from tensorflow.contrib.kafka.python.ops import gen_dataset_ops
+from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -58,8 +59,8 @@ class KafkaDataset(Dataset):
timeout, dtype=dtypes.int64, name="timeout")
def _as_variant_tensor(self):
- return gen_kafka_ops.kafka_dataset(self._topics, self._servers, self._group,
- self._eof, self._timeout)
+ return gen_dataset_ops.kafka_dataset(self._topics, self._servers,
+ self._group, self._eof, self._timeout)
@property
def output_classes(self):
diff --git a/tensorflow/contrib/kafka/python/ops/kafka_op_loader.py b/tensorflow/contrib/kafka/python/ops/kafka_op_loader.py
new file mode 100644
index 0000000000..ec2fdea962
--- /dev/null
+++ b/tensorflow/contrib/kafka/python/ops/kafka_op_loader.py
@@ -0,0 +1,24 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python helper for loading kafka ops and kernels."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.platform import resource_loader
+
+_dataset_ops = loader.load_op_library(
+ resource_loader.get_path_to_datafile("../../_dataset_ops.so"))