aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-07 17:14:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-07 17:25:02 -0800
commit30d2a2235a66e158757b13b6dd0bab7bf5e60492 (patch)
treeffa260e3b9c50281a99d283b9b87d7def9092358
parentfe5740fef35b3dae53d40f59451523cd5360c6ba (diff)
Fix bug in tf.einsum() when input is placeholder.
Change: 141378762
-rw-r--r--tensorflow/python/ops/special_math_ops.py4
-rw-r--r--tensorflow/python/ops/special_math_ops_test.py12
2 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py
index 02c111e3bd..9c56ad6591 100644
--- a/tensorflow/python/ops/special_math_ops.py
+++ b/tensorflow/python/ops/special_math_ops.py
@@ -121,6 +121,7 @@ def einsum(equation, *inputs):
Many common operations can be expressed in this way. For example:
+ ```python
# Matrix multiplication
>>> einsum('ij,jk->ik', m0, m1) # output[i,k] = sum_j m0[i,j] * m1[j, k]
@@ -135,6 +136,7 @@ def einsum(equation, *inputs):
# Batch matrix multiplication
>>> einsum('aij,ajk->aik', s, t) # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
+ ```
This function behaves like `numpy.einsum`, but does not support:
* Ellipses (subscripts like `ij...,jk...->ik...`)
@@ -358,6 +360,8 @@ def _transpose_if_necessary(tensor, perm):
def _reshape_if_necessary(tensor, new_shape):
"""Like reshape(), but avoids creating a new tensor if possible."""
+ # Accept None as an alias for -1 in new_shape.
+ new_shape = tuple(-1 if x is None else x for x in new_shape)
cur_shape = tuple(x.value for x in tensor.get_shape())
if (len(new_shape) == len(cur_shape) and
all(d0 == d1 or d1 == -1 for d0, d1 in zip(cur_shape, new_shape))):
diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py
index 43c01e31a7..9e6fe68800 100644
--- a/tensorflow/python/ops/special_math_ops_test.py
+++ b/tensorflow/python/ops/special_math_ops_test.py
@@ -254,6 +254,18 @@ class EinsumTest(tf.test.TestCase):
np.testing.assert_almost_equal([[7]],
sess.run(out, feed_dict=feed_dict))
+ with tf.Graph().as_default():
+ m0 = tf.placeholder(tf.int32, shape=(None, 3))
+ m1 = tf.placeholder(tf.int32, shape=(3,))
+ out = tf.einsum('ij,j->i', m0, m1)
+ with tf.Session() as sess:
+ feed_dict = {
+ m0: [[1, 2, 3]],
+ m1: [2, 1, 1],
+ }
+ np.testing.assert_almost_equal([7],
+ sess.run(out, feed_dict=feed_dict))
+
if __name__ == '__main__':
tf.test.main()