aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/canned/dnn_testing_utils.py
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-10-09 08:16:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 08:21:39 -0700
commitcadcacc6224bcbb8a05bf3b70d625d9024a9c0f3 (patch)
treefe73a2d1ed500dbd1e5b0f6f20229e534f813d90 /tensorflow/python/estimator/canned/dnn_testing_utils.py
parenta0ed9452d5c7f897e26788d8dca5164cb6fba023 (diff)
Allowing for mixture of V1 and V2 feature columns usage in canned estimators. This is required for TF hub use cases where users might send in new feature columns to old model code. Implemented this support by making V2 feature columns support the V1 API. This is needed temporarily and would definitely be removed by TF 2.0, possibly earlier depending on what guarantees are provided by TF hub.
The only case we don't allow here is mixing in V2 shared embedding columns with V1 Feature columns. V2 Shared FC's depend on a SharedEmbeddingState manager that would have to be passed in to the various API's and there wasn't really a very clean way to make that work. Mixing V2 feature columns with V1 shared embedding columns is fine though and along with all other combinations PiperOrigin-RevId: 216359041
Diffstat (limited to 'tensorflow/python/estimator/canned/dnn_testing_utils.py')
-rw-r--r--tensorflow/python/estimator/canned/dnn_testing_utils.py109
1 files changed, 109 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py
index cd66d0a3bd..71d7e54783 100644
--- a/tensorflow/python/estimator/canned/dnn_testing_utils.py
+++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py
@@ -34,6 +34,7 @@ from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -479,6 +480,60 @@ class BaseDNNModelFnTest(object):
else:
self.fail('Invalid mode: {}'.format(mode))
+ def test_multi_feature_column_mix_multi_dim_logits(self):
+ """Tests multiple feature columns and multi-dimensional logits.
+
+ All numbers are the same as test_multi_dim_input_multi_dim_logits. The only
+ difference is that the input consists of two 1D feature columns, instead of
+ one 2D feature column.
+ """
+ base_global_step = 100
+ create_checkpoint((
+ ([[.6, .5], [-.6, -.5]], [.1, -.1]),
+ ([[1., .8], [-.8, -1.]], [.2, -.2]),
+ ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),
+ ), base_global_step, self._model_dir)
+ hidden_units = (2, 2)
+ logits_dimension = 3
+ inputs = ([[10.]], [[8.]])
+ expected_logits = [[-0.48, 0.48, 0.39]]
+
+ for mode in [
+ model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.PREDICT
+ ]:
+ with ops.Graph().as_default():
+ training_util.create_global_step()
+ head = mock_head(
+ self,
+ hidden_units=hidden_units,
+ logits_dimension=logits_dimension,
+ expected_logits=expected_logits)
+ estimator_spec = self._dnn_model_fn(
+ features={
+ 'age': constant_op.constant(inputs[0]),
+ 'height': constant_op.constant(inputs[1])
+ },
+ labels=constant_op.constant([[1]]),
+ mode=mode,
+ head=head,
+ hidden_units=hidden_units,
+ feature_columns=[
+ feature_column.numeric_column('age'),
+ feature_column_v2.numeric_column('height')
+ ],
+ optimizer=mock_optimizer(self, hidden_units))
+ with monitored_session.MonitoredTrainingSession(
+ checkpoint_dir=self._model_dir) as sess:
+ if mode == model_fn.ModeKeys.TRAIN:
+ sess.run(estimator_spec.train_op)
+ elif mode == model_fn.ModeKeys.EVAL:
+ sess.run(estimator_spec.loss)
+ elif mode == model_fn.ModeKeys.PREDICT:
+ sess.run(estimator_spec.predictions)
+ else:
+ self.fail('Invalid mode: {}'.format(mode))
+
def test_features_tensor_raises_value_error(self):
"""Tests that passing a Tensor for features raises a ValueError."""
hidden_units = (2, 2)
@@ -806,6 +861,60 @@ class BaseDNNLogitFnTest(object):
checkpoint_dir=self._model_dir) as sess:
self.assertAllClose(expected_logits, sess.run(logits))
+ def test_multi_feature_column_mix_multi_dim_logits(self):
+ """Tests multiple feature columns and multi-dimensional logits.
+
+ All numbers are the same as test_multi_dim_input_multi_dim_logits. The only
+ difference is that the input consists of two 1D feature columns, instead of
+ one 2D feature column.
+ """
+ base_global_step = 100
+ create_checkpoint((
+ ([[.6, .5], [-.6, -.5]], [.1, -.1]),
+ ([[1., .8], [-.8, -1.]], [.2, -.2]),
+ ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),
+ ), base_global_step, self._model_dir)
+
+ hidden_units = (2, 2)
+ logits_dimension = 3
+ inputs = ([[10.]], [[8.]])
+ expected_logits = [[-0.48, 0.48, 0.39]]
+
+ for mode in [
+ model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.PREDICT
+ ]:
+ with ops.Graph().as_default():
+ # Global step needed for MonitoredSession, which is in turn used to
+ # explicitly set variable weights through a checkpoint.
+ training_util.create_global_step()
+ # Use a variable scope here with 'dnn', emulating the dnn model_fn, so
+ # the checkpoint naming is shared.
+ with variable_scope.variable_scope('dnn'):
+ input_layer_partitioner = (
+ partitioned_variables.min_max_variable_partitioner(
+ max_partitions=0, min_slice_size=64 << 20))
+ logit_fn = self._dnn_logit_fn_builder(
+ units=logits_dimension,
+ hidden_units=hidden_units,
+ feature_columns=[
+ feature_column.numeric_column('age'),
+ feature_column_v2.numeric_column('height')
+ ],
+ activation_fn=nn.relu,
+ dropout=None,
+ input_layer_partitioner=input_layer_partitioner,
+ batch_norm=False)
+ logits = logit_fn(
+ features={
+ 'age': constant_op.constant(inputs[0]),
+ 'height': constant_op.constant(inputs[1])
+ },
+ mode=mode)
+ with monitored_session.MonitoredTrainingSession(
+ checkpoint_dir=self._model_dir) as sess:
+ self.assertAllClose(expected_logits, sess.run(logits))
+
class BaseDNNWarmStartingTest(object):