author     2016-07-18 16:58:41 -0800
committer  2016-07-18 18:03:33 -0700
commit     075d1ab68fcaf5ddd90c3d30f6a046b78c3c04f3 (patch)
tree       ff341eceaa08df93ffafa1e6f13e6e2cda2dba48 /tensorflow/core/ops/training_ops.cc
parent     6696de2b02d9f2c6d16fd39ef388be87309525bb (diff)
Add C++ shape inference functions for the Sparse ops in training_ops.cc.
Change: 127782110
Diffstat (limited to 'tensorflow/core/ops/training_ops.cc')
-rw-r--r--  tensorflow/core/ops/training_ops.cc  140
1 file changed, 105 insertions, 35 deletions
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index d80cbf6c35..f779215bc1 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -23,6 +23,30 @@ typedef shape_inference::InferenceContext InferenceContext;
 typedef shape_inference::Shape Shape;
 static constexpr auto kUnknownDim = InferenceContext::kUnknownDim;
 
+// Handle the gradient and, if <sparse>, indices inputs.
+// <s> is an input+output parameter, containing the current known input shape to
+// the gradient.
+static Status HandleGradAndIndicesInputs(InferenceContext* c, bool sparse,
+                                         int grad_idx, const Shape** s) {
+  const Shape* grad = c->input(grad_idx);
+  if (!sparse) {
+    TF_RETURN_IF_ERROR(c->Merge(*s, grad, s));
+    return Status::OK();
+  }
+  // Indices is a vector where indices.dim[0].rank == grad[0].rank.
+  const Shape* indices;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(grad_idx + 1), 1, &indices));
+  const Dimension* unused;
+  TF_RETURN_IF_ERROR(c->Merge(c->Dim(indices, 0), c->Dim(grad, 0), &unused));
+
+  // Trailing part of grad matches *s.
+  const Shape* grad_subshape;
+  TF_RETURN_IF_ERROR(c->Subshape(grad, 1, &grad_subshape));
+  TF_RETURN_IF_ERROR(c->Merge(*s, grad_subshape, s));
+
+  return Status::OK();
+}
+
 static Status ApplyGradientDescentShapeFn(InferenceContext* c) {
   const Shape* unused;
   const Shape* s = c->input(0);                               // var
@@ -51,13 +75,15 @@ use_locking: If `True`, the subtraction will be protected by a lock;
   otherwise the behavior is undefined, but may exhibit less contention.
 )doc");
 
-static Status ApplyProxiimalGradientDescentShapeFn(InferenceContext* c) {
+static Status ApplyProximalGradientDescentShapeFn(InferenceContext* c,
+                                                  bool sparse) {
   const Shape* unused;
   const Shape* s = c->input(0);                               // var
   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));   // alpha
   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));   // l1
   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));   // l2
-  TF_RETURN_IF_ERROR(c->Merge(s, c->input(4), &s));           // delta
+  TF_RETURN_IF_ERROR(
+      HandleGradAndIndicesInputs(c, sparse, 4 /* grad_idx */, &s));
   c->set_output(0, s);
   return Status::OK();
 }
@@ -71,7 +97,9 @@ REGISTER_OP("ApplyProximalGradientDescent")
     .Output("out: Ref(T)")
     .Attr("T: numbertype")
     .Attr("use_locking: bool = false")
-    .SetShapeFn(OpShapeInferenceFn(ApplyProxiimalGradientDescentShapeFn))
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      return ApplyProximalGradientDescentShapeFn(c, false /* sparse */);
+    }))
     .Doc(R"doc(
 Update '*var' as FOBOS algorithm with fixed learning rate.
 prox_v = var - alpha * delta
@@ -98,6 +126,9 @@ REGISTER_OP("SparseApplyProximalGradientDescent")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      return ApplyProximalGradientDescentShapeFn(c, true /* sparse */);
+    }))
     .Doc(R"doc(
 Sparse update '*var' as FOBOS algorithm with fixed learning rate.
 
@@ -115,7 +146,7 @@ out: Same as "var".
 use_locking: If True, the subtraction will be protected by a lock;
   otherwise the behavior is undefined, but may exhibit less contention.
 )doc");
 
-static Status ApplyAdadeltaShapeFn(InferenceContext* c) {
+static Status ApplyAdadeltaShapeFn(InferenceContext* c, bool sparse) {
   const Shape* unused;
   const Shape* s = c->input(0);                               // var
   TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s));           // accum
@@ -123,7 +154,8 @@ static Status ApplyAdadeltaShapeFn(InferenceContext* c) {
   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));   // lr
   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));   // rho
   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));   // epsilon
-  TF_RETURN_IF_ERROR(c->Merge(s, c->input(6), &s));           // grad
+  TF_RETURN_IF_ERROR(
+      HandleGradAndIndicesInputs(c, sparse, 6 /* grad_idx */, &s));
   c->set_output(0, s);
   return Status::OK();
 }
@@ -139,7 +171,9 @@ REGISTER_OP("ApplyAdadelta")
     .Output("out: Ref(T)")
     .Attr("T: numbertype")
     .Attr("use_locking: bool = false")
