aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2017-03-06 16:54:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-06 17:09:49 -0800
commitc25a6623cdb6780d4f84d1123dc1165a13446fb9 (patch)
treefacdda5a35b3743f891bf61d1f805753cabd3ee4
parent020501c695119f76bba0dd2bb47abfeaa939d669 (diff)
Record memory deallocation for destroying temp variable op.
Change: 149363494
-rw-r--r--tensorflow/core/framework/op_kernel.h3
-rw-r--r--tensorflow/core/kernels/variable_ops.h9
2 files changed, 10 insertions, 2 deletions
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 8e7608ec4e..5b13c8be76 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -1010,6 +1010,7 @@ class OpKernelContext {
TensorValue release_output(int index);
bool track_allocations() const { return params_->track_allocations; }
+ bool allocate_on_host(AllocatorAttributes alloc_attr) const;
// Records temporary memory sizes.
void record_host_temp_memory_size(int64 size) {
@@ -1064,8 +1065,6 @@ class OpKernelContext {
Tensor* out_tensor, AllocatorAttributes allocator_attr,
const AllocationAttributes& allocation_attr);
- bool allocate_on_host(AllocatorAttributes alloc_attr) const;
-
// This is called by PersistentTensor::AccessTensor whenever the
// wrapped tensor is retrieved, to ensure the runtime knows that the
// Tensor is being accessed within an Op. This is necessary for
diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h
index 2839c3d8cf..642bff055f 100644
--- a/tensorflow/core/kernels/variable_ops.h
+++ b/tensorflow/core/kernels/variable_ops.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_KERNELS_VARIABLE_OPS_H_
#define TENSORFLOW_KERNELS_VARIABLE_OPS_H_
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -154,6 +155,14 @@ class DestroyTemporaryVariableOp : public OpKernel {
OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
OP_REQUIRES_OK(context, rm->Delete<TemporaryVariableOp::TmpVar>(
context->step_container()->name(), var_name_));
+ if (context->track_allocations()) {
+ if (context->allocate_on_host(AllocatorAttributes())) {
+ context->record_host_persistent_memory_allocation(-tmpvar.TotalBytes());
+ } else {
+ context->record_device_persistent_memory_allocation(
+ -tmpvar.TotalBytes());
+ }
+ }
}
private: