aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/lookup_table_op.h
blob: a5bccbddd28d1f8a1ed079bc307d927818a7eba3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
/* Copyright 2015 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_
#define TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_

#include "tensorflow/core/framework/lookup_interface.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/lookup_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/public/tensor.h"
#include "tensorflow/core/public/tensor_shape.h"

namespace tensorflow {

// Lookup table op that supports different table implementations specified by
// the 'Container' template. Container must be derived from LookupInterface. The
// key and value are of the templated type "key_dtype" and "value_dtype"
// respectively.
template <class Container, class key_dtype, class value_dtype>
class LookupTableOp : public OpKernel {
 public:
  // ctx is not owned by this class.
  explicit LookupTableOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), table_handle_set_(false) {
    OP_REQUIRES_OK(ctx, ctx->allocate_persistent(tensorflow::DT_STRING,
                                                 tensorflow::TensorShape({2}),
                                                 &table_handle_, nullptr));
  }

  // ctx is not owned by this function.
  void Compute(OpKernelContext* ctx) override {
    mutex_lock l(mu_);
    if (!table_handle_set_) {
      OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def()));
      auto creator = [this](lookup::LookupInterface** ret) {
        *ret = new Container();
        return Status::OK();
      };

      lookup::LookupInterface* table = nullptr;
      OP_REQUIRES_OK(
          ctx, cinfo_.resource_manager()
                   ->template LookupOrCreate<lookup::LookupInterface>(
                       cinfo_.container(), cinfo_.name(), &table, creator));
      core::ScopedUnref unref_me(table);

      OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes(
                              *table, DataTypeToEnum<key_dtype>::v(),
                              DataTypeToEnum<value_dtype>::v(), cinfo_.name()));

      auto h = table_handle_.AccessTensor(ctx)->template flat<string>();
      h(0) = cinfo_.container();
      h(1) = cinfo_.name();
      table_handle_set_ = true;
    }
    ctx->set_output_ref(0, &mu_, table_handle_.AccessTensor(ctx));
  }

  ~LookupTableOp() override {
    // If the table object was not shared, delete it.
    if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
      TF_CHECK_OK(
          cinfo_.resource_manager()->template Delete<lookup::LookupInterface>(
              cinfo_.container(), cinfo_.name()));
    }
  }

 private:
  mutex mu_;
  PersistentTensor table_handle_ GUARDED_BY(mu_);
  bool table_handle_set_ GUARDED_BY(mu_);
  ContainerInfo cinfo_;

  TF_DISALLOW_COPY_AND_ASSIGN(LookupTableOp);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_