aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py')
-rw-r--r--tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py
index 348623d8f8..470e300ccb 100644
--- a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py
+++ b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py
@@ -21,11 +21,17 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.contrib.periodic_resample.python.ops import gen_periodic_resample_op
-from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample
+from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample, periodic_resample_op_grad
from tensorflow.contrib.util import loader
+from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
# pylint: enable=unused-import
_periodic_resample_op = loader.load_op_library(
resource_loader.get_path_to_datafile('_periodic_resample_op.so'))
+
+@ops.RegisterGradient("PeriodicResample")
+def _periodic_resample_grad_cc(op, grad):
+ return periodic_resample_op_grad(
+ grad, op.inputs[0].shape, op.get_attr('shape'))