about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/estimator
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-10-03 11:00:21 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-10-03 11:08:35 -0700
commit: 55ea7f89ee6aa45c5a7623ac9ba671044467e807 (patch)
tree: 4a2bcc407b8536540af4e7bc31328a85b6d07baf /tensorflow/contrib/estimator
parent: b25ef3877da28b7ec31d0bd69a7a6268f5e8a4b4 (diff)
Supports TPUEstimatorSpec in multi_head for TRAIN and PREDICT modes.
PiperOrigin-RevId: 215590676
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r-- tensorflow/contrib/estimator/python/estimator/multi_head.py | 67
-rw-r--r-- tensorflow/contrib/estimator/python/estimator/multi_head_test.py | 75
2 files changed, 111 insertions, 31 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py
index ce75899214..6e793c8302 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py
@@ -233,6 +233,22 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
self, features, mode, logits, labels=None, optimizer=None,
train_op_fn=None):
"""See `_Head`."""
+ return self._create_estimator_spec(
+ features=features, mode=mode, logits=logits, labels=labels,
+ optimizer=optimizer, train_op_fn=train_op_fn, use_tpu=False)
+
+ def _create_tpu_estimator_spec(
+ self, features, mode, logits, labels=None, optimizer=None,
+ train_op_fn=None):
+ """See `_Head`."""
+ return self._create_estimator_spec(
+ features=features, mode=mode, logits=logits, labels=labels,
+ optimizer=optimizer, train_op_fn=train_op_fn, use_tpu=True)
+
+ def _create_estimator_spec(
+ self, features, mode, logits, labels=None, optimizer=None,
+ train_op_fn=None, use_tpu=False):
+ """Returns `EstimatorSpec` or `TPUEstimatorSpec`."""
if isinstance(logits, dict):
logits_dict = logits
else:
@@ -255,14 +271,15 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
spec = self._merge_train(
all_estimator_spec=all_estimator_spec,
optimizer=optimizer,
- train_op_fn=train_op_fn)
+ train_op_fn=train_op_fn,
+ use_tpu=use_tpu)
with ops.name_scope(''):
summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss)
return spec
if mode == model_fn.ModeKeys.PREDICT:
- return self._merge_predict(all_estimator_spec)
+ return self._merge_predict(all_estimator_spec, use_tpu=use_tpu)
if mode == model_fn.ModeKeys.EVAL:
- return self._merge_eval(all_estimator_spec)
+ return self._merge_eval(all_estimator_spec, use_tpu=use_tpu)
raise ValueError('mode={} unrecognized'.format(mode))
def _split_logits(self, logits):
@@ -284,28 +301,28 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
begin_idx += head.logits_dimension
return logits_dict
- def _merge_train(self, all_estimator_spec, optimizer, train_op_fn):
- """Merges list of `EstimatorSpec` for training.
+ def _merge_train(
+ self, all_estimator_spec, optimizer, train_op_fn, use_tpu=False):
+ """Merges list of `EstimatorSpec` or `TPUEstimatorSpec` for training.
Args:
- all_estimator_spec: list of `EstimatorSpec` for the individual heads.
+ all_estimator_spec: list of `EstimatorSpec` or `TPUEstimatorSpec` for the
+ individual heads.
optimizer: `Optimizer` instance to create train op. See
`create_estimator_spec` documentation for more details.
train_op_fn: Function to create train op. Used if `optimizer` is `None`.
+ use_tpu: If `True`, returns `TPUEstimatorSpec`.
Returns:
- `EstimatorSpec` that merges all heads for TRAIN.
+ `EstimatorSpec` or `TPUEstimatorSpec` that merges all heads for TRAIN.
Raises:
ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
mode.
"""
losses = []
- metrics = {}
for spec in all_estimator_spec:
losses.append(spec.loss)
- # Metric keys already contain head.name.
- metrics.update(spec.eval_metric_ops or {})
loss = _merge_losses(losses, self._head_weights)
if optimizer is not None:
if train_op_fn is not None:
@@ -317,20 +334,23 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
else:
raise ValueError('train_op_fn and optimizer cannot both be None.')
- return model_fn.EstimatorSpec(
+ spec_type = (
+ model_fn._TPUEstimatorSpec if use_tpu else model_fn.EstimatorSpec) # pylint:disable=protected-access
+ return spec_type(
mode=model_fn.ModeKeys.TRAIN,
loss=loss,
- train_op=train_op,
- eval_metric_ops=metrics)
+ train_op=train_op)
- def _merge_predict(self, all_estimator_spec):
- """Merges list of `EstimatorSpec` for prediction.
+ def _merge_predict(self, all_estimator_spec, use_tpu=False):
+ """Merges list of `EstimatorSpec` or `TPUEstimatorSpec` for prediction.
Args:
- all_estimator_spec: list of `EstimatorSpec` for the individual heads.
+ all_estimator_spec: list of `EstimatorSpec` or `TPUEstimatorSpec` for the
+ individual heads.
+ use_tpu: If `True`, returns `TPUEstimatorSpec`.
Returns:
- `EstimatorSpec` that merges all heads for PREDICT.
+ `EstimatorSpec` or `TPUEstimatorSpec` that merges all heads for PREDICT.
"""
predictions = {}
export_outputs = {
@@ -357,20 +377,29 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
export_outputs[head_lib._PREDICT_SERVING_KEY] = ( # pylint:disable=protected-access
export_output_lib.PredictOutput(merged_predict_outputs))
- return model_fn.EstimatorSpec(
+ spec_type = (
+ model_fn._TPUEstimatorSpec if use_tpu else model_fn.EstimatorSpec) # pylint:disable=protected-access
+ return spec_type(
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs=export_outputs)
- def _merge_eval(self, all_estimator_spec):
+ def _merge_eval(self, all_estimator_spec, use_tpu=False):
"""Merges list of `EstimatorSpec` for eval.
Args:
all_estimator_spec: list of `EstimatorSpec` for the individual heads.
+ use_tpu: If `True`, will raise `NotImplementedError`, because TPU is not
+ yet supported for eval.
Returns:
`EstimatorSpec` that merges all heads for EVAL.
+ Raises:
+ NotImplementedError: If `use_tpu` is `True`.
"""
+ if use_tpu:
+ raise NotImplementedError(
+ 'TPU evaluation is not implemented for multi_head.')
predictions = {}
metrics = {}
losses = []
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
index 2b4d5f5261..a602f87b4a 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
@@ -106,7 +106,7 @@ class MultiHeadTest(test.TestCase):
multi_head = multi_head_lib.multi_head([head1, head2])
self.assertEqual('head1_head2', multi_head.name)
- def test_predict_two_heads_logits_dict(self):
+ def _test_predict_two_heads_logits_dict(self, use_tpu):
"""Tests predict with logits as dict."""
head1 = head_lib.multi_label_head(n_classes=2, name='head1')
head2 = head_lib.multi_label_head(n_classes=3, name='head2')
@@ -121,10 +121,16 @@ class MultiHeadTest(test.TestCase):
'head2': _sigmoid(logits['head2']),
}
- spec = multi_head.create_estimator_spec(
- features={'x': np.array(((42,),), dtype=np.int32)},
- mode=model_fn.ModeKeys.PREDICT,
- logits=logits)
+ if use_tpu:
+ spec = multi_head._create_tpu_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.PREDICT,
+ logits=logits).as_estimator_spec()
+ else:
+ spec = multi_head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.PREDICT,
+ logits=logits)
self.assertItemsEqual(
(_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification',
@@ -175,6 +181,12 @@ class MultiHeadTest(test.TestCase):
sess.run(
spec.export_outputs['head2/predict'].outputs['probabilities']))
+ def test_predict_two_heads_logits_dict(self):
+ self._test_predict_two_heads_logits_dict(use_tpu=False)
+
+ def test_predict_two_heads_logits_dict_tpu(self):
+ self._test_predict_two_heads_logits_dict(use_tpu=True)
+
def test_predict_two_heads_logits_tensor(self):
"""Tests predict with logits as Tensor."""
head1 = head_lib.multi_label_head(n_classes=2, name='head1')
@@ -350,6 +362,31 @@ class MultiHeadTest(test.TestCase):
rtol=tol,
atol=tol)
+ def test_eval_tpu(self):
+ head1 = head_lib.multi_label_head(n_classes=2, name='head1')
+ head2 = head_lib.multi_label_head(n_classes=3, name='head2')
+ multi_head = multi_head_lib.multi_head(
+ [head1, head2], head_weights=[1., 2.])
+
+ logits = {
+ 'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
+ 'head2': np.array([[20., -20., 20.], [-30., 20., -20.]],
+ dtype=np.float32),
+ }
+ labels = {
+ 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),
+ 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),
+ }
+
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ r'TPU evaluation is not implemented for multi_head\.'):
+ multi_head._create_tpu_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels)
+
def test_train_create_loss_one_head(self):
head1 = head_lib.multi_label_head(n_classes=2, name='head1')
multi_head = multi_head_lib.multi_head([head1])
@@ -587,7 +624,7 @@ class MultiHeadTest(test.TestCase):
six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
train_result)
- def test_train_two_heads_with_weights(self):
+ def _test_train_two_heads_with_weights(self, use_tpu):
head1 = head_lib.multi_label_head(n_classes=2, name='head1')
head2 = head_lib.multi_label_head(n_classes=3, name='head2')
multi_head = multi_head_lib.multi_head(
@@ -619,12 +656,20 @@ class MultiHeadTest(test.TestCase):
[constant_op.constant(expected_train_result),
string_ops.as_string(loss, precision=3)])
- spec = multi_head.create_estimator_spec(
- features={'x': np.array(((42,),), dtype=np.int32)},
- mode=model_fn.ModeKeys.TRAIN,
- logits=logits,
- labels=labels,
- train_op_fn=_train_op_fn)
+ if use_tpu:
+ spec = multi_head._create_tpu_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn).as_estimator_spec()
+ else:
+ spec = multi_head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn)
self.assertIsNotNone(spec.loss)
self.assertEqual({}, spec.eval_metric_ops)
@@ -649,6 +694,12 @@ class MultiHeadTest(test.TestCase):
metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2,
}, summary_str, tol)
+ def test_train_two_heads_with_weights(self):
+ self._test_train_two_heads_with_weights(use_tpu=False)
+
+ def test_train_two_heads_with_weights_tpu(self):
+ self._test_train_two_heads_with_weights(use_tpu=True)
+
if __name__ == '__main__':
test.main()