aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/lookup_interface.cc
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2016-11-16 10:12:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-16 10:26:41 -0800
commit634790044c81521c438799b558d33b8440fa9e23 (patch)
treea0319de4c600738e4e917a3f9f2a194b08214635 /tensorflow/core/framework/lookup_interface.cc
parenteabf41b7cc8ab515b556bc91b4f282d1d671c1a7 (diff)
Fixing a bug in the MutableDenseHashTable implementation where the difference in shapes between the Insert and Import functions was causing issues with a vector key and scalar value input.
Fixed by splitting the LookupInterface CheckKeysAndValueTensors method into one for Insert and the other for Import. Change: 139346138
Diffstat (limited to 'tensorflow/core/framework/lookup_interface.cc')
-rw-r--r--tensorflow/core/framework/lookup_interface.cc37
1 files changed, 22 insertions, 15 deletions
diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc
index 3322e3e8df..bf3204ea6e 100644
--- a/tensorflow/core/framework/lookup_interface.cc
+++ b/tensorflow/core/framework/lookup_interface.cc
@@ -21,17 +21,14 @@ limitations under the License.
namespace tensorflow {
namespace lookup {
-namespace {
-Status CheckKeyShape(const TensorShape& table_key_shape,
- const TensorShape& key_shape) {
- if (!TensorShapeUtils::EndsWith(key_shape, table_key_shape)) {
- return errors::InvalidArgument("Input key shape ", key_shape.DebugString(),
+Status LookupInterface::CheckKeyShape(const TensorShape& shape) {
+ if (!TensorShapeUtils::EndsWith(shape, key_shape())) {
+ return errors::InvalidArgument("Input key shape ", shape.DebugString(),
" must end with the table's key shape ",
- table_key_shape.DebugString());
+ key_shape().DebugString());
}
return Status::OK();
}
-} // namespace
Status LookupInterface::CheckKeyAndValueTypes(const Tensor& keys,
const Tensor& values) {
@@ -46,28 +43,38 @@ Status LookupInterface::CheckKeyAndValueTypes(const Tensor& keys,
return Status::OK();
}
-Status LookupInterface::CheckKeyAndValueTensors(const Tensor& key,
- const Tensor& value) {
- TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, value));
- TF_RETURN_IF_ERROR(CheckKeyShape(key_shape(), key.shape()));
+Status LookupInterface::CheckKeyAndValueTensorsHelper(const Tensor& keys,
+ const Tensor& values) {
+ TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(keys, values));
+ TF_RETURN_IF_ERROR(CheckKeyShape(keys.shape()));
- TensorShape expected_value_shape = key.shape();
+ TensorShape expected_value_shape = keys.shape();
for (int i = 0; i < key_shape().dims(); ++i) {
expected_value_shape.RemoveDim(expected_value_shape.dims() - 1);
}
expected_value_shape.AppendShape(value_shape());
- if (value.shape() != expected_value_shape) {
+ if (values.shape() != expected_value_shape) {
return errors::InvalidArgument(
"Expected shape ", expected_value_shape.DebugString(),
- " for value, got ", value.shape().DebugString());
+ " for value, got ", values.shape().DebugString());
}
return Status::OK();
}
+Status LookupInterface::CheckKeyAndValueTensorsForInsert(const Tensor& keys,
+ const Tensor& values) {
+ return CheckKeyAndValueTensorsHelper(keys, values);
+}
+
+Status LookupInterface::CheckKeyAndValueTensorsForImport(const Tensor& keys,
+ const Tensor& values) {
+ return CheckKeyAndValueTensorsHelper(keys, values);
+}
+
Status LookupInterface::CheckFindArguments(const Tensor& key,
const Tensor& default_value) {
TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value));
- TF_RETURN_IF_ERROR(CheckKeyShape(key_shape(), key.shape()));
+ TF_RETURN_IF_ERROR(CheckKeyShape(key.shape()));
if (default_value.shape() != value_shape()) {
return errors::InvalidArgument(
"Expected shape ", value_shape().DebugString(),