about summary refs log tree commit diff homepage
path: root/tensorflow/python/kernel_tests
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-02-21 17:31:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-21 17:52:15 -0800
commit4891c01b1cadf085a915a3eac5dd1b8d8cdee203 (patch)
tree87ec00e1927877ba26a2ffb69bc4f74f25c36f6a /tensorflow/python/kernel_tests
parent123c2bb0af532d5fdaa05358158da33497d4bfe6 (diff)
Allow (safe) in-place computation in TensorFlow C++ ops. When at least one input tensor has the same size and type as the output, and the underlying buffer is owned by the op, i.e. when its refcount is 1 at the time the op's Compute method executes, the computation can be performed in place and allocation of the output buffer avoided.
I updated the following ops to perform in-place computation automatically when possible: * All standard coefficient-wise unary and binary operators (including with broadcasting) inheriting from base classes in kernels/cwise_ops_common.h. * unary and binary operators inheriting from base classes in framework/numeric_op.h. This is mostly old code for the Relu family and associated gradients. * All linear algebra ops inheriting from linalg_common. * Misc individual files/ops: softmax, select, bias, aggregate ops, batch_norm & fused_batch_norm, adjust_hue, constant, depthwise_conv_grad, fractional_avg_pool, misc. pooling ops, matrix_set_diag, xent & sparse_xent, unique_op. Change: 148166936
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py97
-rw-r--r--tensorflow/python/kernel_tests/slice_op_test.py9
2 files changed, 67 insertions, 39 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index e9db47716d..6c7cbbff9c 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -326,9 +326,8 @@ class ControlFlowTest(test.TestCase):
def testFetchables(self):
with self.test_session() as sess:
x = array_ops.placeholder(dtypes.float32)
- control_flow_ops.cond(constant_op.constant(True),
- lambda: x + 2,
- lambda: x + 0)
+ control_flow_ops.cond(
+ constant_op.constant(True), lambda: x + 2, lambda: x + 0)
tensor_names = all_fetchables()
for name in tensor_names:
sess.run(name, feed_dict={x: 3})
@@ -388,11 +387,12 @@ class ControlFlowTest(test.TestCase):
rv = resource_variable_ops.ResourceVariable(True)
variables.global_variables_initializer().run()
t = ops.convert_to_tensor(1.0)
+
def case():
- assign = resource_variable_ops.assign_variable_op(
- rv.handle, False)
+ assign = resource_variable_ops.assign_variable_op(rv.handle, False)
with ops.control_dependencies([assign]):
return array_ops.identity(t)
+
self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval())
def testCondIndexedSlicesDifferentTypes(self):
@@ -544,13 +544,15 @@ class ControlFlowTest(test.TestCase):
with self.test_session() as sess:
control_holder = array_ops.placeholder(dtypes.float32, shape=())
a = constant_op.constant(3)
+
def true_branch():
with ops.control_dependencies([control_holder]):
_ = a + 1
return a + 2
- r = control_flow_ops.cond(constant_op.constant(True),
- true_branch,
- lambda: constant_op.constant(1))
+
+ r = control_flow_ops.cond(
+ constant_op.constant(True), true_branch,
+ lambda: constant_op.constant(1))
self.assertEqual(5, r.eval())
def testUninitializedRefIdentity(self):
@@ -770,16 +772,37 @@ class ControlFlowTest(test.TestCase):
o = ops.convert_to_tensor([0])
x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
s = array_ops.size(x)
- r = control_flow_ops.while_loop(lambda i, c, o: math_ops.less(i, s),
- compute, [i, c, o], [
- i.get_shape(),
- tensor_shape.unknown_shape(),
- tensor_shape.unknown_shape()
- ])
+ r = control_flow_ops.while_loop(
+ lambda i, c, o: math_ops.less(i, s), compute, [i, c, o], [
+ i.get_shape(), tensor_shape.unknown_shape(),
+ tensor_shape.unknown_shape()
+ ])
result = r[2].eval()
self.assertTrue(check_op_order(i.graph))
self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result)
+ def testBufferForwarding(self):
+ run_options = config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE)
+ run_metadata = config_pb2.RunMetadata()
+
+ with self.test_session() as sess:
+ with ops.device("/cpu:0"):
+ c = constant_op.constant(2)
+ i0 = constant_op.constant(0)
+ r = control_flow_ops.while_loop(lambda i: i < 1000,
+ lambda i: math_ops.square(c) + i, [i0])
+ r_val = sess.run(r, options=run_options, run_metadata=run_metadata)
+ self.assertEqual(1000, r_val)
+ self.assertTrue(run_metadata.HasField("step_stats"))
+ unique_allocs = set()
+ for node_stat in run_metadata.step_stats.dev_stats[0].node_stats:
+ for output in node_stat.output:
+ unique_allocs.add(
+ output.tensor_description.allocation_description.ptr)
+ # Prior to cl/147536680, the number of unique allocations was about 1005.
+ self.assertLess(len(unique_allocs), 756)
+
def _testWhile_Gpu_1(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
n = constant_op.constant(1.0)
@@ -1368,8 +1391,9 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(45, rx.eval())
def _testWhileGrad_ColocateGradients(self, colocate):
- gpu_dev_name = test.gpu_device_name() if test.is_gpu_available() else "/gpu:0"
- gpu_short_name = gpu_dev_name.split('/')[-1]
+ gpu_dev_name = test.gpu_device_name() if test.is_gpu_available(
+ ) else "/gpu:0"
+ gpu_short_name = gpu_dev_name.split("/")[-1]
with self.test_session(graph=ops.Graph()) as sess:
v = constant_op.constant(2.0, name="v")
@@ -1485,16 +1509,21 @@ class ControlFlowTest(test.TestCase):
def _testNestedWhileCondWhileGrad(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
v = constant_op.constant(1.0)
+
def inner_loop(s):
z = constant_op.constant(0)
c = lambda i, x: math_ops.less(i, 4)
b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
return control_flow_ops.while_loop(c, b, [z, s])
+
c = lambda x: math_ops.less(x, 128.0)
+
def b(x):
- return control_flow_ops.cond(constant_op.constant(True),
- lambda: math_ops.square(inner_loop(x)[1]),
- lambda: math_ops.multiply(x, 2.0))
+ return control_flow_ops.cond(
+ constant_op.constant(True),
+ lambda: math_ops.square(inner_loop(x)[1]),
+ lambda: math_ops.multiply(x, 2.0))
+
r = control_flow_ops.while_loop(c, b, [v])
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(512.0, r.eval())
@@ -1550,10 +1579,9 @@ class ControlFlowTest(test.TestCase):
with self.test_session() as sess:
named = collections.namedtuple("named", ("a", "b"))
loop_vars = [
- named(
- a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
- (constant_op.constant(2.0), constant_op.constant(3.0)),
- constant_op.constant(4.0)
+ named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
+ (constant_op.constant(2.0),
+ constant_op.constant(3.0)), constant_op.constant(4.0)
]
c = lambda lv0, _1, _2: lv0.a < 100.0
@@ -1578,10 +1606,9 @@ class ControlFlowTest(test.TestCase):
with self.test_session():
named = collections.namedtuple("named", ("a", "b"))
loop_vars = [
- named(
- a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
- (constant_op.constant(2.0), constant_op.constant(3.0)),
- constant_op.constant(4.0)
+ named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
+ (constant_op.constant(2.0),
+ constant_op.constant(3.0)), constant_op.constant(4.0)
]
c = lambda lv0, _1, _2: lv0.a < 100.0
@@ -2522,15 +2549,11 @@ class TupleTest(test.TestCase):
with self.test_session():
v1 = variables.Variable([1.0])
add1 = math_ops.add(
- control_flow_ops.with_dependencies(
- [v1.initializer],
- v1._ref()), # pylint: disable=protected-access
+ control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access
2.0)
v2 = variables.Variable([10.0])
add2 = math_ops.add(
- control_flow_ops.with_dependencies(
- [v2.initializer],
- v2._ref()), # pylint: disable=protected-access
+ control_flow_ops.with_dependencies([v2.initializer], v2._ref()), # pylint: disable=protected-access
20.0)
t1, _, t2 = control_flow_ops.tuple([add1, None, add2])
@@ -2558,18 +2581,14 @@ class TupleTest(test.TestCase):
np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(
np.float32))
v1_at_1 = ops.IndexedSlices(
- control_flow_ops.with_dependencies(
- [v1.initializer],
- v1._ref()), # pylint: disable=protected-access
+ control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access
constant_op.constant([1]))
v2 = variables.Variable(
np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype(
np.float32))
v2_at_1 = ops.IndexedSlices(
- control_flow_ops.with_dependencies(
- [v2.initializer],
- v2._ref()), # pylint: disable=protected-access
+ control_flow_ops.with_dependencies([v2.initializer], v2._ref()), # pylint: disable=protected-access
constant_op.constant([1]))
st1, st2 = control_flow_ops.tuple([v1_at_1, v2_at_1])
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py
index 29f76a2182..c11f78b77e 100644
--- a/tensorflow/python/kernel_tests/slice_op_test.py
+++ b/tensorflow/python/kernel_tests/slice_op_test.py
@@ -269,6 +269,15 @@ class SliceTest(test.TestCase):
c = array_ops.slice(a, [begin, 0], [-1, 2])
self.assertEqual([None, 2], c.get_shape().as_list())
+ def testSliceOfSlice(self):
+ with self.test_session(use_gpu=True):
+ a = constant_op.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
+ b = a[1:, :]
+ c = b[:-1, :]
+ d = c[1, :]
+ res = 2 * d - c[1, :] + a[2, :] - 2 * b[-2, :]
+ self.assertAllEqual([0, 0, 0], res.eval())
+
if __name__ == "__main__":
test.main()