author    Eugene Brevdo <ebrevdo@google.com>    2016-11-09 09:43:51 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2016-11-09 10:04:51 -0800
commit    c66f878520ed42715606dea0fa0f2d8e0e2425d1 (patch)
tree      6cd39837517a4412b04f0763ba9afd6eecf2ee1a
parent    d9da9721f45950035f5087c59f9bc6910e232271 (diff)
bayesflow: replace SampleAndReshapeValue with SampleValue()
Change: 138649779
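The practical effect for callers (a minimal sketch based on the updated docstrings and tests in this change; the variable names and the explicit reshape below are illustrative, not part of the commit): the removed SampleAndReshapeValue(n=k) drew k samples and folded the sample dimension into the batch dimension, whereas SampleValue(shape) now simply forwards shape to dist.sample(shape), so the sample dimension stays as a new outer dimension and any caller that relied on the old (n * batch, ...) layout has to reshape explicitly.

    import tensorflow as tf
    from tensorflow.contrib import distributions
    from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor as st

    mu = tf.zeros((2, 3))
    sigma = tf.ones((2, 3))

    # Old behavior: SampleAndReshapeValue(n=2) produced a value of shape (4, 3),
    # merging the 2 samples into the batch dimension.
    # New behavior: SampleValue(2) produces a value of shape (2, 2, 3).
    with st.value_type(st.SampleValue(2)):
        prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))

    value = prior.value()                  # static shape (2, 2, 3)
    # To recover the old reshaped layout, flatten the outer two dimensions:
    reshaped = tf.reshape(value, [-1, 3])  # shape (4, 3)

The default SampleValue() (empty sample shape) draws a single sample without adding an outer dimension, matching the old SampleAndReshapeValue(n=1) behavior, which is why the call sites in the tests below could be switched mechanically.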
-rw-r--r--  tensorflow/contrib/bayesflow/examples/reinforce_simple/reinforce_simple_example.py |   2
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py          |   8
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py         |  40
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/stochastic_tensor.py                       | 118
4 files changed, 46 insertions, 122 deletions
diff --git a/tensorflow/contrib/bayesflow/examples/reinforce_simple/reinforce_simple_example.py b/tensorflow/contrib/bayesflow/examples/reinforce_simple/reinforce_simple_example.py
index ff9abfea45..2eb625487f 100644
--- a/tensorflow/contrib/bayesflow/examples/reinforce_simple/reinforce_simple_example.py
+++ b/tensorflow/contrib/bayesflow/examples/reinforce_simple/reinforce_simple_example.py
@@ -113,7 +113,7 @@ class REINFORCESimpleExample(tf.test.TestCase):
with self.test_session() as sess:
# Use sampling to train REINFORCE
- with st.value_type(st.SampleAndReshapeValue(n=1)):
+ with st.value_type(st.SampleValue()):
(route_selection,
routing_loss,
final_loss) = build_split_apply_merge_model()
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py
index de5c5c82b8..5d4fc66c69 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py
@@ -38,7 +38,7 @@ class TestSurrogateLosses(tf.test.TestCase):
with self.test_session():
mu = [0.0, 0.1, 0.2]
sigma = tf.constant([1.1, 1.2, 1.3])
- with st.value_type(st.SampleAndReshapeValue()):
+ with st.value_type(st.SampleValue()):
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
likelihood = st.StochasticTensor(
distributions.Normal(mu=prior, sigma=sigma))
@@ -76,7 +76,7 @@ class TestSurrogateLosses(tf.test.TestCase):
with self.test_session() as sess:
mu = tf.constant([0.0, 0.1, 0.2])
sigma = tf.constant([1.1, 1.2, 1.3])
- with st.value_type(st.SampleAndReshapeValue()):
+ with st.value_type(st.SampleValue()):
prior = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
likelihood = st.StochasticTensor(NormalNotParam(mu=prior, sigma=sigma))
prior_2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
@@ -153,7 +153,7 @@ class TestSurrogateLosses(tf.test.TestCase):
with self.test_session():
mu = tf.constant([0.0, 0.1, 0.2])
sigma = tf.constant([1.1, 1.2, 1.3])
- with st.value_type(st.SampleAndReshapeValue()):
+ with st.value_type(st.SampleValue()):
dt = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma),
loss_fn=None)
self.assertEqual(None, dt.loss(tf.constant([2.0])))
@@ -162,7 +162,7 @@ class TestSurrogateLosses(tf.test.TestCase):
with self.test_session() as sess:
mu = tf.constant([0.0, 0.1, 0.2])
sigma = tf.constant([1.1, 1.2, 1.3])
- with st.value_type(st.SampleAndReshapeValue()):
+ with st.value_type(st.SampleValue()):
dt1 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
dt2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
loss = tf.square(tf.identity(dt1)) + 10. + dt2
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py
index b7bd2adfe8..b73e87ce28 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py
@@ -37,19 +37,19 @@ class StochasticTensorTest(tf.test.TestCase):
prior_default = st.StochasticTensor(
distributions.Normal(mu=mu, sigma=sigma))
self.assertTrue(
- isinstance(prior_default.value_type, st.SampleAndReshapeValue))
+ isinstance(prior_default.value_type, st.SampleValue))
prior_0 = st.StochasticTensor(
distributions.Normal(mu=mu, sigma=sigma),
- dist_value_type=st.SampleAndReshapeValue())
- self.assertTrue(isinstance(prior_0.value_type, st.SampleAndReshapeValue))
+ dist_value_type=st.SampleValue())
+ self.assertTrue(isinstance(prior_0.value_type, st.SampleValue))
- with st.value_type(st.SampleAndReshapeValue()):
+ with st.value_type(st.SampleValue()):
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
- self.assertTrue(isinstance(prior.value_type, st.SampleAndReshapeValue))
+ self.assertTrue(isinstance(prior.value_type, st.SampleValue))
likelihood = st.StochasticTensor(
distributions.Normal(mu=prior, sigma=sigma2))
self.assertTrue(
- isinstance(likelihood.value_type, st.SampleAndReshapeValue))
+ isinstance(likelihood.value_type, st.SampleValue))
coll = tf.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
self.assertEqual(coll, [prior_default, prior_0, prior, likelihood])
@@ -87,15 +87,14 @@ class StochasticTensorTest(tf.test.TestCase):
self.assertAllEqual(prior_mean_val, mu)
self.assertAllEqual(prior_mean_val, prior_value_val)
- def testSampleAndReshapeValue(self):
+ def testSampleValueScalar(self):
with self.test_session() as sess:
mu = [[0.0, -1.0, 1.0], [0.0, -1.0, 1.0]]
sigma = tf.constant([[1.1, 1.2, 1.3], [1.1, 1.2, 1.3]])
- with st.value_type(st.SampleAndReshapeValue()):
+ with st.value_type(st.SampleValue()):
prior_single = st.StochasticTensor(
- distributions.Normal(
- mu=mu, sigma=sigma))
+ distributions.Normal(mu=mu, sigma=sigma))
prior_single_value = prior_single.value()
self.assertEqual(prior_single_value.get_shape(), (2, 3))
@@ -103,22 +102,7 @@ class StochasticTensorTest(tf.test.TestCase):
prior_single_value_val = sess.run([prior_single_value])[0]
self.assertEqual(prior_single_value_val.shape, (2, 3))
- with st.value_type(st.SampleAndReshapeValue(n=2)):
- prior_double = st.StochasticTensor(
- distributions.Normal(mu=mu, sigma=sigma))
-
- prior_double_value = prior_double.value()
- self.assertEqual(prior_double_value.get_shape(), (4, 3))
-
- prior_double_value_val = sess.run([prior_double_value])[0]
- self.assertEqual(prior_double_value_val.shape, (4, 3))
-
- def testSampleValue(self):
- with self.test_session() as sess:
- mu = [[0.0, -1.0, 1.0], [0.0, -1.0, 1.0]]
- sigma = tf.constant([[1.1, 1.2, 1.3], [1.1, 1.2, 1.3]])
-
- with st.value_type(st.SampleValue()):
+ with st.value_type(st.SampleValue(1)):
prior_single = st.StochasticTensor(
distributions.Normal(mu=mu, sigma=sigma))
self.assertTrue(isinstance(prior_single.value_type, st.SampleValue))
@@ -129,7 +113,7 @@ class StochasticTensorTest(tf.test.TestCase):
prior_single_value_val = sess.run([prior_single_value])[0]
self.assertEqual(prior_single_value_val.shape, (1, 2, 3))
- with st.value_type(st.SampleValue(n=2)):
+ with st.value_type(st.SampleValue(2)):
prior_double = st.StochasticTensor(
distributions.Normal(mu=mu, sigma=sigma))
@@ -182,7 +166,7 @@ class ValueTypeTest(tf.test.TestCase):
def testValueType(self):
type_mean = st.MeanValue()
- type_reshape = st.SampleAndReshapeValue()
+ type_reshape = st.SampleValue()
type_full = st.SampleValue()
with st.value_type(type_mean):
self.assertEqual(st.get_current_value_type(), type_mean)
diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_tensor.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_tensor.py
index eaee3344e5..e52c81740d 100644
--- a/tensorflow/contrib/bayesflow/python/ops/stochastic_tensor.py
+++ b/tensorflow/contrib/bayesflow/python/ops/stochastic_tensor.py
@@ -31,7 +31,6 @@ both continuous and discrete stochastic nodes.
@@MeanValue
@@SampleValue
-@@SampleAndReshapeValue
@@value_type
@@get_current_value_type
@@ -51,7 +50,6 @@ import six
from tensorflow.contrib import distributions
from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators as sge
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
STOCHASTIC_TENSOR_COLLECTION = "_stochastic_tensor_collection_"
@@ -122,8 +120,7 @@ ops.register_tensor_conversion_function(
class _StochasticValueType(object):
"""Interface for the ValueType classes.
- This is the base class for MeanValue, SampleValue, SampleAndReshapeValue,
- and their descendants.
+ This is the base class for MeanValue, SampleValue, and their descendants.
"""
def pushed_above(self, unused_value_type):
@@ -155,89 +152,53 @@ class MeanValue(_StochasticValueType):
class SampleValue(_StochasticValueType):
- """Draw n samples along a new outer dimension.
+ """Draw samples, possibly adding new outer dimensions along the way.
- This ValueType draws `n` samples from StochasticTensors run within its
- context, increasing the rank by one along a new outer dimension.
+ This ValueType draws samples from StochasticTensors run within its
+ context, increasing the rank according to the requested shape.
- Example:
+ Examples:
```python
mu = tf.zeros((2,3))
sigma = tf.ones((2, 3))
- with sg.value_type(sg.SampleValue(n=4)):
+ with sg.value_type(sg.SampleValue()):
st = sg.StochasticTensor(
distributions.Normal, mu=mu, sigma=sigma)
- # draws 4 samples each with shape (2, 3) and concatenates
- assertEqual(st.value().get_shape(), (4, 2, 3))
+ # draws 1 sample and does not reshape
+ assertEqual(st.value().get_shape(), (2, 3))
```
- """
-
- def __init__(self, n=1, stop_gradient=False):
- """Sample `n` times and concatenate along a new outer dimension.
-
- Args:
- n: A python integer or int32 tensor. The number of samples to take.
- stop_gradient: If `True`, StochasticTensors' values are wrapped in
- `stop_gradient`, to avoid backpropagation through.
- """
- self._n = n
- self._stop_gradient = stop_gradient
-
- @property
- def n(self):
- return self._n
-
- @property
- def stop_gradient(self):
- return self._stop_gradient
-
-
-class SampleAndReshapeValue(_StochasticValueType):
- """Ask the StochasticTensor for n samples and reshape the result.
-
- Sampling from a StochasticTensor increases the rank of the value by 1
- (because each sample represents a new outer dimension).
-
- This ValueType requests `n` samples from StochasticTensors run within its
- context that the outer two dimensions are reshaped to intermix the samples
- with the outermost (usually batch) dimension.
-
- Example:
```python
- # mu and sigma are both shaped (2, 3)
- mu = [[0.0, -1.0, 1.0], [0.0, -1.0, 1.0]]
- sigma = tf.constant([[1.1, 1.2, 1.3], [1.1, 1.2, 1.3]])
-
- with sg.value_type(sg.SampleAndReshapeValue(n=2)):
+ mu = tf.zeros((2,3))
+ sigma = tf.ones((2, 3))
+ with sg.value_type(sg.SampleValue(4)):
st = sg.StochasticTensor(
- distributions.Normal, mu=mu, sigma=sigma)
-
- # sample(2) creates a (2, 2, 3) tensor, and the two outermost dimensions
- # are reshaped into one: the final value is a (4, 3) tensor.
- st_value = st.value()
- assertEqual(st_value.get_shape(), (4, 3))
-
- st_value_val = sess.run([st_value])[0] # or e.g. run([tf.identity(st)])[0]
- assertEqual(st_value_val.shape, (4, 3))
+ distributions.Normal, mu=mu, sigma=sigma)
+ # draws 4 samples each with shape (2, 3) and concatenates
+ assertEqual(st.value().get_shape(), (4, 2, 3))
```
"""
- def __init__(self, n=1, stop_gradient=False):
- """Sample `n` times and reshape the outer 2 axes so rank does not change.
+ def __init__(self, shape=(), stop_gradient=False):
+ """Sample according to shape.
+
+ For the given StochasticTensor `st` using this value type,
+ the shape of `st.value()` will match that of
+ `st.distribution.sample(shape)`.
Args:
- n: A python integer or int32 tensor. The number of samples to take.
+ shape: A shape tuple or int32 tensor. The sample shape.
+ Default is a scalar: take one sample and do not change the size.
stop_gradient: If `True`, StochasticTensors' values are wrapped in
`stop_gradient`, to avoid backpropagation through.
"""
- self._n = n
+ self._shape = shape
self._stop_gradient = stop_gradient
@property
- def n(self):
- return self._n
+ def shape(self):
+ return self._shape
@property
def stop_gradient(self):
@@ -267,7 +228,7 @@ def value_type(dist_value_type):
in a `stop_gradients` call to disable any possible backpropagation.
Args:
- dist_value_type: An instance of `MeanValue`, `SampleAndReshapeValue`, or
+ dist_value_type: An instance of `MeanValue`, `SampleValue`, or
any other stochastic value type.
Yields:
@@ -317,7 +278,7 @@ class StochasticTensor(BaseStochasticTensor):
`StochasticTensor` is backed by the `dist` distribution and its `value`
method will return the same value each time it is called. What `value` is
returned is controlled by the `dist_value_type` (defaults to
- `SampleAndReshapeValue`).
+ `SampleValue`).
Some distributions' sample functions are not differentiable (e.g. a sample
from a discrete distribution like a Bernoulli) and so to differentiate
@@ -356,7 +317,7 @@ class StochasticTensor(BaseStochasticTensor):
try:
self._value_type = get_current_value_type()
except NoValueTypeSetError:
- self._value_type = SampleAndReshapeValue()
+ self._value_type = SampleValue()
else:
# We want to enforce a value type here, but use the value_type()
# context manager to enforce some error checking.
@@ -388,26 +349,7 @@ class StochasticTensor(BaseStochasticTensor):
if isinstance(self._value_type, MeanValue):
value_tensor = self._dist.mean()
elif isinstance(self._value_type, SampleValue):
- value_tensor = self._dist.sample(self._value_type.n)
- elif isinstance(self._value_type, SampleAndReshapeValue):
- if self._value_type.n == 1:
- value_tensor = self._dist.sample()
- else:
- samples = self._dist.sample(self._value_type.n)
- samples_shape = array_ops.shape(samples)
- samples_static_shape = samples.get_shape()
- new_batch_size = samples_shape[0] * samples_shape[1]
- value_tensor = array_ops.reshape(
- samples, array_ops.concat(0, ([new_batch_size], samples_shape[2:])))
- if samples_static_shape.ndims is not None:
- # Update the static shape for shape inference purposes
- shape_list = samples_static_shape.as_list()
- new_shape = tensor_shape.vector(
- shape_list[0] * shape_list[1]
- if shape_list[0] is not None and shape_list[1] is not None
- else None)
- new_shape = new_shape.concatenate(samples_static_shape[2:])
- value_tensor.set_shape(new_shape)
+ value_tensor = self._dist.sample(self._value_type.shape)
else:
raise TypeError(
"Unrecognized Distribution Value Type: %s", self._value_type)
@@ -462,7 +404,6 @@ class StochasticTensor(BaseStochasticTensor):
with ops.name_scope(self.name, values=[final_loss]):
with ops.name_scope(name):
if (self._value_type.stop_gradient or
- isinstance(self._value_type, SampleAndReshapeValue) or
isinstance(self._value_type, SampleValue)):
return self._loss_fn(self, self._value, final_loss)
elif isinstance(self._value_type, MeanValue):
@@ -530,7 +471,6 @@ __all__ = [
"ObservedStochasticTensor",
"MeanValue",
"SampleValue",
- "SampleAndReshapeValue",
"value_type",
"get_current_value_type",
]