aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py')
-rw-r--r--tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
index b6becfa9fc..2aa771a71e 100644
--- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
@@ -278,7 +278,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution):
* math_ops.log(self.temperature))
# compute the unnormalized density
log_softmax = nn_ops.log_softmax(logits_2d - x_2d * self._temperature_2d)
- log_unnorm_prob = math_ops.reduce_sum(log_softmax, [-1], keep_dims=False)
+ log_unnorm_prob = math_ops.reduce_sum(log_softmax, [-1], keepdims=False)
# combine unnormalized density with normalization constant
log_prob = log_norm_const + log_unnorm_prob
# Reshapes log_prob to be consistent with shape of user-supplied logits