Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator/dnn.py')
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/dnn.py | 7
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py
index 4bb90cf81b..9efa8f474d 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn.py
@@ -112,7 +112,8 @@ class DNNEstimator(estimator.Estimator):
dropout=None,
input_layer_partitioner=None,
config=None,
- warm_start_from=None):
+ warm_start_from=None,
+ batch_norm=False):
"""Initializes a `DNNEstimator` instance.
Args:
@@ -142,6 +143,7 @@ class DNNEstimator(estimator.Estimator):
string filepath is provided instead of a `WarmStartSettings`, then all
weights are warm-started, and it is assumed that vocabularies and Tensor
names are unchanged.
+ batch_norm: Whether to use batch normalization after each hidden layer.
"""
def _model_fn(features, labels, mode, config):
return dnn_lib._dnn_model_fn( # pylint: disable=protected-access
@@ -155,7 +157,8 @@ class DNNEstimator(estimator.Estimator):
activation_fn=activation_fn,
dropout=dropout,
input_layer_partitioner=input_layer_partitioner,
- config=config)
+ config=config,
+ batch_norm=batch_norm)
super(DNNEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config,
warm_start_from=warm_start_from)