diff options
author | 2018-09-26 14:36:59 -0700 | |
---|---|---|
committer | 2018-09-26 14:41:28 -0700 | |
commit | a1801ecdbb75b4583d757204611afd9af28b4a49 (patch) | |
tree | 97f8e453de9d819a33bd4d54bc87278814c56fea /tensorflow/contrib/tpu/BUILD | |
parent | 2116c6649cfe339ce8a3859eb425806db8ae32b9 (diff) |
Add experimental asynchronous checkpoint hook.
This triggers checkpoints in a separate thread while allowing training to
continue. This can effectively parallelize checkpointing and training for
workloads like TPUEstimator, where the weights are only updated after a number
of device iterations.
PiperOrigin-RevId: 214670991
Diffstat (limited to 'tensorflow/contrib/tpu/BUILD')
-rw-r--r-- | tensorflow/contrib/tpu/BUILD | 22 |
1 file changed, 22 insertions, 0 deletions
```diff
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 4e0b61227e..8355c92a4d 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -36,6 +36,27 @@ cc_library(
 )
 
 py_library(
+    name = "async_checkpoint",
+    srcs = ["python/tpu/async_checkpoint.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:state_ops",
+        "//tensorflow/python:summary",
+        "//tensorflow/python:summary_ops_v2",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/estimator:estimator_py",
+    ],
+)
+
+py_library(
     name = "tpu_estimator",
     srcs = [
         "python/tpu/error_handling.py",
@@ -46,6 +67,7 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
+        ":async_checkpoint",
         ":tpu_lib",
         "//tensorflow/compiler/xla/experimental/xla_sharding",
         "//tensorflow/compiler/xla/python_api:xla_shape",
```