diff options
Diffstat (limited to 'tensorflow/python/ops/special_math_ops_test.py')
-rw-r--r-- | tensorflow/python/ops/special_math_ops_test.py | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py index 3d289bcc9a..c792d32277 100644 --- a/tensorflow/python/ops/special_math_ops_test.py +++ b/tensorflow/python/ops/special_math_ops_test.py @@ -318,7 +318,19 @@ class EinsumTest(test.TestCase): m1: [3, 2], } np.testing.assert_almost_equal( - [[7]], sess.run(out, feed_dict=feed_dict)) + [[7]], sess.run(out, feed_dict=feed_dict)) + + with ops.Graph().as_default(): + m0 = array_ops.placeholder(dtypes.int32, shape=(None, 2, None, 2)) + m1 = array_ops.placeholder(dtypes.int32, shape=(None, 2)) + out = special_math_ops.einsum('ijkl,ij->ikl', m0, m1) + with session.Session() as sess: + feed_dict = { + m0: [[[[1, 2]], [[2, 1]]]], + m1: [[3, 2]], + } + np.testing.assert_almost_equal( + [[[7, 8]]], sess.run(out, feed_dict=feed_dict)) if __name__ == '__main__': |