diff options
Diffstat (limited to 'tensorflow/core/ops/nn_ops.cc')
-rw-r--r-- | tensorflow/core/ops/nn_ops.cc | 23 |
1 files changed, 19 insertions, 4 deletions
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 1f4e9753c3..6c2fc60bab 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -1062,12 +1062,27 @@ REGISTER_OP("SoftmaxCrossEntropyWithLogits") .Attr("T: {half, bfloat16, float, double}") .SetShapeFn([](InferenceContext* c) { ShapeHandle input; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); - TF_RETURN_IF_ERROR(c->Merge(input, c->input(1), &input)); + if (c->WithRank(c->input(0), 2, &input) == Status::OK() && + c->Merge(input, c->input(1), &input) == Status::OK()) { + DimensionHandle batch_size = c->Dim(input, 0); + c->set_output(0, c->Vector(batch_size)); + c->set_output(1, input); + return Status::OK(); + } + TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFn(c, 1)); - DimensionHandle batch_size = c->Dim(input, 0); + if (!c->RankKnown(c->output(1))) { + return errors::InvalidArgument( + "Shape must be broadcasted with rank 2, but is rank is unknown."); + } + + if (c->Rank(c->output(1)) != 2) { + return errors::InvalidArgument( + "Shape must be broadcasted with rank 2, but is rank ", + c->Rank(c->output(1))); + } + DimensionHandle batch_size = c->Dim(c->output(1), 0); c->set_output(0, c->Vector(batch_size)); - c->set_output(1, input); return Status::OK(); }); |