aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-02-16 18:13:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-16 18:17:31 -0800
commit090bb9168cbcb5bbb3d7fb8e0b64f7d00013d188 (patch)
tree0ab953b0f33b7883db614206005b593ef1edcc32
parenta189502cc3032f0bc8f3294b0e39062e89fe9181 (diff)
[XLA] Pass the module to HloDataflowAnalysis by const reference.
PiperOrigin-RevId: 186072673
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/liveness_util_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc2
9 files changed, 22 insertions, 23 deletions
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index c812df4235..cc195879a6 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -1156,7 +1156,7 @@ bool IsWhileBody(const HloComputation* computation,
HloModule* module) {
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
- HloDataflowAnalysis::Run(module));
+ HloDataflowAnalysis::Run(*module));
bool changed = false;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
index 916b556fd4..9db85bc788 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
@@ -49,7 +49,7 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
- HloDataflowAnalysis::Run(module));
+ HloDataflowAnalysis::Run(*module));
// Make sure all operands of a library call are in memory instead of constants
// in IR.
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
index 6d2a3aa5b5..30e32a46d7 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
@@ -419,7 +419,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
auto alias_analysis = WrapUnique(new HloAliasAnalysis(module));
TF_ASSIGN_OR_RETURN(
alias_analysis->dataflow_analysis_,
- HloDataflowAnalysis::Run(module, /*ssa_form=*/true,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true,
/*bitcast_defines_value=*/false));
BufferValueMap buffer_map(alias_analysis->dataflow_analysis());
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index ccbbe8f196..934e43ba48 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -38,12 +38,12 @@ namespace xla {
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
-HloDataflowAnalysis::HloDataflowAnalysis(HloModule* module, bool ssa_form,
+HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
bool bitcast_defines_value)
: module_(module),
ssa_form_(ssa_form),
bitcast_defines_value_(bitcast_defines_value),
- call_graph_(CallGraph::Build(module)) {}
+ call_graph_(CallGraph::Build(&module)) {}
bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
const ShapeIndex& index) const {
@@ -115,9 +115,9 @@ void HloDataflowAnalysis::DeleteMarkedValues() {
}
string HloDataflowAnalysis::ToString() const {
- string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n");
+ string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
StrAppend(&out, " Instruction value sets:\n");
- for (const HloComputation* computation : module_->computations()) {
+ for (const HloComputation* computation : module_.computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
StrAppend(&out, " ", instruction->name(), ":\n");
if (ShapeUtil::IsTuple(instruction->shape())) {
@@ -592,7 +592,7 @@ void HloDataflowAnalysis::Propagate() {
}
};
- for (HloComputation* computation : module_->computations()) {
+ for (HloComputation* computation : module_.computations()) {
for (HloInstruction* instruction : computation->instructions()) {
add_to_worklist(instruction);
}
@@ -686,7 +686,7 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
}
Status HloDataflowAnalysis::InitializeInstructionValueSets() {
- for (const HloComputation* computation : module_->computations()) {
+ for (const HloComputation* computation : module_.computations()) {
const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
for (HloInstruction* instruction : computation->instructions()) {
// Create an empty shape tree.
@@ -787,9 +787,9 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
/* static */
StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
- HloModule* module, bool ssa_form, bool bitcast_defines_value) {
- VLOG(1) << "HloDataflowAnalysis::Run on module " << module->name();
- XLA_VLOG_LINES(2, module->ToString());
+ const HloModule& module, bool ssa_form, bool bitcast_defines_value) {
+ VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
+ XLA_VLOG_LINES(2, module.ToString());
auto dataflow_analysis = WrapUnique(
new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value));
@@ -806,7 +806,7 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
// lookup is faster.
std::vector<std::vector<HloPosition>> value_positions(
dataflow_analysis->next_value_id_);
- for (const HloComputation* computation : module->computations()) {
+ for (const HloComputation* computation : module.computations()) {
for (HloInstruction* instruction : computation->instructions()) {
for (const auto& pair :
dataflow_analysis->GetInstructionValueSet(instruction)) {
@@ -858,7 +858,7 @@ Status HloDataflowAnalysis::Verify() const {
// For each value in each value set, verify that the value set's position
// appears in the value's positions().
- for (const auto& computation : module_->computations()) {
+ for (const auto& computation : module_.computations()) {
for (const auto& instruction : computation->instructions()) {
for (const auto& pair : GetInstructionValueSet(instruction)) {
const ShapeIndex& index = pair.first;
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index 89d318188f..7b8a74b096 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -60,7 +60,7 @@ class HloDataflowAnalysis {
// a new HLO value in the analysis. If false then Bitcast forwards the
// value of its operand.
static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run(
- HloModule* module, bool ssa_form = false,
+ const HloModule& module, bool ssa_form = false,
bool bitcast_defines_value = false);
// Returns true if 'instruction' defines an HLO value at the given shape index
@@ -119,7 +119,7 @@ class HloDataflowAnalysis {
string ToString() const;
protected:
- HloDataflowAnalysis(HloModule* module, bool ssa_form,
+ HloDataflowAnalysis(const HloModule& module, bool ssa_form,
bool bitcast_defines_value = false);
// Returns a new HloValue defined at the given instruction and shape index.
@@ -180,7 +180,7 @@ class HloDataflowAnalysis {
// Verify various invariants of the dataflow analysis.
Status Verify() const;
- HloModule* const module_;
+ const HloModule& module_;
const bool ssa_form_;
const bool bitcast_defines_value_;
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index e714b2567f..7bf3a1a060 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -50,7 +50,7 @@ class HloDataflowAnalysisTest : public HloTestBase,
bool bitcast_defines_value = false) {
hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before dataflow analysis");
analysis_ =
- HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value)
+ HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value)
.ConsumeValueOrDie();
return *analysis_;
}
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index aba66114de..a989fce632 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -262,8 +262,8 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) {
scalar_shape, HloOpcode::kAdd, constant, xla_while));
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(
- auto dataflow, HloDataflowAnalysis::Run(module.get(), /*ssa_form=*/true));
+ TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
DependencyHloOrdering ordering(module.get());
// Init value is defined before the while, but live range is not before the
diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc
index 2c2a02f637..f8b309488e 100644
--- a/tensorflow/compiler/xla/service/liveness_util_test.cc
+++ b/tensorflow/compiler/xla/service/liveness_util_test.cc
@@ -35,8 +35,7 @@ class PointsToAnalysisTestBase : public HloTestBase {
CHECK_NOTNULL(module_.get());
points_to_analysis_ =
TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
- dataflow_analysis_ =
- HloDataflowAnalysis::Run(module_.get()).ConsumeValueOrDie();
+ dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie();
}
void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index b060fb13b1..0bc7df2a65 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -287,7 +287,7 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
HloModule* const module) {
- TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(module));
+ TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
const auto params = module->entry_computation()->parameter_instructions();
std::minstd_rand0 engine;
std::vector<std::unique_ptr<Literal>> arguments(params.size());