diff options
author | Anjali Sridhar <anjalisridhar@google.com> | 2018-07-06 13:50:29 -0700 |
---|---|---|
committer | Yifei Feng <yifeif@google.com> | 2018-07-06 15:17:59 -0700 |
commit | 90fc5e3819ed62e93228a9c2c29dede0f0f8cfd6 (patch) | |
tree | 0e50e14646a382fbdf5edec988f9818bb93b12c0 /tensorflow/java | |
parent | d64754c5c768f26b6a95b350cfd8c7ded2590dc9 (diff) |
Allow is_initialized and initializer to be called on MirroredVariables and TowerLocalVariables.
PiperOrigin-RevId: 203520287
Diffstat (limited to 'tensorflow/java')
14 files changed, 424 insertions, 95 deletions
diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index f5f54bf4d3..d9d6f8adc8 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_ #define TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_ -#include <string> #include <list> #include <map> +#include <string> #include <utility> namespace tensorflow { diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index 759d800ecf..05decd6b54 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -19,10 +19,10 @@ limitations under the License. #include <string> #include <vector> -#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/java/src/gen/cc/op_specs.h" namespace tensorflow { diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc index 63e99fbb04..941ab2699c 100644 --- a/tensorflow/java/src/gen/cc/op_specs.cc +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include <map> -#include <vector> #include <string> #include <utility> +#include <vector> #include "re2/re2.h" #include "tensorflow/core/framework/op.h" @@ -50,7 +50,7 @@ class TypeResolver { // For example, if the argument's datatype is DT_STRING, this method will // return "java.lang.String", so the argument can become "Operand<String>" // in the Ops API - Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out); + Type TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out); // Returns types of an input attribute // @@ -62,7 +62,7 @@ class TypeResolver { // <java.lang.Float, float>, so the attribute can be used as a "Float" object // in the Ops API and casted to a "float" when passing through the JNI layer. std::pair<Type, Type> TypesOf(const OpDef_AttrDef& attr_def, - bool *iterable_out); + bool* iterable_out); // Returns true if the type of this attribute has already been resolved bool IsAttributeVisited(const string& attr_name) { @@ -89,8 +89,7 @@ class TypeResolver { } }; -Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, - bool* iterable_out) { +Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) { *iterable_out = false; if (!arg_def.number_attr().empty()) { // when number_attr is set, argument has to be a list of tensors @@ -154,13 +153,13 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, } else { LOG(FATAL) << "Cannot resolve data type of argument \"" << arg_def.name() - << "\" in operation \"" << op_def_.name() << "\""; + << "\" in operation \"" << op_def_.name() << "\""; } return type; } std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def, - bool* iterable_out) { + bool* iterable_out) { std::pair<Type, Type> types = MakeTypePair(Type::Wildcard()); *iterable_out = false; StringPiece attr_type = attr_def.type(); @@ -185,7 +184,7 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def, } else if (attr_type == "tensor") { types = MakeTypePair(Type::Class("Tensor", "org.tensorflow") - .add_parameter(Type::Wildcard())); + .add_parameter(Type::Wildcard())); } else if (attr_type == "type") { Type type = *iterable_out ? Type::Wildcard() : NextGeneric(); @@ -196,7 +195,7 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def, } else { LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type - << "\" in operation \"" << op_def_.name() << "\""; + << "\" in operation \"" << op_def_.name() << "\""; } visited_attrs_.insert(std::make_pair(attr_def.name(), types.first)); return types; @@ -219,47 +218,43 @@ string SnakeToCamelCase(const string& str, bool upper = false) { return result; } -bool FindAndCut(re2::StringPiece* input, const RE2& expr, - re2::StringPiece* before_match, re2::StringPiece* ret_match = nullptr) { - re2::StringPiece match; - if (!expr.Match(*input, 0, input->size(), RE2::UNANCHORED, &match, 1)) { - return false; - } - before_match->set(input->data(), match.begin() - input->begin()); - input->remove_prefix(match.end() - before_match->begin()); - if (ret_match != nullptr) { - *ret_match = match; - } +bool FindAndCut(string* input, const RE2& expr, string* before_match, + string* ret_match = nullptr) { + string match; + if (!RE2::PartialMatch(*input, expr, &match)) return false; + *before_match = input->substr(0, input->find(match)); + *input = input->substr(before_match->size() + match.size()); + if (ret_match != nullptr) *ret_match = match; return true; } -string ParseDocumentation(re2::StringPiece input) { +string ParseDocumentation(const string& inp) { std::stringstream javadoc_text; // TODO(karllessard) This is a very minimalist utility method for converting // markdown syntax, as found in ops descriptions, to Javadoc/html tags. Check // for alternatives to increase the level of support for markups. std::vector<string> markups_subexpr; - markups_subexpr.push_back("\n+\\*\\s+"); // lists - markups_subexpr.push_back("\n{2,}"); // paragraphs + markups_subexpr.push_back("\n+\\*\\s+"); // lists + markups_subexpr.push_back("\n{2,}"); // paragraphs markups_subexpr.push_back("`{3,}\\s*[^\\s\n]*\\s*\n"); // code blocks - markups_subexpr.push_back("`+"); // inlined code and code blocks + markups_subexpr.push_back("`+"); // inlined code and code blocks markups_subexpr.push_back("\\*{1,2}\\b"); // text emphasis - markups_subexpr.push_back("\\["); // hyperlinks - const RE2 markup_expr(str_util::Join(markups_subexpr, "|")); + markups_subexpr.push_back("\\["); // hyperlinks + const RE2 markup_expr("(" + str_util::Join(markups_subexpr, "|") + ")"); bool in_list = false; + string input = inp; while (true) { - re2::StringPiece text; - re2::StringPiece markup; + string text, markup; if (!FindAndCut(&input, markup_expr, &text, &markup)) { javadoc_text << input; break; // end of loop } javadoc_text << text; - if (markup.starts_with("\n")) { + if (str_util::StartsWith(markup, "\n")) { javadoc_text << "\n"; - if (markup.contains("*")) { + if (str_util::StrContains(markup, "*")) { // new list item javadoc_text << (in_list ? "</li>\n" : "<ul>\n") << "<li>\n"; in_list = true; @@ -267,18 +262,18 @@ string ParseDocumentation(re2::StringPiece input) { // end of list javadoc_text << "</li>\n</ul>\n"; in_list = false; - } else if (!input.starts_with("```")) { + } else if (!str_util::StartsWith(input, "```")) { // new paragraph (not required if a <pre> block follows) javadoc_text << "<p>\n"; } - } else if (markup.starts_with("```")) { + } else if (str_util::StartsWith(markup, "```")) { // code blocks - if (FindAndCut(&input, "```\\s*\n*", &text)) { + if (FindAndCut(&input, "(```\\s*\n*)", &text)) { javadoc_text << "<pre>{@code\n" << text << "}</pre>\n"; } else { javadoc_text << markup; } - } else if (markup.starts_with("`")) { + } else if (str_util::StartsWith("(" + markup + ")", "`")) { // inlined code if (FindAndCut(&input, markup, &text)) { javadoc_text << "{@code " << text << "}"; @@ -287,26 +282,28 @@ string ParseDocumentation(re2::StringPiece input) { } } else if (markup == "**") { // text emphasis (strong) - if (FindAndCut(&input, "\\b\\*{2}", &text)) { + if (FindAndCut(&input, "(\\b\\*{2})", &text)) { javadoc_text << "<b>" << ParseDocumentation(text) << "</b>"; } else { javadoc_text << markup; } } else if (markup == "*") { // text emphasis (normal) - if (FindAndCut(&input, "\\b\\*{1}", &text)) { + if (FindAndCut(&input, "(\\b\\*{1})", &text)) { javadoc_text << "<i>" << ParseDocumentation(text) << "</i>"; } else { javadoc_text << markup; } - } else if (markup.starts_with("[")) { + } else if (str_util::StartsWith(markup, "[")) { // hyperlinks string label; string link; - if (RE2::Consume(&input, "([^\\[]+)\\]\\((http.+)\\)", &label, &link)) { + if (RE2::PartialMatch(input, "([^\\[]+)\\]\\((http.+)\\)", &label, + &link) && + str_util::StartsWith(input, label + link)) { + input = input.substr(label.size() + link.size()); javadoc_text << "<a href=\"" << link << "\">" - << ParseDocumentation(label) - << "</a>"; + << ParseDocumentation(label) << "</a>"; } else { javadoc_text << markup; } @@ -319,57 +316,56 @@ string ParseDocumentation(re2::StringPiece input) { } ArgumentSpec CreateInput(const OpDef_ArgDef& input_def, - const ApiDef::Arg& input_api_def, TypeResolver* type_resolver) { + const ApiDef::Arg& input_api_def, + TypeResolver* type_resolver) { bool iterable = false; Type type = type_resolver->TypeOf(input_def, &iterable); - Type var_type = Type::Interface("Operand", "org.tensorflow") - .add_parameter(type); + Type var_type = + Type::Interface("Operand", "org.tensorflow").add_parameter(type); if (iterable) { var_type = Type::IterableOf(var_type); } - return ArgumentSpec(input_api_def.name(), + return ArgumentSpec( + input_api_def.name(), Variable::Create(SnakeToCamelCase(input_api_def.rename_to()), var_type), - type, - ParseDocumentation(input_api_def.description()), - iterable); + type, ParseDocumentation(input_api_def.description()), iterable); } AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def, - const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) { + const ApiDef::Attr& attr_api_def, + TypeResolver* type_resolver) { bool iterable = false; std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable); - Type var_type = types.first.kind() == Type::GENERIC ? - Type::Class("Class").add_parameter(types.first) : types.first; + Type var_type = types.first.kind() == Type::GENERIC + ? Type::Class("Class").add_parameter(types.first) + : types.first; if (iterable) { var_type = Type::ListOf(var_type); } - return AttributeSpec(attr_api_def.name(), + return AttributeSpec( + attr_api_def.name(), Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type), - types.first, - types.second, - ParseDocumentation(attr_api_def.description()), - iterable, - attr_api_def.has_default_value()); + types.first, types.second, ParseDocumentation(attr_api_def.description()), + iterable, attr_api_def.has_default_value()); } ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def, - const ApiDef::Arg& output_api, TypeResolver* type_resolver) { + const ApiDef::Arg& output_api, + TypeResolver* type_resolver) { bool iterable = false; Type type = type_resolver->TypeOf(output_def, &iterable); - Type var_type = Type::Class("Output", "org.tensorflow") - .add_parameter(type); + Type var_type = Type::Class("Output", "org.tensorflow").add_parameter(type); if (iterable) { var_type = Type::ListOf(var_type); } - return ArgumentSpec(output_api.name(), + return ArgumentSpec( + output_api.name(), Variable::Create(SnakeToCamelCase(output_api.rename_to()), var_type), - type, - ParseDocumentation(output_api.description()), - iterable); + type, ParseDocumentation(output_api.description()), iterable); } EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def, - const ApiDef_Endpoint& endpoint_def) { + const ApiDef_Endpoint& endpoint_def) { std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), "."); string package; string name; @@ -377,27 +373,25 @@ EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def, package = name_tokens.at(0); name = name_tokens.at(1); } else { - package = kDefaultEndpointPackage; + package = "core"; // generate unclassified ops in the 'core' package name = name_tokens.at(0); } - return EndpointSpec(package, - name, - Javadoc::Create(ParseDocumentation(api_def.summary())) - .details(ParseDocumentation(api_def.description()))); + return EndpointSpec(package, name, + Javadoc::Create(ParseDocumentation(api_def.summary())) + .details(ParseDocumentation(api_def.description()))); } } // namespace OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) { - OpSpec op(api_def.graph_op_name(), - api_def.visibility() == ApiDef::HIDDEN, - op_def.deprecation().explanation()); + OpSpec op(api_def.graph_op_name(), api_def.visibility() == ApiDef::HIDDEN, + op_def.deprecation().explanation()); TypeResolver type_resolver(op_def); for (const string& next_input_name : api_def.arg_order()) { for (int i = 0; i < op_def.input_arg().size(); ++i) { if (op_def.input_arg(i).name() == next_input_name) { op.inputs_.push_back(CreateInput(op_def.input_arg(i), api_def.in_arg(i), - &type_resolver)); + &type_resolver)); break; } } @@ -406,8 +400,8 @@ OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) { // do not parse attributes already visited, they have probably been inferred // before as an input argument type if (!type_resolver.IsAttributeVisited(op_def.attr(i).name())) { - AttributeSpec attr = CreateAttribute(op_def.attr(i), api_def.attr(i), - &type_resolver); + AttributeSpec attr = + CreateAttribute(op_def.attr(i), api_def.attr(i), &type_resolver); // attributes with a default value are optional if (attr.has_default_value() && attr.type().kind() != Type::GENERIC) { op.optional_attributes_.push_back(attr); @@ -417,8 +411,8 @@ OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) { } } for (int i = 0; i < op_def.output_arg().size(); ++i) { - op.outputs_.push_back(CreateOutput(op_def.output_arg(i), api_def.out_arg(i), - &type_resolver)); + op.outputs_.push_back( + CreateOutput(op_def.output_arg(i), api_def.out_arg(i), &type_resolver)); } for (const auto& endpoint_def : api_def.endpoint()) { op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def)); diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h index 3b53c730df..30ecb8ce53 100644 --- a/tensorflow/java/src/gen/cc/op_specs.h +++ b/tensorflow/java/src/gen/cc/op_specs.h @@ -19,9 +19,9 @@ limitations under the License. #include <string> #include <vector> -#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/java/src/gen/cc/java_defs.h" namespace tensorflow { @@ -38,9 +38,8 @@ class EndpointSpec { // javadoc: the endpoint class documentation // TODO(annarev): hardcode depcreated to false until deprecated is possible EndpointSpec(const string& package, const string& name, - const Javadoc& javadoc) - : package_(package), name_(name), javadoc_(javadoc), - deprecated_(false) {} + const Javadoc& javadoc) + : package_(package), name_(name), javadoc_(javadoc), deprecated_(false) {} const string& package() const { return package_; } const string& name() const { return name_; } @@ -63,10 +62,13 @@ class ArgumentSpec { // type: the tensor type of this argument // description: a description of this argument, in javadoc // iterable: true if this argument is a list - ArgumentSpec(const string& op_def_name, const Variable& var, - const Type& type, const string& description, bool iterable) - : op_def_name_(op_def_name), var_(var), type_(type), - description_(description), iterable_(iterable) {} + ArgumentSpec(const string& op_def_name, const Variable& var, const Type& type, + const string& description, bool iterable) + : op_def_name_(op_def_name), + var_(var), + type_(type), + description_(description), + iterable_(iterable) {} const string& op_def_name() const { return op_def_name_; } const Variable& var() const { return var_; } @@ -94,11 +96,16 @@ class AttributeSpec { // iterable: true if this attribute is a list // has_default_value: true if this attribute has a default value if not set AttributeSpec(const string& op_def_name, const Variable& var, - const Type& type, const Type& jni_type, const string& description, - bool iterable, bool has_default_value) - : op_def_name_(op_def_name), var_(var), type_(type), - description_(description), iterable_(iterable), - jni_type_(jni_type), has_default_value_(has_default_value) {} + const Type& type, const Type& jni_type, + const string& description, bool iterable, + bool has_default_value) + : op_def_name_(op_def_name), + var_(var), + type_(type), + description_(description), + iterable_(iterable), + jni_type_(jni_type), + has_default_value_(has_default_value) {} const string& op_def_name() const { return op_def_name_; } const Variable& var() const { return var_; } @@ -147,9 +154,10 @@ class OpSpec { // hidden: true if this op should not be visible through the Graph Ops API // deprecation_explanation: message to show if all endpoints are deprecated explicit OpSpec(const string& graph_op_name, bool hidden, - const string& deprecation_explanation) - : graph_op_name_(graph_op_name), hidden_(hidden), - deprecation_explanation_(deprecation_explanation) {} + const string& deprecation_explanation) + : graph_op_name_(graph_op_name), + hidden_(hidden), + deprecation_explanation_(deprecation_explanation) {} const string graph_op_name_; const bool hidden_; diff --git a/tensorflow/java/src/main/java/org/tensorflow/Input.java b/tensorflow/java/src/main/java/org/tensorflow/Input.java new file mode 100644 index 0000000000..13bc463e7d --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/Input.java @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +/** + * Interface implemented by operands of a TensorFlow operation. + * + * <p>Example usage: + * + * <pre>{@code + * // The "decodeJpeg" operation can be used as input to the "cast" operation + * Input decodeJpeg = ops.image().decodeJpeg(...); + * ops.math().cast(decodeJpeg, DataType.FLOAT); + * + * // The output "y" of the "unique" operation can be used as input to the "cast" operation + * Output y = ops.array().unique(...).y(); + * ops.math().cast(y, DataType.FLOAT); + * + * // The "split" operation can be used as input list to the "concat" operation + * Iterable<? extends Input> split = ops.array().split(...); + * ops.array().concat(0, split); + * }</pre> + */ +public interface Input<T> { + + /** + * Returns the symbolic handle of a tensor. + * + * <p>Inputs to TensorFlow operations are outputs of another TensorFlow operation. This method is + * used to obtain a symbolic handle that represents the computation of the input. + * + * @see OperationBuilder#addInput(Output) + */ + Output<T> asOutput(); +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java new file mode 100644 index 0000000000..ab34f6aa12 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// GENERATED FILE. To update, edit tftypes.pl instead. + +package org.tensorflow.types; + +import org.tensorflow.DataType; + +/** Represents a boolean. */ +public class TFBool implements TFType { + private TFBool() {} + static { + Types.typeCodes.put(TFBool.class, DataType.BOOL); + } + static { + Types.scalars.put(TFBool.class, false); + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java new file mode 100644 index 0000000000..49e5d9f2f3 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// GENERATED FILE. To update, edit tftypes.pl instead. + +package org.tensorflow.types; + +import org.tensorflow.DataType; + +/** Represents a 64-bit double precision floating point number. */ +public class TFDouble implements TFType { + private TFDouble() {} + static { + Types.typeCodes.put(TFDouble.class, DataType.DOUBLE); + } + static { + Types.scalars.put(TFDouble.class, 0.0); + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java new file mode 100644 index 0000000000..8426ee41f0 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// GENERATED FILE. To update, edit tftypes.pl instead. + +package org.tensorflow.types; + +import org.tensorflow.DataType; + +/** Represents a 32-bit single precision floating point number. */ +public class TFFloat implements TFType { + private TFFloat() {} + static { + Types.typeCodes.put(TFFloat.class, DataType.FLOAT); + } + static { + Types.scalars.put(TFFloat.class, 0f); + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java new file mode 100644 index 0000000000..3947b6ad09 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// GENERATED FILE. To update, edit tftypes.pl instead. + +package org.tensorflow.types; + +import org.tensorflow.DataType; + +/** Represents a 32-bit signed integer. */ +public class TFInt32 implements TFType { + private TFInt32() {} + static { + Types.typeCodes.put(TFInt32.class, DataType.INT32); + } + static { + Types.scalars.put(TFInt32.class, 0); + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java new file mode 100644 index 0000000000..ccdded8693 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// GENERATED FILE. To update, edit tftypes.pl instead. + +package org.tensorflow.types; + +import org.tensorflow.DataType; + +/** Represents a 64-bit signed integer. */ +public class TFInt64 implements TFType { + private TFInt64() {} + static { + Types.typeCodes.put(TFInt64.class, DataType.INT64); + } + static { + Types.scalars.put(TFInt64.class, 0L); + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java new file mode 100644 index 0000000000..e7327e8c57 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java @@ -0,0 +1,27 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// GENERATED FILE. To update, edit tftypes.pl instead. + +package org.tensorflow.types; + +import org.tensorflow.DataType; + +/** Represents an arbitrary sequence of bytes. */ +public class TFString implements TFType { + private TFString() {} + static { + Types.typeCodes.put(TFString.class, DataType.STRING); + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java new file mode 100644 index 0000000000..562953ac9d --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java @@ -0,0 +1,20 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.types; + +/** + * A marker interface for classes representing TensorFlow types. + */ +public interface TFType {} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java new file mode 100644 index 0000000000..d7305ca5a8 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// GENERATED FILE. To update, edit tftypes.pl instead. + +package org.tensorflow.types; + +import org.tensorflow.DataType; + +/** Represents an 8-bit unsigned integer. */ +public class TFUInt8 implements TFType { + private TFUInt8() {} + static { + Types.typeCodes.put(TFUInt8.class, DataType.UINT8); + } + static { + Types.scalars.put(TFUInt8.class, (byte)0); + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/Types.java b/tensorflow/java/src/main/java/org/tensorflow/types/Types.java new file mode 100644 index 0000000000..976cd9fd34 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/Types.java @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.types; + +import java.util.HashMap; +import java.util.Map; +import org.tensorflow.DataType; + +/** + * Utility class for managing the representation of TensorFlow types as Java + * types. For each TensorFlow type (e.g., int32), there is a corresponding Java + * type (e.g., TFInt32) that represents it at compile time and a corresponding + * class object (e.g., TFInt32.class) that represents it at run time. There is + * also an enumeration value in DataType that can be used to represent the + * type, though that should rarely be required. + */ +public class Types { + + private Types() {} // not instantiable + + static final Map<Class<?>, DataType> typeCodes = new HashMap<>(); + + /** Returns the DataType value corresponding to a TensorFlow type class. */ + public static DataType dataType(Class<? extends TFType> c) { + DataType dtype = typeCodes.get(c); + if (dtype == null) { + throw new IllegalArgumentException("" + c + " is not a TensorFlow type."); + } + return dtype; + } + + static final Map<Class<?>, Object> scalars = new HashMap<>(); + + /** Returns the zero value of type described by {@code c}, or null if + * the type (e.g., string) is not numeric and therefore has no zero value. + */ + public static Object zeroValue(Class<? extends TFType> c) { + return scalars.get(c); + } +} |