diff options
author | 2018-05-08 16:43:54 -0700 | |
---|---|---|
committer | 2018-05-08 17:09:23 -0700 | |
commit | 14d5f219f33b1ab8e0a67b84d97204d046adb91f (patch) | |
tree | b887f04458ef204e522d2b3d81d15128104b397c /tensorflow/compiler | |
parent | 79b773a4395caf7f0b17ce9ac84a1f34dd277bb9 (diff) |
Make eager functions runable on TPU
PiperOrigin-RevId: 195897321
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/jit/BUILD | 24 | ||||
-rw-r--r-- | tensorflow/compiler/jit/create_xla_launch_op.cc | 207 | ||||
-rw-r--r-- | tensorflow/compiler/jit/create_xla_launch_op.h | 35 | ||||
-rw-r--r-- | tensorflow/compiler/jit/create_xla_launch_op_test.cc | 145 | ||||
-rw-r--r-- | tensorflow/compiler/jit/kernels/xla_launch_op.cc | 90 | ||||
-rw-r--r-- | tensorflow/compiler/jit/kernels/xla_launch_op.h | 51 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_compile_on_demand_op.cc | 3 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_launch_util.cc | 18 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_launch_util.h | 15 | ||||
-rw-r--r-- | tensorflow/compiler/tests/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/compiler/tests/eager_test.py | 112 |
11 files changed, 602 insertions, 102 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 07136d6a74..a6b3ce394c 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -261,6 +261,7 @@ cc_library( name = "create_xla_launch_op", srcs = [ "create_xla_launch_op.cc", + "create_xla_launch_op.h", ], deps = [ ":common", @@ -270,6 +271,29 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "create_xla_launch_op_test", + srcs = [ + "create_xla_launch_op.h", + "create_xla_launch_op_test.cc", + ], + deps = [ + ":create_xla_launch_op", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index 18d901323f..f35e916eb9 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/jit/create_xla_launch_op.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/kernels/xla_launch_op.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" @@ -25,78 +27,189 @@ limitations under the License. namespace tensorflow { namespace { -// Givens a NodeDef 'ndef' and the function library runtime 'flr', if -// 'ndef' is a call to a compilable function defined in 'flr', returns OK -// and fills in 'kernel' with a XlaLaunchOp kernel which computes the -// node. Otherwise, returns a non-OK. +// Utility which searches for values in a sorted list by scanning over it once. +// No matter how many times ScanForValue is called, the list is scanned at most +// once. However, if a call to ScanForValue skips over a value, that value is +// not revisited in future calls to ScanForValue, so callers must take +// care to order their calls. // -// This routine is here so that FunctionLibraryRuntime can jit a -// specific function call as requested. -Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef, - std::unique_ptr<OpKernel>* kernel) { - bool xla_compile = false; - if (!flr->GetFunctionLibraryDefinition() - ->GetAttr(ndef, kXlaCompileAttr, &xla_compile) - .ok() || - !xla_compile) { - // Not marked as _XlaCompile=true. - return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op()); +// Useful for merging multiple sorted lists in O(n) time. +class SinglePassSearch { + public: + // Creates a SinglePassSearch object that can be used to search in `values`. + // Does not take ownership of `values`. `values` must outlive this. + // `values` must be sorted. + explicit SinglePassSearch(const std::vector<int>* values) + : current_index_(0), values_(values) {} + + // Scans forward in the vector looking for "value", updating the internal + // position in to the vector. + // Returns true iff the vector contains the given value at or after current + // position. + // Not thread-safe. + bool ScanForValue(int value) { + while (current_index_ < values_->size() && + (*values_)[current_index_] <= value) { + if ((*values_)[current_index_] == value) { + current_index_++; + return true; + } + current_index_++; + } + return false; } - // Make sure that kernels have been registered on the JIT device. - XlaOpRegistry::RegisterCompilationKernels(); - if (!IsCompilable(flr, ndef)) { - // ndef is calling a function that XLA can't compile. - return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString()); + + private: + int current_index_; + const std::vector<int>* values_; +}; + +Status CompilationRequested(const FunctionLibraryRuntime& flr, + const NodeDef& node_def) { + bool xla_compile = false; + // Check if op is marked _XlaCompile=true. + Status status = flr.GetFunctionLibraryDefinition()->GetAttr( + node_def, kXlaCompileAttr, &xla_compile); + if (!status.ok() || !xla_compile) { + if (VLOG_IS_ON(3)) { + if (!status.ok()) { + VLOG(3) << "No " << kXlaCompileAttr << " attr defined for " + << node_def.op() << ". status=" << status.ToString(); + } else { + VLOG(3) << node_def.op() << " is explicitly marked not to be compiled"; + } + } + return Status(error::INVALID_ARGUMENT, ""); } + return Status::OK(); +} + +// Given a FunctionLibraryRuntime and a NodeDef calling a function in the +// runtime, returns this function's body in `fbody` as well as the indices +// of its constant and resource arguments. +// `fbody` is owned by `flr`. +// `constant_arg_indices` and `resource_arg_indices` should be empty vector. +// They are sorted in ascending order on this function's return. +Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, + const NodeDef& node_def, + const FunctionBody** fbody, + std::vector<int>* constant_arg_indices, + std::vector<int>* resource_arg_indices) { FunctionLibraryRuntime::Handle handle; - // If ndef is not instantiable, e.g., the function does not exist, + // If node_def is not instantiable, e.g., the function does not exist, // simply bail out. TF_RETURN_IF_ERROR( - flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle)); - const FunctionBody* fbody = flr->GetFunctionBody(handle); - CHECK(fbody); // Can't be nullptr since we just instantiated it. - std::vector<bool> const_args(fbody->arg_types.size()); + flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle)); + *fbody = flr->GetFunctionBody(handle); + CHECK(*fbody); // Can't be nullptr since we just instantiated it. + const DataTypeVector& arg_types = (*fbody)->arg_types; + std::vector<bool> const_args(arg_types.size()); // If we can't analyze the const args. Bail out. - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args)); for (int i = 0; i < const_args.size(); ++i) { if (const_args[i]) { - // There is a const arg. Bail out. - return errors::InvalidArgument("Const arg: ", i, " in ", - DebugString(fbody->fdef)); + constant_arg_indices->push_back(i); + } + } + + // There can be hundreds of resource variables. Reserve the space for them. + // We don't reserve for constants above as they are usually few. + resource_arg_indices->reserve(arg_types.size()); + for (int i = 0; i < arg_types.size(); ++i) { + if (arg_types[i] == DT_RESOURCE) { + resource_arg_indices->push_back(i); } } - NodeDef launch_def; - launch_def.set_name(ndef.name()); - launch_def.set_op("_XlaLaunch"); - launch_def.set_device(flr->device()->name()); - AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def); - AddNodeAttr("Nresources", 0, &launch_def); - AddNodeAttr("Targs", fbody->arg_types, &launch_def); - AddNodeAttr("Tresults", fbody->ret_types, &launch_def); - NameAttrList func; - func.set_name(ndef.op()); - *(func.mutable_attr()) = ndef.attr(); - AddNodeAttr("function", func, &launch_def); - - // TODO(b/32387911): Handles the host memory types across function - // calls properly. For now, we assume all inputs and outputs are on - // the device memory. + return Status::OK(); +} + +} // namespace + +Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, + std::unique_ptr<OpKernel>* kernel) { + TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def)); + + VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString(); + + // Make sure that kernels have been registered on the JIT device. + XlaOpRegistry::RegisterCompilationKernels(); + if (!IsCompilable(flr, node_def)) { + // node_def is calling a function that XLA can't compile. + return errors::InvalidArgument("Not compilable: ", + node_def.ShortDebugString()); + } + + // Get function body, constant args, and resource args. + const FunctionBody* fbody = nullptr; + std::vector<int> constant_arg_indices; + std::vector<int> resource_arg_indices; + TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( + flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices)); + + // Set input and output memory types. MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY); + // These indices are used only for optimization purposes. They allow us + // to loop over constant_arg_indices and resource_arg_indices only once + // while iterating over all the function arguments checking if it is a + // resource or a constant. + // The reason we optimized this code is because functions can have a lot of + // captured arguments. For example, the backward pass of ResNet50 takes in all + // 214 variables and a similar number of activations. + SinglePassSearch constants_search(&constant_arg_indices); + SinglePassSearch resources_search(&resource_arg_indices); + for (int i = 0; i < fbody->arg_types.size(); ++i) { + if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) { + // Compile-time constants and resource handles are expected to be in + // host memory. + input_memory_types[i] = HOST_MEMORY; + } + } + // One might wonder, about the case where a compile-time constant argument + // (which must be in host memory) is also used as an input into an op, + // e.g. Add, that expects its inputs in device memory. Here is how it + // works now. + // First, what do we mean by "op expects an input in XYZ memory"? + // There are two types of "ops" here: the tf2xla kernel and the HLO + // computation it builds. The tf2xla kernel needs to retrieve the actual + // numeric value of the compile-time constant tensors, so it really expects + // them to be on in host memory. However, for other inputs, it refers to them + // using xla::ComputationDataHandle, which is just a symbolic handle that + // xla::ComputationBuilder assigns. How does this handle gets assigned for + // constant arguments? Even constant arguments get an _Arg node in the graph + // instatiated for Function compilation. The tf2xla kernel for constant _Arg + // nodes takes the constant value, converts it to XlaLiteral, and feeds it + // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This + // constant XlaLiteral is included in the HLO graph, and subsequently, in + // the actual executable, which is copied to the device before being + // executed. Thus, when this executable runs, the constant is available in + // device memory. + + // XlaLaunch kernel keeps all outputs (including constants, which it copies), + // in device memory MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); + // Create the kernel. + NameAttrList function; + function.set_name(node_def.op()); + *(function.mutable_attr()) = node_def.attr(); + Device* dev = flr->device(); Status s; OpKernelConstruction construction( DeviceType(dev->device_type()), dev, - dev->GetAllocator(AllocatorAttributes()), &launch_def, + dev->GetAllocator(AllocatorAttributes()), &node_def, &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types, fbody->ret_types, output_memory_types, flr->graph_def_version(), &s); - kernel->reset(new XlaLocalLaunchOp(&construction)); + + *kernel = absl::make_unique<XlaLocalLaunchBase>( + &construction, constant_arg_indices, resource_arg_indices, function); return s; } +namespace { + bool RegisterLaunchOpCreator() { RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp); return true; diff --git a/tensorflow/compiler/jit/create_xla_launch_op.h b/tensorflow/compiler/jit/create_xla_launch_op.h new file mode 100644 index 0000000000..98a22e3515 --- /dev/null +++ b/tensorflow/compiler/jit/create_xla_launch_op.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ +#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class FunctionLibraryRuntime; +class OpKernel; + +// Given a NodeDef 'node_def' and the function library runtime 'flr', if +// 'node_def' is a call to a compilable function defined in 'flr', returns OK +// and fills in 'kernel' with a XlaLaunchOp kernel which computes the +// node. Otherwise, returns a non-OK. +Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, + std::unique_ptr<OpKernel>* kernel); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc new file mode 100644 index 0000000000..bcd5e75c7e --- /dev/null +++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc @@ -0,0 +1,145 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/create_xla_launch_op.h" + +#include "absl/memory/memory.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +NodeDef ToNodeDef(const string& text) { + NodeDef node_def; + EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def)); + return node_def; +} + +// Create a FunctionDef that takes one resource and one regular param +FunctionDef XTimesY() { + return FunctionDefHelper::Define( + // Name + "XTimesY", + // Args + {"x: float", "y: resource"}, + // Return values + {"z: float"}, + // Attr def + {}, + // Nodes + { + {{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}}, + {{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}}, + }); +} + +class CreateXlaLaunchOpTest : public ::testing::Test { + protected: + void Init(const std::vector<FunctionDef>& flib) { + SessionOptions options; + auto* device_count = options.config.mutable_device_count(); + device_count->insert({"CPU", 1}); + TF_CHECK_OK(DeviceFactory::AddDevices( + options, "/job:localhost/replica:0/task:0", &devices_)); + + FunctionDefLibrary proto; + for (const auto& fdef : flib) { + *(proto.add_function()) = fdef; + } + lib_def_ = absl::make_unique<FunctionLibraryDefinition>( + OpRegistry::Global(), proto); + OptimizerOptions opts; + device_mgr_ = absl::make_unique<DeviceMgr>(devices_); + pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>( + device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), + opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); + } + + FunctionLibraryRuntime* flr_; + std::vector<Device*> devices_; + std::unique_ptr<DeviceMgr> device_mgr_; + std::unique_ptr<FunctionLibraryDefinition> lib_def_; + std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; + + std::unique_ptr<OpKernel> kernel_; +}; + +AttrValue BoolAttr(bool b) { + AttrValue v; + v.set_b(b); + return v; +} + +TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) { + FunctionDef fdef = XTimesY(); + (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true); + Init({fdef}); + + Status status = CreateXlaLaunchOp( + flr_, ToNodeDef(R"pb( + name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b' + )pb"), &kernel_); + ASSERT_TRUE(status.ok()) << status.ToString(); + + EXPECT_EQ("XTimesY", kernel_->name()); + EXPECT_EQ("XTimesY", kernel_->type_string()); + + EXPECT_EQ(2, kernel_->num_inputs()); + EXPECT_EQ(DT_FLOAT, kernel_->input_type(0)); + EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1)); + EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]); + EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]); + + EXPECT_EQ(1, kernel_->num_outputs()); + EXPECT_EQ(DT_FLOAT, kernel_->output_type(0)); + EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]); +} + +TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrNotSet) { + FunctionDef fdef = XTimesY(); + Init({fdef}); + + Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), &kernel_); + EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString(); +} + +TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrIsSetToFalse) { + FunctionDef fdef = XTimesY(); + (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false); + Init({fdef}); + + Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), &kernel_); + EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 049d170fa4..86a9fd3b8e 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -39,15 +39,15 @@ limitations under the License. namespace tensorflow { -XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) - : OpKernel(ctx), device_type_(ctx->device_type()) { - const NameAttrList* func; - OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func)); - function_ = *func; - DataTypeVector constant_types; - OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types)); - num_constant_args_ = constant_types.size(); - OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_)); +XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector<int>& constants, + const std::vector<int>& resources, + const NameAttrList& function) + : OpKernel(ctx), + constants_(constants), + resources_(resources), + device_type_(ctx->device_type()), + function_(function) { if (device_type_ == DeviceType(DEVICE_CPU)) { platform_id_ = se::host::kHostPlatformId; } else if (device_type_ == DeviceType(DEVICE_GPU)) { @@ -57,8 +57,8 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) } } -Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** cache) { +Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** cache) { const XlaDevice::Metadata* metadata; Status s = XlaDevice::GetMetadata(ctx, &metadata); if (s.ok()) { @@ -90,8 +90,8 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, return Status::OK(); } -void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { - VLOG(1) << "XlaLocalLaunchOp::Compute " +void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { + VLOG(1) << "XlaLocalLaunchOpBase::Compute " << Canonicalize(function_.name(), AttrSlice(&function_.attr())); // We store information about the JIT-compiled XLA computation // in the ResourceMgr. @@ -124,7 +124,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { } std::map<int, OptionalTensor> variables = - SnapshotResourceVariables(ctx, num_resource_args_); + SnapshotResourceVariables(ctx, resources_); xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client()); @@ -161,7 +161,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { xla::LocalExecutable* executable; std::map<int, Tensor> constant_args; - for (int i = 0; i < num_constant_args_; ++i) { + for (int i : constants_) { constant_args.insert({i, ctx->input(i)}); } OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args, @@ -170,8 +170,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { VLOG(1) << "Executing XLA Computation..."; - XlaComputationLaunchContext launch_context( - num_resource_args_, client, xla_allocator, allocate_xla_tensors); + XlaComputationLaunchContext launch_context(client, xla_allocator, + allocate_xla_tensors); launch_context.PopulateInputs(ctx, kernel, variables); // Execute the computation. @@ -194,6 +194,62 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { VLOG(1) << "Done"; } +namespace { + +// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that +// in error case, it returns RET instead of void. +#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + return RET; \ + } \ + } while (0) + +// Helper static functions to construct parameters for +// XlaLocalLaunchBase constructor from OpKernelConstruction. +std::vector<int> ConstantsVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), + ctx->GetAttr("Tconstants", &constant_types)); + std::vector<int> constants(constant_types.size()); + std::iota(constants.begin(), constants.end(), 0); + return constants; +} + +std::vector<int> ResourcesVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), + ctx->GetAttr("Tconstants", &constant_types)); + + DataTypeVector arg_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), + ctx->GetAttr("Targs", &arg_types)); + + int num_resources; + OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), + ctx->GetAttr("Nresources", &num_resources)); + + std::vector<int> resources(num_resources); + std::iota(resources.begin(), resources.end(), + constant_types.size() + arg_types.size()); + return resources; +} + +NameAttrList FunctionAttr(OpKernelConstruction* ctx) { + const NameAttrList* func; + OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func)); + return *func; +} + +#undef OP_REQUIRES_OK_RETURN +} // namespace + +XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) + : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx), + FunctionAttr(ctx)) {} + XlaLocalLaunchOp::~XlaLocalLaunchOp() { VLOG(1) << "XlaLocalLaunchOp destroyed"; } diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h index 8f8e646f0f..8dfc4b382d 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.h +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h @@ -26,6 +26,41 @@ limitations under the License. namespace tensorflow { +// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. +// The only difference is that it does not require arguments to follow +// the "constants, then regular args, then resources" order. +// It takes vectors of constant and resource arguments explicitly. +// It does not have corresponding OpDef because it is never present +// in the GraphDef. +// Currently, it is used by eager runtime. FunctionLibraryRuntime creates +// this kernel when asked to create a kernel for an XLA-compiled function. +class XlaLocalLaunchBase : public OpKernel { + public: + XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector<int>& constants, + const std::vector<int>& resources, + const NameAttrList& function); + XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete; + XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete; + ~XlaLocalLaunchBase() override = default; + + void Compute(OpKernelContext* ctx) override; + + protected: + // Builds a XlaCompilationCache class suitable for the current device. + Status BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** cache); + + // Indexes of compile-time constant inputs + std::vector<int> constants_; + // Indexes of resource inputs + std::vector<int> resources_; + + DeviceType device_type_; + NameAttrList function_; + se::Platform::Id platform_id_; +}; + // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph // which will be compiled and executed using XLA. The XlaLocalLaunchOp is // responsible for handling interactions with the TensorFlow executor. @@ -35,26 +70,12 @@ namespace tensorflow { // XlaLocalLaunchOp uses xla::LocalClient::Compile() and // xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device // memory. -class XlaLocalLaunchOp : public OpKernel { +class XlaLocalLaunchOp : public XlaLocalLaunchBase { public: explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); ~XlaLocalLaunchOp() override; - void Compute(OpKernelContext* ctx) override; - private: - // Builds a XlaCompilationCache class suitable for the current device. - Status BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** compiler); - - DeviceType device_type_; - NameAttrList function_; - int num_constant_args_; - // Number of resource variable arguments. - int num_resource_args_; - - se::Platform::Id platform_id_; - TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); }; diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 60458f6f33..6b83cf67ff 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -48,13 +48,12 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, const XlaCompiler::CompilationResult* result, xla::LocalExecutable* executable) { std::map<int, OptionalTensor> variables = GetVariables(ctx); - int64 num_resource_args = variables.size(); xla::LocalClient* client = metadata.client(); // Builds an XLA allocator for the device. XlaComputationLaunchContext launch_context( - num_resource_args, client, client->backend().memory_allocator(), true); + client, client->backend().memory_allocator(), true); launch_context.PopulateInputs(ctx, result, variables); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 33e53612b9..0223f97a03 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -38,14 +38,13 @@ using xla::ScopedShapedBuffer; using xla::ShapedBuffer; } // anonymous namespace -std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx, - int num_variables) { +std::map<int, OptionalTensor> SnapshotResourceVariables( + OpKernelContext* ctx, const std::vector<int>& variables) { std::map<int, OptionalTensor> snapshot; - int first_variable = ctx->num_inputs() - num_variables; - for (int i = 0; i < num_variables; ++i) { + for (int i : variables) { Var* variable = nullptr; - ResourceHandle handle = HandleFromInput(ctx, first_variable + i); - OptionalTensor& tensor = snapshot[first_variable + i]; + ResourceHandle handle = HandleFromInput(ctx, i); + OptionalTensor& tensor = snapshot[i]; if (LookupResource(ctx, handle, &variable).ok()) { tf_shared_lock lock(*variable->mu()); tensor.name = handle.name(); @@ -112,10 +111,9 @@ ScopedShapedBuffer ExtractSubShapedBuffer( using internal::ExtractSubShapedBuffer; XlaComputationLaunchContext::XlaComputationLaunchContext( - int64 num_resource_args, xla::LocalClient* client, - xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors) - : num_resource_args_(num_resource_args), - client_(client), + xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, + bool allocate_xla_tensors) + : client_(client), xla_allocator_(xla_allocator), allocate_xla_tensors_(allocate_xla_tensors) {} diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 38291b0bd4..a2431253f8 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -31,15 +31,17 @@ limitations under the License. namespace tensorflow { class XlaAllocator; -// Takes a snapshot of the values of resource variable arguments, which are -// the last `num_variables` arguments. We snapshot tensors that back +// Takes a snapshot of the values of resource variable arguments, whose +// indices are specified in `variables` argument. We snapshot tensors that back // resource variables since concurrent updates may modify the shape, and it is // important that the shapes used for compilation match the true shapes of the // buffers. // -// Returns a map of TensorFlow argument index to resource variable. -std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx, - int num_variables); +// Returns a map of TensorFlow argument index to resource variable. If a +// resource variable is not initialized, the corresponding OptionalTensor +// will have its `present` field set to false. +std::map<int, OptionalTensor> SnapshotResourceVariables( + OpKernelContext* ctx, const std::vector<int>& variables); // Adapter class that wraps a Tensorflow allocator as an XLA allocator. // Assumes that the Tensorflow allocator permits asynchronous deallocation: @@ -72,7 +74,7 @@ class XlaComputationLaunchContext { // Create a new launch context. 'allocate_xla_tensors' is true if allocated // output tensors and variables are always XlaTensors. If false they are // assumed to be "normal" device pointers. - XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client, + XlaComputationLaunchContext(xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors); @@ -92,7 +94,6 @@ class XlaComputationLaunchContext { const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; } private: - int64 num_resource_args_; xla::LocalClient* client_; xla::DeviceMemoryAllocator* xla_allocator_; bool allocate_xla_tensors_; diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index aaea83ae9c..9791792f29 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -327,7 +327,11 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:layers", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", "//tensorflow/python:platform_test", + "//tensorflow/python/eager:function", ], ) diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index bdd0185dfe..5ab1585f8c 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -24,10 +24,16 @@ from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.layers import convolutional +from tensorflow.python.layers import pooling from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import googletest @@ -43,7 +49,7 @@ class EagerTest(XLATestCase): def testExecuteListOutputLen0(self): with self.test_scope(): - empty = constant_op.constant([], dtype=dtypes.int32) + empty = constant_op.constant([], dtype=dtypes.float32) result = array_ops.unstack(empty, 0) self.assertTrue(isinstance(result, list)) self.assertEqual(0, len(result)) @@ -51,7 +57,7 @@ class EagerTest(XLATestCase): def testExecuteListOutputLen1(self): with self.test_scope(): split_dim = constant_op.constant(1) - value = constant_op.constant([[0, 1, 2], [3, 4, 5]]) + value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]]) result = array_ops.split(value, 1, axis=split_dim) self.assertTrue(isinstance(result, list)) self.assertEqual(1, len(result)) @@ -60,7 +66,7 @@ class EagerTest(XLATestCase): def testExecuteListOutputLen3(self): with self.test_scope(): split_dim = constant_op.constant(1) - value = constant_op.constant([[0, 1, 2], [3, 4, 5]]) + value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]]) result = array_ops.split(value, 3, axis=split_dim) self.assertTrue(isinstance(result, list)) self.assertEqual(3, len(result)) @@ -131,7 +137,105 @@ class EagerTest(XLATestCase): self.assertEqual(2., grads[0][0].numpy()) -if __name__ == "__main__": +class EagerFunctionTest(XLATestCase): + + def testBasic(self): + with self.test_scope(): + matmul = function.defun(math_ops.matmul, compiled=True) + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + sq = matmul(t, t, transpose_a=True) + self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20]) + + def testConv(self): + if 'GPU' in self.device: + # TODO(b/32333178) + self.skipTest('Current implementation of RandomStandardNormal kernel ' + 'is very slow on GPU, and has been blacklisted.') + with self.test_scope(): + data_format = 'channels_last' + conv = convolutional.Conv2D( + filters=1, kernel_size=2, padding='VALID', + data_format=data_format, activation=nn_ops.relu, + kernel_initializer=init_ops.ones_initializer(), + bias_initializer=init_ops.zeros_initializer()) + pool = pooling.MaxPooling2D(2, 2, data_format=data_format) + + def model(x): + x = conv(x) + return pool(x) + model = function.defun(model, compiled=True) + + x = array_ops.ones([1, 4, 4, 1]) + y = model(x) + self.assertAllEqual(y.numpy(), [[[[4.]]]]) + + def testReadVariable(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun(compiled=True) + def f(): + return v.read_value() + + var = f() + self.assertEqual(1.0, var.numpy()) + + def testUpdateVariable(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + + def f(v): + v.assign_add(1.0) + return v + + f = function.defun(f, compiled=True) + + var = f(v) + self.assertEqual(2.0, var.numpy()) + + def testAllArgumentKinds(self): + """Test a complex function that takes different argument kinds. + + tf2xla machinery that translates, compiles, and runs defuns + classifies arguments into: compile-time constants, regular tensors, + and resources. This test creates a function with a mix of all these + kinds. Moreover, the order of function arguments is intentionally mixed up. + + This also tests the case when the same argument is a compile-time constant + as well as used in an operation that normally expects its inputs to be + in device memory - addition in this case. + """ + with self.test_scope(): + def foo(c1, r1, v1, c2, v2, r2): + # c1 and c2 are compile-time constants + # r1 and r2 are regular tensors + # v1 and v2 are resource variables + a = c1 + r1 + b = math_ops.cast(c2, dtypes.float32) + v2 + c = array_ops.slice(v1, c1, c2) + d = r2 * v2 + return a, b, c, d + + foo = function.defun(foo, compiled=True) + + c1 = [0, 0] + c2 = array_ops.ones([2], dtype=dtypes.int32) + + r1 = array_ops.ones([2]) + r2 = [[2., 2.], [3., 3.]] + + v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]]) + v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]]) + + a, b, c, d = foo(c1, r1, v1, c2, v2, r2) + + self.assertAllEqual([1, 1], a.numpy()) + self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy()) + self.assertAllEqual([[1.]], c.numpy()) + self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy()) + + +if __name__ == '__main__': ops.enable_eager_execution( config=config_pb2.ConfigProto(log_device_placement=True)) googletest.main() |