diff options
Diffstat (limited to 'tensorflow/contrib/opt/python/training/external_optimizer.py')
-rw-r--r-- | tensorflow/contrib/opt/python/training/external_optimizer.py | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/tensorflow/contrib/opt/python/training/external_optimizer.py b/tensorflow/contrib/opt/python/training/external_optimizer.py index ff80167ff4..0909760b38 100644 --- a/tensorflow/contrib/opt/python/training/external_optimizer.py +++ b/tensorflow/contrib/opt/python/training/external_optimizer.py @@ -99,8 +99,13 @@ class ExternalOptimizerInterface(object): slice(start, end) for start, end in zip(accumulated_dims[:-1], accumulated_dims[1:])] - def minimize(self, session=None, feed_dict=None, fetches=None, - step_callback=None, loss_callback=None): + def minimize(self, + session=None, + feed_dict=None, + fetches=None, + step_callback=None, + loss_callback=None, + **run_kwargs): """Minimize a scalar `Tensor`. Variables subject to optimization are updated in-place at the end of @@ -120,6 +125,7 @@ class ExternalOptimizerInterface(object): flattened into a single vector. loss_callback: A function to be called every time the loss and gradients are computed, with evaluated fetches supplied as positional arguments. + **run_kwargs: kwargs to pass to `session.run`. """ session = session or ops.get_default_session() feed_dict = feed_dict or {} @@ -160,8 +166,10 @@ class ExternalOptimizerInterface(object): for packing_slice in self._packing_slices] # Set optimization variables to their new values. - session.run(self._var_updates, - feed_dict=dict(zip(self._update_placeholders, var_vals))) + session.run( + self._var_updates, + feed_dict=dict(zip(self._update_placeholders, var_vals)), + **run_kwargs) def _minimize(self, initial_val, loss_grad_func, equality_funcs, equality_grad_funcs, inequality_funcs, inequality_grad_funcs, |