aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--configure.py2
-rw-r--r--tensorflow/BUILD6
-rw-r--r--tensorflow/contrib/BUILD15
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt2
-rw-r--r--tensorflow/contrib/ignite/BUILD136
-rw-r--r--tensorflow/contrib/ignite/README.md167
-rw-r--r--tensorflow/contrib/ignite/__init__.py42
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc304
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h54
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_client.cc55
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_client.h40
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset.cc123
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset.h65
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc447
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h87
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc145
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_plain_client.h43
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc132
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc143
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc149
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h49
-rw-r--r--tensorflow/contrib/ignite/ops/dataset_ops.cc64
-rw-r--r--tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py763
-rw-r--r--tensorflow/contrib/ignite/python/ops/ignite_op_loader.py25
-rwxr-xr-xtensorflow/contrib/ignite/python/tests/bin/start-plain.sh24
-rwxr-xr-xtensorflow/contrib/ignite/python/tests/bin/start-ssl-auth.sh28
-rwxr-xr-xtensorflow/contrib/ignite/python/tests/bin/start-ssl.sh26
-rw-r--r--tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml39
-rw-r--r--tensorflow/contrib/ignite/python/tests/config/ignite-config-ssl-auth.xml59
-rw-r--r--tensorflow/contrib/ignite/python/tests/config/ignite-config-ssl.xml59
-rw-r--r--tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py77
-rw-r--r--tensorflow/contrib/ignite/python/tests/keystore/client.jksbin0 -> 3232 bytes
-rw-r--r--tensorflow/contrib/ignite/python/tests/keystore/client.pem69
-rw-r--r--tensorflow/contrib/ignite/python/tests/keystore/server.jksbin0 -> 3230 bytes
-rw-r--r--tensorflow/contrib/ignite/python/tests/keystore/trust.jksbin0 -> 2432 bytes
-rw-r--r--tensorflow/contrib/ignite/python/tests/sql/init.sql20
-rwxr-xr-xtensorflow/contrib/ignite/python/tests/start_ignite.sh30
-rwxr-xr-xtensorflow/contrib/ignite/python/tests/stop_ignite.sh19
38 files changed, 3508 insertions, 0 deletions
diff --git a/configure.py b/configure.py
index 361bd4764d..8f1957e870 100644
--- a/configure.py
+++ b/configure.py
@@ -1502,6 +1502,8 @@ def main():
'with_aws_support', True, 'aws')
set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
'with_kafka_support', True, 'kafka')
+ set_build_var(environ_cp, 'TF_NEED_IGNITE', 'Apache Ignite',
+ 'with_ignite_support', True, 'ignite')
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
False, 'xla')
set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support',
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 386e0096ff..6c29c78793 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -248,6 +248,12 @@ config_setting(
visibility = ["//visibility:public"],
)
+config_setting(
+ name = "with_ignite_support",
+ define_values = {"with_ignite_support": "true"},
+ visibility = ["//visibility:public"],
+)
+
# Crosses between platforms and file system libraries not supported on those
# platforms due to limitations in nested select() statements.
config_setting(
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 798f499870..f055e643d0 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -119,6 +119,11 @@ py_library(
],
"//conditions:default": [],
}) + select({
+ "//tensorflow:with_ignite_support": [
+ "//tensorflow/contrib/ignite",
+ ],
+ "//conditions:default": [],
+ }) + select({
"//tensorflow:with_aws_support_windows_override": [],
"//tensorflow:with_aws_support": [
"//tensorflow/contrib/kinesis",
@@ -161,6 +166,11 @@ cc_library(
],
"//conditions:default": [],
}) + select({
+ "//tensorflow:with_ignite_support": [
+ "//tensorflow/contrib/ignite:dataset_kernels",
+ ],
+ "//conditions:default": [],
+ }) + select({
"//tensorflow:with_aws_support_windows_override": [],
"//tensorflow:with_aws_support": [
"//tensorflow/contrib/kinesis:dataset_kernels",
@@ -198,6 +208,11 @@ cc_library(
],
"//conditions:default": [],
}) + select({
+ "//tensorflow:with_ignite_support": [
+ "//tensorflow/contrib/ignite:dataset_ops_op_lib",
+ ],
+ "//conditions:default": [],
+ }) + select({
"//tensorflow:with_aws_support_windows_override": [],
"//tensorflow:with_aws_support": [
"//tensorflow/contrib/kinesis:dataset_ops_op_lib",
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index fb871acae9..56755e817a 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -207,6 +207,8 @@ tensorflow/contrib/integrate/python
tensorflow/contrib/integrate/python/ops
tensorflow/contrib/kafka/python
tensorflow/contrib/kafka/python/ops
+tensorflow/contrib/ignite/python
+tensorflow/contrib/ignite/python/ops
tensorflow/contrib/keras
tensorflow/contrib/keras/api
tensorflow/contrib/keras/api/keras
diff --git a/tensorflow/contrib/ignite/BUILD b/tensorflow/contrib/ignite/BUILD
new file mode 100644
index 0000000000..9f6c666893
--- /dev/null
+++ b/tensorflow/contrib/ignite/BUILD
@@ -0,0 +1,136 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+ "tf_custom_op_library",
+ "tf_custom_op_py_library",
+ "tf_gen_op_libs",
+ "tf_py_test",
+ "if_not_windows",
+ "if_windows",
+)
+
+py_library(
+ name = "ignite",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_ops",
+ ],
+)
+
+tf_custom_op_library(
+ name = "_dataset_ops.so",
+ srcs = ["ops/dataset_ops.cc"],
+ deps = [":dataset_kernels"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["dataset_ops"],
+)
+
+cc_library(
+ name = "dataset_kernels",
+ srcs = [
+ "kernels/ignite_dataset_ops.cc",
+ "kernels/ignite_client.h",
+ "kernels/ignite_client.cc",
+ "kernels/ignite_plain_client.h",
+ "kernels/ignite_ssl_wrapper.h",
+ "kernels/ignite_ssl_wrapper.cc",
+ "kernels/ignite_binary_object_parser.h",
+ "kernels/ignite_binary_object_parser.cc",
+ "kernels/ignite_dataset.h",
+ "kernels/ignite_dataset.cc",
+ "kernels/ignite_dataset_iterator.h",
+ "kernels/ignite_dataset_iterator.cc",
+ ] + if_not_windows([
+ "kernels/ignite_plain_client_unix.cc",
+ ]) + if_windows([
+ "kernels/ignite_plain_client_windows.cc",
+ ]),
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@boringssl//:ssl",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+ alwayslink = 1,
+)
+
+py_library(
+ name = "dataset_ops",
+ srcs = [
+ "python/ops/ignite_dataset_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ignite_op_loader",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_dataset_ops",
+ out = "python/ops/gen_dataset_ops.py",
+ deps = ["//tensorflow/contrib/ignite:dataset_ops_op_lib"],
+)
+
+tf_kernel_library(
+ name = "dataset_ops_kernels",
+ deps = [
+ ":dataset_kernels",
+ "//tensorflow/core:framework",
+ ],
+ alwayslink = 1,
+)
+
+tf_custom_op_py_library(
+ name = "ignite_op_loader",
+ srcs = ["python/ops/ignite_op_loader.py"],
+ dso = ["//tensorflow/contrib/ignite:_dataset_ops.so"],
+ kernels = [
+ ":dataset_ops_kernels",
+ "//tensorflow/contrib/ignite:dataset_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gen_dataset_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:platform",
+ ],
+)
+
+# The Apache Ignite servers have to setup before the test and tear down
+# after the test manually. The docker engine has to be installed.
+#
+# To setup Apache Ignite servers:
+# $ bash ./python/tests/start_ignite.sh
+#
+# To tear down Apache Ignite servers:
+# $ bash ./python/tests/stop_ignite.sh
+tf_py_test(
+ name = "ignite_dataset_test",
+ srcs = ["python/tests/ignite_dataset_test.py"],
+ additional_deps = [
+ ":ignite",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+ tags = [
+ "manual",
+ "no_windows",
+ "notap",
+ ],
+)
diff --git a/tensorflow/contrib/ignite/README.md b/tensorflow/contrib/ignite/README.md
new file mode 100644
index 0000000000..9054344e94
--- /dev/null
+++ b/tensorflow/contrib/ignite/README.md
@@ -0,0 +1,167 @@
+### Ignite Dataset
+# Ignite Dataset
+
+- [Overview](#overview)
+- [Features](#features)
+ * [Distributed In-Memory Datasource](#distributed-in-memory-datasource)
+ * [Structured Objects](#structured-objects)
+ * [Distributed Training](#distributed-training)
+ * [SSL Connection](#ssl-connection)
+ * [Windows Support](#windows-support)
+- [Try it out](#try-it-out)
+- [Limitations](#limitations)
+
+## Overview
+
+[Apache Ignite](https://ignite.apache.org/) is a memory-centric distributed database, caching, and processing platform for
+transactional, analytical, and streaming workloads, delivering in-memory speeds at petabyte scale. This contrib package contains an integration between Apache Ignite and TensorFlow. The integration is based on [tf.data](https://www.tensorflow.org/api_docs/python/tf/data) from TensorFlow side and [Binary Client Protocol](https://apacheignite.readme.io/v2.6/docs/binary-client-protocol) from Apache Ignite side. It allows to use Apache Ignite as a datasource for neural network training, inference and all other computations supported by TensorFlow.
+
+## Features
+
+Ignite Dataset provides a set of features that makes it possible to use it in a wide range of cases. The most important and interesting features are described below.
+
+### Distributed In-Memory Datasource
+[Apache Ignite](https://ignite.apache.org/) is a distributed in-memory database, caching, and processing platform that allows to avoid limitations of hard drive and provide high reading speed and ability to store and operate with as much data as you need in distributed cluster. Using of Ignite Dataset makes it possible to utilize all these advantages.
+- If you have a **gigabyte** of data you can keep it on a single machine on a hard drive, but you will face with hard drive speed limitations. At the same time, you can store your data in Apache Ignite on the same machine and use it as a datasource for TensorFlow and thus avoid these limitations.
+- If you have a **terabyte** of data you probably still can keep it on a single machine on a hard drive, but you will face with hard drive speed limitations again. At the same time, you can store your data in Apache Ignite distributed in-memory cluster and use it as a datasource for TensorFlow and thus avoid these limitations.
+- If you have a **petabyte** of data you can't keep it on a single machine. At the same time, you can store your data in Apache Ignite distributed in-memory cluster and use it as a datasource for TensorFlow.
+
+It's important that Apache Ignite is not just a step of ETL pipeline between database or data warehouse and TensorFlow. Apache Ignite is a high-grade database itself. Choosing Apache Ignite and TensorFlow you are getting everything you need to work with operational or historical data and, in the same time, an ability to use this data for neural network training and inference.
+
+```bash
+$ apache-ignite-fabric/bin/ignite.sh
+$ apache-ignite-fabric/bin/sqlline.sh -u "jdbc:ignite:thin://localhost:10800/"
+
+jdbc:ignite:thin://localhost/> CREATE TABLE KITTEN_CACHE (ID LONG PRIMARY KEY, NAME VARCHAR);
+jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (1, 'WARM KITTY');
+jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (2, 'SOFT KITTY');
+jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL OF FUR');
+```
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset(cache_name="SQL_PUBLIC_KITTEN_CACHE")
+>>> iterator = dataset.make_one_shot_iterator()
+>>> next_obj = iterator.get_next()
+>>>
+>>> with tf.Session() as sess:
+>>> for _ in range(3):
+>>> print(sess.run(next_obj))
+
+{'key': 1, 'val': {'NAME': b'WARM KITTY'}}
+{'key': 2, 'val': {'NAME': b'SOFT KITTY'}}
+{'key': 3, 'val': {'NAME': b'LITTLE BALL OF FUR'}}
+```
+
+### Structured Objects
+[Apache Ignite](https://ignite.apache.org/) allows to store any objects you would like to store. These objects can have any hierarchy. Ignite Dataset provides an ability to work with such objects.
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset(cache_name="IMAGES")
+>>> iterator = dataset.make_one_shot_iterator()
+>>> next_obj = iterator.get_next()
+>>>
+>>> with tf.Session() as sess:
+>>> print(sess.run(next_obj))
+
+{
+ 'key': 'kitten.png',
+ 'val': {
+ 'metadata': {
+ 'file_name': b'kitten.png',
+ 'label': b'little ball of fur',
+ width: 800,
+ height: 600
+ },
+ 'pixels': [0, 0, 0, 0, ..., 0]
+ }
+}
+```
+ Neural network training and other computations require transformations that can be done as part of [tf.data](https://www.tensorflow.org/api_docs/python/tf/data) pipeline if you use Ignite Dataset.
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset(cache_name="IMAGES").map(lambda obj: obj['val']['pixels'])
+>>> iterator = dataset.make_one_shot_iterator()
+>>> next_obj = iterator.get_next()
+>>>
+>>> with tf.Session() as sess:
+>>> print(sess.run(next_obj))
+
+[0, 0, 0, 0, ..., 0]
+```
+
+### Distributed Training
+
+TensorFlow is a machine learning framework that [natively supports](https://www.tensorflow.org/deploy/distributed) distributed neural network training, inference and other computations. The main idea behind the distributed neural network training is an ability to calculate gradients of loss functions (squares of the errors) on every partition of data (in terms of horizontal partitioning) and then sum them to get loss function gradient of the whole dataset.
+
+<a href="https://www.codecogs.com/eqnedit.php?latex=\nabla[\sum_1^n(y&space;-&space;\hat{y})^2]&space;=&space;\nabla[\sum_1^{n_1}(y&space;-&space;\hat{y})^2]&space;&plus;&space;\nabla[\sum_{n_1}^{n_2}(y&space;-&space;\hat{y})^2]&space;&plus;&space;...&space;&plus;&space;\nabla[\sum_{n_{k-1}}^n(y&space;-&space;\hat{y})^2]" target="_blank"><img src="https://latex.codecogs.com/gif.latex?\nabla[\sum_1^n(y&space;-&space;\hat{y})^2]&space;=&space;\nabla[\sum_1^{n_1}(y&space;-&space;\hat{y})^2]&space;&plus;&space;\nabla[\sum_{n_1}^{n_2}(y&space;-&space;\hat{y})^2]&space;&plus;&space;...&space;&plus;&space;\nabla[\sum_{n_{k-1}}^n(y&space;-&space;\hat{y})^2]" title="\nabla[\sum_1^n(y - \hat{y})^2] = \nabla[\sum_1^{n_1}(y - \hat{y})^2] + \nabla[\sum_{n_1}^{n_2}(y - \hat{y})^2] + ... + \nabla[\sum_{n_{k-1}}^n(y - \hat{y})^2]" /></a>
+
+Utilizing this ability we can calculate gradients on the nodes the data is stored on, reduce them and then finally update model parameters. It allows to avoid data transfers between nodes and thus to avoid network bottleneck.
+
+Apache Ignite uses horizontal partitioning to store data in distributed cluster. When we create Apache Ignite cache (or table in terms of SQL) we can specify the number of partitions the data will be partitioned on. If, for example, Apache Ignite cluster consists of 10 machines and we creates cache with 10 partitions then every machine will maintain approximately one data partition.
+
+Ignite Dataset allows to utilize these two aspects of distributed neural network training (using TensorFlow) and Apache Ignite partitioning. Ignite Dataset is a computation graph operation that might be performed on a remote worker. The remote worker can override Ignite Dataset parameters (such as `host`, `port` or `part`) by setting correstondent environment variables for worker process (such as `IGNITE_DATASET_HOST`, `IGNITE_DATASET_PORT` or `IGNITE_DATASET_PART`). Using this overriding approach we are able to assign specific partition to every worker so that one worker handles one partition and, at the same time, transparently work with single dataset.
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset("IMAGES")
+>>>
+>>> # Compute gradients locally on every worker node.
+>>> gradients = []
+>>> for i in range(5):
+>>> with tf.device("/job:WORKER/task:%d" % i):
+>>> device_iterator = dataset.make_one_shot_iterator()
+>>> device_next_obj = device_iterator.get_next()
+>>> gradient = compute_gradient(device_next_obj)
+>>> gradients.append(gradient)
+>>>
+>>> # Aggregate them on master node.
+>>> result_gradient = tf.reduce_sum(gradients)
+>>>
+>>> with tf.Session("grpc://localhost:10000") as sess:
+>>> print(sess.run(result_gradient))
+```
+
+High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well.
+
+### SSL Connection
+
+Your data should not be accessible without any control. Apache Ignite allows to protect data transfer channels by [SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and authentification. Ignite Dataset supports both SSL connection with and without authntication. For more information please see [Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) documentation.
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset(cache_name="IMAGES", certfile="client.pem", cert_password="password", username="ignite", password="ignite")
+>>> ...
+```
+
+### Windows Support
+
+Ignite Dataset is fully compatible with Windows, so you can use it as part of TensorFlow on your Windows workstation as well as on Linux/MacOS systems.
+
+## Try it out
+
+The simplest way to try Ignite Dataset out is to run [Docker](https://www.docker.com/) container with Apache Ignite and loaded [MNIST](http://yann.lecun.com/exdb/mnist/) data and then interruct with it using Ignite Dataset. Such container is available on Docker Hub: [dmitrievanthony/ignite-with-mnist](https://hub.docker.com/r/dmitrievanthony/ignite-with-mnist/). You need to start this container on your machine:
+
+```
+docker run -it -p 10800:10800 dmitrievanthony/ignite-with-mnist
+```
+
+After that you will be able to work with it following way:
+
+![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist.png "Ignite Dataset Mnist")
+
+## Limitations
+
+Presently Ignite Dataset works with assumption that all objects in the cache have the same structure (homogeneous objects) and the cache contains at least one object. Another limitation concerns structured objects, Ignite Dataset does not support UUID, Maps and Object arrays that might be parts of object structures. \ No newline at end of file
diff --git a/tensorflow/contrib/ignite/__init__.py b/tensorflow/contrib/ignite/__init__.py
new file mode 100644
index 0000000000..468920a557
--- /dev/null
+++ b/tensorflow/contrib/ignite/__init__.py
@@ -0,0 +1,42 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Apache Ignite is a memory-centric distributed database, caching, and
+ processing platform for transactional, analytical, and streaming workloads,
+ delivering in-memory speeds at petabyte scale. This contrib package
+ contains an integration between Apache Ignite and TensorFlow. The
+ integration is based on tf.data from TensorFlow side and Binary Client
+ Protocol from Apache Ignite side. It allows to use Apache Ignite as a
+ datasource for neural network training, inference and all other
+ computations supported by TensorFlow. Ignite Dataset is based on Apache
+ Ignite Binary Client Protocol:
+ https://apacheignite.readme.io/v2.6/docs/binary-client-protocol.
+
+@@IgniteDataset
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.ignite.python.ops.ignite_dataset_ops \
+import IgniteDataset
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ "IgniteDataset",
+]
+
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc
new file mode 100644
index 0000000000..bf0ef8766e
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc
@@ -0,0 +1,304 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ignite_binary_object_parser.h"
+
+namespace ignite {
+
+tensorflow::Status BinaryObjectParser::Parse(
+ uint8_t*& ptr, std::vector<tensorflow::Tensor>& out_tensors,
+ std::vector<int32_t>& types) {
+ uint8_t object_type_id = *ptr;
+ ptr += 1;
+
+ switch (object_type_id) {
+ case BYTE: {
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_UINT8, {});
+ tensor.scalar<tensorflow::uint8>()() = *((uint8_t*)ptr);
+ ptr += 1;
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case SHORT: {
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_INT16, {});
+ tensor.scalar<tensorflow::int16>()() = *((int16_t*)ptr);
+ ptr += 2;
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case INT: {
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_INT32, {});
+ tensor.scalar<tensorflow::int32>()() = *((int32_t*)ptr);
+ ptr += 4;
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case LONG: {
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_INT64, {});
+ tensor.scalar<tensorflow::int64>()() = *((int64_t*)ptr);
+ ptr += 8;
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case FLOAT: {
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_FLOAT, {});
+ tensor.scalar<float>()() = *((float*)ptr);
+ ptr += 4;
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case DOUBLE: {
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_DOUBLE, {});
+ tensor.scalar<double>()() = *((double*)ptr);
+ ptr += 8;
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case UCHAR: {
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_UINT16, {});
+ tensor.scalar<tensorflow::uint16>()() = *((uint16_t*)ptr);
+ ptr += 2;
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case BOOL: {
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_BOOL, {});
+ tensor.scalar<bool>()() = *((bool*)ptr);
+ ptr += 1;
+ out_tensors.emplace_back(std::move(tensor));
+
+ break;
+ }
+ case STRING: {
+ int32_t length = *((int32_t*)ptr);
+ ptr += 4;
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_STRING, {});
+ tensor.scalar<std::string>()() = std::string((char*)ptr, length);
+ ptr += length;
+ out_tensors.emplace_back(std::move(tensor));
+
+ break;
+ }
+ case DATE: {
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_INT64, {});
+ tensor.scalar<tensorflow::int64>()() = *((int64_t*)ptr);
+ ptr += 8;
+ out_tensors.emplace_back(std::move(tensor));
+
+ break;
+ }
+ case BYTE_ARR: {
+ int32_t length = *((int32_t*)ptr);
+ ptr += 4;
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_UINT8,
+ tensorflow::TensorShape({length}));
+
+ uint8_t* arr = (uint8_t*)ptr;
+ ptr += length;
+
+ std::copy_n(arr, length, tensor.flat<tensorflow::uint8>().data());
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case SHORT_ARR: {
+ int32_t length = *((int32_t*)ptr);
+ ptr += 4;
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_INT16,
+ tensorflow::TensorShape({length}));
+
+ int16_t* arr = (int16_t*)ptr;
+ ptr += length * 2;
+
+ std::copy_n(arr, length, tensor.flat<tensorflow::int16>().data());
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case INT_ARR: {
+ int32_t length = *((int32_t*)ptr);
+ ptr += 4;
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_INT32,
+ tensorflow::TensorShape({length}));
+
+ int32_t* arr = (int32_t*)ptr;
+ ptr += length * 4;
+
+ std::copy_n(arr, length, tensor.flat<tensorflow::int32>().data());
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case LONG_ARR: {
+ int32_t length = *((int32_t*)ptr);
+ ptr += 4;
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_INT64,
+ tensorflow::TensorShape({length}));
+
+ int64_t* arr = (int64_t*)ptr;
+ ptr += length * 8;
+
+ std::copy_n(arr, length, tensor.flat<tensorflow::int64>().data());
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case FLOAT_ARR: {
+ int32_t length = *((int32_t*)ptr);
+ ptr += 4;
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_FLOAT,
+ tensorflow::TensorShape({length}));
+
+ float* arr = (float*)ptr;
+ ptr += 4 * length;
+
+ std::copy_n(arr, length, tensor.flat<float>().data());
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case DOUBLE_ARR: {
+ int32_t length = *((int32_t*)ptr);
+ ptr += 4;
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_DOUBLE,
+ tensorflow::TensorShape({length}));
+
+ double* arr = (double*)ptr;
+ ptr += 8 * length;
+
+ std::copy_n(arr, length, tensor.flat<double>().data());
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case UCHAR_ARR: {
+ int32_t length = *((int32_t*)ptr);
+ ptr += 4;
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_UINT16,
+ tensorflow::TensorShape({length}));
+
+ uint16_t* arr = (uint16_t*)ptr;
+ ptr += length * 2;
+
+ std::copy_n(arr, length, tensor.flat<tensorflow::uint16>().data());
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case BOOL_ARR: {
+ int32_t length = *((int32_t*)ptr);
+ ptr += 4;
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_BOOL,
+ tensorflow::TensorShape({length}));
+
+ bool* arr = (bool*)ptr;
+ ptr += length;
+
+ std::copy_n(arr, length, tensor.flat<bool>().data());
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case STRING_ARR: {
+ int32_t length = *((int32_t*)ptr);
+ ptr += 4;
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_STRING,
+ tensorflow::TensorShape({length}));
+
+ for (int32_t i = 0; i < length; i++) {
+ int32_t str_length = *((int32_t*)ptr);
+ ptr += 4;
+ const int8_t* str = (const int8_t*)ptr;
+ ptr += str_length;
+ tensor.vec<std::string>()(i) = std::string((char*)str, str_length);
+ }
+
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case DATE_ARR: {
+ int32_t length = *((int32_t*)ptr);
+ ptr += 4;
+ tensorflow::Tensor tensor(tensorflow::cpu_allocator(),
+ tensorflow::DT_INT64,
+ tensorflow::TensorShape({length}));
+ int64_t* arr = (int64_t*)ptr;
+ ptr += length * 8;
+
+ std::copy_n(arr, length, tensor.flat<tensorflow::int64>().data());
+ out_tensors.emplace_back(std::move(tensor));
+ break;
+ }
+ case WRAPPED_OBJ: {
+ int32_t byte_arr_size = *((int32_t*)ptr);
+ ptr += 4;
+
+ tensorflow::Status status = Parse(ptr, out_tensors, types);
+ if (!status.ok()) return status;
+
+ int32_t offset = *((int32_t*)ptr);
+ ptr += 4;
+
+ break;
+ }
+ case COMPLEX_OBJ: {
+ uint8_t version = *ptr;
+ ptr += 1;
+ int16_t flags = *((int16_t*)ptr); // USER_TYPE = 1, HAS_SCHEMA = 2
+ ptr += 2;
+ int32_t type_id = *((int32_t*)ptr);
+ ptr += 4;
+ int32_t hash_code = *((int32_t*)ptr);
+ ptr += 4;
+ int32_t length = *((int32_t*)ptr);
+ ptr += 4;
+ int32_t schema_id = *((int32_t*)ptr);
+ ptr += 4;
+ int32_t schema_offset = *((int32_t*)ptr);
+ ptr += 4;
+
+ uint8_t* end = ptr + schema_offset - 24;
+ int32_t i = 0;
+ while (ptr < end) {
+ i++;
+ tensorflow::Status status = Parse(ptr, out_tensors, types);
+ if (!status.ok()) return status;
+ }
+
+ ptr += (length - schema_offset);
+
+ break;
+ }
+ default: {
+ return tensorflow::errors::Internal("Unknowd binary type (type id ",
+ (int)object_type_id, ")");
+ }
+ }
+
+ return tensorflow::Status::OK();
+}
+
+} // namespace ignite
diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h
new file mode 100644
index 0000000000..1e845cbc56
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h
@@ -0,0 +1,54 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace ignite {
+
+class BinaryObjectParser {
+ public:
+ tensorflow::Status Parse(uint8_t*& ptr,
+ std::vector<tensorflow::Tensor>& out_tensors,
+ std::vector<int32_t>& types);
+};
+
+enum ObjectType {
+ BYTE = 1,
+ SHORT = 2,
+ INT = 3,
+ LONG = 4,
+ FLOAT = 5,
+ DOUBLE = 6,
+ UCHAR = 7,
+ BOOL = 8,
+ STRING = 9,
+ DATE = 11,
+ BYTE_ARR = 12,
+ SHORT_ARR = 13,
+ INT_ARR = 14,
+ LONG_ARR = 15,
+ FLOAT_ARR = 16,
+ DOUBLE_ARR = 17,
+ UCHAR_ARR = 18,
+ BOOL_ARR = 19,
+ STRING_ARR = 20,
+ DATE_ARR = 22,
+ WRAPPED_OBJ = 27,
+ COMPLEX_OBJ = 103
+};
+
+} // namespace ignite
diff --git a/tensorflow/contrib/ignite/kernels/ignite_client.cc b/tensorflow/contrib/ignite/kernels/ignite_client.cc
new file mode 100644
index 0000000000..5a8eddb944
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_client.cc
@@ -0,0 +1,55 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef IGNITE_CLIENT_H
+#define IGNITE_CLIENT_H
+#include "ignite_client.h"
+#endif
+
+namespace ignite {
+
+tensorflow::Status Client::ReadByte(uint8_t& data) {
+ return ReadData((uint8_t*)&data, 1);
+}
+
+tensorflow::Status Client::ReadShort(int16_t& data) {
+ return ReadData((uint8_t*)&data, 2);
+}
+
+tensorflow::Status Client::ReadInt(int32_t& data) {
+ return ReadData((uint8_t*)&data, 4);
+}
+
+tensorflow::Status Client::ReadLong(int64_t& data) {
+ return ReadData((uint8_t*)&data, 8);
+}
+
+tensorflow::Status Client::WriteByte(uint8_t data) {
+ return WriteData((uint8_t*)&data, 1);
+}
+
+tensorflow::Status Client::WriteShort(int16_t data) {
+ return WriteData((uint8_t*)&data, 2);
+}
+
+tensorflow::Status Client::WriteInt(int32_t data) {
+ return WriteData((uint8_t*)&data, 4);
+}
+
+tensorflow::Status Client::WriteLong(int64_t data) {
+ return WriteData((uint8_t*)&data, 8);
+}
+
+} // namespace ignite
diff --git a/tensorflow/contrib/ignite/kernels/ignite_client.h b/tensorflow/contrib/ignite/kernels/ignite_client.h
new file mode 100644
index 0000000000..64e28d75f0
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_client.h
@@ -0,0 +1,40 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/core/status.h"
+
+namespace ignite {
+
+class Client {
+ public:
+ virtual tensorflow::Status Connect() = 0;
+ virtual tensorflow::Status Disconnect() = 0;
+ virtual bool IsConnected() = 0;
+ virtual int GetSocketDescriptor() = 0;
+
+ virtual tensorflow::Status ReadByte(uint8_t& data);
+ virtual tensorflow::Status ReadShort(int16_t& data);
+ virtual tensorflow::Status ReadInt(int32_t& data);
+ virtual tensorflow::Status ReadLong(int64_t& data);
+ virtual tensorflow::Status ReadData(uint8_t* buf, int32_t length) = 0;
+
+ virtual tensorflow::Status WriteByte(uint8_t data);
+ virtual tensorflow::Status WriteShort(int16_t data);
+ virtual tensorflow::Status WriteInt(int32_t data);
+ virtual tensorflow::Status WriteLong(int64_t data);
+ virtual tensorflow::Status WriteData(uint8_t* buf, int32_t length) = 0;
+};
+
+} // namespace ignite
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset.cc
new file mode 100644
index 0000000000..a9bf26955b
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset.cc
@@ -0,0 +1,123 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ignite_dataset_iterator.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace ignite {
+
+IgniteDataset::IgniteDataset(tensorflow::OpKernelContext* ctx,
+ std::string cache_name, std::string host,
+ tensorflow::int32 port, bool local,
+ tensorflow::int32 part,
+ tensorflow::int32 page_size, std::string username,
+ std::string password, std::string certfile,
+ std::string keyfile, std::string cert_password,
+ std::vector<tensorflow::int32> schema,
+ std::vector<tensorflow::int32> permutation)
+ : DatasetBase(tensorflow::DatasetContext(ctx)),
+ cache_name(cache_name),
+ host(host),
+ port(port),
+ local(local),
+ part(part),
+ page_size(page_size),
+ username(username),
+ password(password),
+ certfile(certfile),
+ keyfile(keyfile),
+ cert_password(cert_password),
+ schema(schema),
+ permutation(permutation) {
+ SchemaToTypes();
+ SchemaToShapes();
+
+ LOG(INFO) << "Ignite Dataset created [cache_name='" << cache_name
+ << "', host='" << host << "', port=" << port << ", local=" << local
+ << ", part=" << part << ", page_size=" << page_size
+ << ", username='" << username << "', certfile='" << certfile
+ << "', keyfile='" << keyfile + "']";
+}
+
+IgniteDataset::~IgniteDataset() { LOG(INFO) << "Ignite Dataset destroyed"; }
+
+std::unique_ptr<tensorflow::IteratorBase> IgniteDataset::MakeIteratorInternal(
+ const tensorflow::string& prefix) const {
+ return std::unique_ptr<tensorflow::IteratorBase>(new IgniteDatasetIterator(
+ {this, tensorflow::strings::StrCat(prefix, "::Ignite")}, this->host,
+ this->port, this->cache_name, this->local, this->part, this->page_size,
+ this->username, this->password, this->certfile, this->keyfile,
+ this->cert_password, this->schema, this->permutation));
+}
+
+const tensorflow::DataTypeVector& IgniteDataset::output_dtypes() const {
+ return dtypes;
+}
+
+const std::vector<tensorflow::PartialTensorShape>&
+IgniteDataset::output_shapes() const {
+ return shapes;
+}
+
+tensorflow::string IgniteDataset::DebugString() const {
+ return "IgniteDatasetOp::Dataset";
+}
+
+tensorflow::Status IgniteDataset::AsGraphDefInternal(
+ tensorflow::SerializationContext* ctx, DatasetGraphDefBuilder* b,
+ tensorflow::Node** output) const {
+ return tensorflow::errors::Unimplemented(
+ "IgniteDataset does not support 'AsGraphDefInternal'");
+}
+
+void IgniteDataset::SchemaToTypes() {
+ for (auto e : schema) {
+ if (e == BYTE || e == BYTE_ARR) {
+ dtypes.push_back(tensorflow::DT_UINT8);
+ } else if (e == SHORT || e == SHORT_ARR) {
+ dtypes.push_back(tensorflow::DT_INT16);
+ } else if (e == INT || e == INT_ARR) {
+ dtypes.push_back(tensorflow::DT_INT32);
+ } else if (e == LONG || e == LONG_ARR) {
+ dtypes.push_back(tensorflow::DT_INT64);
+ } else if (e == FLOAT || e == FLOAT_ARR) {
+ dtypes.push_back(tensorflow::DT_FLOAT);
+ } else if (e == DOUBLE || e == DOUBLE_ARR) {
+ dtypes.push_back(tensorflow::DT_DOUBLE);
+ } else if (e == UCHAR || e == UCHAR_ARR) {
+ dtypes.push_back(tensorflow::DT_UINT8);
+ } else if (e == BOOL || e == BOOL_ARR) {
+ dtypes.push_back(tensorflow::DT_BOOL);
+ } else if (e == STRING || e == STRING_ARR) {
+ dtypes.push_back(tensorflow::DT_STRING);
+ } else {
+ LOG(ERROR) << "Unexpected type in schema [type_id=" << e << "]";
+ }
+ }
+}
+
+void IgniteDataset::SchemaToShapes() {
+ for (auto e : schema) {
+ if (e >= 1 && e < 10) {
+ shapes.push_back(tensorflow::PartialTensorShape({}));
+ } else if (e >= 12 && e < 21) {
+ shapes.push_back(tensorflow::PartialTensorShape({-1}));
+ } else {
+ LOG(ERROR) << "Unexpected type in schema [type_id=" << e << "]";
+ }
+ }
+}
+
+} // namespace ignite
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.h b/tensorflow/contrib/ignite/kernels/ignite_dataset.h
new file mode 100644
index 0000000000..2120dfd342
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset.h
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/dataset.h"
+
+namespace ignite {
+
+class IgniteDataset : public tensorflow::DatasetBase {
+ public:
+ IgniteDataset(tensorflow::OpKernelContext* ctx, std::string cache_name,
+ std::string host, tensorflow::int32 port, bool local,
+ tensorflow::int32 part, tensorflow::int32 page_size,
+ std::string username, std::string password,
+ std::string certfile, std::string keyfile,
+ std::string cert_password,
+ std::vector<tensorflow::int32> schema,
+ std::vector<tensorflow::int32> permutation);
+ ~IgniteDataset();
+ std::unique_ptr<tensorflow::IteratorBase> MakeIteratorInternal(
+ const tensorflow::string& prefix) const override;
+ const tensorflow::DataTypeVector& output_dtypes() const override;
+ const std::vector<tensorflow::PartialTensorShape>& output_shapes()
+ const override;
+ tensorflow::string DebugString() const override;
+
+ protected:
+ tensorflow::Status AsGraphDefInternal(
+ tensorflow::SerializationContext* ctx, DatasetGraphDefBuilder* b,
+ tensorflow::Node** output) const override;
+
+ private:
+ const std::string cache_name;
+ const std::string host;
+ const tensorflow::int32 port;
+ const bool local;
+ const tensorflow::int32 part;
+ const tensorflow::int32 page_size;
+ const std::string username;
+ const std::string password;
+ const std::string certfile;
+ const std::string keyfile;
+ const std::string cert_password;
+ const std::vector<tensorflow::int32> schema;
+ const std::vector<tensorflow::int32> permutation;
+
+ tensorflow::DataTypeVector dtypes;
+ std::vector<tensorflow::PartialTensorShape> shapes;
+
+ void SchemaToTypes();
+ void SchemaToShapes();
+};
+
+} // namespace ignite
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc
new file mode 100644
index 0000000000..03cc3c1291
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc
@@ -0,0 +1,447 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ignite_dataset_iterator.h"
+
+#include "ignite_plain_client.h"
+#include "ignite_ssl_wrapper.h"
+#include "tensorflow/core/platform/logging.h"
+
+#include <time.h>
+#include <chrono>
+
+namespace ignite {
+
+#define CHECK_STATUS(status) \
+ if (!status.ok()) return status;
+
+IgniteDatasetIterator::IgniteDatasetIterator(
+ const Params& params, std::string host, tensorflow::int32 port,
+ std::string cache_name, bool local, tensorflow::int32 part,
+ tensorflow::int32 page_size, std::string username, std::string password,
+ std::string certfile, std::string keyfile, std::string cert_password,
+ std::vector<tensorflow::int32> schema,
+ std::vector<tensorflow::int32> permutation)
+ : tensorflow::DatasetIterator<IgniteDataset>(params),
+ cache_name(cache_name),
+ local(local),
+ part(part),
+ page_size(page_size),
+ username(username),
+ password(password),
+ schema(schema),
+ permutation(permutation),
+ remainder(-1),
+ cursor_id(-1),
+ last_page(false) {
+ Client* p_client = new PlainClient(host, port);
+
+ if (certfile.empty())
+ client = std::unique_ptr<Client>(p_client);
+ else
+ client = std::unique_ptr<Client>(new SslWrapper(
+ std::unique_ptr<Client>(p_client), certfile, keyfile, cert_password));
+
+ LOG(INFO) << "Ignite Dataset Iterator created";
+}
+
+IgniteDatasetIterator::~IgniteDatasetIterator() {
+ tensorflow::Status status = CloseConnection();
+ if (!status.ok()) LOG(ERROR) << status.ToString();
+
+ LOG(INFO) << "Ignite Dataset Iterator destroyed";
+}
+
+tensorflow::Status IgniteDatasetIterator::EstablishConnection() {
+ if (!client->IsConnected()) {
+ tensorflow::Status status = client->Connect();
+ if (!status.ok()) return status;
+
+ status = Handshake();
+ if (!status.ok()) {
+ tensorflow::Status disconnect_status = client->Disconnect();
+ if (!disconnect_status.ok()) LOG(ERROR) << disconnect_status.ToString();
+
+ return status;
+ }
+ }
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status IgniteDatasetIterator::CloseConnection() {
+ if (cursor_id != -1 && !last_page) {
+ tensorflow::Status conn_status = EstablishConnection();
+ if (!conn_status.ok()) return conn_status;
+
+ CHECK_STATUS(client->WriteInt(18)); // Message length
+ CHECK_STATUS(
+ client->WriteShort(close_connection_opcode)); // Operation code
+ CHECK_STATUS(client->WriteLong(0)); // Request ID
+ CHECK_STATUS(client->WriteLong(cursor_id)); // Resource ID
+
+ int32_t res_len;
+ CHECK_STATUS(client->ReadInt(res_len));
+ if (res_len < 12)
+ return tensorflow::errors::Internal(
+ "Close Resource Response is corrupted");
+
+ int64_t req_id;
+ CHECK_STATUS(client->ReadLong(req_id));
+ int32_t status;
+ CHECK_STATUS(client->ReadInt(status));
+ if (status != 0) {
+ uint8_t err_msg_header;
+ CHECK_STATUS(client->ReadByte(err_msg_header));
+ if (err_msg_header == string_val) {
+ int32_t err_msg_length;
+ CHECK_STATUS(client->ReadInt(err_msg_length));
+ uint8_t* err_msg_c = new uint8_t[err_msg_length];
+ CHECK_STATUS(client->ReadData(err_msg_c, err_msg_length));
+ std::string err_msg((char*)err_msg_c, err_msg_length);
+ delete[] err_msg_c;
+
+ return tensorflow::errors::Internal("Close Resource Error [status=",
+ status, ", message=", err_msg, "]");
+ }
+ return tensorflow::errors::Internal("Close Resource Error [status=",
+ status, "]");
+ }
+
+ LOG(INFO) << "Query Cursor " << cursor_id << " is closed";
+
+ cursor_id = -1;
+
+ return client->Disconnect();
+ } else {
+ LOG(INFO) << "Query Cursor " << cursor_id << " is already closed";
+ }
+
+ return client->IsConnected() ? client->Disconnect()
+ : tensorflow::Status::OK();
+}
+
+tensorflow::Status IgniteDatasetIterator::GetNextInternal(
+ tensorflow::IteratorContext* ctx,
+ std::vector<tensorflow::Tensor>* out_tensors, bool* end_of_sequence) {
+ if (remainder == 0 && last_page) {
+ LOG(INFO) << "Query Cursor " << cursor_id << " is closed";
+
+ cursor_id = -1;
+ *end_of_sequence = true;
+ return tensorflow::Status::OK();
+ } else {
+ tensorflow::Status status = EstablishConnection();
+ if (!status.ok()) return status;
+
+ if (remainder == -1 || remainder == 0) {
+ tensorflow::Status status =
+ remainder == -1 ? ScanQuery() : LoadNextPage();
+ if (!status.ok()) return status;
+ }
+
+ uint8_t* initial_ptr = ptr;
+ std::vector<int32_t> types;
+ std::vector<tensorflow::Tensor> tensors;
+
+ status = parser.Parse(ptr, tensors, types); // Parse key
+ if (!status.ok()) return status;
+
+ status = parser.Parse(ptr, tensors, types); // Parse val
+ if (!status.ok()) return status;
+
+ remainder -= (ptr - initial_ptr);
+
+ out_tensors->resize(tensors.size());
+ for (int32_t i = 0; i < tensors.size(); i++)
+ (*out_tensors)[permutation[i]] = std::move(tensors[i]);
+
+ *end_of_sequence = false;
+ return tensorflow::Status::OK();
+ }
+
+ *end_of_sequence = true;
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status IgniteDatasetIterator::SaveInternal(
+ tensorflow::IteratorStateWriter* writer) {
+ return tensorflow::errors::Unimplemented(
+ "Iterator for IgniteDataset does not support 'SaveInternal'");
+}
+
+tensorflow::Status IgniteDatasetIterator::RestoreInternal(
+ tensorflow::IteratorContext* ctx, tensorflow::IteratorStateReader* reader) {
+ return tensorflow::errors::Unimplemented(
+ "Iterator for IgniteDataset does not support 'RestoreInternal')");
+}
+
+tensorflow::Status IgniteDatasetIterator::Handshake() {
+ int32_t msg_len = 8;
+
+ if (username.empty())
+ msg_len += 1;
+ else
+ msg_len += 5 + username.length();
+
+ if (password.empty())
+ msg_len += 1;
+ else
+ msg_len += 5 + password.length();
+
+ CHECK_STATUS(client->WriteInt(msg_len));
+ CHECK_STATUS(client->WriteByte(1));
+ CHECK_STATUS(client->WriteShort(protocol_major_version));
+ CHECK_STATUS(client->WriteShort(protocol_minor_version));
+ CHECK_STATUS(client->WriteShort(protocol_patch_version));
+ CHECK_STATUS(client->WriteByte(2));
+ if (username.empty()) {
+ CHECK_STATUS(client->WriteByte(null_val));
+ } else {
+ CHECK_STATUS(client->WriteByte(string_val));
+ CHECK_STATUS(client->WriteInt(username.length()));
+ CHECK_STATUS(
+ client->WriteData((uint8_t*)username.c_str(), username.length()));
+ }
+
+ if (password.empty()) {
+ CHECK_STATUS(client->WriteByte(null_val));
+ } else {
+ CHECK_STATUS(client->WriteByte(string_val));
+ CHECK_STATUS(client->WriteInt(password.length()));
+ CHECK_STATUS(
+ client->WriteData((uint8_t*)password.c_str(), password.length()));
+ }
+
+ int32_t handshake_res_len;
+ CHECK_STATUS(client->ReadInt(handshake_res_len));
+ uint8_t handshake_res;
+ CHECK_STATUS(client->ReadByte(handshake_res));
+
+ LOG(INFO) << "Handshake length " << handshake_res_len << ", res "
+ << (int16_t)handshake_res;
+
+ if (handshake_res != 1) {
+ int16_t serv_ver_major;
+ CHECK_STATUS(client->ReadShort(serv_ver_major));
+ int16_t serv_ver_minor;
+ CHECK_STATUS(client->ReadShort(serv_ver_minor));
+ int16_t serv_ver_patch;
+ CHECK_STATUS(client->ReadShort(serv_ver_patch));
+ uint8_t header;
+ CHECK_STATUS(client->ReadByte(header));
+
+ if (header == string_val) {
+ int32_t length;
+ CHECK_STATUS(client->ReadInt(length));
+ uint8_t* err_msg_c = new uint8_t[length];
+ CHECK_STATUS(client->ReadData(err_msg_c, length));
+ std::string err_msg((char*)err_msg_c, length);
+ delete[] err_msg_c;
+
+ return tensorflow::errors::Internal(
+ "Handshake Error [result=", handshake_res, ", version=",
+ serv_ver_major, ".", serv_ver_minor, ".", serv_ver_patch,
+ ", message='", err_msg, "']");
+ } else if (header == null_val) {
+ return tensorflow::errors::Internal(
+ "Handshake Error [result=", handshake_res, ", version=",
+ serv_ver_major, ".", serv_ver_minor, ".", serv_ver_patch, "]");
+ } else {
+ return tensorflow::errors::Internal(
+ "Handshake Error [result=", handshake_res, ", version=",
+ serv_ver_major, ".", serv_ver_minor, ".", serv_ver_patch, "]");
+ }
+ }
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status IgniteDatasetIterator::ScanQuery() {
+ CHECK_STATUS(client->WriteInt(25)); // Message length
+ CHECK_STATUS(client->WriteShort(scan_query_opcode)); // Operation code
+ CHECK_STATUS(client->WriteLong(0)); // Request ID
+ CHECK_STATUS(client->WriteInt(JavaHashCode(cache_name))); // Cache name
+ CHECK_STATUS(client->WriteByte(0)); // Flags
+ CHECK_STATUS(client->WriteByte(null_val)); // Filter object
+ CHECK_STATUS(client->WriteInt(page_size)); // Cursor page size
+ CHECK_STATUS(client->WriteInt(part)); // Partition to query
+ CHECK_STATUS(client->WriteByte(local)); // Local flag
+
+ int64_t wait_start = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+
+ int32_t res_len;
+ CHECK_STATUS(client->ReadInt(res_len));
+
+ int64_t wait_stop = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+
+ LOG(INFO) << "Scan Query waited " << (wait_stop - wait_start) << " ms";
+
+ if (res_len < 12)
+ return tensorflow::errors::Internal("Scan Query Response is corrupted");
+
+ int64_t req_id;
+ CHECK_STATUS(client->ReadLong(req_id));
+
+ int32_t status;
+ CHECK_STATUS(client->ReadInt(status));
+
+ if (status != 0) {
+ uint8_t err_msg_header;
+ CHECK_STATUS(client->ReadByte(err_msg_header));
+
+ if (err_msg_header == string_val) {
+ int32_t err_msg_length;
+ CHECK_STATUS(client->ReadInt(err_msg_length));
+
+ uint8_t* err_msg_c = new uint8_t[err_msg_length];
+ CHECK_STATUS(client->ReadData(err_msg_c, err_msg_length));
+ std::string err_msg((char*)err_msg_c, err_msg_length);
+ delete[] err_msg_c;
+
+ return tensorflow::errors::Internal("Scan Query Error [status=", status,
+ ", message=", err_msg, "]");
+ }
+ return tensorflow::errors::Internal("Scan Query Error [status=", status,
+ "]");
+ }
+
+ CHECK_STATUS(client->ReadLong(cursor_id));
+
+ LOG(INFO) << "Query Cursor " << cursor_id << " is opened";
+
+ int32_t row_cnt;
+ CHECK_STATUS(client->ReadInt(row_cnt));
+
+ remainder = res_len - 25;
+ page = std::unique_ptr<uint8_t>(new uint8_t[remainder]);
+ ptr = page.get();
+
+ int64_t start = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+
+ CHECK_STATUS(client->ReadData(ptr, remainder));
+
+ int64_t stop = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+ ;
+
+ double size_in_mb = 1.0 * remainder / 1024 / 1024;
+ double time_in_s = 1.0 * (stop - start) / 1000;
+ LOG(INFO) << "Page size " << size_in_mb << " Mb, time " << time_in_s * 1000
+ << " ms download speed " << size_in_mb / time_in_s << " Mb/sec";
+
+ uint8_t last_page_b;
+ CHECK_STATUS(client->ReadByte(last_page_b));
+
+ last_page = !last_page_b;
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status IgniteDatasetIterator::LoadNextPage() {
+ CHECK_STATUS(client->WriteInt(18)); // Message length
+ CHECK_STATUS(client->WriteShort(load_next_page_opcode)); // Operation code
+ CHECK_STATUS(client->WriteLong(0)); // Request ID
+ CHECK_STATUS(client->WriteLong(cursor_id)); // Cursor ID
+
+ int64_t wait_start = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+
+ int32_t res_len;
+ CHECK_STATUS(client->ReadInt(res_len));
+
+ int64_t wait_stop = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+
+ LOG(INFO) << "Load Next Page waited " << (wait_stop - wait_start) << " ms";
+
+ if (res_len < 12)
+ return tensorflow::errors::Internal("Load Next Page Response is corrupted");
+
+ int64_t req_id;
+ CHECK_STATUS(client->ReadLong(req_id));
+
+ int32_t status;
+ CHECK_STATUS(client->ReadInt(status));
+
+ if (status != 0) {
+ uint8_t err_msg_header;
+ CHECK_STATUS(client->ReadByte(err_msg_header));
+
+ if (err_msg_header == string_val) {
+ int32_t err_msg_length;
+ CHECK_STATUS(client->ReadInt(err_msg_length));
+
+ uint8_t* err_msg_c = new uint8_t[err_msg_length];
+ CHECK_STATUS(client->ReadData(err_msg_c, err_msg_length));
+ std::string err_msg((char*)err_msg_c, err_msg_length);
+ delete[] err_msg_c;
+
+ return tensorflow::errors::Internal("Load Next Page Error [status=",
+ status, ", message=", err_msg, "]");
+ }
+ return tensorflow::errors::Internal("Load Next Page Error [status=", status,
+ "]");
+ }
+
+ int32_t row_cnt;
+ CHECK_STATUS(client->ReadInt(row_cnt));
+
+ remainder = res_len - 17;
+ page = std::unique_ptr<uint8_t>(new uint8_t[remainder]);
+ ptr = page.get();
+
+ int64_t start = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+
+ CHECK_STATUS(client->ReadData(ptr, remainder));
+
+ int64_t stop = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+ ;
+
+ double size_in_mb = 1.0 * remainder / 1024 / 1024;
+ double time_in_s = 1.0 * (stop - start) / 1000;
+ LOG(INFO) << "Page size " << size_in_mb << " Mb, time " << time_in_s * 1000
+ << " ms download speed " << size_in_mb / time_in_s << " Mb/sec";
+
+ uint8_t last_page_b;
+ CHECK_STATUS(client->ReadByte(last_page_b));
+
+ last_page = !last_page_b;
+
+ return tensorflow::Status::OK();
+}
+
+int32_t IgniteDatasetIterator::JavaHashCode(std::string str) {
+ int32_t h = 0;
+ for (char& c : str) {
+ h = 31 * h + c;
+ }
+ return h;
+}
+
+} // namespace ignite
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h
new file mode 100644
index 0000000000..d1df4527f9
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h
@@ -0,0 +1,87 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ignite_binary_object_parser.h"
+#include "ignite_dataset.h"
+
+#ifndef IGNITE_CLIENT_H
+#define IGNITE_CLIENT_H
+#include "ignite_client.h"
+#endif
+
+namespace ignite {
+
+class IgniteDatasetIterator
+ : public tensorflow::DatasetIterator<IgniteDataset> {
+ public:
+ IgniteDatasetIterator(const Params& params, std::string host,
+ tensorflow::int32 port, std::string cache_name,
+ bool local, tensorflow::int32 part,
+ tensorflow::int32 page_size, std::string username,
+ std::string password, std::string certfile,
+ std::string keyfile, std::string cert_password,
+ std::vector<tensorflow::int32> schema,
+ std::vector<tensorflow::int32> permutation);
+ ~IgniteDatasetIterator();
+ tensorflow::Status GetNextInternal(
+ tensorflow::IteratorContext* ctx,
+ std::vector<tensorflow::Tensor>* out_tensors,
+ bool* end_of_sequence) override;
+
+ protected:
+ tensorflow::Status SaveInternal(
+ tensorflow::IteratorStateWriter* writer) override;
+ tensorflow::Status RestoreInternal(
+ tensorflow::IteratorContext* ctx,
+ tensorflow::IteratorStateReader* reader) override;
+
+ private:
+ std::unique_ptr<Client> client;
+ BinaryObjectParser parser;
+
+ const std::string cache_name;
+ const bool local;
+ const tensorflow::int32 part;
+ const tensorflow::int32 page_size;
+ const std::string username;
+ const std::string password;
+ const std::vector<tensorflow::int32> schema;
+ const std::vector<tensorflow::int32> permutation;
+
+ int32_t remainder;
+ int64_t cursor_id;
+ bool last_page;
+
+ std::unique_ptr<uint8_t> page;
+ uint8_t* ptr;
+
+ tensorflow::Status EstablishConnection();
+ tensorflow::Status CloseConnection();
+ tensorflow::Status Handshake();
+ tensorflow::Status ScanQuery();
+ tensorflow::Status LoadNextPage();
+ int32_t JavaHashCode(std::string str);
+};
+
+constexpr uint8_t null_val = 101;
+constexpr uint8_t string_val = 9;
+constexpr uint8_t protocol_major_version = 1;
+constexpr uint8_t protocol_minor_version = 1;
+constexpr uint8_t protocol_patch_version = 0;
+constexpr int16_t scan_query_opcode = 2000;
+constexpr int16_t load_next_page_opcode = 2001;
+constexpr int16_t close_connection_opcode = 0;
+
+} // namespace ignite
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc
new file mode 100644
index 0000000000..543b5e4afc
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc
@@ -0,0 +1,145 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ignite_dataset.h"
+#include <stdlib.h>
+#include "tensorflow/core/framework/dataset.h"
+
+namespace tensorflow {
+
+class IgniteDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ std::string cache_name = "";
+ std::string host = "";
+ int32 port = -1;
+ bool local = false;
+ int32 part = -1;
+ int32 page_size = -1;
+ std::string username = "";
+ std::string password = "";
+ std::string certfile = "";
+ std::string keyfile = "";
+ std::string cert_password = "";
+
+ const char* env_cache_name = std::getenv("IGNITE_DATASET_CACHE_NAME");
+ const char* env_host = std::getenv("IGNITE_DATASET_HOST");
+ const char* env_port = std::getenv("IGNITE_DATASET_PORT");
+ const char* env_local = std::getenv("IGNITE_DATASET_LOCAL");
+ const char* env_part = std::getenv("IGNITE_DATASET_PART");
+ const char* env_page_size = std::getenv("IGNITE_DATASET_PAGE_SIZE");
+ const char* env_username = std::getenv("IGNITE_DATASET_USERNAME");
+ const char* env_password = std::getenv("IGNITE_DATASET_PASSWORD");
+ const char* env_certfile = std::getenv("IGNITE_DATASET_CERTFILE");
+ const char* env_keyfile = std::getenv("IGNITE_DATASET_KEYFILE");
+ const char* env_cert_password = std::getenv("IGNITE_DATASET_CERT_PASSWORD");
+
+ if (env_cache_name)
+ cache_name = std::string(env_cache_name);
+ else
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "cache_name",
+ &cache_name));
+
+ if (env_host)
+ host = std::string(env_host);
+ else
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "host", &host));
+
+ if (env_port)
+ port = atoi(env_port);
+ else
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int32>(ctx, "port", &port));
+
+ if (env_local)
+ local = true;
+ else
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "local", &local));
+
+ if (env_part)
+ part = atoi(env_part);
+ else
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int32>(ctx, "part", &part));
+
+ if (env_page_size)
+ page_size = atoi(env_page_size);
+ else
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int32>(ctx, "page_size", &page_size));
+
+ if (env_username)
+ username = std::string(env_username);
+ else
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<std::string>(ctx, "username", &username));
+
+ if (env_password)
+ password = std::string(env_password);
+ else
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<std::string>(ctx, "password", &password));
+
+ if (env_certfile)
+ certfile = std::string(env_certfile);
+ else
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<std::string>(ctx, "certfile", &certfile));
+
+ if (env_keyfile)
+ keyfile = std::string(env_keyfile);
+ else
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<std::string>(ctx, "keyfile", &keyfile));
+
+ if (env_cert_password)
+ cert_password = std::string(env_cert_password);
+ else
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "cert_password",
+ &cert_password));
+
+ const Tensor* schema_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("schema", &schema_tensor));
+ OP_REQUIRES(ctx, schema_tensor->dims() == 1,
+ errors::InvalidArgument("`schema` must be a vector."));
+
+ std::vector<int32> schema;
+ schema.reserve(schema_tensor->NumElements());
+ for (int i = 0; i < schema_tensor->NumElements(); i++) {
+ schema.push_back(schema_tensor->flat<int32>()(i));
+ }
+
+ const Tensor* permutation_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("permutation", &permutation_tensor));
+ OP_REQUIRES(ctx, schema_tensor->dims() == 1,
+ errors::InvalidArgument("`permutation` must be a vector."));
+
+ std::vector<int32> permutation;
+ permutation.reserve(permutation_tensor->NumElements());
+ for (int i = 0; i < permutation_tensor->NumElements(); i++) {
+ permutation.push_back(permutation_tensor->flat<int32>()(i));
+ }
+
+ *output = new ignite::IgniteDataset(
+ ctx, cache_name, host, port, local, part, page_size, username, password,
+ certfile, keyfile, cert_password, std::move(schema),
+ std::move(permutation));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("IgniteDataset").Device(DEVICE_CPU),
+ IgniteDatasetOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client.h b/tensorflow/contrib/ignite/kernels/ignite_plain_client.h
new file mode 100644
index 0000000000..5491af68d6
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client.h
@@ -0,0 +1,43 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef IGNITE_CLIENT_H
+#define IGNITE_CLIENT_H
+#include "ignite_client.h"
+#endif
+
+#include <string>
+
+namespace ignite {
+
+class PlainClient : public Client {
+ public:
+ PlainClient(std::string host, int port);
+ ~PlainClient();
+
+ virtual tensorflow::Status Connect();
+ virtual tensorflow::Status Disconnect();
+ virtual bool IsConnected();
+ virtual int GetSocketDescriptor();
+ virtual tensorflow::Status ReadData(uint8_t* buf, int32_t length);
+ virtual tensorflow::Status WriteData(uint8_t* buf, int32_t length);
+
+ private:
+ std::string host;
+ int port;
+ int sock;
+};
+
+} // namespace ignite
diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc b/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc
new file mode 100644
index 0000000000..dbfa4f8786
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc
@@ -0,0 +1,132 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ignite_plain_client.h"
+
+#include <arpa/inet.h>
+#include <netdb.h>
+#include <sys/socket.h>
+
+#include <arpa/inet.h>
+#include <sys/socket.h>
+#include <unistd.h>
+#include <map>
+
+#include <iostream>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace ignite {
+
+PlainClient::PlainClient(std::string host, int port)
+ : host(host), port(port), sock(-1) {}
+
+PlainClient::~PlainClient() {
+ if (IsConnected()) {
+ tensorflow::Status status = Disconnect();
+ if (!status.ok()) LOG(WARNING) << status.ToString();
+ }
+}
+
+tensorflow::Status PlainClient::Connect() {
+ if (sock == -1) {
+ sock = socket(AF_INET, SOCK_STREAM, 0);
+ if (sock == -1)
+ return tensorflow::errors::Internal("Failed to create socket");
+ }
+
+ sockaddr_in server;
+
+ server.sin_addr.s_addr = inet_addr(host.c_str());
+ if (server.sin_addr.s_addr == -1) {
+ hostent* he;
+ in_addr** addr_list;
+
+ if ((he = gethostbyname(host.c_str())) == NULL)
+ return tensorflow::errors::Internal("Failed to resolve hostname \"", host,
+ "\"");
+
+ addr_list = (in_addr**)he->h_addr_list;
+ if (addr_list[0] != NULL) server.sin_addr = *addr_list[0];
+ }
+
+ server.sin_family = AF_INET;
+ server.sin_port = htons(port);
+
+ if (connect(sock, (sockaddr*)&server, sizeof(server)) < 0)
+ return tensorflow::errors::Internal("Failed to connect to \"", host, ":",
+ port, "\"");
+
+ LOG(INFO) << "Connection to \"" << host << ":" << port << "\" established";
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status PlainClient::Disconnect() {
+ int close_res = close(sock);
+ sock = -1;
+
+ LOG(INFO) << "Connection to \"" << host << ":" << port << "\" is closed";
+
+ return close_res == 0 ? tensorflow::Status::OK()
+ : tensorflow::errors::Internal(
+ "Failed to correctly close connection");
+}
+
+bool PlainClient::IsConnected() { return sock != -1; }
+
+int PlainClient::GetSocketDescriptor() { return sock; }
+
+tensorflow::Status PlainClient::ReadData(uint8_t* buf, int32_t length) {
+ int recieved = 0;
+
+ while (recieved < length) {
+ int res = recv(sock, buf, length - recieved, 0);
+
+ if (res < 0)
+ return tensorflow::errors::Internal(
+ "Error occured while reading from socket: ", res, ", ",
+ std::string(strerror(errno)));
+
+ if (res == 0)
+ return tensorflow::errors::Internal("Server closed connection");
+
+ recieved += res;
+ buf += res;
+ }
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status PlainClient::WriteData(uint8_t* buf, int32_t length) {
+ int sent = 0;
+
+ while (sent < length) {
+ int res = send(sock, buf, length - sent, 0);
+
+ if (res < 0)
+ return tensorflow::errors::Internal(
+ "Error occured while writing into socket: ", res, ", ",
+ std::string(strerror(errno)));
+
+ sent += res;
+ buf += res;
+ }
+
+ return tensorflow::Status::OK();
+}
+
+} // namespace ignite
diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc b/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc
new file mode 100644
index 0000000000..f78c9b3627
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc
@@ -0,0 +1,143 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ignite_plain_client.h"
+
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#include <winsock2.h>
+#include <ws2tcpip.h>
+
+#pragma comment(lib, "Ws2_32.lib")
+#pragma comment(lib, "Mswsock.lib")
+#pragma comment(lib, "AdvApi32.lib")
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace ignite {
+
+PlainClient::PlainClient(std::string host, int port)
+ : host(host), port(port), sock(INVALID_SOCKET) {}
+
+PlainClient::~PlainClient() {
+ if (IsConnected()) {
+ tensorflow::Status status = Disconnect();
+ if (!status.ok()) LOG(WARNING) << status.ToString();
+ }
+}
+
+tensorflow::Status PlainClient::Connect() {
+ WSADATA wsaData;
+ addrinfo *result = NULL, *ptr = NULL, hints;
+
+ int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
+ if (res != 0)
+ return tensorflow::errors::Internal("WSAStartup failed with error: ", res);
+
+ ZeroMemory(&hints, sizeof(hints));
+ hints.ai_family = AF_UNSPEC;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_protocol = IPPROTO_TCP;
+
+ res =
+ getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &result);
+ if (res != 0)
+ return tensorflow::errors::Internal("Getaddrinfo failed with error: ", res);
+
+ for (ptr = result; ptr != NULL; ptr = ptr->ai_next) {
+ sock = socket(ptr->ai_family, ptr->ai_socktype, ptr->ai_protocol);
+ if (sock == INVALID_SOCKET) {
+ WSACleanup();
+ return tensorflow::errors::Internal("Socket failed with error: ",
+ WSAGetLastError());
+ }
+
+ res = connect(sock, ptr->ai_addr, (int)ptr->ai_addrlen);
+ if (res == SOCKET_ERROR) {
+ closesocket(sock);
+ sock = INVALID_SOCKET;
+ continue;
+ }
+
+ break;
+ }
+
+ freeaddrinfo(result);
+
+ if (sock == INVALID_SOCKET) {
+ WSACleanup();
+ return tensorflow::errors::Internal("Unable to connect to server");
+ }
+
+ LOG(INFO) << "Connection to \"" << host << ":" << port << "\" established";
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status PlainClient::Disconnect() {
+ int res = shutdown(sock, SD_SEND);
+ closesocket(sock);
+ WSACleanup();
+
+ if (res == SOCKET_ERROR)
+ return tensorflow::errors::Internal("Shutdown failed with error: ",
+ WSAGetLastError());
+ else
+ return tensorflow::Status::OK();
+}
+
+bool PlainClient::IsConnected() { return sock != INVALID_SOCKET; }
+
+int PlainClient::GetSocketDescriptor() { return sock; }
+
+tensorflow::Status PlainClient::ReadData(uint8_t *buf, int32_t length) {
+ int recieved = 0;
+
+ while (recieved < length) {
+ int res = recv(sock, buf, length - recieved, 0);
+
+ if (res < 0)
+ return tensorflow::errors::Internal(
+ "Error occured while reading from socket: ", res);
+
+ if (res == 0)
+ return tensorflow::errors::Internal("Server closed connection");
+
+ recieved += res;
+ buf += res;
+ }
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status PlainClient::WriteData(uint8_t *buf, int32_t length) {
+ int sent = 0;
+
+ while (sent < length) {
+ int res = send(sock, buf, length - sent, 0);
+
+ if (res < 0)
+ return tensorflow::errors::Internal(
+ "Error occured while writing into socket: ", res);
+
+ sent += res;
+ buf += res;
+ }
+
+ return tensorflow::Status::OK();
+}
+
+} // namespace ignite
diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc
new file mode 100644
index 0000000000..a1101b91f3
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc
@@ -0,0 +1,149 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ignite_ssl_wrapper.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+
+namespace ignite {
+
+static int PasswordCb(char *buf, int size, int rwflag, void *password) {
+ strncpy(buf, (char *)(password), size);
+ buf[size - 1] = '\0';
+ return (strlen(buf));
+}
+
+SslWrapper::SslWrapper(std::shared_ptr<Client> client, std::string certfile,
+ std::string keyfile, std::string cert_password)
+ : client(client),
+ certfile(certfile),
+ keyfile(keyfile),
+ cert_password(cert_password),
+ ctx(NULL) {}
+
+SslWrapper::~SslWrapper() {
+ if (IsConnected()) {
+ tensorflow::Status status = Disconnect();
+ if (!status.ok()) LOG(WARNING) << status.ToString();
+ }
+
+ if (ctx != NULL) {
+ SSL_CTX_free(ctx);
+ ctx = NULL;
+ }
+}
+
+tensorflow::Status SslWrapper::InitSslContext() {
+ OpenSSL_add_all_algorithms();
+ SSL_load_error_strings();
+
+ ctx = SSL_CTX_new(SSLv23_method());
+ if (ctx == NULL)
+ return tensorflow::errors::Internal("Couldn't create SSL context");
+
+ SSL_CTX_set_default_passwd_cb(ctx, PasswordCb);
+ SSL_CTX_set_default_passwd_cb_userdata(ctx, (void *)cert_password.c_str());
+
+ if (SSL_CTX_use_certificate_chain_file(ctx, certfile.c_str()) != 1)
+ return tensorflow::errors::Internal(
+ "Couldn't load cetificate chain (file '", certfile, "')");
+
+ std::string private_key_file = keyfile.empty() ? certfile : keyfile;
+ if (SSL_CTX_use_PrivateKey_file(ctx, private_key_file.c_str(),
+ SSL_FILETYPE_PEM) != 1)
+ return tensorflow::errors::Internal("Couldn't load private key (file '",
+ private_key_file, "')");
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status SslWrapper::Connect() {
+ tensorflow::Status status;
+
+ if (ctx == NULL) {
+ status = InitSslContext();
+ if (!status.ok()) return status;
+ }
+
+ ssl = SSL_new(ctx);
+ if (ssl == NULL)
+ return tensorflow::errors::Internal("Failed to establish SSL connection");
+
+ status = client->Connect();
+ if (!status.ok()) return status;
+
+ SSL_set_fd(ssl, client->GetSocketDescriptor());
+ if (SSL_connect(ssl) != 1)
+ return tensorflow::errors::Internal("Failed to establish SSL connection");
+
+ LOG(INFO) << "SSL connection established";
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status SslWrapper::Disconnect() {
+ SSL_free(ssl);
+
+ LOG(INFO) << "SSL connection closed";
+
+ return client->Disconnect();
+}
+
+bool SslWrapper::IsConnected() { return client->IsConnected(); }
+
+int SslWrapper::GetSocketDescriptor() { return client->GetSocketDescriptor(); }
+
+tensorflow::Status SslWrapper::ReadData(uint8_t *buf, int32_t length) {
+ int recieved = 0;
+
+ while (recieved < length) {
+ int res = SSL_read(ssl, buf, length - recieved);
+
+ if (res < 0)
+ return tensorflow::errors::Internal(
+ "Error occured while reading from SSL socket: ", res);
+
+ if (res == 0)
+ return tensorflow::errors::Internal("Server closed SSL connection");
+
+ recieved += res;
+ buf += res;
+ }
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status SslWrapper::WriteData(uint8_t *buf, int32_t length) {
+ int sent = 0;
+
+ while (sent < length) {
+ int res = SSL_write(ssl, buf, length - sent);
+
+ if (res < 0)
+ return tensorflow::errors::Internal(
+ "Error occured while writing into socket: ", res);
+
+ sent += res;
+ buf += res;
+ }
+
+ return tensorflow::Status::OK();
+}
+
+} // namespace ignite
diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h
new file mode 100644
index 0000000000..e0c2a242dc
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef IGNITE_CLIENT_H
+#define IGNITE_CLIENT_H
+#include "ignite_client.h"
+#endif
+
+#include <openssl/ssl.h>
+#include <string>
+
+namespace ignite {
+
+class SslWrapper : public Client {
+ public:
+ SslWrapper(std::shared_ptr<Client> client, std::string certfile,
+ std::string keyfile, std::string cert_password);
+ ~SslWrapper();
+
+ virtual tensorflow::Status Connect();
+ virtual tensorflow::Status Disconnect();
+ virtual bool IsConnected();
+ virtual int GetSocketDescriptor();
+ virtual tensorflow::Status ReadData(uint8_t* buf, int32_t length);
+ virtual tensorflow::Status WriteData(uint8_t* buf, int32_t length);
+
+ private:
+ std::shared_ptr<Client> client;
+ std::string certfile;
+ std::string keyfile;
+ std::string cert_password;
+ SSL_CTX* ctx;
+ SSL* ssl;
+ tensorflow::Status InitSslContext();
+};
+
+} // namespace ignite
diff --git a/tensorflow/contrib/ignite/ops/dataset_ops.cc b/tensorflow/contrib/ignite/ops/dataset_ops.cc
new file mode 100644
index 0000000000..17494d1cfd
--- /dev/null
+++ b/tensorflow/contrib/ignite/ops/dataset_ops.cc
@@ -0,0 +1,64 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("IgniteDataset")
+ .Input("cache_name: string")
+ .Input("host: string")
+ .Input("port: int32")
+ .Input("local: bool")
+ .Input("part: int32")
+ .Input("page_size: int32")
+ .Input("username: string")
+ .Input("password: string")
+ .Input("certfile: string")
+ .Input("keyfile: string")
+ .Input("cert_password: string")
+ .Input("schema: int32")
+ .Input("permutation: int32")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Apache Ignite is a memory-centric distributed database, caching, and processing
+platform for transactional, analytical, and streaming workloads, delivering
+in-memory speeds at petabyte scale. This contrib package contains an
+integration between Apache Ignite and TensorFlow. The integration is based on
+tf.data from TensorFlow side and Binary Client Protocol from Apache Ignite side.
+It allows to use Apache Ignite as a datasource for neural network training,
+inference and all other computations supported by TensorFlow. Ignite Dataset
+is based on Apache Ignite Binary Client Protocol.
+
+cache_name: Ignite Cache Name.
+host: Ignite Thin Client Host.
+port: Ignite Thin Client Port.
+local: Local flag that defines that data should be fetched from local host only.
+part: Partition data should be fetched from.
+page_size: Page size for Ignite Thin Client.
+username: Username to authenticate via Ignite Thin Client.
+password: Password to authenticate via Ignite Thin Client.
+certfile: SSL certificate to establish SSL connection.
+keyfile: Private key file to establish SSL connection.
+cert_password: SSL certificate password to establish SSL connection.
+schema: Internal structure that defines schema of cache objects.
+permutation: Internal structure that defines permutation of cache objects.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py
new file mode 100644
index 0000000000..6fa073957a
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py
@@ -0,0 +1,763 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Ignite Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import socket
+import struct
+import ssl
+import abc
+
+from tensorflow.contrib.ignite.python.ops import ignite_op_loader # pylint: disable=unused-import
+from tensorflow.contrib.ignite.python.ops import gen_dataset_ops
+from tensorflow.python.data.ops.dataset_ops import Dataset
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+class Readable():
+ """Readable abstract class that exposes methods to do reading-related
+ operations.
+ """
+
+ @abc.abstractmethod
+ def __init__(self):
+ pass
+
+ def read_byte(self):
+ """Reads and returnes byte."""
+ return self.__read("b", 1)
+
+ def read_short(self):
+ """Reads and returns short (2 bytes, little-endian)."""
+ return self.__read("h", 2)
+
+ def read_int(self):
+ """Reads and returns int (4 bytes, little-endian)."""
+ return self.__read("i", 4)
+
+ def read_long(self):
+ """Reads and returns long (8 bytes, little-endian)."""
+ return self.__read("q", 8)
+
+ def skip(self, length):
+ """Skips the specified number of bytes."""
+ self.read_data(length)
+
+ @abc.abstractmethod
+ def read_data(self, length):
+ """Reads the specified number of bytes and returns them as a buffer."""
+ return None
+
+ def __read(self, data_type, length):
+ """Reads, unpacks and returns specified type (little-endian)."""
+ buffer = self.read_data(length)
+ return struct.unpack("<" + data_type, buffer)[0]
+
+class DataBuffer(Readable):
+ """DataBuffer class that exposes methods to read data from a byte buffer."""
+
+ def __init__(self, buffer):
+ """Constructs a new instance of DataBuffer based on the specified byte
+ buffer.
+
+ Args:
+ buffer: Buffer to be read.
+ """
+ Readable.__init__(self)
+ self.buffer = buffer
+ self.ptr = 0
+
+ def read_data(self, length):
+ """Reads the specified number of bytes and returns them as a buffer."""
+ data_buffer = self.buffer[self.ptr:][:length]
+ self.ptr += length
+ return data_buffer
+
+class TcpClient(Readable):
+ """TcpClient class that exposes methods to read data from a socket."""
+
+ def __init__(self, host, port, certfile=None, keyfile=None, password=None):
+ """Constructs a new instance of TcpClient based on the specified host
+ and port.
+
+ Args:
+ host: Host to be connected.
+ port: Port to be connected.
+ certfile: File in PEM format containing the certificate as well as any
+ number of CA certificates needed to establish the certificate’s
+ authenticity.
+ keyfile: File containing the private key (otherwise the private key
+ will be taken from certfile as well).
+ password: Password to be used if the private key is encrypted and a
+ password is necessary.
+ """
+ Readable.__init__(self)
+ self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+ if certfile is not None:
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.load_cert_chain(certfile, keyfile, password)
+ self.sock = context.wrap_socket(self.sock)
+ else:
+ if keyfile is not None:
+ raise Exception("SSL is disabled, keyfile must not be specified \
+ (to enable SSL specify certfile)")
+ if password is not None:
+ raise Exception("SSL is disabled, password must not be specified \
+ (to enable SSL specify certfile)")
+
+ self.host = host
+ self.port = port
+
+ def __enter__(self):
+ """Connects to host and port specified in the constructor."""
+ self.sock.connect((self.host, self.port))
+ return self
+
+ def __exit__(self, t, v, traceback):
+ """Disconnects the socket."""
+ self.sock.close()
+
+ def write_byte(self, v):
+ """Writes the specified byte."""
+ self.__write(v, "b")
+
+ def write_short(self, v):
+ """Writes the specified short (2 bytes, little-endian)."""
+ self.__write(v, "h")
+
+ def write_int(self, v):
+ """Writes the specified short (4 bytes, little-endian)."""
+ self.__write(v, "i")
+
+ def write_long(self, v):
+ """Writes the specified int (8 bytes, little-endian)."""
+ self.__write(v, "q")
+
+ def write_string(self, v):
+ """Writes the specified string."""
+ self.sock.sendall(v.encode("UTF-8"))
+
+ def read_data(self, length):
+ """Reads the specified number of bytes and returns them as a buffer."""
+ data_buffer = None
+ rem = length
+ while rem > 0:
+ buf = self.sock.recv(rem)
+ rem = rem - len(buf)
+ if data_buffer is None:
+ data_buffer = buf
+ else:
+ data_buffer += buf
+ return data_buffer
+
+ def __write(self, value, data_type):
+ """Packs and writes data using the specified type (little-endian)."""
+ data_buffer = struct.pack("<" + data_type, value)
+ self.sock.sendall(data_buffer)
+
+class BinaryType():
+ """BinaryType class that encapsulated type id, type name and fields."""
+
+ def __init__(self, type_id, type_name, fields):
+ """Constructs a new instance of BinaryType."""
+ self.type_id = type_id
+ self.type_name = type_name
+ self.fields = fields
+
+class BinaryField():
+ """BinaryField class that encapsulated field name, type id and field id."""
+
+ def __init__(self, field_name, type_id, field_id):
+ """Constructs a new instance of BinaryField."""
+ self.field_name = field_name
+ self.type_id = type_id
+ self.field_id = field_id
+
+# Binary types defined in Apache Ignite Thin client and supported by
+# TensorFlow on Apache Ignite, see
+# https://apacheignite.readme.io/v2.6/docs/binary-client-protocol.
+types = {
+ 1: (dtypes.uint8, False),
+ 2: (dtypes.int16, False),
+ 3: (dtypes.int32, False),
+ 4: (dtypes.int64, False),
+ 5: (dtypes.float32, False),
+ 6: (dtypes.float64, False),
+ 7: (dtypes.uint16, False),
+ 8: (dtypes.bool, False),
+ 9: (dtypes.string, False),
+ 12: (dtypes.uint8, True),
+ 13: (dtypes.int16, True),
+ 14: (dtypes.int32, True),
+ 15: (dtypes.int64, True),
+ 16: (dtypes.float32, True),
+ 17: (dtypes.float64, True),
+ 18: (dtypes.uint16, True),
+ 19: (dtypes.bool, True),
+ 20: (dtypes.string, True)
+}
+
+class TypeTreeNode():
+ """TypeTreeNode class exposes methods to format object tree structure
+ data.
+ """
+ def __init__(self, name, type_id, fields=None, permutation=None):
+ """Constructs a new instance of TypeTreeNode.
+
+ Args:
+ name: Name of the object tree node.
+ type_id: Type id of the object tree node.
+ fields: List of fields (children of the object tree node).
+ permutation: Permutation that should be applied to order object children.
+ """
+ self.name = name
+ self.type_id = type_id
+ self.fields = fields
+ self.permutation = permutation
+
+ def to_output_classes(self):
+ """Formats the tree object the way required in 'output_classes' property of
+ dataset.
+ """
+ if self.fields is None:
+ return ops.Tensor
+ output_classes = {}
+ for field in self.fields:
+ output_classes[field.name] = field.to_output_classes()
+ return output_classes
+
+ def to_output_shapes(self):
+ """Formats the tree object the way required in 'output_shapes' property of
+ dataset.
+ """
+ if self.fields is None:
+ object_type = types[self.type_id]
+ if object_type is not None:
+ is_array = object_type[1]
+ if is_array:
+ return tensor_shape.TensorShape([None])
+ return tensor_shape.TensorShape([])
+ raise Exception("Unsupported type [type_id=%d]" % self.type_id)
+ output_shapes = {}
+ for field in self.fields:
+ output_shapes[field.name] = field.to_output_shapes()
+ return output_shapes
+
+ def to_output_types(self):
+ """Formats the tree object the way required in 'output_types' property of
+ dataset.
+ """
+ if self.fields is None:
+ object_type = types[self.type_id]
+ if object_type is not None:
+ return object_type[0]
+ raise Exception("Unsupported type [type_id=%d]" % self.type_id)
+ else:
+ output_types = {}
+ for field in self.fields:
+ output_types[field.name] = field.to_output_types()
+ return output_types
+
+ def to_flat(self):
+ """Returns a list of leaf node types."""
+ return self.to_flat_rec([])
+
+ def to_permutation(self):
+ """Returns a permutation that should be applied to order object leafs."""
+ correct_order_dict = {}
+ self.traversal_rec(correct_order_dict, 0)
+ object_order = []
+ self.traversal_permutation_rec(object_order)
+ return [correct_order_dict[o] for o in object_order]
+
+ def to_flat_rec(self, flat):
+ """Formats a list of leaf node types."""
+ flat.append(self.type_id)
+ if self.fields is not None:
+ for field in self.fields:
+ field.to_flat_rec(flat)
+ return flat
+
+ def traversal_permutation_rec(self, permutation):
+ """Collects nodes in accordance with permutation."""
+ if self.fields is None:
+ permutation.append(self)
+ else:
+ for idx in self.permutation:
+ field = self.fields[idx]
+ field.traversal_permutation_rec(permutation)
+
+ def traversal_rec(self, d, i):
+ """Collects nodes in pre-order traversal."""
+ if self.fields is None:
+ d[self] = i
+ i += 1
+ else:
+ for field in self.fields:
+ i = field.traversal_rec(d, i)
+ return i
+
+class IgniteClient(TcpClient):
+ """IgniteClient class exposes methods to work with Apache Ignite using Thin
+ client. This client works with assumption that all object in the cache
+ have the same structure (homogeneous objects) and the cache contains at
+ least one object.
+ """
+ def __init__(self, host, port, username=None, password=None, certfile=None,\
+ keyfile=None, cert_password=None):
+ """Constructs a new instance of IgniteClient.
+
+ Args:
+ host: Apache Ignite Thin client host to be connected.
+ port: Apache Ignite Thin client port to be connected.
+ username: Apache Ignite Thin Client authentication username.
+ password: Apache Ignite Thin Client authentication password.
+ certfile: File in PEM format containing the certificate as well as
+ any number of CA certificates needed to establish the certificate’s
+ authenticity.
+ keyfile: File containing the private key (otherwise the private key
+ will be taken from certfile as well).
+ cert_password: Password to be used if the private key is encrypted and a
+ password is necessary.
+ """
+ TcpClient.__init__(self, host, port, certfile, keyfile, cert_password)
+ self.username = username
+ self.password = password
+
+ def handshake(self):
+ """Makes a handshake required to be made after connect before any other
+ calls.
+ """
+ msg_len = 8
+
+ if self.username is None:
+ msg_len += 1
+ else:
+ msg_len += 5 + len(self.username)
+
+ if self.password is None:
+ msg_len += 1
+ else:
+ msg_len += 5 + len(self.password)
+
+ self.write_int(msg_len) # Message length
+ self.write_byte(1) # Handshake operation
+ self.write_short(1) # Version (1.1.0)
+ self.write_short(1)
+ self.write_short(0)
+ self.write_byte(2) # Thin client
+
+ if self.username is None: # Username
+ self.write_byte(101)
+ else:
+ self.write_byte(9)
+ self.write_int(len(self.username))
+ self.write_string(self.username)
+
+ if self.password is None: # Password
+ self.write_byte(101)
+ else:
+ self.write_byte(9)
+ self.write_int(len(self.password))
+ self.write_string(self.password)
+
+ self.read_int() # Result length
+ res = self.read_byte()
+
+ if res != 1:
+ serv_ver_major = self.read_short()
+ serv_ver_minor = self.read_short()
+ serv_ver_patch = self.read_short()
+ err_msg = self.__parse_string()
+ if err_msg is None:
+ raise Exception("Handshake Error [result=%d, version=%d.%d.%d]" \
+ % (res, serv_ver_major, serv_ver_minor, serv_ver_patch))
+ else:
+ raise Exception("Handshake Error [result=%d, version=%d.%d.%d, \
+ message='%s']" % (
+ res,
+ serv_ver_major,
+ serv_ver_minor,
+ serv_ver_patch,
+ err_msg
+ ))
+
+ def get_cache_type(self, cache_name):
+ """Collects type information about objects stored in the specified
+ cache.
+ """
+ cache_name_hash = self.__java_hash_code(cache_name)
+ self.write_int(25) # Message length
+ self.write_short(2000) # Operation code
+ self.write_long(0) # Request ID
+ self.write_int(cache_name_hash) # Cache name
+ self.write_byte(0) # Flags
+ self.write_byte(101) # Filter (NULL)
+ self.write_int(1) # Cursor page size
+ self.write_int(-1) # Partition to query
+ self.write_byte(0) # Local flag
+
+ result_length = self.read_int()
+ self.read_long() # Request id
+ status = self.read_int()
+
+ if status != 0:
+ err_msg = self.__parse_string()
+ if err_msg is None:
+ raise Exception("Scan Query Error [status=%s]" % status)
+ else:
+ raise Exception("Scan Query Error [status=%s, message='%s']" \
+ % (status, err_msg))
+
+ self.read_long() # Cursor id
+ row_count = self.read_int()
+
+ if row_count == 0:
+ raise Exception("Scan Query returned empty result, so it's \
+ impossible to derive the cache type")
+
+ payload = DataBuffer(self.read_data(result_length - 25))
+
+ self.read_byte() # Next page
+
+ res = TypeTreeNode("root", 0, [
+ self.__collect_types("key", payload),
+ self.__collect_types("val", payload)
+ ], [0, 1])
+
+ return res
+
+ def __java_hash_code(self, s):
+ """Computes hash code of the specified string using Java code."""
+ h = 0
+ for c in s:
+ h = (31 * h + ord(c)) & 0xFFFFFFFF
+ return ((h + 0x80000000) & 0xFFFFFFFF) - 0x80000000
+
+ def __collect_types(self, field_name, data):
+ """Extracts type information from the specified object."""
+ type_id = data.read_byte()
+
+ # Byte scalar.
+ if type_id == 1:
+ data.skip(1)
+ return TypeTreeNode(field_name, type_id)
+
+ # Short scalar.
+ if type_id == 2:
+ data.skip(2)
+ return TypeTreeNode(field_name, type_id)
+
+ # Integer scalar.
+ if type_id == 3:
+ data.skip(4)
+ return TypeTreeNode(field_name, type_id)
+
+ # Long scalar.
+ if type_id == 4:
+ data.skip(8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Float scalar.
+ if type_id == 5:
+ data.skip(4)
+ return TypeTreeNode(field_name, type_id)
+
+ # Double scalar.
+ if type_id == 6:
+ data.skip(8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Char scalar.
+ if type_id == 7:
+ data.skip(2)
+ return TypeTreeNode(field_name, type_id)
+
+ # Bool scalar.
+ if type_id == 8:
+ data.skip(1)
+ return TypeTreeNode(field_name, type_id)
+
+ # String scalar.
+ if type_id == 9:
+ length = data.read_int()
+ data.skip(length)
+ return TypeTreeNode(field_name, type_id)
+
+ # UUID scalar.
+ if type_id == 10:
+ data.skip(16)
+ return TypeTreeNode(field_name, type_id)
+
+ # Date scalar.
+ if type_id == 11:
+ data.skip(8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Byte array.
+ if type_id == 12:
+ length = data.read_int()
+ data.skip(length)
+ return TypeTreeNode(field_name, type_id)
+
+ # Short array.
+ if type_id == 13:
+ length = data.read_int()
+ data.skip(length * 2)
+ return TypeTreeNode(field_name, type_id)
+
+ # Integer array.
+ if type_id == 14:
+ length = data.read_int()
+ data.skip(length * 4)
+ return TypeTreeNode(field_name, type_id)
+
+ # Long array.
+ if type_id == 15:
+ length = data.read_int()
+ data.skip(length * 8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Float array.
+ if type_id == 16:
+ length = data.read_int()
+ data.skip(length * 4)
+ return TypeTreeNode(field_name, type_id)
+
+ # Double array.
+ if type_id == 17:
+ length = data.read_int()
+ data.skip(length * 8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Char array.
+ if type_id == 18:
+ length = data.read_int()
+ data.skip(length * 2)
+ return TypeTreeNode(field_name, type_id)
+
+ # Bool array.
+ if type_id == 19:
+ length = data.read_int()
+ data.skip(length)
+ return TypeTreeNode(field_name, type_id)
+
+ # String array.
+ if type_id == 20:
+ length = data.read_int()
+ for _ in range(length):
+ header = data.read_byte()
+ if header == 9:
+ str_length = data.read_int()
+ data.skip(str_length)
+ elif header == 101:
+ pass
+ else:
+ raise Exception("Unknown binary type when expected string \
+ [type_id=%d]" % header)
+ return TypeTreeNode(field_name, type_id)
+
+ # UUID array.
+ if type_id == 21:
+ length = data.read_int()
+ data.skip(length * 16) # TODO: support NULL values.
+ return TypeTreeNode(field_name, type_id)
+
+ # Date array.
+ if type_id == 22:
+ length = data.read_int()
+ data.skip(length * 8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Wrapped Binary Object.
+ if type_id == 27:
+ length = data.read_int()
+ inner_data = data.read_data(length)
+ data.read_int() # Offset
+ return self.__collect_types(field_name, DataBuffer(inner_data))
+
+ # Complex Object.
+ if type_id == 103:
+ data.read_byte() # Object version
+ data.read_short() # Object flags
+ obj_type_id = data.read_int()
+ data.read_int() # Object hash code
+ obj_length = data.read_int()
+ data.read_int() # Object schema id
+ obj_schema_offset = data.read_int()
+
+ obj_type = self.__get_type(obj_type_id)
+ children = []
+
+ for obj_field in obj_type.fields:
+ child = self.__collect_types(obj_field.field_name, data)
+ children.append(child)
+
+ children_sorted = sorted(children, key=lambda child: child.name)
+ permutation = [children_sorted.index(child) for child in children]
+ children = children_sorted
+
+ data.skip(obj_length - obj_schema_offset)
+
+ return TypeTreeNode(field_name, type_id, children, permutation)
+
+ raise Exception("Unknown binary type [type_id=%d]" % type_id)
+
+ def __get_type(self, type_id):
+ """Queries Apache Ignite information about type by type id."""
+ self.write_int(14) # Message length
+ self.write_short(3002) # Operation code
+ self.write_long(0) # Request ID
+ self.write_int(type_id) # Type ID
+
+ self.read_int() # Result length
+ self.read_long() # Request id
+ status = self.read_int()
+
+ if status != 0:
+ err_msg = self.__parse_string()
+ if err_msg is None:
+ raise Exception("Get Binary Type Error [status=%d, message='%s']" \
+ % (status, err_msg))
+ else:
+ raise Exception("Get Binary Type Error [status=%d]" % status)
+
+ binary_type_exists = self.read_byte()
+
+ if binary_type_exists == 0:
+ raise Exception("Binary type not found [type_id=%d] " % type_id)
+
+ binary_type_id = self.read_int()
+ binary_type_name = self.__parse_string()
+ self.__parse_string() # Affinity field name
+
+ fields = []
+ for _ in range(self.read_int()):
+ field_name = self.__parse_string()
+ field_type_id = self.read_int()
+ field_id = self.read_int()
+
+ field = BinaryField(field_name, field_type_id, field_id)
+ fields.append(field)
+
+ is_enum = self.read_byte()
+ if is_enum == 1:
+ raise Exception("Enum fields are not supported yet")
+
+ schema_cnt = self.read_int()
+ for _ in range(schema_cnt):
+ self.read_int() # Schema id
+ field_cnt = self.read_int()
+ self.skip(field_cnt * 4)
+
+ return BinaryType(binary_type_id, binary_type_name, fields)
+
+ def __parse_string(self):
+ """Parses string."""
+ header = self.read_byte()
+ if header == 9:
+ length = self.read_int()
+ return self.read_data(length).decode("utf-8")
+ if header == 101:
+ return None
+ raise Exception("Unknown binary type when expected string [type_id=%d]" \
+ % header)
+
+class IgniteDataset(Dataset):
+ """Apache Ignite is a memory-centric distributed database, caching, and
+ processing platform for transactional, analytical, and streaming workloads,
+ delivering in-memory speeds at petabyte scale. This contrib package
+ contains an integration between Apache Ignite and TensorFlow. The
+ integration is based on tf.data from TensorFlow side and Binary Client
+ Protocol from Apache Ignite side. It allows to use Apache Ignite as a
+ datasource for neural network training, inference and all other
+ computations supported by TensorFlow. Ignite Dataset is based on Apache
+ Ignite Binary Client Protocol.
+ """
+
+ def __init__(self, cache_name, host="localhost", port=10800, local=False,\
+ part=-1, page_size=100, username=None, password=None, certfile=None,\
+ keyfile=None, cert_password=None):
+ """Create a IgniteDataset.
+
+ Args:
+ cache_name: Cache name to be used as datasource.
+ host: Apache Ignite Thin Client host to be connected.
+ port: Apache Ignite Thin Client port to be connected.
+ local: Local flag that defines to query only local data.
+ part: Number of partitions to be queried.
+ page_size: Apache Ignite Thin Client page size.
+ username: Apache Ignite Thin Client authentication username.
+ password: Apache Ignite Thin Client authentication password.
+ certfile: File in PEM format containing the certificate as well as
+ any number of CA certificates needed to establish the certificate’s
+ authenticity.
+ keyfile: File containing the private key (otherwise the private key
+ will be taken from certfile as well).
+ cert_password: Password to be used if the private key is encrypted and a
+ password is necessary.
+ """
+ super(IgniteDataset, self).__init__()
+
+ with IgniteClient(host, port, username, password, certfile, keyfile,\
+ cert_password) as client:
+ client.handshake()
+ self.cache_type = client.get_cache_type(cache_name)
+
+ self.cache_name = ops.convert_to_tensor(cache_name, dtype=dtypes.string,\
+ name="cache_name")
+ self.host = ops.convert_to_tensor(host, dtype=dtypes.string, name="host")
+ self.port = ops.convert_to_tensor(port, dtype=dtypes.int32, name="port")
+ self.local = ops.convert_to_tensor(local, dtype=dtypes.bool, name="local")
+ self.part = ops.convert_to_tensor(part, dtype=dtypes.int32, name="part")
+ self.page_size = ops.convert_to_tensor(page_size, dtype=dtypes.int32,\
+ name="page_size")
+ self.username = ops.convert_to_tensor("" if username is None else username,\
+ dtype=dtypes.string, name="username")
+ self.password = ops.convert_to_tensor("" if password is None else password,\
+ dtype=dtypes.string, name="password")
+ self.certfile = ops.convert_to_tensor("" if certfile is None else certfile,\
+ dtype=dtypes.string, name="certfile")
+ self.keyfile = ops.convert_to_tensor("" if keyfile is None else keyfile,\
+ dtype=dtypes.string, name="keyfile")
+ self.cert_password = ops.convert_to_tensor("" if cert_password is None\
+ else cert_password, dtype=dtypes.string, name="cert_password")
+ self.schema = ops.convert_to_tensor(self.cache_type.to_flat(),\
+ dtype=dtypes.int32, name="schema")
+ self.permutation = ops.convert_to_tensor(self.cache_type.to_permutation(),\
+ dtype=dtypes.int32, name="permutation")
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.ignite_dataset(self.cache_name, self.host,\
+ self.port, self.local, self.part, self.page_size, self.username,\
+ self.password, self.certfile, self.keyfile, self.cert_password,\
+ self.schema, self.permutation)
+
+ @property
+ def output_classes(self):
+ return self.cache_type.to_output_classes()
+
+ @property
+ def output_shapes(self):
+ return self.cache_type.to_output_shapes()
+
+ @property
+ def output_types(self):
+ return self.cache_type.to_output_types()
diff --git a/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py
new file mode 100644
index 0000000000..8115bda85b
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py
@@ -0,0 +1,25 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Python helper for loading Ignite ops and kernels."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.platform import resource_loader
+
+_dataset_ops = loader.load_op_library(
+ resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
diff --git a/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh b/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh
new file mode 100755
index 0000000000..f4607ce8ad
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+nohup apache-ignite-fabric/bin/ignite.sh /data/config/ignite-config-plain.xml &
+sleep 5 # Wait Apache Ignite to be started
+
+./apache-ignite-fabric/bin/sqlline.sh \
+-u "jdbc:ignite:thin://127.0.0.1/" \
+--run=/data/sql/init.sql
+
+tail -f nohup.out
diff --git a/tensorflow/contrib/ignite/python/tests/bin/start-ssl-auth.sh b/tensorflow/contrib/ignite/python/tests/bin/start-ssl-auth.sh
new file mode 100755
index 0000000000..dde1162816
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/bin/start-ssl-auth.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+nohup apache-ignite-fabric/bin/ignite.sh /data/config/ignite-config-ssl-auth.xml &
+sleep 5 # Wait Apache Ignite to be started
+
+./apache-ignite-fabric/bin/sqlline.sh -u "jdbc:ignite:thin://127.0.0.1/?\
+sslMode=require&\
+sslClientCertificateKeyStoreUrl=/data/keystore/client.jks&\
+sslClientCertificateKeyStorePassword=123456&\
+sslTrustAll=true&\
+username=ignite&\
+password=ignite" --run=/data/sql/init.sql
+
+tail -f nohup.out
diff --git a/tensorflow/contrib/ignite/python/tests/bin/start-ssl.sh b/tensorflow/contrib/ignite/python/tests/bin/start-ssl.sh
new file mode 100755
index 0000000000..58b40b2738
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/bin/start-ssl.sh
@@ -0,0 +1,26 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+nohup apache-ignite-fabric/bin/ignite.sh /data/config/ignite-config-ssl.xml &
+sleep 5 # Wait Apache Ignite to be started
+
+./apache-ignite-fabric/bin/sqlline.sh -u "jdbc:ignite:thin://127.0.0.1/?\
+sslMode=require&\
+sslClientCertificateKeyStoreUrl=/data/keystore/client.jks&\
+sslClientCertificateKeyStorePassword=123456&\
+sslTrustAll=true" --run=/data/sql/init.sql --verbose=true
+
+tail -f nohup.out
diff --git a/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml b/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml
new file mode 100644
index 0000000000..d900174a8a
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml
@@ -0,0 +1,39 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<beans xmlns="http://www.springframework.org/schema/beans"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xmlns:util="http://www.springframework.org/schema/util"
+ xsi:schemaLocation="http://www.springframework.org/schema/beans
+ http://www.springframework.org/schema/beans/spring-beans.xsd
+ http://www.springframework.org/schema/util
+ http://www.springframework.org/schema/util/spring-util.xsd">
+
+ <bean class="org.apache.ignite.configuration.IgniteConfiguration">
+ <property name="discoverySpi">
+ <bean class="org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi">
+ <property name="ipFinder">
+ <bean class="org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder">
+ <property name="addresses">
+ <list>
+ <value>127.0.0.1</value>
+ </list>
+ </property>
+ </bean>
+ </property>
+ </bean>
+ </property>
+ </bean>
+
+</beans>
diff --git a/tensorflow/contrib/ignite/python/tests/config/ignite-config-ssl-auth.xml b/tensorflow/contrib/ignite/python/tests/config/ignite-config-ssl-auth.xml
new file mode 100644
index 0000000000..8e001b28ab
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/config/ignite-config-ssl-auth.xml
@@ -0,0 +1,59 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+
+<beans xmlns="http://www.springframework.org/schema/beans"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xmlns:util="http://www.springframework.org/schema/util"
+ xsi:schemaLocation="http://www.springframework.org/schema/beans
+ http://www.springframework.org/schema/beans/spring-beans.xsd
+ http://www.springframework.org/schema/util
+ http://www.springframework.org/schema/util/spring-util.xsd">
+
+ <bean id="client-connector-configuration"
+ class="org.apache.ignite.configuration.ClientConnectorConfiguration">
+ <property name="sslClientAuth" value="true" />
+ <property name="sslEnabled" value="true" />
+ <property name="useIgniteSslContextFactory" value="true" />
+ </bean>
+
+ <bean id="ssl-context-factory"
+ class="org.apache.ignite.ssl.SslContextFactory">
+ <property name="keyStoreFilePath" value="/data/keystore/server.jks"/>
+ <property name="keyStorePassword" value="123456"/>
+ <property name="trustStoreFilePath" value="/data/keystore/trust.jks"/>
+ <property name="trustStorePassword" value="123456"/>
+ </bean>
+
+ <bean id="ignite-configuration"
+ class="org.apache.ignite.configuration.IgniteConfiguration">
+ <property name="clientConnectorConfiguration"
+ ref="client-connector-configuration" />
+ <property name="sslContextFactory" ref="ssl-context-factory" />
+ <property name="discoverySpi">
+ <bean class="org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi">
+ <property name="ipFinder">
+ <bean class="org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder">
+ <property name="addresses">
+ <list>
+ <value>127.0.0.1</value>
+ </list>
+ </property>
+ </bean>
+ </property>
+ </bean>
+ </property>
+ </bean>
+
+</beans>
diff --git a/tensorflow/contrib/ignite/python/tests/config/ignite-config-ssl.xml b/tensorflow/contrib/ignite/python/tests/config/ignite-config-ssl.xml
new file mode 100644
index 0000000000..42d480c114
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/config/ignite-config-ssl.xml
@@ -0,0 +1,59 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+
+<beans xmlns="http://www.springframework.org/schema/beans"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xmlns:util="http://www.springframework.org/schema/util"
+ xsi:schemaLocation="http://www.springframework.org/schema/beans
+ http://www.springframework.org/schema/beans/spring-beans.xsd
+ http://www.springframework.org/schema/util
+ http://www.springframework.org/schema/util/spring-util.xsd">
+
+ <bean id="client-connector-configuration"
+ class="org.apache.ignite.configuration.ClientConnectorConfiguration">
+ <property name="sslClientAuth" value="false" />
+ <property name="sslEnabled" value="true" />
+ <property name="useIgniteSslContextFactory" value="true" />
+ </bean>
+
+ <bean id="ssl-context-factory"
+ class="org.apache.ignite.ssl.SslContextFactory">
+ <property name="keyStoreFilePath" value="/data/keystore/server.jks"/>
+ <property name="keyStorePassword" value="123456"/>
+ <property name="trustStoreFilePath" value="/data/keystore/trust.jks"/>
+ <property name="trustStorePassword" value="123456"/>
+ </bean>
+
+ <bean id="ignite-configuration"
+ class="org.apache.ignite.configuration.IgniteConfiguration">
+ <property name="clientConnectorConfiguration"
+ ref="client-connector-configuration" />
+ <property name="sslContextFactory" ref="ssl-context-factory" />
+ <property name="discoverySpi">
+ <bean class="org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi">
+ <property name="ipFinder">
+ <bean class="org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder">
+ <property name="addresses">
+ <list>
+ <value>127.0.0.1</value>
+ </list>
+ </property>
+ </bean>
+ </property>
+ </bean>
+ </property>
+ </bean>
+
+</beans>
diff --git a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py
new file mode 100644
index 0000000000..933e62b804
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py
@@ -0,0 +1,77 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for IgniteDataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tensorflow as tf
+from tensorflow.contrib.ignite import IgniteDataset
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+class IgniteDatasetTest(test.TestCase):
+ """The Apache Ignite servers have to setup before the test and tear down
+ after the test manually. The docker engine has to be installed.
+
+ To setup Apache Ignite servers:
+ $ bash start_ignite.sh
+
+ To tear down Apache Ignite servers:
+ $ bash stop_ignite.sh
+ """
+
+ def test_ignite_dataset_with_plain_client(self):
+ ds = IgniteDataset(cache_name="SQL_PUBLIC_TEST_CACHE", port=42300)
+ self.__check_dataset(ds)
+
+ def test_ignite_dataset_with_ssl_client(self):
+ ds = IgniteDataset(cache_name="SQL_PUBLIC_TEST_CACHE", port=42301,\
+ certfile=os.path.dirname(os.path.realpath(__file__)) +\
+ "/keystore/client.pem", cert_password="123456")
+ self.__check_dataset(ds)
+
+ def test_ignite_dataset_with_ssl_client_and_auth(self):
+ ds = IgniteDataset(cache_name="SQL_PUBLIC_TEST_CACHE", port=42302,\
+ certfile=os.path.dirname(os.path.realpath(__file__)) +\
+ "/keystore/client.pem", cert_password="123456",\
+ username="ignite", password="ignite")
+ self.__check_dataset(ds)
+
+ def __check_dataset(self, dataset):
+ """Checks that dataset provids correct data.
+ """
+ self.assertEquals(tf.int64, dataset.output_types['key'])
+ self.assertEquals(tf.string, dataset.output_types['val']['NAME'])
+ self.assertEquals(tf.int64, dataset.output_types['val']['VAL'])
+
+ it = dataset.make_one_shot_iterator()
+ ne = it.get_next()
+
+ with tf.Session() as sess:
+ rows = [sess.run(ne), sess.run(ne), sess.run(ne)]
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(ne)
+
+ self.assertEquals({'key': 1, 'val': {'NAME': b'TEST1', 'VAL': 42}},\
+ rows[0])
+ self.assertEquals({'key': 2, 'val': {'NAME': b'TEST2', 'VAL': 43}},\
+ rows[1])
+ self.assertEquals({'key': 3, 'val': {'NAME': b'TEST3', 'VAL': 44}},\
+ rows[2])
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/ignite/python/tests/keystore/client.jks b/tensorflow/contrib/ignite/python/tests/keystore/client.jks
new file mode 100644
index 0000000000..1875c71b60
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/keystore/client.jks
Binary files differ
diff --git a/tensorflow/contrib/ignite/python/tests/keystore/client.pem b/tensorflow/contrib/ignite/python/tests/keystore/client.pem
new file mode 100644
index 0000000000..a71a87e0bb
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/keystore/client.pem
@@ -0,0 +1,69 @@
+Bag Attributes
+ friendlyName: client
+ localKeyID: 54 69 6D 65 20 31 33 33 39 32 33 39 38 35 39 34 34 36
+Key Attributes: <No Attributes>
+-----BEGIN RSA PRIVATE KEY-----
+Proc-Type: 4,ENCRYPTED
+DEK-Info: DES-EDE3-CBC,CE61EDD98349D0C7
+
+Kzl16sj8R7YUXPCEZCqCrY4LSAjiKCRFNOagEehvN9Jpswcz4JbatoFmvVvOCgBF
+7kkeCaALhfM5a+46uynZ1sOOFUOn8fUFgguN3lLInWfm6vTuXDPslg0/tRNI0YqW
+ujfxyzrm1/k4RX0oLzRE1jZr69VZsBmZndkz9nkz3anWKLE7X/VIFV6U/N6YNPch
+BG1Fxpt/HtM9p3B5wNDSjCVaeNP1ROKe3APLRY6k+SppTuntHV5q9Ni82r1l3ahU
+zf2QvocSy9MLh+bGusJGHyJJAGuwPHm6ytPwbXGHn5xe4HPIno28j9kN7EL1ZoUs
+q0PhipAkFrGIM4zg6nAwVdzY5iGySDQ3fWpz2MkrKMDRftBwA3o/M321NBUW9/2X
+l+XmjXcJd0dEOslGxveb6UXLL2YvYszjQXRR4dCV/40bMJL3umRhVSay0NteoXfY
+82rQchm2NHKOiDfB4RpD8JJtVQeDSMXc9TH5y2Ua7FZND60JXtFpdnfCVfVZuBJm
+yBafyIsXR7EQzLG4z28Dvp4fs42A3JkF+e9Aq6Y6MmYA1wsvIKKT9HKEifqKmbgG
+4E9WOZn5IWi4ZJ44VAwN/uBGrLm//3OjByeB9y8vszNbyY8dQ8x5XqnF/IzIvgqc
+uKA8xuLAkTFmgRGQ/lmMDR+iMhet5dCtg9Orb9tYVL55JAb/OfsCX0LTJ3Y2RmIx
+CaFpkUP7KKYD+69ajnFCxvfGnGxyBkf+JeuDYIZVFklVT9SUtL9RJh26jUdvHt2A
+LQerBl8UCkVbPxsxYjdawvxuBNTD6tSRykM8zwtWcvIubp+gxE7png==
+-----END RSA PRIVATE KEY-----
+Bag Attributes
+ friendlyName: 1.2.840.113549.1.9.1=#1613636c69656e7440677269646761696e2e636f6d,CN=client,OU=Dev,O=GridGain,ST=SPb,C=RU
+ localKeyID: 54 69 6D 65 20 31 33 33 39 32 33 39 38 35 39 34 34 36
+subject=/C=RU/ST=SPb/O=GridGain/OU=Dev/CN=client/emailAddress=client@gridgain.com
+issuer=/C=RU/ST=SPb/L=SPb/O=GridGain/OU=Dev/CN=ca/emailAddress=ca@gridgain.com
+-----BEGIN CERTIFICATE-----
+MIIC2TCCAkKgAwIBAgIBJDANBgkqhkiG9w0BAQUFADB3MQswCQYDVQQGEwJSVTEM
+MAoGA1UECBMDU1BiMQwwCgYDVQQHEwNTUGIxETAPBgNVBAoTCEdyaWRHYWluMQww
+CgYDVQQLEwNEZXYxCzAJBgNVBAMTAmNhMR4wHAYJKoZIhvcNAQkBFg9jYUBncmlk
+Z2Fpbi5jb20wHhcNMTIwNjA5MTEwNDE3WhcNMzIwNjA5MTEwNDE3WjBxMQswCQYD
+VQQGEwJSVTEMMAoGA1UECBMDU1BiMREwDwYDVQQKEwhHcmlkR2FpbjEMMAoGA1UE
+CxMDRGV2MQ8wDQYDVQQDEwZjbGllbnQxIjAgBgkqhkiG9w0BCQEWE2NsaWVudEBn
+cmlkZ2Fpbi5jb20wgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBANIHHcYiA+CP
+EBPKNZJ6mtvN4d9Yj43B5/hzs/TK3e4XImLsMhXaElYtrXQX/SDK7Zv5zdj6AkKH
+QkJ9BT8Jw7wvOQx/v4Qxrl+gTgcf6gjk6DvzqMlZUwH+ohbALj2TWsy9y+0uHKal
+EVrHpbYeB9TGpD+3NHwO/CG4SySk/Y4nAgMBAAGjezB5MAkGA1UdEwQCMAAwLAYJ
+YIZIAYb4QgENBB8WHU9wZW5TU0wgR2VuZXJhdGVkIENlcnRpZmljYXRlMB0GA1Ud
+DgQWBBRD/TKyBQyoVxqEupLzUB8hDrSF6DAfBgNVHSMEGDAWgBS1+Ah4ZG58tImL
+KqLVX+xBKbeFUTANBgkqhkiG9w0BAQUFAAOBgQCL2vhjwcJkA1OJGuXsuO2/87Zu
+HMa7gc4pm+Iol1B1gD2ksQEAU2dz/adD3369H7gZdHuk3RYPeYmD5Ppp9eECDsXc
+gNWrNYaqcSTYWRAUe1/St7vB9HzPdOm/eADfQaMnal6fmjfpzTgg65A/2w4GCsqt
+RL98pvdAft8v5WSx7A==
+-----END CERTIFICATE-----
+Bag Attributes
+ friendlyName: 1.2.840.113549.1.9.1=#160f636140677269646761696e2e636f6d,CN=ca,OU=Dev,O=GridGain,L=SPb,ST=SPb,C=RU
+subject=/C=RU/ST=SPb/L=SPb/O=GridGain/OU=Dev/CN=ca/emailAddress=ca@gridgain.com
+issuer=/C=RU/ST=SPb/L=SPb/O=GridGain/OU=Dev/CN=ca/emailAddress=ca@gridgain.com
+-----BEGIN CERTIFICATE-----
+MIIDSTCCArKgAwIBAgIJAKmuj925215OMA0GCSqGSIb3DQEBBQUAMHcxCzAJBgNV
+BAYTAlJVMQwwCgYDVQQIEwNTUGIxDDAKBgNVBAcTA1NQYjERMA8GA1UEChMIR3Jp
+ZEdhaW4xDDAKBgNVBAsTA0RldjELMAkGA1UEAxMCY2ExHjAcBgkqhkiG9w0BCQEW
+D2NhQGdyaWRnYWluLmNvbTAeFw0xMjA2MDkwNjU1MTJaFw0zMjA2MDQwNjU1MTJa
+MHcxCzAJBgNVBAYTAlJVMQwwCgYDVQQIEwNTUGIxDDAKBgNVBAcTA1NQYjERMA8G
+A1UEChMIR3JpZEdhaW4xDDAKBgNVBAsTA0RldjELMAkGA1UEAxMCY2ExHjAcBgkq
+hkiG9w0BCQEWD2NhQGdyaWRnYWluLmNvbTCBnzANBgkqhkiG9w0BAQEFAAOBjQAw
+gYkCgYEAtd16DCObyM63NKF/cvRcE+8cr1dc3c7mSnTEQ61WfqPJ2QqsQAB6e+5+
+q9Np1SaJyqFTTag6483ibrU+DkGPGgEXndRHtQHQPbStWsf47DBBW2bMi6+bkPox
+Cp6BhYO1DQUG5tP9CQ/g32mLQLB7LH0KtS1JcKpAClCjjWZC8b8CAwEAAaOB3DCB
+2TAdBgNVHQ4EFgQUtfgIeGRufLSJiyqi1V/sQSm3hVEwgakGA1UdIwSBoTCBnoAU
+tfgIeGRufLSJiyqi1V/sQSm3hVGhe6R5MHcxCzAJBgNVBAYTAlJVMQwwCgYDVQQI
+EwNTUGIxDDAKBgNVBAcTA1NQYjERMA8GA1UEChMIR3JpZEdhaW4xDDAKBgNVBAsT
+A0RldjELMAkGA1UEAxMCY2ExHjAcBgkqhkiG9w0BCQEWD2NhQGdyaWRnYWluLmNv
+bYIJAKmuj925215OMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADgYEAhrzd
+qusVLHO3wtyu0o+EAFyoDv5avCBTFsQLeDDPMyfDcEO6wfxhTanfH8C7gZc0rRnv
+2nbkVbfortHIOfU2wch5gClju0cXSTIXSKOAWPIMp3HLxC/l+KpFo3epFz0rsMVB
+M1ymOOdRDdAcTxcTTGY7WJXquEM3ZbT5Gh4RLDk=
+-----END CERTIFICATE-----
diff --git a/tensorflow/contrib/ignite/python/tests/keystore/server.jks b/tensorflow/contrib/ignite/python/tests/keystore/server.jks
new file mode 100644
index 0000000000..006ececc31
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/keystore/server.jks
Binary files differ
diff --git a/tensorflow/contrib/ignite/python/tests/keystore/trust.jks b/tensorflow/contrib/ignite/python/tests/keystore/trust.jks
new file mode 100644
index 0000000000..a00f1251af
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/keystore/trust.jks
Binary files differ
diff --git a/tensorflow/contrib/ignite/python/tests/sql/init.sql b/tensorflow/contrib/ignite/python/tests/sql/init.sql
new file mode 100644
index 0000000000..5a192aef17
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/sql/init.sql
@@ -0,0 +1,20 @@
+-- Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+--
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+-- ==============================================================================
+
+CREATE TABLE TEST_CACHE (ID LONG PRIMARY KEY, NAME VARCHAR, VAL LONG);
+
+INSERT INTO TEST_CACHE VALUES (1, 'TEST1', 42);
+INSERT INTO TEST_CACHE VALUES (2, 'TEST2', 43);
+INSERT INTO TEST_CACHE VALUES (3, 'TEST3', 44);
diff --git a/tensorflow/contrib/ignite/python/tests/start_ignite.sh b/tensorflow/contrib/ignite/python/tests/start_ignite.sh
new file mode 100755
index 0000000000..fbcf656afd
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/start_ignite.sh
@@ -0,0 +1,30 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+IGNITE_VERSION=2.6.0
+SCRIPT_PATH="$( cd "$(dirname "$0")" ; pwd -P )"
+
+# Start Apache Ignite with plain client listener.
+docker run -itd --name ignite-plain -p 42300:10800 \
+-v ${SCRIPT_PATH}:/data apacheignite/ignite:${IGNITE_VERSION} /data/bin/start-plain.sh
+
+# Start Apache Ignite with SSL client listener.
+docker run -itd --name ignite-ssl -p 42301:10800 \
+-v ${SCRIPT_PATH}:/data apacheignite/ignite:${IGNITE_VERSION} /data/bin/start-ssl.sh
+
+# Start Apache Ignite with SSL client listener with auth.
+docker run -itd --name ignite-ssl-auth -p 42302:10800 \
+-v ${SCRIPT_PATH}:/data apacheignite/ignite:${IGNITE_VERSION} /data/bin/start-ssl-auth.sh
diff --git a/tensorflow/contrib/ignite/python/tests/stop_ignite.sh b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh
new file mode 100755
index 0000000000..8f03dbd1ed
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh
@@ -0,0 +1,19 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+docker rm -f ignite-plain
+docker rm -f ignite-ssl
+docker rm -f ignite-ssl-auth