aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/gradients
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-07-31 10:57:23 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-07-31 10:57:23 +0800
commit3ffa6eeac07064ef6b3b270d48ab4fa4ce088803 (patch)
tree01a414a60f2c07b65c3c11c2265325704656ef92 /tensorflow/cc/gradients
parentfee3f260d6eba1aec57df09045459790dcae686f (diff)
parenta6572d3d003cf7ef5b0fffd5ad7c5fc86919465c (diff)
Merge remote-tracking branch 'upstream/master' into ENH/unsafe_div
Diffstat (limited to 'tensorflow/cc/gradients')
-rw-r--r--tensorflow/cc/gradients/array_grad.cc52
-rw-r--r--tensorflow/cc/gradients/array_grad_test.cc7
-rw-r--r--tensorflow/cc/gradients/math_grad.cc1
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc6
-rw-r--r--tensorflow/cc/gradients/nn_grad.cc85
-rw-r--r--tensorflow/cc/gradients/nn_grad_test.cc22
6 files changed, 154 insertions, 19 deletions
diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc
index ff348fadb2..b353accddc 100644
--- a/tensorflow/cc/gradients/array_grad.cc
+++ b/tensorflow/cc/gradients/array_grad.cc
@@ -421,6 +421,58 @@ Status StridedSliceGradHelper(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper);
+Status SliceGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // Propagate the incoming gradient along all the selected values,
+ // and zero everywhere else. Use the Pad operator for this.
+ //
+ // First create an Nx2 padding where N is the number of input
+ // dimensions. The first column is the number of prepended zeros
+ // for each dimension, and the second column is the number of
+ // appended zeros.
+ //
+ // The first column is just the begin vector.
+ // The second column is the shape of the input element-wise
+ // subtracted by begin+size
+
+ // Running example:
+ // input.shape = [3, 5, 3]
+ // begin = [1, 2, 1], size = [1, 3, 2]
+ Input input = op.input(0);
+ Input begin = op.input(1);
+ // input_rank = 3
+ auto input_rank = Rank(scope, input);
+ // slice_size = [1, 3, 2]
+ auto slice_size = Shape(scope, op.output(0));
+ // padding_shape = [3, 1]
+ auto padding_shape = Stack(scope, {input_rank, 1});
+ // before_padding = [[1]
+ // [2]
+ // [1]]
+ Input before_padding = Reshape(scope, begin, padding_shape);
+ // after_padding_sizes = shape(input) - slice_size - begin
+ // = [3, 5, 3] - [1, 3, 2] - [1, 2, 1]
+ // = [1, 0, 0]
+ auto after_padding_sizes =
+ Sub(scope, Sub(scope, Shape(scope, input), slice_size), begin);
+ // after_padding = [[1]
+ // [0]
+ // [0]]
+ Input after_padding = Reshape(scope, after_padding_sizes, padding_shape);
+ // paddings = [[1 1]
+ // [2 0]
+ // [1 0]]
+ auto paddings =
+ Concat(scope, {before_padding, after_padding}, Const(scope, 1));
+ grad_outputs->push_back(Pad(scope, grad_inputs[0], paddings));
+ // Nothing propagated for "begin" and "size" inputs
+ grad_outputs->push_back(NoGradient());
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Slice", SliceGrad);
+
} // anonymous namespace
} // namespace ops
} // namespace tensorflow
diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc
index de3bd0fc9e..d09275b648 100644
--- a/tensorflow/cc/gradients/array_grad_test.cc
+++ b/tensorflow/cc/gradients/array_grad_test.cc
@@ -378,5 +378,12 @@ TEST_F(ArrayGradTest, StridedSliceGrad) {
RunTest(x, x_shape, y, {1, 2, 2, 2});
}
+TEST_F(ArrayGradTest, SliceGrad) {
+ TensorShape x_shape({3, 5, 3});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
+ auto y = Slice(scope_, x, {1, 2, 1}, {1, 3, 2});
+ RunTest(x, x_shape, y, {1, 3, 2});
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index a8909846c9..84552e7c5e 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -38,6 +38,7 @@ REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");
+REGISTER_NO_GRADIENT_OP("Floor");
// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 27021e28f8..330d1722af 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -479,11 +479,7 @@ TEST_F(CWiseUnaryGradTest, Tan_Complex) {
auto x_fn = [this](const int i) {
return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
};
- // TODO(kbsriram)
- // Enable when tan kernel supports complex inputs
- if (false) {
- TestCWiseGrad<complex64, complex64>(TAN, x_fn);
- }
+ TestCWiseGrad<complex64, complex64>(TAN, x_fn);
}
TEST_F(CWiseUnaryGradTest, Atan) {
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc
index c73482d5f4..588e96cb19 100644
--- a/tensorflow/cc/gradients/nn_grad.cc
+++ b/tensorflow/cc/gradients/nn_grad.cc
@@ -47,6 +47,72 @@ Status SoftmaxGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad);
+bool IsZero(const Scope& scope, const Output& grad) {
+ string op_type_name = grad.op().node()->type_string();
+ if (op_type_name == "ZerosLike" || op_type_name == "Zeros") {
+ return true;
+ }
+ // The Operation we were provided is not named something obvious so
+ // we need to actually look at its contents.
+ // The original python code did this by calling a utility function called
+ // tensor_util.constant_value.
+ // There is no C++ equivalent to tensor_util.constant_value so we do nothing
+ // for the moment.
+ return false;
+}
+
+// Multiply after broadcasting vec to match dimensions of mat.
+// Args:
+// vec: A 1-D tensor of dimension [D0]
+// mat: A 2-D tensor of dimesnion [D0, D1]
+//
+// Returns:
+// A tensor of dimension [D0, D1], the result fo vec * mat.
+Output BroadcastMul(const Scope& scope, const Output& vec, const Output& mat) {
+ auto reshaped = ExpandDims(scope, vec, -1);
+ return Multiply(scope, reshaped, mat);
+}
+
+Status SoftmaxCrossEntropyWithLogitsGrad(const Scope& scope,
+ const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // Softmax gradient with cross entropy logits function.
+ // We multiply the backprop for cost with the gradients - op.output[1].
+ // There is no gradient for labels.
+
+ // The outputs of the network are at input index 0.
+ auto logits = op.input(0);
+ // The "truth" labels are at index 1.
+ auto softmax_grad = op.output(1);
+
+ // The loss is the output at index 0, and backprop is the output at index 1.
+ auto grad_loss = grad_inputs[0];
+ auto grad_grad = grad_inputs[1];
+
+ auto grad = BroadcastMul(scope, grad_loss, softmax_grad);
+ if (!IsZero(scope, grad_grad)) {
+ std::vector<int> axis;
+ auto logits_softmax = Softmax(scope, logits);
+
+ auto grad_grad_expand = ExpandDims(scope, grad_grad, 1);
+ auto logits_softmax_expand = ExpandDims(scope, logits_softmax, 2);
+ auto matmul_result =
+ BatchMatMul(scope, grad_grad_expand, logits_softmax_expand);
+ axis.push_back(1);
+ auto squeeze_result = Squeeze(scope, matmul_result, Squeeze::Axis(axis));
+ auto subtraction_result = Subtract(scope, grad_grad, squeeze_result);
+ auto multiply_result = Multiply(scope, subtraction_result, logits_softmax);
+ grad = Add(scope, grad, multiply_result);
+ }
+ auto minus_log_softmax = Multiply(scope, LogSoftmax(scope, logits), -1.0f);
+ grad_outputs->push_back(grad);
+ grad_outputs->push_back(BroadcastMul(scope, grad_loss, minus_log_softmax));
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("SoftmaxCrossEntropyWithLogits",
+ SoftmaxCrossEntropyWithLogitsGrad);
+
Status LogSoftmaxGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
@@ -195,9 +261,9 @@ Status MaxPool3DGradHelper(const Scope& scope, const Operation& op,
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
MaxPool3DGrad::Attrs grad_attrs;
- auto dx = MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0],
- ksize, strides, padding,
- grad_attrs.DataFormat(data_format));
+ auto dx =
+ MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0], ksize,
+ strides, padding, grad_attrs.DataFormat(data_format));
grad_outputs->push_back(dx);
return scope.status();
}
@@ -216,10 +282,9 @@ Status AvgPoolGradHelper(const Scope& scope, const Operation& op,
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
internal::AvgPoolGrad::Attrs grad_attrs;
- auto dx =
- internal::AvgPoolGrad(scope, Shape(scope, op.input(0)), grad_inputs[0],
- ksize, strides, padding,
- grad_attrs.DataFormat(data_format));
+ auto dx = internal::AvgPoolGrad(scope, Shape(scope, op.input(0)),
+ grad_inputs[0], ksize, strides, padding,
+ grad_attrs.DataFormat(data_format));
grad_outputs->push_back(dx);
return scope.status();
}
@@ -238,9 +303,9 @@ Status AvgPool3DGradHelper(const Scope& scope, const Operation& op,
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
AvgPool3DGrad::Attrs grad_attrs;
- auto dx = AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0],
- ksize, strides, padding,
- grad_attrs.DataFormat(data_format));
+ auto dx =
+ AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], ksize,
+ strides, padding, grad_attrs.DataFormat(data_format));
grad_outputs->push_back(dx);
return scope.status();
}
diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc
index b4d457a9d1..aa72cf7ba2 100644
--- a/tensorflow/cc/gradients/nn_grad_test.cc
+++ b/tensorflow/cc/gradients/nn_grad_test.cc
@@ -25,6 +25,8 @@ limitations under the License.
namespace tensorflow {
namespace {
+using ops::AvgPool;
+using ops::AvgPool3D;
using ops::BiasAdd;
using ops::Conv2D;
using ops::Elu;
@@ -33,11 +35,9 @@ using ops::FractionalMaxPool;
using ops::L2Loss;
using ops::LogSoftmax;
using ops::LRN;
-using ops::AvgPool;
-using ops::AvgPool3D;
using ops::MaxPool;
-using ops::MaxPoolV2;
using ops::MaxPool3D;
+using ops::MaxPoolV2;
using ops::Placeholder;
using ops::Relu;
using ops::Relu6;
@@ -111,6 +111,20 @@ TEST_F(NNGradTest, SoftmaxGrad) {
RunTest(x, shape, y, shape);
}
+TEST_F(NNGradTest, SoftmaxCrossEntropyWithLogitsGrad) {
+ TensorShape logits_shape({5, 3});
+ TensorShape loss_shape({5});
+
+ auto logits = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(logits_shape));
+ auto labels = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(logits_shape));
+ auto y =
+ tensorflow::ops::SoftmaxCrossEntropyWithLogits(scope_, logits, labels);
+ // Note the reversal of the backprop and loss orders. Issue #18734 has been
+ // opened for this.
+ RunTest({logits, labels}, {logits_shape, logits_shape}, {y.backprop, y.loss},
+ {logits_shape, loss_shape});
+}
+
TEST_F(NNGradTest, LogSoftmaxGrad) {
TensorShape shape({5, 3});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
@@ -253,7 +267,7 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
RunTest(x, x_shape, y, y_shape);
}
-TEST_F(NNGradTest, LRN){
+TEST_F(NNGradTest, LRN) {
TensorShape x_shape({1, 1, 2, 1});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
auto y = LRN(scope_, x);