aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc')
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc149
1 files changed, 149 insertions, 0 deletions
diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc
new file mode 100644
index 0000000000..a1101b91f3
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc
@@ -0,0 +1,149 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ignite_ssl_wrapper.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+
+namespace ignite {
+
+static int PasswordCb(char *buf, int size, int rwflag, void *password) {
+ strncpy(buf, (char *)(password), size);
+ buf[size - 1] = '\0';
+ return (strlen(buf));
+}
+
+SslWrapper::SslWrapper(std::shared_ptr<Client> client, std::string certfile,
+ std::string keyfile, std::string cert_password)
+ : client(client),
+ certfile(certfile),
+ keyfile(keyfile),
+ cert_password(cert_password),
+ ctx(NULL) {}
+
+SslWrapper::~SslWrapper() {
+ if (IsConnected()) {
+ tensorflow::Status status = Disconnect();
+ if (!status.ok()) LOG(WARNING) << status.ToString();
+ }
+
+ if (ctx != NULL) {
+ SSL_CTX_free(ctx);
+ ctx = NULL;
+ }
+}
+
+tensorflow::Status SslWrapper::InitSslContext() {
+ OpenSSL_add_all_algorithms();
+ SSL_load_error_strings();
+
+ ctx = SSL_CTX_new(SSLv23_method());
+ if (ctx == NULL)
+ return tensorflow::errors::Internal("Couldn't create SSL context");
+
+ SSL_CTX_set_default_passwd_cb(ctx, PasswordCb);
+ SSL_CTX_set_default_passwd_cb_userdata(ctx, (void *)cert_password.c_str());
+
+ if (SSL_CTX_use_certificate_chain_file(ctx, certfile.c_str()) != 1)
+ return tensorflow::errors::Internal(
+ "Couldn't load cetificate chain (file '", certfile, "')");
+
+ std::string private_key_file = keyfile.empty() ? certfile : keyfile;
+ if (SSL_CTX_use_PrivateKey_file(ctx, private_key_file.c_str(),
+ SSL_FILETYPE_PEM) != 1)
+ return tensorflow::errors::Internal("Couldn't load private key (file '",
+ private_key_file, "')");
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status SslWrapper::Connect() {
+ tensorflow::Status status;
+
+ if (ctx == NULL) {
+ status = InitSslContext();
+ if (!status.ok()) return status;
+ }
+
+ ssl = SSL_new(ctx);
+ if (ssl == NULL)
+ return tensorflow::errors::Internal("Failed to establish SSL connection");
+
+ status = client->Connect();
+ if (!status.ok()) return status;
+
+ SSL_set_fd(ssl, client->GetSocketDescriptor());
+ if (SSL_connect(ssl) != 1)
+ return tensorflow::errors::Internal("Failed to establish SSL connection");
+
+ LOG(INFO) << "SSL connection established";
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status SslWrapper::Disconnect() {
+ SSL_free(ssl);
+
+ LOG(INFO) << "SSL connection closed";
+
+ return client->Disconnect();
+}
+
+bool SslWrapper::IsConnected() { return client->IsConnected(); }
+
+int SslWrapper::GetSocketDescriptor() { return client->GetSocketDescriptor(); }
+
+tensorflow::Status SslWrapper::ReadData(uint8_t *buf, int32_t length) {
+ int recieved = 0;
+
+ while (recieved < length) {
+ int res = SSL_read(ssl, buf, length - recieved);
+
+ if (res < 0)
+ return tensorflow::errors::Internal(
+ "Error occured while reading from SSL socket: ", res);
+
+ if (res == 0)
+ return tensorflow::errors::Internal("Server closed SSL connection");
+
+ recieved += res;
+ buf += res;
+ }
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status SslWrapper::WriteData(uint8_t *buf, int32_t length) {
+ int sent = 0;
+
+ while (sent < length) {
+ int res = SSL_write(ssl, buf, length - sent);
+
+ if (res < 0)
+ return tensorflow::errors::Internal(
+ "Error occured while writing into socket: ", res);
+
+ sent += res;
+ buf += res;
+ }
+
+ return tensorflow::Status::OK();
+}
+
+} // namespace ignite