diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-05-01 11:15:23 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-01 12:29:29 -0700 |
commit | 9528658bb51d3a4e14ec1014e8fdd8e0076805e8 (patch) | |
tree | 83cd29bd789c9bf48dd951198a62cabe73170813 /tensorflow | |
parent | 03327190420dd5b1c34a5ffdd0000aff40980ed5 (diff) |
LookupTable ops go to V2, use resource handles.
There's some duplicate documentation; yutaka will deprecate the old ops
and remove the duplicate docs, update the shape tests to point to the new ops,
and modify lookup_ops.py to use the _v2 ops.
Change: 154743350
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/framework/resource_mgr.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/framework/resource_mgr.h | 2 | ||||
-rw-r--r-- | tensorflow/core/kernels/lookup_table_init_op.cc | 13 | ||||
-rw-r--r-- | tensorflow/core/kernels/lookup_table_op.cc | 51 | ||||
-rw-r--r-- | tensorflow/core/kernels/lookup_table_op.h | 72 | ||||
-rw-r--r-- | tensorflow/core/kernels/lookup_util.cc | 50 | ||||
-rw-r--r-- | tensorflow/core/ops/data_flow_ops.cc | 312 | ||||
-rw-r--r-- | tensorflow/python/ops/hidden_ops.txt | 11 |
8 files changed, 470 insertions, 49 deletions
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index 7f9fe084ba..ab7dd0c547 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -246,6 +246,14 @@ ResourceHandle HandleFromInput(OpKernelContext* ctx, int input) { return ctx->input(input).flat<ResourceHandle>()(0); } +Status HandleFromInput(OpKernelContext* ctx, StringPiece input, + ResourceHandle* handle) { + const Tensor* tensor; + TF_RETURN_IF_ERROR(ctx->input(input, &tensor)); + *handle = tensor->flat<ResourceHandle>()(0); + return Status::OK(); +} + Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p)); return ctx->resource_manager()->Delete(p); diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index fe6e09378f..26a5766569 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -211,6 +211,8 @@ ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx, // Returns a resource handle from a numbered op input. ResourceHandle HandleFromInput(OpKernelContext* ctx, int input); +Status HandleFromInput(OpKernelContext* ctx, StringPiece input, + ResourceHandle* handle); // Create a resource pointed by a given resource handle. template <typename T> diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc index bde1d0360a..ada6fe8d95 100644 --- a/tensorflow/core/kernels/lookup_table_init_op.cc +++ b/tensorflow/core/kernels/lookup_table_init_op.cc @@ -367,7 +367,9 @@ class InitializeTableOp : public OpKernel { GetInitializableLookupTable("table_handle", ctx, &table)); core::ScopedUnref unref_me(table); - DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(), + DataType expected_input_0 = + (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; + DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), table->value_dtype()}; DataTypeVector expected_outputs = {}; OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); @@ -408,6 +410,8 @@ class InitializeTableOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("InitializeTable").Device(DEVICE_CPU), InitializeTableOp); +REGISTER_KERNEL_BUILDER(Name("InitializeTableV2").Device(DEVICE_CPU), + InitializeTableOp); // Kernel to initialize a lookup table from a text file. // @@ -433,7 +437,9 @@ class InitializeTableFromTextFileOp : public OpKernel { GetInitializableLookupTable("table_handle", ctx, &table)); core::ScopedUnref unref_me(table); - DataTypeVector expected_inputs = {DT_STRING_REF, DT_STRING}; + DataType expected_input_0 = + (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; + DataTypeVector expected_inputs = {expected_input_0, DT_STRING}; DataTypeVector expected_outputs = {}; OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); @@ -472,5 +478,8 @@ class InitializeTableFromTextFileOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("InitializeTableFromTextFile").Device(DEVICE_CPU), InitializeTableFromTextFileOp); +REGISTER_KERNEL_BUILDER( + Name("InitializeTableFromTextFileV2").Device(DEVICE_CPU), + InitializeTableFromTextFileOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index 0a065e37d7..11ce2a71dc 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -624,7 +624,10 @@ class LookupTableFindOp : public OpKernel { OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); core::ScopedUnref unref_me(table); - DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(), + // Input 0 could be a STRING_REF or a RESOURCE + DataType expected_input_0 = + (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; + DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), table->value_dtype()}; DataTypeVector expected_outputs = {table->value_dtype()}; OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); @@ -647,6 +650,8 @@ class LookupTableFindOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU), LookupTableFindOp); +REGISTER_KERNEL_BUILDER(Name("LookupTableFindV2").Device(DEVICE_CPU), + LookupTableFindOp); // Table insert op. class LookupTableInsertOp : public OpKernel { @@ -658,7 +663,9 @@ class LookupTableInsertOp : public OpKernel { OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); core::ScopedUnref unref_me(table); - DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(), + DataType expected_input_0 = + (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; + DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), table->value_dtype()}; OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); @@ -680,6 +687,8 @@ class LookupTableInsertOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("LookupTableInsert").Device(DEVICE_CPU), LookupTableInsertOp); +REGISTER_KERNEL_BUILDER(Name("LookupTableInsertV2").Device(DEVICE_CPU), + LookupTableInsertOp); // Op that returns the size of the given table. class LookupTableSizeOp : public OpKernel { @@ -699,6 +708,8 @@ class LookupTableSizeOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("LookupTableSize").Device(DEVICE_CPU), LookupTableSizeOp); +REGISTER_KERNEL_BUILDER(Name("LookupTableSizeV2").Device(DEVICE_CPU), + LookupTableSizeOp); // Op that outputs tensors of all keys and all values. class LookupTableExportOp : public OpKernel { @@ -716,6 +727,8 @@ class LookupTableExportOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("LookupTableExport").Device(DEVICE_CPU), LookupTableExportOp); +REGISTER_KERNEL_BUILDER(Name("LookupTableExportV2").Device(DEVICE_CPU), + LookupTableExportOp); // Clear the table and insert data. class LookupTableImportOp : public OpKernel { @@ -727,7 +740,9 @@ class LookupTableImportOp : public OpKernel { OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); core::ScopedUnref unref_me(table); - DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(), + DataType expected_input_0 = + (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; + DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), table->value_dtype()}; OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); @@ -749,6 +764,8 @@ class LookupTableImportOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("LookupTableImport").Device(DEVICE_CPU), LookupTableImportOp); +REGISTER_KERNEL_BUILDER(Name("LookupTableImportV2").Device(DEVICE_CPU), + LookupTableImportOp); // Register the HashTable op with the currently supported key and value types. #define REGISTER_KERNEL(key_dtype, value_dtype) \ @@ -758,6 +775,13 @@ REGISTER_KERNEL_BUILDER(Name("LookupTableImport").Device(DEVICE_CPU), .TypeConstraint<key_dtype>("key_dtype") \ .TypeConstraint<value_dtype>("value_dtype"), \ LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \ + value_dtype>) \ + REGISTER_KERNEL_BUILDER( \ + Name("HashTableV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<key_dtype>("key_dtype") \ + .TypeConstraint<value_dtype>("value_dtype"), \ + LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \ value_dtype>) REGISTER_KERNEL(string, double); @@ -779,6 +803,13 @@ REGISTER_KERNEL(string, bool); .TypeConstraint<key_dtype>("key_dtype") \ .TypeConstraint<value_dtype>("value_dtype"), \ LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \ + key_dtype, value_dtype>) \ + REGISTER_KERNEL_BUILDER( \ + Name("MutableHashTableV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<key_dtype>("key_dtype") \ + .TypeConstraint<value_dtype>("value_dtype"), \ + LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \ key_dtype, value_dtype>) REGISTER_KERNEL(string, float); @@ -797,6 +828,13 @@ REGISTER_KERNEL(int64, float); .TypeConstraint<key_dtype>("key_dtype") \ .TypeConstraint<value_dtype>("value_dtype"), \ LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \ + key_dtype, value_dtype>) \ + REGISTER_KERNEL_BUILDER( \ + Name("MutableHashTableOfTensorsV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<key_dtype>("key_dtype") \ + .TypeConstraint<value_dtype>("value_dtype"), \ + LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \ key_dtype, value_dtype>) REGISTER_KERNEL(string, float); @@ -814,6 +852,13 @@ REGISTER_KERNEL(string, bool); .TypeConstraint<key_dtype>("key_dtype") \ .TypeConstraint<value_dtype>("value_dtype"), \ LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \ + key_dtype, value_dtype>) \ + REGISTER_KERNEL_BUILDER( \ + Name("MutableDenseHashTableV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<key_dtype>("key_dtype") \ + .TypeConstraint<value_dtype>("value_dtype"), \ + LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \ key_dtype, value_dtype>) REGISTER_KERNEL(int64, int64); diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h index ae253b4dc9..4cd25a3cc6 100644 --- a/tensorflow/core/kernels/lookup_table_op.h +++ b/tensorflow/core/kernels/lookup_table_op.h @@ -51,40 +51,52 @@ class LookupTableOp : public OpKernel { // ctx is not owned by this function. void Compute(OpKernelContext* ctx) override { mutex_lock l(mu_); + if (!table_handle_set_) { OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), use_node_name_sharing_)); - auto creator = [ctx, this](lookup::LookupInterface** ret) { - lookup::LookupInterface* container = new Container(ctx, this); - if (!ctx->status().ok()) { - container->Unref(); - return ctx->status(); - } - if (ctx->track_allocations()) { - ctx->record_device_persistent_memory_allocation( - container->MemoryUsed()); - } - *ret = container; - return Status::OK(); - }; - - lookup::LookupInterface* table = nullptr; - OP_REQUIRES_OK( - ctx, cinfo_.resource_manager() - ->template LookupOrCreate<lookup::LookupInterface>( - cinfo_.container(), cinfo_.name(), &table, creator)); - core::ScopedUnref unref_me(table); - - OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes( - *table, DataTypeToEnum<key_dtype>::v(), - DataTypeToEnum<value_dtype>::v(), cinfo_.name())); - - auto h = table_handle_.AccessTensor(ctx)->template flat<string>(); - h(0) = cinfo_.container(); - h(1) = cinfo_.name(); - table_handle_set_ = true; } - ctx->set_output_ref(0, &mu_, table_handle_.AccessTensor(ctx)); + + auto creator = [ctx, this](lookup::LookupInterface** ret) { + lookup::LookupInterface* container = new Container(ctx, this); + if (!ctx->status().ok()) { + container->Unref(); + return ctx->status(); + } + if (ctx->track_allocations()) { + ctx->record_device_persistent_memory_allocation( + container->MemoryUsed()); + } + *ret = container; + return Status::OK(); + }; + + lookup::LookupInterface* table = nullptr; + OP_REQUIRES_OK(ctx, + cinfo_.resource_manager() + ->template LookupOrCreate<lookup::LookupInterface>( + cinfo_.container(), cinfo_.name(), &table, creator)); + core::ScopedUnref unref_me(table); + + OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes( + *table, DataTypeToEnum<key_dtype>::v(), + DataTypeToEnum<value_dtype>::v(), cinfo_.name())); + + if (ctx->expected_output_dtype(0) == DT_RESOURCE) { + Tensor* handle; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); + handle->scalar<ResourceHandle>()() = + MakeResourceHandle<lookup::LookupInterface>(ctx, cinfo_.container(), + cinfo_.name()); + } else { + if (!table_handle_set_) { + auto h = table_handle_.AccessTensor(ctx)->template flat<string>(); + h(0) = cinfo_.container(); + h(1) = cinfo_.name(); + } + ctx->set_output_ref(0, &mu_, table_handle_.AccessTensor(ctx)); + } + table_handle_set_ = true; } ~LookupTableOp() override { diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc index f87ce0e6b2..d0f269be23 100644 --- a/tensorflow/core/kernels/lookup_util.cc +++ b/tensorflow/core/kernels/lookup_util.cc @@ -49,26 +49,48 @@ Status GetLookupTable(const string& input_name, OpKernelContext* ctx, LookupInterface** table) { string container; string table_handle; - TF_RETURN_IF_ERROR( - GetTableHandle(input_name, ctx, &container, &table_handle)); - return ctx->resource_manager()->Lookup(container, table_handle, table); + DataType handle_dtype; + TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype)); + if (handle_dtype == DT_RESOURCE) { + ResourceHandle handle; + TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle)); + return LookupResource(ctx, handle, table); + } else { + TF_RETURN_IF_ERROR( + GetTableHandle(input_name, ctx, &container, &table_handle)); + return ctx->resource_manager()->Lookup(container, table_handle, table); + } } Status GetInitializableLookupTable(const string& input_name, OpKernelContext* ctx, InitializableLookupTable** table) { - string container; - string table_handle; - TF_RETURN_IF_ERROR( - GetTableHandle(input_name, ctx, &container, &table_handle)); LookupInterface* lookup_table; - TF_RETURN_IF_ERROR( - ctx->resource_manager()->Lookup(container, table_handle, &lookup_table)); - *table = lookup_table->GetInitializableLookupTable(); - if (*table == nullptr) { - lookup_table->Unref(); - return errors::InvalidArgument("Table ", container, " ", table_handle, - " is not initializable"); + DataType handle_dtype; + TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype)); + if (handle_dtype == DT_RESOURCE) { + ResourceHandle handle; + TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle)); + TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &lookup_table)); + *table = lookup_table->GetInitializableLookupTable(); + if (*table == nullptr) { + lookup_table->Unref(); + return errors::InvalidArgument("Table ", handle.container(), " ", + handle.name(), " is not initializable"); + } + } else { + string container; + string table_handle; + TF_RETURN_IF_ERROR( + GetTableHandle(input_name, ctx, &container, &table_handle)); + TF_RETURN_IF_ERROR(ctx->resource_manager()->Lookup(container, table_handle, + &lookup_table)); + *table = lookup_table->GetInitializableLookupTable(); + if (*table == nullptr) { + lookup_table->Unref(); + return errors::InvalidArgument("Table ", container, " ", table_handle, + " is not initializable"); + } } return Status::OK(); } diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index b34dd4ae90..f82e9d1eb7 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -210,10 +210,29 @@ Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { return Status::OK(); } +Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { + ShapeHandle handle; + DimensionHandle unused_handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + for (int i = 1; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle)); + } + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->Scalar()); + } + return Status::OK(); +} + Status TwoElementOutput(InferenceContext* c) { c->set_output(0, c->Vector(2)); return Status::OK(); } + +Status ScalarOutput(InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); +} } // namespace REGISTER_OP("RandomShuffleQueue") @@ -1881,6 +1900,38 @@ values: Same shape as `keys`. Values found in the table, or `default_values` for missing keys. )doc"); +REGISTER_OP("LookupTableFindV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Input("default_value: Tout") + .Output("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + // Default value must be scalar or vector. + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); + c->set_output(0, c->UnknownShape()); + return Status::OK(); + }) + .Doc(R"doc( +Looks up keys in a table, outputs the corresponding values. + +The tensor `keys` must of the same type as the keys of the table. +The output `values` is of the type of the table values. + +The scalar `default_value` is the value output for keys not present in the +table. It must also be of the same type as the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Same shape as `keys`. Values found in the table, or `default_values` + for missing keys. +)doc"); + REGISTER_OP("LookupTableInsert") .Input("table_handle: Ref(string)") .Input("keys: Tin") @@ -1893,6 +1944,30 @@ REGISTER_OP("LookupTableInsert") DimensionHandle unused_dim; TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + // TODO(ebrevdo): Validate keys and values shape. + return Status::OK(); + }) + .Doc(R"doc( +Updates the table to associates keys with values. + +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Values to associate with keys. +)doc"); + +REGISTER_OP("LookupTableInsertV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Input("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + // TODO: Validate keys and values shape. return Status::OK(); }) @@ -1918,6 +1993,17 @@ table_handle: Handle to the table. size: Scalar that contains number of elements in the table. )doc"); +REGISTER_OP("LookupTableSizeV2") + .Input("table_handle: resource") + .Output("size: int64") + .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs) + .Doc(R"doc( +Computes the number of elements in the given table. + +table_handle: Handle to the table. +size: Scalar that contains number of elements in the table. +)doc"); + REGISTER_OP("LookupTableExport") .Input("table_handle: Ref(string)") .Output("keys: Tkeys") @@ -1945,6 +2031,31 @@ keys: Vector of all keys present in the table. values: Tensor of all values in the table. Indexed in parallel with `keys`. )doc"); +REGISTER_OP("LookupTableExportV2") + .Input("table_handle: resource") + .Output("keys: Tkeys") + .Output("values: Tvalues") + .Attr("Tkeys: type") + .Attr("Tvalues: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + ShapeHandle values = c->UnknownShape(); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); + ShapeHandle keys = c->Vector(c->Dim(values, 0)); + c->set_output(0, keys); + c->set_output(1, values); + return Status::OK(); + }) + .Doc(R"doc( +Outputs all keys and values in the table. + +table_handle: Handle to the table. +keys: Vector of all keys present in the table. +values: Tensor of all values in the table. Indexed in parallel with `keys`. +)doc"); + REGISTER_OP("LookupTableImport") .Input("table_handle: Ref(string)") .Input("keys: Tin") @@ -1957,6 +2068,30 @@ REGISTER_OP("LookupTableImport") DimensionHandle unused_dim; TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + // TODO(ebrevdo): Validate keys and values shape. + return Status::OK(); + }) + .Doc(R"doc( +Replaces the contents of the table with the specified keys and values. + +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Values to associate with keys. +)doc"); + +REGISTER_OP("LookupTableImportV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Input("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + // TODO: Validate keys and values shape. return Status::OK(); }) @@ -1998,6 +2133,33 @@ key_dtype: Type of the table keys. value_dtype: Type of the table values. )doc"); +REGISTER_OP("HashTableV2") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates a non-initialized hash table. + +This op creates a hash table, specifying the type of its keys and values. +Before using the table you will have to initialize it. After initialization the +table will be immutable. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +use_node_name_sharing: If true and shared_name is empty, the table is shared + using the node name. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + REGISTER_OP("MutableHashTable") .Output("table_handle: Ref(string)") .Attr("container: string = ''") @@ -2025,6 +2187,33 @@ key_dtype: Type of the table keys. value_dtype: Type of the table values. )doc"); +REGISTER_OP("MutableHashTableV2") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +use_node_name_sharing: If true and shared_name is empty, the table is shared + using the node name. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + REGISTER_OP("MutableHashTableOfTensors") .Output("table_handle: Ref(string)") .Attr("container: string = ''") @@ -2051,6 +2240,32 @@ key_dtype: Type of the table keys. value_dtype: Type of the table values. )doc"); +REGISTER_OP("MutableHashTableOfTensorsV2") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a vector. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + REGISTER_OP("MutableDenseHashTable") .Input("empty_key: key_dtype") .Output("table_handle: Ref(string)") @@ -2088,6 +2303,43 @@ max_load_factor: The maximum ratio between number of entries and number of buckets before growing the table. Must be between 0 and 1. )doc"); +REGISTER_OP("MutableDenseHashTableV2") + .Input("empty_key: key_dtype") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .Attr("initial_num_buckets: int = 131072") // 2^17 + .Attr("max_load_factor: float = 0.8") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates an empty hash table that uses tensors as the backing store. It uses +"open addressing" with quadratic reprobing to resolve collisions. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +empty_key: The key used to represent empty key buckets internally. Must not + be used in insert or lookup operations. +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +value_shape: The shape of each value. +initial_num_buckets: The initial number of hash table buckets. Must be a power + to 2. +max_load_factor: The maximum ratio between number of entries and number of + buckets before growing the table. Must be between 0 and 1. +)doc"); + REGISTER_OP("InitializeTable") .Input("table_handle: Ref(string)") .Input("keys: Tkey") @@ -2113,6 +2365,29 @@ keys: Keys of type Tkey. values: Values of type Tval. )doc"); +REGISTER_OP("InitializeTableV2") + .Input("table_handle: resource") + .Input("keys: Tkey") + .Input("values: Tval") + .Attr("Tkey: type") + .Attr("Tval: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + ShapeHandle keys; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); + TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); + return Status::OK(); + }) + .Doc(R"doc( +Table initializer that takes two tensors for keys and values respectively. + +table_handle: Handle to a table which will be initialized. +keys: Keys of type Tkey. +values: Values of type Tval. +)doc"); + REGISTER_OP("InitializeTableFromTextFile") .Input("table_handle: Ref(string)") .Input("filename: string") @@ -2152,6 +2427,43 @@ vocab_size: Number of elements of the file, use -1 if unknown. delimiter: Delimiter to separate fields in a line. )doc"); +REGISTER_OP("InitializeTableFromTextFileV2") + .Input("table_handle: resource") + .Input("filename: string") + .Attr("key_index: int >= -2") + .Attr("value_index: int >= -2") + .Attr("vocab_size: int >= -1 = -1") + .Attr("delimiter: string = '\t'") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); + return Status::OK(); + }) + .Doc(R"doc( +Initializes a table from a text file. + +It inserts one key-value pair into the table for each line of the file. +The key and value is extracted from the whole line content, elements from the +split line based on `delimiter` or the line number (starting from zero). +Where to extract the key and value from a line is specified by `key_index` and +`value_index`. + +- A value of -1 means use the line number(starting from zero), expects `int64`. +- A value of -2 means use the whole line content, expects `string`. +- A value >= 0 means use the index (starting at zero) of the split line based + on `delimiter`. + +table_handle: Handle to a table which will be initialized. +filename: Filename of a vocabulary text file. +key_index: Column index in a line to get the table `key` values from. +value_index: Column index that represents information of a line to get the table + `value` values from. +vocab_size: Number of elements of the file, use -1 if unknown. +delimiter: Delimiter to separate fields in a line. +)doc"); + REGISTER_OP("GetSessionHandle") .Input("value: T") .Output("handle: string") diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt index 9022e1453d..ec02ee3e03 100644 --- a/tensorflow/python/ops/hidden_ops.txt +++ b/tensorflow/python/ops/hidden_ops.txt @@ -63,16 +63,27 @@ GetSessionHandle GetSessionHandleV2 GetSessionTensor HashTable +HashTableV2 InitializeTable +InitializeTableV2 InitializeTableFromTextFile +InitializeTableFromTextFileV2 LookupTableExport +LookupTableExportV2 LookupTableFind +LookupTableFindV2 LookupTableImport +LookupTableImportV2 LookupTableInsert +LookupTableInsertV2 LookupTableSize +LookupTableSizeV2 MutableDenseHashTable +MutableDenseHashTableV2 MutableHashTable +MutableHashTableV2 MutableHashTableOfTensors +MutableHashTableOfTensorsV2 Mutex MutexAcquire MutexRelease |