path: root/tensorflow/python/estimator/canned/linear_testing_utils.py
Diffstat (limited to 'tensorflow/python/estimator/canned/linear_testing_utils.py')
-rw-r--r--  tensorflow/python/estimator/canned/linear_testing_utils.py  141
1 file changed, 139 insertions(+), 2 deletions(-)
diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py
index 0e6436b421..c3934c7a80 100644
--- a/tensorflow/python/estimator/canned/linear_testing_utils.py
+++ b/tensorflow/python/estimator/canned/linear_testing_utils.py
@@ -29,6 +29,7 @@ import six
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.client import session as tf_session
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator.canned import linear
@@ -260,6 +261,8 @@ class BaseLinearRegressorEvaluationTest(object):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 9.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -285,6 +288,8 @@ class BaseLinearRegressorEvaluationTest(object):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 18.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -315,6 +320,8 @@ class BaseLinearRegressorEvaluationTest(object):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 27.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
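
The three hunks above all add the same PREDICTION_MEAN and LABEL_MEAN expectations. As a quick standalone sanity check (not part of the diff; the checkpoint weights and input_fn are outside this excerpt, so the concrete values are taken from the assertions themselves): with a mean prediction of 13. and a mean label of 10., the per-example squared-error loss is 9., and the summed LOSS scales with the implied batch size (9., 18., 27.).

# Standalone sketch replaying the arithmetic implied by the asserted metrics.
prediction_mean = 13.
label_mean = 10.
loss_mean = (prediction_mean - label_mean) ** 2  # squared error per example
assert loss_mean == 9.

# LOSS is the sum-reduced loss over the batch, so it grows with the number of
# identical examples implied by LOSS / LOSS_MEAN in each hunk.
for num_examples, expected_loss in [(1, 9.), (2, 18.), (3, 27.)]:
  assert num_examples * loss_mean == expected_loss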
@@ -345,7 +352,9 @@ class BaseLinearRegressorEvaluationTest(object):
self.assertItemsEqual(
(metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,
- ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys())
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP),
+ eval_metrics.keys())
# Logit is
# [2., 4., 5.] * [1.0, 2.0] + [7.0, 8.0] = [39, 50] + [7.0, 8.0]
@@ -382,7 +391,9 @@ class BaseLinearRegressorEvaluationTest(object):
eval_metrics = est.evaluate(input_fn=input_fn, steps=1)
self.assertItemsEqual(
(metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,
- ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys())
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP),
+ eval_metrics.keys())
# Logit is [(20. * 10.0 + 4 * 2.0 + 5.0), (40. * 10.0 + 8 * 2.0 + 5.0)] =
# [213.0, 421.0], while label is [213., 421.]. Loss = 0.
@@ -484,6 +495,69 @@ class BaseLinearRegressorPredictTest(object):
# x0 * weight0 + x1 * weight1 + bias = 2. * 10. + 3. * 20 + .2 = 80.2
self.assertAllClose([[80.2]], predicted_scores)
+ def testSparseCombiner(self):
+ w_a = 2.0
+ w_b = 3.0
+ w_c = 5.0
+ bias = 5.0
+ with ops.Graph().as_default():
+ variables_lib.Variable([[w_a], [w_b], [w_c]], name=LANGUAGE_WEIGHT_NAME)
+ variables_lib.Variable([bias], name=BIAS_NAME)
+ variables_lib.Variable(1, name=ops.GraphKeys.GLOBAL_STEP,
+ dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ def _input_fn():
+ return dataset_ops.Dataset.from_tensors({
+ 'language': sparse_tensor.SparseTensor(
+ values=['a', 'c', 'b', 'c'],
+ indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ })
+
+ feature_columns = (
+ feature_column_lib.categorical_column_with_vocabulary_list(
+ 'language', vocabulary_list=['a', 'b', 'c']),)
+
+ # Check prediction for each sparse_combiner.
+ # With sparse_combiner = 'sum', we have
+ # logits_1 = w_a + w_c + bias
+ # = 2.0 + 5.0 + 5.0 = 12.0
+ # logits_2 = w_b + w_c + bias
+ # = 3.0 + 5.0 + 5.0 = 13.0
+ linear_regressor = self._linear_regressor_fn(
+ feature_columns=feature_columns,
+ model_dir=self._model_dir)
+ predictions = linear_regressor.predict(input_fn=_input_fn)
+ predicted_scores = list([x['predictions'] for x in predictions])
+ self.assertAllClose([[12.0], [13.0]], predicted_scores)
+
+ # With sparse_combiner = 'mean', we have
+ # logits_1 = 1/2 * (w_a + w_c) + bias
+ # = 1/2 * (2.0 + 5.0) + 5.0 = 8.5
+ # logits_2 = 1/2 * (w_b + w_c) + bias
+ # = 1/2 * (3.0 + 5.0) + 5.0 = 9.0
+ linear_regressor = self._linear_regressor_fn(
+ feature_columns=feature_columns,
+ model_dir=self._model_dir,
+ sparse_combiner='mean')
+ predictions = linear_regressor.predict(input_fn=_input_fn)
+ predicted_scores = list([x['predictions'] for x in predictions])
+ self.assertAllClose([[8.5], [9.0]], predicted_scores)
+
+ # With sparse_combiner = 'sqrtn', we have
+ # logits_1 = sqrt(2)/2 * (w_a + w_c) + bias
+ # = sqrt(2)/2 * (2.0 + 5.0) + 5.0 = 9.94974
+ # logits_2 = sqrt(2)/2 * (w_b + w_c) + bias
+ # = sqrt(2)/2 * (3.0 + 5.0) + 5.0 = 10.65685
+ linear_regressor = self._linear_regressor_fn(
+ feature_columns=feature_columns,
+ model_dir=self._model_dir,
+ sparse_combiner='sqrtn')
+ predictions = linear_regressor.predict(input_fn=_input_fn)
+ predicted_scores = list([x['predictions'] for x in predictions])
+ self.assertAllClose([[9.94974], [10.65685]], predicted_scores)
+
class BaseLinearRegressorIntegrationTest(object):
@@ -1636,6 +1710,69 @@ class BaseLinearClassifierPredictTest(object):
for i in range(n_classes)],
label_output_fn=lambda x: ('class_vocab_%s' % x).encode())
+ def testSparseCombiner(self):
+ w_a = 2.0
+ w_b = 3.0
+ w_c = 5.0
+ bias = 5.0
+ with ops.Graph().as_default():
+ variables_lib.Variable([[w_a], [w_b], [w_c]], name=LANGUAGE_WEIGHT_NAME)
+ variables_lib.Variable([bias], name=BIAS_NAME)
+ variables_lib.Variable(1, name=ops.GraphKeys.GLOBAL_STEP,
+ dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ def _input_fn():
+ return dataset_ops.Dataset.from_tensors({
+ 'language': sparse_tensor.SparseTensor(
+ values=['a', 'c', 'b', 'c'],
+ indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ })
+
+ feature_columns = (
+ feature_column_lib.categorical_column_with_vocabulary_list(
+ 'language', vocabulary_list=['a', 'b', 'c']),)
+
+ # Check prediction for each sparse_combiner.
+ # With sparse_combiner = 'sum', we have
+ # logits_1 = w_a + w_c + bias
+ # = 2.0 + 5.0 + 5.0 = 12.0
+ # logits_2 = w_b + w_c + bias
+ # = 3.0 + 5.0 + 5.0 = 13.0
+ linear_classifier = self._linear_classifier_fn(
+ feature_columns=feature_columns,
+ model_dir=self._model_dir)
+ predictions = linear_classifier.predict(input_fn=_input_fn)
+ predicted_scores = list([x['logits'] for x in predictions])
+ self.assertAllClose([[12.0], [13.0]], predicted_scores)
+
+ # With sparse_combiner = 'mean', we have
+ # logits_1 = 1/2 * (w_a + w_c) + bias
+ # = 1/2 * (2.0 + 5.0) + 5.0 = 8.5
+ # logits_2 = 1/2 * (w_b + w_c) + bias
+ # = 1/2 * (3.0 + 5.0) + 5.0 = 9.0
+ linear_classifier = self._linear_classifier_fn(
+ feature_columns=feature_columns,
+ model_dir=self._model_dir,
+ sparse_combiner='mean')
+ predictions = linear_classifier.predict(input_fn=_input_fn)
+ predicted_scores = list([x['logits'] for x in predictions])
+ self.assertAllClose([[8.5], [9.0]], predicted_scores)
+
+ # With sparse_combiner = 'sqrtn', we have
+ # logits_1 = sqrt(2)/2 * (w_a + w_c) + bias
+ # = sqrt(2)/2 * (2.0 + 5.0) + 5.0 = 9.94974
+ # logits_2 = sqrt(2)/2 * (w_b + w_c) + bias
+ # = sqrt(2)/2 * (3.0 + 5.0) + 5.0 = 10.65685
+ linear_classifier = self._linear_classifier_fn(
+ feature_columns=feature_columns,
+ model_dir=self._model_dir,
+ sparse_combiner='sqrtn')
+ predictions = linear_classifier.predict(input_fn=_input_fn)
+ predicted_scores = list([x['logits'] for x in predictions])
+ self.assertAllClose([[9.94974], [10.65685]], predicted_scores)
+
class BaseLinearClassifierIntegrationTest(object):
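
The regressor and classifier variants of testSparseCombiner above write the same weights to their checkpoints and expect identical values (the classifier asserts on the raw 'logits' key). As a quick standalone check of the expected numbers (a sketch, not part of the diff; the logits helper below is purely illustrative), the three combiners only differ in how the two active weights per example are scaled before the bias is added:

# Standalone sketch replaying the sparse_combiner arithmetic from the tests.
# Each of the two examples activates two categories ('a'+'c' and 'b'+'c').
import math

w = {'a': 2.0, 'b': 3.0, 'c': 5.0}
bias = 5.0
examples = [['a', 'c'], ['b', 'c']]

def logits(combiner):
  out = []
  for cats in examples:
    total = sum(w[c] for c in cats)
    n = len(cats)
    if combiner == 'sum':
      scale = 1.0
    elif combiner == 'mean':
      scale = 1.0 / n  # divide by the number of active categories
    else:  # 'sqrtn': divide by the L2 norm of the all-ones id weights, sqrt(n)
      scale = 1.0 / math.sqrt(n)
    out.append(scale * total + bias)
  return out

print(logits('sum'))    # [12.0, 13.0]
print(logits('mean'))   # [8.5, 9.0]
print(logits('sqrtn'))  # [9.94974..., 10.65685...]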