| author | 2017-02-14 09:14:26 -0800 |
|---|---|
| committer | 2017-02-14 09:26:35 -0800 |
| commit | ab336a6a7d56a712364897f8bd4b30b7b7a4b186 |
| tree | f6b43cad8d4da9804432ae4a935859d21343156f |
| parent | e4c4f4b9045d71571fdc3e088625f58821ba49cb |
Refactor head implementations to reduce code duplication and eliminate deep inheritance.
Change: 147480076
| mode | path | changed |
|---|---|---|
| -rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/head.py | 402 |
| -rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/head_test.py | 62 |

2 files changed, 301 insertions, 163 deletions
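The heart of the change: each head's hand-rolled `create_model_fn_ops` is replaced by one shared builder, `_create_model_fn_ops`, which receives the head-specific steps (`transform_labels_fn`, `loss_fn`, `logits_to_predictions_fn`, `metrics_fn`) as plain callables, so heads compose behavior instead of inheriting it. Below is a minimal, dependency-free sketch of that pattern; `make_model_ops`, `ModelOps`, and the toy squared-error head are illustrative stand-ins, not the actual API:

```python
# Illustrative sketch only: the shared-builder pattern this commit applies.
# One builder owns the common flow; heads pass their specific steps in as
# callables instead of overriding methods on a deep base-class hierarchy.
from collections import namedtuple

ModelOps = namedtuple("ModelOps", ["predictions", "loss"])


def make_model_ops(logits, labels, transform_labels_fn, loss_fn,
                   logits_to_predictions_fn):
  """Common flow: transform labels, derive predictions, compute loss."""
  labels = transform_labels_fn(labels) if labels is not None else None
  predictions = logits_to_predictions_fn(logits)
  loss = loss_fn(logits, labels) if labels is not None else None
  return ModelOps(predictions=predictions, loss=loss)


class ToyRegressionHead(object):
  """Head-specific behavior is injected, not inherited."""

  def create_model_ops(self, logits, labels):
    return make_model_ops(
        logits, labels,
        transform_labels_fn=lambda y: [float(v) for v in y],
        loss_fn=lambda z, y: sum((zi - yi) ** 2
                                 for zi, yi in zip(z, y)) / len(y),
        logits_to_predictions_fn=lambda z: {"scores": z})


print(ToyRegressionHead().create_model_ops([1., 2.], [1.5, 2.5]).loss)  # 0.25
```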
```diff
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 2f005a70cb..1d0a3c11e3 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -361,7 +361,7 @@ class _SingleHead(_Head):
     if problem_type is None:
       raise ValueError("Invalid problem_type %s." % problem_type)
     if logits_dimension is None or logits_dimension < 1:
-      raise ValueError("Invalid logits_dimension %s." % problem_type)
+      raise ValueError("Invalid logits_dimension %s." % logits_dimension)
     self._problem_type = problem_type
     self._logits_dimension = logits_dimension
     self._label_name = label_name
@@ -474,6 +474,66 @@ def _logits(logits_input, logits, logits_dimension):
   return logits
 
 
+def _create_model_fn_ops(features,
+                         mode,
+                         transform_labels_fn,
+                         loss_fn,
+                         logits_to_predictions_fn,
+                         metrics_fn,
+                         create_output_alternatives_fn,
+                         default_variable_scope_name,
+                         labels=None,
+                         train_op_fn=None,
+                         logits=None,
+                         logits_input=None,
+                         logits_dimension=None,
+                         head_name=None,
+                         weight_column_name=None,
+                         enable_centered_bias=False):
+  """Returns a `ModelFnOps` object."""
+  _check_mode_valid(mode)
+
+  with variable_scope.variable_scope(
+      None,
+      default_name=head_name or default_variable_scope_name,
+      values=(tuple(six.itervalues(features)) +
+              (labels, logits, logits_input))):
+    if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
+      labels = transform_labels_fn(labels)
+    else:
+      labels = None
+
+    logits = _logits(logits_input, logits, logits_dimension)
+    centered_bias = None
+    if enable_centered_bias:
+      centered_bias = _centered_bias(logits_dimension, head_name)
+      logits = nn.bias_add(logits, centered_bias)
+
+    predictions = logits_to_predictions_fn(logits)
+    loss = None
+    train_op = None
+    eval_metric_ops = None
+    if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
+      weight_tensor = _weight_tensor(features, weight_column_name)
+      loss, weighted_average_loss = _loss(
+          loss_fn(logits, labels), weight_tensor)
+      logging_ops.scalar_summary(
+          _summary_key(head_name, mkey.LOSS), weighted_average_loss)
+
+      if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None):
+        train_op = _train_op(loss, labels, train_op_fn, centered_bias,
+                             logits_dimension, loss_fn)
+      eval_metric_ops = metrics_fn(
+          weighted_average_loss, predictions, labels, weight_tensor)
+    return model_fn.ModelFnOps(
+        mode=mode,
+        predictions=predictions,
+        loss=loss,
+        train_op=train_op,
+        eval_metric_ops=eval_metric_ops,
+        output_alternatives=create_output_alternatives_fn(predictions))
+
+
 class _RegressionHead(_SingleHead):
   """_Head for regression with a generalized linear model."""
```
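One detail of the builder worth calling out: when `enable_centered_bias` is set, it creates a per-logit bias variable and folds it in with `nn.bias_add`, which broadcasts the bias across the batch. A NumPy sketch of just that computation (values invented for illustration):

```python
# NumPy sketch of the nn.bias_add(logits, centered_bias) step above:
# one trainable bias per logit column, broadcast over the batch dimension.
import numpy as np

logits = np.array([[1.0, -1.0],
                   [0.5, 0.0]])        # shape [batch, logits_dimension]
centered_bias = np.array([0.2, -0.3])  # shape [logits_dimension]
print(logits + centered_bias)          # [[ 1.2 -1.3]
                                       #  [ 0.7 -0.3]]
```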
```diff
@@ -525,44 +585,29 @@ class _RegressionHead(_SingleHead):
                           logits_input=None,
                           scope=None):
     """See `_Head`."""
-    _check_mode_valid(mode)
-
-    with variable_scope.variable_scope(
-        scope,
-        self.head_name or "regression_head",
-        values=(tuple(six.itervalues(features)) +
-                (labels, logits, logits_input))):
-      logits = _logits(logits_input, logits, self.logits_dimension)
-      centered_bias = None
-      if self._enable_centered_bias:
-        centered_bias = _centered_bias(self.logits_dimension, self.head_name)
-        logits = nn.bias_add(logits, centered_bias)
-
-      predictions = self._logits_to_predictions(logits)
-      loss = None
-      train_op = None
-      eval_metric_ops = None
-      if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
-        labels_tensor = _to_labels_tensor(labels, self._label_name)
-        loss, weighted_average_loss = _loss(
-            self._loss_fn(logits, labels_tensor),
-            _weight_tensor(features, self.weight_column_name))
-        logging_ops.scalar_summary(
-            _summary_key(self.head_name, mkey.LOSS), weighted_average_loss)
-
-        if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None):
-          train_op = _train_op(loss, labels_tensor, train_op_fn, centered_bias,
-                               self.logits_dimension, self._loss_fn)
-        with ops.name_scope("default_metrics", values=[weighted_average_loss]):
-          eval_metric_ops = {_summary_key(self.head_name, mkey.LOSS):
-                             metrics_lib.streaming_mean(weighted_average_loss)}
-    return model_fn.ModelFnOps(
+    return _create_model_fn_ops(
+        features=features,
         mode=mode,
-        predictions=predictions,
-        loss=loss,
-        train_op=train_op,
-        eval_metric_ops=eval_metric_ops,
-        output_alternatives=self._create_output_alternatives(predictions))
+        transform_labels_fn=self._transform_labels,
+        loss_fn=self._loss_fn,
+        logits_to_predictions_fn=self._logits_to_predictions,
+        metrics_fn=self._metrics,
+        create_output_alternatives_fn=self._create_output_alternatives,
+        default_variable_scope_name="regression_head",
+        labels=labels,
+        train_op_fn=train_op_fn,
+        logits=logits,
+        logits_input=logits_input,
+        logits_dimension=self.logits_dimension,
+        head_name=self.head_name,
+        weight_column_name=self.weight_column_name,
+        enable_centered_bias=self._enable_centered_bias)
+
+  def _transform_labels(self, labels):
+    """Applies transformations to labels tensor."""
+    labels_tensor = _to_labels_tensor(labels, self._label_name)
+    _check_no_sparse_tensor(labels_tensor)
+    return labels_tensor
 
   def _logits_to_predictions(self, logits):
     """Returns a dict of predictions.
@@ -579,6 +624,14 @@ class _RegressionHead(_SingleHead):
       logits = array_ops.squeeze(logits, squeeze_dims=(1,), name=key)
       return {key: self._link_fn(logits)}
 
+  def _metrics(self, eval_loss, predictions, labels, weights):
+    """Returns a dict of metrics keyed by name."""
+    del predictions, labels, weights  # Unused by this head.
+    with ops.name_scope("metrics", values=[eval_loss]):
+      return {
+          _summary_key(self.head_name, mkey.LOSS):
+              metrics_lib.streaming_mean(eval_loss)}
+
 
 def _log_loss_with_two_classes(logits, labels):
   with ops.name_scope(None, "log_loss_with_two_classes",
```
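The new `_RegressionHead._metrics` reports a single metric: a streaming mean of the evaluation loss. As a mental model (plain Python, not the `contrib.metrics` implementation), a streaming mean is a value/update pair that accumulates across eval batches:

```python
# Plain-Python model of the contract behind metrics_lib.streaming_mean:
# an update step accumulates per-batch values; the value is the running mean.
class StreamingMean(object):

  def __init__(self):
    self.total, self.count = 0.0, 0

  def update(self, batch_value):  # analogous to the metric's update_op
    self.total += batch_value
    self.count += 1

  def value(self):  # analogous to the metric's value tensor
    return self.total / self.count if self.count else 0.0


m = StreamingMean()
for batch_loss in (1.0, 3.0):
  m.update(batch_loss)
print(m.value())  # 2.0 -- the mean loss across both eval batches
```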
+ with ops.name_scope("metrics", values=[eval_loss]): + return { + _summary_key(self.head_name, mkey.LOSS): + metrics_lib.streaming_mean(eval_loss)} + def _log_loss_with_two_classes(logits, labels): with ops.name_scope(None, "log_loss_with_two_classes", @@ -646,45 +699,29 @@ class _BinaryLogisticHead(_SingleHead): logits_input=None, scope=None): """See `_Head`.""" - _check_mode_valid(mode) - - with variable_scope.variable_scope( - scope, - self.head_name or "binary_logistic_head", - values=(tuple(six.itervalues(features)) + - (labels, logits, logits_input))): - logits = _logits(logits_input, logits, self.logits_dimension) - centered_bias = None - if self._enable_centered_bias: - centered_bias = _centered_bias(1, self.head_name) - logits = nn.bias_add(logits, centered_bias) - - predictions = self._logits_to_predictions(logits) - loss = None - train_op = None - eval_metric_ops = None - if (mode != model_fn.ModeKeys.INFER) and (labels is not None): - weight_tensor = _weight_tensor(features, self.weight_column_name) - labels_tensor = _to_labels_tensor(labels, self._label_name) - loss, weighted_average_loss = _loss( - self._loss_fn(logits, labels_tensor), weight_tensor) - logging_ops.scalar_summary( - _summary_key(self.head_name, mkey.LOSS), weighted_average_loss) - - if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None): - train_op = _train_op(loss, labels_tensor, train_op_fn, centered_bias, - self.logits_dimension, self._loss_fn) - eval_metric_ops = self._default_metrics(weighted_average_loss, - predictions, labels_tensor, - weight_tensor) - - return model_fn.ModelFnOps( + return _create_model_fn_ops( + features=features, mode=mode, - predictions=predictions, - loss=loss, - train_op=train_op, - eval_metric_ops=eval_metric_ops, - output_alternatives=self._create_output_alternatives(predictions)) + transform_labels_fn=self._transform_labels, + loss_fn=self._loss_fn, + logits_to_predictions_fn=self._logits_to_predictions, + metrics_fn=self._metrics, + create_output_alternatives_fn=self._create_output_alternatives, + default_variable_scope_name="binary_logistic_head", + labels=labels, + train_op_fn=train_op_fn, + logits=logits, + logits_input=logits_input, + logits_dimension=self.logits_dimension, + head_name=self.head_name, + weight_column_name=self.weight_column_name, + enable_centered_bias=self._enable_centered_bias) + + def _transform_labels(self, labels): + """Applies transformations to labels tensor.""" + labels_tensor = _to_labels_tensor(labels, self._label_name) + _check_no_sparse_tensor(labels_tensor) + return labels_tensor def _logits_to_predictions(self, logits): """Returns a dict of predictions. 
@@ -714,9 +751,9 @@ class _BinaryLogisticHead(_SingleHead): name=prediction_key.PredictionKey.CLASSES) } - def _default_metrics(self, eval_loss, predictions, labels, weights): + def _metrics(self, eval_loss, predictions, labels, weights): """Returns a dict of metrics keyed by name.""" - with ops.name_scope("default_metrics", values=( + with ops.name_scope("metrics", values=( [eval_loss, labels, weights] + list(six.itervalues(predictions)))): classes = predictions[prediction_key.PredictionKey.CLASSES] logistic = predictions[prediction_key.PredictionKey.LOGISTIC] @@ -837,45 +874,29 @@ class _MultiClassHead(_SingleHead): logits_input=None, scope=None): """See `_Head`.""" - _check_mode_valid(mode) - - with variable_scope.variable_scope( - scope, - self.head_name or "multi_class_head", - values=(tuple(six.itervalues(features)) + - (labels, logits, logits_input))): - logits = _logits(logits_input, logits, self.logits_dimension) - centered_bias = None - if self._enable_centered_bias: - centered_bias = _centered_bias(self.logits_dimension, self.head_name) - logits = nn.bias_add(logits, centered_bias) - - predictions = self._logits_to_predictions(logits) - loss = None - train_op = None - eval_metric_ops = None - if (mode != model_fn.ModeKeys.INFER) and (labels is not None): - weight_tensor = _weight_tensor(features, self.weight_column_name) - labels_tensor = _to_labels_tensor(labels, self._label_name, - self._logits_dimension) - loss, weighted_average_loss = _loss( - self._loss_fn(logits, labels_tensor), weight_tensor) - logging_ops.scalar_summary( - _summary_key(self.head_name, mkey.LOSS), weighted_average_loss) - - if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None): - train_op = _train_op(loss, labels_tensor, train_op_fn, centered_bias, - self.logits_dimension, self._loss_fn) - eval_metric_ops = self._default_metrics(weighted_average_loss, - predictions, labels_tensor, - weight_tensor) - return model_fn.ModelFnOps( + return _create_model_fn_ops( + features=features, mode=mode, - predictions=predictions, - loss=loss, - train_op=train_op, - eval_metric_ops=eval_metric_ops, - output_alternatives=self._create_output_alternatives(predictions)) + transform_labels_fn=self._transform_labels, + loss_fn=self._loss_fn, + logits_to_predictions_fn=self._logits_to_predictions, + metrics_fn=self._metrics, + create_output_alternatives_fn=self._create_output_alternatives, + default_variable_scope_name="multi_class_head", + labels=labels, + train_op_fn=train_op_fn, + logits=logits, + logits_input=logits_input, + logits_dimension=self.logits_dimension, + head_name=self.head_name, + weight_column_name=self.weight_column_name, + enable_centered_bias=self._enable_centered_bias) + + def _transform_labels(self, labels): + """Applies transformations to labels tensor.""" + labels_tensor = _to_labels_tensor(labels, self._label_name) + _check_no_sparse_tensor(labels_tensor) + return labels_tensor def _logits_to_predictions(self, logits): """Returns a dict of predictions. 
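For context on the `_logits_to_predictions` methods these hunks keep delegating to: a multi-class head typically emits the raw logits, softmax probabilities, and the argmax class. A NumPy sketch with the `prediction_key.PredictionKey` constants simplified to plain strings:

```python
# NumPy sketch (illustrative) of a multi-class prediction dict:
# logits pass through, probabilities come from a stable softmax,
# classes from the argmax over the class dimension.
import numpy as np


def logits_to_predictions(logits):
  shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
  exp = np.exp(shifted)
  return {
      "logits": logits,
      "probabilities": exp / exp.sum(axis=1, keepdims=True),
      "classes": logits.argmax(axis=1),
  }


print(logits_to_predictions(np.array([[2.0, 1.0, 0.1]]))["classes"])  # [0]
```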
```diff
@@ -898,9 +919,9 @@ class _MultiClassHead(_SingleHead):
             logits, 1, name=prediction_key.PredictionKey.CLASSES)
     }
 
-  def _default_metrics(self, eval_loss, predictions, labels, weights):
+  def _metrics(self, eval_loss, predictions, labels, weights):
     """Returns a dict of metrics keyed by name."""
-    with ops.name_scope("default_metrics", values=(
+    with ops.name_scope("metrics", values=(
         [eval_loss, labels, weights] + list(six.itervalues(predictions)))):
       classes = predictions[prediction_key.PredictionKey.CLASSES]
       probabilities = predictions[prediction_key.PredictionKey.PROBABILITIES]
@@ -937,31 +958,44 @@ class _MultiClassHead(_SingleHead):
       return metrics
 
 
-def _to_labels_tensor(labels, label_name, num_classes=None):
-  """Returns label as a tensor, converting from sparse to dense if needed.
+def _to_labels_tensor(labels, label_name):
+  """Returns label as a tensor.
 
   Args:
-    labels: Label tensor or a dict containig labels.
+    labels: Label `Tensor` or `SparseTensor` or a dict containing labels.
     label_name: Label name if labels is a dict.
+
+  Returns:
+    Label `Tensor` or `SparseTensor`.
+  """
+  labels = labels[label_name] if isinstance(labels, dict) else labels
+  return framework_lib.convert_to_tensor_or_sparse_tensor(labels)
+
+
+def _check_no_sparse_tensor(x):
+  """Raises ValueError if the given tensor is `SparseTensor`."""
+  if isinstance(x, sparse_tensor.SparseTensor):
+    raise ValueError("SparseTensor is not supported.")
+
+
+def _sparse_labels_to_indicator(labels, num_classes):
+  """If labels is `SparseTensor`, converts it to indicator `Tensor`.
+
+  Args:
+    labels: Label `Tensor` or `SparseTensor`.
     num_classes: Number of classes.
 
   Returns:
-    Dense label tensor. If label is sparse, it will be converted to dense.
+    Dense label `Tensor`.
 
   Raises:
-    ValueError: if label is sparse and num_classes is not provided or <2
+    ValueError: If labels is `SparseTensor` and `num_classes` < 2.
   """
-  labels = labels[label_name] if isinstance(labels, dict) else labels
-  labels = framework_lib.convert_to_tensor_or_sparse_tensor(labels)
   if isinstance(labels, sparse_tensor.SparseTensor):
-    if num_classes is None:
-      raise ValueError("Must set num_classes when passing labels as a "
-                       "SparseTensor. Sparse labels are currently supported "
-                       "for MultiLabelHead only.")
     if num_classes < 2:
       raise ValueError("Must set num_classes >= 2 when passing labels as a "
                        "SparseTensor.")
-    labels = math_ops.to_int64(
+    return math_ops.to_int64(
         sparse_ops.sparse_to_indicator(labels, num_classes))
   return labels
```
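The new `_sparse_labels_to_indicator` leans on `sparse_ops.sparse_to_indicator`, which turns a batch of variable-length class-id lists into the dense 0/1 indicator matrix that `_MultiLabelHead` trains against. A plain-Python sketch of the conversion, with Python lists standing in for `SparseTensor`s:

```python
# Plain-Python sketch of what sparse_to_indicator produces for multi-label
# targets: each example's set of class ids becomes a dense 0/1 row.
def sparse_labels_to_indicator(class_ids_per_example, num_classes):
  indicator = [[0] * num_classes for _ in class_ids_per_example]
  for row, class_ids in enumerate(class_ids_per_example):
    for class_id in class_ids:
      indicator[row][class_id] = 1
  return indicator


print(sparse_labels_to_indicator([[0, 2], [1]], num_classes=3))
# [[1, 0, 1], [0, 1, 0]]
```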
```diff
@@ -972,7 +1006,7 @@ def _assert_labels_rank(labels):
       ("labels shape should be either [batch_size, 1] or [batch_size]",))
 
 
-class _BinarySvmHead(_BinaryLogisticHead):
+class _BinarySvmHead(_SingleHead):
   """_Head for binary classification using SVMs."""
 
   def __init__(self, label_name, weight_column_name, enable_centered_bias,
@@ -985,12 +1019,47 @@ class _BinarySvmHead(_BinaryLogisticHead):
       return losses_lib.hinge_loss(logits, labels, scope=name)
 
     super(_BinarySvmHead, self).__init__(
+        problem_type=constants.ProblemType.LOGISTIC_REGRESSION,
+        logits_dimension=1,
         label_name=label_name,
         weight_column_name=weight_column_name,
-        enable_centered_bias=enable_centered_bias,
-        head_name=head_name,
-        loss_fn=_loss_fn,
-        thresholds=thresholds)
+        head_name=head_name)
+    self._thresholds = thresholds if thresholds else (.5,)
+    self._loss_fn = _loss_fn
+    self._enable_centered_bias = enable_centered_bias
+
+  def create_model_fn_ops(self,
+                          features,
+                          mode,
+                          labels=None,
+                          train_op_fn=None,
+                          logits=None,
+                          logits_input=None,
+                          scope=None):
+    """See `_Head`."""
+    return _create_model_fn_ops(
+        features=features,
+        mode=mode,
+        transform_labels_fn=self._transform_labels,
+        loss_fn=self._loss_fn,
+        logits_to_predictions_fn=self._logits_to_predictions,
+        metrics_fn=self._metrics,
+        create_output_alternatives_fn=self._create_output_alternatives,
+        default_variable_scope_name="binary_svm_head",
+        labels=labels,
+        train_op_fn=train_op_fn,
+        logits=logits,
+        logits_input=logits_input,
+        logits_dimension=self.logits_dimension,
+        head_name=self.head_name,
+        weight_column_name=self.weight_column_name,
+        enable_centered_bias=self._enable_centered_bias)
+
+  def _transform_labels(self, labels):
+    """Applies transformations to labels tensor."""
+    labels_tensor = _to_labels_tensor(labels, self._label_name)
+    _check_no_sparse_tensor(labels_tensor)
+    return labels_tensor
 
   def _logits_to_predictions(self, logits):
     """See `_MultiClassHead`."""
@@ -1005,9 +1074,9 @@ class _BinarySvmHead(_BinaryLogisticHead):
             name=prediction_key.PredictionKey.CLASSES)
     }
 
-  def _default_metrics(self, eval_loss, predictions, labels, weights):
+  def _metrics(self, eval_loss, predictions, labels, weights):
     """See `_MultiClassHead`."""
-    with ops.name_scope("default_metrics", values=(
+    with ops.name_scope("metrics", values=(
         [eval_loss, labels, weights] + list(six.itervalues(predictions)))):
       metrics = {_summary_key(self.head_name, mkey.LOSS):
                      metrics_lib.streaming_mean(eval_loss)}
```
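`_BinarySvmHead` keeps `losses_lib.hinge_loss` as its `_loss_fn`; only its place in the class hierarchy changes. As a reminder of what that loss computes, a plain-Python sketch (the real op works on tensors and maps `{0, 1}` labels to `{-1, +1}`):

```python
# Plain-Python sketch of the hinge loss used by _BinarySvmHead:
# each example contributes max(0, 1 - y * logit) with y in {-1, +1}.
def hinge_loss(logits, labels01):
  return [max(0.0, 1.0 - (2.0 * y - 1.0) * z)
          for z, y in zip(logits, labels01)]


print(hinge_loss(logits=[-0.5, 1.2, 10.0], labels01=[0, 1, 1]))
# [0.5, 0.0, 0.0] -- examples beyond the margin incur no loss
```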
```diff
@@ -1022,7 +1091,7 @@ class _BinarySvmHead(_BinaryLogisticHead):
       return metrics
 
 
-class _MultiLabelHead(_MultiClassHead):
+class _MultiLabelHead(_SingleHead):
   """_Head for multlabel classification."""
 
   # TODO(zakaria): add signature and metric for multilabel.
@@ -1036,14 +1105,54 @@ class _MultiLabelHead(_MultiClassHead):
                metric_class_ids=None):
 
     super(_MultiLabelHead, self).__init__(
-        n_classes=n_classes,
+        problem_type=constants.ProblemType.CLASSIFICATION,
+        logits_dimension=n_classes,
         label_name=label_name,
         weight_column_name=weight_column_name,
-        enable_centered_bias=enable_centered_bias,
-        head_name=head_name,
-        loss_fn=_sigmoid_cross_entropy_loss,
-        thresholds=thresholds,
-        metric_class_ids=metric_class_ids)
+        head_name=head_name)
+
+    self._thresholds = thresholds if thresholds else (.5,)
+    self._loss_fn = _sigmoid_cross_entropy_loss
+    self._enable_centered_bias = enable_centered_bias
+    self._metric_class_ids = tuple([] if metric_class_ids is None else
+                                   metric_class_ids)
+    for class_id in self._metric_class_ids:
+      if (class_id < 0) or (class_id >= n_classes):
+        raise ValueError("Class ID %s not in [0, %s)." % (class_id, n_classes))
+
+  def create_model_fn_ops(self,
+                          features,
+                          mode,
+                          labels=None,
+                          train_op_fn=None,
+                          logits=None,
+                          logits_input=None,
+                          scope=None):
+    """See `_Head`."""
+    return _create_model_fn_ops(
+        features=features,
+        mode=mode,
+        transform_labels_fn=self._transform_labels,
+        loss_fn=self._loss_fn,
+        logits_to_predictions_fn=self._logits_to_predictions,
+        metrics_fn=self._metrics,
+        create_output_alternatives_fn=self._create_output_alternatives,
+        default_variable_scope_name="multi_label_head",
+        labels=labels,
+        train_op_fn=train_op_fn,
+        logits=logits,
+        logits_input=logits_input,
+        logits_dimension=self.logits_dimension,
+        head_name=self.head_name,
+        weight_column_name=self.weight_column_name,
+        enable_centered_bias=self._enable_centered_bias)
+
+  def _transform_labels(self, labels):
+    """Applies transformations to labels tensor."""
+    labels_tensor = _to_labels_tensor(labels, self._label_name)
+    labels_tensor = _sparse_labels_to_indicator(labels_tensor,
+                                                self._logits_dimension)
+    return labels_tensor
 
   def _logits_to_predictions(self, logits):
     """See `_MultiClassHead`."""
@@ -1060,9 +1169,9 @@ class _MultiLabelHead(_MultiClassHead):
             name=prediction_key.PredictionKey.CLASSES)
     }
 
-  def _default_metrics(self, eval_loss, predictions, labels, weights):
+  def _metrics(self, eval_loss, predictions, labels, weights):
     """Returns a dict of metrics keyed by name."""
-    with ops.name_scope("default_metrics", values=(
+    with ops.name_scope("metrics", values=(
         [eval_loss, labels, weights] + list(six.itervalues(predictions)))):
       classes = predictions[prediction_key.PredictionKey.CLASSES]
       probabilities = predictions[prediction_key.PredictionKey.PROBABILITIES]
@@ -1554,4 +1663,3 @@ def _streaming_recall_at_threshold(predictions, labels, weights, threshold):
       predictions, labels=labels, thresholds=(threshold,),
       weights=_float_weights_or_none(weights))
   return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
-
```

```diff
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
index 715275e3eb..da2e509a4a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
@@ -277,7 +277,7 @@ class RegressionHeadTest(test.TestCase):
         values=(0., 1., 1.),
         dense_shape=(3, 1))
     with self.assertRaisesRegexp(ValueError,
-                                 "Must set num_classes when passing"):
+                                 "SparseTensor is not supported"):
       head.create_model_fn_ops(
           {},
           labels=labels,
```
```diff
@@ -336,6 +336,36 @@ class MultiLabelHeadTest(test.TestCase):
       _assert_metrics(self, expected_loss,
                       self._expected_eval_metrics(expected_loss),
                       model_fn_ops)
 
+  def testMultiLabelTwoClasses(self):
+    n_classes = 2
+    labels = ((0, 1),)
+    logits = ((1., 0.),)
+    head = head_lib._multi_label_head(
+        n_classes=n_classes, metric_class_ids=range(n_classes))
+    with ops.Graph().as_default(), session.Session():
+      model_fn_ops = head.create_model_fn_ops(
+          {}, model_fn.ModeKeys.TRAIN, labels=labels,
+          train_op_fn=_noop_train_op, logits=logits)
+      self._assert_output_alternatives(model_fn_ops)
+      _assert_no_variables(self)
+      _assert_summary_tags(self, ["loss"])
+      expected_loss = 1.00320443
+      _assert_metrics(self, expected_loss, {
+          "accuracy": 0.,
+          "auc": 0.,
+          "loss": expected_loss,
+          "auc/class0": 1.,
+          "auc/class1": 0.,
+          "labels/actual_label_mean/class0": labels[0][0],
+          "labels/actual_label_mean/class1": labels[0][1],
+          "labels/logits_mean/class0": logits[0][0],
+          "labels/logits_mean/class1": logits[0][1],
+          "labels/prediction_mean/class0": logits[0][0],
+          "labels/prediction_mean/class1": logits[0][1],
+          "labels/probability_mean/class0": _sigmoid(logits[0][0]),
+          "labels/probability_mean/class1": _sigmoid(logits[0][1]),
+      }, model_fn_ops)
+
   def testMultiLabelWithInvalidLogits(self):
     head = head_lib._multi_label_head(n_classes=len(self._labels[0]) + 1)
     with ops.Graph().as_default(), session.Session():
```
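The `expected_loss` constant in `testMultiLabelTwoClasses` above is just the per-class sigmoid cross entropy averaged over the two classes. A quick plain-Python check using the numerically stable form `max(z, 0) - z*y + log(1 + exp(-|z|))`:

```python
# Reproducing expected_loss = 1.00320443 for logits (1., 0.), labels (0, 1).
import math


def sigmoid_cross_entropy(logit, label):
  # Stable form of -label*log(sigmoid(z)) - (1-label)*log(1-sigmoid(z)).
  return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


logits, labels = (1.0, 0.0), (0, 1)
loss = sum(sigmoid_cross_entropy(z, y)
           for z, y in zip(logits, labels)) / len(logits)
print(round(loss, 8))  # 1.00320443
```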
("binary_logistic_head/binary_logistic_head/centered_bias_weight/" + "binary_svm_head/centered_bias_weight:0", + ("binary_svm_head/binary_svm_head/centered_bias_weight/" "Adagrad:0"), ), - expected_trainable=("binary_logistic_head/centered_bias_weight:0",)) + expected_trainable=("binary_svm_head/centered_bias_weight:0",)) variables.global_variables_initializer().run() _assert_summary_tags( - self, ["loss", "binary_logistic_head/centered_bias/bias_0"]) + self, ["loss", "binary_svm_head/centered_bias/bias_0"]) expected_loss = np.average(self._expected_losses) _assert_metrics(self, expected_loss, { "accuracy": 1., |