diff options
author | Alexandre Passos <apassos@google.com> | 2018-07-26 13:26:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-26 13:30:43 -0700 |
commit | ca69ddc34b37258534d8327ec55a26b2add6a632 (patch) | |
tree | 373ccf4c116a750c0d92fc4ffeb8d6dcbae15306 | |
parent | 63563579653c1f0829d460eef5f05963111e08f0 (diff) |
ResourceVariables shouldn't need twice the memory when initializing.
This is safe because all ops which write to resource variables check whether
there are other outstanding references to the buffer and copy if that's the
case. So we can safely reuse the buffer of initializer tensors even in weird
cases such as initializing from a constant (which should never be mutated)
or using the same tensor to initialize multiple variables.
PiperOrigin-RevId: 206211065
-rw-r--r-- | tensorflow/core/kernels/resource_variable_ops.cc | 68 |
1 file changed, 18 insertions(+), 50 deletions(-)
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index c5292e1ae1..cab9eb729d 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -213,64 +213,32 @@ class AssignVariableOp : public OpKernel { "Variable and value dtypes don't match; respectively, ", dtype_, " and ", context->input(1).dtype())); Var* variable = nullptr; - OP_REQUIRES_OK( - context, - LookupOrCreateResource<Var>( - context, HandleFromInput(context, 0), &variable, - [this, context](Var** ptr) { - *ptr = new Var(dtype_); - PersistentTensor unused; - Tensor* tmp; - AllocatorAttributes attr; - if (!relax_constraints_) { - attr.set_gpu_compatible(true); - attr.set_nic_compatible(true); - } - TF_RETURN_IF_ERROR(context->allocate_persistent( - dtype_, context->input(1).shape(), &unused, &tmp, attr)); - *(*ptr)->tensor() = *tmp; - return Status::OK(); - })); + const Tensor& value = context->input(1); + // Note: every resource-variable-manipulating op assumes copy-on-write + // semantics, and creates a copy of the variable's Tensor if its refcount is + // bigger than 1 when we try to modify it. This means we never need to copy + // the original tensor for AssignVariableOp; even if there are other live + // users of it we know none can modify it so this is always safe (even in + // esoteric cases where the same tensor is used to initialize multiple + // variables or the tensor is a constant this is safe, as future writes will + // trigger copies). + OP_REQUIRES_OK(context, LookupOrCreateResource<Var>( + context, HandleFromInput(context, 0), &variable, + [this, &value](Var** ptr) { + *ptr = new Var(dtype_); + *(*ptr)->tensor() = value; + (*ptr)->is_initialized = true; + return Status::OK(); + })); core::ScopedUnref s(variable); - OP_REQUIRES(context, variable->tensor()->dtype() == dtype_, errors::InvalidArgument( "Trying to assign variable with wrong dtype. 
Expected ", DataTypeString(variable->tensor()->dtype()), " got ", DataTypeString(dtype_))); - - const Tensor& value = context->input(1); - AllocatorAttributes attr; - if (!relax_constraints_) { - attr.set_gpu_compatible(true); - attr.set_nic_compatible(true); - } - - // Copying is unnecessary if we are the last user of the value - // tensor, we can just adopt the input tensor's buffer instead. - std::unique_ptr<Tensor> input_alias = context->forward_input( - 1, OpKernelContext::Params::kNoReservation /*output_index*/, dtype_, - value.shape(), DEVICE_MEMORY, attr); mutex_lock ml(*variable->mu()); variable->is_initialized = true; - if (input_alias) { - *variable->tensor() = *input_alias; - return; - } - - // Need to copy, but maybe we can re-use variable's buffer? - if (!variable->tensor()->RefCountIsOne() || - !variable->tensor()->shape().IsSameSize(value.shape())) { - // Copy to new buffer - PersistentTensor unused; - Tensor* tmp; - OP_REQUIRES_OK(context, context->allocate_persistent( - dtype_, value.shape(), &unused, &tmp, attr)); - *variable->tensor() = *tmp; - } - functor::DenseUpdate<Device, T, ASSIGN> copy_functor; - copy_functor(context->eigen_device<Device>(), variable->tensor()->flat<T>(), - value.flat<T>()); + *variable->tensor() = value; } private: |