aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2016-06-27 15:01:58 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-27 16:16:13 -0700
commit380e801f9251995122ae09ac11d83bd6a30edefa (patch)
tree5754257fdffffe214971f4dd3eb429c0bee382c9 /tensorflow/python
parent7a62f1e0be05291cf046f57482f364e0033f6e05 (diff)
Add a simple way to set the `default_session_config` on a `tf.train.Server`.
This makes it easier to set properties such as the `gpu_options.per_process_gpu_memory_fraction`, which have to be set on the server, rather than individual serssions. Fixes #3057. Change: 126009942
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/training/server_lib.py36
-rw-r--r--tensorflow/python/training/server_lib_test.py24
2 files changed, 54 insertions, 6 deletions
diff --git a/tensorflow/python/training/server_lib.py b/tensorflow/python/training/server_lib.py
index 9a3877c456..605e970fe6 100644
--- a/tensorflow/python/training/server_lib.py
+++ b/tensorflow/python/training/server_lib.py
@@ -26,7 +26,8 @@ from tensorflow.python.framework import errors
from tensorflow.python.util import compat
-def _make_server_def(server_or_cluster_def, job_name, task_index, protocol):
+def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
+ config):
"""Creates a `tf.train.ServerDef` protocol buffer.
Args:
@@ -43,6 +44,8 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol):
protocol: (Optional.) Specifies the protocol to be used by the server.
Acceptable values include `"grpc"`. Defaults to the value in
`server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
+ config: (Options.) A `tf.ConfigProto` that specifies default configuration
+ options for all sessions that run on this server.
Returns:
A `tf.train.ServerDef`.
@@ -60,6 +63,8 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol):
server_def.task_index = task_index
if protocol is not None:
server_def.protocol = protocol
+ if config is not None:
+ server_def.default_session_config.MergeFrom(config)
else:
try:
cluster_spec = ClusterSpec(server_or_cluster_def)
@@ -82,6 +87,8 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol):
server_def = tensorflow_server_pb2.ServerDef(
cluster=cluster_spec.as_cluster_def(),
job_name=job_name, task_index=task_index, protocol=protocol)
+ if config is not None:
+ server_def.default_session_config.MergeFrom(config)
return server_def
@@ -98,6 +105,7 @@ class Server(object):
@@__init__
@@create_local_server
@@target
+ @@server_def
@@start
@@join
@@ -108,6 +116,7 @@ class Server(object):
job_name=None,
task_index=None,
protocol=None,
+ config=None,
start=True):
"""Creates a new server with the given definition.
@@ -128,6 +137,8 @@ class Server(object):
protocol: (Optional.) Specifies the protocol to be used by the server.
Acceptable values include `"grpc"`. Defaults to the value in
`server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
+ config: (Options.) A `tf.ConfigProto` that specifies default
+ configuration options for all sessions that run on this server.
start: (Optional.) Boolean, indicating whether to start the server
after creating it. Defaults to `True`.
@@ -135,11 +146,11 @@ class Server(object):
tf.errors.OpError: Or one of its subclasses if an error occurs while
creating the TensorFlow server.
"""
- server_def = _make_server_def(server_or_cluster_def,
- job_name, task_index, protocol)
+ self._server_def = _make_server_def(server_or_cluster_def,
+ job_name, task_index, protocol, config)
with errors.raise_exception_on_not_ok_status() as status:
self._server = pywrap_tensorflow.PyServer_New(
- server_def.SerializeToString(), status)
+ self._server_def.SerializeToString(), status)
if start:
self.start()
@@ -166,6 +177,16 @@ class Server(object):
pywrap_tensorflow.PyServer_Join(self._server, status)
@property
+ def server_def(self):
+ """Returns the `tf.train.ServerDef` for this server.
+
+ Returns:
+ A `tf.train.ServerDef` prototocol buffer that describes the configuration
+ of this server.
+ """
+ return self._server_def
+
+ @property
def target(self):
"""Returns the target for a `tf.Session` to connect to this server.
@@ -185,7 +206,7 @@ class Server(object):
return self._server.target()
@staticmethod
- def create_local_server(start=True):
+ def create_local_server(config=None, start=True):
"""Creates a new single-process cluster running on the local host.
This method is a convenience wrapper for creating a
@@ -194,6 +215,8 @@ class Server(object):
`"local"`.
Args:
+ config: (Options.) A `tf.ConfigProto` that specifies default
+ configuration options for all sessions that run on this server.
start: (Optional.) Boolean, indicating whether to start the server after
creating it. Defaults to `True`.
@@ -202,7 +225,8 @@ class Server(object):
"""
# Specifying port 0 means that the OS will choose a free port for the
# server.
- return Server({"local": ["localhost:0"]}, protocol="grpc", start=start)
+ return Server({"local": ["localhost:0"]}, protocol="grpc", config=config,
+ start=start)
class ClusterSpec(object):
diff --git a/tensorflow/python/training/server_lib_test.py b/tensorflow/python/training/server_lib_test.py
index 94fa6c295f..40d399cced 100644
--- a/tensorflow/python/training/server_lib_test.py
+++ b/tensorflow/python/training/server_lib_test.py
@@ -250,6 +250,30 @@ class GrpcServerTest(tf.test.TestCase):
sess.close()
blocking_thread.join()
+ def testSetConfiguration(self):
+ config = tf.ConfigProto(
+ gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.1))
+
+ # Configure a server using the default local server options.
+ server = tf.train.Server.create_local_server(config=config, start=False)
+ self.assertEqual(
+ 0.1,
+ server.server_def.default_session_config
+ .gpu_options.per_process_gpu_memory_fraction)
+
+ # Configure a server using an explicit ServerDefd with an
+ # overridden config.
+ cluster_def = tf.train.ClusterSpec(
+ {"localhost": ["localhost:0"]}).as_cluster_def()
+ server_def = tf.train.ServerDef(
+ cluster=cluster_def, job_name="localhost", task_index=0,
+ protocol="grpc")
+ server = tf.train.Server(server_def, config=config, start=False)
+ self.assertEqual(
+ 0.1,
+ server.server_def.default_session_config
+ .gpu_options.per_process_gpu_memory_fraction)
+
def testInvalidHostname(self):
with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, "port"):
_ = tf.train.Server({"local": ["localhost"]},