aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/where_op.cc
diff options
context:
space:
mode:
author: Eugene Brevdo <ebrevdo@google.com> 2017-07-06 09:49:25 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-07-06 09:53:12 -0700
commit: 9cf44446550b6d2c3141074013509875649b0fd5 (patch)
tree: 86adf23dcb8eb2ad73b109c4d5178670f1805e78 /tensorflow/core/kernels/where_op.cc
parent: a75496f74cad8c6bd25ee4c6c17d3f52199b2fe8 (diff)
Bugfixes for GPU WhereOp.
1. Set the CUDA context properly within ComputeAsync, and also within the WhereOp GPU callback.
2. Ensure report_uninitialized_variables runs on CPU. This avoids an intermediate copy of the data to the GPU after reading the variables' state and before returning it.
PiperOrigin-RevId: 161092040
Diffstat (limited to 'tensorflow/core/kernels/where_op.cc')
-rw-r--r-- tensorflow/core/kernels/where_op.cc 7
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/where_op.cc b/tensorflow/core/kernels/where_op.cc
index 6fdcb331cc..59b474e41c 100644
--- a/tensorflow/core/kernels/where_op.cc
+++ b/tensorflow/core/kernels/where_op.cc
@@ -40,6 +40,9 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
+#include "tensorflow/core/platform/cuda.h"
+
+using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
#endif // GOOGLE_CUDA
namespace tensorflow {
@@ -257,6 +260,10 @@ class WhereGPUOp : public AsyncOpKernel {
auto create_and_check_output = [context, &d, &input, input_dims,
num_true_host, done]() {
+ // Ensure that within the callback, the proper GPU settings are
+ // configured.
+ auto stream = context->op_device_context()->stream();
+ ScopedActivateExecutorContext scoped_activation{stream->parent()};
Tindex num_true = *num_true_host.data();