aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/lookup_table_init_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/lookup_table_init_op.cc')
-rw-r--r--tensorflow/core/kernels/lookup_table_init_op.cc13
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc
index bde1d0360a..ada6fe8d95 100644
--- a/tensorflow/core/kernels/lookup_table_init_op.cc
+++ b/tensorflow/core/kernels/lookup_table_init_op.cc
@@ -367,7 +367,9 @@ class InitializeTableOp : public OpKernel {
GetInitializableLookupTable("table_handle", ctx, &table));
core::ScopedUnref unref_me(table);
- DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(),
+ DataType expected_input_0 =
+ (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
+ DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
table->value_dtype()};
DataTypeVector expected_outputs = {};
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
@@ -408,6 +410,8 @@ class InitializeTableOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("InitializeTable").Device(DEVICE_CPU),
InitializeTableOp);
+REGISTER_KERNEL_BUILDER(Name("InitializeTableV2").Device(DEVICE_CPU),
+ InitializeTableOp);
// Kernel to initialize a lookup table from a text file.
//
@@ -433,7 +437,9 @@ class InitializeTableFromTextFileOp : public OpKernel {
GetInitializableLookupTable("table_handle", ctx, &table));
core::ScopedUnref unref_me(table);
- DataTypeVector expected_inputs = {DT_STRING_REF, DT_STRING};
+ DataType expected_input_0 =
+ (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
+ DataTypeVector expected_inputs = {expected_input_0, DT_STRING};
DataTypeVector expected_outputs = {};
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
@@ -472,5 +478,8 @@ class InitializeTableFromTextFileOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("InitializeTableFromTextFile").Device(DEVICE_CPU),
InitializeTableFromTextFileOp);
+REGISTER_KERNEL_BUILDER(
+ Name("InitializeTableFromTextFileV2").Device(DEVICE_CPU),
+ InitializeTableFromTextFileOp);
} // namespace tensorflow