aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/lookup_interface.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/lookup_interface.cc')
-rw-r--r--tensorflow/core/framework/lookup_interface.cc45
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc
new file mode 100644
index 0000000000..c660b84aa0
--- /dev/null
+++ b/tensorflow/core/framework/lookup_interface.cc
@@ -0,0 +1,45 @@
+#include "tensorflow/core/framework/lookup_interface.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace lookup {
+
+Status LookupInterface::CheckKeyAndValueTensors(const Tensor& key,
+ const Tensor& value) {
+ if (key.dtype() != key_dtype()) {
+ return errors::InvalidArgument("Key must be type ", key_dtype(),
+ " but got ", key.dtype());
+ }
+ if (value.dtype() != value_dtype()) {
+ return errors::InvalidArgument("Value must be type ", value_dtype(),
+ " but got ", value.dtype());
+ }
+ if (key.NumElements() != value.NumElements()) {
+ return errors::InvalidArgument("Number of elements of key(",
+ key.NumElements(), ") and value(",
+ value.NumElements(), ") are different.");
+ }
+ if (!key.shape().IsSameSize(value.shape())) {
+ return errors::InvalidArgument("key and value have different shapes.");
+ }
+ return Status::OK();
+}
+
+Status LookupInterface::CheckFindArguments(const Tensor& key,
+ const Tensor& value,
+ const Tensor& default_value) {
+ TF_RETURN_IF_ERROR(CheckKeyAndValueTensors(key, value));
+
+ if (default_value.dtype() != value_dtype()) {
+ return errors::InvalidArgument("Default value must be type ", value_dtype(),
+ " but got ", default_value.dtype());
+ }
+ if (!TensorShapeUtils::IsScalar(default_value.shape())) {
+ return errors::InvalidArgument("Default values must be scalar.");
+ }
+ return Status::OK();
+}
+
+} // namespace lookup
+} // namespace tensorflow