aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc')
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc145
1 files changed, 145 insertions, 0 deletions
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc
new file mode 100644
index 0000000000..543b5e4afc
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc
@@ -0,0 +1,145 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ignite_dataset.h"
+#include <stdlib.h>
+#include "tensorflow/core/framework/dataset.h"
+
+namespace tensorflow {
+
+class IgniteDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ std::string cache_name = "";
+ std::string host = "";
+ int32 port = -1;
+ bool local = false;
+ int32 part = -1;
+ int32 page_size = -1;
+ std::string username = "";
+ std::string password = "";
+ std::string certfile = "";
+ std::string keyfile = "";
+ std::string cert_password = "";
+
+ const char* env_cache_name = std::getenv("IGNITE_DATASET_CACHE_NAME");
+ const char* env_host = std::getenv("IGNITE_DATASET_HOST");
+ const char* env_port = std::getenv("IGNITE_DATASET_PORT");
+ const char* env_local = std::getenv("IGNITE_DATASET_LOCAL");
+ const char* env_part = std::getenv("IGNITE_DATASET_PART");
+ const char* env_page_size = std::getenv("IGNITE_DATASET_PAGE_SIZE");
+ const char* env_username = std::getenv("IGNITE_DATASET_USERNAME");
+ const char* env_password = std::getenv("IGNITE_DATASET_PASSWORD");
+ const char* env_certfile = std::getenv("IGNITE_DATASET_CERTFILE");
+ const char* env_keyfile = std::getenv("IGNITE_DATASET_KEYFILE");
+ const char* env_cert_password = std::getenv("IGNITE_DATASET_CERT_PASSWORD");
+
+ if (env_cache_name)
+ cache_name = std::string(env_cache_name);
+ else
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "cache_name",
+ &cache_name));
+
+ if (env_host)
+ host = std::string(env_host);
+ else
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "host", &host));
+
+ if (env_port)
+ port = atoi(env_port);
+ else
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int32>(ctx, "port", &port));
+
+ if (env_local)
+ local = true;
+ else
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "local", &local));
+
+ if (env_part)
+ part = atoi(env_part);
+ else
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int32>(ctx, "part", &part));
+
+ if (env_page_size)
+ page_size = atoi(env_page_size);
+ else
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int32>(ctx, "page_size", &page_size));
+
+ if (env_username)
+ username = std::string(env_username);
+ else
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<std::string>(ctx, "username", &username));
+
+ if (env_password)
+ password = std::string(env_password);
+ else
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<std::string>(ctx, "password", &password));
+
+ if (env_certfile)
+ certfile = std::string(env_certfile);
+ else
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<std::string>(ctx, "certfile", &certfile));
+
+ if (env_keyfile)
+ keyfile = std::string(env_keyfile);
+ else
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<std::string>(ctx, "keyfile", &keyfile));
+
+ if (env_cert_password)
+ cert_password = std::string(env_cert_password);
+ else
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "cert_password",
+ &cert_password));
+
+ const Tensor* schema_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("schema", &schema_tensor));
+ OP_REQUIRES(ctx, schema_tensor->dims() == 1,
+ errors::InvalidArgument("`schema` must be a vector."));
+
+ std::vector<int32> schema;
+ schema.reserve(schema_tensor->NumElements());
+ for (int i = 0; i < schema_tensor->NumElements(); i++) {
+ schema.push_back(schema_tensor->flat<int32>()(i));
+ }
+
+ const Tensor* permutation_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("permutation", &permutation_tensor));
+ OP_REQUIRES(ctx, schema_tensor->dims() == 1,
+ errors::InvalidArgument("`permutation` must be a vector."));
+
+ std::vector<int32> permutation;
+ permutation.reserve(permutation_tensor->NumElements());
+ for (int i = 0; i < permutation_tensor->NumElements(); i++) {
+ permutation.push_back(permutation_tensor->flat<int32>()(i));
+ }
+
+ *output = new ignite::IgniteDataset(
+ ctx, cache_name, host, port, local, part, page_size, username, password,
+ certfile, keyfile, cert_password, std::move(schema),
+ std::move(permutation));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("IgniteDataset").Device(DEVICE_CPU),
+ IgniteDatasetOp);
+
+} // namespace tensorflow