aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/common_shape_fns.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/common_shape_fns.h')
-rw-r--r--tensorflow/core/framework/common_shape_fns.h8
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 293c40e04d..789746b403 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -265,9 +265,15 @@ Status ConcatShape(shape_inference::InferenceContext* c,
// Shape function for concat operations.
Status ConcatV2Shape(shape_inference::InferenceContext* c);
+// Shape function for binary operators that broadcast their inputs
+// and with output to output_index.
+Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, int output_index);
+
// Shape function for binary operators that broadcast their inputs.
// Tested by ops/math_ops_test.cc.
-Status BroadcastBinaryOpShapeFn(InferenceContext* c);
+inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
+ return BroadcastBinaryOpOutputShapeFn(c, 0);
+}
// Shape function for random operations.
Status RandomShape(shape_inference::InferenceContext* c);