diff options
Diffstat (limited to 'tensorflow/cc/ops/functional_grad.cc')
-rw-r--r-- | tensorflow/cc/ops/functional_grad.cc | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/cc/ops/functional_grad.cc b/tensorflow/cc/ops/functional_grad.cc new file mode 100644 index 0000000000..28b8b4a0e5 --- /dev/null +++ b/tensorflow/cc/ops/functional_grad.cc @@ -0,0 +1,42 @@ +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +typedef FunctionDefHelper FDH; + +Status MapAccumulateGrad(const AttrSlice& attrs, FunctionDef* ret) { + const NameAttrList* func; + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "f", &func)); + DataType T; + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T)); + int k; + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "K", &k)); + // The gradient function of f. + // f : (K*T, T, T) -> T + // g : (K*T, T, T, T) -> (K*T, T, T) + auto grad = FDH::FunctionRef("SymbolicGradient", + {{"f", *func}, + {"Tin", std::vector<DataType>(k + 3, T)}, + {"Tout", std::vector<DataType>(k + 2, T)}}); + *ret = FDH::Define( + // Arg defs + {"theta: K*T", "x: T", "u: T", "dy: T"}, + // Ret val defs + {"dtheta: K*T", "dx: T", "du: T"}, + // Attr defs + {{"T: {float, double}"}}, + // nodes. + {{{"y"}, + "MapAccumulate", + {"theta", "x", "u"}, + {{"f", *func}, {"T", "$T"}, {"K", k}}}, + {{"dtheta", "dx", "du"}, + "MapAccumulateGrad", + {"theta", "x", "u", "y", "dy"}, + {{"g", grad}, {"T", "$T"}, {"K", k}}}}); + return Status::OK(); +} +REGISTER_OP_GRADIENT("MapAccumulate", MapAccumulateGrad); + +} // end namespace tensorflow |