diff options
Diffstat (limited to 'tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py')
-rw-r--r-- | tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py | 34 |
1 files changed, 20 insertions, 14 deletions
diff --git a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py index 30e1281845..b66e7cd537 100644 --- a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py +++ b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py @@ -18,28 +18,33 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import tensorflow as tf +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test -class StatSummarizerTest(tf.test.TestCase): +class StatSummarizerTest(test.TestCase): def testStatSummarizer(self): - with tf.Graph().as_default() as graph: - matrix1 = tf.constant([[3., 3.]]) - matrix2 = tf.constant([[2.], [2.]]) - product = tf.matmul(matrix1, matrix2) + with ops.Graph().as_default() as graph: + matrix1 = constant_op.constant([[3., 3.]]) + matrix2 = constant_op.constant([[2.], [2.]]) + product = math_ops.matmul(matrix1, matrix2) graph_def = graph.as_graph_def() - ss = tf.contrib.stat_summarizer.NewStatSummarizer( - graph_def.SerializeToString()) + ss = pywrap_tensorflow.NewStatSummarizer(graph_def.SerializeToString()) with self.test_session() as sess: - sess.run(tf.global_variables_initializer()) + sess.run(variables.global_variables_initializer()) for _ in range(20): - run_metadata = tf.RunMetadata() - run_options = tf.RunOptions( - trace_level=tf.RunOptions.FULL_TRACE) + run_metadata = config_pb2.RunMetadata() + run_options = config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE) sess.run(product, options=run_options, run_metadata=run_metadata) ss.ProcessStepStatsStr(run_metadata.step_stats.SerializeToString()) @@ -60,7 +65,8 @@ class StatSummarizerTest(tf.test.TestCase): # Test that a CDF summed to 100% self.assertRegexpMatches(output_string, r"100\.") - tf.contrib.stat_summarizer.DeleteStatSummarizer(ss) + pywrap_tensorflow.DeleteStatSummarizer(ss) + if __name__ == "__main__": - tf.test.main() + test.main() |