diff options
Diffstat (limited to 'tensorflow/core/framework/common_shape_fns.h')
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.h | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index 293c40e04d..789746b403 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -265,9 +265,15 @@ Status ConcatShape(shape_inference::InferenceContext* c, // Shape function for concat operations. Status ConcatV2Shape(shape_inference::InferenceContext* c); +// Shape function for binary operators that broadcast their inputs +// and with output to output_index. +Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, int output_index); + // Shape function for binary operators that broadcast their inputs. // Tested by ops/math_ops_test.cc. -Status BroadcastBinaryOpShapeFn(InferenceContext* c); +inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) { + return BroadcastBinaryOpOutputShapeFn(c, 0); +} // Shape function for random operations. Status RandomShape(shape_inference::InferenceContext* c); |