about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/computation_tracker.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/computation_tracker.cc')
-rw-r--r-- tensorflow/compiler/xla/service/computation_tracker.cc | 204
1 files changed, 204 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/computation_tracker.cc b/tensorflow/compiler/xla/service/computation_tracker.cc
new file mode 100644
index 0000000000..281277bed5
--- /dev/null
+++ b/tensorflow/compiler/xla/service/computation_tracker.cc
@@ -0,0 +1,204 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/computation_tracker.h"
+
+#include <list>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
// Handle values start at 1, so 0 is never a valid handle value.
ComputationTracker::ComputationTracker() : next_computation_(1) {}
+
+ComputationHandle ComputationTracker::NewComputation(
+ const string& computation_name) {
+ tensorflow::mutex_lock lock(computation_mutex_);
+ ComputationHandle computation_handle;
+ int64 handle_value = next_computation_++;
+ computation_handle.set_handle(handle_value);
+ opaque_to_computation_[handle_value] =
+ MakeUnique<UserComputation>(computation_name, computation_handle);
+ return computation_handle;
+}
+
+StatusOr<ComputationHandle> ComputationTracker::LoadSessionModule(
+ const SessionModule& session_module) {
+ tensorflow::mutex_lock lock(computation_mutex_);
+
+ // For each embedded computation, create a new computation based on its
+ // serialized data, and place the mapping from the old computation handle to
+ // the new computation handle.
+ std::map<int64, ComputationHandle> old_to_new;
+ for (const SessionComputation& computation :
+ session_module.embedded_computations()) {
+ const int64 old_handle = computation.computation_handle().handle();
+ TF_ASSIGN_OR_RETURN(old_to_new[old_handle],
+ LoadSessionComputation(computation, &old_to_new));
+ }
+
+ // Finally, place the entry computation in the tracker with all of the
+ // remappings populated from the above.
+ const int64 old_handle = session_module.entry().computation_handle().handle();
+ TF_ASSIGN_OR_RETURN(
+ old_to_new[old_handle],
+ LoadSessionComputation(session_module.entry(), &old_to_new));
+ return old_to_new[old_handle];
+}
+
// Serializes the computation graph rooted at 'computation' into a
// SessionModule proto holding the entry computation plus every computation it
// transitively embeds.
StatusOr<std::unique_ptr<SessionModule>>
ComputationTracker::SnapshotComputation(const ComputationHandle& computation) {
  TF_ASSIGN_OR_RETURN(UserComputation * user_computation, Resolve(computation));
  const VersionedComputationHandle entry_versioned_handle =
      user_computation->GetVersionedHandle();
  std::set<VersionedComputationHandle> visited;
  std::list<VersionedComputationHandle> post_order;
  {
    // ComputeComputationPostOrder reads tracker state, so hold
    // computation_mutex_ for the traversal only; the Resolve() calls below
    // re-acquire the lock themselves.
    tensorflow::mutex_lock lock(computation_mutex_);
    ComputeComputationPostOrder(entry_versioned_handle, &visited, &post_order);
  }
  auto session_module = MakeUnique<SessionModule>();
  *session_module->mutable_entry() =
      Resolve(entry_versioned_handle.handle)
          .ValueOrDie()
          ->CloneSessionComputation(entry_versioned_handle.version);
  // The entry computation is the last element of the post order; starting at
  // ++rbegin() skips it, so only embedded computations are emitted here.
  // NOTE(review): this walks the post order in reverse (callers before
  // callees), so embedded_computations() are serialized caller-first —
  // presumably LoadSessionComputation's remapping tolerates forward
  // references; confirm against UserComputation::MakeWithRemapping.
  for (auto it = ++post_order.rbegin(); it != post_order.rend(); ++it) {
    *session_module->add_embedded_computations() =
        Resolve(it->handle).ValueOrDie()->CloneSessionComputation(it->version);
  }
  return std::move(session_module);
}
+
// Thread-safe lookup of the UserComputation registered for 'computation';
// returns NotFound if the handle is unknown to this tracker.
StatusOr<UserComputation*> ComputationTracker::Resolve(
    const ComputationHandle& computation) const {
  tensorflow::mutex_lock lock(computation_mutex_);
  return ResolveInternal(computation);
}
+
+ComputationHandle ComputationTracker::AllocateHandle() {
+ int64 handle_value = next_computation_++;
+ ComputationHandle result;
+ result.set_handle(handle_value);
+ return result;
+}
+
+StatusOr<ComputationHandle> ComputationTracker::LoadSessionComputation(
+ const SessionComputation& session_computation,
+ std::map<int64, ComputationHandle>* old_to_new) {
+ TF_RET_CHECK(old_to_new != nullptr);
+ const ComputationHandle new_handle = AllocateHandle();
+ (*old_to_new)[session_computation.computation_handle().handle()] = new_handle;
+ TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()],
+ UserComputation::MakeWithRemapping(
+ session_computation, new_handle, *old_to_new));
+ return new_handle;
+}
+
+StatusOr<UserComputation*> ComputationTracker::ResolveInternal(
+ const ComputationHandle& computation) const {
+ auto it = opaque_to_computation_.find(computation.handle());
+ if (it == opaque_to_computation_.end()) {
+ return NotFound("computation handle not found: %lld", computation.handle());
+ }
+ UserComputation* user_computation = it->second.get();
+ return user_computation;
+}
+
+void ComputationTracker::ComputeComputationPostOrder(
+ const VersionedComputationHandle& versioned_handle,
+ std::set<VersionedComputationHandle>* visited,
+ std::list<VersionedComputationHandle>* post_order) const {
+ if (visited->count(versioned_handle) > 0) {
+ DCHECK_EQ(1, visited->count(versioned_handle));
+ return;
+ }
+
+ UserComputation* computation =
+ ResolveInternal(versioned_handle.handle).ValueOrDie();
+ std::vector<VersionedComputationHandle> embedded_handles =
+ computation->GetEmbeddedComputations(versioned_handle.version);
+
+ for (const auto& embedded_handle : embedded_handles) {
+ ComputeComputationPostOrder(embedded_handle, visited, post_order);
+ }
+
+ visited->insert(versioned_handle);
+ post_order->push_back(versioned_handle);
+ return;
+}
+
// Lowers the computation identified by 'entry_handle' (at its recorded
// version) and everything it transitively embeds into a freshly constructed
// HloModule. 'include_unused_parameters' is forwarded to
// UserComputation::BuildHloComputation for every lowered computation.
StatusOr<std::unique_ptr<HloModule>> ComputationTracker::BuildHloModule(
    const VersionedComputationHandle& entry_handle,
    bool include_unused_parameters) const {
  tensorflow::mutex_lock lock(computation_mutex_);

  TF_ASSIGN_OR_RETURN(UserComputation * entry_computation,
                      ResolveInternal(entry_handle.handle));

  // Build a topological sort of the entry and any embedded computations as a
  // list. The root of the computation will be the last element in the list.
  std::set<VersionedComputationHandle> visited;
  std::list<VersionedComputationHandle> post_order;
  ComputeComputationPostOrder(entry_handle, &visited, &post_order);

  // Map from ComputationHandle value and computation version to HloComputation.
  std::map<VersionedComputationHandle, HloComputation*> hlo_computations;

  // The resolver lambda resolves VersionedHandles to embedded
  // HloComputation*. This is required by UserComputation::BuildHloComputation
  // when lowering calling operations (map, reduce etc).
  auto resolver = [&hlo_computations](
      const VersionedComputationHandle& versioned_handle) -> HloComputation* {
    // Computations are lowered in post order below, so any embedded
    // computation a caller asks for must already be in the map.
    CHECK_GT(hlo_computations.count(versioned_handle), 0);
    return hlo_computations.at(versioned_handle);
  };

  string module_name =
      tensorflow::strings::StrCat(entry_computation->name(), "_module");
  auto module = MakeUnique<HloModule>(module_name, entry_handle);
  // Lower callees before callers; the entry computation is lowered last.
  for (auto versioned_handle : post_order) {
    UserComputation* computation =
        ResolveInternal(versioned_handle.handle).ValueOrDie();

    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloComputation> hlo_computation,
        computation->BuildHloComputation(versioned_handle.version, resolver,
                                         include_unused_parameters));

    // Add the newly created computation to VersionedHandle-to-HloComputation
    // map.
    DCHECK_EQ(0, hlo_computations.count(versioned_handle));
    hlo_computations[versioned_handle] = hlo_computation.get();

    // The module retains ownership; only a non-owning pointer stays in the
    // resolver map above.
    if (computation == entry_computation) {
      module->AddEntryComputation(std::move(hlo_computation));
    } else {
      module->AddEmbeddedComputation(std::move(hlo_computation));
    }
  }

  return std::move(module);
}
+
+} // namespace xla