diff options
author | 2018-04-30 17:41:33 -0700 | |
---|---|---|
committer | 2018-04-30 17:43:59 -0700 | |
commit | 45bafe9a3589fc735c22c3c703f8689ea9c1e71e (patch) | |
tree | e39723521a1ca68e9c2c74d1a9d3ac5ef2e8abc4 /tensorflow/compiler/aot | |
parent | c89a1d9605427d74079774af7da37933f9ca153c (diff) |
[XLA] Redesign: migrate tensorflow/compiler/tf2xla, tensorflow/compiler/aot:
- xla::ComputationBuilder -> xla::XlaBuilder
- xla::ComputationDataHandle -> xla::XlaOp
- xla::Computation -> xla::XlaComputation
- xla::CompileOnlyClient::AotComputationInstance -> xla::CompileOnlyClient::AotXlaComputationInstance
- xla::SessionModule -> xla::HloSnapshot
PiperOrigin-RevId: 194874462
Diffstat (limited to 'tensorflow/compiler/aot')
-rw-r--r-- | tensorflow/compiler/aot/compile.cc | 12 | ||||
-rw-r--r-- | tensorflow/compiler/aot/tests/tfcompile_test.cc | 14 |
2 files changed, 14 insertions, 12 deletions
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 31044ff85d..bbc35da2ef 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -44,7 +44,7 @@ namespace { // Compiles the XLA computation into executable code. Status CompileXla(xla::CompileOnlyClient* client, - const xla::Computation& computation, + const xla::XlaComputation& computation, const xla::cpu::CpuAotCompilationOptions& aot_opts, CompileResult* compile_result) { // Retrieves arg and result layouts from the computation. @@ -62,7 +62,7 @@ Status CompileXla(xla::CompileOnlyClient* client, for (int i = 0; i < pshape->parameters_size(); ++i) { arg_layouts.push_back(pshape->mutable_parameters(i)); } - xla::CompileOnlyClient::AotComputationInstance instance; + xla::CompileOnlyClient::AotXlaComputationInstance instance; instance.computation = &computation; instance.argument_layouts = std::move(arg_layouts); instance.result_layout = &pshape->result(); @@ -93,14 +93,14 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, xla::CompileOnlyClient* client = xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform) .ValueOrDie(); - xla::Computation computation; + xla::XlaComputation computation; TF_RETURN_IF_ERROR( ConvertGraphDefToXla(graph_def, config, client, &computation)); if (!flags.out_session_module.empty()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module, + TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module, computation.Snapshot()); - // Serialize the SessionModule deterministically so that all the outputs of - // a tf_library genrule are deterministic. + // Serialize the HloSnapshot deterministically so that all the outputs of a + // tf_library genrule are deterministic. string proto; TF_RET_CHECK(SerializeToStringDeterministic(*module, &proto)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index aa9d968265..27ba42b31f 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -525,14 +525,16 @@ TEST(TFCompileTest, HloProfiling) { auto header = HasSubstr("Execution profile for"); auto total_cycles_profile_line = HasSubstr("[total]"); auto dot_profile_line = HasSubstr( - "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)"); + "%dot.0.2 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " + "%arg1.0.1)"); auto add_profile_line = HasSubstr( - "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)"); + "%add.0.5 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " + "%arg1.0.1)"); auto tuple_profile_line = HasSubstr( - "%tuple.2 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, " - "f32[2,2]{1,0} %add)"); - auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)"); - auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)"); + "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} " + "%dot.0.2, f32[2,2]{1,0} %add.0.5)"); + auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)"); + auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)"); hlo_profile_lines.erase(hlo_profile_lines.begin() + 7, hlo_profile_lines.end()); |