path: root/tensorflow/cc/ops
author    Manjunath Kudlur <keveman@google.com>  2016-07-15 14:28:59 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-07-15 15:33:32 -0700
commit    25ac3dabfa3af7a313eb46b03690117c85030cc2 (patch)
tree      06010c7cc7d25a538880c5f0d53df079e27093fd /tensorflow/cc/ops
parent    194efde51895e0251d39c72c969dff1a50b67d35 (diff)
Improvements to the C++ graph building API.
TESTED:
- passed opensource_build: http://ci.tensorflow.org/job/tensorflow-cl-presubmit-multijob/2780/
Change: 127585603
Diffstat (limited to 'tensorflow/cc/ops')
-rw-r--r--  tensorflow/cc/ops/cc_op_gen.cc       | 382
-rw-r--r--  tensorflow/cc/ops/cc_op_gen.h        |  29
-rw-r--r--  tensorflow/cc/ops/cc_op_gen_main.cc  |  49
-rw-r--r--  tensorflow/cc/ops/const_op.cc        | 142
-rw-r--r--  tensorflow/cc/ops/const_op.h         |  94
-rw-r--r--  tensorflow/cc/ops/const_op_test.cc   | 128
-rw-r--r--  tensorflow/cc/ops/standard_ops.h     |  14
7 files changed, 212 insertions, 626 deletions
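
For orientation before the hunks: a minimal sketch of the Scope-based graph-building flow this change moves the C++ ops toward, pieced together from the const_op.h and const_op_test.cc hunks below. Header paths and call forms are taken from this diff; treat it as an illustration rather than a verified build.

#include <iostream>

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"

int main() {
  // A Scope owns the graph under construction and accumulates error status,
  // replacing GraphDefBuilder::Options and null-Node* error returns.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  // ops::Const now takes the scope first and returns an Output
  // (node + output index) rather than a raw Node*.
  auto c = tensorflow::ops::Const(root, {{2.0}, {3.0}});         // 2x1 DT_DOUBLE
  auto x = tensorflow::ops::Const(root.WithOpName("x"), 42.0f);  // scalar named "x"

  // Errors surface through the scope's status instead of nullptr returns.
  if (!root.status().ok()) {
    std::cerr << root.status().ToString() << "\n";
    return 1;
  }
  return 0;
}
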
diff --git a/tensorflow/cc/ops/cc_op_gen.cc b/tensorflow/cc/ops/cc_op_gen.cc
deleted file mode 100644
index d6cfccd6a9..0000000000
--- a/tensorflow/cc/ops/cc_op_gen.cc
+++ /dev/null
@@ -1,382 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// TODO(josh11b): Rewrite function parameter names to avoid C++ keywords
-// or "opts".
-
-#include "tensorflow/cc/ops/cc_op_gen.h"
-
-#include <unordered_map>
-#include "tensorflow/core/framework/attr_value_util.h"
-#include "tensorflow/core/framework/op_def.pb.h"
-#include "tensorflow/core/framework/op_def_util.h"
-#include "tensorflow/core/framework/op_gen_lib.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/gtl/stl_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace {
-
-const int kRightMargin = 79;
-
-const char* AttrTypeName(StringPiece attr_type) {
- static const char* kAttrTypeName[][2] = {
- {"string", "StringPiece"},
- {"list(string)", "gtl::ArraySlice<string>"},
- {"int", "int64"},
- {"list(int)", "gtl::ArraySlice<int>"},
- {"float", "float"},
- {"list(float)", "gtl::ArraySlice<float>"},
- {"bool", "bool"},
- {"list(bool)", "gtl::ArraySlice<bool>"},
- {"type", "DataType"},
- {"list(type)", "DataTypeSlice"},
- {"shape", "TensorShape"},
- {"list(shape)", "gtl::ArraySlice<TensorShape>"},
- {"tensor", "const Tensor&"},
- {"list(tensor)", "gtl::ArraySlice<Tensor>"},
- {"func", "const NameAttrList&"},
- };
- for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) {
- if (attr_type == kAttrTypeName[i][0]) {
- return kAttrTypeName[i][1];
- }
- }
- LOG(FATAL) << "Unsupported Attr type: " << attr_type;
- return "";
-}
-
-// Change: Into:
-// ABC // ABC
-// //
-// DEF // DEF
-string MakeComment(StringPiece text) {
- string ret;
- while (!text.empty()) {
- int last_non_space = -1;
- int newline;
- for (newline = 0; newline < static_cast<int>(text.size()); ++newline) {
- if (text[newline] == '\n') break;
- if (text[newline] != ' ') last_non_space = newline;
- }
- if (last_non_space == -1) {
- strings::StrAppend(&ret, "//\n");
- } else {
- strings::StrAppend(&ret, "// ", text.substr(0, last_non_space + 1), "\n");
- }
- text.remove_prefix(newline + 1);
- }
- return ret;
-}
-
-void WriteCCOp(const OpDef& op_def, WritableFile* h, WritableFile* cc) {
- // TODO(josh11b): Better wrapping of comments.
- string comment;
- if (op_def.summary().empty()) {
- comment = "TODO: add doc.\n";
- } else {
- comment = strings::StrCat(op_def.summary(), "\n");
- if (op_def.has_deprecation()) {
- strings::StrAppend(&comment, "\nDEPRECATED at GraphDef version ",
- op_def.deprecation().version(), ":\n",
- op_def.deprecation().explanation(), ".\n");
- }
- if (!op_def.description().empty()) {
- strings::StrAppend(&comment, "\n", op_def.description(), "\n");
- }
- }
-
- static const string kSingleInputType = "NodeOut";
- static const string kListInputType = "gtl::ArraySlice<NodeOut>";
-
- std::vector<string> arg_types;
- std::vector<string> arg_names;
-
- strings::StrAppend(&comment, "\nArguments:\n");
-
- // Map from attr name to the first input arg it is inferred from.
- std::unordered_map<string, string> inferred_attrs;
- for (int i = 0; i < op_def.input_arg_size(); ++i) {
- const auto& arg(op_def.input_arg(i));
- arg_names.emplace_back(arg.name());
- bool is_list = false;
-
- if (!arg.type_attr().empty()) {
- gtl::InsertIfNotPresent(&inferred_attrs, arg.type_attr(), arg.name());
- } else if (!arg.type_list_attr().empty()) {
- gtl::InsertIfNotPresent(&inferred_attrs, arg.type_list_attr(),
- arg.name());
- is_list = true;
- }
- if (!arg.number_attr().empty()) {
- gtl::InsertIfNotPresent(&inferred_attrs, arg.number_attr(), arg.name());
- is_list = true;
- }
- if (is_list) {
- arg_types.emplace_back(kListInputType);
- } else {
- arg_types.emplace_back(kSingleInputType);
- }
-
- // TODO(josh11b): Include input type information.
- StringPiece description = arg.description();
- if (!description.empty()) {
- ConsumeEquals(&description);
- strings::StrAppend(&comment, "* ", arg_names.back(), ": ",
- arg.description(), "\n");
- }
- }
-
- string options_comment;
- for (int i = 0; i < op_def.attr_size(); ++i) {
- const auto& attr(op_def.attr(i));
- // Do not add inferred attrs or attrs with defaults to the C++
- // function signature.
- if (inferred_attrs.find(attr.name()) == inferred_attrs.end()) {
- if (!attr.has_default_value()) {
- arg_names.emplace_back(attr.name());
- arg_types.emplace_back(AttrTypeName(attr.type()));
- if (!attr.description().empty()) {
- strings::StrAppend(&comment, "* ", arg_names.back(), ": ",
- attr.description(), "\n");
- }
- } else {
- strings::StrAppend(&options_comment, " .WithAttr(\"", attr.name(),
- "\", ", AttrTypeName(attr.type()), "): Defaults to ",
- SummarizeAttrValue(attr.default_value()), ".\n");
- if (!attr.description().empty()) {
- strings::StrAppend(&options_comment, " ", attr.description(),
- "\n");
- }
- }
- }
- }
- CHECK_EQ(arg_names.size(), arg_types.size());
- strings::StrAppend(&comment, "* opts:\n", options_comment,
- R"comment( .WithName(StringPiece): Set the Node's name
- .WithDevice(StringPiece): Set the Node's requested device
- .WithControlInput(Node*) / .WithControlInputs({Node*, ...}):
- Add control dependencies on the specified Node(s).
-
-Returns a pointer to the created Node)comment");
-
- // TODO(josh11b): Include output type information.
- if (op_def.output_arg_size() == 0) {
- strings::StrAppend(&comment, ".\n");
- } else if (op_def.output_arg_size() == 1) {
- StringPiece description = op_def.output_arg(0).description();
- ConsumeEquals(&description);
- if (description.empty()) {
- strings::StrAppend(&comment, ".\n");
- } else {
- strings::StrAppend(&comment, ", with output:\n", description, "\n");
- }
- } else {
- strings::StrAppend(&comment, ", with outputs:\n");
- for (int o = 0; o < op_def.output_arg_size(); ++o) {
- StringPiece description = op_def.output_arg(o).description();
- ConsumeEquals(&description);
- if (description.empty()) {
- strings::StrAppend(&comment, "* ", op_def.output_arg(o).name(), "\n");
- } else {
- strings::StrAppend(&comment, "* ", op_def.output_arg(o).name(), ": ",
- description, "\n");
- }
- }
- }
-
- // Write the header comment.
- TF_CHECK_OK(h->Append(MakeComment(comment)));
-
- // Declare the function wrapper.
- const string prefix = strings::StrCat("Node* ", op_def.name(), "(");
- string h_rest;
- for (size_t i = 0; i < arg_names.size(); ++i) {
- strings::StrAppend(&h_rest, arg_types[i], " ", arg_names[i], ", ");
- }
- strings::StrAppend(&h_rest, "const GraphDefBuilder::Options& opts");
- string cc_decl = h_rest;
- strings::StrAppend(&h_rest, ");");
- TF_CHECK_OK(h->Append(WordWrap(prefix, h_rest, kRightMargin) + "\n\n"));
-
- // Define the function wrapper.
- strings::StrAppend(&cc_decl, ") {");
- TF_CHECK_OK(cc->Append(WordWrap(prefix, cc_decl, kRightMargin) + "\n"));
- const string op_name = strings::StrCat(" static const string kOpName = \"",
- op_def.name(), "\";\n");
-
- if (arg_types.empty()) {
- TF_CHECK_OK(cc->Append(op_name));
- TF_CHECK_OK(cc->Append(" return SourceOp(kOpName, opts);\n}\n\n"));
- } else if (arg_types == std::vector<string>({kSingleInputType})) {
- TF_CHECK_OK(cc->Append(op_name));
- TF_CHECK_OK(cc->Append(strings::StrCat(" return UnaryOp(kOpName, ",
- arg_names[0], ", opts);\n}\n\n")));
- } else if (arg_types ==
- std::vector<string>({kSingleInputType, kSingleInputType})) {
- TF_CHECK_OK(cc->Append(op_name));
- // TODO(josh11b): Word wrap this if it ever becomes necessary.
- TF_CHECK_OK(
- cc->Append(strings::StrCat(" return BinaryOp(kOpName, ", arg_names[0],
- ", ", arg_names[1], ", opts);\n}\n\n")));
- } else {
- TF_CHECK_OK(cc->Append(" if (opts.HaveError()) return nullptr;\n"));
- TF_CHECK_OK(cc->Append(op_name));
- TF_CHECK_OK(cc->Append(
- " NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName,\n"
- " opts.op_registry());\n"));
- for (size_t i = 0; i < arg_names.size(); ++i) {
- if (i < static_cast<size_t>(op_def.input_arg_size())) {
- TF_CHECK_OK(cc->Append(
- strings::StrCat(" node_builder.Input(", arg_names[i], ");\n")));
- } else {
- TF_CHECK_OK(
- cc->Append(strings::StrCat(" node_builder.Attr(\"", arg_names[i],
- "\", ", arg_names[i], ");\n")));
- }
- }
- TF_CHECK_OK(
- cc->Append(" return opts.FinalizeBuilder(&node_builder);\n"
- "}\n\n"));
- }
-}
-
-// Converts:
-// bazel-out/.../genfiles/(external/YYY/)?XX
-// to: XX.
-string GetPath(const std::string& dot_h_fname) {
- auto pos = dot_h_fname.find("/genfiles/");
- string result = dot_h_fname;
- if (pos != string::npos) {
- // - 1 account for the terminating null character (\0) in "/genfiles/".
- result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1);
- }
- if (result.size() > sizeof("external/") &&
- result.compare(0, sizeof("external/") - 1, "external/") == 0) {
- result = result.substr(sizeof("external/") - 1);
- pos = result.find("/");
- if (pos != string::npos) {
- result = result.substr(pos + 1);
- }
- }
- return result;
-}
-
-// Converts:
-// cc/ops/gen_foo_ops.h
-// to:
-// CC_OPS_GEN_FOO_OPS_H_
-string ToGuard(const std::string& path) {
- string guard;
- guard.reserve(path.size() + 1); // + 1 -> trailing _
- for (const char c : path) {
- if (c >= 'A' && c <= 'Z') {
- guard += c;
- } else if (c >= 'a' && c <= 'z') {
- guard += c + 'A' - 'a';
- } else {
- guard += '_';
- }
- }
- guard += '_';
- return guard;
-}
-
-} // namespace
-
-void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
- const std::string& dot_cc_fname) {
- Env* env = Env::Default();
- std::unique_ptr<WritableFile> h;
- std::unique_ptr<WritableFile> cc;
- TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h));
- TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc));
-
- // .h Header
- const string include = GetPath(dot_h_fname);
- const string guard = ToGuard(include);
- // TODO(josh11b): Mention the library for which wrappers are being generated.
- Status s;
- s = h->Append(
- strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
- "#ifndef ",
- guard,
- "\n"
- "#define ",
- guard, R"header(
-
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/graph/graph_def_builder.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-
-namespace tensorflow {
-namespace ops {
-
-// These add a node to the graph from opts.
-//
-// Note for "NodeOut" inputs, you will typically either pass
-// * a {Node*, int index} (to pass the index-th output of that node), or
-// * a Node* (to pass the first output of that node).
-
-
-)header"));
- TF_CHECK_OK(s);
- // .cc Header
- s = cc->Append(
- strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
- "#include \"",
- include, R"header("
-
-#include "tensorflow/core/graph/node_builder.h"
-
-namespace tensorflow {
-namespace ops {
-
-)header"));
- TF_CHECK_OK(s);
-
- for (const auto& op_def : ops.op()) {
- WriteCCOp(op_def, h.get(), cc.get());
- }
-
- // .h Footer
-
- s = h->Append(strings::StrCat(R"footer(} // namespace ops
-} // namespace tensorflow
-
-#endif // )footer",
- guard, "\n"));
- TF_CHECK_OK(s);
-
- // .cc Footer
-
- s = cc->Append(R"footer(} // namespace ops
-} // namespace tensorflow
-)footer");
- TF_CHECK_OK(s);
-
- TF_CHECK_OK(cc->Close());
- TF_CHECK_OK(h->Close());
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/cc/ops/cc_op_gen.h b/tensorflow/cc/ops/cc_op_gen.h
deleted file mode 100644
index 09f9ca2467..0000000000
--- a/tensorflow/cc/ops/cc_op_gen.h
+++ /dev/null
@@ -1,29 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CC_OPS_CC_OP_GEN_H_
-#define TENSORFLOW_CC_OPS_CC_OP_GEN_H_
-
-#include "tensorflow/core/framework/op_def.pb.h"
-
-namespace tensorflow {
-
-// Result is written to files dot_h and dot_cc.
-void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
- const std::string& dot_cc_fname);
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_CC_OPS_CC_OP_GEN_H_
diff --git a/tensorflow/cc/ops/cc_op_gen_main.cc b/tensorflow/cc/ops/cc_op_gen_main.cc
deleted file mode 100644
index 47491fac10..0000000000
--- a/tensorflow/cc/ops/cc_op_gen_main.cc
+++ /dev/null
@@ -1,49 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/cc/ops/cc_op_gen.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_def.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/platform/init_main.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace {
-
-void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,
- bool include_internal) {
- OpList ops;
- OpRegistry::Global()->Export(include_internal, &ops);
- WriteCCOps(ops, dot_h, dot_cc);
-}
-
-} // namespace
-} // namespace tensorflow
-
-int main(int argc, char* argv[]) {
- tensorflow::port::InitMain(argv[0], &argc, &argv);
- if (argc != 4) {
- fprintf(stderr,
- "Usage: %s out.h out.cc include_internal\n"
- " include_internal: 1 means include internal ops\n",
- argv[0]);
- exit(1);
- }
-
- bool include_internal = tensorflow::StringPiece("1") == argv[3];
- tensorflow::PrintAllCCOps(argv[1], argv[2], include_internal);
- return 0;
-}
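
For contrast, this is roughly how the GraphDefBuilder-style wrappers deleted above (cc_op_gen.cc and the old const_op declarations) were used before this change. The GraphDefBuilder calls are recalled from the pre-change API, not shown in this diff, so take the exact method names as assumptions.

#include "tensorflow/cc/ops/const_op.h"  // pre-change header with the Node*/Options overloads
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph_def_builder.h"

int main() {
  using namespace tensorflow;
  // Old style: every wrapper took a GraphDefBuilder::Options and returned Node*.
  GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
  Node* c = ops::Const(42.0f, b.opts().WithName("c"));
  // Errors showed up as nullptr returns and a failed status on the builder.
  GraphDef graph_def;
  Status s = b.ToGraphDef(&graph_def);
  return (c != nullptr && s.ok()) ? 0 : 1;
}
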
diff --git a/tensorflow/cc/ops/const_op.cc b/tensorflow/cc/ops/const_op.cc
index 4b4b0e06e1..b37b8b67d7 100644
--- a/tensorflow/cc/ops/const_op.cc
+++ b/tensorflow/cc/ops/const_op.cc
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -14,119 +14,59 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
-
#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/graph/node_builder.h"
-#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace ops {
-namespace {
-const string& OpName() {
- static const string kOpName = "Const";
- return kOpName;
-}
-} // namespace
-
-#define DEFINE_CONST_SCALAR(TYPE) \
- Node* Const(TYPE s, const GraphDefBuilder::Options& options) { \
- return Const(gtl::ArraySlice<TYPE>(&s, 1), TensorShape({}), options); \
- }
-
-#define DEFINE_CONST_VECTOR(TYPE) \
- Node* Const(gtl::ArraySlice<TYPE> v, \
- const GraphDefBuilder::Options& options) { \
- return Const(v, TensorShape({static_cast<int64>(v.size())}), options); \
+Output Const(const Scope& scope, const Input::Initializer& val) {
+ if (!scope.ok()) return Output();
+ if (!val.status.ok()) {
+ scope.UpdateStatus(val.status);
+ return Output();
}
-#define DEFINE_CONST_TENSOR(TYPE, ...) \
- Node* Const(gtl::ArraySlice<TYPE> t, const TensorShape& shape, \
- const GraphDefBuilder::Options& options) { \
- if (options.HaveError()) return nullptr; \
- NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(), \
- options.op_registry()); \
- const DataType dt = DataTypeToEnum<TYPE>::v(); \
- if (t.size() == 1) { \
- TensorProto proto; \
- proto.set_dtype(dt); \
- shape.AsProto(proto.mutable_tensor_shape()); \
- __VA_ARGS__; \
- node_builder.Attr("dtype", dt).Attr("value", proto); \
- } else { \
- Tensor tensor(dt, shape); \
- if (tensor.NumElements() != static_cast<int64>(t.size())) { \
- options.UpdateStatus(errors::InvalidArgument( \
- t.size(), " values provided to Const() != ", tensor.NumElements(), \
- " elements for shape ", shape.DebugString())); \
- } else { \
- std::copy_n(t.data(), t.size(), tensor.flat<TYPE>().data()); \
- node_builder.Attr("dtype", dt).Attr("value", tensor); \
- } \
- } \
- return options.FinalizeBuilder(&node_builder); \
- }
-
-#define DEFINE_CONST_IMPL(TYPE, ...) \
- DEFINE_CONST_SCALAR(TYPE) \
- DEFINE_CONST_VECTOR(TYPE) \
- DEFINE_CONST_TENSOR(TYPE, __VA_ARGS__)
-
-#define DEFINE_CONST(TYPE, FIELD) \
- DEFINE_CONST_IMPL(TYPE, proto.add_##FIELD(*t.begin());)
-
-DEFINE_CONST(float, float_val);
-DEFINE_CONST(double, double_val);
-DEFINE_CONST(int32, int_val);
-DEFINE_CONST(uint8, int_val);
-DEFINE_CONST(int16, int_val);
-DEFINE_CONST(int8, int_val);
-DEFINE_CONST(int64, int64_val);
-DEFINE_CONST(bool, bool_val);
+ Node* ret;
+ Graph* graph = scope.graph();
+ const string unique_name = scope.GetUniqueNameForOp("Const");
+ auto builder = NodeBuilder(unique_name, "Const")
+ .Attr("value", val.tensor)
+ .Attr("dtype", val.tensor.dtype());
+ scope.UpdateBuilder(&builder);
+ scope.UpdateStatus(builder.Finalize(graph, &ret));
-DEFINE_CONST_IMPL(Eigen::half, proto.add_half_val(t.begin()->x));
+ if (!scope.ok()) return Output();
-DEFINE_CONST_IMPL(complex64, proto.add_scomplex_val(t.begin()->real());
- proto.add_scomplex_val(t.begin()->imag()););
-
-DEFINE_CONST_IMPL(complex128, proto.add_dcomplex_val(t.begin()->real());
- proto.add_dcomplex_val(t.begin()->imag()););
-
-Node* Const(StringPiece s, const GraphDefBuilder::Options& options) {
- if (options.HaveError()) return nullptr;
- NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(),
- options.op_registry());
- TensorProto proto;
- proto.set_dtype(DT_STRING);
- TensorShape({}).AsProto(proto.mutable_tensor_shape());
- proto.add_string_val(s.data(), s.size());
- node_builder.Attr("dtype", DT_STRING).Attr("value", proto);
- return options.FinalizeBuilder(&node_builder);
+ return Output(ret);
}
-DEFINE_CONST_VECTOR(string)
-DEFINE_CONST_TENSOR(string, proto.add_string_val(*t.begin());)
-
-#undef DEFINE_CONST
-#undef DEFINE_CONST_IMPL
-#undef DEFINE_CONST_TENSOR
-#undef DEFINE_CONST_VECTOR
-#undef DEFINE_CONST_SCALAR
-
-Node* Const(const Tensor& t, const GraphDefBuilder::Options& options) {
- if (options.HaveError()) return nullptr;
- NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(),
- options.op_registry());
- node_builder.Attr("dtype", t.dtype()).Attr("value", t);
- return options.FinalizeBuilder(&node_builder);
+NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp) {
+ if (!inp.status().ok()) {
+ scope.UpdateStatus(inp.status());
+ return NodeBuilder::NodeOut(inp.node(), inp.index());
+ }
+ if (inp.node()) {
+ return NodeBuilder::NodeOut(inp.node(), inp.index());
+ }
+ if (!inp.node_name().empty()) {
+ return NodeBuilder::NodeOut(inp.node_name(), inp.index(), inp.data_type());
+ }
+ auto transformed = Input{
+ Const(scope.NewSubScope("Const"), Input::Initializer(inp.tensor()))};
+ return NodeBuilder::NodeOut{transformed.node(), transformed.index()};
}
-Node* Const(const TensorProto& proto, const GraphDefBuilder::Options& options) {
- if (options.HaveError()) return nullptr;
- NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(),
- options.op_registry());
- node_builder.Attr("dtype", proto.dtype()).Attr("value", proto);
- return options.FinalizeBuilder(&node_builder);
+std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope,
+ const InputList& inp) {
+ std::vector<NodeBuilder::NodeOut> out;
+ for (const auto& i : inp) {
+ const auto node_out = AsNodeOut(scope, i);
+ if (!scope.ok()) {
+ return {};
+ }
+ out.push_back(node_out);
+ }
+ return out;
}
} // namespace ops
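
The AsNodeOut/AsNodeOutList helpers added in this file are what generated op wrappers use to turn an Input (an existing Output, a node name, or a bare constant value) into a NodeBuilder::NodeOut, creating a Const node in a "Const" sub-scope when only a tensor value is supplied. A usage sketch follows; the implicit conversions from Output and from a bare literal to Input are assumed from tensorflow/cc/framework/ops.h.

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/graph/node_builder.h"

int main() {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  auto c = tensorflow::ops::Const(root, 1.0f);
  // An existing Output passes straight through as its {node, index} pair...
  auto from_output = tensorflow::ops::AsNodeOut(root, c);
  // ...while a bare value (relying on Input's implicit Initializer conversion,
  // assumed here) is materialized as a Const node under a "Const" sub-scope.
  auto from_literal = tensorflow::ops::AsNodeOut(root, 7);

  return root.ok() ? 0 : 1;
}
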
diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h
index a079368ef7..75844d124d 100644
--- a/tensorflow/cc/ops/const_op.h
+++ b/tensorflow/cc/ops/const_op.h
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,75 +13,53 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CC_OPS_CONST_OP_H_
-#define TENSORFLOW_CC_OPS_CONST_OP_H_
+#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_
+#define THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor.pb.h"
-#include "tensorflow/core/graph/graph_def_builder.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/core/graph/node_builder.h"
namespace tensorflow {
namespace ops {
-// If a shape is specified, you may either provide the same number of values,
-// or a single value and that value will be duplicated to fill out the Tensor.
-#define DECLARE_CONST(TYPE) \
- Node* Const(TYPE s, const GraphDefBuilder::Options& options); /* Scalar */ \
- Node* Const(gtl::ArraySlice<TYPE> v, \
- const GraphDefBuilder::Options& options); /* Vector */ \
- Node* Const(gtl::ArraySlice<TYPE> t, const TensorShape& shape, \
- const GraphDefBuilder::Options& options); /* Tensor */ \
- inline Node* Const(std::initializer_list<TYPE> v, /* Vector using {...} */ \
- const GraphDefBuilder::Options& options) { \
- return Const(gtl::ArraySlice<TYPE>(v), options); \
- } \
- inline Node* Const(std::initializer_list<TYPE> t, /* Tensor using {...} */ \
- const TensorShape& shape, \
- const GraphDefBuilder::Options& options) { \
- return Const(gtl::ArraySlice<TYPE>(t), shape, options); \
- }
-
-DECLARE_CONST(Eigen::half);
-DECLARE_CONST(float);
-DECLARE_CONST(double);
-DECLARE_CONST(int32);
-DECLARE_CONST(uint8);
-DECLARE_CONST(int16);
-DECLARE_CONST(int8);
-DECLARE_CONST(complex64);
-DECLARE_CONST(complex128);
-DECLARE_CONST(int64);
-DECLARE_CONST(bool);
+Output Const(const Scope& scope, const Input::Initializer& val);
-#undef DECLARE_CONST
-
-// String
-Node* Const(StringPiece s, const GraphDefBuilder::Options& options);
-Node* Const(gtl::ArraySlice<string> v, const GraphDefBuilder::Options& options);
-Node* Const(gtl::ArraySlice<string> t, const TensorShape& shape,
- const GraphDefBuilder::Options& options);
-inline Node* Const(std::initializer_list<string> v,
- const GraphDefBuilder::Options& options) {
- return Const(gtl::ArraySlice<string>(v), options);
-}
-inline Node* Const(std::initializer_list<string> t, const TensorShape& shape,
- const GraphDefBuilder::Options& options) {
- return Const(gtl::ArraySlice<string>(t), shape, options);
+template <typename T>
+Output Const(const Scope& scope, const Input::Initializer& val) {
+ if (!scope.ok()) return Output();
+ if (!val.status.ok()) {
+ scope.UpdateStatus(val.status);
+ return Output();
+ }
+ typedef typename Input::Initializer::RealType<T>::type DstT;
+ if (val.tensor.NumElements() > 0) {
+ // TODO(keveman): Implement the in-situ cast.
+ scope.UpdateStatus(errors::Unimplemented(
+ "Explict cast of a non-empty tensor not implemented yet"));
+ return Output();
+ }
+ Tensor t(DataTypeToEnum<DstT>::v(), val.tensor.shape());
+ return Const(scope, Input::Initializer(t));
}
-// A Tensor of any type.
-Node* Const(const Tensor& t, const GraphDefBuilder::Options& options);
-Node* Const(const TensorProto& proto, const GraphDefBuilder::Options& options);
+template <typename T>
+Output Const(const Scope& scope, const T& v, const TensorShape shape) {
+ return Const(scope, Input::Initializer(v, shape));
+}
-template <class T>
-Node* EmptyConst(const GraphDefBuilder::Options& options) {
- return Const(gtl::ArraySlice<T>(), options);
+template <typename T>
+Output Const(const Scope& scope, const std::initializer_list<T>& v,
+ const TensorShape shape) {
+ return Const(scope, Input::Initializer(v, shape));
}
-// TODO(josh11b): Support other types (e.g. quantized ints, float16).
+NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp);
+
+std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope,
+ const InputList& inp);
} // namespace ops
} // namespace tensorflow
-#endif // TENSORFLOW_CC_OPS_CONST_OP_H_
+#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_
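
The new header keeps the convenience forms as templates. The calls below mirror the cases exercised by const_op_test.cc further down (a scalar broadcast to an explicit shape, an initializer list with a shape, and an explicitly typed empty constant); this is a sketch only, and the int initializer-list case is an assumption by analogy with the string case in the test.

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"

int main() {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  // A scalar paired with an explicit shape is duplicated to fill the tensor.
  auto a = tensorflow::ops::Const(root, 42.0, {2, 2});               // 2x2 of 42.0, DT_DOUBLE

  // An initializer list paired with an explicit shape.
  auto b = tensorflow::ops::Const(root, {1, 2, 3, 4, 5, 6}, {2, 3});

  // An explicitly typed empty constant.
  auto c = tensorflow::ops::Const<int>(root, {{{}}});                // DT_INT32, shape {1, 1, 0}

  return root.ok() ? 0 : 1;
}
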
diff --git a/tensorflow/cc/ops/const_op_test.cc b/tensorflow/cc/ops/const_op_test.cc
new file mode 100644
index 0000000000..a56b66c1cc
--- /dev/null
+++ b/tensorflow/cc/ops/const_op_test.cc
@@ -0,0 +1,128 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+namespace {
+
+template <typename T>
+void ExpectNodeEqual(const Node* n, gtl::ArraySlice<T> values,
+ TensorShape shape) {
+ EXPECT_TRUE(n->IsConstant());
+ Tensor tensor;
+ TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor));
+ DataType dtype;
+ TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
+ EXPECT_EQ(tensor.dtype(), dtype);
+ test::ExpectTensorEqual<T>(tensor, test::AsTensor(values, shape));
+}
+
+void ExpectTypeAndShape(const Node* n, DataType expected_dtype,
+ TensorShape expected_shape) {
+ EXPECT_TRUE(n->IsConstant());
+ Tensor tensor;
+ TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor));
+ DataType dtype;
+ TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
+ EXPECT_EQ(dtype, expected_dtype);
+ EXPECT_EQ(expected_shape, TensorShape(tensor.shape()));
+}
+
+} // namespace
+
+TEST(ConstOpTest, Basic) {
+ Scope root = Scope::NewRootScope();
+ auto c = ops::Const(root, 42.0f);
+ TF_EXPECT_OK(root.status());
+ EXPECT_EQ(c.op().output_type(0), DT_FLOAT);
+ ExpectNodeEqual<float>(c.node(), {42.0f}, {});
+}
+
+TEST(ConstOpTest, MultiDim) {
+ Scope root = Scope::NewRootScope();
+ auto c = ops::Const(root, {{2.0}, {3.0}});
+ TF_CHECK_OK(root.status());
+ EXPECT_EQ(c.op().output_type(0), DT_DOUBLE);
+ ExpectNodeEqual<double>(c.node(), {2.0, 3.0}, {2, 1});
+}
+
+TEST(ConstOpTest, Empty) {
+ Scope root = Scope::NewRootScope();
+
+ auto c1 = ops::Const(root, {});
+ TF_CHECK_OK(root.status());
+ ExpectTypeAndShape(c1.node(), DT_FLOAT, {0});
+
+ auto c2 = ops::Const(root, {{}});
+ TF_CHECK_OK(root.status());
+ ExpectTypeAndShape(c2.node(), DT_FLOAT, {1, 0});
+
+ auto c3 = ops::Const(root, {{{}, {}}});
+ TF_CHECK_OK(root.status());
+ ExpectTypeAndShape(c3.node(), DT_FLOAT, {1, 2, 0});
+
+ auto c4 = ops::Const<int>(root, {{{}}});
+ TF_CHECK_OK(root.status());
+ ExpectTypeAndShape(c4.node(), DT_INT32, {1, 1, 0});
+
+ ops::Const(root, {{}, {{}}});
+ EXPECT_FALSE(root.status().ok());
+}
+
+TEST(ConstOpTest, WithExplicitShape) {
+ Scope root = Scope::NewRootScope();
+ auto c = ops::Const(root, 42.0, {2, 2});
+ TF_CHECK_OK(root.status());
+ EXPECT_EQ(c.op().output_type(0), DT_DOUBLE);
+ ExpectNodeEqual<double>(c.node(), {42.0, 42.0, 42.0, 42.0}, {2, 2});
+
+ auto d = ops::Const(root, {"1", "2", "3", "4", "5", "6"}, {2, 3});
+ TF_CHECK_OK(root.status());
+ EXPECT_EQ(d.op().output_type(0), DT_STRING);
+ ExpectNodeEqual<string>(d.node(), {"1", "2", "3", "4", "5", "6"}, {2, 3});
+}
+
+TEST(ConstOpTest, InvalidInitializer) {
+ Scope root = Scope::NewRootScope();
+ ops::Const(root, {{2.0}, {"df"}});
+ EXPECT_FALSE(root.status().ok());
+}
+
+TEST(ConstOpTest, Names) {
+ Scope root = Scope::NewRootScope();
+ auto c = ops::Const(root, {{2.0}, {3.0}});
+ EXPECT_EQ(c.node()->name(), "Const");
+ auto c_1 = ops::Const(root, {{2.0}, {3.0}});
+ EXPECT_EQ(c_1.node()->name(), "Const_1");
+
+ auto x = ops::Const(root.WithOpName("x"), 1);
+ EXPECT_EQ(x.node()->name(), "x");
+ auto x_1 = ops::Const(root.WithOpName("x"), 1);
+ EXPECT_EQ(x_1.node()->name(), "x_1");
+
+ Scope child = root.NewSubScope("c");
+ auto c_y = ops::Const(child.WithOpName("y"), 1);
+ EXPECT_EQ(c_y.node()->name(), "c/y");
+ auto c_y_1 = ops::Const(child.WithOpName("y"), 1);
+ EXPECT_EQ(c_y_1.node()->name(), "c/y_1");
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/cc/ops/standard_ops.h b/tensorflow/cc/ops/standard_ops.h
index 13cb9d252b..e117ddd042 100644
--- a/tensorflow/cc/ops/standard_ops.h
+++ b/tensorflow/cc/ops/standard_ops.h
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,14 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// #include this file to get access to the standard set of C++ graph
-// definition libraries.
-
-#ifndef TENSORFLOW_CC_OPS_STANDARD_OPS_H_
-#define TENSORFLOW_CC_OPS_STANDARD_OPS_H_
+#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_
+#define THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_
#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/candidate_sampling_ops.h"
#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/control_flow_ops.h"
#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/io_ops.h"
@@ -28,6 +27,7 @@ limitations under the License.
#include "tensorflow/cc/ops/logging_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/no_op.h"
#include "tensorflow/cc/ops/parsing_ops.h"
#include "tensorflow/cc/ops/random_ops.h"
#include "tensorflow/cc/ops/sparse_ops.h"
@@ -36,4 +36,4 @@ limitations under the License.
#include "tensorflow/cc/ops/training_ops.h"
#include "tensorflow/cc/ops/user_ops.h"
-#endif // TENSORFLOW_CC_OPS_STANDARD_OPS_H_
+#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_