diff options
-rw-r--r-- | tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py index 2260b6b0b0..652e47ae07 100644 --- a/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py +++ b/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py @@ -123,6 +123,11 @@ ops.register_tensor_conversion_function( class _StochasticValueType(object): + """Interface for the ValueType classes. + + This is the base class for MeanValue, SampleValue, SampleAndReshapeValue, + and their descendants. + """ def pushed_above(self, unused_value_type): pass @@ -130,6 +135,9 @@ class _StochasticValueType(object): def popped_above(self, unused_value_type): pass + def declare_inputs(self, unused_stochastic_tensor, unused_inputs_dict): + pass + @abc.abstractproperty def stop_gradient(self): """Whether the value should be wrapped in stop_gradient. @@ -310,6 +318,8 @@ class DistributionTensor(StochasticTensor): else: self._value_type = get_current_value_type() + self._value_type.declare_inputs(self, dist_args) + with ops.op_scope(dist_args.values(), name, "DistributionTensor") as scope: self._name = scope self._dist = dist_cls(**dist_args) |