author    A. Unique TensorFlower <gardener@tensorflow.org> 2017-05-23 16:09:04 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-05-23 16:12:54 -0700
commit 6d8c3af9ec938f36eb5df84cb77ece091fa97355 (patch)
tree   8bf5368d21330233319fa3708b45eb916e0ea20f
parent 15705148626d19a58d0e244e9849bce746f653d7 (diff)
Tests _dnn_model_fn using mock head, modifies other tests to use consistent numbers, and drops some unnecessary tests.
PiperOrigin-RevId: 156925729
-rw-r--r--  tensorflow/python/estimator/BUILD              |    5
-rw-r--r--  tensorflow/python/estimator/canned/dnn_test.py | 1028
2 files changed, 376 insertions(+), 657 deletions(-)
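
Every new test below hand-computes the forward pass through a small fully-connected ReLU network whose weights come from the test checkpoint. For reference, a minimal NumPy sketch (not part of the commit; weight values copied from the checkpoint written in test_one_dim_logits below) reproduces the expected logits:

    import numpy as np

    def relu(x):
      return np.maximum(x, 0.)

    def forward(x, layers):
      # Apply each ReLU hidden layer, then the final linear logits layer.
      for w, b in layers[:-1]:
        x = relu(np.dot(x, w) + b)
      logits_w, logits_b = layers[-1]
      return np.dot(x, logits_w) + logits_b

    # Weights and biases as written by _create_checkpoint in test_one_dim_logits.
    layers = [
        (np.array([[.6, .5]]), np.array([.1, -.1])),
        (np.array([[1., .8], [-.8, -1.]]), np.array([.2, -.2])),
        (np.array([[-1.], [1.]]), np.array([.3])),
    ]
    print(forward(np.array([[10.]]), layers))  # ~[[-2.08]]
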
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index b7f83afdb1..7584cede10 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -114,17 +114,20 @@ py_test(
deps = [
":dnn",
":export_export",
+ ":head",
":metric_keys",
+ ":model_fn",
":numpy_io",
":prediction_keys",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
- "//tensorflow/python:nn",
"//tensorflow/python:platform",
"//tensorflow/python:state_ops",
"//tensorflow/python:summary",
diff --git a/tensorflow/python/estimator/canned/dnn_test.py b/tensorflow/python/estimator/canned/dnn_test.py
index d680adbb44..44eba6e6e6 100644
--- a/tensorflow/python/estimator/canned/dnn_test.py
+++ b/tensorflow/python/estimator/canned/dnn_test.py
@@ -27,23 +27,27 @@ import six
from tensorflow.core.framework import summary_pb2
from tensorflow.python.client import session as tf_session
+from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import dnn
+from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary as summary_lib
from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import monitored_session
from tensorflow.python.training import optimizer
from tensorflow.python.training import saver
from tensorflow.python.training import session_run_hook
@@ -84,9 +88,6 @@ def _create_checkpoint(weights_and_biases, global_step, model_dir):
# Create non-model variables.
global_step_var = training_util.create_global_step()
- # TODO(ptucker): We shouldn't have this in the checkpoint for constant LRs.
- # Learning rate.
- variables_lib.Variable(.5, name=_LEARNING_RATE_NAME, dtype=dtypes.float32)
# Initialize vars and save checkpoint.
with tf_session.Session() as sess:
@@ -95,7 +96,55 @@ def _create_checkpoint(weights_and_biases, global_step, model_dir):
saver.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))
-class DNNRegressorEvaluateTest(test.TestCase):
+def _mock_head(
+ testcase, hidden_units, logits_dimension, expected_logits):
+ """Returns a mock head that validates logits values and variable names."""
+ hidden_weights_names = [
+ (_HIDDEN_WEIGHTS_NAME_PATTERN + '/part_0:0') % i
+ for i in range(len(hidden_units))]
+ hidden_biases_names = [
+ (_HIDDEN_BIASES_NAME_PATTERN + '/part_0:0') % i
+ for i in range(len(hidden_units))]
+ expected_var_names = (
+ hidden_weights_names + hidden_biases_names +
+ [_LOGITS_WEIGHTS_NAME + '/part_0:0', _LOGITS_BIASES_NAME + '/part_0:0'])
+
+ def _create_estimator_spec(features, mode, logits, labels, train_op_fn):
+ del features, labels # Not used.
+ trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ testcase.assertItemsEqual(
+ expected_var_names,
+ [var.name for var in trainable_vars])
+ loss = constant_op.constant(1.)
+ assert_logits = _assert_close(
+ expected_logits, logits, message='Failed for mode={}. '.format(mode))
+ with ops.control_dependencies([assert_logits]):
+ if mode == model_fn.ModeKeys.TRAIN:
+ return model_fn.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ train_op=train_op_fn(loss))
+ elif mode == model_fn.ModeKeys.EVAL:
+ return model_fn.EstimatorSpec(
+ mode=mode,
+ loss=array_ops.identity(loss))
+ elif mode == model_fn.ModeKeys.PREDICT:
+ return model_fn.EstimatorSpec(
+ mode=mode,
+ predictions={'logits': array_ops.identity(logits)})
+ else:
+ testcase.fail('Invalid mode: {}'.format(mode))
+
+ mock_head = test.mock.NonCallableMagicMock(spec=head_lib._Head)
+ mock_head.logits_dimension = logits_dimension
+ mock_head.create_estimator_spec = test.mock.MagicMock(
+ wraps=_create_estimator_spec)
+
+ return mock_head
+
+
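+# The mock head above relies on two mock features: spec=head_lib._Head
+# restricts the mock to the real head's attributes, and wraps= makes
+# create_estimator_spec execute the validating function while still
+# recording its calls. A condensed, self-contained sketch of the same
+# pattern (generic names, assuming nothing beyond the standard library):
+#
+#   from unittest import mock
+#
+#   class Head(object):
+#     logits_dimension = None  # Stand-in for the interface being mocked.
+#
+#     def create_estimator_spec(self, features, mode, logits):
+#       raise NotImplementedError
+#
+#   def _validating_spec(features, mode, logits):
+#     # Runs on every call through the mock; validates, then returns.
+#     assert logits == [[-2.08]], logits
+#     return 'spec'
+#
+#   mock_head = mock.NonCallableMagicMock(spec=Head)
+#   mock_head.logits_dimension = 1
+#   mock_head.create_estimator_spec = mock.MagicMock(wraps=_validating_spec)
+#
+#   assert mock_head.create_estimator_spec({}, 'train', [[-2.08]]) == 'spec'
+#   assert mock_head.create_estimator_spec.call_count == 1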
+class DNNModelFnTest(test.TestCase):
+ """Tests that _dnn_model_fn passes expected logits to mock head."""
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -104,388 +153,293 @@ class DNNRegressorEvaluateTest(test.TestCase):
if self._model_dir:
shutil.rmtree(self._model_dir)
- def test_simple(self):
- # Create checkpoint: num_inputs=1, hidden_units=(2, 2), num_outputs=1.
- global_step = 100
+ def _test_logits(
+ self, mode, hidden_units, logits_dimension, inputs, expected_logits):
+ """Tests that the expected logits are passed to mock head."""
+ with ops.Graph().as_default():
+ training_util.create_global_step()
+ head = _mock_head(
+ self,
+ hidden_units=hidden_units,
+ logits_dimension=logits_dimension,
+ expected_logits=expected_logits)
+ estimator_spec = dnn._dnn_model_fn(
+ features={'age': constant_op.constant(inputs)},
+ labels=constant_op.constant([[1]]),
+ mode=mode,
+ head=head,
+ hidden_units=hidden_units,
+ feature_columns=[
+ feature_column.numeric_column('age',
+ shape=np.array(inputs).shape[1:])],
+ optimizer=_mock_optimizer(self, hidden_units))
+ with monitored_session.MonitoredTrainingSession(
+ checkpoint_dir=self._model_dir) as sess:
+ if mode == model_fn.ModeKeys.TRAIN:
+ sess.run(estimator_spec.train_op)
+ elif mode == model_fn.ModeKeys.EVAL:
+ sess.run(estimator_spec.loss)
+ elif mode == model_fn.ModeKeys.PREDICT:
+ sess.run(estimator_spec.predictions)
+ else:
+ self.fail('Invalid mode: {}'.format(mode))
+
+ def test_one_dim_logits(self):
+ """Tests one-dimensional logits.
+
+ input_layer = [[10]]
+ hidden_layer_0 = [[relu(0.6*10 +0.1), relu(0.5*10 -0.1)]] = [[6.1, 4.9]]
+ hidden_layer_1 = [[relu(1*6.1 -0.8*4.9 +0.2), relu(0.8*6.1 -1*4.9 -0.2)]]
+ = [[relu(2.38), relu(-0.22)]] = [[2.38, 0]]
+ logits = [[-1*2.38 +1*0 +0.3]] = [[-2.08]]
+ """
+ base_global_step = 100
_create_checkpoint((
- (((1., 2.),), (3., 4.)),
- (((5., 6.), (7., 8.),), (9., 10.)),
- (((11.,), (12.,),), (13.,))
- ), global_step, self._model_dir)
-
- # Create DNNRegressor and evaluate.
- dnn_regressor = dnn.DNNRegressor(
- hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('age'),),
- model_dir=self._model_dir)
- def _input_fn():
- return {'age': ((1,),)}, ((10.,),)
- # TODO(ptucker): Point to tool for calculating a neural net output?
- # prediction = 1778
- # loss = (10-1778)^2 = 3125824
- expected_loss = 3125824
- self.assertAllClose({
- metric_keys.MetricKeys.LOSS: expected_loss,
- metric_keys.MetricKeys.LOSS_MEAN: expected_loss,
- ops.GraphKeys.GLOBAL_STEP: global_step
- }, dnn_regressor.evaluate(input_fn=_input_fn, steps=1))
+ ([[.6, .5]], [.1, -.1]),
+ ([[1., .8], [-.8, -1.]], [.2, -.2]),
+ ([[-1.], [1.]], [.3]),
+ ), base_global_step, self._model_dir)
- def test_weighted(self):
- # Create checkpoint: num_inputs=1, hidden_units=(2, 2), num_outputs=1.
- global_step = 100
+ for mode in [model_fn.ModeKeys.TRAIN,
+ model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.PREDICT]:
+ self._test_logits(
+ mode,
+ hidden_units=(2, 2),
+ logits_dimension=1,
+ inputs=[[10.]],
+ expected_logits=[[-2.08]])
+
+ def test_multi_dim_logits(self):
+ """Tests multi-dimensional logits.
+
+ input_layer = [[10]]
+ hidden_layer_0 = [[relu(0.6*10 +0.1), relu(0.5*10 -0.1)]] = [[6.1, 4.9]]
+ hidden_layer_1 = [[relu(1*6.1 -0.8*4.9 +0.2), relu(0.8*6.1 -1*4.9 -0.2)]]
+ = [[relu(2.38), relu(-0.22)]] = [[2.38, 0]]
+ logits = [[-1*2.38 +0.3, 1*2.38 -0.3, 0.5*2.38]]
+ = [[-2.08, 2.08, 1.19]]
+ """
+ base_global_step = 100
_create_checkpoint((
- (((1., 2.),), (3., 4.)),
- (((5., 6.), (7., 8.),), (9., 10.)),
- (((11.,), (12.,),), (13.,))
- ), global_step, self._model_dir)
-
- # Create DNNRegressor and evaluate.
- dnn_regressor = dnn.DNNRegressor(
- hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('age'),),
- model_dir=self._model_dir,
- weight_feature_key='label_weight')
- def _input_fn():
- return {'age': ((1,),), 'label_weight': ((1.5,),)}, ((10.,),)
- self.assertAllClose({
- # TODO(ptucker): Point to tool for calculating a neural net output?
- # prediction = 1778
- # loss = 1.5*((10-1778)^2) = 4688736
- metric_keys.MetricKeys.LOSS: 4688736,
- # average_loss = loss / 1.5 = 3125824
- metric_keys.MetricKeys.LOSS_MEAN: 3125824,
- ops.GraphKeys.GLOBAL_STEP: global_step
- }, dnn_regressor.evaluate(input_fn=_input_fn, steps=1))
+ ([[.6, .5]], [.1, -.1]),
+ ([[1., .8], [-.8, -1.]], [.2, -.2]),
+ ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),
+ ), base_global_step, self._model_dir)
- def test_multi_example(self):
- # Create initial checkpoint, 1 input, 2x2 hidden dims, 1 outputs.
- global_step = 100
+ for mode in [model_fn.ModeKeys.TRAIN,
+ model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.PREDICT]:
+ self._test_logits(
+ mode,
+ hidden_units=(2, 2),
+ logits_dimension=3,
+ inputs=[[10.]],
+ expected_logits=[[-2.08, 2.08, 1.19]])
+
+ def test_multi_example_multi_dim_logits(self):
+ """Tests multiple examples and multi-dimensional logits.
+
+ input_layer = [[10], [5]]
+ hidden_layer_0 = [[relu(0.6*10 +0.1), relu(0.5*10 -0.1)],
+ [relu(0.6*5 +0.1), relu(0.5*5 -0.1)]]
+ = [[6.1, 4.9], [3.1, 2.4]]
+ hidden_layer_1 = [[relu(1*6.1 -0.8*4.9 +0.2), relu(0.8*6.1 -1*4.9 -0.2)],
+ [relu(1*3.1 -0.8*2.4 +0.2), relu(0.8*3.1 -1*2.4 -0.2)]]
+ = [[2.38, 0], [1.38, 0]]
+ logits = [[-1*2.38 +0.3, 1*2.38 -0.3, 0.5*2.38],
+ [-1*1.38 +0.3, 1*1.38 -0.3, 0.5*1.38]]
+ = [[-2.08, 2.08, 1.19], [-1.08, 1.08, 0.69]]
+ """
+ base_global_step = 100
_create_checkpoint((
- (((1., 2.),), (3., 4.)),
- (((5., 6.), (7., 8.),), (9., 10.)),
- (((11.,), (12.,),), (13.,))
- ), global_step, self._model_dir)
-
- # Create DNNRegressor and evaluate.
- dnn_regressor = dnn.DNNRegressor(
- hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('age'),),
- model_dir=self._model_dir)
- input_fn = numpy_io.numpy_input_fn(
- x={'age': np.array(((1,), (2,), (3,)))},
- y=np.array(((10,), (9,), (8,))),
- batch_size=3,
- shuffle=False)
- self.assertAllClose({
- # TODO(ptucker): Point to tool for calculating a neural net output?
- # predictions = 1778, 2251, 2724
- # loss = ((10-1778)^2 + (9-2251)^2 + (8-2724)^2) = 15529044
- metric_keys.MetricKeys.LOSS: 15529044.,
- # average_loss = loss / 3 = 5176348
- metric_keys.MetricKeys.LOSS_MEAN: 5176348.,
- ops.GraphKeys.GLOBAL_STEP: global_step
- }, dnn_regressor.evaluate(input_fn=input_fn, steps=1))
+ ([[.6, .5]], [.1, -.1]),
+ ([[1., .8], [-.8, -1.]], [.2, -.2]),
+ ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),
+ ), base_global_step, self._model_dir)
- def test_multi_batch(self):
- # Create checkpoint: num_inputs=1, hidden_units=(2, 2), num_outputs=1.
- global_step = 100
+ for mode in [model_fn.ModeKeys.TRAIN,
+ model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.PREDICT]:
+ self._test_logits(
+ mode,
+ hidden_units=(2, 2),
+ logits_dimension=3,
+ inputs=[[10.], [5.]],
+ expected_logits=[[-2.08, 2.08, 1.19], [-1.08, 1.08, .69]])
+
+ def test_multi_dim_input_one_dim_logits(self):
+ """Tests multi-dimensional inputs and one-dimensional logits.
+
+ input_layer = [[10, 8]]
+ hidden_layer_0 = [[relu(0.6*10 -0.6*8 +0.1), relu(0.5*10 -0.5*8 -0.1)]]
+ = [[1.3, 0.9]]
+ hidden_layer_1 = [[relu(1*1.3 -0.8*0.9 + 0.2), relu(0.8*1.3 -1*0.9 -0.2)]]
+ = [[0.78, relu(-0.06)]] = [[0.78, 0]]
+ logits = [[-1*0.78 +1*0 +0.3]] = [[-0.48]]
+ """
+ base_global_step = 100
_create_checkpoint((
- (((1., 2.),), (3., 4.)),
- (((5., 6.), (7., 8.),), (9., 10.)),
- (((11.,), (12.,),), (13.,))
- ), global_step, self._model_dir)
-
- # Create DNNRegressor and evaluate.
- dnn_regressor = dnn.DNNRegressor(
- hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('age'),),
- model_dir=self._model_dir)
- input_fn = numpy_io.numpy_input_fn(
- x={'age': np.array(((1,), (2,), (3,)))},
- y=np.array(((10,), (9,), (8,))),
- batch_size=1,
- shuffle=False)
- # TODO(ptucker): Point to tool for calculating a neural net output?
- # predictions = 1778, 2251, 2724
- # loss = ((10-1778)^2 + (9-2251)^2 + (8-2724)^2) / 3 = 5176348
- expected_loss = 5176348.
- self.assertAllClose({
- metric_keys.MetricKeys.LOSS: expected_loss,
- metric_keys.MetricKeys.LOSS_MEAN: expected_loss,
- ops.GraphKeys.GLOBAL_STEP: global_step
- }, dnn_regressor.evaluate(input_fn=input_fn, steps=3))
+ ([[.6, .5], [-.6, -.5]], [.1, -.1]),
+ ([[1., .8], [-.8, -1.]], [.2, -.2]),
+ ([[-1.], [1.]], [.3]),
+ ), base_global_step, self._model_dir)
- def test_weighted_multi_example(self):
- # Create checkpoint: num_inputs=4, hidden_units=(2, 2), num_outputs=3.
- global_step = 100
+ for mode in [model_fn.ModeKeys.TRAIN,
+ model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.PREDICT]:
+ self._test_logits(
+ mode,
+ hidden_units=(2, 2),
+ logits_dimension=1,
+ inputs=[[10., 8.]],
+ expected_logits=[[-0.48]])
+
+ def test_multi_dim_input_multi_dim_logits(self):
+ """Tests multi-dimensional inputs and multi-dimensional logits.
+
+ input_layer = [[10, 8]]
+ hidden_layer_0 = [[relu(0.6*10 -0.6*8 +0.1), relu(0.5*10 -0.5*8 -0.1)]]
+ = [[1.3, 0.9]]
+ hidden_layer_1 = [[relu(1*1.3 -0.8*0.9 + 0.2), relu(0.8*1.3 -1*0.9 -0.2)]]
+ = [[0.78, relu(-0.06)]] = [[0.78, 0]]
+ logits = [[-1*0.78 + 0.3, 1*0.78 -0.3, 0.5*0.78]] = [[-0.48, 0.48, 0.39]]
+ """
+ base_global_step = 100
_create_checkpoint((
- (((1., 2.), (3., 4.), (5., 6.), (7., 8.),), (9., 8.)),
- (((7., 6.), (5., 4.),), (3., 2.)),
- (((1., 2., 3.), (4., 5., 6.),), (7., 8., 9.)),
- ), global_step, self._model_dir)
-
- # Create batched input.
- input_fn = numpy_io.numpy_input_fn(
- x={
- # Dimensions are (batch_size, feature_column.dimension).
- 'x': np.array((
- (15., 0., 1.5, 135.2),
- (45., 45000., 1.8, 158.8),
- (21., 33000., 1.7, 207.1),
- (60., 10000., 1.6, 90.2)
- )),
- # TODO(ptucker): Add test for different weight shapes when we fix
- # head._compute_weighted_loss (currently it requires weights to be
- # same shape as labels & logits).
- 'label_weight': np.array((
- (1., 1., 0.),
- (.5, 1., .1),
- (.5, 0., .9),
- (0., 0., 0.),
- ))
- },
- # Label shapes is (batch_size, num_outputs).
- y=np.array((
- (5., 2., 2.),
- (-2., 1., -4.),
- (-1., -1., -1.),
- (-4., 3., 9.),
- )),
- batch_size=4,
- shuffle=False)
-
- # Create DNNRegressor and evaluate.
- dnn_regressor = dnn.DNNRegressor(
- hidden_units=(2, 2),
- feature_columns=(
- # Dimension is number of inputs.
- feature_column.numeric_column(
- 'x', dtype=dtypes.int32, shape=(4,)),
- ),
- model_dir=self._model_dir,
- label_dimension=3,
- weight_feature_key='label_weight')
- self.assertAllClose({
- # TODO(ptucker): Point to tool for calculating a neural net output?
- # predictions = [
- # [ 54033.5 76909.6 99785.7]
- # [8030393.8 11433082.4 14835771.0]
- # [5923209.2 8433014.8 10942820.4]
- # [1810021.6 2576969.6 3343917.6]
- # ]
- # loss = sum(label_weights*(labels-predictions)^2) = 3.10290850204e+14
- metric_keys.MetricKeys.LOSS: 3.10290850204e+14,
- # average_loss = loss / sum(label_weights) = 3.10290850204e+14 / 5.
- # = 6.205817e+13
- metric_keys.MetricKeys.LOSS_MEAN: 6.205817e+13,
- ops.GraphKeys.GLOBAL_STEP: global_step
- }, dnn_regressor.evaluate(input_fn=input_fn, steps=1))
+ ([[.6, .5], [-.6, -.5]], [.1, -.1]),
+ ([[1., .8], [-.8, -1.]], [.2, -.2]),
+ ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),
+ ), base_global_step, self._model_dir)
- def test_weighted_multi_example_multi_column(self):
- # Create checkpoint: num_inputs=4, hidden_units=(2, 2), num_outputs=3.
- global_step = 100
+ for mode in [model_fn.ModeKeys.TRAIN,
+ model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.PREDICT]:
+ self._test_logits(
+ mode,
+ hidden_units=(2, 2),
+ logits_dimension=3,
+ inputs=[[10., 8.]],
+ expected_logits=[[-0.48, 0.48, 0.39]])
+
+ def test_multi_feature_column_multi_dim_logits(self):
+ """Tests multiple feature columns and multi-dimensional logits.
+
+ All numbers are the same as test_multi_dim_input_multi_dim_logits. The only
+ difference is that the input consists of two 1D feature columns, instead of
+ one 2D feature column.
+ """
+ base_global_step = 100
_create_checkpoint((
- (((1., 2.), (3., 4.), (5., 6.), (7., 8.),), (9., 8.)),
- (((7., 6.), (5., 4.),), (3., 2.)),
- (((1., 2., 3.), (4., 5., 6.),), (7., 8., 9.)),
- ), global_step, self._model_dir)
-
- # Create batched input.
- input_fn = numpy_io.numpy_input_fn(
- x={
- # Dimensions are (batch_size, feature_column.dimension).
- 'x': np.array((
- (15., 0.),
- (45., 45000.),
- (21., 33000.),
- (60., 10000.)
- )),
- 'y': np.array((
- (1.5, 135.2),
- (1.8, 158.8),
- (1.7, 207.1),
- (1.6, 90.2)
- )),
- # TODO(ptucker): Add test for different weight shapes when we fix
- # head._compute_weighted_loss (currently it requires weights to be
- # same shape as labels & logits).
- 'label_weight': np.array((
- (1., 1., 0.),
- (.5, 1., .1),
- (.5, 0., .9),
- (0., 0., 0.),
- ))
- },
- # Label shapes is (batch_size, num_outputs).
- y=np.array((
- (5., 2., 2.),
- (-2., 1., -4.),
- (-1., -1., -1.),
- (-4., 3., 9.),
- )),
- batch_size=4,
- shuffle=False)
+ ([[.6, .5], [-.6, -.5]], [.1, -.1]),
+ ([[1., .8], [-.8, -1.]], [.2, -.2]),
+ ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),
+ ), base_global_step, self._model_dir)
+ hidden_units = (2, 2)
+ logits_dimension = 3
+ inputs = ([[10.]], [[8.]])
+ expected_logits = [[-0.48, 0.48, 0.39]]
+
+ for mode in [model_fn.ModeKeys.TRAIN,
+ model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.PREDICT]:
+ with ops.Graph().as_default():
+ training_util.create_global_step()
+ head = _mock_head(
+ self,
+ hidden_units=hidden_units,
+ logits_dimension=logits_dimension,
+ expected_logits=expected_logits)
+ estimator_spec = dnn._dnn_model_fn(
+ features={'age': constant_op.constant(inputs[0]),
+ 'height': constant_op.constant(inputs[1])},
+ labels=constant_op.constant([[1]]),
+ mode=mode,
+ head=head,
+ hidden_units=hidden_units,
+ feature_columns=[
+ feature_column.numeric_column('age'),
+ feature_column.numeric_column('height')],
+ optimizer=_mock_optimizer(self, hidden_units))
+ with monitored_session.MonitoredTrainingSession(
+ checkpoint_dir=self._model_dir) as sess:
+ if mode == model_fn.ModeKeys.TRAIN:
+ sess.run(estimator_spec.train_op)
+ elif mode == model_fn.ModeKeys.EVAL:
+ sess.run(estimator_spec.loss)
+ elif mode == model_fn.ModeKeys.PREDICT:
+ sess.run(estimator_spec.predictions)
+ else:
+ self.fail('Invalid mode: {}'.format(mode))
- # Create DNNRegressor and evaluate.
- dnn_regressor = dnn.DNNRegressor(
- hidden_units=(2, 2),
- feature_columns=(
- # Dimensions add up to 4 (number of inputs).
- feature_column.numeric_column(
- 'x', dtype=dtypes.int32, shape=(2,)),
- feature_column.numeric_column(
- 'y', dtype=dtypes.float32, shape=(2,)),
- ),
- model_dir=self._model_dir,
- label_dimension=3,
- weight_feature_key='label_weight')
- self.assertAllClose({
- # TODO(ptucker): Point to tool for calculating a neural net output?
- # predictions = [
- # [ 54033.5 76909.6 99785.7]
- # [8030393.8 11433082.4 14835771.0]
- # [5923209.2 8433014.8 10942820.4]
- # [1810021.6 2576969.6 3343917.6]
- # ]
- # loss = sum(label_weights*(labels-predictions)^2) = 3.10290850204e+14
- metric_keys.MetricKeys.LOSS: 3.10290850204e+14,
- # average_loss = loss / sum(label_weights) = 3.10290850204e+14 / 5.
- # = 6.205817e+13
- metric_keys.MetricKeys.LOSS_MEAN: 6.205817e+13,
- ops.GraphKeys.GLOBAL_STEP: global_step
- }, dnn_regressor.evaluate(input_fn=input_fn, steps=1))
- def test_weighted_multi_batch(self):
- # Create checkpoint: num_inputs=4, hidden_units=(2, 2), num_outputs=3.
- global_step = 100
- _create_checkpoint((
- (((1., 2.), (3., 4.), (5., 6.), (7., 8.),), (9., 8.)),
- (((7., 6.), (5., 4.),), (3., 2.)),
- (((1., 2., 3.), (4., 5., 6.),), (7., 8., 9.)),
- ), global_step, self._model_dir)
+class DNNRegressorEvaluateTest(test.TestCase):
- # Create batched input.
- input_fn = numpy_io.numpy_input_fn(
- x={
- # Dimensions are (batch_size, feature_column.dimension).
- 'x': np.array((
- (15., 0., 1.5, 135.2),
- (45., 45000., 1.8, 158.8),
- (21., 33000., 1.7, 207.1),
- (60., 10000., 1.6, 90.2)
- )),
- # TODO(ptucker): Add test for different weight shapes when we fix
- # head._compute_weighted_loss (currently it requires weights to be
- # same shape as labels & logits).
- 'label_weights': np.array((
- (1., 1., 0.),
- (.5, 1., .1),
- (.5, 0., .9),
- (0., 0., 0.),
- ))
- },
- # Label shapes is (batch_size, num_outputs).
- y=np.array((
- (5., 2., 2.),
- (-2., 1., -4.),
- (-1., -1., -1.),
- (-4., 3., 9.),
- )),
- batch_size=1,
- shuffle=False)
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
- # Create DNNRegressor and evaluate.
- dnn_regressor = dnn.DNNRegressor(
- hidden_units=(2, 2),
- feature_columns=(
- # Dimension is number of inputs.
- feature_column.numeric_column(
- 'x', dtype=dtypes.int32, shape=(4,)),
- ),
- model_dir=self._model_dir,
- label_dimension=3,
- weight_feature_key='label_weights')
- self.assertAllClose({
- # TODO(ptucker): Point to tool for calculating a neural net output?
- # predictions = [
- # [ 54033.5 76909.6 99785.7]
- # [8030393.8 11433082.4 14835771.0]
- # [5923209.2 8433014.8 10942820.4]
- # [1810021.6 2576969.6 3343917.6]
- # ]
- # losses = label_weights*(labels-predictions)^2 = [
- # [ 2.91907881e+09 5.91477894e+09 0]
- # [ 3.22436284e+13 1.30715350e+14 2.20100220e+13]
- # [ 1.75422095e+13 0 1.07770806e+14]
- # [ 0 0 0]
- # ]
- # total_loss = sum(losses) = 3.10290850204e+14
- # loss = total_loss / 4 = 7.7572712551e+13
- metric_keys.MetricKeys.LOSS: 7.7572712551e+13,
- # average_loss = total_loss / sum(label_weights) = 6.20581700408e+13
- metric_keys.MetricKeys.LOSS_MEAN: 6.20581700408e+13,
- ops.GraphKeys.GLOBAL_STEP: global_step
- }, dnn_regressor.evaluate(input_fn=input_fn, steps=4))
+ def tearDown(self):
+ if self._model_dir:
+ shutil.rmtree(self._model_dir)
- def test_multi_dim(self):
- # Create checkpoint: num_inputs=3, hidden_units=(2, 2), num_outputs=2.
+ def test_one_dim(self):
+ """Asserts evaluation metrics for one-dimensional input and logits."""
+ # Create checkpoint: num_inputs=1, hidden_units=(2, 2), num_outputs=1.
global_step = 100
_create_checkpoint((
- (((1., 2.), (3., 4.), (5., 6.),), (7., 8.)),
- (((9., 8.), (7., 6.),), (5., 4.)),
- (((3., 2.), (1., 2.),), (3., 4.)),
+ ([[.6, .5]], [.1, -.1]),
+ ([[1., .8], [-.8, -1.]], [.2, -.2]),
+ ([[-1.], [1.]], [.3]),
), global_step, self._model_dir)
# Create DNNRegressor and evaluate.
dnn_regressor = dnn.DNNRegressor(
hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('x', shape=(3,)),),
- label_dimension=2,
+ feature_columns=[feature_column.numeric_column('age')],
model_dir=self._model_dir)
- input_fn = numpy_io.numpy_input_fn(
- x={'x': np.array(((2., 4., 5.),))},
- y=np.array(((46., 58.),)),
- batch_size=1,
- shuffle=False)
+ def _input_fn():
+ return {'age': [[10.]]}, [[1.]]
+ # Uses identical numbers as DNNModelFnTest.test_one_dim_logits.
+ # See that test for calculation of logits.
+ # logits = [[-2.08]] => predictions = [-2.08].
+ # loss = (1+2.08)^2 = 9.4864
+ expected_loss = 9.4864
self.assertAllClose({
- # TODO(ptucker): Point to tool for calculating a neural net output?
- # predictions = 3198, 3094
- # loss = ((46-3198)^2 + (58-3094)^2) = 19152400
- metric_keys.MetricKeys.LOSS: 19152400,
- # average_loss = loss / 2 = 9576200
- metric_keys.MetricKeys.LOSS_MEAN: 9576200,
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ metric_keys.MetricKeys.LOSS_MEAN: expected_loss,
ops.GraphKeys.GLOBAL_STEP: global_step
- }, dnn_regressor.evaluate(input_fn=input_fn, steps=1))
+ }, dnn_regressor.evaluate(input_fn=_input_fn, steps=1))
- def test_multi_feature_column(self):
- # Create checkpoint: num_inputs=2, hidden_units=(2, 2), num_outputs=1.
+ def test_multi_dim(self):
+ """Asserts evaluation metrics for multi-dimensional input and logits."""
+ # Create checkpoint: num_inputs=2, hidden_units=(2, 2), num_outputs=3.
global_step = 100
_create_checkpoint((
- (((1., 2.), (3., 4.),), (5., 6.)),
- (((7., 8.), (9., 8.),), (7., 6.)),
- (((5.,), (4.,),), (3.,))
+ ([[.6, .5], [-.6, -.5]], [.1, -.1]),
+ ([[1., .8], [-.8, -1.]], [.2, -.2]),
+ ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),
), global_step, self._model_dir)
+ label_dimension = 3
# Create DNNRegressor and evaluate.
dnn_regressor = dnn.DNNRegressor(
hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('age'),
- feature_column.numeric_column('height')),
+ feature_columns=[feature_column.numeric_column('age', shape=[2])],
+ label_dimension=label_dimension,
model_dir=self._model_dir)
- input_fn = numpy_io.numpy_input_fn(
- x={'age': np.array(((20,), (40,))), 'height': np.array(((4,), (8,)))},
- y=np.array(((213.,), (421.,))),
- batch_size=2,
- shuffle=False)
+ def _input_fn():
+ return {'age': [[10., 8.]]}, [[1., -1., 0.5]]
+ # Uses identical numbers as
+ # DNNModelFnTest.test_multi_dim_input_multi_dim_logits.
+ # See that test for calculation of logits.
+ # logits = [[-0.48, 0.48, 0.39]]
+ # loss = (1+0.48)^2 + (-1-0.48)^2 + (0.5-0.39)^2 = 4.3929
+ expected_loss = 4.3929
self.assertAllClose({
- # TODO(ptucker): Point to tool for calculating a neural net output?
- # predictions = 7315, 13771
- # loss = ((213-7315)^2 + (421-13771)^2) / 2 = 228660896
- metric_keys.MetricKeys.LOSS: 228660896.,
- # average_loss = loss / 2 = 114330452
- metric_keys.MetricKeys.LOSS_MEAN: 114330452.,
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ metric_keys.MetricKeys.LOSS_MEAN: expected_loss / label_dimension,
ops.GraphKeys.GLOBAL_STEP: global_step
- }, dnn_regressor.evaluate(input_fn=input_fn, steps=1))
+ }, dnn_regressor.evaluate(input_fn=_input_fn, steps=1))
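
The expected evaluation metrics follow directly from the logits: for these single-example inputs the regression head reports the sum of squared errors as LOSS and divides by the number of label elements for LOSS_MEAN. A quick check of the multi-dim numbers (a sketch; assumes exactly that sum-of-squares reduction):

    import numpy as np

    labels = np.array([[1., -1., 0.5]])
    logits = np.array([[-0.48, 0.48, 0.39]])
    loss = np.sum((labels - logits) ** 2)
    print(loss)                    # 4.3929 -> MetricKeys.LOSS
    print(loss / labels.shape[1])  # 1.4643 -> MetricKeys.LOSS_MEAN
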
class DNNRegressorPredictTest(test.TestCase):
@@ -497,13 +451,13 @@ class DNNRegressorPredictTest(test.TestCase):
if self._model_dir:
shutil.rmtree(self._model_dir)
- def test_1d(self):
- """Tests predict when all variables are one-dimensional."""
+ def test_one_dim(self):
+ """Asserts predictions for one-dimensional input and logits."""
# Create checkpoint: num_inputs=1, hidden_units=(2, 2), num_outputs=1.
_create_checkpoint((
- (((1., 2.),), (3., 4.)),
- (((5., 6.), (7., 8.),), (9., 10.)),
- (((11.,), (12.,),), (13.,))
+ ([[.6, .5]], [.1, -.1]),
+ ([[1., .8], [-.8, -1.]], [.2, -.2]),
+ ([[-1.], [1.]], [.3]),
), global_step=0, model_dir=self._model_dir)
# Create DNNRegressor and predict.
@@ -512,66 +466,40 @@ class DNNRegressorPredictTest(test.TestCase):
feature_columns=(feature_column.numeric_column('x'),),
model_dir=self._model_dir)
input_fn = numpy_io.numpy_input_fn(
- x={'x': np.array(((1.,),))}, batch_size=1, shuffle=False)
- # TODO(ptucker): Point to tool for calculating a neural net output?
- # prediction = 1778
+ x={'x': np.array([[10.]])}, batch_size=1, shuffle=False)
+ # Uses identical numbers as DNNModelFnTest.test_one_dim_logits.
+ # See that test for calculation of logits.
+ # logits = [[-2.08]] => predictions = [-2.08].
self.assertAllClose({
- prediction_keys.PredictionKeys.PREDICTIONS: (1778.,)
+ prediction_keys.PredictionKeys.PREDICTIONS: [-2.08],
}, next(dnn_regressor.predict(input_fn=input_fn)))
def test_multi_dim(self):
- """Tests predict when all variables are multi-dimenstional."""
- # Create checkpoint: num_inputs=4, hidden_units=(2, 2), num_outputs=3.
+ """Asserts predictions for multi-dimensional input and logits."""
+ # Create checkpoint: num_inputs=2, hidden_units=(2, 2), num_outputs=3.
_create_checkpoint((
- (((1., 2.), (3., 4.), (5., 6.), (7., 8.),), (9., 8.)),
- (((7., 6.), (5., 4.),), (3., 2.)),
- (((1., 2., 3.), (4., 5., 6.),), (7., 8., 9.)),
+ ([[.6, .5], [-.6, -.5]], [.1, -.1]),
+ ([[1., .8], [-.8, -1.]], [.2, -.2]),
+ ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),
), 100, self._model_dir)
# Create DNNRegressor and predict.
dnn_regressor = dnn.DNNRegressor(
hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('x', shape=(4,)),),
+ feature_columns=(feature_column.numeric_column('x', shape=(2,)),),
label_dimension=3,
model_dir=self._model_dir)
input_fn = numpy_io.numpy_input_fn(
# Inputs shape is (batch_size, num_inputs).
- x={'x': np.array(((1., 2., 3., 4.), (5., 6., 7., 8.)))},
- batch_size=2,
- shuffle=False)
- # Output shape=(batch_size, num_outputs).
- self.assertAllClose((
- # TODO(ptucker): Point to tool for calculating a neural net output?
- (3275., 4660., 6045.),
- (6939., 9876., 12813.)
- ), tuple([
- x[prediction_keys.PredictionKeys.PREDICTIONS]
- for x in dnn_regressor.predict(input_fn=input_fn)
- ]), rtol=1e-04)
-
- def test_two_feature_columns(self):
- """Tests predict with two feature columns."""
- # Create checkpoint: num_inputs=2, hidden_units=(2, 2), num_outputs=1.
- _create_checkpoint((
- (((1., 2.), (3., 4.),), (5., 6.)),
- (((7., 8.), (9., 8.),), (7., 6.)),
- (((5.,), (4.,),), (3.,))
- ), 100, self._model_dir)
-
- # Create DNNRegressor and predict.
- dnn_regressor = dnn.DNNRegressor(
- hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('x'),
- feature_column.numeric_column('y')),
- model_dir=self._model_dir)
- input_fn = numpy_io.numpy_input_fn(
- x={'x': np.array((20.,)), 'y': np.array((4.,))},
+ x={'x': np.array([[10., 8.]])},
batch_size=1,
shuffle=False)
+ # Uses identical numbers as
+ # DNNModelFnTest.test_multi_dim_input_multi_dim_logits.
+ # See that test for calculation of logits.
+ # logits = [[-0.48, 0.48, 0.39]] => predictions = [-0.48, 0.48, 0.39]
self.assertAllClose({
- # TODO(ptucker): Point to tool for calculating a neural net output?
- # predictions = 7315
- prediction_keys.PredictionKeys.PREDICTIONS: (7315,)
+ prediction_keys.PredictionKeys.PREDICTIONS: [-0.48, 0.48, 0.39],
}, next(dnn_regressor.predict(input_fn=input_fn)))
@@ -627,7 +555,6 @@ class DNNRegressorIntegrationTest(test.TestCase):
for x in est.predict(predict_input_fn)
])
self.assertAllEqual((batch_size, label_dimension), predictions.shape)
- # TODO(ptucker): Deterministic test for predicted values?
# EXPORT
feature_spec = feature_column.make_parse_example_spec(feature_columns)
@@ -642,22 +569,25 @@ def _full_var_name(var_name):
return '%s/part_0:0' % var_name
-def _assert_close(expected, actual, rtol=1e-04, name='assert_close'):
+def _assert_close(
+ expected, actual, rtol=1e-04, message='', name='assert_close'):
with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope:
expected = ops.convert_to_tensor(expected, name='expected')
actual = ops.convert_to_tensor(actual, name='actual')
- rdiff = math_ops.abs(expected - actual, 'diff') / expected
+ rdiff = math_ops.abs((expected - actual) / expected, 'diff')
rtol = ops.convert_to_tensor(rtol, name='rtol')
return check_ops.assert_less(
rdiff,
rtol,
data=(
+ message,
'Condition expected =~ actual did not hold element-wise:'
'expected = ', expected,
'actual = ', actual,
'rdiff = ', rdiff,
'rtol = ', rtol,
),
+ summarize=expected.get_shape().num_elements(),
name=scope)
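
The rewritten _assert_close checks that the elementwise relative difference |expected - actual| / expected stays below rtol, and summarizes every element on failure. The same logic in plain NumPy (a sketch, not the graph-based version above; it shares the caveat that expected must be nonzero):

    import numpy as np

    def assert_close(expected, actual, rtol=1e-04, message=''):
      expected = np.asarray(expected, dtype=float)
      actual = np.asarray(actual, dtype=float)
      rdiff = np.abs((expected - actual) / expected)
      if not np.all(rdiff < rtol):
        raise AssertionError(
            '%sexpected = %s actual = %s rdiff = %s rtol = %s'
            % (message, expected, actual, rdiff, rtol))

    assert_close([[-2.08]], [[-2.080001]])  # Passes: rdiff ~ 5e-07.
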
@@ -833,20 +763,21 @@ class DNNRegressorTrainTest(test.TestCase):
self.assertIn(metric_keys.MetricKeys.LOSS, summary_keys)
self.assertIn(metric_keys.MetricKeys.LOSS_MEAN, summary_keys)
- def test_simple(self):
+ def test_one_dim(self):
+ """Asserts train loss for one-dimensional input and logits."""
base_global_step = 100
hidden_units = (2, 2)
_create_checkpoint((
- (((1., 2.),), (3., 4.)),
- (((5., 6.), (7., 8.),), (9., 10.)),
- (((11.,), (12.,),), (13.,))
+ ([[.6, .5]], [.1, -.1]),
+ ([[1., .8], [-.8, -1.]], [.2, -.2]),
+ ([[-1.], [1.]], [.3]),
), base_global_step, self._model_dir)
- # Create DNNRegressor with mock optimizer.
- # TODO(ptucker): Point to tool for calculating a neural net output?
- # prediction = 1778
- # loss = (10-1778)^2 = 3125824
- expected_loss = 3125824.
+ # Uses identical numbers as DNNModelFnTest.test_one_dim_logits.
+ # See that test for calculation of logits.
+ # logits = [-2.08] => predictions = [-2.08]
+ # loss = (1 + 2.08)^2 = 9.4864
+ expected_loss = 9.4864
mock_optimizer = _mock_optimizer(
self, hidden_units=hidden_units, expected_loss=expected_loss)
dnn_regressor = dnn.DNNRegressor(
@@ -861,7 +792,7 @@ class DNNRegressorTrainTest(test.TestCase):
num_steps = 5
summary_hook = _SummaryHook()
dnn_regressor.train(
- input_fn=lambda: ({'age': ((1,),)}, ((10.,),)), steps=num_steps,
+ input_fn=lambda: ({'age': [[10.]]}, [[1.]]), steps=num_steps,
hooks=(summary_hook,))
self.assertEqual(1, mock_optimizer.minimize.call_count)
summaries = summary_hook.summaries()
@@ -871,11 +802,8 @@ class DNNRegressorTrainTest(test.TestCase):
self,
{
metric_keys.MetricKeys.LOSS_MEAN: expected_loss,
- 'dnn/dnn/hiddenlayer_0_activation': 0.,
'dnn/dnn/hiddenlayer_0_fraction_of_zero_values': 0.,
- 'dnn/dnn/hiddenlayer_1_activation': 0.,
- 'dnn/dnn/hiddenlayer_1_fraction_of_zero_values': 0.,
- 'dnn/dnn/logits_activation': 0.,
+ 'dnn/dnn/hiddenlayer_1_fraction_of_zero_values': 0.5,
'dnn/dnn/logits_fraction_of_zero_values': 0.,
metric_keys.MetricKeys.LOSS: expected_loss,
},
@@ -884,28 +812,33 @@ class DNNRegressorTrainTest(test.TestCase):
self, base_global_step + num_steps, input_units=1,
hidden_units=hidden_units, output_units=1, model_dir=self._model_dir)
- def test_activation_fn(self):
+ def test_multi_dim(self):
+ """Asserts train loss for multi-dimensional input and logits."""
base_global_step = 100
hidden_units = (2, 2)
_create_checkpoint((
- (((1., 2.),), (3., 4.)),
- (((5., 6.), (7., 8.),), (9., 10.)),
- (((11.,), (12.,),), (13.,))
+ ([[.6, .5], [-.6, -.5]], [.1, -.1]),
+ ([[1., .8], [-.8, -1.]], [.2, -.2]),
+ ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),
), base_global_step, self._model_dir)
-
- # Create DNNRegressor with mock optimizer.
- # TODO(ptucker): Point to tool for calculating a neural net output?
- # prediction = 36
- # loss = (10-36)^2 = 676
- expected_loss = 676.
+ input_dimension = 2
+ label_dimension = 3
+
+ # Uses identical numbers as
+ # DNNModelFnTest.test_multi_dim_input_multi_dim_logits.
+ # See that test for calculation of logits.
+ # logits = [[-0.48, 0.48, 0.39]]
+ # loss = (1+0.48)^2 + (-1-0.48)^2 + (0.5-0.39)^2 = 4.3929
+ expected_loss = 4.3929
mock_optimizer = _mock_optimizer(
self, hidden_units=hidden_units, expected_loss=expected_loss)
dnn_regressor = dnn.DNNRegressor(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=[
+ feature_column.numeric_column('age', shape=[input_dimension])],
+ label_dimension=label_dimension,
optimizer=mock_optimizer,
- model_dir=self._model_dir,
- activation_fn=nn.tanh)
+ model_dir=self._model_dir)
self.assertEqual(0, mock_optimizer.minimize.call_count)
# Train for a few steps, then validate optimizer, summaries, and
@@ -913,7 +846,8 @@ class DNNRegressorTrainTest(test.TestCase):
num_steps = 5
summary_hook = _SummaryHook()
dnn_regressor.train(
- input_fn=lambda: ({'age': ((1,),)}, ((10.,),)), steps=num_steps,
+ input_fn=lambda: ({'age': [[10., 8.]]}, [[1., -1., 0.5]]),
+ steps=num_steps,
hooks=(summary_hook,))
self.assertEqual(1, mock_optimizer.minimize.call_count)
summaries = summary_hook.summaries()
@@ -922,236 +856,16 @@ class DNNRegressorTrainTest(test.TestCase):
_assert_simple_summary(
self,
{
- metric_keys.MetricKeys.LOSS: expected_loss,
- metric_keys.MetricKeys.LOSS_MEAN: expected_loss,
- 'dnn/dnn/hiddenlayer_0_activation': 0.,
+ metric_keys.MetricKeys.LOSS_MEAN: expected_loss / label_dimension,
'dnn/dnn/hiddenlayer_0_fraction_of_zero_values': 0.,
- 'dnn/dnn/hiddenlayer_1_activation': 0.,
- 'dnn/dnn/hiddenlayer_1_fraction_of_zero_values': 0.,
- 'dnn/dnn/logits_activation': 0.,
+ 'dnn/dnn/hiddenlayer_1_fraction_of_zero_values': 0.5,
'dnn/dnn/logits_fraction_of_zero_values': 0.,
+ metric_keys.MetricKeys.LOSS: expected_loss,
},
summary)
_assert_checkpoint(
- self, base_global_step + num_steps, input_units=1,
- hidden_units=hidden_units, output_units=1, model_dir=self._model_dir)
-
- def test_weighted_multi_example_multi_column(self):
- hidden_units = (2, 2)
- base_global_step = 100
- _create_checkpoint((
- (((1., 2.), (3., 4.), (5., 6.), (7., 8.),), (9., 8.)),
- (((7., 6.), (5., 4.),), (3., 2.)),
- (((1., 2., 3.), (4., 5., 6.),), (7., 8., 9.)),
- ), base_global_step, self._model_dir)
-
- # Create DNNRegressor with mock optimizer.
- # TODO(ptucker): Point to tool for calculating a neural net output?
- # predictions = [
- # [ 54033.5 76909.6 99785.7]
- # [8030393.8 11433082.4 14835771.0]
- # [5923209.2 8433014.8 10942820.4]
- # [1810021.6 2576969.6 3343917.6]
- # ]
- # loss = sum(label_weights*(labels-predictions)^2) = 3.10290850204e+14
- expected_loss = 3.10290850204e+14
- mock_optimizer = _mock_optimizer(
- self, hidden_units=hidden_units, expected_loss=expected_loss)
- dnn_regressor = dnn.DNNRegressor(
- hidden_units=hidden_units,
- feature_columns=(
- # Dimensions add up to 4 (number of inputs).
- feature_column.numeric_column(
- 'x', dtype=dtypes.int32, shape=(2,)),
- feature_column.numeric_column(
- 'y', dtype=dtypes.float32, shape=(2,)),
- ),
- optimizer=mock_optimizer,
- model_dir=self._model_dir,
- label_dimension=3,
- weight_feature_key='label_weights')
- self.assertEqual(0, mock_optimizer.minimize.call_count)
-
- # Create batched inputs.
- input_fn = numpy_io.numpy_input_fn(
- # NOTE: feature columns are concatenated in alphabetic order of keys.
- x={
- # Inputs shapes are (batch_size, feature_column.dimension).
- 'x': np.array((
- (15., 0.),
- (45., 45000.),
- (21., 33000.),
- (60., 10000.)
- )),
- 'y': np.array((
- (1.5, 135.2),
- (1.8, 158.8),
- (1.7, 207.1),
- (1.6, 90.2)
- )),
- # TODO(ptucker): Add test for different weight shapes when we fix
- # head._compute_weighted_loss (currently it requires weights to be
- # same shape as labels & logits).
- 'label_weights': np.array((
- (1., 1., 0.),
- (.5, 1., .1),
- (.5, 0., .9),
- (0., 0., 0.),
- ))
- },
- # Labels shapes is (batch_size, num_outputs).
- y=np.array((
- (5., 2., 2.),
- (-2., 1., -4.),
- (-1., -1., -1.),
- (-4., 3., 9.),
- )),
- batch_size=4,
- num_epochs=None,
- shuffle=False)
-
- # Train for 1 step, then validate optimizer, summaries, and checkpoint.
- summary_hook = _SummaryHook()
- dnn_regressor.train(input_fn=input_fn, steps=1, hooks=(summary_hook,))
- self.assertEqual(1, mock_optimizer.minimize.call_count)
- summaries = summary_hook.summaries()
- self.assertEqual(1, len(summaries))
- _assert_simple_summary(
- self,
- {
- metric_keys.MetricKeys.LOSS: expected_loss,
- # average_loss = loss / sum(label_weights) = 3.10290850204e+14 / 5.
- # = 6.205817e+13
- metric_keys.MetricKeys.LOSS_MEAN: 6.205817e+13,
- 'dnn/dnn/hiddenlayer_0_activation': 0.,
- 'dnn/dnn/hiddenlayer_0_fraction_of_zero_values': 0.,
- 'dnn/dnn/hiddenlayer_1_activation': 0.,
- 'dnn/dnn/hiddenlayer_1_fraction_of_zero_values': 0.,
- 'dnn/dnn/logits_activation': 0.,
- 'dnn/dnn/logits_fraction_of_zero_values': 0.,
- },
- summaries[0])
- _assert_checkpoint(
- self,
- base_global_step + 1,
- input_units=4, # Sum of feature column dimensions.
- hidden_units=hidden_units,
- output_units=3, # = label_dimension
- model_dir=self._model_dir)
-
- # Train for 3 steps - we should still get the same loss since we're not
- # updating weights.
- dnn_regressor.train(input_fn=input_fn, steps=3)
- self.assertEqual(2, mock_optimizer.minimize.call_count)
- _assert_checkpoint(
- self,
- base_global_step + 4,
- input_units=4, # Sum of feature column dimensions.
- hidden_units=hidden_units,
- output_units=3, # = label_dimension
- model_dir=self._model_dir)
-
- def test_weighted_multi_batch(self):
- hidden_units = (2, 2)
- base_global_step = 100
- _create_checkpoint((
- (((1., 2.), (3., 4.), (5., 6.), (7., 8.),), (9., 8.)),
- (((7., 6.), (5., 4.),), (3., 2.)),
- (((1., 2., 3.), (4., 5., 6.),), (7., 8., 9.)),
- ), base_global_step, self._model_dir)
-
- mock_optimizer = _mock_optimizer(self, hidden_units=hidden_units)
- dnn_regressor = dnn.DNNRegressor(
- hidden_units=hidden_units,
- feature_columns=(
- # Dimension is number of inputs.
- feature_column.numeric_column(
- 'x', dtype=dtypes.int32, shape=(4,)),
- ),
- optimizer=mock_optimizer,
- model_dir=self._model_dir,
- label_dimension=3,
- weight_feature_key='label_weights')
- self.assertEqual(0, mock_optimizer.minimize.call_count)
-
- # Create batched input.
- input_fn = numpy_io.numpy_input_fn(
- x={
- # Inputs shape is (batch_size, feature_column.dimension).
- 'x': np.array((
- (15., 0., 1.5, 135.2),
- (45., 45000., 1.8, 158.8),
- (21., 33000., 1.7, 207.1),
- (60., 10000., 1.6, 90.2)
- )),
- # TODO(ptucker): Add test for different weight shapes when we fix
- # head._compute_weighted_loss (currently it requires weights to be
- # same shape as labels & logits).
- 'label_weights': np.array((
- (1., 1., 0.),
- (.5, 1., .1),
- (.5, 0., .9),
- (0., 0., 0.),
- ))
- },
- # Labels shapes is (batch_size, num_outputs).
- y=np.array((
- (5., 2., 2.),
- (-2., 1., -4.),
- (-1., -1., -1.),
- (-4., 3., 9.),
- )),
- batch_size=1,
- shuffle=False)
-
- # Train for 1 step, then validate optimizer, summaries, and checkpoint.
- num_steps = 4
- summary_hook = _SummaryHook()
- dnn_regressor.train(
- input_fn=input_fn, steps=num_steps, hooks=(summary_hook,))
- self.assertEqual(1, mock_optimizer.minimize.call_count)
- summaries = summary_hook.summaries()
- self.assertEqual(num_steps, len(summaries))
- # TODO(ptucker): Point to tool for calculating a neural net output?
- # predictions = [
- # [ 54033.5 76909.6 99785.7]
- # [8030393.8 11433082.4 14835771.0]
- # [5923209.2 8433014.8 10942820.4]
- # [1810021.6 2576969.6 3343917.6]
- # ]
- # losses = label_weights*(labels-predictions)^2 = [
- # [2.91907881e+09 5.91477894e+09 0]
- # [3.22436284e+13 1.30715350e+14 2.20100220e+13]
- # [1.75422095e+13 0 1.07770806e+14]
- # [ 0 0 0]
- # ]
- # step_losses = [sum(losses[i]) for i in 0...3]
- # = [8833857750, 1.84969e+14, 1.2531302e+14, 0]
- expected_step_losses = (8833857750, 1.84969e+14, 1.2531302e+14, 0)
- # step_average_losses = [
- # step_losses[i] / sum(label_weights[i]) for i in 0...3
- # ] = [4416928875, 1.1560563e+14, 8.95093e+13, 0]
- expected_step_average_losses = (4416928875, 1.1560563e+14, 8.95093e+13, 0)
- for i in range(len(summaries)):
- _assert_simple_summary(
- self,
- {
- metric_keys.MetricKeys.LOSS: expected_step_losses[i],
- metric_keys.MetricKeys.LOSS_MEAN: expected_step_average_losses[i],
- 'dnn/dnn/hiddenlayer_0_activation': 0.,
- 'dnn/dnn/hiddenlayer_0_fraction_of_zero_values': 0.,
- 'dnn/dnn/hiddenlayer_1_activation': 0.,
- 'dnn/dnn/hiddenlayer_1_fraction_of_zero_values': 0.,
- 'dnn/dnn/logits_activation': 0.,
- 'dnn/dnn/logits_fraction_of_zero_values': 0.,
- },
- summaries[i])
- _assert_checkpoint(
- self,
- base_global_step + num_steps,
- input_units=4, # Sum of feature column dimensions.
- hidden_units=hidden_units,
- output_units=3, # = label_dimension
+ self, base_global_step + num_steps, input_units=input_dimension,
+ hidden_units=hidden_units, output_units=label_dimension,
model_dir=self._model_dir)
@@ -1233,7 +947,8 @@ class DNNClassifierTrainTest(test.TestCase):
([[-1.], [1.]], [.3]),
), base_global_step, self._model_dir)
- # Create DNNClassifier with mock optimizer.
+ # Uses identical numbers as DNNModelFnTest.test_one_dim_logits.
+ # See that test for calculation of logits.
# logits = [-2.08] => probabilities = [0.889, 0.111]
# loss = -1. * log(0.111) = 2.19772100
expected_loss = 2.19772100
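
This expected loss is ordinary sigmoid cross-entropy on the single logit, computable by hand (a sketch of the arithmetic stated in the comment above):

    import math

    logit = -2.08
    p_class_1 = 1. / (1. + math.exp(-logit))  # sigmoid(-2.08) ~ 0.111
    print(-math.log(p_class_1))               # Label is class 1: ~2.197721
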
@@ -1281,7 +996,8 @@ class DNNClassifierTrainTest(test.TestCase):
([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),
), base_global_step, self._model_dir)
- # Create DNNClassifier with mock optimizer.
+ # Uses identical numbers as DNNModelFnTest.test_multi_dim_logits.
+ # See that test for calculation of logits.
# logits = [-2.08, 2.08, 1.19] => probabilities = [0.0109, 0.7011, 0.2879]
# loss = -1. * log(0.7011) = 0.35505795
expected_loss = 0.35505795
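
Likewise, the multi-class figure is softmax cross-entropy over the three logits (same kind of sketch):

    import math

    logits = [-2.08, 2.08, 1.19]
    exps = [math.exp(l) for l in logits]
    probs = [e / sum(exps) for e in exps]  # ~[0.0109, 0.7011, 0.2879]
    print(-math.log(probs[1]))             # Label is class 1: ~0.35505795
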