diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2018-08-13 13:03:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-13 13:07:41 -0700 |
commit | 35535e6313c7c35a851466efd67be0ec1df14c9e (patch) | |
tree | 199000b261fdc9b3331650580e7883e768a36c07 | |
parent | 959f075558b33674c201367aef4bfc9c2dc116c4 (diff) |
[tf.contrib.lookup] Clean up shape inference for lookup ops.
More of the shape inference can be done in C++-land, which may help grappler
do its thing. Also fix a bug where keys.dim_size(0) was being requested even
when keys.dims() == 0
[this should probably lead to DCHECK failure, but doesn't seem to].
PiperOrigin-RevId: 208529368
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops.py | 33 | ||||
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops_test.py | 11 | ||||
-rw-r--r-- | tensorflow/core/kernels/lookup_table_op.cc | 11 | ||||
-rw-r--r-- | tensorflow/core/ops/lookup_ops.cc | 139 |
4 files changed, 149 insertions(+), 45 deletions(-)
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 4942d94176..8c0bfefb30 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -20,7 +20,6 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_lookup_ops from tensorflow.python.ops import lookup_ops # pylint: disable=unused-import @@ -395,17 +394,12 @@ class MutableHashTable(LookupInterface): Raises: TypeError: when `keys` do not match the table data types. """ - if keys.dtype.base_dtype != self._key_dtype: - raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % - (self._key_dtype, keys.dtype)) - with ops.name_scope(name, "%s_lookup_table_find" % self._name, (self._table_ref, keys, self._default_value)) as name: + keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") with ops.colocate_with(self._table_ref): values = gen_lookup_ops.lookup_table_find_v2( self._table_ref, keys, self._default_value, name=name) - - values.set_shape(keys.get_shape().concatenate(self._value_shape)) return values def insert(self, keys, values, name=None): @@ -451,9 +445,6 @@ class MutableHashTable(LookupInterface): with ops.colocate_with(self._table_ref): exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( self._table_ref, self._key_dtype, self._value_dtype, name=name) - - exported_values.set_shape(exported_keys.get_shape().concatenate( - self._value_shape)) return exported_keys, exported_values class _Saveable(BaseSaverBuilder.SaveableObject): @@ -537,14 +528,15 @@ class MutableDenseHashTable(LookupInterface): ValueError: If checkpoint is True and no name was specified. 
""" self._default_value = ops.convert_to_tensor( - default_value, dtype=value_dtype) + default_value, dtype=value_dtype, name="default_value") self._value_shape = self._default_value.get_shape() # The table must be shared if checkpointing is requested for multi-worker # training to work correctly. Use the node name if no shared_name has been # explicitly specified. use_node_name_sharing = checkpoint and shared_name is None - empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype) + empty_key = ops.convert_to_tensor( + empty_key, dtype=key_dtype, name="empty_key") self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2( empty_key=empty_key, shared_name=shared_name, @@ -591,20 +583,13 @@ class MutableDenseHashTable(LookupInterface): Raises: TypeError: when `keys` do not match the table data types. """ - if keys.dtype.base_dtype != self._key_dtype: - raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % - (self._key_dtype, keys.dtype)) - with ops.name_scope(name, "%s_lookup_table_find" % self._name, [self._table_ref, keys]) as name: + keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") with ops.colocate_with(self._table_ref): values = gen_lookup_ops.lookup_table_find_v2( self._table_ref, keys, self._default_value, name=name) - if keys.get_shape().ndims is not None and keys.get_shape().ndims > 0: - values.set_shape( - tensor_shape.TensorShape([keys.get_shape().dims[0]]).concatenate( - self._value_shape)) return values def insert(self, keys, values, name=None): @@ -624,11 +609,11 @@ class MutableDenseHashTable(LookupInterface): TypeError: when `keys` or `values` doesn't match the table data types. 
""" - # pylint: disable=protected-access - lookup_ops._check_table_dtypes(self, keys.dtype, values.dtype) - # pylint: enable=protected-access with ops.name_scope(name, "%s_lookup_table_insert" % self._name, [self._table_ref, keys, values]) as name: + keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") + values = ops.convert_to_tensor( + values, dtype=self._value_dtype, name="values") with ops.colocate_with(self._table_ref): op = gen_lookup_ops.lookup_table_insert_v2( self._table_ref, keys, values, name=name) @@ -650,8 +635,6 @@ class MutableDenseHashTable(LookupInterface): exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( self._table_ref, self._key_dtype, self._value_dtype, name=name) - exported_values.set_shape(exported_keys.get_shape().concatenate( - self._value_shape)) return exported_keys, exported_values class _Saveable(BaseSaverBuilder.SaveableObject): diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 8d510ede58..6fb5244fc6 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -434,8 +434,10 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result) exported_keys, exported_values = table.export() - self.assertAllEqual([None], exported_keys.get_shape().as_list()) - self.assertAllEqual([None, 2], exported_values.get_shape().as_list()) + self.assertAllEqual([None], exported_keys.get_shape().as_list(), + msg="Saw shape %s" % exported_keys.shape) + self.assertAllEqual([None, 2], exported_values.get_shape().as_list(), + msg="Saw shape %s" % exported_values.shape) # exported data is in the order of the internal map, i.e. 
sorted_keys = np.sort(exported_keys.eval()) sorted_values = np.sort(exported_values.eval()) @@ -669,7 +671,7 @@ class MutableHashTableOpTest(test.TestCase): # lookup with keys of the wrong type input_string = constant_op.constant([1, 2, 3], dtypes.int64) - with self.assertRaises(TypeError): + with self.assertRaises(ValueError): table.lookup(input_string).eval() # default value of the wrong type @@ -853,7 +855,8 @@ class MutableDenseHashTableOpTest(test.TestCase): input_string = constant_op.constant([11, 12, 15], dtypes.int64) output = table.lookup(input_string) - self.assertAllEqual([3, 4], output.get_shape()) + self.assertAllEqual( + [3, 4], output.shape, msg="Saw shape: %s" % output.shape) result = output.eval() self.assertAllEqual([[0, 1, 2, 3], [3, 4, 5, 6], [-1, -2, -3, -4]], diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index 07e754a6ef..cbe8560267 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -341,7 +341,7 @@ class MutableDenseHashTable final : public LookupInterface { Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, const Tensor& default_value) override LOCKS_EXCLUDED(mu_) { - const int64 num_elements = key.dim_size(0); + const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0); const int64 key_size = key_shape_.num_elements(); const int64 value_size = value_shape_.num_elements(); if (key.NumElements() != num_elements * key_size) { @@ -403,8 +403,9 @@ class MutableDenseHashTable final : public LookupInterface { Status Insert(OpKernelContext* ctx, const Tensor& key, const Tensor& value) override LOCKS_EXCLUDED(mu_) { - if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) { - TensorShape expected_shape({key.dim_size(0)}); + const int64 batch_size = (key.dims() == 0) ? 
1 : key.dim_size(0); + if (key.NumElements() != batch_size * key_shape_.num_elements()) { + TensorShape expected_shape({batch_size}); expected_shape.AppendShape(key_shape_); return errors::InvalidArgument("Expected key shape ", expected_shape.DebugString(), " got ", @@ -415,7 +416,7 @@ class MutableDenseHashTable final : public LookupInterface { // rather than updates. That means we may grow the table even though we // don't need to. As long as the number of keys inserted in one call is // small compared to the size of the map, the impact of this is minimal. - const int64 pending_num_entries = num_entries_ + key.dim_size(0); + const int64 pending_num_entries = num_entries_ + batch_size; if (pending_num_entries > num_buckets_ * max_load_factor_) { int64 new_num_buckets = num_buckets_; do { @@ -500,7 +501,7 @@ class MutableDenseHashTable final : public LookupInterface { private: Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value, bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - const int64 num_elements = key.dim_size(0); + const int64 num_elements = (key.dims() == 0) ? 
1 : key.dim_size(0); const int64 value_size = value_shape_.num_elements(); const int64 key_size = key_shape_.num_elements(); const auto key_matrix = key.shaped<K, 2>({num_elements, key_size}); diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc index 2059741da9..7c71406c6b 100644 --- a/tensorflow/core/ops/lookup_ops.cc +++ b/tensorflow/core/ops/lookup_ops.cc @@ -23,6 +23,7 @@ namespace tensorflow { using shape_inference::DimensionHandle; using shape_inference::InferenceContext; +using shape_inference::ShapeAndType; using shape_inference::ShapeHandle; // -------------------------------------------------------------------------- @@ -86,6 +87,74 @@ REGISTER_OP("LookupTableFind") return Status::OK(); }); +Status ValidateTableResourceHandle(InferenceContext* c, ShapeHandle keys, + const string& key_dtype_attr, + const string& value_dtype_attr, + bool is_lookup, + ShapeAndType* output_shape_and_type) { + auto* handle_data = c->input_handle_shapes_and_types(0); + if (handle_data == nullptr || handle_data->size() != 2) { + output_shape_and_type->shape = c->UnknownShape(); + output_shape_and_type->dtype = DT_INVALID; + } else { + const ShapeAndType& key_shape_and_type = (*handle_data)[0]; + const ShapeAndType& value_shape_and_type = (*handle_data)[1]; + DataType key_dtype; + TF_RETURN_IF_ERROR(c->GetAttr(key_dtype_attr, &key_dtype)); + if (key_shape_and_type.dtype != key_dtype) { + return errors::InvalidArgument( + "Trying to read value with wrong dtype. " + "Expected ", + DataTypeString(key_shape_and_type.dtype), " got ", + DataTypeString(key_dtype)); + } + DataType value_dtype; + TF_RETURN_IF_ERROR(c->GetAttr(value_dtype_attr, &value_dtype)); + if (value_shape_and_type.dtype != value_dtype) { + return errors::InvalidArgument( + "Trying to read value with wrong dtype. 
" + "Expected ", + DataTypeString(value_shape_and_type.dtype), " got ", + DataTypeString(value_dtype)); + } + output_shape_and_type->dtype = value_shape_and_type.dtype; + + if (is_lookup) { + if (c->RankKnown(key_shape_and_type.shape) && c->RankKnown(keys)) { + int keys_rank = c->Rank(keys); + int key_suffix_rank = c->Rank(key_shape_and_type.shape); + if (keys_rank < key_suffix_rank) { + return errors::InvalidArgument( + "Expected keys to have suffix ", + c->DebugString(key_shape_and_type.shape), + " but saw shape: ", c->DebugString(keys)); + } + for (int d = 0; d < key_suffix_rank; d++) { + // Ensure the suffix of keys match what's in the Table. + DimensionHandle dim = c->Dim(key_shape_and_type.shape, d); + TF_RETURN_IF_ERROR( + c->ReplaceDim(keys, keys_rank - key_suffix_rank + d, dim, &keys)); + } + std::vector<DimensionHandle> keys_prefix_vec; + keys_prefix_vec.reserve(keys_rank - key_suffix_rank); + for (int d = 0; d < keys_rank - key_suffix_rank; ++d) { + keys_prefix_vec.push_back(c->Dim(keys, d)); + } + ShapeHandle keys_prefix = c->MakeShape(keys_prefix_vec); + TF_RETURN_IF_ERROR(c->Concatenate(keys_prefix, + value_shape_and_type.shape, + &output_shape_and_type->shape)); + } else { + output_shape_and_type->shape = c->UnknownShape(); + } + } else { + TF_RETURN_IF_ERROR(c->Concatenate(keys, value_shape_and_type.shape, + &output_shape_and_type->shape)); + } + } + return Status::OK(); +} + REGISTER_OP("LookupTableFindV2") .Input("table_handle: resource") .Input("keys: Tin") @@ -98,9 +167,18 @@ REGISTER_OP("LookupTableFindV2") TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); // Default value must be scalar or vector. 
- ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); - c->set_output(0, c->UnknownShape()); + ShapeHandle keys; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &keys)); + + ShapeAndType value_shape_and_type; + TF_RETURN_IF_ERROR(ValidateTableResourceHandle( + c, + /*keys=*/c->input(1), + /*key_dtype_attr=*/"Tin", + /*value_dtype_attr=*/"Tout", + /*is_lookup=*/true, &value_shape_and_type)); + c->set_output(0, value_shape_and_type.shape); + return Status::OK(); }); WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LookupTableFindV2"); @@ -177,12 +255,16 @@ REGISTER_OP("LookupTableExportV2") .SetShapeFn([](InferenceContext* c) { ShapeHandle handle; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - ShapeHandle values = c->UnknownShape(); - TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); - ShapeHandle keys = c->Vector(c->Dim(values, 0)); + ShapeHandle keys = c->UnknownShapeOfRank(1); + ShapeAndType value_shape_and_type; + TF_RETURN_IF_ERROR(ValidateTableResourceHandle( + c, + /*keys=*/keys, + /*key_dtype_attr=*/"Tkeys", + /*value_dtype_attr=*/"Tvalues", + /*is_lookup=*/false, &value_shape_and_type)); c->set_output(0, keys); - c->set_output(1, values); + c->set_output(1, value_shape_and_type.shape); return Status::OK(); }); @@ -216,6 +298,26 @@ REGISTER_OP("LookupTableImportV2") return Status::OK(); }); +Status MutableHashTableShape(InferenceContext* c, const ShapeHandle& key, + const ShapeHandle& value) { + c->set_output(0, c->Scalar()); + + ShapeHandle key_s; + TF_RETURN_IF_ERROR(c->WithRankAtMost(key, 1, &key_s)); + + DataType key_t; + TF_RETURN_IF_ERROR(c->GetAttr("key_dtype", &key_t)); + + DataType value_t; + TF_RETURN_IF_ERROR(c->GetAttr("value_dtype", &value_t)); + + // ShapeAndType vector for {key, value}. 
+ c->set_output_handle_shapes_and_types( + 0, std::vector<ShapeAndType>{{key_s, key_t}, {value, value_t}}); + + return Status::OK(); +} + REGISTER_OP("HashTable") .Output("table_handle: Ref(string)") .Attr("container: string = ''") @@ -254,7 +356,10 @@ REGISTER_OP("MutableHashTableV2") .Attr("key_dtype: type") .Attr("value_dtype: type") .SetIsStateful() - .SetShapeFn(ScalarOutput); + .SetShapeFn([](InferenceContext* c) { + return MutableHashTableShape(c, /*key=*/c->Scalar(), + /*value=*/c->Scalar()); + }); REGISTER_OP("MutableHashTableOfTensors") .Output("table_handle: Ref(string)") @@ -276,7 +381,13 @@ REGISTER_OP("MutableHashTableOfTensorsV2") .Attr("value_dtype: type") .Attr("value_shape: shape = {}") .SetIsStateful() - .SetShapeFn(ScalarOutput); + .SetShapeFn([](InferenceContext* c) { + PartialTensorShape value_p; + TF_RETURN_IF_ERROR(c->GetAttr("value_shape", &value_p)); + ShapeHandle value_s; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(value_p, &value_s)); + return MutableHashTableShape(c, /*key=*/c->Scalar(), /*value=*/value_s); + }); REGISTER_OP("MutableDenseHashTable") .Input("empty_key: key_dtype") @@ -304,7 +415,13 @@ REGISTER_OP("MutableDenseHashTableV2") .Attr("initial_num_buckets: int = 131072") // 2^17 .Attr("max_load_factor: float = 0.8") .SetIsStateful() - .SetShapeFn(ScalarOutput); + .SetShapeFn([](InferenceContext* c) { + PartialTensorShape value_p; + TF_RETURN_IF_ERROR(c->GetAttr("value_shape", &value_p)); + ShapeHandle value_s; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(value_p, &value_s)); + return MutableHashTableShape(c, /*key=*/c->input(0), /*value=*/value_s); + }); REGISTER_OP("InitializeTable") .Input("table_handle: Ref(string)") |