aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/xla_launch_util.cc
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-05-08 16:43:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-08 17:09:23 -0700
commit14d5f219f33b1ab8e0a67b84d97204d046adb91f (patch)
treeb887f04458ef204e522d2b3d81d15128104b397c /tensorflow/compiler/jit/xla_launch_util.cc
parent79b773a4395caf7f0b17ce9ac84a1f34dd277bb9 (diff)
Make eager functions runnable on TPU
PiperOrigin-RevId: 195897321
Diffstat (limited to 'tensorflow/compiler/jit/xla_launch_util.cc')
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.cc18
1 file changed, 8 insertions, 10 deletions
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 33e53612b9..0223f97a03 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -38,14 +38,13 @@ using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;
} // anonymous namespace
-std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
- int num_variables) {
+std::map<int, OptionalTensor> SnapshotResourceVariables(
+ OpKernelContext* ctx, const std::vector<int>& variables) {
std::map<int, OptionalTensor> snapshot;
- int first_variable = ctx->num_inputs() - num_variables;
- for (int i = 0; i < num_variables; ++i) {
+ for (int i : variables) {
Var* variable = nullptr;
- ResourceHandle handle = HandleFromInput(ctx, first_variable + i);
- OptionalTensor& tensor = snapshot[first_variable + i];
+ ResourceHandle handle = HandleFromInput(ctx, i);
+ OptionalTensor& tensor = snapshot[i];
if (LookupResource(ctx, handle, &variable).ok()) {
tf_shared_lock lock(*variable->mu());
tensor.name = handle.name();
@@ -112,10 +111,9 @@ ScopedShapedBuffer ExtractSubShapedBuffer(
using internal::ExtractSubShapedBuffer;
XlaComputationLaunchContext::XlaComputationLaunchContext(
- int64 num_resource_args, xla::LocalClient* client,
- xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors)
- : num_resource_args_(num_resource_args),
- client_(client),
+ xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
+ bool allocate_xla_tensors)
+ : client_(client),
xla_allocator_(xla_allocator),
allocate_xla_tensors_(allocate_xla_tensors) {}