aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/bayesflow
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2017-04-07 13:36:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-07 14:48:41 -0700
commit82254c06f0e4ea8225216b7518b7339594010955 (patch)
tree2854c46c33ed675228ed4f0ba5bfe96f41c8c53b /tensorflow/contrib/bayesflow
parentfd3d1371de9aed6430b881dfdf22faef43af11ac (diff)
Formatting changes
Change: 152544842
Diffstat (limited to 'tensorflow/contrib/bayesflow')
-rw-r--r--tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py42
1 files changed, 14 insertions, 28 deletions
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py
index 81e40dbe5e..c7f185aab8 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py
@@ -42,12 +42,10 @@ class StochasticTensorTest(test.TestCase):
sigma2 = constant_op.constant([0.1, 0.2, 0.3])
prior_default = st.StochasticTensor(
- distributions.Normal(
- loc=mu, scale=sigma))
+ distributions.Normal(loc=mu, scale=sigma))
self.assertTrue(isinstance(prior_default.value_type, st.SampleValue))
prior_0 = st.StochasticTensor(
- distributions.Normal(
- loc=mu, scale=sigma),
+ distributions.Normal(loc=mu, scale=sigma),
dist_value_type=st.SampleValue())
self.assertTrue(isinstance(prior_0.value_type, st.SampleValue))
@@ -55,8 +53,7 @@ class StochasticTensorTest(test.TestCase):
prior = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma))
self.assertTrue(isinstance(prior.value_type, st.SampleValue))
likelihood = st.StochasticTensor(
- distributions.Normal(
- loc=prior, scale=sigma2))
+ distributions.Normal(loc=prior, scale=sigma2))
self.assertTrue(isinstance(likelihood.value_type, st.SampleValue))
coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
@@ -102,8 +99,7 @@ class StochasticTensorTest(test.TestCase):
with st.value_type(st.SampleValue()):
prior_single = st.StochasticTensor(
- distributions.Normal(
- loc=mu, scale=sigma))
+ distributions.Normal(loc=mu, scale=sigma))
prior_single_value = prior_single.value()
self.assertEqual(prior_single_value.get_shape(), (2, 3))
@@ -113,8 +109,7 @@ class StochasticTensorTest(test.TestCase):
with st.value_type(st.SampleValue(1)):
prior_single = st.StochasticTensor(
- distributions.Normal(
- loc=mu, scale=sigma))
+ distributions.Normal(loc=mu, scale=sigma))
self.assertTrue(isinstance(prior_single.value_type, st.SampleValue))
prior_single_value = prior_single.value()
@@ -125,8 +120,7 @@ class StochasticTensorTest(test.TestCase):
with st.value_type(st.SampleValue(2)):
prior_double = st.StochasticTensor(
- distributions.Normal(
- loc=mu, scale=sigma))
+ distributions.Normal(loc=mu, scale=sigma))
prior_double_value = prior_double.value()
self.assertEqual(prior_double_value.get_shape(), (2, 2, 3))
@@ -163,8 +157,7 @@ class StochasticTensorTest(test.TestCase):
# With passed-in loss_fn.
dt = st.StochasticTensor(
- distributions.Normal(
- loc=mu, scale=sigma),
+ distributions.Normal(loc=mu, scale=sigma),
dist_value_type=st.MeanValue(stop_gradient=True),
loss_fn=sge.get_score_function_with_constant_baseline(
baseline=constant_op.constant(8.0)))
@@ -199,8 +192,7 @@ class ObservedStochasticTensorTest(test.TestCase):
sigma = constant_op.constant([1.1, 1.2, 1.3])
obs = array_ops.zeros((2, 3))
z = st.ObservedStochasticTensor(
- distributions.Normal(
- loc=mu, scale=sigma), value=obs)
+ distributions.Normal(loc=mu, scale=sigma), value=obs)
[obs_val, z_val] = sess.run([obs, z.value()])
self.assertAllEqual(obs_val, z_val)
@@ -212,15 +204,13 @@ class ObservedStochasticTensorTest(test.TestCase):
sigma = array_ops.placeholder(dtypes.float32)
obs = array_ops.placeholder(dtypes.float32)
z = st.ObservedStochasticTensor(
- distributions.Normal(
- loc=mu, scale=sigma), value=obs)
+ distributions.Normal(loc=mu, scale=sigma), value=obs)
mu2 = array_ops.placeholder(dtypes.float32, shape=[None])
sigma2 = array_ops.placeholder(dtypes.float32, shape=[None])
obs2 = array_ops.placeholder(dtypes.float32, shape=[None, None])
z2 = st.ObservedStochasticTensor(
- distributions.Normal(
- loc=mu2, scale=sigma2), value=obs2)
+ distributions.Normal(loc=mu2, scale=sigma2), value=obs2)
coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
self.assertEqual(coll, [z, z2])
@@ -231,22 +221,18 @@ class ObservedStochasticTensorTest(test.TestCase):
self.assertRaises(
ValueError,
st.ObservedStochasticTensor,
- distributions.Normal(
- loc=mu, scale=sigma),
+ distributions.Normal(loc=mu, scale=sigma),
value=array_ops.zeros((3,)))
self.assertRaises(
ValueError,
st.ObservedStochasticTensor,
- distributions.Normal(
- loc=mu, scale=sigma),
+ distributions.Normal(loc=mu, scale=sigma),
value=array_ops.zeros((3, 1)))
self.assertRaises(
ValueError,
st.ObservedStochasticTensor,
- distributions.Normal(
- loc=mu, scale=sigma),
- value=array_ops.zeros(
- (1, 2), dtype=dtypes.int32))
+ distributions.Normal(loc=mu, scale=sigma),
+ value=array_ops.zeros((1, 2), dtype=dtypes.int32))
if __name__ == "__main__":