aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-12-07 20:29:06 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-07 20:32:19 -0800
commit233aff82ad661b4792577690044b6a24132a2470 (patch)
tree8b645adcb39bbd91077f37e357d79c7f4aad3a77
parent09030980ea42dd1f7c0058c15c27fc74c7c505ec (diff)
Create a new Var-like object, LegacyVar, which allows access to its mutex.
Future changes will change how locking happens on the resource-specific Var object. Also hide any access to LegacyVar in the implementation file; and move other ops into the .cc file where they belong. PiperOrigin-RevId: 178334244
-rw-r--r--tensorflow/core/kernels/variable_ops.cc211
-rw-r--r--tensorflow/core/kernels/variable_ops.h158
2 files changed, 200 insertions, 169 deletions
diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc
index 36b8ff09d7..1b7079dcba 100644
--- a/tensorflow/core/kernels/variable_ops.cc
+++ b/tensorflow/core/kernels/variable_ops.cc
@@ -23,6 +23,177 @@ limitations under the License.
namespace tensorflow {
+// Resource stored by variables in the resource manager
+// (legacy, ref-style version).
+class LegacyVar : public ResourceBase {
+ public:
+ explicit LegacyVar(DataType dtype) : tensor_(dtype) {}
+ // Not copyable or movable.
+ LegacyVar(const LegacyVar&) = delete;
+ LegacyVar& operator=(const LegacyVar&) = delete;
+
+ mutex* mu() { return &mu_; }
+ Tensor* tensor() { return &tensor_; }
+
+ string DebugString() override {
+ return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
+ tensor_.shape().DebugString());
+ }
+
+ private:
+ mutex mu_;
+ Tensor tensor_;
+
+ ~LegacyVar() override {}
+};
+
+VariableOp::VariableOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
+ dtype_ = RemoveRefType(context->output_type(0));
+}
+
+void VariableOp::Compute(OpKernelContext* ctx) {
+ mutex_lock l(init_mu_);
+ if (!initialized_) {
+ OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(),
+ true /* use name() */));
+ initialized_ = true;
+ }
+ auto creator = [this](LegacyVar** var) {
+ *var = new LegacyVar(dtype_);
+ (*var)->tensor()->set_shape(shape_);
+ return Status::OK();
+ };
+ LegacyVar* var;
+ OP_REQUIRES_OK(ctx, cinfo_.resource_manager()->LookupOrCreate<LegacyVar>(
+ cinfo_.container(), cinfo_.name(), &var, creator));
+ // Output a reference to our tensor, so it may be updated.
+ //
+ // As long as the resource manager hasn't been cleared the ref we return
+ // here is valid because it owns a ref on var.
+ ctx->set_output_ref(0, var->mu(), var->tensor());
+ if (ctx->track_allocations() && var->tensor()->IsInitialized()) {
+ AllocatorAttributes attr;
+ attr.set_gpu_compatible(true);
+ attr.set_nic_compatible(true);
+ if (ctx->allocate_on_host(attr)) {
+ ctx->record_host_persistent_memory_allocation(
+ var->tensor()->AllocatedBytes());
+ } else {
+ ctx->record_device_persistent_memory_allocation(
+ var->tensor()->AllocatedBytes());
+ }
+ }
+ var->Unref();
+}
+
+class TemporaryVariableOp : public OpKernel {
+ public:
+ explicit TemporaryVariableOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
+ OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
+ OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
+ // Variable name defaults to op name if not specified explicitly.
+ if (var_name_.empty()) var_name_ = name();
+ }
+
+ void Compute(OpKernelContext* context) override {
+ Status s;
+ ResourceMgr* rm = context->resource_manager();
+ OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
+ auto* tmp_var = new TmpVar;
+ OP_REQUIRES(context, tmp_var,
+ errors::ResourceExhausted("Could not allocate TmpVar."));
+ tmp_var->name = var_name_;
+ s = context->allocate_temp(dtype_, shape_, &tmp_var->val);
+ if (!s.ok()) tmp_var->Unref();
+ OP_REQUIRES_OK(context, s);
+ OP_REQUIRES_OK(context, rm->Create(context->step_container()->name(),
+ var_name_, tmp_var));
+ context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
+ if (context->track_allocations()) {
+ AllocatorAttributes attr;
+ if (context->allocate_on_host(attr)) {
+ context->record_host_persistent_memory_allocation(
+ tmp_var->val.AllocatedBytes());
+ } else {
+ context->record_device_persistent_memory_allocation(
+ tmp_var->val.AllocatedBytes());
+ }
+ }
+ }
+
+ private:
+ // Refcounted temporary variable resource.
+ friend class DestroyTemporaryVariableOp;
+ struct TmpVar : public ResourceBase {
+ mutex mu;
+ Tensor val;
+ string name;
+ string DebugString() override { return name; }
+ ~TmpVar() override { VLOG(3) << "TmpVar " << name << " deleted"; }
+ };
+
+ TensorShape shape_;
+ DataType dtype_;
+ string var_name_;
+};
+
+class DestroyTemporaryVariableOp : public OpKernel {
+ public:
+ explicit DestroyTemporaryVariableOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES(context, IsRefType(context->input_type(0)),
+ errors::InvalidArgument("lhs input needs to be a ref type"));
+ OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
+ OP_REQUIRES(context, !var_name_.empty(),
+ errors::InvalidArgument("Missing var_name attribute"));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // NOTE(pbar): All other mutators of the Tensor Ref *must* have completed
+ // their execution before this DestroyTemporaryVariable op executes.
+ // This is typically achieved using control dependencies.
+ CHECK(IsRefType(context->input_dtype(0)));
+ Tensor tmpvar = context->mutable_input(0, false);
+ context->set_output(0, tmpvar);
+ ResourceMgr* rm = context->resource_manager();
+ OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
+ OP_REQUIRES_OK(context, rm->Delete<TemporaryVariableOp::TmpVar>(
+ context->step_container()->name(), var_name_));
+ if (context->track_allocations()) {
+ if (context->allocate_on_host(AllocatorAttributes())) {
+ context->record_host_persistent_memory_allocation(
+ -static_cast<int64>(tmpvar.AllocatedBytes()));
+ } else {
+ context->record_device_persistent_memory_allocation(
+ -static_cast<int64>(tmpvar.AllocatedBytes()));
+ }
+ }
+ }
+
+ private:
+ string var_name_;
+};
+
+class IsVariableInitializedOp : public OpKernel {
+ public:
+ explicit IsVariableInitializedOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // Get a mutable input tensor of the Ref input.
+ const Tensor& input_tensor = context->mutable_input(0, false);
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &output));
+ auto output_tensor = output->tensor<bool, 0>();
+ bool result = input_tensor.IsInitialized();
+ output_tensor() = result;
+ }
+};
+
REGISTER_KERNEL_BUILDER(Name("Variable").Device(DEVICE_CPU), VariableOp);
REGISTER_KERNEL_BUILDER(Name("VariableV2").Device(DEVICE_CPU), VariableOp);
REGISTER_KERNEL_BUILDER(Name("TemporaryVariable").Device(DEVICE_CPU),
@@ -33,30 +204,30 @@ REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized").Device(DEVICE_CPU),
IsVariableInitializedOp);
#ifdef TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNEL(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Variable").Device(DEVICE_SYCL).TypeConstraint<type>("dtype"), \
- VariableOp); \
- REGISTER_KERNEL_BUILDER( \
- Name("VariableV2").Device(DEVICE_SYCL).TypeConstraint<type>("dtype"),\
- VariableOp); \
- REGISTER_KERNEL_BUILDER(Name("TemporaryVariable") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<type>("dtype"), \
- TemporaryVariableOp); \
- REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<type>("T"), \
- DestroyTemporaryVariableOp); \
- REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<type>("dtype") \
- .HostMemory("is_initialized"), \
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Variable").Device(DEVICE_SYCL).TypeConstraint<type>("dtype"), \
+ VariableOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("VariableV2").Device(DEVICE_SYCL).TypeConstraint<type>("dtype"), \
+ VariableOp); \
+ REGISTER_KERNEL_BUILDER(Name("TemporaryVariable") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("dtype"), \
+ TemporaryVariableOp); \
+ REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T"), \
+ DestroyTemporaryVariableOp); \
+ REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("dtype") \
+ .HostMemory("is_initialized"), \
IsVariableInitializedOp);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA
// Only register 'Variable' on GPU for the subset of types also supported by
diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h
index 820b90d041..83134bad37 100644
--- a/tensorflow/core/kernels/variable_ops.h
+++ b/tensorflow/core/kernels/variable_ops.h
@@ -27,10 +27,16 @@ limitations under the License.
namespace tensorflow {
-// Resource stored by variables in the resource manager.
+// Resource stored by variables in the resource manager
+// (new, resource-style version).
class Var : public ResourceBase {
public:
explicit Var(DataType dtype) : tensor_(dtype) {}
+ // Not copyable or movable.
+ Var(const Var&) = delete;
+ Var& operator=(const Var&) = delete;
+
+ // TODO(ebrevdo): Use LockSet instead of exposing mu.
mutex* mu() { return &mu_; }
Tensor* tensor() { return &tensor_; }
@@ -44,52 +50,12 @@ class Var : public ResourceBase {
Tensor tensor_;
~Var() override {}
- TF_DISALLOW_COPY_AND_ASSIGN(Var);
};
class VariableOp : public OpKernel {
public:
- explicit VariableOp(OpKernelConstruction* context) : OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
- dtype_ = RemoveRefType(context->output_type(0));
- }
-
- void Compute(OpKernelContext* ctx) override {
- mutex_lock l(init_mu_);
- if (!initialized_) {
- OP_REQUIRES_OK(
- ctx,
- cinfo_.Init(ctx->resource_manager(), def(), true /* use name() */));
- initialized_ = true;
- }
- auto creator = [this](Var** var) {
- *var = new Var(dtype_);
- (*var)->tensor()->set_shape(shape_);
- return Status::OK();
- };
- Var* var;
- OP_REQUIRES_OK(ctx,
- cinfo_.resource_manager()->LookupOrCreate<Var>(
- cinfo_.container(), cinfo_.name(), &var, creator));
- // Output a reference to our tensor, so it may be updated.
- //
- // As long as the resource manager hasn't been cleared the ref we return
- // here is valid because it owns a ref on var.
- ctx->set_output_ref(0, var->mu(), var->tensor());
- if (ctx->track_allocations() && var->tensor()->IsInitialized()) {
- AllocatorAttributes attr;
- attr.set_gpu_compatible(true);
- attr.set_nic_compatible(true);
- if (ctx->allocate_on_host(attr)) {
- ctx->record_host_persistent_memory_allocation(
- var->tensor()->AllocatedBytes());
- } else {
- ctx->record_device_persistent_memory_allocation(
- var->tensor()->AllocatedBytes());
- }
- }
- var->Unref();
- }
+ explicit VariableOp(OpKernelConstruction* context);
+ void Compute(OpKernelContext* ctx) override;
private:
DataType dtype_;
@@ -102,112 +68,6 @@ class VariableOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(VariableOp);
};
-class TemporaryVariableOp : public OpKernel {
- public:
- explicit TemporaryVariableOp(OpKernelConstruction* context)
- : OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
- OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
- OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
- // Variable name defaults to op name if not specified explicitly.
- if (var_name_ == "") var_name_ = name();
- }
-
- void Compute(OpKernelContext* context) override {
- Status s;
- ResourceMgr* rm = context->resource_manager();
- OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
- auto* tmp_var = new TmpVar;
- OP_REQUIRES(context, tmp_var,
- errors::ResourceExhausted("Could not allocate TmpVar."));
- tmp_var->name = var_name_;
- s = context->allocate_temp(dtype_, shape_, &tmp_var->val);
- if (!s.ok()) tmp_var->Unref();
- OP_REQUIRES_OK(context, s);
- OP_REQUIRES_OK(context, rm->Create(context->step_container()->name(),
- var_name_, tmp_var));
- context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
- if (context->track_allocations()) {
- AllocatorAttributes attr;
- if (context->allocate_on_host(attr)) {
- context->record_host_persistent_memory_allocation(
- tmp_var->val.AllocatedBytes());
- } else {
- context->record_device_persistent_memory_allocation(
- tmp_var->val.AllocatedBytes());
- }
- }
- }
-
- private:
- // Refcounted temporary variable resource.
- friend class DestroyTemporaryVariableOp;
- struct TmpVar : public ResourceBase {
- mutex mu;
- Tensor val;
- string name;
- string DebugString() override { return name; }
- ~TmpVar() override { VLOG(3) << "TmpVar " << name << " deleted"; }
- };
-
- TensorShape shape_;
- DataType dtype_;
- string var_name_;
-};
-
-class DestroyTemporaryVariableOp : public OpKernel {
- public:
- explicit DestroyTemporaryVariableOp(OpKernelConstruction* context)
- : OpKernel(context) {
- OP_REQUIRES(context, IsRefType(context->input_type(0)),
- errors::InvalidArgument("lhs input needs to be a ref type"));
- OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
- OP_REQUIRES(context, var_name_ != "",
- errors::InvalidArgument("Missing var_name attribute"));
- }
-
- void Compute(OpKernelContext* context) override {
- // NOTE(pbar): All other mutators of the Tensor Ref *must* have completed
- // their execution before this DestroyTemporaryVariable op executes.
- // This is typically achieved using control dependencies.
- CHECK(IsRefType(context->input_dtype(0)));
- Tensor tmpvar = context->mutable_input(0, false);
- context->set_output(0, tmpvar);
- ResourceMgr* rm = context->resource_manager();
- OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
- OP_REQUIRES_OK(context, rm->Delete<TemporaryVariableOp::TmpVar>(
- context->step_container()->name(), var_name_));
- if (context->track_allocations()) {
- if (context->allocate_on_host(AllocatorAttributes())) {
- context->record_host_persistent_memory_allocation(
- -static_cast<int64>(tmpvar.AllocatedBytes()));
- } else {
- context->record_device_persistent_memory_allocation(
- -static_cast<int64>(tmpvar.AllocatedBytes()));
- }
- }
- }
-
- private:
- string var_name_;
-};
-
-class IsVariableInitializedOp : public OpKernel {
- public:
- IsVariableInitializedOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- // Get a mutable input tensor of the Ref input.
- const Tensor& input_tensor = context->mutable_input(0, false);
- Tensor* output = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, TensorShape({}), &output));
- auto output_tensor = output->tensor<bool, 0>();
- bool result = input_tensor.IsInitialized();
- output_tensor() = result;
- }
-};
-
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_VARIABLE_OPS_H_