aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-22 00:07:40 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-22 00:10:43 -0800
commit531bf79775188a34c8ca5c427a554620a837847c (patch)
treee548bf350f78e26f4e6d65399f4f60b27056b0da
parenta2039aa91f6741d4a786c268851368eba1119366 (diff)
[XLA] Support conditional in all backends.
PiperOrigin-RevId: 179900775
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc4
-rw-r--r--tensorflow/compiler/xla/tests/BUILD6
2 files changed, 3 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 52014b508c..9805818e4c 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1684,9 +1684,11 @@ bool HloInstruction::IdenticalSlowPath(
return custom_call_target_ == other.custom_call_target_;
case HloOpcode::kReverse:
return dimensions() == other.dimensions();
+ case HloOpcode::kConditional:
+ return eq_computations(true_computation(), other.true_computation()) &&
+ eq_computations(false_computation(), other.false_computation());
// These opcodes are not yet supported.
- case HloOpcode::kConditional:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kSort:
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 782f444dc4..45689976b0 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -442,12 +442,6 @@ xla_test(
xla_test(
name = "conditional_test",
srcs = ["conditional_test.cc"],
- # Currently, Conditional is supported only in CPU and GPU backends.
- backends = [
- "cpu",
- "gpu",
- "cpu_parallel",
- ],
deps = [
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation_builder",