aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py')
-rw-r--r--tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py19
1 files changed, 18 insertions, 1 deletions
diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py
index 70813fb217..41258edd90 100644
--- a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py
+++ b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py
@@ -72,7 +72,8 @@ class ConstrainedMinimizationProblem(object):
else:
proxy_constraints_shape = self.proxy_constraints.get_shape()
- if (constraints_shape is None or proxy_constraints_shape is None or
+ if (constraints_shape.ndims is None or
+ proxy_constraints_shape.ndims is None or
any([ii is None for ii in constraints_shape.as_list()]) or
any([ii is None for ii in proxy_constraints_shape.as_list()])):
raise ValueError(
@@ -121,3 +122,19 @@ class ConstrainedMinimizationProblem(object):
A tensor of proxy constraint functions.
"""
return None
+
+ # This is a property, instead of an abstract property, since it doesn't need
+ # to be overridden: if pre_train_ops returns None, then there are no ops to
+ # run before train_op.
+ @property
+ def pre_train_ops(self):
+ """Returns a list of `Operation`s to run before the train_op.
+
+ When a `ConstrainedOptimizer` creates a train_op (in `minimize`
+ `minimize_unconstrained`, or `minimize_constrained`), it will include these
+ ops before the main training step.
+
+ Returns:
+ A list of `Operation`s.
+ """
+ return None