author     Younghee Kwon <youngheek@google.com>             2018-06-15 11:31:55 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-06-15 11:39:29 -0700
commit     a7fcc5da93988b6cbb1f64fcee1e7862d1f788ab (patch)
tree       c1a1b981119e1d5c24038bfb642161e9cb296213 /tensorflow/contrib/timeseries
parent     8ba25e36b948555f6b5df079b968b2a1382b5328 (diff)
contrib.timeseries: sets the predictions dict in EstimatorSpec for evaluation op.
PiperOrigin-RevId: 200747192
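
With the evaluation-mode EstimatorSpec now carrying `model_outputs.predictions` instead of an empty dict, a metric function attached through `tf.contrib.estimator`'s `extenders.add_metrics` actually sees the model's predictions during `evaluate()`. A minimal sketch of that usage follows; it mirrors the `test_custom_metrics` test added below, and the metric name plus the assumed `estimator` and `input_fn` objects are illustrative, not part of this commit.

    from tensorflow.contrib.estimator.python.estimator import extenders
    from tensorflow.contrib.timeseries.python.timeseries import feature_keys
    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import math_ops


    def metrics_fn(predictions, features):
      # After this change, `predictions` holds the model outputs (e.g. "mean")
      # at EVAL time; before, it was always {} and custom metrics saw nothing.
      predict = predictions["mean"]
      target = features[feature_keys.TrainEvalFeatures.VALUES][:, -1, 0]
      return {
          "mean_abs_error": (math_ops.reduce_mean(math_ops.abs(predict - target)),
                             control_flow_ops.no_op()),
      }


    # `estimator` is assumed to be a tf.contrib.timeseries estimator (e.g. the
    # TimeSeriesRegressor built in head_test.py below) and `input_fn` to yield
    # the usual TIMES/VALUES features.
    estimator = extenders.add_metrics(estimator, metrics_fn)
    evaluation = estimator.evaluate(input_fn, steps=1)  # contains "mean_abs_error"
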
Diffstat (limited to 'tensorflow/contrib/timeseries')
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/head.py       | 13
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/head_test.py  | 45
2 files changed, 51 insertions(+), 7 deletions(-)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py
index a28a5872b8..f236329fdb 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head.py
@@ -132,7 +132,8 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
         loss=model_outputs.loss,
         mode=mode,
         eval_metric_ops=metrics,
-        predictions={})
+        # needed for custom metrics.
+        predictions=model_outputs.predictions)
 
   def _predict_ops(self, features):
     """Add ops for prediction to the graph."""
@@ -210,12 +211,12 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
   def create_estimator_spec(self, features, mode, labels=None):
     """Performs basic error checking and returns an EstimatorSpec."""
     with ops.name_scope(self._name, "head"):
-      if labels:
+      if labels is not None and labels != {}:  # for better error messages.
         raise ValueError(
-            "The model received a `labels` dictionary, which is "
-            "not supported. Pass '{}' and '{}' as "
-            "features.".format(feature_keys.TrainEvalFeatures.TIMES,
-                               feature_keys.TrainEvalFeatures.VALUES))
+            "The model received a `labels`, which is not supported. "
+            "Pass '{}' and '{}' as features.".format(
+                feature_keys.TrainEvalFeatures.TIMES,
+                feature_keys.TrainEvalFeatures.VALUES))
       del labels
       features = {
           name: self._convert_feature_to_tensor(name=name, value=value)
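
The tightened `labels` check above is what the inline comment calls out as giving better error messages: the old `if labels:` tolerated an empty dict, but a bare `labels` Tensor fell into Python's truth-value conversion of a Tensor (a generic TypeError in graph mode), while the new condition raises the explicit ValueError for both non-empty dicts and Tensors and still lets `None` or `{}` through. A small sketch of the predicate in isolation; this helper is illustrative and not part of the head code:

    # Illustrative stand-alone version of the new condition; the real check
    # lives inline in create_estimator_spec above.
    def _labels_should_raise(labels):
      # None and an empty dict are tolerated; anything else (a dict of values,
      # a numpy array, a tf.Tensor) triggers the explicit ValueError. A Tensor
      # compares unequal to {} without ever being converted to a Python bool.
      return labels is not None and labels != {}

    assert not _labels_should_raise(None)
    assert not _labels_should_raise({})
    assert _labels_should_raise({"a": "b"})
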
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index c606db76a6..ed8f29c321 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import numpy
 import six
 
+from tensorflow.contrib.estimator.python.estimator import extenders
 from tensorflow.contrib.timeseries.examples import lstm as lstm_example
 from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators
 from tensorflow.contrib.timeseries.python.timeseries import feature_keys
@@ -35,6 +36,7 @@ from tensorflow.python.feature_column import feature_column
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import metrics
 from tensorflow.python.ops import variables
@@ -53,9 +55,12 @@ class HeadTest(test.TestCase):
     model_fn = _stub_model_fn()
     for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL,
                  estimator_lib.ModeKeys.PREDICT]:
-      with self.assertRaisesRegexp(ValueError, "labels"):
+      with self.assertRaisesRegexp(ValueError, "received a `labels`"):
         model_fn(features={}, labels={"a": "b"}, mode=mode)
+      with self.assertRaisesRegexp(ValueError, "received a `labels`"):
+        model_fn(features={}, labels=array_ops.zeros([]), mode=mode)
+
 
   def test_unknown_mode(self):
     model_fn = _stub_model_fn()
     with self.assertRaisesRegexp(ValueError, "Unknown mode 'Not a mode'"):
@@ -128,6 +133,44 @@ class EvaluationMetricsTests(test.TestCase):
       coordinator.request_stop()
       coordinator.join()
 
+  def test_custom_metrics(self):
+    """Tests that the custom metrics can be applied to the estimator."""
+    model_dir = self.get_temp_dir()
+    estimator = ts_estimators.TimeSeriesRegressor(
+        model=lstm_example._LSTMModel(num_features=1, num_units=4),
+        optimizer=adam.AdamOptimizer(0.001),
+        config=estimator_lib.RunConfig(tf_random_seed=4),
+        model_dir=model_dir)
+
+    def input_fn():
+      return {
+          feature_keys.TrainEvalFeatures.TIMES: [[1, 2, 3], [7, 8, 9]],
+          feature_keys.TrainEvalFeatures.VALUES:
+              numpy.array([[[0.], [1.], [0.]], [[2.], [3.], [2.]]])
+      }
+
+    def metrics_fn(predictions, features):
+      # checking that the inputs are properly passed.
+      predict = predictions["mean"]
+      target = features[feature_keys.TrainEvalFeatures.VALUES][:, -1, 0]
+      return {
+          "plain_boring_metric386":
+              (math_ops.reduce_mean(math_ops.abs(predict - target)),
+               control_flow_ops.no_op()),
+          "fun_metric101": (math_ops.reduce_sum(predict + target),
+                            control_flow_ops.no_op()),
+      }
+
+    # Evaluation without training is enough for testing custom metrics.
+    estimator = extenders.add_metrics(estimator, metrics_fn)
+    evaluation = estimator.evaluate(input_fn, steps=1)
+    self.assertIn("plain_boring_metric386", evaluation)
+    self.assertIn("fun_metric101", evaluation)
+    # The values are deterministic because of fixed tf_random_seed.
+    # However, if they become flaky, remove such exact comparisons.
+    self.assertAllClose(evaluation["plain_boring_metric386"], 1.130380)
+    self.assertAllClose(evaluation["fun_metric101"], 10.435442)
+
 
 class _StubModel(object):
   num_features = 3
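
The two custom metrics in the new test return `(value_tensor, update_op)` pairs, the standard `eval_metric_ops` format, with `control_flow_ops.no_op()` as a trivial update op because each value is computed directly from the single evaluated batch. A metric that should accumulate across evaluation batches can be plugged into the same slot; a sketch follows, assuming `metrics.mean` from the `tensorflow.python.ops.metrics` module already imported by head_test.py behaves like the public `tf.metrics.mean`:

    # Hypothetical streaming variant of metrics_fn; reuses head_test.py's
    # imports (feature_keys, math_ops, metrics).
    def streaming_metrics_fn(predictions, features):
      predict = predictions["mean"]
      target = features[feature_keys.TrainEvalFeatures.VALUES][:, -1, 0]
      return {
          # metrics.mean already returns a (value_tensor, update_op) pair that
          # accumulates across batches, so no explicit no_op is needed.
          "streaming_abs_error": metrics.mean(math_ops.abs(predict - target)),
      }
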