blob: d4036d201988f90e0da77343416766a5ed3b7c54 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
|
#ifndef TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
#define TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/public/tensor.h"
namespace tensorflow {
namespace lookup {
// Lookup interface for batch lookups used by table lookup ops.
//
// Declares the batch Find() entry point, the size and key/value dtype
// accessors that concrete tables must implement, and protected helpers that
// validate the tensors involved in a lookup.
class LookupInterface : public ResourceBase {
 public:
  // Performs batch lookups: for every element in the `keys` tensor, Find
  // writes the corresponding value into the `values` tensor.
  // If an element is not present in the table, `default_value` is used for
  // that element instead.
  // For tables that require initialization, Find is only available once the
  // table is marked as initialized.
  //
  // Returns the following statuses:
  // - OK: when the find finishes successfully.
  // - FailedPrecondition: if the table is not initialized.
  // - InvalidArgument: if any of the preconditions on the lookup key or value
  //   fails.
  // - In addition, implementations may return other non-OK statuses specific
  //   to their own failure modes.
  virtual Status Find(const Tensor& keys, Tensor* values,
                      const Tensor& default_value) = 0;

  // Returns the number of elements in the table.
  virtual size_t size() const = 0;

  // Returns the data type of the table keys.
  virtual DataType key_dtype() const = 0;

  // Returns the data type of the table values.
  virtual DataType value_dtype() const = 0;

  // Generic description used by the resource machinery; implementations may
  // override it with something more specific.
  string DebugString() override { return "A lookup table"; }

 protected:
  // Protected: code outside this class hierarchy cannot destroy a table
  // through a LookupInterface pointer.
  // NOTE(review): lifetime is presumably managed via ResourceBase — confirm.
  virtual ~LookupInterface() = default;

  // Checks the format of the key and value tensors.
  // Returns OK if all of the following requirements are satisfied; otherwise
  // returns InvalidArgument:
  // - the DataType of the key tensor equals the table key_dtype
  // - the DataType of the value tensor equals the table value_dtype
  // - keys and values have the same size and shape
  Status CheckKeyAndValueTensors(const Tensor& keys, const Tensor& values);

  // Checks the arguments of a find operation. Returns OK if all of the
  // following requirements are satisfied; otherwise returns InvalidArgument:
  // - all requirements of CheckKeyAndValueTensors
  // - the type of default_value equals the table value_dtype
  // - default_value is a scalar
  Status CheckFindArguments(const Tensor& keys, const Tensor& values,
                            const Tensor& default_value);
};
} // namespace lookup
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
|