aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/distribute/BUILD3
-rw-r--r--tensorflow/contrib/distribute/__init__.py6
2 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD
index 1126f76f58..d3628d480d 100644
--- a/tensorflow/contrib/distribute/BUILD
+++ b/tensorflow/contrib/distribute/BUILD
@@ -25,10 +25,13 @@ py_library(
srcs = ["__init__.py"],
visibility = ["//tensorflow:internal"],
deps = [
+ "//tensorflow/contrib/distribute/python:collective_all_reduce_strategy",
"//tensorflow/contrib/distribute/python:cross_tower_ops",
"//tensorflow/contrib/distribute/python:mirrored_strategy",
"//tensorflow/contrib/distribute/python:monitor",
+ "//tensorflow/contrib/distribute/python:multi_worker_strategy",
"//tensorflow/contrib/distribute/python:one_device_strategy",
+ "//tensorflow/contrib/distribute/python:parameter_server_strategy",
"//tensorflow/contrib/distribute/python:step_fn",
"//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/python:training",
diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py
index 2e2c3be853..9123ca749b 100644
--- a/tensorflow/contrib/distribute/__init__.py
+++ b/tensorflow/contrib/distribute/__init__.py
@@ -19,10 +19,13 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy
from tensorflow.contrib.distribute.python.cross_tower_ops import *
from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy
+from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy
from tensorflow.contrib.distribute.python.monitor import Monitor
from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy
+from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy
from tensorflow.contrib.distribute.python.step_fn import *
from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy
from tensorflow.python.training.distribute import *
@@ -32,11 +35,14 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'AllReduceCrossTowerOps',
+ 'CollectiveAllReduceStrategy',
'CrossTowerOps',
'DistributionStrategy',
'MirroredStrategy',
+ 'MultiWorkerMirroredStrategy',
'Monitor',
'OneDeviceStrategy',
+ 'ParameterServerStrategy',
'ReductionToOneDeviceCrossTowerOps',
'Step',
'StandardInputStep',