about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-09 16:53:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 17:06:10 -0700
commitbb5fc614a4a358b350ef8dd19cb7010760fa9b29 (patch)
tree43a745ffdc409d0ff4660342d6a62735ac366a13
parent65b7d0b2f84c334327a295bf41bc06c7f6b8ffe5 (diff)
[XLA] Cleanup: Make AllocationTracker::Resolve const.
So that when resolving some global data, we don't have to worry whether "Resolve" is going to mutate the real data.

PiperOrigin-RevId: 216448145
-rw-r--r--tensorflow/compiler/xla/service/allocation_tracker.cc6
-rw-r--r--tensorflow/compiler/xla/service/allocation_tracker.h8
-rw-r--r--tensorflow/compiler/xla/service/service.cc4
-rw-r--r--tensorflow/compiler/xla/service/service.h4
4 files changed, 11 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc
index 1ed6142dce..ef5e211646 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.cc
+++ b/tensorflow/compiler/xla/service/allocation_tracker.cc
@@ -176,13 +176,13 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
}
StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::Resolve(
- const GlobalDataHandle& data) {
+ const GlobalDataHandle& data) const {
tensorflow::mutex_lock lock(mutex_);
return AllocationTracker::ResolveInternal(data);
}
StatusOr<const ShapedBuffer*> AllocationTracker::ResolveForReplica(
- const GlobalDataHandle& data, int replica_id) {
+ const GlobalDataHandle& data, int replica_id) const {
tensorflow::mutex_lock lock(mutex_);
TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
ResolveInternal(data));
@@ -196,7 +196,7 @@ StatusOr<const ShapedBuffer*> AllocationTracker::ResolveForReplica(
}
StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::ResolveInternal(
- const GlobalDataHandle& data) {
+ const GlobalDataHandle& data) const {
VLOG(2) << "resolve:" << data.handle();
auto it = handle_to_shaped_buffers_.find(data.handle());
if (it == handle_to_shaped_buffers_.end()) {
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h
index 43feccee3c..98d1a302a9 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.h
+++ b/tensorflow/compiler/xla/service/allocation_tracker.h
@@ -65,13 +65,13 @@ class AllocationTracker {
// replica, or provide an error status to say whether any of those buffers
// were not found (or found, but found deallocated).
StatusOr<std::vector<const ShapedBuffer*>> Resolve(
- const GlobalDataHandle& data);
+ const GlobalDataHandle& data) const;
// Resolves a handle from an XLA client and replica id to a shaped buffer, or
// provide an error status to say whether it was not found (or found, but
// found deallocated).
StatusOr<const ShapedBuffer*> ResolveForReplica(const GlobalDataHandle& data,
- int replica_id);
+ int replica_id) const;
private:
// Data structure encapsulating single memory allocation on the device.
@@ -87,7 +87,7 @@ class AllocationTracker {
// Internal helper which resolves the given GlobalDataHandle to a
// list of ScopedShapedBuffers.
StatusOr<std::vector<const ShapedBuffer*>> ResolveInternal(
- const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+ const GlobalDataHandle& data) const EXCLUSIVE_LOCKS_REQUIRED(mutex_);
// Internal helper which registers a vector of shaped buffers, one per
// replica. ShapedBufferTy is either ScopedShapedBuffer or ShapedBuffer. If
@@ -113,7 +113,7 @@ class AllocationTracker {
// maintained per device ordinal.
using AllocationMap = absl::flat_hash_map<const void*, Allocation>;
- tensorflow::mutex mutex_;
+ mutable tensorflow::mutex mutex_;
// Backend to use with this tracker. The backend supplies the memory allocator
// to use when deallocating memory.
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index b27a92f2a0..084df17951 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -207,7 +207,7 @@ Status Service::ValidateResultShape(const Shape& client_shape,
StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
Service::ResolveAndValidateArguments(
absl::Span<const GlobalDataHandle* const> arguments,
- absl::Span<se::StreamExecutor* const> stream_executors) {
+ absl::Span<se::StreamExecutor* const> stream_executors) const {
CHECK_EQ(options_.number_of_replicas(), stream_executors.size());
std::vector<std::vector<const ShapedBuffer*>> replicated_arguments;
replicated_arguments.resize(options_.number_of_replicas());
@@ -590,7 +590,7 @@ StatusOr<std::vector<se::StreamExecutor*>> Service::GetExecutors(
StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments(
const ExecutionOptions& execution_options,
- absl::Span<const GlobalDataHandle* const> arguments) {
+ absl::Span<const GlobalDataHandle* const> arguments) const {
// Resolve the allocations for the arguments of the computation, and create
// a vector of device memory offsets for the arguments from the allocations.
// In the case of partitioned computations, assume all arguments go on the
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 1f62fad4c8..8cf1a7b9f0 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -191,7 +191,7 @@ class Service : public ServiceInterface {
// Prepare the arguments for executing parallel.
StatusOr<std::vector<std::vector<const ShapedBuffer*>>> GetArguments(
const ExecutionOptions& execution_options,
- absl::Span<const GlobalDataHandle* const> arguments);
+ absl::Span<const GlobalDataHandle* const> arguments) const;
protected:
friend class LocalExecutable;
@@ -208,7 +208,7 @@ class Service : public ServiceInterface {
StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
ResolveAndValidateArguments(
absl::Span<const GlobalDataHandle* const> arguments,
- absl::Span<se::StreamExecutor* const> stream_executors);
+ absl::Span<se::StreamExecutor* const> stream_executors) const;
// Create a Hlo module config for the given program shape and arguments.
// execution_options is optional; if not given a default is used.