author    Yifei Feng <yifeif@google.com>  2018-01-26 16:53:59 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-01-26 16:59:01 -0800
commit    aee7f95a027accc94f1f9130f0cfaecd9399bc1d (patch)
tree      6b8484915bf631f18b2fa0561a73549d9bf19fad /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parent    e95537708f070a98607393a8f60bc61f1611a77b (diff)
Add C0301 line-too-long error to pylint sanity check.
PiperOrigin-RevId: 183467186
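
The reformatting in the diff below is what lets this file pass pylint's C0301 (line-too-long) check once it is enforced. As a rough illustration only, a standalone invocation along the following lines would flag the over-length lines being fixed here; this is a hedged sketch, not TensorFlow's actual sanity-check script (which is not part of this diff), and the file path is simply the one touched by this commit.

    # Hypothetical standalone check, shown only to illustrate what enabling
    # C0301 means for a file like the one reformatted in this commit.
    from pylint.lint import Run

    # Enable only C0301 (line-too-long) with the 80-character limit used by
    # the TensorFlow Python style; Run() exits with pylint's status code.
    Run([
        "--disable=all",
        "--enable=C0301",
        "--max-line-length=80",
        "tensorflow/python/kernel_tests/control_flow_ops_py_test.py",
    ])
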
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--  tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 198
1 file changed, 111 insertions, 87 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 6e18ed132c..5d648bb235 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -181,8 +181,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(enter_v_constant.shape, [2])
# Otherwise, the shape should be unknown.
- enter_v_non_constant = control_flow_ops.enter(v, "frame2",
- is_constant=False)
+ enter_v_non_constant = control_flow_ops.enter(
+ v, "frame2", is_constant=False)
self.assertEqual(enter_v_non_constant.shape, None)
def testSwitchMergeIndexedSlices(self):
@@ -736,24 +736,21 @@ class ControlFlowTest(test.TestCase):
with self.test_session():
s = constant_op.constant([1, 2, 3, 4, 5])
r = isum(s, maximum_iterations=3)
- self.assertAllEqual([1+3, 2+3, 3+3, 4+3, 5+3], r.eval())
+ self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], r.eval())
def testWhileWithMaximumIterationsAndSingleArgument(self):
with self.test_session():
r = control_flow_ops.while_loop(
- lambda i: i < 3,
- lambda i: i + 1,
- [0],
- maximum_iterations=1)
+ lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
self.assertEqual(1, r.eval())
def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
v = constant_op.constant(1.0)
+
def training_loop_with_gradient(i):
out = control_flow_ops.while_loop(
lambda i_, _: i_ < 3,
- lambda i_, j: [i_ + 1, j * v],
- [0, 1.0],
+ lambda i_, j: [i_ + 1, j * v], [0, 1.0],
maximum_iterations=i)
g = gradients_impl.gradients(out, v)
with ops.control_dependencies(g):
@@ -763,8 +760,8 @@ class ControlFlowTest(test.TestCase):
xla_context.Enter()
# Create training loop, ensure we can call gradient() of
# while_loop inside the training loop.
- loop = control_flow_ops.while_loop(
- lambda i: i < 3, training_loop_with_gradient, [0])
+ loop = control_flow_ops.while_loop(lambda i: i < 3,
+ training_loop_with_gradient, [0])
xla_context.Exit()
loop_execute = array_ops.identity(loop) # Because loop is not fetchable.
@@ -774,17 +771,18 @@ class ControlFlowTest(test.TestCase):
def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
v = constant_op.constant(1.0)
+
def inner_body(i, x):
out = control_flow_ops.while_loop(
lambda i, _: i < 3,
- lambda i, j: [i + 1, j * v],
- [0, x],
+ lambda i, j: [i + 1, j * v], [0, x],
maximum_iterations=i)
return out
def create_while_loop(maximum_iterations=None):
return control_flow_ops.while_loop(
- lambda i, _: i < 3, inner_body, [0, 1.0],
+ lambda i, _: i < 3,
+ inner_body, [0, 1.0],
maximum_iterations=maximum_iterations)
loop_no_xla = create_while_loop(maximum_iterations=5)
@@ -819,14 +817,17 @@ class ControlFlowTest(test.TestCase):
def create_while_loop():
max_iter_holder = []
+
def create_mi():
max_iter_holder.append(array_ops.placeholder(dtypes.int32, shape=()))
return 1.0
- _ = control_flow_ops.cond(constant_op.constant(True),
- create_mi, create_mi)
+
+ _ = control_flow_ops.cond(
+ constant_op.constant(True), create_mi, create_mi)
return control_flow_ops.while_loop(
- lambda i, _: i < 3, lambda i, x: (i + 1, v * x), (0, 1.0),
+ lambda i, _: i < 3,
+ lambda i, x: (i + 1, v * x), (0, 1.0),
maximum_iterations=max_iter_holder[0])
xla_context = control_flow_ops.XLAControlFlowContext()
@@ -849,28 +850,32 @@ class ControlFlowTest(test.TestCase):
p = array_ops.placeholder(dtype=dtypes.int32)
def mid_body_builder(iterations):
+
def mid_body(i, x):
r = control_flow_ops.while_loop(
lambda *_: True,
- lambda i, x: (i + 1, v * x),
- (0, x),
- maximum_iterations=iterations, name="inner")
+ lambda i, x: (i + 1, v * x), (0, x),
+ maximum_iterations=iterations,
+ name="inner")
return (i + 1, gradients_impl.gradients(x + r[1], v)[0])
+
return mid_body
def outer_body(i, x):
iterations = array_ops.size(p, name="iterations")
- return (
- i + 1,
- x + control_flow_ops.while_loop(
- lambda *_: True, mid_body_builder(iterations), (0, x),
- maximum_iterations=iterations, name="mid")[1])
+ return (i + 1, x + control_flow_ops.while_loop(
+ lambda *_: True,
+ mid_body_builder(iterations), (0, x),
+ maximum_iterations=iterations,
+ name="mid")[1])
def create_while_loop():
with ops.device("/cpu:0"):
r = control_flow_ops.while_loop(
- lambda *_: True, outer_body, (0, 1.0),
- maximum_iterations=5, name="outer")
+ lambda *_: True,
+ outer_body, (0, 1.0),
+ maximum_iterations=5,
+ name="outer")
return array_ops.identity(r[1])
xla_context = control_flow_ops.XLAControlFlowContext()
@@ -881,18 +886,19 @@ class ControlFlowTest(test.TestCase):
final_without_xla_context = create_while_loop()
with self.test_session(use_gpu=False) as sess:
- opts = config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE)
+ opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
final_value_without_xla_context = sess.run(
- final_without_xla_context,
- feed_dict={p: [0, 0, 0]})
+ final_without_xla_context, feed_dict={
+ p: [0, 0, 0]
+ })
final_value_with_xla_context = sess.run(
final_with_xla_context,
feed_dict={p: [0, 0, 0]},
- options=opts, run_metadata=run_metadata)
+ options=opts,
+ run_metadata=run_metadata)
node_stats = run_metadata.step_stats.dev_stats[0].node_stats
stack_push_count = len(
@@ -901,8 +907,8 @@ class ControlFlowTest(test.TestCase):
# the last two "3"s comes from size(p), when p == [0, 0, 0].
self.assertEqual(stack_push_count, 5 * 3 * 3)
- self.assertAllClose(
- final_value_with_xla_context, final_value_without_xla_context)
+ self.assertAllClose(final_value_with_xla_context,
+ final_value_without_xla_context)
# Have more than 10 parallel iterations and hence exercise k-bound
# most of the time.
@@ -951,8 +957,7 @@ class ControlFlowTest(test.TestCase):
with self.test_session():
def compute(i, c, o):
- c = array_ops.strided_slice(x,
- array_ops.expand_dims(i, 0),
+ c = array_ops.strided_slice(x, array_ops.expand_dims(i, 0),
[1] + array_ops.expand_dims(i, 0))
o = array_ops.concat([o, c], 0)
i = math_ops.add(i, 1)
@@ -963,11 +968,12 @@ class ControlFlowTest(test.TestCase):
o = ops.convert_to_tensor([0])
x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
s = array_ops.size(x)
- r = control_flow_ops.while_loop(
- lambda i, c, o: math_ops.less(i, s), compute, [i, c, o], [
- i.get_shape(), tensor_shape.unknown_shape(),
- tensor_shape.unknown_shape()
- ])
+ r = control_flow_ops.while_loop(lambda i, c, o: math_ops.less(i, s),
+ compute, [i, c, o], [
+ i.get_shape(),
+ tensor_shape.unknown_shape(),
+ tensor_shape.unknown_shape()
+ ])
result = r[2].eval()
self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result)
@@ -1033,7 +1039,8 @@ class ControlFlowTest(test.TestCase):
return [new_i, new_j]
r = control_flow_ops.while_loop(
- c, _b, [i, m], [i.get_shape(), tensor_shape.unknown_shape()])
+ c, _b, [i, m],
+ [i.get_shape(), tensor_shape.unknown_shape()])
r = r[1] * array_ops.ones([8, 8])
self.assertAllEqual(np.ones((8, 8)), r.eval())
@@ -1065,7 +1072,8 @@ class ControlFlowTest(test.TestCase):
return [new_i, new_j]
r = control_flow_ops.while_loop(
- c, b, [i, m], [i.get_shape(), tensor_shape.TensorShape([None, 2])])
+ c, b, [i, m],
+ [i.get_shape(), tensor_shape.TensorShape([None, 2])])
self.assertTrue(r[1].get_shape()[0].value is None)
self.assertEqual(r[1].get_shape()[1], tensor_shape.Dimension(2))
@@ -1092,20 +1100,22 @@ class ControlFlowTest(test.TestCase):
def b(i, x):
return [
- i + 1, sparse_tensor.SparseTensor(x.indices, x.values * 2.0,
- x.dense_shape)
+ i + 1,
+ sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape)
]
_, r = control_flow_ops.while_loop(c, b, [i, x])
self.assertEqual(r.dense_shape.get_shape()[0].value, 1)
_, r = control_flow_ops.while_loop(
- c, b, [i, x], [i.get_shape(), tensor_shape.TensorShape([None])])
+ c, b, [i, x],
+ [i.get_shape(), tensor_shape.TensorShape([None])])
self.assertTrue(r.dense_shape.get_shape()[0].value is None)
with self.assertRaisesRegexp(ValueError, "is not compatible with"):
_, r = control_flow_ops.while_loop(
- c, b, [i, x], [i.get_shape(), tensor_shape.TensorShape([5])])
+ c, b, [i, x],
+ [i.get_shape(), tensor_shape.TensorShape([5])])
def testWhileShapeInferenceIndexedSlices(self):
with self.test_session():
@@ -1120,7 +1130,8 @@ class ControlFlowTest(test.TestCase):
def b(i, x):
return [
- i + 1, ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape)
+ i + 1,
+ ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape)
]
_, r = control_flow_ops.while_loop(c, b, [i, x])
@@ -1128,14 +1139,16 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(r.values.get_shape(), tensor_shape.TensorShape([2, 2]))
_, r = control_flow_ops.while_loop(
- c, b, [i, x], [i.get_shape(), tensor_shape.TensorShape([None, 2])])
+ c, b, [i, x],
+ [i.get_shape(), tensor_shape.TensorShape([None, 2])])
self.assertEqual(r.dense_shape.get_shape()[0].value, 2)
self.assertTrue(r.values.get_shape()[0].value is None)
self.assertEqual(r.values.get_shape()[1].value, 2)
with self.assertRaisesRegexp(ValueError, "is not compatible with"):
_, r = control_flow_ops.while_loop(
- c, b, [i, x], [i.get_shape(), tensor_shape.TensorShape([None, 5])])
+ c, b, [i, x],
+ [i.get_shape(), tensor_shape.TensorShape([None, 5])])
def _testNestedWhile_1(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
@@ -1276,16 +1289,17 @@ class ControlFlowTest(test.TestCase):
"v", [], initializer=init_ops.constant_initializer(2))
i0 = constant_op.constant(0)
with ops.control_dependencies([i0]):
+
def loop_condition(i):
return i < 4
def loop_body(i):
some_cond = control_flow_ops.cond(
constant_op.constant(True),
- lambda: state_ops.assign(v, math_ops.square(v)),
- lambda: v)
+ lambda: state_ops.assign(v, math_ops.square(v)), lambda: v)
with ops.control_dependencies([some_cond]):
return i + 1
+
r = control_flow_ops.while_loop(loop_condition, loop_body, (i0,))
variables.global_variables_initializer().run()
self.assertEqual(4, r.eval())
@@ -1600,7 +1614,8 @@ class ControlFlowTest(test.TestCase):
_, rx = control_flow_ops.while_loop(
c1,
- b1, [r, x], [r.get_shape(), tensor_shape.unknown_shape()],
+ b1, [r, x],
+ [r.get_shape(), tensor_shape.unknown_shape()],
parallel_iterations=1)
self.assertEqual(45, rx.eval())
@@ -1663,7 +1678,8 @@ class ControlFlowTest(test.TestCase):
b = lambda i, v: [i + 1, math_ops.multiply(x, v)]
r = control_flow_ops.while_loop(
c,
- b, [n, v], [n.get_shape(), tensor_shape.unknown_shape()],
+ b, [n, v],
+ [n.get_shape(), tensor_shape.unknown_shape()],
parallel_iterations=1)
r = gradients_impl.gradients(r[1], x)[0]
@@ -1797,8 +1813,8 @@ class ControlFlowTest(test.TestCase):
named = collections.namedtuple("named", ("a", "b"))
loop_vars = [
named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
- (constant_op.constant(2.0),
- constant_op.constant(3.0)), constant_op.constant(4.0)
+ (constant_op.constant(2.0), constant_op.constant(3.0)),
+ constant_op.constant(4.0)
]
c = lambda lv0, _1, _2: lv0.a < 100.0
@@ -1824,8 +1840,8 @@ class ControlFlowTest(test.TestCase):
named = collections.namedtuple("named", ("a", "b"))
loop_vars = [
named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
- (constant_op.constant(2.0),
- constant_op.constant(3.0)), constant_op.constant(4.0)
+ (constant_op.constant(2.0), constant_op.constant(3.0)),
+ constant_op.constant(4.0)
]
c = lambda lv0, _1, _2: lv0.a < 100.0
@@ -2176,7 +2192,8 @@ class ControlFlowTest(test.TestCase):
def b(i, x):
return [
- i + 1, ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape)
+ i + 1,
+ ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape)
]
_, r = control_flow_ops.while_loop(c, b, [i, x])
@@ -2197,8 +2214,8 @@ class ControlFlowTest(test.TestCase):
def b(i, x):
return [
- i + 1, sparse_tensor.SparseTensor(x.indices, x.values * 2.0,
- x.dense_shape)
+ i + 1,
+ sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape)
]
_, r = control_flow_ops.while_loop(c, b, [i, x])
@@ -2220,8 +2237,8 @@ class ControlFlowTest(test.TestCase):
x1 = x + gradients_impl.gradients(data, params)[0]
return i + 1, x1
- output_grad = control_flow_ops.while_loop(c, b,
- [i0, constant_op.constant(0.0)])
+ output_grad = control_flow_ops.while_loop(
+ c, b, [i0, constant_op.constant(0.0)])
self.assertAllClose(600.0, sess.run(output_grad)[1])
def testWhileAndTensorArray(self):
@@ -2359,9 +2376,12 @@ class ControlFlowTest(test.TestCase):
def testStopGradMultiFlows(self):
with self.test_session():
+
def body(i, y, r):
x = variable_scope.get_variable(
- "x", shape=(), dtype=dtypes.float32,
+ "x",
+ shape=(),
+ dtype=dtypes.float32,
initializer=init_ops.ones_initializer())
y *= x
return [i + 1, y, r + math_ops.reduce_sum(y)]
@@ -2773,7 +2793,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(
lambda i, v: i < 2, lambda i, v: [i + 1, func(v)],
[constant_op.constant(0), x],
- [tensor_shape.unknown_shape(), tensor_shape.unknown_shape()])
+ [tensor_shape.unknown_shape(),
+ tensor_shape.unknown_shape()])
self.assertEqual(r[1].eval(), 65536.0)
r = gradients_impl.gradients(r, x)[0]
@@ -2800,12 +2821,14 @@ class ControlFlowContextCheckTest(test.TestCase):
def _getCondTensor(self):
cond_tensor = []
+
def true_fn():
if not cond_tensor:
cond_tensor.append(constant_op.constant(1))
return cond_tensor[0]
- control_flow_ops.cond(math_ops.less(1, 2), true_fn,
- lambda: constant_op.constant(0))
+
+ control_flow_ops.cond(
+ math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))
return cond_tensor[0]
def testInvalidContext(self):
@@ -2821,14 +2844,13 @@ class ControlFlowContextCheckTest(test.TestCase):
# Accessing a while loop tensor in cond is illegal.
while_tensor = self._getWhileTensor()
with self.assertRaisesRegexp(
- ValueError,
- "Cannot use 'while/Const_1' as input to 'cond/Add' because "
+ ValueError, "Cannot use 'while/Const_1' as input to 'cond/Add' because "
"'while/Const_1' is in a while loop. See info log for more details."):
# TODO(skyewm): this passes if we return while_tensor directly instead
# of using it as input to another op.
- control_flow_ops.cond(math_ops.less(1, 2),
- lambda: math_ops.add(1, while_tensor),
- lambda: constant_op.constant(0))
+ control_flow_ops.cond(
+ math_ops.less(1, 2), lambda: math_ops.add(1, while_tensor),
+ lambda: constant_op.constant(0))
def testInvalidContextInWhile(self):
# Accessing a while loop tensor in a different while loop is illegal.
@@ -2856,6 +2878,7 @@ class ControlFlowContextCheckTest(test.TestCase):
# Accessing a tensor from a cond context from the other branch's cond
# context is OK (although dangerous).
cond_tensor = []
+
def branch_fn():
if not cond_tensor:
cond_tensor.append(constant_op.constant(1))
@@ -2892,12 +2915,13 @@ class ControlFlowContextCheckTest(test.TestCase):
while_tensor = self._getWhileTensor()
return control_flow_ops.while_loop(lambda i: i < 3,
lambda i: i + while_tensor, [0])
+
with self.assertRaisesRegexp(
ValueError,
"Cannot use 'cond/while_1/add' as input to 'cond/while/Const_1' because"
" they are in different while loops. See info log for more details."):
- control_flow_ops.cond(math_ops.less(1, 2), true_fn,
- lambda: constant_op.constant(0))
+ control_flow_ops.cond(
+ math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))
@test_util.with_c_api
@@ -3005,11 +3029,13 @@ class AssertTest(test.TestCase):
sess.run(unguarded_assert, options=opts, run_metadata=unguarded_metadata)
guarded_nodestat_names = [
n.node_name
- for d in guarded_metadata.step_stats.dev_stats for n in d.node_stats
+ for d in guarded_metadata.step_stats.dev_stats
+ for n in d.node_stats
]
unguarded_nodestat_names = [
n.node_name
- for d in unguarded_metadata.step_stats.dev_stats for n in d.node_stats
+ for d in unguarded_metadata.step_stats.dev_stats
+ for n in d.node_stats
]
guarded_memcpy_nodestat_names = [
n for n in guarded_nodestat_names if "MEMCPYDtoH" in n
@@ -3066,6 +3092,7 @@ class WhileOpBenchmark(test.Benchmark):
Returns:
The duration of the run in seconds.
"""
+
def loop_body(i, x):
with ops.device("/gpu:0"):
# Always put loop body on GPU.
@@ -3107,7 +3134,7 @@ class WhileOpBenchmark(test.Benchmark):
start_time = time.time()
for _ in xrange(num_iters):
sess.run(r)
- return (time.time() - start_time)/num_iters
+ return (time.time() - start_time) / num_iters
def benchmarkWhileOpCrossDevicePlacement(self):
iters = 10
@@ -3154,23 +3181,20 @@ class EagerTest(test.TestCase):
def testWhileLoop(self):
with context.eager_mode():
tensor = constant_op.constant([1, 2, 3, 4, 5])
- self.assertAllEqual(isum(tensor).numpy(),
- [46, 47, 48, 49, 50])
+ self.assertAllEqual(isum(tensor).numpy(), [46, 47, 48, 49, 50])
def testWhileLoopWithMaxIterations(self):
with context.eager_mode():
tensor = constant_op.constant([1, 2, 3, 4, 5])
- self.assertAllEqual(isum(tensor, maximum_iterations=3).numpy(),
- [1+3, 2+3, 3+3, 4+3, 5+3])
+ self.assertAllEqual(
+ isum(tensor, maximum_iterations=3).numpy(),
+ [1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3])
def testWhileWithMaximumIterationsAndSingleArgument(self):
with context.eager_mode():
tensor = constant_op.constant(0)
r = control_flow_ops.while_loop(
- lambda i: i < 3,
- lambda i: i + 1,
- [tensor],
- maximum_iterations=1)
+ lambda i: i < 3, lambda i: i + 1, [tensor], maximum_iterations=1)
self.assertEqual(1, r.numpy())
def testWithDependencies(self):
@@ -3197,8 +3221,8 @@ class EagerTest(test.TestCase):
f2 = lambda: constant_op.constant(23)
f3 = lambda: constant_op.constant(-1)
- r1 = control_flow_ops.case([(x < y, f1), (x > z, f2)],
- default=f3, exclusive=True)
+ r1 = control_flow_ops.case(
+ [(x < y, f1), (x > z, f2)], default=f3, exclusive=True)
self.assertAllEqual(r1.numpy(), 17)