aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py1
-rw-r--r--tensorflow/python/kernel_tests/listdiff_op_test.py21
-rw-r--r--tensorflow/python/kernel_tests/pad_op_test.py11
-rw-r--r--tensorflow/python/kernel_tests/where_op_test.py13
-rw-r--r--tensorflow/python/ops/array_ops.py73
-rw-r--r--tensorflow/python/ops/math_ops.py8
6 files changed, 113 insertions, 14 deletions
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 7e56bb5843..d5ad0a7bb3 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -892,5 +892,6 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
[1.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 1.0]])
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py
index 026bbbead1..918ebf05ec 100644
--- a/tensorflow/python/kernel_tests/listdiff_op_test.py
+++ b/tensorflow/python/kernel_tests/listdiff_op_test.py
@@ -34,17 +34,16 @@ class ListDiffTest(tf.test.TestCase):
x = [tf.compat.as_bytes(str(a)) for a in x]
y = [tf.compat.as_bytes(str(a)) for a in y]
out = [tf.compat.as_bytes(str(a)) for a in out]
-
- with self.test_session() as sess:
- x_tensor = tf.convert_to_tensor(x, dtype=dtype)
- y_tensor = tf.convert_to_tensor(y, dtype=dtype)
- out_tensor, idx_tensor = tf.listdiff(x_tensor, y_tensor)
- tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
-
- self.assertAllEqual(tf_out, out)
- self.assertAllEqual(tf_idx, idx)
- self.assertEqual(1, out_tensor.get_shape().ndims)
- self.assertEqual(1, idx_tensor.get_shape().ndims)
+ for diff_func in [tf.listdiff, tf.setdiff1d]:
+ with self.test_session() as sess:
+ x_tensor = tf.convert_to_tensor(x, dtype=dtype)
+ y_tensor = tf.convert_to_tensor(y, dtype=dtype)
+ out_tensor, idx_tensor = diff_func(x_tensor, y_tensor)
+ tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+ self.assertAllEqual(tf_out, out)
+ self.assertAllEqual(tf_idx, idx)
+ self.assertEqual(1, out_tensor.get_shape().ndims)
+ self.assertEqual(1, idx_tensor.get_shape().ndims)
def testBasic1(self):
x = [1, 2, 3, 4]
diff --git a/tensorflow/python/kernel_tests/pad_op_test.py b/tensorflow/python/kernel_tests/pad_op_test.py
index a4e411755a..986571e5c6 100644
--- a/tensorflow/python/kernel_tests/pad_op_test.py
+++ b/tensorflow/python/kernel_tests/pad_op_test.py
@@ -85,10 +85,11 @@ class PadOpTest(tf.test.TestCase):
self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
def _testAll(self, np_inputs, paddings):
- for mode in ("CONSTANT", "REFLECT", "SYMMETRIC"):
+ for mode in ("CONSTANT", "REFLECT", "SYMMETRIC", "reflect", "symmetric",
+ "constant"):
# Zero-sized input is not allowed for REFLECT mode, but we still want
# zero-sized input test cases for the other modes.
- if np_inputs.size or mode != "REFLECT":
+ if np_inputs.size or mode.upper() != "REFLECT":
self._testPad(np_inputs, paddings, mode=mode)
if np_inputs.dtype == np.float32:
self._testGradient(np_inputs, paddings, mode=mode)
@@ -155,6 +156,12 @@ class PadOpTest(tf.test.TestCase):
tf.constant([0, 3], shape=[1, 2]),
mode="SYMMETRIC").eval()
+ def testInvalid(self):
+ with self.test_session():
+ x = [[1, 2, 3], [4, 5, 6]]
+ with self.assertRaisesRegexp(ValueError, "Unknown padding mode"):
+ tf.pad(x, [[1, 0], [2, 1]], mode="weird").eval()
+
def testIntTypes(self):
# TODO(touts): Figure out why the padding tests do not work on GPU
# for int types and rank > 2.
diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py
index b559b37d8c..00d35d8265 100644
--- a/tensorflow/python/kernel_tests/where_op_test.py
+++ b/tensorflow/python/kernel_tests/where_op_test.py
@@ -35,6 +35,13 @@ class WhereOpTest(tf.test.TestCase):
with self.assertRaisesOpError(expected_err_re):
ans.eval()
+ def testWrongNumbers(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.where([False, True], [1, 2], None)
+ with self.assertRaises(ValueError):
+ tf.where([False, True], None, [1, 2])
+
def testBasicMat(self):
x = np.asarray([[True, False], [True, False]])
@@ -55,6 +62,12 @@ class WhereOpTest(tf.test.TestCase):
self._testWhere(x, truth)
+ def testThreeArgument(self):
+ x = np.array([[-2, 3, -1], [1, -3, -3]])
+ np_val = np.where(x > 0, x*x, -x)
+ with self.test_session():
+ tf_val = tf.where(tf.constant(x) > 0, x*x, -x).eval()
+ self.assertAllEqual(tf_val, np_val)
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index e1f4ce2172..1a12cc9964 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -79,6 +79,7 @@ or join multiple tensors together.
@@dequantize
@@quantize_v2
@@quantized_concat
+@@setdiff1d
"""
from __future__ import absolute_import
@@ -116,6 +117,24 @@ _baseslice = slice
listdiff = gen_array_ops.list_diff
+def setdiff1d(x, y, index_dtype=dtypes.int32, name=None):
+ """Returns the difference between the `x` and `y` treated as sets.
+
+ Args:
+ x: Set of values not assumed to be unique.
+ y: Set of values not assumed to be unique.
+ index_dtype: Output index type (`tf.int32`, `tf.int64`) default: `tf.int32`
+ name: A name for the operation (optional).
+
+
+ Returns:
+ A `Tensor` the same type as `x` and `y`
+ A `Tensor` that is of type `index_dtype` representing indices from .
+ """
+
+ return gen_array_ops.list_diff(x, y, index_dtype, name)
+
+
def shape(input, name=None, out_type=dtypes.int32):
# pylint: disable=redefined-builtin
"""Returns the shape of a tensor.
@@ -1525,7 +1544,7 @@ def pad(tensor, paddings, mode="CONSTANT", name=None): # pylint: disable=invali
Args:
tensor: A `Tensor`.
paddings: A `Tensor` of type `int32`.
- mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC".
+ mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC" (case-insensitive)
name: A name for the operation (optional).
Returns:
@@ -1535,6 +1554,9 @@ def pad(tensor, paddings, mode="CONSTANT", name=None): # pylint: disable=invali
ValueError: When mode is not one of "CONSTANT", "REFLECT", or "SYMMETRIC".
"""
+ # Convert lower/mixed case to upper for NumPy compatibility
+ # NumPy uses all lower-case modes.
+ mode = mode.upper()
if mode == "CONSTANT":
return gen_array_ops._pad(tensor, paddings, name=name)
if mode == "REFLECT":
@@ -2453,6 +2475,55 @@ def squeeze(input, squeeze_dims=None, name=None):
return gen_array_ops._squeeze(input, squeeze_dims, name)
+def where(condition, x=None, y=None, name=None):
+ """Return the elements, either from `x` or `y`, depending on the `condition`.
+
+ If both `x` and `y` are None, then this operation returns the coordinates of
+ true elements of `condition`. The coordinates are returned in a 2-D tensor
+ where the first dimension (rows) represents the number of true elements, and
+ the second dimension (columns) represents the coordinates of the true
+ elements. Keep in mind, the shape of the output tensor can vary depending on
+ how many true values there are in input. Indices are output in row-major
+ order.
+
+ If both non-None, `x` and `y` must have the same shape.
+ The `condition` tensor must be a scalar if `x` and `y` are scalar.
+ If `x` and `y` are vectors or higher rank, then `condition` must be either a
+ vector with size matching the first dimension of `x`, or must have the same
+ shape as `x`.
+
+ The `condition` tensor acts as a mask that chooses, based on the value at each
+ element, whether the corresponding element / row in the output should be taken
+ from `x` (if true) or `y` (if false).
+
+ If `condition` is a vector and `x` and `y` are higher rank matrices, then it
+ chooses which row (outer dimension) to copy from `x` and `y`. If `condition`
+ has the same shape as `x` and `y`, then it chooses which element to copy from
+ `x` and `y`.
+
+ Args:
+ condition: A `Tensor` of type `bool`
+ x: A Tensor which may have the same shape as `condition`. If `condition` is
+ rank 1, `x` may have higher rank, but its first dimension must match the
+ size of `condition`.
+ y: A `tensor` with the same shape and type as `x`.
+ name: A name of the operation (optional)
+
+ Returns:
+ A `Tensor` with the same type and shape as `x`, `y` if they are non-None.
+ A `Tensor` with shape `(num_true, dim_size(condition))`.
+
+ Raises:
+ ValueError: When exactly one of `x` or `y` is non-None.
+ """
+ if x is None and y is None:
+ return gen_array_ops.where(input=condition, name=name)
+ elif x is not None and y is not None:
+ return gen_math_ops.select(condition=condition, t=x, e=y, name=name)
+ else:
+ raise ValueError("x and y must both be non-None or both be None.")
+
+
@ops.RegisterShape("QuantizedReshape")
def _QuantizedReshapeShape(op):
return _ReshapeShape(op) + [tensor_shape.scalar(), tensor_shape.scalar()]
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 1a34634cf2..cf4b47757e 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -24,8 +24,10 @@ operators to your graph.
@@add
@@sub
@@mul
+@@multiply
@@scalar_mul
@@div
+@@divide
@@truediv
@@floordiv
@@mod
@@ -39,6 +41,7 @@ mathematical functions to your graph.
@@add_n
@@abs
@@neg
+@@negative
@@sign
@@inv
@@square
@@ -274,6 +277,11 @@ def divide(x, y, name=None):
with ops.name_scope(name, "Divide", [x]) as name:
return x / y
+# Make Python Aliases
+multiply = gen_math_ops.mul
+subtract = gen_math_ops.sub
+negative = gen_math_ops.neg
+
def neg(x, name=None):
"""Computes numerical negative value element-wise.