aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/special_math_ops_test.py
diff options
context:
space:
mode:
authorGravatar Andrew Harp <andrewharp@google.com>2017-03-01 17:59:22 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-01 18:08:24 -0800
commit3e975ea978bac4d861bb09328b06f3c316212611 (patch)
tree79bac044c9723df8443495eb962c2dd98a2ed421 /tensorflow/python/ops/special_math_ops_test.py
parent8043a27ed77f59bb68409070f2bfa01df0e04b89 (diff)
Merge changes from github.
Change: 148954491
Diffstat (limited to 'tensorflow/python/ops/special_math_ops_test.py')
-rw-r--r--tensorflow/python/ops/special_math_ops_test.py14
1 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py
index 3d289bcc9a..c792d32277 100644
--- a/tensorflow/python/ops/special_math_ops_test.py
+++ b/tensorflow/python/ops/special_math_ops_test.py
@@ -318,7 +318,19 @@ class EinsumTest(test.TestCase):
m1: [3, 2],
}
np.testing.assert_almost_equal(
- [[7]], sess.run(out, feed_dict=feed_dict))
+ [[7]], sess.run(out, feed_dict=feed_dict))
+
+ with ops.Graph().as_default():
+ m0 = array_ops.placeholder(dtypes.int32, shape=(None, 2, None, 2))
+ m1 = array_ops.placeholder(dtypes.int32, shape=(None, 2))
+ out = special_math_ops.einsum('ijkl,ij->ikl', m0, m1)
+ with session.Session() as sess:
+ feed_dict = {
+ m0: [[[[1, 2]], [[2, 1]]]],
+ m1: [[3, 2]],
+ }
+ np.testing.assert_almost_equal(
+ [[[7, 8]]], sess.run(out, feed_dict=feed_dict))
if __name__ == '__main__':