diff options
Diffstat (limited to 'tensorflow/cc/ops')
-rw-r--r-- | tensorflow/cc/ops/array_grad.cc | 32 | ||||
-rw-r--r-- | tensorflow/cc/ops/cc_op_gen.cc | 350 | ||||
-rw-r--r-- | tensorflow/cc/ops/cc_op_gen.h | 14 | ||||
-rw-r--r-- | tensorflow/cc/ops/cc_op_gen_main.cc | 34 | ||||
-rw-r--r-- | tensorflow/cc/ops/const_op.cc | 113 | ||||
-rw-r--r-- | tensorflow/cc/ops/const_op.h | 70 | ||||
-rw-r--r-- | tensorflow/cc/ops/functional_grad.cc | 42 | ||||
-rw-r--r-- | tensorflow/cc/ops/math_grad.cc | 566 | ||||
-rw-r--r-- | tensorflow/cc/ops/nn_grad.cc | 55 | ||||
-rw-r--r-- | tensorflow/cc/ops/standard_ops.h | 26 |
10 files changed, 1302 insertions, 0 deletions
diff --git a/tensorflow/cc/ops/array_grad.cc b/tensorflow/cc/ops/array_grad.cc new file mode 100644 index 0000000000..37ffed9792 --- /dev/null +++ b/tensorflow/cc/ops/array_grad.cc @@ -0,0 +1,32 @@ +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +typedef FunctionDefHelper FDH; + +REGISTER_OP_NO_GRADIENT("Shape"); +REGISTER_OP_NO_GRADIENT("Rank"); +REGISTER_OP_NO_GRADIENT("Size"); + +Status ReshapeGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + *g = FDH::Define( + // Arg defs + {"x: T", "shape: int32", "dy: T"}, + // Ret val defs + {"dx: T", "dshape: int32"}, + // Attr defs + {{"T: {float, double}"}}, + // Nodes + { + {{"x_shape"}, "Shape", {"x"}, {{"T", "$T"}}}, + {{"dx"}, "Reshape", {"dy", "x_shape"}, {{"T", "$T"}}}, + {{"dshape"}, "ZerosLike", {"shape"}, {{"T", DT_INT32}}}, + }); + // clang-format on + return Status::OK(); +} +REGISTER_OP_GRADIENT("Reshape", ReshapeGrad); + +} // end namespace tensorflow diff --git a/tensorflow/cc/ops/cc_op_gen.cc b/tensorflow/cc/ops/cc_op_gen.cc new file mode 100644 index 0000000000..fdecf967f8 --- /dev/null +++ b/tensorflow/cc/ops/cc_op_gen.cc @@ -0,0 +1,350 @@ +// TODO(josh11b): Rewrite function parameter names to avoid C++ keywords +// or "opts". 
+ +#include "tensorflow/cc/ops/cc_op_gen.h" + +#include <unordered_map> +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { +namespace { + +const int kRightMargin = 79; + +// Returns the C++ parameter type spelled in generated wrappers for an attr of +// type attr_type. LOG(FATAL)s when attr_type is not in the table. +const char* AttrTypeName(StringPiece attr_type) { + static const char* kAttrTypeName[][2] = { + {"string", "StringPiece"}, + {"list(string)", "gtl::ArraySlice<string>"}, + {"int", "int64"}, + {"list(int)", "gtl::ArraySlice<int>"}, + {"float", "float"}, + {"list(float)", "gtl::ArraySlice<float>"}, + {"bool", "bool"}, + {"list(bool)", "gtl::ArraySlice<bool>"}, + {"type", "DataType"}, + {"list(type)", "DataTypeSlice"}, + {"shape", "TensorShape"}, + {"list(shape)", "gtl::ArraySlice<TensorShape>"}, + {"tensor", "const Tensor&"}, + {"list(tensor)", "gtl::ArraySlice<Tensor>"}, + {"func", "const NameAttrList&"}, + }; + for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) { + if (attr_type == kAttrTypeName[i][0]) { + return kAttrTypeName[i][1]; + } + } + LOG(FATAL) << "Unsupported Attr type: " << attr_type; + return ""; +} + +// Change: Into: +// ABC // ABC +// // +// DEF // DEF +// Trailing spaces on each line are dropped. The final line may or may not end +// in a newline. +string MakeComment(StringPiece text) { + string ret; + while (!text.empty()) { + int last_non_space = -1; + int newline; + for (newline = 0; newline < static_cast<int>(text.size()); ++newline) { + if (text[newline] == '\n') break; + if (text[newline] != ' ') last_non_space = newline; + } + if (last_non_space == -1) { + strings::StrAppend(&ret, "//\n"); + } else { + strings::StrAppend(&ret, "// ", text.substr(0, last_non_space + 1), 
"\n"); + } + // Also skip the newline, but never step past the end when the final line + // is unterminated (newline == text.size() in that case). + if (newline < static_cast<int>(text.size())) ++newline; + text.remove_prefix(newline); + } + return ret; +} + +// Emits the C++ wrapper for one op: the generated doc comment and the +// declaration are appended to *h, the definition is appended to *cc. +void WriteCCOp(const OpDef& op_def, WritableFile* h, WritableFile* cc) { + // TODO(josh11b): Better wrapping of comments. + string comment; + if (op_def.summary().empty()) { + comment = "TODO: add doc.\n"; + } else { + comment = strings::StrCat(op_def.summary(), "\n"); + if (!op_def.description().empty()) { + strings::StrAppend(&comment, "\n", op_def.description(), "\n"); + } + } + + static const string kSingleInputType = "NodeOut"; + static const string kListInputType = "gtl::ArraySlice<NodeOut>"; + + std::vector<string> arg_types; + std::vector<string> arg_names; + + strings::StrAppend(&comment, "\nArguments:\n"); + + // Map from attr name to the first input arg it is inferred from. + std::unordered_map<string, string> inferred_attrs; + for (int i = 0; i < op_def.input_arg_size(); ++i) { + const auto& arg(op_def.input_arg(i)); + arg_names.emplace_back(arg.name()); + bool is_list = false; + + if (!arg.type_attr().empty()) { + gtl::InsertIfNotPresent(&inferred_attrs, arg.type_attr(), arg.name()); + } else if (!arg.type_list_attr().empty()) { + gtl::InsertIfNotPresent(&inferred_attrs, arg.type_list_attr(), + arg.name()); + is_list = true; + } + if (!arg.number_attr().empty()) { + gtl::InsertIfNotPresent(&inferred_attrs, arg.number_attr(), arg.name()); + is_list = true; + } + if (is_list) { + arg_types.emplace_back(kListInputType); + } else { + arg_types.emplace_back(kSingleInputType); + } + + // TODO(josh11b): Include input type information. + StringPiece description = arg.description(); + if (!description.empty()) { + ConsumeEquals(&description); + // Emit the consumed description, as is done for output args below. + strings::StrAppend(&comment, "* ", arg_names.back(), ": ", + description, "\n"); + } + } + + string options_comment; + for (int i = 0; i < op_def.attr_size(); ++i) { + const auto& attr(op_def.attr(i)); + // Do not add inferred attrs or attrs with defaults to the C++ + // function signature. 
+ if (inferred_attrs.find(attr.name()) == inferred_attrs.end()) { + if (!attr.has_default_value()) { + arg_names.emplace_back(attr.name()); + arg_types.emplace_back(AttrTypeName(attr.type())); + if (!attr.description().empty()) { + strings::StrAppend(&comment, "* ", arg_names.back(), ": ", + attr.description(), "\n"); + } + } else { + strings::StrAppend(&options_comment, " .WithAttr(\"", attr.name(), + "\", ", AttrTypeName(attr.type()), "): Defaults to ", + SummarizeAttrValue(attr.default_value()), ".\n"); + if (!attr.description().empty()) { + strings::StrAppend(&options_comment, " ", attr.description(), + "\n"); + } + } + } + } + CHECK_EQ(arg_names.size(), arg_types.size()); + strings::StrAppend(&comment, "* opts:\n", options_comment, + R"comment( .WithName(StringPiece): Set the Node's name + .WithDevice(StringPiece): Set the Node's requested device + .WithControlInput(Node*) / .WithControlInputs({Node*, ...}): + Add control depencies on the specified Node(s). + +Returns a pointer to the created Node)comment"); + + // TODO(josh11b): Include output type information. + if (op_def.output_arg_size() == 0) { + strings::StrAppend(&comment, ".\n"); + } else if (op_def.output_arg_size() == 1) { + StringPiece description = op_def.output_arg(0).description(); + ConsumeEquals(&description); + if (description.empty()) { + strings::StrAppend(&comment, ".\n"); + } else { + strings::StrAppend(&comment, ", with output:\n", description, "\n"); + } + } else { + strings::StrAppend(&comment, ", with outputs:\n"); + for (int o = 0; o < op_def.output_arg_size(); ++o) { + StringPiece description = op_def.output_arg(o).description(); + ConsumeEquals(&description); + if (description.empty()) { + strings::StrAppend(&comment, "* ", op_def.output_arg(o).name(), "\n"); + } else { + strings::StrAppend(&comment, "* ", op_def.output_arg(o).name(), ": ", + description, "\n"); + } + } + } + + // Write the header comment. 
+ TF_CHECK_OK(h->Append(MakeComment(comment))); + + // Declare the function wrapper. + const string prefix = strings::StrCat("Node* ", op_def.name(), "("); + string h_rest; + for (size_t i = 0; i < arg_names.size(); ++i) { + strings::StrAppend(&h_rest, arg_types[i], " ", arg_names[i], ", "); + } + strings::StrAppend(&h_rest, "const GraphDefBuilder::Options& opts"); + string cc_decl = h_rest; + strings::StrAppend(&h_rest, ");"); + TF_CHECK_OK(h->Append(WordWrap(prefix, h_rest, kRightMargin) + "\n\n")); + + // Define the function wrapper. + strings::StrAppend(&cc_decl, ") {"); + TF_CHECK_OK(cc->Append(WordWrap(prefix, cc_decl, kRightMargin) + "\n")); + const string op_name = strings::StrCat(" static const string kOpName = \"", + op_def.name(), "\";\n"); + + if (arg_types.empty()) { + TF_CHECK_OK(cc->Append(op_name)); + TF_CHECK_OK(cc->Append(" return SourceOp(kOpName, opts);\n}\n\n")); + } else if (arg_types == std::vector<string>({kSingleInputType})) { + TF_CHECK_OK(cc->Append(op_name)); + TF_CHECK_OK(cc->Append(strings::StrCat(" return UnaryOp(kOpName, ", + arg_names[0], ", opts);\n}\n\n"))); + } else if (arg_types == + std::vector<string>({kSingleInputType, kSingleInputType})) { + TF_CHECK_OK(cc->Append(op_name)); + // TODO(josh11b): Word wrap this if it ever becomes necessary. 
+ TF_CHECK_OK( + cc->Append(strings::StrCat(" return BinaryOp(kOpName, ", arg_names[0], + ", ", arg_names[1], ", opts);\n}\n\n"))); + } else { + TF_CHECK_OK(cc->Append(" if (opts.HaveError()) return nullptr;\n")); + TF_CHECK_OK(cc->Append(op_name)); + TF_CHECK_OK(cc->Append( + " NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName,\n" + " opts.op_registry());\n")); + for (size_t i = 0; i < arg_names.size(); ++i) { + if (i < static_cast<size_t>(op_def.input_arg_size())) { + TF_CHECK_OK(cc->Append( + strings::StrCat(" node_builder.Input(", arg_names[i], ");\n"))); + } else { + TF_CHECK_OK( + cc->Append(strings::StrCat(" node_builder.Attr(\"", arg_names[i], + "\", ", arg_names[i], ");\n"))); + } + } + TF_CHECK_OK( + cc->Append(" return opts.FinalizeBuilder(&node_builder);\n" + "}\n\n")); + } +} + +// Converts: +// bazel-out/.../genfiles/XX +// to: XX. +string GetPath(const std::string& dot_h_fname) { + auto pos = dot_h_fname.find("/genfiles/"); + if (pos == string::npos) return dot_h_fname; + // - 1 account for the terminating null character (\0) in "/genfiles/". 
+ return dot_h_fname.substr(pos + sizeof("/genfiles/") - 1); +} + +// Converts: +// cc/ops/gen_foo_ops.h +// to: +// CC_OPS_GEN_FOO_OPS_H_ +string ToGuard(const std::string& path) { + string guard; + guard.reserve(path.size() + 1); // + 1 -> trailing _ + for (const char c : path) { + if (c >= 'A' && c <= 'Z') { + guard += c; + } else if (c >= 'a' && c <= 'z') { + guard += c + 'A' - 'a'; + } else { + guard += '_'; + } + } + guard += '_'; + return guard; +} + +} // namespace + +void WriteCCOps(const OpList& ops, const std::string& dot_h_fname, + const std::string& dot_cc_fname) { + Env* env = Env::Default(); + WritableFile* h = nullptr; + WritableFile* cc = nullptr; + TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h)); + TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc)); + + // .h Header + const string include = GetPath(dot_h_fname); + const string guard = ToGuard(include); + // TODO(josh11b): Mention the library for which wrappers are being generated. + Status s; + s = h->Append( + strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n" + "#ifndef ", + guard, + "\n" + "#define ", + guard, R"header( + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { +namespace ops { + +// These add a node to the graph from opts. +// +// Note for "NodeOut" inputs, you will typically either pass +// * a {Node*, int index} (to pass the index-th output of that node), or +// * a Node* (to pass the first output of that node). + + +)header")); + TF_CHECK_OK(s); + // .cc Header + s = cc->Append( + strings::StrCat("// This file is MACHINE GENERATED! 
Do not edit.\n\n" + "#include \"", + include, R"header(" + +#include "tensorflow/core/graph/node_builder.h" + +namespace tensorflow { +namespace ops { + +)header")); + TF_CHECK_OK(s); + + for (const auto& op_def : ops.op()) { + WriteCCOp(op_def, h, cc); + } + + // .h Footer + + s = h->Append(strings::StrCat(R"footer(} // namespace ops +} // namespace tensorflow + +#endif // )footer", + guard, "\n")); + TF_CHECK_OK(s); + + // .cc Footer + + s = cc->Append(R"footer(} // namespace ops +} // namespace tensorflow +)footer"); + TF_CHECK_OK(s); + + TF_CHECK_OK(cc->Close()); + TF_CHECK_OK(h->Close()); +} + +} // namespace tensorflow diff --git a/tensorflow/cc/ops/cc_op_gen.h b/tensorflow/cc/ops/cc_op_gen.h new file mode 100644 index 0000000000..1a9474ec87 --- /dev/null +++ b/tensorflow/cc/ops/cc_op_gen.h @@ -0,0 +1,14 @@ +#ifndef TENSORFLOW_CC_OPS_CC_OP_GEN_H_ +#define TENSORFLOW_CC_OPS_CC_OP_GEN_H_ + +#include "tensorflow/core/framework/op_def.pb.h" + +namespace tensorflow { + +// Result is written to files dot_h and dot_cc. 
+void WriteCCOps(const OpList& ops, const std::string& dot_h_fname, + const std::string& dot_cc_fname); + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_OPS_CC_OP_GEN_H_ diff --git a/tensorflow/cc/ops/cc_op_gen_main.cc b/tensorflow/cc/ops/cc_op_gen_main.cc new file mode 100644 index 0000000000..b9f0e4a9bd --- /dev/null +++ b/tensorflow/cc/ops/cc_op_gen_main.cc @@ -0,0 +1,34 @@ +#include "tensorflow/cc/ops/cc_op_gen.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace { + +void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, + bool include_internal) { + OpList ops; + OpRegistry::Global()->Export(include_internal, &ops); + WriteCCOps(ops, dot_h, dot_cc); +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char* argv[]) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + if (argc != 4) { + fprintf(stderr, + "Usage: %s out.h out.cc include_internal\n" + " include_internal: 1 means include internal ops\n", + argv[0]); + exit(1); + } + + bool include_internal = tensorflow::StringPiece("1") == argv[3]; + tensorflow::PrintAllCCOps(argv[1], argv[2], include_internal); + return 0; +} diff --git a/tensorflow/cc/ops/const_op.cc b/tensorflow/cc/ops/const_op.cc new file mode 100644 index 0000000000..e428e4f35e --- /dev/null +++ b/tensorflow/cc/ops/const_op.cc @@ -0,0 +1,113 @@ +#include "tensorflow/cc/ops/const_op.h" + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace ops { + +namespace { +const string& OpName() { + static const string kOpName = "Const"; + return kOpName; +} +} // namespace + +#define DEFINE_CONST_SCALAR(TYPE) \ + Node* Const(TYPE s, const GraphDefBuilder::Options& 
options) { \ + return Const(gtl::ArraySlice<TYPE>(&s, 1), TensorShape({}), options); \ + } + +#define DEFINE_CONST_VECTOR(TYPE) \ + Node* Const(gtl::ArraySlice<TYPE> v, \ + const GraphDefBuilder::Options& options) { \ + return Const(v, TensorShape({static_cast<int64>(v.size())}), options); \ + } + +#define DEFINE_CONST_TENSOR(TYPE, ...) \ + Node* Const(gtl::ArraySlice<TYPE> t, const TensorShape& shape, \ + const GraphDefBuilder::Options& options) { \ + if (options.HaveError()) return nullptr; \ + NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(), \ + options.op_registry()); \ + const DataType dt = DataTypeToEnum<TYPE>::v(); \ + if (t.size() == 1) { \ + TensorProto proto; \ + proto.set_dtype(dt); \ + shape.AsProto(proto.mutable_tensor_shape()); \ + __VA_ARGS__; \ + node_builder.Attr("dtype", dt).Attr("value", proto); \ + } else { \ + Tensor tensor(dt, shape); \ + if (tensor.NumElements() != static_cast<int64>(t.size())) { \ + options.UpdateStatus(errors::InvalidArgument( \ + t.size(), " values provided to Const() != ", tensor.NumElements(), \ + " elements for shape ", shape.ShortDebugString())); \ + } else { \ + std::copy_n(t.data(), t.size(), tensor.flat<TYPE>().data()); \ + node_builder.Attr("dtype", dt).Attr("value", tensor); \ + } \ + } \ + return options.FinalizeBuilder(&node_builder); \ + } + +#define DEFINE_CONST_IMPL(TYPE, ...) 
\ + DEFINE_CONST_SCALAR(TYPE) \ + DEFINE_CONST_VECTOR(TYPE) \ + DEFINE_CONST_TENSOR(TYPE, __VA_ARGS__) + +#define DEFINE_CONST(TYPE, FIELD) \ + DEFINE_CONST_IMPL(TYPE, proto.add_##FIELD(*t.begin());) + +DEFINE_CONST(float, float_val); +DEFINE_CONST(double, double_val); +DEFINE_CONST(int32, int_val); +DEFINE_CONST(uint8, int_val); +DEFINE_CONST(int16, int_val); +DEFINE_CONST(int8, int_val); +DEFINE_CONST(int64, int64_val); +DEFINE_CONST(bool, bool_val); + +DEFINE_CONST_IMPL(complex64, proto.add_scomplex_val(t.begin()->real()); + proto.add_scomplex_val(t.begin()->imag());); + +Node* Const(StringPiece s, const GraphDefBuilder::Options& options) { + if (options.HaveError()) return nullptr; + NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(), + options.op_registry()); + TensorProto proto; + proto.set_dtype(DT_STRING); + TensorShape({}).AsProto(proto.mutable_tensor_shape()); + proto.add_string_val(s.data(), s.size()); + node_builder.Attr("dtype", DT_STRING).Attr("value", proto); + return options.FinalizeBuilder(&node_builder); +} + +DEFINE_CONST_VECTOR(string) +DEFINE_CONST_TENSOR(string, proto.add_string_val(*t.begin());) + +#undef DEFINE_CONST +#undef DEFINE_CONST_IMPL +#undef DEFINE_CONST_TENSOR +#undef DEFINE_CONST_VECTOR +#undef DEFINE_CONST_SCALAR + +Node* Const(const Tensor& t, const GraphDefBuilder::Options& options) { + if (options.HaveError()) return nullptr; + NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(), + options.op_registry()); + node_builder.Attr("dtype", t.dtype()).Attr("value", t); + return options.FinalizeBuilder(&node_builder); +} + +Node* Const(const TensorProto& proto, const GraphDefBuilder::Options& options) { + if (options.HaveError()) return nullptr; + NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(), + options.op_registry()); + node_builder.Attr("dtype", proto.dtype()).Attr("value", proto); + return options.FinalizeBuilder(&node_builder); +} + +} // namespace ops +} // namespace tensorflow 
diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h new file mode 100644 index 0000000000..1fb739b974 --- /dev/null +++ b/tensorflow/cc/ops/const_op.h @@ -0,0 +1,70 @@ +#ifndef TENSORFLOW_CC_OPS_CONST_OP_H_ +#define TENSORFLOW_CC_OPS_CONST_OP_H_ + +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +namespace ops { + +// If a shape is specified, you may either provide the same number of values, +// or a single value and that value will be duplicated to fill out the Tensor. +#define DECLARE_CONST(TYPE) \ + Node* Const(TYPE s, const GraphDefBuilder::Options& options); /* Scalar */ \ + Node* Const(gtl::ArraySlice<TYPE> v, \ + const GraphDefBuilder::Options& options); /* Vector */ \ + Node* Const(gtl::ArraySlice<TYPE> t, const TensorShape& shape, \ + const GraphDefBuilder::Options& options); /* Tensor */ \ + inline Node* Const(std::initializer_list<TYPE> v, /* Vector using {...} */ \ + const GraphDefBuilder::Options& options) { \ + return Const(gtl::ArraySlice<TYPE>(v), options); \ + } \ + inline Node* Const(std::initializer_list<TYPE> t, /* Tensor using {...} */ \ + const TensorShape& shape, \ + const GraphDefBuilder::Options& options) { \ + return Const(gtl::ArraySlice<TYPE>(t), shape, options); \ + } + +DECLARE_CONST(float); +DECLARE_CONST(double); +DECLARE_CONST(int32); +DECLARE_CONST(uint8); +DECLARE_CONST(int16); +DECLARE_CONST(int8); +DECLARE_CONST(complex64); +DECLARE_CONST(int64); +DECLARE_CONST(bool); + +#undef DECLARE_CONST + +// String +Node* Const(StringPiece s, const GraphDefBuilder::Options& options); +Node* Const(gtl::ArraySlice<string> v, const GraphDefBuilder::Options& options); +Node* Const(gtl::ArraySlice<string> t, const TensorShape& shape, + const GraphDefBuilder::Options& options); +inline Node* Const(std::initializer_list<string> v, + const 
GraphDefBuilder::Options& options) { + return Const(gtl::ArraySlice<string>(v), options); +} +inline Node* Const(std::initializer_list<string> t, const TensorShape& shape, + const GraphDefBuilder::Options& options) { + return Const(gtl::ArraySlice<string>(t), shape, options); +} + +// A Tensor of any type. +Node* Const(const Tensor& t, const GraphDefBuilder::Options& options); +Node* Const(const TensorProto& proto, const GraphDefBuilder::Options& options); + +template <class T> +Node* EmptyConst(const GraphDefBuilder::Options& options) { + return Const(gtl::ArraySlice<T>(), options); +} + +// TODO(josh11b): Support other types (e.g. quantized ints, float16). + +} // namespace ops +} // namespace tensorflow + +#endif // TENSORFLOW_CC_OPS_CONST_OP_H_ diff --git a/tensorflow/cc/ops/functional_grad.cc b/tensorflow/cc/ops/functional_grad.cc new file mode 100644 index 0000000000..28b8b4a0e5 --- /dev/null +++ b/tensorflow/cc/ops/functional_grad.cc @@ -0,0 +1,42 @@ +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +typedef FunctionDefHelper FDH; + +Status MapAccumulateGrad(const AttrSlice& attrs, FunctionDef* ret) { + const NameAttrList* func; + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "f", &func)); + DataType T; + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T)); + int k; + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "K", &k)); + // The gradient function of f. + // f : (K*T, T, T) -> T + // g : (K*T, T, T, T) -> (K*T, T, T) + auto grad = FDH::FunctionRef("SymbolicGradient", + {{"f", *func}, + {"Tin", std::vector<DataType>(k + 3, T)}, + {"Tout", std::vector<DataType>(k + 2, T)}}); + *ret = FDH::Define( + // Arg defs + {"theta: K*T", "x: T", "u: T", "dy: T"}, + // Ret val defs + {"dtheta: K*T", "dx: T", "du: T"}, + // Attr defs + {{"T: {float, double}"}}, + // nodes. 
+ {{{"y"}, + "MapAccumulate", + {"theta", "x", "u"}, + {{"f", *func}, {"T", "$T"}, {"K", k}}}, + {{"dtheta", "dx", "du"}, + "MapAccumulateGrad", + {"theta", "x", "u", "y", "dy"}, + {{"g", grad}, {"T", "$T"}, {"K", k}}}}); + return Status::OK(); +} +REGISTER_OP_GRADIENT("MapAccumulate", MapAccumulateGrad); + +} // end namespace tensorflow diff --git a/tensorflow/cc/ops/math_grad.cc b/tensorflow/cc/ops/math_grad.cc new file mode 100644 index 0000000000..4e8baa0d10 --- /dev/null +++ b/tensorflow/cc/ops/math_grad.cc @@ -0,0 +1,566 @@ +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +typedef FunctionDefHelper FDH; + +// Cwise unary ops +Status GradForUnaryCwise(FunctionDef* g, std::vector<FDH::Node> nodes) { + for (auto& n : nodes) { + if (n.attr.empty()) { + n.attr = {{"T", "$T"}}; + } + } + *g = FDH::Define( + // Arg defs + {"x: T", "dy: T"}, + // Ret val defs + {"dx: T"}, + // Attr defs + {{"T: {float, double}"}}, + // Nodes + nodes); + return Status::OK(); +} + +Status AbsGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"sign"}, "Sign", {"x"}}, + {{"dx"}, "Mul", {"dy", "sign"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Abs", AbsGrad); + +Status NegGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"dx"}, "Neg", {"dy"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Neg", NegGrad); + +Status InvGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"y"}, "Inv", {"x"}}, + {{"y2"}, "Square", {"y"}}, + {{"y2_neg"}, "Neg", {"y2"}}, + {{"dx"}, "Mul", {"dy", "y2_neg"}} + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Inv", InvGrad); + +Status SquareGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + FDH::Const("c", 2LL), + {{"two"}, "Cast", {"c"}, {{"SrcT", DT_INT64}, 
{"DstT", "$T"}}}, + {{"x2"}, "Mul", {"x", "two"}}, // x * 2 + {{"dx"}, "Mul", {"dy", "x2"}}, // dy * (x * 2) + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Square", SquareGrad); + +Status SqrtGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"y"}, "Sqrt", {"x"}}, + {{"y_inv"}, "Inv", {"y"}}, + FDH::Const("const", 0.5f), + {{"half"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}}, + {{"a"}, "Mul", {"half", "y_inv"}}, // .5 * 1/y + {{"dx"}, "Mul", {"dy", "a"}}, // dy * (.5 * 1/y) + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Sqrt", SqrtGrad); + +// Gradient of Rsqrt: dx = dy * (-0.5 * 1/x * rsqrt(x)). +Status RsqrtGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"x_inv"}, "Inv", {"x"}}, + {{"y"}, "Rsqrt", {"x"}}, + FDH::Const("const", -.5f), + {{"neghalf"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}}, + {{"a"}, "Mul", {"neghalf", "x_inv"}}, // -0.5 * 1/x + {{"b"}, "Mul", {"a", "y"}}, // -0.5 * 1/x * y + {{"dx"}, "Mul", {"dy", "b"}}, // dy * (-0.5 * 1/x * y) + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Rsqrt", RsqrtGrad); + +Status ExpGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"y"}, "Exp", {"x"}}, + {{"dx"}, "Mul", {"dy", "y"}}, // dy * y + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Exp", ExpGrad); + +Status LogGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"x_inv"}, "Inv", {"x"}}, + {{"dx"}, "Mul", {"dy", "x_inv"}}, // dy * 1/x + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Log", LogGrad); + +Status TanhGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"y"}, "Tanh", {"x"}}, + {{"y2"}, "Square", {"y"}}, + FDH::Const("const", 1.0f), + {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}}, + {{"a"}, "Sub", {"one", "y2"}}, + {{"dx"}, "Mul", {"dy", "a"}}, // dy * (1 - y*y) + }); + // 
clang-format on +} +REGISTER_OP_GRADIENT("Tanh", TanhGrad); + +Status SigmoidGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"y"}, "Sigmoid", {"x"}}, + FDH::Const("const", 1.0f), + {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}}, + {{"a"}, "Sub", {"one", "y"}}, + {{"b"}, "Mul", {"y", "a"}}, // y * (1 - y) + {{"dx"}, "Mul", {"dy", "b"}}, // dy * y * (1 - y) + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Sigmoid", SigmoidGrad); + +Status SignGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"s"}, "Shape", {"x"}}, + FDH::Const("zero", 0.f), + {{"val"}, "Cast", {"zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}}, + {{"dx"}, "Fill", {"s", "val"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Sign", SignGrad); + +Status SinGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"cos"}, "Cos", {"x"}}, + {{"dx"}, "Mul", {"dy", "cos"}}, // dy * cos(x) + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Sin", SinGrad); + +Status CosGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + {{"sin"}, "Sin", {"x"}}, + {{"neg"}, "Neg", {"sin"}}, + {{"dx"}, "Mul", {"dy", "neg"}}, // dy * (-sin(x)) + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Cos", CosGrad); + +Status RealGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + FDH::Const("zero", 0.f), + {{"dx"}, "Complex", {"dy", "zero"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Real", RealGrad); + +Status ImagGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForUnaryCwise(g, { + FDH::Const("zero", 0.f), + {{"dx"}, "Complex", {"zero", "dy"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Imag", ImagGrad); + +Status ConjGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + 
return GradForUnaryCwise(g, { + {{"dx"}, "Conj", {"dy"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Conj", ConjGrad); + +// Cwise binary ops +// +// TODO(zhifengc): This can be arranged as a function in the standard +// library. +Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) { + // clang-format off + std::vector<FDH::Node> nodes = { + {{"sx"}, "Shape", {"x"}}, + {{"sy"}, "Shape", {"y"}}, + }; + nodes.insert(nodes.end(), body.begin(), body.end()); + std::vector<FDH::Node> reshapes = { + {{"sum_gx"}, "Sum", {"gx", "rx"}}, + {{"dx"}, "Reshape", {"sum_gx", "sx"}}, + {{"sum_gy"}, "Sum", {"gy", "ry"}}, + {{"dy"}, "Reshape", {"sum_gy", "sy"}}, + }; + nodes.insert(nodes.end(), reshapes.begin(), reshapes.end()); + + // clang-format on + for (auto& n : nodes) { + if (n.attr.empty()) { + n.attr = {{"T", "$T"}}; + } + } + // "BroadcastGradientArgs" doesn't need any attrs. + nodes.push_back({{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}}); + *g = FDH::Define( + // Arg defs + {"x: T", "y: T", "dz: T"}, + // Ret val defs + {"dx: T", "dy: T"}, + // Attr defs + {{"T: {float, double}"}}, + // Nodes + nodes); + return Status::OK(); +} + +Status AddGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForBinaryCwise(g, { + {{"gx"}, "Identity", {"dz"}}, + {{"gy"}, "Identity", {"dz"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Add", AddGrad); + +Status SubGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForBinaryCwise(g, { + {{"gx"}, "Identity", {"dz"}}, + {{"gy"}, "Neg", {"dz"}}, // -dz + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Sub", SubGrad); + +Status MulGrad(const AttrSlice& attrs, FunctionDef* g) { + DataType T; + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T)); + if (T == DT_COMPLEX64) { + return GradForBinaryCwise( + g, { + {{"cy"}, "Conj", {"y"}}, + {{"gx"}, "Mul", {"dz", "cy"}}, // dz * Conj(y) + {{"cx"}, "Conj", {"x"}}, + {{"gy"}, "Mul", {"cx", "dz"}}, // 
Conj(x) * dz + }); + } else { + // clang-format off + return GradForBinaryCwise(g, { + {{"gx"}, "Mul", {"dz", "y"}}, // dz * y + {{"gy"}, "Mul", {"x", "dz"}}, // x * dz + }); + // clang-format on + } +} +REGISTER_OP_GRADIENT("Mul", MulGrad); + +Status DivGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForBinaryCwise(g, { + {{"gx"}, "Div", {"dz", "y"}}, + {{"nx"}, "Neg", {"x"}}, + {{"y2"}, "Square", {"y"}}, + {{"nx_y2"}, "Div", {"nx", "y2"}}, + {{"gy"}, "Mul", {"dz", "nx_y2"}}, // dz * (- x / y^2) + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Div", DivGrad); + +Status PowGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForBinaryCwise(g, { + {{"z"}, "Pow", {"x", "y"}}, + // dz * y * Pow(x, y - 1) + FDH::Const("const", 1.0f), + {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}}, + {{"t0"}, "Sub", {"y", "one"}}, + {{"t1"}, "Pow", {"x", "t0"}}, + {{"t2"}, "Mul", {"dz", "y"}}, + {{"gx"}, "Mul", {"t1", "t2"}}, + // dz * z * Log(x) + {{"t3"}, "Log", {"x"}}, + {{"t4"}, "Mul", {"dz", "z"}}, + {{"gy"}, "Mul", {"t3", "t4"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Pow", PowGrad); + +Status MaximumMinimumGradHelper(const string& comparator, + const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForBinaryCwise(g, { + {{"c"}, comparator, {"x", "y"}}, + {{"mask"}, "Cast", {"c"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}}, + {{"gx"}, "Mul", {"dz", "mask"}}, + {{"gy"}, "Sub", {"dz", "gx"}}, + }); + // clang-format on +} + +Status MaximumGrad(const AttrSlice& attrs, FunctionDef* g) { + return MaximumMinimumGradHelper("GreaterEqual", attrs, g); +} +REGISTER_OP_GRADIENT("Maximum", MaximumGrad); + +Status MinimumGrad(const AttrSlice& attrs, FunctionDef* g) { + return MaximumMinimumGradHelper("LessEqual", attrs, g); +} +REGISTER_OP_GRADIENT("Minimum", MinimumGrad); + +Status ComplexGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return 
GradForBinaryCwise(g, { + {{"gx"}, "Real", {"dz"}}, + {{"gy"}, "Imag", {"dz"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Complex", ComplexGrad); + +// Cwise ternary ops. +Status SelectGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + *g = FDH::Define( + {"c:bool", "x:T", "y:T", "dz:T"}, + {"dc:bool", "dx:T", "dy:T"}, + {{"T: {float, double}"}}, + { + {{"dc"}, "ZerosLike", {"c"}, {{"T", DT_BOOL}}}, + {{"zeros"}, "ZerosLike", {"x"}, {{"T", "$T"}}}, + {{"dx"}, "Select", {"c", "dz", "zeros"}, {{"T", "$T"}}}, + {{"dy"}, "Select", {"c", "zeros", "dz"}, {{"T", "$T"}}}, + }); + // clang-format on + return Status::OK(); +} +REGISTER_OP_GRADIENT("Select", SelectGrad); + +// N-ry ops +// REGISTER_OP_GRADIENT("AddN", AddNGrad); + +// Reduction ops +// +// TODO(zhifengc): This helper is pretty ugly. Do something better. +// TODO(zhifengc): This can be arrange as a function in the standard library. +Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) { + // Shape manipulation nodes. 
+ + // clang-format off + std::vector<FDH::Node> nodes = { + {{"x_shape"}, "Shape", {"x"}}, + {{"x_rank"}, "Rank", {"x"}}, + {{"i_shape"}, "Shape", {"i"}, {{"T", DT_INT32}}}, + FDH::Const("zero", 0), + FDH::Const("one", 1), + // stitch_idx0 = Range(0, x_rank, 1) + {{"stitch_idx1"}, "Identity", {"i"}, {{"T", DT_INT32}}}, + {{"stitch_idx"}, "_ListToArray", {"stitch_idx0", "stitch_idx1"}, + {{"Tin", DataTypeSlice{DT_INT32, DT_INT32}}, + {"T", DT_INT32}, {"N", 2}}}, + {{"stitch_val0"}, "Identity", {"x_shape"}, {{"T", DT_INT32}}}, + {{"stitch_val1"}, "Fill", {"i_shape", "one"}, {{"T", DT_INT32}}}, + {{"stitch_val"}, "_ListToArray", {"stitch_val0", "stitch_val1"}, + {{"Tin", DataTypeSlice{DT_INT32, DT_INT32}}, + {"T", DT_INT32}, {"N", 2}}}, + {{"y_shape"}, "DynamicStitch", {"stitch_idx", "stitch_val"}, + {{"N", 2}, {"T", DT_INT32}}}, + {{"tile_scaling"}, "Div", {"x_shape", "y_shape"}, {{"T", DT_INT32}}}, + {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}} + }; + // clang-format on + nodes.insert(nodes.end(), body.begin(), body.end()); + for (auto& n : nodes) { + if (n.attr.empty()) { + n.attr = {{"T", "$T"}}; + } + } + // "Range" doesn't need any attr. 
+ nodes.push_back({{"stitch_idx0"}, "Range", {"zero", "x_rank", "one"}, {}}); + *g = FDH::Define( + // Arg defs + {"x:T", "i:int32", "dy:T"}, + // Ret val defs + {"dx:T", "di:int32"}, + // Attr defs + {{"T: {float, double}"}}, + // Nodes + nodes); + return Status::OK(); +} + +Status SumGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForReductionOp(g, { + {{"dy_reshaped"}, "Reshape", {"dy", "y_shape"}}, + {{"dx"}, "Tile", {"dy_reshaped", "tile_scaling"}}, + }); + // clang-format on + return Status::OK(); +} +REGISTER_OP_GRADIENT("Sum", SumGrad); + +Status MeanGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForReductionOp(g, { + {{"factor"}, "Prod", {"tile_scaling", "zero"}, {{"T", DT_INT32}}}, + {{"factor_T"}, "Cast", {"factor"}, {{"SrcT", DT_INT32}, {"DstT", "$T"}}}, + {{"dy_scaled"}, "Div", {"dy", "factor_T"}}, + {{"dy_reshaped"}, "Reshape", {"dy_scaled", "y_shape"}}, + {{"dx"}, "Tile", {"dy_reshaped", "tile_scaling"}}, + }); + // clang-format on + return Status::OK(); +} +REGISTER_OP_GRADIENT("Mean", MeanGrad); + +// REGISTER_OP_GRADIENT("Prod", ProdGrad); +// REGISTER_OP_GRADIENT("SegmentSum", SegmentSumGrad); +// REGISTER_OP_GRADIENT("SegmentMean", SegmentMeanGrad); +// REGISTER_OP_GRADIENT("SparseSegmentSum", SparseSegmentSumGrad); +// REGISTER_OP_GRADIENT("SparseSegmentMean", SparseSegmentMeanGrad); +// REGISTER_OP_GRADIENT("SegmentMin", SegmentMinGrad); +// REGISTER_OP_GRADIENT("SegmentMax", SegmentMaxGrad); +// REGISTER_OP_GRADIENT("UnsortedSegmentSum", UnsortedSegmentSumGrad); + +Status MinMaxGradHelper(const string& op, const AttrSlice& attrs, + FunctionDef* g) { + // clang-format off + *g = FDH::Define( + // Arg defs + {"x:T", "i:int32", "dy:T"}, + // Ret val defs + {"dx:T", "di:int32"}, + // Attr defs + {{"T: {float, double}"}}, + { + // keep_dims because we need to do x == y, which requries x + // and y are broadcastable. 
+ {{"y"}, op, {"x", "i"}, {{"T", "$T"}, {"keep_dims", true}}}, + {{"mask"}, "Equal", {"x", "y"}, {{"T", "$T"}}}, + {{"mask_cast"}, "Cast", {"mask"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}}, + {{"mask_sum"}, "Sum", {"mask_cast", "i"}, {{"T", "$T"}}}, + {{"norm_dy"}, "Div", {"dy", "mask_sum"}, {{"T", "$T"}}}, + {{"sy"}, "Shape", {"y"}, {{"T", "$T"}}}, + {{"norm_dy_reshaped"}, "Reshape", {"norm_dy", "sy"}, {{"T", "$T"}}}, + {{"dx"}, "Mul", {"mask_cast", "norm_dy_reshaped"}, {{"T", "$T"}}}, + {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}} + }); + // clang-format on + return Status::OK(); +} + +Status MaxGrad(const AttrSlice& attrs, FunctionDef* g) { + return MinMaxGradHelper("Max", attrs, g); +} +REGISTER_OP_GRADIENT("Max", MaxGrad); + +Status MinGrad(const AttrSlice& attrs, FunctionDef* g) { + return MinMaxGradHelper("Min", attrs, g); +} +REGISTER_OP_GRADIENT("Min", MinGrad); + +static Status MatMulGradHelper(FunctionDef* g, const string& x0, bool tx0, + const string& x1, bool tx1, const string& y0, + bool ty0, const string& y1, bool ty1) { + *g = FDH::Define( + // Arg defs + {"x: T", "y: T", "dz: T"}, + // Ret val defs + {"dx: T", "dy: T"}, + // Attr defs + {{"T: {float, double}"}}, + // Nodes + { + {{"dx"}, + "MatMul", + {x0, x1}, + {{"T", "$T"}, {"transpose_a", tx0}, {"transpose_b", tx1}}}, + {{"dy"}, + "MatMul", + {y0, y1}, + {{"T", "$T"}, {"transpose_a", ty0}, {"transpose_b", ty1}}}, + }); + return Status::OK(); +} + +Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) { + DataType T; + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T)); + if (T == DT_COMPLEX64) { + return errors::Unimplemented( + "MatMul gradient for complex is not supported yet."); + } + bool ta; + bool tb; + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "transpose_a", &ta)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "transpose_b", &tb)); + if (!ta && !tb) { + return MatMulGradHelper(g, "dz", false, "y", true, "x", true, "dz", false); + } + if (!ta && tb) { + return MatMulGradHelper(g, "dz", false, "y", 
false, "dz", true, "x", false); + } + if (ta && !tb) { + return MatMulGradHelper(g, "y", false, "dz", true, "x", false, "dz", false); + } + CHECK(ta && tb); + return MatMulGradHelper(g, "y", true, "dz", true, "dz", true, "x", true); +} +REGISTER_OP_GRADIENT("MatMul", MatMulGrad); + +// REGISTER_OP_GRADIENT("SparseMatMul", SparseMatMulGrad); +// REGISTER_OP_GRADIENT("BatchMatMul", BatchMatMulGrad); + +// Comparison ops. +REGISTER_OP_NO_GRADIENT("Less"); +REGISTER_OP_NO_GRADIENT("LessEqual"); +REGISTER_OP_NO_GRADIENT("Greater"); +REGISTER_OP_NO_GRADIENT("GreaterEqual"); +REGISTER_OP_NO_GRADIENT("Equal"); +REGISTER_OP_NO_GRADIENT("NotEqual"); + +// Logical ops. +REGISTER_OP_NO_GRADIENT("LogicalAnd"); +REGISTER_OP_NO_GRADIENT("LogicalOr"); +REGISTER_OP_NO_GRADIENT("LogicalNot"); + +// Sequence generation ops. +REGISTER_OP_NO_GRADIENT("Range"); +REGISTER_OP_NO_GRADIENT("LinSpace"); + +} // end namespace tensorflow diff --git a/tensorflow/cc/ops/nn_grad.cc b/tensorflow/cc/ops/nn_grad.cc new file mode 100644 index 0000000000..89b037e3c8 --- /dev/null +++ b/tensorflow/cc/ops/nn_grad.cc @@ -0,0 +1,55 @@ +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +typedef FunctionDefHelper FDH; + +Status ReluGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + *g = FDH::Define( + // Arg defs + {"x: T", "dy: T"}, + // Ret val defs + {"dx: T"}, + // Attr defs + {{"T: {float, double}"}}, + // Nodes + { + {{"dx"}, "ReluGrad", {"dy", "x"}, {{"T", "$T"}}} + }); + // clang-format on + return Status::OK(); +} +REGISTER_OP_GRADIENT("Relu", ReluGrad); + +Status CrossEntropyGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + *g = FDH::Define( + // Arg defs + {"features: T", "labels: T", "dcost_dloss: T", "donotcare: T"}, + // Ret val defs + {"dcost_dfeatures: T", "dcost_dlabels: T"}, + // Attr defs + {{"T: {float, double}"}}, + // Nodes + { + // _, dloss_dfeatures = 
CrossEntropy(features, labels) + {{"donotcare_loss", "dloss_dfeatures"}, "CrossEntropy", + {"features", "labels"}, {{"T", "$T"}}}, + // dcost_dloss is of shape [batch_size]. + // dcost_dloss_mat is of shape [batch_size, 1]. + FDH::Const("neg1", -1), + {{"dcost_dloss_mat"}, "ExpandDims", {"dcost_dloss", "neg1"}, + {{"T", "$T"}}}, + // chain rule: dcost/dfeatures = dcost/dloss * dloss/dfeatures + {{"dcost_dfeatures"}, "Mul", {"dcost_dloss_mat", "dloss_dfeatures"}, + {{"T", "$T"}}}, + {{"dcost_dlabels"}, "ZerosLike", {"labels"}, {{"T", "$T"}}}, + }); + // clang-format on + return Status::OK(); +} +REGISTER_OP_GRADIENT("CrossEntropy", CrossEntropyGrad); + +} // end namespace tensorflow diff --git a/tensorflow/cc/ops/standard_ops.h b/tensorflow/cc/ops/standard_ops.h new file mode 100644 index 0000000000..8d7160a7f9 --- /dev/null +++ b/tensorflow/cc/ops/standard_ops.h @@ -0,0 +1,26 @@ +// #include this file to get access to the standard set of C++ graph +// definition libraries. + +#ifndef TENSORFLOW_CC_OPS_STANDARD_OPS_H_ +#define TENSORFLOW_CC_OPS_STANDARD_OPS_H_ + +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/attention_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/data_flow_ops.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/io_ops.h" +#include "tensorflow/cc/ops/linalg_ops.h" +#include "tensorflow/cc/ops/logging_ops.h" +#include "tensorflow/cc/ops/math_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/parsing_ops.h" +#include "tensorflow/cc/ops/random_ops.h" +#include "tensorflow/cc/ops/sparse_ops.h" +#include "tensorflow/cc/ops/state_ops.h" +#include "tensorflow/cc/ops/string_ops.h" +#include "tensorflow/cc/ops/summary_ops.h" +#include "tensorflow/cc/ops/training_ops.h" +#include "tensorflow/cc/ops/user_ops.h" + +#endif // TENSORFLOW_CC_OPS_STANDARD_OPS_H_ |