aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py')
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py25
1 files changed, 13 insertions, 12 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
index a647f963ed..ecf068bf6b 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
@@ -20,11 +20,11 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
-from tensorflow.contrib.distributions.python.ops.bijectors.chain import Chain
-from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp
-from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
-from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
+from tensorflow.contrib.distributions.python.ops.bijectors import bijector_test_util
+from tensorflow.contrib.distributions.python.ops.bijectors import chain as chain_lib
+from tensorflow.contrib.distributions.python.ops.bijectors import exp as exp_lib
+from tensorflow.contrib.distributions.python.ops.bijectors import softmax_centered as softmax_centered_lib
+from tensorflow.contrib.distributions.python.ops.bijectors import softplus as softplus_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
@@ -34,7 +34,8 @@ class ChainBijectorTest(test.TestCase):
def testBijector(self):
with self.test_session():
- chain = Chain((Exp(event_ndims=1), Softplus(event_ndims=1)))
+ chain = chain_lib.Chain((exp_lib.Exp(event_ndims=1),
+ softplus_lib.Softplus(event_ndims=1)))
self.assertEqual("chain_of_exp_of_softplus", chain.name)
x = np.asarray([[[1., 2.],
[2., 3.]]])
@@ -48,7 +49,7 @@ class ChainBijectorTest(test.TestCase):
def testBijectorIdentity(self):
with self.test_session():
- chain = Chain()
+ chain = chain_lib.Chain()
self.assertEqual("identity", chain.name)
x = np.asarray([[[1., 2.],
[2., 3.]]])
@@ -59,16 +60,16 @@ class ChainBijectorTest(test.TestCase):
def testScalarCongruency(self):
with self.test_session():
- bijector = Chain((Exp(), Softplus()))
- assert_scalar_congruency(
+ bijector = chain_lib.Chain((exp_lib.Exp(), softplus_lib.Softplus()))
+ bijector_test_util.assert_scalar_congruency(
bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05)
def testShapeGetters(self):
with self.test_session():
- bijector = Chain([
- SoftmaxCentered(
+ bijector = chain_lib.Chain([
+ softmax_centered_lib.SoftmaxCentered(
event_ndims=1, validate_args=True),
- SoftmaxCentered(
+ softmax_centered_lib.SoftmaxCentered(
event_ndims=0, validate_args=True)])
x = tensor_shape.TensorShape([])
y = tensor_shape.TensorShape([2 + 1])