aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/lookup_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/lookup_ops.cc')
-rw-r--r--tensorflow/core/ops/lookup_ops.cc139
1 files changed, 128 insertions, 11 deletions
diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc
index 2059741da9..7c71406c6b 100644
--- a/tensorflow/core/ops/lookup_ops.cc
+++ b/tensorflow/core/ops/lookup_ops.cc
@@ -23,6 +23,7 @@ namespace tensorflow {
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
+using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;
// --------------------------------------------------------------------------
@@ -86,6 +87,74 @@ REGISTER_OP("LookupTableFind")
return Status::OK();
});
+Status ValidateTableResourceHandle(InferenceContext* c, ShapeHandle keys,
+ const string& key_dtype_attr,
+ const string& value_dtype_attr,
+ bool is_lookup,
+ ShapeAndType* output_shape_and_type) {
+ auto* handle_data = c->input_handle_shapes_and_types(0);
+ if (handle_data == nullptr || handle_data->size() != 2) {
+ output_shape_and_type->shape = c->UnknownShape();
+ output_shape_and_type->dtype = DT_INVALID;
+ } else {
+ const ShapeAndType& key_shape_and_type = (*handle_data)[0];
+ const ShapeAndType& value_shape_and_type = (*handle_data)[1];
+ DataType key_dtype;
+ TF_RETURN_IF_ERROR(c->GetAttr(key_dtype_attr, &key_dtype));
+ if (key_shape_and_type.dtype != key_dtype) {
+ return errors::InvalidArgument(
+ "Trying to read value with wrong dtype. "
+ "Expected ",
+ DataTypeString(key_shape_and_type.dtype), " got ",
+ DataTypeString(key_dtype));
+ }
+ DataType value_dtype;
+ TF_RETURN_IF_ERROR(c->GetAttr(value_dtype_attr, &value_dtype));
+ if (value_shape_and_type.dtype != value_dtype) {
+ return errors::InvalidArgument(
+ "Trying to read value with wrong dtype. "
+ "Expected ",
+ DataTypeString(value_shape_and_type.dtype), " got ",
+ DataTypeString(value_dtype));
+ }
+ output_shape_and_type->dtype = value_shape_and_type.dtype;
+
+ if (is_lookup) {
+ if (c->RankKnown(key_shape_and_type.shape) && c->RankKnown(keys)) {
+ int keys_rank = c->Rank(keys);
+ int key_suffix_rank = c->Rank(key_shape_and_type.shape);
+ if (keys_rank < key_suffix_rank) {
+ return errors::InvalidArgument(
+ "Expected keys to have suffix ",
+ c->DebugString(key_shape_and_type.shape),
+ " but saw shape: ", c->DebugString(keys));
+ }
+ for (int d = 0; d < key_suffix_rank; d++) {
+ // Ensure the suffix of keys match what's in the Table.
+ DimensionHandle dim = c->Dim(key_shape_and_type.shape, d);
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(keys, keys_rank - key_suffix_rank + d, dim, &keys));
+ }
+ std::vector<DimensionHandle> keys_prefix_vec;
+ keys_prefix_vec.reserve(keys_rank - key_suffix_rank);
+ for (int d = 0; d < keys_rank - key_suffix_rank; ++d) {
+ keys_prefix_vec.push_back(c->Dim(keys, d));
+ }
+ ShapeHandle keys_prefix = c->MakeShape(keys_prefix_vec);
+ TF_RETURN_IF_ERROR(c->Concatenate(keys_prefix,
+ value_shape_and_type.shape,
+ &output_shape_and_type->shape));
+ } else {
+ output_shape_and_type->shape = c->UnknownShape();
+ }
+ } else {
+ TF_RETURN_IF_ERROR(c->Concatenate(keys, value_shape_and_type.shape,
+ &output_shape_and_type->shape));
+ }
+ }
+ return Status::OK();
+}
+
REGISTER_OP("LookupTableFindV2")
.Input("table_handle: resource")
.Input("keys: Tin")
@@ -98,9 +167,18 @@ REGISTER_OP("LookupTableFindV2")
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
// Default value must be scalar or vector.
- ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
- c->set_output(0, c->UnknownShape());
+ ShapeHandle keys;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &keys));
+
+ ShapeAndType value_shape_and_type;
+ TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
+ c,
+ /*keys=*/c->input(1),
+ /*key_dtype_attr=*/"Tin",
+ /*value_dtype_attr=*/"Tout",
+ /*is_lookup=*/true, &value_shape_and_type));
+ c->set_output(0, value_shape_and_type.shape);
+
return Status::OK();
});
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LookupTableFindV2");
@@ -177,12 +255,16 @@ REGISTER_OP("LookupTableExportV2")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
-
- ShapeHandle values = c->UnknownShape();
- TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
- ShapeHandle keys = c->Vector(c->Dim(values, 0));
+ ShapeHandle keys = c->UnknownShapeOfRank(1);
+ ShapeAndType value_shape_and_type;
+ TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
+ c,
+ /*keys=*/keys,
+ /*key_dtype_attr=*/"Tkeys",
+ /*value_dtype_attr=*/"Tvalues",
+ /*is_lookup=*/false, &value_shape_and_type));
c->set_output(0, keys);
- c->set_output(1, values);
+ c->set_output(1, value_shape_and_type.shape);
return Status::OK();
});
@@ -216,6 +298,26 @@ REGISTER_OP("LookupTableImportV2")
return Status::OK();
});
+Status MutableHashTableShape(InferenceContext* c, const ShapeHandle& key,
+ const ShapeHandle& value) {
+ c->set_output(0, c->Scalar());
+
+ ShapeHandle key_s;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(key, 1, &key_s));
+
+ DataType key_t;
+ TF_RETURN_IF_ERROR(c->GetAttr("key_dtype", &key_t));
+
+ DataType value_t;
+ TF_RETURN_IF_ERROR(c->GetAttr("value_dtype", &value_t));
+
+ // ShapeAndType vector for {key, value}.
+ c->set_output_handle_shapes_and_types(
+ 0, std::vector<ShapeAndType>{{key_s, key_t}, {value, value_t}});
+
+ return Status::OK();
+}
+
REGISTER_OP("HashTable")
.Output("table_handle: Ref(string)")
.Attr("container: string = ''")
@@ -254,7 +356,10 @@ REGISTER_OP("MutableHashTableV2")
.Attr("key_dtype: type")
.Attr("value_dtype: type")
.SetIsStateful()
- .SetShapeFn(ScalarOutput);
+ .SetShapeFn([](InferenceContext* c) {
+ return MutableHashTableShape(c, /*key=*/c->Scalar(),
+ /*value=*/c->Scalar());
+ });
REGISTER_OP("MutableHashTableOfTensors")
.Output("table_handle: Ref(string)")
@@ -276,7 +381,13 @@ REGISTER_OP("MutableHashTableOfTensorsV2")
.Attr("value_dtype: type")
.Attr("value_shape: shape = {}")
.SetIsStateful()
- .SetShapeFn(ScalarOutput);
+ .SetShapeFn([](InferenceContext* c) {
+ PartialTensorShape value_p;
+ TF_RETURN_IF_ERROR(c->GetAttr("value_shape", &value_p));
+ ShapeHandle value_s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(value_p, &value_s));
+ return MutableHashTableShape(c, /*key=*/c->Scalar(), /*value=*/value_s);
+ });
REGISTER_OP("MutableDenseHashTable")
.Input("empty_key: key_dtype")
@@ -304,7 +415,13 @@ REGISTER_OP("MutableDenseHashTableV2")
.Attr("initial_num_buckets: int = 131072") // 2^17
.Attr("max_load_factor: float = 0.8")
.SetIsStateful()
- .SetShapeFn(ScalarOutput);
+ .SetShapeFn([](InferenceContext* c) {
+ PartialTensorShape value_p;
+ TF_RETURN_IF_ERROR(c->GetAttr("value_shape", &value_p));
+ ShapeHandle value_s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(value_p, &value_s));
+ return MutableHashTableShape(c, /*key=*/c->input(0), /*value=*/value_s);
+ });
REGISTER_OP("InitializeTable")
.Input("table_handle: Ref(string)")