diff options
Diffstat (limited to 'tensorflow/core/kernels/resource_variable_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/resource_variable_ops.cc | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index 115a8eb251..ebcfb673d1 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -232,12 +232,12 @@ class AssignVariableOp : public OpKernel { return Status::OK(); })); core::ScopedUnref s(variable); + mutex_lock ml(*variable->mu()); OP_REQUIRES(context, variable->tensor()->dtype() == dtype_, errors::InvalidArgument( "Trying to assign variable with wrong dtype. Expected ", DataTypeString(variable->tensor()->dtype()), " got ", DataTypeString(dtype_))); - mutex_lock ml(*variable->mu()); variable->is_initialized = true; *variable->tensor() = value; } @@ -268,11 +268,6 @@ class AssignVariableOp<Device, Variant> : public OpKernel { return Status::OK(); })); core::ScopedUnref s(variable); - OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT, - errors::InvalidArgument( - "Trying to assign variable with wrong dtype. Expected ", - DataTypeString(variable->tensor()->dtype()), " got ", - DataTypeString(DT_VARIANT))); // For purposes of forwarding DT_VARIANT, we want the least // restrictive attr; we already know the input is on host. @@ -293,6 +288,11 @@ class AssignVariableOp<Device, Variant> : public OpKernel { attr); mutex_lock ml(*variable->mu()); + OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT, + errors::InvalidArgument( + "Trying to assign variable with wrong dtype. Expected ", + DataTypeString(variable->tensor()->dtype()), " got ", + DataTypeString(DT_VARIANT))); variable->is_initialized = true; *variable->tensor() = Tensor(DT_VARIANT, value.shape()); |