diff options
Diffstat (limited to 'tensorflow/core/kernels/variable_ops.h')
-rw-r--r-- | tensorflow/core/kernels/variable_ops.h | 146 |
1 files changed, 146 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h new file mode 100644 index 0000000000..77d2da0ad4 --- /dev/null +++ b/tensorflow/core/kernels/variable_ops.h @@ -0,0 +1,146 @@ +#ifndef TENSORFLOW_KERNELS_VARIABLE_OPS_H_ +#define TENSORFLOW_KERNELS_VARIABLE_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class VariableOp : public OpKernel { + public: + explicit VariableOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); + dtype_ = RemoveRefType(context->output_type(0)); + } + + ~VariableOp() override { + if (var_) var_->Unref(); + } + + void Compute(OpKernelContext* ctx) override { + mutex_lock l(init_mu_); + if (var_ == nullptr) { + OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), + true /* use name() */)); + auto creator = [this](Var** var) { + *var = new Var(dtype_); + (*var)->tensor()->set_shape(shape_); + return Status::OK(); + }; + OP_REQUIRES_OK(ctx, + cinfo_.resource_manager()->LookupOrCreate<Var>( + cinfo_.container(), cinfo_.name(), &var_, creator)); + } + // Output a reference to our tensor, so it may be updated. + // + // As long as *this is alive, the ref we return here is valid + // because *this owns a ref on var_. + ctx->set_output_ref(0, var_->mu(), var_->tensor()); + } + + private: + class Var : public ResourceBase { + public: + explicit Var(DataType dtype) : tensor_(dtype) {} + mutex* mu() { return &mu_; } + Tensor* tensor() { return &tensor_; } + + string DebugString() override { + return strings::StrCat(DataTypeString(tensor_.dtype()), "/", + tensor_.shape().ShortDebugString()); + } + + private: + mutex mu_; + Tensor tensor_; + + ~Var() override {} + TF_DISALLOW_COPY_AND_ASSIGN(Var); + }; + + DataType dtype_; + TensorShape shape_; + + mutex init_mu_; + ContainerInfo cinfo_ GUARDED_BY(init_mu_); + Var* var_ GUARDED_BY(init_mu_) = nullptr; + + TF_DISALLOW_COPY_AND_ASSIGN(VariableOp); +}; + +class TemporaryVariableOp : public OpKernel { + public: + explicit TemporaryVariableOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); + OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); + OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_)); + // Variable name defaults to op name if not specified explicitly. + if (var_name_ == "") var_name_ = name(); + } + + void Compute(OpKernelContext* context) override { + Status s; + ResourceMgr* rm = context->step_resource_manager(); + OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager.")); + auto* tmp_var = new TmpVar; + OP_REQUIRES(context, tmp_var, + errors::ResourceExhausted("Could not allocate TmpVar.")); + tmp_var->name = var_name_; + s = context->allocate_temp(dtype_, shape_, &tmp_var->val); + if (!s.ok()) tmp_var->Unref(); + OP_REQUIRES_OK(context, s); + OP_REQUIRES_OK(context, rm->Create("tmp_var", var_name_, tmp_var)); + context->set_output_ref(0, &tmp_var->mu, &tmp_var->val); + } + + private: + // Refcounted temporary variable resource. + friend class DestroyTemporaryVariableOp; + struct TmpVar : public ResourceBase { + mutex mu; + Tensor val; + string name; + string DebugString() override { return name; } + ~TmpVar() override { VLOG(3) << "TmpVar " << name << " deleted"; } + }; + + TensorShape shape_; + DataType dtype_; + string var_name_; +}; + +class DestroyTemporaryVariableOp : public OpKernel { + public: + explicit DestroyTemporaryVariableOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES(context, IsRefType(context->input_type(0)), + errors::InvalidArgument("lhs input needs to be a ref type")) + OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_)); + OP_REQUIRES(context, var_name_ != "", + errors::InvalidArgument("Missing var_name attribute")); + } + + void Compute(OpKernelContext* context) override { + // NOTE(pbar): All other mutators of the Tensor Ref *must* have completed + // their execution before this DestroyTemporaryVariable op executes. + // This is typically achieved using control dependencies. + CHECK(IsRefType(context->input_dtype(0))); + Tensor tmpvar = context->mutable_input(0, false); + context->set_output(0, tmpvar); + ResourceMgr* rm = context->step_resource_manager(); + OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager.")); + OP_REQUIRES_OK( + context, rm->Delete<TemporaryVariableOp::TmpVar>("tmp_var", var_name_)); + } + + private: + string var_name_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_VARIABLE_OPS_H_ |