aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_runner.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-10-24 20:05:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-24 20:09:09 -0700
commit557b0b27edff763c165ad59d10d49da8bccbec4f (patch)
tree98836179520fb866dcdad16964594dc4fb6646fd /tensorflow/compiler/xla/service/hlo_runner.cc
parent16953025097793d9748099ebf4296edca04a5366 (diff)
Make HloRunner methods return StatusOr. Also move templated method definition
of Execute into the header file. PiperOrigin-RevId: 173348703
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_runner.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.cc40
1 files changed, 9 insertions, 31 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index d5d7042a02..9fdda38d2d 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -29,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
@@ -133,7 +132,8 @@ StatusOr<se::DeviceMemoryBase> HloRunner::Execute(
return result;
}
-se::DeviceMemoryBase HloRunner::TransferToDevice(const Literal& literal) {
+StatusOr<se::DeviceMemoryBase> HloRunner::TransferToDevice(
+ const Literal& literal) {
// Allocate memory on the device using the stream executor.
int64 allocation_size =
backend().transfer_manager()->GetByteSizeRequirement(literal.shape());
@@ -142,52 +142,30 @@ se::DeviceMemoryBase HloRunner::TransferToDevice(const Literal& literal) {
allocation_size);
allocations_.push_back(allocation);
- TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice(
+ TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
backend().default_stream_executor(), literal, &allocation));
return allocation;
}
-std::unique_ptr<Literal> HloRunner::TransferFromDevice(
+StatusOr<std::unique_ptr<Literal>> HloRunner::TransferFromDevice(
const Shape& shape, se::DeviceMemoryBase device_base) {
auto literal = MakeUnique<Literal>();
- TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice(
+ TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromDevice(
backend().default_stream_executor(), device_base, shape, shape,
literal.get()));
- return literal;
+ return std::move(literal);
}
-std::unique_ptr<Literal> HloRunner::ExecuteAndTransfer(
+StatusOr<std::unique_ptr<Literal>> HloRunner::ExecuteAndTransfer(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
Shape result_shape;
- se::DeviceMemoryBase device_base =
- Execute(std::move(module), arguments, &result_shape).ValueOrDie();
+ TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase device_base,
+ Execute(std::move(module), arguments, &result_shape));
return TransferFromDevice(result_shape, device_base);
}
-template <>
-std::unique_ptr<Literal> HloRunner::Execute(
- std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>>& literals) {
- std::vector<se::DeviceMemoryBase> arguments;
- for (const auto& literal : literals) {
- arguments.push_back(TransferToDevice(*literal));
- }
- return ExecuteAndTransfer(std::move(module), arguments);
-}
-
-template <>
-std::unique_ptr<Literal> HloRunner::Execute(
- std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<Literal*>& literals) {
- std::vector<se::DeviceMemoryBase> arguments;
- for (const auto& literal : literals) {
- arguments.push_back(TransferToDevice(*literal));
- }
- return ExecuteAndTransfer(std::move(module), arguments);
-}
-
Backend& HloRunner::backend() {
if (!backend_) {
backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();