aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar Jing Li <jingli@google.com>2018-10-05 09:54:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 09:59:41 -0700
commit8b7c789e7401fe56b4f648a04f675a3cb69119e5 (patch)
tree774d7ef4eaf5169620ae26325cc819bf47d072f2 /tensorflow/contrib/tpu
parent5a43e01ef0f8cb86d836a4d1c08a246630e26f8c (diff)
- Don't set tpu optimizer parameter variable during weight initialization if the optimizer isn't set, e.g. loading weights and then predict.
- Add load_weights for `KerasTpuModel`. PiperOrigin-RevId: 215920993
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index a3a7fd8bb0..af183b3232 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -1998,6 +1998,9 @@ class KerasTPUModel(models.Model):
logging.info('Setting weights on TPU model.')
cloned_model.set_weights(weights)
+ if self._tpu_model.optimizer is None:
+ # tpu_model may not be compiled, e.g., loading weights and then predict.
+ return
for k, v in six.iteritems(cpu_optimizer_config):
opt_var = getattr(self._tpu_model.optimizer, k)
if isinstance(opt_var, variables.Variable):
@@ -2052,6 +2055,10 @@ class KerasTPUModel(models.Model):
self._cpu_model.set_weights(weights)
self._tpu_weights_initialized = False
+ def load_weights(self, filepath, by_name=False):
+ self._cpu_model.load_weights(filepath, by_name)
+ self._tpu_weights_initialized = False
+
# pylint: disable=bad-continuation
def _validate_shapes(model):