author    TensorFlower Gardener <gardener@tensorflow.org>  2018-07-24 13:28:53 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-07-24 13:28:53 -0700
commit 3e9adc8b51a209c30e79812dc844827f8706bc3f (patch)
tree   3d2a254bb1c2d9a8c16adb1ae966305cca86a44f /tensorflow/cc
parent e2f8d4a8bdfc4e3970cacc89a6b184297205a1cc (diff)
parent 7fe0ae12ec42eca1ea07d93bbd63de394743a018 (diff)
Merge pull request #20763 from pbanavara:master
PiperOrigin-RevId: 205880828
Diffstat (limited to 'tensorflow/cc')
-rw-r--r--  tensorflow/cc/gradients/nn_grad.cc      | 85
-rw-r--r--  tensorflow/cc/gradients/nn_grad_test.cc | 22
2 files changed, 93 insertions(+), 14 deletions(-)
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc
index c73482d5f4..588e96cb19 100644
--- a/tensorflow/cc/gradients/nn_grad.cc
+++ b/tensorflow/cc/gradients/nn_grad.cc
@@ -47,6 +47,72 @@ Status SoftmaxGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad);
+bool IsZero(const Scope& scope, const Output& grad) {
+ string op_type_name = grad.op().node()->type_string();
+ if (op_type_name == "ZerosLike" || op_type_name == "Zeros") {
+ return true;
+ }
+  // The Operation we were given is not named something obvious, so we would
+  // need to actually inspect its contents to decide.
+  // The original Python code did this with the utility function
+  // tensor_util.constant_value; there is no C++ equivalent, so for the moment
+  // we conservatively return false.
+ return false;
+}
+
+// Multiply after broadcasting vec to match dimensions of mat.
+// Args:
+// vec: A 1-D tensor of dimension [D0]
+// mat: A 2-D tensor of dimension [D0, D1]
+//
+// Returns:
+// A tensor of dimension [D0, D1], the result of vec * mat.
+Output BroadcastMul(const Scope& scope, const Output& vec, const Output& mat) {
+ auto reshaped = ExpandDims(scope, vec, -1);
+ return Multiply(scope, reshaped, mat);
+}
+
+Status SoftmaxCrossEntropyWithLogitsGrad(const Scope& scope,
+ const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+  // Softmax gradient with cross entropy logits function.
+  // We multiply the backprop for cost with the softmax gradient the op
+  // already computed (op.output(1)). The gradient with respect to the labels
+  // is -grad_loss * log_softmax(logits), computed at the end.
+
+  // The logits (the network's outputs) are at input index 0; the labels are
+  // at input index 1.
+  auto logits = op.input(0);
+  // Output index 1 holds the softmax gradient: softmax(logits) - labels.
+  auto softmax_grad = op.output(1);
+
+  // grad_inputs[0] is the incoming gradient for the loss (output 0);
+  // grad_inputs[1] is the incoming gradient for the backprop output (output 1).
+  auto grad_loss = grad_inputs[0];
+  auto grad_grad = grad_inputs[1];
+
+ auto grad = BroadcastMul(scope, grad_loss, softmax_grad);
+ if (!IsZero(scope, grad_grad)) {
+ std::vector<int> axis;
+ auto logits_softmax = Softmax(scope, logits);
+
+ auto grad_grad_expand = ExpandDims(scope, grad_grad, 1);
+ auto logits_softmax_expand = ExpandDims(scope, logits_softmax, 2);
+ auto matmul_result =
+ BatchMatMul(scope, grad_grad_expand, logits_softmax_expand);
+ axis.push_back(1);
+ auto squeeze_result = Squeeze(scope, matmul_result, Squeeze::Axis(axis));
+ auto subtraction_result = Subtract(scope, grad_grad, squeeze_result);
+ auto multiply_result = Multiply(scope, subtraction_result, logits_softmax);
+ grad = Add(scope, grad, multiply_result);
+ }
+ auto minus_log_softmax = Multiply(scope, LogSoftmax(scope, logits), -1.0f);
+ grad_outputs->push_back(grad);
+ grad_outputs->push_back(BroadcastMul(scope, grad_loss, minus_log_softmax));
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("SoftmaxCrossEntropyWithLogits",
+ SoftmaxCrossEntropyWithLogitsGrad);
+
Status LogSoftmaxGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
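As a quick illustration of the BroadcastMul helper added in the hunk above, here is a minimal, self-contained sketch (not part of this change) that broadcasts a [D0] vector against a [D0, D1] matrix with the same ExpandDims-then-Multiply trick. The main() wrapper, the constant values, and the exact set of headers are assumptions made for the example only.

#include <vector>

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/logging.h"

int main() {
  using namespace tensorflow;  // NOLINT
  Scope root = Scope::NewRootScope();

  // vec has shape [2] (D0 = 2); mat has shape [2, 3] (D0 = 2, D1 = 3).
  auto vec = ops::Const(root, {2.0f, 3.0f});
  auto mat = ops::Const(root, {{1.0f, 1.0f, 1.0f}, {1.0f, 1.0f, 1.0f}});

  // Reshape vec from [2] to [2, 1] so the elementwise Multiply broadcasts it
  // across the second dimension, exactly as BroadcastMul does above.
  auto reshaped = ops::ExpandDims(root, vec, -1);
  auto product = ops::Multiply(root, reshaped, mat);

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_CHECK_OK(session.Run({product}, &outputs));
  // outputs[0] is a [2, 3] tensor whose rows are scaled by 2 and 3.
  LOG(INFO) << outputs[0].DebugString();
  return 0;
}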
@@ -195,9 +261,9 @@ Status MaxPool3DGradHelper(const Scope& scope, const Operation& op,
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
MaxPool3DGrad::Attrs grad_attrs;
- auto dx = MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0],
- ksize, strides, padding,
- grad_attrs.DataFormat(data_format));
+ auto dx =
+ MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0], ksize,
+ strides, padding, grad_attrs.DataFormat(data_format));
grad_outputs->push_back(dx);
return scope.status();
}
@@ -216,10 +282,9 @@ Status AvgPoolGradHelper(const Scope& scope, const Operation& op,
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
internal::AvgPoolGrad::Attrs grad_attrs;
- auto dx =
- internal::AvgPoolGrad(scope, Shape(scope, op.input(0)), grad_inputs[0],
- ksize, strides, padding,
- grad_attrs.DataFormat(data_format));
+ auto dx = internal::AvgPoolGrad(scope, Shape(scope, op.input(0)),
+ grad_inputs[0], ksize, strides, padding,
+ grad_attrs.DataFormat(data_format));
grad_outputs->push_back(dx);
return scope.status();
}
@@ -238,9 +303,9 @@ Status AvgPool3DGradHelper(const Scope& scope, const Operation& op,
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
AvgPool3DGrad::Attrs grad_attrs;
- auto dx = AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0],
- ksize, strides, padding,
- grad_attrs.DataFormat(data_format));
+ auto dx =
+ AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], ksize,
+ strides, padding, grad_attrs.DataFormat(data_format));
grad_outputs->push_back(dx);
return scope.status();
}
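For context on how REGISTER_GRADIENT_OP is consumed by callers, the following is a hedged, self-contained sketch (not part of this diff) of pulling the newly registered gradient through the symbolic gradient builder in tensorflow/cc/framework/gradients.h. The shapes are illustrative and error handling is reduced to TF_CHECK_OK.

#include <vector>

#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_shape.h"

int main() {
  using namespace tensorflow;  // NOLINT
  Scope scope = Scope::NewRootScope();

  TensorShape shape({5, 3});
  auto logits = ops::Placeholder(scope, DT_FLOAT, ops::Placeholder::Shape(shape));
  auto labels = ops::Placeholder(scope, DT_FLOAT, ops::Placeholder::Shape(shape));
  auto y = ops::SoftmaxCrossEntropyWithLogits(scope, logits, labels);

  // AddSymbolicGradients walks the graph and dispatches to the gradient
  // function registered above for "SoftmaxCrossEntropyWithLogits".
  std::vector<Output> grads;
  TF_CHECK_OK(AddSymbolicGradients(scope, {y.loss}, {logits}, &grads));
  // grads[0] is the symbolic d(loss)/d(logits).
  return 0;
}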
diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc
index b4d457a9d1..aa72cf7ba2 100644
--- a/tensorflow/cc/gradients/nn_grad_test.cc
+++ b/tensorflow/cc/gradients/nn_grad_test.cc
@@ -25,6 +25,8 @@ limitations under the License.
namespace tensorflow {
namespace {
+using ops::AvgPool;
+using ops::AvgPool3D;
using ops::BiasAdd;
using ops::Conv2D;
using ops::Elu;
@@ -33,11 +35,9 @@ using ops::FractionalMaxPool;
using ops::L2Loss;
using ops::LogSoftmax;
using ops::LRN;
-using ops::AvgPool;
-using ops::AvgPool3D;
using ops::MaxPool;
-using ops::MaxPoolV2;
using ops::MaxPool3D;
+using ops::MaxPoolV2;
using ops::Placeholder;
using ops::Relu;
using ops::Relu6;
@@ -111,6 +111,20 @@ TEST_F(NNGradTest, SoftmaxGrad) {
RunTest(x, shape, y, shape);
}
+TEST_F(NNGradTest, SoftmaxCrossEntropyWithLogitsGrad) {
+ TensorShape logits_shape({5, 3});
+ TensorShape loss_shape({5});
+
+ auto logits = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(logits_shape));
+ auto labels = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(logits_shape));
+ auto y =
+ tensorflow::ops::SoftmaxCrossEntropyWithLogits(scope_, logits, labels);
+  // Note that the op's outputs are passed to RunTest in reverse order
+  // (backprop before loss); issue #18734 tracks this.
+ RunTest({logits, labels}, {logits_shape, logits_shape}, {y.backprop, y.loss},
+ {logits_shape, loss_shape});
+}
+
TEST_F(NNGradTest, LogSoftmaxGrad) {
TensorShape shape({5, 3});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
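The RunTest call in the new test above ultimately drives the finite-difference gradient checker. As a rough sketch of what that amounts to for this op (assuming the gradient_checker.h API rather than quoting the actual NNGradTest fixture code), one could compare the numeric and symbolic gradients directly; the names logits, labels, y, logits_shape, and loss_shape refer to the test body above, and the 1e-3 tolerance mirrors the other tests in this file.

// Hedged fragment, not the fixture's actual implementation.
#include "tensorflow/cc/framework/gradient_checker.h"

float max_error;
TF_CHECK_OK((ComputeGradientError<float, float, float>(
    scope_, {logits, labels}, {logits_shape, logits_shape},
    {y.backprop, y.loss}, {logits_shape, loss_shape}, &max_error)));
EXPECT_LT(max_error, 1e-3);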
@@ -253,7 +267,7 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
RunTest(x, x_shape, y, y_shape);
}
-TEST_F(NNGradTest, LRN){
+TEST_F(NNGradTest, LRN) {
TensorShape x_shape({1, 1, 2, 1});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
auto y = LRN(scope_, x);