diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-03-23 13:04:48 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-23 14:48:50 -0700 |
commit | fd8e8a1cb6eaf76c7102995921dfff6f90c019cc (patch) | |
tree | d5274022cbfca84611723efb4a09d32b9fac5fba | |
parent | a89c86eadc30e884bbc94b1977830c0bfb22640e (diff) |
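Add a `check_numerics` flag (default `True`) to `create_train_op` in both slim `learning.py` and contrib/training `training.py`, so that the inf/NaN check on the loss tensor can be disabled.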
Change: 151050903
-rw-r--r-- | tensorflow/contrib/slim/python/slim/learning.py | 8 | ||||
-rw-r--r-- | tensorflow/contrib/training/python/training/training.py | 9 |
2 files changed, 12 insertions, 5 deletions
diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py index 48c96a58a6..814ce51100 100644 --- a/tensorflow/contrib/slim/python/slim/learning.py +++ b/tensorflow/contrib/slim/python/slim/learning.py @@ -382,7 +382,8 @@ def create_train_op(total_loss, gate_gradients=tf_optimizer.Optimizer.GATE_OP, aggregation_method=None, colocate_gradients_with_ops=False, - gradient_multipliers=None): + gradient_multipliers=None, + check_numerics=True): """Creates an `Operation` that evaluates the gradients and returns the loss. Args: @@ -408,6 +409,8 @@ def create_train_op(total_loss, gradient_multipliers: A dictionary of either `Variables` or `Variable` op names to the coefficient by which the associated gradient should be scaled. + check_numerics: Whether or not we apply check_numerics. + Returns: A `Tensor` that when evaluated, computes the gradients and returns the total loss value. @@ -433,7 +436,8 @@ def create_train_op(total_loss, summarize_gradients=summarize_gradients, gate_gradients=gate_gradients, aggregation_method=aggregation_method, - colocate_gradients_with_ops=colocate_gradients_with_ops) + colocate_gradients_with_ops=colocate_gradients_with_ops, + check_numerics=check_numerics) def _wait_for_step(sess, global_step, step): diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py index 1bb9175766..048410d321 100644 --- a/tensorflow/contrib/training/python/training/training.py +++ b/tensorflow/contrib/training/python/training/training.py @@ -368,7 +368,8 @@ def create_train_op(total_loss, summarize_gradients=False, gate_gradients=tf_optimizer.Optimizer.GATE_OP, aggregation_method=None, - colocate_gradients_with_ops=False): + colocate_gradients_with_ops=False, + check_numerics=True): """Creates an `Operation` that evaluates the gradients and returns the loss. 
Args: @@ -393,6 +394,7 @@ def create_train_op(total_loss, Valid values are defined in the class `AggregationMethod`. colocate_gradients_with_ops: Whether or not to try colocating the gradients with the ops that generated them. + check_numerics: Whether or not we apply check_numerics. Returns: A `Tensor` that when evaluated, computes the gradients and returns the total @@ -449,8 +451,9 @@ def create_train_op(total_loss, with ops.name_scope('train_op'): # Make sure total_loss is valid. - total_loss = array_ops.check_numerics(total_loss, - 'LossTensor is inf or nan') + if check_numerics: + total_loss = array_ops.check_numerics(total_loss, + 'LossTensor is inf or nan') # Ensure the train_tensor computes grad_updates. train_op = control_flow_ops.with_dependencies([grad_updates], total_loss) |