aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Lakshay Garg <lakshayg@outlook.in>2017-08-15 04:57:12 +0530
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2017-08-14 16:27:12 -0700
commit227fc9f811eee28a2f1471aac4933f312c21479b (patch)
tree33dec3eb3cfb11a01985f3f9df5e7b1aafd1f837
parent17bc79f4eb0608f5fd82fe08bb50688773a41bc7 (diff)
Implements tf.arg for complex numbers (Closes #483) (#10643)
* implement tf.arg, closes #483 * Remove GPU support for arg op * rename arg to angle
-rw-r--r--tensorflow/cc/gradients/math_grad.cc16
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc16
-rw-r--r--tensorflow/core/kernels/cwise_op_arg.cc37
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_arg.cu.cc28
-rw-r--r--tensorflow/core/kernels/cwise_ops.h4
-rw-r--r--tensorflow/core/ops/math_grad.cc14
-rw-r--r--tensorflow/core/ops/math_grad_test.cc1
-rw-r--r--tensorflow/core/ops/math_ops.cc28
-rw-r--r--tensorflow/core/ops/ops.pbtxt39
-rw-r--r--tensorflow/docs_src/api_guides/python/math_ops.md1
-rw-r--r--tensorflow/go/op/wrappers.go46
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py27
-rw-r--r--tensorflow/python/ops/math_grad.py13
-rw-r--r--tensorflow/python/ops/math_ops.py37
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt4
15 files changed, 306 insertions, 5 deletions
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index 0b9b665b1e..906fe9b70b 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -373,6 +373,22 @@ Status ImagGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);
+Status AngleGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // y = Angle(x)
+ // dx = -dy / (Im(x) + iRe(x)) = -dy * z
+ auto re = Real(scope, op.input(0));
+ auto im = Imag(scope, op.input(0));
+ auto z_inv = Reciprocal(scope, Complex(scope, im, re));
+ auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
+ auto grad = Complex(scope, grad_inputs[0], zero);
+ auto dx = Neg(scope, Mul(scope, grad, z_inv));
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Angle", AngleGrad);
+
Status ConjGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 48b3ddbe90..9983333566 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -681,7 +681,7 @@ class CWiseUnaryComplexGradTest : public ::testing::Test {
CWiseUnaryComplexGradTest()
: scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
- enum UnaryOpType { REAL, IMAG, CONJ };
+ enum UnaryOpType { REAL, IMAG, ANGLE, CONJ };
void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x,
const Tensor& dy, const Tensor& dx_expected) {
@@ -693,6 +693,9 @@ class CWiseUnaryComplexGradTest : public ::testing::Test {
case IMAG:
y = Imag(scope_, x);
break;
+ case ANGLE:
+ y = Angle(scope_, x);
+ break;
case CONJ:
y = Conj(scope_, x);
break;
@@ -727,6 +730,17 @@ TEST_F(CWiseUnaryComplexGradTest, Imag) {
TestCWiseGradComplex(IMAG, x, dy, dx_expected);
}
+TEST_F(CWiseUnaryComplexGradTest, Angle) {
+ Tensor x = test::AsTensor<complex64>(
+ {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
+ Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
+ Tensor dx_expected = test::AsTensor<complex64>(
+ {{5.5, 5.5}, {3, 3},
+ {2.1666666666666665, 2.1666666666666665}, {1.75, 1.75},
+ {0.9375, 0.9375}, {0.8888888888888888, 0.8888888888888888}}, {2, 3});
+ TestCWiseGradComplex(ANGLE, x, dy, dx_expected);
+}
+
TEST_F(CWiseUnaryComplexGradTest, Conj) {
Tensor x = test::AsTensor<complex64>(
{{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
diff --git a/tensorflow/core/kernels/cwise_op_arg.cc b/tensorflow/core/kernels/cwise_op_arg.cc
new file mode 100644
index 0000000000..62ffa0718f
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_arg.cc
@@ -0,0 +1,37 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+#define REGISTER_COMPLEX(D, R, C) \
+ REGISTER_KERNEL_BUILDER(Name("Angle") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<C>("T") \
+ .TypeConstraint<R>("Tout"), \
+ UnaryOp<D##Device, functor::get_angle<C>>);
+
+REGISTER_COMPLEX(CPU, float, complex64);
+REGISTER_COMPLEX(CPU, double, complex128);
+
+// TODO: Enable GPU support for angle op after resolving
+// build failures on GPU (See #10643 for context).
+#if 0 && GOOGLE_CUDA
+REGISTER_COMPLEX(GPU, float, complex64);
+REGISTER_COMPLEX(GPU, double, complex128);
+#endif
+
+#undef REGISTER_COMPLEX
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_arg.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_arg.cu.cc
new file mode 100644
index 0000000000..9b3f8200bd
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_arg.cu.cc
@@ -0,0 +1,28 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// TODO: Enable GPU support for angle op after resolving
+// build failures on GPU (See #10643 for context).
+#if 0 && GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(get_angle, complex64, complex128);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 65a60720dd..d935331904 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -831,6 +831,10 @@ struct get_imag
: base<T, Eigen::internal::scalar_imag_op<T>, typename T::value_type> {};
template <typename T>
+struct get_angle
+ : base<T, Eigen::internal::scalar_arg_op<T>, typename T::value_type> {};
+
+template <typename T>
struct conj : base<T, Eigen::internal::scalar_conjugate_op<T>> {};
////////////////////////////////////////////////////////////////////////////////
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index 5e082ce8f5..1290d3103e 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -349,6 +349,20 @@ Status ImagGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("Imag", ImagGrad);
+Status AngleGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"re"}, "Real", {"x"}},
+ {{"im"}, "Imag", {"x"}},
+ {{"z"}, "Complex", {"im", "re"}},
+ {{"z_inv"}, "Reciprocal", {"z"}},
+ {{"neg"}, "Neg", {"z_inv"}},
+ {{"dx"}, "Mul", {"neg", "dy"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Angle", AngleGrad);
+
Status ConjGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForUnaryCwise(g, {
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index 1393bffb91..2b4b35547b 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -612,6 +612,7 @@ TEST_F(MathGradTest, Cos) {
// TODO(zhifengc)
// TEST_F(MathGradSComplexTest, Real) {}
// TEST_F(MathGradSComplexTest, Imag) {}
+// TEST_F(MathGradSComplexTest, Angle) {}
// TEST_F(MathGradSComplexTest, Conj) {}
// TEST_F(MathGradTernary, Select) {}
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 36f999ff60..6ff05bd2a6 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -2065,6 +2065,34 @@ tf.imag(input) ==> [4.75, 5.75]
```
)doc");
+REGISTER_OP("Angle")
+ .Input("input: T")
+ .Output("output: Tout")
+ .Attr("T: {complex64, complex128} = DT_COMPLEX64")
+ .Attr("Tout: {float, double} = DT_FLOAT")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Returns the argument of a complex number.
+
+Given a tensor `input` of complex numbers, this operation returns a tensor of
+type `float` that is the argument of each element in `input`. All elements in
+`input` must be complex numbers of the form \\(a + bj\\), where *a*
+is the real part and *b* is the imaginary part.
+
+The argument returned by this operation is of the form \\(atan2(b, a)\\).
+
+For example:
+
+```
+# tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
+tf.angle(input) ==> [2.0132, 1.056]
+```
+
+@compatibility(numpy)
+Equivalent to np.angle.
+@end_compatibility
+)doc");
+
REGISTER_OP("Conj")
.Input("input: T")
.Output("output: T")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 35342b75a8..bdc57c521e 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -593,6 +593,45 @@ op {
is_stateful: true
}
op {
+ name: "Angle"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "Tout"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_COMPLEX64
+ }
+ allowed_values {
+ list {
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ attr {
+ name: "Tout"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ summary: "Returns the argument of a complex number."
+ description: "Given a tensor `input` of complex numbers, this operation returns a tensor of\ntype `float` that is the argument of each element in `input`. All elements in\n`input` must be complex numbers of the form \\(a + bj\\), where *a*\nis the real part and *b* is the imaginary part.\n\nThe argument returned by this operation is of the form \\(atan2(b, a)\\).\n\nFor example:\n```\n # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]\ntf.angle(input) ==> [2.0132, 1.056]\n```"
+}
+op {
name: "Any"
input_arg {
name: "input"
diff --git a/tensorflow/docs_src/api_guides/python/math_ops.md b/tensorflow/docs_src/api_guides/python/math_ops.md
index b3c7a0c010..dee7f1618a 100644
--- a/tensorflow/docs_src/api_guides/python/math_ops.md
+++ b/tensorflow/docs_src/api_guides/python/math_ops.md
@@ -122,6 +122,7 @@ functions to your graph.
* @{tf.complex}
* @{tf.conj}
* @{tf.imag}
+* @{tf.angle}
* @{tf.real}
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 8f5de741fc..dfa7f26c2b 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -10006,6 +10006,52 @@ func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf
return op.Output(0)
}
+// ArgAttr is an optional argument to Arg.
+type ArgAttr func(optionalAttr)
+
+// ArgTout sets the optional Tout attribute to value.
+// If not specified, defaults to DT_FLOAT
+func ArgTout(value tf.DataType) ArgAttr {
+ return func(m optionalAttr) {
+ m["Tout"] = value
+ }
+}
+
+// Returns the argument of a complex number.
+//
+// Given a tensor `input` of complex numbers, this operation returns a tensor of
+// type `float` that is the argument of each element in `input`. All elements in
+// `input` must be complex numbers of the form \\(a + bj\\), where *a*
+// is the real part and *b* is the imaginary part.
+//
+//
+// The argument returned by this operation is of the form \\(atan2(b, a)\\).
+//
+// For example:
+//
+// ```
+// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
+// tf.angle(input) ==> [2.0132, 1.056]
+// ```
+func Angle(scope *Scope, input tf.Output, optional ...ArgAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Angle",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes fingerprints of the input strings.
//
// Arguments:
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index aff78f2c70..abe325ac41 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -1950,6 +1950,33 @@ class ComplexMakeRealImagTest(test.TestCase):
self._compareRealImag(cplx, use_gpu=False)
self._compareRealImag(cplx, use_gpu=True)
+ def _compareAngle(self, cplx, use_gpu):
+ np_angle = np.angle(cplx)
+ with self.test_session(use_gpu=use_gpu) as sess:
+ inx = ops.convert_to_tensor(cplx)
+ tf_angle = math_ops.angle(inx)
+ tf_angle_val = sess.run([tf_angle])
+ self.assertAllEqual(np_angle, tf_angle_val)
+ self.assertShapeEqual(np_angle, tf_angle)
+
+ def testAngle64(self):
+ real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float32)
+ imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float32)
+ cplx = real + 1j * imag
+ self._compareAngle(cplx, use_gpu=False)
+ # TODO: Enable GPU tests for angle op after resolving
+ # build failures on GPU (See #10643 for context).
+ # self._compareAngle(cplx, use_gpu=True)
+
+ def testAngle(self):
+ real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float64)
+ imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float64)
+ cplx = real + 1j * imag
+ self._compareAngle(cplx, use_gpu=False)
+ # TODO: Enable GPU tests for angle op after resolving
+ # build failures on GPU (See #10643 for context).
+ # self._compareAngle(cplx, use_gpu=True)
+
def testRealReal(self):
for dtype in dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.float32, dtypes_lib.float64:
x = array_ops.placeholder(dtype)
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 9aa64a3298..262f5d9564 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -1014,6 +1014,19 @@ def _ImagGrad(_, grad):
return math_ops.complex(zero, grad)
+@ops.RegisterGradient("Angle")
+def _AngleGrad(op, grad):
+ """Returns -grad / (Im(x) + iRe(x))"""
+ x = op.inputs[0]
+ with ops.control_dependencies([grad.op]):
+ re = math_ops.real(x)
+ im = math_ops.imag(x)
+ z = math_ops.reciprocal(math_ops.complex(im, re))
+ zero = constant_op.constant(0, dtype=grad.dtype)
+ complex_grad = math_ops.complex(grad, zero)
+ return -complex_grad * z
+
+
@ops.RegisterGradient("Conj")
def _ConjGrad(_, grad):
"""Returns the complex conjugate of grad."""
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index c7b2068353..fb486ba23d 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -100,6 +100,7 @@ See the @{$python/math_ops} guide.
@@complex
@@conj
@@imag
+@@angle
@@real
@@fft
@@ifft
@@ -622,10 +623,9 @@ def imag(input, name=None):
r"""Returns the imaginary part of a complex number.
Given a tensor `input` of complex numbers, this operation returns a tensor of
- type `float32` or `float64` that is the imaginary part of each element in
- `input`. All elements in `input` must be complex numbers of the form \\(a +
- bj\\), where *a* is the real part and *b* is the imaginary part returned by
- this operation.
+ type `float` that is the argument of each element in `input`. All elements in
+ `input` must be complex numbers of the form \\(a + bj\\), where *a*
+ is the real part and *b* is the imaginary part returned by the operation.
For example:
@@ -646,6 +646,35 @@ def imag(input, name=None):
return gen_math_ops.imag(input, Tout=input.dtype.real_dtype, name=name)
+def angle(input, name=None):
+ r"""Returns the argument of a complex number.
+
+ Given a tensor `input` of complex numbers, this operation returns a tensor of
+ type `float32` or `float64` that is the argument of each element in `input`.
+ All elements in `input` must be complex numbers of the form \\(a + bj\\),
+ where *a* is the real part and *b* is the imaginary part.
+
+ The argument returned by this function is of the form \\(atan2(b, a)\\).
+
+ For example:
+
+ ```
+ # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
+ tf.angle(input) ==> [2.0132, 1.056]
+ ```
+
+ Args:
+ input: A `Tensor`. Must be one of the following types: `complex64`,
+ `complex128`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `float32` or `float64`.
+ """
+ with ops.name_scope(name, "Angle", [input]) as name:
+ return gen_math_ops.angle(input, Tout=input.dtype.real_dtype, name=name)
+
+
# pylint: enable=redefined-outer-name,redefined-builtin
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index cfee6a122d..194888b454 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -557,6 +557,10 @@ tf_module {
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "angle"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "arg_max"
argspec: "args=[\'input\', \'dimension\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}