aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/opt/python/training/external_optimizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/opt/python/training/external_optimizer.py')
-rw-r--r--tensorflow/contrib/opt/python/training/external_optimizer.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/tensorflow/contrib/opt/python/training/external_optimizer.py b/tensorflow/contrib/opt/python/training/external_optimizer.py
index ff80167ff4..0909760b38 100644
--- a/tensorflow/contrib/opt/python/training/external_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/external_optimizer.py
@@ -99,8 +99,13 @@ class ExternalOptimizerInterface(object):
slice(start, end) for start, end in zip(accumulated_dims[:-1],
accumulated_dims[1:])]
- def minimize(self, session=None, feed_dict=None, fetches=None,
- step_callback=None, loss_callback=None):
+ def minimize(self,
+ session=None,
+ feed_dict=None,
+ fetches=None,
+ step_callback=None,
+ loss_callback=None,
+ **run_kwargs):
"""Minimize a scalar `Tensor`.
Variables subject to optimization are updated in-place at the end of
@@ -120,6 +125,7 @@ class ExternalOptimizerInterface(object):
flattened into a single vector.
loss_callback: A function to be called every time the loss and gradients
are computed, with evaluated fetches supplied as positional arguments.
+ **run_kwargs: kwargs to pass to `session.run`.
"""
session = session or ops.get_default_session()
feed_dict = feed_dict or {}
@@ -160,8 +166,10 @@ class ExternalOptimizerInterface(object):
for packing_slice in self._packing_slices]
# Set optimization variables to their new values.
- session.run(self._var_updates,
- feed_dict=dict(zip(self._update_placeholders, var_vals)))
+ session.run(
+ self._var_updates,
+ feed_dict=dict(zip(self._update_placeholders, var_vals)),
+ **run_kwargs)
def _minimize(self, initial_val, loss_grad_func, equality_funcs,
equality_grad_funcs, inequality_funcs, inequality_grad_funcs,