author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-03-23 13:04:48 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-03-23 14:48:50 -0700
commit    fd8e8a1cb6eaf76c7102995921dfff6f90c019cc (patch)
tree      d5274022cbfca84611723efb4a09d32b9fac5fba
parent    a89c86eadc30e884bbc94b1977830c0bfb22640e (diff)
Adjust create_train_op to have a flag for check_numerics.
Change: 151050903
-rw-r--r--  tensorflow/contrib/slim/python/slim/learning.py          |  8 ++++++--
-rw-r--r--  tensorflow/contrib/training/python/training/training.py  |  9 ++++++---
2 files changed, 12 insertions(+), 5 deletions(-)
diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py
index 48c96a58a6..814ce51100 100644
--- a/tensorflow/contrib/slim/python/slim/learning.py
+++ b/tensorflow/contrib/slim/python/slim/learning.py
@@ -382,7 +382,8 @@ def create_train_op(total_loss,
                     gate_gradients=tf_optimizer.Optimizer.GATE_OP,
                     aggregation_method=None,
                     colocate_gradients_with_ops=False,
-                    gradient_multipliers=None):
+                    gradient_multipliers=None,
+                    check_numerics=True):
"""Creates an `Operation` that evaluates the gradients and returns the loss.
Args:
@@ -408,6 +409,8 @@ def create_train_op(total_loss,
     gradient_multipliers: A dictionary of either `Variables` or `Variable` op
       names to the coefficient by which the associated gradient should be
       scaled.
+    check_numerics: Whether or not we apply check_numerics.
+
   Returns:
     A `Tensor` that when evaluated, computes the gradients and returns the total
     loss value.
@@ -433,7 +436,8 @@ def create_train_op(total_loss,
       summarize_gradients=summarize_gradients,
       gate_gradients=gate_gradients,
       aggregation_method=aggregation_method,
-      colocate_gradients_with_ops=colocate_gradients_with_ops)
+      colocate_gradients_with_ops=colocate_gradients_with_ops,
+      check_numerics=check_numerics)
 def _wait_for_step(sess, global_step, step):
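For context, a minimal usage sketch of the new flag (assumes the TensorFlow 1.x
tf.contrib.slim API as of this commit; the loss/optimizer setup below is purely
illustrative, not part of the patch):

    import tensorflow as tf

    slim = tf.contrib.slim

    # Illustrative graph-mode setup; any scalar loss tensor works here.
    total_loss = slim.losses.get_total_loss()
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

    # check_numerics defaults to True, preserving the previous behavior;
    # pass False to skip the NaN/Inf assertion on the loss.
    train_op = slim.learning.create_train_op(
        total_loss, optimizer, check_numerics=False)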
diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py
index 1bb9175766..048410d321 100644
--- a/tensorflow/contrib/training/python/training/training.py
+++ b/tensorflow/contrib/training/python/training/training.py
@@ -368,7 +368,8 @@ def create_train_op(total_loss,
                     summarize_gradients=False,
                     gate_gradients=tf_optimizer.Optimizer.GATE_OP,
                     aggregation_method=None,
-                    colocate_gradients_with_ops=False):
+                    colocate_gradients_with_ops=False,
+                    check_numerics=True):
"""Creates an `Operation` that evaluates the gradients and returns the loss.
Args:
@@ -393,6 +394,7 @@ def create_train_op(total_loss,
       Valid values are defined in the class `AggregationMethod`.
     colocate_gradients_with_ops: Whether or not to try colocating the gradients
       with the ops that generated them.
+    check_numerics: Whether or not we apply check_numerics.

   Returns:
     A `Tensor` that when evaluated, computes the gradients and returns the total
@@ -449,8 +451,9 @@ def create_train_op(total_loss,
   with ops.name_scope('train_op'):
     # Make sure total_loss is valid.
-    total_loss = array_ops.check_numerics(total_loss,
-                                          'LossTensor is inf or nan')
+    if check_numerics:
+      total_loss = array_ops.check_numerics(total_loss,
+                                            'LossTensor is inf or nan')

     # Ensure the train_tensor computes grad_updates.
     train_op = control_flow_ops.with_dependencies([grad_updates], total_loss)
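Behaviorally, the training.py change reduces to a simple conditional wrap. A
standalone sketch of the same guard (TensorFlow 1.x assumed; predictions and
labels are hypothetical placeholders, not part of the patch):

    import tensorflow as tf

    predictions = tf.placeholder(tf.float32, shape=[None])
    labels = tf.placeholder(tf.float32, shape=[None])
    total_loss = tf.reduce_mean(tf.square(predictions - labels))

    # Mirrors the patched create_train_op: only wrap the loss when the flag
    # is set. tf.check_numerics passes the tensor through unchanged, but
    # fails at run time with the given message if it contains NaN or Inf.
    check_numerics = True
    if check_numerics:
      total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan')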