diff options
author | 2017-04-07 13:36:36 -0800 | |
---|---|---|
committer | 2017-04-07 14:48:41 -0700 | |
commit | 82254c06f0e4ea8225216b7518b7339594010955 (patch) | |
tree | 2854c46c33ed675228ed4f0ba5bfe96f41c8c53b /tensorflow/contrib/bayesflow | |
parent | fd3d1371de9aed6430b881dfdf22faef43af11ac (diff) |
Formatting changes
Change: 152544842
Diffstat (limited to 'tensorflow/contrib/bayesflow')
-rw-r--r-- | tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py | 42 |
1 files changed, 14 insertions, 28 deletions
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py index 81e40dbe5e..c7f185aab8 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py @@ -42,12 +42,10 @@ class StochasticTensorTest(test.TestCase): sigma2 = constant_op.constant([0.1, 0.2, 0.3]) prior_default = st.StochasticTensor( - distributions.Normal( - loc=mu, scale=sigma)) + distributions.Normal(loc=mu, scale=sigma)) self.assertTrue(isinstance(prior_default.value_type, st.SampleValue)) prior_0 = st.StochasticTensor( - distributions.Normal( - loc=mu, scale=sigma), + distributions.Normal(loc=mu, scale=sigma), dist_value_type=st.SampleValue()) self.assertTrue(isinstance(prior_0.value_type, st.SampleValue)) @@ -55,8 +53,7 @@ class StochasticTensorTest(test.TestCase): prior = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma)) self.assertTrue(isinstance(prior.value_type, st.SampleValue)) likelihood = st.StochasticTensor( - distributions.Normal( - loc=prior, scale=sigma2)) + distributions.Normal(loc=prior, scale=sigma2)) self.assertTrue(isinstance(likelihood.value_type, st.SampleValue)) coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION) @@ -102,8 +99,7 @@ class StochasticTensorTest(test.TestCase): with st.value_type(st.SampleValue()): prior_single = st.StochasticTensor( - distributions.Normal( - loc=mu, scale=sigma)) + distributions.Normal(loc=mu, scale=sigma)) prior_single_value = prior_single.value() self.assertEqual(prior_single_value.get_shape(), (2, 3)) @@ -113,8 +109,7 @@ class StochasticTensorTest(test.TestCase): with st.value_type(st.SampleValue(1)): prior_single = st.StochasticTensor( - distributions.Normal( - loc=mu, scale=sigma)) + distributions.Normal(loc=mu, scale=sigma)) self.assertTrue(isinstance(prior_single.value_type, st.SampleValue)) prior_single_value = prior_single.value() @@ -125,8 +120,7 @@ class StochasticTensorTest(test.TestCase): with st.value_type(st.SampleValue(2)): prior_double = st.StochasticTensor( - distributions.Normal( - loc=mu, scale=sigma)) + distributions.Normal(loc=mu, scale=sigma)) prior_double_value = prior_double.value() self.assertEqual(prior_double_value.get_shape(), (2, 2, 3)) @@ -163,8 +157,7 @@ class StochasticTensorTest(test.TestCase): # With passed-in loss_fn. dt = st.StochasticTensor( - distributions.Normal( - loc=mu, scale=sigma), + distributions.Normal(loc=mu, scale=sigma), dist_value_type=st.MeanValue(stop_gradient=True), loss_fn=sge.get_score_function_with_constant_baseline( baseline=constant_op.constant(8.0))) @@ -199,8 +192,7 @@ class ObservedStochasticTensorTest(test.TestCase): sigma = constant_op.constant([1.1, 1.2, 1.3]) obs = array_ops.zeros((2, 3)) z = st.ObservedStochasticTensor( - distributions.Normal( - loc=mu, scale=sigma), value=obs) + distributions.Normal(loc=mu, scale=sigma), value=obs) [obs_val, z_val] = sess.run([obs, z.value()]) self.assertAllEqual(obs_val, z_val) @@ -212,15 +204,13 @@ class ObservedStochasticTensorTest(test.TestCase): sigma = array_ops.placeholder(dtypes.float32) obs = array_ops.placeholder(dtypes.float32) z = st.ObservedStochasticTensor( - distributions.Normal( - loc=mu, scale=sigma), value=obs) + distributions.Normal(loc=mu, scale=sigma), value=obs) mu2 = array_ops.placeholder(dtypes.float32, shape=[None]) sigma2 = array_ops.placeholder(dtypes.float32, shape=[None]) obs2 = array_ops.placeholder(dtypes.float32, shape=[None, None]) z2 = st.ObservedStochasticTensor( - distributions.Normal( - loc=mu2, scale=sigma2), value=obs2) + distributions.Normal(loc=mu2, scale=sigma2), value=obs2) coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION) self.assertEqual(coll, [z, z2]) @@ -231,22 +221,18 @@ class ObservedStochasticTensorTest(test.TestCase): self.assertRaises( ValueError, st.ObservedStochasticTensor, - distributions.Normal( - loc=mu, scale=sigma), + distributions.Normal(loc=mu, scale=sigma), value=array_ops.zeros((3,))) self.assertRaises( ValueError, st.ObservedStochasticTensor, - distributions.Normal( - loc=mu, scale=sigma), + distributions.Normal(loc=mu, scale=sigma), value=array_ops.zeros((3, 1))) self.assertRaises( ValueError, st.ObservedStochasticTensor, - distributions.Normal( - loc=mu, scale=sigma), - value=array_ops.zeros( - (1, 2), dtype=dtypes.int32)) + distributions.Normal(loc=mu, scale=sigma), + value=array_ops.zeros((1, 2), dtype=dtypes.int32)) if __name__ == "__main__": |