diff options
author | 2016-07-20 16:32:32 -0800 | |
---|---|---|
committer | 2016-07-20 17:47:08 -0700 | |
commit | 7573d54aa1b0181f4855d12ea063251cc995d630 (patch) | |
tree | a191386bf31960886b5a16d9e34445b51d35f087 | |
parent | 81b26083e15d4f79c4033343840c08d09b97ec56 (diff) |
Add stochastic_tensors arg to surrogate_loss to specify which STs get loss terms
Change: 128009334
-rw-r--r-- | tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py | 22 | ||||
-rw-r--r-- | tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py | 18 |
2 files changed, 35 insertions, 5 deletions
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py index eebefbad28..d42da11c77 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py @@ -331,5 +331,27 @@ class TestSurrogateLosses(tf.test.TestCase): loss_fn=None) self.assertEqual(None, dt.loss(tf.constant([2.0]))) + def testExplicitStochasticTensors(self): + with self.test_session() as sess: + mu = tf.constant([0.0, 0.1, 0.2]) + sigma = tf.constant([1.1, 1.2, 1.3]) + with sg.value_type(sg.SampleAndReshapeValue()): + dt1 = sg.DistributionTensor(NormalNotParam, mu=mu, sigma=sigma) + dt2 = sg.DistributionTensor(NormalNotParam, mu=mu, sigma=sigma) + loss = tf.square(tf.identity(dt1)) + 10. + dt2 + + sl_all = sg.surrogate_loss([loss]) + sl_dt1 = sg.surrogate_loss([loss], stochastic_tensors=[dt1]) + sl_dt2 = sg.surrogate_loss([loss], stochastic_tensors=[dt2]) + + dt1_term = dt1.distribution.log_pdf(dt1) * loss + dt2_term = dt2.distribution.log_pdf(dt2) * loss + + self.assertAllClose(*sess.run( + [sl_all, sum([loss, dt1_term, dt2_term])])) + self.assertAllClose(*sess.run([sl_dt1, sum([loss, dt1_term])])) + self.assertAllClose(*sess.run([sl_dt2, sum([loss, dt2_term])])) + + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py index 7686162944..ca87f71182 100644 --- a/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py +++ b/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py @@ -493,11 +493,13 @@ class DistributionTensor(StochasticTensor): self._value_type) -def _stochastic_dependencies_map(fixed_losses): +def _stochastic_dependencies_map(fixed_losses, stochastic_tensors=None): """Map stochastic tensors to the fixed losses that depend on them. Args: - fixed_losses: a list of Tensors. + fixed_losses: a list of `Tensor`s. + stochastic_tensors: a list of `StochasticTensor`s to map to fixed losses. + If `None`, all `StochasticTensor`s in the graph will be used. Returns: A dict `dependencies` that maps `StochasticTensor` objects to subsets of @@ -506,7 +508,7 @@ def _stochastic_dependencies_map(fixed_losses): If `loss in dependencies[st]`, for some `loss` in `fixed_losses` then there is a direct path from `st.value()` to `loss` in the graph. """ - stoch_value_collection = ops.get_collection( + stoch_value_collection = stochastic_tensors or ops.get_collection( STOCHASTIC_TENSOR_COLLECTION) if not stoch_value_collection: @@ -530,7 +532,9 @@ def _stochastic_dependencies_map(fixed_losses): return stoch_dependencies_map -def surrogate_loss(sample_losses, name="SurrogateLoss"): +def surrogate_loss(sample_losses, + stochastic_tensors=None, + name="SurrogateLoss"): """Surrogate loss for stochastic graphs. This function will call `loss_fn` on each `StochasticTensor` @@ -543,6 +547,9 @@ def surrogate_loss(sample_losses, name="SurrogateLoss"): sample_losses: a list or tuple of final losses. Each loss should be per example in the batch (and possibly per sample); that is, it should have dimensionality of 1 or greater. All losses should have the same shape. + stochastic_tensors: a list of `StochasticTensor`s to add loss terms for. + If None, defaults to all `StochasticTensor`s in the graph upstream of + the `Tensor`s in `sample_losses`. name: the name with which to prepend created ops. Returns: @@ -568,7 +575,8 @@ def surrogate_loss(sample_losses, name="SurrogateLoss"): loss) fixed_losses.append(array_ops.stop_gradient(loss)) - stoch_dependencies_map = _stochastic_dependencies_map(fixed_losses) + stoch_dependencies_map = _stochastic_dependencies_map( + fixed_losses, stochastic_tensors=stochastic_tensors) if not stoch_dependencies_map: logging.warn( "No collection of Stochastic Tensors found for current graph.") |