diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-22 13:56:55 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-22 14:05:05 -0700 |
commit | 21cdc5b13e3eeb873f92648f229ad29b3b7b1129 (patch) | |
tree | 4ab3ee521572f7da90929ee1b8897461b5fd9d3f /tensorflow/contrib/eager | |
parent | 7acfb875a0217777287a299ea8013e16fca59d4e (diff) |
Making the side effect of the result() function of tfe.Metrics.Mean optional.
PiperOrigin-RevId: 209824328
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r-- | tensorflow/contrib/eager/python/metrics_impl.py | 22 | ||||
-rw-r--r-- | tensorflow/contrib/eager/python/metrics_test.py | 22 |
2 files changed, 42 insertions, 2 deletions
diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 6efafccd6b..930e62b680 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -336,9 +336,27 @@ class Mean(Metric): return values return values, weights - def result(self): + def result(self, write_summary=True): + """Returns the result of the Metric. + + Args: + write_summary: bool indicating whether to feed the result to the summary + before returning. + Returns: + aggregated metric as float. + Raises: + ValueError: if the optional argument is not bool + """ + # Convert the boolean to tensor for tf.cond, if it is not. + if not isinstance(write_summary, ops.Tensor): + write_summary = ops.convert_to_tensor(write_summary) t = self.numer / self.denom - summary_ops.scalar(name=self.name, tensor=t) + def write_summary_f(): + summary_ops.scalar(name=self.name, tensor=t) + return t + control_flow_ops.cond(write_summary, + write_summary_f, + lambda: t) return t diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 20d938d492..aa99616810 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -46,6 +46,18 @@ class MetricsTest(test.TestCase): self.assertEqual(dtypes.float64, m.dtype) self.assertEqual(dtypes.float64, m.result().dtype) + def testSummaryArg(self): + m = metrics.Mean() + m([1, 10, 100]) + m(1000) + m([10000.0, 100000.0]) + self.assertEqual(111111.0/6, m.result(write_summary=True).numpy()) + self.assertEqual(111111.0/6, m.result(write_summary=False).numpy()) + with self.assertRaises(ValueError): + m.result(write_summary=5) + with self.assertRaises(ValueError): + m.result(write_summary=[True]) + def testVariableCollections(self): with context.graph_mode(), ops.Graph().as_default(): m = metrics.Mean() @@ -93,6 +105,16 @@ class MetricsTest(test.TestCase): 
self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 37.0) + # Get result without saving the summary. + logdir = tempfile.mkdtemp() + with summary_ops.create_file_writer( + logdir, max_queue=0, + name="t0").as_default(), summary_ops.always_record_summaries(): + m.result(write_summary=False) # Should NOT write a summary as a side effect. + # events_from_logdir(_) asserts the directory exists. + events = summary_test_util.events_from_logdir(logdir) + self.assertEqual(len(events), 1) + def testWeightedMean(self): m = metrics.Mean() m([1, 100, 100000], weights=[1, 0.2, 0.3]) |