1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
|
#include "tensorflow/core/framework/lookup_interface.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace lookup {
// Validates a key/value tensor pair against this table's declared types.
//
// The checks run in the order below; the first one that fails is reported as
// an InvalidArgument status:
//   1. `key` has dtype key_dtype().
//   2. `value` has dtype value_dtype().
//   3. `key` and `value` hold the same number of elements.
//   4. The shapes of `key` and `value` satisfy IsSameSize().
// Returns Status::OK() when every check passes.
Status LookupInterface::CheckKeyAndValueTensors(const Tensor& key,
                                                const Tensor& value) {
  // Key dtype must match the table's key type.
  const auto expected_key_type = key_dtype();
  const auto actual_key_type = key.dtype();
  if (actual_key_type != expected_key_type) {
    return errors::InvalidArgument("Key must be type ", expected_key_type,
                                   " but got ", actual_key_type);
  }

  // Value dtype must match the table's value type.
  const auto expected_value_type = value_dtype();
  const auto actual_value_type = value.dtype();
  if (actual_value_type != expected_value_type) {
    return errors::InvalidArgument("Value must be type ", expected_value_type,
                                   " but got ", actual_value_type);
  }

  // Element counts are compared before shapes so that a count mismatch gets
  // the more specific message containing both counts.
  const auto key_count = key.NumElements();
  const auto value_count = value.NumElements();
  if (key_count != value_count) {
    return errors::InvalidArgument("Number of elements of key(", key_count,
                                   ") and value(", value_count,
                                   ") are different.");
  }

  // Finally the shapes themselves must agree.
  const auto& key_shape = key.shape();
  const auto& value_shape = value.shape();
  if (key_shape.IsSameSize(value_shape)) {
    return Status::OK();
  }
  return errors::InvalidArgument("key and value have different shapes.");
}
// Validates the arguments of a find operation.
//
// First applies CheckKeyAndValueTensors(key, value) and returns its status
// unchanged if it is not OK. Then requires that `default_value`:
//   - has dtype value_dtype(), and
//   - has a scalar shape.
// Either violation is reported as an InvalidArgument status; otherwise
// Status::OK() is returned.
Status LookupInterface::CheckFindArguments(const Tensor& key,
                                           const Tensor& value,
                                           const Tensor& default_value) {
  // Key/value validation comes first; its error (if any) takes precedence.
  Status key_value_status = CheckKeyAndValueTensors(key, value);
  if (!key_value_status.ok()) {
    return key_value_status;
  }

  // The default value is stored into the value output, so it must share the
  // table's value dtype.
  const auto expected_type = value_dtype();
  const auto default_type = default_value.dtype();
  if (default_type != expected_type) {
    return errors::InvalidArgument("Default value must be type ", expected_type,
                                   " but got ", default_type);
  }

  // Only a scalar default value is accepted.
  if (TensorShapeUtils::IsScalar(default_value.shape())) {
    return Status::OK();
  }
  return errors::InvalidArgument("Default values must be scalar.");
}
} // namespace lookup
} // namespace tensorflow
|