-    .SetShapeFn(OpShapeInferenceFn(ApplyAdadeltaShapeFn))
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      return ApplyAdadeltaShapeFn(c, false /* sparse */);
+    }))
     .Doc(R"doc(
 Update '*var' according to the adadelta scheme.
 
@@ -173,6 +207,9 @@ REGISTER_OP("SparseApplyAdadelta")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      return ApplyAdadeltaShapeFn(c, true /* sparse */);
+    }))
     .Doc(R"doc(
 var: Should be from a Variable().
 accum: Should be from a Variable().
@@ -186,12 +223,14 @@ out: Same as "var".
 use_locking: If True, updating of the var and accum tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
 )doc");
-static Status ApplyAdagradShapeFn(InferenceContext* c) {
+
+static Status ApplyAdagradShapeFn(InferenceContext* c, bool sparse) {
   const Shape* unused;
   const Shape* s = c->input(0);                               // var
   TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s));           // accum
   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));   // lr
-  TF_RETURN_IF_ERROR(c->Merge(s, c->input(3), &s));           // grad
+  TF_RETURN_IF_ERROR(
+      HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
   c->set_output(0, s);
   return Status::OK();
 }
@@ -204,8 +243,9 @@ REGISTER_OP("ApplyAdagrad")
     .Output("out: Ref(T)")
     .Attr("T: numbertype")
     .Attr("use_locking: bool = false")
-    .SetShapeFn(OpShapeInferenceFn(ApplyAdagradShapeFn))
-
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      return ApplyAdagradShapeFn(c, false /* sparse */);
+    }))
     .Doc(R"doc(
 Update '*var' according to the adagrad scheme.
 
@@ -221,14 +261,15 @@ use_locking: If `True`, updating of the var and accum tensors will be protected
   by a lock; otherwise the behavior is undefined, but may exhibit less
   contention.
 )doc");
-static Status ApplyProximalAdagradShapeFn(InferenceContext* c) {
+static Status ApplyProximalAdagradShapeFn(InferenceContext* c, bool sparse) {
   const Shape* unused;
   const Shape* s = c->input(0);                               // var
   TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s));           // accum
   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));   // lr
   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));   // l1
   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));   // l2
-  TF_RETURN_IF_ERROR(c->Merge(s, c->input(5), &s));           // grad
+  TF_RETURN_IF_ERROR(
+      HandleGradAndIndicesInputs(c, sparse, 5 /* grad_idx */, &s));
   c->set_output(0, s);
   return Status::OK();
 }
@@ -243,8 +284,9 @@ REGISTER_OP("ApplyProximalAdagrad")
     .Output("out: Ref(T)")
     .Attr("T: numbertype")
     .Attr("use_locking: bool = false")
-    .SetShapeFn(OpShapeInferenceFn(ApplyProximalAdagradShapeFn))
-
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      return ApplyProximalAdagradShapeFn(c, false /* sparse */);
+    }))
     .Doc(R"doc(
 Update '*var' and '*accum' according to FOBOS with Adagrad learning rate.
 accum += grad * grad
@@ -272,6 +314,9 @@ REGISTER_OP("SparseApplyAdagrad")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      return ApplyAdagradShapeFn(c, true /* sparse */);
+    }))
     .Doc(R"doc(
 Update relevant entries in '*var' and '*accum' according to the adagrad scheme.
 
@@ -302,6 +347,9 @@ REGISTER_OP("SparseApplyProximalAdagrad")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      return ApplyProximalAdagradShapeFn(c, true /* sparse */);
+    }))
     .Doc(R"doc(
 Sparse update entries in '*var' and '*accum' according to FOBOS algorithm.
 
@@ -323,16 +371,18 @@ use_locking: If True, updating of the var and accum tensors will be protected by
   a lock; otherwise the behavior is undefined, but may exhibit less contention.
 )doc");
 
-static Status ApplyFtrlShapeFn(InferenceContext* c) {
+static Status ApplyFtrlShapeFn(InferenceContext* c, bool sparse) {
   const Shape* unused;
-  const Shape* s = c->input(0);                               // var
-  TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s));           // accum
-  TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s));           // linear
-  TF_RETURN_IF_ERROR(c->Merge(s, c->input(3), &s));           // grad
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));   // lr
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));   // l1
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));   // l2
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));   // lr_power
+  const Shape* s = c->input(0);                                // var
+  TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s));            // accum
+  TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s));            // linear
+  TF_RETURN_IF_ERROR(
+      HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
+  int idx = sparse ? 5 : 4;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // lr
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // l1
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // l2
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // lr_power
   c->set_output(0, s);
   return Status::OK();
 }
@@ -349,8 +399,9 @@ REGISTER_OP("ApplyFtrl")
     .Output("out: Ref(T)")
     .Attr("T: numbertype")
     .Attr("use_locking: bool = false")
