aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/common_shape_fns.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/common_shape_fns.h')
-rw-r--r--tensorflow/core/framework/common_shape_fns.h17
1 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 87bb133d92..2bedce1d6a 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -267,7 +267,22 @@ Status ConcatV2Shape(shape_inference::InferenceContext* c);
// Shape function for binary operators that broadcast their inputs
// and with output to output_index.
-Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, int output_index);
+// Note: out cannot be NULL.
+Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
+ ShapeHandle shape_x,
+ ShapeHandle shape_y,
+ ShapeHandle* out);
+
+// Shape function for binary operators that broadcast their inputs
+// and with output to output_index.
+inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c,
+ int output_index) {
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(
+ BroadcastBinaryOpOutputShapeFnHelper(c, c->input(0), c->input(1), &out));
+ c->set_output(output_index, out);
+ return Status::OK();
+}
// Shape function for binary operators that broadcast their inputs.
// Tested by ops/math_ops_test.cc.