diff options
Diffstat (limited to 'tensorflow/core/kernels/functional_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/functional_ops.cc | 189 |
1 files changed, 168 insertions, 21 deletions
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc index b687088db1..911aa3a78f 100644 --- a/tensorflow/core/kernels/functional_ops.cc +++ b/tensorflow/core/kernels/functional_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -21,10 +20,12 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/platform/mutex.h" -namespace tensorflow { +#if GOOGLE_CUDA +#include "tensorflow/stream_executor/stream.h" +#endif // GOOGLE_CUDA +namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; typedef Eigen::ThreadPoolDevice CPUDevice; typedef FunctionLibraryRuntime::Handle FHandle; @@ -106,11 +107,9 @@ void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts, opts->runner = ctx->runner(); } -} // end namespace - -class FunctionalIf : public AsyncOpKernel { +class IfOp : public AsyncOpKernel { public: - explicit FunctionalIf(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { + explicit IfOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { auto lib = ctx->function_library(); OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library")); const NameAttrList* func; @@ -120,7 +119,7 @@ class FunctionalIf : public AsyncOpKernel { OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &else_handle_)); } - ~FunctionalIf() override {} + ~IfOp() override {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { bool cond; @@ -134,8 +133,7 @@ class FunctionalIf : public AsyncOpKernel { class State { public: - State(FunctionalIf* kernel, OpKernelContext* ctx, bool cond, - DoneCallback done) + State(IfOp* kernel, OpKernelContext* ctx, bool cond, DoneCallback done) : kernel_(kernel), ctx_(ctx), cond_(cond), @@ -168,7 +166,7 @@ class FunctionalIf : public AsyncOpKernel { } private: - FunctionalIf* const kernel_; + IfOp* const kernel_; OpKernelContext* const ctx_; const bool cond_; const DoneCallback done_; @@ -179,18 +177,22 @@ class FunctionalIf : public AsyncOpKernel { }; }; -REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), FunctionalIf); +// TODO(drpng): remove this. +REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), IfOp); REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"), - FunctionalIf); + IfOp); + +REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_CPU), IfOp); +REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_GPU).HostMemory("cond"), IfOp); -class FunctionalWhile : public AsyncOpKernel { +class WhileOp : public AsyncOpKernel { public: - explicit FunctionalWhile(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { + explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &cond_func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &body_func_)); } - ~FunctionalWhile() override {} + ~WhileOp() override {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { auto lib = ctx->function_library(); @@ -234,7 +236,7 @@ class FunctionalWhile : public AsyncOpKernel { class State { public: - State(FunctionalWhile* kernel, OpKernelContext* ctx, FHandle cond_handle, + State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle, FHandle body_handle, DoneCallback done) : kernel_(kernel), ctx_(ctx), @@ -253,7 +255,7 @@ class FunctionalWhile : public AsyncOpKernel { void Start() { EvalCond(); } private: - FunctionalWhile* const kernel_; + WhileOp* const kernel_; OpKernelContext* const ctx_; const FHandle cond_handle_; const FHandle body_handle_; @@ -316,7 +318,152 @@ class FunctionalWhile : public AsyncOpKernel { } }; }; -REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), FunctionalWhile); -REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), FunctionalWhile); +// TODO(drpng): remove these. +REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), WhileOp); +REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), WhileOp); + +REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_CPU), WhileOp); +REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_GPU), WhileOp); + +Status GetScalar(OpKernelContext* ctx, int index, int32* value, + const char* label) { + Tensor t = ctx->input(index); + if (!TensorShapeUtils::IsScalar(t.shape())) { + return errors::InvalidArgument(label, " must be a scalar, but ", + t.shape().DebugString()); + } + *value = t.scalar<int32>()(); + return Status::OK(); +} + +class ForOp : public AsyncOpKernel { + public: + explicit ForOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { + auto lib = ctx->function_library(); + OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library")); + const NameAttrList* func; + OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &func)); + OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &body_handle_)); + } + + ~ForOp() override {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + (new State(this, ctx, done))->Start(); + } + + private: + FHandle body_handle_; + + class State { + public: + State(ForOp* kernel, OpKernelContext* ctx, DoneCallback done) + : kernel_(kernel), + ctx_(ctx), + done_(std::move(done)), + lib_(CHECK_NOTNULL(ctx_->function_library())), + args_(1 + ctx_->num_inputs() - 3) { + args_[0] = Tensor(DT_INT32, {}); + iter_ = &args_[0].scalar<int32>()(); + + const int32 num_loop_inputs = ctx_->num_inputs() - 3; + rets_.reserve(num_loop_inputs); + for (int i = 0; i < num_loop_inputs; ++i) { + rets_.push_back(ctx_->input(3 + i)); + } + } + + ~State() {} + + void Start() { + Status s = StartLoop(); + if (!s.ok()) Finish(s); + } + + private: + ForOp* const kernel_; + OpKernelContext* const ctx_; + const DoneCallback done_; + FunctionLibraryRuntime* const lib_; + FunctionLibraryRuntime::Options opts_; + TensorVec args_; + TensorVec rets_; + + int32* iter_; // points to args_[0]. + int32 limit_; + int32 delta_; + + // If an error e is returned, caller must call Finish(e). + // If OK is returned, the async loop execution has been started. + Status StartLoop() { + SetRunOptions(ctx_, &opts_, false /* always_collect_stats */); + + TF_RETURN_IF_ERROR(GetScalar(ctx_, 0, iter_, "start")); + TF_RETURN_IF_ERROR(GetScalar(ctx_, 1, &limit_, "limit")); + TF_RETURN_IF_ERROR(GetScalar(ctx_, 2, &delta_, "delta")); + + if ((delta_ > 0 && *iter_ <= limit_) || + (delta_ < 0 && *iter_ >= limit_) || + (delta_ == 0 && *iter_ == limit_)) { + RunNext(); + return Status::OK(); + } else { + return errors::InvalidArgument("Invalid start/limit/delta: ", *iter_, + " ", limit_, " ", delta_); + } + } + + void RunNext() { + bool done_loop; + if (delta_ > 0) { + done_loop = *iter_ >= limit_; + } else { + done_loop = *iter_ <= limit_; + } + if (done_loop) { + Finish(Status::OK()); + return; + } + + if (rets_.size() >= args_.size()) { + Finish(errors::InvalidArgument( + "For loop body returned ", rets_.size(), + " arguments. Expected: ", args_.size() - 1)); + return; + } + for (int i = 0; i < rets_.size(); ++i) { + args_[1 + i] = std::move(rets_[i]); + } + rets_.clear(); + lib_->Run(opts_, kernel_->body_handle_, args_, &rets_, + [this](const Status& s) { + if (s.ok()) { + *iter_ += delta_; + RunNext(); + } else { + Finish(s); + } + }); + } + + void Finish(Status s) { + if (s.ok()) { + s = SetOutputs(kernel_, ctx_, rets_); + } + ctx_->SetStatus(s); + done_(); + delete this; + } + }; +}; + +REGISTER_KERNEL_BUILDER(Name("For").Device(DEVICE_CPU), ForOp); +REGISTER_KERNEL_BUILDER(Name("For") + .Device(DEVICE_GPU) + .HostMemory("start") + .HostMemory("limit") + .HostMemory("delta"), + ForOp); +} // namespace } // namespace tensorflow |