diff options
author | Andrew Cotter <acotter@google.com> | 2018-04-23 15:57:02 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-23 16:01:56 -0700 |
commit | ff15c81e2b92ef8fb47bb15790cffd18377a4ef2 (patch) | |
tree | 5c806cfb8155cba0b7eae9ea137f62e9777e73e6 /tensorflow/contrib/constrained_optimization | |
parent | bb4a80c92105426ccf20a98c4291a1a3f8499b54 (diff) |
This is a library for performing constrained optimization. It defines two interfaces: ConstrainedMinimizationProblem, which specifies a constrained optimization problem, and ConstrainedOptimizer, which is slightly different from a tf.train.Optimizer, mostly due to the fact that it is meant to optimize ConstrainedMinimizationProblems. In addition to these two interfaces, three ConstrainedOptimizer implementations are included, as well as helper functions which, given a set of candidate solutions, heuristically find the best candidate (to the constrained problem), or the best distribution over candidates.
For more details, please see our arXiv paper: "https://arxiv.org/abs/1804.06500".
PiperOrigin-RevId: 193999550
Diffstat (limited to 'tensorflow/contrib/constrained_optimization')
12 files changed, 2598 insertions, 0 deletions
diff --git a/tensorflow/contrib/constrained_optimization/BUILD b/tensorflow/contrib/constrained_optimization/BUILD new file mode 100644 index 0000000000..619153df67 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/BUILD @@ -0,0 +1,91 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +# Transitive dependencies of this target will be included in the pip package. +py_library( + name = "constrained_optimization_pip", + deps = [ + ":constrained_optimization", + ":test_util", + ], +) + +py_library( + name = "constrained_optimization", + srcs = [ + "__init__.py", + "python/candidates.py", + "python/constrained_minimization_problem.py", + "python/constrained_optimizer.py", + "python/external_regret_optimizer.py", + "python/swap_regret_optimizer.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:standard_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_test( + name = "candidates_test", + srcs = ["python/candidates_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":constrained_optimization", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + +# NOTE: This library can't be "testonly" since it needs to be included in the +# pip package. +py_library( + name = "test_util", + srcs = ["python/test_util.py"], + srcs_version = "PY2AND3", + deps = [ + ":constrained_optimization", + "//tensorflow/python:dtypes", + "//tensorflow/python:standard_ops", + ], +) + +py_test( + name = "external_regret_optimizer_test", + srcs = ["python/external_regret_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":constrained_optimization", + ":test_util", + "//tensorflow/python:client_testlib", + "//tensorflow/python:standard_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + +py_test( + name = "swap_regret_optimizer_test", + srcs = ["python/swap_regret_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":constrained_optimization", + ":test_util", + "//tensorflow/python:client_testlib", + "//tensorflow/python:standard_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/constrained_optimization/README.md b/tensorflow/contrib/constrained_optimization/README.md new file mode 100644 index 0000000000..c65a150464 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/README.md @@ -0,0 +1,345 @@ +<!-- TODO(acotter): Add usage example of non-convex optimization and stochastic classification. --> + +# ConstrainedOptimization (TFCO) + +TFCO is a library for optimizing inequality-constrained problems in TensorFlow. +Both the objective function and the constraints are represented as Tensors, +giving users the maximum amount of flexibility in specifying their optimization +problems. + +This flexibility makes optimization considerably more difficult: on a non-convex +problem, if one uses the "standard" approach of introducing a Lagrange +multiplier for each constraint, and then jointly maximizing over the Lagrange +multipliers and minimizing over the model parameters, then a stable stationary +point might not even *exist*. Hence, in some cases, oscillation, instead of +convergence, is inevitable. + +Thankfully, it turns out that even if, over the course of optimization, no +*particular* iterate does a good job of minimizing the objective while +satisfying the constraints, the *sequence* of iterates, on average, usually +will. This observation suggests the following approach: at training time, we'll +periodically snapshot the model state during optimization; then, at evaluation +time, each time we're given a new example to evaluate, we'll sample one of the +saved snapshots uniformly at random, and apply it to the example. This +*stochastic model* will generally perform well, both with respect to the +objective function, and the constraints. + +In fact, we can do better: it's possible to post-process the set of snapshots to +find a distribution over at most $$m+1$$ snapshots, where $$m$$ is the number of +constraints, that will be at least as good (and will usually be much better) +than the (much larger) uniform distribution described above. If you're unable or +unwilling to use a stochastic model at all, then you can instead use a heuristic +to choose the single best snapshot. + +For full details, motivation, and theoretical results on the approach taken by +this library, please refer to: + +> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex +> Constrained Optimization". +> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + +which will be referred to as [CoJiSr18] throughout the remainder of this +document. + +### Proxy Constraints + +Imagine that we want to constrain the recall of a binary classifier to be at +least 90%. Since the recall is proportional to the number of true positive +classifications, which itself is a sum of indicator functions, this constraint +is non-differentible, and therefore cannot be used in a problem that will be +optimized using a (stochastic) gradient-based algorithm. + +For this and similar problems, TFCO supports so-called *proxy constraints*, +which are (at least semi-differentiable) approximations of the original +constraints. For example, one could create a proxy recall function by replacing +the indicator functions with sigmoids. During optimization, each proxy +constraint function will be penalized, with the magnitude of the penalty being +chosen to satisfy the corresponding *original* (non-proxy) constraint. + +On a problem including proxy constraints—even a convex problem—the +Lagrangian approach discussed above isn't guaranteed to work. However, a +different algorithm, based on minimizing *swap regret*, does work. Aside from +this difference, the recommended procedure for optimizing a proxy-constrained +problem remains the same: periodically snapshot the model during optimization, +and then either find the best $$m+1$$-sized distribution, or heuristically +choose the single best snapshot. + +## Components + +* [constrained_minimization_problem](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py): + contains the `ConstrainedMinimizationProblem` interface. Your own + constrained optimization problems should be represented using + implementations of this interface. + +* [constrained_optimizer](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py): + contains the `ConstrainedOptimizer` interface, which is similar to (but + different from) `tf.train.Optimizer`, with the main difference being that + `ConstrainedOptimizer`s are given `ConstrainedMinimizationProblem`s to + optimize, and perform constrained optimization. + + * [external_regret_optimizer](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py): + contains the `AdditiveExternalRegretOptimizer` implementation, which is + a `ConstrainedOptimizer` implementing the Lagrangian approach discussed + above (with additive updates to the Lagrange multipliers). You should + use this optimizer for problems *without* proxy constraints. It may also + work for problems with proxy constraints, but we recommend using a swap + regret optimizer, instead. + + This optimizer is most similar to Algorithm 3 in Appendix C.3 of + [CoJiSr18], and is discussed in Section 3. The two differences are that + it uses proxy constraints (if they're provided) in the update of the + model parameters, and uses `tf.train.Optimizer`s, instead of SGD, for + the "inner" updates. + + * [swap_regret_optimizer](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py): + contains the `AdditiveSwapRegretOptimizer` and + `MultiplicativeSwapRegretOptimizer` implementations, which are + `ConstrainedOptimizer`s implementing the swap-regret minimization + approach mentioned above (with additive or multiplicative updates, + respectively, to the parameters associated with the + constraints—these parameters are not Lagrange multipliers, but + play a similar role). You should use one of these optimizers (we suggest + `MultiplicativeSwapRegretOptimizer`) for problems *with* proxy + constraints. + + The `MultiplicativeSwapRegretOptimizer` is most similar to Algorithm 2 + in Section 4 of [CoJiSr18], with the difference being that it uses + `tf.train.Optimizer`s, instead of SGD, for the "inner" updates. The + `AdditiveSwapRegretOptimizer` differs further in that it performs + additive (instead of multiplicative) updates of the stochastic matrix. + +* [candidates](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/candidates.py): + contains two functions, `find_best_candidate_distribution` and + `find_best_candidate_index`. Both of these functions are given a set of + candidate solutions to a constrained optimization problem, from which the + former finds the best distribution over at most $$m+1$$ candidates, and the + latter heuristically finds the single best candidate. As discussed above, + the set of candidates will typically be model snapshots saved periodically + during optimization. Both of these functions require that scipy be + installed. + + The `find_best_candidate_distribution` function implements the approach + described in Lemma 3 of [CoJiSr18], while `find_best_candidate_index` + implements the heuristic used for hyperparameter search in the experiments + of Section 5.2. + +## Convex Example with Proxy Constraints + +This is a simple example of recall-constrained optimization on simulated data: +we will try to find a classifier that minimizes the average hinge loss while +constraining recall to be at least 90%. + +We'll start with the required imports—notice the definition of `tfco`: + +```python +import math +import numpy as np +import tensorflow as tf + +tfco = tf.contrib.constrained_optimization +``` + +We'll now create an implementation of the `ConstrainedMinimizationProblem` class +for this problem. The constructor takes three parameters: a Tensor containing +the classification labels (0 or 1) for every training example, another Tensor +containing the model's predictions on every training example (sometimes called +the "logits"), and the lower bound on recall that will be enforced using a +constraint. + +This implementation will contain both constraints *and* proxy constraints: the +former represents the constraint that the true recall (defined in terms of the +*number* of true positives) be at least `recall_lower_bound`, while the latter +represents the same constraint, but on a hinge approximation of the recall. + +```python +class ExampleProblem(tfco.ConstrainedMinimizationProblem): + + def __init__(self, labels, predictions, recall_lower_bound): + self._labels = labels + self._predictions = predictions + self._recall_lower_bound = recall_lower_bound + # The number of positively-labeled examples. + self._positive_count = tf.reduce_sum(self._labels) + + @property + def objective(self): + return tf.losses.hinge_loss(labels=self._labels, logits=self._predictions) + + @property + def constraints(self): + true_positives = self._labels * tf.to_float(self._predictions > 0) + true_positive_count = tf.reduce_sum(true_positives) + recall = true_positive_count / self._positive_count + # The constraint is (recall >= self._recall_lower_bound), which we convert + # to (self._recall_lower_bound - recall <= 0) because + # ConstrainedMinimizationProblems must always provide their constraints in + # the form (tensor <= 0). + # + # The result of this function should be a tensor, with each element being + # a quantity that is constrained to be nonpositive. We only have one + # constraint, so we return a one-element tensor. + return self._recall_lower_bound - recall + + @property + def proxy_constraints(self): + # Use 1 - hinge since we're SUBTRACTING recall in the constraint function, + # and we want the proxy constraint function to be convex. + true_positives = self._labels * tf.minimum(1.0, self._predictions) + true_positive_count = tf.reduce_sum(true_positives) + recall = true_positive_count / self._positive_count + # Please see the corresponding comment in the constraints property. + return self._recall_lower_bound - recall +``` + +We'll now create a simple simulated dataset by sampling 1000 random +10-dimensional feature vectors from a Gaussian, finding their labels using a +random "ground truth" linear model, and then adding noise by randomly flipping +200 labels. + +```python +# Create a simulated 10-dimensional training dataset consisting of 1000 labeled +# examples, of which 800 are labeled correctly and 200 are mislabeled. +num_examples = 1000 +num_mislabeled_examples = 200 +dimension = 10 +# We will constrain the recall to be at least 90%. +recall_lower_bound = 0.9 + +# Create random "ground truth" parameters to a linear model. +ground_truth_weights = np.random.normal(size=dimension) / math.sqrt(dimension) +ground_truth_threshold = 0 + +# Generate a random set of features for each example. +features = np.random.normal(size=(num_examples, dimension)).astype( + np.float32) / math.sqrt(dimension) +# Compute the labels from these features given the ground truth linear model. +labels = (np.matmul(features, ground_truth_weights) > + ground_truth_threshold).astype(np.float32) +# Add noise by randomly flipping num_mislabeled_examples labels. +mislabeled_indices = np.random.choice( + num_examples, num_mislabeled_examples, replace=False) +labels[mislabeled_indices] = 1 - labels[mislabeled_indices] +``` + +We're now ready to construct our model, and the corresponding optimization +problem. We'll use a linear model of the form $$f(x) = w^T x - t$$, where $$w$$ +is the `weights`, and $$t$$ is the `threshold`. The `problem` variable will hold +an instance of the `ExampleProblem` class we created earlier. + +```python +# Create variables containing the model parameters. +weights = tf.Variable(tf.zeros(dimension), dtype=tf.float32, name="weights") +threshold = tf.Variable(0.0, dtype=tf.float32, name="threshold") + +# Create the optimization problem. +constant_labels = tf.constant(labels, dtype=tf.float32) +constant_features = tf.constant(features, dtype=tf.float32) +predictions = tf.tensordot(constant_features, weights, axes=(1, 0)) - threshold +problem = ExampleProblem( + labels=constant_labels, + predictions=predictions, + recall_lower_bound=recall_lower_bound, +) +``` + +We're almost ready to train our model, but first we'll create a couple of +functions to measure its performance. We're interested in two quantities: the +average hinge loss (which we seek to minimize), and the recall (which we +constrain). + +```python +def average_hinge_loss(labels, predictions): + num_examples, = np.shape(labels) + signed_labels = (labels * 2) - 1 + total_hinge_loss = np.sum(np.maximum(0.0, 1.0 - signed_labels * predictions)) + return total_hinge_loss / num_examples + +def recall(labels, predictions): + positive_count = np.sum(labels) + true_positives = labels * (predictions > 0) + true_positive_count = np.sum(true_positives) + return true_positive_count / positive_count +``` + +As was mentioned earlier, external regret optimizers suffice for problems +without proxy constraints, but swap regret optimizers are recommended for +problems *with* proxy constraints. Since this problem contains proxy +constraints, we use the `MultiplicativeSwapRegretOptimizer`. + +For this problem, the constraint is fairly easy to satisfy, so we can use the +same "inner" optimizer (an `AdagradOptimizer` with a learning rate of 1) for +optimization of both the model parameters (`weights` and `threshold`), and the +internal parameters associated with the constraints (these are the analogues of +the Lagrange multipliers used by the `MultiplicativeSwapRegretOptimizer`). For +more difficult problems, it will often be necessary to use different optimizers, +with different learning rates (presumably found via a hyperparameter search): to +accomplish this, pass *both* the `optimizer` and `constraint_optimizer` +parameters to `MultiplicativeSwapRegretOptimizer`'s constructor. + +Since this is a convex problem (both the objective and proxy constraint +functions are convex), we can just take the last iterate. Periodic snapshotting, +and the use of the `find_best_candidate_distribution` or +`find_best_candidate_index` functions, is generally only necessary for +non-convex problems (and even then, it isn't *always* necessary). + +```python +with tf.Session() as session: + optimizer = tfco.MultiplicativeSwapRegretOptimizer( + optimizer=tf.train.AdagradOptimizer(learning_rate=1.0)) + train_op = optimizer.minimize(problem) + + session.run(tf.global_variables_initializer()) + for ii in xrange(1000): + session.run(train_op) + + trained_weights, trained_threshold = session.run((weights, threshold)) + +trained_predictions = np.matmul(features, trained_weights) - trained_threshold +print("Constrained average hinge loss = %f" % average_hinge_loss( + labels, trained_predictions)) +print("Constrained recall = %f" % recall(labels, trained_predictions)) +``` + +Running the above code gives the following output (due to the randomness of the +dataset, you'll get a different result when you run it): + +```none +Constrained average hinge loss = 0.710019 +Constrained recall = 0.899811 +``` + +As we hoped, the recall is extremely close to 90%—and, thanks to the use +of proxy constraints, this is the *true* recall, not a hinge approximation. + +For comparison, let's try optimizing the same problem *without* the recall +constraint: + +```python +with tf.Session() as session: + optimizer = tf.train.AdagradOptimizer(learning_rate=1.0) + # For optimizing the unconstrained problem, we just minimize the "objective" + # portion of the minimization problem. + train_op = optimizer.minimize(problem.objective) + + session.run(tf.global_variables_initializer()) + for ii in xrange(1000): + session.run(train_op) + + trained_weights, trained_threshold = session.run((weights, threshold)) + +trained_predictions = np.matmul(features, trained_weights) - trained_threshold +print("Unconstrained average hinge loss = %f" % average_hinge_loss( + labels, trained_predictions)) +print("Unconstrained recall = %f" % recall(labels, trained_predictions)) +``` + +This code gives the following output (again, you'll get a different answer, +since the dataset is random): + +```none +Unconstrained average hinge loss = 0.627271 +Unconstrained recall = 0.793951 +``` + +Because there is no constraint, the unconstrained problem does a better job of +minimizing the average hinge loss, but naturally doesn't approach 90% recall. diff --git a/tensorflow/contrib/constrained_optimization/__init__.py b/tensorflow/contrib/constrained_optimization/__init__.py new file mode 100644 index 0000000000..1e49ba9f17 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/__init__.py @@ -0,0 +1,41 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A library for performing constrained optimization in TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import +from tensorflow.contrib.constrained_optimization.python.candidates import * +from tensorflow.contrib.constrained_optimization.python.constrained_minimization_problem import * +from tensorflow.contrib.constrained_optimization.python.constrained_optimizer import * +from tensorflow.contrib.constrained_optimization.python.external_regret_optimizer import * +from tensorflow.contrib.constrained_optimization.python.swap_regret_optimizer import * +# pylint: enable=wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "AdditiveExternalRegretOptimizer", + "AdditiveSwapRegretOptimizer", + "ConstrainedMinimizationProblem", + "ConstrainedOptimizer", + "find_best_candidate_distribution", + "find_best_candidate_index", + "MultiplicativeSwapRegretOptimizer", +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/constrained_optimization/python/candidates.py b/tensorflow/contrib/constrained_optimization/python/candidates.py new file mode 100644 index 0000000000..ac86a6741b --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/candidates.py @@ -0,0 +1,319 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Code for optimizing over a set of candidate solutions. + +The functions in this file deal with the constrained problem: + +> minimize f(w) +> s.t. g_i(w) <= 0 for all i in {0,1,...,m-1} + +Here, f(w) is the "objective function", and g_i(w) is the ith (of m) "constraint +function". Given the values of the objective and constraint functions for a set +of n "candidate solutions" {w_0,w_1,...,w_{n-1}} (for a total of n objective +function values, and n*m constraint function values), the +`find_best_candidate_distribution` function finds the best DISTRIBUTION over +these candidates, while `find_best_candidate_index' heuristically finds the +single best candidate. + +Both of these functions have dependencies on `scipy`, so if you want to call +them, then you must make sure that `scipy` is available. The imports are +performed inside the functions themselves, so if they're not actually called, +then `scipy` is not needed. + +For more specifics, please refer to: + +> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex +> Constrained Optimization". +> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + +The `find_best_candidate_distribution` function implements the approach +described in Lemma 3, while `find_best_candidate_index` implements the heuristic +used for hyperparameter search in the experiments of Section 5.2. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + + +def _find_best_candidate_distribution_helper(objective_vector, + constraints_matrix, + maximum_violation=0.0): + """Finds a distribution minimizing an objective subject to constraints. + + This function deals with the constrained problem: + + > minimize f(w) + > s.t. g_i(w) <= 0 for all i in {0,1,...,m-1} + + Here, f(w) is the "objective function", and g_i(w) is the ith (of m) + "constraint function". Given a set of n "candidate solutions" + {w_0,w_1,...,w_{n-1}}, this function finds a distribution over these n + candidates that, in expectation, minimizes the objective while violating + the constraints by no more than `maximum_violation`. If no such distribution + exists, it returns an error (using Go-style error reporting). + + The `objective_vector` parameter should be a numpy array with shape (n,), for + which objective_vector[i] = f(w_i). Likewise, `constraints_matrix` should be a + numpy array with shape (m,n), for which constraints_matrix[i,j] = g_i(w_j). + + This function will return a distribution for which at most m+1 probabilities, + and often fewer, are nonzero. + + Args: + objective_vector: numpy array of shape (n,), where n is the number of + "candidate solutions". Contains the objective function values. + constraints_matrix: numpy array of shape (m,n), where m is the number of + constraints and n is the number of "candidate solutions". Contains the + constraint violation magnitudes. + maximum_violation: nonnegative float, the maximum amount by which any + constraint may be violated, in expectation. + + Returns: + A pair (`result`, `message`), exactly one of which is None. If `message` is + None, then the `result` contains the optimal distribution as a numpy array + of shape (n,). If `result` is None, then `message` contains an error + message. + + Raises: + ValueError: If `objective_vector` and `constraints_matrix` have inconsistent + shapes, or if `maximum_violation` is negative. + ImportError: If we're unable to import `scipy.optimize`. + """ + if maximum_violation < 0.0: + raise ValueError("maximum_violation must be nonnegative") + + mm, nn = np.shape(constraints_matrix) + if (nn,) != np.shape(objective_vector): + raise ValueError( + "objective_vector must have shape (n,), and constraints_matrix (m, n)," + " where n is the number of candidates, and m is the number of " + "constraints") + + # We import scipy inline, instead of at the top of the file, so that a scipy + # dependency is only introduced if either find_best_candidate_distribution() + # or find_best_candidate_index() are actually called. + import scipy.optimize # pylint: disable=g-import-not-at-top + + # Feasibility (within maximum_violation) constraints. + a_ub = constraints_matrix + b_ub = np.full((mm, 1), maximum_violation) + # Sum-to-one constraint. + a_eq = np.ones((1, nn)) + b_eq = np.ones((1, 1)) + # Nonnegativity constraints. + bounds = (0, None) + + result = scipy.optimize.linprog( + objective_vector, + A_ub=a_ub, + b_ub=b_ub, + A_eq=a_eq, + b_eq=b_eq, + bounds=bounds) + # Go-style error reporting. We don't raise on error, since + # find_best_candidate_distribution() needs to handle the failure case, and we + # shouldn't use exceptions as flow-control. + if not result.success: + return (None, result.message) + else: + return (result.x, None) + + +def find_best_candidate_distribution(objective_vector, + constraints_matrix, + epsilon=0.0): + """Finds a distribution minimizing an objective subject to constraints. + + This function deals with the constrained problem: + + > minimize f(w) + > s.t. g_i(w) <= 0 for all i in {0,1,...,m-1} + + Here, f(w) is the "objective function", and g_i(w) is the ith (of m) + "constraint function". Given a set of n "candidate solutions" + {w_0,w_1,...,w_{n-1}}, this function finds a distribution over these n + candidates that, in expectation, minimizes the objective while violating + the constraints by the smallest possible amount (with the amount being found + via bisection search). + + The `objective_vector` parameter should be a numpy array with shape (n,), for + which objective_vector[i] = f(w_i). Likewise, `constraints_matrix` should be a + numpy array with shape (m,n), for which constraints_matrix[i,j] = g_i(w_j). + + This function will return a distribution for which at most m+1 probabilities, + and often fewer, are nonzero. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + This function implements the approach described in Lemma 3. + + Args: + objective_vector: numpy array of shape (n,), where n is the number of + "candidate solutions". Contains the objective function values. + constraints_matrix: numpy array of shape (m,n), where m is the number of + constraints and n is the number of "candidate solutions". Contains the + constraint violation magnitudes. + epsilon: nonnegative float, the threshold at which to terminate the binary + search while searching for the minimal expected constraint violation + magnitude. + + Returns: + The optimal distribution, as a numpy array of shape (n,). + + Raises: + ValueError: If `objective_vector` and `constraints_matrix` have inconsistent + shapes, or if `epsilon` is negative. + ImportError: If we're unable to import `scipy.optimize`. + """ + if epsilon < 0.0: + raise ValueError("epsilon must be nonnegative") + + # If there is a feasible solution (i.e. with maximum_violation=0), then that's + # what we'll return. + pp, _ = _find_best_candidate_distribution_helper(objective_vector, + constraints_matrix) + if pp is not None: + return pp + + # The bound is the minimum over all candidates, of the maximum per-candidate + # constraint violation. + lower = 0.0 + upper = np.min(np.amax(constraints_matrix, axis=0)) + best_pp, _ = _find_best_candidate_distribution_helper( + objective_vector, constraints_matrix, maximum_violation=upper) + assert best_pp is not None + + # Throughout this loop, a maximum_violation of "lower" is not achievable, + # but a maximum_violation of "upper" is achiveable. + while True: + middle = 0.5 * (lower + upper) + if (middle - lower <= epsilon) or (upper - middle <= epsilon): + break + else: + pp, _ = _find_best_candidate_distribution_helper( + objective_vector, constraints_matrix, maximum_violation=middle) + if pp is None: + lower = middle + else: + best_pp = pp + upper = middle + + return best_pp + + +def find_best_candidate_index(objective_vector, + constraints_matrix, + rank_objectives=False): + """Heuristically finds the best candidate solution to a constrained problem. + + This function deals with the constrained problem: + + > minimize f(w) + > s.t. g_i(w) <= 0 for all i in {0,1,...,m-1} + + Here, f(w) is the "objective function", and g_i(w) is the ith (of m) + "constraint function". Given a set of n "candidate solutions" + {w_0,w_1,...,w_{n-1}}, this function finds the "best" solution according + to the following heuristic: + + 1. Across all models, the ith constraint violations (i.e. max{0, g_i(0)}) + are ranked, as are the objectives (if rank_objectives=True). + 2. Each model is then associated its MAXIMUM rank across all m constraints + (and the objective, if rank_objectives=True). + 3. The model with the minimal maximum rank is then identified. Ties are + broken using the objective function value. + 4. The index of this "best" model is returned. + + The `objective_vector` parameter should be a numpy array with shape (n,), for + which objective_vector[i] = f(w_i). Likewise, `constraints_matrix` should be a + numpy array with shape (m,n), for which constraints_matrix[i,j] = g_i(w_j). + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + This function implements the heuristic used for hyperparameter search in the + experiments of Section 5.2. + + Args: + objective_vector: numpy array of shape (n,), where n is the number of + "candidate solutions". Contains the objective function values. + constraints_matrix: numpy array of shape (m,n), where m is the number of + constraints and n is the number of "candidate solutions". Contains the + constraint violation magnitudes. + rank_objectives: bool, whether the objective function values should be + included in the initial ranking step. If True, both the objective and + constraints will be ranked. If False, only the constraints will be ranked. + In either case, the objective function values will be used for + tiebreaking. + + Returns: + The index (in {0,1,...,n-1}) of the "best" model according to the above + heuristic. + + Raises: + ValueError: If `objective_vector` and `constraints_matrix` have inconsistent + shapes. + ImportError: If we're unable to import `scipy.stats`. + """ + mm, nn = np.shape(constraints_matrix) + if (nn,) != np.shape(objective_vector): + raise ValueError( + "objective_vector must have shape (n,), and constraints_matrix (m, n)," + " where n is the number of candidates, and m is the number of " + "constraints") + + # We import scipy inline, instead of at the top of the file, so that a scipy + # dependency is only introduced if either find_best_candidate_distribution() + # or find_best_candidate_index() are actually called. + import scipy.stats # pylint: disable=g-import-not-at-top + + if rank_objectives: + maximum_ranks = scipy.stats.rankdata(objective_vector, method="min") + else: + maximum_ranks = np.zeros(nn, dtype=np.int64) + for ii in xrange(mm): + # Take the maximum of the constraint functions with zero, since we want to + # rank the magnitude of constraint *violations*. If the constraint is + # satisfied, then we don't care how much it's satisfied by (as a result, we + # we expect all models satisfying a constraint to be tied at rank 1). + ranks = scipy.stats.rankdata( + np.maximum(0.0, constraints_matrix[ii, :]), method="min") + maximum_ranks = np.maximum(maximum_ranks, ranks) + + best_index = None + best_rank = float("Inf") + best_objective = float("Inf") + for ii in xrange(nn): + if maximum_ranks[ii] < best_rank: + best_index = ii + best_rank = maximum_ranks[ii] + best_objective = objective_vector[ii] + elif (maximum_ranks[ii] == best_rank) and (objective_vector[ii] <= + best_objective): + best_index = ii + best_objective = objective_vector[ii] + + return best_index diff --git a/tensorflow/contrib/constrained_optimization/python/candidates_test.py b/tensorflow/contrib/constrained_optimization/python/candidates_test.py new file mode 100644 index 0000000000..a4c49d48bc --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/candidates_test.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for constrained_optimization.python.candidates.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.constrained_optimization.python import candidates +from tensorflow.python.platform import test + + +class CandidatesTest(test.TestCase): + + def test_inconsistent_shapes_for_best_distribution(self): + """An error is raised when parameters have inconsistent shapes.""" + objective_vector = np.array([1, 2, 3]) + constraints_matrix = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]) + with self.assertRaises(ValueError): + _ = candidates.find_best_candidate_distribution(objective_vector, + constraints_matrix) + + def test_inconsistent_shapes_for_best_index(self): + """An error is raised when parameters have inconsistent shapes.""" + objective_vector = np.array([1, 2, 3]) + constraints_matrix = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]) + with self.assertRaises(ValueError): + _ = candidates.find_best_candidate_index(objective_vector, + constraints_matrix) + + def test_best_distribution(self): + """Distribution should match known solution.""" + objective_vector = np.array( + [0.03053309, -0.06667082, 0.88355145, 0.46529806]) + constraints_matrix = np.array( + [[-0.60164551, 0.36676229, 0.7856454, -0.8441711], + [0.00371592, -0.16392108, -0.59778071, -0.56908492]]) + distribution = candidates.find_best_candidate_distribution( + objective_vector, constraints_matrix) + # Verify that the solution is a probability distribution. + self.assertTrue(np.all(distribution >= 0)) + self.assertAlmostEqual(np.sum(distribution), 1.0) + # Verify that the solution satisfies the constraints. + maximum_constraint_violation = np.amax( + np.dot(constraints_matrix, distribution)) + self.assertLessEqual(maximum_constraint_violation, 0) + # Verify that the solution matches that which we expect. + expected_distribution = np.array([0.37872711, 0.62127289, 0, 0]) + self.assertAllClose(expected_distribution, distribution, rtol=0, atol=1e-6) + + def test_best_index_rank_objectives_true(self): + """Index should match known solution.""" + # Objective ranks = [2, 1, 4, 3]. + objective_vector = np.array( + [0.03053309, -0.06667082, 0.88355145, 0.46529806]) + # Constraint ranks = [[1, 3, 4, 1], [4, 1, 1, 1]]. + constraints_matrix = np.array( + [[-0.60164551, 0.36676229, 0.7856454, -0.8441711], + [0.00371592, -0.16392108, -0.59778071, -0.56908492]]) + # Maximum ranks = [4, 3, 4, 3]. + index = candidates.find_best_candidate_index( + objective_vector, constraints_matrix, rank_objectives=True) + self.assertEqual(1, index) + + def test_best_index_rank_objectives_false(self): + """Index should match known solution.""" + # Objective ranks = [2, 1, 4, 3]. + objective_vector = np.array( + [0.03053309, -0.06667082, 0.88355145, 0.46529806]) + # Constraint ranks = [[1, 3, 4, 1], [4, 1, 1, 1]]. + constraints_matrix = np.array( + [[-0.60164551, 0.36676229, 0.7856454, -0.8441711], + [0.00371592, -0.16392108, -0.59778071, -0.56908492]]) + # Maximum ranks = [4, 3, 4, 1]. + index = candidates.find_best_candidate_index( + objective_vector, constraints_matrix, rank_objectives=False) + self.assertEqual(3, index) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py new file mode 100644 index 0000000000..70813fb217 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines abstract class for `ConstrainedMinimizationProblem`s. + +A ConstrainedMinimizationProblem consists of an objective function to minimize, +and a set of constraint functions that are constrained to be nonpositive. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + + +@six.add_metaclass(abc.ABCMeta) +class ConstrainedMinimizationProblem(object): + """Abstract class representing a `ConstrainedMinimizationProblem`. + + A ConstrainedMinimizationProblem consists of an objective function to + minimize, and a set of constraint functions that are constrained to be + nonpositive. + + In addition to the constraint functions, there may (optionally) be proxy + constraint functions: a ConstrainedOptimizer will attempt to penalize these + proxy constraint functions so as to satisfy the (non-proxy) constraints. Proxy + constraints could be used if the constraints functions are difficult or + impossible to optimize (e.g. if they're piecewise constant), in which case the + proxy constraints should be some approximation of the original constraints + that is well-enough behaved to permit successful optimization. + """ + + @abc.abstractproperty + def objective(self): + """Returns the objective function. + + Returns: + A 0d tensor that should be minimized. + """ + pass + + @property + def num_constraints(self): + """Returns the number of constraints. + + Returns: + An int containing the number of constraints. + + Raises: + ValueError: If the constraints (or proxy_constraints, if present) do not + have fully-known shapes, OR if proxy_constraints are present, and the + shapes of constraints and proxy_constraints are fully-known, but they're + different. + """ + constraints_shape = self.constraints.get_shape() + if self.proxy_constraints is None: + proxy_constraints_shape = constraints_shape + else: + proxy_constraints_shape = self.proxy_constraints.get_shape() + + if (constraints_shape is None or proxy_constraints_shape is None or + any([ii is None for ii in constraints_shape.as_list()]) or + any([ii is None for ii in proxy_constraints_shape.as_list()])): + raise ValueError( + "constraints and proxy_constraints must have fully-known shapes") + if constraints_shape != proxy_constraints_shape: + raise ValueError( + "constraints and proxy_constraints must have the same shape") + + size = 1 + for ii in constraints_shape.as_list(): + size *= ii + return int(size) + + @abc.abstractproperty + def constraints(self): + """Returns the vector of constraint functions. + + Letting g_i be the ith element of the constraints vector, the ith constraint + will be g_i <= 0. + + Returns: + A tensor of constraint functions. + """ + pass + + # This is a property, instead of an abstract property, since it doesn't need + # to be overridden: if proxy_constraints returns None, then there are no + # proxy constraints. + @property + def proxy_constraints(self): + """Returns the optional vector of proxy constraint functions. + + The difference between `constraints` and `proxy_constraints` is that, when + proxy constraints are present, the `constraints` are merely EVALUATED during + optimization, whereas the `proxy_constraints` are DIFFERENTIATED. If there + are no proxy constraints, then the `constraints` are both evaluated and + differentiated. + + For example, if we want to impose constraints on step functions, then we + could use these functions for `constraints`. However, because a step + function has zero gradient almost everywhere, we can't differentiate these + functions, so we would take `proxy_constraints` to be some differentiable + approximation of `constraints`. + + Returns: + A tensor of proxy constraint functions. + """ + return None diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py new file mode 100644 index 0000000000..8055545366 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py @@ -0,0 +1,208 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines base class for `ConstrainedOptimizer`s.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.training import optimizer as train_optimizer + + +@six.add_metaclass(abc.ABCMeta) +class ConstrainedOptimizer(object): + """Base class representing a constrained optimizer. + + A ConstrainedOptimizer wraps a tf.train.Optimizer (or more than one), and + applies it to a ConstrainedMinimizationProblem. Unlike a tf.train.Optimizer, + which takes a tensor to minimize as a parameter to its minimize() method, a + constrained optimizer instead takes a ConstrainedMinimizationProblem. + """ + + def __init__(self, optimizer): + """Constructs a new `ConstrainedOptimizer`. + + Args: + optimizer: tf.train.Optimizer, used to optimize the + ConstraintedMinimizationProblem. + + Returns: + A new `ConstrainedOptimizer`. + """ + self._optimizer = optimizer + + @property + def optimizer(self): + """Returns the `tf.train.Optimizer` used for optimization.""" + return self._optimizer + + def minimize_unconstrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Op` for minimizing the unconstrained problem. + + Unlike `minimize_constrained`, this function ignores the `constraints` (and + `proxy_constraints`) portion of the minimization problem entirely, and only + minimizes `objective`. + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. + name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + TensorFlow Op. + """ + return self.optimizer.minimize( + minimization_problem.objective, + global_step=global_step, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + name=name, + grad_loss=grad_loss) + + @abc.abstractmethod + def minimize_constrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Op` for minimizing the constrained problem. + + Unlike `minimize_unconstrained`, this function attempts to find a solution + that minimizes the `objective` portion of the minimization problem while + satisfying the `constraints` portion. + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. + name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + TensorFlow Op. + """ + pass + + def minimize(self, + minimization_problem, + unconstrained_steps=None, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Op` for minimizing the constrained problem. + + This method combines the functionality of `minimize_unconstrained` and + `minimize_constrained`. If global_step < unconstrained_steps, it will + perform an unconstrained update, and if global_step >= unconstrained_steps, + it will perform a constrained update. + + The reason for this functionality is that it may be best to initialize the + constrained optimizer with an approximate optimum of the unconstrained + problem. + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + unconstrained_steps: int, number of steps for which we should perform + unconstrained updates, before transitioning to constrained updates. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. + name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + TensorFlow Op. + + Raises: + ValueError: If unconstrained_steps is provided, but global_step is not. + """ + + def unconstrained_fn(): + """Returns an `Op` for minimizing the unconstrained problem.""" + return self.minimize_unconstrained( + minimization_problem=minimization_problem, + global_step=global_step, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + name=name, + grad_loss=grad_loss) + + def constrained_fn(): + """Returns an `Op` for minimizing the constrained problem.""" + return self.minimize_constrained( + minimization_problem=minimization_problem, + global_step=global_step, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + name=name, + grad_loss=grad_loss) + + if unconstrained_steps is not None: + if global_step is None: + raise ValueError( + "global_step cannot be None if unconstrained_steps is provided") + unconstrained_steps_tensor = ops.convert_to_tensor(unconstrained_steps) + dtype = unconstrained_steps_tensor.dtype + return control_flow_ops.cond( + standard_ops.cast(global_step, dtype) < unconstrained_steps_tensor, + true_fn=unconstrained_fn, + false_fn=constrained_fn) + else: + return constrained_fn() diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py new file mode 100644 index 0000000000..01c6e4f08a --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py @@ -0,0 +1,375 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines `AdditiveExternalRegretOptimizer`. + +This optimizer minimizes a `ConstrainedMinimizationProblem` by introducing +Lagrange multipliers, and using `tf.train.Optimizer`s to jointly optimize over +the model parameters and Lagrange multipliers. + +For the purposes of constrained optimization, at least in theory, +external-regret minimization suffices if the `ConstrainedMinimizationProblem` +we're optimizing doesn't have any `proxy_constraints`, while swap-regret +minimization should be used if `proxy_constraints` are present. + +For more specifics, please refer to: + +> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex +> Constrained Optimization". +> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + +The formulation used by the AdditiveExternalRegretOptimizer--which is simply the +usual Lagrangian formulation--can be found in Definition 1, and is discussed in +Section 3. This optimizer is most similar to Algorithm 3 in Appendix C.3, with +the two differences being that it uses proxy constraints (if they're provided) +in the update of the model parameters, and uses `tf.train.Optimizer`s, instead +of SGD, for the "inner" updates. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.contrib.constrained_optimization.python import constrained_optimizer + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import optimizer as train_optimizer + + +def _project_multipliers_wrt_euclidean_norm(multipliers, radius): + """Projects its argument onto the feasible region. + + The feasible region is the set of all vectors with nonnegative elements that + sum to at most `radius`. + + Args: + multipliers: 1d tensor, the Lagrange multipliers to project. + radius: float, the radius of the feasible region. + + Returns: + The 1d tensor that results from projecting `multipliers` onto the feasible + region w.r.t. the Euclidean norm. + + Raises: + ValueError: if the `multipliers` tensor does not have a fully-known shape, + or is not one-dimensional. + """ + multipliers_shape = multipliers.get_shape() + if multipliers_shape is None: + raise ValueError("multipliers must have known shape") + if multipliers_shape.ndims != 1: + raise ValueError( + "multipliers must be one dimensional (instead is %d-dimensional)" % + multipliers_shape.ndims) + dimension = multipliers_shape[0].value + if dimension is None: + raise ValueError("multipliers must have fully-known shape") + + def while_loop_condition(iteration, multipliers, inactive, old_inactive): + """Returns false if the while loop should terminate.""" + del multipliers # Needed by the body, but not the condition. + not_done = (iteration < dimension) + not_converged = standard_ops.reduce_any( + standard_ops.not_equal(inactive, old_inactive)) + return standard_ops.logical_and(not_done, not_converged) + + def while_loop_body(iteration, multipliers, inactive, old_inactive): + """Performs one iteration of the projection.""" + del old_inactive # Needed by the condition, but not the body. + iteration += 1 + scale = standard_ops.minimum( + 0.0, + (radius - standard_ops.reduce_sum(multipliers)) / standard_ops.maximum( + 1.0, standard_ops.reduce_sum(inactive))) + multipliers += scale * inactive + new_inactive = standard_ops.to_float(multipliers > 0) + multipliers *= new_inactive + return (iteration, multipliers, new_inactive, inactive) + + iteration = standard_ops.constant(0) + inactive = standard_ops.ones_like(multipliers) + + # We actually want a do-while loop, so we explicitly call while_loop_body() + # once before tf.while_loop(). + iteration, multipliers, inactive, old_inactive = while_loop_body( + iteration, multipliers, inactive, inactive) + iteration, multipliers, inactive, old_inactive = control_flow_ops.while_loop( + while_loop_condition, + while_loop_body, + loop_vars=(iteration, multipliers, inactive, old_inactive), + name="euclidean_projection") + + return multipliers + + +@six.add_metaclass(abc.ABCMeta) +class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): + """Base class representing an `_ExternalRegretOptimizer`. + + This class contains most of the logic for performing constrained + optimization, minimizing external regret for the constraints player. What it + *doesn't* do is keep track of the internal state (the Lagrange multipliers). + Instead, the state is accessed via the _initial_state(), + _lagrange_multipliers(), _constraint_grad_and_var() and _projection_op() + methods. + + The reason for this is that we want to make it easy to implement different + representations of the internal state. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + The formulation used by `_ExternalRegretOptimizer`s--which is simply the usual + Lagrangian formulation--can be found in Definition 1, and is discussed in + Section 3. Such optimizers are most similar to Algorithm 3 in Appendix C.3. + """ + + def __init__(self, optimizer, constraint_optimizer=None): + """Constructs a new `_ExternalRegretOptimizer`. + + The difference between `optimizer` and `constraint_optimizer` (if the latter + is provided) is that the former is used for learning the model parameters, + while the latter us used for the Lagrange multipliers. If no + `constraint_optimizer` is provided, then `optimizer` is used for both. + + Args: + optimizer: tf.train.Optimizer, used to optimize the objective and + proxy_constraints portion of the ConstrainedMinimizationProblem. If + constraint_optimizer is not provided, this will also be used to optimize + the Lagrange multipliers. + constraint_optimizer: optional tf.train.Optimizer, used to optimize the + Lagrange multipliers. + + Returns: + A new `_ExternalRegretOptimizer`. + """ + super(_ExternalRegretOptimizer, self).__init__(optimizer=optimizer) + self._constraint_optimizer = constraint_optimizer + + @property + def constraint_optimizer(self): + """Returns the `tf.train.Optimizer` used for the Lagrange multipliers.""" + return self._constraint_optimizer + + @abc.abstractmethod + def _initial_state(self, num_constraints): + pass + + @abc.abstractmethod + def _lagrange_multipliers(self, state): + pass + + @abc.abstractmethod + def _constraint_grad_and_var(self, state, gradient): + pass + + @abc.abstractmethod + def _projection_op(self, state, name=None): + pass + + def minimize_constrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Op` for minimizing the constrained problem. + + The `optimizer` constructor parameter will be used to update the model + parameters, while the Lagrange multipliers will be updated using + `constrained_optimizer` (if provided) or `optimizer` (if not). + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. + name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + TensorFlow Op. + """ + objective = minimization_problem.objective + + constraints = minimization_problem.constraints + proxy_constraints = minimization_problem.proxy_constraints + if proxy_constraints is None: + proxy_constraints = constraints + # Flatten both constraints tensors to 1d. + num_constraints = minimization_problem.num_constraints + constraints = standard_ops.reshape(constraints, shape=(num_constraints,)) + proxy_constraints = standard_ops.reshape( + proxy_constraints, shape=(num_constraints,)) + + # We use a lambda to initialize the state so that, if this function call is + # inside the scope of a tf.control_dependencies() block, the dependencies + # will not be applied to the initializer. + state = standard_ops.Variable( + lambda: self._initial_state(num_constraints), + trainable=False, + name="external_regret_optimizer_state") + + multipliers = self._lagrange_multipliers(state) + loss = ( + objective + standard_ops.tensordot(multipliers, proxy_constraints, 1)) + multipliers_gradient = constraints + + update_ops = [] + if self.constraint_optimizer is None: + # If we don't have a separate constraint_optimizer, then we use + # self._optimizer for both the update of the model parameters, and that of + # the internal state. + grads_and_vars = self.optimizer.compute_gradients( + loss, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + grads_and_vars.append( + self._constraint_grad_and_var(state, multipliers_gradient)) + update_ops.append( + self.optimizer.apply_gradients(grads_and_vars, name="update")) + else: + # If we have a separate constraint_optimizer, then we use self._optimizer + # for the update of the model parameters, and self._constraint_optimizer + # for that of the internal state. + grads_and_vars = self.optimizer.compute_gradients( + loss, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + multiplier_grads_and_vars = [ + self._constraint_grad_and_var(state, multipliers_gradient) + ] + + gradients = [ + gradient for gradient, _ in grads_and_vars + multiplier_grads_and_vars + if gradient is not None + ] + with ops.control_dependencies(gradients): + update_ops.append( + self.optimizer.apply_gradients(grads_and_vars, name="update")) + update_ops.append( + self.constraint_optimizer.apply_gradients( + multiplier_grads_and_vars, name="optimizer_state_update")) + + with ops.control_dependencies(update_ops): + if global_step is None: + # If we don't have a global step, just project, and we're done. + return self._projection_op(state, name=name) + else: + # If we have a global step, then we need to increment it in addition to + # projecting. + projection_op = self._projection_op(state, name="project") + with ops.colocate_with(global_step): + global_step_op = state_ops.assign_add( + global_step, 1, name="global_step_increment") + return control_flow_ops.group(projection_op, global_step_op, name=name) + + +class AdditiveExternalRegretOptimizer(_ExternalRegretOptimizer): + """A `ConstrainedOptimizer` based on external-regret minimization. + + This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly + minimize over the model parameters, and maximize over Lagrange multipliers, + with the latter maximization using additive updates and an algorithm that + minimizes external regret. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + The formulation used by this optimizer--which is simply the usual Lagrangian + formulation--can be found in Definition 1, and is discussed in Section 3. It + is most similar to Algorithm 3 in Appendix C.3, with the two differences being + that it uses proxy constraints (if they're provided) in the update of the + model parameters, and uses `tf.train.Optimizer`s, instead of SGD, for the + "inner" updates. + """ + + def __init__(self, + optimizer, + constraint_optimizer=None, + maximum_multiplier_radius=None): + """Constructs a new `AdditiveExternalRegretOptimizer`. + + Args: + optimizer: tf.train.Optimizer, used to optimize the objective and + proxy_constraints portion of ConstrainedMinimizationProblem. If + constraint_optimizer is not provided, this will also be used to optimize + the Lagrange multipliers. + constraint_optimizer: optional tf.train.Optimizer, used to optimize the + Lagrange multipliers. + maximum_multiplier_radius: float, an optional upper bound to impose on the + sum of the Lagrange multipliers. + + Returns: + A new `AdditiveExternalRegretOptimizer`. + + Raises: + ValueError: If the maximum_multiplier_radius parameter is nonpositive. + """ + super(AdditiveExternalRegretOptimizer, self).__init__( + optimizer=optimizer, constraint_optimizer=constraint_optimizer) + + if maximum_multiplier_radius and (maximum_multiplier_radius <= 0.0): + raise ValueError("maximum_multiplier_radius must be strictly positive") + + self._maximum_multiplier_radius = maximum_multiplier_radius + + def _initial_state(self, num_constraints): + # For an AdditiveExternalRegretOptimizer, the internal state is simply a + # tensor of Lagrange multipliers with shape (m,), where m is the number of + # constraints. + return standard_ops.zeros((num_constraints,), dtype=dtypes.float32) + + def _lagrange_multipliers(self, state): + return state + + def _constraint_grad_and_var(self, state, gradient): + # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True? + return (-gradient, state) + + def _projection_op(self, state, name=None): + with ops.colocate_with(state): + if self._maximum_multiplier_radius: + projected_multipliers = _project_multipliers_wrt_euclidean_norm( + state, self._maximum_multiplier_radius) + else: + projected_multipliers = standard_ops.maximum(state, 0.0) + return state_ops.assign(state, projected_multipliers, name=name) diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py new file mode 100644 index 0000000000..9b4bf62710 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py @@ -0,0 +1,136 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for constrained_optimization.python.external_regret_optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.constrained_optimization.python import external_regret_optimizer +from tensorflow.contrib.constrained_optimization.python import test_util + +from tensorflow.python.ops import standard_ops +from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent + + +class AdditiveExternalRegretOptimizerWrapper( + external_regret_optimizer.AdditiveExternalRegretOptimizer): + """Testing wrapper class around AdditiveExternalRegretOptimizer. + + This class is identical to AdditiveExternalRegretOptimizer, except that it + caches the internal optimization state when _lagrange_multipliers() is called, + so that we can test that the Lagrange multipliers take on their expected + values. + """ + + def __init__(self, + optimizer, + constraint_optimizer=None, + maximum_multiplier_radius=None): + """Same as AdditiveExternalRegretOptimizer.__init__.""" + super(AdditiveExternalRegretOptimizerWrapper, self).__init__( + optimizer=optimizer, + constraint_optimizer=constraint_optimizer, + maximum_multiplier_radius=maximum_multiplier_radius) + self._cached_lagrange_multipliers = None + + @property + def lagrange_multipliers(self): + """Returns the cached Lagrange multipliers.""" + return self._cached_lagrange_multipliers + + def _lagrange_multipliers(self, state): + """Caches the internal state for testing.""" + self._cached_lagrange_multipliers = super( + AdditiveExternalRegretOptimizerWrapper, + self)._lagrange_multipliers(state) + return self._cached_lagrange_multipliers + + +class ExternalRegretOptimizerTest(test.TestCase): + + def test_project_multipliers_wrt_euclidean_norm(self): + """Tests Euclidean projection routine on some known values.""" + multipliers1 = standard_ops.constant([-0.1, -0.6, -0.3]) + expected_projected_multipliers1 = np.array([0.0, 0.0, 0.0]) + + multipliers2 = standard_ops.constant([-0.1, 0.6, 0.3]) + expected_projected_multipliers2 = np.array([0.0, 0.6, 0.3]) + + multipliers3 = standard_ops.constant([0.4, 0.7, -0.2, 0.5, 0.1]) + expected_projected_multipliers3 = np.array([0.2, 0.5, 0.0, 0.3, 0.0]) + + with self.test_session() as session: + projected_multipliers1 = session.run( + external_regret_optimizer._project_multipliers_wrt_euclidean_norm( + multipliers1, 1.0)) + projected_multipliers2 = session.run( + external_regret_optimizer._project_multipliers_wrt_euclidean_norm( + multipliers2, 1.0)) + projected_multipliers3 = session.run( + external_regret_optimizer._project_multipliers_wrt_euclidean_norm( + multipliers3, 1.0)) + + self.assertAllClose( + expected_projected_multipliers1, + projected_multipliers1, + rtol=0, + atol=1e-6) + self.assertAllClose( + expected_projected_multipliers2, + projected_multipliers2, + rtol=0, + atol=1e-6) + self.assertAllClose( + expected_projected_multipliers3, + projected_multipliers3, + rtol=0, + atol=1e-6) + + def test_additive_external_regret_optimizer(self): + """Tests that the Lagrange multipliers update as expected.""" + minimization_problem = test_util.ConstantMinimizationProblem( + np.array([0.6, -0.1, 0.4])) + optimizer = AdditiveExternalRegretOptimizerWrapper( + gradient_descent.GradientDescentOptimizer(1.0), + maximum_multiplier_radius=1.0) + train_op = optimizer.minimize_constrained(minimization_problem) + + expected_multipliers = [ + np.array([0.0, 0.0, 0.0]), + np.array([0.6, 0.0, 0.4]), + np.array([0.7, 0.0, 0.3]), + np.array([0.8, 0.0, 0.2]), + np.array([0.9, 0.0, 0.1]), + np.array([1.0, 0.0, 0.0]), + np.array([1.0, 0.0, 0.0]), + ] + + multipliers = [] + with self.test_session() as session: + session.run(standard_ops.global_variables_initializer()) + while len(multipliers) < len(expected_multipliers): + multipliers.append(session.run(optimizer.lagrange_multipliers)) + session.run(train_op) + + for expected, actual in zip(expected_multipliers, multipliers): + self.assertAllClose(expected, actual, rtol=0, atol=1e-6) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py new file mode 100644 index 0000000000..04014ab4ae --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py @@ -0,0 +1,595 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines `{Additive,Multiplicative}SwapRegretOptimizer`s. + +These optimizers minimize a `ConstrainedMinimizationProblem` by using a +swap-regret minimizing algorithm (either SGD or multiplicative weights) to learn +what weights should be associated with the objective function and constraints. +These algorithms do *not* use Lagrange multipliers, but the idea is similar. +The main differences between the formulation used here, and the standard +Lagrangian formulation, are that (i) the objective function is weighted, in +addition to the constraints, and (ii) we learn a matrix of weights, instead of a +vector. + +For the purposes of constrained optimization, at least in theory, +external-regret minimization suffices if the `ConstrainedMinimizationProblem` +we're optimizing doesn't have any `proxy_constraints`, while swap-regret +minimization should be used if `proxy_constraints` are present. + +For more specifics, please refer to: + +> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex +> Constrained Optimization". +> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + +The formulation used by both of the SwapRegretOptimizers can be found in +Definition 2, and is discussed in Section 4. The +`MultiplicativeSwapRegretOptimizer` is most similar to Algorithm 2 in Section 4, +with the difference being that it uses `tf.train.Optimizer`s, instead of SGD, +for the "inner" updates. The `AdditiveSwapRegretOptimizer` differs further in +that it performs additive (instead of multiplicative) updates of the stochastic +matrix. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import math + +import six + +from tensorflow.contrib.constrained_optimization.python import constrained_optimizer + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import optimizer as train_optimizer + + +def _maximal_eigenvector_power_method(matrix, + epsilon=1e-6, + maximum_iterations=100): + """Returns the maximal right-eigenvector of `matrix` using the power method. + + Args: + matrix: 2D Tensor, the matrix of which we will find the maximal + right-eigenvector. + epsilon: nonnegative float, if two iterations of the power method differ (in + L2 norm) by no more than epsilon, we will terminate. + maximum_iterations: nonnegative int, if we perform this many iterations, we + will terminate. + + Result: + The maximal right-eigenvector of `matrix`. + + Raises: + ValueError: If the epsilon or maximum_iterations parameters violate their + bounds. + """ + if epsilon <= 0.0: + raise ValueError("epsilon must be strictly positive") + if maximum_iterations <= 0: + raise ValueError("maximum_iterations must be strictly positive") + + def while_loop_condition(iteration, eigenvector, old_eigenvector): + """Returns false if the while loop should terminate.""" + not_done = (iteration < maximum_iterations) + not_converged = (standard_ops.norm(eigenvector - old_eigenvector) > epsilon) + return standard_ops.logical_and(not_done, not_converged) + + def while_loop_body(iteration, eigenvector, old_eigenvector): + """Performs one iteration of the power method.""" + del old_eigenvector # Needed by the condition, but not the body. + iteration += 1 + # We need to use tf.matmul() and tf.expand_dims(), instead of + # tf.tensordot(), since the former will infer the shape of the result, while + # the latter will not (tf.while_loop() needs the shapes). + new_eigenvector = standard_ops.matmul( + matrix, standard_ops.expand_dims(eigenvector, 1))[:, 0] + new_eigenvector /= standard_ops.norm(new_eigenvector) + return (iteration, new_eigenvector, eigenvector) + + iteration = standard_ops.constant(0) + eigenvector = standard_ops.ones_like(matrix[:, 0]) + eigenvector /= standard_ops.norm(eigenvector) + + # We actually want a do-while loop, so we explicitly call while_loop_body() + # once before tf.while_loop(). + iteration, eigenvector, old_eigenvector = while_loop_body( + iteration, eigenvector, eigenvector) + iteration, eigenvector, old_eigenvector = control_flow_ops.while_loop( + while_loop_condition, + while_loop_body, + loop_vars=(iteration, eigenvector, old_eigenvector), + name="power_method") + + return eigenvector + + +def _project_stochastic_matrix_wrt_euclidean_norm(matrix): + """Projects its argument onto the set of left-stochastic matrices. + + This algorithm is O(n^3) at worst, where `matrix` is n*n. It can be done in + O(n^2 * log(n)) time by sorting each column (and maybe better with a different + algorithm), but the algorithm implemented here is easier to implement in + TensorFlow. + + Args: + matrix: 2d square tensor, the matrix to project. + + Returns: + The 2d square tensor that results from projecting `matrix` onto the set of + left-stochastic matrices w.r.t. the Euclidean norm applied column-wise + (i.e. the Frobenius norm). + + Raises: + ValueError: if the `matrix` tensor does not have a fully-known shape, or is + not two-dimensional and square. + """ + matrix_shape = matrix.get_shape() + if matrix_shape is None: + raise ValueError("matrix must have known shape") + if matrix_shape.ndims != 2: + raise ValueError( + "matrix must be two dimensional (instead is %d-dimensional)" % + matrix_shape.ndims) + if matrix_shape[0] != matrix_shape[1]: + raise ValueError("matrix must be be square (instead has shape (%d,%d))" % + (matrix_shape[0], matrix_shape[1])) + dimension = matrix_shape[0].value + if dimension is None: + raise ValueError("matrix must have fully-known shape") + + def while_loop_condition(iteration, matrix, inactive, old_inactive): + """Returns false if the while loop should terminate.""" + del matrix # Needed by the body, but not the condition. + not_done = (iteration < dimension) + not_converged = standard_ops.reduce_any( + standard_ops.not_equal(inactive, old_inactive)) + return standard_ops.logical_and(not_done, not_converged) + + def while_loop_body(iteration, matrix, inactive, old_inactive): + """Performs one iteration of the projection.""" + del old_inactive # Needed by the condition, but not the body. + iteration += 1 + scale = (1.0 - standard_ops.reduce_sum( + matrix, axis=0, keep_dims=True)) / standard_ops.maximum( + 1.0, standard_ops.reduce_sum(inactive, axis=0, keep_dims=True)) + matrix += scale * inactive + new_inactive = standard_ops.to_float(matrix > 0) + matrix *= new_inactive + return (iteration, matrix, new_inactive, inactive) + + iteration = standard_ops.constant(0) + inactive = standard_ops.ones_like(matrix) + + # We actually want a do-while loop, so we explicitly call while_loop_body() + # once before tf.while_loop(). + iteration, matrix, inactive, old_inactive = while_loop_body( + iteration, matrix, inactive, inactive) + iteration, matrix, inactive, old_inactive = control_flow_ops.while_loop( + while_loop_condition, + while_loop_body, + loop_vars=(iteration, matrix, inactive, old_inactive), + name="euclidean_projection") + + return matrix + + +def _project_log_stochastic_matrix_wrt_kl_divergence(log_matrix): + """Projects its argument onto the set of log-left-stochastic matrices. + + Args: + log_matrix: 2d square tensor, the element-wise logarithm of the matrix to + project. + + Returns: + The 2d square tensor that results from projecting exp(`matrix`) onto the set + of left-stochastic matrices w.r.t. the KL-divergence applied column-wise. + """ + + # For numerical reasons, make sure that the largest matrix element is zero + # before exponentiating. + log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keep_dims=True) + log_matrix -= standard_ops.log( + standard_ops.reduce_sum( + standard_ops.exp(log_matrix), axis=0, keep_dims=True)) + return log_matrix + + +@six.add_metaclass(abc.ABCMeta) +class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): + """Base class representing a `_SwapRegretOptimizer`. + + This class contains most of the logic for performing constrained optimization, + minimizing external regret for the constraints player. What it *doesn't* do is + keep track of the internal state (the stochastic matrix). Instead, the state + is accessed via the _initial_state(), _stochastic_matrix(), + _constraint_grad_and_var() and _projection_op() methods. + + The reason for this is that we want to make it easy to implement different + representations of the internal state. For example, for additive updates, it's + most natural to store the stochastic matrix directly, whereas for + multiplicative updates, it's most natural to store its element-wise logarithm. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + The formulation used by `_SwapRegretOptimizer`s can be found in Definition 2, + and is discussed in Section 4. Such optimizers are most similar to Algorithm + 2 in Section 4. Most notably, the internal state is a left-stochastic matrix + of shape (m+1,m+1), where m is the number of constraints. + """ + + def __init__(self, optimizer, constraint_optimizer=None): + """Constructs a new `_SwapRegretOptimizer`. + + The difference between `optimizer` and `constraint_optimizer` (if the latter + is provided) is that the former is used for learning the model parameters, + while the latter us used for the update to the constraint/objective weight + matrix (the analogue of Lagrange multipliers). If no `constraint_optimizer` + is provided, then `optimizer` is used for both. + + Args: + optimizer: tf.train.Optimizer, used to optimize the objective and + proxy_constraints portion of ConstrainedMinimizationProblem. If + constraint_optimizer is not provided, this will also be used to optimize + the Lagrange multiplier analogues. + constraint_optimizer: optional tf.train.Optimizer, used to optimize the + Lagrange multiplier analogues. + + Returns: + A new `_SwapRegretOptimizer`. + """ + super(_SwapRegretOptimizer, self).__init__(optimizer=optimizer) + self._constraint_optimizer = constraint_optimizer + + @property + def constraint_optimizer(self): + """Returns the `tf.train.Optimizer` used for the matrix.""" + return self._constraint_optimizer + + @abc.abstractmethod + def _initial_state(self, num_constraints): + pass + + @abc.abstractmethod + def _stochastic_matrix(self, state): + pass + + def _distribution(self, state): + distribution = _maximal_eigenvector_power_method( + self._stochastic_matrix(state)) + distribution = standard_ops.abs(distribution) + distribution /= standard_ops.reduce_sum(distribution) + return distribution + + @abc.abstractmethod + def _constraint_grad_and_var(self, state, gradient): + pass + + @abc.abstractmethod + def _projection_op(self, state, name=None): + pass + + def minimize_constrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Op` for minimizing the constrained problem. + + The `optimizer` constructor parameter will be used to update the model + parameters, while the constraint/objective weight matrix (the analogue of + Lagrange multipliers) will be updated using `constrained_optimizer` (if + provided) or `optimizer` (if not). Whether the matrix updates are additive + or multiplicative depends on the derived class. + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. + name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + TensorFlow Op. + """ + objective = minimization_problem.objective + + constraints = minimization_problem.constraints + proxy_constraints = minimization_problem.proxy_constraints + if proxy_constraints is None: + proxy_constraints = constraints + # Flatten both constraints tensors to 1d. + num_constraints = minimization_problem.num_constraints + constraints = standard_ops.reshape(constraints, shape=(num_constraints,)) + proxy_constraints = standard_ops.reshape( + proxy_constraints, shape=(num_constraints,)) + + # We use a lambda to initialize the state so that, if this function call is + # inside the scope of a tf.control_dependencies() block, the dependencies + # will not be applied to the initializer. + state = standard_ops.Variable( + lambda: self._initial_state(num_constraints), + trainable=False, + name="swap_regret_optimizer_state") + + zero_and_constraints = standard_ops.concat( + (standard_ops.zeros((1,)), constraints), axis=0) + objective_and_proxy_constraints = standard_ops.concat( + (standard_ops.expand_dims(objective, 0), proxy_constraints), axis=0) + + distribution = self._distribution(state) + loss = standard_ops.tensordot(distribution, objective_and_proxy_constraints, + 1) + matrix_gradient = standard_ops.matmul( + standard_ops.expand_dims(zero_and_constraints, 1), + standard_ops.expand_dims(distribution, 0)) + + update_ops = [] + if self.constraint_optimizer is None: + # If we don't have a separate constraint_optimizer, then we use + # self._optimizer for both the update of the model parameters, and that of + # the internal state. + grads_and_vars = self.optimizer.compute_gradients( + loss, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + grads_and_vars.append( + self._constraint_grad_and_var(state, matrix_gradient)) + update_ops.append( + self.optimizer.apply_gradients(grads_and_vars, name="update")) + else: + # If we have a separate constraint_optimizer, then we use self._optimizer + # for the update of the model parameters, and self._constraint_optimizer + # for that of the internal state. + grads_and_vars = self.optimizer.compute_gradients( + loss, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + matrix_grads_and_vars = [ + self._constraint_grad_and_var(state, matrix_gradient) + ] + + gradients = [ + gradient for gradient, _ in grads_and_vars + matrix_grads_and_vars + if gradient is not None + ] + with ops.control_dependencies(gradients): + update_ops.append( + self.optimizer.apply_gradients(grads_and_vars, name="update")) + update_ops.append( + self.constraint_optimizer.apply_gradients( + matrix_grads_and_vars, name="optimizer_state_update")) + + with ops.control_dependencies(update_ops): + if global_step is None: + # If we don't have a global step, just project, and we're done. + return self._projection_op(state, name=name) + else: + # If we have a global step, then we need to increment it in addition to + # projecting. + projection_op = self._projection_op(state, name="project") + with ops.colocate_with(global_step): + global_step_op = state_ops.assign_add( + global_step, 1, name="global_step_increment") + return control_flow_ops.group(projection_op, global_step_op, name=name) + + +class AdditiveSwapRegretOptimizer(_SwapRegretOptimizer): + """A `ConstrainedOptimizer` based on swap-regret minimization. + + This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly + minimize over the model parameters, and maximize over constraint/objective + weight matrix (the analogue of Lagrange multipliers), with the latter + maximization using additive updates and an algorithm that minimizes swap + regret. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + The formulation used by this optimizer can be found in Definition 2, and is + discussed in Section 4. It is most similar to Algorithm 2 in Section 4, with + the differences being that it uses `tf.train.Optimizer`s, instead of SGD, for + the "inner" updates, and performs additive (instead of multiplicative) updates + of the stochastic matrix. + """ + + def __init__(self, optimizer, constraint_optimizer=None): + """Constructs a new `AdditiveSwapRegretOptimizer`. + + Args: + optimizer: tf.train.Optimizer, used to optimize the objective and + proxy_constraints portion of ConstrainedMinimizationProblem. If + constraint_optimizer is not provided, this will also be used to optimize + the Lagrange multiplier analogues. + constraint_optimizer: optional tf.train.Optimizer, used to optimize the + Lagrange multiplier analogues. + + Returns: + A new `AdditiveSwapRegretOptimizer`. + """ + # TODO(acotter): add a parameter determining the initial values of the + # matrix elements (like initial_multiplier_radius in + # MultiplicativeSwapRegretOptimizer). + super(AdditiveSwapRegretOptimizer, self).__init__( + optimizer=optimizer, constraint_optimizer=constraint_optimizer) + + def _initial_state(self, num_constraints): + # For an AdditiveSwapRegretOptimizer, the internal state is a tensor of + # shape (m+1,m+1), where m is the number of constraints, representing a + # left-stochastic matrix. + dimension = num_constraints + 1 + # Initialize by putting all weight on the objective, and none on the + # constraints. + return standard_ops.concat( + (standard_ops.ones( + (1, dimension)), standard_ops.zeros((dimension - 1, dimension))), + axis=0) + + def _stochastic_matrix(self, state): + return state + + def _constraint_grad_and_var(self, state, gradient): + # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True? + return (-gradient, state) + + def _projection_op(self, state, name=None): + with ops.colocate_with(state): + return state_ops.assign( + state, + _project_stochastic_matrix_wrt_euclidean_norm(state), + name=name) + + +class MultiplicativeSwapRegretOptimizer(_SwapRegretOptimizer): + """A `ConstrainedOptimizer` based on swap-regret minimization. + + This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly + minimize over the model parameters, and maximize over constraint/objective + weight matrix (the analogue of Lagrange multipliers), with the latter + maximization using multiplicative updates and an algorithm that minimizes swap + regret. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + The formulation used by this optimizer can be found in Definition 2, and is + discussed in Section 4. It is most similar to Algorithm 2 in Section 4, with + the difference being that it uses `tf.train.Optimizer`s, instead of SGD, for + the "inner" updates. + """ + + def __init__(self, + optimizer, + constraint_optimizer=None, + minimum_multiplier_radius=1e-3, + initial_multiplier_radius=None): + """Constructs a new `MultiplicativeSwapRegretOptimizer`. + + Args: + optimizer: tf.train.Optimizer, used to optimize the objective and + proxy_constraints portion of ConstrainedMinimizationProblem. If + constraint_optimizer is not provided, this will also be used to optimize + the Lagrange multiplier analogues. + constraint_optimizer: optional tf.train.Optimizer, used to optimize the + Lagrange multiplier analogues. + minimum_multiplier_radius: float, each element of the matrix will be lower + bounded by `minimum_multiplier_radius` divided by one plus the number of + constraints. + initial_multiplier_radius: float, the initial value of each element of the + matrix associated with a constraint (i.e. excluding those elements + associated with the objective) will be `initial_multiplier_radius` + divided by one plus the number of constraints. Defaults to the value of + `minimum_multiplier_radius`. + + Returns: + A new `MultiplicativeSwapRegretOptimizer`. + + Raises: + ValueError: If the two radius parameters are inconsistent. + """ + super(MultiplicativeSwapRegretOptimizer, self).__init__( + optimizer=optimizer, constraint_optimizer=constraint_optimizer) + + if (minimum_multiplier_radius <= 0.0) or (minimum_multiplier_radius >= 1.0): + raise ValueError("minimum_multiplier_radius must be in the range (0,1)") + if initial_multiplier_radius is None: + initial_multiplier_radius = minimum_multiplier_radius + elif (initial_multiplier_radius < + minimum_multiplier_radius) or (minimum_multiplier_radius > 1.0): + raise ValueError("initial_multiplier_radius must be in the range " + "[minimum_multiplier_radius,1]") + + self._minimum_multiplier_radius = minimum_multiplier_radius + self._initial_multiplier_radius = initial_multiplier_radius + + def _initial_state(self, num_constraints): + # For a MultiplicativeSwapRegretOptimizer, the internal state is a tensor of + # shape (m+1,m+1), where m is the number of constraints, representing the + # element-wise logarithm of a left-stochastic matrix. + dimension = num_constraints + 1 + # Initialize by putting as much weight as possible on the objective, and as + # little as possible on the constraints. + log_initial_one = math.log(1.0 - (self._initial_multiplier_radius * + (dimension - 1) / (dimension))) + log_initial_zero = math.log(self._initial_multiplier_radius / dimension) + return standard_ops.concat( + (standard_ops.constant( + log_initial_one, dtype=dtypes.float32, shape=(1, dimension)), + standard_ops.constant( + log_initial_zero, + dtype=dtypes.float32, + shape=(dimension - 1, dimension))), + axis=0) + + def _stochastic_matrix(self, state): + return standard_ops.exp(state) + + def _constraint_grad_and_var(self, state, gradient): + # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True? + return (-gradient, state) + + def _projection_op(self, state, name=None): + with ops.colocate_with(state): + # Gets the dimension of the state (num_constraints + 1)--all of these + # assertions are of things that should be impossible, since the state + # passed into this method will have the same shape as that returned by + # _initial_state(). + state_shape = state.get_shape() + assert state_shape is not None + assert state_shape.ndims == 2 + assert state_shape[0] == state_shape[1] + dimension = state_shape[0].value + assert dimension is not None + + minimum_log_multiplier = standard_ops.log( + self._minimum_multiplier_radius / standard_ops.to_float(dimension)) + + return state_ops.assign( + state, + standard_ops.maximum( + _project_log_stochastic_matrix_wrt_kl_divergence(state), + minimum_log_multiplier), + name=name) diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py new file mode 100644 index 0000000000..34c4543dca --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py @@ -0,0 +1,212 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for constrained_optimization.python.swap_regret_optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.constrained_optimization.python import swap_regret_optimizer +from tensorflow.contrib.constrained_optimization.python import test_util + +from tensorflow.python.ops import standard_ops +from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent + + +class AdditiveSwapRegretOptimizerWrapper( + swap_regret_optimizer.AdditiveSwapRegretOptimizer): + """Testing wrapper class around AdditiveSwapRegretOptimizer. + + This class is identical to AdditiveSwapRegretOptimizer, except that it caches + the internal optimization state when _stochastic_matrix() is called, so that + we can test that the stochastic matrices take on their expected values. + """ + + def __init__(self, optimizer, constraint_optimizer=None): + """Same as AdditiveSwapRegretOptimizer.__init__().""" + super(AdditiveSwapRegretOptimizerWrapper, self).__init__( + optimizer=optimizer, constraint_optimizer=constraint_optimizer) + self._cached_stochastic_matrix = None + + @property + def stochastic_matrix(self): + """Returns the cached stochastic matrix.""" + return self._cached_stochastic_matrix + + def _stochastic_matrix(self, state): + """Caches the internal state for testing.""" + self._cached_stochastic_matrix = super(AdditiveSwapRegretOptimizerWrapper, + self)._stochastic_matrix(state) + return self._cached_stochastic_matrix + + +class MultiplicativeSwapRegretOptimizerWrapper( + swap_regret_optimizer.MultiplicativeSwapRegretOptimizer): + """Testing wrapper class around MultiplicativeSwapRegretOptimizer. + + This class is identical to MultiplicativeSwapRegretOptimizer, except that it + caches the internal optimization state when _stochastic_matrix() is called, so + that we can test that the stochastic matrices take on their expected values. + """ + + def __init__(self, + optimizer, + constraint_optimizer=None, + minimum_multiplier_radius=None, + initial_multiplier_radius=None): + """Same as MultiplicativeSwapRegretOptimizer.__init__().""" + super(MultiplicativeSwapRegretOptimizerWrapper, self).__init__( + optimizer=optimizer, + constraint_optimizer=constraint_optimizer, + minimum_multiplier_radius=1e-3, + initial_multiplier_radius=initial_multiplier_radius) + self._cached_stochastic_matrix = None + + @property + def stochastic_matrix(self): + """Returns the cached stochastic matrix.""" + return self._cached_stochastic_matrix + + def _stochastic_matrix(self, state): + """Caches the internal state for testing.""" + self._cached_stochastic_matrix = super( + MultiplicativeSwapRegretOptimizerWrapper, + self)._stochastic_matrix(state) + return self._cached_stochastic_matrix + + +class SwapRegretOptimizerTest(test.TestCase): + + def test_maximum_eigenvector_power_method(self): + """Tests power method routine on some known left-stochastic matrices.""" + matrix1 = np.matrix([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], [0.4, 0.3, 0.0]]) + matrix2 = np.matrix([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], [0.4, 0.5, 0.3]]) + + with self.test_session() as session: + eigenvector1 = session.run( + swap_regret_optimizer._maximal_eigenvector_power_method( + standard_ops.constant(matrix1))) + eigenvector2 = session.run( + swap_regret_optimizer._maximal_eigenvector_power_method( + standard_ops.constant(matrix2))) + + # Check that eigenvector1 and eigenvector2 are eigenvectors of matrix1 and + # matrix2 (respectively) with associated eigenvalue 1. + matrix_eigenvector1 = np.tensordot(matrix1, eigenvector1, axes=1) + matrix_eigenvector2 = np.tensordot(matrix2, eigenvector2, axes=1) + self.assertAllClose(eigenvector1, matrix_eigenvector1, rtol=0, atol=1e-6) + self.assertAllClose(eigenvector2, matrix_eigenvector2, rtol=0, atol=1e-6) + + def test_project_stochastic_matrix_wrt_euclidean_norm(self): + """Tests Euclidean projection routine on some known values.""" + matrix = standard_ops.constant([[-0.1, -0.1, 0.4], [-0.8, 0.4, 1.2], + [-0.3, 0.1, 0.2]]) + expected_projected_matrix = np.array([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], + [0.4, 0.3, 0.0]]) + + with self.test_session() as session: + projected_matrix = session.run( + swap_regret_optimizer._project_stochastic_matrix_wrt_euclidean_norm( + matrix)) + + self.assertAllClose( + expected_projected_matrix, projected_matrix, rtol=0, atol=1e-6) + + def test_project_log_stochastic_matrix_wrt_kl_divergence(self): + """Tests KL-divergence projection routine on some known values.""" + matrix = standard_ops.constant([[0.2, 0.8, 0.6], [0.1, 0.2, 1.5], + [0.2, 1.0, 0.9]]) + expected_projected_matrix = np.array([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], + [0.4, 0.5, 0.3]]) + + with self.test_session() as session: + projected_matrix = session.run( + standard_ops.exp( + swap_regret_optimizer. + _project_log_stochastic_matrix_wrt_kl_divergence( + standard_ops.log(matrix)))) + + self.assertAllClose( + expected_projected_matrix, projected_matrix, rtol=0, atol=1e-6) + + def test_additive_swap_regret_optimizer(self): + """Tests that the stochastic matrices update as expected.""" + minimization_problem = test_util.ConstantMinimizationProblem( + np.array([0.6, -0.1, 0.4])) + optimizer = AdditiveSwapRegretOptimizerWrapper( + gradient_descent.GradientDescentOptimizer(1.0)) + train_op = optimizer.minimize_constrained(minimization_problem) + + # Calculated using a numpy+python implementation of the algorithm. + expected_matrices = [ + np.array([[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]), + np.array([[0.66666667, 1.0, 1.0, 1.0], [0.26666667, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], [0.06666667, 0.0, 0.0, 0.0]]), + np.array([[0.41666667, 0.93333333, 1.0, + 0.98333333], [0.46666667, 0.05333333, 0.0, + 0.01333333], [0.0, 0.0, 0.0, 0.0], + [0.11666667, 0.01333333, 0.0, 0.00333333]]), + ] + + matrices = [] + with self.test_session() as session: + session.run(standard_ops.global_variables_initializer()) + while len(matrices) < len(expected_matrices): + matrices.append(session.run(optimizer.stochastic_matrix)) + session.run(train_op) + + for expected, actual in zip(expected_matrices, matrices): + self.assertAllClose(expected, actual, rtol=0, atol=1e-6) + + def test_multiplicative_swap_regret_optimizer(self): + """Tests that the stochastic matrices update as expected.""" + minimization_problem = test_util.ConstantMinimizationProblem( + np.array([0.6, -0.1, 0.4])) + optimizer = MultiplicativeSwapRegretOptimizerWrapper( + gradient_descent.GradientDescentOptimizer(1.0), + initial_multiplier_radius=0.8) + train_op = optimizer.minimize_constrained(minimization_problem) + + # Calculated using a numpy+python implementation of the algorithm. + expected_matrices = [ + np.array([[0.4, 0.4, 0.4, 0.4], [0.2, 0.2, 0.2, 0.2], + [0.2, 0.2, 0.2, 0.2], [0.2, 0.2, 0.2, 0.2]]), + np.array([[0.36999014, 0.38528351, 0.38528351, 0.38528351], [ + 0.23517483, 0.21720297, 0.21720297, 0.21720297 + ], [0.17774131, 0.18882719, 0.18882719, 0.18882719], + [0.21709373, 0.20868632, 0.20868632, 0.20868632]]), + np.array([[0.33972109, 0.36811863, 0.37118462, 0.36906575], [ + 0.27114826, 0.23738228, 0.23376693, 0.23626491 + ], [0.15712313, 0.17641793, 0.17858959, 0.17708679], + [0.23200752, 0.21808115, 0.21645886, 0.21758255]]), + ] + + matrices = [] + with self.test_session() as session: + session.run(standard_ops.global_variables_initializer()) + while len(matrices) < len(expected_matrices): + matrices.append(session.run(optimizer.stochastic_matrix)) + session.run(train_op) + + for expected, actual in zip(expected_matrices, matrices): + self.assertAllClose(expected, actual, rtol=0, atol=1e-6) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/constrained_optimization/python/test_util.py b/tensorflow/contrib/constrained_optimization/python/test_util.py new file mode 100644 index 0000000000..704b36ca4c --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/test_util.py @@ -0,0 +1,58 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contains helpers used by tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.constrained_optimization.python import constrained_minimization_problem + +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import standard_ops + + +class ConstantMinimizationProblem( + constrained_minimization_problem.ConstrainedMinimizationProblem): + """A `ConstrainedMinimizationProblem` with constant constraint violations. + + This minimization problem is intended for use in performing simple tests of + the Lagrange multiplier (or equivalent) update in the optimizers. There is a + one-element "dummy" model parameter, but it should be ignored. + """ + + def __init__(self, constraints): + """Constructs a new `ConstantMinimizationProblem'. + + Args: + constraints: 1d numpy array, the constant constraint violations. + + Returns: + A new `ConstantMinimizationProblem'. + """ + # We make an fake 1-parameter linear objective so that we don't get a "no + # variables to optimize" error. + self._objective = standard_ops.Variable(0.0, dtype=dtypes.float32) + self._constraints = standard_ops.constant(constraints, dtype=dtypes.float32) + + @property + def objective(self): + """Returns the objective function.""" + return self._objective + + @property + def constraints(self): + """Returns the constant constraint violations.""" + return self._constraints |