aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/lookup_interface.h
blob: d4036d201988f90e0da77343416766a5ed3b7c54 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#ifndef TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
#define TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_

#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/public/tensor.h"

namespace tensorflow {
namespace lookup {

// Abstract table interface that the table lookup ops use to perform batched
// key -> value lookups.
class LookupInterface : public ResourceBase {
 public:
  // Batch lookup: for every element of `keys`, Find writes the matching value
  // into the `values` tensor. When an element is not present in the table,
  // `default_value` is used in its place.
  //
  // For tables that require initialization, Find only becomes available after
  // the table has been marked as initialized.
  //
  // Status returned:
  //   - OK: the lookup finished successfully.
  //   - FailedPrecondition: the table has not been initialized.
  //   - InvalidArgument: a precondition on the lookup keys or values failed.
  //   - Implementations may also return other non-OK statuses specific to
  //     their own failure modes.
  virtual Status Find(const Tensor& keys,
                      Tensor* values,
                      const Tensor& default_value) = 0;

  // Number of elements in the table.
  virtual size_t size() const = 0;

  // DataType of the table's keys.
  virtual DataType key_dtype() const = 0;

  // DataType of the table's values.
  virtual DataType value_dtype() const = 0;

  // Human-readable description of this resource. It is the same fixed string
  // for every table unless a subclass overrides it.
  string DebugString() override {
    return "A lookup table";
  }

 protected:
  // Protected: code outside the class hierarchy cannot destroy a
  // LookupInterface directly.
  virtual ~LookupInterface() = default;

  // Validates the arguments of a find operation. Returns OK when every check
  // below passes, and InvalidArgument otherwise:
  //   - every requirement enforced by CheckKeyAndValueTensors holds;
  //   - default_value has the table's value_dtype;
  //   - default_value is a scalar.
  Status CheckFindArguments(const Tensor& keys, const Tensor& values,
                            const Tensor& default_value);

  // Validates a pair of key/value tensors. Returns OK when every check below
  // passes, and InvalidArgument otherwise:
  //   - the DataType of `keys` equals the table's key_dtype;
  //   - the DataType of `values` equals the table's value_dtype;
  //   - `keys` and `values` have the same size and shape.
  Status CheckKeyAndValueTensors(const Tensor& keys, const Tensor& values);
};

}  // namespace lookup
}  // namespace tensorflow

#endif  // TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_