aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/lookup_interface.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-06-26 08:14:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-26 09:18:06 -0700
commitacd8c859f7c2d077464422cd033efeb5cce4b986 (patch)
treed6ccfcd3ed7d932b7ab9876ff0f327aed7b46a19 /tensorflow/core/framework/lookup_interface.cc
parent0868ce67f4174b2b857641f473a00da81a1f511a (diff)
Add a variant of mutable hash table that supports tensors as values.
Add support for exporting the contents of a table. Change: 125901929
Diffstat (limited to 'tensorflow/core/framework/lookup_interface.cc')
-rw-r--r--tensorflow/core/framework/lookup_interface.cc26
1 files changed, 14 insertions, 12 deletions
diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc
index aafa9e084a..0d20766673 100644
--- a/tensorflow/core/framework/lookup_interface.cc
+++ b/tensorflow/core/framework/lookup_interface.cc
@@ -30,28 +30,30 @@ Status LookupInterface::CheckKeyAndValueTensors(const Tensor& key,
return errors::InvalidArgument("Value must be type ", value_dtype(),
" but got ", value.dtype());
}
- if (key.NumElements() != value.NumElements()) {
- return errors::InvalidArgument("Number of elements of key(",
- key.NumElements(), ") and value(",
- value.NumElements(), ") are different.");
- }
- if (!key.shape().IsSameSize(value.shape())) {
- return errors::InvalidArgument("key and value have different shapes.");
+ TensorShape expected_value_shape = key.shape();
+ expected_value_shape.AppendShape(value_shape());
+ if (value.shape() != expected_value_shape) {
+ return errors::InvalidArgument(
+ "Expected shape ", expected_value_shape.DebugString(),
+ " for value, got ", value.shape().DebugString());
}
return Status::OK();
}
Status LookupInterface::CheckFindArguments(const Tensor& key,
- const Tensor& value,
const Tensor& default_value) {
- TF_RETURN_IF_ERROR(CheckKeyAndValueTensors(key, value));
-
+ if (key.dtype() != key_dtype()) {
+ return errors::InvalidArgument("Key must be type ", key_dtype(),
+ " but got ", key.dtype());
+ }
if (default_value.dtype() != value_dtype()) {
return errors::InvalidArgument("Default value must be type ", value_dtype(),
" but got ", default_value.dtype());
}
- if (!TensorShapeUtils::IsScalar(default_value.shape())) {
- return errors::InvalidArgument("Default values must be scalar.");
+ if (default_value.shape() != value_shape()) {
+ return errors::InvalidArgument(
+ "Expected shape ", value_shape().DebugString(),
+ " for default value, got ", default_value.shape().DebugString());
}
return Status::OK();
}