aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu/python/tpu/tpu.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/tpu.py')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index c1f90c3963..0f9f7cd91b 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -654,13 +654,16 @@ def split_compile_and_replicate(computation,
# variables.
# Partitioned variables is not supported (b/112311320).
def custom_getter(getter, name, *args, **kwargs):
+ """Variables on TPU have a few restrictions."""
partitioner = kwargs["partitioner"]
- if partitioner is None:
- return getter(name, *args, **kwargs)
- else:
- raise ValueError(
+ if partitioner is not None:
+ kwargs["partitioner"] = None
+ logging.warning(
"Partitioned variables are not supported on TPU. Got "
- "`partitioner` that is {}.".format(partitioner))
+ "`partitioner` that is {} for variable {}. "
+ "Setting `partitioner` to `None`."
+ .format(partitioner, name))
+ return getter(name, *args, **kwargs)
vscope = variable_scope.get_variable_scope()