author    Derek Murray <mrry@google.com>                  2016-03-11 10:20:54 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-03-11 11:42:35 -0800
commit    f6f77accbb9c978b5d3562922425bffd01690f79 (patch)
tree      ce3be814626d2ea9b27e80839f38d2d47444c03b
parent    0e1a324c7bb73d80e0e2d069396d61115e0d9096 (diff)
Implements "bool strictness" for `tf.Tensor`.
Using a `tf.Tensor` as a Python boolean value is ambiguous: in library code it typically means "is this tensor not None?", whereas in user code it can mean "does this tensor evaluate to `True`?" (which does not work and leads to buggy code).

This change adds a warning when a tensor is used as a bool, and replaces all known instances in TensorFlow with the appropriate `is None` or `is not None` checks. This is the first step towards addressing Issue #1454.

Change: 116984574
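To make the ambiguity concrete, here is a minimal sketch (illustrative only, not part of the commit) against the graph-mode TensorFlow API of this era:

```python
import tensorflow as tf

t = tf.constant(False)

# Ambiguous: `t` is a symbolic tensor, not a runtime value, so this
# branch is taken no matter what `t` would evaluate to in a session
# (and, after this change, it emits a DeprecationWarning).
if t:
    print("always reached")

# What library code almost always means is a presence check:
if t is not None:
    print("the tensor exists")
```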
-rw-r--r--  tensorflow/contrib/layers/python/framework/tensor_util.py |  2
-rw-r--r--  tensorflow/python/framework/function.py                   |  2
-rw-r--r--  tensorflow/python/framework/ops.py                        | 49
-rw-r--r--  tensorflow/python/kernel_tests/cwise_ops_test.py          | 31
-rw-r--r--  tensorflow/python/ops/clip_ops.py                         |  2
-rw-r--r--  tensorflow/python/ops/control_flow_grad.py                |  2
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py                 | 12
-rw-r--r--  tensorflow/python/ops/gradients.py                        | 33
-rw-r--r--  tensorflow/python/ops/gradients_test.py                   | 10
-rw-r--r--  tensorflow/python/ops/nn.py                               |  3
-rw-r--r--  tensorflow/python/ops/nn_ops.py                           |  2
-rw-r--r--  tensorflow/python/ops/rnn.py                              |  4
-rw-r--r--  tensorflow/python/ops/rnn_cell.py                         |  3
-rw-r--r--  tensorflow/python/ops/tensor_array_ops.py                 | 21
-rw-r--r--  tensorflow/python/ops/variable_scope.py                   |  2
-rw-r--r--  tensorflow/python/training/input.py                       |  2
-rw-r--r--  tensorflow/python/training/optimizer.py                   |  2
17 files changed, 137 insertions(+), 45 deletions(-)
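Nearly every hunk below applies the same mechanical pattern: a truthiness test on a value that might be a `Tensor` becomes an explicit `None` check, because a symbolic tensor's truth value says nothing about its runtime value. A before/after sketch of the pattern (the helper name `_keep_present_grads` is illustrative, not taken from the diff):

```python
def _keep_present_grads(out_grads):
    """Keep only the gradients that are present (i.e. not None)."""
    # Before this change: `[g for g in out_grads if g]`, which also
    # invokes `Tensor.__bool__` on real gradients (now a warning,
    # eventually a TypeError).
    # After this change: test only for absence.
    return [g for g in out_grads if g is not None]
```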
diff --git a/tensorflow/contrib/layers/python/framework/tensor_util.py b/tensorflow/contrib/layers/python/framework/tensor_util.py
index 7515b78265..1a5450630c 100644
--- a/tensorflow/contrib/layers/python/framework/tensor_util.py
+++ b/tensorflow/contrib/layers/python/framework/tensor_util.py
@@ -51,7 +51,7 @@ def _assert_same_base_type(items, expected_type=None):
"""
original_item_str = None
for item in items:
- if item:
+ if item is not None:
item_type = item.dtype.base_dtype
if not expected_type:
expected_type = item_type
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index f0329a6730..1277e57c0d 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -383,7 +383,7 @@ def define_function(func, input_types):
outputs = func(*inputs)
else:
outputs = func(**kwargs)
- if not outputs:
+ if not isinstance(outputs, ops.Tensor) and not outputs:
raise ValueError("Function must return at least one tensor")
# Convenience: if func only returned one value, make it a tuple.
if not isinstance(outputs, (list, tuple)):
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index b1c68f33fb..d7ba56c479 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -26,6 +26,7 @@ import linecache
import re
import sys
import threading
+import warnings
import weakref
import six
@@ -440,6 +441,43 @@ class Tensor(object):
"""
raise TypeError("'Tensor' object is not iterable.")
+ def __bool__(self):
+ """Dummy method to warn when a tensor is being used as a Python `bool`.
+
+ NOTE(mrry): This overload produces a warning when the user
+ inadvertently treats a `Tensor` as a boolean (e.g. in an `if`
+ statement). For example:
+
+ ```python
+ if tf.constant(True): # Will warn.
+ # ...
+
+ if tf.constant(5) < tf.constant(7): # Will warn.
+ # ...
+ ```
+
+ This functionality is deprecated. In future it will raise a `TypeError`.
+
+ Returns:
+ `True`.
+ """
+ warnings.warn("Using a `tf.Tensor` as a Python `bool` is deprecated. "
+ "Use `if t is None:` instead of `if t:` in new code. "
+ "A `TypeError` will be raised in future versions.",
+ DeprecationWarning)
+ return True
+
+ def __nonzero__(self):
+ """Dummy method to warn when a tensor is being used as a Python `bool`.
+
+ NOTE(mrry): This is the Python 2.x counterpart to `__bool__()`
+ above.
+
+ Returns:
+ `True`.
+ """
+ return self.__bool__()
+
def eval(self, feed_dict=None, session=None):
"""Evaluates this tensor in a `Session`.
@@ -783,7 +821,8 @@ class IndexedSlices(object):
def __str__(self):
return "IndexedSlices(indices=%s, values=%s%s)" % (
self._indices, self._values,
- (", dense_shape=%s" % self._dense_shape) if self._dense_shape else "")
+ (", dense_shape=%s" % self._dense_shape)
+ if self._dense_shape is not None else "")
def __neg__(self):
return IndexedSlices(-self.values, self.indices, self.dense_shape)
@@ -2137,7 +2176,9 @@ class Graph(object):
else:
raise ValueError("allow_tensor and allow_operation can't both be False.")
- obj = _as_graph_element(obj) or obj
+ temp_obj = _as_graph_element(obj)
+ if temp_obj is not None:
+ obj = temp_obj
# If obj appears to be a name...
if isinstance(obj, compat.bytes_or_text_types):
@@ -3362,11 +3403,11 @@ def _get_graph_from_inputs(op_input_list, graph=None):
else:
graph_element = _as_graph_element(op_input)
- if graph_element:
+ if graph_element is not None:
if not graph:
original_graph_element = graph_element
graph = graph_element.graph
- elif original_graph_element:
+ elif original_graph_element is not None:
_assert_same_graph(original_graph_element, graph_element)
elif graph_element.graph is not graph:
raise ValueError(
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 347b844345..9e5516cb87 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import math
+import warnings
import numpy as np
import tensorflow as tf
@@ -893,6 +894,36 @@ class LogicalOpTest(tf.test.TestCase):
ValueError, lambda e: "Incompatible shapes" in str(e)):
f(x, y)
+ def testUsingAsPythonValueFails(self):
+ # TODO(mrry): Replace with `assertRaises(TypeError)` once using a
+ # tensor as a `bool` raises an error instead of just warning.
+ warnings.simplefilter("always")
+ # Ensure that we issue a warning when the user attempts to treat a
+ # `Tensor` as a Python `bool`.
+ b = tf.constant(False)
+ with warnings.catch_warnings(record=True) as w:
+ if b:
+ pass
+ self.assertEqual(1, len(w))
+ self.assertTrue("`bool` is deprecated" in str(w[-1].message))
+
+ x = tf.constant(3)
+ y = tf.constant(4)
+ with warnings.catch_warnings(record=True) as w:
+ if x > y:
+ pass
+ self.assertEqual(1, len(w))
+ self.assertTrue("`bool` is deprecated" in str(w[-1].message))
+
+ z = tf.constant(7)
+
+ # The chained comparison warns because Python must convert `x < y`
+ # to a `bool` in order to decide whether to short-circuit with `z`.
+ with warnings.catch_warnings(record=True) as w:
+ _ = x < y < z
+ self.assertEqual(1, len(w))
+ self.assertTrue("`bool` is deprecated" in str(w[-1].message))
+
class SelectOpTest(tf.test.TestCase):
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index fea6d44476..a85e943dc6 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -127,7 +127,7 @@ def global_norm(t_list, name=None):
if t is not None else t
for i, t in enumerate(t_list)]
squared_norms = array_ops.pack(
- [math_ops.reduce_sum(v * v) for v in values if v])
+ [math_ops.reduce_sum(v * v) for v in values if v is not None])
norm = math_ops.sqrt(
math_ops.reduce_sum(squared_norms), name="global_norm")
diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py
index aa85c12931..44ad433b34 100644
--- a/tensorflow/python/ops/control_flow_grad.py
+++ b/tensorflow/python/ops/control_flow_grad.py
@@ -104,7 +104,7 @@ def _MergeGrad(op, grad, _):
# use the accumulated values as the predicate for this backprop switch.
grad_state = grad_ctxt.grad_state
real_pred = grad_state.history_map.get(pred.name)
- if not real_pred:
+ if real_pred is None:
# Remember the value of pred for every iteration.
grad_ctxt = grad_state.grad_context
grad_ctxt.Exit()
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index e861206107..813f8fa372 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -256,8 +256,8 @@ def merge(inputs, name=None):
values, _ = _Merge([inp.values for inp in inputs], name=name)
indices, chosen_index = _Merge(
[inp.indices for inp in inputs], name="indices")
- if any(inp.dense_shape for inp in inputs):
- if not all(inp.dense_shape for inp in inputs):
+ if any(inp.dense_shape is not None for inp in inputs):
+ if any(inp.dense_shape is None for inp in inputs):
raise ValueError("Either all merged IndexedSlices must have a "
"dense_shape, or none must have a dense_shape.")
dense_shape, _ = _Merge(
@@ -604,7 +604,7 @@ class GradLoopState(object):
# Guard stack pop with a switch if it is controlled by a cond
grad_state = self
pred = None
- while not pred and grad_state:
+ while pred is None and grad_state:
pred = grad_state.history_map.get(cond_ctxt.pred.name)
grad_state = grad_state.outer_grad_state
branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
@@ -1200,7 +1200,7 @@ class WhileContext(ControlFlowContext):
return self
def GetControlPivot(self):
- if self._pivot_for_body:
+ if self._pivot_for_body is not None:
return self._pivot_for_body
return self._pivot_for_pred
@@ -1715,7 +1715,7 @@ def tuple(tensors, name=None, control_inputs=None):
"""
with ops.op_scope(tensors, name, "tuple") as name:
- gating_ops = [t.op for t in tensors if t]
+ gating_ops = [t.op for t in tensors if t is not None]
if control_inputs:
for c in control_inputs:
if isinstance(c, ops.Tensor):
@@ -1731,7 +1731,7 @@ def tuple(tensors, name=None, control_inputs=None):
gate = group(*gating_ops)
tpl = []
for t in tensors:
- if t:
+ if t is not None:
tpl.append(with_dependencies([gate], t))
else:
tpl.append(None)
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index 9fc1aa80d1..5966f0039f 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -435,7 +435,9 @@ def gradients(ys,
grad_fn = None
# pylint: disable=protected-access
is_func_call = ops.get_default_graph()._is_function(op.type)
- if not is_func_call and any(out_grads) and op._id not in stop_ops:
+ if not is_func_call and any(
+ isinstance(g, ops.Tensor) or g for g in out_grads) and (
+ op._id not in stop_ops):
# pylint: enable=protected-access
# A grad_fn must be defined, either as a function or as None
# for ops that do not have gradients.
@@ -448,12 +450,14 @@ def gradients(ys,
if loop_state:
loop_state.EnterGradWhileContext(op, before=False)
- if (grad_fn or is_func_call) and any(out_grads):
+ if (grad_fn or is_func_call) and any(
+ isinstance(g, ops.Tensor) or g for g in out_grads):
# NOTE: If _AggregatedGrads didn't compute a value for the i'th
# output, it means that the cost does not depend on output[i],
# therefore dC/doutput[i] is 0.
for i, out_grad in enumerate(out_grads):
- if not out_grad and _IsFloat(op.outputs[i]):
+ if (not isinstance(out_grad, ops.Tensor)
+ and not out_grad) and _IsFloat(op.outputs[i]):
# Only floating-point outputs get a zero gradient. Gradient
# functions should ignore the gradient for other outputs.
if loop_state:
@@ -476,19 +480,27 @@ def gradients(ys,
else:
in_grads = _AsList(grad_fn(op, *out_grads))
_VerifyGeneratedGradients(in_grads, op)
- if gate_gradients and len(tuple(filter(None, in_grads))) > 1:
+ if gate_gradients and len(
+ [x for x in in_grads if x is not None]) > 1:
in_grads = control_flow_ops.tuple(in_grads)
logging.vlog(1, "Gradient for '" + op.name + "'")
+ def _FilterGrad(x):
+ if x is None:
+ return False
+ if isinstance(x, (list, tuple)):
+ return bool(x)
+ else:
+ return True
logging.vlog(1, " in --> %s",
- ", ".join([x.name for x in out_grads if x]))
+ ", ".join([x.name for x in out_grads if _FilterGrad(x)]))
logging.vlog(1, " out --> %s",
- ", ".join([x.name for x in in_grads if x]))
+ ", ".join([x.name for x in in_grads if _FilterGrad(x)]))
else:
# If no grad_fn is defined or none of out_grads is available,
# just propagates a list of None backwards.
in_grads = [None] * len(op.inputs)
for t_in, in_grad in zip(op.inputs, in_grads):
- if in_grad:
+ if in_grad is not None:
_SetGrad(grads, t_in, in_grad)
if loop_state:
loop_state.ExitGradWhileContext(op, before=False)
@@ -621,12 +633,12 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
continue
# Grads have to be Tensors or IndexedSlices
if not all([isinstance(g, (ops.Tensor, ops.IndexedSlices))
- for g in out_grad if g]):
+ for g in out_grad if g is not None]):
raise TypeError("gradients have to be either all Tensors "
"or all IndexedSlices")
# Aggregate multiple gradients, and convert [] to None.
if out_grad:
- if all([isinstance(g, ops.Tensor) for g in out_grad if g]):
+ if all([isinstance(g, ops.Tensor) for g in out_grad if g is not None]):
tensor_shape = _AccumulatorShape(out_grad)
if len(out_grad) < 2:
used = "nop"
@@ -665,7 +677,8 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad),
tensor_shape, used)
else:
- out_grad = math_ops._as_indexed_slices_list([g for g in out_grad if g])
+ out_grad = math_ops._as_indexed_slices_list([g for g in out_grad
+ if g is not None])
out_grad = [_HandleNestedIndexedSlices(x) for x in out_grad]
# Form IndexedSlices out of the concatenated values and
# indices.
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index c535b5639d..492d60931f 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -190,7 +190,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
y = x + 1.0
z = y + 1
grads = gradients.gradients(z, [x])
- self.assertTrue(all([x for x in grads]))
+ self.assertTrue(all(x is not None for x in grads))
def testBoundaryContinue(self):
# Test that we differentiate both 'x' and 'y' correctly when x is a
@@ -200,7 +200,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
y = x * 2.0
z = y * 3.0
grads = gradients.gradients(z, [x, y])
- self.assertTrue(all([x for x in grads]))
+ self.assertTrue(all(x is not None for x in grads))
self.assertEqual(6.0, grads[0].eval())
def testAggregationMethodAccumulateN(self):
@@ -213,7 +213,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
[x, y],
aggregation_method=
gradients.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
- self.assertTrue(all([x for x in grads]))
+ self.assertTrue(all(x is not None for x in grads))
self.assertEqual(20.0, grads[0].eval())
self.assertEqual(10.0, grads[1].eval())
@@ -226,7 +226,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
z,
[x, y],
aggregation_method=gradients.AggregationMethod.ADD_N)
- self.assertTrue(all([x for x in grads]))
+ self.assertTrue(all(x is not None for x in grads))
self.assertEqual(20.0, grads[0].eval())
self.assertEqual(10.0, grads[1].eval())
@@ -239,7 +239,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
z,
[x, y],
aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE)
- self.assertTrue(all([x for x in grads]))
+ self.assertTrue(all(x is not None for x in grads))
self.assertEqual(20.0, grads[0].eval())
self.assertEqual(10.0, grads[1].eval())
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 02a444ace3..75de6b5d7d 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -687,7 +687,8 @@ def batch_normalization(x,
inv = math_ops.rsqrt(variance + variance_epsilon)
if scale is not None:
inv *= scale
- return x * inv + (offset - mean * inv if offset else -mean * inv)
+ return x * inv + (
+ offset - mean * inv if offset is not None else -mean * inv)
def batch_norm_with_global_normalization(t,
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index f7891bb2d0..118240cfdb 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -663,7 +663,7 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
keep_prob, dtype=x.dtype, name="keep_prob")
keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())
- noise_shape = noise_shape or array_ops.shape(x)
+ noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x)
# uniform [keep_prob, 1.0 + keep_prob)
random_tensor = keep_prob
random_tensor += random_ops.random_uniform(
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 611f5fa314..cddaaeee7b 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -122,7 +122,7 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
if sequence_length is not None:
sequence_length = math_ops.to_int32(sequence_length)
- if sequence_length: # Prepare variables
+ if sequence_length is not None: # Prepare variables
zero_output = array_ops.zeros(
array_ops.pack([batch_size, cell.output_size]), inputs[0].dtype)
zero_output.set_shape(
@@ -135,7 +135,7 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
# pylint: disable=cell-var-from-loop
call_cell = lambda: cell(input_, state)
# pylint: enable=cell-var-from-loop
- if sequence_length:
+ if sequence_length is not None:
(output, state) = _rnn_step(
time, sequence_length, min_sequence_length, max_sequence_length,
zero_output, state, call_cell)
diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py
index c9dfbca979..3db490007a 100644
--- a/tensorflow/python/ops/rnn_cell.py
+++ b/tensorflow/python/ops/rnn_cell.py
@@ -683,7 +683,8 @@ def linear(args, output_size, bias, bias_start=0.0, scope=None):
Raises:
ValueError: if some of the arguments has unspecified or wrong shape.
"""
- assert args
+ if args is None or (isinstance(args, (list, tuple)) and not args):
+ raise ValueError("`args` must be specified")
if not isinstance(args, (list, tuple)):
args = [args]
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index ace9277a0b..e17757daae 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -73,17 +73,17 @@ class TensorArray(object):
ValueError: if both handle and tensor_array_name are provided.
TypeError: if handle is provided but is not a Tensor.
"""
- if handle and tensor_array_name:
+ if handle is not None and tensor_array_name:
raise ValueError(
"Cannot construct with both handle and tensor_array_name")
- if handle and not isinstance(handle, ops.Tensor):
+ if handle is not None and not isinstance(handle, ops.Tensor):
raise TypeError("Handle must be a Tensor")
if handle is None and size is None:
raise ValueError("Size must be provided if handle is not provided")
- if handle and size is not None:
+ if handle is not None and size is not None:
raise ValueError("Cannot provide both a handle and size "
"at the same time")
- if handle and dynamic_size is not None:
+ if handle is not None and dynamic_size is not None:
raise ValueError("Cannot provide both a handle and dynamic_size "
"at the same time")
@@ -91,13 +91,16 @@ class TensorArray(object):
self._dtype = dtype
with ops.op_scope([handle, size, flow], name, "TensorArray") as scope:
- if handle:
+ if handle is not None:
self._handle = handle
else:
self._handle = gen_data_flow_ops._tensor_array(
dtype=dtype, size=size, dynamic_size=dynamic_size,
tensor_array_name=tensor_array_name, name=scope)
- self._flow = flow or constant_op.constant(0, dtype=_dtypes.float32)
+ if flow is not None:
+ self._flow = flow
+ else:
+ self._flow = constant_op.constant(0, dtype=_dtypes.float32)
@property
def flow(self):
@@ -119,9 +122,11 @@ class TensorArray(object):
# TensorArrays are dynamically sized. This forces the creation
# of the grad TensorArray only once the final forward array's size
# is fixed.
+ if flow is None:
+ flow = self.flow
g_handle = gen_data_flow_ops._tensor_array_grad(
- handle=self._handle, source=source, flow_in=flow or self.flow)
- g = TensorArray(dtype=self._dtype, handle=g_handle, flow=flow or self.flow)
+ handle=self._handle, source=source, flow_in=flow)
+ g = TensorArray(dtype=self._dtype, handle=g_handle, flow=flow)
return g
def read(self, index, name=None):
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 7ffdd37212..9cb08fcd13 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -154,7 +154,7 @@ class _VariableStore(object):
if regularizer:
with ops.name_scope(name + "/Regularizer/"):
loss = regularizer(v)
- if loss:
+ if loss is not None:
logging.info("Applied regularizer to %s and added the result %s to "
"REGULARIZATION_LOSSES.", v.name, loss.name)
ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss)
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 55ae8adba8..ace7b49d97 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -128,7 +128,7 @@ def string_input_producer(string_tensor, num_epochs=None, shuffle=True,
will fail with an assertion if string_tensor becomes a null tensor.
"""
not_null_err = "string_input_producer requires a non-null input tensor"
- if not string_tensor:
+ if not isinstance(string_tensor, ops.Tensor) and not string_tensor:
raise ValueError(not_null_err)
with ops.op_scope([string_tensor], name, "input_producer") as name:
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 1c3ac2d09d..70a3761de1 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -290,7 +290,7 @@ class Optimizer(object):
with ops.op_scope([], name, self._name) as name:
self._prepare()
for grad, var in grads_and_vars:
- if not grad:
+ if grad is None:
continue
# We colocate all ops created in _apply_dense or _apply_sparse
# on the same device as the variable.
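A closing note, not part of the commit: `DeprecationWarning` is ignored by default on Python 2.7 and later, so the new warning may never surface in ordinary scripts. A small sketch using only the standard `warnings` module to make it visible, or to opt in early to the `TypeError` promised above:

```python
import warnings

# Print the deprecation warning every time it is triggered (by default,
# Python 2.7+ filters out DeprecationWarning entirely).
warnings.simplefilter("always", DeprecationWarning)

# Alternatively, escalate it to an exception so offending code fails
# loudly in tests today rather than when the behaviour changes:
# warnings.simplefilter("error", DeprecationWarning)
```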