Diffstat (limited to 'tensorflow/core/kernels/functional_ops.cc')
-rw-r--r--  tensorflow/core/kernels/functional_ops.cc  189
1 file changed, 168 insertions(+), 21 deletions(-)
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc
index b687088db1..911aa3a78f 100644
--- a/tensorflow/core/kernels/functional_ops.cc
+++ b/tensorflow/core/kernels/functional_ops.cc
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -21,10 +20,12 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/platform/mutex.h"
-namespace tensorflow {
+#if GOOGLE_CUDA
+#include "tensorflow/stream_executor/stream.h"
+#endif // GOOGLE_CUDA
+namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef FunctionLibraryRuntime::Handle FHandle;
@@ -106,11 +107,9 @@ void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
opts->runner = ctx->runner();
}
-} // end namespace
-
-class FunctionalIf : public AsyncOpKernel {
+class IfOp : public AsyncOpKernel {
public:
- explicit FunctionalIf(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+ explicit IfOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
auto lib = ctx->function_library();
OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
const NameAttrList* func;
@@ -120,7 +119,7 @@ class FunctionalIf : public AsyncOpKernel {
OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &else_handle_));
}
- ~FunctionalIf() override {}
+ ~IfOp() override {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
bool cond;
@@ -134,8 +133,7 @@ class FunctionalIf : public AsyncOpKernel {
class State {
public:
- State(FunctionalIf* kernel, OpKernelContext* ctx, bool cond,
- DoneCallback done)
+ State(IfOp* kernel, OpKernelContext* ctx, bool cond, DoneCallback done)
: kernel_(kernel),
ctx_(ctx),
cond_(cond),
@@ -168,7 +166,7 @@ class FunctionalIf : public AsyncOpKernel {
}
private:
- FunctionalIf* const kernel_;
+ IfOp* const kernel_;
OpKernelContext* const ctx_;
const bool cond_;
const DoneCallback done_;
@@ -179,18 +177,22 @@ class FunctionalIf : public AsyncOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), FunctionalIf);
+// TODO(drpng): remove this.
+REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"),
- FunctionalIf);
+ IfOp);
+
+REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_CPU), IfOp);
+REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_GPU).HostMemory("cond"), IfOp);
-class FunctionalWhile : public AsyncOpKernel {
+class WhileOp : public AsyncOpKernel {
public:
- explicit FunctionalWhile(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+ explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &cond_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &body_func_));
}
- ~FunctionalWhile() override {}
+ ~WhileOp() override {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
auto lib = ctx->function_library();
@@ -234,7 +236,7 @@ class FunctionalWhile : public AsyncOpKernel {
class State {
public:
- State(FunctionalWhile* kernel, OpKernelContext* ctx, FHandle cond_handle,
+ State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
FHandle body_handle, DoneCallback done)
: kernel_(kernel),
ctx_(ctx),
@@ -253,7 +255,7 @@ class FunctionalWhile : public AsyncOpKernel {
void Start() { EvalCond(); }
private:
- FunctionalWhile* const kernel_;
+ WhileOp* const kernel_;
OpKernelContext* const ctx_;
const FHandle cond_handle_;
const FHandle body_handle_;
@@ -316,7 +318,152 @@ class FunctionalWhile : public AsyncOpKernel {
}
};
};
-REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), FunctionalWhile);
-REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), FunctionalWhile);
+// TODO(drpng): remove these.
+REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), WhileOp);
+REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), WhileOp);
+
+REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_CPU), WhileOp);
+REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_GPU), WhileOp);
+
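+// Reads the scalar int32 input at `index` into *value, returning an
+// InvalidArgument error (tagged with `label`) if that input is not a scalar.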
+Status GetScalar(OpKernelContext* ctx, int index, int32* value,
+ const char* label) {
+ Tensor t = ctx->input(index);
+ if (!TensorShapeUtils::IsScalar(t.shape())) {
+ return errors::InvalidArgument(label, " must be a scalar, but ",
+ t.shape().DebugString());
+ }
+ *value = t.scalar<int32>()();
+ return Status::OK();
+}
+
+class ForOp : public AsyncOpKernel {
+ public:
+ explicit ForOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+ auto lib = ctx->function_library();
+ OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
+ const NameAttrList* func;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &func));
+ OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &body_handle_));
+ }
+
+ ~ForOp() override {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ (new State(this, ctx, done))->Start();
+ }
+
+ private:
+ FHandle body_handle_;
+
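+  // Drives one asynchronous run of the loop. args_[0] holds the current
+  // iteration counter; args_[1..] carry the loop variables, which are
+  // threaded through the body function on every iteration.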
+ class State {
+ public:
+ State(ForOp* kernel, OpKernelContext* ctx, DoneCallback done)
+ : kernel_(kernel),
+ ctx_(ctx),
+ done_(std::move(done)),
+ lib_(CHECK_NOTNULL(ctx_->function_library())),
+ args_(1 + ctx_->num_inputs() - 3) {
+ args_[0] = Tensor(DT_INT32, {});
+ iter_ = &args_[0].scalar<int32>()();
+
+ const int32 num_loop_inputs = ctx_->num_inputs() - 3;
+ rets_.reserve(num_loop_inputs);
+ for (int i = 0; i < num_loop_inputs; ++i) {
+ rets_.push_back(ctx_->input(3 + i));
+ }
+ }
+
+ ~State() {}
+
+ void Start() {
+ Status s = StartLoop();
+ if (!s.ok()) Finish(s);
+ }
+
+ private:
+ ForOp* const kernel_;
+ OpKernelContext* const ctx_;
+ const DoneCallback done_;
+ FunctionLibraryRuntime* const lib_;
+ FunctionLibraryRuntime::Options opts_;
+ TensorVec args_;
+ TensorVec rets_;
+
+ int32* iter_; // points to args_[0].
+ int32 limit_;
+ int32 delta_;
+
+ // If an error e is returned, caller must call Finish(e).
+ // If OK is returned, the async loop execution has been started.
+ Status StartLoop() {
+ SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);
+
+ TF_RETURN_IF_ERROR(GetScalar(ctx_, 0, iter_, "start"));
+ TF_RETURN_IF_ERROR(GetScalar(ctx_, 1, &limit_, "limit"));
+ TF_RETURN_IF_ERROR(GetScalar(ctx_, 2, &delta_, "delta"));
+
+ if ((delta_ > 0 && *iter_ <= limit_) ||
+ (delta_ < 0 && *iter_ >= limit_) ||
+ (delta_ == 0 && *iter_ == limit_)) {
+ RunNext();
+ return Status::OK();
+ } else {
+ return errors::InvalidArgument("Invalid start/limit/delta: ", *iter_,
+ " ", limit_, " ", delta_);
+ }
+ }
+
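+    // Runs one iteration of the body, or finishes once the counter has
+    // passed limit_ in the direction given by delta_.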
+ void RunNext() {
+ bool done_loop;
+ if (delta_ > 0) {
+ done_loop = *iter_ >= limit_;
+ } else {
+ done_loop = *iter_ <= limit_;
+ }
+ if (done_loop) {
+ Finish(Status::OK());
+ return;
+ }
+
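+      // The body must return one value per loop variable (args_.size() - 1);
+      // reject a body that returns more than that.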
+ if (rets_.size() >= args_.size()) {
+ Finish(errors::InvalidArgument(
+ "For loop body returned ", rets_.size(),
+ " arguments. Expected: ", args_.size() - 1));
+ return;
+ }
+ for (int i = 0; i < rets_.size(); ++i) {
+ args_[1 + i] = std::move(rets_[i]);
+ }
+ rets_.clear();
+ lib_->Run(opts_, kernel_->body_handle_, args_, &rets_,
+ [this](const Status& s) {
+ if (s.ok()) {
+ *iter_ += delta_;
+ RunNext();
+ } else {
+ Finish(s);
+ }
+ });
+ }
+
+ void Finish(Status s) {
+ if (s.ok()) {
+ s = SetOutputs(kernel_, ctx_, rets_);
+ }
+ ctx_->SetStatus(s);
+ done_();
+ delete this;
+ }
+ };
+};
+
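+// start/limit/delta are read on the host by GetScalar, so they are pinned
+// to host memory when the kernel runs on GPU.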
+REGISTER_KERNEL_BUILDER(Name("For").Device(DEVICE_CPU), ForOp);
+REGISTER_KERNEL_BUILDER(Name("For")
+ .Device(DEVICE_GPU)
+ .HostMemory("start")
+ .HostMemory("limit")
+ .HostMemory("delta"),
+ ForOp);
+} // namespace
} // namespace tensorflow
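
For reference, the loop semantics ForOp implements can be summarized by this
minimal standalone C++ sketch (the Body function and variable names here are
illustrative stand-ins, not part of the patch):

    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-in for one call of the instantiated body function:
    // takes the iteration counter and one loop-carried value, returns the
    // updated value.
    static int32_t Body(int32_t iter, int32_t acc) { return acc + iter; }

    int main() {
      const int32_t start = 0, limit = 10, delta = 2;
      int32_t acc = 0;  // loop-carried value, like args_[1]

      // StartLoop's validity check: delta == 0 is only legal when
      // start == limit (a zero-trip loop).
      if (!((delta > 0 && start <= limit) || (delta < 0 && start >= limit) ||
            (delta == 0 && start == limit))) {
        std::fprintf(stderr, "Invalid start/limit/delta\n");
        return 1;
      }

      // RunNext's termination test, applied before each body call.
      for (int32_t iter = start; delta > 0 ? iter < limit : iter > limit;
           iter += delta) {
        acc = Body(iter, acc);
      }
      std::printf("acc = %d\n", acc);  // 0 + 2 + 4 + 6 + 8 = 20
      return 0;
    }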