author    Eugene Brevdo <ebrevdo@google.com>              2017-05-01 11:15:23 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-05-01 12:29:29 -0700
commit    9528658bb51d3a4e14ec1014e8fdd8e0076805e8 (patch)
tree      83cd29bd789c9bf48dd951198a62cabe73170813 /tensorflow
parent    03327190420dd5b1c34a5ffdd0000aff40980ed5 (diff)
LookupTable ops go to V2, use resource handles.
There's some duplicate documentation; yutaka will deprecate the old ops, remove the duplicate docs, update the shape tests to point to the new ops, and modify lookup_ops.py to use the _v2 ops.

Change: 154743350
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/core/framework/resource_mgr.cc       |   8
-rw-r--r--  tensorflow/core/framework/resource_mgr.h        |   2
-rw-r--r--  tensorflow/core/kernels/lookup_table_init_op.cc |  13
-rw-r--r--  tensorflow/core/kernels/lookup_table_op.cc      |  51
-rw-r--r--  tensorflow/core/kernels/lookup_table_op.h       |  72
-rw-r--r--  tensorflow/core/kernels/lookup_util.cc          |  50
-rw-r--r--  tensorflow/core/ops/data_flow_ops.cc            | 312
-rw-r--r--  tensorflow/python/ops/hidden_ops.txt            |  11
8 files changed, 470 insertions(+), 49 deletions(-)
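Two handle representations run through this change: the existing ops pass the table as a Ref(string) handle (a 2-element vector holding {container, name}), while the new *V2 ops pass a scalar DT_RESOURCE tensor holding a ResourceHandle. A condensed, hedged sketch of both sides of that contract, using only APIs that appear in the hunks below (`container` and `name` stand in for the kernel's ContainerInfo values):

    // Producer side (cf. LookupTableOp::Compute in lookup_table_op.h): when the
    // op's output dtype is DT_RESOURCE, emit a scalar ResourceHandle instead of
    // the legacy {container, name} string-ref vector.
    Tensor* handle_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_out));
    handle_out->scalar<ResourceHandle>()() =
        MakeResourceHandle<lookup::LookupInterface>(ctx, container, name);

    // Consumer side (cf. lookup_util.cc): resolve the handle back to the table
    // through the resource manager.
    ResourceHandle handle;
    TF_RETURN_IF_ERROR(HandleFromInput(ctx, "table_handle", &handle));
    lookup::LookupInterface* table = nullptr;
    TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &table));
    core::ScopedUnref unref_me(table);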
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index 7f9fe084ba..ab7dd0c547 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -246,6 +246,14 @@ ResourceHandle HandleFromInput(OpKernelContext* ctx, int input) {
return ctx->input(input).flat<ResourceHandle>()(0);
}
+Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
+ ResourceHandle* handle) {
+ const Tensor* tensor;
+ TF_RETURN_IF_ERROR(ctx->input(input, &tensor));
+ *handle = tensor->flat<ResourceHandle>()(0);
+ return Status::OK();
+}
+
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
return ctx->resource_manager()->Delete(p);
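For orientation, the new overload complements the numbered-input variant declared in resource_mgr.h just below: it resolves an op input by name and surfaces failures through Status. A minimal usage sketch, assuming an op whose first input is named "table_handle" (as the lookup ops in this change name it):

    // Existing overload: positional input, returns the handle directly.
    ResourceHandle h0 = HandleFromInput(ctx, 0);

    // New overload: named input; errors (e.g. a missing input name) come back
    // as a Status instead.
    ResourceHandle h1;
    TF_RETURN_IF_ERROR(HandleFromInput(ctx, "table_handle", &h1));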
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index fe6e09378f..26a5766569 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -211,6 +211,8 @@ ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
// Returns a resource handle from a numbered op input.
ResourceHandle HandleFromInput(OpKernelContext* ctx, int input);
+Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
+ ResourceHandle* handle);
// Create a resource pointed by a given resource handle.
template <typename T>
diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc
index bde1d0360a..ada6fe8d95 100644
--- a/tensorflow/core/kernels/lookup_table_init_op.cc
+++ b/tensorflow/core/kernels/lookup_table_init_op.cc
@@ -367,7 +367,9 @@ class InitializeTableOp : public OpKernel {
GetInitializableLookupTable("table_handle", ctx, &table));
core::ScopedUnref unref_me(table);
- DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(),
+ DataType expected_input_0 =
+ (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
+ DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
table->value_dtype()};
DataTypeVector expected_outputs = {};
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
@@ -408,6 +410,8 @@ class InitializeTableOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("InitializeTable").Device(DEVICE_CPU),
InitializeTableOp);
+REGISTER_KERNEL_BUILDER(Name("InitializeTableV2").Device(DEVICE_CPU),
+ InitializeTableOp);
// Kernel to initialize a lookup table from a text file.
//
@@ -433,7 +437,9 @@ class InitializeTableFromTextFileOp : public OpKernel {
GetInitializableLookupTable("table_handle", ctx, &table));
core::ScopedUnref unref_me(table);
- DataTypeVector expected_inputs = {DT_STRING_REF, DT_STRING};
+ DataType expected_input_0 =
+ (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
+ DataTypeVector expected_inputs = {expected_input_0, DT_STRING};
DataTypeVector expected_outputs = {};
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
@@ -472,5 +478,8 @@ class InitializeTableFromTextFileOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("InitializeTableFromTextFile").Device(DEVICE_CPU),
InitializeTableFromTextFileOp);
+REGISTER_KERNEL_BUILDER(
+ Name("InitializeTableFromTextFileV2").Device(DEVICE_CPU),
+ InitializeTableFromTextFileOp);
} // namespace tensorflow
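One kernel class now backs both the legacy and the V2 op names registered above; this works because Compute() inspects the dtype of input 0 at run time and builds the expected signature accordingly. The dispatch, condensed from the two hunks in this file (InitializeTableOp shown; the text-file variant expects DT_STRING as its second input):

    // DT_RESOURCE when bound to InitializeTableV2 / InitializeTableFromTextFileV2,
    // DT_STRING_REF when bound to the legacy ref-based ops.
    DataType expected_input_0 =
        (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
    DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
                                      table->value_dtype()};
    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, /*outputs=*/{}));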
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index 0a065e37d7..11ce2a71dc 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -624,7 +624,10 @@ class LookupTableFindOp : public OpKernel {
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
core::ScopedUnref unref_me(table);
- DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(),
+ // Input 0 could be a STRING_REF or a RESOURCE
+ DataType expected_input_0 =
+ (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
+ DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
table->value_dtype()};
DataTypeVector expected_outputs = {table->value_dtype()};
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
@@ -647,6 +650,8 @@ class LookupTableFindOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU),
LookupTableFindOp);
+REGISTER_KERNEL_BUILDER(Name("LookupTableFindV2").Device(DEVICE_CPU),
+ LookupTableFindOp);
// Table insert op.
class LookupTableInsertOp : public OpKernel {
@@ -658,7 +663,9 @@ class LookupTableInsertOp : public OpKernel {
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
core::ScopedUnref unref_me(table);
- DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(),
+ DataType expected_input_0 =
+ (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
+ DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
table->value_dtype()};
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
@@ -680,6 +687,8 @@ class LookupTableInsertOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("LookupTableInsert").Device(DEVICE_CPU),
LookupTableInsertOp);
+REGISTER_KERNEL_BUILDER(Name("LookupTableInsertV2").Device(DEVICE_CPU),
+ LookupTableInsertOp);
// Op that returns the size of the given table.
class LookupTableSizeOp : public OpKernel {
@@ -699,6 +708,8 @@ class LookupTableSizeOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("LookupTableSize").Device(DEVICE_CPU),
LookupTableSizeOp);
+REGISTER_KERNEL_BUILDER(Name("LookupTableSizeV2").Device(DEVICE_CPU),
+ LookupTableSizeOp);
// Op that outputs tensors of all keys and all values.
class LookupTableExportOp : public OpKernel {
@@ -716,6 +727,8 @@ class LookupTableExportOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("LookupTableExport").Device(DEVICE_CPU),
LookupTableExportOp);
+REGISTER_KERNEL_BUILDER(Name("LookupTableExportV2").Device(DEVICE_CPU),
+ LookupTableExportOp);
// Clear the table and insert data.
class LookupTableImportOp : public OpKernel {
@@ -727,7 +740,9 @@ class LookupTableImportOp : public OpKernel {
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
core::ScopedUnref unref_me(table);
- DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(),
+ DataType expected_input_0 =
+ (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
+ DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
table->value_dtype()};
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
@@ -749,6 +764,8 @@ class LookupTableImportOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("LookupTableImport").Device(DEVICE_CPU),
LookupTableImportOp);
+REGISTER_KERNEL_BUILDER(Name("LookupTableImportV2").Device(DEVICE_CPU),
+ LookupTableImportOp);
// Register the HashTable op with the currently supported key and value types.
#define REGISTER_KERNEL(key_dtype, value_dtype) \
@@ -758,6 +775,13 @@ REGISTER_KERNEL_BUILDER(Name("LookupTableImport").Device(DEVICE_CPU),
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \
+ value_dtype>) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("HashTableV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<key_dtype>("key_dtype") \
+ .TypeConstraint<value_dtype>("value_dtype"), \
+ LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \
value_dtype>)
REGISTER_KERNEL(string, double);
@@ -779,6 +803,13 @@ REGISTER_KERNEL(string, bool);
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
+ key_dtype, value_dtype>) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MutableHashTableV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<key_dtype>("key_dtype") \
+ .TypeConstraint<value_dtype>("value_dtype"), \
+ LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
REGISTER_KERNEL(string, float);
@@ -797,6 +828,13 @@ REGISTER_KERNEL(int64, float);
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
+ key_dtype, value_dtype>) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MutableHashTableOfTensorsV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<key_dtype>("key_dtype") \
+ .TypeConstraint<value_dtype>("value_dtype"), \
+ LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
REGISTER_KERNEL(string, float);
@@ -814,6 +852,13 @@ REGISTER_KERNEL(string, bool);
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \
+ key_dtype, value_dtype>) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MutableDenseHashTableV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<key_dtype>("key_dtype") \
+ .TypeConstraint<value_dtype>("value_dtype"), \
+ LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
REGISTER_KERNEL(int64, int64);
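Each REGISTER_KERNEL macro above now emits two registrations per (key_dtype, value_dtype) pair, binding the same LookupTableOp specialization to the legacy op name and its V2 counterpart. Roughly, REGISTER_KERNEL(string, double) in the HashTable block expands to:

    REGISTER_KERNEL_BUILDER(
        Name("HashTable")
            .Device(DEVICE_CPU)
            .TypeConstraint<string>("key_dtype")
            .TypeConstraint<double>("value_dtype"),
        LookupTableOp<lookup::HashTable<string, double>, string, double>);
    REGISTER_KERNEL_BUILDER(
        Name("HashTableV2")
            .Device(DEVICE_CPU)
            .TypeConstraint<string>("key_dtype")
            .TypeConstraint<double>("value_dtype"),
        LookupTableOp<lookup::HashTable<string, double>, string, double>);

The V2 kernel is the same class as the V1 kernel; only the op's handle dtype differs, and Compute() in lookup_table_op.h picks the output form from ctx->expected_output_dtype(0).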
diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h
index ae253b4dc9..4cd25a3cc6 100644
--- a/tensorflow/core/kernels/lookup_table_op.h
+++ b/tensorflow/core/kernels/lookup_table_op.h
@@ -51,40 +51,52 @@ class LookupTableOp : public OpKernel {
// ctx is not owned by this function.
void Compute(OpKernelContext* ctx) override {
mutex_lock l(mu_);
+
if (!table_handle_set_) {
OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(),
use_node_name_sharing_));
- auto creator = [ctx, this](lookup::LookupInterface** ret) {
- lookup::LookupInterface* container = new Container(ctx, this);
- if (!ctx->status().ok()) {
- container->Unref();
- return ctx->status();
- }
- if (ctx->track_allocations()) {
- ctx->record_device_persistent_memory_allocation(
- container->MemoryUsed());
- }
- *ret = container;
- return Status::OK();
- };
-
- lookup::LookupInterface* table = nullptr;
- OP_REQUIRES_OK(
- ctx, cinfo_.resource_manager()
- ->template LookupOrCreate<lookup::LookupInterface>(
- cinfo_.container(), cinfo_.name(), &table, creator));
- core::ScopedUnref unref_me(table);
-
- OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes(
- *table, DataTypeToEnum<key_dtype>::v(),
- DataTypeToEnum<value_dtype>::v(), cinfo_.name()));
-
- auto h = table_handle_.AccessTensor(ctx)->template flat<string>();
- h(0) = cinfo_.container();
- h(1) = cinfo_.name();
- table_handle_set_ = true;
}
- ctx->set_output_ref(0, &mu_, table_handle_.AccessTensor(ctx));
+
+ auto creator = [ctx, this](lookup::LookupInterface** ret) {
+ lookup::LookupInterface* container = new Container(ctx, this);
+ if (!ctx->status().ok()) {
+ container->Unref();
+ return ctx->status();
+ }
+ if (ctx->track_allocations()) {
+ ctx->record_device_persistent_memory_allocation(
+ container->MemoryUsed());
+ }
+ *ret = container;
+ return Status::OK();
+ };
+
+ lookup::LookupInterface* table = nullptr;
+ OP_REQUIRES_OK(ctx,
+ cinfo_.resource_manager()
+ ->template LookupOrCreate<lookup::LookupInterface>(
+ cinfo_.container(), cinfo_.name(), &table, creator));
+ core::ScopedUnref unref_me(table);
+
+ OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes(
+ *table, DataTypeToEnum<key_dtype>::v(),
+ DataTypeToEnum<value_dtype>::v(), cinfo_.name()));
+
+ if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
+ Tensor* handle;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
+ handle->scalar<ResourceHandle>()() =
+ MakeResourceHandle<lookup::LookupInterface>(ctx, cinfo_.container(),
+ cinfo_.name());
+ } else {
+ if (!table_handle_set_) {
+ auto h = table_handle_.AccessTensor(ctx)->template flat<string>();
+ h(0) = cinfo_.container();
+ h(1) = cinfo_.name();
+ }
+ ctx->set_output_ref(0, &mu_, table_handle_.AccessTensor(ctx));
+ }
+ table_handle_set_ = true;
}
~LookupTableOp() override {
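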
diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc
index f87ce0e6b2..d0f269be23 100644
--- a/tensorflow/core/kernels/lookup_util.cc
+++ b/tensorflow/core/kernels/lookup_util.cc
@@ -49,26 +49,48 @@ Status GetLookupTable(const string& input_name, OpKernelContext* ctx,
LookupInterface** table) {
string container;
string table_handle;
- TF_RETURN_IF_ERROR(
- GetTableHandle(input_name, ctx, &container, &table_handle));
- return ctx->resource_manager()->Lookup(container, table_handle, table);
+ DataType handle_dtype;
+ TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype));
+ if (handle_dtype == DT_RESOURCE) {
+ ResourceHandle handle;
+ TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle));
+ return LookupResource(ctx, handle, table);
+ } else {
+ TF_RETURN_IF_ERROR(
+ GetTableHandle(input_name, ctx, &container, &table_handle));
+ return ctx->resource_manager()->Lookup(container, table_handle, table);
+ }
}
Status GetInitializableLookupTable(const string& input_name,
OpKernelContext* ctx,
InitializableLookupTable** table) {
- string container;
- string table_handle;
- TF_RETURN_IF_ERROR(
- GetTableHandle(input_name, ctx, &container, &table_handle));
LookupInterface* lookup_table;
- TF_RETURN_IF_ERROR(
- ctx->resource_manager()->Lookup(container, table_handle, &lookup_table));
- *table = lookup_table->GetInitializableLookupTable();
- if (*table == nullptr) {
- lookup_table->Unref();
- return errors::InvalidArgument("Table ", container, " ", table_handle,
- " is not initializable");
+ DataType handle_dtype;
+ TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype));
+ if (handle_dtype == DT_RESOURCE) {
+ ResourceHandle handle;
+ TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle));
+ TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &lookup_table));
+ *table = lookup_table->GetInitializableLookupTable();
+ if (*table == nullptr) {
+ lookup_table->Unref();
+ return errors::InvalidArgument("Table ", handle.container(), " ",
+ handle.name(), " is not initializable");
+ }
+ } else {
+ string container;
+ string table_handle;
+ TF_RETURN_IF_ERROR(
+ GetTableHandle(input_name, ctx, &container, &table_handle));
+ TF_RETURN_IF_ERROR(ctx->resource_manager()->Lookup(container, table_handle,
+ &lookup_table));
+ *table = lookup_table->GetInitializableLookupTable();
+ if (*table == nullptr) {
+ lookup_table->Unref();
+ return errors::InvalidArgument("Table ", container, " ", table_handle,
+ " is not initializable");
+ }
}
return Status::OK();
}
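GetLookupTable and GetInitializableLookupTable now branch on the handle dtype internally, so the table kernels stay handle-agnostic. A hedged sketch of the calling pattern, mirroring the Find/Insert/Size kernels in lookup_table_op.cc:

    // Works whether input 0 is a DT_STRING_REF {container, name} vector or a
    // scalar DT_RESOURCE handle.
    lookup::LookupInterface* table = nullptr;
    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
    core::ScopedUnref unref_me(table);
    // ... then use table->key_dtype(), table->value_dtype(), table->Find(...), etc.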
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index b34dd4ae90..f82e9d1eb7 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -210,10 +210,29 @@ Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
return Status::OK();
}
+Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
+ ShapeHandle handle;
+ DimensionHandle unused_handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+ for (int i = 1; i < c->num_inputs(); ++i) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
+ }
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->Scalar());
+ }
+ return Status::OK();
+}
+
Status TwoElementOutput(InferenceContext* c) {
c->set_output(0, c->Vector(2));
return Status::OK();
}
+
+Status ScalarOutput(InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+}
} // namespace
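These helpers codify how the handle input's shape changes between op versions: a V1 table handle is a Ref(string) vector of shape [2] ({container, name}), while a V2 handle is a scalar resource. Fragments of the corresponding checks, both used by shape functions in this file:

    // V1 ops (e.g. LookupTableInsert/Import): input 0 is a length-2 vector.
    ShapeHandle handle;
    DimensionHandle unused_dim;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));

    // V2 ops (e.g. LookupTableFindV2, LookupTableSizeV2): input 0 is a scalar.
    TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));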
REGISTER_OP("RandomShuffleQueue")
@@ -1881,6 +1900,38 @@ values: Same shape as `keys`. Values found in the table, or `default_values`
for missing keys.
)doc");
+REGISTER_OP("LookupTableFindV2")
+ .Input("table_handle: resource")
+ .Input("keys: Tin")
+ .Input("default_value: Tout")
+ .Output("values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+ // Default value must be scalar or vector.
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
+ c->set_output(0, c->UnknownShape());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Looks up keys in a table, outputs the corresponding values.
+
+The tensor `keys` must be of the same type as the keys of the table.
+The output `values` is of the type of the table values.
+
+The scalar `default_value` is the value output for keys not present in the
+table. It must also be of the same type as the table values.
+
+table_handle: Handle to the table.
+keys: Any shape. Keys to look up.
+values: Same shape as `keys`. Values found in the table, or `default_values`
+ for missing keys.
+)doc");
+
REGISTER_OP("LookupTableInsert")
.Input("table_handle: Ref(string)")
.Input("keys: Tin")
@@ -1893,6 +1944,30 @@ REGISTER_OP("LookupTableInsert")
DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
+ // TODO(ebrevdo): Validate keys and values shape.
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Updates the table to associate keys with values.
+
+The tensor `keys` must be of the same type as the keys of the table.
+The tensor `values` must be of the type of the table values.
+
+table_handle: Handle to the table.
+keys: Any shape. Keys to look up.
+values: Values to associate with keys.
+)doc");
+
+REGISTER_OP("LookupTableInsertV2")
+ .Input("table_handle: resource")
+ .Input("keys: Tin")
+ .Input("values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
// TODO: Validate keys and values shape.
return Status::OK();
})
@@ -1918,6 +1993,17 @@ table_handle: Handle to the table.
size: Scalar that contains number of elements in the table.
)doc");
+REGISTER_OP("LookupTableSizeV2")
+ .Input("table_handle: resource")
+ .Output("size: int64")
+ .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs)
+ .Doc(R"doc(
+Computes the number of elements in the given table.
+
+table_handle: Handle to the table.
+size: Scalar that contains number of elements in the table.
+)doc");
+
REGISTER_OP("LookupTableExport")
.Input("table_handle: Ref(string)")
.Output("keys: Tkeys")
@@ -1945,6 +2031,31 @@ keys: Vector of all keys present in the table.
values: Tensor of all values in the table. Indexed in parallel with `keys`.
)doc");
+REGISTER_OP("LookupTableExportV2")
+ .Input("table_handle: resource")
+ .Output("keys: Tkeys")
+ .Output("values: Tvalues")
+ .Attr("Tkeys: type")
+ .Attr("Tvalues: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+ ShapeHandle values = c->UnknownShape();
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
+ ShapeHandle keys = c->Vector(c->Dim(values, 0));
+ c->set_output(0, keys);
+ c->set_output(1, values);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Outputs all keys and values in the table.
+
+table_handle: Handle to the table.
+keys: Vector of all keys present in the table.
+values: Tensor of all values in the table. Indexed in parallel with `keys`.
+)doc");
+
REGISTER_OP("LookupTableImport")
.Input("table_handle: Ref(string)")
.Input("keys: Tin")
@@ -1957,6 +2068,30 @@ REGISTER_OP("LookupTableImport")
DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
+ // TODO(ebrevdo): Validate keys and values shape.
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Replaces the contents of the table with the specified keys and values.
+
+The tensor `keys` must be of the same type as the keys of the table.
+The tensor `values` must be of the type of the table values.
+
+table_handle: Handle to the table.
+keys: Any shape. Keys to look up.
+values: Values to associate with keys.
+)doc");
+
+REGISTER_OP("LookupTableImportV2")
+ .Input("table_handle: resource")
+ .Input("keys: Tin")
+ .Input("values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
// TODO: Validate keys and values shape.
return Status::OK();
})
@@ -1998,6 +2133,33 @@ key_dtype: Type of the table keys.
value_dtype: Type of the table values.
)doc");
+REGISTER_OP("HashTableV2")
+ .Output("table_handle: resource")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .SetIsStateful()
+ .SetShapeFn(ScalarOutput)
+ .Doc(R"doc(
+Creates a non-initialized hash table.
+
+This op creates a hash table, specifying the type of its keys and values.
+Before using the table you will have to initialize it. After initialization the
+table will be immutable.
+
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+use_node_name_sharing: If true and shared_name is empty, the table is shared
+ using the node name.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+)doc");
+
REGISTER_OP("MutableHashTable")
.Output("table_handle: Ref(string)")
.Attr("container: string = ''")
@@ -2025,6 +2187,33 @@ key_dtype: Type of the table keys.
value_dtype: Type of the table values.
)doc");
+REGISTER_OP("MutableHashTableV2")
+ .Output("table_handle: resource")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .SetIsStateful()
+ .SetShapeFn(ScalarOutput)
+ .Doc(R"doc(
+Creates an empty hash table.
+
+This op creates a mutable hash table, specifying the type of its keys and
+values. Each value must be a scalar. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
+
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+use_node_name_sharing: If true and shared_name is empty, the table is shared
+ using the node name.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+)doc");
+
REGISTER_OP("MutableHashTableOfTensors")
.Output("table_handle: Ref(string)")
.Attr("container: string = ''")
@@ -2051,6 +2240,32 @@ key_dtype: Type of the table keys.
value_dtype: Type of the table values.
)doc");
+REGISTER_OP("MutableHashTableOfTensorsV2")
+ .Output("table_handle: resource")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .Attr("value_shape: shape = {}")
+ .SetIsStateful()
+ .SetShapeFn(ScalarOutput)
+ .Doc(R"doc(
+Creates an empty hash table.
+
+This op creates a mutable hash table, specifying the type of its keys and
+values. Each value must be a vector. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
+
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+)doc");
+
REGISTER_OP("MutableDenseHashTable")
.Input("empty_key: key_dtype")
.Output("table_handle: Ref(string)")
@@ -2088,6 +2303,43 @@ max_load_factor: The maximum ratio between number of entries and number of
buckets before growing the table. Must be between 0 and 1.
)doc");
+REGISTER_OP("MutableDenseHashTableV2")
+ .Input("empty_key: key_dtype")
+ .Output("table_handle: resource")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("use_node_name_sharing: bool = false")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .Attr("value_shape: shape = {}")
+ .Attr("initial_num_buckets: int = 131072") // 2^17
+ .Attr("max_load_factor: float = 0.8")
+ .SetIsStateful()
+ .SetShapeFn(ScalarOutput)
+ .Doc(R"doc(
+Creates an empty hash table that uses tensors as the backing store. It uses
+"open addressing" with quadratic reprobing to resolve collisions.
+
+This op creates a mutable hash table, specifying the type of its keys and
+values. Each value must be a scalar. Data can be inserted into the table using
+the insert operations. It does not support the initialization operation.
+
+empty_key: The key used to represent empty key buckets internally. Must not
+ be used in insert or lookup operations.
+table_handle: Handle to a table.
+container: If non-empty, this table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this table is shared under the given name across
+ multiple sessions.
+key_dtype: Type of the table keys.
+value_dtype: Type of the table values.
+value_shape: The shape of each value.
+initial_num_buckets: The initial number of hash table buckets. Must be a power
+  of 2.
+max_load_factor: The maximum ratio between number of entries and number of
+ buckets before growing the table. Must be between 0 and 1.
+)doc");
+
REGISTER_OP("InitializeTable")
.Input("table_handle: Ref(string)")
.Input("keys: Tkey")
@@ -2113,6 +2365,29 @@ keys: Keys of type Tkey.
values: Values of type Tval.
)doc");
+REGISTER_OP("InitializeTableV2")
+ .Input("table_handle: resource")
+ .Input("keys: Tkey")
+ .Input("values: Tval")
+ .Attr("Tkey: type")
+ .Attr("Tval: type")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+ ShapeHandle keys;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
+ TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Table initializer that takes two tensors for keys and values respectively.
+
+table_handle: Handle to a table which will be initialized.
+keys: Keys of type Tkey.
+values: Values of type Tval.
+)doc");
+
REGISTER_OP("InitializeTableFromTextFile")
.Input("table_handle: Ref(string)")
.Input("filename: string")
@@ -2152,6 +2427,43 @@ vocab_size: Number of elements of the file, use -1 if unknown.
delimiter: Delimiter to separate fields in a line.
)doc");
+REGISTER_OP("InitializeTableFromTextFileV2")
+ .Input("table_handle: resource")
+ .Input("filename: string")
+ .Attr("key_index: int >= -2")
+ .Attr("value_index: int >= -2")
+ .Attr("vocab_size: int >= -1 = -1")
+ .Attr("delimiter: string = '\t'")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Initializes a table from a text file.
+
+It inserts one key-value pair into the table for each line of the file.
+The key and value are extracted from the whole line content, elements from the
+split line based on `delimiter` or the line number (starting from zero).
+Where to extract the key and value from a line is specified by `key_index` and
+`value_index`.
+
+- A value of -1 means use the line number (starting from zero), expects `int64`.
+- A value of -2 means use the whole line content, expects `string`.
+- A value >= 0 means use the index (starting at zero) of the split line based
+ on `delimiter`.
+
+table_handle: Handle to a table which will be initialized.
+filename: Filename of a vocabulary text file.
+key_index: Column index in a line to get the table `key` values from.
+value_index: Column index that represents information of a line to get the table
+ `value` values from.
+vocab_size: Number of elements of the file, use -1 if unknown.
+delimiter: Delimiter to separate fields in a line.
+)doc");
+
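The key_index / value_index rules documented above are easiest to see on a concrete line. A self-contained, hedged sketch of the documented conventions (not code from this change; the real logic lives in lookup_table_init_op.cc):

    #include <cstdint>
    #include <sstream>
    #include <string>
    #include <vector>

    // index == -1 -> use the line number (starting from zero)
    // index == -2 -> use the whole line content
    // index >=  0 -> use that column of the line split on `delimiter`
    std::string SelectToken(const std::string& line, int64_t line_number,
                            int index, char delimiter) {
      if (index == -1) return std::to_string(line_number);
      if (index == -2) return line;
      std::vector<std::string> columns;
      std::stringstream ss(line);
      std::string column;
      while (std::getline(ss, column, delimiter)) columns.push_back(column);
      return index < static_cast<int>(columns.size()) ? columns[index] : "";
    }

    // Example: for line 7 containing "hello\t42" with delimiter '\t',
    // key_index = 0 and value_index = 1 yield the pair ("hello", "42");
    // key_index = -1 would instead use the line number 7 (an int64 key), and
    // key_index = -2 the whole line "hello\t42".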
REGISTER_OP("GetSessionHandle")
.Input("value: T")
.Output("handle: string")
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index 9022e1453d..ec02ee3e03 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -63,16 +63,27 @@ GetSessionHandle
GetSessionHandleV2
GetSessionTensor
HashTable
+HashTableV2
InitializeTable
+InitializeTableV2
InitializeTableFromTextFile
+InitializeTableFromTextFileV2
LookupTableExport
+LookupTableExportV2
LookupTableFind
+LookupTableFindV2
LookupTableImport
+LookupTableImportV2
LookupTableInsert
+LookupTableInsertV2
LookupTableSize
+LookupTableSizeV2
MutableDenseHashTable
+MutableDenseHashTableV2
MutableHashTable
+MutableHashTableV2
MutableHashTableOfTensors
+MutableHashTableOfTensorsV2
Mutex
MutexAcquire
MutexRelease