aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/resource_variable_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/resource_variable_ops.cc')
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc12
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 115a8eb251..ebcfb673d1 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -232,12 +232,12 @@ class AssignVariableOp : public OpKernel {
return Status::OK();
}));
core::ScopedUnref s(variable);
+ mutex_lock ml(*variable->mu());
OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
errors::InvalidArgument(
"Trying to assign variable with wrong dtype. Expected ",
DataTypeString(variable->tensor()->dtype()), " got ",
DataTypeString(dtype_)));
- mutex_lock ml(*variable->mu());
variable->is_initialized = true;
*variable->tensor() = value;
}
@@ -268,11 +268,6 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
return Status::OK();
}));
core::ScopedUnref s(variable);
- OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
- errors::InvalidArgument(
- "Trying to assign variable with wrong dtype. Expected ",
- DataTypeString(variable->tensor()->dtype()), " got ",
- DataTypeString(DT_VARIANT)));
// For purposes of forwarding DT_VARIANT, we want the least
// restrictive attr; we already know the input is on host.
@@ -293,6 +288,11 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
attr);
mutex_lock ml(*variable->mu());
+ OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
+ errors::InvalidArgument(
+ "Trying to assign variable with wrong dtype. Expected ",
+ DataTypeString(variable->tensor()->dtype()), " got ",
+ DataTypeString(DT_VARIANT)));
variable->is_initialized = true;
*variable->tensor() = Tensor(DT_VARIANT, value.shape());