aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn
diff options
context:
space:
mode:
authorGravatar Siu Kei, Muk <muksiukei@gmail.com>2018-04-16 10:23:20 +0800
committerGravatar Jonathan Hseu <vomjom@vomjom.net>2018-04-15 19:23:20 -0700
commitc6fdeaca7dd32c6bec3ff2df14889c3f2c129f14 (patch)
treee36a2f95f1919b241466f1ea4f3021859e952466 /tensorflow/contrib/learn
parentba1c53a5f2bb106e16ec7503dbd4d0db9ecc9799 (diff)
adding ps_strategy to run_config to enable different placement strate… (#15640)
* adding ps_strategy to run_config to enable different placement strategy in estimator * 1. Moved estimator._device_fn to RunConfig as @property 2. Made RunConfig.device_fn to return custom device function if one is specified, otherwise the result from `tf.train.replica_device_setter` call is used 3. Added some basic unit tests, may need further tests. * 1. Removing ps_strategy. 2. Modified estimator to take overriden device_fn from if set. 3. Removed ps_strategy related unit tests. * Adding manual initialization of _device_fn in legacy RunConfig class * Updated estimator golden API through 1. bazel build //tensorflow/tools/api/tests:api_compatibility_test 2. bazel-bin/tensorflow/tools/api/tests/api_compatibility_test --update_goldens True * fixing code styles
Diffstat (limited to 'tensorflow/contrib/learn')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/run_config.py1
1 files changed, 1 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py
index 8c85c431be..14ee2ba609 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py
@@ -299,6 +299,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
# so instead of breaking compatibility with that assumption, we
# just manually initialize this field:
self._train_distribute = None
+ self._device_fn = None
gpu_options = config_pb2.GPUOptions(
per_process_gpu_memory_fraction=gpu_memory_fraction)