aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-11 14:48:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-11 14:51:46 -0700
commit4e3e8ca367eb9203ba2df07a3826be0005c18157 (patch)
tree400404c2523f5695d61c2d8e3963c67df184fcb9
parentdce5a71d588079c86f74033a06e600fb7710ab9b (diff)
Handle 1d weights.
PiperOrigin-RevId: 155799189
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py31
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head_test.py145
2 files changed, 161 insertions, 15 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 52b4213463..e4ef6996d8 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -163,10 +163,10 @@ class Head(object):
ModeFnOps.loss to compute and apply gradients.
logits: logits `Tensor` to be used by the head.
logits_input: `Tensor` from which to build logits, often needed when you
- don't want to compute the logits. Typically this is the activation of the
- last hidden layer in a DNN. Some heads (like the ones responsible for
- candidate sampling) intrinsically avoid computing full logits and only
- accepts logits_input.
+ don't want to compute the logits. Typically this is the activation of
+ the last hidden layer in a DNN. Some heads (like the ones responsible
+ for candidate sampling) intrinsically avoid computing full logits and
+ only accepts logits_input.
scope: Optional scope for `variable_scope`.
Returns:
@@ -1646,12 +1646,27 @@ class _MultiHead(Head):
def _weight_tensor(features, weight_column_name):
- """Returns weights as 1d `Tensor`."""
+ """Returns weights as `Tensor` of rank 0, or at least 2."""
if not weight_column_name:
return None
- with ops.name_scope(None, "weight_tensor",
- tuple(six.itervalues(features))):
- return math_ops.to_float(features[weight_column_name])
+ if weight_column_name not in features:
+ raise ValueError("Weights {} missing from features.".format(
+ weight_column_name))
+ with ops.name_scope(None, "weight_tensor", tuple(six.itervalues(features))):
+ weight_tensor = math_ops.to_float(features[weight_column_name])
+ shape = weight_tensor.get_shape()
+ rank = shape.ndims
+ # We don't bother with expanding dims of non-staticly shaped tensors or
+ # scalars, and >1d is already in a good format.
+ if rank == 1:
+ logging.warning(
+ "Weights {} has shape {}, expanding to make it 2d.",
+ weight_column_name, shape)
+ return (
+ sparse_ops.sparse_reshape(weight_tensor, (-1, 1))
+ if isinstance(weight_tensor, sparse_tensor.SparseTensor) else
+ array_ops.reshape(weight_tensor, (-1, 1)))
+ return weight_tensor
# TODO(zakaria): This function is needed for backward compatibility and should
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
index f7934fc188..012b919d63 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
@@ -225,20 +225,56 @@ class RegressionHeadTest(test.TestCase):
_assert_summary_tags(self, ["loss"])
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
- def testRegressionWithWeights(self):
+ def testRegressionWithScalarWeights(self):
+ head = head_lib.regression_head(weight_column_name="label_weight")
+ with ops.Graph().as_default(), session.Session():
+ weights = 2.
+ labels = ((0.,), (1.,), (1.,))
+ model_fn_ops = head.create_model_fn_ops(
+ features={"label_weight": weights},
+ labels=labels,
+ mode=model_fn.ModeKeys.TRAIN,
+ train_op_fn=head_lib.no_op_train_fn,
+ logits=((1.,), (1.,), (3.,)))
+ self._assert_output_alternatives(model_fn_ops)
+ _assert_no_variables(self)
+ _assert_summary_tags(self, ["loss"])
+ _assert_metrics(self, (weights * 5.) / len(labels), {
+ "loss": (weights * 5.) / (weights * len(labels))
+ }, model_fn_ops)
+
+ def testRegressionWith1DWeights(self):
+ head = head_lib.regression_head(weight_column_name="label_weight")
+ with ops.Graph().as_default(), session.Session():
+ weights = (2., 5., 0.)
+ labels = ((0.,), (1.,), (1.,))
+ model_fn_ops = head.create_model_fn_ops(
+ features={"label_weight": weights},
+ labels=labels,
+ mode=model_fn.ModeKeys.TRAIN,
+ train_op_fn=head_lib.no_op_train_fn,
+ logits=((1.,), (1.,), (3.,)))
+ self._assert_output_alternatives(model_fn_ops)
+ _assert_no_variables(self)
+ _assert_summary_tags(self, ["loss"])
+ _assert_metrics(self, 2. / len(labels), {"loss": 2. / np.sum(weights)},
+ model_fn_ops)
+
+ def testRegressionWith2DWeights(self):
head = head_lib.regression_head(weight_column_name="label_weight")
with ops.Graph().as_default(), session.Session():
weights = ((2.,), (5.,), (0.,))
+ labels = ((0.,), (1.,), (1.,))
model_fn_ops = head.create_model_fn_ops(
features={"label_weight": weights},
- labels=((0.,), (1.,), (1.,)),
+ labels=labels,
mode=model_fn.ModeKeys.TRAIN,
train_op_fn=head_lib.no_op_train_fn,
logits=((1.,), (1.,), (3.,)))
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
- _assert_metrics(self, 2. / len(weights), {"loss": 2. / np.sum(weights)},
+ _assert_metrics(self, 2. / len(labels), {"loss": 2. / np.sum(weights)},
model_fn_ops)
def testRegressionWithCenteredBias(self):
@@ -525,7 +561,7 @@ class MultiLabelHeadTest(test.TestCase):
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
- def testMultiLabelWithWeight(self):
+ def testMultiLabelWithScalarWeight(self):
n_classes = 3
head = head_lib.multi_label_head(
n_classes=n_classes,
@@ -544,7 +580,23 @@ class MultiLabelHeadTest(test.TestCase):
_assert_metrics(self, .089985214,
self._expected_eval_metrics(.89985214), model_fn_ops)
- def testMultiLabelWithMultiDimensionalWeight(self):
+ def testMultiLabelWith1DWeight(self):
+ n_classes = 3
+ head = head_lib.multi_label_head(
+ n_classes=n_classes,
+ weight_column_name="label_weight",
+ metric_class_ids=range(n_classes))
+ with ops.Graph().as_default(), session.Session():
+ with self.assertRaisesRegexp(
+ ValueError, "weights can not be broadcast to values"):
+ head.create_model_fn_ops(
+ features={"label_weight": (.1, .1, .1)},
+ labels=self._labels,
+ mode=model_fn.ModeKeys.TRAIN,
+ train_op_fn=head_lib.no_op_train_fn,
+ logits=self._logits)
+
+ def testMultiLabelWith2DWeight(self):
n_classes = 3
head = head_lib.multi_label_head(
n_classes=n_classes,
@@ -843,7 +895,42 @@ class BinaryClassificationHeadTest(test.TestCase):
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
- def testBinaryClassificationWithWeights(self):
+ def testBinaryClassificationWith1DWeights(self):
+ n_classes = 2
+ head = head_lib.multi_class_head(
+ n_classes=n_classes, weight_column_name="label_weight")
+ with ops.Graph().as_default(), session.Session():
+ weights = (1., 0.)
+ # logloss: z:label, x:logit
+ # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+ model_fn_ops = head.create_model_fn_ops(
+ features={"label_weight": weights},
+ labels=self._labels,
+ mode=model_fn.ModeKeys.TRAIN,
+ train_op_fn=head_lib.no_op_train_fn,
+ logits=self._logits)
+ self._assert_output_alternatives(model_fn_ops)
+ _assert_no_variables(self)
+ _assert_summary_tags(self, ["loss"])
+ expected_total_loss = .31326166
+ _assert_metrics(
+ self,
+ expected_total_loss / len(weights),
+ {
+ "accuracy": 1. / 1,
+ "accuracy/baseline_label_mean": 1. / 1,
+ "accuracy/threshold_0.500000_mean": 1. / 1,
+ "auc": 0. / 1,
+ "labels/actual_label_mean": 1. / 1,
+ "labels/prediction_mean": .731059, # softmax
+ # eval loss is weighted loss divided by sum of weights.
+ "loss": expected_total_loss,
+ "precision/positive_threshold_0.500000_mean": 1. / 1,
+ "recall/positive_threshold_0.500000_mean": 1. / 1,
+ },
+ model_fn_ops)
+
+ def testBinaryClassificationWith2DWeights(self):
n_classes = 2
head = head_lib.multi_class_head(
n_classes=n_classes, weight_column_name="label_weight")
@@ -1154,6 +1241,30 @@ class MultiClassHeadTest(test.TestCase):
_assert_metrics(self, expected_loss * weight,
self._expected_eval_metrics(expected_loss), model_fn_ops)
+ def testMultiClassWith1DWeight(self):
+ n_classes = 3
+ head = head_lib.multi_class_head(
+ n_classes=n_classes,
+ weight_column_name="label_weight",
+ metric_class_ids=range(n_classes))
+ with ops.Graph().as_default(), session.Session():
+ weight = .1
+ weights = (weight,)
+ # logloss: z:label, x:logit
+ # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+ model_fn_ops = head.create_model_fn_ops(
+ features={"label_weight": weights},
+ labels=self._labels,
+ mode=model_fn.ModeKeys.TRAIN,
+ train_op_fn=head_lib.no_op_train_fn,
+ logits=self._logits)
+ self._assert_output_alternatives(model_fn_ops)
+ _assert_no_variables(self)
+ _assert_summary_tags(self, ["loss"])
+ expected_loss = 1.5514447
+ _assert_metrics(self, expected_loss * weight,
+ self._expected_eval_metrics(expected_loss), model_fn_ops)
+
def testMultiClassWith2DWeight(self):
n_classes = 3
head = head_lib.multi_class_head(
@@ -1457,7 +1568,27 @@ class BinarySvmHeadTest(test.TestCase):
"loss": expected_loss,
}, model_fn_ops)
- def testBinarySVMWithWeights(self):
+ def testBinarySVMWith1DWeights(self):
+ head = head_lib.binary_svm_head(weight_column_name="weights")
+ with ops.Graph().as_default(), session.Session():
+ weights = (7., 11.)
+ model_fn_ops = head.create_model_fn_ops(
+ # We have to add an extra dim here for weights broadcasting to work.
+ features={"weights": weights},
+ mode=model_fn.ModeKeys.TRAIN,
+ labels=self._labels,
+ train_op_fn=head_lib.no_op_train_fn,
+ logits=self._predictions)
+ self._assert_output_alternatives(model_fn_ops)
+ _assert_no_variables(self)
+ _assert_summary_tags(self, ["loss"])
+ expected_weighted_losses = np.multiply(weights, self._expected_losses)
+ _assert_metrics(self, np.mean(expected_weighted_losses), {
+ "accuracy": 1.,
+ "loss": np.sum(expected_weighted_losses) / np.sum(weights),
+ }, model_fn_ops)
+
+ def testBinarySVMWith2DWeights(self):
head = head_lib.binary_svm_head(weight_column_name="weights")
with ops.Graph().as_default(), session.Session():
weights = (7., 11.)