aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/ops/nn_grad.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/cc/ops/nn_grad.cc')
-rw-r--r--tensorflow/cc/ops/nn_grad.cc55
1 files changed, 55 insertions, 0 deletions
diff --git a/tensorflow/cc/ops/nn_grad.cc b/tensorflow/cc/ops/nn_grad.cc
new file mode 100644
index 0000000000..89b037e3c8
--- /dev/null
+++ b/tensorflow/cc/ops/nn_grad.cc
@@ -0,0 +1,55 @@
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+typedef FunctionDefHelper FDH;
+
+Status ReluGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ *g = FDH::Define(
+ // Arg defs
+ {"x: T", "dy: T"},
+ // Ret val defs
+ {"dx: T"},
+ // Attr defs
+ {{"T: {float, double}"}},
+ // Nodes
+ {
+ {{"dx"}, "ReluGrad", {"dy", "x"}, {{"T", "$T"}}}
+ });
+ // clang-format on
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("Relu", ReluGrad);
+
+Status CrossEntropyGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ *g = FDH::Define(
+ // Arg defs
+ {"features: T", "labels: T", "dcost_dloss: T", "donotcare: T"},
+ // Ret val defs
+ {"dcost_dfeatures: T", "dcost_dlabels: T"},
+ // Attr defs
+ {{"T: {float, double}"}},
+ // Nodes
+ {
+ // _, dloss_dfeatures = CrossEntropy(features, labels)
+ {{"donotcare_loss", "dloss_dfeatures"}, "CrossEntropy",
+ {"features", "labels"}, {{"T", "$T"}}},
+ // dcost_dloss is of shape [batch_size].
+ // dcost_dloss_mat is of shape [batch_size, 1].
+ FDH::Const("neg1", -1),
+ {{"dcost_dloss_mat"}, "ExpandDims", {"dcost_dloss", "neg1"},
+ {{"T", "$T"}}},
+ // chain rule: dcost/dfeatures = dcost/dloss * dloss/dfeatures
+ {{"dcost_dfeatures"}, "Mul", {"dcost_dloss_mat", "dloss_dfeatures"},
+ {{"T", "$T"}}},
+ {{"dcost_dlabels"}, "ZerosLike", {"labels"}, {{"T", "$T"}}},
+ });
+ // clang-format on
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("CrossEntropy", CrossEntropyGrad);
+
+} // end namespace tensorflow