diff options
author | A. Unique TensorFlower <nobody@tensorflow.org> | 2016-03-23 08:30:21 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-03-23 12:10:09 -0700 |
commit | 9394661924a37fe17e5f11f43833e9bbffbf23ff (patch) | |
tree | b97cff67b6b80d66e7d959dee21d2ee8c4a6cf19 | |
parent | d79bc31f45b1fdf98bdcac022f0a86a3e7fbf860 (diff) |
Fix dense_shape check.
This rectifies the following error:
TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is
not None:` instead of `if t:` to test if a tensor is defined, and use the
logical TensorFlow ops to test the value of a tensor.
when the conditional branch contains a tf.IndexedSlices object with
dense_shape=tf.constant(...).
Change: 117937593
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 2 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops_test.py | 15 |
2 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index edfe69a8d0..ee9f19ac39 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -213,7 +213,7 @@ def switch(data, pred, dtype=None, name=None): val, ind, dense_shape = data.values, data.indices, data.dense_shape val_f, val_t = gen_control_flow_ops._switch(val, pred, name=name) ind_f, ind_t = gen_control_flow_ops._switch(ind, pred, name="indices") - if dense_shape: + if dense_shape is not None: dense_shape_f, dense_shape_t = gen_control_flow_ops._switch( dense_shape, pred, name="dense_shape") else: diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index 0376779d23..c537ed1b73 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -101,5 +101,20 @@ class ShapeTestCase(TensorFlowTestCase): [tf.constant(1.0)], tensor).get_shape()) +class SwitchTestCase(TensorFlowTestCase): + + def testIndexedSlicesWithDenseShape(self): + with self.test_session(): + data = ops.IndexedSlices(tf.constant([1, 2, 3]), + tf.constant([0, 1]), + dense_shape=tf.constant([3])) + zero = tf.constant(0) + one = tf.constant(1) + less_op = tf.less(zero, one) + switch_false, switch_true = control_flow_ops.switch(data, less_op) + self.assertAllEqual([1, 2, 3], switch_true.values.eval()) + self.assertAllEqual([0, 1], switch_true.indices.eval()) + + if __name__ == "__main__": googletest.main() |