diff options
author | 2018-09-10 20:39:11 -0700 | |
---|---|---|
committer | 2018-09-10 20:42:37 -0700 | |
commit | 0b176e9e45d391b2e6da5199fc6c5e8000a772a4 (patch) | |
tree | 87f8181ec1c81bc74ec68527785f45676a0799e0 | |
parent | e6830cdb06efe6f4cea2e4f30aa98f66ee1b305a (diff) |
Give a warning about partitioned variable on TPU and set it to None, instead of erring out.
PiperOrigin-RevId: 212385555
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu.py | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index c1f90c3963..0f9f7cd91b 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -654,13 +654,16 @@ def split_compile_and_replicate(computation, # variables. # Partitioned variables is not supported (b/112311320). def custom_getter(getter, name, *args, **kwargs): + """Variables on TPU have a few restrictions.""" partitioner = kwargs["partitioner"] - if partitioner is None: - return getter(name, *args, **kwargs) - else: - raise ValueError( + if partitioner is not None: + kwargs["partitioner"] = None + logging.warning( "Partitioned variables are not supported on TPU. Got " - "`partitioner` that is {}.".format(partitioner)) + "`partitioner` that is {} for variable {}. " + "Setting `partitioner` to `None`." + .format(partitioner, name)) + return getter(name, *args, **kwargs) vscope = variable_scope.get_variable_scope() |