aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/gen/cc/op_specs.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/gen/cc/op_specs.cc')
-rw-r--r--tensorflow/java/src/gen/cc/op_specs.cc423
1 files changed, 423 insertions, 0 deletions
diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc
new file mode 100644
index 0000000000..181fd4c5e3
--- /dev/null
+++ b/tensorflow/java/src/gen/cc/op_specs.cc
@@ -0,0 +1,423 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <map>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "re2/re2.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/java/src/gen/cc/op_specs.h"
+
+namespace tensorflow {
+namespace java {
+namespace {
+
+inline bool IsRealNumbers(const AttrValue& values) {
+ if (!values.has_list()) {
+ return RealNumberTypes().Contains(values.type());
+ }
+ for (int i = 0; i < values.list().type_size(); ++i) {
+ if (!RealNumberTypes().Contains(values.list().type(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+class TypeResolver {
+ public:
+ explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {}
+
+ // Returns the class type of an input/output argument
+ //
+ // For example, if the argument's datatype is DT_STRING, this method will
+ // return "java.lang.String", so the argument can become "Operand<String>"
+ // in the Ops API
+ Type TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out);
+
+ // Returns types of an input attribute
+ //
+ // The first element of the pair is the class type of this attribute while
+ // the second is its JNI/primitive type equivalent, required for explicit
+ // unboxing.
+ //
+ // For example, if the attribute is of type "float", this method will return
+ // <java.lang.Float, float>, so the attribute can be used as a "Float" object
+ // in the Ops API and casted to a "float" when passing through the JNI layer.
+ std::pair<Type, Type> TypesOf(const OpDef_AttrDef& attr_def,
+ bool* iterable_out);
+
+ // Returns true if the type of this attribute has already been resolved
+ bool IsAttributeVisited(const string& attr_name) {
+ return visited_attrs_.find(attr_name) != visited_attrs_.cend();
+ }
+
+ private:
+ const OpDef op_def_;
+ std::map<std::string, Type> visited_attrs_;
+ char next_generic_letter_ = 'T';
+
+ std::pair<Type, Type> MakeTypePair(const Type& type, const Type& jni_type) {
+ return std::make_pair(type, jni_type);
+ }
+ std::pair<Type, Type> MakeTypePair(const Type& type) {
+ return std::make_pair(type, type);
+ }
+ Type NextGeneric() {
+ char generic_letter = next_generic_letter_++;
+ if (next_generic_letter_ > 'Z') {
+ next_generic_letter_ = 'A';
+ }
+ return Type::Generic(string(1, generic_letter));
+ }
+};
+
+Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) {
+ *iterable_out = false;
+ if (!arg_def.number_attr().empty()) {
+ // when number_attr is set, argument has to be a list of tensors
+ *iterable_out = true;
+ visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int()));
+ }
+ Type type = Type::Wildcard();
+ if (arg_def.type() != DataType::DT_INVALID) {
+ // resolve type from DataType
+ switch (arg_def.type()) {
+ case DataType::DT_BOOL:
+ type = Type::Class("Boolean");
+ break;
+ case DataType::DT_STRING:
+ type = Type::Class("String");
+ break;
+ case DataType::DT_FLOAT:
+ type = Type::Class("Float");
+ break;
+ case DataType::DT_DOUBLE:
+ type = Type::Class("Double");
+ break;
+ case DataType::DT_UINT8:
+ type = Type::Class("UInt8", "org.tensorflow.types");
+ break;
+ case DataType::DT_INT32:
+ type = Type::Class("Integer");
+ break;
+ case DataType::DT_INT64:
+ type = Type::Class("Long");
+ break;
+ case DataType::DT_RESOURCE:
+ // TODO(karllessard) create a Resource utility class that could be
+ // used to store a resource and its type (passed in a second argument).
+ // For now, we need to force a wildcard and we will unfortunately lose
+ // track of the resource type.
+ break;
+ default:
+ // Any other datatypes does not have a equivalent in Java and must
+ // remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
+ break;
+ }
+ } else if (!arg_def.type_attr().empty()) {
+ // resolve type from attribute (if already visited, retrieve its type)
+ if (IsAttributeVisited(arg_def.type_attr())) {
+ type = visited_attrs_.at(arg_def.type_attr());
+ } else {
+ for (const auto& attr_def : op_def_.attr()) {
+ if (attr_def.name() == arg_def.type_attr()) {
+ type = TypesOf(attr_def, iterable_out).first;
+ break;
+ }
+ }
+ }
+ } else if (!arg_def.type_list_attr().empty()) {
+ // type is a list of tensors that can be of different data types, so leave
+ // it as a list of wildcards
+ *iterable_out = true;
+ visited_attrs_.insert(std::make_pair(arg_def.type_list_attr(), type));
+
+ } else {
+ LOG(FATAL) << "Cannot resolve data type of argument \"" << arg_def.name()
+ << "\" in operation \"" << op_def_.name() << "\"";
+ }
+ return type;
+}
+
+std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
+ bool* iterable_out) {
+ std::pair<Type, Type> types = MakeTypePair(Type::Wildcard());
+ *iterable_out = false;
+ StringPiece attr_type = attr_def.type();
+ if (str_util::ConsumePrefix(&attr_type, "list(")) {
+ attr_type.remove_suffix(1); // remove closing brace
+ *iterable_out = true;
+ }
+ if (attr_type == "string") {
+ types = MakeTypePair(Type::Class("String"));
+
+ } else if (attr_type == "int") {
+ types = MakeTypePair(Type::Class("Long"), Type::Long());
+
+ } else if (attr_type == "float") {
+ types = MakeTypePair(Type::Class("Float"), Type::Float());
+
+ } else if (attr_type == "bool") {
+ types = MakeTypePair(Type::Class("Boolean"), Type::Boolean());
+
+ } else if (attr_type == "shape") {
+ types = MakeTypePair(Type::Class("Shape", "org.tensorflow"));
+
+ } else if (attr_type == "tensor") {
+ types = MakeTypePair(Type::Class("Tensor", "org.tensorflow")
+ .add_parameter(Type::Wildcard()));
+
+ } else if (attr_type == "type") {
+ Type type = *iterable_out ? Type::Wildcard() : NextGeneric();
+ if (IsRealNumbers(attr_def.allowed_values())) {
+ type.add_supertype(Type::Class("Number"));
+ }
+ types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow"));
+
+ } else {
+ LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
+ << "\" in operation \"" << op_def_.name() << "\"";
+ }
+ visited_attrs_.insert(std::make_pair(attr_def.name(), types.first));
+ return types;
+}
+
+string SnakeToCamelCase(const string& str, bool upper = false) {
+ string result;
+ bool cap = upper;
+ for (string::const_iterator it = str.begin(); it != str.end(); ++it) {
+ const char c = *it;
+ if (c == '_') {
+ cap = true;
+ } else if (cap) {
+ result += toupper(c);
+ cap = false;
+ } else {
+ result += c;
+ }
+ }
+ return result;
+}
+
+bool FindAndCut(string* input, const RE2& expr, string* before_match,
+ string* ret_match = nullptr) {
+ string match;
+ if (!RE2::PartialMatch(*input, expr, &match)) return false;
+ *before_match = input->substr(0, input->find(match));
+ *input = input->substr(before_match->size() + match.size());
+ if (ret_match != nullptr) *ret_match = match;
+ return true;
+}
+
+string ParseDocumentation(const string& inp) {
+ std::stringstream javadoc_text;
+
+ // TODO(karllessard) This is a very minimalist utility method for converting
+ // markdown syntax, as found in ops descriptions, to Javadoc/html tags. Check
+ // for alternatives to increase the level of support for markups.
+ std::vector<string> markups_subexpr;
+ markups_subexpr.push_back("\n+\\*\\s+"); // lists
+ markups_subexpr.push_back("\n{2,}"); // paragraphs
+ markups_subexpr.push_back("`{3,}\\s*[^\\s\n]*\\s*\n"); // code blocks
+ markups_subexpr.push_back("`+"); // inlined code and code blocks
+ markups_subexpr.push_back("\\*{1,2}\\b"); // text emphasis
+ markups_subexpr.push_back("\\["); // hyperlinks
+ const RE2 markup_expr("(" + str_util::Join(markups_subexpr, "|") + ")");
+
+ bool in_list = false;
+ string input = inp;
+ while (true) {
+ string text, markup;
+ if (!FindAndCut(&input, markup_expr, &text, &markup)) {
+ javadoc_text << input;
+ break; // end of loop
+ }
+ javadoc_text << text;
+ if (str_util::StartsWith(markup, "\n")) {
+ javadoc_text << "\n";
+ if (str_util::StrContains(markup, "*")) {
+ // new list item
+ javadoc_text << (in_list ? "</li>\n" : "<ul>\n") << "<li>\n";
+ in_list = true;
+ } else if (in_list) {
+ // end of list
+ javadoc_text << "</li>\n</ul>\n";
+ in_list = false;
+ } else if (!str_util::StartsWith(input, "```")) {
+ // new paragraph (not required if a <pre> block follows)
+ javadoc_text << "<p>\n";
+ }
+ } else if (str_util::StartsWith(markup, "```")) {
+ // code blocks
+ if (FindAndCut(&input, "(```\\s*\n*)", &text)) {
+ javadoc_text << "<pre>{@code\n" << text << "}</pre>\n";
+ } else {
+ javadoc_text << markup;
+ }
+ } else if (str_util::StartsWith("(" + markup + ")", "`")) {
+ // inlined code
+ if (FindAndCut(&input, markup, &text)) {
+ javadoc_text << "{@code " << text << "}";
+ } else {
+ javadoc_text << markup;
+ }
+ } else if (markup == "**") {
+ // text emphasis (strong)
+ if (FindAndCut(&input, "(\\b\\*{2})", &text)) {
+ javadoc_text << "<b>" << ParseDocumentation(text) << "</b>";
+ } else {
+ javadoc_text << markup;
+ }
+ } else if (markup == "*") {
+ // text emphasis (normal)
+ if (FindAndCut(&input, "(\\b\\*{1})", &text)) {
+ javadoc_text << "<i>" << ParseDocumentation(text) << "</i>";
+ } else {
+ javadoc_text << markup;
+ }
+ } else if (str_util::StartsWith(markup, "[")) {
+ // hyperlinks
+ string label;
+ string link;
+ if (RE2::PartialMatch(input, "([^\\[]+)\\]\\((http.+)\\)", &label,
+ &link) &&
+ str_util::StartsWith(input, label + link)) {
+ input = input.substr(label.size() + link.size());
+ javadoc_text << "<a href=\"" << link << "\">"
+ << ParseDocumentation(label) << "</a>";
+ } else {
+ javadoc_text << markup;
+ }
+ } else {
+ // safe fallback
+ javadoc_text << markup;
+ }
+ }
+ return javadoc_text.str();
+}
+
+ArgumentSpec CreateInput(const OpDef_ArgDef& input_def,
+ const ApiDef::Arg& input_api_def,
+ TypeResolver* type_resolver) {
+ bool iterable = false;
+ Type type = type_resolver->TypeOf(input_def, &iterable);
+ Type var_type =
+ Type::Interface("Operand", "org.tensorflow").add_parameter(type);
+ if (iterable) {
+ var_type = Type::IterableOf(var_type);
+ }
+ return ArgumentSpec(
+ input_api_def.name(),
+ Variable::Create(SnakeToCamelCase(input_api_def.rename_to()), var_type),
+ type, ParseDocumentation(input_api_def.description()), iterable);
+}
+
+AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
+ const ApiDef::Attr& attr_api_def,
+ TypeResolver* type_resolver) {
+ bool iterable = false;
+ std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
+ Type var_type = types.first.kind() == Type::GENERIC
+ ? Type::Class("Class").add_parameter(types.first)
+ : types.first;
+ if (iterable) {
+ var_type = Type::ListOf(var_type);
+ }
+ return AttributeSpec(
+ attr_api_def.name(),
+ Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
+ types.first, types.second, ParseDocumentation(attr_api_def.description()),
+ iterable, attr_api_def.has_default_value());
+}
+
+ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
+ const ApiDef::Arg& output_api,
+ TypeResolver* type_resolver) {
+ bool iterable = false;
+ Type type = type_resolver->TypeOf(output_def, &iterable);
+ Type var_type = Type::Class("Output", "org.tensorflow").add_parameter(type);
+ if (iterable) {
+ var_type = Type::ListOf(var_type);
+ }
+ return ArgumentSpec(
+ output_api.name(),
+ Variable::Create(SnakeToCamelCase(output_api.rename_to()), var_type),
+ type, ParseDocumentation(output_api.description()), iterable);
+}
+
+EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def,
+ const ApiDef_Endpoint& endpoint_def) {
+ std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), ".");
+ string package;
+ string name;
+ if (name_tokens.size() > 1) {
+ package = name_tokens.at(0);
+ name = name_tokens.at(1);
+ } else {
+ package = "core"; // generate unclassified ops in the 'core' package
+ name = name_tokens.at(0);
+ }
+ return EndpointSpec(package, name,
+ Javadoc::Create(ParseDocumentation(api_def.summary()))
+ .details(ParseDocumentation(api_def.description())));
+}
+
+} // namespace
+
+OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
+ OpSpec op(api_def.graph_op_name(), api_def.visibility() == ApiDef::HIDDEN,
+ op_def.deprecation().explanation());
+ TypeResolver type_resolver(op_def);
+ for (const string& next_input_name : api_def.arg_order()) {
+ for (int i = 0; i < op_def.input_arg().size(); ++i) {
+ if (op_def.input_arg(i).name() == next_input_name) {
+ op.inputs_.push_back(CreateInput(op_def.input_arg(i), api_def.in_arg(i),
+ &type_resolver));
+ break;
+ }
+ }
+ }
+ for (int i = 0; i < op_def.attr().size(); ++i) {
+ // do not parse attributes already visited, they have probably been inferred
+ // before as an input argument type
+ if (!type_resolver.IsAttributeVisited(op_def.attr(i).name())) {
+ AttributeSpec attr =
+ CreateAttribute(op_def.attr(i), api_def.attr(i), &type_resolver);
+ // attributes with a default value are optional
+ if (attr.has_default_value() && attr.type().kind() != Type::GENERIC) {
+ op.optional_attributes_.push_back(attr);
+ } else {
+ op.attributes_.push_back(attr);
+ }
+ }
+ }
+ for (int i = 0; i < op_def.output_arg().size(); ++i) {
+ op.outputs_.push_back(
+ CreateOutput(op_def.output_arg(i), api_def.out_arg(i), &type_resolver));
+ }
+ for (const auto& endpoint_def : api_def.endpoint()) {
+ op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def));
+ }
+ return op;
+}
+
+} // namespace java
+} // namespace tensorflow