diff options
Diffstat (limited to 'tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py')
-rw-r--r-- | tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py | 25 |
1 files changed, 13 insertions, 12 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py index a647f963ed..ecf068bf6b 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py @@ -20,11 +20,11 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency -from tensorflow.contrib.distributions.python.ops.bijectors.chain import Chain -from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp -from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered -from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus +from tensorflow.contrib.distributions.python.ops.bijectors import bijector_test_util +from tensorflow.contrib.distributions.python.ops.bijectors import chain as chain_lib +from tensorflow.contrib.distributions.python.ops.bijectors import exp as exp_lib +from tensorflow.contrib.distributions.python.ops.bijectors import softmax_centered as softmax_centered_lib +from tensorflow.contrib.distributions.python.ops.bijectors import softplus as softplus_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import test @@ -34,7 +34,8 @@ class ChainBijectorTest(test.TestCase): def testBijector(self): with self.test_session(): - chain = Chain((Exp(event_ndims=1), Softplus(event_ndims=1))) + chain = chain_lib.Chain((exp_lib.Exp(event_ndims=1), + softplus_lib.Softplus(event_ndims=1))) self.assertEqual("chain_of_exp_of_softplus", chain.name) x = np.asarray([[[1., 2.], [2., 3.]]]) @@ -48,7 +49,7 @@ class ChainBijectorTest(test.TestCase): def testBijectorIdentity(self): with self.test_session(): - chain = Chain() + chain = chain_lib.Chain() self.assertEqual("identity", chain.name) x = np.asarray([[[1., 2.], [2., 3.]]]) @@ -59,16 +60,16 @@ class ChainBijectorTest(test.TestCase): def testScalarCongruency(self): with self.test_session(): - bijector = Chain((Exp(), Softplus())) - assert_scalar_congruency( + bijector = chain_lib.Chain((exp_lib.Exp(), softplus_lib.Softplus())) + bijector_test_util.assert_scalar_congruency( bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05) def testShapeGetters(self): with self.test_session(): - bijector = Chain([ - SoftmaxCentered( + bijector = chain_lib.Chain([ + softmax_centered_lib.SoftmaxCentered( event_ndims=1, validate_args=True), - SoftmaxCentered( + softmax_centered_lib.SoftmaxCentered( event_ndims=0, validate_args=True)]) x = tensor_shape.TensorShape([]) y = tensor_shape.TensorShape([2 + 1]) |