aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/initializable_lookup_table.h
blob: a14d4967a59f53668a4a4c7135e79ed046666edb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
#define TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_

#include "tensorflow/core/framework/lookup_interface.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {
namespace lookup {

// Base class for lookup tables that require initialization.
class InitializableLookupTable : public LookupInterface {
 public:
  class InitTableIterator;

  // Performs batch lookups: for every element in the key tensor, Find writes
  // the corresponding value into the values tensor.
  // If an element is not present in the table, the given default value is used.
  //
  // For tables that require initialization, `Find` is available once the table
  // is marked as initialized.
  //
  // Returns the following statuses:
  // - OK: when the find finishes successfully.
  // - FailedPrecondition: if the table is not initialized.
  // - InvalidArgument: if any of the preconditions on the lookup key or value
  //   fails.
  // - In addition, other implementations may provide another non-OK status
  //   specific to their failure modes.
  Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values,
              const Tensor& default_value) final;

  // Always returns errors::Unimplemented: initializable tables cannot be
  // mutated through the generic LookupInterface insert path.
  Status Insert(OpKernelContext* ctx, const Tensor& keys,
                const Tensor& values) final {
    return errors::Unimplemented(
        "Insert not supported by InitializableLookupTable implementations");
  }

  // Always returns errors::Unimplemented.
  Status Remove(OpKernelContext* ctx, const Tensor& keys) final {
    return errors::Unimplemented(
        "Remove not supported by InitializableLookupTable implementations");
  }

  // Returns errors::Unimplemented by default. Declared `override` rather than
  // `final`, so a subclass may provide a real implementation.
  Status ExportValues(OpKernelContext* context) override {
    return errors::Unimplemented(
        "ExportValues not supported by InitializableLookupTable "
        "implementations");
  }

  // Imports the given keys and values into the table (defined out of line).
  Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
                      const Tensor& values) final;

  // Both report the default-constructed TensorShape().
  TensorShape key_shape() const final { return TensorShape(); }

  TensorShape value_shape() const final { return TensorShape(); }

  // Returns whether the table was initialized and is ready to serve lookups.
  bool is_initialized() const { return is_initialized_; }

  // Initializes the table from the given init table iterator.
  //
  // Atomically, this operation prepares the table, populates it with the given
  // iterator, and marks the table as initialized.
  //
  // Returns the following statuses:
  // - OK: when the initialization was successful.
  // - InvalidArgument: if any of the preconditions on the lookup key or value
  //   fails.
  // - FailedPrecondition: may be returned if the table is already
  //   initialized. NOTE(review): this method takes no flag controlling
  //   re-initialization; confirm the exact behavior in the implementation.
  // - In addition, other implementations may provide another non-OK status
  //   specific to their failure modes.
  Status Initialize(InitTableIterator& iter);

  // Basic iterator to initialize lookup tables.
  // It yields a sequence of pairs of `keys()` and `values()` Tensors, so that
  // the consumer may insert key-value pairs in batches.
  //
  // When the iterator is exhausted, Valid() returns false and status() returns
  // an OutOfRange error.
  //
  // This class is not thread-safe.
  class InitTableIterator {
   public:
    InitTableIterator() = default;

    virtual ~InitTableIterator() = default;

    // Prepares the next batch of key and value tensors.
    virtual void Next() = 0;

    // Returns true if keys() and values() point to valid tensors.
    virtual bool Valid() const = 0;

    // Returns a tensor that contains the current batch of 'key' values.
    // Only meaningful while Valid() is true.
    virtual const Tensor& keys() const = 0;

    // Returns a tensor that contains the current batch of 'value' values.
    // Only meaningful while Valid() is true.
    virtual const Tensor& values() const = 0;

    // Returns an error if one has occurred, otherwise returns Status::OK.
    virtual Status status() const = 0;

    // Returns the total number of elements that the iterator will produce.
    // It might return -1 in case of error.
    virtual int64 total_size() const = 0;

   private:
    TF_DISALLOW_COPY_AND_ASSIGN(InitTableIterator);
  };

  InitializableLookupTable* GetInitializableLookupTable() override {
    return this;
  }

 protected:
  // Prepares and allocates the underlying data structure to store the given
  // number of expected elements.
  virtual Status DoPrepare(size_t expected_num_elements) = 0;

  // Same as DoPrepare() but derived implementations might choose to skip
  // calling get_expected_num_elements if size is not needed for DoPrepare.
  //
  // The default implementation always evaluates get_expected_num_elements,
  // rejects a negative count with FailedPrecondition (so a negative int64 is
  // never converted to size_t), and otherwise forwards to DoPrepare().
  virtual Status DoLazyPrepare(
      std::function<int64(void)> get_expected_num_elements) {
    int64 expected_num_elements = get_expected_num_elements();
    if (expected_num_elements < 0) {
      return errors::FailedPrecondition("Got negative expected_num_elements.");
    }
    return DoPrepare(expected_num_elements);
  }

  // Populates the table in batches given keys and values as tensors into the
  // underlying data structure.
  virtual Status DoInsert(const Tensor& keys, const Tensor& values) = 0;

  // Performs the batch find operation on the underlying data structure.
  virtual Status DoFind(const Tensor& keys, Tensor* values,
                        const Tensor& default_value) = 0;

  // NOTE(review): not used by any code in this header; presumably guards
  // initialization in the implementation. is_initialized() reads
  // is_initialized_ without taking it -- confirm the synchronization scheme.
  mutex mu_;
  // Value reported by is_initialized(); starts out false.
  bool is_initialized_ = false;
};

