author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-08 13:46:12 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-10-08 13:51:07 -0700
commit    b052c51374f558c25a29c70918d79205dfec808b (patch)
tree      fd2eb631b27c5facc6c27a17cfa3c6feb539136d /tensorflow/python
parent    76ab96c8a5b2d77dfc191c94ff54fd5e52c561f2 (diff)
Add tf.test.benchmark_config(), which returns a session config appropriate for benchmarking. At the moment it returns a default config with only the Grappler dependency optimizer disabled.

Many benchmarks wrap the subgraph they want to time in control_flow_ops.group() to keep the overhead of copying the output back to the Python client out of the measurement. In the graph, this only adds a control dependency between the subgraph output and the fetch node, which in turn (often) causes the dependency optimizer to turn all nodes in the graph into no-ops, so the benchmark ends up timing an empty graph.
PiperOrigin-RevId: 216242463
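
To make the failure mode concrete, here is a minimal sketch of the benchmark pattern the message describes, written against the TF 1.x public API. It is illustrative only, not taken from this commit; the class name, shapes, and benchmark name are made up:

    import tensorflow as tf

    class MatmulBenchmark(tf.test.Benchmark):

      def benchmarkMatmul(self):
        # benchmark_config() disables Grappler's dependency optimizer; without
        # it, fetching tf.group(y) leaves y reachable only via a control edge,
        # and the optimizer may prune the matmul away entirely.
        with tf.Graph().as_default(), \
            tf.Session(config=tf.test.benchmark_config()) as sess:
          x = tf.random_uniform([1024, 1024])
          y = tf.matmul(x, x)
          # Fetch a group() of the output so the device-to-host copy of the
          # result is not included in the measured time.
          op = tf.group(y)
          self.run_op_benchmark(sess, op, min_iters=25, name="matmul_1024")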
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/kernel_tests/benchmark_test.py               2
-rw-r--r--  tensorflow/python/kernel_tests/cholesky_op_test.py             7
-rw-r--r--  tensorflow/python/kernel_tests/determinant_op_test.py          9
-rw-r--r--  tensorflow/python/kernel_tests/matrix_band_part_op_test.py     5
-rw-r--r--  tensorflow/python/kernel_tests/matrix_exponential_op_test.py   5
-rw-r--r--  tensorflow/python/kernel_tests/matrix_inverse_op_test.py       5
-rw-r--r--  tensorflow/python/kernel_tests/matrix_logarithm_op_test.py     3
-rw-r--r--  tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py      5
-rw-r--r--  tensorflow/python/kernel_tests/matrix_solve_op_test.py         5
-rw-r--r--  tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py  3
-rw-r--r--  tensorflow/python/kernel_tests/where_op_test.py                5
-rw-r--r--  tensorflow/python/ops/image_ops_test.py                        62
-rw-r--r--  tensorflow/python/platform/benchmark.py                        14
13 files changed, 76 insertions(+), 54 deletions(-)
diff --git a/tensorflow/python/kernel_tests/benchmark_test.py b/tensorflow/python/kernel_tests/benchmark_test.py
index 78b6e38d94..5777a5d097 100644
--- a/tensorflow/python/kernel_tests/benchmark_test.py
+++ b/tensorflow/python/kernel_tests/benchmark_test.py
@@ -64,7 +64,7 @@ class TestReportingBenchmark(test.Benchmark):
"other_key": "string"})
def benchmark_times_an_op(self):
- with session.Session() as sess:
+ with session.Session(config=benchmark.benchmark_config()) as sess:
a = constant_op.constant(0.0)
a_plus_a = a + a
return self.run_op_benchmark(
diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py
index 782e6b5068..2ebf74a4d7 100644
--- a/tensorflow/python/kernel_tests/cholesky_op_test.py
+++ b/tensorflow/python/kernel_tests/cholesky_op_test.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
@@ -327,7 +328,7 @@ class CholeskyBenchmark(test.Benchmark):
def benchmarkCholeskyOp(self):
for shape in self.shapes:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix = variables.Variable(self._GenerateMatrix(shape))
l = linalg_ops.cholesky(matrix)
@@ -341,7 +342,7 @@ class CholeskyBenchmark(test.Benchmark):
if test.is_gpu_available(True):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/device:GPU:0"):
matrix = variables.Variable(self._GenerateMatrix(shape))
l = linalg_ops.cholesky(matrix)
@@ -359,7 +360,7 @@ class CholeskyBenchmark(test.Benchmark):
for shape in self.shapes:
matrix = self._GenerateMatrix(shape)
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device(device):
l = variables.Variable(np.linalg.cholesky(matrix))
grad_matrix = variables.Variable(
diff --git a/tensorflow/python/kernel_tests/determinant_op_test.py b/tensorflow/python/kernel_tests/determinant_op_test.py
index a52b2c0dc3..fb114f9f24 100644
--- a/tensorflow/python/kernel_tests/determinant_op_test.py
+++ b/tensorflow/python/kernel_tests/determinant_op_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -185,8 +186,8 @@ class MatrixDeterminantBenchmark(test.Benchmark):
def benchmarkMatrixDeterminantOp(self):
for shape in self.shapes:
- with ops.Graph().as_default(), session.Session() as sess, ops.device(
- "/cpu:0"):
+ with ops.Graph().as_default(), session.Session(
+ config=benchmark.benchmark_config()) as sess, ops.device("/cpu:0"):
matrix = self._GenerateMatrix(shape)
d = linalg_ops.matrix_determinant(matrix)
variables.global_variables_initializer().run()
@@ -198,8 +199,8 @@ class MatrixDeterminantBenchmark(test.Benchmark):
name="matrix_determinant_cpu_{shape}".format(shape=shape))
if test.is_gpu_available(True):
- with ops.Graph().as_default(), session.Session() as sess, ops.device(
- "/gpu:0"):
+ with ops.Graph().as_default(), session.Session(
+ config=benchmark.benchmark_config()) as sess, ops.device("/gpu:0"):
matrix = self._GenerateMatrix(shape)
d = linalg_ops.matrix_determinant(matrix)
variables.global_variables_initializer().run()
diff --git a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
index 68d626de2c..a0ef3a607e 100644
--- a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test as test_lib
@@ -109,7 +110,7 @@ class MatrixBandPartBenchmark(test_lib.Benchmark):
for shape_ in self.shapes:
for limits in (-1, -1), (-1, 0), (0, -1), (2, 2):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix = variables.Variable(array_ops.ones(shape_))
band = array_ops.matrix_band_part(matrix, limits[0], limits[1])
@@ -123,7 +124,7 @@ class MatrixBandPartBenchmark(test_lib.Benchmark):
if test_lib.is_gpu_available(True):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/gpu:0"):
matrix = variables.Variable(array_ops.ones(shape_))
band = array_ops.matrix_band_part(matrix, limits[0], limits[1])
diff --git a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
index 0386e91276..9630c052b8 100644
--- a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg_impl
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -181,7 +182,7 @@ class MatrixExponentialBenchmark(test.Benchmark):
def benchmarkMatrixExponentialOp(self):
for shape in self.shapes:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix = self._GenerateMatrix(shape)
expm = linalg_impl.matrix_exponential(matrix)
@@ -195,7 +196,7 @@ class MatrixExponentialBenchmark(test.Benchmark):
if test.is_gpu_available(True):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/gpu:0"):
matrix = self._GenerateMatrix(shape)
expm = linalg_impl.matrix_exponential(matrix)
diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
index 720ba806e9..8bda04b53d 100644
--- a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -179,7 +180,7 @@ class MatrixInverseBenchmark(test.Benchmark):
for adjoint in False, True:
for shape in self.shapes:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix = self._GenerateMatrix(shape)
inv = linalg_ops.matrix_inverse(matrix, adjoint=adjoint)
@@ -193,7 +194,7 @@ class MatrixInverseBenchmark(test.Benchmark):
if test.is_gpu_available(True):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/gpu:0"):
matrix = self._GenerateMatrix(shape)
inv = linalg_ops.matrix_inverse(matrix, adjoint=adjoint)
diff --git a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
index 723a15fbd1..3205e211d9 100644
--- a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg_impl
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -159,7 +160,7 @@ class MatrixLogarithmBenchmark(test.Benchmark):
def benchmarkMatrixLogarithmOp(self):
for shape in self.shapes:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix = self._GenerateMatrix(shape)
logm = gen_linalg_ops.matrix_logarithm(matrix)
diff --git a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
index de495968a7..225a10e117 100644
--- a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test as test_lib
@@ -313,7 +314,7 @@ class MatrixSolveLsBenchmark(test_lib.Benchmark):
for num_rhs in 1, 2, matrix_shape[-1]:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix, rhs = _GenerateTestData(matrix_shape, num_rhs)
x = linalg_ops.matrix_solve_ls(matrix, rhs, regularizer)
@@ -328,7 +329,7 @@ class MatrixSolveLsBenchmark(test_lib.Benchmark):
if run_gpu_test and (len(matrix_shape) < 3 or matrix_shape[0] < 513):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/gpu:0"):
matrix, rhs = _GenerateTestData(matrix_shape, num_rhs)
x = linalg_ops.matrix_solve_ls(matrix, rhs, regularizer)
diff --git a/tensorflow/python/kernel_tests/matrix_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_op_test.py
index b8f2736b7b..264df2565c 100644
--- a/tensorflow/python/kernel_tests/matrix_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_solve_op_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -167,7 +168,7 @@ class MatrixSolveBenchmark(test.Benchmark):
for num_rhs in 1, 2, matrix_shape[-1]:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix, rhs = self._GenerateTestData(matrix_shape, num_rhs)
x = linalg_ops.matrix_solve(matrix, rhs, adjoint=adjoint)
@@ -185,7 +186,7 @@ class MatrixSolveBenchmark(test.Benchmark):
if run_gpu_test:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/gpu:0"):
matrix, rhs = self._GenerateTestData(matrix_shape, num_rhs)
x = linalg_ops.matrix_solve(matrix, rhs, adjoint=adjoint)
diff --git a/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py b/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
index 31e84341ae..fdfe1001b8 100644
--- a/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
# pylint: disable=protected-access
@@ -192,7 +193,7 @@ class BenchmarkSparseTensorsMapVsSerialization(test.Benchmark):
sorted(zip(indices_batch, indices_value)), dtype=np.int64)
values = ["feature_value_for_embedding_lookup"] * num_elements
shape = np.asarray([batch_size, num_elements], dtype=np.int64)
- with session.Session() as sess:
+ with session.Session(config=benchmark.benchmark_config()) as sess:
with ops.device("/cpu:0"):
indices = variables.Variable(indices)
values = variables.Variable(values)
diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py
index 29fb002ef4..04ac589432 100644
--- a/tensorflow/python/kernel_tests/where_op_test.py
+++ b/tensorflow/python/kernel_tests/where_op_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -160,7 +161,7 @@ class WhereBenchmark(test.Benchmark):
x = random_ops.random_uniform((m, n), dtype=dtypes.float32) <= p
v = resource_variable_ops.ResourceVariable(x)
op = array_ops.where(v)
- with session.Session() as sess:
+ with session.Session(config=benchmark.benchmark_config()) as sess:
v.initializer.run()
r = self.run_op_benchmark(sess, op, min_iters=100, name=name)
gb_processed_input = m * n / 1.0e9
@@ -186,7 +187,7 @@ class WhereBenchmark(test.Benchmark):
y = resource_variable_ops.ResourceVariable(y_gen)
c = resource_variable_ops.ResourceVariable(c_gen)
op = array_ops.where(c, x, y)
- with session.Session() as sess:
+ with session.Session(config=benchmark.benchmark_config()) as sess:
x.initializer.run()
y.initializer.run()
c.initializer.run()
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 35fdee4fad..ff86df6346 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -602,20 +602,19 @@ class AdjustHueBenchmark(test.Benchmark):
if cpu_count is not None:
config.inter_op_parallelism_threads = 1
config.intra_op_parallelism_threads = cpu_count
- with session.Session("", graph=ops.Graph(), config=config) as sess:
- with ops.device(device):
- inputs = variables.Variable(
- random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
- trainable=False,
- dtype=dtypes.float32)
- delta = constant_op.constant(0.1, dtype=dtypes.float32)
- outputs = image_ops.adjust_hue(inputs, delta)
- run_op = control_flow_ops.group(outputs)
- sess.run(variables.global_variables_initializer())
- for i in xrange(warmup_rounds + benchmark_rounds):
- if i == warmup_rounds:
- start = time.time()
- sess.run(run_op)
+ with self.benchmark_session(config=config, device=device) as sess:
+ inputs = variables.Variable(
+ random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
+ trainable=False,
+ dtype=dtypes.float32)
+ delta = constant_op.constant(0.1, dtype=dtypes.float32)
+ outputs = image_ops.adjust_hue(inputs, delta)
+ run_op = control_flow_ops.group(outputs)
+ sess.run(variables.global_variables_initializer())
+ for i in xrange(warmup_rounds + benchmark_rounds):
+ if i == warmup_rounds:
+ start = time.time()
+ sess.run(run_op)
end = time.time()
step_time = (end - start) / benchmark_rounds
tag = device + "_%s" % (cpu_count if cpu_count is not None else "_all")
@@ -646,21 +645,20 @@ class AdjustSaturationBenchmark(test.Benchmark):
if cpu_count is not None:
config.inter_op_parallelism_threads = 1
config.intra_op_parallelism_threads = cpu_count
- with session.Session("", graph=ops.Graph(), config=config) as sess:
- with ops.device(device):
- inputs = variables.Variable(
- random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
- trainable=False,
- dtype=dtypes.float32)
- delta = constant_op.constant(0.1, dtype=dtypes.float32)
- outputs = image_ops.adjust_saturation(inputs, delta)
- run_op = control_flow_ops.group(outputs)
- sess.run(variables.global_variables_initializer())
- for _ in xrange(warmup_rounds):
- sess.run(run_op)
- start = time.time()
- for _ in xrange(benchmark_rounds):
- sess.run(run_op)
+ with self.benchmark_session(config=config, device=device) as sess:
+ inputs = variables.Variable(
+ random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
+ trainable=False,
+ dtype=dtypes.float32)
+ delta = constant_op.constant(0.1, dtype=dtypes.float32)
+ outputs = image_ops.adjust_saturation(inputs, delta)
+ run_op = control_flow_ops.group(outputs)
+ sess.run(variables.global_variables_initializer())
+ for _ in xrange(warmup_rounds):
+ sess.run(run_op)
+ start = time.time()
+ for _ in xrange(benchmark_rounds):
+ sess.run(run_op)
end = time.time()
step_time = (end - start) / benchmark_rounds
tag = device + "_%s" % (cpu_count if cpu_count is not None else "_all")
@@ -699,7 +697,7 @@ class ResizeBilinearBenchmark(test.Benchmark):
deps = [resize_op]
benchmark_op = control_flow_ops.group(*deps)
- with session.Session() as sess:
+ with self.benchmark_session() as sess:
sess.run(variables.global_variables_initializer())
results = self.run_op_benchmark(
sess,
@@ -747,7 +745,7 @@ class ResizeBicubicBenchmark(test.Benchmark):
deps = [resize_op]
benchmark_op = control_flow_ops.group(*deps)
- with session.Session() as sess:
+ with self.benchmark_session() as sess:
sess.run(variables.global_variables_initializer())
results = self.run_op_benchmark(
sess,
@@ -804,7 +802,7 @@ class ResizeAreaBenchmark(test.Benchmark):
deps = [resize_op]
benchmark_op = control_flow_ops.group(*deps)
- with session.Session() as sess:
+ with self.benchmark_session() as sess:
sess.run(variables.global_variables_initializer())
results = self.run_op_benchmark(
sess,
diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py
index fa17b17d10..4f7abb311a 100644
--- a/tensorflow/python/platform/benchmark.py
+++ b/tensorflow/python/platform/benchmark.py
@@ -27,6 +27,7 @@ import time
import six
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.util import test_log_pb2
from tensorflow.python.client import timeline
from tensorflow.python.platform import app
@@ -182,6 +183,19 @@ class Benchmark(six.with_metaclass(_BenchmarkRegistrar, object)):
throughput=throughput, extras=extras)
+@tf_export("test.benchmark_config")
+def benchmark_config():
+ """Returns a tf.ConfigProto for disabling the dependency optimizer.
+
+ Returns:
+ A TensorFlow ConfigProto object.
+ """
+ config = config_pb2.ConfigProto()
+ config.graph_options.rewrite_options.dependency_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ return config
+
+
@tf_export("test.Benchmark")
class TensorFlowBenchmark(Benchmark):
"""Abstract class that provides helpers for TensorFlow benchmarks."""