aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/while_v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/while_v2.py')
-rw-r--r--tensorflow/python/ops/while_v2.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 8e88a84d60..0419656143 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl as cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
@@ -580,7 +581,7 @@ def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
def _copy_handle_data(src_tensors, tgt_tensors):
for src_t, tgt_t in zip(src_tensors, tgt_tensors):
- function._copy_handle_data(src_t, tgt_t)
+ custom_gradient.copy_handle_data(src_t, tgt_t)
# TODO(srbs): Move to common utils for cond_v2 and while_v2.