diff options
author | Manjunath Kudlur <keveman@google.com> | 2016-07-15 14:28:59 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-07-15 15:33:32 -0700 |
commit | 25ac3dabfa3af7a313eb46b03690117c85030cc2 (patch) | |
tree | 06010c7cc7d25a538880c5f0d53df079e27093fd /tensorflow/cc/ops | |
parent | 194efde51895e0251d39c72c969dff1a50b67d35 (diff) |
Improvements to the C++ graph building API.
TESTED:
- passed opensource_build: http://ci.tensorflow.org/job/tensorflow-cl-presubmit-multijob/2780/
Change: 127585603
Diffstat (limited to 'tensorflow/cc/ops')
-rw-r--r-- | tensorflow/cc/ops/cc_op_gen.cc | 382 | ||||
-rw-r--r-- | tensorflow/cc/ops/cc_op_gen.h | 29 | ||||
-rw-r--r-- | tensorflow/cc/ops/cc_op_gen_main.cc | 49 | ||||
-rw-r--r-- | tensorflow/cc/ops/const_op.cc | 142 | ||||
-rw-r--r-- | tensorflow/cc/ops/const_op.h | 94 | ||||
-rw-r--r-- | tensorflow/cc/ops/const_op_test.cc | 128 | ||||
-rw-r--r-- | tensorflow/cc/ops/standard_ops.h | 14 |
7 files changed, 212 insertions, 626 deletions
diff --git a/tensorflow/cc/ops/cc_op_gen.cc b/tensorflow/cc/ops/cc_op_gen.cc deleted file mode 100644 index d6cfccd6a9..0000000000 --- a/tensorflow/cc/ops/cc_op_gen.cc +++ /dev/null @@ -1,382 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// TODO(josh11b): Rewrite function parameter names to avoid C++ keywords -// or "opts". - -#include "tensorflow/cc/ops/cc_op_gen.h" - -#include <unordered_map> -#include "tensorflow/core/framework/attr_value_util.h" -#include "tensorflow/core/framework/op_def.pb.h" -#include "tensorflow/core/framework/op_def_util.h" -#include "tensorflow/core/framework/op_gen_lib.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/gtl/stl_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace { - -const int kRightMargin = 79; - -const char* AttrTypeName(StringPiece attr_type) { - static const char* kAttrTypeName[][2] = { - {"string", "StringPiece"}, - {"list(string)", "gtl::ArraySlice<string>"}, - {"int", "int64"}, - {"list(int)", "gtl::ArraySlice<int>"}, - {"float", "float"}, - {"list(float)", "gtl::ArraySlice<float>"}, - {"bool", "bool"}, - {"list(bool)", "gtl::ArraySlice<bool>"}, - {"type", "DataType"}, - {"list(type)", "DataTypeSlice"}, - {"shape", "TensorShape"}, - {"list(shape)", "gtl::ArraySlice<TensorShape>"}, - {"tensor", "const Tensor&"}, - {"list(tensor)", "gtl::ArraySlice<Tensor>"}, - {"func", "const NameAttrList&"}, - }; - for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) { - if (attr_type == kAttrTypeName[i][0]) { - return kAttrTypeName[i][1]; - } - } - LOG(FATAL) << "Unsupported Attr type: " << attr_type; - return ""; -} - -// Change: Into: -// ABC // ABC -// // -// DEF // DEF -string MakeComment(StringPiece text) { - string ret; - while (!text.empty()) { - int last_non_space = -1; - int newline; - for (newline = 0; newline < static_cast<int>(text.size()); ++newline) { - if (text[newline] == '\n') break; - if (text[newline] != ' ') last_non_space = newline; - } - if (last_non_space == -1) { - strings::StrAppend(&ret, "//\n"); - } else { - strings::StrAppend(&ret, "// ", text.substr(0, last_non_space + 1), "\n"); - } - text.remove_prefix(newline + 1); - } - return ret; -} - -void WriteCCOp(const OpDef& op_def, WritableFile* h, WritableFile* cc) { - // TODO(josh11b): Better wrapping of comments. - string comment; - if (op_def.summary().empty()) { - comment = "TODO: add doc.\n"; - } else { - comment = strings::StrCat(op_def.summary(), "\n"); - if (op_def.has_deprecation()) { - strings::StrAppend(&comment, "\nDEPRECATED at GraphDef version ", - op_def.deprecation().version(), ":\n", - op_def.deprecation().explanation(), ".\n"); - } - if (!op_def.description().empty()) { - strings::StrAppend(&comment, "\n", op_def.description(), "\n"); - } - } - - static const string kSingleInputType = "NodeOut"; - static const string kListInputType = "gtl::ArraySlice<NodeOut>"; - - std::vector<string> arg_types; - std::vector<string> arg_names; - - strings::StrAppend(&comment, "\nArguments:\n"); - - // Map from attr name to the first input arg it is inferred from. - std::unordered_map<string, string> inferred_attrs; - for (int i = 0; i < op_def.input_arg_size(); ++i) { - const auto& arg(op_def.input_arg(i)); - arg_names.emplace_back(arg.name()); - bool is_list = false; - - if (!arg.type_attr().empty()) { - gtl::InsertIfNotPresent(&inferred_attrs, arg.type_attr(), arg.name()); - } else if (!arg.type_list_attr().empty()) { - gtl::InsertIfNotPresent(&inferred_attrs, arg.type_list_attr(), - arg.name()); - is_list = true; - } - if (!arg.number_attr().empty()) { - gtl::InsertIfNotPresent(&inferred_attrs, arg.number_attr(), arg.name()); - is_list = true; - } - if (is_list) { - arg_types.emplace_back(kListInputType); - } else { - arg_types.emplace_back(kSingleInputType); - } - - // TODO(josh11b): Include input type information. - StringPiece description = arg.description(); - if (!description.empty()) { - ConsumeEquals(&description); - strings::StrAppend(&comment, "* ", arg_names.back(), ": ", - arg.description(), "\n"); - } - } - - string options_comment; - for (int i = 0; i < op_def.attr_size(); ++i) { - const auto& attr(op_def.attr(i)); - // Do not add inferred attrs or attrs with defaults to the C++ - // function signature. - if (inferred_attrs.find(attr.name()) == inferred_attrs.end()) { - if (!attr.has_default_value()) { - arg_names.emplace_back(attr.name()); - arg_types.emplace_back(AttrTypeName(attr.type())); - if (!attr.description().empty()) { - strings::StrAppend(&comment, "* ", arg_names.back(), ": ", - attr.description(), "\n"); - } - } else { - strings::StrAppend(&options_comment, " .WithAttr(\"", attr.name(), - "\", ", AttrTypeName(attr.type()), "): Defaults to ", - SummarizeAttrValue(attr.default_value()), ".\n"); - if (!attr.description().empty()) { - strings::StrAppend(&options_comment, " ", attr.description(), - "\n"); - } - } - } - } - CHECK_EQ(arg_names.size(), arg_types.size()); - strings::StrAppend(&comment, "* opts:\n", options_comment, - R"comment( .WithName(StringPiece): Set the Node's name - .WithDevice(StringPiece): Set the Node's requested device - .WithControlInput(Node*) / .WithControlInputs({Node*, ...}): - Add control dependencies on the specified Node(s). - -Returns a pointer to the created Node)comment"); - - // TODO(josh11b): Include output type information. - if (op_def.output_arg_size() == 0) { - strings::StrAppend(&comment, ".\n"); - } else if (op_def.output_arg_size() == 1) { - StringPiece description = op_def.output_arg(0).description(); - ConsumeEquals(&description); - if (description.empty()) { - strings::StrAppend(&comment, ".\n"); - } else { - strings::StrAppend(&comment, ", with output:\n", description, "\n"); - } - } else { - strings::StrAppend(&comment, ", with outputs:\n"); - for (int o = 0; o < op_def.output_arg_size(); ++o) { - StringPiece description = op_def.output_arg(o).description(); - ConsumeEquals(&description); - if (description.empty()) { - strings::StrAppend(&comment, "* ", op_def.output_arg(o).name(), "\n"); - } else { - strings::StrAppend(&comment, "* ", op_def.output_arg(o).name(), ": ", - description, "\n"); - } - } - } - - // Write the header comment. - TF_CHECK_OK(h->Append(MakeComment(comment))); - - // Declare the function wrapper. - const string prefix = strings::StrCat("Node* ", op_def.name(), "("); - string h_rest; - for (size_t i = 0; i < arg_names.size(); ++i) { - strings::StrAppend(&h_rest, arg_types[i], " ", arg_names[i], ", "); - } - strings::StrAppend(&h_rest, "const GraphDefBuilder::Options& opts"); - string cc_decl = h_rest; - strings::StrAppend(&h_rest, ");"); - TF_CHECK_OK(h->Append(WordWrap(prefix, h_rest, kRightMargin) + "\n\n")); - - // Define the function wrapper. - strings::StrAppend(&cc_decl, ") {"); - TF_CHECK_OK(cc->Append(WordWrap(prefix, cc_decl, kRightMargin) + "\n")); - const string op_name = strings::StrCat(" static const string kOpName = \"", - op_def.name(), "\";\n"); - - if (arg_types.empty()) { - TF_CHECK_OK(cc->Append(op_name)); - TF_CHECK_OK(cc->Append(" return SourceOp(kOpName, opts);\n}\n\n")); - } else if (arg_types == std::vector<string>({kSingleInputType})) { - TF_CHECK_OK(cc->Append(op_name)); - TF_CHECK_OK(cc->Append(strings::StrCat(" return UnaryOp(kOpName, ", - arg_names[0], ", opts);\n}\n\n"))); - } else if (arg_types == - std::vector<string>({kSingleInputType, kSingleInputType})) { - TF_CHECK_OK(cc->Append(op_name)); - // TODO(josh11b): Word wrap this if it ever becomes necessary. - TF_CHECK_OK( - cc->Append(strings::StrCat(" return BinaryOp(kOpName, ", arg_names[0], - ", ", arg_names[1], ", opts);\n}\n\n"))); - } else { - TF_CHECK_OK(cc->Append(" if (opts.HaveError()) return nullptr;\n")); - TF_CHECK_OK(cc->Append(op_name)); - TF_CHECK_OK(cc->Append( - " NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName,\n" - " opts.op_registry());\n")); - for (size_t i = 0; i < arg_names.size(); ++i) { - if (i < static_cast<size_t>(op_def.input_arg_size())) { - TF_CHECK_OK(cc->Append( - strings::StrCat(" node_builder.Input(", arg_names[i], ");\n"))); - } else { - TF_CHECK_OK( - cc->Append(strings::StrCat(" node_builder.Attr(\"", arg_names[i], - "\", ", arg_names[i], ");\n"))); - } - } - TF_CHECK_OK( - cc->Append(" return opts.FinalizeBuilder(&node_builder);\n" - "}\n\n")); - } -} - -// Converts: -// bazel-out/.../genfiles/(external/YYY/)?XX -// to: XX. -string GetPath(const std::string& dot_h_fname) { - auto pos = dot_h_fname.find("/genfiles/"); - string result = dot_h_fname; - if (pos != string::npos) { - // - 1 account for the terminating null character (\0) in "/genfiles/". - result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1); - } - if (result.size() > sizeof("external/") && - result.compare(0, sizeof("external/") - 1, "external/") == 0) { - result = result.substr(sizeof("external/") - 1); - pos = result.find("/"); - if (pos != string::npos) { - result = result.substr(pos + 1); - } - } - return result; -} - -// Converts: -// cc/ops/gen_foo_ops.h -// to: -// CC_OPS_GEN_FOO_OPS_H_ -string ToGuard(const std::string& path) { - string guard; - guard.reserve(path.size() + 1); // + 1 -> trailing _ - for (const char c : path) { - if (c >= 'A' && c <= 'Z') { - guard += c; - } else if (c >= 'a' && c <= 'z') { - guard += c + 'A' - 'a'; - } else { - guard += '_'; - } - } - guard += '_'; - return guard; -} - -} // namespace - -void WriteCCOps(const OpList& ops, const std::string& dot_h_fname, - const std::string& dot_cc_fname) { - Env* env = Env::Default(); - std::unique_ptr<WritableFile> h; - std::unique_ptr<WritableFile> cc; - TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h)); - TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc)); - - // .h Header - const string include = GetPath(dot_h_fname); - const string guard = ToGuard(include); - // TODO(josh11b): Mention the library for which wrappers are being generated. - Status s; - s = h->Append( - strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n" - "#ifndef ", - guard, - "\n" - "#define ", - guard, R"header( - -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/lib/gtl/array_slice.h" - -namespace tensorflow { -namespace ops { - -// These add a node to the graph from opts. -// -// Note for "NodeOut" inputs, you will typically either pass -// * a {Node*, int index} (to pass the index-th output of that node), or -// * a Node* (to pass the first output of that node). - - -)header")); - TF_CHECK_OK(s); - // .cc Header - s = cc->Append( - strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n" - "#include \"", - include, R"header(" - -#include "tensorflow/core/graph/node_builder.h" - -namespace tensorflow { -namespace ops { - -)header")); - TF_CHECK_OK(s); - - for (const auto& op_def : ops.op()) { - WriteCCOp(op_def, h.get(), cc.get()); - } - - // .h Footer - - s = h->Append(strings::StrCat(R"footer(} // namespace ops -} // namespace tensorflow - -#endif // )footer", - guard, "\n")); - TF_CHECK_OK(s); - - // .cc Footer - - s = cc->Append(R"footer(} // namespace ops -} // namespace tensorflow -)footer"); - TF_CHECK_OK(s); - - TF_CHECK_OK(cc->Close()); - TF_CHECK_OK(h->Close()); -} - -} // namespace tensorflow diff --git a/tensorflow/cc/ops/cc_op_gen.h b/tensorflow/cc/ops/cc_op_gen.h deleted file mode 100644 index 09f9ca2467..0000000000 --- a/tensorflow/cc/ops/cc_op_gen.h +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CC_OPS_CC_OP_GEN_H_ -#define TENSORFLOW_CC_OPS_CC_OP_GEN_H_ - -#include "tensorflow/core/framework/op_def.pb.h" - -namespace tensorflow { - -// Result is written to files dot_h and dot_cc. -void WriteCCOps(const OpList& ops, const std::string& dot_h_fname, - const std::string& dot_cc_fname); - -} // namespace tensorflow - -#endif // TENSORFLOW_CC_OPS_CC_OP_GEN_H_ diff --git a/tensorflow/cc/ops/cc_op_gen_main.cc b/tensorflow/cc/ops/cc_op_gen_main.cc deleted file mode 100644 index 47491fac10..0000000000 --- a/tensorflow/cc/ops/cc_op_gen_main.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/cc/ops/cc_op_gen.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_def.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace { - -void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, - bool include_internal) { - OpList ops; - OpRegistry::Global()->Export(include_internal, &ops); - WriteCCOps(ops, dot_h, dot_cc); -} - -} // namespace -} // namespace tensorflow - -int main(int argc, char* argv[]) { - tensorflow::port::InitMain(argv[0], &argc, &argv); - if (argc != 4) { - fprintf(stderr, - "Usage: %s out.h out.cc include_internal\n" - " include_internal: 1 means include internal ops\n", - argv[0]); - exit(1); - } - - bool include_internal = tensorflow::StringPiece("1") == argv[3]; - tensorflow::PrintAllCCOps(argv[1], argv[2], include_internal); - return 0; -} diff --git a/tensorflow/cc/ops/const_op.cc b/tensorflow/cc/ops/const_op.cc index 4b4b0e06e1..b37b8b67d7 100644 --- a/tensorflow/cc/ops/const_op.cc +++ b/tensorflow/cc/ops/const_op.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,119 +14,59 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/cc/ops/const_op.h" - #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/graph/node_builder.h" -#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { namespace ops { -namespace { -const string& OpName() { - static const string kOpName = "Const"; - return kOpName; -} -} // namespace - -#define DEFINE_CONST_SCALAR(TYPE) \ - Node* Const(TYPE s, const GraphDefBuilder::Options& options) { \ - return Const(gtl::ArraySlice<TYPE>(&s, 1), TensorShape({}), options); \ - } - -#define DEFINE_CONST_VECTOR(TYPE) \ - Node* Const(gtl::ArraySlice<TYPE> v, \ - const GraphDefBuilder::Options& options) { \ - return Const(v, TensorShape({static_cast<int64>(v.size())}), options); \ +Output Const(const Scope& scope, const Input::Initializer& val) { + if (!scope.ok()) return Output(); + if (!val.status.ok()) { + scope.UpdateStatus(val.status); + return Output(); } -#define DEFINE_CONST_TENSOR(TYPE, ...) \ - Node* Const(gtl::ArraySlice<TYPE> t, const TensorShape& shape, \ - const GraphDefBuilder::Options& options) { \ - if (options.HaveError()) return nullptr; \ - NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(), \ - options.op_registry()); \ - const DataType dt = DataTypeToEnum<TYPE>::v(); \ - if (t.size() == 1) { \ - TensorProto proto; \ - proto.set_dtype(dt); \ - shape.AsProto(proto.mutable_tensor_shape()); \ - __VA_ARGS__; \ - node_builder.Attr("dtype", dt).Attr("value", proto); \ - } else { \ - Tensor tensor(dt, shape); \ - if (tensor.NumElements() != static_cast<int64>(t.size())) { \ - options.UpdateStatus(errors::InvalidArgument( \ - t.size(), " values provided to Const() != ", tensor.NumElements(), \ - " elements for shape ", shape.DebugString())); \ - } else { \ - std::copy_n(t.data(), t.size(), tensor.flat<TYPE>().data()); \ - node_builder.Attr("dtype", dt).Attr("value", tensor); \ - } \ - } \ - return options.FinalizeBuilder(&node_builder); \ - } - -#define DEFINE_CONST_IMPL(TYPE, ...) \ - DEFINE_CONST_SCALAR(TYPE) \ - DEFINE_CONST_VECTOR(TYPE) \ - DEFINE_CONST_TENSOR(TYPE, __VA_ARGS__) - -#define DEFINE_CONST(TYPE, FIELD) \ - DEFINE_CONST_IMPL(TYPE, proto.add_##FIELD(*t.begin());) - -DEFINE_CONST(float, float_val); -DEFINE_CONST(double, double_val); -DEFINE_CONST(int32, int_val); -DEFINE_CONST(uint8, int_val); -DEFINE_CONST(int16, int_val); -DEFINE_CONST(int8, int_val); -DEFINE_CONST(int64, int64_val); -DEFINE_CONST(bool, bool_val); + Node* ret; + Graph* graph = scope.graph(); + const string unique_name = scope.GetUniqueNameForOp("Const"); + auto builder = NodeBuilder(unique_name, "Const") + .Attr("value", val.tensor) + .Attr("dtype", val.tensor.dtype()); + scope.UpdateBuilder(&builder); + scope.UpdateStatus(builder.Finalize(graph, &ret)); -DEFINE_CONST_IMPL(Eigen::half, proto.add_half_val(t.begin()->x)); + if (!scope.ok()) return Output(); -DEFINE_CONST_IMPL(complex64, proto.add_scomplex_val(t.begin()->real()); - proto.add_scomplex_val(t.begin()->imag());); - -DEFINE_CONST_IMPL(complex128, proto.add_dcomplex_val(t.begin()->real()); - proto.add_dcomplex_val(t.begin()->imag());); - -Node* Const(StringPiece s, const GraphDefBuilder::Options& options) { - if (options.HaveError()) return nullptr; - NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(), - options.op_registry()); - TensorProto proto; - proto.set_dtype(DT_STRING); - TensorShape({}).AsProto(proto.mutable_tensor_shape()); - proto.add_string_val(s.data(), s.size()); - node_builder.Attr("dtype", DT_STRING).Attr("value", proto); - return options.FinalizeBuilder(&node_builder); + return Output(ret); } -DEFINE_CONST_VECTOR(string) -DEFINE_CONST_TENSOR(string, proto.add_string_val(*t.begin());) - -#undef DEFINE_CONST -#undef DEFINE_CONST_IMPL -#undef DEFINE_CONST_TENSOR -#undef DEFINE_CONST_VECTOR -#undef DEFINE_CONST_SCALAR - -Node* Const(const Tensor& t, const GraphDefBuilder::Options& options) { - if (options.HaveError()) return nullptr; - NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(), - options.op_registry()); - node_builder.Attr("dtype", t.dtype()).Attr("value", t); - return options.FinalizeBuilder(&node_builder); +NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp) { + if (!inp.status().ok()) { + scope.UpdateStatus(inp.status()); + return NodeBuilder::NodeOut(inp.node(), inp.index()); + } + if (inp.node()) { + return NodeBuilder::NodeOut(inp.node(), inp.index()); + } + if (!inp.node_name().empty()) { + return NodeBuilder::NodeOut(inp.node_name(), inp.index(), inp.data_type()); + } + auto transformed = Input{ + Const(scope.NewSubScope("Const"), Input::Initializer(inp.tensor()))}; + return NodeBuilder::NodeOut{transformed.node(), transformed.index()}; } -Node* Const(const TensorProto& proto, const GraphDefBuilder::Options& options) { - if (options.HaveError()) return nullptr; - NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(), - options.op_registry()); - node_builder.Attr("dtype", proto.dtype()).Attr("value", proto); - return options.FinalizeBuilder(&node_builder); +std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope, + const InputList& inp) { + std::vector<NodeBuilder::NodeOut> out; + for (const auto& i : inp) { + const auto node_out = AsNodeOut(scope, i); + if (!scope.ok()) { + return {}; + } + out.push_back(node_out); + } + return out; } } // namespace ops diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h index a079368ef7..75844d124d 100644 --- a/tensorflow/cc/ops/const_op.h +++ b/tensorflow/cc/ops/const_op.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,75 +13,53 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CC_OPS_CONST_OP_H_ -#define TENSORFLOW_CC_OPS_CONST_OP_H_ +#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_ +#define THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_ -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor.pb.h" -#include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/core/graph/node_builder.h" namespace tensorflow { namespace ops { -// If a shape is specified, you may either provide the same number of values, -// or a single value and that value will be duplicated to fill out the Tensor. -#define DECLARE_CONST(TYPE) \ - Node* Const(TYPE s, const GraphDefBuilder::Options& options); /* Scalar */ \ - Node* Const(gtl::ArraySlice<TYPE> v, \ - const GraphDefBuilder::Options& options); /* Vector */ \ - Node* Const(gtl::ArraySlice<TYPE> t, const TensorShape& shape, \ - const GraphDefBuilder::Options& options); /* Tensor */ \ - inline Node* Const(std::initializer_list<TYPE> v, /* Vector using {...} */ \ - const GraphDefBuilder::Options& options) { \ - return Const(gtl::ArraySlice<TYPE>(v), options); \ - } \ - inline Node* Const(std::initializer_list<TYPE> t, /* Tensor using {...} */ \ - const TensorShape& shape, \ - const GraphDefBuilder::Options& options) { \ - return Const(gtl::ArraySlice<TYPE>(t), shape, options); \ - } - -DECLARE_CONST(Eigen::half); -DECLARE_CONST(float); -DECLARE_CONST(double); -DECLARE_CONST(int32); -DECLARE_CONST(uint8); -DECLARE_CONST(int16); -DECLARE_CONST(int8); -DECLARE_CONST(complex64); -DECLARE_CONST(complex128); -DECLARE_CONST(int64); -DECLARE_CONST(bool); +Output Const(const Scope& scope, const Input::Initializer& val); -#undef DECLARE_CONST - -// String -Node* Const(StringPiece s, const GraphDefBuilder::Options& options); -Node* Const(gtl::ArraySlice<string> v, const GraphDefBuilder::Options& options); -Node* Const(gtl::ArraySlice<string> t, const TensorShape& shape, - const GraphDefBuilder::Options& options); -inline Node* Const(std::initializer_list<string> v, - const GraphDefBuilder::Options& options) { - return Const(gtl::ArraySlice<string>(v), options); -} -inline Node* Const(std::initializer_list<string> t, const TensorShape& shape, - const GraphDefBuilder::Options& options) { - return Const(gtl::ArraySlice<string>(t), shape, options); +template <typename T> +Output Const(const Scope& scope, const Input::Initializer& val) { + if (!scope.ok()) return Output(); + if (!val.status.ok()) { + scope.UpdateStatus(val.status); + return Output(); + } + typedef typename Input::Initializer::RealType<T>::type DstT; + if (val.tensor.NumElements() > 0) { + // TODO(keveman): Implement the in-situ cast. + scope.UpdateStatus(errors::Unimplemented( + "Explict cast of a non-empty tensor not implemented yet")); + return Output(); + } + Tensor t(DataTypeToEnum<DstT>::v(), val.tensor.shape()); + return Const(scope, Input::Initializer(t)); } -// A Tensor of any type. -Node* Const(const Tensor& t, const GraphDefBuilder::Options& options); -Node* Const(const TensorProto& proto, const GraphDefBuilder::Options& options); +template <typename T> +Output Const(const Scope& scope, const T& v, const TensorShape shape) { + return Const(scope, Input::Initializer(v, shape)); +} -template <class T> -Node* EmptyConst(const GraphDefBuilder::Options& options) { - return Const(gtl::ArraySlice<T>(), options); +template <typename T> +Output Const(const Scope& scope, const std::initializer_list<T>& v, + const TensorShape shape) { + return Const(scope, Input::Initializer(v, shape)); } -// TODO(josh11b): Support other types (e.g. quantized ints, float16). +NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp); + +std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope, + const InputList& inp); } // namespace ops } // namespace tensorflow -#endif // TENSORFLOW_CC_OPS_CONST_OP_H_ +#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_ diff --git a/tensorflow/cc/ops/const_op_test.cc b/tensorflow/cc/ops/const_op_test.cc new file mode 100644 index 0000000000..a56b66c1cc --- /dev/null +++ b/tensorflow/cc/ops/const_op_test.cc @@ -0,0 +1,128 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace { + +template <typename T> +void ExpectNodeEqual(const Node* n, gtl::ArraySlice<T> values, + TensorShape shape) { + EXPECT_TRUE(n->IsConstant()); + Tensor tensor; + TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor)); + DataType dtype; + TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype)); + EXPECT_EQ(tensor.dtype(), dtype); + test::ExpectTensorEqual<T>(tensor, test::AsTensor(values, shape)); +} + +void ExpectTypeAndShape(const Node* n, DataType expected_dtype, + TensorShape expected_shape) { + EXPECT_TRUE(n->IsConstant()); + Tensor tensor; + TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor)); + DataType dtype; + TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype)); + EXPECT_EQ(dtype, expected_dtype); + EXPECT_EQ(expected_shape, TensorShape(tensor.shape())); +} + +} // namespace + +TEST(ConstOpTest, Basic) { + Scope root = Scope::NewRootScope(); + auto c = ops::Const(root, 42.0f); + TF_EXPECT_OK(root.status()); + EXPECT_EQ(c.op().output_type(0), DT_FLOAT); + ExpectNodeEqual<float>(c.node(), {42.0f}, {}); +} + +TEST(ConstOpTest, MultiDim) { + Scope root = Scope::NewRootScope(); + auto c = ops::Const(root, {{2.0}, {3.0}}); + TF_CHECK_OK(root.status()); + EXPECT_EQ(c.op().output_type(0), DT_DOUBLE); + ExpectNodeEqual<double>(c.node(), {2.0, 3.0}, {2, 1}); +} + +TEST(ConstOpTest, Empty) { + Scope root = Scope::NewRootScope(); + + auto c1 = ops::Const(root, {}); + TF_CHECK_OK(root.status()); + ExpectTypeAndShape(c1.node(), DT_FLOAT, {0}); + + auto c2 = ops::Const(root, {{}}); + TF_CHECK_OK(root.status()); + ExpectTypeAndShape(c2.node(), DT_FLOAT, {1, 0}); + + auto c3 = ops::Const(root, {{{}, {}}}); + TF_CHECK_OK(root.status()); + ExpectTypeAndShape(c3.node(), DT_FLOAT, {1, 2, 0}); + + auto c4 = ops::Const<int>(root, {{{}}}); + TF_CHECK_OK(root.status()); + ExpectTypeAndShape(c4.node(), DT_INT32, {1, 1, 0}); + + ops::Const(root, {{}, {{}}}); + EXPECT_FALSE(root.status().ok()); +} + +TEST(ConstOpTest, WithExplicitShape) { + Scope root = Scope::NewRootScope(); + auto c = ops::Const(root, 42.0, {2, 2}); + TF_CHECK_OK(root.status()); + EXPECT_EQ(c.op().output_type(0), DT_DOUBLE); + ExpectNodeEqual<double>(c.node(), {42.0, 42.0, 42.0, 42.0}, {2, 2}); + + auto d = ops::Const(root, {"1", "2", "3", "4", "5", "6"}, {2, 3}); + TF_CHECK_OK(root.status()); + EXPECT_EQ(d.op().output_type(0), DT_STRING); + ExpectNodeEqual<string>(d.node(), {"1", "2", "3", "4", "5", "6"}, {2, 3}); +} + +TEST(ConstOpTest, InvalidInitializer) { + Scope root = Scope::NewRootScope(); + ops::Const(root, {{2.0}, {"df"}}); + EXPECT_FALSE(root.status().ok()); +} + +TEST(ConstOpTest, Names) { + Scope root = Scope::NewRootScope(); + auto c = ops::Const(root, {{2.0}, {3.0}}); + EXPECT_EQ(c.node()->name(), "Const"); + auto c_1 = ops::Const(root, {{2.0}, {3.0}}); + EXPECT_EQ(c_1.node()->name(), "Const_1"); + + auto x = ops::Const(root.WithOpName("x"), 1); + EXPECT_EQ(x.node()->name(), "x"); + auto x_1 = ops::Const(root.WithOpName("x"), 1); + EXPECT_EQ(x_1.node()->name(), "x_1"); + + Scope child = root.NewSubScope("c"); + auto c_y = ops::Const(child.WithOpName("y"), 1); + EXPECT_EQ(c_y.node()->name(), "c/y"); + auto c_y_1 = ops::Const(child.WithOpName("y"), 1); + EXPECT_EQ(c_y_1.node()->name(), "c/y_1"); +} + +} // namespace tensorflow diff --git a/tensorflow/cc/ops/standard_ops.h b/tensorflow/cc/ops/standard_ops.h index 13cb9d252b..e117ddd042 100644 --- a/tensorflow/cc/ops/standard_ops.h +++ b/tensorflow/cc/ops/standard_ops.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// #include this file to get access to the standard set of C++ graph -// definition libraries. - -#ifndef TENSORFLOW_CC_OPS_STANDARD_OPS_H_ -#define TENSORFLOW_CC_OPS_STANDARD_OPS_H_ +#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_ +#define THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_ #include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/candidate_sampling_ops.h" #include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/control_flow_ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/image_ops.h" #include "tensorflow/cc/ops/io_ops.h" @@ -28,6 +27,7 @@ limitations under the License. #include "tensorflow/cc/ops/logging_ops.h" #include "tensorflow/cc/ops/math_ops.h" #include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/no_op.h" #include "tensorflow/cc/ops/parsing_ops.h" #include "tensorflow/cc/ops/random_ops.h" #include "tensorflow/cc/ops/sparse_ops.h" @@ -36,4 +36,4 @@ limitations under the License. #include "tensorflow/cc/ops/training_ops.h" #include "tensorflow/cc/ops/user_ops.h" -#endif // TENSORFLOW_CC_OPS_STANDARD_OPS_H_ +#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_ |