aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-26 19:08:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-26 19:10:31 -0700
commit56d1cfde15c04ebe27fe31409a724a56e7051b15 (patch)
tree51fa8253438a06004294cc18333100a9516d30cc /tensorflow/compiler/xla/service/shape_inference.cc
parent0be974c423f6e5c363db2d95ed335dde4cb4e69b (diff)
[XLA] Redesign: implement and test ternary ops.
PiperOrigin-RevId: 190561679
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc8
1 files changed, 6 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 2a70ea0354..36456d552d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1038,8 +1038,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs,
const HloInstruction* ehs) {
- return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs->shape(),
- rhs->shape(), ehs->shape());
+ return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape());
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
+ HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) {
+ return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs, rhs, ehs);
}
/* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(