diff options
Diffstat (limited to 'tensorflow/core/framework/common_shape_fns.cc')
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.cc | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index ed3318d841..21c6940b62 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -1231,11 +1231,13 @@ Status ConcatV2Shape(InferenceContext* c) { c->num_inputs() - 1 /* dim_index */); } -Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, int output_index) { - ShapeHandle shape_x = c->input(0); - ShapeHandle shape_y = c->input(1); +Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, + ShapeHandle shape_x, + ShapeHandle shape_y, + ShapeHandle* out) { + CHECK_NOTNULL(out); if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) { - c->set_output(0, c->UnknownShape()); + *out = c->UnknownShape(); return Status::OK(); } const int32 rank_x = c->Rank(shape_x); @@ -1293,7 +1295,7 @@ Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, int output_index) { } } - c->set_output(output_index, c->MakeShape(dims)); + *out = c->MakeShape(dims); return Status::OK(); } |