diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2017-03-06 16:54:54 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-06 17:09:49 -0800 |
commit | c25a6623cdb6780d4f84d1123dc1165a13446fb9 (patch) | |
tree | facdda5a35b3743f891bf61d1f805753cabd3ee4 | |
parent | 020501c695119f76bba0dd2bb47abfeaa939d669 (diff) |
Record memory deallocation for destroying temp variable op.
Change: 149363494
-rw-r--r-- | tensorflow/core/framework/op_kernel.h | 3 | ||||
-rw-r--r-- | tensorflow/core/kernels/variable_ops.h | 9 |
2 files changed, 10 insertions, 2 deletions
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 8e7608ec4e..5b13c8be76 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -1010,6 +1010,7 @@ class OpKernelContext { TensorValue release_output(int index); bool track_allocations() const { return params_->track_allocations; } + bool allocate_on_host(AllocatorAttributes alloc_attr) const; // Records temporary memory sizes. void record_host_temp_memory_size(int64 size) { @@ -1064,8 +1065,6 @@ class OpKernelContext { Tensor* out_tensor, AllocatorAttributes allocator_attr, const AllocationAttributes& allocation_attr); - bool allocate_on_host(AllocatorAttributes alloc_attr) const; - // This is called by PersistentTensor::AccessTensor whenever the // wrapped tensor is retrieved, to ensure the runtime knows that the // Tensor is being accessed within an Op. This is necessary for diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h index 2839c3d8cf..642bff055f 100644 --- a/tensorflow/core/kernels/variable_ops.h +++ b/tensorflow/core/kernels/variable_ops.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_KERNELS_VARIABLE_OPS_H_ #define TENSORFLOW_KERNELS_VARIABLE_OPS_H_ +#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -154,6 +155,14 @@ class DestroyTemporaryVariableOp : public OpKernel { OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager.")); OP_REQUIRES_OK(context, rm->Delete<TemporaryVariableOp::TmpVar>( context->step_container()->name(), var_name_)); + if (context->track_allocations()) { + if (context->allocate_on_host(AllocatorAttributes())) { + context->record_host_persistent_memory_allocation(-tmpvar.TotalBytes()); + } else { + context->record_device_persistent_memory_allocation( + -tmpvar.TotalBytes()); + } + } } private: |