diff options
author | 2015-11-06 16:27:58 -0800 | |
---|---|---|
committer | 2015-11-06 16:27:58 -0800 | |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/cc/ops/nn_grad.cc |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/cc/ops/nn_grad.cc')
-rw-r--r-- | tensorflow/cc/ops/nn_grad.cc | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/tensorflow/cc/ops/nn_grad.cc b/tensorflow/cc/ops/nn_grad.cc new file mode 100644 index 0000000000..89b037e3c8 --- /dev/null +++ b/tensorflow/cc/ops/nn_grad.cc @@ -0,0 +1,55 @@ +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +typedef FunctionDefHelper FDH; + +Status ReluGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + *g = FDH::Define( + // Arg defs + {"x: T", "dy: T"}, + // Ret val defs + {"dx: T"}, + // Attr defs + {{"T: {float, double}"}}, + // Nodes + { + {{"dx"}, "ReluGrad", {"dy", "x"}, {{"T", "$T"}}} + }); + // clang-format on + return Status::OK(); +} +REGISTER_OP_GRADIENT("Relu", ReluGrad); + +Status CrossEntropyGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + *g = FDH::Define( + // Arg defs + {"features: T", "labels: T", "dcost_dloss: T", "donotcare: T"}, + // Ret val defs + {"dcost_dfeatures: T", "dcost_dlabels: T"}, + // Attr defs + {{"T: {float, double}"}}, + // Nodes + { + // _, dloss_dfeatures = CrossEntropy(features, labels) + {{"donotcare_loss", "dloss_dfeatures"}, "CrossEntropy", + {"features", "labels"}, {{"T", "$T"}}}, + // dcost_dloss is of shape [batch_size]. + // dcost_dloss_mat is of shape [batch_size, 1]. + FDH::Const("neg1", -1), + {{"dcost_dloss_mat"}, "ExpandDims", {"dcost_dloss", "neg1"}, + {{"T", "$T"}}}, + // chain rule: dcost/dfeatures = dcost/dloss * dloss/dfeatures + {{"dcost_dfeatures"}, "Mul", {"dcost_dloss_mat", "dloss_dfeatures"}, + {{"T", "$T"}}}, + {{"dcost_dlabels"}, "ZerosLike", {"labels"}, {{"T", "$T"}}}, + }); + // clang-format on + return Status::OK(); +} +REGISTER_OP_GRADIENT("CrossEntropy", CrossEntropyGrad); + +} // end namespace tensorflow |