aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-02-14 09:14:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-14 09:26:35 -0800
commitab336a6a7d56a712364897f8bd4b30b7b7a4b186 (patch)
treef6b43cad8d4da9804432ae4a935859d21343156f
parente4c4f4b9045d71571fdc3e088625f58821ba49cb (diff)
Refactor head implementations to reduce code duplication and eliminate deep inheritance.
Change: 147480076
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py402
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head_test.py62
2 files changed, 301 insertions, 163 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 2f005a70cb..1d0a3c11e3 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -361,7 +361,7 @@ class _SingleHead(_Head):
if problem_type is None:
raise ValueError("Invalid problem_type %s." % problem_type)
if logits_dimension is None or logits_dimension < 1:
- raise ValueError("Invalid logits_dimension %s." % problem_type)
+ raise ValueError("Invalid logits_dimension %s." % logits_dimension)
self._problem_type = problem_type
self._logits_dimension = logits_dimension
self._label_name = label_name
@@ -474,6 +474,66 @@ def _logits(logits_input, logits, logits_dimension):
return logits
+def _create_model_fn_ops(features,
+ mode,
+ transform_labels_fn,
+ loss_fn,
+ logits_to_predictions_fn,
+ metrics_fn,
+ create_output_alternatives_fn,
+ default_variable_scope_name,
+ labels=None,
+ train_op_fn=None,
+ logits=None,
+ logits_input=None,
+ logits_dimension=None,
+ head_name=None,
+ weight_column_name=None,
+ enable_centered_bias=False):
+ """Returns a `ModelFnOps` object."""
+ _check_mode_valid(mode)
+
+ with variable_scope.variable_scope(
+ None,
+ default_name=head_name or default_variable_scope_name,
+ values=(tuple(six.itervalues(features)) +
+ (labels, logits, logits_input))):
+ if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
+ labels = transform_labels_fn(labels)
+ else:
+ labels = None
+
+ logits = _logits(logits_input, logits, logits_dimension)
+ centered_bias = None
+ if enable_centered_bias:
+ centered_bias = _centered_bias(logits_dimension, head_name)
+ logits = nn.bias_add(logits, centered_bias)
+
+ predictions = logits_to_predictions_fn(logits)
+ loss = None
+ train_op = None
+ eval_metric_ops = None
+ if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
+ weight_tensor = _weight_tensor(features, weight_column_name)
+ loss, weighted_average_loss = _loss(
+ loss_fn(logits, labels), weight_tensor)
+ logging_ops.scalar_summary(
+ _summary_key(head_name, mkey.LOSS), weighted_average_loss)
+
+ if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None):
+ train_op = _train_op(loss, labels, train_op_fn, centered_bias,
+ logits_dimension, loss_fn)
+ eval_metric_ops = metrics_fn(
+ weighted_average_loss, predictions, labels, weight_tensor)
+ return model_fn.ModelFnOps(
+ mode=mode,
+ predictions=predictions,
+ loss=loss,
+ train_op=train_op,
+ eval_metric_ops=eval_metric_ops,
+ output_alternatives=create_output_alternatives_fn(predictions))
+
+
class _RegressionHead(_SingleHead):
"""_Head for regression with a generalized linear model."""
@@ -525,44 +585,29 @@ class _RegressionHead(_SingleHead):
logits_input=None,
scope=None):
"""See `_Head`."""
- _check_mode_valid(mode)
-
- with variable_scope.variable_scope(
- scope,
- self.head_name or "regression_head",
- values=(tuple(six.itervalues(features)) +
- (labels, logits, logits_input))):
- logits = _logits(logits_input, logits, self.logits_dimension)
- centered_bias = None
- if self._enable_centered_bias:
- centered_bias = _centered_bias(self.logits_dimension, self.head_name)
- logits = nn.bias_add(logits, centered_bias)
-
- predictions = self._logits_to_predictions(logits)
- loss = None
- train_op = None
- eval_metric_ops = None
- if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
- labels_tensor = _to_labels_tensor(labels, self._label_name)
- loss, weighted_average_loss = _loss(
- self._loss_fn(logits, labels_tensor),
- _weight_tensor(features, self.weight_column_name))
- logging_ops.scalar_summary(
- _summary_key(self.head_name, mkey.LOSS), weighted_average_loss)
-
- if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None):
- train_op = _train_op(loss, labels_tensor, train_op_fn, centered_bias,
- self.logits_dimension, self._loss_fn)
- with ops.name_scope("default_metrics", values=[weighted_average_loss]):
- eval_metric_ops = {_summary_key(self.head_name, mkey.LOSS):
- metrics_lib.streaming_mean(weighted_average_loss)}
- return model_fn.ModelFnOps(
+ return _create_model_fn_ops(
+ features=features,
mode=mode,
- predictions=predictions,
- loss=loss,
- train_op=train_op,
- eval_metric_ops=eval_metric_ops,
- output_alternatives=self._create_output_alternatives(predictions))
+ transform_labels_fn=self._transform_labels,
+ loss_fn=self._loss_fn,
+ logits_to_predictions_fn=self._logits_to_predictions,
+ metrics_fn=self._metrics,
+ create_output_alternatives_fn=self._create_output_alternatives,
+ default_variable_scope_name="regression_head",
+ labels=labels,
+ train_op_fn=train_op_fn,
+ logits=logits,
+ logits_input=logits_input,
+ logits_dimension=self.logits_dimension,
+ head_name=self.head_name,
+ weight_column_name=self.weight_column_name,
+ enable_centered_bias=self._enable_centered_bias)
+
+ def _transform_labels(self, labels):
+ """Applies transformations to labels tensor."""
+ labels_tensor = _to_labels_tensor(labels, self._label_name)
+ _check_no_sparse_tensor(labels_tensor)
+ return labels_tensor
def _logits_to_predictions(self, logits):
"""Returns a dict of predictions.
@@ -579,6 +624,14 @@ class _RegressionHead(_SingleHead):
logits = array_ops.squeeze(logits, squeeze_dims=(1,), name=key)
return {key: self._link_fn(logits)}
+ def _metrics(self, eval_loss, predictions, labels, weights):
+ """Returns a dict of metrics keyed by name."""
+ del predictions, labels, weights # Unused by this head.
+ with ops.name_scope("metrics", values=[eval_loss]):
+ return {
+ _summary_key(self.head_name, mkey.LOSS):
+ metrics_lib.streaming_mean(eval_loss)}
+
def _log_loss_with_two_classes(logits, labels):
with ops.name_scope(None, "log_loss_with_two_classes",
@@ -646,45 +699,29 @@ class _BinaryLogisticHead(_SingleHead):
logits_input=None,
scope=None):
"""See `_Head`."""
- _check_mode_valid(mode)
-
- with variable_scope.variable_scope(
- scope,
- self.head_name or "binary_logistic_head",
- values=(tuple(six.itervalues(features)) +
- (labels, logits, logits_input))):
- logits = _logits(logits_input, logits, self.logits_dimension)
- centered_bias = None
- if self._enable_centered_bias:
- centered_bias = _centered_bias(1, self.head_name)
- logits = nn.bias_add(logits, centered_bias)
-
- predictions = self._logits_to_predictions(logits)
- loss = None
- train_op = None
- eval_metric_ops = None
- if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
- weight_tensor = _weight_tensor(features, self.weight_column_name)
- labels_tensor = _to_labels_tensor(labels, self._label_name)
- loss, weighted_average_loss = _loss(
- self._loss_fn(logits, labels_tensor), weight_tensor)
- logging_ops.scalar_summary(
- _summary_key(self.head_name, mkey.LOSS), weighted_average_loss)
-
- if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None):
- train_op = _train_op(loss, labels_tensor, train_op_fn, centered_bias,
- self.logits_dimension, self._loss_fn)
- eval_metric_ops = self._default_metrics(weighted_average_loss,
- predictions, labels_tensor,
- weight_tensor)
-
- return model_fn.ModelFnOps(
+ return _create_model_fn_ops(
+ features=features,
mode=mode,
- predictions=predictions,
- loss=loss,
- train_op=train_op,
- eval_metric_ops=eval_metric_ops,
- output_alternatives=self._create_output_alternatives(predictions))
+ transform_labels_fn=self._transform_labels,
+ loss_fn=self._loss_fn,
+ logits_to_predictions_fn=self._logits_to_predictions,
+ metrics_fn=self._metrics,
+ create_output_alternatives_fn=self._create_output_alternatives,
+ default_variable_scope_name="binary_logistic_head",
+ labels=labels,
+ train_op_fn=train_op_fn,
+ logits=logits,
+ logits_input=logits_input,
+ logits_dimension=self.logits_dimension,
+ head_name=self.head_name,
+ weight_column_name=self.weight_column_name,
+ enable_centered_bias=self._enable_centered_bias)
+
+ def _transform_labels(self, labels):
+ """Applies transformations to labels tensor."""
+ labels_tensor = _to_labels_tensor(labels, self._label_name)
+ _check_no_sparse_tensor(labels_tensor)
+ return labels_tensor
def _logits_to_predictions(self, logits):
"""Returns a dict of predictions.
@@ -714,9 +751,9 @@ class _BinaryLogisticHead(_SingleHead):
name=prediction_key.PredictionKey.CLASSES)
}
- def _default_metrics(self, eval_loss, predictions, labels, weights):
+ def _metrics(self, eval_loss, predictions, labels, weights):
"""Returns a dict of metrics keyed by name."""
- with ops.name_scope("default_metrics", values=(
+ with ops.name_scope("metrics", values=(
[eval_loss, labels, weights] + list(six.itervalues(predictions)))):
classes = predictions[prediction_key.PredictionKey.CLASSES]
logistic = predictions[prediction_key.PredictionKey.LOGISTIC]
@@ -837,45 +874,29 @@ class _MultiClassHead(_SingleHead):
logits_input=None,
scope=None):
"""See `_Head`."""
- _check_mode_valid(mode)
-
- with variable_scope.variable_scope(
- scope,
- self.head_name or "multi_class_head",
- values=(tuple(six.itervalues(features)) +
- (labels, logits, logits_input))):
- logits = _logits(logits_input, logits, self.logits_dimension)
- centered_bias = None
- if self._enable_centered_bias:
- centered_bias = _centered_bias(self.logits_dimension, self.head_name)
- logits = nn.bias_add(logits, centered_bias)
-
- predictions = self._logits_to_predictions(logits)
- loss = None
- train_op = None
- eval_metric_ops = None
- if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
- weight_tensor = _weight_tensor(features, self.weight_column_name)
- labels_tensor = _to_labels_tensor(labels, self._label_name,
- self._logits_dimension)
- loss, weighted_average_loss = _loss(
- self._loss_fn(logits, labels_tensor), weight_tensor)
- logging_ops.scalar_summary(
- _summary_key(self.head_name, mkey.LOSS), weighted_average_loss)
-
- if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None):
- train_op = _train_op(loss, labels_tensor, train_op_fn, centered_bias,
- self.logits_dimension, self._loss_fn)
- eval_metric_ops = self._default_metrics(weighted_average_loss,
- predictions, labels_tensor,
- weight_tensor)
- return model_fn.ModelFnOps(
+ return _create_model_fn_ops(
+ features=features,
mode=mode,
- predictions=predictions,
- loss=loss,
- train_op=train_op,
- eval_metric_ops=eval_metric_ops,
- output_alternatives=self._create_output_alternatives(predictions))
+ transform_labels_fn=self._transform_labels,
+ loss_fn=self._loss_fn,
+ logits_to_predictions_fn=self._logits_to_predictions,
+ metrics_fn=self._metrics,
+ create_output_alternatives_fn=self._create_output_alternatives,
+ default_variable_scope_name="multi_class_head",
+ labels=labels,
+ train_op_fn=train_op_fn,
+ logits=logits,
+ logits_input=logits_input,
+ logits_dimension=self.logits_dimension,
+ head_name=self.head_name,
+ weight_column_name=self.weight_column_name,
+ enable_centered_bias=self._enable_centered_bias)
+
+ def _transform_labels(self, labels):
+ """Applies transformations to labels tensor."""
+ labels_tensor = _to_labels_tensor(labels, self._label_name)
+ _check_no_sparse_tensor(labels_tensor)
+ return labels_tensor
def _logits_to_predictions(self, logits):
"""Returns a dict of predictions.
@@ -898,9 +919,9 @@ class _MultiClassHead(_SingleHead):
logits, 1, name=prediction_key.PredictionKey.CLASSES)
}
- def _default_metrics(self, eval_loss, predictions, labels, weights):
+ def _metrics(self, eval_loss, predictions, labels, weights):
"""Returns a dict of metrics keyed by name."""
- with ops.name_scope("default_metrics", values=(
+ with ops.name_scope("metrics", values=(
[eval_loss, labels, weights] + list(six.itervalues(predictions)))):
classes = predictions[prediction_key.PredictionKey.CLASSES]
probabilities = predictions[prediction_key.PredictionKey.PROBABILITIES]
@@ -937,31 +958,44 @@ class _MultiClassHead(_SingleHead):
return metrics
-def _to_labels_tensor(labels, label_name, num_classes=None):
- """Returns label as a tensor, converting from sparse to dense if needed.
+def _to_labels_tensor(labels, label_name):
+ """Returns label as a tensor.
Args:
- labels: Label tensor or a dict containig labels.
+ labels: Label `Tensor` or `SparseTensor` or a dict containig labels.
label_name: Label name if labels is a dict.
+
+ Returns:
+ Label `Tensor` or `SparseTensor`.
+ """
+ labels = labels[label_name] if isinstance(labels, dict) else labels
+ return framework_lib.convert_to_tensor_or_sparse_tensor(labels)
+
+
+def _check_no_sparse_tensor(x):
+ """Raises ValueError if the given tensor is `SparseTensor`."""
+ if isinstance(x, sparse_tensor.SparseTensor):
+ raise ValueError("SparseTensor is not supported.")
+
+
+def _sparse_labels_to_indicator(labels, num_classes):
+ """If labels is `SparseTensor`, converts it to indicator `Tensor`.
+
+ Args:
+ labels: Label `Tensor` or `SparseTensor`.
num_classes: Number of classes.
Returns:
- Dense label tensor. If label is sparse, it will be converted to dense.
+ Dense label `Tensor`.
Raises:
- ValueError: if label is sparse and num_classes is not provided or <2
+ ValueError: If labels is `SparseTensot` and `num_classes` < 2.
"""
- labels = labels[label_name] if isinstance(labels, dict) else labels
- labels = framework_lib.convert_to_tensor_or_sparse_tensor(labels)
if isinstance(labels, sparse_tensor.SparseTensor):
- if num_classes is None:
- raise ValueError("Must set num_classes when passing labels as a "
- "SparseTensor. Sparse labels are currently supported "
- "for MultiLabelHead only.")
if num_classes < 2:
raise ValueError("Must set num_classes >= 2 when passing labels as a "
"SparseTensor.")
- labels = math_ops.to_int64(
+ return math_ops.to_int64(
sparse_ops.sparse_to_indicator(labels, num_classes))
return labels
@@ -972,7 +1006,7 @@ def _assert_labels_rank(labels):
("labels shape should be either [batch_size, 1] or [batch_size]",))
-class _BinarySvmHead(_BinaryLogisticHead):
+class _BinarySvmHead(_SingleHead):
"""_Head for binary classification using SVMs."""
def __init__(self, label_name, weight_column_name, enable_centered_bias,
@@ -985,12 +1019,47 @@ class _BinarySvmHead(_BinaryLogisticHead):
return losses_lib.hinge_loss(logits, labels, scope=name)
super(_BinarySvmHead, self).__init__(
+ problem_type=constants.ProblemType.LOGISTIC_REGRESSION,
+ logits_dimension=1,
label_name=label_name,
weight_column_name=weight_column_name,
- enable_centered_bias=enable_centered_bias,
- head_name=head_name,
- loss_fn=_loss_fn,
- thresholds=thresholds)
+ head_name=head_name)
+ self._thresholds = thresholds if thresholds else (.5,)
+ self._loss_fn = _loss_fn
+ self._enable_centered_bias = enable_centered_bias
+
+ def create_model_fn_ops(self,
+ features,
+ mode,
+ labels=None,
+ train_op_fn=None,
+ logits=None,
+ logits_input=None,
+ scope=None):
+ """See `_Head`."""
+ return _create_model_fn_ops(
+ features=features,
+ mode=mode,
+ transform_labels_fn=self._transform_labels,
+ loss_fn=self._loss_fn,
+ logits_to_predictions_fn=self._logits_to_predictions,
+ metrics_fn=self._metrics,
+ create_output_alternatives_fn=self._create_output_alternatives,
+ default_variable_scope_name="binary_svm_head",
+ labels=labels,
+ train_op_fn=train_op_fn,
+ logits=logits,
+ logits_input=logits_input,
+ logits_dimension=self.logits_dimension,
+ head_name=self.head_name,
+ weight_column_name=self.weight_column_name,
+ enable_centered_bias=self._enable_centered_bias)
+
+ def _transform_labels(self, labels):
+ """Applies transformations to labels tensor."""
+ labels_tensor = _to_labels_tensor(labels, self._label_name)
+ _check_no_sparse_tensor(labels_tensor)
+ return labels_tensor
def _logits_to_predictions(self, logits):
"""See `_MultiClassHead`."""
@@ -1005,9 +1074,9 @@ class _BinarySvmHead(_BinaryLogisticHead):
name=prediction_key.PredictionKey.CLASSES)
}
- def _default_metrics(self, eval_loss, predictions, labels, weights):
+ def _metrics(self, eval_loss, predictions, labels, weights):
"""See `_MultiClassHead`."""
- with ops.name_scope("default_metrics", values=(
+ with ops.name_scope("metrics", values=(
[eval_loss, labels, weights] + list(six.itervalues(predictions)))):
metrics = {_summary_key(self.head_name, mkey.LOSS):
metrics_lib.streaming_mean(eval_loss)}
@@ -1022,7 +1091,7 @@ class _BinarySvmHead(_BinaryLogisticHead):
return metrics
-class _MultiLabelHead(_MultiClassHead):
+class _MultiLabelHead(_SingleHead):
"""_Head for multlabel classification."""
# TODO(zakaria): add signature and metric for multilabel.
@@ -1036,14 +1105,54 @@ class _MultiLabelHead(_MultiClassHead):
metric_class_ids=None):
super(_MultiLabelHead, self).__init__(
- n_classes=n_classes,
+ problem_type=constants.ProblemType.CLASSIFICATION,
+ logits_dimension=n_classes,
label_name=label_name,
weight_column_name=weight_column_name,
- enable_centered_bias=enable_centered_bias,
- head_name=head_name,
- loss_fn=_sigmoid_cross_entropy_loss,
- thresholds=thresholds,
- metric_class_ids=metric_class_ids)
+ head_name=head_name)
+
+ self._thresholds = thresholds if thresholds else (.5,)
+ self._loss_fn = _sigmoid_cross_entropy_loss
+ self._enable_centered_bias = enable_centered_bias
+ self._metric_class_ids = tuple([] if metric_class_ids is None else
+ metric_class_ids)
+ for class_id in self._metric_class_ids:
+ if (class_id < 0) or (class_id >= n_classes):
+ raise ValueError("Class ID %s not in [0, %s)." % (class_id, n_classes))
+
+ def create_model_fn_ops(self,
+ features,
+ mode,
+ labels=None,
+ train_op_fn=None,
+ logits=None,
+ logits_input=None,
+ scope=None):
+ """See `_Head`."""
+ return _create_model_fn_ops(
+ features=features,
+ mode=mode,
+ transform_labels_fn=self._transform_labels,
+ loss_fn=self._loss_fn,
+ logits_to_predictions_fn=self._logits_to_predictions,
+ metrics_fn=self._metrics,
+ create_output_alternatives_fn=self._create_output_alternatives,
+ default_variable_scope_name="multi_label_head",
+ labels=labels,
+ train_op_fn=train_op_fn,
+ logits=logits,
+ logits_input=logits_input,
+ logits_dimension=self.logits_dimension,
+ head_name=self.head_name,
+ weight_column_name=self.weight_column_name,
+ enable_centered_bias=self._enable_centered_bias)
+
+ def _transform_labels(self, labels):
+ """Applies transformations to labels tensor."""
+ labels_tensor = _to_labels_tensor(labels, self._label_name)
+ labels_tensor = _sparse_labels_to_indicator(labels_tensor,
+ self._logits_dimension)
+ return labels_tensor
def _logits_to_predictions(self, logits):
"""See `_MultiClassHead`."""
@@ -1060,9 +1169,9 @@ class _MultiLabelHead(_MultiClassHead):
name=prediction_key.PredictionKey.CLASSES)
}
- def _default_metrics(self, eval_loss, predictions, labels, weights):
+ def _metrics(self, eval_loss, predictions, labels, weights):
"""Returns a dict of metrics keyed by name."""
- with ops.name_scope("default_metrics", values=(
+ with ops.name_scope("metrics", values=(
[eval_loss, labels, weights] + list(six.itervalues(predictions)))):
classes = predictions[prediction_key.PredictionKey.CLASSES]
probabilities = predictions[prediction_key.PredictionKey.PROBABILITIES]
@@ -1554,4 +1663,3 @@ def _streaming_recall_at_threshold(predictions, labels, weights, threshold):
predictions, labels=labels, thresholds=(threshold,),
weights=_float_weights_or_none(weights))
return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
-
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
index 715275e3eb..da2e509a4a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
@@ -277,7 +277,7 @@ class RegressionHeadTest(test.TestCase):
values=(0., 1., 1.),
dense_shape=(3, 1))
with self.assertRaisesRegexp(ValueError,
- "Must set num_classes when passing"):
+ "SparseTensor is not supported"):
head.create_model_fn_ops(
{},
labels=labels,
@@ -336,6 +336,36 @@ class MultiLabelHeadTest(test.TestCase):
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
+ def testMultiLabelTwoClasses(self):
+ n_classes = 2
+ labels = ((0, 1),)
+ logits = ((1., 0.),)
+ head = head_lib._multi_label_head(
+ n_classes=n_classes, metric_class_ids=range(n_classes))
+ with ops.Graph().as_default(), session.Session():
+ model_fn_ops = head.create_model_fn_ops(
+ {}, model_fn.ModeKeys.TRAIN, labels=labels,
+ train_op_fn=_noop_train_op, logits=logits)
+ self._assert_output_alternatives(model_fn_ops)
+ _assert_no_variables(self)
+ _assert_summary_tags(self, ["loss"])
+ expected_loss = 1.00320443
+ _assert_metrics(self, expected_loss, {
+ "accuracy": 0.,
+ "auc": 0.,
+ "loss": expected_loss,
+ "auc/class0": 1.,
+ "auc/class1": 0.,
+ "labels/actual_label_mean/class0": labels[0][0],
+ "labels/actual_label_mean/class1": labels[0][1],
+ "labels/logits_mean/class0": logits[0][0],
+ "labels/logits_mean/class1": logits[0][1],
+ "labels/prediction_mean/class0": logits[0][0],
+ "labels/prediction_mean/class1": logits[0][1],
+ "labels/probability_mean/class0": _sigmoid(logits[0][0]),
+ "labels/probability_mean/class1": _sigmoid(logits[0][1]),
+ }, model_fn_ops)
+
def testMultiLabelWithInvalidLogits(self):
head = head_lib._multi_label_head(n_classes=len(self._labels[0]) + 1)
with ops.Graph().as_default(), session.Session():
@@ -353,8 +383,8 @@ class MultiLabelHeadTest(test.TestCase):
{}, model_fn.ModeKeys.TRAIN, self._labels, _noop_train_op,
logits_input=((0., 0.),))
self._assert_output_alternatives(model_fn_ops)
- w = ("multi_class_head/logits/weights:0",
- "multi_class_head/logits/biases:0")
+ w = ("multi_label_head/logits/weights:0",
+ "multi_label_head/logits/biases:0")
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
@@ -459,16 +489,16 @@ class MultiLabelHeadTest(test.TestCase):
_assert_variables(
self,
expected_global=(
- "multi_class_head/centered_bias_weight:0",
- ("multi_class_head/multi_class_head/centered_bias_weight/"
+ "multi_label_head/centered_bias_weight:0",
+ ("multi_label_head/multi_label_head/centered_bias_weight/"
"Adagrad:0"),),
- expected_trainable=("multi_class_head/centered_bias_weight:0",))
+ expected_trainable=("multi_label_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(self, (
"loss",
- "multi_class_head/centered_bias/bias_0",
- "multi_class_head/centered_bias/bias_1",
- "multi_class_head/centered_bias/bias_2"
+ "multi_label_head/centered_bias/bias_0",
+ "multi_label_head/centered_bias/bias_1",
+ "multi_label_head/centered_bias/bias_2"
))
expected_loss = .89985204
_assert_metrics(self, expected_loss,
@@ -662,7 +692,7 @@ class BinaryClassificationHeadTest(test.TestCase):
values=(0, 1, 1),
dense_shape=(3, 1))
with self.assertRaisesRegexp(ValueError,
- "Must set num_classes when passing"):
+ "SparseTensor is not supported"):
head.create_model_fn_ops(
{},
model_fn.ModeKeys.TRAIN,
@@ -959,8 +989,8 @@ class BinarySvmHeadTest(test.TestCase):
_noop_train_op,
logits_input=((0., 0.), (0., 0.)))
self._assert_output_alternatives(model_fn_ops)
- w = ("binary_logistic_head/logits/weights:0",
- "binary_logistic_head/logits/biases:0")
+ w = ("binary_svm_head/logits/weights:0",
+ "binary_svm_head/logits/biases:0")
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
@@ -1055,14 +1085,14 @@ class BinarySvmHeadTest(test.TestCase):
_assert_variables(
self,
expected_global=(
- "binary_logistic_head/centered_bias_weight:0",
- ("binary_logistic_head/binary_logistic_head/centered_bias_weight/"
+ "binary_svm_head/centered_bias_weight:0",
+ ("binary_svm_head/binary_svm_head/centered_bias_weight/"
"Adagrad:0"),
),
- expected_trainable=("binary_logistic_head/centered_bias_weight:0",))
+ expected_trainable=("binary_svm_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(
- self, ["loss", "binary_logistic_head/centered_bias/bias_0"])
+ self, ["loss", "binary_svm_head/centered_bias/bias_0"])
expected_loss = np.average(self._expected_losses)
_assert_metrics(self, expected_loss, {
"accuracy": 1.,