aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-09-04 17:46:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-04 18:02:15 -0700
commitbfde272cf661d942b11877a8709739a09c5d41fd (patch)
tree307fd2a92b916ad36b4d4fb72a636dee7b2d6c49 /tensorflow/python/estimator
parent65899c10ab9a384670369257662c7c00fca12f19 (diff)
Disable variable partitioning from TPU DNN canned estimator.
PiperOrigin-RevId: 211557743
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--tensorflow/python/estimator/canned/dnn.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index c08cf61220..1c0c4581c0 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -142,7 +142,7 @@ def _dnn_model_fn(features,
dropout=None,
input_layer_partitioner=None,
config=None,
- tpu_estimator_spec=False,
+ use_tpu=False,
batch_norm=False):
"""Deep Neural Net model_fn.
@@ -164,8 +164,8 @@ def _dnn_model_fn(features,
input_layer_partitioner: Partitioner for input layer. Defaults
to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
config: `RunConfig` object to configure the runtime settings.
- tpu_estimator_spec: Whether to return a `_TPUEstimatorSpec` or
- or `model_fn.EstimatorSpec` instance.
+ use_tpu: Whether to make a DNN model able to run on TPU. Will make function
+ return a `_TPUEstimatorSpec` instance and disable variable partitioning.
batch_norm: Whether to use batch normalization after each hidden layer.
Returns:
@@ -182,13 +182,15 @@ def _dnn_model_fn(features,
optimizer, learning_rate=_LEARNING_RATE)
num_ps_replicas = config.num_ps_replicas if config else 0
- partitioner = partitioned_variables.min_max_variable_partitioner(
- max_partitions=num_ps_replicas)
+ partitioner = (None if use_tpu else
+ partitioned_variables.min_max_variable_partitioner(
+ max_partitions=num_ps_replicas))
with variable_scope.variable_scope(
'dnn',
values=tuple(six.itervalues(features)),
partitioner=partitioner):
input_layer_partitioner = input_layer_partitioner or (
+ None if use_tpu else
partitioned_variables.min_max_variable_partitioner(
max_partitions=num_ps_replicas,
min_slice_size=64 << 20))
@@ -203,7 +205,7 @@ def _dnn_model_fn(features,
batch_norm=batch_norm)
logits = logit_fn(features=features, mode=mode)
- if tpu_estimator_spec:
+ if use_tpu:
return head._create_tpu_estimator_spec( # pylint: disable=protected-access
features=features,
mode=mode,