Diffstat (limited to 'tensorflow/cc/ops/math_grad.cc')
-rw-r--r-- | tensorflow/cc/ops/math_grad.cc | 566
1 files changed, 566 insertions, 0 deletions
diff --git a/tensorflow/cc/ops/math_grad.cc b/tensorflow/cc/ops/math_grad.cc
new file mode 100644
index 0000000000..4e8baa0d10
--- /dev/null
+++ b/tensorflow/cc/ops/math_grad.cc
@@ -0,0 +1,566 @@
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+typedef FunctionDefHelper FDH;
+
+// Cwise unary ops
+Status GradForUnaryCwise(FunctionDef* g, std::vector<FDH::Node> nodes) {
+  for (auto& n : nodes) {
+    if (n.attr.empty()) {
+      n.attr = {{"T", "$T"}};
+    }
+  }
+  *g = FDH::Define(
+      // Arg defs
+      {"x: T", "dy: T"},
+      // Ret val defs
+      {"dx: T"},
+      // Attr defs
+      {{"T: {float, double}"}},
+      // Nodes
+      nodes);
+  return Status::OK();
+}
+
+Status AbsGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForUnaryCwise(g, {
+      {{"sign"}, "Sign", {"x"}},
+      {{"dx"}, "Mul", {"dy", "sign"}},
+  });
+  // clang-format on
+}
+REGISTER_OP_GRADIENT("Abs", AbsGrad);
+
+Status NegGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForUnaryCwise(g, {
+      {{"dx"}, "Neg", {"dy"}},
+  });
+  // clang-format on
+}
+REGISTER_OP_GRADIENT("Neg", NegGrad);
+
+Status InvGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForUnaryCwise(g, {
+      {{"y"}, "Inv", {"x"}},
+      {{"y2"}, "Square", {"y"}},
+      {{"y2_neg"}, "Neg", {"y2"}},
+      {{"dx"}, "Mul", {"dy", "y2_neg"}}
+  });
+  // clang-format on
+}
+REGISTER_OP_GRADIENT("Inv", InvGrad);
+
+Status SquareGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForUnaryCwise(g, {
+      FDH::Const("c", 2LL),
+      {{"two"}, "Cast", {"c"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+      {{"x2"}, "Mul", {"x", "two"}},  // x * 2
+      {{"dx"}, "Mul", {"dy", "x2"}},  // dy * (x * 2)
+  });
+  // clang-format on
+}
+REGISTER_OP_GRADIENT("Square", SquareGrad);
+
+Status SqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForUnaryCwise(g, {
+      {{"y"}, "Sqrt", {"x"}},
+      {{"y_inv"}, "Inv", {"y"}},
+      FDH::Const("const", 0.5f),
+      {{"half"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
+      {{"a"}, "Mul", {"half", "y_inv"}},  // .5 * 1/y
+      {{"dx"}, "Mul", {"dy", "a"}},       // dy * (.5 * 1/y)
+  });
+  // clang-format on
+}
+REGISTER_OP_GRADIENT("Sqrt", SqrtGrad);
+
+Status RsqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForUnaryCwise(g, {
+      {{"x_inv"}, "Inv", {"x"}},
+      {{"y"}, "Rsqrt", {"x"}},
+      FDH::Const("const", -.5f),
+      {{"neghalf"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
+      {{"a"}, "Mul", {"neghalf", "x_inv"}},  // -0.5 * 1/x
+      {{"b"}, "Mul", {"a", "y"}},            // -0.5 * 1/x * y
+      {{"dx"}, "Mul", {"dy", "b"}},          // dy * (-0.5 * 1/x * y)
+  });
+  // clang-format on
+}
+REGISTER_OP_GRADIENT("Rsqrt", RsqrtGrad);
+
+Status ExpGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForUnaryCwise(g, {
+      {{"y"}, "Exp", {"x"}},
+      {{"dx"}, "Mul", {"dy", "y"}},  // dy * y
+  });
+  // clang-format on
+}
+REGISTER_OP_GRADIENT("Exp", ExpGrad);
+
+Status LogGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForUnaryCwise(g, {
+      {{"x_inv"}, "Inv", {"x"}},
+      {{"dx"}, "Mul", {"dy", "x_inv"}},  // dy * 1/x
+  });
+  // clang-format on
+}
+REGISTER_OP_GRADIENT("Log", LogGrad);
+
+Status TanhGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForUnaryCwise(g, {
+      {{"y"}, "Tanh", {"x"}},
+      {{"y2"}, "Square", {"y"}},
+      FDH::Const("const", 1.0f),
+      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
"$T"}}}, + {{"a"}, "Sub", {"one", "y2"}}, + {{"dx"}, "Mul", {"dy", "a"}}, // dy * (1 - y*y) + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Tanh", TanhGrad); + +Status SigmoidGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"y"}, "Sigmoid", {"x"}}, + FDH::Const("const", 1.0f), + {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}}, + {{"a"}, "Sub", {"one", "y"}}, + {{"b"}, "Mul", {"y", "a"}}, // y * (1 - y) + {{"dx"}, "Mul", {"dy", "b"}}, // dy * y * (1 - y) + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Sigmoid", SigmoidGrad); + +Status SignGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"s"}, "Shape", {"x"}}, + FDH::Const("zero", 0.f), + {{"val"}, "Cast", {"zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}}, + {{"dx"}, "Fill", {"s", "val"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Sign", SignGrad); + +Status SinGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"cos"}, "Cos", {"x"}}, + {{"dx"}, "Mul", {"dy", "cos"}}, // dy * cos(x) + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Sin", SinGrad); + +Status CosGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"sin"}, "Sin", {"x"}}, + {{"neg"}, "Neg", {"sin"}}, + {{"dx"}, "Mul", {"dy", "neg"}}, // dy * (-sin(x)) + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Cos", CosGrad); + +Status RealGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + FDH::Const("zero", 0.f), + {{"dx"}, "Complex", {"dy", "zero"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Real", RealGrad); + +Status ImagGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + FDH::Const("zero", 0.f), + {{"dx"}, "Complex", {"zero", "dy"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Imag", ImagGrad); + +Status ConjGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"dx"}, "Conj", {"dy"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Conj", ConjGrad); + +// Cwise binary ops +// +// TODO(zhifengc): This can be arrange as a function in the standard +// library. +Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) { + // clang-format off + std::vector<FDH::Node> nodes = { + {{"sx"}, "Shape", {"x"}}, + {{"sy"}, "Shape", {"y"}}, + }; + nodes.insert(nodes.end(), body.begin(), body.end()); + std::vector<FDH::Node> reshapes = { + {{"sum_gx"}, "Sum", {"gx", "rx"}}, + {{"dx"}, "Reshape", {"sum_gx", "sx"}}, + {{"sum_gy"}, "Sum", {"gy", "ry"}}, + {{"dy"}, "Reshape", {"sum_gy", "sy"}}, + }; + nodes.insert(nodes.end(), reshapes.begin(), reshapes.end()); + + // clang-format on + for (auto& n : nodes) { + if (n.attr.empty()) { + n.attr = {{"T", "$T"}}; + } + } + // "BroadcastGradientArgs" doesn't need any attrs. 
+  nodes.push_back({{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}});
+  *g = FDH::Define(
+      // Arg defs
+      {"x: T", "y: T", "dz: T"},
+      // Ret val defs
+      {"dx: T", "dy: T"},
+      // Attr defs
+      {{"T: {float, double}"}},
+      // Nodes
+      nodes);
+  return Status::OK();
+}
+
+Status AddGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForBinaryCwise(g, {
+      {{"gx"}, "Identity", {"dz"}},
+      {{"gy"}, "Identity", {"dz"}},
+  });
+  // clang-format on
+}
+REGISTER_OP_GRADIENT("Add", AddGrad);
+
+Status SubGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForBinaryCwise(g, {
+      {{"gx"}, "Identity", {"dz"}},
+      {{"gy"}, "Neg", {"dz"}},  // -dz
+  });
+  // clang-format on
+}
+REGISTER_OP_GRADIENT("Sub", SubGrad);
+
+Status MulGrad(const AttrSlice& attrs, FunctionDef* g) {
+  DataType T;
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
+  if (T == DT_COMPLEX64) {
+    return GradForBinaryCwise(
+        g, {
+               {{"cy"}, "Conj", {"y"}},
+               {{"gx"}, "Mul", {"dz", "cy"}},  // dz * Conj(y)
+               {{"cx"}, "Conj", {"x"}},
+               {{"gy"}, "Mul", {"cx", "dz"}},  // Conj(x) * dz
+           });
+  } else {
+    // clang-format off
+    return GradForBinaryCwise(g, {
+        {{"gx"}, "Mul", {"dz", "y"}},  // dz * y
+        {{"gy"}, "Mul", {"x", "dz"}},  // x * dz
+    });
+    // clang-format on
+  }
+}
+REGISTER_OP_GRADIENT("Mul", MulGrad);
+
+Status DivGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForBinaryCwise(g, {
+      {{"gx"}, "Div", {"dz", "y"}},
+      {{"nx"}, "Neg", {"x"}},
+      {{"y2"}, "Square", {"y"}},
+      {{"nx_y2"}, "Div", {"nx", "y2"}},
+      {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
+  });
+  // clang-format on
+}
+REGISTER_OP_GRADIENT("Div", DivGrad);
+
+Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForBinaryCwise(g, {
+      {{"z"}, "Pow", {"x", "y"}},
+      // dz * y * Pow(x, y - 1)
+      FDH::Const("const", 1.0f),
+      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
+      {{"t0"}, "Sub", {"y", "one"}},
+      {{"t1"}, "Pow", {"x", "t0"}},
+      {{"t2"}, "Mul", {"dz", "y"}},
+      {{"gx"}, "Mul", {"t1", "t2"}},
+      // dz * z * Log(x)
+      {{"t3"}, "Log", {"x"}},
+      {{"t4"}, "Mul", {"dz", "z"}},
+      {{"gy"}, "Mul", {"t3", "t4"}},
+  });
+  // clang-format on
+}
+REGISTER_OP_GRADIENT("Pow", PowGrad);
+
+Status MaximumMinimumGradHelper(const string& comparator,
+                                const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForBinaryCwise(g, {
+      {{"c"}, comparator, {"x", "y"}},
+      {{"mask"}, "Cast", {"c"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
+      {{"gx"}, "Mul", {"dz", "mask"}},
+      {{"gy"}, "Sub", {"dz", "gx"}},
+  });
+  // clang-format on
+}
+
+Status MaximumGrad(const AttrSlice& attrs, FunctionDef* g) {
+  return MaximumMinimumGradHelper("GreaterEqual", attrs, g);
+}
+REGISTER_OP_GRADIENT("Maximum", MaximumGrad);
+
+Status MinimumGrad(const AttrSlice& attrs, FunctionDef* g) {
+  return MaximumMinimumGradHelper("LessEqual", attrs, g);
+}
+REGISTER_OP_GRADIENT("Minimum", MinimumGrad);
+
+Status ComplexGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForBinaryCwise(g, {
+      {{"gx"}, "Real", {"dz"}},
+      {{"gy"}, "Imag", {"dz"}},
+  });
+  // clang-format on
+}
+REGISTER_OP_GRADIENT("Complex", ComplexGrad);
+
+// Cwise ternary ops.
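+// Select(c, x, y) picks x element-wise where c is true and y elsewhere, so
+// the gradient below routes dz into dx where c holds and into dy where it
+// does not; the boolean condition itself receives a zero gradient.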
+Status SelectGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  *g = FDH::Define(
+      {"c:bool", "x:T", "y:T", "dz:T"},
+      {"dc:bool", "dx:T", "dy:T"},
+      {{"T: {float, double}"}},
+      {
+        {{"dc"}, "ZerosLike", {"c"}, {{"T", DT_BOOL}}},
+        {{"zeros"}, "ZerosLike", {"x"}, {{"T", "$T"}}},
+        {{"dx"}, "Select", {"c", "dz", "zeros"}, {{"T", "$T"}}},
+        {{"dy"}, "Select", {"c", "zeros", "dz"}, {{"T", "$T"}}},
+      });
+  // clang-format on
+  return Status::OK();
+}
+REGISTER_OP_GRADIENT("Select", SelectGrad);
+
+// N-ary ops
+// REGISTER_OP_GRADIENT("AddN", AddNGrad);
+
+// Reduction ops
+//
+// TODO(zhifengc): This helper is pretty ugly. Do something better.
+// TODO(zhifengc): This can be arranged as a function in the standard library.
+Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) {
+  // Shape manipulation nodes.
+
+  // clang-format off
+  std::vector<FDH::Node> nodes = {
+      {{"x_shape"}, "Shape", {"x"}},
+      {{"x_rank"}, "Rank", {"x"}},
+      {{"i_shape"}, "Shape", {"i"}, {{"T", DT_INT32}}},
+      FDH::Const("zero", 0),
+      FDH::Const("one", 1),
+      // stitch_idx0 = Range(0, x_rank, 1)
+      {{"stitch_idx1"}, "Identity", {"i"}, {{"T", DT_INT32}}},
+      {{"stitch_idx"}, "_ListToArray", {"stitch_idx0", "stitch_idx1"},
+       {{"Tin", DataTypeSlice{DT_INT32, DT_INT32}},
+        {"T", DT_INT32}, {"N", 2}}},
+      {{"stitch_val0"}, "Identity", {"x_shape"}, {{"T", DT_INT32}}},
+      {{"stitch_val1"}, "Fill", {"i_shape", "one"}, {{"T", DT_INT32}}},
+      {{"stitch_val"}, "_ListToArray", {"stitch_val0", "stitch_val1"},
+       {{"Tin", DataTypeSlice{DT_INT32, DT_INT32}},
+        {"T", DT_INT32}, {"N", 2}}},
+      {{"y_shape"}, "DynamicStitch", {"stitch_idx", "stitch_val"},
+       {{"N", 2}, {"T", DT_INT32}}},
+      {{"tile_scaling"}, "Div", {"x_shape", "y_shape"}, {{"T", DT_INT32}}},
+      {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
+  };
+  // clang-format on
+  nodes.insert(nodes.end(), body.begin(), body.end());
+  for (auto& n : nodes) {
+    if (n.attr.empty()) {
+      n.attr = {{"T", "$T"}};
+    }
+  }
+  // "Range" doesn't need any attr.
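+  // Worked example: for x of shape [2, 3, 4] reduced over i = [1],
+  // y_shape = [2, 1, 4] (the reduced dim restored with size 1) and
+  // tile_scaling = x_shape / y_shape = [1, 3, 1]; tiling a dy reshaped
+  // to y_shape by tile_scaling yields a gradient of x's shape.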
+  nodes.push_back({{"stitch_idx0"}, "Range", {"zero", "x_rank", "one"}, {}});
+  *g = FDH::Define(
+      // Arg defs
+      {"x:T", "i:int32", "dy:T"},
+      // Ret val defs
+      {"dx:T", "di:int32"},
+      // Attr defs
+      {{"T: {float, double}"}},
+      // Nodes
+      nodes);
+  return Status::OK();
+}
+
+Status SumGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForReductionOp(g, {
+      {{"dy_reshaped"}, "Reshape", {"dy", "y_shape"}},
+      {{"dx"}, "Tile", {"dy_reshaped", "tile_scaling"}},
+  });
+  // clang-format on
+}
+REGISTER_OP_GRADIENT("Sum", SumGrad);
+
+Status MeanGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  return GradForReductionOp(g, {
+      {{"factor"}, "Prod", {"tile_scaling", "zero"}, {{"T", DT_INT32}}},
+      {{"factor_T"}, "Cast", {"factor"}, {{"SrcT", DT_INT32}, {"DstT", "$T"}}},
+      {{"dy_scaled"}, "Div", {"dy", "factor_T"}},
+      {{"dy_reshaped"}, "Reshape", {"dy_scaled", "y_shape"}},
+      {{"dx"}, "Tile", {"dy_reshaped", "tile_scaling"}},
+  });
+  // clang-format on
+}
+REGISTER_OP_GRADIENT("Mean", MeanGrad);
+
+// REGISTER_OP_GRADIENT("Prod", ProdGrad);
+// REGISTER_OP_GRADIENT("SegmentSum", SegmentSumGrad);
+// REGISTER_OP_GRADIENT("SegmentMean", SegmentMeanGrad);
+// REGISTER_OP_GRADIENT("SparseSegmentSum", SparseSegmentSumGrad);
+// REGISTER_OP_GRADIENT("SparseSegmentMean", SparseSegmentMeanGrad);
+// REGISTER_OP_GRADIENT("SegmentMin", SegmentMinGrad);
+// REGISTER_OP_GRADIENT("SegmentMax", SegmentMaxGrad);
+// REGISTER_OP_GRADIENT("UnsortedSegmentSum", UnsortedSegmentSumGrad);
+
+Status MinMaxGradHelper(const string& op, const AttrSlice& attrs,
+                        FunctionDef* g) {
+  // clang-format off
+  *g = FDH::Define(
+      // Arg defs
+      {"x:T", "i:int32", "dy:T"},
+      // Ret val defs
+      {"dx:T", "di:int32"},
+      // Attr defs
+      {{"T: {float, double}"}},
+      {
+        // keep_dims because we need to do x == y, which requires x
+        // and y to be broadcastable.
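+        // Gradient flows only to the entries that attain the reduced
+        // extreme; "mask_sum" counts how many entries tie for it, so
+        // each tied entry receives an equal share of dy.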
+        {{"y"}, op, {"x", "i"}, {{"T", "$T"}, {"keep_dims", true}}},
+        {{"mask"}, "Equal", {"x", "y"}, {{"T", "$T"}}},
+        {{"mask_cast"}, "Cast", {"mask"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
+        {{"mask_sum"}, "Sum", {"mask_cast", "i"}, {{"T", "$T"}}},
+        {{"norm_dy"}, "Div", {"dy", "mask_sum"}, {{"T", "$T"}}},
+        {{"sy"}, "Shape", {"y"}, {{"T", "$T"}}},
+        {{"norm_dy_reshaped"}, "Reshape", {"norm_dy", "sy"}, {{"T", "$T"}}},
+        {{"dx"}, "Mul", {"mask_cast", "norm_dy_reshaped"}, {{"T", "$T"}}},
+        {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
+      });
+  // clang-format on
+  return Status::OK();
+}
+
+Status MaxGrad(const AttrSlice& attrs, FunctionDef* g) {
+  return MinMaxGradHelper("Max", attrs, g);
+}
+REGISTER_OP_GRADIENT("Max", MaxGrad);
+
+Status MinGrad(const AttrSlice& attrs, FunctionDef* g) {
+  return MinMaxGradHelper("Min", attrs, g);
+}
+REGISTER_OP_GRADIENT("Min", MinGrad);
+
+static Status MatMulGradHelper(FunctionDef* g, const string& x0, bool tx0,
+                               const string& x1, bool tx1, const string& y0,
+                               bool ty0, const string& y1, bool ty1) {
+  *g = FDH::Define(
+      // Arg defs
+      {"x: T", "y: T", "dz: T"},
+      // Ret val defs
+      {"dx: T", "dy: T"},
+      // Attr defs
+      {{"T: {float, double}"}},
+      // Nodes
+      {
+          {{"dx"},
+           "MatMul",
+           {x0, x1},
+           {{"T", "$T"}, {"transpose_a", tx0}, {"transpose_b", tx1}}},
+          {{"dy"},
+           "MatMul",
+           {y0, y1},
+           {{"T", "$T"}, {"transpose_a", ty0}, {"transpose_b", ty1}}},
+      });
+  return Status::OK();
+}
+
+Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
+  DataType T;
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
+  if (T == DT_COMPLEX64) {
+    return errors::Unimplemented(
+        "MatMul gradient for complex is not supported yet.");
+  }
+  bool ta;
+  bool tb;
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "transpose_a", &ta));
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "transpose_b", &tb));
+  if (!ta && !tb) {
+    return MatMulGradHelper(g, "dz", false, "y", true, "x", true, "dz", false);
+  }
+  if (!ta && tb) {
+    return MatMulGradHelper(g, "dz", false, "y", false, "dz", true, "x", false);
+  }
+  if (ta && !tb) {
+    return MatMulGradHelper(g, "y", false, "dz", true, "x", false, "dz", false);
+  }
+  CHECK(ta && tb);
+  return MatMulGradHelper(g, "y", true, "dz", true, "dz", true, "x", true);
+}
+REGISTER_OP_GRADIENT("MatMul", MatMulGrad);
+
+// REGISTER_OP_GRADIENT("SparseMatMul", SparseMatMulGrad);
+// REGISTER_OP_GRADIENT("BatchMatMul", BatchMatMulGrad);
+
+// Comparison ops.
+REGISTER_OP_NO_GRADIENT("Less");
+REGISTER_OP_NO_GRADIENT("LessEqual");
+REGISTER_OP_NO_GRADIENT("Greater");
+REGISTER_OP_NO_GRADIENT("GreaterEqual");
+REGISTER_OP_NO_GRADIENT("Equal");
+REGISTER_OP_NO_GRADIENT("NotEqual");
+
+// Logical ops.
+REGISTER_OP_NO_GRADIENT("LogicalAnd");
+REGISTER_OP_NO_GRADIENT("LogicalOr");
+REGISTER_OP_NO_GRADIENT("LogicalNot");
+
+// Sequence generation ops.
+REGISTER_OP_NO_GRADIENT("Range");
+REGISTER_OP_NO_GRADIENT("LinSpace");
+
+}  // end namespace tensorflow
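
Gradient functions registered with REGISTER_OP_GRADIENT are retrieved through the gradient registry declared in tensorflow/core/framework/function.h. A minimal sketch of instantiating one of the definitions above; the gradient::GetOpGradientCreator lookup and the empty AttrSlice construction follow that header as of this commit and should be treated as assumptions:

    #include "tensorflow/core/framework/function.h"
    #include "tensorflow/core/lib/core/status.h"

    namespace tensorflow {

    // Builds the FunctionDef produced by SquareGrad via the registry.
    Status MakeSquareGrad(FunctionDef* g) {
      gradient::Creator creator;
      TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator("Square", &creator));
      AttrValueMap empty_attrs;  // SquareGrad does not read its attrs.
      return creator(AttrSlice(&empty_attrs), g);
    }

    }  // end namespace tensorflow

The returned FunctionDef can then be instantiated with T bound to float or double, the two types every gradient in this file admits.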