aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/lookup_interface.h
blob: d33945fd1b0c44264855ed518714eb35faf4b29f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_
#define TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_

#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {

class OpKernelContext;

namespace lookup {

// Forward declaration so we can define GetInitializableLookupTable() in
// LookupInterface.
class InitializableLookupTable;

// Lookup interface for batch lookups used by table lookup ops.
class LookupInterface : public ResourceBase {
 public:
  // Performs batch lookups, for every element in the key tensor, Find returns
  // the corresponding value into the values tensor.
  // If an element is not present in the table, the given default value is used.

  // For tables that require initialization, Find is available once the table
  // is marked as initialized.

  // Returns the following statuses:
  // - OK: when the find finishes successfully.
  // - FailedPrecondition: if the table is not initialized.
  // - InvalidArgument: if any of the preconditions on the lookup key or value
  //   fails.
  // - In addition, other implementations may provide another non-OK status
  //   specific to their failure modes.
  virtual Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values,
                      const Tensor& default_value) = 0;

  // Inserts elements into the table. Each element of the key tensor is
  // associated with the corresponding element in the value tensor.
  // This method is only implemented in mutable tables that can be updated over
  // the execution of the graph. It returns Status::NotImplemented for read-only
  // tables that are initialized once before they can be looked up.

  // Returns the following statuses:
  // - OK: when the insert finishes successfully.
  // - InvalidArgument: if any of the preconditions on the lookup key or value
  //   fails.
  // - Unimplemented: if the table does not support insertions.
  virtual Status Insert(OpKernelContext* ctx, const Tensor& keys,
                        const Tensor& values) = 0;

  // Removes elements from the table.
  // This method is only implemented in mutable tables that can be updated over
  // the execution of the graph. It returns Status::NotImplemented for read-only
  // tables that are initialized once before they can be looked up.

  // Returns the following statuses:
  // - OK: when the remove finishes successfully.
  // - InvalidArgument: if any of the preconditions on the lookup key fails.
  // - Unimplemented: if the table does not support removals.
  virtual Status Remove(OpKernelContext* ctx, const Tensor& keys) = 0;

  // Returns the number of elements in the table.
  virtual size_t size() const = 0;

  // Exports the values of the table to two tensors named keys and values.
  // Note that the shape of the tensors is completely up to the implementation
  // of the table and can be different than the tensors used for the Insert
  // function above.
  virtual Status ExportValues(OpKernelContext* ctx) = 0;

  // Imports previously exported keys and values.
  // As mentioned above, the shape of the keys and values tensors are determined
  // by the ExportValues function above and can be different than for the
  // Insert function.
  virtual Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
                              const Tensor& values) = 0;

  // Returns the data type of the key.
  virtual DataType key_dtype() const = 0;

  // Returns the data type of the value.
  virtual DataType value_dtype() const = 0;

  // Returns the shape of a key in the table.
  virtual TensorShape key_shape() const = 0;

  // Returns the shape of a value in the table.
  virtual TensorShape value_shape() const = 0;

  // Check format of the key and value tensors for the Insert function.
  // Returns OK if all the following requirements are satisfied, otherwise it
  // returns InvalidArgument:
  // - DataType of the tensor keys equals to the table key_dtype
  // - DataType of the tensor values equals to the table value_dtype
  // - the values tensor has the required shape given keys and the tables's
  //   value shape.
  virtual Status CheckKeyAndValueTensorsForInsert(const Tensor& keys,
                                                  const Tensor& values);

  // Similar to the function above but instead checks eligibility for the Import
  // function.
  virtual Status CheckKeyAndValueTensorsForImport(const Tensor& keys,
                                                  const Tensor& values);

  // Check format of the key tensor for the Remove function.
  // Returns OK if all the following requirements are satisfied, otherwise it
  // returns InvalidArgument:
  // - DataType of the tensor keys equals to the table key_dtype
  virtual Status CheckKeyTensorForRemove(const Tensor& keys);

  // Check the arguments of a find operation. Returns OK if all the following
  // requirements are satisfied, otherwise it returns InvalidArgument:
  // - DataType of the tensor keys equals to the table key_dtype
  // - DataType of the tensor default_value equals to the table value_dtype
  // - the default_value tensor shape matches the table's value shape.
  Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);

  string DebugString() override {
    return strings::StrCat("A lookup table of size: ", size());
  }

  // Returns an InitializableLookupTable, a subclass of LookupInterface, if the
  // current object is an InitializableLookupTable. Otherwise, returns nullptr.
  virtual InitializableLookupTable* GetInitializableLookupTable() {
    return nullptr;
  }

 protected:
  virtual ~LookupInterface() = default;

  // Makes sure that the key and value tensor DataType's match the table
  // key_dtype and value_dtype.
  Status CheckKeyAndValueTypes(const Tensor& keys, const Tensor& values);

  // Makes sure that the provided shape is consistent with the table keys shape.
  Status CheckKeyShape(const TensorShape& shape);

 private:
  Status CheckKeyAndValueTensorsHelper(const Tensor& keys,
                                       const Tensor& values);
};

}  // namespace lookup
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_