diff options
author | Igor Ganichev <iga@google.com> | 2018-05-08 16:43:54 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-08 17:09:23 -0700 |
commit | 14d5f219f33b1ab8e0a67b84d97204d046adb91f (patch) | |
tree | b887f04458ef204e522d2b3d81d15128104b397c /tensorflow/compiler/jit/xla_launch_util.cc | |
parent | 79b773a4395caf7f0b17ce9ac84a1f34dd277bb9 (diff) |
Make eager functions runable on TPU
PiperOrigin-RevId: 195897321
Diffstat (limited to 'tensorflow/compiler/jit/xla_launch_util.cc')
-rw-r--r-- | tensorflow/compiler/jit/xla_launch_util.cc | 18 |
1 files changed, 8 insertions, 10 deletions
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 33e53612b9..0223f97a03 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -38,14 +38,13 @@ using xla::ScopedShapedBuffer; using xla::ShapedBuffer; } // anonymous namespace -std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx, - int num_variables) { +std::map<int, OptionalTensor> SnapshotResourceVariables( + OpKernelContext* ctx, const std::vector<int>& variables) { std::map<int, OptionalTensor> snapshot; - int first_variable = ctx->num_inputs() - num_variables; - for (int i = 0; i < num_variables; ++i) { + for (int i : variables) { Var* variable = nullptr; - ResourceHandle handle = HandleFromInput(ctx, first_variable + i); - OptionalTensor& tensor = snapshot[first_variable + i]; + ResourceHandle handle = HandleFromInput(ctx, i); + OptionalTensor& tensor = snapshot[i]; if (LookupResource(ctx, handle, &variable).ok()) { tf_shared_lock lock(*variable->mu()); tensor.name = handle.name(); @@ -112,10 +111,9 @@ ScopedShapedBuffer ExtractSubShapedBuffer( using internal::ExtractSubShapedBuffer; XlaComputationLaunchContext::XlaComputationLaunchContext( - int64 num_resource_args, xla::LocalClient* client, - xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors) - : num_resource_args_(num_resource_args), - client_(client), + xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, + bool allocate_xla_tensors) + : client_(client), xla_allocator_(xla_allocator), allocate_xla_tensors_(allocate_xla_tensors) {} |