aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/nn_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/nn_ops.cc')
-rw-r--r--tensorflow/core/ops/nn_ops.cc23
1 files changed, 19 insertions, 4 deletions
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 1f4e9753c3..6c2fc60bab 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -1062,12 +1062,27 @@ REGISTER_OP("SoftmaxCrossEntropyWithLogits")
.Attr("T: {half, bfloat16, float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
- TF_RETURN_IF_ERROR(c->Merge(input, c->input(1), &input));
+ if (c->WithRank(c->input(0), 2, &input) == Status::OK() &&
+ c->Merge(input, c->input(1), &input) == Status::OK()) {
+ DimensionHandle batch_size = c->Dim(input, 0);
+ c->set_output(0, c->Vector(batch_size));
+ c->set_output(1, input);
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFn(c, 1));
- DimensionHandle batch_size = c->Dim(input, 0);
+ if (!c->RankKnown(c->output(1))) {
+ return errors::InvalidArgument(
+ "Shape must be broadcasted with rank 2, but is rank is unknown.");
+ }
+
+ if (c->Rank(c->output(1)) != 2) {
+ return errors::InvalidArgument(
+ "Shape must be broadcasted with rank 2, but is rank ",
+ c->Rank(c->output(1)));
+ }
+ DimensionHandle batch_size = c->Dim(c->output(1), 0);
c->set_output(0, c->Vector(batch_size));
- c->set_output(1, input);
return Status::OK();
});