-    .SetShapeFn(OpShapeInferenceFn(ApplyFtrlShapeFn))
-
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      return ApplyFtrlShapeFn(c, false /* sparse */);
+    }))
    .Doc(R"doc(
 Update '*var' according to the Ftrl-proximal scheme.
 
@@ -388,6 +439,9 @@ REGISTER_OP("SparseApplyFtrl")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      return ApplyFtrlShapeFn(c, true /* sparse */);
+    }))
     .Doc(R"doc(
 Update relevant entries in '*var' according to the Ftrl-proximal scheme.
 
@@ -413,13 +467,15 @@ use_locking: If `True`, updating of the var and accum tensors will be protected
   contention.
 )doc");
 
-static Status ApplyMomentumShapeFn(InferenceContext* c) {
+static Status ApplyMomentumShapeFn(InferenceContext* c, bool sparse) {
   const Shape* unused;
   const Shape* s = c->input(0);                               // var
   TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s));           // accum
   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));   // lr
-  TF_RETURN_IF_ERROR(c->Merge(s, c->input(3), &s));           // grad
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));   // momentum
+  TF_RETURN_IF_ERROR(
+      HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
+  int idx = sparse ? 5 : 4;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // momentum
   c->set_output(0, s);
   return Status::OK();
 }
@@ -433,7 +489,9 @@ REGISTER_OP("ApplyMomentum")
     .Output("out: Ref(T)")
     .Attr("T: numbertype")
     .Attr("use_locking: bool = false")
-    .SetShapeFn(OpShapeInferenceFn(ApplyMomentumShapeFn))
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      return ApplyMomentumShapeFn(c, false /* sparse */);
+    }))
     .Doc(R"doc(
 Update '*var' according to the momentum scheme.
 
@@ -462,6 +520,9 @@ REGISTER_OP("SparseApplyMomentum")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      return ApplyMomentumShapeFn(c, true /* sparse */);
+    }))
     .Doc(R"doc(
 Update relevant entries in '*var' and '*accum' according to the momentum scheme.
 
@@ -482,7 +543,7 @@ use_locking: If `True`, updating of the var and accum tensors will be protected
   contention.
 )doc");
 
-static Status ApplyAdamShapeFn(InferenceContext* c) {
+static Status ApplyAdamShapeFn(InferenceContext* c, bool sparse) {
   const Shape* unused;
   const Shape* s = c->input(0);                               // var
   TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s));           // m
@@ -493,7 +554,8 @@ static Status ApplyAdamShapeFn(InferenceContext* c) {
   TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));   // beta1
   TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));   // beta2
   TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));   // epsilon
-  TF_RETURN_IF_ERROR(c->Merge(s, c->input(9), &s));           // grad
+  TF_RETURN_IF_ERROR(
+      HandleGradAndIndicesInputs(c, sparse, 9 /* grad_idx */, &s));
   c->set_output(0, s);
   return Status::OK();
 }
@@ -512,7 +574,9 @@ REGISTER_OP("ApplyAdam")
     .Output("out: Ref(T)")
     .Attr("T: numbertype")
     .Attr("use_locking: bool = false")
-    .SetShapeFn(OpShapeInferenceFn(ApplyAdamShapeFn))
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      return ApplyAdamShapeFn(c, false /* sparse */);
+    }))
     .Doc(R"doc(
 Update '*var' according to the Adam algorithm.
 
@@ -537,7 +601,7 @@ use_locking: If `True`, updating of the var, m, and v tensors will be protected
   contention.
 )doc");
 
-static Status ApplyRMSPropShapeFn(InferenceContext* c) {
+static Status ApplyRMSPropShapeFn(InferenceContext* c, bool sparse) {
   const Shape* unused;
   const Shape* s = c->input(0);                               // var
   TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s));           // ms
@@ -546,7 +610,8 @@ static Status ApplyRMSPropShapeFn(InferenceContext* c) {
   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));   // rho
   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));   // momentum
   TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));   // epsilon
-  TF_RETURN_IF_ERROR(c->Merge(s, c->input(7), &s));           // grad
+  TF_RETURN_IF_ERROR(
+      HandleGradAndIndicesInputs(c, sparse, 7 /* grad_idx */, &s));
   c->set_output(0, s);
   return Status::OK();
 }
@@ -563,7 +628,9 @@ REGISTER_OP("ApplyRMSProp")
     .Output("out: Ref(T)")
     .Attr("T: numbertype")
     .Attr("use_locking: bool = false")
-    .SetShapeFn(OpShapeInferenceFn(ApplyRMSPropShapeFn))
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      return ApplyRMSPropShapeFn(c, false /* sparse */);
+    }))
     .Doc(R"doc(
 Update '*var' according to the RMSProp algorithm.
 Note that in dense implement of this algorithm, ms and mom will
@@ -604,6 +671,9 @@ REGISTER_OP("SparseApplyRMSProp")
     .Attr("T: numbertype")
     .Attr("Tindices: {int32, int64}")
     .Attr("use_locking: bool = false")
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      return ApplyRMSPropShapeFn(c, true /* sparse */);
+    }))
    .Doc(R"doc(
 Update '*var' according to the RMSProp algorithm.
 Note that in dense implement of this algorithm, ms and mom will