aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/lookup_table_init_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/lookup_table_init_op.cc')
-rw-r--r--tensorflow/core/kernels/lookup_table_init_op.cc116
1 files changed, 116 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc
new file mode 100644
index 0000000000..9781bcfa59
--- /dev/null
+++ b/tensorflow/core/kernels/lookup_table_init_op.cc
@@ -0,0 +1,116 @@
+#define EIGEN_USE_THREADS
+
+#include <string>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/initializable_lookup_table.h"
+#include "tensorflow/core/kernels/lookup_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+namespace lookup {
+
+// Iterator to initialize tables given 'keys' and 'values' tensors.
+//
+// The two tensors are returned in the first iteration. It doesn't loop
+// over each element of the tensor since insertions in the lookup table can
+// process batches.
+class KeyValueTensorIterator
+ : public InitializableLookupTable::InitTableIterator {
+ public:
+ // keys and values are not owned by the iterator.
+ explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values)
+ : keys_(keys), values_(values), valid_(true), status_(Status::OK()) {
+ TensorShape key_shape = keys_->shape();
+ if (!key_shape.IsSameSize(values_->shape())) {
+ valid_ = false;
+ status_ = errors::InvalidArgument(
+ "keys and values should have the same dimension.",
+ key_shape.DebugString(), " vs ", values_->shape().DebugString());
+ }
+ if (key_shape.num_elements() == 0) {
+ valid_ = false;
+ status_ =
+ errors::InvalidArgument("keys and values cannot be empty tensors.");
+ }
+ }
+
+ bool Valid() const override { return valid_; }
+
+ void Next() override {
+ valid_ = false;
+ status_ = errors::OutOfRange("No more data.");
+ }
+
+ const Tensor& keys() const override { return *keys_; }
+
+ const Tensor& values() const override { return *values_; }
+
+ Status status() const override { return status_; }
+
+ int64 total_size() const {
+ return keys_ == nullptr ? -1 : keys_->NumElements();
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(KeyValueTensorIterator);
+
+ const Tensor* keys_; // Doesn't own it.
+ const Tensor* values_; // Doesn't own it.
+ bool valid_; // true if the iterator points to an existing range.
+ Status status_;
+};
+
+} // namespace lookup
+
+// Kernel to initialize a look table given a key and value tensors.
+// After this operation, the table becomes read-only.
+class InitializeTableOp : public OpKernel {
+ public:
+ explicit InitializeTableOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ mutex_lock l(mu_);
+ lookup::InitializableLookupTable* table;
+ OP_REQUIRES_OK(ctx,
+ GetInitializableLookupTable("table_handle", ctx, &table));
+ core::ScopedUnref unref_me(table);
+
+ DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(),
+ table->value_dtype()};
+ DataTypeVector expected_outputs = {};
+ OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
+
+ const Tensor& keys = ctx->input(1);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(keys.shape()),
+ errors::InvalidArgument("Keys must be a vector, but received ",
+ keys.shape().DebugString()));
+
+ const Tensor& values = ctx->input(2);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsVector(values.shape()),
+ errors::InvalidArgument("Values must be a vector, but received ",
+ values.shape().DebugString()));
+
+ OP_REQUIRES(ctx, keys.NumElements() == values.NumElements(),
+ errors::InvalidArgument(
+ "Keys and values must have the same size ",
+ keys.NumElements(), " vs ", values.NumElements()));
+
+ lookup::KeyValueTensorIterator iter(&keys, &values);
+ OP_REQUIRES_OK(ctx, table->Initialize(iter));
+ }
+
+ private:
+ mutex mu_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("InitializeTable").Device(DEVICE_CPU),
+ InitializeTableOp);
+
+} // namespace tensorflow