aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc')
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc447
1 files changed, 447 insertions, 0 deletions
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc
new file mode 100644
index 0000000000..03cc3c1291
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc
@@ -0,0 +1,447 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ignite_dataset_iterator.h"
+
+#include "ignite_plain_client.h"
+#include "ignite_ssl_wrapper.h"
+#include "tensorflow/core/platform/logging.h"
+
+#include <time.h>
+#include <chrono>
+
+namespace ignite {
+
+#define CHECK_STATUS(status) \
+ if (!status.ok()) return status;
+
+IgniteDatasetIterator::IgniteDatasetIterator(
+ const Params& params, std::string host, tensorflow::int32 port,
+ std::string cache_name, bool local, tensorflow::int32 part,
+ tensorflow::int32 page_size, std::string username, std::string password,
+ std::string certfile, std::string keyfile, std::string cert_password,
+ std::vector<tensorflow::int32> schema,
+ std::vector<tensorflow::int32> permutation)
+ : tensorflow::DatasetIterator<IgniteDataset>(params),
+ cache_name(cache_name),
+ local(local),
+ part(part),
+ page_size(page_size),
+ username(username),
+ password(password),
+ schema(schema),
+ permutation(permutation),
+ remainder(-1),
+ cursor_id(-1),
+ last_page(false) {
+ Client* p_client = new PlainClient(host, port);
+
+ if (certfile.empty())
+ client = std::unique_ptr<Client>(p_client);
+ else
+ client = std::unique_ptr<Client>(new SslWrapper(
+ std::unique_ptr<Client>(p_client), certfile, keyfile, cert_password));
+
+ LOG(INFO) << "Ignite Dataset Iterator created";
+}
+
+IgniteDatasetIterator::~IgniteDatasetIterator() {
+ tensorflow::Status status = CloseConnection();
+ if (!status.ok()) LOG(ERROR) << status.ToString();
+
+ LOG(INFO) << "Ignite Dataset Iterator destroyed";
+}
+
+tensorflow::Status IgniteDatasetIterator::EstablishConnection() {
+ if (!client->IsConnected()) {
+ tensorflow::Status status = client->Connect();
+ if (!status.ok()) return status;
+
+ status = Handshake();
+ if (!status.ok()) {
+ tensorflow::Status disconnect_status = client->Disconnect();
+ if (!disconnect_status.ok()) LOG(ERROR) << disconnect_status.ToString();
+
+ return status;
+ }
+ }
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status IgniteDatasetIterator::CloseConnection() {
+ if (cursor_id != -1 && !last_page) {
+ tensorflow::Status conn_status = EstablishConnection();
+ if (!conn_status.ok()) return conn_status;
+
+ CHECK_STATUS(client->WriteInt(18)); // Message length
+ CHECK_STATUS(
+ client->WriteShort(close_connection_opcode)); // Operation code
+ CHECK_STATUS(client->WriteLong(0)); // Request ID
+ CHECK_STATUS(client->WriteLong(cursor_id)); // Resource ID
+
+ int32_t res_len;
+ CHECK_STATUS(client->ReadInt(res_len));
+ if (res_len < 12)
+ return tensorflow::errors::Internal(
+ "Close Resource Response is corrupted");
+
+ int64_t req_id;
+ CHECK_STATUS(client->ReadLong(req_id));
+ int32_t status;
+ CHECK_STATUS(client->ReadInt(status));
+ if (status != 0) {
+ uint8_t err_msg_header;
+ CHECK_STATUS(client->ReadByte(err_msg_header));
+ if (err_msg_header == string_val) {
+ int32_t err_msg_length;
+ CHECK_STATUS(client->ReadInt(err_msg_length));
+ uint8_t* err_msg_c = new uint8_t[err_msg_length];
+ CHECK_STATUS(client->ReadData(err_msg_c, err_msg_length));
+ std::string err_msg((char*)err_msg_c, err_msg_length);
+ delete[] err_msg_c;
+
+ return tensorflow::errors::Internal("Close Resource Error [status=",
+ status, ", message=", err_msg, "]");
+ }
+ return tensorflow::errors::Internal("Close Resource Error [status=",
+ status, "]");
+ }
+
+ LOG(INFO) << "Query Cursor " << cursor_id << " is closed";
+
+ cursor_id = -1;
+
+ return client->Disconnect();
+ } else {
+ LOG(INFO) << "Query Cursor " << cursor_id << " is already closed";
+ }
+
+ return client->IsConnected() ? client->Disconnect()
+ : tensorflow::Status::OK();
+}
+
+tensorflow::Status IgniteDatasetIterator::GetNextInternal(
+ tensorflow::IteratorContext* ctx,
+ std::vector<tensorflow::Tensor>* out_tensors, bool* end_of_sequence) {
+ if (remainder == 0 && last_page) {
+ LOG(INFO) << "Query Cursor " << cursor_id << " is closed";
+
+ cursor_id = -1;
+ *end_of_sequence = true;
+ return tensorflow::Status::OK();
+ } else {
+ tensorflow::Status status = EstablishConnection();
+ if (!status.ok()) return status;
+
+ if (remainder == -1 || remainder == 0) {
+ tensorflow::Status status =
+ remainder == -1 ? ScanQuery() : LoadNextPage();
+ if (!status.ok()) return status;
+ }
+
+ uint8_t* initial_ptr = ptr;
+ std::vector<int32_t> types;
+ std::vector<tensorflow::Tensor> tensors;
+
+ status = parser.Parse(ptr, tensors, types); // Parse key
+ if (!status.ok()) return status;
+
+ status = parser.Parse(ptr, tensors, types); // Parse val
+ if (!status.ok()) return status;
+
+ remainder -= (ptr - initial_ptr);
+
+ out_tensors->resize(tensors.size());
+ for (int32_t i = 0; i < tensors.size(); i++)
+ (*out_tensors)[permutation[i]] = std::move(tensors[i]);
+
+ *end_of_sequence = false;
+ return tensorflow::Status::OK();
+ }
+
+ *end_of_sequence = true;
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status IgniteDatasetIterator::SaveInternal(
+ tensorflow::IteratorStateWriter* writer) {
+ return tensorflow::errors::Unimplemented(
+ "Iterator for IgniteDataset does not support 'SaveInternal'");
+}
+
+tensorflow::Status IgniteDatasetIterator::RestoreInternal(
+ tensorflow::IteratorContext* ctx, tensorflow::IteratorStateReader* reader) {
+ return tensorflow::errors::Unimplemented(
+ "Iterator for IgniteDataset does not support 'RestoreInternal')");
+}
+
+tensorflow::Status IgniteDatasetIterator::Handshake() {
+ int32_t msg_len = 8;
+
+ if (username.empty())
+ msg_len += 1;
+ else
+ msg_len += 5 + username.length();
+
+ if (password.empty())
+ msg_len += 1;
+ else
+ msg_len += 5 + password.length();
+
+ CHECK_STATUS(client->WriteInt(msg_len));
+ CHECK_STATUS(client->WriteByte(1));
+ CHECK_STATUS(client->WriteShort(protocol_major_version));
+ CHECK_STATUS(client->WriteShort(protocol_minor_version));
+ CHECK_STATUS(client->WriteShort(protocol_patch_version));
+ CHECK_STATUS(client->WriteByte(2));
+ if (username.empty()) {
+ CHECK_STATUS(client->WriteByte(null_val));
+ } else {
+ CHECK_STATUS(client->WriteByte(string_val));
+ CHECK_STATUS(client->WriteInt(username.length()));
+ CHECK_STATUS(
+ client->WriteData((uint8_t*)username.c_str(), username.length()));
+ }
+
+ if (password.empty()) {
+ CHECK_STATUS(client->WriteByte(null_val));
+ } else {
+ CHECK_STATUS(client->WriteByte(string_val));
+ CHECK_STATUS(client->WriteInt(password.length()));
+ CHECK_STATUS(
+ client->WriteData((uint8_t*)password.c_str(), password.length()));
+ }
+
+ int32_t handshake_res_len;
+ CHECK_STATUS(client->ReadInt(handshake_res_len));
+ uint8_t handshake_res;
+ CHECK_STATUS(client->ReadByte(handshake_res));
+
+ LOG(INFO) << "Handshake length " << handshake_res_len << ", res "
+ << (int16_t)handshake_res;
+
+ if (handshake_res != 1) {
+ int16_t serv_ver_major;
+ CHECK_STATUS(client->ReadShort(serv_ver_major));
+ int16_t serv_ver_minor;
+ CHECK_STATUS(client->ReadShort(serv_ver_minor));
+ int16_t serv_ver_patch;
+ CHECK_STATUS(client->ReadShort(serv_ver_patch));
+ uint8_t header;
+ CHECK_STATUS(client->ReadByte(header));
+
+ if (header == string_val) {
+ int32_t length;
+ CHECK_STATUS(client->ReadInt(length));
+ uint8_t* err_msg_c = new uint8_t[length];
+ CHECK_STATUS(client->ReadData(err_msg_c, length));
+ std::string err_msg((char*)err_msg_c, length);
+ delete[] err_msg_c;
+
+ return tensorflow::errors::Internal(
+ "Handshake Error [result=", handshake_res, ", version=",
+ serv_ver_major, ".", serv_ver_minor, ".", serv_ver_patch,
+ ", message='", err_msg, "']");
+ } else if (header == null_val) {
+ return tensorflow::errors::Internal(
+ "Handshake Error [result=", handshake_res, ", version=",
+ serv_ver_major, ".", serv_ver_minor, ".", serv_ver_patch, "]");
+ } else {
+ return tensorflow::errors::Internal(
+ "Handshake Error [result=", handshake_res, ", version=",
+ serv_ver_major, ".", serv_ver_minor, ".", serv_ver_patch, "]");
+ }
+ }
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status IgniteDatasetIterator::ScanQuery() {
+ CHECK_STATUS(client->WriteInt(25)); // Message length
+ CHECK_STATUS(client->WriteShort(scan_query_opcode)); // Operation code
+ CHECK_STATUS(client->WriteLong(0)); // Request ID
+ CHECK_STATUS(client->WriteInt(JavaHashCode(cache_name))); // Cache name
+ CHECK_STATUS(client->WriteByte(0)); // Flags
+ CHECK_STATUS(client->WriteByte(null_val)); // Filter object
+ CHECK_STATUS(client->WriteInt(page_size)); // Cursor page size
+ CHECK_STATUS(client->WriteInt(part)); // Partition to query
+ CHECK_STATUS(client->WriteByte(local)); // Local flag
+
+ int64_t wait_start = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+
+ int32_t res_len;
+ CHECK_STATUS(client->ReadInt(res_len));
+
+ int64_t wait_stop = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+
+ LOG(INFO) << "Scan Query waited " << (wait_stop - wait_start) << " ms";
+
+ if (res_len < 12)
+ return tensorflow::errors::Internal("Scan Query Response is corrupted");
+
+ int64_t req_id;
+ CHECK_STATUS(client->ReadLong(req_id));
+
+ int32_t status;
+ CHECK_STATUS(client->ReadInt(status));
+
+ if (status != 0) {
+ uint8_t err_msg_header;
+ CHECK_STATUS(client->ReadByte(err_msg_header));
+
+ if (err_msg_header == string_val) {
+ int32_t err_msg_length;
+ CHECK_STATUS(client->ReadInt(err_msg_length));
+
+ uint8_t* err_msg_c = new uint8_t[err_msg_length];
+ CHECK_STATUS(client->ReadData(err_msg_c, err_msg_length));
+ std::string err_msg((char*)err_msg_c, err_msg_length);
+ delete[] err_msg_c;
+
+ return tensorflow::errors::Internal("Scan Query Error [status=", status,
+ ", message=", err_msg, "]");
+ }
+ return tensorflow::errors::Internal("Scan Query Error [status=", status,
+ "]");
+ }
+
+ CHECK_STATUS(client->ReadLong(cursor_id));
+
+ LOG(INFO) << "Query Cursor " << cursor_id << " is opened";
+
+ int32_t row_cnt;
+ CHECK_STATUS(client->ReadInt(row_cnt));
+
+ remainder = res_len - 25;
+ page = std::unique_ptr<uint8_t>(new uint8_t[remainder]);
+ ptr = page.get();
+
+ int64_t start = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+
+ CHECK_STATUS(client->ReadData(ptr, remainder));
+
+ int64_t stop = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+ ;
+
+ double size_in_mb = 1.0 * remainder / 1024 / 1024;
+ double time_in_s = 1.0 * (stop - start) / 1000;
+ LOG(INFO) << "Page size " << size_in_mb << " Mb, time " << time_in_s * 1000
+ << " ms download speed " << size_in_mb / time_in_s << " Mb/sec";
+
+ uint8_t last_page_b;
+ CHECK_STATUS(client->ReadByte(last_page_b));
+
+ last_page = !last_page_b;
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status IgniteDatasetIterator::LoadNextPage() {
+ CHECK_STATUS(client->WriteInt(18)); // Message length
+ CHECK_STATUS(client->WriteShort(load_next_page_opcode)); // Operation code
+ CHECK_STATUS(client->WriteLong(0)); // Request ID
+ CHECK_STATUS(client->WriteLong(cursor_id)); // Cursor ID
+
+ int64_t wait_start = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+
+ int32_t res_len;
+ CHECK_STATUS(client->ReadInt(res_len));
+
+ int64_t wait_stop = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+
+ LOG(INFO) << "Load Next Page waited " << (wait_stop - wait_start) << " ms";
+
+ if (res_len < 12)
+ return tensorflow::errors::Internal("Load Next Page Response is corrupted");
+
+ int64_t req_id;
+ CHECK_STATUS(client->ReadLong(req_id));
+
+ int32_t status;
+ CHECK_STATUS(client->ReadInt(status));
+
+ if (status != 0) {
+ uint8_t err_msg_header;
+ CHECK_STATUS(client->ReadByte(err_msg_header));
+
+ if (err_msg_header == string_val) {
+ int32_t err_msg_length;
+ CHECK_STATUS(client->ReadInt(err_msg_length));
+
+ uint8_t* err_msg_c = new uint8_t[err_msg_length];
+ CHECK_STATUS(client->ReadData(err_msg_c, err_msg_length));
+ std::string err_msg((char*)err_msg_c, err_msg_length);
+ delete[] err_msg_c;
+
+ return tensorflow::errors::Internal("Load Next Page Error [status=",
+ status, ", message=", err_msg, "]");
+ }
+ return tensorflow::errors::Internal("Load Next Page Error [status=", status,
+ "]");
+ }
+
+ int32_t row_cnt;
+ CHECK_STATUS(client->ReadInt(row_cnt));
+
+ remainder = res_len - 17;
+ page = std::unique_ptr<uint8_t>(new uint8_t[remainder]);
+ ptr = page.get();
+
+ int64_t start = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+
+ CHECK_STATUS(client->ReadData(ptr, remainder));
+
+ int64_t stop = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch())
+ .count();
+ ;
+
+ double size_in_mb = 1.0 * remainder / 1024 / 1024;
+ double time_in_s = 1.0 * (stop - start) / 1000;
+ LOG(INFO) << "Page size " << size_in_mb << " Mb, time " << time_in_s * 1000
+ << " ms download speed " << size_in_mb / time_in_s << " Mb/sec";
+
+ uint8_t last_page_b;
+ CHECK_STATUS(client->ReadByte(last_page_b));
+
+ last_page = !last_page_b;
+
+ return tensorflow::Status::OK();
+}
+
+int32_t IgniteDatasetIterator::JavaHashCode(std::string str) {
+ int32_t h = 0;
+ for (char& c : str) {
+ h = 31 * h + c;
+ }
+ return h;
+}
+
+} // namespace ignite