aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/function.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/function.cc')
-rw-r--r--tensorflow/core/common_runtime/function.cc1335
1 files changed, 1335 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
new file mode 100644
index 0000000000..2b1a041235
--- /dev/null
+++ b/tensorflow/core/common_runtime/function.cc
@@ -0,0 +1,1335 @@
+#include "tensorflow/core/common_runtime/function.h"
+
+#include <deque>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/optimizer_cse.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+
+// A few string constant used throughout this module.
+static const char* const kArgOp = "_Arg";
+static const char* const kRetOp = "_Retval";
+static const char* const kGradientOp = "SymbolicGradient";
+static const char* const kNodeLabel = "Func";
+
+// Represents the index-th output of a node.
+struct Endpoint {
+ Node* node;
+ int index;
+
+ // Returns the string name represents this endpoint.
+ string name() const {
+ if (index == 0) {
+ return node->name();
+ } else {
+ return strings::StrCat(node->name(), ":", index);
+ }
+ }
+
+ DataType dtype() const { return node->output_type(index); }
+};
+
+struct EndpointHash {
+ uint64 operator()(const Endpoint& x) const {
+ return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
+ x.index);
+ }
+};
+
+struct EndpointEq {
+ bool operator()(const Endpoint& x, const Endpoint& y) const {
+ return (x.node == y.node) && (x.index == y.index);
+ }
+};
+
+// The following Add* routines are used to add a few graph nodes while
+// functions are transformed.
+static Node* AddNoOp(Graph* g) {
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("NoOp");
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ return ret;
+}
+
+static Node* AddIdentity(Graph* g, Endpoint input) {
+ DCHECK_LT(0, input.dtype());
+ DCHECK_LT(input.dtype(), DT_FLOAT_REF);
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("Identity");
+ ndef.add_input(input.name());
+ AddNodeAttr("T", input.dtype(), &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(input.node, input.index, ret, 0);
+ return ret;
+}
+
+static Node* AddArg(Graph* g, DataType dtype, int index) {
+ DCHECK_LT(0, dtype);
+ DCHECK_LT(dtype, DT_FLOAT_REF);
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op(kArgOp);
+ AddNodeAttr("T", dtype, &ndef);
+ AddNodeAttr("index", index, &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ return ret;
+}
+
+static Node* AddRet(Graph* g, Endpoint input, int index) {
+ DCHECK_LT(0, input.dtype());
+ DCHECK_LT(input.dtype(), DT_FLOAT_REF);
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op(kRetOp);
+ ndef.add_input(input.name());
+ AddNodeAttr("T", input.dtype(), &ndef);
+ AddNodeAttr("index", index, &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(input.node, input.index, ret, 0);
+ return ret;
+}
+
+static Node* AddZerosLike(Graph* g, Endpoint input) {
+ DCHECK_LT(0, input.dtype());
+ DCHECK_LT(input.dtype(), DT_FLOAT_REF);
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("ZerosLike");
+ ndef.add_input(input.name());
+ AddNodeAttr("T", input.dtype(), &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(input.node, input.index, ret, 0);
+ return ret;
+}
+
+static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<Endpoint> grads) {
+ const int num_x = n->num_inputs();
+ const int num_y = n->num_outputs();
+ CHECK_EQ(num_y, grads.size());
+
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op(kGradientOp);
+
+ // The gradient node should have num_x + num_y inputs.
+ std::vector<Endpoint> n_inputs(num_x);
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ n_inputs[e->dst_input()] = {e->src(), e->src_output()};
+ }
+ DataTypeVector in_types;
+ for (const Endpoint& ep : n_inputs) {
+ ndef.add_input(ep.name());
+ in_types.push_back(ep.dtype());
+ }
+ for (const Endpoint& ep : grads) {
+ ndef.add_input(ep.name());
+ in_types.push_back(ep.dtype());
+ }
+ CHECK_EQ(ndef.input_size(), num_x + num_y);
+
+ AddNodeAttr("Tin", in_types, &ndef);
+
+ // The gradient node's outputs have the same types as the node 'n's
+ // inputs.
+ AddNodeAttr("Tout", n->input_types(), &ndef);
+ NameAttrList func;
+ func.set_name(n->type_string());
+ *(func.mutable_attr()) = n->def().attr();
+ AddNodeAttr("f", func, &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ return ret;
+}
+
+class ArgOp : public OpKernel {
+ public:
+ explicit ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ auto frame = ctx->call_frame();
+ OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
+ Tensor val;
+ OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
+ OP_REQUIRES(ctx, val.dtype() == dtype_,
+ errors::InvalidArgument(
+ "Type mismatch: actual ", DataTypeString(val.dtype()),
+ " vs. expect ", DataTypeString(dtype_)));
+ ctx->set_output(0, val);
+ }
+
+ private:
+ int index_;
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp);
+REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_GPU), ArgOp);
+
+class RetvalOp : public OpKernel {
+ public:
+ explicit RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& val = ctx->input(0);
+ OP_REQUIRES(ctx, val.dtype() == dtype_,
+ errors::InvalidArgument(
+ "Type mismatch: actual ", DataTypeString(val.dtype()),
+ " vs. expect ", DataTypeString(dtype_)));
+ auto frame = ctx->call_frame();
+ OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
+ OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val));
+ }
+
+ private:
+ int index_;
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);
+REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_GPU), RetvalOp);
+
+static const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
+
+class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
+ public:
+ FunctionLibraryRuntimeImpl(Device* device, Runner runner,
+ const FunctionLibraryDefinition* lib_def);
+
+ ~FunctionLibraryRuntimeImpl() override;
+
+ Status Instantiate(const string& function_name,
+ const InstantiateAttrValueMap& attrs,
+ Handle* handle) override;
+
+ const FunctionBody* GetFunctionBody(Handle handle) override;
+
+ Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;
+
+ void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
+ std::vector<Tensor>* rets, DoneCallback done) override;
+
+ bool IsDefined(const string& function_name) override;
+
+ private:
+ typedef FunctionLibraryRuntimeImpl ME;
+
+ Device* const device_;
+ Runner runner_ = nullptr;
+ const FunctionLibraryDefinition* const lib_def_;
+ std::function<Status(const string&, const OpDef**)> get_func_sig_;
+ std::function<Status(const NodeDef&, OpKernel**)> create_kernel_;
+
+ mutable mutex mu_;
+
+ // Maps function instantiation to a handle. The key is a
+ // canonicalized representation of the function name and
+ // instantiation attrs. The handle is an index into the items_.
+ std::unordered_map<string, Handle> table_ GUARDED_BY(mu_);
+
+ // func_graphs_ never shrinks or reorders its members.
+ std::vector<FunctionBody*> func_graphs_ GUARDED_BY(mu_);
+
+ // The instantiated and transformed function is encoded as a Graph
+ // object, and an executor is created for the graph.
+ struct Item : public core::RefCounted {
+ Executor* exec = nullptr;
+
+ ~Item() override { delete this->exec; }
+ };
+ std::vector<Item*> items_;
+
+ Status FunctionDefToBody(const FunctionDef& fdef,
+ const InstantiateAttrValueMap& attrs,
+ FunctionBody** fbody);
+ Status CreateItem(Handle handle, Item** item);
+ Status GetOrCreateItem(Handle handle, Item** item);
+ Status InstantiateSymbolicGradient(const InstantiateAttrValueMap& attrs,
+ FunctionBody** g_body);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
+};
+
+FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
+ Device* device, Runner runner, const FunctionLibraryDefinition* lib_def)
+ : device_(device), runner_(runner), lib_def_(lib_def) {
+ get_func_sig_ = [this](const string& op, const OpDef** sig) {
+ Status s;
+ *sig = lib_def_->LookUp(op, &s);
+ return s;
+ };
+ create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
+ return CreateKernel(ndef, kernel);
+ };
+}
+
+FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
+ for (FunctionBody* p : func_graphs_) delete p;
+ for (Item* item : items_)
+ if (item) item->Unref();
+}
+
+// An asynchronous op kernel which executes an instantiated function
+// defined in a library.
+class CallOp : public AsyncOpKernel {
+ public:
+ CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx), handle_(handle) {}
+
+ ~CallOp() override {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ FunctionLibraryRuntime* lib = ctx->function_library();
+ OP_REQUIRES_ASYNC(ctx, lib != nullptr,
+ errors::Internal("No function library is provided."),
+ done);
+ FunctionLibraryRuntime::Options opts;
+ std::vector<Tensor> args;
+ args.reserve(ctx->num_inputs());
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ args.push_back(ctx->input(i));
+ }
+ std::vector<Tensor>* rets = new std::vector<Tensor>;
+ lib->Run(opts, handle_, args, rets,
+ [ctx, done, rets](const Status& status) {
+ if (!status.ok()) {
+ ctx->SetStatus(status);
+ } else {
+ CHECK_EQ(rets->size(), ctx->num_outputs());
+ for (size_t i = 0; i < rets->size(); ++i) {
+ ctx->set_output(i, (*rets)[i]);
+ }
+ }
+ delete rets;
+ done();
+ });
+ }
+
+ private:
+ FunctionLibraryRuntime::Handle handle_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
+};
+
+const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
+ mutex_lock l(mu_);
+ CHECK_LE(0, h);
+ CHECK_LT(h, func_graphs_.size());
+ return func_graphs_[h];
+}
+
+Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
+ OpKernel** kernel) {
+ if (ndef.op() != kGradientOp && (lib_def_->Find(ndef.op()) == nullptr)) {
+ return CreateNonCachedKernel(device_, this, ndef, kernel);
+ }
+
+ // Try to instantiate this function for the func/attr. Maybe its
+ // cached already.
+ Handle handle;
+ TF_RETURN_IF_ERROR(Instantiate(ndef.op(), ndef.attr(), &handle));
+
+ const FunctionBody* fbody = GetFunctionBody(handle);
+ CHECK_NOTNULL(fbody);
+
+ // Constructs a CallOp kernel for running the instantiated function.
+ Status s;
+ auto device_type = DeviceType(device_->attributes().device_type());
+ OpKernelConstruction construction(
+ device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
+ &fbody->fdef.signature(), this, fbody->arg_types, fbody->ret_types, &s);
+ *kernel = new CallOp(handle, &construction);
+ if (!s.ok()) {
+ delete kernel;
+ }
+ return s;
+}
+
+Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
+ const FunctionDef& fdef, const InstantiateAttrValueMap& attrs,
+ FunctionBody** fbody) {
+ // Instantiates the function template into a graph def.
+ InstantiationResult result;
+ TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig_, &result));
+
+ Graph* graph = new Graph(lib_def_);
+ GraphConstructorOptions opts;
+ opts.allow_internal_ops = true;
+ opts.expect_device_spec = false;
+ Status s = ConvertGraphDefToGraph(opts, result.gdef, graph);
+ if (!s.ok()) {
+ delete graph;
+ } else {
+ *fbody = new FunctionBody(fdef, result.arg_types, result.ret_types, graph);
+ }
+ return s;
+}
+
+Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
+ const InstantiateAttrValueMap& attrs, FunctionBody** g_body) {
+ const AttrValue* f = gtl::FindOrNull(attrs, "f");
+ if (f == nullptr) {
+ return errors::InvalidArgument("SymbolicGradient is missing attr: f");
+ }
+ const auto& func = f->func();
+ const FunctionDef* fdef = lib_def_->Find(func.name());
+ if (fdef == nullptr) {
+ // f is a primitve op.
+ gradient::Creator creator;
+ TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
+ if (creator == nullptr) {
+ return errors::InvalidArgument("No gradient is defined for ",
+ func.name());
+ }
+ FunctionDef grad_fdef;
+ TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
+ TF_RETURN_IF_ERROR(FunctionDefToBody(grad_fdef, func.attr(), g_body));
+ } else {
+ // f is a user-defined function.
+ Handle f_handle;
+ TF_RETURN_IF_ERROR(Instantiate(func.name(), func.attr(), &f_handle));
+ const FunctionBody* f_body = GetFunctionBody(f_handle);
+ CHECK_NOTNULL(f_body);
+ *g_body = SymbolicGradient(*f_body);
+ }
+ return Status::OK();
+}
+
+Status FunctionLibraryRuntimeImpl::Instantiate(
+ const string& function_name, const InstantiateAttrValueMap& attrs,
+ Handle* handle) {
+ const string key = Canonicalize(function_name, attrs);
+ {
+ mutex_lock l(mu_);
+ *handle = gtl::FindWithDefault(table_, key, kInvalidHandle);
+ if (*handle != kInvalidHandle) {
+ return Status::OK();
+ }
+ }
+
+ Status s;
+ FunctionBody* fbody = nullptr;
+ if (function_name == kGradientOp) {
+ TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(attrs, &fbody));
+ } else {
+ const FunctionDef* fdef = lib_def_->Find(function_name);
+ if (fdef == nullptr) {
+ return errors::NotFound("Function ", function_name, " is not defined.");
+ }
+ TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, &fbody));
+ }
+
+ {
+ mutex_lock l(mu_);
+ *handle = gtl::FindWithDefault(table_, key, kInvalidHandle);
+ if (*handle != kInvalidHandle) {
+ delete fbody;
+ } else {
+ *handle = func_graphs_.size();
+ table_.insert({key, *handle});
+ func_graphs_.push_back(fbody);
+ items_.resize(func_graphs_.size());
+ }
+ }
+ return Status::OK();
+}
+
+static void DumpGraph(const char* label, const Graph* g) {
+ if (VLOG_IS_ON(1)) {
+ LOG(INFO) << label << ": " << std::endl << DebugString(g);
+ }
+}
+
+static void SimplifyGraph(Graph* g) {
+ if (RemoveListArrayConverter(g)) {
+ DumpGraph("RemoveListArrayConverter", g);
+ }
+ bool changed;
+ do {
+ changed = false;
+ if (RemoveDeadNodes(g)) {
+ changed = true;
+ DumpGraph("RemoveDeadNodes", g);
+ }
+ if (RemoveIdentityNodes(g)) {
+ changed = true;
+ DumpGraph("RemoveIdentityNodes", g);
+ }
+ FixupSourceAndSinkEdges(g);
+ OptimizeCSE(g, nullptr);
+ DumpGraph("OptimizeCSE", g);
+ } while (changed);
+}
+
+void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g) {
+ DumpGraph("Initial", *g);
+ const int kNumInlineRounds = 10;
+ for (int i = 0; i < kNumInlineRounds; ++i) {
+ if (!ExpandInlineFunctions(lib, *g)) break;
+ DumpGraph("ExpandInlineFunctions", *g);
+ SimplifyGraph(*g);
+ }
+
+ // Makes a copy so that we densify node ids.
+ Graph* copy = new Graph((*g)->op_registry());
+ CopyGraph(**g, copy);
+ delete *g;
+ *g = copy;
+ DumpGraph("ReCopy", *g);
+}
+
+Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
+ const FunctionBody* fbody = GetFunctionBody(handle);
+ CHECK_NOTNULL(fbody);
+ Graph* g = new Graph(lib_def_);
+ CopyGraph(*fbody->graph, g);
+ OptimizeGraph(this, &g);
+
+ // Creates an executor based on the g. This must be done without
+ // holding mu_ because create_kernel_ calls back into the library.
+ LocalExecutorParams params;
+ params.device = device_;
+ params.function_library = this;
+ params.has_control_flow = false;
+ params.create_kernel = create_kernel_;
+ params.delete_kernel = [](OpKernel* kernel) {
+ DeleteNonCachedKernel(kernel);
+ };
+ Executor* exec;
+ TF_RETURN_IF_ERROR(NewLocalExecutor(params, g, &exec));
+
+ *item = new Item;
+ (*item)->exec = exec;
+ return Status::OK();
+}
+
+Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
+ {
+ mutex_lock l(mu_);
+ if (handle >= items_.size()) {
+ return errors::NotFound("Function handle ", handle,
+ " is not valid. Likely an internal error.");
+ }
+ *item = items_[handle];
+ if (*item != nullptr) {
+ (*item)->Ref();
+ return Status::OK();
+ }
+ }
+ // NOTE: We need to call CreateItem out of mu_ because creating an
+ // executor needs to call CreateKernel.
+ TF_RETURN_IF_ERROR(CreateItem(handle, item));
+
+ {
+ mutex_lock l(mu_);
+ if (items_[handle] == nullptr) {
+ // Install *item in items_.
+ items_[handle] = *item;
+ (*item)->Ref();
+ }
+ }
+ return Status::OK();
+}
+
+void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
+ gtl::ArraySlice<Tensor> args,
+ std::vector<Tensor>* rets,
+ DoneCallback done) {
+ if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
+ return done(errors::Cancelled(""));
+ }
+ const FunctionBody* fbody = GetFunctionBody(handle);
+ FunctionCallFrame* frame =
+ new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
+ Status s = frame->SetArgs(args);
+ if (!s.ok()) {
+ delete frame;
+ return done(s);
+ }
+ Item* item = nullptr;
+ s = GetOrCreateItem(handle, &item);
+ if (!s.ok()) {
+ delete frame;
+ return done(s);
+ }
+ Executor::Args exec_args;
+ exec_args.call_frame = frame;
+ exec_args.cancellation_manager = opts.cancellation_manager;
+ exec_args.runner = runner_;
+ item->exec->RunAsync(
+ // Executor args
+ exec_args,
+ // Done callback.
+ [item, frame, rets, done](const Status& status) {
+ item->Unref();
+ Status s = status;
+ if (s.ok()) {
+ s = frame->GetRetvals(rets);
+ }
+ delete frame;
+ done(s);
+ });
+}
+
+bool FunctionLibraryRuntimeImpl::IsDefined(const string& function_name) {
+ return lib_def_->Find(function_name) != nullptr;
+}
+
+FunctionLibraryRuntime* NewFunctionLibraryRuntime(
+ Device* device, Runner runner, const FunctionLibraryDefinition* lib_def) {
+ return new FunctionLibraryRuntimeImpl(device, runner, lib_def);
+}
+
+bool RemoveDeadNodes(Graph* g) {
+ std::vector<bool> visited(g->num_node_ids(), false);
+ visited[Graph::kSourceId] = true;
+ visited[Graph::kSinkId] = true;
+ std::deque<Node*> q;
+ for (auto n : g->nodes()) {
+ if (n->op_def().is_stateful()) {
+ visited[n->id()] = true;
+ } else if (n->type_string() == kArgOp) {
+ visited[n->id()] = true;
+ } else if (n->type_string() == kRetOp) {
+ visited[n->id()] = true;
+ q.push_back(n);
+ }
+ }
+ while (!q.empty()) {
+ const Node* n = q.front();
+ q.pop_front();
+ visited[n->id()] = true;
+ for (auto e : n->in_edges()) {
+ q.push_back(e->src());
+ }
+ }
+ bool removed_any = false;
+ for (Node* n : g->nodes()) {
+ if (!visited[n->id()]) {
+ g->RemoveNode(n);
+ removed_any = true;
+ }
+ }
+ return removed_any;
+}
+
+namespace {
+// If 'edges' contains only 1 non-control edge, returns it. Otherwise,
+// returns a nullptr.
+const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) {
+ const Edge* ret = nullptr;
+ for (const Edge* e : edges) {
+ if (e->IsControlEdge() || ret) return nullptr;
+ ret = e;
+ }
+ return ret;
+}
+} // end namespace
+
+bool RemoveIdentityNodes(Graph* g) {
+ bool removed_any = false;
+ gtl::InlinedVector<Node*, 8> matches;
+ for (Node* n : g->nodes()) {
+ if ((n->type_string() == "Identity") && GetTheOnlyDataEdge(n->in_edges())) {
+ matches.push_back(n);
+ }
+ }
+ if (!matches.empty()) {
+ for (Node* n : matches) {
+ const Edge* in = GetTheOnlyDataEdge(n->in_edges());
+ for (const Edge* out : n->out_edges()) {
+ if (out->IsControlEdge()) {
+ g->AddControlEdge(in->src(), out->dst());
+ } else {
+ g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input());
+ }
+ }
+ g->RemoveNode(n);
+ removed_any = true;
+ }
+ }
+ return removed_any;
+}
+
+bool RemoveListArrayConverter(Graph* g) {
+ gtl::InlinedVector<Node*, 8> matches;
+ for (Node* n : g->nodes()) {
+ if ((n->type_string() == "_ListToArray") ||
+ (n->type_string() == "_ArrayToList")) {
+ matches.push_back(n);
+ }
+ }
+ bool removed_any = false;
+ if (!matches.empty()) {
+ for (Node* n : matches) {
+ if (n->num_inputs() != n->num_outputs()) {
+ continue; // Not expected. Skip.
+ }
+ gtl::InlinedVector<Node*, 8> identity_nodes(n->num_inputs(), nullptr);
+
+ // Process input edges first.
+ Node* input_control_node = nullptr;
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) {
+ if (input_control_node == nullptr) {
+ // If node "n" has any control dependencies, adds a no-op
+ // node (input_control_node) which the additional Identity
+ // nodes depends on and the input_control_node depends on
+ // the node "n"s control dependencies.
+ input_control_node = AddNoOp(g);
+ }
+ g->AddControlEdge(e->src(), input_control_node);
+ } else {
+ const int index = e->dst_input();
+ Node** id_node = &identity_nodes[index];
+ if (*id_node != nullptr) {
+ LOG(ERROR)
+ << "RemoveListArrayConverter unexpected duplicated input: "
+ << e->dst_input();
+ return removed_any;
+ }
+ *id_node = AddIdentity(g, {e->src(), e->src_output()});
+ }
+ }
+
+ // If node "n" has any control dependencies, the added identity
+ // nodes should have control dependencies on input_control_node.
+ if (input_control_node != nullptr) {
+ for (Node* id : identity_nodes) {
+ g->AddControlEdge(input_control_node, id);
+ }
+ }
+
+ Node* output_control_node = nullptr;
+ for (const Edge* e : n->out_edges()) {
+ if (e->IsControlEdge()) {
+ if (output_control_node == nullptr) {
+ // If node "n" is control-depended upon by other nodes,
+ // adds a no-op node (output_control_node) which those
+ // nodes will depend on and output_control_node depends on
+ // all Identity nodes.
+ output_control_node = AddNoOp(g);
+ }
+ g->AddControlEdge(output_control_node, e->dst());
+ } else {
+ Node* id_node = identity_nodes[e->src_output()];
+ if (id_node == nullptr) {
+ LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: "
+ << e->src_output();
+ return removed_any;
+ }
+ CHECK(id_node);
+ g->AddEdge(id_node, 0, e->dst(), e->dst_input());
+ }
+ }
+
+ // If any nodes have control dependencies on node "n", those
+ // nodes should have control dependencies on
+ // output_control_node.
+ if (output_control_node != nullptr) {
+ for (Node* id : identity_nodes) {
+ g->AddControlEdge(id, output_control_node);
+ }
+ }
+
+ g->RemoveNode(n);
+ removed_any = true;
+ }
+ }
+ return removed_any;
+}
+
+// Returns true iff the function '*fbody' can be inlined at 'node'
+// based on the type signature of 'node' and 'fbody'.
+static bool ValidateInlining(const Node* node, const FunctionBody* fbody) {
+ if (static_cast<size_t>(node->num_inputs()) != fbody->arg_types.size()) {
+ return false;
+ }
+ if (static_cast<size_t>(node->num_inputs()) != fbody->arg_nodes.size()) {
+ return false;
+ }
+ if (static_cast<size_t>(node->num_outputs()) != fbody->ret_types.size()) {
+ return false;
+ }
+ if (static_cast<size_t>(node->num_outputs()) != fbody->ret_nodes.size()) {
+ return false;
+ }
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ if (node->input_type(i) != fbody->arg_types[i]) return false;
+ }
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ if (node->output_type(i) != fbody->ret_types[i]) return false;
+ }
+ return true;
+}
+
+// Given a "caller" in "graph", which is a function call of a function
+// to "fbody". Replaces the "caller" with fbody->graph and connects
+// edges properly.
+static void InlineFunctionBody(Graph* g, Node* caller,
+ const FunctionBody* fbody) {
+ if (!ValidateInlining(caller, fbody)) {
+ LOG(WARNING) << "Inlining mismatch: " << caller->DebugString() << " vs. "
+ << DebugString(fbody->graph);
+ return;
+ }
+
+ // Duplicate fbody->graph into 'g'. First, we copy the nodes of
+ // fbody->graph into 'g' except the source and sink nodes. We copy
+ // edges among nodes in 'fbody->graph'.
+ //
+ // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
+ // remember 'y' in node_map[x->id()].
+ std::vector<Node*> node_map(fbody->graph->num_node_ids());
+ for (Node* n : fbody->graph->nodes()) {
+ if (n->IsSource() || n->IsSink()) continue;
+ CHECK(n->IsOp());
+ node_map[n->id()] = g->CopyNode(n);
+ }
+ for (const Edge* e : fbody->graph->edges()) {
+ if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
+ e->dst()->IsSink()) {
+ continue;
+ }
+ Node* src_copy = node_map[e->src()->id()];
+ Node* dst_copy = node_map[e->dst()->id()];
+ g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
+ }
+
+ // Connect input edges.
+ //
+ // For data edges coming into "caller", we first compute the
+ // <src>:<src_output> for the i-th input in "inputs". We create one
+ // Identity node for each input. Then, we connect inputs[i] to to
+ // the i-th identity node added. The nodes that previously connects
+ // to the j-th output of i-th arg node are reconnected to th i-th
+ // identity node.
+ //
+ // If "caller" has any input control dependencies, we add a NoOp
+ // node "input_control_node". This "input_control_node" depends on
+ // what "caller" depends on, and the added identity nodes depend on
+ // "input_control_node".
+ std::vector<Endpoint> inputs(caller->num_inputs());
+ Node* input_control_node = nullptr;
+ for (const Edge* e : caller->in_edges()) {
+ if (e->IsControlEdge()) {
+ if (input_control_node == nullptr) {
+ input_control_node = AddNoOp(g);
+ }
+ g->AddControlEdge(e->src(), input_control_node);
+ } else {
+ inputs[e->dst_input()] = {e->src(), e->src_output()};
+ }
+ }
+ for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
+ Node* arg = node_map[fbody->arg_nodes[i]->id()];
+ Node* n = AddIdentity(g, inputs[i]);
+ if (input_control_node) {
+ g->AddControlEdge(input_control_node, n);
+ }
+ for (const Edge* e : arg->out_edges()) {
+ if (e->IsControlEdge()) {
+ g->AddControlEdge(n, e->dst());
+ } else {
+ g->AddEdge(n, 0, e->dst(), e->dst_input());
+ }
+ }
+ node_map[fbody->arg_nodes[i]->id()] = n;
+ g->RemoveNode(arg); // 'arg' is disconnected.
+ }
+
+ // Connect output edges.
+ //
+ // For i-th return node in fbody->graph, we add in "g" an identity
+ // node (outputs[i-th]). We then reconnect every incoming edge into
+ // the i-th return node to the added identity node.
+ //
+ // For every data edge coming out of "callee"s i-th output, we
+ // reconnect it to the i-th identity added above.
+ //
+ // If "callee" is control-depended upon by any other nodes, we add a
+ // NoOp node "output_control_node". "output_control_node" depends on
+ // all identity nodes added above. And nodes previously depend on
+ // "callee" is changed to depend on "output_control_node".
+ std::vector<Node*> outputs(caller->num_inputs());
+ for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
+ Node* ret = node_map[fbody->ret_nodes[i]->id()];
+ Endpoint data; // Data input for the ret node.
+ for (const Edge* e : ret->in_edges()) {
+ if (!e->IsControlEdge()) {
+ data = {e->src(), e->src_output()};
+ break;
+ }
+ }
+ CHECK(data.node != nullptr);
+ Node* n = AddIdentity(g, data);
+ outputs[i] = n;
+ for (const Edge* e : ret->in_edges()) {
+ if (e->IsControlEdge()) {
+ g->AddControlEdge(e->src(), n);
+ }
+ }
+ g->RemoveNode(ret); // 'ret' is disconnected.
+ }
+ Node* output_control_node = nullptr;
+ for (const Edge* e : caller->out_edges()) {
+ if (e->IsControlEdge()) {
+ if (output_control_node == nullptr) {
+ output_control_node = AddNoOp(g);
+ for (Node* n : outputs) {
+ g->AddControlEdge(n, output_control_node);
+ }
+ }
+ g->AddControlEdge(output_control_node, e->dst());
+ } else {
+ g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
+ }
+ }
+ g->RemoveNode(caller); // 'caller' is replaced with inlined nodes.
+}
+
+bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
+ std::vector<std::pair<Node*, const FunctionBody*>> candidates;
+ for (Node* node : graph->nodes()) {
+ VLOG(3) << "Expanding " << node->DebugString();
+ FunctionLibraryRuntime::Handle handle;
+ Status s =
+ lib->Instantiate(node->type_string(), node->def().attr(), &handle);
+ if (!s.ok()) {
+ // Either "node" is a primitive op, or the instantiation failed.
+ if (errors::IsNotFound(s)) {
+ VLOG(2) << "ExpandInlineFunctions " << s;
+ } else {
+ LOG(ERROR) << "ExpandInlineFunctions " << s;
+ }
+ continue;
+ }
+ const FunctionBody* fbody = lib->GetFunctionBody(handle);
+ CHECK_NOTNULL(fbody);
+ candidates.push_back({node, fbody});
+ }
+ for (const auto& p : candidates) {
+ InlineFunctionBody(graph, p.first, p.second);
+ }
+ return !candidates.empty();
+}
+
+// TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef.
+// and stash the original NodeDef name as an attr for documentation
+// purpose.
+static void ToGraphDef(const Graph* g, GraphDef* gdef) {
+ // We visit nodes in forward topological sort order, which is a
+ // possible execution order of the graph.
+ std::vector<int> pending(g->num_node_ids());
+ std::deque<const Node*> ready;
+ for (const Node* n : g->nodes()) {
+ pending[n->id()] = n->in_edges().size();
+ if (pending[n->id()] == 0) ready.push_back(n);
+ }
+ gtl::InlinedVector<const Edge*, 4> inputs;
+ gdef->Clear();
+ while (!ready.empty()) {
+ const Node* n = ready.front();
+ ready.pop_front();
+ for (const Edge* e : n->out_edges()) {
+ const Node* next = e->dst();
+ if (--pending[next->id()] == 0) {
+ ready.push_back(next);
+ }
+ }
+ if (!n->IsOp()) continue;
+ NodeDef* ndef = gdef->add_node();
+ ndef->set_name(strings::StrCat("n", n->id()));
+ ndef->set_op(n->type_string());
+ *(ndef->mutable_attr()) = n->def().attr();
+ inputs.clear();
+ inputs.resize(n->num_inputs());
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) {
+ inputs.push_back(e);
+ } else {
+ if (inputs[e->dst_input()] == nullptr) {
+ inputs[e->dst_input()] = e;
+ } else {
+ LOG(WARNING) << "Malformed graph node. multiple input edges: "
+ << n->DebugString();
+ }
+ }
+ }
+ // node->name() is merely NodeDef::name, which are not guaranteed
+ // to be unique and stable after optimization rewrites. Therefore,
+ // we use "n<node id>" instead.
+ for (const Edge* e : inputs) {
+ if (e == nullptr) {
+ ndef->add_input("unknown");
+ } else if (!e->src()->IsOp()) {
+ } else if (e->IsControlEdge()) {
+ ndef->add_input(strings::StrCat("^n", e->src()->id()));
+ } else if (e->src_output() == 0) {
+ ndef->add_input(strings::StrCat("n", e->src()->id()));
+ } else {
+ ndef->add_input(
+ strings::StrCat("n", e->src()->id(), ":", e->src_output()));
+ }
+ }
+ }
+}
+
+string DebugString(const Graph* g) {
+ GraphDef gdef;
+ ToGraphDef(g, &gdef);
+ return DebugString(gdef);
+}
+
+FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
+ DataTypeSlice ret_t, Graph* g)
+ : fdef(f),
+ graph(g),
+ arg_types(arg_t.begin(), arg_t.end()),
+ ret_types(ret_t.begin(), ret_t.end()) {
+ this->arg_nodes.resize(arg_types.size());
+ this->ret_nodes.resize(ret_types.size());
+ for (Node* n : this->graph->nodes()) {
+ gtl::InlinedVector<Node*, 4>* node_vec;
+ if (n->type_string() == kRetOp) {
+ node_vec = &this->ret_nodes;
+ } else if (n->type_string() == kArgOp) {
+ node_vec = &this->arg_nodes;
+ } else {
+ continue;
+ }
+ int index;
+ TF_CHECK_OK(GetNodeAttr(n->def(), "index", &index));
+ CHECK_LE(0, index);
+ CHECK_LT(index, node_vec->size());
+ (*node_vec)[index] = n;
+ }
+}
+
+FunctionBody::~FunctionBody() { delete this->graph; }
+
+class SymbolicGradientHelper {
+ public:
+ explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {}
+
+ ~SymbolicGradientHelper() { delete gbody_; }
+
+ FunctionBody* Compute();
+
+ private:
+ const FunctionBody* fbody_;
+ FunctionBody* gbody_ = nullptr;
+
+ // A vector of output endpoints which represents backpropagated
+ // gradients
+ typedef std::vector<Endpoint> BackpropedGradients;
+
+ // backprops_ is a map from an output endpoint to its accumulated
+ // gradients. When an output endpoint has accumulated all its
+ // gradients, we add a node which sums them up.
+ std::unordered_map<Endpoint, BackpropedGradients, EndpointHash, EndpointEq>
+ backprops_;
+
+ // pending[i] is count-down counter for i-th node's expected
+ // backprops. When pending[i] becomes zero, we collected all
+ // backprop gradients for all output endpoint of the ith-node.
+ std::vector<int> pending_;
+
+ // 'ready' keeps track of nodes that have been completely
+ // backpropped. Initially, for every output y of the function f, we
+ // add dy as an input of the the gradient function.
+ std::deque<Node*> ready_;
+
+ // Makes a copy of fbody_ in gbody_.
+ void Copy();
+
+ // Initialize pending_ and ready_.
+ void InitBackprop();
+
+ // In the original function body, there is a forward edge from 'src'
+ // to 'dst', when the backprop algorithm constructs the node
+ // 'dst_grad' which computes the gradient, we need to propagate it
+ // to 'src'.
+ void BackpropAlongEdge(const Endpoint& dst_grad, const Endpoint& src);
+ void BackpropZerosAlongEdge(const Endpoint& src);
+
+ Endpoint SumGradients(const Endpoint& src);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper);
+};
+
+void SymbolicGradientHelper::Copy() {
+ const Graph& src = *(fbody_->graph);
+ gbody_->graph = new Graph(src.op_registry());
+ Graph* dst = gbody_->graph;
+
+ std::vector<Node*> node_map(src.num_node_ids());
+
+ // Copy the nodes.
+ node_map[src.source_node()->id()] = dst->source_node();
+ node_map[src.sink_node()->id()] = dst->sink_node();
+ for (Node* n : src.nodes()) {
+ if (n->IsSource() || n->IsSink()) continue;
+ CHECK(n->IsOp());
+ node_map[n->id()] = dst->CopyNode(n);
+ }
+
+ // Copy the edges.
+ for (const Edge* e : src.edges()) {
+ Node* src_copy = node_map[e->src()->id()];
+ Node* dst_copy = node_map[e->dst()->id()];
+ dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
+ }
+
+ // Save inputs in copied graph.
+ CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size());
+ gbody_->arg_types = fbody_->arg_types;
+ for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
+ gbody_->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]);
+ }
+
+ // Save outputs in copied graph.
+ CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size());
+ gbody_->ret_types = fbody_->ret_types;
+ for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) {
+ gbody_->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]);
+ }
+}
+
+void SymbolicGradientHelper::BackpropAlongEdge(const Endpoint& dst_grad,
+ const Endpoint& src) {
+ CHECK_NOTNULL(src.node);
+ auto iter = backprops_.find(src);
+ if (iter != backprops_.end()) {
+ auto* grads = &iter->second;
+ grads->push_back(dst_grad);
+ if (--pending_[src.node->id()] == 0) {
+ ready_.push_back(src.node);
+ }
+ }
+}
+
+void SymbolicGradientHelper::BackpropZerosAlongEdge(const Endpoint& src) {
+ CHECK_NOTNULL(src.node);
+ auto iter = backprops_.find(src);
+ if (iter != backprops_.end()) {
+ if (--pending_[src.node->id()] == 0) {
+ ready_.push_back(src.node);
+ }
+ }
+}
+
+void SymbolicGradientHelper::InitBackprop() {
+ Graph* g = gbody_->graph;
+ pending_.resize(g->num_node_ids(), 0);
+ {
+ backprops_.clear();
+ std::unordered_set<Node*> visited;
+ std::deque<Node*> queue;
+ for (Node* n : gbody_->arg_nodes) {
+ queue.push_back(n);
+ }
+
+ // Going forward to figure out which endpoints need backprop-ed.
+ // A node's endpoints need to be backprop-ed only if one of the
+ // arg node can reach the node via data edges.
+ while (!queue.empty()) {
+ Node* n = queue.front();
+ queue.pop_front();
+ visited.insert(n);
+ for (int i = 0; i < n->num_outputs(); ++i) {
+ backprops_[{n, i}].clear();
+ }
+ int num_expected_backprops = 0;
+ for (const Edge* e : n->out_edges()) {
+ if (e->IsControlEdge()) continue;
+ ++num_expected_backprops;
+ if (visited.find(e->dst()) == visited.end()) {
+ queue.push_back(e->dst());
+ }
+ }
+ pending_[n->id()] = num_expected_backprops;
+ }
+ }
+
+ {
+ const int num_y = gbody_->ret_nodes.size();
+ for (int i = 0; i < num_y; ++i) {
+ Node* y = gbody_->ret_nodes[i];
+ DCHECK_EQ(y->type_string(), kRetOp);
+ const DataType dtype = y->input_type(0);
+ const int index = gbody_->arg_nodes.size();
+ Node* dy = AddArg(g, dtype, index);
+ gbody_->arg_types.push_back(dtype);
+ gbody_->arg_nodes.push_back(dy);
+
+ // What's the input to y?
+ Endpoint y_in{nullptr, 0};
+ for (const Edge* e : y->in_edges()) {
+ if (!e->IsControlEdge()) {
+ y_in = {e->src(), e->src_output()};
+ break;
+ }
+ }
+ CHECK_NOTNULL(y_in.node);
+ BackpropAlongEdge({dy, 0}, y_in);
+ }
+ }
+}
+
+Endpoint SymbolicGradientHelper::SumGradients(const Endpoint& src) {
+ Graph* g = gbody_->graph;
+ const DataType dtype = src.dtype();
+ auto iter = backprops_.find(src);
+ CHECK(iter != backprops_.end());
+ const auto& grads = iter->second;
+ if (grads.empty()) {
+ // Nothing propagated back. The best we can come up is zeros.
+ Node* zero_like = AddZerosLike(g, src);
+ return {zero_like, 0};
+ }
+ if (grads.size() == 1) {
+ // Just one backprop edge.
+ return grads[0];
+ }
+ // Otherwise, adds backprop-ed gradients.
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("AddN"); // N-way Add
+ for (const Endpoint& ep : grads) {
+ ndef.add_input(ep.name());
+ }
+ AddNodeAttr("N", static_cast<int64>(grads.size()), &ndef);
+ AddNodeAttr("T", dtype, &ndef);
+ Status s;
+ Node* add = gbody_->graph->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ for (size_t i = 0; i < grads.size(); ++i) {
+ const Endpoint& ep = grads[i];
+ g->AddEdge(ep.node, ep.index, add, i);
+ }
+ return {add, 0};
+}
+
+static bool IsPrimitiveOpWithNoGrad(const string& func) {
+ gradient::Creator creator;
+ Status s = gradient::GetOpGradientCreator(func, &creator);
+ return s.ok() && (creator == nullptr);
+}
+
+FunctionBody* SymbolicGradientHelper::Compute() {
+ CHECK(gbody_ == nullptr);
+ gbody_ = new FunctionBody;
+
+ // Copy fbody_ into gbody_.
+ Copy();
+
+ // Initialize backprops.
+ InitBackprop();
+
+ // Backward propagation.
+ gtl::InlinedVector<Endpoint, 8> dy;
+ Graph* g = gbody_->graph;
+ while (!ready_.empty()) {
+ // n has collected all gradients.
+ Node* n = ready_.front();
+ ready_.pop_front();
+
+ if (n->type_string() == kArgOp) {
+ // We'll handle the _Arg node after backprop is done.
+ continue;
+ }
+
+ // "n" has num_x inputs and num_y outputs.
+ const int num_x = n->num_inputs();
+ const int num_y = n->num_outputs();
+
+ // dy[i] is the sum of i-th output's backpropped gradients.
+ dy.clear();
+ dy.resize(num_y, {nullptr, 0});
+ for (int i = 0; i < num_y; ++i) {
+ dy[i] = SumGradients({n, i});
+ }
+
+ if (IsPrimitiveOpWithNoGrad(n->type_string())) {
+ // No grad defined for this op. Backprops zeros along the in
+ // edges.
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ BackpropZerosAlongEdge({e->src(), e->src_output()});
+ }
+ continue;
+ }
+
+ // Adds a gradient node with num_x + num_y inputs and num_x
+ // outputs.
+ Node* grad = AddSymGrad(g, n, dy);
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ g->AddEdge(e->src(), e->src_output(), grad, e->dst_input());
+ }
+ for (int i = 0; i < num_y; ++i) {
+ g->AddEdge(dy[i].node, dy[i].index, grad, num_x + i);
+ }
+
+ // Backprops along the in edges.
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ BackpropAlongEdge({grad, e->dst_input()}, {e->src(), e->src_output()});
+ }
+ }
+
+ // The gradient's retval nodes.
+ for (Node* n : gbody_->ret_nodes) {
+ g->RemoveNode(n);
+ }
+ gbody_->ret_types = fbody_->arg_types;
+ gbody_->ret_nodes.clear();
+ for (size_t i = 0; i < fbody_->arg_types.size(); ++i) {
+ Endpoint grad = SumGradients({gbody_->arg_nodes[i], 0});
+ Node* ret = AddRet(g, grad, i);
+ gbody_->ret_nodes.push_back(ret);
+ }
+
+ auto ret = gbody_;
+ gbody_ = nullptr;
+ return ret;
+}
+
+FunctionBody* SymbolicGradient(const FunctionBody& f) {
+ return SymbolicGradientHelper(f).Compute();
+}
+
+} // end namespace tensorflow