aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-03-23 08:30:21 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-03-23 12:10:09 -0700
commit9394661924a37fe17e5f11f43833e9bbffbf23ff (patch)
treeb97cff67b6b80d66e7d959dee21d2ee8c4a6cf19
parentd79bc31f45b1fdf98bdcac022f0a86a3e7fbf860 (diff)
Fix dense_shape check.
This rectifies the following error: TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use the logical TensorFlow ops to test the value of a tensor. when the conditional branch contains a tf.IndexedSlices object with dense_shape=tf.constant(...). Change: 117937593
-rw-r--r--tensorflow/python/ops/control_flow_ops.py2
-rw-r--r--tensorflow/python/ops/control_flow_ops_test.py15
2 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index edfe69a8d0..ee9f19ac39 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -213,7 +213,7 @@ def switch(data, pred, dtype=None, name=None):
val, ind, dense_shape = data.values, data.indices, data.dense_shape
val_f, val_t = gen_control_flow_ops._switch(val, pred, name=name)
ind_f, ind_t = gen_control_flow_ops._switch(ind, pred, name="indices")
- if dense_shape:
+ if dense_shape is not None:
dense_shape_f, dense_shape_t = gen_control_flow_ops._switch(
dense_shape, pred, name="dense_shape")
else:
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 0376779d23..c537ed1b73 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -101,5 +101,20 @@ class ShapeTestCase(TensorFlowTestCase):
[tf.constant(1.0)], tensor).get_shape())
+class SwitchTestCase(TensorFlowTestCase):
+
+ def testIndexedSlicesWithDenseShape(self):
+ with self.test_session():
+ data = ops.IndexedSlices(tf.constant([1, 2, 3]),
+ tf.constant([0, 1]),
+ dense_shape=tf.constant([3]))
+ zero = tf.constant(0)
+ one = tf.constant(1)
+ less_op = tf.less(zero, one)
+ switch_false, switch_true = control_flow_ops.switch(data, less_op)
+ self.assertAllEqual([1, 2, 3], switch_true.values.eval())
+ self.assertAllEqual([0, 1], switch_true.indices.eval())
+
+
if __name__ == "__main__":
googletest.main()