diff options
Diffstat (limited to 'tensorflow/core/ops/state_ops.cc')
-rw-r--r-- | tensorflow/core/ops/state_ops.cc | 290 |
1 files changed, 290 insertions, 0 deletions
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc new file mode 100644 index 0000000000..da9fd4ad08 --- /dev/null +++ b/tensorflow/core/ops/state_ops.cc @@ -0,0 +1,290 @@ +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("Variable") + .Output("ref: Ref(dtype)") + .Attr("shape: shape") + .Attr("dtype: type") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .Doc(R"doc( +Holds state in the form of a tensor that persists across steps. + +Outputs a ref to the tensor state so it may be read or modified. +TODO(zhifengc/mrry): Adds a pointer to a more detail document +about sharing states in tensorflow. + +ref: A reference to the variable tensor. +shape: The shape of the variable tensor. +dtype: The type of elements in the variable tensor. +container: If non-empty, this variable is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this variable is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. +)doc"); + +REGISTER_OP("TemporaryVariable") + .Output("ref: Ref(dtype)") + .Attr("shape: shape") + .Attr("dtype: type") + .Attr("var_name: string = ''") + .SetIsStateful() + .Doc(R"doc( +Returns a tensor that may be mutated, but only persists within a single step. + +This is an experimental op for internal use only and it is possible to use this +op in unsafe ways. DO NOT USE unless you fully understand the risks. + +It is the caller's responsibility to ensure that 'ref' is eventually passed to a +matching 'DestroyTemporaryVariable' op after all other uses have completed. + +Outputs a ref to the tensor state so it may be read or modified. + + E.g. + var = state_ops._temporary_variable([1, 2], types.float_) + var_name = var.op.name + var = state_ops.assign(var, [[4.0, 5.0]]) + var = state_ops.assign_add(var, [[6.0, 7.0]]) + final = state_ops._destroy_temporary_variable(var, var_name=var_name) + +ref: A reference to the variable tensor. +shape: The shape of the variable tensor. +dtype: The type of elements in the variable tensor. +var_name: Overrides the name used for the temporary variable resource. Default +value is the name of the 'TemporaryVariable' op (which is guaranteed unique). +)doc"); + +REGISTER_OP("DestroyTemporaryVariable") + .Input("ref: Ref(T)") + .Output("value: T") + .Attr("T: type") + .Attr("var_name: string") + .Doc(R"doc( +Destroys the temporary variable and returns its final value. + +Sets output to the value of the Tensor pointed to by 'ref', then destroys +the temporary variable called 'var_name'. +All other uses of 'ref' *must* have executed before this op. +This is typically achieved by chaining the ref through each assign op, or by +using control dependencies. + +Outputs the final value of the tensor pointed to by 'ref'. + +ref: A reference to the temporary variable tensor. +var_name: Name of the temporary variable, usually the name of the matching +'TemporaryVariable' op. +)doc"); + +REGISTER_OP("Assign") + .Input("ref: Ref(T)") + .Input("value: T") + .Output("output_ref: Ref(T)") + .Attr("T: type") + .Attr("validate_shape: bool = true") + .Attr("use_locking: bool = true") + .SetAllowsUninitializedInput() + .Doc(R"doc( +Update 'ref' by assigning 'value' to it. + +This operation outputs "ref" after the assignment is done. +This makes it easier to chain operations that need to use the reset value. + +ref: Should be from a `Variable` node. May be uninitialized. +value: The value to be assigned to the variable. +validate_shape: If true, the operation will validate that the shape + of 'value' matches the shape of the Tensor being assigned to. If false, + 'ref' will take on the shape of 'value'. +use_locking: If True, the assignment will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +output_ref:= Same as "ref". Returned as a convenience for operations that want + to use the new value after the variable has been reset. +)doc"); + +REGISTER_OP("AssignAdd") + .Input("ref: Ref(T)") + .Input("value: T") + .Output("output_ref: Ref(T)") + .Attr("T: numbertype") + .Attr("use_locking: bool = false") + .Doc(R"doc( +Update 'ref' by adding 'value' to it. + +This operation outputs "ref" after the update is done. +This makes it easier to chain operations that need to use the reset value. + +ref: Should be from a `Variable` node. +value: The value to be added to the variable. +use_locking: If True, the addition will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +output_ref:= Same as "ref". Returned as a convenience for operations that want + to use the new value after the variable has been updated. +)doc"); + +REGISTER_OP("AssignSub") + .Input("ref: Ref(T)") + .Input("value: T") + .Output("output_ref: Ref(T)") + .Attr("T: numbertype") + .Attr("use_locking: bool = false") + .Doc(R"doc( +Update 'ref' by subtracting 'value' from it. + +This operation outputs "ref" after the update is done. +This makes it easier to chain operations that need to use the reset value. + +ref: Should be from a `Variable` node. +value: The value to be subtracted to the variable. +use_locking: If True, the subtraction will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +output_ref:= Same as "ref". Returned as a convenience for operations that want + to use the new value after the variable has been updated. +)doc"); + +REGISTER_OP("ScatterUpdate") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = true") + .Doc(R"doc( +Applies sparse updates to a variable reference. + +This operation computes + + # Scalar indices + ref[indices, ...] = updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] = updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] + +This operation outputs `ref` after the update is done. +This makes it easier to chain operations that need to use the reset value. + +If `indices` contains duplicate entries, lexicographically later entries +override earlier entries. + +Requires `updates.shape = indices.shape + ref.shape[1:]`. + +<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../images/ScatterUpdate.png" alt> +</div> + +ref: Should be from a `Variable` node. +indices: A tensor of indices into the first dimension of `ref`. +updates: A tensor of updated values to store in `ref`. +output_ref:= Same as `ref`. Returned as a convenience for operations that want + to use the updated values after the update is done. +use_locking: If True, the assignment will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +)doc"); + +REGISTER_OP("ScatterAdd") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .Doc(R"doc( +Adds sparse updates to a variable reference. + +This operation computes + + # Scalar indices + ref[indices, ...] += updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] += updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] + +This operation outputs `ref` after the update is done. +This makes it easier to chain operations that need to use the reset value. + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions add. + +Requires `updates.shape = indices.shape + ref.shape[1:]`. + +<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../images/ScatterAdd.png" alt> +</div> + +ref: Should be from a `Variable` node. +indices: A tensor of indices into the first dimension of `ref`. +updates: A tensor of updated values to add to `ref`. +output_ref:= Same as `ref`. Returned as a convenience for operations that want + to use the updated values after the update is done. +use_locking: If True, the addition will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +)doc"); + +REGISTER_OP("ScatterSub") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .Doc(R"doc( +Subtracts sparse updates to a variable reference. + + # Scalar indices + ref[indices, ...] -= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] -= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...] + +This operation outputs `ref` after the update is done. +This makes it easier to chain operations that need to use the reset value. + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their (negated) contributions add. + +Requires `updates.shape = indices.shape + ref.shape[1:]`. + +<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="../images/ScatterSub.png" alt> +</div> + +ref: Should be from a `Variable` node. +indices: A tensor of indices into the first dimension of `ref`. +updates: A tensor of updated values to subtract from `ref`. +output_ref:= Same as `ref`. Returned as a convenience for operations that want + to use the updated values after the update is done. +use_locking: If True, the subtraction will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +)doc"); + +REGISTER_OP("CountUpTo") + .Input("ref: Ref(T)") + .Output("output: T") + .Attr("limit: int") + .Attr("T: {int32, int64}") + .Doc(R"doc( +Increments 'ref' until it reaches 'limit'. + +This operation outputs "ref" after the update is done. This makes it +easier to chain operations that need to use the updated value. + +ref: Should be from a scalar `Variable` node. +limit: If incrementing ref would bring it above limit, instead generates an + 'OutOfRange' error. +output: A copy of the input before increment. If nothing else modifies the + input, the values produced will all be distinct. +)doc"); + +} // namespace tensorflow |