diff options
Diffstat (limited to 'tensorflow/core/kernels/lookup_util.cc')
-rw-r--r-- | tensorflow/core/kernels/lookup_util.cc | 50 |
1 files changed, 36 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc index f87ce0e6b2..d0f269be23 100644 --- a/tensorflow/core/kernels/lookup_util.cc +++ b/tensorflow/core/kernels/lookup_util.cc @@ -49,26 +49,48 @@ Status GetLookupTable(const string& input_name, OpKernelContext* ctx, LookupInterface** table) { string container; string table_handle; - TF_RETURN_IF_ERROR( - GetTableHandle(input_name, ctx, &container, &table_handle)); - return ctx->resource_manager()->Lookup(container, table_handle, table); + DataType handle_dtype; + TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype)); + if (handle_dtype == DT_RESOURCE) { + ResourceHandle handle; + TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle)); + return LookupResource(ctx, handle, table); + } else { + TF_RETURN_IF_ERROR( + GetTableHandle(input_name, ctx, &container, &table_handle)); + return ctx->resource_manager()->Lookup(container, table_handle, table); + } } Status GetInitializableLookupTable(const string& input_name, OpKernelContext* ctx, InitializableLookupTable** table) { - string container; - string table_handle; - TF_RETURN_IF_ERROR( - GetTableHandle(input_name, ctx, &container, &table_handle)); LookupInterface* lookup_table; - TF_RETURN_IF_ERROR( - ctx->resource_manager()->Lookup(container, table_handle, &lookup_table)); - *table = lookup_table->GetInitializableLookupTable(); - if (*table == nullptr) { - lookup_table->Unref(); - return errors::InvalidArgument("Table ", container, " ", table_handle, - " is not initializable"); + DataType handle_dtype; + TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype)); + if (handle_dtype == DT_RESOURCE) { + ResourceHandle handle; + TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle)); + TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &lookup_table)); + *table = lookup_table->GetInitializableLookupTable(); + if (*table == nullptr) { + lookup_table->Unref(); + return errors::InvalidArgument("Table ", handle.container(), " ", + handle.name(), " is not initializable"); + } + } else { + string container; + string table_handle; + TF_RETURN_IF_ERROR( + GetTableHandle(input_name, ctx, &container, &table_handle)); + TF_RETURN_IF_ERROR(ctx->resource_manager()->Lookup(container, table_handle, + &lookup_table)); + *table = lookup_table->GetInitializableLookupTable(); + if (*table == nullptr) { + lookup_table->Unref(); + return errors::InvalidArgument("Table ", container, " ", table_handle, + " is not initializable"); + } } return Status::OK(); } |