diff options
Diffstat (limited to 'tensorflow/core/kernels/lookup_util.cc')
-rw-r--r-- | tensorflow/core/kernels/lookup_util.cc | 72 |
1 files changed, 72 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc new file mode 100644 index 0000000000..634c11e4a5 --- /dev/null +++ b/tensorflow/core/kernels/lookup_util.cc @@ -0,0 +1,72 @@ +#include "tensorflow/core/kernels/lookup_util.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { +namespace lookup { +namespace { + +Status GetTableHandle(const string& input_name, OpKernelContext* ctx, + string* container, string* table_handle) { + { + mutex* mu; + TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu)); + mutex_lock l(*mu); + Tensor tensor; + TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true)); + if (tensor.NumElements() != 2) { + return errors::InvalidArgument( + "Lookup table handle must be scalar, but had shape: ", + tensor.shape().DebugString()); + } + auto h = tensor.flat<string>(); + *container = h(0); + *table_handle = h(1); + } + return Status::OK(); +} + +} // namespace + +Status GetLookupTable(const string& input_name, OpKernelContext* ctx, + LookupInterface** table) { + string container; + string table_handle; + TF_RETURN_IF_ERROR( + GetTableHandle(input_name, ctx, &container, &table_handle)); + return ctx->resource_manager()->Lookup(container, table_handle, table); +} + +Status GetInitializableLookupTable(const string& input_name, + OpKernelContext* ctx, + InitializableLookupTable** table) { + string container; + string table_handle; + TF_RETURN_IF_ERROR( + GetTableHandle(input_name, ctx, &container, &table_handle)); + LookupInterface* lookup_table; + TF_RETURN_IF_ERROR( + ctx->resource_manager()->Lookup(container, table_handle, &lookup_table)); + *table = dynamic_cast<InitializableLookupTable*>(lookup_table); + if (*table == nullptr) { + lookup_table->Unref(); + return errors::InvalidArgument("Table ", container, " ", table_handle, + " is not initializable"); + } + return Status::OK(); +} + +Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype, + DataType value_dtype, const string& table_name) { + if (table.key_dtype() != key_dtype || table.value_dtype() != value_dtype) { + return errors::InvalidArgument( + "Conflicting key/value dtypes ", key_dtype, "->", value_dtype, " with ", + table.key_dtype(), "-", table.value_dtype(), " for table ", table_name); + } + return Status::OK(); +} + +} // namespace lookup +} // namespace tensorflow |