diff options
author | 2018-10-03 11:00:21 -0700 | |
---|---|---|
committer | 2018-10-03 11:08:35 -0700 | |
commit | 55ea7f89ee6aa45c5a7623ac9ba671044467e807 (patch) | |
tree | 4a2bcc407b8536540af4e7bc31328a85b6d07baf /tensorflow/contrib | |
parent | b25ef3877da28b7ec31d0bd69a7a6268f5e8a4b4 (diff) |
Supports TPUEstimatorSpec in multi_head for TRAIN and PREDICT modes.
PiperOrigin-RevId: 215590676
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/multi_head.py | 67 | ||||
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/multi_head_test.py | 75 |
2 files changed, 111 insertions, 31 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index ce75899214..6e793c8302 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -233,6 +233,22 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access self, features, mode, logits, labels=None, optimizer=None, train_op_fn=None): """See `_Head`.""" + return self._create_estimator_spec( + features=features, mode=mode, logits=logits, labels=labels, + optimizer=optimizer, train_op_fn=train_op_fn, use_tpu=False) + + def _create_tpu_estimator_spec( + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None): + """See `_Head`.""" + return self._create_estimator_spec( + features=features, mode=mode, logits=logits, labels=labels, + optimizer=optimizer, train_op_fn=train_op_fn, use_tpu=True) + + def _create_estimator_spec( + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, use_tpu=False): + """Returns `EstimatorSpec` or `TPUEstimatorSpec`.""" if isinstance(logits, dict): logits_dict = logits else: @@ -255,14 +271,15 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access spec = self._merge_train( all_estimator_spec=all_estimator_spec, optimizer=optimizer, - train_op_fn=train_op_fn) + train_op_fn=train_op_fn, + use_tpu=use_tpu) with ops.name_scope(''): summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss) return spec if mode == model_fn.ModeKeys.PREDICT: - return self._merge_predict(all_estimator_spec) + return self._merge_predict(all_estimator_spec, use_tpu=use_tpu) if mode == model_fn.ModeKeys.EVAL: - return self._merge_eval(all_estimator_spec) + return self._merge_eval(all_estimator_spec, use_tpu=use_tpu) raise ValueError('mode={} unrecognized'.format(mode)) def _split_logits(self, logits): @@ -284,28 +301,28 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access begin_idx += head.logits_dimension return logits_dict - def _merge_train(self, all_estimator_spec, optimizer, train_op_fn): - """Merges list of `EstimatorSpec` for training. + def _merge_train( + self, all_estimator_spec, optimizer, train_op_fn, use_tpu=False): + """Merges list of `EstimatorSpec` or `TPUEstimatorSpec` for training. Args: - all_estimator_spec: list of `EstimatorSpec` for the individual heads. + all_estimator_spec: list of `EstimatorSpec` or `TPUEstimatorSpec` for the + individual heads. optimizer: `Optimizer` instance to create train op. See `create_estimator_spec` documentation for more details. train_op_fn: Function to create train op. Used if `optimizer` is `None`. + use_tpu: If `True`, returns `TPUEstimatorSpec`. Returns: - `EstimatorSpec` that merges all heads for TRAIN. + `EstimatorSpec` or `TPUEstimatorSpec` that merges all heads for TRAIN. Raises: ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN mode. """ losses = [] - metrics = {} for spec in all_estimator_spec: losses.append(spec.loss) - # Metric keys already contain head.name. - metrics.update(spec.eval_metric_ops or {}) loss = _merge_losses(losses, self._head_weights) if optimizer is not None: if train_op_fn is not None: @@ -317,20 +334,23 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access else: raise ValueError('train_op_fn and optimizer cannot both be None.') - return model_fn.EstimatorSpec( + spec_type = ( + model_fn._TPUEstimatorSpec if use_tpu else model_fn.EstimatorSpec) # pylint:disable=protected-access + return spec_type( mode=model_fn.ModeKeys.TRAIN, loss=loss, - train_op=train_op, - eval_metric_ops=metrics) + train_op=train_op) - def _merge_predict(self, all_estimator_spec): - """Merges list of `EstimatorSpec` for prediction. + def _merge_predict(self, all_estimator_spec, use_tpu=False): + """Merges list of `EstimatorSpec` or `TPUEstimatorSpec` for prediction. Args: - all_estimator_spec: list of `EstimatorSpec` for the individual heads. + all_estimator_spec: list of `EstimatorSpec` or `TPUEstimatorSpec` for the + individual heads. + use_tpu: If `True`, returns `TPUEstimatorSpec`. Returns: - `EstimatorSpec` that merges all heads for PREDICT. + `EstimatorSpec` or `TPUEstimatorSpec` that merges all heads for PREDICT. """ predictions = {} export_outputs = { @@ -357,20 +377,29 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access export_outputs[head_lib._PREDICT_SERVING_KEY] = ( # pylint:disable=protected-access export_output_lib.PredictOutput(merged_predict_outputs)) - return model_fn.EstimatorSpec( + spec_type = ( + model_fn._TPUEstimatorSpec if use_tpu else model_fn.EstimatorSpec) # pylint:disable=protected-access + return spec_type( mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs=export_outputs) - def _merge_eval(self, all_estimator_spec): + def _merge_eval(self, all_estimator_spec, use_tpu=False): """Merges list of `EstimatorSpec` for eval. Args: all_estimator_spec: list of `EstimatorSpec` for the individual heads. + use_tpu: If `True`, will raise `NotImplementedError`, because TPU is not + yet supported for eval. Returns: `EstimatorSpec` that merges all heads for EVAL. + Raises: + NotImplementedError: If `use_tpu` is `True`. """ + if use_tpu: + raise NotImplementedError( + 'TPU evaluation is not implemented for multi_head.') predictions = {} metrics = {} losses = [] diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index 2b4d5f5261..a602f87b4a 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -106,7 +106,7 @@ class MultiHeadTest(test.TestCase): multi_head = multi_head_lib.multi_head([head1, head2]) self.assertEqual('head1_head2', multi_head.name) - def test_predict_two_heads_logits_dict(self): + def _test_predict_two_heads_logits_dict(self, use_tpu): """Tests predict with logits as dict.""" head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') @@ -121,10 +121,16 @@ class MultiHeadTest(test.TestCase): 'head2': _sigmoid(logits['head2']), } - spec = multi_head.create_estimator_spec( - features={'x': np.array(((42,),), dtype=np.int32)}, - mode=model_fn.ModeKeys.PREDICT, - logits=logits) + if use_tpu: + spec = multi_head._create_tpu_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits).as_estimator_spec() + else: + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) self.assertItemsEqual( (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification', @@ -175,6 +181,12 @@ class MultiHeadTest(test.TestCase): sess.run( spec.export_outputs['head2/predict'].outputs['probabilities'])) + def test_predict_two_heads_logits_dict(self): + self._test_predict_two_heads_logits_dict(use_tpu=False) + + def test_predict_two_heads_logits_dict_tpu(self): + self._test_predict_two_heads_logits_dict(use_tpu=True) + def test_predict_two_heads_logits_tensor(self): """Tests predict with logits as Tensor.""" head1 = head_lib.multi_label_head(n_classes=2, name='head1') @@ -350,6 +362,31 @@ class MultiHeadTest(test.TestCase): rtol=tol, atol=tol) + def test_eval_tpu(self): + head1 = head_lib.multi_label_head(n_classes=2, name='head1') + head2 = head_lib.multi_label_head(n_classes=3, name='head2') + multi_head = multi_head_lib.multi_head( + [head1, head2], head_weights=[1., 2.]) + + logits = { + 'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), + 'head2': np.array([[20., -20., 20.], [-30., 20., -20.]], + dtype=np.float32), + } + labels = { + 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64), + 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64), + } + + with self.assertRaisesRegexp( + NotImplementedError, + r'TPU evaluation is not implemented for multi_head\.'): + multi_head._create_tpu_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.EVAL, + logits=logits, + labels=labels) + def test_train_create_loss_one_head(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') multi_head = multi_head_lib.multi_head([head1]) @@ -587,7 +624,7 @@ class MultiHeadTest(test.TestCase): six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) - def test_train_two_heads_with_weights(self): + def _test_train_two_heads_with_weights(self, use_tpu): head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') multi_head = multi_head_lib.multi_head( @@ -619,12 +656,20 @@ class MultiHeadTest(test.TestCase): [constant_op.constant(expected_train_result), string_ops.as_string(loss, precision=3)]) - spec = multi_head.create_estimator_spec( - features={'x': np.array(((42,),), dtype=np.int32)}, - mode=model_fn.ModeKeys.TRAIN, - logits=logits, - labels=labels, - train_op_fn=_train_op_fn) + if use_tpu: + spec = multi_head._create_tpu_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn).as_estimator_spec() + else: + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) self.assertIsNotNone(spec.loss) self.assertEqual({}, spec.eval_metric_ops) @@ -649,6 +694,12 @@ class MultiHeadTest(test.TestCase): metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2, }, summary_str, tol) + def test_train_two_heads_with_weights(self): + self._test_train_two_heads_with_weights(use_tpu=False) + + def test_train_two_heads_with_weights_tpu(self): + self._test_train_two_heads_with_weights(use_tpu=True) + if __name__ == '__main__': test.main() |