diff options
Diffstat (limited to 'tensorflow/python/ops/special_math_ops_test.py')
-rw-r--r-- | tensorflow/python/ops/special_math_ops_test.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py index d17bb80d4b..3d289bcc9a 100644 --- a/tensorflow/python/ops/special_math_ops_test.py +++ b/tensorflow/python/ops/special_math_ops_test.py @@ -283,6 +283,43 @@ class EinsumTest(test.TestCase): } np.testing.assert_almost_equal([7], sess.run(out, feed_dict=feed_dict)) + # Tests for placeholders which have two or more None values + with ops.Graph().as_default(): + m0 = array_ops.placeholder(dtypes.int32, shape=(None, None, 2)) + m1 = array_ops.placeholder(dtypes.int32, shape=(2, 1)) + out = special_math_ops.einsum('ijk,kl->ijl', m0, m1) + with session.Session() as sess: + feed_dict = { + m0: [[[1,2]]], + m1: [[3], [2]], + } + np.testing.assert_almost_equal( + [[[7]]], sess.run(out, feed_dict=feed_dict)) + + with ops.Graph().as_default(): + m0 = array_ops.placeholder(dtypes.int32, shape=(2, 1)) + m1 = array_ops.placeholder(dtypes.int32, shape=(None, None, 2)) + out = special_math_ops.einsum('kl,ijk->ijl', m0, m1) + with session.Session() as sess: + feed_dict = { + m0: [[3], [2]], + m1: [[[1,2]]], + } + np.testing.assert_almost_equal( + [[[7]]], sess.run(out, feed_dict=feed_dict)) + + with ops.Graph().as_default(): + m0 = array_ops.placeholder(dtypes.int32, shape=(None, None, 2)) + m1 = array_ops.placeholder(dtypes.int32, shape=(2,)) + out = special_math_ops.einsum('ijk,k->ij', m0, m1) + with session.Session() as sess: + feed_dict = { + m0: [[[1, 2]]], + m1: [3, 2], + } + np.testing.assert_almost_equal( + [[7]], sess.run(out, feed_dict=feed_dict)) + if __name__ == '__main__': test.main() |