diff options
Diffstat (limited to 'tensorflow/compiler/jit/xla_compilation_cache.h')
-rw-r--r-- | tensorflow/compiler/jit/xla_compilation_cache.h | 112 |
1 files changed, 112 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h new file mode 100644 index 0000000000..44d76db0fd --- /dev/null +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +// The XlaCompilationCache class caches the results of the XlaCompiler class, +// which converts a Tensorflow graph into a compiled XLA compilation. +// +// Since XLA computations must have static shapes, the cache generates a new +// XLA computation for each new set of input shapes. +// +// Currently no cache eviction policy is implemented and the cache grows without +// bound. +class XlaCompilationCache : public ResourceBase { + public: + explicit XlaCompilationCache(const XlaCompiler::Options& options); + ~XlaCompilationCache() override; + + // Compiles a function into a XlaCompiler::CompilationResult that can be used + // to execute an XLA Computation. `compilation_result` must be non-null. + // If `executable` is non-null, also builds an xla::LocalExecutable and sets + // `executable to point to it. The resulting executable pointer may be null if + // the computation has no non-constant outputs. + // Compilation results are cached. + Status Compile(const NameAttrList& function, int num_constant_args, + OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable); + + xla::Client* client() const { return compiler_.client(); } + + string DebugString() override; + + private: + XlaCompiler compiler_; + std::unique_ptr<FunctionLibraryRuntime> function_library_runtime_; + + // Describes the types, shapes and any compile-time constant arguments + // to a kernel. + struct Signature { + string name; + + std::vector<std::pair<DataType, TensorShape>> arg_types; + + // List of (argument #, value) pairs for arguments whose values are + // part of the JIT signature, and that are therefore constants in any given + // JIT compilation. Tensors must be in host memory. + std::vector<std::pair<int, Tensor>> arg_values; + + bool operator==(const Signature& other) const; + + struct Hash { + uint64 operator()(const Signature& signature) const; + }; + }; + static string SignatureDebugString(const Signature& sig); + + // The value associated with a cache entry. + struct Entry { + mutex mu; + + // Have we tried compiling this entry? + bool compiled = false; + + // Did compilation succeed? + Status compilation_status GUARDED_BY(mu); + + // Output of the XlaCompiler. + XlaCompiler::CompilationResult compilation_result GUARDED_BY(mu); + + // The XLA executable compiled from <computation>. May be null if no + // executable has been built. + std::unique_ptr<xla::LocalExecutable> executable GUARDED_BY(mu); + }; + + mutex mu_; + std::unordered_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_ + GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ |