aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator/python/estimator
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator')
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py4
2 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index ae2fd8b490..3dcf0374c8 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -485,7 +485,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
reduction=losses.Reduction.NONE)
# Averages loss over classes.
unweighted_loss = math_ops.reduce_mean(
- unweighted_loss, axis=-1, keep_dims=True)
+ unweighted_loss, axis=-1, keepdims=True)
weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access,
features=features, weight_column=self._weight_column, logits=logits)
training_loss = losses.compute_weighted_loss(
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
index fa2697800e..a8774d6dab 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
@@ -456,7 +456,7 @@ def _get_local_devices(device_type):
def _split_batch(features, labels, number_of_shards, device):
- """Split input features and labes into batches."""
+ """Split input features and labels into batches."""
def ensure_divisible_by_shards(sequence):
batch_size = ops_lib.convert_to_tensor(sequence).get_shape()[0]
@@ -602,7 +602,7 @@ def _local_device_setter(worker_device, ps_devices, ps_strategy):
def _scale_tower_loss(tower_spec, loss_reduction, number_of_towers):
- """Produce an EstimatorSpec with approproriately scaled loss."""
+ """Produce an EstimatorSpec with appropriately scaled loss."""
if tower_spec.loss is None:
return tower_spec