author     Igor Ganichev <iga@google.com>  2018-05-08 16:43:54 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-05-08 17:09:23 -0700
commit     14d5f219f33b1ab8e0a67b84d97204d046adb91f (patch)
tree       b887f04458ef204e522d2b3d81d15128104b397c /tensorflow
parent     79b773a4395caf7f0b17ce9ac84a1f34dd277bb9 (diff)
Make eager functions runnable on TPU
PiperOrigin-RevId: 195897321
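
The user-visible entry point in this change is the new `compiled` argument to `tfe.defun` (see the function.py and eager_test.py hunks below). A minimal usage sketch, assuming a TF 1.x eager setup; per the defun docstring in this change, the compiled path is experimental and currently supported only for execution on TPUs:

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe

    tf.enable_eager_execution()

    def matmul_fn(a, b):
      # A plain eager function; defun(compiled=True) traces it into a graph
      # function that carries the _XlaCompile attribute added by this change.
      return tf.matmul(a, b)

    compiled_matmul = tfe.defun(matmul_fn, compiled=True)

    t = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    print(compiled_matmul(t, t))  # [[ 7. 10.] [15. 22.]]
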
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/compiler/jit/BUILD                                       |  24
-rw-r--r--  tensorflow/compiler/jit/create_xla_launch_op.cc                     | 207
-rw-r--r--  tensorflow/compiler/jit/create_xla_launch_op.h                      |  35
-rw-r--r--  tensorflow/compiler/jit/create_xla_launch_op_test.cc                | 145
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.cc                    |  90
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.h                     |  51
-rw-r--r--  tensorflow/compiler/jit/xla_compile_on_demand_op.cc                 |   3
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.cc                          |  18
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.h                           |  15
-rw-r--r--  tensorflow/compiler/tests/BUILD                                     |   4
-rw-r--r--  tensorflow/compiler/tests/eager_test.py                             | 112
-rw-r--r--  tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py |  55
-rw-r--r--  tensorflow/python/eager/function.py                                 | 127
13 files changed, 722 insertions, 164 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 07136d6a74..a6b3ce394c 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -261,6 +261,7 @@ cc_library(
name = "create_xla_launch_op",
srcs = [
"create_xla_launch_op.cc",
+ "create_xla_launch_op.h",
],
deps = [
":common",
@@ -270,6 +271,29 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/memory",
+ ],
+ alwayslink = 1,
+)
+
+tf_cc_test(
+ name = "create_xla_launch_op_test",
+ srcs = [
+ "create_xla_launch_op.h",
+ "create_xla_launch_op_test.cc",
+ ],
+ deps = [
+ ":create_xla_launch_op",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_options",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index 18d901323f..f35e916eb9 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/jit/create_xla_launch_op.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -25,78 +27,189 @@ limitations under the License.
namespace tensorflow {
namespace {
-// Givens a NodeDef 'ndef' and the function library runtime 'flr', if
-// 'ndef' is a call to a compilable function defined in 'flr', returns OK
-// and fills in 'kernel' with a XlaLaunchOp kernel which computes the
-// node. Otherwise, returns a non-OK.
+// Utility which searches for values in a sorted list by scanning over it once.
+// No matter how many times ScanForValue is called, the list is scanned at most
+// once. However, if a call to ScanForValue skips over a value, that value is
+// not revisited in future calls to ScanForValue, so callers must take
+// care to order their calls.
//
-// This routine is here so that FunctionLibraryRuntime can jit a
-// specific function call as requested.
-Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
- std::unique_ptr<OpKernel>* kernel) {
- bool xla_compile = false;
- if (!flr->GetFunctionLibraryDefinition()
- ->GetAttr(ndef, kXlaCompileAttr, &xla_compile)
- .ok() ||
- !xla_compile) {
- // Not marked as _XlaCompile=true.
- return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op());
+// Useful for merging multiple sorted lists in O(n) time.
+class SinglePassSearch {
+ public:
+ // Creates a SinglePassSearch object that can be used to search in `values`.
+ // Does not take ownership of `values`. `values` must outlive this.
+ // `values` must be sorted.
+ explicit SinglePassSearch(const std::vector<int>* values)
+ : current_index_(0), values_(values) {}
+
+ // Scans forward in the vector looking for "value", updating the internal
+  // position in the vector.
+ // Returns true iff the vector contains the given value at or after current
+ // position.
+ // Not thread-safe.
+ bool ScanForValue(int value) {
+ while (current_index_ < values_->size() &&
+ (*values_)[current_index_] <= value) {
+ if ((*values_)[current_index_] == value) {
+ current_index_++;
+ return true;
+ }
+ current_index_++;
+ }
+ return false;
}
- // Make sure that kernels have been registered on the JIT device.
- XlaOpRegistry::RegisterCompilationKernels();
- if (!IsCompilable(flr, ndef)) {
- // ndef is calling a function that XLA can't compile.
- return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString());
+
+ private:
+ int current_index_;
+ const std::vector<int>* values_;
+};
+
+Status CompilationRequested(const FunctionLibraryRuntime& flr,
+ const NodeDef& node_def) {
+ bool xla_compile = false;
+ // Check if op is marked _XlaCompile=true.
+ Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
+ node_def, kXlaCompileAttr, &xla_compile);
+ if (!status.ok() || !xla_compile) {
+ if (VLOG_IS_ON(3)) {
+ if (!status.ok()) {
+ VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
+ << node_def.op() << ". status=" << status.ToString();
+ } else {
+ VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
+ }
+ }
+ return Status(error::INVALID_ARGUMENT, "");
}
+ return Status::OK();
+}
+
+// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
+// runtime, returns this function's body in `fbody` as well as the indices
+// of its constant and resource arguments.
+// `fbody` is owned by `flr`.
+// `constant_arg_indices` and `resource_arg_indices` should be empty vectors.
+// They are sorted in ascending order on this function's return.
+Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
+ const NodeDef& node_def,
+ const FunctionBody** fbody,
+ std::vector<int>* constant_arg_indices,
+ std::vector<int>* resource_arg_indices) {
FunctionLibraryRuntime::Handle handle;
- // If ndef is not instantiable, e.g., the function does not exist,
+ // If node_def is not instantiable, e.g., the function does not exist,
// simply bail out.
TF_RETURN_IF_ERROR(
- flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
- const FunctionBody* fbody = flr->GetFunctionBody(handle);
- CHECK(fbody); // Can't be nullptr since we just instantiated it.
- std::vector<bool> const_args(fbody->arg_types.size());
+ flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
+ *fbody = flr->GetFunctionBody(handle);
+ CHECK(*fbody); // Can't be nullptr since we just instantiated it.
+ const DataTypeVector& arg_types = (*fbody)->arg_types;
+ std::vector<bool> const_args(arg_types.size());
// If we can't analyze the const args. Bail out.
- TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args));
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args));
for (int i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
- // There is a const arg. Bail out.
- return errors::InvalidArgument("Const arg: ", i, " in ",
- DebugString(fbody->fdef));
+ constant_arg_indices->push_back(i);
+ }
+ }
+
+ // There can be hundreds of resource variables. Reserve the space for them.
+ // We don't reserve for constants above as they are usually few.
+ resource_arg_indices->reserve(arg_types.size());
+ for (int i = 0; i < arg_types.size(); ++i) {
+ if (arg_types[i] == DT_RESOURCE) {
+ resource_arg_indices->push_back(i);
}
}
- NodeDef launch_def;
- launch_def.set_name(ndef.name());
- launch_def.set_op("_XlaLaunch");
- launch_def.set_device(flr->device()->name());
- AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def);
- AddNodeAttr("Nresources", 0, &launch_def);
- AddNodeAttr("Targs", fbody->arg_types, &launch_def);
- AddNodeAttr("Tresults", fbody->ret_types, &launch_def);
- NameAttrList func;
- func.set_name(ndef.op());
- *(func.mutable_attr()) = ndef.attr();
- AddNodeAttr("function", func, &launch_def);
-
- // TODO(b/32387911): Handles the host memory types across function
- // calls properly. For now, we assume all inputs and outputs are on
- // the device memory.
+ return Status::OK();
+}
+
+} // namespace
+
+Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
+ std::unique_ptr<OpKernel>* kernel) {
+ TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def));
+
+ VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString();
+
+ // Make sure that kernels have been registered on the JIT device.
+ XlaOpRegistry::RegisterCompilationKernels();
+ if (!IsCompilable(flr, node_def)) {
+ // node_def is calling a function that XLA can't compile.
+ return errors::InvalidArgument("Not compilable: ",
+ node_def.ShortDebugString());
+ }
+
+ // Get function body, constant args, and resource args.
+ const FunctionBody* fbody = nullptr;
+ std::vector<int> constant_arg_indices;
+ std::vector<int> resource_arg_indices;
+ TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
+ flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
+
+ // Set input and output memory types.
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
+ // These indices are used only for optimization purposes. They allow us
+ // to loop over constant_arg_indices and resource_arg_indices only once
+  // while iterating over all the function arguments, checking whether each
+  // one is a resource or a constant.
+  // The reason we optimized this code is that functions can have a lot of
+ // captured arguments. For example, the backward pass of ResNet50 takes in all
+ // 214 variables and a similar number of activations.
+ SinglePassSearch constants_search(&constant_arg_indices);
+ SinglePassSearch resources_search(&resource_arg_indices);
+ for (int i = 0; i < fbody->arg_types.size(); ++i) {
+ if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
+ // Compile-time constants and resource handles are expected to be in
+ // host memory.
+ input_memory_types[i] = HOST_MEMORY;
+ }
+ }
+  // One might wonder about the case where a compile-time constant argument
+ // (which must be in host memory) is also used as an input into an op,
+ // e.g. Add, that expects its inputs in device memory. Here is how it
+ // works now.
+ // First, what do we mean by "op expects an input in XYZ memory"?
+ // There are two types of "ops" here: the tf2xla kernel and the HLO
+ // computation it builds. The tf2xla kernel needs to retrieve the actual
+ // numeric value of the compile-time constant tensors, so it really expects
+  // them to be in host memory. However, for other inputs, it refers to them
+  // using xla::ComputationDataHandle, which is just a symbolic handle that
+  // xla::ComputationBuilder assigns. How does this handle get assigned for
+  // constant arguments? Even constant arguments get an _Arg node in the graph
+  // instantiated for Function compilation. The tf2xla kernel for constant _Arg
+ // nodes takes the constant value, converts it to XlaLiteral, and feeds it
+ // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
+ // constant XlaLiteral is included in the HLO graph, and subsequently, in
+ // the actual executable, which is copied to the device before being
+ // executed. Thus, when this executable runs, the constant is available in
+ // device memory.
+
+  // The XlaLaunch kernel keeps all outputs (including constants, which it
+  // copies) in device memory.
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
+ // Create the kernel.
+ NameAttrList function;
+ function.set_name(node_def.op());
+ *(function.mutable_attr()) = node_def.attr();
+
Device* dev = flr->device();
Status s;
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
- dev->GetAllocator(AllocatorAttributes()), &launch_def,
+ dev->GetAllocator(AllocatorAttributes()), &node_def,
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
- kernel->reset(new XlaLocalLaunchOp(&construction));
+
+ *kernel = absl::make_unique<XlaLocalLaunchBase>(
+ &construction, constant_arg_indices, resource_arg_indices, function);
return s;
}
+namespace {
+
bool RegisterLaunchOpCreator() {
RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp);
return true;
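
The memory-type loop in CreateXlaLaunchOp above merges two sorted index vectors (constant and resource argument indices) against the full argument range in one pass using SinglePassSearch. A rough Python analogue of that merge, illustrative only and not part of this change:

    def assign_memory_types(num_args, constant_indices, resource_indices):
      # Both index lists are sorted ascending, as guaranteed by
      # GetBodyAndConstantsAndResources; each list is scanned at most once.
      memory_types = ["DEVICE_MEMORY"] * num_args
      ci = ri = 0
      for i in range(num_args):
        if ci < len(constant_indices) and constant_indices[ci] == i:
          memory_types[i] = "HOST_MEMORY"  # compile-time constants live on host
          ci += 1
        elif ri < len(resource_indices) and resource_indices[ri] == i:
          memory_types[i] = "HOST_MEMORY"  # resource handles live on host
          ri += 1
      return memory_types
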
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.h b/tensorflow/compiler/jit/create_xla_launch_op.h
new file mode 100644
index 0000000000..98a22e3515
--- /dev/null
+++ b/tensorflow/compiler/jit/create_xla_launch_op.h
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
+#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class FunctionLibraryRuntime;
+class OpKernel;
+
+// Given a NodeDef 'node_def' and the function library runtime 'flr', if
+// 'node_def' is a call to a compilable function defined in 'flr', returns OK
+// and fills in 'kernel' with an XlaLaunchOp kernel that computes the
+// node. Otherwise, returns a non-OK status.
+Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
+ std::unique_ptr<OpKernel>* kernel);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
new file mode 100644
index 0000000000..bcd5e75c7e
--- /dev/null
+++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
@@ -0,0 +1,145 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/create_xla_launch_op.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+NodeDef ToNodeDef(const string& text) {
+ NodeDef node_def;
+ EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
+ return node_def;
+}
+
+// Create a FunctionDef that takes one resource and one regular param
+FunctionDef XTimesY() {
+ return FunctionDefHelper::Define(
+ // Name
+ "XTimesY",
+ // Args
+ {"x: float", "y: resource"},
+ // Return values
+ {"z: float"},
+ // Attr def
+ {},
+ // Nodes
+ {
+ {{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}},
+ {{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}},
+ });
+}
+
+class CreateXlaLaunchOpTest : public ::testing::Test {
+ protected:
+ void Init(const std::vector<FunctionDef>& flib) {
+ SessionOptions options;
+ auto* device_count = options.config.mutable_device_count();
+ device_count->insert({"CPU", 1});
+ TF_CHECK_OK(DeviceFactory::AddDevices(
+ options, "/job:localhost/replica:0/task:0", &devices_));
+
+ FunctionDefLibrary proto;
+ for (const auto& fdef : flib) {
+ *(proto.add_function()) = fdef;
+ }
+ lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
+ OpRegistry::Global(), proto);
+ OptimizerOptions opts;
+ device_mgr_ = absl::make_unique<DeviceMgr>(devices_);
+ pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
+ device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
+ opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
+ flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
+ }
+
+ FunctionLibraryRuntime* flr_;
+ std::vector<Device*> devices_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+ std::unique_ptr<FunctionLibraryDefinition> lib_def_;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+
+ std::unique_ptr<OpKernel> kernel_;
+};
+
+AttrValue BoolAttr(bool b) {
+ AttrValue v;
+ v.set_b(b);
+ return v;
+}
+
+TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) {
+ FunctionDef fdef = XTimesY();
+ (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true);
+ Init({fdef});
+
+ Status status = CreateXlaLaunchOp(
+ flr_, ToNodeDef(R"pb(
+ name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
+ )pb"), &kernel_);
+ ASSERT_TRUE(status.ok()) << status.ToString();
+
+ EXPECT_EQ("XTimesY", kernel_->name());
+ EXPECT_EQ("XTimesY", kernel_->type_string());
+
+ EXPECT_EQ(2, kernel_->num_inputs());
+ EXPECT_EQ(DT_FLOAT, kernel_->input_type(0));
+ EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1));
+ EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]);
+ EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]);
+
+ EXPECT_EQ(1, kernel_->num_outputs());
+ EXPECT_EQ(DT_FLOAT, kernel_->output_type(0));
+ EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]);
+}
+
+TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrNotSet) {
+ FunctionDef fdef = XTimesY();
+ Init({fdef});
+
+ Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
+ name: 'XTimesY'
+ op: 'XTimesY'
+ input: 'a'
+ input: 'b'
+ )proto"), &kernel_);
+ EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
+}
+
+TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrIsSetToFalse) {
+ FunctionDef fdef = XTimesY();
+ (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false);
+ Init({fdef});
+
+ Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
+ name: 'XTimesY'
+ op: 'XTimesY'
+ input: 'a'
+ input: 'b'
+ )proto"), &kernel_);
+ EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 049d170fa4..86a9fd3b8e 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -39,15 +39,15 @@ limitations under the License.
namespace tensorflow {
-XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), device_type_(ctx->device_type()) {
- const NameAttrList* func;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
- function_ = *func;
- DataTypeVector constant_types;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
- num_constant_args_ = constant_types.size();
- OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_));
+XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector<int>& constants,
+ const std::vector<int>& resources,
+ const NameAttrList& function)
+ : OpKernel(ctx),
+ constants_(constants),
+ resources_(resources),
+ device_type_(ctx->device_type()),
+ function_(function) {
if (device_type_ == DeviceType(DEVICE_CPU)) {
platform_id_ = se::host::kHostPlatformId;
} else if (device_type_ == DeviceType(DEVICE_GPU)) {
@@ -57,8 +57,8 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
}
}
-Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** cache) {
+Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
+ XlaCompilationCache** cache) {
const XlaDevice::Metadata* metadata;
Status s = XlaDevice::GetMetadata(ctx, &metadata);
if (s.ok()) {
@@ -90,8 +90,8 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
return Status::OK();
}
-void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
- VLOG(1) << "XlaLocalLaunchOp::Compute "
+void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XlaLocalLaunchOpBase::Compute "
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
@@ -124,7 +124,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
}
std::map<int, OptionalTensor> variables =
- SnapshotResourceVariables(ctx, num_resource_args_);
+ SnapshotResourceVariables(ctx, resources_);
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
@@ -161,7 +161,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
xla::LocalExecutable* executable;
std::map<int, Tensor> constant_args;
- for (int i = 0; i < num_constant_args_; ++i) {
+ for (int i : constants_) {
constant_args.insert({i, ctx->input(i)});
}
OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args,
@@ -170,8 +170,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "Executing XLA Computation...";
- XlaComputationLaunchContext launch_context(
- num_resource_args_, client, xla_allocator, allocate_xla_tensors);
+ XlaComputationLaunchContext launch_context(client, xla_allocator,
+ allocate_xla_tensors);
launch_context.PopulateInputs(ctx, kernel, variables);
// Execute the computation.
@@ -194,6 +194,62 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "Done";
}
+namespace {
+
+// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
+// in the error case, it returns RET instead of void.
+#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
+ do { \
+ ::tensorflow::Status _s(__VA_ARGS__); \
+ if (!TF_PREDICT_TRUE(_s.ok())) { \
+ (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+ return RET; \
+ } \
+ } while (0)
+
+// Helper static functions to construct parameters for
+// XlaLocalLaunchBase constructor from OpKernelConstruction.
+std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Tconstants", &constant_types));
+ std::vector<int> constants(constant_types.size());
+ std::iota(constants.begin(), constants.end(), 0);
+ return constants;
+}
+
+std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Tconstants", &constant_types));
+
+ DataTypeVector arg_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Targs", &arg_types));
+
+ int num_resources;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Nresources", &num_resources));
+
+ std::vector<int> resources(num_resources);
+ std::iota(resources.begin(), resources.end(),
+ constant_types.size() + arg_types.size());
+ return resources;
+}
+
+NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
+ const NameAttrList* func;
+ OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
+ return *func;
+}
+
+#undef OP_REQUIRES_OK_RETURN
+} // namespace
+
+XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
+ : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
+ FunctionAttr(ctx)) {}
+
XlaLocalLaunchOp::~XlaLocalLaunchOp() {
VLOG(1) << "XlaLocalLaunchOp destroyed";
}
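
For the graph-mode _XlaLaunch op, inputs still follow the fixed "constants, then regular args, then resources" layout, so ConstantsVector and ResourcesVector above can derive the index vectors from the attribute sizes alone. A small sketch of that arithmetic with hypothetical counts:

    # Hypothetical attribute sizes read from Tconstants, Targs, and Nresources.
    num_constants, num_args, num_resources = 2, 3, 4

    # Constants occupy the leading input slots.
    constants = list(range(num_constants))                # [0, 1]
    # Resources occupy the trailing slots after the regular arguments.
    resources = [num_constants + num_args + i
                 for i in range(num_resources)]           # [5, 6, 7, 8]
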
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
index 8f8e646f0f..8dfc4b382d 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h
@@ -26,6 +26,41 @@ limitations under the License.
namespace tensorflow {
+// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
+// The only difference is that it does not require arguments to follow
+// the "constants, then regular args, then resources" order.
+// It takes vectors of constant and resource arguments explicitly.
+// It does not have a corresponding OpDef because it is never present
+// in the GraphDef.
+// Currently, it is used by the eager runtime. FunctionLibraryRuntime creates
+// this kernel when asked to create a kernel for an XLA-compiled function.
+class XlaLocalLaunchBase : public OpKernel {
+ public:
+ XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector<int>& constants,
+ const std::vector<int>& resources,
+ const NameAttrList& function);
+ XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
+ XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
+ ~XlaLocalLaunchBase() override = default;
+
+ void Compute(OpKernelContext* ctx) override;
+
+ protected:
+ // Builds a XlaCompilationCache class suitable for the current device.
+ Status BuildCompilationCache(OpKernelContext* ctx,
+ XlaCompilationCache** cache);
+
+ // Indexes of compile-time constant inputs
+ std::vector<int> constants_;
+ // Indexes of resource inputs
+ std::vector<int> resources_;
+
+ DeviceType device_type_;
+ NameAttrList function_;
+ se::Platform::Id platform_id_;
+};
+
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
// responsible for handling interactions with the TensorFlow executor.
@@ -35,26 +70,12 @@ namespace tensorflow {
// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
// memory.
-class XlaLocalLaunchOp : public OpKernel {
+class XlaLocalLaunchOp : public XlaLocalLaunchBase {
public:
explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
~XlaLocalLaunchOp() override;
- void Compute(OpKernelContext* ctx) override;
-
private:
- // Builds a XlaCompilationCache class suitable for the current device.
- Status BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** compiler);
-
- DeviceType device_type_;
- NameAttrList function_;
- int num_constant_args_;
- // Number of resource variable arguments.
- int num_resource_args_;
-
- se::Platform::Id platform_id_;
-
TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
};
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index 60458f6f33..6b83cf67ff 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -48,13 +48,12 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable) {
std::map<int, OptionalTensor> variables = GetVariables(ctx);
- int64 num_resource_args = variables.size();
xla::LocalClient* client = metadata.client();
// Builds an XLA allocator for the device.
XlaComputationLaunchContext launch_context(
- num_resource_args, client, client->backend().memory_allocator(), true);
+ client, client->backend().memory_allocator(), true);
launch_context.PopulateInputs(ctx, result, variables);
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 33e53612b9..0223f97a03 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -38,14 +38,13 @@ using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;
} // anonymous namespace
-std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
- int num_variables) {
+std::map<int, OptionalTensor> SnapshotResourceVariables(
+ OpKernelContext* ctx, const std::vector<int>& variables) {
std::map<int, OptionalTensor> snapshot;
- int first_variable = ctx->num_inputs() - num_variables;
- for (int i = 0; i < num_variables; ++i) {
+ for (int i : variables) {
Var* variable = nullptr;
- ResourceHandle handle = HandleFromInput(ctx, first_variable + i);
- OptionalTensor& tensor = snapshot[first_variable + i];
+ ResourceHandle handle = HandleFromInput(ctx, i);
+ OptionalTensor& tensor = snapshot[i];
if (LookupResource(ctx, handle, &variable).ok()) {
tf_shared_lock lock(*variable->mu());
tensor.name = handle.name();
@@ -112,10 +111,9 @@ ScopedShapedBuffer ExtractSubShapedBuffer(
using internal::ExtractSubShapedBuffer;
XlaComputationLaunchContext::XlaComputationLaunchContext(
- int64 num_resource_args, xla::LocalClient* client,
- xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors)
- : num_resource_args_(num_resource_args),
- client_(client),
+ xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
+ bool allocate_xla_tensors)
+ : client_(client),
xla_allocator_(xla_allocator),
allocate_xla_tensors_(allocate_xla_tensors) {}
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 38291b0bd4..a2431253f8 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -31,15 +31,17 @@ limitations under the License.
namespace tensorflow {
class XlaAllocator;
-// Takes a snapshot of the values of resource variable arguments, which are
-// the last `num_variables` arguments. We snapshot tensors that back
+// Takes a snapshot of the values of resource variable arguments, whose
+// indices are specified in the `variables` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
// important that the shapes used for compilation match the true shapes of the
// buffers.
//
-// Returns a map of TensorFlow argument index to resource variable.
-std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
- int num_variables);
+// Returns a map of TensorFlow argument index to resource variable. If a
+// resource variable is not initialized, the corresponding OptionalTensor
+// will have its `present` field set to false.
+std::map<int, OptionalTensor> SnapshotResourceVariables(
+ OpKernelContext* ctx, const std::vector<int>& variables);
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
@@ -72,7 +74,7 @@ class XlaComputationLaunchContext {
// Create a new launch context. 'allocate_xla_tensors' is true if allocated
// output tensors and variables are always XlaTensors. If false they are
// assumed to be "normal" device pointers.
- XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client,
+ XlaComputationLaunchContext(xla::LocalClient* client,
xla::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors);
@@ -92,7 +94,6 @@ class XlaComputationLaunchContext {
const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }
private:
- int64 num_resource_args_;
xla::LocalClient* client_;
xla::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;
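
The SnapshotResourceVariables change above replaces the implicit "resources are the last N inputs" convention with explicit indices, which is what allows eager-defined functions to interleave resources with other arguments. An illustrative contrast with hypothetical positions, not code from this change:

    num_inputs, num_variables = 5, 2

    # Before: resource variables were assumed to be the trailing inputs.
    variable_indices = list(range(num_inputs - num_variables, num_inputs))  # [3, 4]

    # After: the caller passes the exact input indices, so resource arguments
    # may sit anywhere in the signature.
    resource_arg_indices = [1, 4]  # hypothetical
    variable_indices = resource_arg_indices
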
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index aaea83ae9c..9791792f29 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -327,7 +327,11 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn",
"//tensorflow/python:platform_test",
+ "//tensorflow/python/eager:function",
],
)
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index bdd0185dfe..5ab1585f8c 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -24,10 +24,16 @@ from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.layers import convolutional
+from tensorflow.python.layers import pooling
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import googletest
@@ -43,7 +49,7 @@ class EagerTest(XLATestCase):
def testExecuteListOutputLen0(self):
with self.test_scope():
- empty = constant_op.constant([], dtype=dtypes.int32)
+ empty = constant_op.constant([], dtype=dtypes.float32)
result = array_ops.unstack(empty, 0)
self.assertTrue(isinstance(result, list))
self.assertEqual(0, len(result))
@@ -51,7 +57,7 @@ class EagerTest(XLATestCase):
def testExecuteListOutputLen1(self):
with self.test_scope():
split_dim = constant_op.constant(1)
- value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
+ value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
result = array_ops.split(value, 1, axis=split_dim)
self.assertTrue(isinstance(result, list))
self.assertEqual(1, len(result))
@@ -60,7 +66,7 @@ class EagerTest(XLATestCase):
def testExecuteListOutputLen3(self):
with self.test_scope():
split_dim = constant_op.constant(1)
- value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
+ value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
result = array_ops.split(value, 3, axis=split_dim)
self.assertTrue(isinstance(result, list))
self.assertEqual(3, len(result))
@@ -131,7 +137,105 @@ class EagerTest(XLATestCase):
self.assertEqual(2., grads[0][0].numpy())
-if __name__ == "__main__":
+class EagerFunctionTest(XLATestCase):
+
+ def testBasic(self):
+ with self.test_scope():
+ matmul = function.defun(math_ops.matmul, compiled=True)
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ sq = matmul(t, t, transpose_a=True)
+ self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
+
+ def testConv(self):
+ if 'GPU' in self.device:
+ # TODO(b/32333178)
+ self.skipTest('Current implementation of RandomStandardNormal kernel '
+ 'is very slow on GPU, and has been blacklisted.')
+ with self.test_scope():
+ data_format = 'channels_last'
+ conv = convolutional.Conv2D(
+ filters=1, kernel_size=2, padding='VALID',
+ data_format=data_format, activation=nn_ops.relu,
+ kernel_initializer=init_ops.ones_initializer(),
+ bias_initializer=init_ops.zeros_initializer())
+ pool = pooling.MaxPooling2D(2, 2, data_format=data_format)
+
+ def model(x):
+ x = conv(x)
+ return pool(x)
+ model = function.defun(model, compiled=True)
+
+ x = array_ops.ones([1, 4, 4, 1])
+ y = model(x)
+ self.assertAllEqual(y.numpy(), [[[[4.]]]])
+
+ def testReadVariable(self):
+ with self.test_scope():
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ @function.defun(compiled=True)
+ def f():
+ return v.read_value()
+
+ var = f()
+ self.assertEqual(1.0, var.numpy())
+
+ def testUpdateVariable(self):
+ with self.test_scope():
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ def f(v):
+ v.assign_add(1.0)
+ return v
+
+ f = function.defun(f, compiled=True)
+
+ var = f(v)
+ self.assertEqual(2.0, var.numpy())
+
+ def testAllArgumentKinds(self):
+ """Test a complex function that takes different argument kinds.
+
+    The tf2xla machinery that translates, compiles, and runs defuns
+    classifies arguments into compile-time constants, regular tensors,
+    and resources. This test creates a function with a mix of all these
+ kinds. Moreover, the order of function arguments is intentionally mixed up.
+
+    This also tests the case where the same argument is both a compile-time
+    constant and an input to an operation that normally expects its inputs
+    to be in device memory - addition in this case.
+ """
+ with self.test_scope():
+ def foo(c1, r1, v1, c2, v2, r2):
+ # c1 and c2 are compile-time constants
+ # r1 and r2 are regular tensors
+ # v1 and v2 are resource variables
+ a = c1 + r1
+ b = math_ops.cast(c2, dtypes.float32) + v2
+ c = array_ops.slice(v1, c1, c2)
+ d = r2 * v2
+ return a, b, c, d
+
+ foo = function.defun(foo, compiled=True)
+
+ c1 = [0, 0]
+ c2 = array_ops.ones([2], dtype=dtypes.int32)
+
+ r1 = array_ops.ones([2])
+ r2 = [[2., 2.], [3., 3.]]
+
+ v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]])
+ v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]])
+
+ a, b, c, d = foo(c1, r1, v1, c2, v2, r2)
+
+ self.assertAllEqual([1, 1], a.numpy())
+ self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy())
+ self.assertAllEqual([[1.]], c.numpy())
+ self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())
+
+
+if __name__ == '__main__':
ops.enable_eager_execution(
config=config_pb2.ConfigProto(log_device_placement=True))
googletest.main()
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
index 8517a3bf7b..b8f352d5f5 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
@@ -36,9 +36,7 @@ def device_and_data_format():
'channels_last')
-def random_batch(batch_size, device_and_format=None):
- _, data_format = device_and_format or device_and_data_format()
-
+def random_batch(batch_size, data_format):
shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3)
shape = (batch_size,) + shape
@@ -70,7 +68,7 @@ class ResNet50Test(tf.test.TestCase):
if defun:
model.call = tfe.defun(model.call)
with tf.device(device), tfe.execution_mode(execution_mode):
- images, _ = random_batch(2)
+ images, _ = random_batch(2, data_format)
output = model(images, training=False)
tfe.async_wait()
self.assertEqual((2, 1000), output.shape)
@@ -91,7 +89,7 @@ class ResNet50Test(tf.test.TestCase):
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format, include_top=False)
with tf.device(device):
- images, _ = random_batch(2)
+ images, _ = random_batch(2, data_format)
output = model(images, training=False)
output_shape = ((2, 2048, 1, 1)
if data_format == 'channels_first' else (2, 1, 1, 2048))
@@ -101,7 +99,7 @@ class ResNet50Test(tf.test.TestCase):
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
with tf.device(device):
- images, _ = random_batch(2)
+ images, _ = random_batch(2, data_format)
output = model(images, training=False)
self.assertEqual((2, 2048), output.shape)
@@ -115,7 +113,7 @@ class ResNet50Test(tf.test.TestCase):
name='t0').as_default(), tf.contrib.summary.always_record_summaries():
with tf.device(device), tfe.execution_mode(execution_mode):
optimizer = tf.train.GradientDescentOptimizer(0.1)
- images, labels = random_batch(2)
+ images, labels = random_batch(2, data_format)
train_one_step(model, images, labels, optimizer)
self.assertEqual(320, len(model.variables))
tfe.async_wait()
@@ -134,7 +132,7 @@ class ResNet50Test(tf.test.TestCase):
model = resnet50.ResNet50(data_format)
optimizer = tf.train.GradientDescentOptimizer(0.1)
with tf.device(device):
- images, labels = random_batch(2)
+ images, labels = random_batch(2, data_format)
gc.disable()
# Warm up. Note that this first run does create significant amounts of
# garbage to be collected. The hope is that this is a build-only effect,
@@ -202,18 +200,18 @@ class ResNet50Benchmarks(tf.test.Benchmark):
# which forces a sync. This is a roundabout way, yes.
tf.constant(1.).cpu()
- def _benchmark_eager_apply(self, label, defun=False, execution_mode=None,
- device_and_format=None):
+ def _benchmark_eager_apply(self, label, device_and_format, defun=False,
+ execution_mode=None, compiled=False):
with tfe.execution_mode(execution_mode):
- device, data_format = device_and_format or device_and_data_format()
+ device, data_format = device_and_format
model = resnet50.ResNet50(data_format)
if defun:
- model.call = tfe.defun(model.call)
+ model.call = tfe.defun(model.call, compiled=compiled)
batch_size = 64
num_burn = 5
num_iters = 30
with tf.device(device):
- images, _ = random_batch(batch_size, device_and_format)
+ images, _ = random_batch(batch_size, data_format)
for _ in xrange(num_burn):
model(images, training=False).cpu()
if execution_mode:
@@ -227,30 +225,34 @@ class ResNet50Benchmarks(tf.test.Benchmark):
self._report(label, start, num_iters, device, batch_size, data_format)
def benchmark_eager_apply_sync(self):
- self._benchmark_eager_apply('eager_apply', defun=False)
+ self._benchmark_eager_apply('eager_apply', device_and_data_format(),
+ defun=False)
def benchmark_eager_apply_async(self):
self._benchmark_eager_apply(
- 'eager_apply_async', defun=False, execution_mode=tfe.ASYNC)
+ 'eager_apply_async', device_and_data_format(), defun=False,
+ execution_mode=tfe.ASYNC)
def benchmark_eager_apply_with_defun(self):
- self._benchmark_eager_apply('eager_apply_with_defun', defun=True)
+ self._benchmark_eager_apply('eager_apply_with_defun',
+ device_and_data_format(), defun=True)
def _benchmark_eager_train(self,
label,
make_iterator,
+ device_and_format,
defun=False,
execution_mode=None,
- device_and_format=None):
+ compiled=False):
with tfe.execution_mode(execution_mode):
- device, data_format = device_and_format or device_and_data_format()
+ device, data_format = device_and_format
for batch_size in self._train_batch_sizes():
- (images, labels) = random_batch(batch_size, device_and_format)
+ (images, labels) = random_batch(batch_size, data_format)
num_burn = 3
num_iters = 10
model = resnet50.ResNet50(data_format)
if defun:
- model.call = tfe.defun(model.call)
+ model.call = tfe.defun(model.call, compiled=compiled)
optimizer = tf.train.GradientDescentOptimizer(0.1)
with tf.device(device):
@@ -273,18 +275,21 @@ class ResNet50Benchmarks(tf.test.Benchmark):
self._report(label, start, num_iters, device, batch_size, data_format)
def benchmark_eager_train_sync(self):
- self._benchmark_eager_train('eager_train', MockIterator, defun=False)
+ self._benchmark_eager_train('eager_train', MockIterator,
+ device_and_data_format(), defun=False)
def benchmark_eager_train_async(self):
self._benchmark_eager_train(
'eager_train_async',
MockIterator,
+ device_and_data_format(),
defun=False,
execution_mode=tfe.ASYNC)
def benchmark_eager_train_with_defun(self):
self._benchmark_eager_train(
- 'eager_train_with_defun', MockIterator, defun=True)
+ 'eager_train_with_defun', MockIterator,
+ device_and_data_format(), defun=True)
def benchmark_eager_train_datasets(self):
@@ -294,7 +299,8 @@ class ResNet50Benchmarks(tf.test.Benchmark):
return tfe.Iterator(ds)
self._benchmark_eager_train(
- 'eager_train_dataset', make_iterator, defun=False)
+ 'eager_train_dataset', make_iterator,
+ device_and_data_format(), defun=False)
def benchmark_eager_train_datasets_with_defun(self):
@@ -304,7 +310,8 @@ class ResNet50Benchmarks(tf.test.Benchmark):
return tfe.Iterator(ds)
self._benchmark_eager_train(
- 'eager_train_dataset_with_defun', make_iterator, defun=True)
+ 'eager_train_dataset_with_defun', make_iterator,
+ device_and_data_format(), defun=True)
if __name__ == '__main__':
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 89257bb20a..b478b6b0db 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -23,6 +23,7 @@ import collections
import numpy as np
+from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
@@ -227,7 +228,7 @@ def _inference_name(n):
class _EagerDefinedFunction(object):
"""Function object with the interface of tf _DefinedFunction."""
- def __init__(self, name, graph, operations, inputs, outputs):
+ def __init__(self, name, graph, operations, inputs, outputs, attrs):
"""Initializes an eager defined function.
Args:
@@ -237,6 +238,7 @@ class _EagerDefinedFunction(object):
which will be in the function
inputs: the tensors in the graph to be used as inputs to the function
outputs: the tensors in the graph which will be outputs to the function
+ attrs: dict mapping names of attributes to their AttrValue values
"""
fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
graph._c_graph, # pylint: disable=protected-access
@@ -248,6 +250,14 @@ class _EagerDefinedFunction(object):
[],
None,
compat.as_str(""))
+
+ for name, attr_value in attrs.items():
+ serialized = attr_value.SerializeToString()
+ # TODO(iga): this creates and deletes a new TF_Status for every attr.
+ # It might be worth creating a convenient way to re-use status.
+ pywrap_tensorflow.TF_FunctionSetAttrValueProto(
+ fn, compat.as_str(name), serialized)
+
# TODO(apassos) avoid creating a FunctionDef (specially to grab the
# signature, but also in general it's nice not to depend on it.
with c_api_util.tf_buffer() as buffer_:
@@ -289,25 +299,6 @@ def _flatten(sequence):
class GraphModeFunction(object):
"""Callable object representing a graph-mode function.
-
- Args:
- name: str the name of the created function
- input_placeholders: list of placeholder values (tensors) to feed when
- calling the wrapped function.
- extra_inputs: Tensor inputs this function definition closed over which
- are passed as arguments. Need to track so gradients are supported
- correctly.
- graph: the Graph from which the operations will be pulled. Used as
- a context when computing gradients.
- operations: the subset of Operations in the graph used in the function
- definition.
- outputs: a flat list of the Tensors in the graph used as outputs to the
- function
- func_outputs: a possibly nested python object which will be returned by
- this function. The Tensors in this structure will be replaced by their
- corresponding values in outputs.
- output_shapes: List of shapes of all tensors in outputs
- variables: (optional) List of variables to watch during function execution.
"""
def __init__(self,
@@ -319,9 +310,36 @@ class GraphModeFunction(object):
outputs,
func_outputs,
output_shapes,
- variables=None):
+ variables=None,
+ attrs=None):
+ """Initialize a GraphModeFunction.
+
+ Args:
+ name: str the name of the created function
+ input_placeholders: list of placeholder values (tensors) to feed when
+ calling the wrapped function.
+ extra_inputs: Tensor inputs this function definition closed over which
+ are passed as arguments. Need to track so gradients are supported
+ correctly.
+ graph: the Graph from which the operations will be pulled. Used as
+ a context when computing gradients.
+ operations: the subset of Operations in the graph used in the function
+ definition.
+ outputs: a flat list of the Tensors in the graph used as outputs to the
+ function
+ func_outputs: a possibly nested python object which will be returned by
+ this function. The Tensors in this structure will be replaced by their
+ corresponding values in outputs.
+ output_shapes: List of shapes of all tensors in outputs
+ variables: (optional) List of variables to watch during function
+ execution.
+ attrs: (optional) dict mapping names of attributes to their AttrValue
+ values. Attributes in `attrs` will be included in this function's
+ definition.
+ """
+ self._attrs = attrs or {}
defined_function = _EagerDefinedFunction(
- name, graph, operations, input_placeholders, outputs)
+ name, graph, operations, input_placeholders, outputs, self._attrs)
if len(input_placeholders) != len(defined_function.signature.input_arg):
raise ValueError("Internal error: invalid lengths. %s %s" % (
len(input_placeholders), len(defined_function.signature.input_arg)))
@@ -374,7 +392,7 @@ class GraphModeFunction(object):
forward_name = _forward_name(self._func_name)
self._forward_fdef = _EagerDefinedFunction(
forward_name, self._graph, self._ops, self._input_placeholders,
- filtered_outputs + captures)
+ filtered_outputs + captures, self._attrs)
all_inputs = self._out_grad_placeholders + captures
# Excluding input ops from the body as we do not intend to execute these
# operations when the function is executed.
@@ -388,7 +406,7 @@ class GraphModeFunction(object):
bname = _backward_name(self._func_name)
self._backward_function = GraphModeFunction(
bname, all_inputs, [], self._graph, function_def_ops,
- backward_outputs, in_gradients, output_shapes)
+ backward_outputs, in_gradients, output_shapes, attrs=self._attrs)
def _backprop_call(self, args):
"""Calls the wrapped function and records the result on a tape."""
@@ -562,7 +580,7 @@ def _get_defun_inputs(args):
return nest.pack_sequence_as(args, ret)
-def _defun_internal(name, func, args, kwds):
+def _defun_internal(name, func, compiled, args, kwds):
"""Defines and returns graph-mode version of func."""
graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
with context.graph_mode():
@@ -627,9 +645,14 @@ def _defun_internal(name, func, args, kwds):
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
# TODO(ashankar): What about the gradient registry?
_register(f._c_func.func) # pylint: disable=protected-access
+
+ attrs = {}
+ if compiled:
+ attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=True)
+
return GraphModeFunction(
fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
- func_outputs, output_shapes, variables)
+ func_outputs, output_shapes, variables, attrs)
# Defun uses this instead of Tensor as a cache key. Using dtype because
@@ -671,7 +694,7 @@ def _register(fn):
# TODO(apassos): better error messages for non-hashable arguments.
-def named_defun(func, name):
+def named_defun(func, name, compiled=False):
"""Defines a function with a given name.
See the documentation for `defun` for more information on the semantics of the
@@ -680,6 +703,7 @@ def named_defun(func, name):
Args:
func: the function to be wrapped.
name: the name given to it.
+ compiled: if true, the framework will attempt to compile func with XLA.
Returns:
the wrapped function.
@@ -696,13 +720,13 @@ def named_defun(func, name):
if cache_key not in arguments_to_functions:
arguments_to_functions[cache_key] = _defun_internal(
- name, func, args, kwds)
+ name, func, compiled, args, kwds)
return arguments_to_functions[cache_key](*args)
return decorated
-def defun(func):
+def defun(func=None, compiled=False):
"""Decorator to compile func into graph_mode.
`defun` converts a function that constructs a TensorFlow graph into a function
@@ -745,18 +769,45 @@ def defun(func):
```
Args:
- func: function to be compiled.
+ func: function to be compiled. If `func` is None, returns a
+ decorator that can be invoked with a single argument - `func`. The
+ end result is equivalent to providing all the arguments up front.
+ In other words, defun(compiled=True)(func) is equivalent to
+ defun(func, compiled=True). The former allows the following use case:
+ @tfe.defun(compiled=True)
+ def foo(...):
+ ...
+ compiled: If True, an attempt to compile `func` with XLA will be made.
+      If it fails, the function will be run normally. Experimental.
+      Currently supported only for execution on TPUs.
Returns:
- A callable that will execute the compiled function (and return zero
- or more `tf.Tensor` objects).
+ If `func` is not None, returns callable that will execute the compiled
+ function (and return zero or more `tf.Tensor` objects).
+ If `func` is None, returns a decorator that, when invoked with a single
+ `func` argument, returns a callable equivalent to the case above.
"""
# TODO(apassos): deal with captured global state. Deal with control flow.
- try:
- name = func.__name__
- except AttributeError:
- name = "function"
- return tf_decorator.make_decorator(func, named_defun(func, name))
+ def decorated(function):
+ try:
+ name = function.__name__
+ except AttributeError:
+ name = "function"
+ return tf_decorator.make_decorator(
+ function, named_defun(function, name, compiled=compiled))
+
+ # This code path is for the `foo = tfe.defun(foo, ...)` use case
+ if func is not None:
+ return decorated(func)
+
+ # This code path is for the
+ #
+ # @tfe.defun(...)
+ # def foo(...):
+ # ...
+ #
+ # use case, which is equivalent to `foo = tfe.defun(...)(foo)`
+ return decorated
def make_defun_op(func, *args, **kwds):
@@ -808,7 +859,7 @@ def make_defun_op(func, *args, **kwds):
name = func.__name__
if any(isinstance(x, ops.EagerTensor) for x in kwds.values()):
raise ValueError("Tensor keyword arguments are not supported.")
- return _defun_internal(name, func, args, kwds)
+ return _defun_internal(name, func, False, args, kwds)
class AutomaticControlDependencies(object):
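
On the Python side, the whole effect of `compiled=True` is to stamp an `_XlaCompile` attribute onto the generated function definition; CompilationRequested in create_xla_launch_op.cc then keys off that attribute when deciding whether to build an XlaLaunchOp kernel. A minimal sketch of constructing that attribute value with the same proto API used in this change:

    from tensorflow.core.framework import attr_value_pb2

    compiled = True
    attrs = {}
    if compiled:
      # kXlaCompileAttr on the C++ side is the string "_XlaCompile"; the
      # attribute is attached to the function via TF_FunctionSetAttrValueProto.
      attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=True)
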