aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py')
-rw-r--r--tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py34
1 files changed, 20 insertions, 14 deletions
diff --git a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
index 30e1281845..b66e7cd537 100644
--- a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
+++ b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
@@ -18,28 +18,33 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import tensorflow as tf
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
-class StatSummarizerTest(tf.test.TestCase):
+class StatSummarizerTest(test.TestCase):
def testStatSummarizer(self):
- with tf.Graph().as_default() as graph:
- matrix1 = tf.constant([[3., 3.]])
- matrix2 = tf.constant([[2.], [2.]])
- product = tf.matmul(matrix1, matrix2)
+ with ops.Graph().as_default() as graph:
+ matrix1 = constant_op.constant([[3., 3.]])
+ matrix2 = constant_op.constant([[2.], [2.]])
+ product = math_ops.matmul(matrix1, matrix2)
graph_def = graph.as_graph_def()
- ss = tf.contrib.stat_summarizer.NewStatSummarizer(
- graph_def.SerializeToString())
+ ss = pywrap_tensorflow.NewStatSummarizer(graph_def.SerializeToString())
with self.test_session() as sess:
- sess.run(tf.global_variables_initializer())
+ sess.run(variables.global_variables_initializer())
for _ in range(20):
- run_metadata = tf.RunMetadata()
- run_options = tf.RunOptions(
- trace_level=tf.RunOptions.FULL_TRACE)
+ run_metadata = config_pb2.RunMetadata()
+ run_options = config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE)
sess.run(product, options=run_options, run_metadata=run_metadata)
ss.ProcessStepStatsStr(run_metadata.step_stats.SerializeToString())
@@ -60,7 +65,8 @@ class StatSummarizerTest(tf.test.TestCase):
# Test that a CDF summed to 100%
self.assertRegexpMatches(output_string, r"100\.")
- tf.contrib.stat_summarizer.DeleteStatSummarizer(ss)
+ pywrap_tensorflow.DeleteStatSummarizer(ss)
+
if __name__ == "__main__":
- tf.test.main()
+ test.main()