diff options
author | 2016-12-07 17:14:52 -0800 | |
---|---|---|
committer | 2016-12-07 17:25:02 -0800 | |
commit | 30d2a2235a66e158757b13b6dd0bab7bf5e60492 (patch) | |
tree | ffa260e3b9c50281a99d283b9b87d7def9092358 | |
parent | fe5740fef35b3dae53d40f59451523cd5360c6ba (diff) |
Fix bug in tf.einsum() when input is placeholder.
Change: 141378762
-rw-r--r-- | tensorflow/python/ops/special_math_ops.py | 4 | ||||
-rw-r--r-- | tensorflow/python/ops/special_math_ops_test.py | 12 |
2 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py index 02c111e3bd..9c56ad6591 100644 --- a/tensorflow/python/ops/special_math_ops.py +++ b/tensorflow/python/ops/special_math_ops.py @@ -121,6 +121,7 @@ def einsum(equation, *inputs): Many common operations can be expressed in this way. For example: + ```python # Matrix multiplication >>> einsum('ij,jk->ik', m0, m1) # output[i,k] = sum_j m0[i,j] * m1[j, k] @@ -135,6 +136,7 @@ def einsum(equation, *inputs): # Batch matrix multiplication >>> einsum('aij,ajk->aik', s, t) # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k] + ``` This function behaves like `numpy.einsum`, but does not support: * Ellipses (subscripts like `ij...,jk...->ik...`) @@ -358,6 +360,8 @@ def _transpose_if_necessary(tensor, perm): def _reshape_if_necessary(tensor, new_shape): """Like reshape(), but avoids creating a new tensor if possible.""" + # Accept None as an alias for -1 in new_shape. + new_shape = tuple(-1 if x is None else x for x in new_shape) cur_shape = tuple(x.value for x in tensor.get_shape()) if (len(new_shape) == len(cur_shape) and all(d0 == d1 or d1 == -1 for d0, d1 in zip(cur_shape, new_shape))): diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py index 43c01e31a7..9e6fe68800 100644 --- a/tensorflow/python/ops/special_math_ops_test.py +++ b/tensorflow/python/ops/special_math_ops_test.py @@ -254,6 +254,18 @@ class EinsumTest(tf.test.TestCase): np.testing.assert_almost_equal([[7]], sess.run(out, feed_dict=feed_dict)) + with tf.Graph().as_default(): + m0 = tf.placeholder(tf.int32, shape=(None, 3)) + m1 = tf.placeholder(tf.int32, shape=(3,)) + out = tf.einsum('ij,j->i', m0, m1) + with tf.Session() as sess: + feed_dict = { + m0: [[1, 2, 3]], + m1: [2, 1, 1], + } + np.testing.assert_almost_equal([7], + sess.run(out, feed_dict=feed_dict)) + if __name__ == '__main__': tf.test.main() |