diff options
Diffstat (limited to 'tensorflow/core/ops/lookup_ops.cc')
-rw-r--r-- | tensorflow/core/ops/lookup_ops.cc | 139 |
1 files changed, 128 insertions, 11 deletions
diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc index 2059741da9..7c71406c6b 100644 --- a/tensorflow/core/ops/lookup_ops.cc +++ b/tensorflow/core/ops/lookup_ops.cc @@ -23,6 +23,7 @@ namespace tensorflow { using shape_inference::DimensionHandle; using shape_inference::InferenceContext; +using shape_inference::ShapeAndType; using shape_inference::ShapeHandle; // -------------------------------------------------------------------------- @@ -86,6 +87,74 @@ REGISTER_OP("LookupTableFind") return Status::OK(); }); +Status ValidateTableResourceHandle(InferenceContext* c, ShapeHandle keys, + const string& key_dtype_attr, + const string& value_dtype_attr, + bool is_lookup, + ShapeAndType* output_shape_and_type) { + auto* handle_data = c->input_handle_shapes_and_types(0); + if (handle_data == nullptr || handle_data->size() != 2) { + output_shape_and_type->shape = c->UnknownShape(); + output_shape_and_type->dtype = DT_INVALID; + } else { + const ShapeAndType& key_shape_and_type = (*handle_data)[0]; + const ShapeAndType& value_shape_and_type = (*handle_data)[1]; + DataType key_dtype; + TF_RETURN_IF_ERROR(c->GetAttr(key_dtype_attr, &key_dtype)); + if (key_shape_and_type.dtype != key_dtype) { + return errors::InvalidArgument( + "Trying to read value with wrong dtype. " + "Expected ", + DataTypeString(key_shape_and_type.dtype), " got ", + DataTypeString(key_dtype)); + } + DataType value_dtype; + TF_RETURN_IF_ERROR(c->GetAttr(value_dtype_attr, &value_dtype)); + if (value_shape_and_type.dtype != value_dtype) { + return errors::InvalidArgument( + "Trying to read value with wrong dtype. " + "Expected ", + DataTypeString(value_shape_and_type.dtype), " got ", + DataTypeString(value_dtype)); + } + output_shape_and_type->dtype = value_shape_and_type.dtype; + + if (is_lookup) { + if (c->RankKnown(key_shape_and_type.shape) && c->RankKnown(keys)) { + int keys_rank = c->Rank(keys); + int key_suffix_rank = c->Rank(key_shape_and_type.shape); + if (keys_rank < key_suffix_rank) { + return errors::InvalidArgument( + "Expected keys to have suffix ", + c->DebugString(key_shape_and_type.shape), + " but saw shape: ", c->DebugString(keys)); + } + for (int d = 0; d < key_suffix_rank; d++) { + // Ensure the suffix of keys match what's in the Table. + DimensionHandle dim = c->Dim(key_shape_and_type.shape, d); + TF_RETURN_IF_ERROR( + c->ReplaceDim(keys, keys_rank - key_suffix_rank + d, dim, &keys)); + } + std::vector<DimensionHandle> keys_prefix_vec; + keys_prefix_vec.reserve(keys_rank - key_suffix_rank); + for (int d = 0; d < keys_rank - key_suffix_rank; ++d) { + keys_prefix_vec.push_back(c->Dim(keys, d)); + } + ShapeHandle keys_prefix = c->MakeShape(keys_prefix_vec); + TF_RETURN_IF_ERROR(c->Concatenate(keys_prefix, + value_shape_and_type.shape, + &output_shape_and_type->shape)); + } else { + output_shape_and_type->shape = c->UnknownShape(); + } + } else { + TF_RETURN_IF_ERROR(c->Concatenate(keys, value_shape_and_type.shape, + &output_shape_and_type->shape)); + } + } + return Status::OK(); +} + REGISTER_OP("LookupTableFindV2") .Input("table_handle: resource") .Input("keys: Tin") @@ -98,9 +167,18 @@ REGISTER_OP("LookupTableFindV2") TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); // Default value must be scalar or vector. - ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); - c->set_output(0, c->UnknownShape()); + ShapeHandle keys; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &keys)); + + ShapeAndType value_shape_and_type; + TF_RETURN_IF_ERROR(ValidateTableResourceHandle( + c, + /*keys=*/c->input(1), + /*key_dtype_attr=*/"Tin", + /*value_dtype_attr=*/"Tout", + /*is_lookup=*/true, &value_shape_and_type)); + c->set_output(0, value_shape_and_type.shape); + return Status::OK(); }); WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LookupTableFindV2"); @@ -177,12 +255,16 @@ REGISTER_OP("LookupTableExportV2") .SetShapeFn([](InferenceContext* c) { ShapeHandle handle; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - ShapeHandle values = c->UnknownShape(); - TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); - ShapeHandle keys = c->Vector(c->Dim(values, 0)); + ShapeHandle keys = c->UnknownShapeOfRank(1); + ShapeAndType value_shape_and_type; + TF_RETURN_IF_ERROR(ValidateTableResourceHandle( + c, + /*keys=*/keys, + /*key_dtype_attr=*/"Tkeys", + /*value_dtype_attr=*/"Tvalues", + /*is_lookup=*/false, &value_shape_and_type)); c->set_output(0, keys); - c->set_output(1, values); + c->set_output(1, value_shape_and_type.shape); return Status::OK(); }); @@ -216,6 +298,26 @@ REGISTER_OP("LookupTableImportV2") return Status::OK(); }); +Status MutableHashTableShape(InferenceContext* c, const ShapeHandle& key, + const ShapeHandle& value) { + c->set_output(0, c->Scalar()); + + ShapeHandle key_s; + TF_RETURN_IF_ERROR(c->WithRankAtMost(key, 1, &key_s)); + + DataType key_t; + TF_RETURN_IF_ERROR(c->GetAttr("key_dtype", &key_t)); + + DataType value_t; + TF_RETURN_IF_ERROR(c->GetAttr("value_dtype", &value_t)); + + // ShapeAndType vector for {key, value}. + c->set_output_handle_shapes_and_types( + 0, std::vector<ShapeAndType>{{key_s, key_t}, {value, value_t}}); + + return Status::OK(); +} + REGISTER_OP("HashTable") .Output("table_handle: Ref(string)") .Attr("container: string = ''") @@ -254,7 +356,10 @@ REGISTER_OP("MutableHashTableV2") .Attr("key_dtype: type") .Attr("value_dtype: type") .SetIsStateful() - .SetShapeFn(ScalarOutput); + .SetShapeFn([](InferenceContext* c) { + return MutableHashTableShape(c, /*key=*/c->Scalar(), + /*value=*/c->Scalar()); + }); REGISTER_OP("MutableHashTableOfTensors") .Output("table_handle: Ref(string)") @@ -276,7 +381,13 @@ REGISTER_OP("MutableHashTableOfTensorsV2") .Attr("value_dtype: type") .Attr("value_shape: shape = {}") .SetIsStateful() - .SetShapeFn(ScalarOutput); + .SetShapeFn([](InferenceContext* c) { + PartialTensorShape value_p; + TF_RETURN_IF_ERROR(c->GetAttr("value_shape", &value_p)); + ShapeHandle value_s; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(value_p, &value_s)); + return MutableHashTableShape(c, /*key=*/c->Scalar(), /*value=*/value_s); + }); REGISTER_OP("MutableDenseHashTable") .Input("empty_key: key_dtype") @@ -304,7 +415,13 @@ REGISTER_OP("MutableDenseHashTableV2") .Attr("initial_num_buckets: int = 131072") // 2^17 .Attr("max_load_factor: float = 0.8") .SetIsStateful() - .SetShapeFn(ScalarOutput); + .SetShapeFn([](InferenceContext* c) { + PartialTensorShape value_p; + TF_RETURN_IF_ERROR(c->GetAttr("value_shape", &value_p)); + ShapeHandle value_s; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(value_p, &value_s)); + return MutableHashTableShape(c, /*key=*/c->input(0), /*value=*/value_s); + }); REGISTER_OP("InitializeTable") .Input("table_handle: Ref(string)") |