diff options
Diffstat (limited to 'tensorflow/core/framework/lookup_interface.cc')
-rw-r--r-- | tensorflow/core/framework/lookup_interface.cc | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc new file mode 100644 index 0000000000..c660b84aa0 --- /dev/null +++ b/tensorflow/core/framework/lookup_interface.cc @@ -0,0 +1,45 @@ +#include "tensorflow/core/framework/lookup_interface.h" + +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace lookup { + +Status LookupInterface::CheckKeyAndValueTensors(const Tensor& key, + const Tensor& value) { + if (key.dtype() != key_dtype()) { + return errors::InvalidArgument("Key must be type ", key_dtype(), + " but got ", key.dtype()); + } + if (value.dtype() != value_dtype()) { + return errors::InvalidArgument("Value must be type ", value_dtype(), + " but got ", value.dtype()); + } + if (key.NumElements() != value.NumElements()) { + return errors::InvalidArgument("Number of elements of key(", + key.NumElements(), ") and value(", + value.NumElements(), ") are different."); + } + if (!key.shape().IsSameSize(value.shape())) { + return errors::InvalidArgument("key and value have different shapes."); + } + return Status::OK(); +} + +Status LookupInterface::CheckFindArguments(const Tensor& key, + const Tensor& value, + const Tensor& default_value) { + TF_RETURN_IF_ERROR(CheckKeyAndValueTensors(key, value)); + + if (default_value.dtype() != value_dtype()) { + return errors::InvalidArgument("Default value must be type ", value_dtype(), + " but got ", default_value.dtype()); + } + if (!TensorShapeUtils::IsScalar(default_value.shape())) { + return errors::InvalidArgument("Default values must be scalar."); + } + return Status::OK(); +} + +} // namespace lookup +} // namespace tensorflow |