Diffstat (limited to 'tensorflow/cc')
-rw-r--r--  tensorflow/cc/BUILD                           89
-rw-r--r--  tensorflow/cc/ops/array_grad.cc               32
-rw-r--r--  tensorflow/cc/ops/cc_op_gen.cc               350
-rw-r--r--  tensorflow/cc/ops/cc_op_gen.h                 14
-rw-r--r--  tensorflow/cc/ops/cc_op_gen_main.cc           34
-rw-r--r--  tensorflow/cc/ops/const_op.cc                113
-rw-r--r--  tensorflow/cc/ops/const_op.h                  70
-rw-r--r--  tensorflow/cc/ops/functional_grad.cc          42
-rw-r--r--  tensorflow/cc/ops/math_grad.cc               566
-rw-r--r--  tensorflow/cc/ops/nn_grad.cc                  55
-rw-r--r--  tensorflow/cc/ops/standard_ops.h              26
-rw-r--r--  tensorflow/cc/tutorials/example_trainer.cc   146
12 files changed, 1537 insertions, 0 deletions
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
new file mode 100644
index 0000000000..8a5bf87a29
--- /dev/null
+++ b/tensorflow/cc/BUILD
@@ -0,0 +1,89 @@
+# Description:
+# TensorFlow is a computational framework, primarily for use in machine
+# learning applications.
+
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("/tensorflow/tensorflow", "tf_copts")
+load("/tensorflow/tensorflow", "tf_gen_op_wrappers_cc")
+
+cc_library(
+ name = "cc_op_gen_main",
+ srcs = [
+ "ops/cc_op_gen.cc",
+ "ops/cc_op_gen_main.cc",
+ ],
+ hdrs = ["ops/cc_op_gen.h"],
+ copts = tf_copts(),
+ deps = [
+ "//tensorflow/core:framework",
+ ],
+)
+
+# Generates a library that contains C++ wrappers for ops.
+tf_gen_op_wrappers_cc(
+ name = "cc_ops",
+ op_lib_names = [
+ "array_ops",
+ "attention_ops",
+ "candidate_sampling_ops",
+ "control_flow_ops",
+ "data_flow_ops",
+ "image_ops",
+ "io_ops",
+ "linalg_ops",
+ "logging_ops",
+ "math_ops",
+ "nn_ops",
+ "no_op",
+ "parsing_ops",
+ "random_ops",
+ "sendrecv_ops",
+ "sparse_ops",
+ "state_ops",
+ "string_ops",
+ "summary_ops",
+ "training_ops",
+ "user_ops",
+ ],
+ other_hdrs = [
+ "ops/const_op.h",
+ "ops/standard_ops.h",
+ ],
+ other_srcs = [
+ "ops/const_op.cc",
+ ] + glob(["ops/*_grad.cc"]),
+ pkg = "//tensorflow/core",
+)
+
+cc_binary(
+ name = "tutorials_example_trainer",
+ srcs = ["tutorials/example_trainer.cc"],
+ copts = tf_copts(),
+ linkopts = [
+ "-lpthread",
+ "-lm",
+ ],
+ deps = [
+ ":cc_ops",
+ "//tensorflow/core:kernels",
+ "//tensorflow/core:local",
+ "//tensorflow/core:tensorflow",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/cc/ops/array_grad.cc b/tensorflow/cc/ops/array_grad.cc
new file mode 100644
index 0000000000..37ffed9792
--- /dev/null
+++ b/tensorflow/cc/ops/array_grad.cc
@@ -0,0 +1,32 @@
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+typedef FunctionDefHelper FDH;
+
+REGISTER_OP_NO_GRADIENT("Shape");
+REGISTER_OP_NO_GRADIENT("Rank");
+REGISTER_OP_NO_GRADIENT("Size");
+
+Status ReshapeGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ *g = FDH::Define(
+ // Arg defs
+ {"x: T", "shape: int32", "dy: T"},
+ // Ret val defs
+ {"dx: T", "dshape: int32"},
+ // Attr defs
+ {{"T: {float, double}"}},
+ // Nodes
+ {
+ {{"x_shape"}, "Shape", {"x"}, {{"T", "$T"}}},
+ {{"dx"}, "Reshape", {"dy", "x_shape"}, {{"T", "$T"}}},
+ {{"dshape"}, "ZerosLike", {"shape"}, {{"T", DT_INT32}}},
+ });
+ // clang-format on
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("Reshape", ReshapeGrad);
+
+} // end namespace tensorflow
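
A gradient registered this way is just a FunctionDef assembled from arg defs, return-value defs, attr defs, and nodes. As a purely illustrative sketch in the same style (hypothetical, not part of this change), a pass-through gradient for "Identity" could be written as:

Status IdentityGrad(const AttrSlice& attrs, FunctionDef* g) {
  // d(Identity(x))/dx = 1, so the incoming gradient passes through unchanged.
  *g = FDH::Define(
      // Arg defs
      {"x: T", "dy: T"},
      // Ret val defs
      {"dx: T"},
      // Attr defs
      {{"T: {float, double}"}},
      // Nodes
      {
          {{"dx"}, "Identity", {"dy"}, {{"T", "$T"}}},
      });
  return Status::OK();
}
REGISTER_OP_GRADIENT("Identity", IdentityGrad);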
diff --git a/tensorflow/cc/ops/cc_op_gen.cc b/tensorflow/cc/ops/cc_op_gen.cc
new file mode 100644
index 0000000000..fdecf967f8
--- /dev/null
+++ b/tensorflow/cc/ops/cc_op_gen.cc
@@ -0,0 +1,350 @@
+// TODO(josh11b): Rewrite function parameter names to avoid C++ keywords
+// or "opts".
+
+#include "tensorflow/cc/ops/cc_op_gen.h"
+
+#include <unordered_map>
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace {
+
+const int kRightMargin = 79;
+
+const char* AttrTypeName(StringPiece attr_type) {
+ static const char* kAttrTypeName[][2] = {
+ {"string", "StringPiece"},
+ {"list(string)", "gtl::ArraySlice<string>"},
+ {"int", "int64"},
+ {"list(int)", "gtl::ArraySlice<int>"},
+ {"float", "float"},
+ {"list(float)", "gtl::ArraySlice<float>"},
+ {"bool", "bool"},
+ {"list(bool)", "gtl::ArraySlice<bool>"},
+ {"type", "DataType"},
+ {"list(type)", "DataTypeSlice"},
+ {"shape", "TensorShape"},
+ {"list(shape)", "gtl::ArraySlice<TensorShape>"},
+ {"tensor", "const Tensor&"},
+ {"list(tensor)", "gtl::ArraySlice<Tensor>"},
+ {"func", "const NameAttrList&"},
+ };
+ for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) {
+ if (attr_type == kAttrTypeName[i][0]) {
+ return kAttrTypeName[i][1];
+ }
+ }
+ LOG(FATAL) << "Unsupported Attr type: " << attr_type;
+ return "";
+}
+
+// Change:     Into:
+//   ABC       // ABC
+//             //
+//   DEF       // DEF
+string MakeComment(StringPiece text) {
+ string ret;
+ while (!text.empty()) {
+ int last_non_space = -1;
+ int newline;
+ for (newline = 0; newline < static_cast<int>(text.size()); ++newline) {
+ if (text[newline] == '\n') break;
+ if (text[newline] != ' ') last_non_space = newline;
+ }
+ if (last_non_space == -1) {
+ strings::StrAppend(&ret, "//\n");
+ } else {
+ strings::StrAppend(&ret, "// ", text.substr(0, last_non_space + 1), "\n");
+ }
+ text.remove_prefix(newline + 1);
+ }
+ return ret;
+}
+
+void WriteCCOp(const OpDef& op_def, WritableFile* h, WritableFile* cc) {
+ // TODO(josh11b): Better wrapping of comments.
+ string comment;
+ if (op_def.summary().empty()) {
+ comment = "TODO: add doc.\n";
+ } else {
+ comment = strings::StrCat(op_def.summary(), "\n");
+ if (!op_def.description().empty()) {
+ strings::StrAppend(&comment, "\n", op_def.description(), "\n");
+ }
+ }
+
+ static const string kSingleInputType = "NodeOut";
+ static const string kListInputType = "gtl::ArraySlice<NodeOut>";
+
+ std::vector<string> arg_types;
+ std::vector<string> arg_names;
+
+ strings::StrAppend(&comment, "\nArguments:\n");
+
+ // Map from attr name to the first input arg it is inferred from.
+ std::unordered_map<string, string> inferred_attrs;
+ for (int i = 0; i < op_def.input_arg_size(); ++i) {
+ const auto& arg(op_def.input_arg(i));
+ arg_names.emplace_back(arg.name());
+ bool is_list = false;
+
+ if (!arg.type_attr().empty()) {
+ gtl::InsertIfNotPresent(&inferred_attrs, arg.type_attr(), arg.name());
+ } else if (!arg.type_list_attr().empty()) {
+ gtl::InsertIfNotPresent(&inferred_attrs, arg.type_list_attr(),
+ arg.name());
+ is_list = true;
+ }
+ if (!arg.number_attr().empty()) {
+ gtl::InsertIfNotPresent(&inferred_attrs, arg.number_attr(), arg.name());
+ is_list = true;
+ }
+ if (is_list) {
+ arg_types.emplace_back(kListInputType);
+ } else {
+ arg_types.emplace_back(kSingleInputType);
+ }
+
+ // TODO(josh11b): Include input type information.
+ StringPiece description = arg.description();
+ if (!description.empty()) {
+ ConsumeEquals(&description);
+ strings::StrAppend(&comment, "* ", arg_names.back(), ": ",
+ arg.description(), "\n");
+ }
+ }
+
+ string options_comment;
+ for (int i = 0; i < op_def.attr_size(); ++i) {
+ const auto& attr(op_def.attr(i));
+ // Do not add inferred attrs or attrs with defaults to the C++
+ // function signature.
+ if (inferred_attrs.find(attr.name()) == inferred_attrs.end()) {
+ if (!attr.has_default_value()) {
+ arg_names.emplace_back(attr.name());
+ arg_types.emplace_back(AttrTypeName(attr.type()));
+ if (!attr.description().empty()) {
+ strings::StrAppend(&comment, "* ", arg_names.back(), ": ",
+ attr.description(), "\n");
+ }
+ } else {
+ strings::StrAppend(&options_comment, " .WithAttr(\"", attr.name(),
+ "\", ", AttrTypeName(attr.type()), "): Defaults to ",
+ SummarizeAttrValue(attr.default_value()), ".\n");
+ if (!attr.description().empty()) {
+ strings::StrAppend(&options_comment, " ", attr.description(),
+ "\n");
+ }
+ }
+ }
+ }
+ CHECK_EQ(arg_names.size(), arg_types.size());
+ strings::StrAppend(&comment, "* opts:\n", options_comment,
+ R"comment( .WithName(StringPiece): Set the Node's name
+ .WithDevice(StringPiece): Set the Node's requested device
+ .WithControlInput(Node*) / .WithControlInputs({Node*, ...}):
+ Add control dependencies on the specified Node(s).
+
+Returns a pointer to the created Node)comment");
+
+ // TODO(josh11b): Include output type information.
+ if (op_def.output_arg_size() == 0) {
+ strings::StrAppend(&comment, ".\n");
+ } else if (op_def.output_arg_size() == 1) {
+ StringPiece description = op_def.output_arg(0).description();
+ ConsumeEquals(&description);
+ if (description.empty()) {
+ strings::StrAppend(&comment, ".\n");
+ } else {
+ strings::StrAppend(&comment, ", with output:\n", description, "\n");
+ }
+ } else {
+ strings::StrAppend(&comment, ", with outputs:\n");
+ for (int o = 0; o < op_def.output_arg_size(); ++o) {
+ StringPiece description = op_def.output_arg(o).description();
+ ConsumeEquals(&description);
+ if (description.empty()) {
+ strings::StrAppend(&comment, "* ", op_def.output_arg(o).name(), "\n");
+ } else {
+ strings::StrAppend(&comment, "* ", op_def.output_arg(o).name(), ": ",
+ description, "\n");
+ }
+ }
+ }
+
+ // Write the header comment.
+ TF_CHECK_OK(h->Append(MakeComment(comment)));
+
+ // Declare the function wrapper.
+ const string prefix = strings::StrCat("Node* ", op_def.name(), "(");
+ string h_rest;
+ for (size_t i = 0; i < arg_names.size(); ++i) {
+ strings::StrAppend(&h_rest, arg_types[i], " ", arg_names[i], ", ");
+ }
+ strings::StrAppend(&h_rest, "const GraphDefBuilder::Options& opts");
+ string cc_decl = h_rest;
+ strings::StrAppend(&h_rest, ");");
+ TF_CHECK_OK(h->Append(WordWrap(prefix, h_rest, kRightMargin) + "\n\n"));
+
+ // Define the function wrapper.
+ strings::StrAppend(&cc_decl, ") {");
+ TF_CHECK_OK(cc->Append(WordWrap(prefix, cc_decl, kRightMargin) + "\n"));
+ const string op_name = strings::StrCat(" static const string kOpName = \"",
+ op_def.name(), "\";\n");
+
+ if (arg_types.empty()) {
+ TF_CHECK_OK(cc->Append(op_name));
+ TF_CHECK_OK(cc->Append(" return SourceOp(kOpName, opts);\n}\n\n"));
+ } else if (arg_types == std::vector<string>({kSingleInputType})) {
+ TF_CHECK_OK(cc->Append(op_name));
+ TF_CHECK_OK(cc->Append(strings::StrCat(" return UnaryOp(kOpName, ",
+ arg_names[0], ", opts);\n}\n\n")));
+ } else if (arg_types ==
+ std::vector<string>({kSingleInputType, kSingleInputType})) {
+ TF_CHECK_OK(cc->Append(op_name));
+ // TODO(josh11b): Word wrap this if it ever becomes necessary.
+ TF_CHECK_OK(
+ cc->Append(strings::StrCat(" return BinaryOp(kOpName, ", arg_names[0],
+ ", ", arg_names[1], ", opts);\n}\n\n")));
+ } else {
+ TF_CHECK_OK(cc->Append(" if (opts.HaveError()) return nullptr;\n"));
+ TF_CHECK_OK(cc->Append(op_name));
+ TF_CHECK_OK(cc->Append(
+ " NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName,\n"
+ " opts.op_registry());\n"));
+ for (size_t i = 0; i < arg_names.size(); ++i) {
+ if (i < static_cast<size_t>(op_def.input_arg_size())) {
+ TF_CHECK_OK(cc->Append(
+ strings::StrCat(" node_builder.Input(", arg_names[i], ");\n")));
+ } else {
+ TF_CHECK_OK(
+ cc->Append(strings::StrCat(" node_builder.Attr(\"", arg_names[i],
+ "\", ", arg_names[i], ");\n")));
+ }
+ }
+ TF_CHECK_OK(
+ cc->Append(" return opts.FinalizeBuilder(&node_builder);\n"
+ "}\n\n"));
+ }
+}
+
+// Converts:
+// bazel-out/.../genfiles/XX
+// to: XX.
+string GetPath(const std::string& dot_h_fname) {
+ auto pos = dot_h_fname.find("/genfiles/");
+ if (pos == string::npos) return dot_h_fname;
+ // The - 1 accounts for the terminating null character (\0) in "/genfiles/".
+ return dot_h_fname.substr(pos + sizeof("/genfiles/") - 1);
+}
+
+// Converts:
+// cc/ops/gen_foo_ops.h
+// to:
+// CC_OPS_GEN_FOO_OPS_H_
+string ToGuard(const std::string& path) {
+ string guard;
+ guard.reserve(path.size() + 1); // + 1 -> trailing _
+ for (const char c : path) {
+ if (c >= 'A' && c <= 'Z') {
+ guard += c;
+ } else if (c >= 'a' && c <= 'z') {
+ guard += c + 'A' - 'a';
+ } else {
+ guard += '_';
+ }
+ }
+ guard += '_';
+ return guard;
+}
+
+} // namespace
+
+void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
+ const std::string& dot_cc_fname) {
+ Env* env = Env::Default();
+ WritableFile* h = nullptr;
+ WritableFile* cc = nullptr;
+ TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h));
+ TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc));
+
+ // .h Header
+ const string include = GetPath(dot_h_fname);
+ const string guard = ToGuard(include);
+ // TODO(josh11b): Mention the library for which wrappers are being generated.
+ Status s;
+ s = h->Append(
+ strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
+ "#ifndef ",
+ guard,
+ "\n"
+ "#define ",
+ guard, R"header(
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+namespace ops {
+
+// These add a node to the graph from opts.
+//
+// Note for "NodeOut" inputs, you will typically either pass
+// * a {Node*, int index} (to pass the index-th output of that node), or
+// * a Node* (to pass the first output of that node).
+
+
+)header"));
+ TF_CHECK_OK(s);
+ // .cc Header
+ s = cc->Append(
+ strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
+ "#include \"",
+ include, R"header("
+
+#include "tensorflow/core/graph/node_builder.h"
+
+namespace tensorflow {
+namespace ops {
+
+)header"));
+ TF_CHECK_OK(s);
+
+ for (const auto& op_def : ops.op()) {
+ WriteCCOp(op_def, h, cc);
+ }
+
+ // .h Footer
+
+ s = h->Append(strings::StrCat(R"footer(} // namespace ops
+} // namespace tensorflow
+
+#endif // )footer",
+ guard, "\n"));
+ TF_CHECK_OK(s);
+
+ // .cc Footer
+
+ s = cc->Append(R"footer(} // namespace ops
+} // namespace tensorflow
+)footer");
+ TF_CHECK_OK(s);
+
+ TF_CHECK_OK(cc->Close());
+ TF_CHECK_OK(h->Close());
+}
+
+} // namespace tensorflow
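
To make the generator's output concrete: for an op with a single tensor input and no required attrs (the UnaryOp branch above), the emitted wrapper looks roughly like the sketch below. The op name "Identity" is used purely for illustration; the exact text (including the leading doc comment) depends on the OpDef being processed.

// In the generated .h:
Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts);

// In the generated .cc:
Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
  static const string kOpName = "Identity";
  return UnaryOp(kOpName, input, opts);
}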
diff --git a/tensorflow/cc/ops/cc_op_gen.h b/tensorflow/cc/ops/cc_op_gen.h
new file mode 100644
index 0000000000..1a9474ec87
--- /dev/null
+++ b/tensorflow/cc/ops/cc_op_gen.h
@@ -0,0 +1,14 @@
+#ifndef TENSORFLOW_CC_OPS_CC_OP_GEN_H_
+#define TENSORFLOW_CC_OPS_CC_OP_GEN_H_
+
+#include "tensorflow/core/framework/op_def.pb.h"
+
+namespace tensorflow {
+
+// Result is written to files dot_h and dot_cc.
+void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
+ const std::string& dot_cc_fname);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CC_OPS_CC_OP_GEN_H_
diff --git a/tensorflow/cc/ops/cc_op_gen_main.cc b/tensorflow/cc/ops/cc_op_gen_main.cc
new file mode 100644
index 0000000000..b9f0e4a9bd
--- /dev/null
+++ b/tensorflow/cc/ops/cc_op_gen_main.cc
@@ -0,0 +1,34 @@
+#include "tensorflow/cc/ops/cc_op_gen.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace {
+
+void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,
+ bool include_internal) {
+ OpList ops;
+ OpRegistry::Global()->Export(include_internal, &ops);
+ WriteCCOps(ops, dot_h, dot_cc);
+}
+
+} // namespace
+} // namespace tensorflow
+
+int main(int argc, char* argv[]) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+ if (argc != 4) {
+ fprintf(stderr,
+ "Usage: %s out.h out.cc include_internal\n"
+ " include_internal: 1 means include internal ops\n",
+ argv[0]);
+ exit(1);
+ }
+
+ bool include_internal = tensorflow::StringPiece("1") == argv[3];
+ tensorflow::PrintAllCCOps(argv[1], argv[2], include_internal);
+ return 0;
+}
diff --git a/tensorflow/cc/ops/const_op.cc b/tensorflow/cc/ops/const_op.cc
new file mode 100644
index 0000000000..e428e4f35e
--- /dev/null
+++ b/tensorflow/cc/ops/const_op.cc
@@ -0,0 +1,113 @@
+#include "tensorflow/cc/ops/const_op.h"
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace ops {
+
+namespace {
+const string& OpName() {
+ static const string kOpName = "Const";
+ return kOpName;
+}
+} // namespace
+
+#define DEFINE_CONST_SCALAR(TYPE) \
+ Node* Const(TYPE s, const GraphDefBuilder::Options& options) { \
+ return Const(gtl::ArraySlice<TYPE>(&s, 1), TensorShape({}), options); \
+ }
+
+#define DEFINE_CONST_VECTOR(TYPE) \
+ Node* Const(gtl::ArraySlice<TYPE> v, \
+ const GraphDefBuilder::Options& options) { \
+ return Const(v, TensorShape({static_cast<int64>(v.size())}), options); \
+ }
+
+#define DEFINE_CONST_TENSOR(TYPE, ...) \
+ Node* Const(gtl::ArraySlice<TYPE> t, const TensorShape& shape, \
+ const GraphDefBuilder::Options& options) { \
+ if (options.HaveError()) return nullptr; \
+ NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(), \
+ options.op_registry()); \
+ const DataType dt = DataTypeToEnum<TYPE>::v(); \
+ if (t.size() == 1) { \
+ TensorProto proto; \
+ proto.set_dtype(dt); \
+ shape.AsProto(proto.mutable_tensor_shape()); \
+ __VA_ARGS__; \
+ node_builder.Attr("dtype", dt).Attr("value", proto); \
+ } else { \
+ Tensor tensor(dt, shape); \
+ if (tensor.NumElements() != static_cast<int64>(t.size())) { \
+ options.UpdateStatus(errors::InvalidArgument( \
+ t.size(), " values provided to Const() != ", tensor.NumElements(), \
+ " elements for shape ", shape.ShortDebugString())); \
+ } else { \
+ std::copy_n(t.data(), t.size(), tensor.flat<TYPE>().data()); \
+ node_builder.Attr("dtype", dt).Attr("value", tensor); \
+ } \
+ } \
+ return options.FinalizeBuilder(&node_builder); \
+ }
+
+#define DEFINE_CONST_IMPL(TYPE, ...) \
+ DEFINE_CONST_SCALAR(TYPE) \
+ DEFINE_CONST_VECTOR(TYPE) \
+ DEFINE_CONST_TENSOR(TYPE, __VA_ARGS__)
+
+#define DEFINE_CONST(TYPE, FIELD) \
+ DEFINE_CONST_IMPL(TYPE, proto.add_##FIELD(*t.begin());)
+
+DEFINE_CONST(float, float_val);
+DEFINE_CONST(double, double_val);
+DEFINE_CONST(int32, int_val);
+DEFINE_CONST(uint8, int_val);
+DEFINE_CONST(int16, int_val);
+DEFINE_CONST(int8, int_val);
+DEFINE_CONST(int64, int64_val);
+DEFINE_CONST(bool, bool_val);
+
+DEFINE_CONST_IMPL(complex64, proto.add_scomplex_val(t.begin()->real());
+ proto.add_scomplex_val(t.begin()->imag()););
+
+Node* Const(StringPiece s, const GraphDefBuilder::Options& options) {
+ if (options.HaveError()) return nullptr;
+ NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(),
+ options.op_registry());
+ TensorProto proto;
+ proto.set_dtype(DT_STRING);
+ TensorShape({}).AsProto(proto.mutable_tensor_shape());
+ proto.add_string_val(s.data(), s.size());
+ node_builder.Attr("dtype", DT_STRING).Attr("value", proto);
+ return options.FinalizeBuilder(&node_builder);
+}
+
+DEFINE_CONST_VECTOR(string)
+DEFINE_CONST_TENSOR(string, proto.add_string_val(*t.begin());)
+
+#undef DEFINE_CONST
+#undef DEFINE_CONST_IMPL
+#undef DEFINE_CONST_TENSOR
+#undef DEFINE_CONST_VECTOR
+#undef DEFINE_CONST_SCALAR
+
+Node* Const(const Tensor& t, const GraphDefBuilder::Options& options) {
+ if (options.HaveError()) return nullptr;
+ NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(),
+ options.op_registry());
+ node_builder.Attr("dtype", t.dtype()).Attr("value", t);
+ return options.FinalizeBuilder(&node_builder);
+}
+
+Node* Const(const TensorProto& proto, const GraphDefBuilder::Options& options) {
+ if (options.HaveError()) return nullptr;
+ NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(),
+ options.op_registry());
+ node_builder.Attr("dtype", proto.dtype()).Attr("value", proto);
+ return options.FinalizeBuilder(&node_builder);
+}
+
+} // namespace ops
+} // namespace tensorflow
diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h
new file mode 100644
index 0000000000..1fb739b974
--- /dev/null
+++ b/tensorflow/cc/ops/const_op.h
@@ -0,0 +1,70 @@
+#ifndef TENSORFLOW_CC_OPS_CONST_OP_H_
+#define TENSORFLOW_CC_OPS_CONST_OP_H_
+
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+namespace ops {
+
+// If a shape is specified, you may either provide the same number of values,
+// or a single value and that value will be duplicated to fill out the Tensor.
+#define DECLARE_CONST(TYPE) \
+ Node* Const(TYPE s, const GraphDefBuilder::Options& options); /* Scalar */ \
+ Node* Const(gtl::ArraySlice<TYPE> v, \
+ const GraphDefBuilder::Options& options); /* Vector */ \
+ Node* Const(gtl::ArraySlice<TYPE> t, const TensorShape& shape, \
+ const GraphDefBuilder::Options& options); /* Tensor */ \
+ inline Node* Const(std::initializer_list<TYPE> v, /* Vector using {...} */ \
+ const GraphDefBuilder::Options& options) { \
+ return Const(gtl::ArraySlice<TYPE>(v), options); \
+ } \
+ inline Node* Const(std::initializer_list<TYPE> t, /* Tensor using {...} */ \
+ const TensorShape& shape, \
+ const GraphDefBuilder::Options& options) { \
+ return Const(gtl::ArraySlice<TYPE>(t), shape, options); \
+ }
+
+DECLARE_CONST(float);
+DECLARE_CONST(double);
+DECLARE_CONST(int32);
+DECLARE_CONST(uint8);
+DECLARE_CONST(int16);
+DECLARE_CONST(int8);
+DECLARE_CONST(complex64);
+DECLARE_CONST(int64);
+DECLARE_CONST(bool);
+
+#undef DECLARE_CONST
+
+// String
+Node* Const(StringPiece s, const GraphDefBuilder::Options& options);
+Node* Const(gtl::ArraySlice<string> v, const GraphDefBuilder::Options& options);
+Node* Const(gtl::ArraySlice<string> t, const TensorShape& shape,
+ const GraphDefBuilder::Options& options);
+inline Node* Const(std::initializer_list<string> v,
+ const GraphDefBuilder::Options& options) {
+ return Const(gtl::ArraySlice<string>(v), options);
+}
+inline Node* Const(std::initializer_list<string> t, const TensorShape& shape,
+ const GraphDefBuilder::Options& options) {
+ return Const(gtl::ArraySlice<string>(t), shape, options);
+}
+
+// A Tensor of any type.
+Node* Const(const Tensor& t, const GraphDefBuilder::Options& options);
+Node* Const(const TensorProto& proto, const GraphDefBuilder::Options& options);
+
+template <class T>
+Node* EmptyConst(const GraphDefBuilder::Options& options) {
+ return Const(gtl::ArraySlice<T>(), options);
+}
+
+// TODO(josh11b): Support other types (e.g. quantized ints, float16).
+
+} // namespace ops
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CC_OPS_CONST_OP_H_
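
A brief usage sketch of these overloads (illustrative only; the builder and node names are arbitrary). It mirrors how example_trainer.cc below builds constants, including the single-value-fills-the-shape case described in the comment above:

#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph_def_builder.h"

void BuildSomeConstants() {
  tensorflow::GraphDefBuilder b;
  using namespace tensorflow::ops;  // NOLINT(build/namespaces)
  tensorflow::Node* s = Const(3.0f, b.opts());                           // scalar
  tensorflow::Node* v = Const({1.f, 2.f, 3.f}, b.opts());                // length-3 vector
  tensorflow::Node* m = Const({3.f, 2.f, -1.f, 0.f}, {2, 2}, b.opts());  // 2x2, row major
  tensorflow::Node* z = Const({0.f}, {2, 1}, b.opts());                  // one value fills 2x1
  tensorflow::GraphDef def;
  TF_CHECK_OK(b.ToGraphDef(&def));
}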
diff --git a/tensorflow/cc/ops/functional_grad.cc b/tensorflow/cc/ops/functional_grad.cc
new file mode 100644
index 0000000000..28b8b4a0e5
--- /dev/null
+++ b/tensorflow/cc/ops/functional_grad.cc
@@ -0,0 +1,42 @@
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+typedef FunctionDefHelper FDH;
+
+Status MapAccumulateGrad(const AttrSlice& attrs, FunctionDef* ret) {
+ const NameAttrList* func;
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "f", &func));
+ DataType T;
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
+ int k;
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "K", &k));
+ // The gradient function of f.
+ // f : (K*T, T, T) -> T
+ // g : (K*T, T, T, T) -> (K*T, T, T)
+ auto grad = FDH::FunctionRef("SymbolicGradient",
+ {{"f", *func},
+ {"Tin", std::vector<DataType>(k + 3, T)},
+ {"Tout", std::vector<DataType>(k + 2, T)}});
+ *ret = FDH::Define(
+ // Arg defs
+ {"theta: K*T", "x: T", "u: T", "dy: T"},
+ // Ret val defs
+ {"dtheta: K*T", "dx: T", "du: T"},
+ // Attr defs
+ {{"T: {float, double}"}},
+ // nodes.
+ {{{"y"},
+ "MapAccumulate",
+ {"theta", "x", "u"},
+ {{"f", *func}, {"T", "$T"}, {"K", k}}},
+ {{"dtheta", "dx", "du"},
+ "MapAccumulateGrad",
+ {"theta", "x", "u", "y", "dy"},
+ {{"g", grad}, {"T", "$T"}, {"K", k}}}});
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("MapAccumulate", MapAccumulateGrad);
+
+} // end namespace tensorflow
diff --git a/tensorflow/cc/ops/math_grad.cc b/tensorflow/cc/ops/math_grad.cc
new file mode 100644
index 0000000000..4e8baa0d10
--- /dev/null
+++ b/tensorflow/cc/ops/math_grad.cc
@@ -0,0 +1,566 @@
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+typedef FunctionDefHelper FDH;
+
+// Cwise unary ops
+Status GradForUnaryCwise(FunctionDef* g, std::vector<FDH::Node> nodes) {
+ for (auto& n : nodes) {
+ if (n.attr.empty()) {
+ n.attr = {{"T", "$T"}};
+ }
+ }
+ *g = FDH::Define(
+ // Arg defs
+ {"x: T", "dy: T"},
+ // Ret val defs
+ {"dx: T"},
+ // Attr defs
+ {{"T: {float, double}"}},
+ // Nodes
+ nodes);
+ return Status::OK();
+}
+
+Status AbsGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"sign"}, "Sign", {"x"}},
+ {{"dx"}, "Mul", {"dy", "sign"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Abs", AbsGrad);
+
+Status NegGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"dx"}, "Neg", {"dy"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Neg", NegGrad);
+
+Status InvGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"y"}, "Inv", {"x"}},
+ {{"y2"}, "Square", {"y"}},
+ {{"y2_neg"}, "Neg", {"y2"}},
+ {{"dx"}, "Mul", {"dy", "y2_neg"}}
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Inv", InvGrad);
+
+Status SquareGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ FDH::Const("c", 2LL),
+ {{"two"}, "Cast", {"c"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+ {{"x2"}, "Mul", {"x", "two"}}, // x * 2
+ {{"dx"}, "Mul", {"dy", "x2"}}, // dy * (x * 2)
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Square", SquareGrad);
+
+Status SqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"y"}, "Sqrt", {"x"}},
+ {{"y_inv"}, "Inv", {"y"}},
+ FDH::Const("const", 0.5f),
+ {{"half"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
+ {{"a"}, "Mul", {"half", "y_inv"}}, // .5 * 1/y
+ {{"dx"}, "Mul", {"dy", "a"}}, // dy * (.5 * 1/y)
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Sqrt", SqrtGrad);
+
+Status RsqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"x_inv"}, "Inv", {"x"}},
+ {{"y"}, "Rsqrt", {"x"}},
+ FDH::Const("const", -.5f),
+ {{"neghalf"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
+ {{"a"}, "Mul", {"neghalf", "x_inv"}}, // -0.5 * 1/x
+ {{"b"}, "Mul", {"a", "y"}}, // -0.5 * 1/x * y
+ {{"dx"}, "Mul", {"dy", "b"}}, // dy * (1/y * .5)
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Rsqrt", RsqrtGrad);
+
+Status ExpGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"y"}, "Exp", {"x"}},
+ {{"dx"}, "Mul", {"dy", "y"}}, // dy * y
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Exp", ExpGrad);
+
+Status LogGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"x_inv"}, "Inv", {"x"}},
+ {{"dx"}, "Mul", {"dy", "x_inv"}}, // dy * 1/x
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Log", LogGrad);
+
+Status TanhGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"y"}, "Tanh", {"x"}},
+ {{"y2"}, "Square", {"y"}},
+ FDH::Const("const", 1.0f),
+ {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
+ {{"a"}, "Sub", {"one", "y2"}},
+ {{"dx"}, "Mul", {"dy", "a"}}, // dy * (1 - y*y)
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Tanh", TanhGrad);
+
+Status SigmoidGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"y"}, "Sigmoid", {"x"}},
+ FDH::Const("const", 1.0f),
+ {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
+ {{"a"}, "Sub", {"one", "y"}},
+ {{"b"}, "Mul", {"y", "a"}}, // y * (1 - y)
+ {{"dx"}, "Mul", {"dy", "b"}}, // dy * y * (1 - y)
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Sigmoid", SigmoidGrad);
+
+Status SignGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"s"}, "Shape", {"x"}},
+ FDH::Const("zero", 0.f),
+ {{"val"}, "Cast", {"zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
+ {{"dx"}, "Fill", {"s", "val"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Sign", SignGrad);
+
+Status SinGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"cos"}, "Cos", {"x"}},
+ {{"dx"}, "Mul", {"dy", "cos"}}, // dy * cos(x)
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Sin", SinGrad);
+
+Status CosGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"sin"}, "Sin", {"x"}},
+ {{"neg"}, "Neg", {"sin"}},
+ {{"dx"}, "Mul", {"dy", "neg"}}, // dy * (-sin(x))
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Cos", CosGrad);
+
+Status RealGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ FDH::Const("zero", 0.f),
+ {{"dx"}, "Complex", {"dy", "zero"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Real", RealGrad);
+
+Status ImagGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ FDH::Const("zero", 0.f),
+ {{"dx"}, "Complex", {"zero", "dy"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Imag", ImagGrad);
+
+Status ConjGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"dx"}, "Conj", {"dy"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Conj", ConjGrad);
+
+// Cwise binary ops
+//
+// TODO(zhifengc): This can be arranged as a function in the standard
+// library.
+Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) {
+ // clang-format off
+ std::vector<FDH::Node> nodes = {
+ {{"sx"}, "Shape", {"x"}},
+ {{"sy"}, "Shape", {"y"}},
+ };
+ nodes.insert(nodes.end(), body.begin(), body.end());
+ std::vector<FDH::Node> reshapes = {
+ {{"sum_gx"}, "Sum", {"gx", "rx"}},
+ {{"dx"}, "Reshape", {"sum_gx", "sx"}},
+ {{"sum_gy"}, "Sum", {"gy", "ry"}},
+ {{"dy"}, "Reshape", {"sum_gy", "sy"}},
+ };
+ nodes.insert(nodes.end(), reshapes.begin(), reshapes.end());
+
+ // clang-format on
+ for (auto& n : nodes) {
+ if (n.attr.empty()) {
+ n.attr = {{"T", "$T"}};
+ }
+ }
+ // "BroadcastGradientArgs" doesn't need any attrs.
+ nodes.push_back({{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}});
+ *g = FDH::Define(
+ // Arg defs
+ {"x: T", "y: T", "dz: T"},
+ // Ret val defs
+ {"dx: T", "dy: T"},
+ // Attr defs
+ {{"T: {float, double}"}},
+ // Nodes
+ nodes);
+ return Status::OK();
+}
+
+Status AddGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"gx"}, "Identity", {"dz"}},
+ {{"gy"}, "Identity", {"dz"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Add", AddGrad);
+
+Status SubGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"gx"}, "Identity", {"dz"}},
+ {{"gy"}, "Neg", {"dz"}}, // -dz
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Sub", SubGrad);
+
+Status MulGrad(const AttrSlice& attrs, FunctionDef* g) {
+ DataType T;
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
+ if (T == DT_COMPLEX64) {
+ return GradForBinaryCwise(
+ g, {
+ {{"cy"}, "Conj", {"y"}},
+ {{"gx"}, "Mul", {"dz", "cy"}}, // dz * Conj(y)
+ {{"cx"}, "Conj", {"x"}},
+ {{"gy"}, "Mul", {"cx", "dz"}}, // Conj(x) * dz
+ });
+ } else {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"gx"}, "Mul", {"dz", "y"}}, // dz * y
+ {{"gy"}, "Mul", {"x", "dz"}}, // x * dz
+ });
+ // clang-format on
+ }
+}
+REGISTER_OP_GRADIENT("Mul", MulGrad);
+
+Status DivGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"gx"}, "Div", {"dz", "y"}},
+ {{"nx"}, "Neg", {"x"}},
+ {{"y2"}, "Square", {"y"}},
+ {{"nx_y2"}, "Div", {"nx", "y2"}},
+ {{"gy"}, "Mul", {"dz", "nx_y2"}}, // dz * (- x / y^2)
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Div", DivGrad);
+
+Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"z"}, "Pow", {"x", "y"}},
+ // dz * y * Pow(x, y - 1)
+ FDH::Const("const", 1.0f),
+ {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
+ {{"t0"}, "Sub", {"y", "one"}},
+ {{"t1"}, "Pow", {"x", "t0"}},
+ {{"t2"}, "Mul", {"dz", "y"}},
+ {{"gx"}, "Mul", {"t1", "t2"}},
+ // dz * z * Log(x)
+ {{"t3"}, "Log", {"x"}},
+ {{"t4"}, "Mul", {"dz", "z"}},
+ {{"gy"}, "Mul", {"t3", "t4"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Pow", PowGrad);
+
+Status MaximumMinimumGradHelper(const string& comparator,
+ const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"c"}, comparator, {"x", "y"}},
+ {{"mask"}, "Cast", {"c"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
+ {{"gx"}, "Mul", {"dz", "mask"}},
+ {{"gy"}, "Sub", {"dz", "gx"}},
+ });
+ // clang-format on
+}
+
+Status MaximumGrad(const AttrSlice& attrs, FunctionDef* g) {
+ return MaximumMinimumGradHelper("GreaterEqual", attrs, g);
+}
+REGISTER_OP_GRADIENT("Maximum", MaximumGrad);
+
+Status MinimumGrad(const AttrSlice& attrs, FunctionDef* g) {
+ return MaximumMinimumGradHelper("LessEqual", attrs, g);
+}
+REGISTER_OP_GRADIENT("Minimum", MinimumGrad);
+
+Status ComplexGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"gx"}, "Real", {"dz"}},
+ {{"gy"}, "Imag", {"dz"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Complex", ComplexGrad);
+
+// Cwise ternary ops.
+Status SelectGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ *g = FDH::Define(
+ {"c:bool", "x:T", "y:T", "dz:T"},
+ {"dc:bool", "dx:T", "dy:T"},
+ {{"T: {float, double}"}},
+ {
+ {{"dc"}, "ZerosLike", {"c"}, {{"T", DT_BOOL}}},
+ {{"zeros"}, "ZerosLike", {"x"}, {{"T", "$T"}}},
+ {{"dx"}, "Select", {"c", "dz", "zeros"}, {{"T", "$T"}}},
+ {{"dy"}, "Select", {"c", "zeros", "dz"}, {{"T", "$T"}}},
+ });
+ // clang-format on
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("Select", SelectGrad);
+
+// N-ary ops
+// REGISTER_OP_GRADIENT("AddN", AddNGrad);
+
+// Reduction ops
+//
+// TODO(zhifengc): This helper is pretty ugly. Do something better.
+// TODO(zhifengc): This can be arranged as a function in the standard library.
+Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) {
+ // Shape manipulation nodes.
+
+ // clang-format off
+ std::vector<FDH::Node> nodes = {
+ {{"x_shape"}, "Shape", {"x"}},
+ {{"x_rank"}, "Rank", {"x"}},
+ {{"i_shape"}, "Shape", {"i"}, {{"T", DT_INT32}}},
+ FDH::Const("zero", 0),
+ FDH::Const("one", 1),
+ // stitch_idx0 = Range(0, x_rank, 1)
+ {{"stitch_idx1"}, "Identity", {"i"}, {{"T", DT_INT32}}},
+ {{"stitch_idx"}, "_ListToArray", {"stitch_idx0", "stitch_idx1"},
+ {{"Tin", DataTypeSlice{DT_INT32, DT_INT32}},
+ {"T", DT_INT32}, {"N", 2}}},
+ {{"stitch_val0"}, "Identity", {"x_shape"}, {{"T", DT_INT32}}},
+ {{"stitch_val1"}, "Fill", {"i_shape", "one"}, {{"T", DT_INT32}}},
+ {{"stitch_val"}, "_ListToArray", {"stitch_val0", "stitch_val1"},
+ {{"Tin", DataTypeSlice{DT_INT32, DT_INT32}},
+ {"T", DT_INT32}, {"N", 2}}},
+ {{"y_shape"}, "DynamicStitch", {"stitch_idx", "stitch_val"},
+ {{"N", 2}, {"T", DT_INT32}}},
+ {{"tile_scaling"}, "Div", {"x_shape", "y_shape"}, {{"T", DT_INT32}}},
+ {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
+ };
+ // clang-format on
+ nodes.insert(nodes.end(), body.begin(), body.end());
+ for (auto& n : nodes) {
+ if (n.attr.empty()) {
+ n.attr = {{"T", "$T"}};
+ }
+ }
+ // "Range" doesn't need any attr.
+ nodes.push_back({{"stitch_idx0"}, "Range", {"zero", "x_rank", "one"}, {}});
+ *g = FDH::Define(
+ // Arg defs
+ {"x:T", "i:int32", "dy:T"},
+ // Ret val defs
+ {"dx:T", "di:int32"},
+ // Attr defs
+ {{"T: {float, double}"}},
+ // Nodes
+ nodes);
+ return Status::OK();
+}
+
+Status SumGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForReductionOp(g, {
+ {{"dy_reshaped"}, "Reshape", {"dy", "y_shape"}},
+ {{"dx"}, "Tile", {"dy_reshaped", "tile_scaling"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Sum", SumGrad);
+
+Status MeanGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForReductionOp(g, {
+ {{"factor"}, "Prod", {"tile_scaling", "zero"}, {{"T", DT_INT32}}},
+ {{"factor_T"}, "Cast", {"factor"}, {{"SrcT", DT_INT32}, {"DstT", "$T"}}},
+ {{"dy_scaled"}, "Div", {"dy", "factor_T"}},
+ {{"dy_reshaped"}, "Reshape", {"dy_scaled", "y_shape"}},
+ {{"dx"}, "Tile", {"dy_reshaped", "tile_scaling"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Mean", MeanGrad);
+
+// REGISTER_OP_GRADIENT("Prod", ProdGrad);
+// REGISTER_OP_GRADIENT("SegmentSum", SegmentSumGrad);
+// REGISTER_OP_GRADIENT("SegmentMean", SegmentMeanGrad);
+// REGISTER_OP_GRADIENT("SparseSegmentSum", SparseSegmentSumGrad);
+// REGISTER_OP_GRADIENT("SparseSegmentMean", SparseSegmentMeanGrad);
+// REGISTER_OP_GRADIENT("SegmentMin", SegmentMinGrad);
+// REGISTER_OP_GRADIENT("SegmentMax", SegmentMaxGrad);
+// REGISTER_OP_GRADIENT("UnsortedSegmentSum", UnsortedSegmentSumGrad);
+
+Status MinMaxGradHelper(const string& op, const AttrSlice& attrs,
+ FunctionDef* g) {
+ // clang-format off
+ *g = FDH::Define(
+ // Arg defs
+ {"x:T", "i:int32", "dy:T"},
+ // Ret val defs
+ {"dx:T", "di:int32"},
+ // Attr defs
+ {{"T: {float, double}"}},
+ {
+ // keep_dims because we need to do x == y, which requires x
+ // and y to be broadcastable.
+ {{"y"}, op, {"x", "i"}, {{"T", "$T"}, {"keep_dims", true}}},
+ {{"mask"}, "Equal", {"x", "y"}, {{"T", "$T"}}},
+ {{"mask_cast"}, "Cast", {"mask"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
+ {{"mask_sum"}, "Sum", {"mask_cast", "i"}, {{"T", "$T"}}},
+ {{"norm_dy"}, "Div", {"dy", "mask_sum"}, {{"T", "$T"}}},
+ {{"sy"}, "Shape", {"y"}, {{"T", "$T"}}},
+ {{"norm_dy_reshaped"}, "Reshape", {"norm_dy", "sy"}, {{"T", "$T"}}},
+ {{"dx"}, "Mul", {"mask_cast", "norm_dy_reshaped"}, {{"T", "$T"}}},
+ {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
+ });
+ // clang-format on
+ return Status::OK();
+}
+
+Status MaxGrad(const AttrSlice& attrs, FunctionDef* g) {
+ return MinMaxGradHelper("Max", attrs, g);
+}
+REGISTER_OP_GRADIENT("Max", MaxGrad);
+
+Status MinGrad(const AttrSlice& attrs, FunctionDef* g) {
+ return MinMaxGradHelper("Min", attrs, g);
+}
+REGISTER_OP_GRADIENT("Min", MinGrad);
+
+static Status MatMulGradHelper(FunctionDef* g, const string& x0, bool tx0,
+ const string& x1, bool tx1, const string& y0,
+ bool ty0, const string& y1, bool ty1) {
+ *g = FDH::Define(
+ // Arg defs
+ {"x: T", "y: T", "dz: T"},
+ // Ret val defs
+ {"dx: T", "dy: T"},
+ // Attr defs
+ {{"T: {float, double}"}},
+ // Nodes
+ {
+ {{"dx"},
+ "MatMul",
+ {x0, x1},
+ {{"T", "$T"}, {"transpose_a", tx0}, {"transpose_b", tx1}}},
+ {{"dy"},
+ "MatMul",
+ {y0, y1},
+ {{"T", "$T"}, {"transpose_a", ty0}, {"transpose_b", ty1}}},
+ });
+ return Status::OK();
+}
+
+Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
+ DataType T;
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
+ if (T == DT_COMPLEX64) {
+ return errors::Unimplemented(
+ "MatMul gradient for complex is not supported yet.");
+ }
+ bool ta;
+ bool tb;
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "transpose_a", &ta));
+ TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "transpose_b", &tb));
+ if (!ta && !tb) {
+ return MatMulGradHelper(g, "dz", false, "y", true, "x", true, "dz", false);
+ }
+ if (!ta && tb) {
+ return MatMulGradHelper(g, "dz", false, "y", false, "dz", true, "x", false);
+ }
+ if (ta && !tb) {
+ return MatMulGradHelper(g, "y", false, "dz", true, "x", false, "dz", false);
+ }
+ CHECK(ta && tb);
+ return MatMulGradHelper(g, "y", true, "dz", true, "dz", true, "x", true);
+}
+REGISTER_OP_GRADIENT("MatMul", MatMulGrad);
+
+// REGISTER_OP_GRADIENT("SparseMatMul", SparseMatMulGrad);
+// REGISTER_OP_GRADIENT("BatchMatMul", BatchMatMulGrad);
+
+// Comparison ops.
+REGISTER_OP_NO_GRADIENT("Less");
+REGISTER_OP_NO_GRADIENT("LessEqual");
+REGISTER_OP_NO_GRADIENT("Greater");
+REGISTER_OP_NO_GRADIENT("GreaterEqual");
+REGISTER_OP_NO_GRADIENT("Equal");
+REGISTER_OP_NO_GRADIENT("NotEqual");
+
+// Logical ops.
+REGISTER_OP_NO_GRADIENT("LogicalAnd");
+REGISTER_OP_NO_GRADIENT("LogicalOr");
+REGISTER_OP_NO_GRADIENT("LogicalNot");
+
+// Sequence generation ops.
+REGISTER_OP_NO_GRADIENT("Range");
+REGISTER_OP_NO_GRADIENT("LinSpace");
+
+} // end namespace tensorflow
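
The GradForUnaryCwise helper makes further elementwise gradients easy to express. As a hypothetical example in the same style (not registered by this change), a "Softplus" gradient could use d/dx softplus(x) = sigmoid(x):

Status SoftplusGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"s"}, "Sigmoid", {"x"}},
      {{"dx"}, "Mul", {"dy", "s"}},  // dy * sigmoid(x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Softplus", SoftplusGrad);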
diff --git a/tensorflow/cc/ops/nn_grad.cc b/tensorflow/cc/ops/nn_grad.cc
new file mode 100644
index 0000000000..89b037e3c8
--- /dev/null
+++ b/tensorflow/cc/ops/nn_grad.cc
@@ -0,0 +1,55 @@
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+typedef FunctionDefHelper FDH;
+
+Status ReluGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ *g = FDH::Define(
+ // Arg defs
+ {"x: T", "dy: T"},
+ // Ret val defs
+ {"dx: T"},
+ // Attr defs
+ {{"T: {float, double}"}},
+ // Nodes
+ {
+ {{"dx"}, "ReluGrad", {"dy", "x"}, {{"T", "$T"}}}
+ });
+ // clang-format on
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("Relu", ReluGrad);
+
+Status CrossEntropyGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ *g = FDH::Define(
+ // Arg defs
+ {"features: T", "labels: T", "dcost_dloss: T", "donotcare: T"},
+ // Ret val defs
+ {"dcost_dfeatures: T", "dcost_dlabels: T"},
+ // Attr defs
+ {{"T: {float, double}"}},
+ // Nodes
+ {
+ // _, dloss_dfeatures = CrossEntropy(features, labels)
+ {{"donotcare_loss", "dloss_dfeatures"}, "CrossEntropy",
+ {"features", "labels"}, {{"T", "$T"}}},
+ // dcost_dloss is of shape [batch_size].
+ // dcost_dloss_mat is of shape [batch_size, 1].
+ FDH::Const("neg1", -1),
+ {{"dcost_dloss_mat"}, "ExpandDims", {"dcost_dloss", "neg1"},
+ {{"T", "$T"}}},
+ // chain rule: dcost/dfeatures = dcost/dloss * dloss/dfeatures
+ {{"dcost_dfeatures"}, "Mul", {"dcost_dloss_mat", "dloss_dfeatures"},
+ {{"T", "$T"}}},
+ {{"dcost_dlabels"}, "ZerosLike", {"labels"}, {{"T", "$T"}}},
+ });
+ // clang-format on
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("CrossEntropy", CrossEntropyGrad);
+
+} // end namespace tensorflow
diff --git a/tensorflow/cc/ops/standard_ops.h b/tensorflow/cc/ops/standard_ops.h
new file mode 100644
index 0000000000..8d7160a7f9
--- /dev/null
+++ b/tensorflow/cc/ops/standard_ops.h
@@ -0,0 +1,26 @@
+// #include this file to get access to the standard set of C++ graph
+// definition libraries.
+
+#ifndef TENSORFLOW_CC_OPS_STANDARD_OPS_H_
+#define TENSORFLOW_CC_OPS_STANDARD_OPS_H_
+
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/attention_ops.h"
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/data_flow_ops.h"
+#include "tensorflow/cc/ops/image_ops.h"
+#include "tensorflow/cc/ops/io_ops.h"
+#include "tensorflow/cc/ops/linalg_ops.h"
+#include "tensorflow/cc/ops/logging_ops.h"
+#include "tensorflow/cc/ops/math_ops.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/parsing_ops.h"
+#include "tensorflow/cc/ops/random_ops.h"
+#include "tensorflow/cc/ops/sparse_ops.h"
+#include "tensorflow/cc/ops/state_ops.h"
+#include "tensorflow/cc/ops/string_ops.h"
+#include "tensorflow/cc/ops/summary_ops.h"
+#include "tensorflow/cc/ops/training_ops.h"
+#include "tensorflow/cc/ops/user_ops.h"
+
+#endif // TENSORFLOW_CC_OPS_STANDARD_OPS_H_
diff --git a/tensorflow/cc/tutorials/example_trainer.cc b/tensorflow/cc/tutorials/example_trainer.cc
new file mode 100644
index 0000000000..49046dd220
--- /dev/null
+++ b/tensorflow/cc/tutorials/example_trainer.cc
@@ -0,0 +1,146 @@
+#include <cstdio>
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/default_device.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/core/command_line_flags.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+namespace example {
+
+struct Options {
+ int num_concurrent_sessions = 10; // The number of concurrent sessions
+ int num_concurrent_steps = 10; // The number of concurrent steps
+ int num_iterations = 100; // Each step repeats this many times
+ bool use_gpu = false; // Whether to use gpu in the training
+};
+
+TF_DEFINE_int32(num_concurrent_sessions, 10, "Number of concurrent sessions");
+TF_DEFINE_int32(num_concurrent_steps, 10, "Number of concurrent steps");
+TF_DEFINE_int32(num_iterations, 100, "Number of iterations");
+TF_DEFINE_bool(use_gpu, false, "Whether to use gpu in the training");
+
+// A = [3 2; -1 0]; x = rand(2, 1);
+// We want to compute the largest eigenvalue for A.
+// (Here det(A - lambda*I) = (lambda - 1)(lambda - 2), so the iteration
+// should converge to lambda = 2.)
+// repeat x = y / y.norm(); y = A * x; end
+GraphDef CreateGraphDef() {
+ // TODO(jeff,opensource): This should really be a more interesting
+ // computation. Maybe turn this into an MNIST model instead?
+ GraphDefBuilder b;
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ // Store rows [3, 2] and [-1, 0] in row major format.
+ Node* a = Const({3.f, 2.f, -1.f, 0.f}, {2, 2}, b.opts());
+
+ // x is from the feed.
+ Node* x = Const({0.f}, {2, 1}, b.opts().WithName("x"));
+
+ // y = A * x
+ Node* y = MatMul(a, x, b.opts().WithName("y"));
+
+ // y2 = y.^2
+ Node* y2 = Square(y, b.opts());
+
+ // y2_sum = sum(y2)
+ Node* y2_sum = Sum(y2, Const(0, b.opts()), b.opts());
+
+ // y_norm = sqrt(y2_sum)
+ Node* y_norm = Sqrt(y2_sum, b.opts());
+
+ // y_normalized = y ./ y_norm
+ Div(y, y_norm, b.opts().WithName("y_normalized"));
+
+ GraphDef def;
+ TF_CHECK_OK(b.ToGraphDef(&def));
+ return def;
+}
+
+string DebugString(const Tensor& x, const Tensor& y) {
+ CHECK_EQ(x.NumElements(), 2);
+ CHECK_EQ(y.NumElements(), 2);
+ auto x_flat = x.flat<float>();
+ auto y_flat = y.flat<float>();
+ const float lambda = y_flat(0) / x_flat(0);
+ return strings::Printf("lambda = %8.6f x = [%8.6f %8.6f] y = [%8.6f %8.6f]",
+ lambda, x_flat(0), x_flat(1), y_flat(0), y_flat(1));
+}
+
+void ConcurrentSteps(const Options* opts, int session_index) {
+ // Creates a session.
+ SessionOptions options;
+ std::unique_ptr<Session> session(NewSession(options));
+ GraphDef def = CreateGraphDef();
+ if (options.target.empty()) {
+ graph::SetDefaultDevice(opts->use_gpu ? "/gpu:0" : "/cpu:0", &def);
+ }
+
+ TF_CHECK_OK(session->Create(def));
+
+ // Spawn M threads for M concurrent steps.
+ const int M = opts->num_concurrent_steps;
+ thread::ThreadPool step_threads(Env::Default(), "trainer", M);
+
+ for (int step = 0; step < M; ++step) {
+ step_threads.Schedule([&session, opts, session_index, step]() {
+ // Randomly initialize the input.
+ Tensor x(DT_FLOAT, TensorShape({2, 1}));
+ x.flat<float>().setRandom();
+
+ // Iterations.
+ std::vector<Tensor> outputs;
+ for (int iter = 0; iter < opts->num_iterations; ++iter) {
+ outputs.clear();
+ TF_CHECK_OK(
+ session->Run({{"x", x}}, {"y:0", "y_normalized:0"}, {}, &outputs));
+ CHECK_EQ(2, outputs.size());
+
+ const Tensor& y = outputs[0];
+ const Tensor& y_norm = outputs[1];
+ // Print out lambda, x, and y.
+ std::printf("%06d/%06d %s\n", session_index, step,
+ DebugString(x, y).c_str());
+ // Copies y_normalized to x.
+ x = y_norm;
+ }
+ });
+ }
+
+ TF_CHECK_OK(session->Close());
+}
+
+void ConcurrentSessions(const Options& opts) {
+ // Spawn N threads for N concurrent sessions.
+ const int N = opts.num_concurrent_sessions;
+ thread::ThreadPool session_threads(Env::Default(), "trainer", N);
+ for (int i = 0; i < N; ++i) {
+ session_threads.Schedule(std::bind(&ConcurrentSteps, &opts, i));
+ }
+}
+
+} // end namespace example
+} // end namespace tensorflow
+
+int main(int argc, char* argv[]) {
+ tensorflow::example::Options opts;
+ tensorflow::Status s = tensorflow::ParseCommandLineFlags(&argc, argv);
+ if (!s.ok()) {
+ LOG(FATAL) << "Error parsing command line flags: " << s.ToString();
+ }
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ opts.num_concurrent_sessions =
+ tensorflow::example::FLAGS_num_concurrent_sessions;
+ opts.num_concurrent_steps = tensorflow::example::FLAGS_num_concurrent_steps;
+ opts.num_iterations = tensorflow::example::FLAGS_num_iterations;
+ opts.use_gpu = tensorflow::example::FLAGS_use_gpu;
+ tensorflow::example::ConcurrentSessions(opts);
+}