aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-05-08 16:43:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-08 17:09:23 -0700
commit14d5f219f33b1ab8e0a67b84d97204d046adb91f (patch)
treeb887f04458ef204e522d2b3d81d15128104b397c /tensorflow/compiler/jit
parent79b773a4395caf7f0b17ce9ac84a1f34dd277bb9 (diff)
Make eager functions runnable on TPU
PiperOrigin-RevId: 195897321
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--tensorflow/compiler/jit/BUILD24
-rw-r--r--tensorflow/compiler/jit/create_xla_launch_op.cc207
-rw-r--r--tensorflow/compiler/jit/create_xla_launch_op.h35
-rw-r--r--tensorflow/compiler/jit/create_xla_launch_op_test.cc145
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.cc90
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.h51
-rw-r--r--tensorflow/compiler/jit/xla_compile_on_demand_op.cc3
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.cc18
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.h15
9 files changed, 490 insertions, 98 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 07136d6a74..a6b3ce394c 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -261,6 +261,7 @@ cc_library(
name = "create_xla_launch_op",
srcs = [
"create_xla_launch_op.cc",
+ "create_xla_launch_op.h",
],
deps = [
":common",
@@ -270,6 +271,29 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/memory",
+ ],
+ alwayslink = 1,
+)
+
+tf_cc_test(
+ name = "create_xla_launch_op_test",
+ srcs = [
+ "create_xla_launch_op.h",
+ "create_xla_launch_op_test.cc",
+ ],
+ deps = [
+ ":create_xla_launch_op",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_options",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index 18d901323f..f35e916eb9 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/jit/create_xla_launch_op.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -25,78 +27,189 @@ limitations under the License.
namespace tensorflow {
namespace {
-// Givens a NodeDef 'ndef' and the function library runtime 'flr', if
-// 'ndef' is a call to a compilable function defined in 'flr', returns OK
-// and fills in 'kernel' with a XlaLaunchOp kernel which computes the
-// node. Otherwise, returns a non-OK.
+// Utility which searches for values in a sorted list by scanning over it once.
+// No matter how many times ScanForValue is called, the list is scanned at most
+// once. However, if a call to ScanForValue skips over a value, that value is
+// not revisited in future calls to ScanForValue, so callers must take
+// care to order their calls.
//
-// This routine is here so that FunctionLibraryRuntime can jit a
-// specific function call as requested.
-Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
- std::unique_ptr<OpKernel>* kernel) {
- bool xla_compile = false;
- if (!flr->GetFunctionLibraryDefinition()
- ->GetAttr(ndef, kXlaCompileAttr, &xla_compile)
- .ok() ||
- !xla_compile) {
- // Not marked as _XlaCompile=true.
- return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op());
+// Useful for merging multiple sorted lists in O(n) time.
+class SinglePassSearch {
+ public:
+ // Creates a SinglePassSearch object that can be used to search in `values`.
+ // Does not take ownership of `values`. `values` must outlive this.
+ // `values` must be sorted.
+ explicit SinglePassSearch(const std::vector<int>* values)
+ : current_index_(0), values_(values) {}
+
+ // Scans forward in the vector looking for "value", updating the internal
+ // position into the vector.
+ // Returns true iff the vector contains the given value at or after current
+ // position.
+ // Not thread-safe.
+ bool ScanForValue(int value) {
+ while (current_index_ < values_->size() &&
+ (*values_)[current_index_] <= value) {
+ if ((*values_)[current_index_] == value) {
+ current_index_++;
+ return true;
+ }
+ current_index_++;
+ }
+ return false;
}
- // Make sure that kernels have been registered on the JIT device.
- XlaOpRegistry::RegisterCompilationKernels();
- if (!IsCompilable(flr, ndef)) {
- // ndef is calling a function that XLA can't compile.
- return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString());
+
+ private:
+ int current_index_;
+ const std::vector<int>* values_;
+};
+
+Status CompilationRequested(const FunctionLibraryRuntime& flr,
+ const NodeDef& node_def) {
+ bool xla_compile = false;
+ // Check if op is marked _XlaCompile=true.
+ Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
+ node_def, kXlaCompileAttr, &xla_compile);
+ if (!status.ok() || !xla_compile) {
+ if (VLOG_IS_ON(3)) {
+ if (!status.ok()) {
+ VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
+ << node_def.op() << ". status=" << status.ToString();
+ } else {
+ VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
+ }
+ }
+ return Status(error::INVALID_ARGUMENT, "");
}
+ return Status::OK();
+}
+
+// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
+// runtime, returns this function's body in `fbody` as well as the indices
+// of its constant and resource arguments.
+// `fbody` is owned by `flr`.
+// `constant_arg_indices` and `resource_arg_indices` should be empty vectors.
+// They are sorted in ascending order on this function's return.
+Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
+ const NodeDef& node_def,
+ const FunctionBody** fbody,
+ std::vector<int>* constant_arg_indices,
+ std::vector<int>* resource_arg_indices) {
FunctionLibraryRuntime::Handle handle;
- // If ndef is not instantiable, e.g., the function does not exist,
+ // If node_def is not instantiable, e.g., the function does not exist,
// simply bail out.
TF_RETURN_IF_ERROR(
- flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
- const FunctionBody* fbody = flr->GetFunctionBody(handle);
- CHECK(fbody); // Can't be nullptr since we just instantiated it.
- std::vector<bool> const_args(fbody->arg_types.size());
+ flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
+ *fbody = flr->GetFunctionBody(handle);
+ CHECK(*fbody); // Can't be nullptr since we just instantiated it.
+ const DataTypeVector& arg_types = (*fbody)->arg_types;
+ std::vector<bool> const_args(arg_types.size());
// If we can't analyze the const args. Bail out.
- TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args));
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args));
for (int i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
- // There is a const arg. Bail out.
- return errors::InvalidArgument("Const arg: ", i, " in ",
- DebugString(fbody->fdef));
+ constant_arg_indices->push_back(i);
+ }
+ }
+
+ // There can be hundreds of resource variables. Reserve the space for them.
+ // We don't reserve for constants above as they are usually few.
+ resource_arg_indices->reserve(arg_types.size());
+ for (int i = 0; i < arg_types.size(); ++i) {
+ if (arg_types[i] == DT_RESOURCE) {
+ resource_arg_indices->push_back(i);
}
}
- NodeDef launch_def;
- launch_def.set_name(ndef.name());
- launch_def.set_op("_XlaLaunch");
- launch_def.set_device(flr->device()->name());
- AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def);
- AddNodeAttr("Nresources", 0, &launch_def);
- AddNodeAttr("Targs", fbody->arg_types, &launch_def);
- AddNodeAttr("Tresults", fbody->ret_types, &launch_def);
- NameAttrList func;
- func.set_name(ndef.op());
- *(func.mutable_attr()) = ndef.attr();
- AddNodeAttr("function", func, &launch_def);
-
- // TODO(b/32387911): Handles the host memory types across function
- // calls properly. For now, we assume all inputs and outputs are on
- // the device memory.
+ return Status::OK();
+}
+
+} // namespace
+
+Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
+ std::unique_ptr<OpKernel>* kernel) {
+ TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def));
+
+ VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString();
+
+ // Make sure that kernels have been registered on the JIT device.
+ XlaOpRegistry::RegisterCompilationKernels();
+ if (!IsCompilable(flr, node_def)) {
+ // node_def is calling a function that XLA can't compile.
+ return errors::InvalidArgument("Not compilable: ",
+ node_def.ShortDebugString());
+ }
+
+ // Get function body, constant args, and resource args.
+ const FunctionBody* fbody = nullptr;
+ std::vector<int> constant_arg_indices;
+ std::vector<int> resource_arg_indices;
+ TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
+ flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
+
+ // Set input and output memory types.
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
+ // These indices are used only for optimization purposes. They allow us
+ // to loop over constant_arg_indices and resource_arg_indices only once
+ // while iterating over all the function arguments checking if it is a
+ // resource or a constant.
+ // The reason we optimized this code is because functions can have a lot of
+ // captured arguments. For example, the backward pass of ResNet50 takes in all
+ // 214 variables and a similar number of activations.
+ SinglePassSearch constants_search(&constant_arg_indices);
+ SinglePassSearch resources_search(&resource_arg_indices);
+ for (int i = 0; i < fbody->arg_types.size(); ++i) {
+ if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
+ // Compile-time constants and resource handles are expected to be in
+ // host memory.
+ input_memory_types[i] = HOST_MEMORY;
+ }
+ }
+ // One might wonder, about the case where a compile-time constant argument
+ // (which must be in host memory) is also used as an input into an op,
+ // e.g. Add, that expects its inputs in device memory. Here is how it
+ // works now.
+ // First, what do we mean by "op expects an input in XYZ memory"?
+ // There are two types of "ops" here: the tf2xla kernel and the HLO
+ // computation it builds. The tf2xla kernel needs to retrieve the actual
+ // numeric value of the compile-time constant tensors, so it really expects
+ // them to be in host memory. However, for other inputs, it refers to them
+ // using xla::ComputationDataHandle, which is just a symbolic handle that
+ // xla::ComputationBuilder assigns. How does this handle get assigned for
+ // constant arguments? Even constant arguments get an _Arg node in the graph
+ // instantiated for Function compilation. The tf2xla kernel for constant _Arg
+ // nodes takes the constant value, converts it to XlaLiteral, and feeds it
+ // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
+ // constant XlaLiteral is included in the HLO graph, and subsequently, in
+ // the actual executable, which is copied to the device before being
+ // executed. Thus, when this executable runs, the constant is available in
+ // device memory.
+
+ // XlaLaunch kernel keeps all outputs (including constants, which it copies),
+ // in device memory
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
+ // Create the kernel.
+ NameAttrList function;
+ function.set_name(node_def.op());
+ *(function.mutable_attr()) = node_def.attr();
+
Device* dev = flr->device();
Status s;
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
- dev->GetAllocator(AllocatorAttributes()), &launch_def,
+ dev->GetAllocator(AllocatorAttributes()), &node_def,
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
- kernel->reset(new XlaLocalLaunchOp(&construction));
+
+ *kernel = absl::make_unique<XlaLocalLaunchBase>(
+ &construction, constant_arg_indices, resource_arg_indices, function);
return s;
}
+namespace {
+
bool RegisterLaunchOpCreator() {
RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp);
return true;
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.h b/tensorflow/compiler/jit/create_xla_launch_op.h
new file mode 100644
index 0000000000..98a22e3515
--- /dev/null
+++ b/tensorflow/compiler/jit/create_xla_launch_op.h
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
+#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class FunctionLibraryRuntime;
+class OpKernel;
+
+// Given a NodeDef 'node_def' and the function library runtime 'flr', if
+// 'node_def' is a call to a compilable function defined in 'flr', returns OK
+// and fills in 'kernel' with a XlaLaunchOp kernel which computes the
+// node. Otherwise, returns a non-OK.
+Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
+ std::unique_ptr<OpKernel>* kernel);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
new file mode 100644
index 0000000000..bcd5e75c7e
--- /dev/null
+++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
@@ -0,0 +1,145 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/create_xla_launch_op.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+NodeDef ToNodeDef(const string& text) {
+ NodeDef node_def;
+ EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
+ return node_def;
+}
+
+// Create a FunctionDef that takes one resource and one regular param
+FunctionDef XTimesY() {
+ return FunctionDefHelper::Define(
+ // Name
+ "XTimesY",
+ // Args
+ {"x: float", "y: resource"},
+ // Return values
+ {"z: float"},
+ // Attr def
+ {},
+ // Nodes
+ {
+ {{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}},
+ {{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}},
+ });
+}
+
+class CreateXlaLaunchOpTest : public ::testing::Test {
+ protected:
+ void Init(const std::vector<FunctionDef>& flib) {
+ SessionOptions options;
+ auto* device_count = options.config.mutable_device_count();
+ device_count->insert({"CPU", 1});
+ TF_CHECK_OK(DeviceFactory::AddDevices(
+ options, "/job:localhost/replica:0/task:0", &devices_));
+
+ FunctionDefLibrary proto;
+ for (const auto& fdef : flib) {
+ *(proto.add_function()) = fdef;
+ }
+ lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
+ OpRegistry::Global(), proto);
+ OptimizerOptions opts;
+ device_mgr_ = absl::make_unique<DeviceMgr>(devices_);
+ pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
+ device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
+ opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
+ flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
+ }
+
+ FunctionLibraryRuntime* flr_;
+ std::vector<Device*> devices_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+ std::unique_ptr<FunctionLibraryDefinition> lib_def_;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+
+ std::unique_ptr<OpKernel> kernel_;
+};
+
+AttrValue BoolAttr(bool b) {
+ AttrValue v;
+ v.set_b(b);
+ return v;
+}
+
+TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) {
+ FunctionDef fdef = XTimesY();
+ (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true);
+ Init({fdef});
+
+ Status status = CreateXlaLaunchOp(
+ flr_, ToNodeDef(R"pb(
+ name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
+ )pb"), &kernel_);
+ ASSERT_TRUE(status.ok()) << status.ToString();
+
+ EXPECT_EQ("XTimesY", kernel_->name());
+ EXPECT_EQ("XTimesY", kernel_->type_string());
+
+ EXPECT_EQ(2, kernel_->num_inputs());
+ EXPECT_EQ(DT_FLOAT, kernel_->input_type(0));
+ EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1));
+ EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]);
+ EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]);
+
+ EXPECT_EQ(1, kernel_->num_outputs());
+ EXPECT_EQ(DT_FLOAT, kernel_->output_type(0));
+ EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]);
+}
+
+TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrNotSet) {
+ FunctionDef fdef = XTimesY();
+ Init({fdef});
+
+ Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
+ name: 'XTimesY'
+ op: 'XTimesY'
+ input: 'a'
+ input: 'b'
+ )proto"), &kernel_);
+ EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
+}
+
+TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrIsSetToFalse) {
+ FunctionDef fdef = XTimesY();
+ (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false);
+ Init({fdef});
+
+ Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
+ name: 'XTimesY'
+ op: 'XTimesY'
+ input: 'a'
+ input: 'b'
+ )proto"), &kernel_);
+ EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 049d170fa4..86a9fd3b8e 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -39,15 +39,15 @@ limitations under the License.
namespace tensorflow {
-XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), device_type_(ctx->device_type()) {
- const NameAttrList* func;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
- function_ = *func;
- DataTypeVector constant_types;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
- num_constant_args_ = constant_types.size();
- OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_));
+XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector<int>& constants,
+ const std::vector<int>& resources,
+ const NameAttrList& function)
+ : OpKernel(ctx),
+ constants_(constants),
+ resources_(resources),
+ device_type_(ctx->device_type()),
+ function_(function) {
if (device_type_ == DeviceType(DEVICE_CPU)) {
platform_id_ = se::host::kHostPlatformId;
} else if (device_type_ == DeviceType(DEVICE_GPU)) {
@@ -57,8 +57,8 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
}
}
-Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** cache) {
+Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
+ XlaCompilationCache** cache) {
const XlaDevice::Metadata* metadata;
Status s = XlaDevice::GetMetadata(ctx, &metadata);
if (s.ok()) {
@@ -90,8 +90,8 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
return Status::OK();
}
-void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
- VLOG(1) << "XlaLocalLaunchOp::Compute "
+void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XlaLocalLaunchOpBase::Compute "
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
@@ -124,7 +124,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
}
std::map<int, OptionalTensor> variables =
- SnapshotResourceVariables(ctx, num_resource_args_);
+ SnapshotResourceVariables(ctx, resources_);
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
@@ -161,7 +161,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
xla::LocalExecutable* executable;
std::map<int, Tensor> constant_args;
- for (int i = 0; i < num_constant_args_; ++i) {
+ for (int i : constants_) {
constant_args.insert({i, ctx->input(i)});
}
OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args,
@@ -170,8 +170,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "Executing XLA Computation...";
- XlaComputationLaunchContext launch_context(
- num_resource_args_, client, xla_allocator, allocate_xla_tensors);
+ XlaComputationLaunchContext launch_context(client, xla_allocator,
+ allocate_xla_tensors);
launch_context.PopulateInputs(ctx, kernel, variables);
// Execute the computation.
@@ -194,6 +194,62 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "Done";
}
+namespace {
+
+// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
+// in error case, it returns RET instead of void.
+#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
+ do { \
+ ::tensorflow::Status _s(__VA_ARGS__); \
+ if (!TF_PREDICT_TRUE(_s.ok())) { \
+ (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+ return RET; \
+ } \
+ } while (0)
+
+// Helper static functions to construct parameters for
+// XlaLocalLaunchBase constructor from OpKernelConstruction.
+std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Tconstants", &constant_types));
+ std::vector<int> constants(constant_types.size());
+ std::iota(constants.begin(), constants.end(), 0);
+ return constants;
+}
+
+std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Tconstants", &constant_types));
+
+ DataTypeVector arg_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Targs", &arg_types));
+
+ int num_resources;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Nresources", &num_resources));
+
+ std::vector<int> resources(num_resources);
+ std::iota(resources.begin(), resources.end(),
+ constant_types.size() + arg_types.size());
+ return resources;
+}
+
+NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
+ const NameAttrList* func;
+ OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
+ return *func;
+}
+
+#undef OP_REQUIRES_OK_RETURN
+} // namespace
+
+XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
+ : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
+ FunctionAttr(ctx)) {}
+
XlaLocalLaunchOp::~XlaLocalLaunchOp() {
VLOG(1) << "XlaLocalLaunchOp destroyed";
}
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
index 8f8e646f0f..8dfc4b382d 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h
@@ -26,6 +26,41 @@ limitations under the License.
namespace tensorflow {
+// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
+// The only difference is that it does not require arguments to follow
+// the "constants, then regular args, then resources" order.
+// It takes vectors of constant and resource arguments explicitly.
+// It does not have corresponding OpDef because it is never present
+// in the GraphDef.
+// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
+// this kernel when asked to create a kernel for an XLA-compiled function.
+class XlaLocalLaunchBase : public OpKernel {
+ public:
+ XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector<int>& constants,
+ const std::vector<int>& resources,
+ const NameAttrList& function);
+ XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
+ XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
+ ~XlaLocalLaunchBase() override = default;
+
+ void Compute(OpKernelContext* ctx) override;
+
+ protected:
+ // Builds a XlaCompilationCache class suitable for the current device.
+ Status BuildCompilationCache(OpKernelContext* ctx,
+ XlaCompilationCache** cache);
+
+ // Indexes of compile-time constant inputs
+ std::vector<int> constants_;
+ // Indexes of resource inputs
+ std::vector<int> resources_;
+
+ DeviceType device_type_;
+ NameAttrList function_;
+ se::Platform::Id platform_id_;
+};
+
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
// responsible for handling interactions with the TensorFlow executor.
@@ -35,26 +70,12 @@ namespace tensorflow {
// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
// memory.
-class XlaLocalLaunchOp : public OpKernel {
+class XlaLocalLaunchOp : public XlaLocalLaunchBase {
public:
explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
~XlaLocalLaunchOp() override;
- void Compute(OpKernelContext* ctx) override;
-
private:
- // Builds a XlaCompilationCache class suitable for the current device.
- Status BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** compiler);
-
- DeviceType device_type_;
- NameAttrList function_;
- int num_constant_args_;
- // Number of resource variable arguments.
- int num_resource_args_;
-
- se::Platform::Id platform_id_;
-
TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
};
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index 60458f6f33..6b83cf67ff 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -48,13 +48,12 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable) {
std::map<int, OptionalTensor> variables = GetVariables(ctx);
- int64 num_resource_args = variables.size();
xla::LocalClient* client = metadata.client();
// Builds an XLA allocator for the device.
XlaComputationLaunchContext launch_context(
- num_resource_args, client, client->backend().memory_allocator(), true);
+ client, client->backend().memory_allocator(), true);
launch_context.PopulateInputs(ctx, result, variables);
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 33e53612b9..0223f97a03 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -38,14 +38,13 @@ using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;
} // anonymous namespace
-std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
- int num_variables) {
+std::map<int, OptionalTensor> SnapshotResourceVariables(
+ OpKernelContext* ctx, const std::vector<int>& variables) {
std::map<int, OptionalTensor> snapshot;
- int first_variable = ctx->num_inputs() - num_variables;
- for (int i = 0; i < num_variables; ++i) {
+ for (int i : variables) {
Var* variable = nullptr;
- ResourceHandle handle = HandleFromInput(ctx, first_variable + i);
- OptionalTensor& tensor = snapshot[first_variable + i];
+ ResourceHandle handle = HandleFromInput(ctx, i);
+ OptionalTensor& tensor = snapshot[i];
if (LookupResource(ctx, handle, &variable).ok()) {
tf_shared_lock lock(*variable->mu());
tensor.name = handle.name();
@@ -112,10 +111,9 @@ ScopedShapedBuffer ExtractSubShapedBuffer(
using internal::ExtractSubShapedBuffer;
XlaComputationLaunchContext::XlaComputationLaunchContext(
- int64 num_resource_args, xla::LocalClient* client,
- xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors)
- : num_resource_args_(num_resource_args),
- client_(client),
+ xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
+ bool allocate_xla_tensors)
+ : client_(client),
xla_allocator_(xla_allocator),
allocate_xla_tensors_(allocate_xla_tensors) {}
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 38291b0bd4..a2431253f8 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -31,15 +31,17 @@ limitations under the License.
namespace tensorflow {
class XlaAllocator;
-// Takes a snapshot of the values of resource variable arguments, which are
-// the last `num_variables` arguments. We snapshot tensors that back
+// Takes a snapshot of the values of resource variable arguments, whose
+// indices are specified in `variables` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
// important that the shapes used for compilation match the true shapes of the
// buffers.
//
-// Returns a map of TensorFlow argument index to resource variable.
-std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
- int num_variables);
+// Returns a map of TensorFlow argument index to resource variable. If a
+// resource variable is not initialized, the corresponding OptionalTensor
+// will have its `present` field set to false.
+std::map<int, OptionalTensor> SnapshotResourceVariables(
+ OpKernelContext* ctx, const std::vector<int>& variables);
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
@@ -72,7 +74,7 @@ class XlaComputationLaunchContext {
// Create a new launch context. 'allocate_xla_tensors' is true if allocated
// output tensors and variables are always XlaTensors. If false they are
// assumed to be "normal" device pointers.
- XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client,
+ XlaComputationLaunchContext(xla::LocalClient* client,
xla::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors);
@@ -92,7 +94,6 @@ class XlaComputationLaunchContext {
const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }
private:
- int64 num_resource_args_;
xla::LocalClient* client_;
xla::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;