aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-04-23 21:19:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-23 21:21:38 -0700
commit22f3a97b8b089202f60bb0c7697feb0c8e0713cc (patch)
treed16f95826e4be15bbb3b0f22bed0ca25d3eb5897 /tensorflow/contrib/kfac
parent24b7c9a800ab5086d45a7d83ebcd6218424dc9e3 (diff)
Merge changes from github.
PiperOrigin-RevId: 194031845
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--tensorflow/contrib/kfac/python/ops/loss_functions.py6
-rw-r--r--tensorflow/contrib/kfac/python/ops/loss_functions_lib.py1
2 files changed, 3 insertions, 4 deletions
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py
index e7d4243fc3..42d525c2c2 100644
--- a/tensorflow/contrib/kfac/python/ops/loss_functions.py
+++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py
@@ -613,19 +613,19 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
def multiply_fisher(self, vector):
probs = self._probs
return vector * probs - probs * math_ops.reduce_sum(
- vector * probs, axis=-1, keep_dims=True)
+ vector * probs, axis=-1, keepdims=True)
def multiply_fisher_factor(self, vector):
probs = self._probs
sqrt_probs = self._sqrt_probs
return sqrt_probs * vector - probs * math_ops.reduce_sum(
- sqrt_probs * vector, axis=-1, keep_dims=True)
+ sqrt_probs * vector, axis=-1, keepdims=True)
def multiply_fisher_factor_transpose(self, vector):
probs = self._probs
sqrt_probs = self._sqrt_probs
return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum(
- probs * vector, axis=-1, keep_dims=True)
+ probs * vector, axis=-1, keepdims=True)
def multiply_fisher_factor_replicated_one_hot(self, index):
assert len(index) == 1, "Length of index was {}".format(len(index))
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py
index 705a871d48..4279cb2792 100644
--- a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py
@@ -33,7 +33,6 @@ _allowed_symbols = [
"CategoricalLogitsNegativeLogProbLoss",
"OnehotCategoricalLogitsNegativeLogProbLoss",
"MultiBernoulliNegativeLogProbLoss",
- "MultiBernoulliNegativeLogProbLoss",
"insert_slice_in_zeros",
]