aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py
index 2260b6b0b0..652e47ae07 100644
--- a/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py
+++ b/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py
@@ -123,6 +123,11 @@ ops.register_tensor_conversion_function(
class _StochasticValueType(object):
+ """Interface for the ValueType classes.
+
+ This is the base class for MeanValue, SampleValue, SampleAndReshapeValue,
+ and their descendants.
+ """
def pushed_above(self, unused_value_type):
pass
@@ -130,6 +135,9 @@ class _StochasticValueType(object):
def popped_above(self, unused_value_type):
pass
+ def declare_inputs(self, unused_stochastic_tensor, unused_inputs_dict):
+ pass
+
@abc.abstractproperty
def stop_gradient(self):
"""Whether the value should be wrapped in stop_gradient.
@@ -310,6 +318,8 @@ class DistributionTensor(StochasticTensor):
else:
self._value_type = get_current_value_type()
+ self._value_type.declare_inputs(self, dist_args)
+
with ops.op_scope(dist_args.values(), name, "DistributionTensor") as scope:
self._name = scope
self._dist = dist_cls(**dist_args)