diff options
Diffstat (limited to 'tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py')
-rw-r--r-- | tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py index 348623d8f8..470e300ccb 100644 --- a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py +++ b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py @@ -21,11 +21,17 @@ from __future__ import print_function # pylint: disable=unused-import from tensorflow.contrib.periodic_resample.python.ops import gen_periodic_resample_op -from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample +from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample, periodic_resample_op_grad from tensorflow.contrib.util import loader +from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader # pylint: enable=unused-import _periodic_resample_op = loader.load_op_library( resource_loader.get_path_to_datafile('_periodic_resample_op.so')) + +@ops.RegisterGradient("PeriodicResample") +def _periodic_resample_grad_cc(op, grad): + return periodic_resample_op_grad( + grad, op.inputs[0].shape, op.get_attr('shape')) |