aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/dist_test
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-05-02 18:56:32 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-02 20:10:03 -0700
commit3af03be757b63ea6fbd28cc351d5d2323c526354 (patch)
tree7e29728255f6a4124ded1bde08afa2ac6d01d5a2 /tensorflow/tools/dist_test
parenta8d720c2b7e4260bee7020822168bfba852274ac (diff)
tfdbg: internal-only changes
Change: 154914490
Diffstat (limited to 'tensorflow/tools/dist_test')
-rw-r--r--tensorflow/tools/dist_test/server/BUILD2
-rw-r--r--[-rwxr-xr-x]tensorflow/tools/dist_test/server/grpc_tensorflow_server.py12
2 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/tools/dist_test/server/BUILD b/tensorflow/tools/dist_test/server/BUILD
index 9d008ec9ce..865af8dd7b 100644
--- a/tensorflow/tools/dist_test/server/BUILD
+++ b/tensorflow/tools/dist_test/server/BUILD
@@ -9,7 +9,7 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
-py_library(
+py_binary(
name = "grpc_tensorflow_server",
srcs = [
"grpc_tensorflow_server.py",
diff --git a/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
index 2d774577b6..bd6700a0b1 100755..100644
--- a/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
+++ b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
@@ -36,6 +36,7 @@ from __future__ import print_function
import argparse
import sys
+from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.platform import app
from tensorflow.python.training import server_lib
@@ -103,8 +104,11 @@ def main(unused_args):
raise ValueError("Invalid task_id: %d" % FLAGS.task_id)
server_def.task_index = FLAGS.task_id
+ config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions(
+ per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction))
+
# Create GRPC Server instance
- server = server_lib.Server(server_def)
+ server = server_lib.Server(server_def, config=config)
# join() is blocking, unlike start()
server.join()
@@ -138,6 +142,11 @@ if __name__ == "__main__":
help="Task index, e.g., 0"
)
parser.add_argument(
+ "--gpu_memory_fraction",
+ type=float,
+ default=1.0,
+ help="Fraction of GPU memory allocated",)
+ parser.add_argument(
"--verbose",
type="bool",
nargs="?",
@@ -145,5 +154,6 @@ if __name__ == "__main__":
default=False,
help="Verbose mode"
)
+
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)