// Source: tensorflow/compiler/jit/xla_compilation_cache.h
// (blob 75c7758f730f9f2f8251c02e7fac1a01f8cc9c2b)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {

// A Tensor that may or may not be present.
struct OptionalTensor {
  // A descriptive name.
  string name;

  // Whether the tensor is present. A default-constructed OptionalTensor is
  // absent.
  bool present{false};

  // The Tensor's value, if present.
  Tensor value;
};

// The XlaCompilationCache class caches the results of the XlaCompiler class,
// which converts a Tensorflow graph into an XLA computation.
//
// Since XLA computations must have static shapes, the cache generates a new
// XLA computation for each new set of input shapes.
//
// Currently no cache eviction policy is implemented and the cache grows without
// bound.
class XlaCompilationCache : public ResourceBase {
 public:
  // `client` and `device_type` are presumably stored in the const members
  // `client_` and `device_type_` (the constructor body is not visible here) and
  // are exposed again through client() and device_type().
  XlaCompilationCache(xla::LocalClient* client, DeviceType device_type);
  ~XlaCompilationCache() override;

  // Compiles a function into a XlaCompiler::CompilationResult that can be used
  // to execute an XLA Computation. Compilation results are cached.
  // `function` is the name of a Tensorflow function to compile.
  // `constant_args` is a map of tensorflow argument number to its constant
  // value.
  // `variable_args` is a snapshot of the current values of the
  // resource variable arguments to `function`; uninitialized variables are
  // represented by an absent OptionalTensor.
  // The result of compilation is written to `*out_compilation_result`;
  // `out_compilation_result` must be non-null. If `out_executable` is
  // non-null, also builds an xla::LocalExecutable and sets `*out_executable`
  // to point to it. The resulting executable pointer may be null if the
  // computation has no non-constant outputs.
  //
  // NOTE(review): `variable_args` is keyed by int like `constant_args`,
  // presumably also the argument number -- confirm against callers.
  // NOTE(review): the out pointers presumably refer to storage held by a cache
  // Entry (see `compilation_result` / `executable` in Entry below) rather than
  // to objects the caller owns -- confirm in the .cc.
  Status Compile(const XlaCompiler::Options& options,
                 const NameAttrList& function,
                 const std::map<int, Tensor>& constant_args,
                 const std::map<int, OptionalTensor>& variable_args,
                 OpKernelContext* ctx,
                 const XlaCompiler::CompileOptions& compile_options,
                 const XlaCompiler::CompilationResult** out_compilation_result,
                 xla::LocalExecutable** out_executable);

  // As above, but calls XlaCompiler::CompileSingleOp instead of
  // XlaCompiler::CompileFunction. Takes no `function` argument; the remaining
  // parameters are the same as for Compile().
  Status CompileSingleOp(
      const XlaCompiler::Options& options,
      const std::map<int, Tensor>& constant_args,
      const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
      const XlaCompiler::CompileOptions& compile_options,
      const XlaCompiler::CompilationResult** out_compilation_result,
      xla::LocalExecutable** out_executable);

  // Accessors for the client and device type this cache was constructed with.
  xla::LocalClient* client() const { return client_; }
  const DeviceType& device_type() const { return device_type_; }

  // Overrides ResourceBase::DebugString().
  string DebugString() override;

 private:
  // Common implementation of Compile and CompileSingleOp. Takes the union of
  // their parameters plus `compile_single_op`, which presumably selects
  // between the two XlaCompiler entry points -- confirm in the .cc.
  Status CompileImpl(
      const XlaCompiler::Options& options, const NameAttrList& function,
      const std::map<int, Tensor>& constant_args,
      const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
      const XlaCompiler::CompileOptions& compile_options,
      bool compile_single_op,
      const XlaCompiler::CompilationResult** out_compilation_result,
      xla::LocalExecutable** out_executable);

  // Takes `result` which has been compiled from a Tensorflow subgraph to a
  // XLA computation already, and generates an XLA LocalExecutable, which is
  // returned through `*executable`.
  Status BuildExecutable(const XlaCompiler::Options& options,
                         const XlaCompiler::CompilationResult& result,
                         std::unique_ptr<xla::LocalExecutable>* executable);

  // Both members are const, so they are fixed for the lifetime of the cache.
  // NOTE(review): `client_` is a raw pointer; presumably not owned by the
  // cache -- confirm against the destructor.
  xla::LocalClient* const client_;
  const DeviceType device_type_;

  // Describes the types, shapes and any compile-time constant arguments
  // to a kernel. Key that uniquely identifies a compilation output; used as
  // the key type of `cache_` below.
  struct Signature {
    string name;

    // One (type, shape) pair per argument.
    // NOTE(review): which arguments are included (e.g. whether constants are
    // excluded) is decided in BuildSignature -- confirm in the .cc.
    std::vector<std::pair<DataType, TensorShape>> arg_types;

    // List of Tensor values for compile-time constant arguments to the
    // compilation, ordered by argument number. Tensors must be in host memory.
    std::vector<Tensor> arg_values;

    bool operator==(const Signature& other) const;

    // Hash functor for Signature; supplied as the hasher of `cache_`.
    struct Hash {
      uint64 operator()(const Signature& signature) const;
    };
  };
  // Returns a string describing `sig`, presumably for logging or debugging.
  static string SignatureDebugString(const Signature& sig);

  // Builds the signature for a compilation from the same `function`,
  // `constant_args`, `variable_args` and `ctx` that are passed to Compile();
  // the result is returned through `*signature`.
  Status BuildSignature(const NameAttrList& function,
                        const std::map<int, Tensor>& constant_args,
                        const std::map<int, OptionalTensor>& variable_args,
                        OpKernelContext* ctx, Signature* signature);

  // The value associated with a cache entry.
  struct Entry {
    // Guards the annotated fields below.
    mutex mu;

    // Have we tried compiling this entry?
    // NOTE(review): unlike the other fields of Entry, this one is not annotated
    // GUARDED_BY(mu) -- confirm whether every access happens with `mu` held
    // and whether the annotation was omitted deliberately.
    bool compiled = false;

    // Did compilation succeed?
    Status compilation_status GUARDED_BY(mu);

    // Output of the XlaCompiler.
    XlaCompiler::CompilationResult compilation_result GUARDED_BY(mu);

    // The XLA executable compiled from <computation>. May be null if no
    // executable has been built.
    std::unique_ptr<xla::LocalExecutable> executable GUARDED_BY(mu);
  };

  // Guards `cache_`. Each Entry additionally carries its own mutex for its
  // contents.
  mutex compile_cache_mu_;
  // Maps a compilation Signature to its heap-allocated Entry.
  absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
      GUARDED_BY(compile_cache_mu_);

  // Per-cluster compilation statistics; see `compile_stats_` below.
  struct CompileStats {
    // Number of times the cluster has been (re-)compiled.
    int64 compile_count = 0;

    // Cumulative time spent compiling the cluster, in microseconds (per the
    // `_us` suffix).
    int64 cumulative_compile_time_us = 0;
  };
  // Guards `compile_stats_`.
  mutex compile_stats_mu_;

  // Maps cluster names to compilation statistics for said cluster.
  absl::flat_hash_map<string, CompileStats> compile_stats_
      GUARDED_BY(compile_stats_mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_