aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/nn_grad.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/nn_grad.cc')
-rw-r--r--tensorflow/core/ops/nn_grad.cc31
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/core/ops/nn_grad.cc b/tensorflow/core/ops/nn_grad.cc
index e3b876b240..05ad635f58 100644
--- a/tensorflow/core/ops/nn_grad.cc
+++ b/tensorflow/core/ops/nn_grad.cc
@@ -181,4 +181,35 @@ Status MaxPoolGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("MaxPool", MaxPoolGrad);
+Status MaxPoolGradGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ *g = FDH::Define(
+ // Arg defs
+ {"input: T", "grad: T"},
+ // Ret val defs
+ {"output: T"},
+ // Attr defs
+ {"T: {float, half} = DT_FLOAT",
+ "ksize: list(int) >= 4",
+ "strides: list(int) >= 4",
+ GetPaddingAttrString()},
+ // Nodes
+ {
+ // Invoke MaxPool again to recompute the outputs (removed by CSE?).
+ {{"maxpool"}, "MaxPool", {"input"},
+ /*Attrs=*/{{"T", "$T"},
+ {"ksize", "$ksize"},
+ {"strides", "$strides"},
+ {"padding", "$padding"}}},
+ {{"output"}, "MaxPoolGradGrad", {"input", "maxpool", "grad"},
+ /*Attrs=*/{{"T", "$T"},
+ {"ksize", "$ksize"},
+ {"strides", "$strides"},
+ {"padding", "$padding"}}}
+ });
+ // clang-format on
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("MaxPoolGrad", MaxPoolGradGrad);
+
} // end namespace tensorflow