diff options
author | 2016-06-27 15:01:58 -0800 | |
---|---|---|
committer | 2016-06-27 16:16:13 -0700 | |
commit | 380e801f9251995122ae09ac11d83bd6a30edefa (patch) | |
tree | 5754257fdffffe214971f4dd3eb429c0bee382c9 /tensorflow/python | |
parent | 7a62f1e0be05291cf046f57482f364e0033f6e05 (diff) |
Add a simple way to set the `default_session_config` on a `tf.train.Server`.
This makes it easier to set properties such as the
`gpu_options.per_process_gpu_memory_fraction`, which have to be set on
the server, rather than individual serssions.
Fixes #3057.
Change: 126009942
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/training/server_lib.py | 36 | ||||
-rw-r--r-- | tensorflow/python/training/server_lib_test.py | 24 |
2 files changed, 54 insertions, 6 deletions
diff --git a/tensorflow/python/training/server_lib.py b/tensorflow/python/training/server_lib.py index 9a3877c456..605e970fe6 100644 --- a/tensorflow/python/training/server_lib.py +++ b/tensorflow/python/training/server_lib.py @@ -26,7 +26,8 @@ from tensorflow.python.framework import errors from tensorflow.python.util import compat -def _make_server_def(server_or_cluster_def, job_name, task_index, protocol): +def _make_server_def(server_or_cluster_def, job_name, task_index, protocol, + config): """Creates a `tf.train.ServerDef` protocol buffer. Args: @@ -43,6 +44,8 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol): protocol: (Optional.) Specifies the protocol to be used by the server. Acceptable values include `"grpc"`. Defaults to the value in `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`. + config: (Options.) A `tf.ConfigProto` that specifies default configuration + options for all sessions that run on this server. Returns: A `tf.train.ServerDef`. @@ -60,6 +63,8 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol): server_def.task_index = task_index if protocol is not None: server_def.protocol = protocol + if config is not None: + server_def.default_session_config.MergeFrom(config) else: try: cluster_spec = ClusterSpec(server_or_cluster_def) @@ -82,6 +87,8 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol): server_def = tensorflow_server_pb2.ServerDef( cluster=cluster_spec.as_cluster_def(), job_name=job_name, task_index=task_index, protocol=protocol) + if config is not None: + server_def.default_session_config.MergeFrom(config) return server_def @@ -98,6 +105,7 @@ class Server(object): @@__init__ @@create_local_server @@target + @@server_def @@start @@join @@ -108,6 +116,7 @@ class Server(object): job_name=None, task_index=None, protocol=None, + config=None, start=True): """Creates a new server with the given definition. @@ -128,6 +137,8 @@ class Server(object): protocol: (Optional.) Specifies the protocol to be used by the server. Acceptable values include `"grpc"`. Defaults to the value in `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`. + config: (Options.) A `tf.ConfigProto` that specifies default + configuration options for all sessions that run on this server. start: (Optional.) Boolean, indicating whether to start the server after creating it. Defaults to `True`. @@ -135,11 +146,11 @@ class Server(object): tf.errors.OpError: Or one of its subclasses if an error occurs while creating the TensorFlow server. """ - server_def = _make_server_def(server_or_cluster_def, - job_name, task_index, protocol) + self._server_def = _make_server_def(server_or_cluster_def, + job_name, task_index, protocol, config) with errors.raise_exception_on_not_ok_status() as status: self._server = pywrap_tensorflow.PyServer_New( - server_def.SerializeToString(), status) + self._server_def.SerializeToString(), status) if start: self.start() @@ -166,6 +177,16 @@ class Server(object): pywrap_tensorflow.PyServer_Join(self._server, status) @property + def server_def(self): + """Returns the `tf.train.ServerDef` for this server. + + Returns: + A `tf.train.ServerDef` prototocol buffer that describes the configuration + of this server. + """ + return self._server_def + + @property def target(self): """Returns the target for a `tf.Session` to connect to this server. @@ -185,7 +206,7 @@ class Server(object): return self._server.target() @staticmethod - def create_local_server(start=True): + def create_local_server(config=None, start=True): """Creates a new single-process cluster running on the local host. This method is a convenience wrapper for creating a @@ -194,6 +215,8 @@ class Server(object): `"local"`. Args: + config: (Options.) A `tf.ConfigProto` that specifies default + configuration options for all sessions that run on this server. start: (Optional.) Boolean, indicating whether to start the server after creating it. Defaults to `True`. @@ -202,7 +225,8 @@ class Server(object): """ # Specifying port 0 means that the OS will choose a free port for the # server. - return Server({"local": ["localhost:0"]}, protocol="grpc", start=start) + return Server({"local": ["localhost:0"]}, protocol="grpc", config=config, + start=start) class ClusterSpec(object): diff --git a/tensorflow/python/training/server_lib_test.py b/tensorflow/python/training/server_lib_test.py index 94fa6c295f..40d399cced 100644 --- a/tensorflow/python/training/server_lib_test.py +++ b/tensorflow/python/training/server_lib_test.py @@ -250,6 +250,30 @@ class GrpcServerTest(tf.test.TestCase): sess.close() blocking_thread.join() + def testSetConfiguration(self): + config = tf.ConfigProto( + gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.1)) + + # Configure a server using the default local server options. + server = tf.train.Server.create_local_server(config=config, start=False) + self.assertEqual( + 0.1, + server.server_def.default_session_config + .gpu_options.per_process_gpu_memory_fraction) + + # Configure a server using an explicit ServerDefd with an + # overridden config. + cluster_def = tf.train.ClusterSpec( + {"localhost": ["localhost:0"]}).as_cluster_def() + server_def = tf.train.ServerDef( + cluster=cluster_def, job_name="localhost", task_index=0, + protocol="grpc") + server = tf.train.Server(server_def, config=config, start=False) + self.assertEqual( + 0.1, + server.server_def.default_session_config + .gpu_options.per_process_gpu_memory_fraction) + def testInvalidHostname(self): with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, "port"): _ = tf.train.Server({"local": ["localhost"]}, |