aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/constrained_optimization
diff options
context:
space:
mode:
authorGravatar Andrew Cotter <acotter@google.com>2018-04-23 15:57:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-23 16:01:56 -0700
commitff15c81e2b92ef8fb47bb15790cffd18377a4ef2 (patch)
tree5c806cfb8155cba0b7eae9ea137f62e9777e73e6 /tensorflow/contrib/constrained_optimization
parentbb4a80c92105426ccf20a98c4291a1a3f8499b54 (diff)
This is a library for performing constrained optimization. It defines two interfaces: ConstrainedMinimizationProblem, which specifies a constrained optimization problem, and ConstrainedOptimizer, which is slightly different from a tf.train.Optimizer, mostly due to the fact that it is meant to optimize ConstrainedMinimizationProblems. In addition to these two interfaces, three ConstrainedOptimizer implementations are included, as well as helper functions which, given a set of candidate solutions, heuristically find the best candidate (to the constrained problem), or the best distribution over candidates.
For more details, please see our arXiv paper: "https://arxiv.org/abs/1804.06500". PiperOrigin-RevId: 193999550
Diffstat (limited to 'tensorflow/contrib/constrained_optimization')
-rw-r--r--tensorflow/contrib/constrained_optimization/BUILD91
-rw-r--r--tensorflow/contrib/constrained_optimization/README.md345
-rw-r--r--tensorflow/contrib/constrained_optimization/__init__.py41
-rw-r--r--tensorflow/contrib/constrained_optimization/python/candidates.py319
-rw-r--r--tensorflow/contrib/constrained_optimization/python/candidates_test.py95
-rw-r--r--tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py123
-rw-r--r--tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py208
-rw-r--r--tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py375
-rw-r--r--tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py136
-rw-r--r--tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py595
-rw-r--r--tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py212
-rw-r--r--tensorflow/contrib/constrained_optimization/python/test_util.py58
12 files changed, 2598 insertions, 0 deletions
diff --git a/tensorflow/contrib/constrained_optimization/BUILD b/tensorflow/contrib/constrained_optimization/BUILD
new file mode 100644
index 0000000000..619153df67
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/BUILD
@@ -0,0 +1,91 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+# Transitive dependencies of this target will be included in the pip package.
+py_library(
+ name = "constrained_optimization_pip",
+ deps = [
+ ":constrained_optimization",
+ ":test_util",
+ ],
+)
+
+py_library(
+ name = "constrained_optimization",
+ srcs = [
+ "__init__.py",
+ "python/candidates.py",
+ "python/constrained_minimization_problem.py",
+ "python/constrained_optimizer.py",
+ "python/external_regret_optimizer.py",
+ "python/swap_regret_optimizer.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:standard_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "candidates_test",
+ srcs = ["python/candidates_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":constrained_optimization",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+# NOTE: This library can't be "testonly" since it needs to be included in the
+# pip package.
+py_library(
+ name = "test_util",
+ srcs = ["python/test_util.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":constrained_optimization",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:standard_ops",
+ ],
+)
+
+py_test(
+ name = "external_regret_optimizer_test",
+ srcs = ["python/external_regret_optimizer_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":constrained_optimization",
+ ":test_util",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:standard_ops",
+ "//tensorflow/python:training",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "swap_regret_optimizer_test",
+ srcs = ["python/swap_regret_optimizer_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":constrained_optimization",
+ ":test_util",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:standard_ops",
+ "//tensorflow/python:training",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/constrained_optimization/README.md b/tensorflow/contrib/constrained_optimization/README.md
new file mode 100644
index 0000000000..c65a150464
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/README.md
@@ -0,0 +1,345 @@
+<!-- TODO(acotter): Add usage example of non-convex optimization and stochastic classification. -->
+
+# ConstrainedOptimization (TFCO)
+
+TFCO is a library for optimizing inequality-constrained problems in TensorFlow.
+Both the objective function and the constraints are represented as Tensors,
+giving users the maximum amount of flexibility in specifying their optimization
+problems.
+
+This flexibility makes optimization considerably more difficult: on a non-convex
+problem, if one uses the "standard" approach of introducing a Lagrange
+multiplier for each constraint, and then jointly maximizing over the Lagrange
+multipliers and minimizing over the model parameters, then a stable stationary
+point might not even *exist*. Hence, in some cases, oscillation, instead of
+convergence, is inevitable.
+
+Thankfully, it turns out that even if, over the course of optimization, no
+*particular* iterate does a good job of minimizing the objective while
+satisfying the constraints, the *sequence* of iterates, on average, usually
+will. This observation suggests the following approach: at training time, we'll
+periodically snapshot the model state during optimization; then, at evaluation
+time, each time we're given a new example to evaluate, we'll sample one of the
+saved snapshots uniformly at random, and apply it to the example. This
+*stochastic model* will generally perform well, both with respect to the
+objective function, and the constraints.
+
+In fact, we can do better: it's possible to post-process the set of snapshots to
+find a distribution over at most $$m+1$$ snapshots, where $$m$$ is the number of
+constraints, that will be at least as good (and will usually be much better)
+than the (much larger) uniform distribution described above. If you're unable or
+unwilling to use a stochastic model at all, then you can instead use a heuristic
+to choose the single best snapshot.
+
+For full details, motivation, and theoretical results on the approach taken by
+this library, please refer to:
+
+> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+> Constrained Optimization".
+> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+which will be referred to as [CoJiSr18] throughout the remainder of this
+document.
+
+### Proxy Constraints
+
+Imagine that we want to constrain the recall of a binary classifier to be at
+least 90%. Since the recall is proportional to the number of true positive
+classifications, which itself is a sum of indicator functions, this constraint
+is non-differentible, and therefore cannot be used in a problem that will be
+optimized using a (stochastic) gradient-based algorithm.
+
+For this and similar problems, TFCO supports so-called *proxy constraints*,
+which are (at least semi-differentiable) approximations of the original
+constraints. For example, one could create a proxy recall function by replacing
+the indicator functions with sigmoids. During optimization, each proxy
+constraint function will be penalized, with the magnitude of the penalty being
+chosen to satisfy the corresponding *original* (non-proxy) constraint.
+
+On a problem including proxy constraints&mdash;even a convex problem&mdash;the
+Lagrangian approach discussed above isn't guaranteed to work. However, a
+different algorithm, based on minimizing *swap regret*, does work. Aside from
+this difference, the recommended procedure for optimizing a proxy-constrained
+problem remains the same: periodically snapshot the model during optimization,
+and then either find the best $$m+1$$-sized distribution, or heuristically
+choose the single best snapshot.
+
+## Components
+
+* [constrained_minimization_problem](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py):
+ contains the `ConstrainedMinimizationProblem` interface. Your own
+ constrained optimization problems should be represented using
+ implementations of this interface.
+
+* [constrained_optimizer](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py):
+ contains the `ConstrainedOptimizer` interface, which is similar to (but
+ different from) `tf.train.Optimizer`, with the main difference being that
+ `ConstrainedOptimizer`s are given `ConstrainedMinimizationProblem`s to
+ optimize, and perform constrained optimization.
+
+ * [external_regret_optimizer](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py):
+ contains the `AdditiveExternalRegretOptimizer` implementation, which is
+ a `ConstrainedOptimizer` implementing the Lagrangian approach discussed
+ above (with additive updates to the Lagrange multipliers). You should
+ use this optimizer for problems *without* proxy constraints. It may also
+ work for problems with proxy constraints, but we recommend using a swap
+ regret optimizer, instead.
+
+ This optimizer is most similar to Algorithm 3 in Appendix C.3 of
+ [CoJiSr18], and is discussed in Section 3. The two differences are that
+ it uses proxy constraints (if they're provided) in the update of the
+ model parameters, and uses `tf.train.Optimizer`s, instead of SGD, for
+ the "inner" updates.
+
+ * [swap_regret_optimizer](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py):
+ contains the `AdditiveSwapRegretOptimizer` and
+ `MultiplicativeSwapRegretOptimizer` implementations, which are
+ `ConstrainedOptimizer`s implementing the swap-regret minimization
+ approach mentioned above (with additive or multiplicative updates,
+ respectively, to the parameters associated with the
+ constraints&mdash;these parameters are not Lagrange multipliers, but
+ play a similar role). You should use one of these optimizers (we suggest
+ `MultiplicativeSwapRegretOptimizer`) for problems *with* proxy
+ constraints.
+
+ The `MultiplicativeSwapRegretOptimizer` is most similar to Algorithm 2
+ in Section 4 of [CoJiSr18], with the difference being that it uses
+ `tf.train.Optimizer`s, instead of SGD, for the "inner" updates. The
+ `AdditiveSwapRegretOptimizer` differs further in that it performs
+ additive (instead of multiplicative) updates of the stochastic matrix.
+
+* [candidates](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/candidates.py):
+ contains two functions, `find_best_candidate_distribution` and
+ `find_best_candidate_index`. Both of these functions are given a set of
+ candidate solutions to a constrained optimization problem, from which the
+ former finds the best distribution over at most $$m+1$$ candidates, and the
+ latter heuristically finds the single best candidate. As discussed above,
+ the set of candidates will typically be model snapshots saved periodically
+ during optimization. Both of these functions require that scipy be
+ installed.
+
+ The `find_best_candidate_distribution` function implements the approach
+ described in Lemma 3 of [CoJiSr18], while `find_best_candidate_index`
+ implements the heuristic used for hyperparameter search in the experiments
+ of Section 5.2.
+
+## Convex Example with Proxy Constraints
+
+This is a simple example of recall-constrained optimization on simulated data:
+we will try to find a classifier that minimizes the average hinge loss while
+constraining recall to be at least 90%.
+
+We'll start with the required imports&mdash;notice the definition of `tfco`:
+
+```python
+import math
+import numpy as np
+import tensorflow as tf
+
+tfco = tf.contrib.constrained_optimization
+```
+
+We'll now create an implementation of the `ConstrainedMinimizationProblem` class
+for this problem. The constructor takes three parameters: a Tensor containing
+the classification labels (0 or 1) for every training example, another Tensor
+containing the model's predictions on every training example (sometimes called
+the "logits"), and the lower bound on recall that will be enforced using a
+constraint.
+
+This implementation will contain both constraints *and* proxy constraints: the
+former represents the constraint that the true recall (defined in terms of the
+*number* of true positives) be at least `recall_lower_bound`, while the latter
+represents the same constraint, but on a hinge approximation of the recall.
+
+```python
+class ExampleProblem(tfco.ConstrainedMinimizationProblem):
+
+ def __init__(self, labels, predictions, recall_lower_bound):
+ self._labels = labels
+ self._predictions = predictions
+ self._recall_lower_bound = recall_lower_bound
+ # The number of positively-labeled examples.
+ self._positive_count = tf.reduce_sum(self._labels)
+
+ @property
+ def objective(self):
+ return tf.losses.hinge_loss(labels=self._labels, logits=self._predictions)
+
+ @property
+ def constraints(self):
+ true_positives = self._labels * tf.to_float(self._predictions > 0)
+ true_positive_count = tf.reduce_sum(true_positives)
+ recall = true_positive_count / self._positive_count
+ # The constraint is (recall >= self._recall_lower_bound), which we convert
+ # to (self._recall_lower_bound - recall <= 0) because
+ # ConstrainedMinimizationProblems must always provide their constraints in
+ # the form (tensor <= 0).
+ #
+ # The result of this function should be a tensor, with each element being
+ # a quantity that is constrained to be nonpositive. We only have one
+ # constraint, so we return a one-element tensor.
+ return self._recall_lower_bound - recall
+
+ @property
+ def proxy_constraints(self):
+ # Use 1 - hinge since we're SUBTRACTING recall in the constraint function,
+ # and we want the proxy constraint function to be convex.
+ true_positives = self._labels * tf.minimum(1.0, self._predictions)
+ true_positive_count = tf.reduce_sum(true_positives)
+ recall = true_positive_count / self._positive_count
+ # Please see the corresponding comment in the constraints property.
+ return self._recall_lower_bound - recall
+```
+
+We'll now create a simple simulated dataset by sampling 1000 random
+10-dimensional feature vectors from a Gaussian, finding their labels using a
+random "ground truth" linear model, and then adding noise by randomly flipping
+200 labels.
+
+```python
+# Create a simulated 10-dimensional training dataset consisting of 1000 labeled
+# examples, of which 800 are labeled correctly and 200 are mislabeled.
+num_examples = 1000
+num_mislabeled_examples = 200
+dimension = 10
+# We will constrain the recall to be at least 90%.
+recall_lower_bound = 0.9
+
+# Create random "ground truth" parameters to a linear model.
+ground_truth_weights = np.random.normal(size=dimension) / math.sqrt(dimension)
+ground_truth_threshold = 0
+
+# Generate a random set of features for each example.
+features = np.random.normal(size=(num_examples, dimension)).astype(
+ np.float32) / math.sqrt(dimension)
+# Compute the labels from these features given the ground truth linear model.
+labels = (np.matmul(features, ground_truth_weights) >
+ ground_truth_threshold).astype(np.float32)
+# Add noise by randomly flipping num_mislabeled_examples labels.
+mislabeled_indices = np.random.choice(
+ num_examples, num_mislabeled_examples, replace=False)
+labels[mislabeled_indices] = 1 - labels[mislabeled_indices]
+```
+
+We're now ready to construct our model, and the corresponding optimization
+problem. We'll use a linear model of the form $$f(x) = w^T x - t$$, where $$w$$
+is the `weights`, and $$t$$ is the `threshold`. The `problem` variable will hold
+an instance of the `ExampleProblem` class we created earlier.
+
+```python
+# Create variables containing the model parameters.
+weights = tf.Variable(tf.zeros(dimension), dtype=tf.float32, name="weights")
+threshold = tf.Variable(0.0, dtype=tf.float32, name="threshold")
+
+# Create the optimization problem.
+constant_labels = tf.constant(labels, dtype=tf.float32)
+constant_features = tf.constant(features, dtype=tf.float32)
+predictions = tf.tensordot(constant_features, weights, axes=(1, 0)) - threshold
+problem = ExampleProblem(
+ labels=constant_labels,
+ predictions=predictions,
+ recall_lower_bound=recall_lower_bound,
+)
+```
+
+We're almost ready to train our model, but first we'll create a couple of
+functions to measure its performance. We're interested in two quantities: the
+average hinge loss (which we seek to minimize), and the recall (which we
+constrain).
+
+```python
+def average_hinge_loss(labels, predictions):
+ num_examples, = np.shape(labels)
+ signed_labels = (labels * 2) - 1
+ total_hinge_loss = np.sum(np.maximum(0.0, 1.0 - signed_labels * predictions))
+ return total_hinge_loss / num_examples
+
+def recall(labels, predictions):
+ positive_count = np.sum(labels)
+ true_positives = labels * (predictions > 0)
+ true_positive_count = np.sum(true_positives)
+ return true_positive_count / positive_count
+```
+
+As was mentioned earlier, external regret optimizers suffice for problems
+without proxy constraints, but swap regret optimizers are recommended for
+problems *with* proxy constraints. Since this problem contains proxy
+constraints, we use the `MultiplicativeSwapRegretOptimizer`.
+
+For this problem, the constraint is fairly easy to satisfy, so we can use the
+same "inner" optimizer (an `AdagradOptimizer` with a learning rate of 1) for
+optimization of both the model parameters (`weights` and `threshold`), and the
+internal parameters associated with the constraints (these are the analogues of
+the Lagrange multipliers used by the `MultiplicativeSwapRegretOptimizer`). For
+more difficult problems, it will often be necessary to use different optimizers,
+with different learning rates (presumably found via a hyperparameter search): to
+accomplish this, pass *both* the `optimizer` and `constraint_optimizer`
+parameters to `MultiplicativeSwapRegretOptimizer`'s constructor.
+
+Since this is a convex problem (both the objective and proxy constraint
+functions are convex), we can just take the last iterate. Periodic snapshotting,
+and the use of the `find_best_candidate_distribution` or
+`find_best_candidate_index` functions, is generally only necessary for
+non-convex problems (and even then, it isn't *always* necessary).
+
+```python
+with tf.Session() as session:
+ optimizer = tfco.MultiplicativeSwapRegretOptimizer(
+ optimizer=tf.train.AdagradOptimizer(learning_rate=1.0))
+ train_op = optimizer.minimize(problem)
+
+ session.run(tf.global_variables_initializer())
+ for ii in xrange(1000):
+ session.run(train_op)
+
+ trained_weights, trained_threshold = session.run((weights, threshold))
+
+trained_predictions = np.matmul(features, trained_weights) - trained_threshold
+print("Constrained average hinge loss = %f" % average_hinge_loss(
+ labels, trained_predictions))
+print("Constrained recall = %f" % recall(labels, trained_predictions))
+```
+
+Running the above code gives the following output (due to the randomness of the
+dataset, you'll get a different result when you run it):
+
+```none
+Constrained average hinge loss = 0.710019
+Constrained recall = 0.899811
+```
+
+As we hoped, the recall is extremely close to 90%&mdash;and, thanks to the use
+of proxy constraints, this is the *true* recall, not a hinge approximation.
+
+For comparison, let's try optimizing the same problem *without* the recall
+constraint:
+
+```python
+with tf.Session() as session:
+ optimizer = tf.train.AdagradOptimizer(learning_rate=1.0)
+ # For optimizing the unconstrained problem, we just minimize the "objective"
+ # portion of the minimization problem.
+ train_op = optimizer.minimize(problem.objective)
+
+ session.run(tf.global_variables_initializer())
+ for ii in xrange(1000):
+ session.run(train_op)
+
+ trained_weights, trained_threshold = session.run((weights, threshold))
+
+trained_predictions = np.matmul(features, trained_weights) - trained_threshold
+print("Unconstrained average hinge loss = %f" % average_hinge_loss(
+ labels, trained_predictions))
+print("Unconstrained recall = %f" % recall(labels, trained_predictions))
+```
+
+This code gives the following output (again, you'll get a different answer,
+since the dataset is random):
+
+```none
+Unconstrained average hinge loss = 0.627271
+Unconstrained recall = 0.793951
+```
+
+Because there is no constraint, the unconstrained problem does a better job of
+minimizing the average hinge loss, but naturally doesn't approach 90% recall.
diff --git a/tensorflow/contrib/constrained_optimization/__init__.py b/tensorflow/contrib/constrained_optimization/__init__.py
new file mode 100644
index 0000000000..1e49ba9f17
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/__init__.py
@@ -0,0 +1,41 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A library for performing constrained optimization in TensorFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=wildcard-import
+from tensorflow.contrib.constrained_optimization.python.candidates import *
+from tensorflow.contrib.constrained_optimization.python.constrained_minimization_problem import *
+from tensorflow.contrib.constrained_optimization.python.constrained_optimizer import *
+from tensorflow.contrib.constrained_optimization.python.external_regret_optimizer import *
+from tensorflow.contrib.constrained_optimization.python.swap_regret_optimizer import *
+# pylint: enable=wildcard-import
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ "AdditiveExternalRegretOptimizer",
+ "AdditiveSwapRegretOptimizer",
+ "ConstrainedMinimizationProblem",
+ "ConstrainedOptimizer",
+ "find_best_candidate_distribution",
+ "find_best_candidate_index",
+ "MultiplicativeSwapRegretOptimizer",
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/constrained_optimization/python/candidates.py b/tensorflow/contrib/constrained_optimization/python/candidates.py
new file mode 100644
index 0000000000..ac86a6741b
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/candidates.py
@@ -0,0 +1,319 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Code for optimizing over a set of candidate solutions.
+
+The functions in this file deal with the constrained problem:
+
+> minimize f(w)
+> s.t. g_i(w) <= 0 for all i in {0,1,...,m-1}
+
+Here, f(w) is the "objective function", and g_i(w) is the ith (of m) "constraint
+function". Given the values of the objective and constraint functions for a set
+of n "candidate solutions" {w_0,w_1,...,w_{n-1}} (for a total of n objective
+function values, and n*m constraint function values), the
+`find_best_candidate_distribution` function finds the best DISTRIBUTION over
+these candidates, while `find_best_candidate_index' heuristically finds the
+single best candidate.
+
+Both of these functions have dependencies on `scipy`, so if you want to call
+them, then you must make sure that `scipy` is available. The imports are
+performed inside the functions themselves, so if they're not actually called,
+then `scipy` is not needed.
+
+For more specifics, please refer to:
+
+> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+> Constrained Optimization".
+> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+The `find_best_candidate_distribution` function implements the approach
+described in Lemma 3, while `find_best_candidate_index` implements the heuristic
+used for hyperparameter search in the experiments of Section 5.2.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+
+def _find_best_candidate_distribution_helper(objective_vector,
+ constraints_matrix,
+ maximum_violation=0.0):
+ """Finds a distribution minimizing an objective subject to constraints.
+
+ This function deals with the constrained problem:
+
+ > minimize f(w)
+ > s.t. g_i(w) <= 0 for all i in {0,1,...,m-1}
+
+ Here, f(w) is the "objective function", and g_i(w) is the ith (of m)
+ "constraint function". Given a set of n "candidate solutions"
+ {w_0,w_1,...,w_{n-1}}, this function finds a distribution over these n
+ candidates that, in expectation, minimizes the objective while violating
+ the constraints by no more than `maximum_violation`. If no such distribution
+ exists, it returns an error (using Go-style error reporting).
+
+ The `objective_vector` parameter should be a numpy array with shape (n,), for
+ which objective_vector[i] = f(w_i). Likewise, `constraints_matrix` should be a
+ numpy array with shape (m,n), for which constraints_matrix[i,j] = g_i(w_j).
+
+ This function will return a distribution for which at most m+1 probabilities,
+ and often fewer, are nonzero.
+
+ Args:
+ objective_vector: numpy array of shape (n,), where n is the number of
+ "candidate solutions". Contains the objective function values.
+ constraints_matrix: numpy array of shape (m,n), where m is the number of
+ constraints and n is the number of "candidate solutions". Contains the
+ constraint violation magnitudes.
+ maximum_violation: nonnegative float, the maximum amount by which any
+ constraint may be violated, in expectation.
+
+ Returns:
+ A pair (`result`, `message`), exactly one of which is None. If `message` is
+ None, then the `result` contains the optimal distribution as a numpy array
+ of shape (n,). If `result` is None, then `message` contains an error
+ message.
+
+ Raises:
+ ValueError: If `objective_vector` and `constraints_matrix` have inconsistent
+ shapes, or if `maximum_violation` is negative.
+ ImportError: If we're unable to import `scipy.optimize`.
+ """
+ if maximum_violation < 0.0:
+ raise ValueError("maximum_violation must be nonnegative")
+
+ mm, nn = np.shape(constraints_matrix)
+ if (nn,) != np.shape(objective_vector):
+ raise ValueError(
+ "objective_vector must have shape (n,), and constraints_matrix (m, n),"
+ " where n is the number of candidates, and m is the number of "
+ "constraints")
+
+ # We import scipy inline, instead of at the top of the file, so that a scipy
+ # dependency is only introduced if either find_best_candidate_distribution()
+ # or find_best_candidate_index() are actually called.
+ import scipy.optimize # pylint: disable=g-import-not-at-top
+
+ # Feasibility (within maximum_violation) constraints.
+ a_ub = constraints_matrix
+ b_ub = np.full((mm, 1), maximum_violation)
+ # Sum-to-one constraint.
+ a_eq = np.ones((1, nn))
+ b_eq = np.ones((1, 1))
+ # Nonnegativity constraints.
+ bounds = (0, None)
+
+ result = scipy.optimize.linprog(
+ objective_vector,
+ A_ub=a_ub,
+ b_ub=b_ub,
+ A_eq=a_eq,
+ b_eq=b_eq,
+ bounds=bounds)
+ # Go-style error reporting. We don't raise on error, since
+ # find_best_candidate_distribution() needs to handle the failure case, and we
+ # shouldn't use exceptions as flow-control.
+ if not result.success:
+ return (None, result.message)
+ else:
+ return (result.x, None)
+
+
+def find_best_candidate_distribution(objective_vector,
+ constraints_matrix,
+ epsilon=0.0):
+ """Finds a distribution minimizing an objective subject to constraints.
+
+ This function deals with the constrained problem:
+
+ > minimize f(w)
+ > s.t. g_i(w) <= 0 for all i in {0,1,...,m-1}
+
+ Here, f(w) is the "objective function", and g_i(w) is the ith (of m)
+ "constraint function". Given a set of n "candidate solutions"
+ {w_0,w_1,...,w_{n-1}}, this function finds a distribution over these n
+ candidates that, in expectation, minimizes the objective while violating
+ the constraints by the smallest possible amount (with the amount being found
+ via bisection search).
+
+ The `objective_vector` parameter should be a numpy array with shape (n,), for
+ which objective_vector[i] = f(w_i). Likewise, `constraints_matrix` should be a
+ numpy array with shape (m,n), for which constraints_matrix[i,j] = g_i(w_j).
+
+ This function will return a distribution for which at most m+1 probabilities,
+ and often fewer, are nonzero.
+
+ For more specifics, please refer to:
+
+ > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+ > Constrained Optimization".
+ > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+ This function implements the approach described in Lemma 3.
+
+ Args:
+ objective_vector: numpy array of shape (n,), where n is the number of
+ "candidate solutions". Contains the objective function values.
+ constraints_matrix: numpy array of shape (m,n), where m is the number of
+ constraints and n is the number of "candidate solutions". Contains the
+ constraint violation magnitudes.
+ epsilon: nonnegative float, the threshold at which to terminate the binary
+ search while searching for the minimal expected constraint violation
+ magnitude.
+
+ Returns:
+ The optimal distribution, as a numpy array of shape (n,).
+
+ Raises:
+ ValueError: If `objective_vector` and `constraints_matrix` have inconsistent
+ shapes, or if `epsilon` is negative.
+ ImportError: If we're unable to import `scipy.optimize`.
+ """
+ if epsilon < 0.0:
+ raise ValueError("epsilon must be nonnegative")
+
+ # If there is a feasible solution (i.e. with maximum_violation=0), then that's
+ # what we'll return.
+ pp, _ = _find_best_candidate_distribution_helper(objective_vector,
+ constraints_matrix)
+ if pp is not None:
+ return pp
+
+ # The bound is the minimum over all candidates, of the maximum per-candidate
+ # constraint violation.
+ lower = 0.0
+ upper = np.min(np.amax(constraints_matrix, axis=0))
+ best_pp, _ = _find_best_candidate_distribution_helper(
+ objective_vector, constraints_matrix, maximum_violation=upper)
+ assert best_pp is not None
+
+ # Throughout this loop, a maximum_violation of "lower" is not achievable,
+ # but a maximum_violation of "upper" is achiveable.
+ while True:
+ middle = 0.5 * (lower + upper)
+ if (middle - lower <= epsilon) or (upper - middle <= epsilon):
+ break
+ else:
+ pp, _ = _find_best_candidate_distribution_helper(
+ objective_vector, constraints_matrix, maximum_violation=middle)
+ if pp is None:
+ lower = middle
+ else:
+ best_pp = pp
+ upper = middle
+
+ return best_pp
+
+
+def find_best_candidate_index(objective_vector,
+ constraints_matrix,
+ rank_objectives=False):
+ """Heuristically finds the best candidate solution to a constrained problem.
+
+ This function deals with the constrained problem:
+
+ > minimize f(w)
+ > s.t. g_i(w) <= 0 for all i in {0,1,...,m-1}
+
+ Here, f(w) is the "objective function", and g_i(w) is the ith (of m)
+ "constraint function". Given a set of n "candidate solutions"
+ {w_0,w_1,...,w_{n-1}}, this function finds the "best" solution according
+ to the following heuristic:
+
+ 1. Across all models, the ith constraint violations (i.e. max{0, g_i(0)})
+ are ranked, as are the objectives (if rank_objectives=True).
+ 2. Each model is then associated its MAXIMUM rank across all m constraints
+ (and the objective, if rank_objectives=True).
+ 3. The model with the minimal maximum rank is then identified. Ties are
+ broken using the objective function value.
+ 4. The index of this "best" model is returned.
+
+ The `objective_vector` parameter should be a numpy array with shape (n,), for
+ which objective_vector[i] = f(w_i). Likewise, `constraints_matrix` should be a
+ numpy array with shape (m,n), for which constraints_matrix[i,j] = g_i(w_j).
+
+ For more specifics, please refer to:
+
+ > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+ > Constrained Optimization".
+ > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+ This function implements the heuristic used for hyperparameter search in the
+ experiments of Section 5.2.
+
+ Args:
+ objective_vector: numpy array of shape (n,), where n is the number of
+ "candidate solutions". Contains the objective function values.
+ constraints_matrix: numpy array of shape (m,n), where m is the number of
+ constraints and n is the number of "candidate solutions". Contains the
+ constraint violation magnitudes.
+ rank_objectives: bool, whether the objective function values should be
+ included in the initial ranking step. If True, both the objective and
+ constraints will be ranked. If False, only the constraints will be ranked.
+ In either case, the objective function values will be used for
+ tiebreaking.
+
+ Returns:
+ The index (in {0,1,...,n-1}) of the "best" model according to the above
+ heuristic.
+
+ Raises:
+ ValueError: If `objective_vector` and `constraints_matrix` have inconsistent
+ shapes.
+ ImportError: If we're unable to import `scipy.stats`.
+ """
+ mm, nn = np.shape(constraints_matrix)
+ if (nn,) != np.shape(objective_vector):
+ raise ValueError(
+ "objective_vector must have shape (n,), and constraints_matrix (m, n),"
+ " where n is the number of candidates, and m is the number of "
+ "constraints")
+
+ # We import scipy inline, instead of at the top of the file, so that a scipy
+ # dependency is only introduced if either find_best_candidate_distribution()
+ # or find_best_candidate_index() are actually called.
+ import scipy.stats # pylint: disable=g-import-not-at-top
+
+ if rank_objectives:
+ maximum_ranks = scipy.stats.rankdata(objective_vector, method="min")
+ else:
+ maximum_ranks = np.zeros(nn, dtype=np.int64)
+ for ii in xrange(mm):
+ # Take the maximum of the constraint functions with zero, since we want to
+ # rank the magnitude of constraint *violations*. If the constraint is
+ # satisfied, then we don't care how much it's satisfied by (as a result, we
+ # we expect all models satisfying a constraint to be tied at rank 1).
+ ranks = scipy.stats.rankdata(
+ np.maximum(0.0, constraints_matrix[ii, :]), method="min")
+ maximum_ranks = np.maximum(maximum_ranks, ranks)
+
+ best_index = None
+ best_rank = float("Inf")
+ best_objective = float("Inf")
+ for ii in xrange(nn):
+ if maximum_ranks[ii] < best_rank:
+ best_index = ii
+ best_rank = maximum_ranks[ii]
+ best_objective = objective_vector[ii]
+ elif (maximum_ranks[ii] == best_rank) and (objective_vector[ii] <=
+ best_objective):
+ best_index = ii
+ best_objective = objective_vector[ii]
+
+ return best_index
diff --git a/tensorflow/contrib/constrained_optimization/python/candidates_test.py b/tensorflow/contrib/constrained_optimization/python/candidates_test.py
new file mode 100644
index 0000000000..a4c49d48bc
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/candidates_test.py
@@ -0,0 +1,95 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for constrained_optimization.python.candidates."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.constrained_optimization.python import candidates
+from tensorflow.python.platform import test
+
+
+class CandidatesTest(test.TestCase):
+
+ def test_inconsistent_shapes_for_best_distribution(self):
+ """An error is raised when parameters have inconsistent shapes."""
+ objective_vector = np.array([1, 2, 3])
+ constraints_matrix = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
+ with self.assertRaises(ValueError):
+ _ = candidates.find_best_candidate_distribution(objective_vector,
+ constraints_matrix)
+
+ def test_inconsistent_shapes_for_best_index(self):
+ """An error is raised when parameters have inconsistent shapes."""
+ objective_vector = np.array([1, 2, 3])
+ constraints_matrix = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
+ with self.assertRaises(ValueError):
+ _ = candidates.find_best_candidate_index(objective_vector,
+ constraints_matrix)
+
+ def test_best_distribution(self):
+ """Distribution should match known solution."""
+ objective_vector = np.array(
+ [0.03053309, -0.06667082, 0.88355145, 0.46529806])
+ constraints_matrix = np.array(
+ [[-0.60164551, 0.36676229, 0.7856454, -0.8441711],
+ [0.00371592, -0.16392108, -0.59778071, -0.56908492]])
+ distribution = candidates.find_best_candidate_distribution(
+ objective_vector, constraints_matrix)
+ # Verify that the solution is a probability distribution.
+ self.assertTrue(np.all(distribution >= 0))
+ self.assertAlmostEqual(np.sum(distribution), 1.0)
+ # Verify that the solution satisfies the constraints.
+ maximum_constraint_violation = np.amax(
+ np.dot(constraints_matrix, distribution))
+ self.assertLessEqual(maximum_constraint_violation, 0)
+ # Verify that the solution matches that which we expect.
+ expected_distribution = np.array([0.37872711, 0.62127289, 0, 0])
+ self.assertAllClose(expected_distribution, distribution, rtol=0, atol=1e-6)
+
+ def test_best_index_rank_objectives_true(self):
+ """Index should match known solution."""
+ # Objective ranks = [2, 1, 4, 3].
+ objective_vector = np.array(
+ [0.03053309, -0.06667082, 0.88355145, 0.46529806])
+ # Constraint ranks = [[1, 3, 4, 1], [4, 1, 1, 1]].
+ constraints_matrix = np.array(
+ [[-0.60164551, 0.36676229, 0.7856454, -0.8441711],
+ [0.00371592, -0.16392108, -0.59778071, -0.56908492]])
+ # Maximum ranks = [4, 3, 4, 3].
+ index = candidates.find_best_candidate_index(
+ objective_vector, constraints_matrix, rank_objectives=True)
+ self.assertEqual(1, index)
+
+ def test_best_index_rank_objectives_false(self):
+ """Index should match known solution."""
+ # Objective ranks = [2, 1, 4, 3].
+ objective_vector = np.array(
+ [0.03053309, -0.06667082, 0.88355145, 0.46529806])
+ # Constraint ranks = [[1, 3, 4, 1], [4, 1, 1, 1]].
+ constraints_matrix = np.array(
+ [[-0.60164551, 0.36676229, 0.7856454, -0.8441711],
+ [0.00371592, -0.16392108, -0.59778071, -0.56908492]])
+ # Maximum ranks = [4, 3, 4, 1].
+ index = candidates.find_best_candidate_index(
+ objective_vector, constraints_matrix, rank_objectives=False)
+ self.assertEqual(3, index)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py
new file mode 100644
index 0000000000..70813fb217
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py
@@ -0,0 +1,123 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines abstract class for `ConstrainedMinimizationProblem`s.
+
+A ConstrainedMinimizationProblem consists of an objective function to minimize,
+and a set of constraint functions that are constrained to be nonpositive.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+import six
+
+
+@six.add_metaclass(abc.ABCMeta)
+class ConstrainedMinimizationProblem(object):
+ """Abstract class representing a `ConstrainedMinimizationProblem`.
+
+ A ConstrainedMinimizationProblem consists of an objective function to
+ minimize, and a set of constraint functions that are constrained to be
+ nonpositive.
+
+ In addition to the constraint functions, there may (optionally) be proxy
+ constraint functions: a ConstrainedOptimizer will attempt to penalize these
+ proxy constraint functions so as to satisfy the (non-proxy) constraints. Proxy
+ constraints could be used if the constraints functions are difficult or
+ impossible to optimize (e.g. if they're piecewise constant), in which case the
+ proxy constraints should be some approximation of the original constraints
+ that is well-enough behaved to permit successful optimization.
+ """
+
+ @abc.abstractproperty
+ def objective(self):
+ """Returns the objective function.
+
+ Returns:
+ A 0d tensor that should be minimized.
+ """
+ pass
+
+ @property
+ def num_constraints(self):
+ """Returns the number of constraints.
+
+ Returns:
+ An int containing the number of constraints.
+
+ Raises:
+ ValueError: If the constraints (or proxy_constraints, if present) do not
+ have fully-known shapes, OR if proxy_constraints are present, and the
+ shapes of constraints and proxy_constraints are fully-known, but they're
+ different.
+ """
+ constraints_shape = self.constraints.get_shape()
+ if self.proxy_constraints is None:
+ proxy_constraints_shape = constraints_shape
+ else:
+ proxy_constraints_shape = self.proxy_constraints.get_shape()
+
+ if (constraints_shape is None or proxy_constraints_shape is None or
+ any([ii is None for ii in constraints_shape.as_list()]) or
+ any([ii is None for ii in proxy_constraints_shape.as_list()])):
+ raise ValueError(
+ "constraints and proxy_constraints must have fully-known shapes")
+ if constraints_shape != proxy_constraints_shape:
+ raise ValueError(
+ "constraints and proxy_constraints must have the same shape")
+
+ size = 1
+ for ii in constraints_shape.as_list():
+ size *= ii
+ return int(size)
+
+ @abc.abstractproperty
+ def constraints(self):
+ """Returns the vector of constraint functions.
+
+ Letting g_i be the ith element of the constraints vector, the ith constraint
+ will be g_i <= 0.
+
+ Returns:
+ A tensor of constraint functions.
+ """
+ pass
+
+ # This is a property, instead of an abstract property, since it doesn't need
+ # to be overridden: if proxy_constraints returns None, then there are no
+ # proxy constraints.
+ @property
+ def proxy_constraints(self):
+ """Returns the optional vector of proxy constraint functions.
+
+ The difference between `constraints` and `proxy_constraints` is that, when
+ proxy constraints are present, the `constraints` are merely EVALUATED during
+ optimization, whereas the `proxy_constraints` are DIFFERENTIATED. If there
+ are no proxy constraints, then the `constraints` are both evaluated and
+ differentiated.
+
+ For example, if we want to impose constraints on step functions, then we
+ could use these functions for `constraints`. However, because a step
+ function has zero gradient almost everywhere, we can't differentiate these
+ functions, so we would take `proxy_constraints` to be some differentiable
+ approximation of `constraints`.
+
+ Returns:
+ A tensor of proxy constraint functions.
+ """
+ return None
diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py
new file mode 100644
index 0000000000..8055545366
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py
@@ -0,0 +1,208 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines base class for `ConstrainedOptimizer`s."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+import six
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import standard_ops
+from tensorflow.python.training import optimizer as train_optimizer
+
+
+@six.add_metaclass(abc.ABCMeta)
+class ConstrainedOptimizer(object):
+ """Base class representing a constrained optimizer.
+
+ A ConstrainedOptimizer wraps a tf.train.Optimizer (or more than one), and
+ applies it to a ConstrainedMinimizationProblem. Unlike a tf.train.Optimizer,
+ which takes a tensor to minimize as a parameter to its minimize() method, a
+ constrained optimizer instead takes a ConstrainedMinimizationProblem.
+ """
+
+ def __init__(self, optimizer):
+ """Constructs a new `ConstrainedOptimizer`.
+
+ Args:
+ optimizer: tf.train.Optimizer, used to optimize the
+ ConstraintedMinimizationProblem.
+
+ Returns:
+ A new `ConstrainedOptimizer`.
+ """
+ self._optimizer = optimizer
+
+ @property
+ def optimizer(self):
+ """Returns the `tf.train.Optimizer` used for optimization."""
+ return self._optimizer
+
+ def minimize_unconstrained(self,
+ minimization_problem,
+ global_step=None,
+ var_list=None,
+ gate_gradients=train_optimizer.Optimizer.GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ name=None,
+ grad_loss=None):
+ """Returns an `Op` for minimizing the unconstrained problem.
+
+ Unlike `minimize_constrained`, this function ignores the `constraints` (and
+ `proxy_constraints`) portion of the minimization problem entirely, and only
+ minimizes `objective`.
+
+ Args:
+ minimization_problem: ConstrainedMinimizationProblem, the problem to
+ optimize.
+ global_step: as in `tf.train.Optimizer`'s `minimize` method.
+ var_list: as in `tf.train.Optimizer`'s `minimize` method.
+ gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
+ aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
+ colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
+ method.
+ name: as in `tf.train.Optimizer`'s `minimize` method.
+ grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
+
+ Returns:
+ TensorFlow Op.
+ """
+ return self.optimizer.minimize(
+ minimization_problem.objective,
+ global_step=global_step,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ name=name,
+ grad_loss=grad_loss)
+
+ @abc.abstractmethod
+ def minimize_constrained(self,
+ minimization_problem,
+ global_step=None,
+ var_list=None,
+ gate_gradients=train_optimizer.Optimizer.GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ name=None,
+ grad_loss=None):
+ """Returns an `Op` for minimizing the constrained problem.
+
+ Unlike `minimize_unconstrained`, this function attempts to find a solution
+ that minimizes the `objective` portion of the minimization problem while
+ satisfying the `constraints` portion.
+
+ Args:
+ minimization_problem: ConstrainedMinimizationProblem, the problem to
+ optimize.
+ global_step: as in `tf.train.Optimizer`'s `minimize` method.
+ var_list: as in `tf.train.Optimizer`'s `minimize` method.
+ gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
+ aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
+ colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
+ method.
+ name: as in `tf.train.Optimizer`'s `minimize` method.
+ grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
+
+ Returns:
+ TensorFlow Op.
+ """
+ pass
+
+ def minimize(self,
+ minimization_problem,
+ unconstrained_steps=None,
+ global_step=None,
+ var_list=None,
+ gate_gradients=train_optimizer.Optimizer.GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ name=None,
+ grad_loss=None):
+ """Returns an `Op` for minimizing the constrained problem.
+
+ This method combines the functionality of `minimize_unconstrained` and
+ `minimize_constrained`. If global_step < unconstrained_steps, it will
+ perform an unconstrained update, and if global_step >= unconstrained_steps,
+ it will perform a constrained update.
+
+ The reason for this functionality is that it may be best to initialize the
+ constrained optimizer with an approximate optimum of the unconstrained
+ problem.
+
+ Args:
+ minimization_problem: ConstrainedMinimizationProblem, the problem to
+ optimize.
+ unconstrained_steps: int, number of steps for which we should perform
+ unconstrained updates, before transitioning to constrained updates.
+ global_step: as in `tf.train.Optimizer`'s `minimize` method.
+ var_list: as in `tf.train.Optimizer`'s `minimize` method.
+ gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
+ aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
+ colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
+ method.
+ name: as in `tf.train.Optimizer`'s `minimize` method.
+ grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
+
+ Returns:
+ TensorFlow Op.
+
+ Raises:
+ ValueError: If unconstrained_steps is provided, but global_step is not.
+ """
+
+ def unconstrained_fn():
+ """Returns an `Op` for minimizing the unconstrained problem."""
+ return self.minimize_unconstrained(
+ minimization_problem=minimization_problem,
+ global_step=global_step,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ name=name,
+ grad_loss=grad_loss)
+
+ def constrained_fn():
+ """Returns an `Op` for minimizing the constrained problem."""
+ return self.minimize_constrained(
+ minimization_problem=minimization_problem,
+ global_step=global_step,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ name=name,
+ grad_loss=grad_loss)
+
+ if unconstrained_steps is not None:
+ if global_step is None:
+ raise ValueError(
+ "global_step cannot be None if unconstrained_steps is provided")
+ unconstrained_steps_tensor = ops.convert_to_tensor(unconstrained_steps)
+ dtype = unconstrained_steps_tensor.dtype
+ return control_flow_ops.cond(
+ standard_ops.cast(global_step, dtype) < unconstrained_steps_tensor,
+ true_fn=unconstrained_fn,
+ false_fn=constrained_fn)
+ else:
+ return constrained_fn()
diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py
new file mode 100644
index 0000000000..01c6e4f08a
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py
@@ -0,0 +1,375 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines `AdditiveExternalRegretOptimizer`.
+
+This optimizer minimizes a `ConstrainedMinimizationProblem` by introducing
+Lagrange multipliers, and using `tf.train.Optimizer`s to jointly optimize over
+the model parameters and Lagrange multipliers.
+
+For the purposes of constrained optimization, at least in theory,
+external-regret minimization suffices if the `ConstrainedMinimizationProblem`
+we're optimizing doesn't have any `proxy_constraints`, while swap-regret
+minimization should be used if `proxy_constraints` are present.
+
+For more specifics, please refer to:
+
+> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+> Constrained Optimization".
+> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+The formulation used by the AdditiveExternalRegretOptimizer--which is simply the
+usual Lagrangian formulation--can be found in Definition 1, and is discussed in
+Section 3. This optimizer is most similar to Algorithm 3 in Appendix C.3, with
+the two differences being that it uses proxy constraints (if they're provided)
+in the update of the model parameters, and uses `tf.train.Optimizer`s, instead
+of SGD, for the "inner" updates.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+import six
+
+from tensorflow.contrib.constrained_optimization.python import constrained_optimizer
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import standard_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.training import optimizer as train_optimizer
+
+
+def _project_multipliers_wrt_euclidean_norm(multipliers, radius):
+ """Projects its argument onto the feasible region.
+
+ The feasible region is the set of all vectors with nonnegative elements that
+ sum to at most `radius`.
+
+ Args:
+ multipliers: 1d tensor, the Lagrange multipliers to project.
+ radius: float, the radius of the feasible region.
+
+ Returns:
+ The 1d tensor that results from projecting `multipliers` onto the feasible
+ region w.r.t. the Euclidean norm.
+
+ Raises:
+ ValueError: if the `multipliers` tensor does not have a fully-known shape,
+ or is not one-dimensional.
+ """
+ multipliers_shape = multipliers.get_shape()
+ if multipliers_shape is None:
+ raise ValueError("multipliers must have known shape")
+ if multipliers_shape.ndims != 1:
+ raise ValueError(
+ "multipliers must be one dimensional (instead is %d-dimensional)" %
+ multipliers_shape.ndims)
+ dimension = multipliers_shape[0].value
+ if dimension is None:
+ raise ValueError("multipliers must have fully-known shape")
+
+ def while_loop_condition(iteration, multipliers, inactive, old_inactive):
+ """Returns false if the while loop should terminate."""
+ del multipliers # Needed by the body, but not the condition.
+ not_done = (iteration < dimension)
+ not_converged = standard_ops.reduce_any(
+ standard_ops.not_equal(inactive, old_inactive))
+ return standard_ops.logical_and(not_done, not_converged)
+
+ def while_loop_body(iteration, multipliers, inactive, old_inactive):
+ """Performs one iteration of the projection."""
+ del old_inactive # Needed by the condition, but not the body.
+ iteration += 1
+ scale = standard_ops.minimum(
+ 0.0,
+ (radius - standard_ops.reduce_sum(multipliers)) / standard_ops.maximum(
+ 1.0, standard_ops.reduce_sum(inactive)))
+ multipliers += scale * inactive
+ new_inactive = standard_ops.to_float(multipliers > 0)
+ multipliers *= new_inactive
+ return (iteration, multipliers, new_inactive, inactive)
+
+ iteration = standard_ops.constant(0)
+ inactive = standard_ops.ones_like(multipliers)
+
+ # We actually want a do-while loop, so we explicitly call while_loop_body()
+ # once before tf.while_loop().
+ iteration, multipliers, inactive, old_inactive = while_loop_body(
+ iteration, multipliers, inactive, inactive)
+ iteration, multipliers, inactive, old_inactive = control_flow_ops.while_loop(
+ while_loop_condition,
+ while_loop_body,
+ loop_vars=(iteration, multipliers, inactive, old_inactive),
+ name="euclidean_projection")
+
+ return multipliers
+
+
+@six.add_metaclass(abc.ABCMeta)
+class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
+ """Base class representing an `_ExternalRegretOptimizer`.
+
+ This class contains most of the logic for performing constrained
+ optimization, minimizing external regret for the constraints player. What it
+ *doesn't* do is keep track of the internal state (the Lagrange multipliers).
+ Instead, the state is accessed via the _initial_state(),
+ _lagrange_multipliers(), _constraint_grad_and_var() and _projection_op()
+ methods.
+
+ The reason for this is that we want to make it easy to implement different
+ representations of the internal state.
+
+ For more specifics, please refer to:
+
+ > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+ > Constrained Optimization".
+ > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+ The formulation used by `_ExternalRegretOptimizer`s--which is simply the usual
+ Lagrangian formulation--can be found in Definition 1, and is discussed in
+ Section 3. Such optimizers are most similar to Algorithm 3 in Appendix C.3.
+ """
+
+ def __init__(self, optimizer, constraint_optimizer=None):
+ """Constructs a new `_ExternalRegretOptimizer`.
+
+ The difference between `optimizer` and `constraint_optimizer` (if the latter
+ is provided) is that the former is used for learning the model parameters,
+ while the latter us used for the Lagrange multipliers. If no
+ `constraint_optimizer` is provided, then `optimizer` is used for both.
+
+ Args:
+ optimizer: tf.train.Optimizer, used to optimize the objective and
+ proxy_constraints portion of the ConstrainedMinimizationProblem. If
+ constraint_optimizer is not provided, this will also be used to optimize
+ the Lagrange multipliers.
+ constraint_optimizer: optional tf.train.Optimizer, used to optimize the
+ Lagrange multipliers.
+
+ Returns:
+ A new `_ExternalRegretOptimizer`.
+ """
+ super(_ExternalRegretOptimizer, self).__init__(optimizer=optimizer)
+ self._constraint_optimizer = constraint_optimizer
+
+ @property
+ def constraint_optimizer(self):
+ """Returns the `tf.train.Optimizer` used for the Lagrange multipliers."""
+ return self._constraint_optimizer
+
+ @abc.abstractmethod
+ def _initial_state(self, num_constraints):
+ pass
+
+ @abc.abstractmethod
+ def _lagrange_multipliers(self, state):
+ pass
+
+ @abc.abstractmethod
+ def _constraint_grad_and_var(self, state, gradient):
+ pass
+
+ @abc.abstractmethod
+ def _projection_op(self, state, name=None):
+ pass
+
+ def minimize_constrained(self,
+ minimization_problem,
+ global_step=None,
+ var_list=None,
+ gate_gradients=train_optimizer.Optimizer.GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ name=None,
+ grad_loss=None):
+ """Returns an `Op` for minimizing the constrained problem.
+
+ The `optimizer` constructor parameter will be used to update the model
+ parameters, while the Lagrange multipliers will be updated using
+ `constrained_optimizer` (if provided) or `optimizer` (if not).
+
+ Args:
+ minimization_problem: ConstrainedMinimizationProblem, the problem to
+ optimize.
+ global_step: as in `tf.train.Optimizer`'s `minimize` method.
+ var_list: as in `tf.train.Optimizer`'s `minimize` method.
+ gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
+ aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
+ colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
+ method.
+ name: as in `tf.train.Optimizer`'s `minimize` method.
+ grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
+
+ Returns:
+ TensorFlow Op.
+ """
+ objective = minimization_problem.objective
+
+ constraints = minimization_problem.constraints
+ proxy_constraints = minimization_problem.proxy_constraints
+ if proxy_constraints is None:
+ proxy_constraints = constraints
+ # Flatten both constraints tensors to 1d.
+ num_constraints = minimization_problem.num_constraints
+ constraints = standard_ops.reshape(constraints, shape=(num_constraints,))
+ proxy_constraints = standard_ops.reshape(
+ proxy_constraints, shape=(num_constraints,))
+
+ # We use a lambda to initialize the state so that, if this function call is
+ # inside the scope of a tf.control_dependencies() block, the dependencies
+ # will not be applied to the initializer.
+ state = standard_ops.Variable(
+ lambda: self._initial_state(num_constraints),
+ trainable=False,
+ name="external_regret_optimizer_state")
+
+ multipliers = self._lagrange_multipliers(state)
+ loss = (
+ objective + standard_ops.tensordot(multipliers, proxy_constraints, 1))
+ multipliers_gradient = constraints
+
+ update_ops = []
+ if self.constraint_optimizer is None:
+ # If we don't have a separate constraint_optimizer, then we use
+ # self._optimizer for both the update of the model parameters, and that of
+ # the internal state.
+ grads_and_vars = self.optimizer.compute_gradients(
+ loss,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ grad_loss=grad_loss)
+ grads_and_vars.append(
+ self._constraint_grad_and_var(state, multipliers_gradient))
+ update_ops.append(
+ self.optimizer.apply_gradients(grads_and_vars, name="update"))
+ else:
+ # If we have a separate constraint_optimizer, then we use self._optimizer
+ # for the update of the model parameters, and self._constraint_optimizer
+ # for that of the internal state.
+ grads_and_vars = self.optimizer.compute_gradients(
+ loss,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ grad_loss=grad_loss)
+ multiplier_grads_and_vars = [
+ self._constraint_grad_and_var(state, multipliers_gradient)
+ ]
+
+ gradients = [
+ gradient for gradient, _ in grads_and_vars + multiplier_grads_and_vars
+ if gradient is not None
+ ]
+ with ops.control_dependencies(gradients):
+ update_ops.append(
+ self.optimizer.apply_gradients(grads_and_vars, name="update"))
+ update_ops.append(
+ self.constraint_optimizer.apply_gradients(
+ multiplier_grads_and_vars, name="optimizer_state_update"))
+
+ with ops.control_dependencies(update_ops):
+ if global_step is None:
+ # If we don't have a global step, just project, and we're done.
+ return self._projection_op(state, name=name)
+ else:
+ # If we have a global step, then we need to increment it in addition to
+ # projecting.
+ projection_op = self._projection_op(state, name="project")
+ with ops.colocate_with(global_step):
+ global_step_op = state_ops.assign_add(
+ global_step, 1, name="global_step_increment")
+ return control_flow_ops.group(projection_op, global_step_op, name=name)
+
+
+class AdditiveExternalRegretOptimizer(_ExternalRegretOptimizer):
+ """A `ConstrainedOptimizer` based on external-regret minimization.
+
+ This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly
+ minimize over the model parameters, and maximize over Lagrange multipliers,
+ with the latter maximization using additive updates and an algorithm that
+ minimizes external regret.
+
+ For more specifics, please refer to:
+
+ > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+ > Constrained Optimization".
+ > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+ The formulation used by this optimizer--which is simply the usual Lagrangian
+ formulation--can be found in Definition 1, and is discussed in Section 3. It
+ is most similar to Algorithm 3 in Appendix C.3, with the two differences being
+ that it uses proxy constraints (if they're provided) in the update of the
+ model parameters, and uses `tf.train.Optimizer`s, instead of SGD, for the
+ "inner" updates.
+ """
+
+ def __init__(self,
+ optimizer,
+ constraint_optimizer=None,
+ maximum_multiplier_radius=None):
+ """Constructs a new `AdditiveExternalRegretOptimizer`.
+
+ Args:
+ optimizer: tf.train.Optimizer, used to optimize the objective and
+ proxy_constraints portion of ConstrainedMinimizationProblem. If
+ constraint_optimizer is not provided, this will also be used to optimize
+ the Lagrange multipliers.
+ constraint_optimizer: optional tf.train.Optimizer, used to optimize the
+ Lagrange multipliers.
+ maximum_multiplier_radius: float, an optional upper bound to impose on the
+ sum of the Lagrange multipliers.
+
+ Returns:
+ A new `AdditiveExternalRegretOptimizer`.
+
+ Raises:
+ ValueError: If the maximum_multiplier_radius parameter is nonpositive.
+ """
+ super(AdditiveExternalRegretOptimizer, self).__init__(
+ optimizer=optimizer, constraint_optimizer=constraint_optimizer)
+
+ if maximum_multiplier_radius and (maximum_multiplier_radius <= 0.0):
+ raise ValueError("maximum_multiplier_radius must be strictly positive")
+
+ self._maximum_multiplier_radius = maximum_multiplier_radius
+
+ def _initial_state(self, num_constraints):
+ # For an AdditiveExternalRegretOptimizer, the internal state is simply a
+ # tensor of Lagrange multipliers with shape (m,), where m is the number of
+ # constraints.
+ return standard_ops.zeros((num_constraints,), dtype=dtypes.float32)
+
+ def _lagrange_multipliers(self, state):
+ return state
+
+ def _constraint_grad_and_var(self, state, gradient):
+ # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True?
+ return (-gradient, state)
+
+ def _projection_op(self, state, name=None):
+ with ops.colocate_with(state):
+ if self._maximum_multiplier_radius:
+ projected_multipliers = _project_multipliers_wrt_euclidean_norm(
+ state, self._maximum_multiplier_radius)
+ else:
+ projected_multipliers = standard_ops.maximum(state, 0.0)
+ return state_ops.assign(state, projected_multipliers, name=name)
diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py
new file mode 100644
index 0000000000..9b4bf62710
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py
@@ -0,0 +1,136 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for constrained_optimization.python.external_regret_optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.constrained_optimization.python import external_regret_optimizer
+from tensorflow.contrib.constrained_optimization.python import test_util
+
+from tensorflow.python.ops import standard_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import gradient_descent
+
+
+class AdditiveExternalRegretOptimizerWrapper(
+ external_regret_optimizer.AdditiveExternalRegretOptimizer):
+ """Testing wrapper class around AdditiveExternalRegretOptimizer.
+
+ This class is identical to AdditiveExternalRegretOptimizer, except that it
+ caches the internal optimization state when _lagrange_multipliers() is called,
+ so that we can test that the Lagrange multipliers take on their expected
+ values.
+ """
+
+ def __init__(self,
+ optimizer,
+ constraint_optimizer=None,
+ maximum_multiplier_radius=None):
+ """Same as AdditiveExternalRegretOptimizer.__init__."""
+ super(AdditiveExternalRegretOptimizerWrapper, self).__init__(
+ optimizer=optimizer,
+ constraint_optimizer=constraint_optimizer,
+ maximum_multiplier_radius=maximum_multiplier_radius)
+ self._cached_lagrange_multipliers = None
+
+ @property
+ def lagrange_multipliers(self):
+ """Returns the cached Lagrange multipliers."""
+ return self._cached_lagrange_multipliers
+
+ def _lagrange_multipliers(self, state):
+ """Caches the internal state for testing."""
+ self._cached_lagrange_multipliers = super(
+ AdditiveExternalRegretOptimizerWrapper,
+ self)._lagrange_multipliers(state)
+ return self._cached_lagrange_multipliers
+
+
+class ExternalRegretOptimizerTest(test.TestCase):
+
+ def test_project_multipliers_wrt_euclidean_norm(self):
+ """Tests Euclidean projection routine on some known values."""
+ multipliers1 = standard_ops.constant([-0.1, -0.6, -0.3])
+ expected_projected_multipliers1 = np.array([0.0, 0.0, 0.0])
+
+ multipliers2 = standard_ops.constant([-0.1, 0.6, 0.3])
+ expected_projected_multipliers2 = np.array([0.0, 0.6, 0.3])
+
+ multipliers3 = standard_ops.constant([0.4, 0.7, -0.2, 0.5, 0.1])
+ expected_projected_multipliers3 = np.array([0.2, 0.5, 0.0, 0.3, 0.0])
+
+ with self.test_session() as session:
+ projected_multipliers1 = session.run(
+ external_regret_optimizer._project_multipliers_wrt_euclidean_norm(
+ multipliers1, 1.0))
+ projected_multipliers2 = session.run(
+ external_regret_optimizer._project_multipliers_wrt_euclidean_norm(
+ multipliers2, 1.0))
+ projected_multipliers3 = session.run(
+ external_regret_optimizer._project_multipliers_wrt_euclidean_norm(
+ multipliers3, 1.0))
+
+ self.assertAllClose(
+ expected_projected_multipliers1,
+ projected_multipliers1,
+ rtol=0,
+ atol=1e-6)
+ self.assertAllClose(
+ expected_projected_multipliers2,
+ projected_multipliers2,
+ rtol=0,
+ atol=1e-6)
+ self.assertAllClose(
+ expected_projected_multipliers3,
+ projected_multipliers3,
+ rtol=0,
+ atol=1e-6)
+
+ def test_additive_external_regret_optimizer(self):
+ """Tests that the Lagrange multipliers update as expected."""
+ minimization_problem = test_util.ConstantMinimizationProblem(
+ np.array([0.6, -0.1, 0.4]))
+ optimizer = AdditiveExternalRegretOptimizerWrapper(
+ gradient_descent.GradientDescentOptimizer(1.0),
+ maximum_multiplier_radius=1.0)
+ train_op = optimizer.minimize_constrained(minimization_problem)
+
+ expected_multipliers = [
+ np.array([0.0, 0.0, 0.0]),
+ np.array([0.6, 0.0, 0.4]),
+ np.array([0.7, 0.0, 0.3]),
+ np.array([0.8, 0.0, 0.2]),
+ np.array([0.9, 0.0, 0.1]),
+ np.array([1.0, 0.0, 0.0]),
+ np.array([1.0, 0.0, 0.0]),
+ ]
+
+ multipliers = []
+ with self.test_session() as session:
+ session.run(standard_ops.global_variables_initializer())
+ while len(multipliers) < len(expected_multipliers):
+ multipliers.append(session.run(optimizer.lagrange_multipliers))
+ session.run(train_op)
+
+ for expected, actual in zip(expected_multipliers, multipliers):
+ self.assertAllClose(expected, actual, rtol=0, atol=1e-6)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py
new file mode 100644
index 0000000000..04014ab4ae
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py
@@ -0,0 +1,595 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines `{Additive,Multiplicative}SwapRegretOptimizer`s.
+
+These optimizers minimize a `ConstrainedMinimizationProblem` by using a
+swap-regret minimizing algorithm (either SGD or multiplicative weights) to learn
+what weights should be associated with the objective function and constraints.
+These algorithms do *not* use Lagrange multipliers, but the idea is similar.
+The main differences between the formulation used here, and the standard
+Lagrangian formulation, are that (i) the objective function is weighted, in
+addition to the constraints, and (ii) we learn a matrix of weights, instead of a
+vector.
+
+For the purposes of constrained optimization, at least in theory,
+external-regret minimization suffices if the `ConstrainedMinimizationProblem`
+we're optimizing doesn't have any `proxy_constraints`, while swap-regret
+minimization should be used if `proxy_constraints` are present.
+
+For more specifics, please refer to:
+
+> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+> Constrained Optimization".
+> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+The formulation used by both of the SwapRegretOptimizers can be found in
+Definition 2, and is discussed in Section 4. The
+`MultiplicativeSwapRegretOptimizer` is most similar to Algorithm 2 in Section 4,
+with the difference being that it uses `tf.train.Optimizer`s, instead of SGD,
+for the "inner" updates. The `AdditiveSwapRegretOptimizer` differs further in
+that it performs additive (instead of multiplicative) updates of the stochastic
+matrix.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import math
+
+import six
+
+from tensorflow.contrib.constrained_optimization.python import constrained_optimizer
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import standard_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.training import optimizer as train_optimizer
+
+
+def _maximal_eigenvector_power_method(matrix,
+ epsilon=1e-6,
+ maximum_iterations=100):
+ """Returns the maximal right-eigenvector of `matrix` using the power method.
+
+ Args:
+ matrix: 2D Tensor, the matrix of which we will find the maximal
+ right-eigenvector.
+ epsilon: nonnegative float, if two iterations of the power method differ (in
+ L2 norm) by no more than epsilon, we will terminate.
+ maximum_iterations: nonnegative int, if we perform this many iterations, we
+ will terminate.
+
+ Result:
+ The maximal right-eigenvector of `matrix`.
+
+ Raises:
+ ValueError: If the epsilon or maximum_iterations parameters violate their
+ bounds.
+ """
+ if epsilon <= 0.0:
+ raise ValueError("epsilon must be strictly positive")
+ if maximum_iterations <= 0:
+ raise ValueError("maximum_iterations must be strictly positive")
+
+ def while_loop_condition(iteration, eigenvector, old_eigenvector):
+ """Returns false if the while loop should terminate."""
+ not_done = (iteration < maximum_iterations)
+ not_converged = (standard_ops.norm(eigenvector - old_eigenvector) > epsilon)
+ return standard_ops.logical_and(not_done, not_converged)
+
+ def while_loop_body(iteration, eigenvector, old_eigenvector):
+ """Performs one iteration of the power method."""
+ del old_eigenvector # Needed by the condition, but not the body.
+ iteration += 1
+ # We need to use tf.matmul() and tf.expand_dims(), instead of
+ # tf.tensordot(), since the former will infer the shape of the result, while
+ # the latter will not (tf.while_loop() needs the shapes).
+ new_eigenvector = standard_ops.matmul(
+ matrix, standard_ops.expand_dims(eigenvector, 1))[:, 0]
+ new_eigenvector /= standard_ops.norm(new_eigenvector)
+ return (iteration, new_eigenvector, eigenvector)
+
+ iteration = standard_ops.constant(0)
+ eigenvector = standard_ops.ones_like(matrix[:, 0])
+ eigenvector /= standard_ops.norm(eigenvector)
+
+ # We actually want a do-while loop, so we explicitly call while_loop_body()
+ # once before tf.while_loop().
+ iteration, eigenvector, old_eigenvector = while_loop_body(
+ iteration, eigenvector, eigenvector)
+ iteration, eigenvector, old_eigenvector = control_flow_ops.while_loop(
+ while_loop_condition,
+ while_loop_body,
+ loop_vars=(iteration, eigenvector, old_eigenvector),
+ name="power_method")
+
+ return eigenvector
+
+
+def _project_stochastic_matrix_wrt_euclidean_norm(matrix):
+ """Projects its argument onto the set of left-stochastic matrices.
+
+ This algorithm is O(n^3) at worst, where `matrix` is n*n. It can be done in
+ O(n^2 * log(n)) time by sorting each column (and maybe better with a different
+ algorithm), but the algorithm implemented here is easier to implement in
+ TensorFlow.
+
+ Args:
+ matrix: 2d square tensor, the matrix to project.
+
+ Returns:
+ The 2d square tensor that results from projecting `matrix` onto the set of
+ left-stochastic matrices w.r.t. the Euclidean norm applied column-wise
+ (i.e. the Frobenius norm).
+
+ Raises:
+ ValueError: if the `matrix` tensor does not have a fully-known shape, or is
+ not two-dimensional and square.
+ """
+ matrix_shape = matrix.get_shape()
+ if matrix_shape is None:
+ raise ValueError("matrix must have known shape")
+ if matrix_shape.ndims != 2:
+ raise ValueError(
+ "matrix must be two dimensional (instead is %d-dimensional)" %
+ matrix_shape.ndims)
+ if matrix_shape[0] != matrix_shape[1]:
+ raise ValueError("matrix must be be square (instead has shape (%d,%d))" %
+ (matrix_shape[0], matrix_shape[1]))
+ dimension = matrix_shape[0].value
+ if dimension is None:
+ raise ValueError("matrix must have fully-known shape")
+
+ def while_loop_condition(iteration, matrix, inactive, old_inactive):
+ """Returns false if the while loop should terminate."""
+ del matrix # Needed by the body, but not the condition.
+ not_done = (iteration < dimension)
+ not_converged = standard_ops.reduce_any(
+ standard_ops.not_equal(inactive, old_inactive))
+ return standard_ops.logical_and(not_done, not_converged)
+
+ def while_loop_body(iteration, matrix, inactive, old_inactive):
+ """Performs one iteration of the projection."""
+ del old_inactive # Needed by the condition, but not the body.
+ iteration += 1
+ scale = (1.0 - standard_ops.reduce_sum(
+ matrix, axis=0, keep_dims=True)) / standard_ops.maximum(
+ 1.0, standard_ops.reduce_sum(inactive, axis=0, keep_dims=True))
+ matrix += scale * inactive
+ new_inactive = standard_ops.to_float(matrix > 0)
+ matrix *= new_inactive
+ return (iteration, matrix, new_inactive, inactive)
+
+ iteration = standard_ops.constant(0)
+ inactive = standard_ops.ones_like(matrix)
+
+ # We actually want a do-while loop, so we explicitly call while_loop_body()
+ # once before tf.while_loop().
+ iteration, matrix, inactive, old_inactive = while_loop_body(
+ iteration, matrix, inactive, inactive)
+ iteration, matrix, inactive, old_inactive = control_flow_ops.while_loop(
+ while_loop_condition,
+ while_loop_body,
+ loop_vars=(iteration, matrix, inactive, old_inactive),
+ name="euclidean_projection")
+
+ return matrix
+
+
+def _project_log_stochastic_matrix_wrt_kl_divergence(log_matrix):
+ """Projects its argument onto the set of log-left-stochastic matrices.
+
+ Args:
+ log_matrix: 2d square tensor, the element-wise logarithm of the matrix to
+ project.
+
+ Returns:
+ The 2d square tensor that results from projecting exp(`matrix`) onto the set
+ of left-stochastic matrices w.r.t. the KL-divergence applied column-wise.
+ """
+
+ # For numerical reasons, make sure that the largest matrix element is zero
+ # before exponentiating.
+ log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keep_dims=True)
+ log_matrix -= standard_ops.log(
+ standard_ops.reduce_sum(
+ standard_ops.exp(log_matrix), axis=0, keep_dims=True))
+ return log_matrix
+
+
+@six.add_metaclass(abc.ABCMeta)
+class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
+ """Base class representing a `_SwapRegretOptimizer`.
+
+ This class contains most of the logic for performing constrained optimization,
+ minimizing external regret for the constraints player. What it *doesn't* do is
+ keep track of the internal state (the stochastic matrix). Instead, the state
+ is accessed via the _initial_state(), _stochastic_matrix(),
+ _constraint_grad_and_var() and _projection_op() methods.
+
+ The reason for this is that we want to make it easy to implement different
+ representations of the internal state. For example, for additive updates, it's
+ most natural to store the stochastic matrix directly, whereas for
+ multiplicative updates, it's most natural to store its element-wise logarithm.
+
+ For more specifics, please refer to:
+
+ > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+ > Constrained Optimization".
+ > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+ The formulation used by `_SwapRegretOptimizer`s can be found in Definition 2,
+ and is discussed in Section 4. Such optimizers are most similar to Algorithm
+ 2 in Section 4. Most notably, the internal state is a left-stochastic matrix
+ of shape (m+1,m+1), where m is the number of constraints.
+ """
+
+ def __init__(self, optimizer, constraint_optimizer=None):
+ """Constructs a new `_SwapRegretOptimizer`.
+
+ The difference between `optimizer` and `constraint_optimizer` (if the latter
+ is provided) is that the former is used for learning the model parameters,
+ while the latter us used for the update to the constraint/objective weight
+ matrix (the analogue of Lagrange multipliers). If no `constraint_optimizer`
+ is provided, then `optimizer` is used for both.
+
+ Args:
+ optimizer: tf.train.Optimizer, used to optimize the objective and
+ proxy_constraints portion of ConstrainedMinimizationProblem. If
+ constraint_optimizer is not provided, this will also be used to optimize
+ the Lagrange multiplier analogues.
+ constraint_optimizer: optional tf.train.Optimizer, used to optimize the
+ Lagrange multiplier analogues.
+
+ Returns:
+ A new `_SwapRegretOptimizer`.
+ """
+ super(_SwapRegretOptimizer, self).__init__(optimizer=optimizer)
+ self._constraint_optimizer = constraint_optimizer
+
+ @property
+ def constraint_optimizer(self):
+ """Returns the `tf.train.Optimizer` used for the matrix."""
+ return self._constraint_optimizer
+
+ @abc.abstractmethod
+ def _initial_state(self, num_constraints):
+ pass
+
+ @abc.abstractmethod
+ def _stochastic_matrix(self, state):
+ pass
+
+ def _distribution(self, state):
+ distribution = _maximal_eigenvector_power_method(
+ self._stochastic_matrix(state))
+ distribution = standard_ops.abs(distribution)
+ distribution /= standard_ops.reduce_sum(distribution)
+ return distribution
+
+ @abc.abstractmethod
+ def _constraint_grad_and_var(self, state, gradient):
+ pass
+
+ @abc.abstractmethod
+ def _projection_op(self, state, name=None):
+ pass
+
+ def minimize_constrained(self,
+ minimization_problem,
+ global_step=None,
+ var_list=None,
+ gate_gradients=train_optimizer.Optimizer.GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ name=None,
+ grad_loss=None):
+ """Returns an `Op` for minimizing the constrained problem.
+
+ The `optimizer` constructor parameter will be used to update the model
+ parameters, while the constraint/objective weight matrix (the analogue of
+ Lagrange multipliers) will be updated using `constrained_optimizer` (if
+ provided) or `optimizer` (if not). Whether the matrix updates are additive
+ or multiplicative depends on the derived class.
+
+ Args:
+ minimization_problem: ConstrainedMinimizationProblem, the problem to
+ optimize.
+ global_step: as in `tf.train.Optimizer`'s `minimize` method.
+ var_list: as in `tf.train.Optimizer`'s `minimize` method.
+ gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
+ aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
+ colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
+ method.
+ name: as in `tf.train.Optimizer`'s `minimize` method.
+ grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
+
+ Returns:
+ TensorFlow Op.
+ """
+ objective = minimization_problem.objective
+
+ constraints = minimization_problem.constraints
+ proxy_constraints = minimization_problem.proxy_constraints
+ if proxy_constraints is None:
+ proxy_constraints = constraints
+ # Flatten both constraints tensors to 1d.
+ num_constraints = minimization_problem.num_constraints
+ constraints = standard_ops.reshape(constraints, shape=(num_constraints,))
+ proxy_constraints = standard_ops.reshape(
+ proxy_constraints, shape=(num_constraints,))
+
+ # We use a lambda to initialize the state so that, if this function call is
+ # inside the scope of a tf.control_dependencies() block, the dependencies
+ # will not be applied to the initializer.
+ state = standard_ops.Variable(
+ lambda: self._initial_state(num_constraints),
+ trainable=False,
+ name="swap_regret_optimizer_state")
+
+ zero_and_constraints = standard_ops.concat(
+ (standard_ops.zeros((1,)), constraints), axis=0)
+ objective_and_proxy_constraints = standard_ops.concat(
+ (standard_ops.expand_dims(objective, 0), proxy_constraints), axis=0)
+
+ distribution = self._distribution(state)
+ loss = standard_ops.tensordot(distribution, objective_and_proxy_constraints,
+ 1)
+ matrix_gradient = standard_ops.matmul(
+ standard_ops.expand_dims(zero_and_constraints, 1),
+ standard_ops.expand_dims(distribution, 0))
+
+ update_ops = []
+ if self.constraint_optimizer is None:
+ # If we don't have a separate constraint_optimizer, then we use
+ # self._optimizer for both the update of the model parameters, and that of
+ # the internal state.
+ grads_and_vars = self.optimizer.compute_gradients(
+ loss,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ grad_loss=grad_loss)
+ grads_and_vars.append(
+ self._constraint_grad_and_var(state, matrix_gradient))
+ update_ops.append(
+ self.optimizer.apply_gradients(grads_and_vars, name="update"))
+ else:
+ # If we have a separate constraint_optimizer, then we use self._optimizer
+ # for the update of the model parameters, and self._constraint_optimizer
+ # for that of the internal state.
+ grads_and_vars = self.optimizer.compute_gradients(
+ loss,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ grad_loss=grad_loss)
+ matrix_grads_and_vars = [
+ self._constraint_grad_and_var(state, matrix_gradient)
+ ]
+
+ gradients = [
+ gradient for gradient, _ in grads_and_vars + matrix_grads_and_vars
+ if gradient is not None
+ ]
+ with ops.control_dependencies(gradients):
+ update_ops.append(
+ self.optimizer.apply_gradients(grads_and_vars, name="update"))
+ update_ops.append(
+ self.constraint_optimizer.apply_gradients(
+ matrix_grads_and_vars, name="optimizer_state_update"))
+
+ with ops.control_dependencies(update_ops):
+ if global_step is None:
+ # If we don't have a global step, just project, and we're done.
+ return self._projection_op(state, name=name)
+ else:
+ # If we have a global step, then we need to increment it in addition to
+ # projecting.
+ projection_op = self._projection_op(state, name="project")
+ with ops.colocate_with(global_step):
+ global_step_op = state_ops.assign_add(
+ global_step, 1, name="global_step_increment")
+ return control_flow_ops.group(projection_op, global_step_op, name=name)
+
+
+class AdditiveSwapRegretOptimizer(_SwapRegretOptimizer):
+ """A `ConstrainedOptimizer` based on swap-regret minimization.
+
+ This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly
+ minimize over the model parameters, and maximize over constraint/objective
+ weight matrix (the analogue of Lagrange multipliers), with the latter
+ maximization using additive updates and an algorithm that minimizes swap
+ regret.
+
+ For more specifics, please refer to:
+
+ > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+ > Constrained Optimization".
+ > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+ The formulation used by this optimizer can be found in Definition 2, and is
+ discussed in Section 4. It is most similar to Algorithm 2 in Section 4, with
+ the differences being that it uses `tf.train.Optimizer`s, instead of SGD, for
+ the "inner" updates, and performs additive (instead of multiplicative) updates
+ of the stochastic matrix.
+ """
+
+ def __init__(self, optimizer, constraint_optimizer=None):
+ """Constructs a new `AdditiveSwapRegretOptimizer`.
+
+ Args:
+ optimizer: tf.train.Optimizer, used to optimize the objective and
+ proxy_constraints portion of ConstrainedMinimizationProblem. If
+ constraint_optimizer is not provided, this will also be used to optimize
+ the Lagrange multiplier analogues.
+ constraint_optimizer: optional tf.train.Optimizer, used to optimize the
+ Lagrange multiplier analogues.
+
+ Returns:
+ A new `AdditiveSwapRegretOptimizer`.
+ """
+ # TODO(acotter): add a parameter determining the initial values of the
+ # matrix elements (like initial_multiplier_radius in
+ # MultiplicativeSwapRegretOptimizer).
+ super(AdditiveSwapRegretOptimizer, self).__init__(
+ optimizer=optimizer, constraint_optimizer=constraint_optimizer)
+
+ def _initial_state(self, num_constraints):
+ # For an AdditiveSwapRegretOptimizer, the internal state is a tensor of
+ # shape (m+1,m+1), where m is the number of constraints, representing a
+ # left-stochastic matrix.
+ dimension = num_constraints + 1
+ # Initialize by putting all weight on the objective, and none on the
+ # constraints.
+ return standard_ops.concat(
+ (standard_ops.ones(
+ (1, dimension)), standard_ops.zeros((dimension - 1, dimension))),
+ axis=0)
+
+ def _stochastic_matrix(self, state):
+ return state
+
+ def _constraint_grad_and_var(self, state, gradient):
+ # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True?
+ return (-gradient, state)
+
+ def _projection_op(self, state, name=None):
+ with ops.colocate_with(state):
+ return state_ops.assign(
+ state,
+ _project_stochastic_matrix_wrt_euclidean_norm(state),
+ name=name)
+
+
+class MultiplicativeSwapRegretOptimizer(_SwapRegretOptimizer):
+ """A `ConstrainedOptimizer` based on swap-regret minimization.
+
+ This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly
+ minimize over the model parameters, and maximize over constraint/objective
+ weight matrix (the analogue of Lagrange multipliers), with the latter
+ maximization using multiplicative updates and an algorithm that minimizes swap
+ regret.
+
+ For more specifics, please refer to:
+
+ > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+ > Constrained Optimization".
+ > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+ The formulation used by this optimizer can be found in Definition 2, and is
+ discussed in Section 4. It is most similar to Algorithm 2 in Section 4, with
+ the difference being that it uses `tf.train.Optimizer`s, instead of SGD, for
+ the "inner" updates.
+ """
+
+ def __init__(self,
+ optimizer,
+ constraint_optimizer=None,
+ minimum_multiplier_radius=1e-3,
+ initial_multiplier_radius=None):
+ """Constructs a new `MultiplicativeSwapRegretOptimizer`.
+
+ Args:
+ optimizer: tf.train.Optimizer, used to optimize the objective and
+ proxy_constraints portion of ConstrainedMinimizationProblem. If
+ constraint_optimizer is not provided, this will also be used to optimize
+ the Lagrange multiplier analogues.
+ constraint_optimizer: optional tf.train.Optimizer, used to optimize the
+ Lagrange multiplier analogues.
+ minimum_multiplier_radius: float, each element of the matrix will be lower
+ bounded by `minimum_multiplier_radius` divided by one plus the number of
+ constraints.
+ initial_multiplier_radius: float, the initial value of each element of the
+ matrix associated with a constraint (i.e. excluding those elements
+ associated with the objective) will be `initial_multiplier_radius`
+ divided by one plus the number of constraints. Defaults to the value of
+ `minimum_multiplier_radius`.
+
+ Returns:
+ A new `MultiplicativeSwapRegretOptimizer`.
+
+ Raises:
+ ValueError: If the two radius parameters are inconsistent.
+ """
+ super(MultiplicativeSwapRegretOptimizer, self).__init__(
+ optimizer=optimizer, constraint_optimizer=constraint_optimizer)
+
+ if (minimum_multiplier_radius <= 0.0) or (minimum_multiplier_radius >= 1.0):
+ raise ValueError("minimum_multiplier_radius must be in the range (0,1)")
+ if initial_multiplier_radius is None:
+ initial_multiplier_radius = minimum_multiplier_radius
+ elif (initial_multiplier_radius <
+ minimum_multiplier_radius) or (minimum_multiplier_radius > 1.0):
+ raise ValueError("initial_multiplier_radius must be in the range "
+ "[minimum_multiplier_radius,1]")
+
+ self._minimum_multiplier_radius = minimum_multiplier_radius
+ self._initial_multiplier_radius = initial_multiplier_radius
+
+ def _initial_state(self, num_constraints):
+ # For a MultiplicativeSwapRegretOptimizer, the internal state is a tensor of
+ # shape (m+1,m+1), where m is the number of constraints, representing the
+ # element-wise logarithm of a left-stochastic matrix.
+ dimension = num_constraints + 1
+ # Initialize by putting as much weight as possible on the objective, and as
+ # little as possible on the constraints.
+ log_initial_one = math.log(1.0 - (self._initial_multiplier_radius *
+ (dimension - 1) / (dimension)))
+ log_initial_zero = math.log(self._initial_multiplier_radius / dimension)
+ return standard_ops.concat(
+ (standard_ops.constant(
+ log_initial_one, dtype=dtypes.float32, shape=(1, dimension)),
+ standard_ops.constant(
+ log_initial_zero,
+ dtype=dtypes.float32,
+ shape=(dimension - 1, dimension))),
+ axis=0)
+
+ def _stochastic_matrix(self, state):
+ return standard_ops.exp(state)
+
+ def _constraint_grad_and_var(self, state, gradient):
+ # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True?
+ return (-gradient, state)
+
+ def _projection_op(self, state, name=None):
+ with ops.colocate_with(state):
+ # Gets the dimension of the state (num_constraints + 1)--all of these
+ # assertions are of things that should be impossible, since the state
+ # passed into this method will have the same shape as that returned by
+ # _initial_state().
+ state_shape = state.get_shape()
+ assert state_shape is not None
+ assert state_shape.ndims == 2
+ assert state_shape[0] == state_shape[1]
+ dimension = state_shape[0].value
+ assert dimension is not None
+
+ minimum_log_multiplier = standard_ops.log(
+ self._minimum_multiplier_radius / standard_ops.to_float(dimension))
+
+ return state_ops.assign(
+ state,
+ standard_ops.maximum(
+ _project_log_stochastic_matrix_wrt_kl_divergence(state),
+ minimum_log_multiplier),
+ name=name)
diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py
new file mode 100644
index 0000000000..34c4543dca
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py
@@ -0,0 +1,212 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for constrained_optimization.python.swap_regret_optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.constrained_optimization.python import swap_regret_optimizer
+from tensorflow.contrib.constrained_optimization.python import test_util
+
+from tensorflow.python.ops import standard_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import gradient_descent
+
+
+class AdditiveSwapRegretOptimizerWrapper(
+ swap_regret_optimizer.AdditiveSwapRegretOptimizer):
+ """Testing wrapper class around AdditiveSwapRegretOptimizer.
+
+ This class is identical to AdditiveSwapRegretOptimizer, except that it caches
+ the internal optimization state when _stochastic_matrix() is called, so that
+ we can test that the stochastic matrices take on their expected values.
+ """
+
+ def __init__(self, optimizer, constraint_optimizer=None):
+ """Same as AdditiveSwapRegretOptimizer.__init__()."""
+ super(AdditiveSwapRegretOptimizerWrapper, self).__init__(
+ optimizer=optimizer, constraint_optimizer=constraint_optimizer)
+ self._cached_stochastic_matrix = None
+
+ @property
+ def stochastic_matrix(self):
+ """Returns the cached stochastic matrix."""
+ return self._cached_stochastic_matrix
+
+ def _stochastic_matrix(self, state):
+ """Caches the internal state for testing."""
+ self._cached_stochastic_matrix = super(AdditiveSwapRegretOptimizerWrapper,
+ self)._stochastic_matrix(state)
+ return self._cached_stochastic_matrix
+
+
+class MultiplicativeSwapRegretOptimizerWrapper(
+ swap_regret_optimizer.MultiplicativeSwapRegretOptimizer):
+ """Testing wrapper class around MultiplicativeSwapRegretOptimizer.
+
+ This class is identical to MultiplicativeSwapRegretOptimizer, except that it
+ caches the internal optimization state when _stochastic_matrix() is called, so
+ that we can test that the stochastic matrices take on their expected values.
+ """
+
+ def __init__(self,
+ optimizer,
+ constraint_optimizer=None,
+ minimum_multiplier_radius=None,
+ initial_multiplier_radius=None):
+ """Same as MultiplicativeSwapRegretOptimizer.__init__()."""
+ super(MultiplicativeSwapRegretOptimizerWrapper, self).__init__(
+ optimizer=optimizer,
+ constraint_optimizer=constraint_optimizer,
+ minimum_multiplier_radius=1e-3,
+ initial_multiplier_radius=initial_multiplier_radius)
+ self._cached_stochastic_matrix = None
+
+ @property
+ def stochastic_matrix(self):
+ """Returns the cached stochastic matrix."""
+ return self._cached_stochastic_matrix
+
+ def _stochastic_matrix(self, state):
+ """Caches the internal state for testing."""
+ self._cached_stochastic_matrix = super(
+ MultiplicativeSwapRegretOptimizerWrapper,
+ self)._stochastic_matrix(state)
+ return self._cached_stochastic_matrix
+
+
+class SwapRegretOptimizerTest(test.TestCase):
+
+ def test_maximum_eigenvector_power_method(self):
+ """Tests power method routine on some known left-stochastic matrices."""
+ matrix1 = np.matrix([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], [0.4, 0.3, 0.0]])
+ matrix2 = np.matrix([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], [0.4, 0.5, 0.3]])
+
+ with self.test_session() as session:
+ eigenvector1 = session.run(
+ swap_regret_optimizer._maximal_eigenvector_power_method(
+ standard_ops.constant(matrix1)))
+ eigenvector2 = session.run(
+ swap_regret_optimizer._maximal_eigenvector_power_method(
+ standard_ops.constant(matrix2)))
+
+ # Check that eigenvector1 and eigenvector2 are eigenvectors of matrix1 and
+ # matrix2 (respectively) with associated eigenvalue 1.
+ matrix_eigenvector1 = np.tensordot(matrix1, eigenvector1, axes=1)
+ matrix_eigenvector2 = np.tensordot(matrix2, eigenvector2, axes=1)
+ self.assertAllClose(eigenvector1, matrix_eigenvector1, rtol=0, atol=1e-6)
+ self.assertAllClose(eigenvector2, matrix_eigenvector2, rtol=0, atol=1e-6)
+
+ def test_project_stochastic_matrix_wrt_euclidean_norm(self):
+ """Tests Euclidean projection routine on some known values."""
+ matrix = standard_ops.constant([[-0.1, -0.1, 0.4], [-0.8, 0.4, 1.2],
+ [-0.3, 0.1, 0.2]])
+ expected_projected_matrix = np.array([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9],
+ [0.4, 0.3, 0.0]])
+
+ with self.test_session() as session:
+ projected_matrix = session.run(
+ swap_regret_optimizer._project_stochastic_matrix_wrt_euclidean_norm(
+ matrix))
+
+ self.assertAllClose(
+ expected_projected_matrix, projected_matrix, rtol=0, atol=1e-6)
+
+ def test_project_log_stochastic_matrix_wrt_kl_divergence(self):
+ """Tests KL-divergence projection routine on some known values."""
+ matrix = standard_ops.constant([[0.2, 0.8, 0.6], [0.1, 0.2, 1.5],
+ [0.2, 1.0, 0.9]])
+ expected_projected_matrix = np.array([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5],
+ [0.4, 0.5, 0.3]])
+
+ with self.test_session() as session:
+ projected_matrix = session.run(
+ standard_ops.exp(
+ swap_regret_optimizer.
+ _project_log_stochastic_matrix_wrt_kl_divergence(
+ standard_ops.log(matrix))))
+
+ self.assertAllClose(
+ expected_projected_matrix, projected_matrix, rtol=0, atol=1e-6)
+
+ def test_additive_swap_regret_optimizer(self):
+ """Tests that the stochastic matrices update as expected."""
+ minimization_problem = test_util.ConstantMinimizationProblem(
+ np.array([0.6, -0.1, 0.4]))
+ optimizer = AdditiveSwapRegretOptimizerWrapper(
+ gradient_descent.GradientDescentOptimizer(1.0))
+ train_op = optimizer.minimize_constrained(minimization_problem)
+
+ # Calculated using a numpy+python implementation of the algorithm.
+ expected_matrices = [
+ np.array([[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
+ np.array([[0.66666667, 1.0, 1.0, 1.0], [0.26666667, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0], [0.06666667, 0.0, 0.0, 0.0]]),
+ np.array([[0.41666667, 0.93333333, 1.0,
+ 0.98333333], [0.46666667, 0.05333333, 0.0,
+ 0.01333333], [0.0, 0.0, 0.0, 0.0],
+ [0.11666667, 0.01333333, 0.0, 0.00333333]]),
+ ]
+
+ matrices = []
+ with self.test_session() as session:
+ session.run(standard_ops.global_variables_initializer())
+ while len(matrices) < len(expected_matrices):
+ matrices.append(session.run(optimizer.stochastic_matrix))
+ session.run(train_op)
+
+ for expected, actual in zip(expected_matrices, matrices):
+ self.assertAllClose(expected, actual, rtol=0, atol=1e-6)
+
+ def test_multiplicative_swap_regret_optimizer(self):
+ """Tests that the stochastic matrices update as expected."""
+ minimization_problem = test_util.ConstantMinimizationProblem(
+ np.array([0.6, -0.1, 0.4]))
+ optimizer = MultiplicativeSwapRegretOptimizerWrapper(
+ gradient_descent.GradientDescentOptimizer(1.0),
+ initial_multiplier_radius=0.8)
+ train_op = optimizer.minimize_constrained(minimization_problem)
+
+ # Calculated using a numpy+python implementation of the algorithm.
+ expected_matrices = [
+ np.array([[0.4, 0.4, 0.4, 0.4], [0.2, 0.2, 0.2, 0.2],
+ [0.2, 0.2, 0.2, 0.2], [0.2, 0.2, 0.2, 0.2]]),
+ np.array([[0.36999014, 0.38528351, 0.38528351, 0.38528351], [
+ 0.23517483, 0.21720297, 0.21720297, 0.21720297
+ ], [0.17774131, 0.18882719, 0.18882719, 0.18882719],
+ [0.21709373, 0.20868632, 0.20868632, 0.20868632]]),
+ np.array([[0.33972109, 0.36811863, 0.37118462, 0.36906575], [
+ 0.27114826, 0.23738228, 0.23376693, 0.23626491
+ ], [0.15712313, 0.17641793, 0.17858959, 0.17708679],
+ [0.23200752, 0.21808115, 0.21645886, 0.21758255]]),
+ ]
+
+ matrices = []
+ with self.test_session() as session:
+ session.run(standard_ops.global_variables_initializer())
+ while len(matrices) < len(expected_matrices):
+ matrices.append(session.run(optimizer.stochastic_matrix))
+ session.run(train_op)
+
+ for expected, actual in zip(expected_matrices, matrices):
+ self.assertAllClose(expected, actual, rtol=0, atol=1e-6)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/constrained_optimization/python/test_util.py b/tensorflow/contrib/constrained_optimization/python/test_util.py
new file mode 100644
index 0000000000..704b36ca4c
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/test_util.py
@@ -0,0 +1,58 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains helpers used by tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.constrained_optimization.python import constrained_minimization_problem
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import standard_ops
+
+
+class ConstantMinimizationProblem(
+ constrained_minimization_problem.ConstrainedMinimizationProblem):
+ """A `ConstrainedMinimizationProblem` with constant constraint violations.
+
+ This minimization problem is intended for use in performing simple tests of
+ the Lagrange multiplier (or equivalent) update in the optimizers. There is a
+ one-element "dummy" model parameter, but it should be ignored.
+ """
+
+ def __init__(self, constraints):
+ """Constructs a new `ConstantMinimizationProblem'.
+
+ Args:
+ constraints: 1d numpy array, the constant constraint violations.
+
+ Returns:
+ A new `ConstantMinimizationProblem'.
+ """
+ # We make an fake 1-parameter linear objective so that we don't get a "no
+ # variables to optimize" error.
+ self._objective = standard_ops.Variable(0.0, dtype=dtypes.float32)
+ self._constraints = standard_ops.constant(constraints, dtype=dtypes.float32)
+
+ @property
+ def objective(self):
+ """Returns the objective function."""
+ return self._objective
+
+ @property
+ def constraints(self):
+ """Returns the constant constraint violations."""
+ return self._constraints