author    TensorFlower Gardener <gardener@tensorflow.org>    2018-10-01 12:25:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-10-01 12:25:39 -0700
commit 61a872068ece1355945ef2d88659e99de2fe7591 (patch)
tree   0aa42811dd0a8741f0a418b2ff97e512c40ebcb8 /tensorflow/contrib/ignite
parent c4b3ce081b8abfae5560814ec445f0169cb4c368 (diff)
parent 90c68770467701a23d23a85c5d769f6f4fa39f0f (diff)
Merge pull request #22210 from dmitrievanthony:apache-ignite-dataset
PiperOrigin-RevId: 215258743
Diffstat (limited to 'tensorflow/contrib/ignite')
-rw-r--r-- tensorflow/contrib/ignite/BUILD | 139
-rw-r--r-- tensorflow/contrib/ignite/README.md | 167
-rw-r--r-- tensorflow/contrib/ignite/__init__.py | 42
-rw-r--r-- tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc | 334
-rw-r--r-- tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h | 81
-rw-r--r-- tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h | 126
-rw-r--r-- tensorflow/contrib/ignite/kernels/ignite_client.h | 84
-rw-r--r-- tensorflow/contrib/ignite/kernels/ignite_dataset.cc | 81
-rw-r--r-- tensorflow/contrib/ignite/kernels/ignite_dataset.h | 63
-rw-r--r-- tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc | 422
-rw-r--r-- tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h | 99
-rw-r--r-- tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc | 198
-rw-r--r-- tensorflow/contrib/ignite/kernels/ignite_plain_client.h | 43
-rw-r--r-- tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc | 123
-rw-r--r-- tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc | 142
-rw-r--r-- tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc | 151
-rw-r--r-- tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h | 51
-rw-r--r-- tensorflow/contrib/ignite/ops/dataset_ops.cc | 56
-rw-r--r-- tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py | 772
-rw-r--r-- tensorflow/contrib/ignite/python/ops/ignite_op_loader.py | 24
-rwxr-xr-x tensorflow/contrib/ignite/python/tests/bin/start-plain.sh | 24
-rw-r--r-- tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml | 39
-rw-r--r-- tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py | 118
-rw-r--r-- tensorflow/contrib/ignite/python/tests/sql/init.sql | 20
-rwxr-xr-x tensorflow/contrib/ignite/python/tests/start_ignite.sh | 22
-rwxr-xr-x tensorflow/contrib/ignite/python/tests/stop_ignite.sh | 19
26 files changed, 3440 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/ignite/BUILD b/tensorflow/contrib/ignite/BUILD
new file mode 100644
index 0000000000..9393b702d1
--- /dev/null
+++ b/tensorflow/contrib/ignite/BUILD
@@ -0,0 +1,139 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "if_not_windows",
+ "if_windows",
+ "tf_custom_op_library",
+ "tf_custom_op_py_library",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+ "tf_py_test",
+)
+
+py_library(
+ name = "ignite",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_ops",
+ ],
+)
+
+tf_custom_op_library(
+ name = "_dataset_ops.so",
+ srcs = ["ops/dataset_ops.cc"],
+ deps = [":dataset_kernels"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["dataset_ops"],
+)
+
+cc_library(
+ name = "dataset_kernels",
+ srcs = [
+ "kernels/ignite_dataset_ops.cc",
+ "kernels/ignite_client.h",
+ "kernels/ignite_byte_swapper.h",
+ "kernels/ignite_plain_client.h",
+ "kernels/ignite_ssl_wrapper.h",
+ "kernels/ignite_ssl_wrapper.cc",
+ "kernels/ignite_binary_object_parser.h",
+ "kernels/ignite_binary_object_parser.cc",
+ "kernels/ignite_dataset.h",
+ "kernels/ignite_dataset.cc",
+ "kernels/ignite_dataset_iterator.h",
+ "kernels/ignite_dataset_iterator.cc",
+ ] + if_not_windows([
+ "kernels/ignite_plain_client_unix.cc",
+ ]) + if_windows([
+ "kernels/ignite_plain_client_windows.cc",
+ ]),
+ copts = if_windows([
+ "-DWIN32_LEAN_AND_MEAN",
+ ]),
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@boringssl//:ssl",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+ alwayslink = 1,
+)
+
+py_library(
+ name = "dataset_ops",
+ srcs = [
+ "python/ops/ignite_dataset_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ignite_op_loader",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_dataset_ops",
+ out = "python/ops/gen_dataset_ops.py",
+ deps = ["//tensorflow/contrib/ignite:dataset_ops_op_lib"],
+)
+
+tf_kernel_library(
+ name = "dataset_ops_kernels",
+ deps = [
+ ":dataset_kernels",
+ "//tensorflow/core:framework",
+ ],
+ alwayslink = 1,
+)
+
+tf_custom_op_py_library(
+ name = "ignite_op_loader",
+ srcs = ["python/ops/ignite_op_loader.py"],
+ dso = ["//tensorflow/contrib/ignite:_dataset_ops.so"],
+ kernels = [
+ ":dataset_ops_kernels",
+ "//tensorflow/contrib/ignite:dataset_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gen_dataset_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:platform",
+ ],
+)
+
+# The Apache Ignite servers have to be set up before the test and torn down
+# after the test manually. The Docker engine has to be installed.
+#
+# To setup Apache Ignite servers:
+# $ bash ./python/tests/start_ignite.sh
+#
+# To tear down Apache Ignite servers:
+# $ bash ./python/tests/stop_ignite.sh
+tf_py_test(
+ name = "ignite_dataset_test",
+ srcs = ["python/tests/ignite_dataset_test.py"],
+ additional_deps = [
+ ":ignite",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+ tags = [
+ "manual",
+ "no_windows",
+ "notap",
+ ],
+)
diff --git a/tensorflow/contrib/ignite/README.md b/tensorflow/contrib/ignite/README.md
new file mode 100644
index 0000000000..55c89d2799
--- /dev/null
+++ b/tensorflow/contrib/ignite/README.md
@@ -0,0 +1,167 @@
+# Ignite Dataset
+
+- [Overview](#overview)
+- [Features](#features)
+ * [Distributed In-Memory Datasource](#distributed-in-memory-datasource)
+ * [Structured Objects](#structured-objects)
+ * [Distributed Training](#distributed-training)
+ * [SSL Connection](#ssl-connection)
+ * [Windows Support](#windows-support)
+- [Try it out](#try-it-out)
+- [Limitations](#limitations)
+
+## Overview
+
+[Apache Ignite](https://ignite.apache.org/) is a memory-centric distributed database, caching, and processing platform for
+transactional, analytical, and streaming workloads, delivering in-memory speeds at petabyte scale. This contrib package contains an integration between Apache Ignite and TensorFlow. The integration is based on [tf.data](https://www.tensorflow.org/api_docs/python/tf/data) on the TensorFlow side and the [Binary Client Protocol](https://apacheignite.readme.io/v2.6/docs/binary-client-protocol) on the Apache Ignite side. It allows Apache Ignite to be used as a data source for neural network training, inference, and all other computations supported by TensorFlow.
+
+## Features
+
+Ignite Dataset provides features that you can use in a wide range of cases. The most important ones are described below.
+
+### Distributed In-Memory Datasource
+[Apache Ignite](https://ignite.apache.org/) is a distributed in-memory database, caching, and processing platform that provides fast data access. It lets you avoid hard drive limitations and store and operate on as much data as you need in a distributed cluster. You can utilize
+these benefits of Apache Ignite by using Ignite Dataset. Moreover, Ignite Dataset covers the following use cases:
+- If you have a **gigabyte** of data, you can keep it on a single machine's hard drive, but you will face hard drive speed limitations. Instead, you can store your data in Apache Ignite on the same machine and use it as a data source for TensorFlow, avoiding these limitations.
+- If you have a **terabyte** of data, you can probably still keep it on a single machine's hard drive, but you will again face hard drive speed limitations. Instead, you can store your data in an Apache Ignite distributed in-memory cluster and use it as a data source for TensorFlow, avoiding these limitations.
+- If you have a **petabyte** of data, you can't keep it on a single machine at all. You can, however, store your data in an Apache Ignite distributed in-memory cluster and use it as a data source for TensorFlow.
+
+Note that Apache Ignite is not just a step in an ETL pipeline between a database or a data warehouse and TensorFlow. Apache Ignite is a high-grade database itself. By choosing Apache Ignite and TensorFlow, you get everything you need to work with operational or historical data and, at the same time, the ability to use this data for neural network training and inference.
+
+```bash
+$ apache-ignite-fabric/bin/ignite.sh
+$ apache-ignite-fabric/bin/sqlline.sh -u "jdbc:ignite:thin://localhost:10800/"
+
+jdbc:ignite:thin://localhost/> CREATE TABLE KITTEN_CACHE (ID LONG PRIMARY KEY, NAME VARCHAR);
+jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (1, 'WARM KITTY');
+jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (2, 'SOFT KITTY');
+jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL OF FUR');
+```
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset(cache_name="SQL_PUBLIC_KITTEN_CACHE")
+>>> iterator = dataset.make_one_shot_iterator()
+>>> next_obj = iterator.get_next()
+>>>
+>>> with tf.Session() as sess:
+>>> for _ in range(3):
+>>> print(sess.run(next_obj))
+
+{'key': 1, 'val': {'NAME': b'WARM KITTY'}}
+{'key': 2, 'val': {'NAME': b'SOFT KITTY'}}
+{'key': 3, 'val': {'NAME': b'LITTLE BALL OF FUR'}}
+```
+
+### Structured Objects
+[Apache Ignite](https://ignite.apache.org/) allows you to store objects of any type and hierarchy. Ignite Dataset is able to work with such objects.
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset(cache_name="IMAGES")
+>>> iterator = dataset.make_one_shot_iterator()
+>>> next_obj = iterator.get_next()
+>>>
+>>> with tf.Session() as sess:
+>>> print(sess.run(next_obj))
+
+{
+ 'key': 'kitten.png',
+ 'val': {
+ 'metadata': {
+ 'file_name': b'kitten.png',
+ 'label': b'little ball of fur',
+      'width': 800,
+      'height': 600
+ },
+ 'pixels': [0, 0, 0, 0, ..., 0]
+ }
+}
+```
+Neural network training and other computations require transformations that can be done as part of a [tf.data](https://www.tensorflow.org/api_docs/python/tf/data) pipeline if you use Ignite Dataset.
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset(cache_name="IMAGES").map(lambda obj: obj['val']['pixels'])
+>>> iterator = dataset.make_one_shot_iterator()
+>>> next_obj = iterator.get_next()
+>>>
+>>> with tf.Session() as sess:
+>>> print(sess.run(next_obj))
+
+[0, 0, 0, 0, ..., 0]
+```
+
+### Distributed Training
+
+TensorFlow is a machine learning framework that [natively supports](https://www.tensorflow.org/deploy/distributed) distributed neural network training, inference and other computations. The main idea behind distributed neural network training is the ability to calculate the gradients of the loss function (here, sums of squared errors) on every partition of data (in terms of horizontal partitioning) and then sum them to obtain the loss function gradient of the whole dataset:
+
+<a href="https://www.codecogs.com/eqnedit.php?latex=\nabla[\sum_1^n(y&space;-&space;\hat{y})^2]&space;=&space;\nabla[\sum_1^{n_1}(y&space;-&space;\hat{y})^2]&space;&plus;&space;\nabla[\sum_{n_1}^{n_2}(y&space;-&space;\hat{y})^2]&space;&plus;&space;...&space;&plus;&space;\nabla[\sum_{n_{k-1}}^n(y&space;-&space;\hat{y})^2]" target="_blank"><img src="https://latex.codecogs.com/gif.latex?\nabla[\sum_1^n(y&space;-&space;\hat{y})^2]&space;=&space;\nabla[\sum_1^{n_1}(y&space;-&space;\hat{y})^2]&space;&plus;&space;\nabla[\sum_{n_1}^{n_2}(y&space;-&space;\hat{y})^2]&space;&plus;&space;...&space;&plus;&space;\nabla[\sum_{n_{k-1}}^n(y&space;-&space;\hat{y})^2]" title="\nabla[\sum_1^n(y - \hat{y})^2] = \nabla[\sum_1^{n_1}(y - \hat{y})^2] + \nabla[\sum_{n_1}^{n_2}(y - \hat{y})^2] + ... + \nabla[\sum_{n_{k-1}}^n(y - \hat{y})^2]" /></a>
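+
+The same identity in plain LaTeX, for readers viewing the raw markdown:
+
+```latex
+\nabla\Big[\sum_1^n (y - \hat{y})^2\Big]
+  = \nabla\Big[\sum_1^{n_1} (y - \hat{y})^2\Big]
+  + \nabla\Big[\sum_{n_1}^{n_2} (y - \hat{y})^2\Big]
+  + \dots
+  + \nabla\Big[\sum_{n_{k-1}}^{n} (y - \hat{y})^2\Big]
+```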
+
+Using this ability, we can calculate gradients on the nodes where the data is stored, reduce them, and then finally update model parameters. This avoids data transfers between nodes and thus network bottlenecks.
+
+Apache Ignite uses horizontal partitioning to store data in a distributed cluster. When we create an Apache Ignite cache (or a table, in SQL terms), we can specify the number of partitions the data will be split into. For example, if an Apache Ignite cluster consists of 10 machines and we create a cache with 10 partitions, every machine will maintain approximately one data partition.
+
+Ignite Dataset brings these two aspects together: distributed neural network training with TensorFlow and Apache Ignite partitioning. Ignite Dataset is a computation graph operation that can be performed on a remote worker. The remote worker can override Ignite Dataset parameters (such as `host`, `port` or `part`) by setting the corresponding environment variables for the worker process (`IGNITE_DATASET_HOST`, `IGNITE_DATASET_PORT` or `IGNITE_DATASET_PART`). Using this overriding approach, we can assign a specific partition to every worker, so that each worker handles one partition while all workers transparently operate on a single dataset. A sketch of such an override follows.
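+
+A minimal sketch of overriding these parameters from within a worker process, before the dataset op runs (the host and partition values here are hypothetical):
+
+```python
+>>> import os
+>>>
+>>> os.environ["IGNITE_DATASET_HOST"] = "ignite-node-0"
+>>> os.environ["IGNITE_DATASET_PORT"] = "10800"
+>>> os.environ["IGNITE_DATASET_PART"] = "0"
+```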
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset("IMAGES")
+>>>
+>>> # Compute gradients locally on every worker node.
+>>> gradients = []
+>>> for i in range(5):
+>>> with tf.device("/job:WORKER/task:%d" % i):
+>>> device_iterator = dataset.make_one_shot_iterator()
+>>> device_next_obj = device_iterator.get_next()
+>>> gradient = compute_gradient(device_next_obj)
+>>> gradients.append(gradient)
+>>>
+>>> # Aggregate them on master node.
+>>> result_gradient = tf.reduce_sum(gradients)
+>>>
+>>> with tf.Session("grpc://localhost:10000") as sess:
+>>> print(sess.run(result_gradient))
+```
+
+The high-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well.
+
+### SSL Connection
+
+Apache Ignite allows data transfer channels to be protected by [SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and authentication. Ignite Dataset supports SSL connections both with and without authentication. For more information, please refer to the [Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) documentation.
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset(cache_name="IMAGES", certfile="client.pem", cert_password="password", username="ignite", password="ignite")
+>>> ...
+```
+
+### Windows Support
+
+Ignite Dataset is fully compatible with Windows. You can use it as part of TensorFlow on your Windows workstation as well as on Linux/macOS systems.
+
+## Try it out
+
+The simplest way to try Ignite Dataset is to run a [Docker](https://www.docker.com/) container with Apache Ignite and preloaded [MNIST](http://yann.lecun.com/exdb/mnist/) data, and then interact with it using Ignite Dataset. Such a container is available on Docker Hub: [dmitrievanthony/ignite-with-mnist](https://hub.docker.com/r/dmitrievanthony/ignite-with-mnist/). Start the container on your machine:
+
+```
+docker run -it -p 10800:10800 dmitrievanthony/ignite-with-mnist
+```
+
+After that, you will be able to work with it in the following way:
+
+![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist.png "Ignite Dataset Mnist")
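+
+A rough equivalent of the session shown above, as a sketch (the cache name is an assumption, not taken from the container's documentation):
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> # Hypothetical cache name; check the container's output for the real one.
+>>> dataset = IgniteDataset(cache_name="MNIST_CACHE")
+>>> iterator = dataset.make_one_shot_iterator()
+>>> next_obj = iterator.get_next()
+>>>
+>>> with tf.Session() as sess:
+>>>     print(sess.run(next_obj))
+```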
+
+## Limitations
+
+Presently, Ignite Dataset works under the assumption that all objects in the cache have the same structure (homogeneous objects) and that the cache contains at least one object. Another limitation concerns structured objects: Ignite Dataset does not support UUIDs, Maps, or Object arrays as parts of an object structure.
diff --git a/tensorflow/contrib/ignite/__init__.py b/tensorflow/contrib/ignite/__init__.py
new file mode 100644
index 0000000000..f42947696f
--- /dev/null
+++ b/tensorflow/contrib/ignite/__init__.py
@@ -0,0 +1,42 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""IgniteDataset that allows to get data from Apache Ignite.
+
+Apache Ignite is a memory-centric distributed database, caching, and
+processing platform for transactional, analytical, and streaming workloads,
+delivering in-memory speeds at petabyte scale. This contrib package
+contains an integration between Apache Ignite and TensorFlow. The
+integration is based on tf.data on the TensorFlow side and the Binary
+Client Protocol on the Apache Ignite side. It allows Apache Ignite to
+be used as a data source for neural network training, inference and all
+other computations supported by TensorFlow. Ignite Dataset is based on the Apache
+Ignite Binary Client Protocol:
+https://apacheignite.readme.io/v2.6/docs/binary-client-protocol.
+
+@@IgniteDataset
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.ignite.python.ops.ignite_dataset_ops import IgniteDataset
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ "IgniteDataset",
+]
+
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc
new file mode 100644
index 0000000000..2c8a7d44b0
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc
@@ -0,0 +1,334 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+BinaryObjectParser::BinaryObjectParser() : byte_swapper_(ByteSwapper(false)) {}
+
+Status BinaryObjectParser::Parse(uint8_t** ptr,
+ std::vector<Tensor>* out_tensors,
+ std::vector<int32_t>* types) const {
+ uint8_t object_type_id = ParseByte(ptr);
+
+ // Skip non-leaf nodes.
+ if (object_type_id != WRAPPED_OBJ && object_type_id != COMPLEX_OBJ)
+ types->push_back(object_type_id);
+
+ switch (object_type_id) {
+ case BYTE: {
+ out_tensors->emplace_back(cpu_allocator(), DT_UINT8, TensorShape({}));
+ out_tensors->back().scalar<uint8>()() = ParseByte(ptr);
+ break;
+ }
+ case SHORT: {
+ out_tensors->emplace_back(cpu_allocator(), DT_INT16, TensorShape({}));
+ out_tensors->back().scalar<int16>()() = ParseShort(ptr);
+ break;
+ }
+ case USHORT: {
+ out_tensors->emplace_back(cpu_allocator(), DT_UINT16, TensorShape({}));
+ out_tensors->back().scalar<uint16>()() = ParseUnsignedShort(ptr);
+ break;
+ }
+ case INT: {
+ out_tensors->emplace_back(cpu_allocator(), DT_INT32, TensorShape({}));
+ out_tensors->back().scalar<int32>()() = ParseInt(ptr);
+ break;
+ }
+ case LONG: {
+ out_tensors->emplace_back(cpu_allocator(), DT_INT64, TensorShape({}));
+ out_tensors->back().scalar<int64>()() = ParseLong(ptr);
+ break;
+ }
+ case FLOAT: {
+ out_tensors->emplace_back(cpu_allocator(), DT_FLOAT, TensorShape({}));
+ out_tensors->back().scalar<float>()() = ParseFloat(ptr);
+ break;
+ }
+ case DOUBLE: {
+ out_tensors->emplace_back(cpu_allocator(), DT_DOUBLE, TensorShape({}));
+ out_tensors->back().scalar<double>()() = ParseDouble(ptr);
+ break;
+ }
+ case BOOL: {
+ out_tensors->emplace_back(cpu_allocator(), DT_BOOL, TensorShape({}));
+ out_tensors->back().scalar<bool>()() = ParseBool(ptr);
+ break;
+ }
+ case STRING: {
+ out_tensors->emplace_back(cpu_allocator(), DT_STRING, TensorShape({}));
+ out_tensors->back().scalar<string>()() = ParseString(ptr);
+ break;
+ }
+ case DATE: {
+ out_tensors->emplace_back(cpu_allocator(), DT_INT64, TensorShape({}));
+ out_tensors->back().scalar<int64>()() = ParseLong(ptr);
+ break;
+ }
+ case BYTE_ARR: {
+ int32_t length = ParseInt(ptr);
+ uint8_t* arr = ParseByteArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_UINT8,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<uint8>().data());
+ break;
+ }
+ case SHORT_ARR: {
+ int32_t length = ParseInt(ptr);
+ int16_t* arr = ParseShortArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_INT16,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<int16>().data());
+ break;
+ }
+ case USHORT_ARR: {
+ int32_t length = ParseInt(ptr);
+ uint16_t* arr = ParseUnsignedShortArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_UINT16,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<uint16>().data());
+ break;
+ }
+ case INT_ARR: {
+ int32_t length = ParseInt(ptr);
+ int32_t* arr = ParseIntArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_INT32,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<int32>().data());
+ break;
+ }
+ case LONG_ARR: {
+ int32_t length = ParseInt(ptr);
+ int64_t* arr = ParseLongArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_INT64,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<int64>().data());
+ break;
+ }
+ case FLOAT_ARR: {
+ int32_t length = ParseInt(ptr);
+ float* arr = ParseFloatArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_FLOAT,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<float>().data());
+ break;
+ }
+ case DOUBLE_ARR: {
+ int32_t length = ParseInt(ptr);
+ double* arr = ParseDoubleArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_DOUBLE,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<double>().data());
+ break;
+ }
+ case BOOL_ARR: {
+ int32_t length = ParseInt(ptr);
+ bool* arr = ParseBoolArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_BOOL,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<bool>().data());
+ break;
+ }
+ case STRING_ARR: {
+ int32_t length = ParseInt(ptr);
+ out_tensors->emplace_back(cpu_allocator(), DT_STRING,
+ TensorShape({length}));
+ for (int32_t i = 0; i < length; i++)
+ out_tensors->back().vec<string>()(i) = ParseString(ptr);
+ break;
+ }
+ case DATE_ARR: {
+ int32_t length = ParseInt(ptr);
+ int64_t* arr = ParseLongArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_INT64,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<int64>().data());
+ break;
+ }
+ case WRAPPED_OBJ: {
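+      // A wrapped object is serialized as a byte array (its length is read
+      // below) containing one object, followed by an offset; the length and
+      // the offset are consumed only to advance the pointer.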
+ int32_t byte_arr_size = ParseInt(ptr);
+ TF_RETURN_IF_ERROR(Parse(ptr, out_tensors, types));
+ int32_t offset = ParseInt(ptr);
+
+ break;
+ }
+ case COMPLEX_OBJ: {
+ uint8_t version = ParseByte(ptr);
+ int16_t flags = ParseShort(ptr);
+ int32_t type_id = ParseInt(ptr);
+ int32_t hash_code = ParseInt(ptr);
+ int32_t length = ParseInt(ptr);
+ int32_t schema_id = ParseInt(ptr);
+ int32_t schema_offset = ParseInt(ptr);
+
+ // 24 is size of header just read.
+ uint8_t* end = *ptr + schema_offset - 24;
+ int32_t i = 0;
+ while (*ptr < end) {
+ i++;
+ TF_RETURN_IF_ERROR(Parse(ptr, out_tensors, types));
+ }
+
+ *ptr += (length - schema_offset);
+
+ break;
+ }
+ default: {
+ return errors::Unknown("Unknowd binary type (type id ",
+ (int)object_type_id, ")");
+ }
+ }
+
+ return Status::OK();
+}
+
+uint8_t BinaryObjectParser::ParseByte(uint8_t** ptr) const {
+ uint8_t res = **ptr;
+ *ptr += 1;
+
+ return res;
+}
+
+int16_t BinaryObjectParser::ParseShort(uint8_t** ptr) const {
+ int16_t* res = *reinterpret_cast<int16_t**>(ptr);
+ byte_swapper_.SwapIfRequiredInt16(res);
+ *ptr += 2;
+
+ return *res;
+}
+
+uint16_t BinaryObjectParser::ParseUnsignedShort(uint8_t** ptr) const {
+ uint16_t* res = *reinterpret_cast<uint16_t**>(ptr);
+ byte_swapper_.SwapIfRequiredUnsignedInt16(res);
+ *ptr += 2;
+
+ return *res;
+}
+
+int32_t BinaryObjectParser::ParseInt(uint8_t** ptr) const {
+ int32_t* res = *reinterpret_cast<int32_t**>(ptr);
+ byte_swapper_.SwapIfRequiredInt32(res);
+ *ptr += 4;
+
+ return *res;
+}
+
+int64_t BinaryObjectParser::ParseLong(uint8_t** ptr) const {
+ int64_t* res = *reinterpret_cast<int64_t**>(ptr);
+ byte_swapper_.SwapIfRequiredInt64(res);
+ *ptr += 8;
+
+ return *res;
+}
+
+float BinaryObjectParser::ParseFloat(uint8_t** ptr) const {
+ float* res = *reinterpret_cast<float**>(ptr);
+ byte_swapper_.SwapIfRequiredFloat(res);
+ *ptr += 4;
+
+ return *res;
+}
+
+double BinaryObjectParser::ParseDouble(uint8_t** ptr) const {
+ double* res = *reinterpret_cast<double**>(ptr);
+ byte_swapper_.SwapIfRequiredDouble(res);
+ *ptr += 8;
+
+ return *res;
+}
+
+bool BinaryObjectParser::ParseBool(uint8_t** ptr) const {
+ bool res = **reinterpret_cast<bool**>(ptr);
+ *ptr += 1;
+
+ return res;
+}
+
+string BinaryObjectParser::ParseString(uint8_t** ptr) const {
+ int32_t length = ParseInt(ptr);
+ string res(*reinterpret_cast<char**>(ptr), length);
+ *ptr += length;
+
+ return res;
+}
+
+uint8_t* BinaryObjectParser::ParseByteArr(uint8_t** ptr, int length) const {
+ uint8_t* res = *reinterpret_cast<uint8_t**>(ptr);
+ *ptr += length;
+
+ return res;
+}
+
+int16_t* BinaryObjectParser::ParseShortArr(uint8_t** ptr, int length) const {
+ int16_t* res = *reinterpret_cast<int16_t**>(ptr);
+ byte_swapper_.SwapIfRequiredInt16Arr(res, length);
+ *ptr += length * 2;
+
+ return res;
+}
+
+uint16_t* BinaryObjectParser::ParseUnsignedShortArr(uint8_t** ptr,
+ int length) const {
+ uint16_t* res = *reinterpret_cast<uint16_t**>(ptr);
+ byte_swapper_.SwapIfRequiredUnsignedInt16Arr(res, length);
+ *ptr += length * 2;
+
+ return res;
+}
+
+int32_t* BinaryObjectParser::ParseIntArr(uint8_t** ptr, int length) const {
+ int32_t* res = *reinterpret_cast<int32_t**>(ptr);
+ byte_swapper_.SwapIfRequiredInt32Arr(res, length);
+ *ptr += length * 4;
+
+ return res;
+}
+
+int64_t* BinaryObjectParser::ParseLongArr(uint8_t** ptr, int length) const {
+ int64_t* res = *reinterpret_cast<int64_t**>(ptr);
+ byte_swapper_.SwapIfRequiredInt64Arr(res, length);
+ *ptr += length * 8;
+
+ return res;
+}
+
+float* BinaryObjectParser::ParseFloatArr(uint8_t** ptr, int length) const {
+ float* res = *reinterpret_cast<float**>(ptr);
+ byte_swapper_.SwapIfRequiredFloatArr(res, length);
+ *ptr += length * 4;
+
+ return res;
+}
+
+double* BinaryObjectParser::ParseDoubleArr(uint8_t** ptr, int length) const {
+ double* res = *reinterpret_cast<double**>(ptr);
+ byte_swapper_.SwapIfRequiredDoubleArr(res, length);
+ *ptr += length * 8;
+
+ return res;
+}
+
+bool* BinaryObjectParser::ParseBoolArr(uint8_t** ptr, int length) const {
+ bool* res = *reinterpret_cast<bool**>(ptr);
+ *ptr += length;
+
+ return res;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h
new file mode 100644
index 0000000000..eb1f856643
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_
+
+#include <vector>
+#include "tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class BinaryObjectParser {
+ public:
+ BinaryObjectParser();
+ Status Parse(uint8_t** ptr, std::vector<Tensor>* out_tensors,
+ std::vector<int32_t>* types) const;
+
+ private:
+ uint8_t ParseByte(uint8_t** ptr) const;
+ int16_t ParseShort(uint8_t** ptr) const;
+ uint16_t ParseUnsignedShort(uint8_t** ptr) const;
+ int32_t ParseInt(uint8_t** ptr) const;
+ int64_t ParseLong(uint8_t** ptr) const;
+ float ParseFloat(uint8_t** ptr) const;
+ double ParseDouble(uint8_t** ptr) const;
+ bool ParseBool(uint8_t** ptr) const;
+ string ParseString(uint8_t** ptr) const;
+ uint8_t* ParseByteArr(uint8_t** ptr, int length) const;
+ int16_t* ParseShortArr(uint8_t** ptr, int length) const;
+ uint16_t* ParseUnsignedShortArr(uint8_t** ptr, int length) const;
+ int32_t* ParseIntArr(uint8_t** ptr, int length) const;
+ int64_t* ParseLongArr(uint8_t** ptr, int length) const;
+ float* ParseFloatArr(uint8_t** ptr, int length) const;
+ double* ParseDoubleArr(uint8_t** ptr, int length) const;
+ bool* ParseBoolArr(uint8_t** ptr, int length) const;
+
+ const ByteSwapper byte_swapper_;
+};
+
+enum ObjectType {
+ BYTE = 1,
+ SHORT = 2,
+ INT = 3,
+ LONG = 4,
+ FLOAT = 5,
+ DOUBLE = 6,
+ USHORT = 7,
+ BOOL = 8,
+ STRING = 9,
+ DATE = 11,
+ BYTE_ARR = 12,
+ SHORT_ARR = 13,
+ INT_ARR = 14,
+ LONG_ARR = 15,
+ FLOAT_ARR = 16,
+ DOUBLE_ARR = 17,
+ USHORT_ARR = 18,
+ BOOL_ARR = 19,
+ STRING_ARR = 20,
+ DATE_ARR = 22,
+ WRAPPED_OBJ = 27,
+ COMPLEX_OBJ = 103
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h b/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h
new file mode 100644
index 0000000000..46df3e39dc
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h
@@ -0,0 +1,126 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_
+
+#include <stdint.h>
+#include "tensorflow/core/platform/byte_order.h"
+
+namespace tensorflow {
+
+class ByteSwapper {
+ public:
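+  // Swapping is needed only when the declared endianness of the input data
+  // differs from the host byte order.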
+ ByteSwapper(bool big_endian) { swap_ = big_endian == port::kLittleEndian; }
+
+ inline void SwapIfRequiredInt16(int16_t *x) const {
+ if (swap_) {
+ Swap16(x);
+ }
+ }
+
+ inline void SwapIfRequiredUnsignedInt16(uint16_t *x) const {
+ if (swap_) {
+ Swap16(reinterpret_cast<int16_t *>(x));
+ }
+ }
+
+ inline void SwapIfRequiredInt32(int32_t *x) const {
+ if (swap_) {
+ Swap32(x);
+ }
+ }
+
+ inline void SwapIfRequiredFloat(float *x) const {
+ if (swap_) {
+ Swap32(reinterpret_cast<int32_t *>(x));
+ }
+ }
+
+ inline void SwapIfRequiredInt64(int64_t *x) const {
+ if (swap_) {
+ Swap64(x);
+ }
+ }
+
+ inline void SwapIfRequiredDouble(double *x) const {
+ if (swap_) {
+ Swap64(reinterpret_cast<int64_t *>(x));
+ }
+ }
+
+ inline void SwapIfRequiredInt16Arr(int16_t *x, int32_t length) const {
+ if (swap_) {
+ for (int32_t i = 0; i < length; i++) Swap16(&x[i]);
+ }
+ }
+
+ inline void SwapIfRequiredUnsignedInt16Arr(uint16_t *x,
+ int32_t length) const {
+ if (swap_) {
+ for (int32_t i = 0; i < length; i++)
+ Swap16(reinterpret_cast<int16_t *>(&x[i]));
+ }
+ }
+
+ inline void SwapIfRequiredInt32Arr(int32_t *x, int32_t length) const {
+ if (swap_) {
+ for (int32_t i = 0; i < length; i++) Swap32(&x[i]);
+ }
+ }
+
+ inline void SwapIfRequiredFloatArr(float *x, int32_t length) const {
+ if (swap_) {
+ for (int32_t i = 0; i < length; i++)
+ Swap32(reinterpret_cast<int32_t *>(&x[i]));
+ }
+ }
+
+ inline void SwapIfRequiredInt64Arr(int64_t *x, int32_t length) const {
+ if (swap_) {
+ for (int32_t i = 0; i < length; i++) Swap64(&x[i]);
+ }
+ }
+
+ inline void SwapIfRequiredDoubleArr(double *x, int32_t length) const {
+ if (swap_) {
+ for (int32_t i = 0; i < length; i++)
+ Swap64(reinterpret_cast<int64_t *>(&x[i]));
+ }
+ }
+
+ private:
+ inline void Swap16(int16_t *x) const {
+ *x = ((*x & 0xFF) << 8) | ((*x >> 8) & 0xFF);
+ }
+
+ inline void Swap32(int32_t *x) const {
+ *x = ((*x & 0xFF) << 24) | (((*x >> 8) & 0xFF) << 16) |
+ (((*x >> 16) & 0xFF) << 8) | ((*x >> 24) & 0xFF);
+ }
+
+ inline void Swap64(int64_t *x) const {
+ *x = ((*x & 0xFF) << 56) | (((*x >> 8) & 0xFF) << 48) |
+ (((*x >> 16) & 0xFF) << 40) | (((*x >> 24) & 0xFF) << 32) |
+ (((*x >> 32) & 0xFF) << 24) | (((*x >> 40) & 0xFF) << 16) |
+ (((*x >> 48) & 0xFF) << 8) | ((*x >> 56) & 0xFF);
+ }
+
+ bool swap_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_client.h b/tensorflow/contrib/ignite/kernels/ignite_client.h
new file mode 100644
index 0000000000..459b50b48f
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_client.h
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_
+
+#include "tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class Client {
+ public:
+ Client(bool big_endian) : byte_swapper_(ByteSwapper(big_endian)) {}
+ virtual Status Connect() = 0;
+ virtual Status Disconnect() = 0;
+ virtual bool IsConnected() = 0;
+ virtual int GetSocketDescriptor() = 0;
+ virtual Status ReadData(uint8_t *buf, const int32_t length) = 0;
+ virtual Status WriteData(const uint8_t *buf, const int32_t length) = 0;
+
+ inline Status ReadByte(uint8_t *data) { return ReadData(data, 1); }
+
+ inline Status ReadShort(int16_t *data) {
+ TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 2));
+ byte_swapper_.SwapIfRequiredInt16(data);
+
+ return Status::OK();
+ }
+
+ inline Status ReadInt(int32_t *data) {
+ TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 4));
+ byte_swapper_.SwapIfRequiredInt32(data);
+
+ return Status::OK();
+ }
+
+ inline Status ReadLong(int64_t *data) {
+ TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 8));
+ byte_swapper_.SwapIfRequiredInt64(data);
+
+ return Status::OK();
+ }
+
+ inline Status WriteByte(const uint8_t data) { return WriteData(&data, 1); }
+
+ inline Status WriteShort(const int16_t data) {
+ int16_t tmp = data;
+ byte_swapper_.SwapIfRequiredInt16(&tmp);
+ return WriteData((uint8_t *)&tmp, 2);
+ }
+
+ inline Status WriteInt(const int32_t data) {
+ int32_t tmp = data;
+ byte_swapper_.SwapIfRequiredInt32(&tmp);
+ return WriteData((uint8_t *)&tmp, 4);
+ }
+
+ inline Status WriteLong(const int64_t data) {
+ int64_t tmp = data;
+ byte_swapper_.SwapIfRequiredInt64(&tmp);
+ return WriteData((uint8_t *)&tmp, 8);
+ }
+
+ private:
+ const ByteSwapper byte_swapper_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset.cc
new file mode 100644
index 0000000000..c4a7d3c513
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset.cc
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+IgniteDataset::IgniteDataset(OpKernelContext* ctx, string cache_name,
+ string host, int32 port, bool local, int32 part,
+ int32 page_size, string username, string password,
+ string certfile, string keyfile,
+ string cert_password, std::vector<int32> schema,
+ std::vector<int32> permutation,
+ DataTypeVector dtypes,
+ std::vector<PartialTensorShape> shapes)
+ : DatasetBase(DatasetContext(ctx)),
+ cache_name_(std::move(cache_name)),
+ host_(std::move(host)),
+ port_(port),
+ local_(local),
+ part_(part),
+ page_size_(page_size),
+ username_(std::move(username)),
+ password_(std::move(password)),
+ certfile_(std::move(certfile)),
+ keyfile_(std::move(keyfile)),
+ cert_password_(std::move(cert_password)),
+ schema_(std::move(schema)),
+ permutation_(std::move(permutation)),
+ dtypes_(dtypes),
+ shapes_(shapes) {
+ LOG(INFO) << "Ignite Dataset created [cache_name='" << cache_name_
+ << "', host='" << host_ << "', port=" << port_
+ << ", local=" << local_ << ", part=" << part_
+ << ", page_size=" << page_size_ << ", username='" << username_
+ << "', certfile='" << certfile_ << "', keyfile='"
+            << keyfile_ << "']";
+}
+
+IgniteDataset::~IgniteDataset() { LOG(INFO) << "Ignite Dataset destroyed"; }
+
+std::unique_ptr<IteratorBase> IgniteDataset::MakeIteratorInternal(
+ const string& prefix) const {
+ return std::unique_ptr<IteratorBase>(new IgniteDatasetIterator(
+ {this, strings::StrCat(prefix, "::Ignite")}, std::move(this->host_),
+ this->port_, std::move(this->cache_name_), this->local_, this->part_,
+ this->page_size_, std::move(this->username_), std::move(this->password_),
+ std::move(this->certfile_), std::move(this->keyfile_),
+ std::move(this->cert_password_), std::move(this->schema_),
+ std::move(this->permutation_)));
+}
+
+const DataTypeVector& IgniteDataset::output_dtypes() const { return dtypes_; }
+
+const std::vector<PartialTensorShape>& IgniteDataset::output_shapes() const {
+ return shapes_;
+}
+
+string IgniteDataset::DebugString() const { return "IgniteDatasetOp::Dataset"; }
+
+Status IgniteDataset::AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const {
+ return errors::Unimplemented(
+ "IgniteDataset does not support 'AsGraphDefInternal'");
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.h b/tensorflow/contrib/ignite/kernels/ignite_dataset.h
new file mode 100644
index 0000000000..66bfdf2e2a
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset.h
@@ -0,0 +1,63 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_
+
+#include "tensorflow/core/framework/dataset.h"
+
+namespace tensorflow {
+
+class IgniteDataset : public DatasetBase {
+ public:
+ IgniteDataset(OpKernelContext* ctx, string cache_name, string host,
+ int32 port, bool local, int32 part, int32 page_size,
+ string username, string password, string certfile,
+ string keyfile, string cert_password, std::vector<int32> schema,
+ std::vector<int32> permutation, DataTypeVector dtypes,
+ std::vector<PartialTensorShape> shapes);
+ ~IgniteDataset();
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override;
+ const DataTypeVector& output_dtypes() const override;
+ const std::vector<PartialTensorShape>& output_shapes() const override;
+ string DebugString() const override;
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override;
+
+ private:
+ const string cache_name_;
+ const string host_;
+ const int32 port_;
+ const bool local_;
+ const int32 part_;
+ const int32 page_size_;
+ const string username_;
+ const string password_;
+ const string certfile_;
+ const string keyfile_;
+ const string cert_password_;
+ const std::vector<int32> schema_;
+ const std::vector<int32> permutation_;
+ const DataTypeVector dtypes_;
+ const std::vector<PartialTensorShape> shapes_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc
new file mode 100644
index 0000000000..5da9127aa6
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc
@@ -0,0 +1,422 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h"
+
+#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h"
+#include "tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+IgniteDatasetIterator::IgniteDatasetIterator(
+ const Params& params, string host, int32 port, string cache_name,
+ bool local, int32 part, int32 page_size, string username, string password,
+ string certfile, string keyfile, string cert_password,
+ std::vector<int32> schema, std::vector<int32> permutation)
+ : DatasetIterator<IgniteDataset>(params),
+ cache_name_(std::move(cache_name)),
+ local_(local),
+ part_(part),
+ page_size_(page_size),
+ username_(std::move(username)),
+ password_(std::move(password)),
+ schema_(std::move(schema)),
+ permutation_(std::move(permutation)),
+ remainder_(-1),
+ cursor_id_(-1),
+ last_page_(false),
+ valid_state_(true) {
+ Client* p_client = new PlainClient(std::move(host), port, false);
+
+ if (certfile.empty())
+ client_ = std::unique_ptr<Client>(p_client);
+ else
+ client_ = std::unique_ptr<Client>(
+ new SslWrapper(std::unique_ptr<Client>(p_client), std::move(certfile),
+ std::move(keyfile), std::move(cert_password), false));
+
+ LOG(INFO) << "Ignite Dataset Iterator created";
+}
+
+IgniteDatasetIterator::~IgniteDatasetIterator() {
+ Status status = CloseConnection();
+ if (!status.ok()) LOG(ERROR) << status.ToString();
+
+ LOG(INFO) << "Ignite Dataset Iterator destroyed";
+}
+
+Status IgniteDatasetIterator::GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) {
+ mutex_lock l(mutex_);
+
+ if (valid_state_) {
+ Status status =
+ GetNextInternalWithValidState(ctx, out_tensors, end_of_sequence);
+
+ if (!status.ok()) valid_state_ = false;
+
+ return status;
+ }
+
+ return errors::Unknown("Iterator is invalid");
+}
+
+Status IgniteDatasetIterator::SaveInternal(IteratorStateWriter* writer) {
+ return errors::Unimplemented(
+ "Iterator for IgniteDataset does not support 'SaveInternal'");
+}
+
+Status IgniteDatasetIterator::RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) {
+ return errors::Unimplemented(
+ "Iterator for IgniteDataset does not support 'RestoreInternal')");
+}
+
+Status IgniteDatasetIterator::GetNextInternalWithValidState(
+ IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) {
+ if (remainder_ == 0 && last_page_) {
+ cursor_id_ = -1;
+ *end_of_sequence = true;
+
+ return Status::OK();
+ } else {
+ TF_RETURN_IF_ERROR(EstablishConnection());
+
+ if (remainder_ == -1) {
+ TF_RETURN_IF_ERROR(ScanQuery());
+ } else if (remainder_ == 0) {
+ TF_RETURN_IF_ERROR(LoadNextPage());
+ }
+
+ uint8_t* initial_ptr = ptr_;
+ std::vector<Tensor> tensors;
+ std::vector<int32_t> types;
+
+ TF_RETURN_IF_ERROR(parser_.Parse(&ptr_, &tensors, &types)); // Parse key
+ TF_RETURN_IF_ERROR(parser_.Parse(&ptr_, &tensors, &types)); // Parse val
+
+ remainder_ -= (ptr_ - initial_ptr);
+
+ TF_RETURN_IF_ERROR(CheckTypes(types));
+
+ for (size_t i = 0; i < tensors.size(); i++)
+ out_tensors->push_back(tensors[permutation_[i]]);
+
+ *end_of_sequence = false;
+
+ return Status::OK();
+ }
+}
+
+Status IgniteDatasetIterator::EstablishConnection() {
+ if (!client_->IsConnected()) {
+ TF_RETURN_IF_ERROR(client_->Connect());
+
+ Status status = Handshake();
+ if (!status.ok()) {
+ Status disconnect_status = client_->Disconnect();
+ if (!disconnect_status.ok()) LOG(ERROR) << disconnect_status.ToString();
+
+ return status;
+ }
+ }
+
+ return Status::OK();
+}
+
+Status IgniteDatasetIterator::CloseConnection() {
+ if (cursor_id_ != -1 && !last_page_) {
+ TF_RETURN_IF_ERROR(EstablishConnection());
+
+ TF_RETURN_IF_ERROR(client_->WriteInt(kCloseConnectionReqLength));
+ TF_RETURN_IF_ERROR(client_->WriteShort(kCloseConnectionOpcode));
+ TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID
+ TF_RETURN_IF_ERROR(client_->WriteLong(cursor_id_)); // Resource ID
+
+ int32_t res_len;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
+ if (res_len < kMinResLength)
+ return errors::Unknown("Close Resource Response is corrupted");
+
+ int64_t req_id;
+ TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
+ int32_t status;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&status));
+ if (status != 0) {
+ uint8_t err_msg_header;
+ TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
+ if (err_msg_header == kStringVal) {
+ int32_t err_msg_length;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
+
+ uint8_t* err_msg_c = new uint8_t[err_msg_length];
+ auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
+ TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
+ string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
+
+ return errors::Unknown("Close Resource Error [status=", status,
+ ", message=", err_msg, "]");
+ }
+ return errors::Unknown("Close Resource Error [status=", status, "]");
+ }
+
+ cursor_id_ = -1;
+
+ return client_->Disconnect();
+ } else {
+ LOG(INFO) << "Query Cursor " << cursor_id_ << " is already closed";
+ }
+
+ return client_->IsConnected() ? client_->Disconnect() : Status::OK();
+}
+
+Status IgniteDatasetIterator::Handshake() {
+ int32_t msg_len = kHandshakeReqDefaultLength;
+
+ if (username_.empty())
+ msg_len += 1;
+ else
+ msg_len += 5 + username_.length(); // 1 byte header, 4 bytes length.
+
+ if (password_.empty())
+ msg_len += 1;
+ else
+ msg_len += 5 + password_.length(); // 1 byte header, 4 bytes length.
+
+ TF_RETURN_IF_ERROR(client_->WriteInt(msg_len));
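+  // Handshake operation code.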
+ TF_RETURN_IF_ERROR(client_->WriteByte(1));
+ TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolMajorVersion));
+ TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolMinorVersion));
+ TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolPatchVersion));
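+  // Client code (2 = thin client).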
+ TF_RETURN_IF_ERROR(client_->WriteByte(2));
+ if (username_.empty()) {
+ TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal));
+ } else {
+ TF_RETURN_IF_ERROR(client_->WriteByte(kStringVal));
+ TF_RETURN_IF_ERROR(client_->WriteInt(username_.length()));
+ TF_RETURN_IF_ERROR(
+ client_->WriteData(reinterpret_cast<const uint8_t*>(username_.c_str()),
+ username_.length()));
+ }
+
+ if (password_.empty()) {
+ TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal));
+ } else {
+ TF_RETURN_IF_ERROR(client_->WriteByte(kStringVal));
+ TF_RETURN_IF_ERROR(client_->WriteInt(password_.length()));
+ TF_RETURN_IF_ERROR(
+ client_->WriteData(reinterpret_cast<const uint8_t*>(password_.c_str()),
+ password_.length()));
+ }
+
+ int32_t handshake_res_len;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&handshake_res_len));
+ uint8_t handshake_res;
+ TF_RETURN_IF_ERROR(client_->ReadByte(&handshake_res));
+
+ if (handshake_res != 1) {
+ int16_t serv_ver_major;
+ TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_major));
+ int16_t serv_ver_minor;
+ TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_minor));
+ int16_t serv_ver_patch;
+ TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_patch));
+ uint8_t header;
+ TF_RETURN_IF_ERROR(client_->ReadByte(&header));
+
+ if (header == kStringVal) {
+ int32_t length;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&length));
+
+ uint8_t* err_msg_c = new uint8_t[length];
+ auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
+ TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, length));
+ string err_msg(reinterpret_cast<char*>(err_msg_c), length);
+
+ return errors::Unknown("Handshake Error [result=", handshake_res,
+ ", version=", serv_ver_major, ".", serv_ver_minor,
+ ".", serv_ver_patch, ", message='", err_msg, "']");
+ } else if (header == kNullVal) {
+ return errors::Unknown("Handshake Error [result=", handshake_res,
+ ", version=", serv_ver_major, ".", serv_ver_minor,
+ ".", serv_ver_patch, "]");
+ } else {
+ return errors::Unknown("Handshake Error [result=", handshake_res,
+ ", version=", serv_ver_major, ".", serv_ver_minor,
+ ".", serv_ver_patch, "]");
+ }
+ }
+
+ return Status::OK();
+}
+
+Status IgniteDatasetIterator::ScanQuery() {
+ TF_RETURN_IF_ERROR(client_->WriteInt(kScanQueryReqLength));
+ TF_RETURN_IF_ERROR(client_->WriteShort(kScanQueryOpcode));
+ TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID
+ TF_RETURN_IF_ERROR(
+ client_->WriteInt(JavaHashCode(cache_name_))); // Cache name
+ TF_RETURN_IF_ERROR(client_->WriteByte(0)); // Flags
+ TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal)); // Filter object
+ TF_RETURN_IF_ERROR(client_->WriteInt(page_size_)); // Cursor page size
+  TF_RETURN_IF_ERROR(client_->WriteInt(part_));     // Partition to query
+  TF_RETURN_IF_ERROR(client_->WriteByte(local_));   // Local flag
+
+ uint64 wait_start = Env::Default()->NowMicros();
+ int32_t res_len;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
+  uint64 wait_stop = Env::Default()->NowMicros();
+
+ LOG(INFO) << "Scan Query waited " << (wait_stop - wait_start) / 1000 << " ms";
+
+ if (res_len < kMinResLength)
+ return errors::Unknown("Scan Query Response is corrupted");
+
+ int64_t req_id;
+ TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
+
+ int32_t status;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&status));
+
+ if (status != 0) {
+ uint8_t err_msg_header;
+ TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
+
+ if (err_msg_header == kStringVal) {
+ int32_t err_msg_length;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
+
+ uint8_t* err_msg_c = new uint8_t[err_msg_length];
+ auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
+ TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
+ string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
+
+ return errors::Unknown("Scan Query Error [status=", status,
+ ", message=", err_msg, "]");
+ }
+ return errors::Unknown("Scan Query Error [status=", status, "]");
+ }
+
+ TF_RETURN_IF_ERROR(client_->ReadLong(&cursor_id_));
+
+ int32_t row_cnt;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&row_cnt));
+
+ int32_t page_size = res_len - kScanQueryResHeaderLength;
+
+ return ReceivePage(page_size);
+}
+
+Status IgniteDatasetIterator::LoadNextPage() {
+ TF_RETURN_IF_ERROR(client_->WriteInt(kLoadNextPageReqLength));
+ TF_RETURN_IF_ERROR(client_->WriteShort(kLoadNextPageOpcode));
+ TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID
+ TF_RETURN_IF_ERROR(client_->WriteLong(cursor_id_)); // Cursor ID
+
+ uint64 wait_start = Env::Default()->NowMicros();
+ int32_t res_len;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
+ uint64 wait_stop = Env::Default()->NowMicros();
+
+ LOG(INFO) << "Load Next Page waited " << (wait_stop - wait_start) / 1000
+ << " ms";
+
+ if (res_len < kMinResLength)
+ return errors::Unknown("Load Next Page Response is corrupted");
+
+ int64_t req_id;
+ TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
+
+ int32_t status;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&status));
+
+ if (status != 0) {
+ uint8_t err_msg_header;
+ TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
+
+ if (err_msg_header == kStringVal) {
+ int32_t err_msg_length;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
+
+ uint8_t* err_msg_c = new uint8_t[err_msg_length];
+ auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
+ TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
+ string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
+
+ return errors::Unknown("Load Next Page Error [status=", status,
+ ", message=", err_msg, "]");
+ }
+ return errors::Unknown("Load Next Page Error [status=", status, "]");
+ }
+
+ int32_t row_cnt;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&row_cnt));
+
+ int32_t page_size = res_len - kLoadNextPageResHeaderLength;
+
+ return ReceivePage(page_size);
+}
+
+Status IgniteDatasetIterator::ReceivePage(int32_t page_size) {
+ remainder_ = page_size;
+ page_ = std::unique_ptr<uint8_t>(new uint8_t[remainder_]);
+ ptr_ = page_.get();
+
+ uint64 start = Env::Default()->NowMicros();
+ TF_RETURN_IF_ERROR(client_->ReadData(ptr_, remainder_));
+ uint64 stop = Env::Default()->NowMicros();
+
+ double size_in_mb = 1.0 * remainder_ / 1024 / 1024;
+ double time_in_s = 1.0 * (stop - start) / 1000 / 1000;
+ LOG(INFO) << "Page size " << size_in_mb << " Mb, time " << time_in_s * 1000
+ << " ms download speed " << size_in_mb / time_in_s << " Mb/sec";
+
+  uint8_t last_page_b;
+  TF_RETURN_IF_ERROR(client_->ReadByte(&last_page_b));
+
+  // The trailing byte is non-zero when more pages are available.
+  last_page_ = !last_page_b;
+
+ return Status::OK();
+}
+
+Status IgniteDatasetIterator::CheckTypes(const std::vector<int32_t>& types) {
+ if (schema_.size() != types.size())
+ return errors::Unknown("Object has unexpected schema");
+
+ for (size_t i = 0; i < schema_.size(); i++) {
+ if (schema_[i] != types[permutation_[i]])
+ return errors::Unknown("Object has unexpected schema");
+ }
+
+ return Status::OK();
+}
+
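+// JavaHashCode mirrors java.lang.String#hashCode so that cache name hashes
+// match the ones Ignite servers compute. For example, "abc" hashes to
+// 97 * 31 * 31 + 98 * 31 + 99 = 96354, the same value "abc".hashCode()
+// yields on the JVM.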
+int32_t IgniteDatasetIterator::JavaHashCode(const string& str) const {
+  int32_t h = 0;
+  for (char c : str) {
+    h = 31 * h + c;
+  }
+ return h;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h
new file mode 100644
index 0000000000..c499e2c9cc
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h
@@ -0,0 +1,99 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_
+
+#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h"
+#include "tensorflow/contrib/ignite/kernels/ignite_client.h"
+#include "tensorflow/contrib/ignite/kernels/ignite_dataset.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+class IgniteDatasetIterator : public DatasetIterator<IgniteDataset> {
+ public:
+ IgniteDatasetIterator(const Params& params, string host, int32 port,
+ string cache_name, bool local, int32 part,
+ int32 page_size, string username, string password,
+ string certfile, string keyfile, string cert_password,
+ std::vector<int32> schema,
+ std::vector<int32> permutation);
+ ~IgniteDatasetIterator();
+ Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override;
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override;
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override;
+
+ private:
+ Status GetNextInternalWithValidState(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence);
+
+ Status EstablishConnection();
+ Status CloseConnection();
+ Status Handshake();
+ Status ScanQuery();
+ Status LoadNextPage();
+ Status ReceivePage(int32_t page_size);
+ Status CheckTypes(const std::vector<int32_t>& types);
+  int32_t JavaHashCode(const string& str) const;
+
+ std::unique_ptr<Client> client_;
+ BinaryObjectParser parser_;
+
+ const string cache_name_;
+ const bool local_;
+ const int32 part_;
+ const int32 page_size_;
+ const string username_;
+ const string password_;
+ const std::vector<int32> schema_;
+ const std::vector<int32> permutation_;
+
+ int32_t remainder_;
+ int64_t cursor_id_;
+ bool last_page_;
+
+ bool valid_state_;
+
+ mutex mutex_;
+
+ std::unique_ptr<uint8_t> page_;
+ uint8_t* ptr_;
+};
+
+constexpr uint8_t kNullVal = 101;
+constexpr uint8_t kStringVal = 9;
+constexpr uint8_t kProtocolMajorVersion = 1;
+constexpr uint8_t kProtocolMinorVersion = 1;
+constexpr uint8_t kProtocolPatchVersion = 0;
+constexpr int16_t kScanQueryOpcode = 2000;
+constexpr int16_t kLoadNextPageOpcode = 2001;
+constexpr int16_t kCloseConnectionOpcode = 0;
+constexpr int32_t kScanQueryReqLength = 25;
+constexpr int32_t kScanQueryResHeaderLength = 25;
+constexpr int32_t kLoadNextPageReqLength = 18;
+constexpr int32_t kLoadNextPageResHeaderLength = 17;
+constexpr int32_t kCloseConnectionReqLength = 18;
+constexpr int32_t kHandshakeReqDefaultLength = 8;
+constexpr int32_t kMinResLength = 12;
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc
new file mode 100644
index 0000000000..f75b1c5ff5
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc
@@ -0,0 +1,198 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdlib.h>
+
+#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h"
+#include "tensorflow/contrib/ignite/kernels/ignite_dataset.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+
+namespace tensorflow {
+namespace {
+
+Status SchemaToTypes(const std::vector<int32>& schema, DataTypeVector* dtypes) {
+ for (auto e : schema) {
+ if (e == BYTE || e == BYTE_ARR) {
+ dtypes->push_back(DT_UINT8);
+ } else if (e == SHORT || e == SHORT_ARR) {
+ dtypes->push_back(DT_INT16);
+ } else if (e == INT || e == INT_ARR) {
+ dtypes->push_back(DT_INT32);
+ } else if (e == LONG || e == LONG_ARR) {
+ dtypes->push_back(DT_INT64);
+ } else if (e == FLOAT || e == FLOAT_ARR) {
+ dtypes->push_back(DT_FLOAT);
+ } else if (e == DOUBLE || e == DOUBLE_ARR) {
+ dtypes->push_back(DT_DOUBLE);
+ } else if (e == USHORT || e == USHORT_ARR) {
+ dtypes->push_back(DT_UINT8);
+ } else if (e == BOOL || e == BOOL_ARR) {
+ dtypes->push_back(DT_BOOL);
+ } else if (e == STRING || e == STRING_ARR) {
+ dtypes->push_back(DT_STRING);
+ } else {
+ return errors::Unknown("Unexpected type in schema [type_id=", e, "]");
+ }
+ }
+
+ return Status::OK();
+}
+
+Status SchemaToShapes(const std::vector<int32>& schema,
+ std::vector<PartialTensorShape>* shapes) {
+ for (auto e : schema) {
+    if (e >= 1 && e < 10) {
+      // Type ids 1-9 denote scalar values.
+      shapes->push_back(PartialTensorShape({}));
+    } else if (e >= 12 && e < 21) {
+      // Type ids 12-20 denote variable-length arrays.
+      shapes->push_back(PartialTensorShape({-1}));
+ } else {
+ return errors::Unknown("Unexpected type in schema [type_id=", e, "]");
+ }
+ }
+
+ return Status::OK();
+}
+
+class IgniteDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ string cache_name = "";
+ string host = "";
+ int32 port = -1;
+ bool local = false;
+ int32 part = -1;
+ int32 page_size = -1;
+ string username = "";
+ string password = "";
+ string certfile = "";
+ string keyfile = "";
+ string cert_password = "";
+
+ const char* env_cache_name = std::getenv("IGNITE_DATASET_CACHE_NAME");
+ const char* env_host = std::getenv("IGNITE_DATASET_HOST");
+ const char* env_port = std::getenv("IGNITE_DATASET_PORT");
+ const char* env_local = std::getenv("IGNITE_DATASET_LOCAL");
+ const char* env_part = std::getenv("IGNITE_DATASET_PART");
+ const char* env_page_size = std::getenv("IGNITE_DATASET_PAGE_SIZE");
+ const char* env_username = std::getenv("IGNITE_DATASET_USERNAME");
+ const char* env_password = std::getenv("IGNITE_DATASET_PASSWORD");
+ const char* env_certfile = std::getenv("IGNITE_DATASET_CERTFILE");
+ const char* env_keyfile = std::getenv("IGNITE_DATASET_KEYFILE");
+ const char* env_cert_password = std::getenv("IGNITE_DATASET_CERT_PASSWORD");
+
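+    // Environment variables, when set, take precedence over the
+    // corresponding op inputs, which makes it possible to re-point an
+    // already exported graph (e.g. at another host or cache) without
+    // rebuilding it.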
+ if (env_cache_name) {
+ cache_name = string(env_cache_name);
+ } else {
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<string>(ctx, "cache_name", &cache_name));
+ }
+
+ if (env_host) {
+ host = string(env_host);
+ } else {
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "host", &host));
+ }
+
+ if (env_port) {
+ OP_REQUIRES(ctx, strings::safe_strto32(env_port, &port),
+ errors::InvalidArgument("IGNITE_DATASET_PORT environment "
+ "variable is not a valid integer: ",
+ env_port));
+ } else {
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int32>(ctx, "port", &port));
+ }
+
+ if (env_local) {
+ local = true;
+ } else {
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "local", &local));
+ }
+
+ if (env_part) {
+ OP_REQUIRES(ctx, strings::safe_strto32(env_part, &part),
+ errors::InvalidArgument("IGNITE_DATASET_PART environment "
+ "variable is not a valid integer: ",
+ env_part));
+ } else {
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int32>(ctx, "part", &part));
+ }
+
+ if (env_page_size) {
+ OP_REQUIRES(ctx, strings::safe_strto32(env_page_size, &page_size),
+ errors::InvalidArgument("IGNITE_DATASET_PAGE_SIZE "
+ "environment variable is not a valid "
+ "integer: ",
+ env_page_size));
+ } else {
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int32>(ctx, "page_size", &page_size));
+ }
+
+ if (env_username) username = string(env_username);
+
+ if (env_password) password = string(env_password);
+
+ if (env_certfile) certfile = string(env_certfile);
+
+ if (env_keyfile) keyfile = string(env_keyfile);
+
+ if (env_cert_password) cert_password = string(env_cert_password);
+
+ const Tensor* schema_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("schema", &schema_tensor));
+ OP_REQUIRES(ctx, schema_tensor->dims() == 1,
+ errors::InvalidArgument("`schema` must be a vector."));
+
+ std::vector<int32> schema;
+ schema.reserve(schema_tensor->NumElements());
+ for (int i = 0; i < schema_tensor->NumElements(); i++) {
+ schema.push_back(schema_tensor->flat<int32>()(i));
+ }
+
+ const Tensor* permutation_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("permutation", &permutation_tensor));
+ OP_REQUIRES(ctx, permutation_tensor->dims() == 1,
+ errors::InvalidArgument("`permutation` must be a vector."));
+
+ std::vector<int32> permutation;
+ permutation.resize(permutation_tensor->NumElements());
+ for (int i = 0; i < permutation_tensor->NumElements(); i++) {
+      // Invert the permutation supplied as the op input.
+      permutation[permutation_tensor->flat<int32>()(i)] = i;
+ }
+
+ DataTypeVector dtypes;
+ std::vector<PartialTensorShape> shapes;
+
+ OP_REQUIRES_OK(ctx, SchemaToTypes(schema, &dtypes));
+ OP_REQUIRES_OK(ctx, SchemaToShapes(schema, &shapes));
+
+ *output = new IgniteDataset(
+ ctx, std::move(cache_name), std::move(host), port, local, part,
+ page_size, std::move(username), std::move(password),
+ std::move(certfile), std::move(keyfile), std::move(cert_password),
+ std::move(schema), std::move(permutation), std::move(dtypes),
+ std::move(shapes));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("IgniteDataset").Device(DEVICE_CPU),
+ IgniteDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client.h b/tensorflow/contrib/ignite/kernels/ignite_plain_client.h
new file mode 100644
index 0000000000..75424c19ee
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client.h
@@ -0,0 +1,43 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_
+
+#include "tensorflow/contrib/ignite/kernels/ignite_client.h"
+
+namespace tensorflow {
+
+class PlainClient : public Client {
+ public:
+ PlainClient(string host, int port, bool big_endian);
+ ~PlainClient();
+
+ Status Connect() override;
+ Status Disconnect() override;
+ bool IsConnected() override;
+ int GetSocketDescriptor() override;
+ Status ReadData(uint8_t* buf, const int32_t length) override;
+ Status WriteData(const uint8_t* buf, const int32_t length) override;
+
+ private:
+ const string host_;
+ const int port_;
+ int sock_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc b/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc
new file mode 100644
index 0000000000..cf672942c6
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc
@@ -0,0 +1,123 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h"
+
+#include <arpa/inet.h>
+#include <netdb.h>
+#include <sys/socket.h>
+#include <unistd.h>
+
+#include <cerrno>
+#include <cstring>
+#include <iostream>
+#include <map>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+PlainClient::PlainClient(string host, int port, bool big_endian)
+ : Client(big_endian), host_(std::move(host)), port_(port), sock_(-1) {}
+
+PlainClient::~PlainClient() {
+ if (IsConnected()) {
+ Status status = Disconnect();
+ if (!status.ok()) LOG(WARNING) << status.ToString();
+ }
+}
+
+Status PlainClient::Connect() {
+ if (sock_ == -1) {
+ sock_ = socket(AF_INET, SOCK_STREAM, 0);
+ if (sock_ == -1) return errors::Internal("Failed to create socket");
+ }
+
+ sockaddr_in server;
+
+ server.sin_addr.s_addr = inet_addr(host_.c_str());
+  if (server.sin_addr.s_addr == INADDR_NONE) {
+ hostent* he;
+ in_addr** addr_list;
+
+ if ((he = gethostbyname(host_.c_str())) == NULL)
+ return errors::Internal("Failed to resolve hostname \"", host_, "\"");
+
+ addr_list = (in_addr**)he->h_addr_list;
+ if (addr_list[0] != NULL) server.sin_addr = *addr_list[0];
+ }
+
+ server.sin_family = AF_INET;
+ server.sin_port = htons(port_);
+
+ if (connect(sock_, (sockaddr*)&server, sizeof(server)) < 0)
+ return errors::Internal("Failed to connect to \"", host_, ":", port_, "\"");
+
+ LOG(INFO) << "Connection to \"" << host_ << ":" << port_ << "\" established";
+
+ return Status::OK();
+}
+
+Status PlainClient::Disconnect() {
+ int close_res = close(sock_);
+ sock_ = -1;
+
+ LOG(INFO) << "Connection to \"" << host_ << ":" << port_ << "\" is closed";
+
+ return close_res == 0
+ ? Status::OK()
+ : errors::Internal("Failed to correctly close connection");
+}
+
+bool PlainClient::IsConnected() { return sock_ != -1; }
+
+int PlainClient::GetSocketDescriptor() { return sock_; }
+
+Status PlainClient::ReadData(uint8_t* buf, const int32_t length) {
+ int received = 0;
+
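+  // recv() may deliver fewer bytes than requested, so loop until the whole
+  // buffer has been filled (or the peer closes the connection).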
+ while (received < length) {
+ int res = recv(sock_, buf, length - received, 0);
+
+ if (res < 0)
+ return errors::Internal("Error occurred while reading from socket: ", res,
+ ", ", string(strerror(errno)));
+
+ if (res == 0) return errors::Internal("Server closed connection");
+
+ received += res;
+ buf += res;
+ }
+
+ return Status::OK();
+}
+
+Status PlainClient::WriteData(const uint8_t* buf, const int32_t length) {
+ int sent = 0;
+
+ while (sent < length) {
+ int res = send(sock_, buf, length - sent, 0);
+
+ if (res < 0)
+ return errors::Internal("Error occurred while writing into socket: ", res,
+ ", ", string(strerror(errno)));
+
+ sent += res;
+ buf += res;
+ }
+
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc b/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc
new file mode 100644
index 0000000000..dad5aace5f
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc
@@ -0,0 +1,142 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h"
+
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#include <winsock2.h>
+#include <ws2tcpip.h>
+
+#pragma comment(lib, "Ws2_32.lib")
+#pragma comment(lib, "Mswsock.lib")
+#pragma comment(lib, "AdvApi32.lib")
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+PlainClient::PlainClient(string host, int port, bool big_endian)
+ : Client(big_endian),
+ host_(std::move(host)),
+ port_(port),
+ sock_(INVALID_SOCKET) {}
+
+PlainClient::~PlainClient() {
+ if (IsConnected()) {
+ Status status = Disconnect();
+ if (!status.ok()) LOG(WARNING) << status.ToString();
+ }
+}
+
+Status PlainClient::Connect() {
+ WSADATA wsaData;
+ addrinfo *result = NULL, *ptr = NULL, hints;
+
+ int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
+ if (res != 0) return errors::Internal("WSAStartup failed with error: ", res);
+
+ ZeroMemory(&hints, sizeof(hints));
+ hints.ai_family = AF_UNSPEC;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_protocol = IPPROTO_TCP;
+
+ res = getaddrinfo(host_.c_str(), std::to_string(port_).c_str(), &hints,
+ &result);
+ if (res != 0) return errors::Internal("Getaddrinfo failed with error: ", res);
+
+ auto clean = gtl::MakeCleanup([result] { freeaddrinfo(result); });
+
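+  // getaddrinfo may return several candidate addresses; try them in order
+  // until one accepts the connection.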
+ for (ptr = result; ptr != NULL; ptr = ptr->ai_next) {
+ sock_ = socket(ptr->ai_family, ptr->ai_socktype, ptr->ai_protocol);
+ if (sock_ == INVALID_SOCKET) {
+ WSACleanup();
+ return errors::Internal("Socket failed with error: ", WSAGetLastError());
+ }
+
+ res = connect(sock_, ptr->ai_addr, (int)ptr->ai_addrlen);
+ if (res == SOCKET_ERROR) {
+ closesocket(sock_);
+ sock_ = INVALID_SOCKET;
+ continue;
+ }
+
+ break;
+ }
+
+ if (sock_ == INVALID_SOCKET) {
+ WSACleanup();
+ return errors::Internal("Unable to connect to server");
+ }
+
+ LOG(INFO) << "Connection to \"" << host_ << ":" << port_ << "\" established";
+
+ return Status::OK();
+}
+
+Status PlainClient::Disconnect() {
+ int res = shutdown(sock_, SD_SEND);
+ closesocket(sock_);
+ WSACleanup();
+
+ if (res == SOCKET_ERROR)
+ return errors::Internal("Shutdown failed with error: ", WSAGetLastError());
+ else
+ return Status::OK();
+}
+
+bool PlainClient::IsConnected() { return sock_ != INVALID_SOCKET; }
+
+int PlainClient::GetSocketDescriptor() { return sock_; }
+
+Status PlainClient::ReadData(uint8_t *buf, const int32_t length) {
+ int received = 0;
+
+ while (received < length) {
+ int res = recv(sock_, (char *)buf, length - received, 0);
+
+ if (res < 0)
+ return errors::Internal("Error occurred while reading from socket: ",
+ res);
+
+ if (res == 0) return errors::Internal("Server closed connection");
+
+ received += res;
+ buf += res;
+ }
+
+ return Status::OK();
+}
+
+Status PlainClient::WriteData(const uint8_t *buf, const int32_t length) {
+ int sent = 0;
+
+ while (sent < length) {
+ int res = send(sock_, (char *)buf, length - sent, 0);
+
+ if (res < 0)
+ return errors::Internal("Error occurred while writing into socket: ",
+ res);
+
+ sent += res;
+ buf += res;
+ }
+
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc
new file mode 100644
index 0000000000..ceb479b084
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc
@@ -0,0 +1,151 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h"
+
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+static int PasswordCb(char *buf, int size, int rwflag, void *password) {
+ strncpy(buf, (char *)(password), size);
+ buf[size - 1] = '\0';
+ return (strlen(buf));
+}
+
+SslWrapper::SslWrapper(std::shared_ptr<Client> client, string certfile,
+ string keyfile, string cert_password, bool big_endian)
+ : Client(big_endian),
+ client_(client),
+ certfile_(std::move(certfile)),
+ keyfile_(std::move(keyfile)),
+ cert_password_(std::move(cert_password)),
+ ctx_(nullptr),
+ ssl_(nullptr) {}
+
+SslWrapper::~SslWrapper() {
+ if (IsConnected()) {
+ Status status = Disconnect();
+ if (!status.ok()) LOG(WARNING) << status.ToString();
+ }
+
+ if (ctx_ != nullptr) {
+ SSL_CTX_free(ctx_);
+ ctx_ = nullptr;
+ }
+
+ if (ssl_ != nullptr) {
+ SSL_free(ssl_);
+ ssl_ = nullptr;
+ }
+}
+
+Status SslWrapper::InitSslContext() {
+ OpenSSL_add_all_algorithms();
+ SSL_load_error_strings();
+
+ ctx_ = SSL_CTX_new(SSLv23_method());
+ if (ctx_ == NULL) return errors::Internal("Couldn't create SSL context");
+
+ SSL_CTX_set_default_passwd_cb(ctx_, PasswordCb);
+ SSL_CTX_set_default_passwd_cb_userdata(ctx_, (void *)cert_password_.c_str());
+
+ if (SSL_CTX_use_certificate_chain_file(ctx_, certfile_.c_str()) != 1)
+ return errors::Internal("Couldn't load cetificate chain (file '", certfile_,
+ "')");
+
+ string private_key_file = keyfile_.empty() ? certfile_ : keyfile_;
+ if (SSL_CTX_use_PrivateKey_file(ctx_, private_key_file.c_str(),
+ SSL_FILETYPE_PEM) != 1)
+ return errors::Internal("Couldn't load private key (file '",
+ private_key_file, "')");
+
+ return Status::OK();
+}
+
+Status SslWrapper::Connect() {
+ if (ctx_ == NULL) {
+ TF_RETURN_IF_ERROR(InitSslContext());
+ }
+
+ ssl_ = SSL_new(ctx_);
+  if (ssl_ == NULL)
+    return errors::Internal("Failed to create SSL object");
+
+ TF_RETURN_IF_ERROR(client_->Connect());
+
+ SSL_set_fd(ssl_, client_->GetSocketDescriptor());
+ if (SSL_connect(ssl_) != 1)
+ return errors::Internal("Failed to establish SSL connection");
+
+ LOG(INFO) << "SSL connection established";
+
+ return Status::OK();
+}
+
+Status SslWrapper::Disconnect() {
+ SSL_free(ssl_);
+ ssl_ = nullptr;
+
+ LOG(INFO) << "SSL connection closed";
+
+ return client_->Disconnect();
+}
+
+bool SslWrapper::IsConnected() { return client_->IsConnected(); }
+
+int SslWrapper::GetSocketDescriptor() { return client_->GetSocketDescriptor(); }
+
+Status SslWrapper::ReadData(uint8_t *buf, const int32_t length) {
+ int received = 0;
+
+ while (received < length) {
+ int res = SSL_read(ssl_, buf, length - received);
+
+ if (res < 0)
+ return errors::Internal("Error occurred while reading from SSL socket: ",
+ res);
+
+ if (res == 0) return errors::Internal("Server closed SSL connection");
+
+ received += res;
+ buf += res;
+ }
+
+ return Status::OK();
+}
+
+Status SslWrapper::WriteData(const uint8_t *buf, const int32_t length) {
+ int sent = 0;
+
+ while (sent < length) {
+ int res = SSL_write(ssl_, buf, length - sent);
+
+ if (res < 0)
+ return errors::Internal("Error occurred while writing into socket: ",
+ res);
+
+ sent += res;
+ buf += res;
+ }
+
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h
new file mode 100644
index 0000000000..0406644bba
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_
+
+#include "tensorflow/contrib/ignite/kernels/ignite_client.h"
+
+#include <openssl/ssl.h>
+
+namespace tensorflow {
+
+class SslWrapper : public Client {
+ public:
+ SslWrapper(std::shared_ptr<Client> client, string certfile, string keyfile,
+ string cert_password, bool big_endian);
+ ~SslWrapper();
+
+ Status Connect() override;
+ Status Disconnect() override;
+ bool IsConnected() override;
+ int GetSocketDescriptor() override;
+ Status ReadData(uint8_t* buf, const int32_t length) override;
+ Status WriteData(const uint8_t* buf, const int32_t length) override;
+
+ private:
+ Status InitSslContext();
+
+ std::shared_ptr<Client> client_;
+ string certfile_;
+ string keyfile_;
+ string cert_password_;
+ SSL_CTX* ctx_;
+ SSL* ssl_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_
diff --git a/tensorflow/contrib/ignite/ops/dataset_ops.cc b/tensorflow/contrib/ignite/ops/dataset_ops.cc
new file mode 100644
index 0000000000..3d6fbe00e6
--- /dev/null
+++ b/tensorflow/contrib/ignite/ops/dataset_ops.cc
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("IgniteDataset")
+ .Input("cache_name: string")
+ .Input("host: string")
+ .Input("port: int32")
+ .Input("local: bool")
+ .Input("part: int32")
+ .Input("page_size: int32")
+ .Input("schema: int32")
+ .Input("permutation: int32")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+IgniteDataset that allows reading data from Apache Ignite.
+
+Apache Ignite is a memory-centric distributed database, caching, and processing
+platform for transactional, analytical, and streaming workloads, delivering
+in-memory speeds at petabyte scale. This contrib package contains an
+integration between Apache Ignite and TensorFlow. The integration is based on
+tf.data from the TensorFlow side and the Binary Client Protocol from the
+Apache Ignite side. It allows using Apache Ignite as a data source for neural
+network training, inference and any other computation supported by TensorFlow.
+
+cache_name: Ignite Cache Name.
+host: Ignite Thin Client Host.
+port: Ignite Thin Client Port.
+local: Local flag that forces data to be fetched from the local host only.
+part: Partition the data should be fetched from.
+page_size: Page size for Ignite Thin Client.
+schema: Internal structure that defines schema of cache objects.
+permutation: Internal structure that defines permutation of cache objects.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py
new file mode 100644
index 0000000000..cfe59b6b23
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py
@@ -0,0 +1,772 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ignite Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import socket
+import ssl
+import struct
+
+from tensorflow.contrib.ignite.python.ops import gen_dataset_ops
+from tensorflow.contrib.ignite.python.ops import ignite_op_loader # pylint: disable=unused-import
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+
+class Readable(object):
+ """Readable abstract class that exposes methods to do reading-related
+
+ operations.
+ """
+
+ @abc.abstractmethod
+ def __init__(self):
+ pass
+
+ def read_byte(self):
+ """Reads and returnes byte."""
+ return self._read("b", 1)
+
+ def read_short(self):
+ """Reads and returns short (2 bytes, little-endian)."""
+ return self._read("h", 2)
+
+ def read_int(self):
+ """Reads and returns int (4 bytes, little-endian)."""
+ return self._read("i", 4)
+
+ def read_long(self):
+ """Reads and returns long (8 bytes, little-endian)."""
+ return self._read("q", 8)
+
+ def skip(self, length):
+ """Skips the specified number of bytes."""
+ self.read_data(length)
+
+ @abc.abstractmethod
+ def read_data(self, length):
+ """Reads the specified number of bytes and returns them as a buffer."""
+ return None
+
+ def _read(self, data_type, length):
+ """Reads, unpacks and returns specified type (little-endian)."""
+ data_buffer = self.read_data(length)
+ return struct.unpack("<" + data_type, data_buffer)[0]
+
+
+class DataBuffer(Readable):
+ """DataBuffer class that exposes methods to read data from a byte buffer."""
+
+ def __init__(self, data_buffer):
+ """Constructs a new instance based on the specified byte buffer.
+
+ Args:
+ data_buffer: Buffer to be read.
+ """
+ Readable.__init__(self)
+ self.buffer = data_buffer
+ self.ptr = 0
+
+ def read_data(self, length):
+ """Reads the specified number of bytes and returns them as a buffer."""
+ data_buffer = self.buffer[self.ptr:][:length]
+ self.ptr += length
+ return data_buffer
+
+
+class TcpClient(Readable):
+ """TcpClient class that exposes methods to read data from a socket."""
+
+ def __init__(self, host, port, certfile=None, keyfile=None, password=None):
+ """Constructs a new instance based on the specified host and port.
+
+ Args:
+ host: Host to be connected.
+ port: Port to be connected.
+ certfile: File in PEM format containing the certificate as well as any
+ number of CA certificates needed to establish the certificate's
+ authenticity.
+ keyfile: File containing the private key (otherwise the private key will
+ be taken from certfile as well).
+ password: Password to be used if the private key is encrypted and a
+ password is necessary.
+
+ Raises:
+ ValueError: If the wrong combination of arguments is provided.
+ """
+ Readable.__init__(self)
+ self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+ if certfile is not None:
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.load_cert_chain(certfile, keyfile, password)
+ self.sock = context.wrap_socket(self.sock)
+ else:
+ if keyfile is not None:
+ raise ValueError("SSL is disabled, keyfile must not be specified "
+ "(to enable SSL specify certfile)")
+ if password is not None:
+ raise ValueError("SSL is disabled, password must not be specified "
+ "(to enable SSL specify certfile)")
+
+ self.host = host
+ self.port = port
+
+ def __enter__(self):
+ """Connects to host and port specified in the constructor."""
+ self.sock.connect((self.host, self.port))
+ return self
+
+ def __exit__(self, t, v, traceback):
+ """Disconnects the socket."""
+ self.sock.close()
+
+ def write_byte(self, v):
+ """Writes the specified byte."""
+ self._write(v, "b")
+
+ def write_short(self, v):
+ """Writes the specified short (2 bytes, little-endian)."""
+ self._write(v, "h")
+
+ def write_int(self, v):
+ """Writes the specified short (4 bytes, little-endian)."""
+ self._write(v, "i")
+
+ def write_long(self, v):
+ """Writes the specified int (8 bytes, little-endian)."""
+ self._write(v, "q")
+
+ def write_string(self, v):
+ """Writes the specified string."""
+ self.sock.sendall(v.encode("UTF-8"))
+
+  def read_data(self, length):
+    """Reads the specified number of bytes and returns them as a buffer."""
+    data_buffer = None
+    rem = length
+    while rem > 0:
+      buf = self.sock.recv(rem)
+      if not buf:
+        raise RuntimeError("Server closed connection")
+      rem = rem - len(buf)
+      if data_buffer is None:
+        data_buffer = buf
+      else:
+        data_buffer += buf
+    return data_buffer
+
+ def _write(self, value, data_type):
+ """Packs and writes data using the specified type (little-endian)."""
+ data_buffer = struct.pack("<" + data_type, value)
+ self.sock.sendall(data_buffer)
+
+
+class BinaryType(object):
+ """BinaryType class that encapsulated type id, type name and fields."""
+
+ def __init__(self, type_id, type_name, fields):
+ """Constructs a new instance of BinaryType."""
+ self.type_id = type_id
+ self.type_name = type_name
+ self.fields = fields
+
+
+class BinaryField(object):
+ """BinaryField class that encapsulated field name, type id and field id."""
+
+ def __init__(self, field_name, type_id, field_id):
+ """Constructs a new instance of BinaryField."""
+ self.field_name = field_name
+ self.type_id = type_id
+ self.field_id = field_id
+
+
+# Binary types defined in Apache Ignite Thin client and supported by
+# TensorFlow on Apache Ignite, see
+# https://apacheignite.readme.io/v2.6/docs/binary-client-protocol.
+# In each tuple, True means the type is a vector (array), False a scalar.
+types = {
+ 1: (dtypes.uint8, False),
+ 2: (dtypes.int16, False),
+ 3: (dtypes.int32, False),
+ 4: (dtypes.int64, False),
+ 5: (dtypes.float32, False),
+ 6: (dtypes.float64, False),
+ 7: (dtypes.uint16, False),
+ 8: (dtypes.bool, False),
+ 9: (dtypes.string, False),
+ 12: (dtypes.uint8, True),
+ 13: (dtypes.int16, True),
+ 14: (dtypes.int32, True),
+ 15: (dtypes.int64, True),
+ 16: (dtypes.float32, True),
+ 17: (dtypes.float64, True),
+ 18: (dtypes.uint16, True),
+ 19: (dtypes.bool, True),
+ 20: (dtypes.string, True)
+}
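+
+# For example, type id 12 (an Ignite byte array) maps to (dtypes.uint8, True),
+# i.e. a variable-length vector of uint8 values:
+#   dtype, is_array = types[12]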
+
+
+class TypeTreeNode(object):
+ """TypeTreeNode class exposes methods to format object tree structure
+
+ data.
+ """
+
+ def __init__(self, name, type_id, fields=None, permutation=None):
+ """Constructs a new instance of TypeTreeNode.
+
+ Args:
+ name: Name of the object tree node.
+ type_id: Type id of the object tree node.
+ fields: List of fields (children of the object tree node).
+ permutation: Permutation that should be applied to order object children.
+ """
+ self.name = name
+ self.type_id = type_id
+ self.fields = fields
+ self.permutation = permutation
+
+ def to_output_classes(self):
+ """Formats the tree object as required by `Dataset.output_classes`."""
+ if self.fields is None:
+ return ops.Tensor
+ output_classes = {}
+ for field in self.fields:
+ output_classes[field.name] = field.to_output_classes()
+ return output_classes
+
+ def to_output_shapes(self):
+ """Formats the tree object as required by `Dataset.output_shapes`."""
+ if self.fields is None:
+ if self.type_id in types:
+ object_type = types[self.type_id]
+ is_array = object_type[1]
+ if is_array:
+ return tensor_shape.TensorShape([None])
+ return tensor_shape.TensorShape([])
+ raise ValueError("Unsupported type [type_id=%d]" % self.type_id)
+ output_shapes = {}
+ for field in self.fields:
+ output_shapes[field.name] = field.to_output_shapes()
+ return output_shapes
+
+ def to_output_types(self):
+ """Formats the tree object as required by `Dataset.output_types`."""
+ if self.fields is None:
+ if self.type_id in types:
+ object_type = types[self.type_id]
+ return object_type[0]
+ raise ValueError("Unsupported type [type_id=%d]" % self.type_id)
+ else:
+ output_types = {}
+ for field in self.fields:
+ output_types[field.name] = field.to_output_types()
+ return output_types
+
+ def to_flat(self):
+ """Returns a list of node types."""
+ return self.to_flat_rec([])
+
+ def to_permutation(self):
+ """Returns a permutation that should be applied to order object leaves."""
+ correct_order_dict = {}
+ self.traversal_rec(correct_order_dict, 0)
+ object_order = []
+ self.traversal_permutation_rec(object_order)
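+    # correct_order_dict maps each leaf to its index in the field-name-sorted
+    # (schema) order, while object_order lists the leaves in the order they
+    # appear on the wire, so the returned list maps wire positions to schema
+    # positions.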
+ return [correct_order_dict[o] for o in object_order]
+
+ def to_flat_rec(self, flat):
+ """Formats a list of leaf node types in pre-order."""
+ if self.fields is None:
+ flat.append(self.type_id)
+ else:
+ for field in self.fields:
+ field.to_flat_rec(flat)
+ return flat
+
+ def traversal_permutation_rec(self, permutation):
+ """Collects nodes in accordance with permutation."""
+ if self.fields is None:
+ permutation.append(self)
+ else:
+ for idx in self.permutation:
+ field = self.fields[idx]
+ field.traversal_permutation_rec(permutation)
+
+ def traversal_rec(self, d, i):
+ """Collects nodes in pre-order traversal."""
+ if self.fields is None:
+ d[self] = i
+ i += 1
+ else:
+ for field in self.fields:
+ i = field.traversal_rec(d, i)
+ return i
+
+
+class IgniteClient(TcpClient):
+ """IgniteClient enables working with Apache Ignite using a thin client.
+
+  This client works under the assumption that all objects in the cache
+  have the same structure (homogeneous objects) and that the cache contains
+  at least one object.
+ """
+
+ def __init__(self,
+ host,
+ port,
+ username=None,
+ password=None,
+ certfile=None,
+ keyfile=None,
+ cert_password=None):
+ """Constructs a new instance of IgniteClient.
+
+ Args:
+ host: Apache Ignite Thin client host to be connected.
+ port: Apache Ignite Thin client port to be connected.
+ username: Apache Ignite Thin Client authentication username.
+ password: Apache Ignite Thin Client authentication password.
+ certfile: File in PEM format containing the certificate as well as any
+ number of CA certificates needed to establish the certificate's
+ authenticity.
+ keyfile: File containing the private key (otherwise the private key will
+ be taken from certfile as well).
+ cert_password: Password to be used if the private key is encrypted and a
+ password is necessary.
+ """
+ TcpClient.__init__(self, host, port, certfile, keyfile, cert_password)
+ self.username = username
+ self.password = password
+
+ def handshake(self):
+ """Makes a handshake after connect and before any other calls."""
+ msg_len = 8
+
+ if self.username is None:
+ msg_len += 1
+ else:
+ msg_len += 5 + len(self.username)
+
+ if self.password is None:
+ msg_len += 1
+ else:
+ msg_len += 5 + len(self.password)
+
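+    # The fixed part after the length field is 8 bytes: handshake op code (1)
+    # + three version shorts (6) + client code (1). Each string field adds a
+    # 1-byte type header, a 4-byte length and the payload itself, while a
+    # NULL field adds a single type byte, matching the accounting above.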
+    self.write_int(msg_len)  # Message length
+    self.write_byte(1)  # Handshake operation
+    self.write_short(1)  # Protocol version major
+    self.write_short(1)  # Protocol version minor
+    self.write_short(0)  # Protocol version patch
+    self.write_byte(2)  # Client code (thin client)
+
+ if self.username is None: # Username
+ self.write_byte(101)
+ else:
+ self.write_byte(9)
+ self.write_int(len(self.username))
+ self.write_string(self.username)
+
+ if self.password is None: # Password
+ self.write_byte(101)
+ else:
+ self.write_byte(9)
+ self.write_int(len(self.password))
+ self.write_string(self.password)
+
+ self.read_int() # Result length
+ res = self.read_byte()
+
+ if res != 1:
+ serv_ver_major = self.read_short()
+ serv_ver_minor = self.read_short()
+ serv_ver_patch = self.read_short()
+ err_msg = self._parse_string()
+ if err_msg is None:
+ raise RuntimeError(
+ "Handshake Error [result=%d, version=%d.%d.%d]" %
+ (res, serv_ver_major, serv_ver_minor, serv_ver_patch))
+ else:
+ raise RuntimeError(
+ "Handshake Error [result=%d, version=%d.%d.%d, message='%s']" %
+ (res, serv_ver_major, serv_ver_minor, serv_ver_patch, err_msg))
+
+ def get_cache_type(self, cache_name):
+ """Collects type information about objects stored in the specified cache."""
+ cache_name_hash = self._java_hash_code(cache_name)
+ self.write_int(25) # Message length
+ self.write_short(2000) # Operation code
+ self.write_long(0) # Request ID
+ self.write_int(cache_name_hash) # Cache name
+ self.write_byte(0) # Flags
+ self.write_byte(101) # Filter (NULL)
+ self.write_int(1) # Cursor page size
+ self.write_int(-1) # Partition to query
+ self.write_byte(0) # Local flag
+
+ result_length = self.read_int()
+ self.read_long() # Request id
+ status = self.read_int()
+
+ if status != 0:
+ err_msg = self._parse_string()
+ if err_msg is None:
+ raise RuntimeError("Scan Query Error [status=%s]" % status)
+ else:
+ raise RuntimeError(
+ "Scan Query Error [status=%s, message='%s']" % (status, err_msg))
+
+ self.read_long() # Cursor id
+ row_count = self.read_int()
+
+ if row_count == 0:
+ raise RuntimeError("Scan Query returned empty result, so it's "
+ "impossible to derive the cache type")
+
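+    # result_length counts everything after the length field: request id (8)
+    # + status (4) + cursor id (8) + row count (4) = 24 bytes of header plus
+    # a trailing "next page" byte, so the row payload is result_length - 25.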
+ payload = DataBuffer(self.read_data(result_length - 25))
+
+ self.read_byte() # Next page
+
+ res = TypeTreeNode("root", 0, [
+ self._collect_types("key", payload),
+ self._collect_types("val", payload)
+ ], [0, 1])
+
+ return res
+
+ def _java_hash_code(self, s):
+ """Computes hash code of the specified string using Java code."""
+ h = 0
+ for c in s:
+ h = (31 * h + ord(c)) & 0xFFFFFFFF
+ return ((h + 0x80000000) & 0xFFFFFFFF) - 0x80000000
+
+ def _collect_types(self, field_name, data):
+ """Extracts type information from the specified object."""
+ type_id = data.read_byte()
+
+ # Byte scalar.
+ if type_id == 1:
+ data.skip(1)
+ return TypeTreeNode(field_name, type_id)
+
+ # Short scalar.
+ if type_id == 2:
+ data.skip(2)
+ return TypeTreeNode(field_name, type_id)
+
+ # Integer scalar.
+ if type_id == 3:
+ data.skip(4)
+ return TypeTreeNode(field_name, type_id)
+
+ # Long scalar.
+ if type_id == 4:
+ data.skip(8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Float scalar.
+ if type_id == 5:
+ data.skip(4)
+ return TypeTreeNode(field_name, type_id)
+
+ # Double scalar.
+ if type_id == 6:
+ data.skip(8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Char scalar.
+ if type_id == 7:
+ data.skip(2)
+ return TypeTreeNode(field_name, type_id)
+
+ # Bool scalar.
+ if type_id == 8:
+ data.skip(1)
+ return TypeTreeNode(field_name, type_id)
+
+ # String scalar.
+ if type_id == 9:
+ length = data.read_int()
+ data.skip(length)
+ return TypeTreeNode(field_name, type_id)
+
+ # UUID scalar.
+ if type_id == 10:
+ data.skip(16)
+ return TypeTreeNode(field_name, type_id)
+
+ # Date scalar.
+ if type_id == 11:
+ data.skip(8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Byte array.
+ if type_id == 12:
+ length = data.read_int()
+ data.skip(length)
+ return TypeTreeNode(field_name, type_id)
+
+ # Short array.
+ if type_id == 13:
+ length = data.read_int()
+ data.skip(length * 2)
+ return TypeTreeNode(field_name, type_id)
+
+ # Integer array.
+ if type_id == 14:
+ length = data.read_int()
+ data.skip(length * 4)
+ return TypeTreeNode(field_name, type_id)
+
+ # Long array.
+ if type_id == 15:
+ length = data.read_int()
+ data.skip(length * 8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Float array.
+ if type_id == 16:
+ length = data.read_int()
+ data.skip(length * 4)
+ return TypeTreeNode(field_name, type_id)
+
+ # Double array.
+ if type_id == 17:
+ length = data.read_int()
+ data.skip(length * 8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Char array.
+ if type_id == 18:
+ length = data.read_int()
+ data.skip(length * 2)
+ return TypeTreeNode(field_name, type_id)
+
+ # Bool array.
+ if type_id == 19:
+ length = data.read_int()
+ data.skip(length)
+ return TypeTreeNode(field_name, type_id)
+
+ # String array.
+ if type_id == 20:
+ length = data.read_int()
+ for _ in range(length):
+ header = data.read_byte()
+ if header == 9:
+ str_length = data.read_int()
+ data.skip(str_length)
+ elif header == 101:
+ pass
+ else:
+ raise RuntimeError(
+ "Unknown binary type when expected string [type_id=%d]" % header)
+ return TypeTreeNode(field_name, type_id)
+
+ # UUID array.
+ if type_id == 21:
+ length = data.read_int()
+ data.skip(length * 16) # TODO(dmitrievanthony): support NULL values.
+ return TypeTreeNode(field_name, type_id)
+
+ # Date array.
+ if type_id == 22:
+ length = data.read_int()
+ data.skip(length * 8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Wrapped Binary Object.
+ if type_id == 27:
+ length = data.read_int()
+ inner_data = data.read_data(length)
+ data.read_int() # Offset
+ return self._collect_types(field_name, DataBuffer(inner_data))
+
+ # Complex Object.
+ if type_id == 103:
+ data.read_byte() # Object version
+ data.read_short() # Object flags
+ obj_type_id = data.read_int()
+ data.read_int() # Object hash code
+ obj_length = data.read_int()
+ data.read_int() # Object schema id
+ obj_schema_offset = data.read_int()
+
+ obj_type = self._get_type(obj_type_id)
+ children = []
+
+ for obj_field in obj_type.fields:
+ child = self._collect_types(obj_field.field_name, data)
+ children.append(child)
+
+ children_sorted = sorted(children, key=lambda child: child.name)
+ permutation = [children_sorted.index(child) for child in children]
+ children = children_sorted
+
+ data.skip(obj_length - obj_schema_offset)
+
+ return TypeTreeNode(field_name, type_id, children, permutation)
+
+ raise RuntimeError("Unknown binary type [type_id=%d]" % type_id)
+
+ def _get_type(self, type_id):
+ """Queries Apache Ignite information about type by type id."""
+ self.write_int(14) # Message length
+ self.write_short(3002) # Operation code
+ self.write_long(0) # Request ID
+ self.write_int(type_id) # Type ID
+
+ self.read_int() # Result length
+ self.read_long() # Request id
+ status = self.read_int()
+
+    if status != 0:
+      err_msg = self._parse_string()
+      if err_msg is None:
+        raise RuntimeError("Get Binary Type Error [status=%d]" % status)
+      else:
+        raise RuntimeError("Get Binary Type Error [status=%d, message='%s']" %
+                           (status, err_msg))
+
+ binary_type_exists = self.read_byte()
+
+ if binary_type_exists == 0:
+ raise RuntimeError("Binary type not found [type_id=%d] " % type_id)
+
+ binary_type_id = self.read_int()
+ binary_type_name = self._parse_string()
+ self._parse_string() # Affinity field name
+
+ fields = []
+ for _ in range(self.read_int()):
+ field_name = self._parse_string()
+ field_type_id = self.read_int()
+ field_id = self.read_int()
+
+ field = BinaryField(field_name, field_type_id, field_id)
+ fields.append(field)
+
+ is_enum = self.read_byte()
+ if is_enum == 1:
+ raise RuntimeError("Enum fields are not supported yet")
+
+ schema_cnt = self.read_int()
+ for _ in range(schema_cnt):
+ self.read_int() # Schema id
+ field_cnt = self.read_int()
+ self.skip(field_cnt * 4)
+
+ return BinaryType(binary_type_id, binary_type_name, fields)
+
+ def _parse_string(self):
+ """Parses string."""
+ header = self.read_byte()
+ if header == 9:
+ length = self.read_int()
+ return self.read_data(length).decode("utf-8")
+ if header == 101:
+ return None
+ raise RuntimeError(
+ "Unknown binary type when expected string [type_id=%d]" % header)
+
+
+class IgniteDataset(dataset_ops.Dataset):
+ """Apache Ignite is a memory-centric distributed database, caching, and
+
+ processing platform for transactional, analytical, and streaming workloads,
+ delivering in-memory speeds at petabyte scale. This contrib package
+ contains an integration between Apache Ignite and TensorFlow. The
+ integration is based on tf.data from TensorFlow side and Binary Client
+ Protocol from Apache Ignite side. It allows to use Apache Ignite as a
+ datasource for neural network training, inference and all other
+ computations supported by TensorFlow. Ignite Dataset is based on Apache
+ Ignite Binary Client Protocol.
+ """
+
+ def __init__(self,
+ cache_name,
+ host="localhost",
+ port=10800,
+ local=False,
+ part=-1,
+ page_size=100,
+ username=None,
+ password=None,
+ certfile=None,
+ keyfile=None,
+ cert_password=None):
+ """Create a IgniteDataset.
+
+ Args:
+ cache_name: Cache name to be used as datasource.
+ host: Apache Ignite Thin Client host to be connected.
+ port: Apache Ignite Thin Client port to be connected.
+      local: Local flag that restricts the query to local data only.
+      part: Partition the data should be fetched from.
+ page_size: Apache Ignite Thin Client page size.
+ username: Apache Ignite Thin Client authentication username.
+ password: Apache Ignite Thin Client authentication password.
+ certfile: File in PEM format containing the certificate as well as any
+ number of CA certificates needed to establish the certificate's
+ authenticity.
+ keyfile: File containing the private key (otherwise the private key will
+ be taken from certfile as well).
+ cert_password: Password to be used if the private key is encrypted and a
+ password is necessary.
+ """
+ super(IgniteDataset, self).__init__()
+
+ with IgniteClient(host, port, username, password, certfile, keyfile,
+ cert_password) as client:
+ client.handshake()
+ self.cache_type = client.get_cache_type(cache_name)
+
+ self.cache_name = ops.convert_to_tensor(
+ cache_name, dtype=dtypes.string, name="cache_name")
+ self.host = ops.convert_to_tensor(host, dtype=dtypes.string, name="host")
+ self.port = ops.convert_to_tensor(port, dtype=dtypes.int32, name="port")
+ self.local = ops.convert_to_tensor(local, dtype=dtypes.bool, name="local")
+ self.part = ops.convert_to_tensor(part, dtype=dtypes.int32, name="part")
+ self.page_size = ops.convert_to_tensor(
+ page_size, dtype=dtypes.int32, name="page_size")
+ self.schema = ops.convert_to_tensor(
+ self.cache_type.to_flat(), dtype=dtypes.int32, name="schema")
+ self.permutation = ops.convert_to_tensor(
+ self.cache_type.to_permutation(),
+ dtype=dtypes.int32,
+ name="permutation")
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.ignite_dataset(self.cache_name, self.host, self.port,
+ self.local, self.part, self.page_size,
+ self.schema, self.permutation)
+
+ @property
+ def output_classes(self):
+ return self.cache_type.to_output_classes()
+
+ @property
+ def output_shapes(self):
+ return self.cache_type.to_output_shapes()
+
+ @property
+ def output_types(self):
+ return self.cache_type.to_output_types()
diff --git a/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py
new file mode 100644
index 0000000000..c9af7386cf
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py
@@ -0,0 +1,24 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python helper for loading Ignite ops and kernels."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.platform import resource_loader
+
+_dataset_ops = loader.load_op_library(
+ resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
diff --git a/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh b/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh
new file mode 100755
index 0000000000..f4607ce8ad
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+nohup apache-ignite-fabric/bin/ignite.sh /data/config/ignite-config-plain.xml &
+sleep 5  # Wait for Apache Ignite to start
+
+./apache-ignite-fabric/bin/sqlline.sh \
+-u "jdbc:ignite:thin://127.0.0.1/" \
+--run=/data/sql/init.sql
+
+tail -f nohup.out
diff --git a/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml b/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml
new file mode 100644
index 0000000000..d900174a8a
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml
@@ -0,0 +1,39 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<beans xmlns="http://www.springframework.org/schema/beans"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xmlns:util="http://www.springframework.org/schema/util"
+ xsi:schemaLocation="http://www.springframework.org/schema/beans
+ http://www.springframework.org/schema/beans/spring-beans.xsd
+ http://www.springframework.org/schema/util
+ http://www.springframework.org/schema/util/spring-util.xsd">
+
+ <bean class="org.apache.ignite.configuration.IgniteConfiguration">
+ <property name="discoverySpi">
+ <bean class="org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi">
+ <property name="ipFinder">
+ <bean class="org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder">
+ <property name="addresses">
+ <list>
+ <value>127.0.0.1</value>
+ </list>
+ </property>
+ </bean>
+ </property>
+ </bean>
+ </property>
+ </bean>
+
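+    <!-- No clientConnectorConfiguration is specified, so the thin-client
+         listener starts on the default port 10800 (mapped to a host port by
+         start_ignite.sh). -->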
+</beans>
diff --git a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py
new file mode 100644
index 0000000000..1856a4fba8
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py
@@ -0,0 +1,118 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for IgniteDataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.ignite import IgniteDataset
+from tensorflow.python.client import session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class IgniteDatasetTest(test.TestCase):
+ """The Apache Ignite servers have to setup before the test and tear down
+
+ after the test manually. The docker engine has to be installed.
+
+ To setup Apache Ignite servers:
+ $ bash start_ignite.sh
+
+ To tear down Apache Ignite servers:
+ $ bash stop_ignite.sh
+ """
+
+ def test_ignite_dataset_with_plain_client(self):
+ """Test Ignite Dataset with plain client.
+
+ """
+ self._clear_env()
+ ds = IgniteDataset(cache_name="SQL_PUBLIC_TEST_CACHE", port=42300)
+ self._check_dataset(ds)
+
+ def test_ignite_dataset_with_ssl_client(self):
+ """Test Ignite Dataset with ssl client.
+
+ """
+ self._clear_env()
+ os.environ["IGNITE_DATASET_CERTFILE"] = os.path.dirname(
+ os.path.realpath(__file__)) + "/keystore/client.pem"
+ os.environ["IGNITE_DATASET_CERT_PASSWORD"] = "123456"
+
+ ds = IgniteDataset(
+ cache_name="SQL_PUBLIC_TEST_CACHE",
+ port=42301,
+ certfile=os.environ["IGNITE_DATASET_CERTFILE"],
+ cert_password=os.environ["IGNITE_DATASET_CERT_PASSWORD"])
+ self._check_dataset(ds)
+
+ def test_ignite_dataset_with_ssl_client_and_auth(self):
+ """Test Ignite Dataset with ssl client and authentication.
+
+ """
+ self._clear_env()
+ os.environ["IGNITE_DATASET_USERNAME"] = "ignite"
+ os.environ["IGNITE_DATASET_PASSWORD"] = "ignite"
+ os.environ["IGNITE_DATASET_CERTFILE"] = os.path.dirname(
+ os.path.realpath(__file__)) + "/keystore/client.pem"
+ os.environ["IGNITE_DATASET_CERT_PASSWORD"] = "123456"
+
+ ds = IgniteDataset(
+ cache_name="SQL_PUBLIC_TEST_CACHE",
+ port=42302,
+ certfile=os.environ["IGNITE_DATASET_CERTFILE"],
+ cert_password=os.environ["IGNITE_DATASET_CERT_PASSWORD"],
+ username=os.environ["IGNITE_DATASET_USERNAME"],
+ password=os.environ["IGNITE_DATASET_PASSWORD"])
+ self._check_dataset(ds)
+
+ def _clear_env(self):
+ """Clears environment variables used by Ignite Dataset.
+
+ """
+ if "IGNITE_DATASET_USERNAME" in os.environ:
+ del os.environ["IGNITE_DATASET_USERNAME"]
+ if "IGNITE_DATASET_PASSWORD" in os.environ:
+ del os.environ["IGNITE_DATASET_PASSWORD"]
+ if "IGNITE_DATASET_CERTFILE" in os.environ:
+ del os.environ["IGNITE_DATASET_CERTFILE"]
+ if "IGNITE_DATASET_CERT_PASSWORD" in os.environ:
+ del os.environ["IGNITE_DATASET_CERT_PASSWORD"]
+
+ def _check_dataset(self, dataset):
+ """Checks that dataset provides correct data."""
+ self.assertEqual(dtypes.int64, dataset.output_types["key"])
+ self.assertEqual(dtypes.string, dataset.output_types["val"]["NAME"])
+ self.assertEqual(dtypes.int64, dataset.output_types["val"]["VAL"])
+
+ it = dataset.make_one_shot_iterator()
+ ne = it.get_next()
+
+ with session.Session() as sess:
+ rows = [sess.run(ne), sess.run(ne), sess.run(ne)]
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(ne)
+
+ self.assertEqual({"key": 1, "val": {"NAME": b"TEST1", "VAL": 42}}, rows[0])
+ self.assertEqual({"key": 2, "val": {"NAME": b"TEST2", "VAL": 43}}, rows[1])
+ self.assertEqual({"key": 3, "val": {"NAME": b"TEST3", "VAL": 44}}, rows[2])
+
+
+if __name__ == "__main__":
+ test.main()
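For reference, a minimal consumer of the dataset mirrors what _check_dataset exercises above. This sketch assumes a plain Ignite server is already listening on 127.0.0.1:42300 with the test cache loaded (i.e. start_ignite.sh has been run):

    from tensorflow.contrib.ignite import IgniteDataset
    from tensorflow.python.client import session

    ds = IgniteDataset(cache_name="SQL_PUBLIC_TEST_CACHE", port=42300)
    it = ds.make_one_shot_iterator()
    next_row = it.get_next()

    with session.Session() as sess:
      # Prints e.g. {"key": 1, "val": {"NAME": b"TEST1", "VAL": 42}}.
      print(sess.run(next_row))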
diff --git a/tensorflow/contrib/ignite/python/tests/sql/init.sql b/tensorflow/contrib/ignite/python/tests/sql/init.sql
new file mode 100644
index 0000000000..5a192aef17
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/sql/init.sql
@@ -0,0 +1,20 @@
+-- Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+--
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+-- ==============================================================================
+
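+-- The PUBLIC-schema table below is exposed to the thin client as the cache
+-- "SQL_PUBLIC_TEST_CACHE", the name used by ignite_dataset_test.py.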
+CREATE TABLE TEST_CACHE (ID LONG PRIMARY KEY, NAME VARCHAR, VAL LONG);
+
+INSERT INTO TEST_CACHE VALUES (1, 'TEST1', 42);
+INSERT INTO TEST_CACHE VALUES (2, 'TEST2', 43);
+INSERT INTO TEST_CACHE VALUES (3, 'TEST3', 44);
diff --git a/tensorflow/contrib/ignite/python/tests/start_ignite.sh b/tensorflow/contrib/ignite/python/tests/start_ignite.sh
new file mode 100755
index 0000000000..a67bd44f2f
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/start_ignite.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+IGNITE_VERSION=2.6.0
+SCRIPT_PATH="$( cd "$(dirname "$0")" ; pwd -P )"
+
+# Start Apache Ignite with a plain (non-SSL) client listener. Port 10800 is
+# the default thin-client port inside the container, mapped to 42300 on the
+# host.
+docker run -itd --name ignite-plain -p 42300:10800 \
+  -v "${SCRIPT_PATH}":/data apacheignite/ignite:${IGNITE_VERSION} /data/bin/start-plain.sh
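+
+# NOTE: The SSL tests in ignite_dataset_test.py expect servers on ports 42301
+# and 42302 (the "ignite-ssl" and "ignite-ssl-auth" containers removed by
+# stop_ignite.sh); this script does not start them.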
diff --git a/tensorflow/contrib/ignite/python/tests/stop_ignite.sh b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh
new file mode 100755
index 0000000000..8f03dbd1ed
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh
@@ -0,0 +1,19 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
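+# Forcibly remove the test containers; the script simply continues past any
+# container that was never started.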
+docker rm -f ignite-plain
+docker rm -f ignite-ssl
+docker rm -f ignite-ssl-auth