aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-10 20:39:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 20:42:37 -0700
commit0b176e9e45d391b2e6da5199fc6c5e8000a772a4 (patch)
tree87f8181ec1c81bc74ec68527785f45676a0799e0
parente6830cdb06efe6f4cea2e4f30aa98f66ee1b305a (diff)
Give a warning about partitioned variable on TPU and set it to None, instead of erring out.
PiperOrigin-RevId: 212385555
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index c1f90c3963..0f9f7cd91b 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -654,13 +654,16 @@ def split_compile_and_replicate(computation,
# variables.
# Partitioned variables is not supported (b/112311320).
def custom_getter(getter, name, *args, **kwargs):
+ """Variables on TPU have a few restrictions."""
partitioner = kwargs["partitioner"]
- if partitioner is None:
- return getter(name, *args, **kwargs)
- else:
- raise ValueError(
+ if partitioner is not None:
+ kwargs["partitioner"] = None
+ logging.warning(
"Partitioned variables are not supported on TPU. Got "
- "`partitioner` that is {}.".format(partitioner))
+ "`partitioner` that is {} for variable {}. "
+ "Setting `partitioner` to `None`."
+ .format(partitioner, name))
+ return getter(name, *args, **kwargs)
vscope = variable_scope.get_variable_scope()