aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/special_math_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/special_math_ops_test.py')
-rw-r--r--tensorflow/python/ops/special_math_ops_test.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py
index d17bb80d4b..3d289bcc9a 100644
--- a/tensorflow/python/ops/special_math_ops_test.py
+++ b/tensorflow/python/ops/special_math_ops_test.py
@@ -283,6 +283,43 @@ class EinsumTest(test.TestCase):
}
np.testing.assert_almost_equal([7], sess.run(out, feed_dict=feed_dict))
+ # Tests for placeholders which have two or more None values
+ with ops.Graph().as_default():
+ m0 = array_ops.placeholder(dtypes.int32, shape=(None, None, 2))
+ m1 = array_ops.placeholder(dtypes.int32, shape=(2, 1))
+ out = special_math_ops.einsum('ijk,kl->ijl', m0, m1)
+ with session.Session() as sess:
+ feed_dict = {
+ m0: [[[1,2]]],
+ m1: [[3], [2]],
+ }
+ np.testing.assert_almost_equal(
+ [[[7]]], sess.run(out, feed_dict=feed_dict))
+
+ with ops.Graph().as_default():
+ m0 = array_ops.placeholder(dtypes.int32, shape=(2, 1))
+ m1 = array_ops.placeholder(dtypes.int32, shape=(None, None, 2))
+ out = special_math_ops.einsum('kl,ijk->ijl', m0, m1)
+ with session.Session() as sess:
+ feed_dict = {
+ m0: [[3], [2]],
+ m1: [[[1,2]]],
+ }
+ np.testing.assert_almost_equal(
+ [[[7]]], sess.run(out, feed_dict=feed_dict))
+
+ with ops.Graph().as_default():
+ m0 = array_ops.placeholder(dtypes.int32, shape=(None, None, 2))
+ m1 = array_ops.placeholder(dtypes.int32, shape=(2,))
+ out = special_math_ops.einsum('ijk,k->ij', m0, m1)
+ with session.Session() as sess:
+ feed_dict = {
+ m0: [[[1, 2]]],
+ m1: [3, 2],
+ }
+ np.testing.assert_almost_equal(
+ [[7]], sess.run(out, feed_dict=feed_dict))
+
if __name__ == '__main__':
test.main()