diff options
Diffstat (limited to 'tensorflow/core/kernels/lookup_table_init_op.cc')
-rw-r--r-- | tensorflow/core/kernels/lookup_table_init_op.cc | 116 |
1 files changed, 116 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc new file mode 100644 index 0000000000..9781bcfa59 --- /dev/null +++ b/tensorflow/core/kernels/lookup_table_init_op.cc @@ -0,0 +1,116 @@ +#define EIGEN_USE_THREADS + +#include <string> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/initializable_lookup_table.h" +#include "tensorflow/core/kernels/lookup_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +namespace lookup { + +// Iterator to initialize tables given 'keys' and 'values' tensors. +// +// The two tensors are returned in the first iteration. It doesn't loop +// over each element of the tensor since insertions in the lookup table can +// process batches. +class KeyValueTensorIterator + : public InitializableLookupTable::InitTableIterator { + public: + // keys and values are not owned by the iterator. + explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values) + : keys_(keys), values_(values), valid_(true), status_(Status::OK()) { + TensorShape key_shape = keys_->shape(); + if (!key_shape.IsSameSize(values_->shape())) { + valid_ = false; + status_ = errors::InvalidArgument( + "keys and values should have the same dimension.", + key_shape.DebugString(), " vs ", values_->shape().DebugString()); + } + if (key_shape.num_elements() == 0) { + valid_ = false; + status_ = + errors::InvalidArgument("keys and values cannot be empty tensors."); + } + } + + bool Valid() const override { return valid_; } + + void Next() override { + valid_ = false; + status_ = errors::OutOfRange("No more data."); + } + + const Tensor& keys() const override { return *keys_; } + + const Tensor& values() const override { return *values_; } + + Status status() const override { return status_; } + + int64 total_size() const { + return keys_ == nullptr ? -1 : keys_->NumElements(); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(KeyValueTensorIterator); + + const Tensor* keys_; // Doesn't own it. + const Tensor* values_; // Doesn't own it. + bool valid_; // true if the iterator points to an existing range. + Status status_; +}; + +} // namespace lookup + +// Kernel to initialize a look table given a key and value tensors. +// After this operation, the table becomes read-only. +class InitializeTableOp : public OpKernel { + public: + explicit InitializeTableOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { + mutex_lock l(mu_); + lookup::InitializableLookupTable* table; + OP_REQUIRES_OK(ctx, + GetInitializableLookupTable("table_handle", ctx, &table)); + core::ScopedUnref unref_me(table); + + DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(), + table->value_dtype()}; + DataTypeVector expected_outputs = {}; + OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); + + const Tensor& keys = ctx->input(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(keys.shape()), + errors::InvalidArgument("Keys must be a vector, but received ", + keys.shape().DebugString())); + + const Tensor& values = ctx->input(2); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(values.shape()), + errors::InvalidArgument("Values must be a vector, but received ", + values.shape().DebugString())); + + OP_REQUIRES(ctx, keys.NumElements() == values.NumElements(), + errors::InvalidArgument( + "Keys and values must have the same size ", + keys.NumElements(), " vs ", values.NumElements())); + + lookup::KeyValueTensorIterator iter(&keys, &values); + OP_REQUIRES_OK(ctx, table->Initialize(iter)); + } + + private: + mutex mu_; +}; + +REGISTER_KERNEL_BUILDER(Name("InitializeTable").Device(DEVICE_CPU), + InitializeTableOp); + +} // namespace tensorflow |