aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-04-10 18:44:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-10 18:46:38 -0700
commit5ad9e4588874f30d0d079acc60e07f2eddc0480f (patch)
treeab800846cc505d867b2961578869aec97eeb81a3 /tensorflow/java/src
parentfad74785d12ea7463e5d0474522cd7d754699656 (diff)
Merge changes from github.
PiperOrigin-RevId: 192388250
Diffstat (limited to 'tensorflow/java/src')
-rw-r--r--tensorflow/java/src/gen/cc/java_defs.h45
-rw-r--r--tensorflow/java/src/gen/cc/source_writer.cc305
-rw-r--r--tensorflow/java/src/gen/cc/source_writer.h192
-rw-r--r--tensorflow/java/src/gen/cc/source_writer_test.cc369
-rw-r--r--tensorflow/java/src/gen/resources/test.java.snippet2
5 files changed, 833 insertions, 80 deletions
diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h
index 615cdc165b..59f8beaee7 100644
--- a/tensorflow/java/src/gen/cc/java_defs.h
+++ b/tensorflow/java/src/gen/cc/java_defs.h
@@ -17,10 +17,7 @@ limitations under the License.
#define TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_
#include <string>
-#include <vector>
-#include <deque>
-
-#include "tensorflow/core/platform/env.h"
+#include <list>
namespace tensorflow {
namespace java {
@@ -104,17 +101,17 @@ class Type {
description_ = description;
return *this;
}
- const std::vector<Type>& parameters() const { return parameters_; }
+ const std::list<Type>& parameters() const { return parameters_; }
Type& add_parameter(const Type& parameter) {
parameters_.push_back(parameter);
return *this;
}
- const std::vector<Annotation>& annotations() const { return annotations_; }
+ const std::list<Annotation>& annotations() const { return annotations_; }
Type& add_annotation(const Annotation& annotation) {
annotations_.push_back(annotation);
return *this;
}
- const std::deque<Type>& supertypes() const { return supertypes_; }
+ const std::list<Type>& supertypes() const { return supertypes_; }
Type& add_supertype(const Type& type) {
if (type.kind_ == CLASS) {
supertypes_.push_front(type); // keep superclass at the front of the list
@@ -141,9 +138,9 @@ class Type {
string name_;
string package_;
string description_;
- std::vector<Type> parameters_;
- std::vector<Annotation> annotations_;
- std::deque<Type> supertypes_;
+ std::list<Type> parameters_;
+ std::list<Annotation> annotations_;
+ std::list<Type> supertypes_;
};
// Definition of a Java annotation
@@ -223,16 +220,12 @@ class Method {
return_description_ = description;
return *this;
}
- const std::vector<Variable>& arguments() const { return arguments_; }
- Method& add_arguments(const std::vector<Variable>& args) {
- arguments_.insert(arguments_.cend(), args.cbegin(), args.cend());
- return *this;
- }
+ const std::list<Variable>& arguments() const { return arguments_; }
Method& add_argument(const Variable& var) {
arguments_.push_back(var);
return *this;
}
- const std::vector<Annotation>& annotations() const { return annotations_; }
+ const std::list<Annotation>& annotations() const { return annotations_; }
Method& add_annotation(const Annotation& annotation) {
annotations_.push_back(annotation);
return *this;
@@ -244,29 +237,13 @@ class Method {
bool constructor_;
string description_;
string return_description_;
- std::vector<Variable> arguments_;
- std::vector<Annotation> annotations_;
+ std::list<Variable> arguments_;
+ std::list<Annotation> annotations_;
Method(const string& name, const Type& return_type, bool constructor)
: name_(name), return_type_(return_type), constructor_(constructor) {}
};
-// A piece of code to read from a file.
-class Snippet {
- public:
- static Snippet Create(const string& fname, Env* env = Env::Default()) {
- return Snippet(fname, env);
- }
- const string& data() const { return data_; }
-
- private:
- string data_;
-
- Snippet(const string& fname, Env* env) {
- TF_CHECK_OK(ReadFileToString(env, fname, &data_));
- }
-};
-
} // namespace java
} // namespace tensorflow
diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc
index 2da81f2911..a02f75ad6e 100644
--- a/tensorflow/java/src/gen/cc/source_writer.cc
+++ b/tensorflow/java/src/gen/cc/source_writer.cc
@@ -14,49 +14,328 @@ limitations under the License.
==============================================================================*/
#include <string>
+#include <algorithm>
+#include <deque>
#include "tensorflow/java/src/gen/cc/source_writer.h"
namespace tensorflow {
+namespace java {
-SourceWriter& SourceWriter::Append(const StringPiece& str) {
- if (!str.empty()) {
- if (newline_) {
- DoAppend(left_margin_ + line_prefix_);
- newline_ = false;
- }
- DoAppend(str);
+SourceWriter::SourceWriter() {
+ // Push an empty generic namespace at start, for simplification.
+ generic_namespaces_.push(new GenericNamespace());
+}
+
+SourceWriter::~SourceWriter() {
+ // Remove empty generic namespace added at start as well as any other
+ // namespace objects that haven't been removed.
+ while (!generic_namespaces_.empty()) {
+ GenericNamespace* generic_namespace = generic_namespaces_.top();
+ generic_namespaces_.pop();
+ delete generic_namespace;
}
+}
+
+SourceWriter& SourceWriter::Indent(int tab) {
+ left_margin_.resize(
+ std::max(static_cast<int>(left_margin_.size() + tab), 0), ' ');
+ return *this;
+}
+
+SourceWriter& SourceWriter::Prefix(const char* line_prefix) {
+ line_prefix_ = line_prefix;
return *this;
}
-SourceWriter& SourceWriter::Write(const string& str) {
+SourceWriter& SourceWriter::Write(const StringPiece& str) {
size_t line_pos = 0;
do {
size_t start_pos = line_pos;
line_pos = str.find('\n', start_pos);
if (line_pos != string::npos) {
++line_pos;
- Append(StringPiece(str.data() + start_pos, line_pos - start_pos));
+ Append(str.substr(start_pos, line_pos - start_pos));
newline_ = true;
} else {
- Append(StringPiece(str.data() + start_pos, str.size() - start_pos));
+ Append(str.substr(start_pos, str.size() - start_pos));
}
} while (line_pos != string::npos && line_pos < str.size());
return *this;
}
+SourceWriter& SourceWriter::WriteFromFile(const string& fname, Env* env) {
+ string data_;
+ TF_CHECK_OK(ReadFileToString(env, fname, &data_));
+ return Write(data_);
+}
+
+SourceWriter& SourceWriter::Append(const StringPiece& str) {
+ if (!str.empty()) {
+ if (newline_) {
+ DoAppend(left_margin_ + line_prefix_);
+ newline_ = false;
+ }
+ DoAppend(str);
+ }
+ return *this;
+}
+
+SourceWriter& SourceWriter::AppendType(const Type& type) {
+ if (type.kind() == Type::Kind::GENERIC && type.name().empty()) {
+ Append("?");
+ } else {
+ Append(type.name());
+ }
+ if (!type.parameters().empty()) {
+ Append("<");
+ for (const Type& t : type.parameters()) {
+ if (&t != &type.parameters().front()) {
+ Append(", ");
+ }
+ AppendType(t);
+ }
+ Append(">");
+ }
+ return *this;
+}
+
SourceWriter& SourceWriter::EndLine() {
Append("\n");
newline_ = true;
return *this;
}
-SourceWriter& SourceWriter::Indent(int tab) {
- left_margin_.resize(std::max(static_cast<int>(left_margin_.size() + tab), 0),
- ' ');
+SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) {
+ GenericNamespace* generic_namespace = PushGenericNamespace(modifiers);
+ if (!method.constructor()) {
+ generic_namespace->Visit(method.return_type());
+ }
+ for (const Variable& v : method.arguments()) {
+ generic_namespace->Visit(v.type());
+ }
+ EndLine();
+ WriteDoc(method.description(), method.return_description(),
+ &method.arguments());
+ if (!method.annotations().empty()) {
+ WriteAnnotations(method.annotations());
+ }
+ WriteModifiers(modifiers);
+ if (!generic_namespace->declared_types().empty()) {
+ WriteGenerics(generic_namespace->declared_types());
+ Append(" ");
+ }
+ if (!method.constructor()) {
+ AppendType(method.return_type()).Append(" ");
+ }
+ Append(method.name()).Append("(");
+ for (const Variable& v : method.arguments()) {
+ if (&v != &method.arguments().front()) {
+ Append(", ");
+ }
+ AppendType(v.type()).Append(v.variadic() ? "... " : " ").Append(v.name());
+ }
+ return Append(")").BeginBlock();
+}
+
+SourceWriter& SourceWriter::EndMethod() {
+ EndBlock();
+ PopGenericNamespace();
+ return *this;
+}
+
+SourceWriter& SourceWriter::BeginType(const Type& type,
+ const std::list<Type>* dependencies, int modifiers) {
+ if (!type.package().empty()) {
+ Append("package ").Append(type.package()).Append(";").EndLine();
+ }
+ if (dependencies != nullptr && !dependencies->empty()) {
+ TypeImporter type_importer(type.package());
+ for (const Type& t : *dependencies) {
+ type_importer.Visit(t);
+ }
+ EndLine();
+ for (const string& s : type_importer.imports()) {
+ Append("import ").Append(s).Append(";").EndLine();
+ }
+ }
+ return BeginInnerType(type, modifiers);
+}
+
+SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers) {
+ GenericNamespace* generic_namespace = PushGenericNamespace(modifiers);
+ generic_namespace->Visit(type);
+ EndLine();
+ WriteDoc(type.description());
+ if (!type.annotations().empty()) {
+ WriteAnnotations(type.annotations());
+ }
+ WriteModifiers(modifiers);
+ CHECK_EQ(Type::Kind::CLASS, type.kind()) << ": Not supported yet";
+ Append("class ").Append(type.name());
+ if (!generic_namespace->declared_types().empty()) {
+ WriteGenerics(generic_namespace->declared_types());
+ }
+ if (!type.supertypes().empty()) {
+ bool first_interface = true;
+ for (const Type& t : type.supertypes()) {
+ if (t.kind() == Type::CLASS) { // superclass is always first in list
+ Append(" extends ");
+ } else if (first_interface) {
+ Append(" implements ");
+ first_interface = false;
+ } else {
+ Append(", ");
+ }
+ AppendType(t);
+ }
+ }
+ return BeginBlock();
+}
+
+SourceWriter& SourceWriter::EndType() {
+ EndBlock();
+ PopGenericNamespace();
+ return *this;
+}
+
+SourceWriter& SourceWriter::WriteFields(const std::list<Variable>& fields,
+ int modifiers) {
+ EndLine();
+ for (const Variable& v : fields) {
+ WriteModifiers(modifiers);
+ AppendType(v.type()).Append(" ").Append(v.name()).Append(";");
+ EndLine();
+ }
+ return *this;
+}
+
+SourceWriter& SourceWriter::WriteModifiers(int modifiers) {
+ if (modifiers & PUBLIC) {
+ Append("public ");
+ } else if (modifiers & PROTECTED) {
+ Append("protected ");
+ } else if (modifiers & PRIVATE) {
+ Append("private ");
+ }
+ if (modifiers & STATIC) {
+ Append("static ");
+ }
+ if (modifiers & FINAL) {
+ Append("final ");
+ }
+ return *this;
+}
+
+SourceWriter& SourceWriter::WriteDoc(const string& description,
+ const string& return_description, const std::list<Variable>* parameters) {
+ if (description.empty() && return_description.empty()
+ && (parameters == nullptr || parameters->empty())) {
+ return *this; // no doc to write
+ }
+ bool do_line_break = false;
+ Append("/**").EndLine().Prefix(" * ");
+ if (!description.empty()) {
+ Write(description).EndLine();
+ do_line_break = true;
+ }
+ if (parameters != nullptr && !parameters->empty()) {
+ if (do_line_break) {
+ EndLine();
+ do_line_break = false;
+ }
+ for (const Variable& v : *parameters) {
+ Append("@param ").Append(v.name());
+ if (!v.description().empty()) {
+ Append(" ").Write(v.description());
+ }
+ EndLine();
+ }
+ }
+ if (!return_description.empty()) {
+ if (do_line_break) {
+ EndLine();
+ do_line_break = false;
+ }
+ Append("@return ").Write(return_description).EndLine();
+ }
+ return Prefix("").Append(" **/").EndLine();
+}
+
+SourceWriter& SourceWriter::WriteAnnotations(
+ const std::list<Annotation>& annotations) {
+ for (const Annotation& a : annotations) {
+ Append("@" + a.name());
+ if (!a.attributes().empty()) {
+ Append("(").Append(a.attributes()).Append(")");
+ }
+ EndLine();
+ }
return *this;
}
+SourceWriter& SourceWriter::WriteGenerics(
+ const std::list<const Type*>& generics) {
+ Append("<");
+ for (const Type* pt : generics) {
+ if (pt != generics.front()) {
+ Append(", ");
+ }
+ Append(pt->name());
+ if (!pt->supertypes().empty()) {
+ Append(" extends ").AppendType(pt->supertypes().front());
+ }
+ }
+ return Append(">");
+}
+
+SourceWriter::GenericNamespace* SourceWriter::PushGenericNamespace(
+ int modifiers) {
+ GenericNamespace* generic_namespace;
+ if (modifiers & STATIC) {
+ generic_namespace = new GenericNamespace();
+ } else {
+ generic_namespace = new GenericNamespace(generic_namespaces_.top());
+ }
+ generic_namespaces_.push(generic_namespace);
+ return generic_namespace;
+}
+
+void SourceWriter::PopGenericNamespace() {
+ GenericNamespace* generic_namespace = generic_namespaces_.top();
+ generic_namespaces_.pop();
+ delete generic_namespace;
+}
+
+void SourceWriter::TypeVisitor::Visit(const Type& type) {
+ DoVisit(type);
+ for (const Type& t : type.parameters()) {
+ DoVisit(t);
+ }
+ for (const Annotation& t : type.annotations()) {
+ DoVisit(t);
+ }
+ for (const Type& t : type.supertypes()) {
+ DoVisit(t);
+ }
+}
+
+void SourceWriter::GenericNamespace::DoVisit(const Type& type) {
+ // ignore non-generic parameters, wildcards and generics already declared
+ if (type.kind() == Type::GENERIC
+ && !type.IsWildcard()
+ && generic_names_.find(type.name()) == generic_names_.end()) {
+ declared_types_.push_back(&type);
+ generic_names_.insert(type.name());
+ }
+}
+
+void SourceWriter::TypeImporter::DoVisit(const Type& type) {
+ if (!type.package().empty() && type.package() != current_package_) {
+ imports_.insert(type.package() + '.' + type.name());
+ }
+}
+
+} // namespace java
} // namespace tensorflow
diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h
index bff26eb185..637072c0df 100644
--- a/tensorflow/java/src/gen/cc/source_writer.h
+++ b/tensorflow/java/src/gen/cc/source_writer.h
@@ -17,44 +17,23 @@ limitations under the License.
#define TENSORFLOW_JAVA_SRC_GEN_CC_SOURCE_WRITER_H_
#include <string>
+#include <stack>
+#include <list>
+#include <set>
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/java/src/gen/cc/java_defs.h"
namespace tensorflow {
+namespace java {
-// A utility class for writing source code, normally generated at
-// compile-time.
-//
-// Source writers are language-agnostic and therefore only expose generic
-// methods common to most languages. Extend or wrap this class to implement
-// language-specific features.
-//
-// Note: if you are looking to reuse this class for generating code in another
-// language than Java, please do by moving it at the '//tensorflow/core/lib/io'
-// level.
+// A class for writing Java source code.
class SourceWriter {
public:
- virtual ~SourceWriter() = default;
-
- // Returns true if the writer is at the beginnig of a new line
- bool newline() const { return newline_; }
-
- // Appends a piece of code or text.
- //
- // It is expected that no newline character is present in the data provided,
- // otherwise Write() must be used.
- SourceWriter& Append(const StringPiece& str);
+ SourceWriter();
- // Writes a block of code or text.
- //
- // The data might potentially contain newline characters, therefore it will
- // be scanned to ensure that each line is indented and prefixed properly,
- // making it a bit slower than Append().
- SourceWriter& Write(const string& text);
-
- // Appends a newline character and start writing on a new line.
- SourceWriter& EndLine();
+ virtual ~SourceWriter();
// Indents following lines with white spaces.
//
@@ -75,18 +54,166 @@ class SourceWriter {
// Indent(2)->Prefix("//") will result in prefixing lines with " //".
//
// An empty value ("") will remove any line prefix that was previously set.
- SourceWriter& Prefix(const char* line_prefix) {
- line_prefix_ = line_prefix;
- return *this;
+ SourceWriter& Prefix(const char* line_prefix);
+
+ // Writes a source code snippet.
+ //
+ // The data might potentially contain newline characters, therefore it will
+ // be scanned to ensure that each line is indented and prefixed properly,
+ // making it a bit slower than Append().
+ SourceWriter& Write(const StringPiece& text);
+
+ // Writes a source code snippet read from a file.
+ //
+ // All lines of the file at the provided path will be read and written back
+ // to the output of this writer in regard of its current attributes (e.g.
+ // the indentation, prefix, etc.)
+ SourceWriter& WriteFromFile(const string& fname, Env* env = Env::Default());
+
+ // Appends a piece of source code.
+ //
+ // It is expected that no newline character is present in the data provided,
+ // otherwise Write() must be used.
+ SourceWriter& Append(const StringPiece& str);
+
+ // Appends a type to the current line.
+ //
+ // The type is written in its simple form (i.e. not prefixed by its package)
+ // and followed by any parameter types it has enclosed in brackets (<>).
+ SourceWriter& AppendType(const Type& type);
+
+ // Appends a newline character.
+ //
+ // Data written after calling this method will start on a new line, in respect
+ // of the current indentation.
+ SourceWriter& EndLine();
+
+ // Begins a block of source code.
+ //
+ // This method appends a new opening brace to the current data and indent the
+ // next lines according to Google Java Style Guide. The block can optionally
+ // be preceded by an expression (e.g. Append("if(true)").BeginBlock();)
+ SourceWriter& BeginBlock() {
+ return Append(newline_ ? "{" : " {").EndLine().Indent(2);
+ }
+
+ // Ends the current block of source code.
+ //
+ // This method appends a new closing brace to the current data and outdent the
+ // next lines back to the margin used before BeginBlock() was invoked.
+ SourceWriter& EndBlock() {
+ return Indent(-2).Append("}").EndLine();
}
+ // Begins to write a method.
+ //
+ // This method outputs the signature of the Java method from the data passed
+ // in the 'method' parameter and starts a new block. Additionnal modifiers can
+ // also be passed in parameter to define the accesses and the scope of this
+ // method.
+ SourceWriter& BeginMethod(const Method& method, int modifiers = 0);
+
+ // Ends the current method.
+ //
+ // This method ends the block of code that has begun when invoking
+ // BeginMethod() prior to this.
+ SourceWriter& EndMethod();
+
+ // Begins to write the main type of a source file.
+ //
+ // This method outputs the declaration of the Java type from the data passed
+ // in the 'type' parameter and starts a new block. Additionnal modifiers can
+ // also be passed in parameter to define the accesses and the scope of this
+ // type.
+ //
+ // If not null, all types found in the 'dependencies' list will be imported
+ // before declaring the new type.
+ SourceWriter& BeginType(const Type& clazz,
+ const std::list<Type>* dependencies, int modifiers = 0);
+
+ // Begins to write a new inner type.
+ //
+ // This method outputs the declaration of the Java type from the data passed
+ // in the 'type' parameter and starts a new block. Additionnal modifiers can
+ // also be passed in parameter to define the accesses and the scope of this
+ // type.
+ SourceWriter& BeginInnerType(const Type& type, int modifiers = 0);
+
+ // Ends the current type.
+ //
+ // This method ends the block of code that has begun when invoking
+ // BeginType() or BeginInnerType() prior to this.
+ SourceWriter& EndType();
+
+ // Writes a list of variables as fields of a type.
+ //
+ // This method must be called within the definition of a type (see BeginType()
+ // or BeginInnerType()). Additional modifiers can also be passed in parameter
+ // to define the accesses and the scope of those fields.
+ SourceWriter& WriteFields(const std::list<Variable>& fields,
+ int modifiers = 0);
+
protected:
virtual void DoAppend(const StringPiece& str) = 0;
private:
+ // A utility base class for visiting elements of a type.
+ class TypeVisitor {
+ public:
+ virtual ~TypeVisitor() = default;
+ void Visit(const Type& type);
+
+ protected:
+ virtual void DoVisit(const Type& type) = 0;
+ };
+
+ // A utility class for keeping track of declared generics in a given scope.
+ class GenericNamespace : public TypeVisitor {
+ public:
+ GenericNamespace() = default;
+ explicit GenericNamespace(const GenericNamespace* parent)
+ : generic_names_(parent->generic_names_) {}
+ std::list<const Type*> declared_types() {
+ return declared_types_;
+ }
+ protected:
+ virtual void DoVisit(const Type& type);
+
+ private:
+ std::list<const Type*> declared_types_;
+ std::set<string> generic_names_;
+ };
+
+ // A utility class for collecting a list of import statements to declare.
+ class TypeImporter : public TypeVisitor {
+ public:
+ explicit TypeImporter(const string& current_package)
+ : current_package_(current_package) {}
+ virtual ~TypeImporter() = default;
+ const std::set<string> imports() {
+ return imports_;
+ }
+ protected:
+ virtual void DoVisit(const Type& type);
+
+ private:
+ string current_package_;
+ std::set<string> imports_;
+ };
+
string left_margin_;
string line_prefix_;
bool newline_ = true;
+ std::stack<GenericNamespace*> generic_namespaces_;
+
+ SourceWriter& WriteModifiers(int modifiers);
+ SourceWriter& WriteDoc(const string& description,
+ const string& return_description = "",
+ const std::list<Variable>* parameters = nullptr);
+ SourceWriter& WriteAnnotations(const std::list<Annotation>& annotations);
+ SourceWriter& WriteGenerics(const std::list<const Type*>& generics);
+ GenericNamespace* PushGenericNamespace(int modifiers);
+ void PopGenericNamespace();
};
// A writer that outputs source code into a file.
@@ -128,6 +255,7 @@ class SourceBufferWriter : public SourceWriter {
string* buffer_;
};
+} // namespace java
} // namespace tensorflow
#endif // TENSORFLOW_JAVA_SRC_GEN_CC_SOURCE_WRITER_H_
diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc
index e973895754..4bce2fea70 100644
--- a/tensorflow/java/src/gen/cc/source_writer_test.cc
+++ b/tensorflow/java/src/gen/cc/source_writer_test.cc
@@ -13,11 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/java/src/gen/cc/source_writer.h"
+#include <list>
+
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/java/src/gen/cc/java_defs.h"
+#include "tensorflow/java/src/gen/cc/source_writer.h"
namespace tensorflow {
+namespace java {
namespace {
TEST(AppendTest, SingleLineText) {
@@ -211,5 +215,368 @@ TEST(MarginTest, EmptyPrefix) {
ASSERT_STREQ(expected, writer.str().data());
}
+TEST(StreamTest, BlocksAndLines) {
+ SourceBufferWriter writer;
+
+ writer.Append("int i = 0;").EndLine()
+ .Append("int j = 10;").EndLine()
+ .Append("if (true)")
+ .BeginBlock()
+ .Append("int aLongWayToTen = 0;").EndLine()
+ .Append("while (++i <= j)")
+ .BeginBlock()
+ .Append("++aLongWayToTen;").EndLine()
+ .EndBlock()
+ .EndBlock();
+
+ const char* expected =
+ "int i = 0;\n"
+ "int j = 10;\n"
+ "if (true) {\n"
+ " int aLongWayToTen = 0;\n"
+ " while (++i <= j) {\n"
+ " ++aLongWayToTen;\n"
+ " }\n"
+ "}\n";
+ ASSERT_STREQ(expected, writer.str().data());
+}
+
+TEST(StreamTest, Types) {
+ SourceBufferWriter writer;
+ Type generic = Type::Generic("T").add_supertype(Type::Class("Number"));
+
+ writer.AppendType(Type::Int()).Append(", ")
+ .AppendType(Type::Class("String")).Append(", ")
+ .AppendType(generic).Append(", ")
+ .AppendType(Type::ListOf(generic)).Append(", ")
+ .AppendType(Type::ListOf(Type::IterableOf(generic))).Append(", ")
+ .AppendType(Type::ListOf(Type::Generic()));
+
+ const char* expected =
+ "int, String, T, List<T>, List<Iterable<T>>, List<?>";
+ ASSERT_STREQ(expected, writer.str().data());
+}
+
+TEST(StreamTest, FileSnippet) {
+ SourceBufferWriter writer;
+ const string fname = tensorflow::io::JoinPath(
+ tensorflow::testing::TensorFlowSrcRoot(),
+ "java/src/gen/resources/test.java.snippet");
+
+ writer.WriteFromFile(fname)
+ .BeginBlock()
+ .WriteFromFile(fname)
+ .EndBlock();
+
+ const char* expected =
+ "// Here is a little snippet\n"
+ "System.out.println(\"Hello!\");\n"
+ "{\n"
+ " // Here is a little snippet\n"
+ " System.out.println(\"Hello!\");\n"
+ "}\n";
+ ASSERT_STREQ(expected, writer.str().data());
+}
+
+TEST(WriteType, SimpleClass) {
+ SourceBufferWriter writer;
+ Type clazz = Type::Class("Test", "org.tensorflow");
+
+ writer.BeginType(clazz, nullptr, PUBLIC).EndType();
+
+ const char* expected =
+ "package org.tensorflow;\n\n"
+ "public class Test {\n}\n";
+ ASSERT_STREQ(expected, writer.str().data());
+}
+
+TEST(WriteType, SimpleClassWithDependencies) {
+ SourceBufferWriter writer;
+ Type clazz = Type::Class("Test", "org.tensorflow");
+ std::list<Type> deps;
+ deps.push_back(Type::Class("TypeA", "org.test.sub"));
+ deps.push_back(Type::Class("TypeA", "org.test.sub")); // a second time
+ deps.push_back(Type::Class("TypeB", "org.other"));
+ deps.push_back(Type::Class("SamePackageType", "org.tensorflow"));
+ deps.push_back(Type::Class("NoPackageType"));
+
+ writer.BeginType(clazz, &deps, PUBLIC).EndType();
+
+ const char* expected =
+ "package org.tensorflow;\n\n"
+ "import org.other.TypeB;\n"
+ "import org.test.sub.TypeA;\n\n"
+ "public class Test {\n}\n";
+ ASSERT_STREQ(expected, writer.str().data());
+}
+
+TEST(WriteType, AnnotatedAndDocumentedClass) {
+ SourceBufferWriter writer;
+ Type clazz = Type::Class("Test", "org.tensorflow");
+ clazz.description("This class has a\n<p>\nmultiline description.");
+ clazz.add_annotation(Annotation::Create("Bean"));
+ clazz.add_annotation(Annotation::Create("SuppressWarnings")
+ .attributes("\"rawtypes\""));
+
+ writer.BeginType(clazz, nullptr, PUBLIC).EndType();
+
+ const char* expected =
+ "package org.tensorflow;\n\n"
+ "/**\n"
+ " * This class has a\n"
+ " * <p>\n"
+ " * multiline description.\n"
+ " **/\n"
+ "@Bean\n"
+ "@SuppressWarnings(\"rawtypes\")\n"
+ "public class Test {\n}\n";
+ ASSERT_STREQ(expected, writer.str().data());
+}
+
+TEST(WriteType, ParameterizedClass) {
+ SourceBufferWriter writer;
+ Type clazz = Type::Class("Test", "org.tensorflow");
+ clazz.add_parameter(Type::Generic("T"));
+ clazz.add_parameter(Type::Generic("U").add_supertype(Type::Class("Number")));
+
+ writer.BeginType(clazz, nullptr, PUBLIC).EndType();
+
+ const char* expected =
+ "package org.tensorflow;\n\n"
+ "public class Test<T, U extends Number> {\n}\n";
+ ASSERT_STREQ(expected, writer.str().data());
+}
+
+TEST(WriteType, ParameterizedClassAndSupertypes) {
+ SourceBufferWriter writer;
+ Type clazz = Type::Class("Test", "org.tensorflow");
+ Type type_t = Type::Generic("T");
+ clazz.add_parameter(type_t);
+ Type type_u = Type::Generic("U").add_supertype(Type::Class("Number"));
+ clazz.add_parameter(type_u);
+ clazz.add_supertype(Type::Interface("Parametrizable").add_parameter(type_u));
+ clazz.add_supertype(Type::Interface("Runnable"));
+ clazz.add_supertype(Type::Class("SuperTest").add_parameter(type_t));
+
+ writer.BeginType(clazz, nullptr, PUBLIC).EndType();
+
+ const char* expected =
+ "package org.tensorflow;\n\n"
+ "public class Test<T, U extends Number>"
+ " extends SuperTest<T> implements Parametrizable<U>, Runnable {\n}\n";
+ ASSERT_STREQ(expected, writer.str().data());
+}
+
+TEST(WriteType, ParameterizedClassFields) {
+ SourceBufferWriter writer;
+ Type clazz = Type::Class("Test", "org.tensorflow");
+ Type type_t = Type::Generic("T").add_supertype(Type::Class("Number"));
+ clazz.add_parameter(type_t);
+ std::list<Variable> static_fields;
+ static_fields.push_back(Variable::Create("field1", Type::Class("String")));
+ std::list<Variable> member_fields;
+ member_fields.push_back(Variable::Create("field2", Type::Class("String")));
+ member_fields.push_back(Variable::Create("field3", type_t));
+
+ writer.BeginType(clazz, nullptr, PUBLIC)
+ .WriteFields(static_fields, STATIC | PUBLIC | FINAL)
+ .WriteFields(member_fields, PRIVATE)
+ .EndType();
+
+ const char* expected =
+ "package org.tensorflow;\n\n"
+ "public class Test<T extends Number> {\n"
+ " \n"
+ " public static final String field1;\n"
+ " \n"
+ " private String field2;\n"
+ " private T field3;\n"
+ "}\n";
+ ASSERT_STREQ(expected, writer.str().data());
+}
+
+TEST(WriteType, SimpleInnerClass) {
+ SourceBufferWriter writer;
+ Type clazz = Type::Class("Test", "org.tensorflow");
+ Type inner_class = Type::Class("InnerTest");
+
+ writer.BeginType(clazz, nullptr, PUBLIC)
+ .BeginInnerType(inner_class, PUBLIC)
+ .EndType()
+ .EndType();
+
+ const char* expected =
+ "package org.tensorflow;\n\n"
+ "public class Test {\n"
+ " \n"
+ " public class InnerTest {\n"
+ " }\n"
+ "}\n";
+ ASSERT_STREQ(expected, writer.str().data());
+}
+
+TEST(WriteType, StaticParameterizedInnerClass) {
+ SourceBufferWriter writer;
+ Type clazz = Type::Class("Test", "org.tensorflow");
+ Type type_t = Type::Generic("T").add_supertype(Type::Class("Number"));
+ clazz.add_parameter(type_t);
+ Type inner_class = Type::Class("InnerTest");
+ inner_class.add_parameter(type_t);
+
+ writer.BeginType(clazz, nullptr, PUBLIC)
+ .BeginInnerType(inner_class, PUBLIC | STATIC)
+ .EndType()
+ .EndType();
+
+ const char* expected =
+ "package org.tensorflow;\n\n"
+ "public class Test<T extends Number> {\n"
+ " \n"
+ " public static class InnerTest<T extends Number> {\n"
+ " }\n"
+ "}\n";
+ ASSERT_STREQ(expected, writer.str().data());
+}
+
+TEST(WriteMethod, SimpleMethod) {
+ SourceBufferWriter writer;
+ Type clazz = Type::Class("Test", "org.tensorflow");
+ Method method = Method::Create("doNothing", Type::Void());
+
+ writer.BeginType(clazz, nullptr, PUBLIC)
+ .BeginMethod(method, PUBLIC).EndMethod()
+ .EndType();
+
+ const char* expected =
+ "package org.tensorflow;\n\n"
+ "public class Test {\n"
+ " \n"
+ " public void doNothing() {\n"
+ " }\n"
+ "}\n";
+ ASSERT_STREQ(expected, writer.str().data());
+}
+
+TEST(WriteMethod, AnnotatedAndDocumentedMethod) {
+ SourceBufferWriter writer;
+ Type clazz = Type::Class("Test", "org.tensorflow");
+ Method method = Method::Create("doNothing", Type::Void());
+ method.description("This method has a\n<p>\nmultiline description.");
+ method.add_annotation(Annotation::Create("Override"));
+ method.add_annotation(Annotation::Create("SuppressWarnings")
+ .attributes("\"rawtypes\""));
+
+ writer.BeginType(clazz, nullptr, PUBLIC)
+ .BeginMethod(method, PUBLIC).EndMethod()
+ .EndType();
+
+ const char* expected =
+ "package org.tensorflow;\n\n"
+ "public class Test {\n"
+ " \n"
+ " /**\n"
+ " * This method has a\n"
+ " * <p>\n"
+ " * multiline description.\n"
+ " **/\n"
+ " @Override\n"
+ " @SuppressWarnings(\"rawtypes\")\n"
+ " public void doNothing() {\n"
+ " }\n"
+ "}\n";
+ ASSERT_STREQ(expected, writer.str().data());
+}
+
+TEST(WriteMethod, DocumentedMethodWithArguments) {
+ SourceBufferWriter writer;
+ Type clazz = Type::Class("Test", "org.tensorflow");
+ Method method = Method::Create("boolToInt", Type::Int());
+ method.description("Converts a boolean to an int");
+ method.return_description("int value for this boolean");
+ method.add_argument(Variable::Create("b", Type::Boolean()));
+ Variable reverse = Variable::Create("reverse", Type::Boolean());
+ reverse.description("if true, value is reversed");
+ method.add_argument(reverse);
+
+ writer.BeginType(clazz, nullptr, PUBLIC)
+ .BeginMethod(method, PUBLIC)
+ .Append("if (b && !reverse)")
+ .BeginBlock()
+ .Append("return 1;").EndLine()
+ .EndBlock()
+ .Append("return 0;").EndLine()
+ .EndMethod()
+ .EndType();
+
+ const char* expected =
+ "package org.tensorflow;\n\n"
+ "public class Test {\n"
+ " \n"
+ " /**\n"
+ " * Converts a boolean to an int\n"
+ " * \n"
+ " * @param b\n"
+ " * @param reverse if true, value is reversed\n"
+ " * @return int value for this boolean\n"
+ " **/\n"
+ " public int boolToInt(boolean b, boolean reverse) {\n"
+ " if (b && !reverse) {\n"
+ " return 1;\n"
+ " }\n"
+ " return 0;\n"
+ " }\n"
+ "}\n";
+ ASSERT_STREQ(expected, writer.str().data());
+}
+
+TEST(WriteMethod, ParameterizedMethod) {
+ SourceBufferWriter writer;
+ Type clazz = Type::Class("Test", "org.tensorflow");
+ Type type_t = Type::Generic("T").add_supertype(Type::Class("Number"));
+ clazz.add_parameter(type_t);
+ Method method = Method::Create("doNothing", type_t);
+
+ writer.BeginType(clazz, nullptr, PUBLIC)
+ .BeginMethod(method, PUBLIC)
+ .Append("return null;").EndLine()
+ .EndMethod()
+ .EndType();
+
+ const char* expected =
+ "package org.tensorflow;\n\n"
+ "public class Test<T extends Number> {\n"
+ " \n"
+ " public T doNothing() {\n"
+ " return null;\n"
+ " }\n"
+ "}\n";
+ ASSERT_STREQ(expected, writer.str().data());
+}
+
+TEST(WriteMethod, StaticParameterizedMethod) {
+ SourceBufferWriter writer;
+ Type clazz = Type::Class("Test", "org.tensorflow");
+ Type type_t = Type::Generic("T").add_supertype(Type::Class("Number"));
+ clazz.add_parameter(type_t);
+ Method method = Method::Create("doNothing", type_t);
+
+ writer.BeginType(clazz, nullptr, PUBLIC)
+ .BeginMethod(method, PUBLIC | STATIC)
+ .Append("return null;").EndLine()
+ .EndMethod()
+ .EndType();
+
+ const char* expected =
+ "package org.tensorflow;\n\n"
+ "public class Test<T extends Number> {\n"
+ " \n"
+ " public static <T extends Number> T doNothing() {\n"
+ " return null;\n"
+ " }\n"
+ "}\n";
+ ASSERT_STREQ(expected, writer.str().data());
+}
+
} // namespace
+} // namespace java
} // namespace tensorflow
diff --git a/tensorflow/java/src/gen/resources/test.java.snippet b/tensorflow/java/src/gen/resources/test.java.snippet
new file mode 100644
index 0000000000..5e412a9aef
--- /dev/null
+++ b/tensorflow/java/src/gen/resources/test.java.snippet
@@ -0,0 +1,2 @@
+// Here is a little snippet
+System.out.println("Hello!");