aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/lookup_interface.cc
blob: c660b84aa0f6198342f1f394d16d842702d33e88 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include "tensorflow/core/framework/lookup_interface.h"

#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace lookup {

// Validates a (key, value) tensor pair against this table.
//
// Checks, in order:
//   1. `key.dtype()` equals `key_dtype()`.
//   2. `value.dtype()` equals `value_dtype()`.
//   3. Both tensors have the same number of elements.
//   4. Both tensors have exactly the same shape.
//
// Returns Status::OK() when every check passes. Otherwise returns an
// InvalidArgument status describing the first failed check.
Status LookupInterface::CheckKeyAndValueTensors(const Tensor& key,
                                                const Tensor& value) {
  if (key.dtype() != key_dtype()) {
    return errors::InvalidArgument("Key must be type ", key_dtype(),
                                   " but got ", key.dtype());
  }
  if (value.dtype() != value_dtype()) {
    return errors::InvalidArgument("Value must be type ", value_dtype(),
                                   " but got ", value.dtype());
  }
  // Checked before the full shape comparison so that a differing element
  // count is reported with both counts in the message.
  if (key.NumElements() != value.NumElements()) {
    return errors::InvalidArgument("Number of elements of key(",
                                   key.NumElements(), ") and value(",
                                   value.NumElements(), ") are different.");
  }
  // Same element count but a different layout (e.g. [2,3] vs [3,2]). Both
  // shapes are included so the caller can see the mismatch. The original
  // message is kept as a prefix for anything matching on it.
  if (!key.shape().IsSameSize(value.shape())) {
    return errors::InvalidArgument(
        "key and value have different shapes. Key shape: ",
        key.shape().DebugString(), ", value shape: ",
        value.shape().DebugString());
  }
  return Status::OK();
}

// Validates the arguments of a Find (lookup) call.
//
// First applies every CheckKeyAndValueTensors() check to `key` / `value`; any
// failure there is returned unchanged. Then `default_value` must have dtype
// `value_dtype()` and must be a scalar tensor.
//
// Returns Status::OK() when all checks pass. Otherwise returns the first
// failing status.
Status LookupInterface::CheckFindArguments(const Tensor& key,
                                           const Tensor& value,
                                           const Tensor& default_value) {
  // Key/value validation runs first so its errors take precedence.
  const Status key_value_status = CheckKeyAndValueTensors(key, value);
  if (!key_value_status.ok()) return key_value_status;

  const auto expected_dtype = value_dtype();
  if (default_value.dtype() != expected_dtype) {
    return errors::InvalidArgument("Default value must be type ",
                                   expected_dtype, " but got ",
                                   default_value.dtype());
  }

  const bool default_is_scalar =
      TensorShapeUtils::IsScalar(default_value.shape());
  return default_is_scalar
             ? Status::OK()
             : errors::InvalidArgument("Default values must be scalar.");
}

}  // namespace lookup
}  // namespace tensorflow