aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/state_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/state_ops.cc')
-rw-r--r--tensorflow/core/ops/state_ops.cc290
1 files changed, 290 insertions, 0 deletions
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
new file mode 100644
index 0000000000..da9fd4ad08
--- /dev/null
+++ b/tensorflow/core/ops/state_ops.cc
@@ -0,0 +1,290 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("Variable")
+ .Output("ref: Ref(dtype)")
+ .Attr("shape: shape")
+ .Attr("dtype: type")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+Holds state in the form of a tensor that persists across steps.
+
+Outputs a ref to the tensor state so it may be read or modified.
+TODO(zhifengc/mrry): Adds a pointer to a more detail document
+about sharing states in tensorflow.
+
+ref: A reference to the variable tensor.
+shape: The shape of the variable tensor.
+dtype: The type of elements in the variable tensor.
+container: If non-empty, this variable is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this variable is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+)doc");
+
+REGISTER_OP("TemporaryVariable")
+ .Output("ref: Ref(dtype)")
+ .Attr("shape: shape")
+ .Attr("dtype: type")
+ .Attr("var_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+Returns a tensor that may be mutated, but only persists within a single step.
+
+This is an experimental op for internal use only and it is possible to use this
+op in unsafe ways. DO NOT USE unless you fully understand the risks.
+
+It is the caller's responsibility to ensure that 'ref' is eventually passed to a
+matching 'DestroyTemporaryVariable' op after all other uses have completed.
+
+Outputs a ref to the tensor state so it may be read or modified.
+
+ E.g.
+ var = state_ops._temporary_variable([1, 2], types.float_)
+ var_name = var.op.name
+ var = state_ops.assign(var, [[4.0, 5.0]])
+ var = state_ops.assign_add(var, [[6.0, 7.0]])
+ final = state_ops._destroy_temporary_variable(var, var_name=var_name)
+
+ref: A reference to the variable tensor.
+shape: The shape of the variable tensor.
+dtype: The type of elements in the variable tensor.
+var_name: Overrides the name used for the temporary variable resource. Default
+value is the name of the 'TemporaryVariable' op (which is guaranteed unique).
+)doc");
+
+REGISTER_OP("DestroyTemporaryVariable")
+ .Input("ref: Ref(T)")
+ .Output("value: T")
+ .Attr("T: type")
+ .Attr("var_name: string")
+ .Doc(R"doc(
+Destroys the temporary variable and returns its final value.
+
+Sets output to the value of the Tensor pointed to by 'ref', then destroys
+the temporary variable called 'var_name'.
+All other uses of 'ref' *must* have executed before this op.
+This is typically achieved by chaining the ref through each assign op, or by
+using control dependencies.
+
+Outputs the final value of the tensor pointed to by 'ref'.
+
+ref: A reference to the temporary variable tensor.
+var_name: Name of the temporary variable, usually the name of the matching
+'TemporaryVariable' op.
+)doc");
+
+REGISTER_OP("Assign")
+ .Input("ref: Ref(T)")
+ .Input("value: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: type")
+ .Attr("validate_shape: bool = true")
+ .Attr("use_locking: bool = true")
+ .SetAllowsUninitializedInput()
+ .Doc(R"doc(
+Update 'ref' by assigning 'value' to it.
+
+This operation outputs "ref" after the assignment is done.
+This makes it easier to chain operations that need to use the reset value.
+
+ref: Should be from a `Variable` node. May be uninitialized.
+value: The value to be assigned to the variable.
+validate_shape: If true, the operation will validate that the shape
+ of 'value' matches the shape of the Tensor being assigned to. If false,
+ 'ref' will take on the shape of 'value'.
+use_locking: If True, the assignment will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+output_ref:= Same as "ref". Returned as a convenience for operations that want
+ to use the new value after the variable has been reset.
+)doc");
+
+REGISTER_OP("AssignAdd")
+ .Input("ref: Ref(T)")
+ .Input("value: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update 'ref' by adding 'value' to it.
+
+This operation outputs "ref" after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+ref: Should be from a `Variable` node.
+value: The value to be added to the variable.
+use_locking: If True, the addition will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+output_ref:= Same as "ref". Returned as a convenience for operations that want
+ to use the new value after the variable has been updated.
+)doc");
+
+REGISTER_OP("AssignSub")
+ .Input("ref: Ref(T)")
+ .Input("value: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update 'ref' by subtracting 'value' from it.
+
+This operation outputs "ref" after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+ref: Should be from a `Variable` node.
+value: The value to be subtracted to the variable.
+use_locking: If True, the subtraction will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+output_ref:= Same as "ref". Returned as a convenience for operations that want
+ to use the new value after the variable has been updated.
+)doc");
+
+REGISTER_OP("ScatterUpdate")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = true")
+ .Doc(R"doc(
+Applies sparse updates to a variable reference.
+
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] = updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] = updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+If `indices` contains duplicate entries, lexicographically later entries
+override earlier entries.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/ScatterUpdate.png" alt>
+</div>
+
+ref: Should be from a `Variable` node.
+indices: A tensor of indices into the first dimension of `ref`.
+updates: A tensor of updated values to store in `ref`.
+output_ref:= Same as `ref`. Returned as a convenience for operations that want
+ to use the updated values after the update is done.
+use_locking: If True, the assignment will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("ScatterAdd")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Adds sparse updates to a variable reference.
+
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] += updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] += updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions add.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/ScatterAdd.png" alt>
+</div>
+
+ref: Should be from a `Variable` node.
+indices: A tensor of indices into the first dimension of `ref`.
+updates: A tensor of updated values to add to `ref`.
+output_ref:= Same as `ref`. Returned as a convenience for operations that want
+ to use the updated values after the update is done.
+use_locking: If True, the addition will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("ScatterSub")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Subtracts sparse updates to a variable reference.
+
+ # Scalar indices
+ ref[indices, ...] -= updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] -= updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their (negated) contributions add.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/ScatterSub.png" alt>
+</div>
+
+ref: Should be from a `Variable` node.
+indices: A tensor of indices into the first dimension of `ref`.
+updates: A tensor of updated values to subtract from `ref`.
+output_ref:= Same as `ref`. Returned as a convenience for operations that want
+ to use the updated values after the update is done.
+use_locking: If True, the subtraction will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("CountUpTo")
+ .Input("ref: Ref(T)")
+ .Output("output: T")
+ .Attr("limit: int")
+ .Attr("T: {int32, int64}")
+ .Doc(R"doc(
+Increments 'ref' until it reaches 'limit'.
+
+This operation outputs "ref" after the update is done. This makes it
+easier to chain operations that need to use the updated value.
+
+ref: Should be from a scalar `Variable` node.
+limit: If incrementing ref would bring it above limit, instead generates an
+ 'OutOfRange' error.
+output: A copy of the input before increment. If nothing else modifies the
+ input, the values produced will all be distinct.
+)doc");
+
+} // namespace tensorflow