diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-26 20:45:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 20:48:59 -0700 |
commit | 941e757a2364bb2e7cf41b8d980d7639849c6c5d (patch) | |
tree | 96306b0e0c91e7e38f77eaea8a9d5e2bc7686e58 /tensorflow/contrib/tpu | |
parent | 40ffbcc12519fa11e1dfb84f2f54a4f5d9b1b1c8 (diff) |
Fix custom getter handling in tpu.rewrite() and friends.
It used to save the existing custom getter then overwrites the custom getter. That means the previous custom getter will never be called inside "computation". It now create a new custom getter that calls the previous custom getter.
PiperOrigin-RevId: 214715720
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 712b02ff0d..883e08bf47 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -661,6 +661,10 @@ def split_compile_and_replicate(computation, # be less confusing to clients if they knowingly choose to use resource # variables. # Partitioned variables is not supported (b/112311320). + vscope = variable_scope.get_variable_scope() + saved_use_resource = vscope.use_resource + saved_custom_getter = vscope.custom_getter + def custom_getter(getter, name, *args, **kwargs): """Variables on TPU have a few restrictions.""" partitioner = kwargs["partitioner"] @@ -671,12 +675,10 @@ def split_compile_and_replicate(computation, "`partitioner` that is {} for variable {}. " "Setting `partitioner` to `None`." .format(partitioner, name)) - return getter(name, *args, **kwargs) - - vscope = variable_scope.get_variable_scope() - - saved_use_resource = vscope.use_resource - saved_custom_getter = vscope.custom_getter + if saved_custom_getter is None: + return getter(name, *args, **kwargs) + else: + return saved_custom_getter(getter, name, *args, **kwargs) vscope.set_use_resource(True) vscope.set_custom_getter(custom_getter) |