aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-07-26 13:26:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-26 13:30:43 -0700
commitca69ddc34b37258534d8327ec55a26b2add6a632 (patch)
tree373ccf4c116a750c0d92fc4ffeb8d6dcbae15306
parent63563579653c1f0829d460eef5f05963111e08f0 (diff)
ResourceVariables shouldn't need twice the memory when initializing.
This is safe because all ops which write to resource variables check whether there are other outstanding references to the buffer and copy if that's the case. So we can safely reuse the buffer of initializer tensors even in weird cases such as initializing from a constant (which should never be mutated) or using the same tensor to initialize multiple variables. PiperOrigin-RevId: 206211065
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc68
1 files changed, 18 insertions, 50 deletions
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index c5292e1ae1..cab9eb729d 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -213,64 +213,32 @@ class AssignVariableOp : public OpKernel {
"Variable and value dtypes don't match; respectively, ",
dtype_, " and ", context->input(1).dtype()));
Var* variable = nullptr;
- OP_REQUIRES_OK(
- context,
- LookupOrCreateResource<Var>(
- context, HandleFromInput(context, 0), &variable,
- [this, context](Var** ptr) {
- *ptr = new Var(dtype_);
- PersistentTensor unused;
- Tensor* tmp;
- AllocatorAttributes attr;
- if (!relax_constraints_) {
- attr.set_gpu_compatible(true);
- attr.set_nic_compatible(true);
- }
- TF_RETURN_IF_ERROR(context->allocate_persistent(
- dtype_, context->input(1).shape(), &unused, &tmp, attr));
- *(*ptr)->tensor() = *tmp;
- return Status::OK();
- }));
+ const Tensor& value = context->input(1);
+ // Note: every resource-variable-manipulating op assumes copy-on-write
+ // semantics, and creates a copy of the variable's Tensor if its refcount is
+ // bigger than 1 when we try to modify it. This means we never need to copy
+ // the original tensor for AssignVariableOp; even if there are other live
+ // users of it we know none can modify it so this is always safe (even in
+ // esoteric cases where the same tensor is used to initialize multiple
+ // variables or the tensor is a constant this is safe, as future writes will
+ // trigger copies).
+ OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
+ context, HandleFromInput(context, 0), &variable,
+ [this, &value](Var** ptr) {
+ *ptr = new Var(dtype_);
+ *(*ptr)->tensor() = value;
+ (*ptr)->is_initialized = true;
+ return Status::OK();
+ }));
core::ScopedUnref s(variable);
-
OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
errors::InvalidArgument(
"Trying to assign variable with wrong dtype. Expected ",
DataTypeString(variable->tensor()->dtype()), " got ",
DataTypeString(dtype_)));
-
- const Tensor& value = context->input(1);
- AllocatorAttributes attr;
- if (!relax_constraints_) {
- attr.set_gpu_compatible(true);
- attr.set_nic_compatible(true);
- }
-
- // Copying is unnecessary if we are the last user of the value
- // tensor, we can just adopt the input tensor's buffer instead.
- std::unique_ptr<Tensor> input_alias = context->forward_input(
- 1, OpKernelContext::Params::kNoReservation /*output_index*/, dtype_,
- value.shape(), DEVICE_MEMORY, attr);
mutex_lock ml(*variable->mu());
variable->is_initialized = true;
- if (input_alias) {
- *variable->tensor() = *input_alias;
- return;
- }
-
- // Need to copy, but maybe we can re-use variable's buffer?
- if (!variable->tensor()->RefCountIsOne() ||
- !variable->tensor()->shape().IsSameSize(value.shape())) {
- // Copy to new buffer
- PersistentTensor unused;
- Tensor* tmp;
- OP_REQUIRES_OK(context, context->allocate_persistent(
- dtype_, value.shape(), &unused, &tmp, attr));
- *variable->tensor() = *tmp;
- }
- functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
- copy_functor(context->eigen_device<Device>(), variable->tensor()->flat<T>(),
- value.flat<T>());
+ *variable->tensor() = value;
}
private: