aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/xla_launch_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/xla_launch_util.cc')
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.cc62
1 files changed, 35 insertions, 27 deletions
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 6134b8c694..4efbb2d5d7 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_launch_util.h"
+#include <memory>
+
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -182,7 +184,7 @@ void XlaComputationLaunchContext::PopulateInputs(
}
}
-void XlaComputationLaunchContext::PopulateOutputs(
+Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
ScopedShapedBuffer output) {
se::Stream* stream =
@@ -211,6 +213,15 @@ void XlaComputationLaunchContext::PopulateOutputs(
output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
}
+ std::shared_ptr<se::Event> definition_event;
+ if (use_multiple_streams_) {
+ definition_event = std::make_shared<se::Event>(stream->parent());
+ if (!definition_event->Init()) {
+ return errors::Internal("Failed to initialize tensor definition event.");
+ }
+ stream->ThenRecordEvent(definition_event.get());
+ }
+
// Copy XLA results to the OpOutputList.
int output_num = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
@@ -228,12 +239,13 @@ void XlaComputationLaunchContext::PopulateOutputs(
// reallocate the device buffer later.
VLOG(1) << "Constant output tensor on device";
- OP_REQUIRES_OK(
- ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
+ TF_RETURN_IF_ERROR(
+ ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
Device* device = dynamic_cast<Device*>(ctx->device());
- OP_REQUIRES(ctx, device != nullptr,
- errors::Internal("DeviceBase was not a Device."));
+ if (device == nullptr) {
+ return errors::Internal("DeviceBase was not a Device.");
+ }
ctx->op_device_context()->CopyCPUTensorToDevice(
&const_tensor, device, output_tensor,
[&](Status status) { TF_CHECK_OK(status); });
@@ -263,16 +275,13 @@ void XlaComputationLaunchContext::PopulateOutputs(
se::DeviceMemoryBase buffer = output.buffer({output_num});
if (allocate_xla_tensors_) {
Tensor* output_tensor;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor));
+ TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
if (xla_tensor) {
xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
if (use_multiple_streams_) {
- se::Event event(stream->parent());
- CHECK(event.Init());
- stream->ThenRecordEvent(&event);
- xla_tensor->SetDefinedOn(stream, std::move(event));
+ xla_tensor->SetDefinedOn(stream, definition_event);
}
} else {
// xla_tensor wasn't valid, which must mean this is a zero-element
@@ -298,41 +307,39 @@ void XlaComputationLaunchContext::PopulateOutputs(
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({});
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
- OP_REQUIRES(ctx,
- write.input_index >= 0 && write.input_index < ctx->num_inputs(),
- errors::Internal("Invalid input index for variable write."));
+ if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) {
+ return errors::Internal("Invalid input index for variable write.");
+ }
se::DeviceMemoryBase buffer = output.buffer({output_num});
Var* variable = nullptr;
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
- OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
- ctx, HandleFromInput(ctx, write.input_index),
- &variable, [this, ctx, &write](Var** ptr) {
- *ptr = new Var(write.type);
- return Status::OK();
- }));
+ TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
+ ctx, HandleFromInput(ctx, write.input_index), &variable,
+ [&write](Var** ptr) {
+ *ptr = new Var(write.type);
+ return Status::OK();
+ }));
core::ScopedUnref s(variable);
mutex_lock ml(*variable->mu());
- OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type,
- errors::Internal("Mismatched type in variable write"));
+ if (variable->tensor()->dtype() != write.type) {
+ return errors::Internal("Mismatched type in variable write");
+ }
if (allocate_xla_tensors_) {
Tensor output_tensor;
- OP_REQUIRES_OK(
- ctx, ctx->allocate_temp(write.type, write.shape, &output_tensor));
+ TF_RETURN_IF_ERROR(
+ ctx->allocate_temp(write.type, write.shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
CHECK(xla_tensor);
xla_tensor->set_shaped_buffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
if (use_multiple_streams_) {
- se::Event event(stream->parent());
- CHECK(event.Init());
- stream->ThenRecordEvent(&event);
- xla_tensor->SetDefinedOn(stream, std::move(event));
+ xla_tensor->SetDefinedOn(stream, definition_event);
}
*variable->tensor() = output_tensor;
} else {
@@ -343,6 +350,7 @@ void XlaComputationLaunchContext::PopulateOutputs(
}
++output_num;
}
+ return Status::OK();
}
} // namespace tensorflow