diff options
author | 2016-06-26 08:14:31 -0800 | |
---|---|---|
committer | 2016-06-26 09:18:06 -0700 | |
commit | acd8c859f7c2d077464422cd033efeb5cce4b986 (patch) | |
tree | d6ccfcd3ed7d932b7ab9876ff0f327aed7b46a19 /tensorflow/core/framework/lookup_interface.cc | |
parent | 0868ce67f4174b2b857641f473a00da81a1f511a (diff) |
Add a variant of mutable hash table that supports tensors as values.
Add support for exporting the contents of a table.
Change: 125901929
Diffstat (limited to 'tensorflow/core/framework/lookup_interface.cc')
-rw-r--r-- | tensorflow/core/framework/lookup_interface.cc | 26 |
1 files changed, 14 insertions, 12 deletions
diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc index aafa9e084a..0d20766673 100644 --- a/tensorflow/core/framework/lookup_interface.cc +++ b/tensorflow/core/framework/lookup_interface.cc @@ -30,28 +30,30 @@ Status LookupInterface::CheckKeyAndValueTensors(const Tensor& key, return errors::InvalidArgument("Value must be type ", value_dtype(), " but got ", value.dtype()); } - if (key.NumElements() != value.NumElements()) { - return errors::InvalidArgument("Number of elements of key(", - key.NumElements(), ") and value(", - value.NumElements(), ") are different."); - } - if (!key.shape().IsSameSize(value.shape())) { - return errors::InvalidArgument("key and value have different shapes."); + TensorShape expected_value_shape = key.shape(); + expected_value_shape.AppendShape(value_shape()); + if (value.shape() != expected_value_shape) { + return errors::InvalidArgument( + "Expected shape ", expected_value_shape.DebugString(), + " for value, got ", value.shape().DebugString()); } return Status::OK(); } Status LookupInterface::CheckFindArguments(const Tensor& key, - const Tensor& value, const Tensor& default_value) { - TF_RETURN_IF_ERROR(CheckKeyAndValueTensors(key, value)); - + if (key.dtype() != key_dtype()) { + return errors::InvalidArgument("Key must be type ", key_dtype(), + " but got ", key.dtype()); + } if (default_value.dtype() != value_dtype()) { return errors::InvalidArgument("Default value must be type ", value_dtype(), " but got ", default_value.dtype()); } - if (!TensorShapeUtils::IsScalar(default_value.shape())) { - return errors::InvalidArgument("Default values must be scalar."); + if (default_value.shape() != value_shape()) { + return errors::InvalidArgument( + "Expected shape ", value_shape().DebugString(), + " for default value, got ", default_value.shape().DebugString()); } return Status::OK(); } |