aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/framework
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-12 10:16:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-12 10:20:30 -0700
commite6b011763a60d239972c8c6c0f36536ab6f885a3 (patch)
tree8930a1e6f5efa50c860683ea86807335c7470cbf /tensorflow/cc/framework
parentf63aa7f49f81a66112bfef6670a18658d5a479e5 (diff)
Extend c++ gradient_checker to complex types.
PiperOrigin-RevId: 168392949
Diffstat (limited to 'tensorflow/cc/framework')
-rw-r--r--tensorflow/cc/framework/gradient_checker.cc291
-rw-r--r--tensorflow/cc/framework/gradient_checker.h28
-rw-r--r--tensorflow/cc/framework/gradient_checker_test.cc70
3 files changed, 294 insertions, 95 deletions
diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc
index f3a7c138c4..de2645cb44 100644
--- a/tensorflow/cc/framework/gradient_checker.cc
+++ b/tensorflow/cc/framework/gradient_checker.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -31,7 +32,74 @@ namespace {
// TODO(andydavis) Vectorize and/or multi-thread Jacobian computations if
// performance becomes an issue.
+// BaseUnitsForType provides a list of typed unit values for each basis in the
+// requested type.
+// When T is real,
+// BaseUnitsForType<T>::values() is just a single-entry vector [1]
+// When T is complex,
+// BaseUnitsForType<T>::values() is a two-entry vector [1, i] - the unit
+// values in each of its two bases.
template <typename T>
+struct BaseUnitsForType {}; // Specializations below
+
+// Template specialization for BaseUnitsForType
+#define SET_BASE_UNITS_FOR_TYPE(TYPE, INIT) \
+ template <> \
+ struct BaseUnitsForType<TYPE> { \
+ static const std::vector<TYPE>& values() { \
+ static std::vector<TYPE>* units = new std::vector<TYPE> INIT; \
+ return *units; \
+ } \
+ }
+
+SET_BASE_UNITS_FOR_TYPE(float, {1});
+SET_BASE_UNITS_FOR_TYPE(double, {1});
+SET_BASE_UNITS_FOR_TYPE(complex64, ({{1, 0}, {0, 1}}));
+SET_BASE_UNITS_FOR_TYPE(complex128, ({{1, 0}, {0, 1}}));
+
+// SetJacobian sets the jacobian value at the provided row and column from a
+// tensor entry with type T.
+// When T is real, this is a simple assignment that casts the entry into the
+// jacobian type.
+// When T is complex, it assigns the real and complex values to successive rows
+// or columns in the matrix depending on the expand_by_row parameter
+template <typename T, typename JAC_T>
+typename std::enable_if<std::is_floating_point<T>::value>::type SetJacobian(
+ typename TTypes<JAC_T>::Matrix* jacobian, const int row, const int col,
+ const T& value, const bool expand_by_row) {
+ (*jacobian)(row, col) = JAC_T{value};
+}
+
+template <typename T, typename JAC_T>
+typename std::enable_if<is_complex<T>::value>::type SetJacobian(
+ typename TTypes<JAC_T>::Matrix* jacobian, const int row, const int col,
+ const T& value, const bool expand_by_row) {
+ (*jacobian)(row, col) = JAC_T{value.real()};
+ if (expand_by_row) {
+ (*jacobian)(row + 1, col) = JAC_T{value.imag()};
+ } else {
+ (*jacobian)(row, col + 1) = JAC_T{value.imag()};
+ }
+}
+
+// JacobianStride<T>::value holds the number of Jacobian elements needed to
+// represent one element of the given type.
+// When T is real the stride is 1, and when T is complex the stride is 2.
+template <typename T>
+struct JacobianStride {}; // Specializations below
+
+#define SET_JACOBIAN_STRIDE(TYPE, VALUE) \
+ template <> \
+ struct JacobianStride<TYPE> { \
+ static constexpr int value = VALUE; \
+ }
+
+SET_JACOBIAN_STRIDE(float, 1);
+SET_JACOBIAN_STRIDE(double, 1);
+SET_JACOBIAN_STRIDE(complex64, 2);
+SET_JACOBIAN_STRIDE(complex128, 2);
+
+template <typename X_T, typename Y_T, typename JAC_T>
Status ComputeTheoreticalJacobianTranspose(
const Scope& scope, const OutputList& xs,
const std::vector<TensorShape>& x_shapes,
@@ -44,9 +112,9 @@ Status ComputeTheoreticalJacobianTranspose(
OutputList dys;
dys.reserve(y_shapes.size());
for (const auto& y_shape : y_shapes) {
- // TODO(suharshs): This currently assumes that all x's are the same type.
+ // TODO(suharshs): This currently assumes that all y's are the same type.
dys.push_back(
- ops::Cast(scope, ops::Const(scope, 1.0, y_shape), xs[0].type()));
+ ops::Cast(scope, ops::Const(scope, 1.0, y_shape), ys[0].type()));
}
OutputList dxs;
TF_RETURN_IF_ERROR(AddSymbolicGradients(scope, ys, xs, dys, &dxs));
@@ -55,7 +123,7 @@ Status ComputeTheoreticalJacobianTranspose(
std::vector<Tensor> dy_datas(y_num);
for (int i = 0; i < y_num; i++) {
dy_datas[i] = Tensor(ys[i].type(), y_shapes[i]);
- auto dy_data_flat = dy_datas[i].flat<T>();
+ auto dy_data_flat = dy_datas[i].flat<Y_T>();
dy_data_flat.setZero();
}
@@ -68,30 +136,41 @@ Status ComputeTheoreticalJacobianTranspose(
feed_list.insert({dys[i], dy_datas[i]});
}
+ // x_stride and y_stride are used to calculate the correct jacobian row and
+ // column position for a pair of elements at positions r, c within the x and y
+ // tensors respectively.
+ const int x_stride = JacobianStride<X_T>::value;
+ const int y_stride = JacobianStride<Y_T>::value;
ClientSession session(scope);
for (int y_idx = 0; y_idx < y_num; y_idx++) {
- auto dy_data_flat = dy_datas[y_idx].flat<T>();
+ auto dy_data_flat = dy_datas[y_idx].flat<Y_T>();
const int64 dy_size = y_shapes[y_idx].num_elements();
// Compute the theoretical Jacobians one row at a time by back propagating
- // '1.0' for each element of 'dy', while holding all other elements of 'dy'
- // at zero.
+ // '1.0' (or '1' and 'i' if y is complex) for each element of 'dy', while
+ // holding all other elements of 'dy' at zero.
for (int c = 0; c < dy_size; ++c) {
- dy_data_flat(c) = 1.0;
-
- std::vector<Tensor> dxout;
- TF_RETURN_IF_ERROR(session.Run(feed_list, dxs, &dxout));
-
- for (int x_idx = 0; x_idx < x_num; x_idx++) {
- const int64 x_size = x_shapes[x_idx].num_elements();
- auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<T>();
- auto dx_flat = dxout[x_idx].flat<T>();
- for (int r = 0; r < x_size; ++r) {
- jacobian(r, c) = dx_flat(r);
+ int unit_dimension = 0;
+ for (Y_T unit : BaseUnitsForType<Y_T>::values()) {
+ dy_data_flat(c) = unit;
+
+ std::vector<Tensor> dxout;
+ TF_RETURN_IF_ERROR(session.Run(feed_list, dxs, &dxout));
+
+ for (int x_idx = 0; x_idx < x_num; x_idx++) {
+ const int64 x_size = x_shapes[x_idx].num_elements();
+ auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<JAC_T>();
+ auto dx_flat = dxout[x_idx].flat<X_T>();
+ for (int r = 0; r < x_size; ++r) {
+ SetJacobian<X_T, JAC_T>(&jacobian, r * x_stride,
+ c * y_stride + unit_dimension, dx_flat(r),
+ true /* expand_by_row=true */);
+ }
}
- }
- dy_data_flat(c) = 0.0;
+ dy_data_flat(c) = Y_T{0};
+ unit_dimension++;
+ }
}
}
return Status::OK();
@@ -122,104 +201,154 @@ Status EvaluateGraph(ClientSession* session, const OutputList& xs,
return Status::OK();
}
-template <typename T>
+template <typename X_T, typename Y_T, typename JAC_T>
Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs,
const std::vector<TensorShape>& x_shapes,
const OutputList& ys,
const std::vector<TensorShape>& y_shapes,
- const T delta,
+ const JAC_T delta,
std::vector<Tensor>* x_datas,
std::vector<Tensor>* jacobian_ts) {
size_t y_num = y_shapes.size();
size_t x_num = x_shapes.size();
+ // x_stride and y_stride are used to calculate the correct jacobian row and
+ // column position for a pair of elements at positions r, c within the x and y
+ // tensors respectively.
+ const int x_stride = JacobianStride<X_T>::value;
+ const int y_stride = JacobianStride<Y_T>::value;
ClientSession session(scope);
for (int x_idx = 0; x_idx < x_num; x_idx++) {
- auto x_data_flat = (*x_datas)[x_idx].flat<T>();
+ auto x_data_flat = (*x_datas)[x_idx].flat<X_T>();
const int64 x_size = x_shapes[x_idx].num_elements();
// Compute the numeric Jacobian one column at a time by perturbing each
// element of 'x_data' (positively and negatively) by 'delta', and
- // updating the jacobian with the centered difference.
+ // updating the jacobian with the centered difference. When x_data is
+ // complex-valued, we perturb its real and complex parts separately.
for (int r = 0; r < x_size; ++r) {
- // Store current value of 'x' at 'r'.
- T v = x_data_flat(r);
- // Evaluate at positive delta.
- x_data_flat(r) = v + delta;
- std::vector<Tensor> y_pos;
- TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_pos));
- // Evaluate at negative delta.
- x_data_flat(r) = v - delta;
- std::vector<Tensor> y_neg;
- TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_neg));
-
- for (int y_idx = 0; y_idx < y_num; y_idx++) {
- // Compute element-wise centered difference and store in each Jacobian.
- auto y_pos_flat = y_pos[y_idx].flat<T>();
- auto y_neg_flat = y_neg[y_idx].flat<T>();
- const int64 y_size = y_shapes[y_idx].num_elements();
- const T scale = 2 * delta;
- auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<T>();
- for (int c = 0; c < y_size; ++c) {
- jacobian(r, c) = (y_pos_flat(c) - y_neg_flat(c)) / scale;
+ int unit_dimension = 0;
+ for (X_T unit : BaseUnitsForType<X_T>::values()) {
+ X_T x_delta = unit * X_T{delta};
+ // Store current value of 'x' at 'r'.
+ X_T v = x_data_flat(r);
+ // Evaluate at positive delta.
+ x_data_flat(r) = v + x_delta;
+ std::vector<Tensor> y_pos;
+ TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_pos));
+ // Evaluate at negative delta.
+ x_data_flat(r) = v - x_delta;
+ std::vector<Tensor> y_neg;
+ TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_neg));
+
+ for (int y_idx = 0; y_idx < y_num; y_idx++) {
+ // Compute element-wise centered difference and store in each
+ // Jacobian.
+ auto y_pos_flat = y_pos[y_idx].flat<Y_T>();
+ auto y_neg_flat = y_neg[y_idx].flat<Y_T>();
+ const int64 y_size = y_shapes[y_idx].num_elements();
+ const Y_T scale = Y_T{2 * delta};
+ auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<JAC_T>();
+ for (int c = 0; c < y_size; ++c) {
+ SetJacobian<Y_T, JAC_T>(&jacobian, r * x_stride + unit_dimension,
+ c * y_stride,
+ (y_pos_flat(c) - y_neg_flat(c)) / scale,
+ false /* expand_by_row=false */);
+ }
}
+ // Restore pre-perturbation value.
+ x_data_flat(r) = v;
+ unit_dimension++;
}
- // Restore pre-perturbation value.
- x_data_flat(r) = v;
}
}
return Status::OK();
}
-template <typename T>
+// The Jacobian is always a real-valued matrix.
+// Given y = f(x) for tensors y and x, it contains the derivatives dy_i/dx_j for
+// every pair y_i in y and x_j in x. Note that the Jacobian is defined directly
+// over the elements of tensors y and x, and doesn't depend on their shapes.
+//
+// If x = (x_1, x_2, ..., x_m) and y = (y_1, y_2, .., y_n) the matrix evaluated
+// is actually the Jacobian transpose, defined as this mxn matrix:
+// dy_1/d_x1 dy_2/dx_1 ... dy_n/dx_1
+// dy_1/dx_2 dy_2/dx_2 ... dy_n/dx_2
+// .
+// .
+// .
+// dy_1/dx_m dy_2/dx_m ... dy_n/dx_m
+//
+// If x or y is complex, each complex entry is "expanded" into a real and
+// imaginary entry, and the Jacobian is organized as above on the expanded list.
+// e.g.
+// [y1, y2] = Square([x1, x2]) where x and y are complex.
+// Writing
+// x = [x1_real, x1_imag, x2_real, x2_imag]
+// y = [y1_real, y1_imag, y2_real, y2_imag]
+// the Jacobian transpose is
+// the 4x4 matrix:
+// dy1_real/dx1_real dy1_imag/dx1_real dy2_real/dx1_real dy2_imag/dx1_real
+// dy1_real/dx1_imag dy1_imag/dx1_imag dy2_real/dx1_imag dy2_imag/dx1_imag
+// dy1_real/dx2_real dy1_imag/dx2_real dy2_real/dx2_real dy2_imag/dx2_real
+// dy1_real/dx2_imag dy1_imag/dx2_imag dy2_real/dx2_imag dy2_imag/dx2_imag
+template <typename X_T, typename Y_T, typename JAC_T>
void InitJacobians(const OutputList& xs,
const std::vector<TensorShape>& x_shapes,
const std::vector<TensorShape>& y_shapes,
std::vector<Tensor>* jacobians) {
- size_t y_num = y_shapes.size();
- size_t x_num = x_shapes.size();
+ const size_t y_num = y_shapes.size();
+ const size_t x_num = x_shapes.size();
+ const DataType jacobian_type = DataTypeToEnum<JAC_T>::v();
jacobians->resize(y_num * x_num);
for (int x_idx = 0; x_idx < x_num; x_idx++) {
- const int64 x_size = x_shapes[x_idx].num_elements();
+ // The number of rows is the number of elements in the x tensor multiplied
+ // by the number of Jacobian entries needed to represent each x type.
+ const int64 x_size =
+ x_shapes[x_idx].num_elements() * JacobianStride<X_T>::value;
for (int y_idx = 0; y_idx < y_num; y_idx++) {
- const int64 y_size = y_shapes[y_idx].num_elements();
- Tensor jacobian_t(xs[x_idx].type(), {x_size, y_size});
- auto jacobian_t_flat = jacobian_t.flat<T>();
+ // The number of columns is the number of elements in the y tensor
+ // multiplied by the number of Jacobian entries needed to represent each
+ // y type.
+ const int64 y_size =
+ y_shapes[y_idx].num_elements() * JacobianStride<Y_T>::value;
+ Tensor jacobian_t(jacobian_type, {x_size, y_size});
+ auto jacobian_t_flat = jacobian_t.flat<JAC_T>();
jacobian_t_flat.setZero();
(*jacobians)[x_idx * y_num + y_idx] = std::move(jacobian_t);
}
}
}
-template <typename T>
+template <typename X_T, typename Y_T, typename JAC_T>
Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs,
const std::vector<TensorShape>& x_shapes,
const OutputList& ys,
const std::vector<TensorShape>& y_shapes,
std::vector<Tensor>* x_datas,
- T* max_error) {
+ JAC_T* max_error) {
// Initialize theoretical Jacobians to zeros.
std::vector<Tensor> jacobian_ts;
- InitJacobians<T>(xs, x_shapes, y_shapes, &jacobian_ts);
+ InitJacobians<X_T, Y_T, JAC_T>(xs, x_shapes, y_shapes, &jacobian_ts);
// Compute theoretical Jacobian.
- TF_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose<T>(
- scope, xs, x_shapes, *x_datas, ys, y_shapes, &jacobian_ts));
+ TF_RETURN_IF_ERROR((ComputeTheoreticalJacobianTranspose<X_T, Y_T, JAC_T>(
+ scope, xs, x_shapes, *x_datas, ys, y_shapes, &jacobian_ts)));
// Initialize numeric Jacobian to zeros.
std::vector<Tensor> jacobian_ns;
- InitJacobians<T>(xs, x_shapes, y_shapes, &jacobian_ns);
+ InitJacobians<X_T, Y_T, JAC_T>(xs, x_shapes, y_shapes, &jacobian_ns);
// Compute numeric Jacobian.
- TF_RETURN_IF_ERROR(ComputeNumericJacobianTranspose<T>(
- scope, xs, x_shapes, ys, y_shapes, 1e-3, x_datas, &jacobian_ns));
+ TF_RETURN_IF_ERROR((ComputeNumericJacobianTranspose<X_T, Y_T, JAC_T>(
+ scope, xs, x_shapes, ys, y_shapes, JAC_T{1e-3f}, x_datas, &jacobian_ns)));
for (int i = 0; i < jacobian_ts.size(); i++) {
// Compute the maximum error between theoretical and numeric Jacobians.
*max_error = 0.0;
- auto jac_t = jacobian_ts[i].matrix<T>();
- auto jac_n = jacobian_ns[i].matrix<T>();
+ auto jac_t = jacobian_ts[i].matrix<JAC_T>();
+ auto jac_n = jacobian_ns[i].matrix<JAC_T>();
for (int r = 0; r < jacobian_ts[i].dim_size(0); ++r) {
for (int c = 0; c < jacobian_ts[i].dim_size(1); ++c) {
*max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c)));
@@ -231,12 +360,12 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs,
} // namespace
-template <typename T>
+template <typename X_T, typename Y_T, typename JAC_T>
Status ComputeGradientError(const Scope& scope, const OutputList& xs,
const std::vector<TensorShape>& x_shapes,
const OutputList& ys,
const std::vector<TensorShape>& y_shapes,
- T* max_error) {
+ JAC_T* max_error) {
if (xs.size() != x_shapes.size()) {
return errors::InvalidArgument("xs(size ", xs.size(),
") and x_shapes(size ", x_shapes.size(),
@@ -251,35 +380,39 @@ Status ComputeGradientError(const Scope& scope, const OutputList& xs,
std::vector<Tensor> x_datas(x_shapes.size());
for (int i = 0; i < x_shapes.size(); i++) {
x_datas[i] = Tensor(xs[i].type(), x_shapes[i]);
- auto x_data_flat = x_datas[i].flat<T>();
+ auto x_data_flat = x_datas[i].flat<X_T>();
x_data_flat.setRandom();
}
// Compute gradient error.
- return ComputeGradientErrorInternal(scope, xs, x_shapes, ys, y_shapes,
- &x_datas, max_error);
+ return ComputeGradientErrorInternal<X_T, Y_T, JAC_T>(
+ scope, xs, x_shapes, ys, y_shapes, &x_datas, max_error);
}
-template <typename T>
+template <typename X_T, typename Y_T, typename JAC_T>
Status ComputeGradientError(const Scope& scope, const Output& x,
const Tensor& x_init_value, const Output& y,
- const TensorShape& y_shape, T* max_error) {
+ const TensorShape& y_shape, JAC_T* max_error) {
// Initialize 'x_data' from 'x_init_value'.
std::vector<Tensor> x_datas(1, Tensor(x_init_value));
// Compute gradient error.
- return ComputeGradientErrorInternal(scope, {x}, {x_datas[0].shape()}, {y},
- {y_shape}, &x_datas, max_error);
+ return ComputeGradientErrorInternal<X_T, Y_T, JAC_T>(
+ scope, {x}, {x_datas[0].shape()}, {y}, {y_shape}, &x_datas, max_error);
}
-#define INSTANTIATE_GRAD_ERR_TYPE(T) \
- template Status ComputeGradientError<T>( \
+#define INSTANTIATE_GRAD_ERR_TYPE(X_T, Y_T, JAC_T) \
+ template Status ComputeGradientError<X_T, Y_T, JAC_T>( \
const Scope& scope, const OutputList& xs, \
const std::vector<TensorShape>& x_shapes, const OutputList& ys, \
- const std::vector<TensorShape>& y_shapes, T* max_error); \
- template Status ComputeGradientError<T>( \
+ const std::vector<TensorShape>& y_shapes, JAC_T* max_error); \
+ template Status ComputeGradientError<X_T, Y_T, JAC_T>( \
const Scope& scope, const Output& x, const Tensor& x_init_value, \
- const Output& y, const TensorShape& y_shape, T* max_error);
-
-INSTANTIATE_GRAD_ERR_TYPE(float);
-INSTANTIATE_GRAD_ERR_TYPE(double);
+ const Output& y, const TensorShape& y_shape, JAC_T* max_error);
+
+INSTANTIATE_GRAD_ERR_TYPE(float, float, float);
+INSTANTIATE_GRAD_ERR_TYPE(double, double, double);
+INSTANTIATE_GRAD_ERR_TYPE(complex64, float, float);
+INSTANTIATE_GRAD_ERR_TYPE(float, complex64, float);
+INSTANTIATE_GRAD_ERR_TYPE(complex64, complex64, float);
+INSTANTIATE_GRAD_ERR_TYPE(complex128, complex128, double);
} // namespace tensorflow
diff --git a/tensorflow/cc/framework/gradient_checker.h b/tensorflow/cc/framework/gradient_checker.h
index 2e61213615..d055c60d09 100644
--- a/tensorflow/cc/framework/gradient_checker.h
+++ b/tensorflow/cc/framework/gradient_checker.h
@@ -24,19 +24,39 @@ namespace tensorflow {
/// Returns in 'max_error' the maximum element-wise error for dy/dx between the
/// computed and numeric Jacobian matrices where 'xs' and 'ys' are tensors.
+/// X_T and Y_T are the c++ types for the x and y tensors, and JAC_T is a
+/// real-valued type to store the Jacobian derivatives dy/dx.
/// This function adds operations to the graph associated with 'scope'.
-template <typename T>
+///
+/// Examples:
+/// if y = Square(x), where x (and so y) are DT_FLOAT,
+/// <X_T, Y_T, JAC_T> should be <float, float, float>
+///
+/// if y = Square(x), where x (and so y) are DT_DOUBLE,
+/// <X_T, Y_T, JAC_T> should be <double, double, double>
+///
+/// if y = Square(x), where x (and so y) are DT_COMPLEX64,
+/// <X_T, Y_T, JAC_T> should be <complex64, complex64, float>
+/// Note that JAC_T is always real-valued, and should be an appropriate
+/// precision to host the partial derivatives for dy/dx
+///
+/// if y = ComplexAbs(x) where x is DT_COMPLEX64 (so y is DT_FLOAT)
+/// <X_T, Y_T, JAC_T> should be <complex64, float, float>
+///
+/// if y = Complex(x, x) where x is DT_FLOAT (so y is DT_COMPLEX64)
+/// <X_T, Y_T, JAC_T> should be <float, complex64, float>
+template <typename X_T, typename Y_T, typename JAC_T>
Status ComputeGradientError(const Scope& scope, const OutputList& xs,
const std::vector<TensorShape>& x_shapes,
const OutputList& ys,
const std::vector<TensorShape>& y_shapes,
- T* max_error);
+ JAC_T* max_error);
/// Overload of ComputeGradientError which takes an initial value for 'x'.
-template <typename T>
+template <typename X_T, typename Y_T, typename JAC_T>
Status ComputeGradientError(const Scope& scope, const Output& x,
const Tensor& x_init_value, const Output& y,
- const TensorShape& y_shape, T* max_error);
+ const TensorShape& y_shape, JAC_T* max_error);
} // namespace tensorflow
diff --git a/tensorflow/cc/framework/gradient_checker_test.cc b/tensorflow/cc/framework/gradient_checker_test.cc
index c5bddc50fc..fdc457f40a 100644
--- a/tensorflow/cc/framework/gradient_checker_test.cc
+++ b/tensorflow/cc/framework/gradient_checker_test.cc
@@ -34,8 +34,8 @@ TEST(GradientCheckerTest, BasicFloat) {
auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape));
auto y = Square(scope, x);
float max_error;
- TF_ASSERT_OK(ComputeGradientError<float>(scope, {x}, {shape}, {y}, {shape},
- &max_error));
+ TF_ASSERT_OK((ComputeGradientError<float, float, float>(
+ scope, {x}, {shape}, {y}, {shape}, &max_error)));
EXPECT_LT(max_error, 1e-4);
}
@@ -45,11 +45,57 @@ TEST(GradientCheckerTest, BasicDouble) {
auto x = Placeholder(scope, DT_DOUBLE, Placeholder::Shape(shape));
auto y = Square(scope, x);
double max_error;
- TF_ASSERT_OK(ComputeGradientError<double>(scope, {x}, {shape}, {y}, {shape},
- &max_error));
+ TF_ASSERT_OK((ComputeGradientError<double, double, double>(
+ scope, {x}, {shape}, {y}, {shape}, &max_error)));
EXPECT_LT(max_error, 1e-10);
}
+TEST(GradientCheckerTest, BasicComplex64) {
+ Scope scope = Scope::NewRootScope();
+ TensorShape shape({2, 4, 3});
+ auto x = Placeholder(scope, DT_COMPLEX64, Placeholder::Shape(shape));
+ auto y = Square(scope, x);
+ float max_error;
+ TF_ASSERT_OK((ComputeGradientError<complex64, complex64, float>(
+ scope, {x}, {shape}, {y}, {shape}, &max_error)));
+ EXPECT_LT(max_error, 1e-4);
+}
+
+TEST(GradientCheckerTest, BasicComplex128) {
+ Scope scope = Scope::NewRootScope();
+ TensorShape shape({2, 4, 3});
+ auto x = Placeholder(scope, DT_COMPLEX128, Placeholder::Shape(shape));
+ auto y = Square(scope, x);
+ double max_error;
+ TF_ASSERT_OK((ComputeGradientError<complex128, complex128, double>(
+ scope, {x}, {shape}, {y}, {shape}, &max_error)));
+ EXPECT_LT(max_error, 1e-10);
+}
+
+TEST(GradientCheckerTest, FloatToComplex64) {
+ // Test an op whose inputs are real and outputs are complex
+ Scope scope = Scope::NewRootScope();
+ TensorShape shape({2, 4, 3});
+ auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape));
+ auto y = Complex(scope, x, x);
+ float max_error;
+ TF_ASSERT_OK((ComputeGradientError<float, complex64, float>(
+ scope, {x}, {shape}, {y}, {shape}, &max_error)));
+ EXPECT_LT(max_error, 1e-4);
+}
+
+TEST(GradientCheckerTest, Complex64ToFloat) {
+ // Test an op whose inputs are complex and outputs are real
+ Scope scope = Scope::NewRootScope();
+ TensorShape shape({2, 4, 3});
+ auto x = Placeholder(scope, DT_COMPLEX64, Placeholder::Shape(shape));
+ auto y = Real(scope, x);
+ float max_error;
+ TF_ASSERT_OK((ComputeGradientError<complex64, float, float>(
+ scope, {x}, {shape}, {y}, {shape}, &max_error)));
+ EXPECT_LT(max_error, 1e-4);
+}
+
TEST(GradientCheckerTest, MatMulGrad) {
Scope scope = Scope::NewRootScope();
@@ -61,8 +107,8 @@ TEST(GradientCheckerTest, MatMulGrad) {
auto y = Const(scope, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, y_shape);
auto z = MatMul(scope, x, y);
double max_error;
- TF_ASSERT_OK(ComputeGradientError<double>(scope, {x}, {x_shape}, {z},
- {z_shape}, &max_error));
+ TF_ASSERT_OK((ComputeGradientError<double, double, double>(
+ scope, {x}, {x_shape}, {z}, {z_shape}, &max_error)));
EXPECT_LT(max_error, 1e-10);
}
@@ -76,8 +122,8 @@ TEST(GradientCheckerTest, SplitGrad) {
auto y = Split(scope, split_dim, x, /* num_split */ 2);
TensorShape y_shape = TensorShape({5, 1});
double max_error;
- TF_ASSERT_OK(ComputeGradientError<double>(scope, {x}, {x_shape}, y.output,
- {y_shape, y_shape}, &max_error));
+ TF_ASSERT_OK((ComputeGradientError<double, double, double>(
+ scope, {x}, {x_shape}, y.output, {y_shape, y_shape}, &max_error)));
EXPECT_LT(max_error, 1e-10);
}
@@ -91,8 +137,8 @@ TEST(GradientCheckerTest, StackGrad) {
auto y = Stack(scope, xs, Stack::Axis(0));
TensorShape y_shape({2, 1, 2, 3});
double max_error;
- TF_ASSERT_OK(ComputeGradientError<double>(scope, xs, {x_shape, x_shape}, {y},
- {y_shape}, &max_error));
+ TF_ASSERT_OK((ComputeGradientError<double, double, double>(
+ scope, xs, {x_shape, x_shape}, {y}, {y_shape}, &max_error)));
EXPECT_LT(max_error, 1e-10);
}
@@ -107,8 +153,8 @@ TEST(GradientCheckerTest, StackUnstackGrad) {
auto tmp = Stack(scope, xs, Stack::Axis(0));
auto y = Unstack(scope, tmp, 2, Unstack::Axis(0));
double max_error;
- TF_ASSERT_OK(ComputeGradientError<double>(scope, xs, {shape, shape}, y.output,
- {shape, shape}, &max_error));
+ TF_ASSERT_OK((ComputeGradientError<double, double, double>(
+ scope, xs, {shape, shape}, y.output, {shape, shape}, &max_error)));
EXPECT_LT(max_error, 1e-10);
}