diff options
author | 2016-11-16 10:12:45 -0800 | |
---|---|---|
committer | 2016-11-16 10:26:41 -0800 | |
commit | 634790044c81521c438799b558d33b8440fa9e23 (patch) | |
tree | a0319de4c600738e4e917a3f9f2a194b08214635 /tensorflow/core/framework/lookup_interface.cc | |
parent | eabf41b7cc8ab515b556bc91b4f282d1d671c1a7 (diff) |
Fixing a bug in the MutableDenseHashTable implementation where the difference in shapes between the Insert and Import functions was causing issues with a vector key and scalar value input.
Fixed by splitting the LookupInterface CheckKeysAndValueTensors method into one for Insert and the other for Import.
Change: 139346138
Diffstat (limited to 'tensorflow/core/framework/lookup_interface.cc')
-rw-r--r-- | tensorflow/core/framework/lookup_interface.cc | 37 |
1 files changed, 22 insertions, 15 deletions
diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc index 3322e3e8df..bf3204ea6e 100644 --- a/tensorflow/core/framework/lookup_interface.cc +++ b/tensorflow/core/framework/lookup_interface.cc @@ -21,17 +21,14 @@ limitations under the License. namespace tensorflow { namespace lookup { -namespace { -Status CheckKeyShape(const TensorShape& table_key_shape, - const TensorShape& key_shape) { - if (!TensorShapeUtils::EndsWith(key_shape, table_key_shape)) { - return errors::InvalidArgument("Input key shape ", key_shape.DebugString(), +Status LookupInterface::CheckKeyShape(const TensorShape& shape) { + if (!TensorShapeUtils::EndsWith(shape, key_shape())) { + return errors::InvalidArgument("Input key shape ", shape.DebugString(), " must end with the table's key shape ", - table_key_shape.DebugString()); + key_shape().DebugString()); } return Status::OK(); } -} // namespace Status LookupInterface::CheckKeyAndValueTypes(const Tensor& keys, const Tensor& values) { @@ -46,28 +43,38 @@ Status LookupInterface::CheckKeyAndValueTypes(const Tensor& keys, return Status::OK(); } -Status LookupInterface::CheckKeyAndValueTensors(const Tensor& key, - const Tensor& value) { - TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, value)); - TF_RETURN_IF_ERROR(CheckKeyShape(key_shape(), key.shape())); +Status LookupInterface::CheckKeyAndValueTensorsHelper(const Tensor& keys, + const Tensor& values) { + TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(keys, values)); + TF_RETURN_IF_ERROR(CheckKeyShape(keys.shape())); - TensorShape expected_value_shape = key.shape(); + TensorShape expected_value_shape = keys.shape(); for (int i = 0; i < key_shape().dims(); ++i) { expected_value_shape.RemoveDim(expected_value_shape.dims() - 1); } expected_value_shape.AppendShape(value_shape()); - if (value.shape() != expected_value_shape) { + if (values.shape() != expected_value_shape) { return errors::InvalidArgument( "Expected shape ", expected_value_shape.DebugString(), - " for value, got ", value.shape().DebugString()); + " for value, got ", values.shape().DebugString()); } return Status::OK(); } +Status LookupInterface::CheckKeyAndValueTensorsForInsert(const Tensor& keys, + const Tensor& values) { + return CheckKeyAndValueTensorsHelper(keys, values); +} + +Status LookupInterface::CheckKeyAndValueTensorsForImport(const Tensor& keys, + const Tensor& values) { + return CheckKeyAndValueTensorsHelper(keys, values); +} + Status LookupInterface::CheckFindArguments(const Tensor& key, const Tensor& default_value) { TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value)); - TF_RETURN_IF_ERROR(CheckKeyShape(key_shape(), key.shape())); + TF_RETURN_IF_ERROR(CheckKeyShape(key.shape())); if (default_value.shape() != value_shape()) { return errors::InvalidArgument( "Expected shape ", value_shape().DebugString(), |