aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-20 16:32:32 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-20 17:47:08 -0700
commit7573d54aa1b0181f4855d12ea063251cc995d630 (patch)
treea191386bf31960886b5a16d9e34445b51d35f087
parent81b26083e15d4f79c4033343840c08d09b97ec56 (diff)
Add stochastic_tensors arg to surrogate_loss to specify which STs get loss terms
Change: 128009334
-rw-r--r--tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py22
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py18
2 files changed, 35 insertions, 5 deletions
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py
index eebefbad28..d42da11c77 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py
@@ -331,5 +331,27 @@ class TestSurrogateLosses(tf.test.TestCase):
loss_fn=None)
self.assertEqual(None, dt.loss(tf.constant([2.0])))
+ def testExplicitStochasticTensors(self):
+ with self.test_session() as sess:
+ mu = tf.constant([0.0, 0.1, 0.2])
+ sigma = tf.constant([1.1, 1.2, 1.3])
+ with sg.value_type(sg.SampleAndReshapeValue()):
+ dt1 = sg.DistributionTensor(NormalNotParam, mu=mu, sigma=sigma)
+ dt2 = sg.DistributionTensor(NormalNotParam, mu=mu, sigma=sigma)
+ loss = tf.square(tf.identity(dt1)) + 10. + dt2
+
+ sl_all = sg.surrogate_loss([loss])
+ sl_dt1 = sg.surrogate_loss([loss], stochastic_tensors=[dt1])
+ sl_dt2 = sg.surrogate_loss([loss], stochastic_tensors=[dt2])
+
+ dt1_term = dt1.distribution.log_pdf(dt1) * loss
+ dt2_term = dt2.distribution.log_pdf(dt2) * loss
+
+ self.assertAllClose(*sess.run(
+ [sl_all, sum([loss, dt1_term, dt2_term])]))
+ self.assertAllClose(*sess.run([sl_dt1, sum([loss, dt1_term])]))
+ self.assertAllClose(*sess.run([sl_dt2, sum([loss, dt2_term])]))
+
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py
index 7686162944..ca87f71182 100644
--- a/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py
+++ b/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py
@@ -493,11 +493,13 @@ class DistributionTensor(StochasticTensor):
self._value_type)
-def _stochastic_dependencies_map(fixed_losses):
+def _stochastic_dependencies_map(fixed_losses, stochastic_tensors=None):
"""Map stochastic tensors to the fixed losses that depend on them.
Args:
- fixed_losses: a list of Tensors.
+ fixed_losses: a list of `Tensor`s.
+ stochastic_tensors: a list of `StochasticTensor`s to map to fixed losses.
+ If `None`, all `StochasticTensor`s in the graph will be used.
Returns:
A dict `dependencies` that maps `StochasticTensor` objects to subsets of
@@ -506,7 +508,7 @@ def _stochastic_dependencies_map(fixed_losses):
If `loss in dependencies[st]`, for some `loss` in `fixed_losses` then there
is a direct path from `st.value()` to `loss` in the graph.
"""
- stoch_value_collection = ops.get_collection(
+ stoch_value_collection = stochastic_tensors or ops.get_collection(
STOCHASTIC_TENSOR_COLLECTION)
if not stoch_value_collection:
@@ -530,7 +532,9 @@ def _stochastic_dependencies_map(fixed_losses):
return stoch_dependencies_map
-def surrogate_loss(sample_losses, name="SurrogateLoss"):
+def surrogate_loss(sample_losses,
+ stochastic_tensors=None,
+ name="SurrogateLoss"):
"""Surrogate loss for stochastic graphs.
This function will call `loss_fn` on each `StochasticTensor`
@@ -543,6 +547,9 @@ def surrogate_loss(sample_losses, name="SurrogateLoss"):
sample_losses: a list or tuple of final losses. Each loss should be per
example in the batch (and possibly per sample); that is, it should have
dimensionality of 1 or greater. All losses should have the same shape.
+ stochastic_tensors: a list of `StochasticTensor`s to add loss terms for.
+ If None, defaults to all `StochasticTensor`s in the graph upstream of
+ the `Tensor`s in `sample_losses`.
name: the name with which to prepend created ops.
Returns:
@@ -568,7 +575,8 @@ def surrogate_loss(sample_losses, name="SurrogateLoss"):
loss)
fixed_losses.append(array_ops.stop_gradient(loss))
- stoch_dependencies_map = _stochastic_dependencies_map(fixed_losses)
+ stoch_dependencies_map = _stochastic_dependencies_map(
+ fixed_losses, stochastic_tensors=stochastic_tensors)
if not stoch_dependencies_map:
logging.warn(
"No collection of Stochastic Tensors found for current graph.")