diff options
author | Michael Case <mikecase@google.com> | 2018-09-04 17:46:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-04 18:02:15 -0700 |
commit | bfde272cf661d942b11877a8709739a09c5d41fd (patch) | |
tree | 307fd2a92b916ad36b4d4fb72a636dee7b2d6c49 /tensorflow/python/estimator | |
parent | 65899c10ab9a384670369257662c7c00fca12f19 (diff) |
Disable variable partitioning from TPU DNN canned estimator.
PiperOrigin-RevId: 211557743
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- | tensorflow/python/estimator/canned/dnn.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index c08cf61220..1c0c4581c0 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -142,7 +142,7 @@ def _dnn_model_fn(features, dropout=None, input_layer_partitioner=None, config=None, - tpu_estimator_spec=False, + use_tpu=False, batch_norm=False): """Deep Neural Net model_fn. @@ -164,8 +164,8 @@ def _dnn_model_fn(features, input_layer_partitioner: Partitioner for input layer. Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. config: `RunConfig` object to configure the runtime settings. - tpu_estimator_spec: Whether to return a `_TPUEstimatorSpec` or - or `model_fn.EstimatorSpec` instance. + use_tpu: Whether to make a DNN model able to run on TPU. Will make function + return a `_TPUEstimatorSpec` instance and disable variable partitioning. batch_norm: Whether to use batch normalization after each hidden layer. Returns: @@ -182,13 +182,15 @@ def _dnn_model_fn(features, optimizer, learning_rate=_LEARNING_RATE) num_ps_replicas = config.num_ps_replicas if config else 0 - partitioner = partitioned_variables.min_max_variable_partitioner( - max_partitions=num_ps_replicas) + partitioner = (None if use_tpu else + partitioned_variables.min_max_variable_partitioner( + max_partitions=num_ps_replicas)) with variable_scope.variable_scope( 'dnn', values=tuple(six.itervalues(features)), partitioner=partitioner): input_layer_partitioner = input_layer_partitioner or ( + None if use_tpu else partitioned_variables.min_max_variable_partitioner( max_partitions=num_ps_replicas, min_slice_size=64 << 20)) @@ -203,7 +205,7 @@ def _dnn_model_fn(features, batch_norm=batch_norm) logits = logit_fn(features=features, mode=mode) - if tpu_estimator_spec: + if use_tpu: return head._create_tpu_estimator_spec( # pylint: disable=protected-access features=features, mode=mode, |