diff options
Diffstat (limited to 'tensorflow/core/kernels/lookup_table_init_op.cc')
-rw-r--r-- | tensorflow/core/kernels/lookup_table_init_op.cc | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc index bde1d0360a..ada6fe8d95 100644 --- a/tensorflow/core/kernels/lookup_table_init_op.cc +++ b/tensorflow/core/kernels/lookup_table_init_op.cc @@ -367,7 +367,9 @@ class InitializeTableOp : public OpKernel { GetInitializableLookupTable("table_handle", ctx, &table)); core::ScopedUnref unref_me(table); - DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(), + DataType expected_input_0 = + (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; + DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), table->value_dtype()}; DataTypeVector expected_outputs = {}; OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); @@ -408,6 +410,8 @@ class InitializeTableOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("InitializeTable").Device(DEVICE_CPU), InitializeTableOp); +REGISTER_KERNEL_BUILDER(Name("InitializeTableV2").Device(DEVICE_CPU), + InitializeTableOp); // Kernel to initialize a lookup table from a text file. // @@ -433,7 +437,9 @@ class InitializeTableFromTextFileOp : public OpKernel { GetInitializableLookupTable("table_handle", ctx, &table)); core::ScopedUnref unref_me(table); - DataTypeVector expected_inputs = {DT_STRING_REF, DT_STRING}; + DataType expected_input_0 = + (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; + DataTypeVector expected_inputs = {expected_input_0, DT_STRING}; DataTypeVector expected_outputs = {}; OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); @@ -472,5 +478,8 @@ class InitializeTableFromTextFileOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("InitializeTableFromTextFile").Device(DEVICE_CPU), InitializeTableFromTextFileOp); +REGISTER_KERNEL_BUILDER( + Name("InitializeTableFromTextFileV2").Device(DEVICE_CPU), + InitializeTableFromTextFileOp); } // namespace tensorflow |