aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-22 13:56:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 14:05:05 -0700
commit21cdc5b13e3eeb873f92648f229ad29b3b7b1129 (patch)
tree4ab3ee521572f7da90929ee1b8897461b5fd9d3f /tensorflow/contrib/eager
parent7acfb875a0217777287a299ea8013e16fca59d4e (diff)
Making the side effect of result() function of tfe.Metrics.Mean optional.
PiperOrigin-RevId: 209824328
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/metrics_impl.py22
-rw-r--r--tensorflow/contrib/eager/python/metrics_test.py22
2 files changed, 42 insertions, 2 deletions
diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py
index 6efafccd6b..930e62b680 100644
--- a/tensorflow/contrib/eager/python/metrics_impl.py
+++ b/tensorflow/contrib/eager/python/metrics_impl.py
@@ -336,9 +336,27 @@ class Mean(Metric):
return values
return values, weights
- def result(self):
+ def result(self, write_summary=True):
+ """Returns the result of the Metric.
+
+ Args:
+ write_summary: bool indicating whether to feed the result to the summary
+ before returning.
+ Returns:
+ aggregated metric as float.
+ Raises:
+ ValueError: if the optional argument is not bool
+ """
+ # Convert the boolean to tensor for tf.cond, if it is not.
+ if not isinstance(write_summary, ops.Tensor):
+ write_summary = ops.convert_to_tensor(write_summary)
t = self.numer / self.denom
- summary_ops.scalar(name=self.name, tensor=t)
+ def write_summary_f():
+ summary_ops.scalar(name=self.name, tensor=t)
+ return t
+ control_flow_ops.cond(write_summary,
+ write_summary_f,
+ lambda: t)
return t
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index 20d938d492..aa99616810 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -46,6 +46,18 @@ class MetricsTest(test.TestCase):
self.assertEqual(dtypes.float64, m.dtype)
self.assertEqual(dtypes.float64, m.result().dtype)
+ def testSummaryArg(self):
+ m = metrics.Mean()
+ m([1, 10, 100])
+ m(1000)
+ m([10000.0, 100000.0])
+ self.assertEqual(111111.0/6, m.result(write_summary=True).numpy())
+ self.assertEqual(111111.0/6, m.result(write_summary=False).numpy())
+ with self.assertRaises(ValueError):
+ m.result(write_summary=5)
+ with self.assertRaises(ValueError):
+ m.result(write_summary=[True])
+
def testVariableCollections(self):
with context.graph_mode(), ops.Graph().as_default():
m = metrics.Mean()
@@ -93,6 +105,16 @@ class MetricsTest(test.TestCase):
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].simple_value, 37.0)
+ # Get result without saving the summary.
+ logdir = tempfile.mkdtemp()
+ with summary_ops.create_file_writer(
+ logdir, max_queue=0,
+ name="t0").as_default(), summary_ops.always_record_summaries():
+ m.result(write_summary=False) # As a side-effect will write summaries.
+ # events_from_logdir(_) asserts the directory exists.
+ events = summary_test_util.events_from_logdir(logdir)
+ self.assertEqual(len(events), 1)
+
def testWeightedMean(self):
m = metrics.Mean()
m([1, 100, 100000], weights=[1, 0.2, 0.3])