// Iterator to initialize tables given 'keys' and 'values' tensors.
//
// The two tensors are returned in the first iteration. It doesn't loop
// over each element of the tensor since insertions in the lookup table can
// process batches.
class KeyValueTensorIterator
    : public InitializableLookupTable::InitTableIterator {
 public:
  // keys and values are not owned by the iterator; they must outlive it.
  //
  // On construction the iterator is positioned on its single batch (both
  // tensors) unless validation fails, in which case Valid() returns false and
  // status() reports the FIRST problem found:
  // - InvalidArgument if keys or values is null,
  // - InvalidArgument if keys and values do not have the same shape,
  // - InvalidArgument if the tensors have zero elements.
  explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values)
      : keys_(keys), values_(values), valid_(true), status_(Status::OK()) {
    // Guard before dereferencing; total_size() already treats a null keys_
    // as an error case.
    if (keys_ == nullptr || values_ == nullptr) {
      valid_ = false;
      status_ = errors::InvalidArgument("keys and values cannot be null.");
      return;
    }
    const TensorShape& key_shape = keys_->shape();
    if (!key_shape.IsSameSize(values_->shape())) {
      valid_ = false;
      // Return here so the emptiness check below cannot overwrite this more
      // fundamental error.
      status_ = errors::InvalidArgument(
          "keys and values should have the same dimension. ",
          key_shape.DebugString(), " vs ", values_->shape().DebugString());
      return;
    }
    if (key_shape.num_elements() == 0) {
      valid_ = false;
      status_ =
          errors::InvalidArgument("keys and values cannot be empty tensors.");
    }
  }

  bool Valid() const override { return valid_; }

  // There is only one batch, so advancing always exhausts the iterator.
  void Next() override {
    valid_ = false;
    status_ = errors::OutOfRange("No more data.");
  }

  // Only call while Valid() is true (the pointers may be null otherwise).
  const Tensor& keys() const override { return *keys_; }

  // Only call while Valid() is true (the pointers may be null otherwise).
  const Tensor& values() const override { return *values_; }

  Status status() const override { return status_; }

  // Number of key elements, or -1 if no keys tensor was supplied.
  int64 total_size() const override {
    return keys_ == nullptr ? -1 : keys_->NumElements();
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(KeyValueTensorIterator);

  const Tensor* keys_;    // Doesn't own it.
  const Tensor* values_;  // Doesn't own it.
  bool valid_;            // true if the iterator points to an existing range.
  Status status_;
};

}  // namespace lookup
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_