author     Peter Hawkins <phawkins@google.com>              2017-05-11 08:30:16 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-05-11 11:13:19 -0700
commit     851982c0e3c543306f3942f92bc4bdbc0d50d1f7 (patch)
tree       5121012e0cd4ef5cc2541c2dadc47c0ff729b6e4
parent     1e390f5f7992f3ae5f9ecfc91ebe3d711efe0b7d (diff)
[TF:XLA] Add function compilation cache.
PiperOrigin-RevId: 155751485
-rw-r--r--  tensorflow/compiler/aot/compile.cc                         |  2
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_device_launch_op.cc    |  4
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_local_launch_op.cc     |  2
-rw-r--r--  tensorflow/compiler/jit/xla_compilation_cache.cc           |  2
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.cc                 | 34
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.h                  | 13
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler_test.cc            |  8
-rw-r--r--  tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc       |  2
-rw-r--r--  tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc  |  2
9 files changed, 53 insertions(+), 16 deletions(-)
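
The change below memoizes XlaCompiler::CompileFunction on the pair of
(canonicalized function name, argument list): a lookup table keyed by that
signature returns a previously compiled result, and the computation is held
through a std::shared_ptr so a cached result can be copied out without
cloning the computation. What follows is a minimal, self-contained sketch of
the same pattern, not the TensorFlow code itself: Compiler, Computation, and
CompilationResult are hypothetical stand-ins, and Argument is simplified to
the fields the diff's operator== compares (the real struct also compares a
constant_value Tensor by shape and bytes).

    #include <cstdint>
    #include <memory>
    #include <string>
    #include <tuple>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    // Simplified stand-in for XlaCompiler::Argument.
    struct Argument {
      int kind = 0;
      int type = 0;
      std::vector<int64_t> shape;
      std::string name;

      bool operator==(const Argument& other) const {
        return std::tie(kind, type, shape, name) ==
               std::tie(other.kind, other.type, other.shape, other.name);
      }
    };

    struct Computation {};  // Hypothetical stand-in for xla::Computation.

    struct CompilationResult {
      // Held by shared_ptr, as in the diff, so copying a cached result out
      // of the cache shares rather than clones the built computation.
      std::shared_ptr<Computation> computation;
    };

    using Signature = std::pair<std::string, std::vector<Argument>>;

    // Hash only the function name; signatures that share a name but differ
    // in arguments land in the same bucket and are told apart by the map's
    // key equality, which compares the full argument vectors.
    struct SignatureHash {
      size_t operator()(const Signature& signature) const {
        return std::hash<std::string>()(signature.first);
      }
    };

    class Compiler {
     public:
      // Returns the memoized result for (function_id, args) if present;
      // otherwise compiles the function, caches the result, and returns it.
      CompilationResult CompileFunction(const std::string& function_id,
                                        const std::vector<Argument>& args) {
        auto it = cache_.find({function_id, args});
        if (it != cache_.end()) return it->second;  // Cache hit.
        CompilationResult result;
        result.computation = std::make_shared<Computation>();
        // ... instantiate the function and build its computation here ...
        cache_[{function_id, args}] = result;
        return result;
      }

     private:
      std::unordered_map<Signature, CompilationResult, SignatureHash> cache_;
    };

Hashing only the name keeps SignatureHash trivial; correctness rests on the
map's key equality comparing the full argument vectors, which is why the
diff adds Argument::operator==. The shared_ptr member is also why the call
sites in the diff change from result.computation to *result.computation and
from computation.IsNull() to computation->IsNull().
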
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index 4c572fd390..162e719ade 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -301,7 +301,7 @@ Status ConvertGraphToXla(xla::CompileOnlyClient* client,
"tfcompile", std::move(graph),
xla_args, &result));
*has_context_arg = result.requires_runtime_context;
- *computation = std::move(result.computation);
+ *computation = std::move(*result.computation);
int num_const_results = 0;
for (int i = 0; i < result.outputs.size(); ++i) {
diff --git a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc
index 497c813d99..cb6c14901b 100644
--- a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc
@@ -121,7 +121,7 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) {
// Runs the computation, if any. There might not be a computation if all
// outputs were compile-time constants.
std::vector<std::unique_ptr<xla::GlobalData>> outputs;
- if (!kernel->computation.IsNull()) {
+ if (!kernel->computation->IsNull()) {
auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
// Builds the inputs to the computation.
@@ -152,7 +152,7 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) {
Env* env = Env::Default();
auto start_time = env->NowMicros();
VLOG(1) << "Executing XLA Computation...";
- auto result = cache->client()->Execute(kernel->computation, arg_ptrs,
+ auto result = cache->client()->Execute(*kernel->computation, arg_ptrs,
&execution_options, &profile);
auto elapsed = env->NowMicros() - start_time;
OP_REQUIRES(ctx, result.ok(), result.status());
diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
index 0a2af4050e..032ded54e6 100644
--- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
@@ -225,7 +225,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
std::unique_ptr<xla::ShapedBuffer> output;
bool output_is_tuple;
- if (!kernel->computation.IsNull()) {
+ if (!kernel->computation->IsNull()) {
// Build xla::ShapedBuffers that point directly to the Tensor buffers.
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
arg_buffers.reserve(kernel->xla_input_shapes.size() + 1);
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 8238f7d919..82af304169 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -273,7 +273,7 @@ Status XlaCompilationCache::Compile(
*compilation_result = &entry->compilation_result;
if (entry->compilation_status.ok() && executable) {
if (entry->executable == nullptr &&
- !entry->compilation_result.computation.IsNull()) {
+ !entry->compilation_result.computation->IsNull()) {
XlaCompiler compiler(options);
entry->compilation_status = compiler.BuildExecutable(
entry->compilation_result, &entry->executable);
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 5ed4393bc5..d246e7f9ac 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -57,6 +57,18 @@ Status CheckSignature(const DataTypeVector& types,
} // namespace
+bool XlaCompiler::Argument::operator==(
+ const XlaCompiler::Argument& other) const {
+ if (std::tie(kind, type, shape, name) !=
+ std::tie(other.kind, other.type, other.shape, other.name)) {
+ return false;
+ }
+ if (constant_value.shape() != other.constant_value.shape()) {
+ return false;
+ }
+ return constant_value.tensor_data() == other.constant_value.tensor_data();
+}
+
XlaCompiler::XlaCompiler(XlaCompiler::Options options)
: options_(std::move(options)),
initialization_status_(Status::OK()),
@@ -85,6 +97,11 @@ int64 XlaCompiler::NextStepId() {
return next_step_id_++;
}
+uint64 XlaCompiler::SignatureHash::operator()(
+ const std::pair<string, std::vector<Argument>>& signature) const {
+ return std::hash<string>()(signature.first);
+}
+
Status XlaCompiler::CompileFunction(
const XlaCompiler::CompileOptions& options, const NameAttrList& function,
const std::vector<XlaCompiler::Argument>& args,
@@ -92,6 +109,12 @@ Status XlaCompiler::CompileFunction(
const string function_id = Canonicalize(function.name(), function.attr());
VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
+ auto it = cache_.find({function_id, args});
+ if (it != cache_.end()) {
+ *result = it->second;
+ return Status::OK();
+ }
+
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(
flib_runtime_->Instantiate(function.name(), function.attr(), &handle));
@@ -129,6 +152,7 @@ Status XlaCompiler::CompileFunction(
CompileGraph(options, function_id, std::move(graph), args, result));
VLOG(1) << "====================================================";
+ cache_[{function_id, args}] = *result;
return Status::OK();
}
@@ -155,7 +179,7 @@ Status XlaCompiler::BuildExecutable(
build_options.set_has_hybrid_result(
options_.local_executable_has_hybrid_result);
- auto compile_result = local_client->Compile(result.computation,
+ auto compile_result = local_client->Compile(*result.computation,
argument_layouts, build_options);
if (!compile_result.ok()) {
return compile_result.status();
@@ -403,10 +427,12 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
flib_runtime_.get(), NextStepId()));
int num_nonconst_outputs;
+ result->computation = std::make_shared<xla::Computation>();
TF_RETURN_IF_ERROR(BuildComputation(
context->retvals(), context->variables(), context->has_side_effects(),
options.return_updated_values_for_all_variables, &builder,
- &result->computation, &num_nonconst_outputs, &result->variable_updates));
+ result->computation.get(), &num_nonconst_outputs,
+ &result->variable_updates));
result->requires_runtime_context = context->has_context_parameter();
@@ -427,13 +453,13 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
}
}
- if (result->computation.IsNull()) {
+ if (result->computation->IsNull()) {
return Status::OK();
}
// Compute the output shapes, if there is a computation with non-constant
// outputs.
- auto computation_shape = client()->GetComputationShape(result->computation);
+ auto computation_shape = client()->GetComputationShape(*result->computation);
if (!computation_shape.ok()) {
return computation_shape.status();
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 1a089155ce..15f723ad78 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -113,6 +113,8 @@ class XlaCompiler {
// The name of this argument, used for debugging.
string name;
+
+ bool operator==(const Argument& other) const;
};
struct OutputDescription {
@@ -173,7 +175,7 @@ class XlaCompiler {
// The XLA computation built from the tensorflow subgraph. May be null
// if the output consists solely of compile-time constants.
- xla::Computation computation;
+ std::shared_ptr<xla::Computation> computation;
};
struct Options {
@@ -290,6 +292,15 @@ class XlaCompiler {
std::unique_ptr<FunctionLibraryRuntime> flib_runtime_;
+ struct SignatureHash {
+ uint64 operator()(
+ const std::pair<string, std::vector<Argument>>& signature) const;
+ };
+
+ std::unordered_map<std::pair<string, std::vector<Argument>>,
+ CompilationResult, SignatureHash>
+ cache_;
+
std::unordered_map<string, xla::ChannelHandle> channels_ GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 11c811613e..58d74057d1 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -131,7 +131,7 @@ TEST_F(XlaCompilerTest, EmptyReturnValues) {
/*args=*/{}, &result));
// No computation should be generated.
- EXPECT_EQ(0, result.computation.handle().handle());
+ EXPECT_EQ(0, result.computation->handle().handle());
}
// Tests compilation and execution of a graph that adds two tensors.
@@ -173,7 +173,7 @@ TEST_F(XlaCompilerTest, Simple) {
std::unique_ptr<xla::GlobalData> actual =
client_
- ->Execute(result.computation, {param0_data.get(), param1_data.get()})
+ ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
@@ -230,7 +230,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
- client_->Execute(result.computation, {param0_data.get()})
+ client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
@@ -265,7 +265,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
- client_->Execute(result.computation, {param0_data.get()})
+ client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc
index 87f7cba5f4..302aa6457a 100644
--- a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc
+++ b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc
@@ -228,7 +228,7 @@ ConvertTfGraphToXlaSessionModule(const std::vector<XlaCompiler::Argument>& args,
TF_CHECK_OK(compiler->CompileGraph(XlaCompiler::CompileOptions(), GRAPH_NAME,
std::move(graph), args, &result));
- return result.computation.Snapshot();
+ return result.computation->Snapshot();
}
xla::StatusOr<std::unordered_map<int64, XlaNode>>
diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc
index 7162627562..beb4c8009b 100644
--- a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc
+++ b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc
@@ -79,7 +79,7 @@ static void DumpHloGraphForDebug(const std::vector<XlaCompiler::Argument>& args,
std::move(graph), args, &result));
// Convert to hlo
- xla::Computation& computation = result.computation;
+ xla::Computation& computation = *result.computation;
xla::Service* service(
static_cast<xla::Service*>(xla::ClientLibrary::GetXlaService(