diff options
Diffstat (limited to 'tensorflow/core/ops/nn_grad.cc')
-rw-r--r-- | tensorflow/core/ops/nn_grad.cc | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/core/ops/nn_grad.cc b/tensorflow/core/ops/nn_grad.cc index e3b876b240..05ad635f58 100644 --- a/tensorflow/core/ops/nn_grad.cc +++ b/tensorflow/core/ops/nn_grad.cc @@ -181,4 +181,35 @@ Status MaxPoolGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("MaxPool", MaxPoolGrad); +Status MaxPoolGradGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + *g = FDH::Define( + // Arg defs + {"input: T", "grad: T"}, + // Ret val defs + {"output: T"}, + // Attr defs + {"T: {float, half} = DT_FLOAT", + "ksize: list(int) >= 4", + "strides: list(int) >= 4", + GetPaddingAttrString()}, + // Nodes + { + // Invoke MaxPool again to recompute the outputs (removed by CSE?). + {{"maxpool"}, "MaxPool", {"input"}, + /*Attrs=*/{{"T", "$T"}, + {"ksize", "$ksize"}, + {"strides", "$strides"}, + {"padding", "$padding"}}}, + {{"output"}, "MaxPoolGradGrad", {"input", "maxpool", "grad"}, + /*Attrs=*/{{"T", "$T"}, + {"ksize", "$ksize"}, + {"strides", "$strides"}, + {"padding", "$padding"}}} + }); + // clang-format on + return Status::OK(); +} +REGISTER_OP_GRADIENT("MaxPoolGrad", MaxPoolGradGrad); + } // end namespace tensorflow |