diff options
-rw-r--r-- | tensorflow/contrib/distribute/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/contrib/distribute/__init__.py | 6 |
2 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD index 1126f76f58..d3628d480d 100644 --- a/tensorflow/contrib/distribute/BUILD +++ b/tensorflow/contrib/distribute/BUILD @@ -25,10 +25,13 @@ py_library( srcs = ["__init__.py"], visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/contrib/distribute/python:collective_all_reduce_strategy", "//tensorflow/contrib/distribute/python:cross_tower_ops", "//tensorflow/contrib/distribute/python:mirrored_strategy", "//tensorflow/contrib/distribute/python:monitor", + "//tensorflow/contrib/distribute/python:multi_worker_strategy", "//tensorflow/contrib/distribute/python:one_device_strategy", + "//tensorflow/contrib/distribute/python:parameter_server_strategy", "//tensorflow/contrib/distribute/python:step_fn", "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:training", diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 2e2c3be853..9123ca749b 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -19,10 +19,13 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy from tensorflow.contrib.distribute.python.cross_tower_ops import * from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy +from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy from tensorflow.contrib.distribute.python.monitor import Monitor from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy +from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy from tensorflow.python.training.distribute import * @@ -32,11 +35,14 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'AllReduceCrossTowerOps', + 'CollectiveAllReduceStrategy', 'CrossTowerOps', 'DistributionStrategy', 'MirroredStrategy', + 'MultiWorkerMirroredStrategy', 'Monitor', 'OneDeviceStrategy', + 'ParameterServerStrategy', 'ReductionToOneDeviceCrossTowerOps', 'Step', 'StandardInputStep', |