aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-08-21 12:35:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 12:43:09 -0700
commit4f41091f88cca9c87a627864ccd6962e7bb44313 (patch)
tree1b7f9682cac04fc1eb6fe20a37ac52d75cf898df /tensorflow
parentc61a49ec318a42e5740efe566936957126dc04d0 (diff)
[XLA] Propagate invalid shape errors through reduce folding and turn it on
HloEvaluator should be stable enough for reduce folding, but it shouldn't crash when it encounters an instruction without a layout. Verify the layout on every instruction that gets evaluated and return an error on failure. PiperOrigin-RevId: 209641401
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/xla/service/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding_test.cc41
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h23
5 files changed, 57 insertions, 17 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 7fdffe85c0..73964733e8 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2456,6 +2456,7 @@ tf_cc_test(
":hlo",
":hlo_constant_folding",
":hlo_matchers",
+ ":hlo_parser",
":hlo_pass",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 6dddda1ca8..2ed645c3ae 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -52,9 +52,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
computation->root_instruction() != instruction) {
continue;
}
- // Skip Constant, Parameter, Reduce, and AfterAll operation.
- // TODO(b/35975797): Enable Reduce operation once arbitrary computation
- // are supported by the evaluator.
+ // Skip Constant, Parameter, and AfterAll operation.
// TODO(b/64407269): Enable Tuple once the timeout issue is resolved.
// TODO(b/110532604): Enable AfterAll once AfterAll requires at least one
// operand in which case constant folding will be impossible and this
@@ -62,7 +60,6 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
if (instruction->opcode() == HloOpcode::kParameter ||
instruction->opcode() == HloOpcode::kConstant ||
instruction->opcode() == HloOpcode::kTuple ||
- instruction->opcode() == HloOpcode::kReduce ||
instruction->opcode() == HloOpcode::kAfterAll) {
continue;
}
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 64a42c1efc..7cd1481a8a 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -202,5 +203,45 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
EXPECT_TRUE(matched);
}
+const char* const kConstantFoldReduce = R"(
+ HloModule ConstantFoldReduce
+
+ add {
+ a = s32[] parameter(0)
+ b = s32[] parameter(1)
+ ROOT add = s32[] add(a, b)
+ }
+
+ ENTRY r {
+ x = s32[3] constant({1, 2, 3})
+ init = s32[] constant(0)
+ ROOT reduce = s32[] reduce(x, init), dimensions={0}, to_apply=add
+ })";
+
+TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(kConstantFoldReduce));
+ HloConstantFolding const_folder;
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ EXPECT_TRUE(result);
+
+ EXPECT_EQ(6, module->entry_computation()
+ ->root_instruction()
+ ->literal()
+ .GetFirstElement<int32>());
+}
+
+TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(kConstantFoldReduce));
+ HloInstruction* add = module->computations().begin()->root_instruction();
+ LayoutUtil::ClearLayout(add->mutable_shape());
+ HloConstantFolding const_folder;
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ EXPECT_FALSE(result);
+
+ EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 35d9e799df..fb90049491 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -230,7 +230,6 @@ template <typename LiteralPtr>
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
HloInstruction* instruction, ArraySlice<LiteralPtr> arg_literals) {
TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
- TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
evaluated_.clear();
arg_literals_.clear();
@@ -267,7 +266,6 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
return tensorflow::errors::FailedPrecondition(
"Not all operands are constants.");
}
- TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
arg_literals_.clear();
evaluated_.clear();
@@ -1285,7 +1283,7 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) {
Status HloEvaluator::Preprocess(HloInstruction* hlo) {
VLOG(2) << "About to visit HLO: " << hlo->ToString();
- return Status::OK();
+ return ShapeUtil::ValidateShape(hlo->shape());
}
Status HloEvaluator::Postprocess(HloInstruction* hlo) {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 83d7b404f0..aafba8afe8 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1544,10 +1544,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
auto result = absl::make_unique<Literal>(reduce->shape());
+ Status eval_status;
// For each resulting dimension, calculate and assign computed value.
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
ReturnT result_val = init_scalar;
+ if (!eval_status.ok()) {
+ return result_val;
+ }
std::vector<int64> base(arg_dimensions.size());
for (int64 i = 0; i < multi_index.size(); ++i) {
@@ -1568,7 +1572,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
arg_dim_steps, func);
return static_cast<ReturnT>(computed_result);
}
- auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
+ auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index)
+ -> StatusOr<bool> {
auto curr_val = arg_literal.Get<ReturnT>(input_index);
// Evaluate computation with specified literal operands.
@@ -1576,12 +1581,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto result_val_literal =
LiteralUtil::CreateR0<ReturnT>(result_val);
- std::unique_ptr<Literal> computed_result =
- embedded_evaluator
- .Evaluate<const Literal*>(
- *function,
- {result_val_literal.get(), curr_val_literal.get()})
- .ConsumeValueOrDie();
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
+ embedded_evaluator.Evaluate<const Literal*>(
+ *function, {result_val_literal.get(),
+ curr_val_literal.get()}));
// Clear visit states so that we can use the evaluator again on
// the same computation.
embedded_evaluator.ResetVisitStates();
@@ -1591,13 +1594,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
};
// Computes one element of the result, reducing all dimensions that
// contribute to that element.
- ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
- arg_dim_steps, func);
+ eval_status = ShapeUtil::ForEachIndexWithStatus(
+ arg_literal.shape(), base, arg_dim_counts, arg_dim_steps, func);
return result_val;
}));
parent_->evaluated_[reduce] = std::move(result);
- return Status::OK();
+ return eval_status;
}
bool IsScalarAdd(HloComputation* computation) {