aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-26 20:45:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 20:48:59 -0700
commit941e757a2364bb2e7cf41b8d980d7639849c6c5d (patch)
tree96306b0e0c91e7e38f77eaea8a9d5e2bc7686e58 /tensorflow/contrib/tpu
parent40ffbcc12519fa11e1dfb84f2f54a4f5d9b1b1c8 (diff)
Fix custom getter handling in tpu.rewrite() and friends.
It used to save the existing custom getter then overwrites the custom getter. That means the previous custom getter will never be called inside "computation". It now create a new custom getter that calls the previous custom getter. PiperOrigin-RevId: 214715720
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 712b02ff0d..883e08bf47 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -661,6 +661,10 @@ def split_compile_and_replicate(computation,
# be less confusing to clients if they knowingly choose to use resource
# variables.
# Partitioned variables is not supported (b/112311320).
+ vscope = variable_scope.get_variable_scope()
+ saved_use_resource = vscope.use_resource
+ saved_custom_getter = vscope.custom_getter
+
def custom_getter(getter, name, *args, **kwargs):
"""Variables on TPU have a few restrictions."""
partitioner = kwargs["partitioner"]
@@ -671,12 +675,10 @@ def split_compile_and_replicate(computation,
"`partitioner` that is {} for variable {}. "
"Setting `partitioner` to `None`."
.format(partitioner, name))
- return getter(name, *args, **kwargs)
-
- vscope = variable_scope.get_variable_scope()
-
- saved_use_resource = vscope.use_resource
- saved_custom_getter = vscope.custom_getter
+ if saved_custom_getter is None:
+ return getter(name, *args, **kwargs)
+ else:
+ return saved_custom_getter(getter, name, *args, **kwargs)
vscope.set_use_resource(True)
vscope.set_custom_getter(custom_getter)