diff options
Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator/dnn.py')
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/dnn.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py index 4bb90cf81b..9efa8f474d 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn.py @@ -112,7 +112,8 @@ class DNNEstimator(estimator.Estimator): dropout=None, input_layer_partitioner=None, config=None, - warm_start_from=None): + warm_start_from=None, + batch_norm=False): """Initializes a `DNNEstimator` instance. Args: @@ -142,6 +143,7 @@ class DNNEstimator(estimator.Estimator): string filepath is provided instead of a `WarmStartSettings`, then all weights are warm-started, and it is assumed that vocabularies and Tensor names are unchanged. + batch_norm: Whether to use batch normalization after each hidden layer. """ def _model_fn(features, labels, mode, config): return dnn_lib._dnn_model_fn( # pylint: disable=protected-access @@ -155,7 +157,8 @@ class DNNEstimator(estimator.Estimator): activation_fn=activation_fn, dropout=dropout, input_layer_partitioner=input_layer_partitioner, - config=config) + config=config, + batch_norm=batch_norm) super(DNNEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, warm_start_from=warm_start_from) |