aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-18 14:27:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-18 14:31:45 -0700
commit3fa0009cbdb8ef95593ffaf63d97e05bf1835cb8 (patch)
tree841d1a766afd0e98a963a9ae7ac54fa9798e15d5 /tensorflow/contrib/distributions
parent3550ef89bc66d03b6e2db8e47bf7b038d9f4ceff (diff)
Replace distribution_util.assert_close with tf.assert_near.
PiperOrigin-RevId: 201058937
Diffstat (limited to 'tensorflow/contrib/distributions')
-rw-r--r--tensorflow/contrib/distributions/python/ops/onehot_categorical.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py2
2 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
index 0c762f17c9..214c6dca4a 100644
--- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
@@ -235,7 +235,7 @@ class OneHotCategorical(distribution.Distribution):
return x
return control_flow_ops.with_dependencies([
check_ops.assert_non_positive(x),
- distribution_util.assert_close(
+ check_ops.assert_near(
array_ops.zeros([], dtype=self.dtype),
math_ops.reduce_logsumexp(x, axis=[-1])),
], x)
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
index 9b5bd7576f..25aaac379a 100644
--- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
@@ -299,7 +299,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution):
return x
return control_flow_ops.with_dependencies([
check_ops.assert_non_positive(x),
- distribution_util.assert_close(
+ check_ops.assert_near(
array_ops.zeros([], dtype=self.dtype),
math_ops.reduce_logsumexp(x, axis=[-1])),
], x)