diff options
Diffstat (limited to 'tensorflow/core/framework/common_shape_fns.h')
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.h | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index 87bb133d92..2bedce1d6a 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -267,7 +267,22 @@ Status ConcatV2Shape(shape_inference::InferenceContext* c); // Shape function for binary operators that broadcast their inputs // and with output to output_index. -Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, int output_index); +// Note: out cannot be NULL. +Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, + ShapeHandle shape_x, + ShapeHandle shape_y, + ShapeHandle* out); + +// Shape function for binary operators that broadcast their inputs +// and with output to output_index. +inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, + int output_index) { + ShapeHandle out; + TF_RETURN_IF_ERROR( + BroadcastBinaryOpOutputShapeFnHelper(c, c->input(0), c->input(1), &out)); + c->set_output(output_index, out); + return Status::OK(); +} // Shape function for binary operators that broadcast their inputs. // Tested by ops/math_ops_test.cc. |