aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/resource_variable_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/resource_variable_ops.cc')
-rw-r--r--tensorflow/core/ops/resource_variable_ops.cc92
1 files changed, 63 insertions, 29 deletions
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index 0d8cf78cc2..3d0a6c2157 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -167,27 +167,75 @@ REGISTER_OP("ResourceGather")
return Status::OK();
});
+namespace {
+
+Status ResourceScatterUpdateShape(InferenceContext* c) {
+ ShapeAndType handle_shape_and_type;
+ TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &handle_shape_and_type));
+ ShapeHandle var_shape = handle_shape_and_type.shape;
+ ShapeHandle indices_shape = c->input(1);
+
+ ShapeHandle unused_updates_shape;
+ ShapeHandle concat;
+ ShapeHandle var_subshape;
+ TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
+ TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
+ TF_RETURN_IF_ERROR(
+ InferenceContext::Rank(c->input(2)) == 0
+ ? Status::OK()
+ : c->Merge(c->input(2), concat, &unused_updates_shape));
+ return Status::OK();
+}
+
+} // namespace
+
REGISTER_OP("ResourceScatterAdd")
.Input("resource: resource")
.Input("indices: Tindices")
.Input("updates: dtype")
.Attr("dtype: numbertype")
.Attr("Tindices: {int32, int64}")
- .SetShapeFn([](InferenceContext* c) {
- ShapeAndType handle_shape_and_type;
- TF_RETURN_IF_ERROR(
- ValidateVariableResourceHandle(c, &handle_shape_and_type));
- ShapeHandle var_shape = handle_shape_and_type.shape;
- ShapeHandle indices_shape = c->input(1);
+ .SetShapeFn(ResourceScatterUpdateShape);
- ShapeHandle unused_updates_shape;
- ShapeHandle concat;
- ShapeHandle var_subshape;
- TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
- TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
- TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
- return Status::OK();
- });
+REGISTER_OP("ResourceScatterSub")
+ .Input("resource: resource")
+ .Input("indices: Tindices")
+ .Input("updates: dtype")
+ .Attr("dtype: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(ResourceScatterUpdateShape);
+
+REGISTER_OP("ResourceScatterMul")
+ .Input("resource: resource")
+ .Input("indices: Tindices")
+ .Input("updates: dtype")
+ .Attr("dtype: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(ResourceScatterUpdateShape);
+
+REGISTER_OP("ResourceScatterDiv")
+ .Input("resource: resource")
+ .Input("indices: Tindices")
+ .Input("updates: dtype")
+ .Attr("dtype: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(ResourceScatterUpdateShape);
+
+REGISTER_OP("ResourceScatterMin")
+ .Input("resource: resource")
+ .Input("indices: Tindices")
+ .Input("updates: dtype")
+ .Attr("dtype: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(ResourceScatterUpdateShape);
+
+REGISTER_OP("ResourceScatterMax")
+ .Input("resource: resource")
+ .Input("indices: Tindices")
+ .Input("updates: dtype")
+ .Attr("dtype: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(ResourceScatterUpdateShape);
REGISTER_OP("ResourceScatterUpdate")
.Input("resource: resource")
@@ -195,21 +243,7 @@ REGISTER_OP("ResourceScatterUpdate")
.Input("updates: dtype")
.Attr("dtype: type")
.Attr("Tindices: {int32, int64}")
- .SetShapeFn([](InferenceContext* c) {
- ShapeAndType handle_shape_and_type;
- TF_RETURN_IF_ERROR(
- ValidateVariableResourceHandle(c, &handle_shape_and_type));
- ShapeHandle var_shape = handle_shape_and_type.shape;
- ShapeHandle indices_shape = c->input(1);
-
- ShapeHandle unused_updates_shape;
- ShapeHandle concat;
- ShapeHandle var_subshape;
- TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
- TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
- TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
- return Status::OK();
- });
+ .SetShapeFn(ResourceScatterUpdateShape);
REGISTER_OP("MutexV2")
.Attr("container: string = ''")