aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu/BUILD
diff options
context:
space:
mode:
authorGravatar Russell Power <power@google.com>2018-09-26 14:36:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 14:41:28 -0700
commita1801ecdbb75b4583d757204611afd9af28b4a49 (patch)
tree97f8e453de9d819a33bd4d54bc87278814c56fea /tensorflow/contrib/tpu/BUILD
parent2116c6649cfe339ce8a3859eb425806db8ae32b9 (diff)
Add experimental asynchronous checkpoint hook.
This triggers checkpoints in a separate thread while allowing training to continue. This can effectively parallelize checkpointing and training for workloads like TPUEstimator, where the weights are only updated after a number of device iterations. PiperOrigin-RevId: 214670991
Diffstat (limited to 'tensorflow/contrib/tpu/BUILD')
-rw-r--r--tensorflow/contrib/tpu/BUILD22
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 4e0b61227e..8355c92a4d 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -36,6 +36,27 @@ cc_library(
)
py_library(
+ name = "async_checkpoint",
+ srcs = ["python/tpu/async_checkpoint.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:summary_ops_v2",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+)
+
+py_library(
name = "tpu_estimator",
srcs = [
"python/tpu/error_handling.py",
@@ -46,6 +67,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":async_checkpoint",
":tpu_lib",
"//tensorflow/compiler/xla/experimental/xla_sharding",
"//tensorflow/compiler/xla/python_api:xla_shape",