diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/computation_placer.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/computation_placer.cc | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index d26486fcfe..187ce568cb 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -29,9 +29,13 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; + namespace xla { Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { @@ -71,6 +75,19 @@ DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { return std::move(assignment); } +string DeviceAssignment::ToString() const { + string output = StrCat("Computations: ", computation_count(), + " Replicas: ", replica_count(), "\n"); + for (int computation = 0; computation < computation_count(); ++computation) { + StrAppend(&output, "Computation ", computation, ": "); + for (int replica = 0; replica < replica_count(); ++replica) { + StrAppend(&output, operator()(replica, computation), " "); + } + StrAppend(&output, "\n"); + } + return output; +} + StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation, int replica_count, int computation_count) { |