about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/ops/training_ops.cc
diff options
context:
space:
mode:
author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-07-18 16:58:41 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-07-18 18:03:33 -0700
commit    075d1ab68fcaf5ddd90c3d30f6a046b78c3c04f3 (patch)
tree      ff341eceaa08df93ffafa1e6f13e6e2cda2dba48 /tensorflow/core/ops/training_ops.cc
parent    6696de2b02d9f2c6d16fd39ef388be87309525bb (diff)
Add C++ shape inference functions for the Sparse ops in training_ops.cc.
Change: 127782110
Diffstat (limited to 'tensorflow/core/ops/training_ops.cc')
-rw-r--r--  tensorflow/core/ops/training_ops.cc  140
1 file changed, 105 insertions, 35 deletions
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index d80cbf6c35..f779215bc1 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -23,6 +23,30 @@ typedef shape_inference::InferenceContext InferenceContext;
typedef shape_inference::Shape Shape;
static constexpr auto kUnknownDim = InferenceContext::kUnknownDim;
+// Handle the gradient and, if <sparse>, indices inputs.
+// <s> is an input+output parameter, containing the current known input shape to
+// the gradient.
+static Status HandleGradAndIndicesInputs(InferenceContext* c, bool sparse,
+ int grad_idx, const Shape** s) {
+ const Shape* grad = c->input(grad_idx);
+ if (!sparse) {
+ TF_RETURN_IF_ERROR(c->Merge(*s, grad, s));
+ return Status::OK();
+ }
+ // Indices is a vector where indices.dim[0].rank == grad[0].rank.
+ const Shape* indices;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(grad_idx + 1), 1, &indices));
+ const Dimension* unused;
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(indices, 0), c->Dim(grad, 0), &unused));
+
+ // Trailing part of grad matches *s.
+ const Shape* grad_subshape;
+ TF_RETURN_IF_ERROR(c->Subshape(grad, 1, &grad_subshape));
+ TF_RETURN_IF_ERROR(c->Merge(*s, grad_subshape, s));
+
+ return Status::OK();
+}
+
static Status ApplyGradientDescentShapeFn(InferenceContext* c) {
const Shape* unused;
const Shape* s = c->input(0); // var
@@ -51,13 +75,15 @@ use_locking: If `True`, the subtraction will be protected by a lock;
otherwise the behavior is undefined, but may exhibit less contention.
)doc");
-static Status ApplyProxiimalGradientDescentShapeFn(InferenceContext* c) {
+static Status ApplyProximalGradientDescentShapeFn(InferenceContext* c,
+ bool sparse) {
const Shape* unused;
const Shape* s = c->input(0); // var
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); // alpha
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // l1
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // l2
- TF_RETURN_IF_ERROR(c->Merge(s, c->input(4), &s)); // delta
+ TF_RETURN_IF_ERROR(
+ HandleGradAndIndicesInputs(c, sparse, 4 /* grad_idx */, &s));
c->set_output(0, s);
return Status::OK();
}
@@ -71,7 +97,9 @@ REGISTER_OP("ApplyProximalGradientDescent")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
- .SetShapeFn(OpShapeInferenceFn(ApplyProxiimalGradientDescentShapeFn))
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return ApplyProximalGradientDescentShapeFn(c, false /* sparse */);
+ }))
.Doc(R"doc(
Update '*var' as FOBOS algorithm with fixed learning rate.
prox_v = var - alpha * delta
@@ -98,6 +126,9 @@ REGISTER_OP("SparseApplyProximalGradientDescent")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return ApplyProximalGradientDescentShapeFn(c, true /* sparse */);
+ }))
.Doc(R"doc(
Sparse update '*var' as FOBOS algorithm with fixed learning rate.
@@ -115,7 +146,7 @@ out: Same as "var".
use_locking: If True, the subtraction will be protected by a lock;
otherwise the behavior is undefined, but may exhibit less contention.
)doc");
-static Status ApplyAdadeltaShapeFn(InferenceContext* c) {
+static Status ApplyAdadeltaShapeFn(InferenceContext* c, bool sparse) {
const Shape* unused;
const Shape* s = c->input(0); // var
TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // accum
@@ -123,7 +154,8 @@ static Status ApplyAdadeltaShapeFn(InferenceContext* c) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // lr
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // rho
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // epsilon
- TF_RETURN_IF_ERROR(c->Merge(s, c->input(6), &s)); // grad
+ TF_RETURN_IF_ERROR(
+ HandleGradAndIndicesInputs(c, sparse, 6 /* grad_idx */, &s));
c->set_output(0, s);
return Status::OK();
}
@@ -139,7 +171,9 @@ REGISTER_OP("ApplyAdadelta")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
- .SetShapeFn(OpShapeInferenceFn(ApplyAdadeltaShapeFn))
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return ApplyAdadeltaShapeFn(c, false /* sparse */);
+ }))
.Doc(R"doc(
Update '*var' according to the adadelta scheme.
@@ -173,6 +207,9 @@ REGISTER_OP("SparseApplyAdadelta")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return ApplyAdadeltaShapeFn(c, true /* sparse */);
+ }))
.Doc(R"doc(
var: Should be from a Variable().
accum: Should be from a Variable().
@@ -186,12 +223,14 @@ out: Same as "var".
use_locking: If True, updating of the var and accum tensors will be protected by
a lock; otherwise the behavior is undefined, but may exhibit less contention.
)doc");
-static Status ApplyAdagradShapeFn(InferenceContext* c) {
+
+static Status ApplyAdagradShapeFn(InferenceContext* c, bool sparse) {
const Shape* unused;
const Shape* s = c->input(0); // var
TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // accum
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr
- TF_RETURN_IF_ERROR(c->Merge(s, c->input(3), &s)); // grad
+ TF_RETURN_IF_ERROR(
+ HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
c->set_output(0, s);
return Status::OK();
}
@@ -204,8 +243,9 @@ REGISTER_OP("ApplyAdagrad")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
- .SetShapeFn(OpShapeInferenceFn(ApplyAdagradShapeFn))
-
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return ApplyAdagradShapeFn(c, false /* sparse */);
+ }))
.Doc(R"doc(
Update '*var' according to the adagrad scheme.
@@ -221,14 +261,15 @@ use_locking: If `True`, updating of the var and accum tensors will be protected
by a lock; otherwise the behavior is undefined, but may exhibit less
contention.
)doc");
-static Status ApplyProximalAdagradShapeFn(InferenceContext* c) {
+static Status ApplyProximalAdagradShapeFn(InferenceContext* c, bool sparse) {
const Shape* unused;
const Shape* s = c->input(0); // var
TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // accum
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // l1
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // l2
- TF_RETURN_IF_ERROR(c->Merge(s, c->input(5), &s)); // grad
+ TF_RETURN_IF_ERROR(
+ HandleGradAndIndicesInputs(c, sparse, 5 /* grad_idx */, &s));
c->set_output(0, s);
return Status::OK();
}
@@ -243,8 +284,9 @@ REGISTER_OP("ApplyProximalAdagrad")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
- .SetShapeFn(OpShapeInferenceFn(ApplyProximalAdagradShapeFn))
-
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return ApplyProximalAdagradShapeFn(c, false /* sparse */);
+ }))
.Doc(R"doc(
Update '*var' and '*accum' according to FOBOS with Adagrad learning rate.
accum += grad * grad
@@ -272,6 +314,9 @@ REGISTER_OP("SparseApplyAdagrad")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return ApplyAdagradShapeFn(c, true /* sparse */);
+ }))
.Doc(R"doc(
Update relevant entries in '*var' and '*accum' according to the adagrad scheme.
@@ -302,6 +347,9 @@ REGISTER_OP("SparseApplyProximalAdagrad")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return ApplyProximalAdagradShapeFn(c, true /* sparse */);
+ }))
.Doc(R"doc(
Sparse update entries in '*var' and '*accum' according to FOBOS algorithm.
@@ -323,16 +371,18 @@ use_locking: If True, updating of the var and accum tensors will be protected by
a lock; otherwise the behavior is undefined, but may exhibit less contention.
)doc");
-static Status ApplyFtrlShapeFn(InferenceContext* c) {
+static Status ApplyFtrlShapeFn(InferenceContext* c, bool sparse) {
const Shape* unused;
- const Shape* s = c->input(0); // var
- TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // accum
- TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // linear
- TF_RETURN_IF_ERROR(c->Merge(s, c->input(3), &s)); // grad
- TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // lr
- TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // l1
- TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // l2
- TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); // lr_power
+ const Shape* s = c->input(0); // var
+ TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // accum
+ TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // linear
+ TF_RETURN_IF_ERROR(
+ HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
+ int idx = sparse ? 5 : 4;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // lr
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // l1
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // l2
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // lr_power
c->set_output(0, s);
return Status::OK();
}
@@ -349,8 +399,9 @@ REGISTER_OP("ApplyFtrl")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
- .SetShapeFn(OpShapeInferenceFn(ApplyFtrlShapeFn))
-
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return ApplyFtrlShapeFn(c, false /* sparse */);
+ }))
.Doc(R"doc(
Update '*var' according to the Ftrl-proximal scheme.
@@ -388,6 +439,9 @@ REGISTER_OP("SparseApplyFtrl")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return ApplyFtrlShapeFn(c, true /* sparse */);
+ }))
.Doc(R"doc(
Update relevant entries in '*var' according to the Ftrl-proximal scheme.
@@ -413,13 +467,15 @@ use_locking: If `True`, updating of the var and accum tensors will be protected
contention.
)doc");
-static Status ApplyMomentumShapeFn(InferenceContext* c) {
+static Status ApplyMomentumShapeFn(InferenceContext* c, bool sparse) {
const Shape* unused;
const Shape* s = c->input(0); // var
TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // accum
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr
- TF_RETURN_IF_ERROR(c->Merge(s, c->input(3), &s)); // grad
- TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // momentum
+ TF_RETURN_IF_ERROR(
+ HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
+ int idx = sparse ? 5 : 4;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // momentum
c->set_output(0, s);
return Status::OK();
}
@@ -433,7 +489,9 @@ REGISTER_OP("ApplyMomentum")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
- .SetShapeFn(OpShapeInferenceFn(ApplyMomentumShapeFn))
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return ApplyMomentumShapeFn(c, false /* sparse */);
+ }))
.Doc(R"doc(
Update '*var' according to the momentum scheme.
@@ -462,6 +520,9 @@ REGISTER_OP("SparseApplyMomentum")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return ApplyMomentumShapeFn(c, true /* sparse */);
+ }))
.Doc(R"doc(
Update relevant entries in '*var' and '*accum' according to the momentum scheme.
@@ -482,7 +543,7 @@ use_locking: If `True`, updating of the var and accum tensors will be protected
contention.
)doc");
-static Status ApplyAdamShapeFn(InferenceContext* c) {
+static Status ApplyAdamShapeFn(InferenceContext* c, bool sparse) {
const Shape* unused;
const Shape* s = c->input(0); // var
TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // m
@@ -493,7 +554,8 @@ static Status ApplyAdamShapeFn(InferenceContext* c) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // beta1
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); // beta2
TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); // epsilon
- TF_RETURN_IF_ERROR(c->Merge(s, c->input(9), &s)); // grad
+ TF_RETURN_IF_ERROR(
+ HandleGradAndIndicesInputs(c, sparse, 9 /* grad_idx */, &s));
c->set_output(0, s);
return Status::OK();
}
@@ -512,7 +574,9 @@ REGISTER_OP("ApplyAdam")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
- .SetShapeFn(OpShapeInferenceFn(ApplyAdamShapeFn))
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return ApplyAdamShapeFn(c, false /* sparse */);
+ }))
.Doc(R"doc(
Update '*var' according to the Adam algorithm.
@@ -537,7 +601,7 @@ use_locking: If `True`, updating of the var, m, and v tensors will be protected
contention.
)doc");
-static Status ApplyRMSPropShapeFn(InferenceContext* c) {
+static Status ApplyRMSPropShapeFn(InferenceContext* c, bool sparse) {
const Shape* unused;
const Shape* s = c->input(0); // var
TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // ms
@@ -546,7 +610,8 @@ static Status ApplyRMSPropShapeFn(InferenceContext* c) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // rho
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // momentum
TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // epsilon
- TF_RETURN_IF_ERROR(c->Merge(s, c->input(7), &s)); // grad
+ TF_RETURN_IF_ERROR(
+ HandleGradAndIndicesInputs(c, sparse, 7 /* grad_idx */, &s));
c->set_output(0, s);
return Status::OK();
}
@@ -563,7 +628,9 @@ REGISTER_OP("ApplyRMSProp")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
- .SetShapeFn(OpShapeInferenceFn(ApplyRMSPropShapeFn))
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return ApplyRMSPropShapeFn(c, false /* sparse */);
+ }))
.Doc(R"doc(
Update '*var' according to the RMSProp algorithm.
Note that in dense implement of this algorithm, ms and mom will
@@ -604,6 +671,9 @@ REGISTER_OP("SparseApplyRMSProp")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return ApplyRMSPropShapeFn(c, true /* sparse */);
+ }))
.Doc(R"doc(
Update '*var' according to the RMSProp algorithm.
Note that in dense implement of this algorithm, ms and mom will