aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Anna R <annarev@google.com>2017-12-04 12:31:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-04 12:37:12 -0800
commit8f1e63d5629bda4f6c91fdec7a3b8418ed96786e (patch)
tree0f56fddabcd4e5cf91090acf5152cc54b3651e6a /tensorflow/python
parenta1c29139ccf441ad4de97c4e7fe2729e6130fcb8 (diff)
Actually use ApiDef when generating Python API.
PiperOrigin-RevId: 177851421
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/build_defs.bzl4
-rw-r--r--tensorflow/python/eager/gen_op.bzl20
-rw-r--r--tensorflow/python/eager/python_eager_op_gen.cc234
-rw-r--r--tensorflow/python/eager/python_eager_op_gen_main.cc37
-rw-r--r--tensorflow/python/framework/python_op_gen.cc60
-rw-r--r--tensorflow/python/framework/python_op_gen_internal.h24
-rw-r--r--tensorflow/python/framework/python_op_gen_main.cc56
-rw-r--r--tensorflow/python/ops/array_ops.py8
8 files changed, 271 insertions, 172 deletions
diff --git a/tensorflow/python/build_defs.bzl b/tensorflow/python/build_defs.bzl
index 2d8625933f..48b03fab0f 100644
--- a/tensorflow/python/build_defs.bzl
+++ b/tensorflow/python/build_defs.bzl
@@ -27,4 +27,8 @@ def tf_gen_op_wrapper_private_py(name, out=None, deps=[],
deps=deps,
require_shape_functions=require_shape_functions,
generated_target_name=name,
+ api_def_srcs = [
+ "//tensorflow/core:base_api_def",
+ "//tensorflow/core:python_api_def",
+ ],
)
diff --git a/tensorflow/python/eager/gen_op.bzl b/tensorflow/python/eager/gen_op.bzl
index 1c99d342be..8bc1d6c10a 100644
--- a/tensorflow/python/eager/gen_op.bzl
+++ b/tensorflow/python/eager/gen_op.bzl
@@ -10,7 +10,9 @@ def tfe_gen_op_wrapper_py(name,
out=None,
visibility=None,
deps=[],
- generated_target_name=None):
+ generated_target_name=None,
+ # ApiDefs will be loaded in the order specified in this list.
+ api_def_srcs=[]):
"""Generate an eager-mode Python op wrapper for an op library."""
# Construct a cc_binary containing the specified ops.
tool_name = "gen_" + name + "_py_wrappers_cc"
@@ -30,11 +32,25 @@ def tfe_gen_op_wrapper_py(name,
if not out:
out = "gen_" + name + ".py"
+ if not api_def_srcs:
+ api_def_args_str = ","
+ else:
+ api_def_args = []
+ for api_def_src in api_def_srcs:
+ # Add directory of the first ApiDef source to args.
+ # We are assuming all ApiDefs in a single api_def_src are in the
+ # same directory.
+ api_def_args.append(
+ "$$(dirname $$(echo $(locations " + api_def_src +
+ ") | cut -d\" \" -f1))")
+ api_def_args_str = ",".join(api_def_args)
+
native.genrule(
name=name + "_pygenrule",
outs=[out],
+ srcs=api_def_srcs,
tools=[tool_name] + tf_binary_additional_srcs(),
- cmd=("$(location " + tool_name + ") > $@"))
+ cmd=("$(location " + tool_name + ") " + api_def_args_str + " > $@"))
# Make a py_library out of the generated python file.
if not generated_target_name:
diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc
index 956fbdac50..90a8779ff8 100644
--- a/tensorflow/python/eager/python_eager_op_gen.cc
+++ b/tensorflow/python/eager/python_eager_op_gen.cc
@@ -99,6 +99,15 @@ string TensorPBString(const TensorProto& pb) {
return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\"");
}
+const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
+ for (int i = 0; i < api_def.in_arg_size(); ++i) {
+ if (api_def.in_arg(i).name() == name) {
+ return &api_def.in_arg(i);
+ }
+ }
+ return nullptr;
+}
+
class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
public:
GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
@@ -164,14 +173,14 @@ string GenEagerPythonOp::FlattenInputs(
} else if (inputs_state == WAS_LIST_INPUT) {
strings::StrAppend(&inputs, " + ");
}
- strings::StrAppend(&inputs, "list(", param_names_[i], ")");
+ strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")");
inputs_state = WAS_LIST_INPUT;
if (output_sizes != nullptr) {
if (!arg.number_attr().empty()) {
output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
} else {
output_sizes->emplace_back(
- strings::StrCat("len(", param_names_[i], ")"));
+ strings::StrCat("len(", param_names_[i].GetRenameTo(), ")"));
}
}
} else {
@@ -182,7 +191,7 @@ string GenEagerPythonOp::FlattenInputs(
} else {
strings::StrAppend(&inputs, "[");
}
- strings::StrAppend(&inputs, param_names_[i]);
+ strings::StrAppend(&inputs, param_names_[i].GetRenameTo());
inputs_state = WAS_SOLO_INPUT;
if (output_sizes != nullptr) output_sizes->emplace_back();
}
@@ -195,15 +204,21 @@ string GenEagerPythonOp::FlattenInputs(
}
string GenEagerPythonOp::Code() {
+ if (api_def_.visibility() == ApiDef::SKIP) {
+ return "";
+ }
// This has all the input args followed by those attrs that don't have
// defaults.
- std::vector<string> args_no_default;
+ std::vector<python_op_gen_internal::ParamNames> params_no_default;
// The parameters with defaults (these have to be listed after those without).
// No input args are included, just attrs.
- std::vector<std::pair<string, string>> args_with_defaults;
- for (int i = 0; i < op_def_.input_arg_size(); ++i) {
- const auto& arg(op_def_.input_arg(i));
- args_no_default.push_back(arg.name());
+ std::vector<std::pair<python_op_gen_internal::ParamNames, string>>
+ params_with_default;
+
+ for (int i = 0; i < api_def_.arg_order_size(); ++i) {
+ const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
+ const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
+ params_no_default.emplace_back(api_def_arg.name(), api_def_arg.rename_to());
if (!arg.type_attr().empty()) {
AddAttrForArg(arg.type_attr(), i);
} else if (!arg.type_list_attr().empty()) {
@@ -215,31 +230,39 @@ string GenEagerPythonOp::Code() {
}
for (int i = 0; i < op_def_.attr_size(); ++i) {
const auto& attr(op_def_.attr(i));
+ const auto& api_def_attr(api_def_.attr(i));
// Do not add inferred attrs to the Python function signature.
if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
- if (attr.has_default_value()) {
+ if (api_def_attr.has_default_value()) {
if (attr.type() == "tensor") {
- args_with_defaults.emplace_back(
- attr.name(),
- strings::StrCat("_execute.make_tensor(",
- TensorPBString(attr.default_value().tensor()),
- ", \"", attr.name(), "\")"));
+ params_with_default.emplace_back(
+ python_op_gen_internal::ParamNames(api_def_attr.name(),
+ api_def_attr.rename_to()),
+ strings::StrCat(
+ "_execute.make_tensor(",
+ TensorPBString(api_def_attr.default_value().tensor()), ", \"",
+ api_def_attr.rename_to(), "\")"));
} else if (attr.type() == "list(tensor)") {
std::vector<string> pbtxt;
- for (const auto& pb : attr.default_value().list().tensor()) {
+ for (const auto& pb : api_def_attr.default_value().list().tensor()) {
pbtxt.emplace_back(TensorPBString(pb));
}
- args_with_defaults.emplace_back(
- attr.name(),
- strings::StrCat("[_execute.make_tensor(_pb, \"", attr.name(),
- "\") for _pb in ", VectorToTuple(pbtxt), "]"));
+ params_with_default.emplace_back(
+ python_op_gen_internal::ParamNames(api_def_attr.name(),
+ api_def_attr.rename_to()),
+ strings::StrCat("[_execute.make_tensor(_pb, \"",
+ api_def_attr.rename_to(), "\") for _pb in ",
+ VectorToTuple(pbtxt), "]"));
} else {
- args_with_defaults.emplace_back(
- attr.name(), python_op_gen_internal::AttrValueToPython(
- attr.type(), attr.default_value(), "_dtypes."));
+ params_with_default.emplace_back(
+ python_op_gen_internal::ParamNames(api_def_attr.name(),
+ api_def_attr.rename_to()),
+ python_op_gen_internal::AttrValueToPython(
+ attr.type(), api_def_attr.default_value(), "_dtypes."));
}
} else {
- args_no_default.push_back(attr.name());
+ params_no_default.emplace_back(api_def_attr.name(),
+ api_def_attr.rename_to());
}
}
}
@@ -247,34 +270,37 @@ string GenEagerPythonOp::Code() {
// Save the list of attr parameters (attrs that won't be inferred),
// those with defaults go at the end.
// Get the attrs in the order we want by taking the attrs without defaults
- // from the end of args_no_default, and adding args_no_default.
- attrs_.reserve(args_no_default.size() - op_def_.input_arg_size() +
- args_with_defaults.size());
- attrs_.insert(attrs_.end(),
- args_no_default.begin() + op_def_.input_arg_size(),
- args_no_default.end());
- for (const auto& a : args_with_defaults) {
- attrs_.push_back(a.first);
+ // from the end of params_no_default, and adding params_no_default.
+ attrs_.reserve(params_no_default.size() - op_def_.input_arg_size() +
+ params_with_default.size());
+ for (int i = op_def_.input_arg_size(); i < params_no_default.size(); ++i) {
+ attrs_.push_back(params_no_default[i].GetName());
+ }
+ for (const auto& p : params_with_default) {
+ attrs_.push_back(p.first.GetName());
+ }
+
+ param_names_.reserve(params_no_default.size() + params_with_default.size());
+ param_names_.insert(param_names_.begin(), params_no_default.begin(),
+ params_no_default.end());
+ for (const auto& param_and_default : params_with_default) {
+ param_names_.push_back(param_and_default.first);
}
- param_names_.reserve(args_no_default.size() + args_with_defaults.size());
string parameters;
- for (const string& name : args_no_default) {
+ for (const auto& param : params_no_default) {
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
- const string param = python_op_gen_internal::AvoidPythonReserved(name);
- strings::StrAppend(&parameters, param);
- param_names_.push_back(param);
+ strings::StrAppend(&parameters, param.GetRenameTo());
}
- for (const auto& name_default : args_with_defaults) {
+ for (const auto& param_and_default : params_with_default) {
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
- const string param =
- python_op_gen_internal::AvoidPythonReserved(name_default.first);
- strings::StrAppend(&parameters, param, "=", name_default.second);
- param_names_.push_back(param);
+ strings::StrAppend(&parameters, param_and_default.first.GetRenameTo(), "=",
+ param_and_default.second);
}
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
strings::StrAppend(&parameters, "name=None");
+ AddExport();
AddDefLine(parameters);
AddDocStringDescription();
AddDocStringArgs();
@@ -297,25 +323,26 @@ string GenEagerPythonOp::Code() {
// inputs are lists and have the same length.
for (auto iter = arg_list->second.begin();
iter != arg_list->second.end(); ++iter) {
- const string& arg_name = param_names_[*iter];
- ExpectListArg(arg_name);
+ const string& arg_api_name = param_names_[*iter].GetRenameTo();
+ ExpectListArg(arg_api_name);
if (iter == arg_list->second.begin()) {
- AddInferredAttr(attr.name(), strings::StrCat("len(", arg_name, ")"),
+ AddInferredAttr(attr.name(),
+ strings::StrCat("len(", arg_api_name, ")"),
&result_, &attr_expressions_);
} else {
const auto& attr_var = attr_expressions_[attr.name()];
- strings::StrAppend(&result_, " if len(", arg_name,
+ strings::StrAppend(&result_, " if len(", arg_api_name,
") != ", attr_var,
":\n"
" raise ValueError(\n"
" \"List argument '",
- arg_name, "' to '", op_name_,
+ arg_api_name, "' to '", op_name_,
"' Op with length %d \"\n"
" \"must match length %d of argument '",
inferred_attrs_[attr.name()],
"'.\" %\n"
" (len(",
- arg_name, "), ", attr_var, "))\n");
+ arg_api_name, "), ", attr_var, "))\n");
}
}
}
@@ -325,65 +352,76 @@ string GenEagerPythonOp::Code() {
// Values for non-inferred attrs.
for (int i = 0; i < attrs_.size(); ++i) {
const string& attr_name = attrs_[i];
- const string& param = param_names_[i + op_def_.input_arg_size()];
+ const auto& param = param_names_[i + op_def_.input_arg_size()];
const auto& attr = *FindAttr(attr_name, op_def_);
+ const string& attr_api_name = param.GetRenameTo();
StringPiece attr_type = attr.type();
- attr_expressions_[attr_name] = param;
- const int default_index = i - (attrs_.size() - args_with_defaults.size());
+ attr_expressions_[attr_name] = attr_api_name;
+ const int default_index = i - (attrs_.size() - params_with_default.size());
if (default_index >= 0) {
- const string& default_value = args_with_defaults[default_index].second;
- strings::StrAppend(&result_, " if ", param, " is None:\n");
- strings::StrAppend(&result_, " ", param, " = ", default_value, "\n");
+ const string& default_value = params_with_default[default_index].second;
+ strings::StrAppend(&result_, " if ", attr_api_name, " is None:\n");
+ strings::StrAppend(&result_, " ", attr_api_name, " = ", default_value,
+ "\n");
}
if (attr_type.starts_with("list(")) {
- ExpectListArg(param);
+ ExpectListArg(attr_api_name);
}
if (attr_type == "string") {
- strings::StrAppend(&result_, " ", param, " = _execute.make_str(", param,
- ", \"", param, "\")\n");
+ strings::StrAppend(&result_, " ", attr_api_name, " = _execute.make_str(",
+ attr_api_name, ", \"", attr_api_name, "\")\n");
} else if (attr_type == "list(string)") {
- strings::StrAppend(&result_, " ", param, " = [_execute.make_str(_s, \"",
- param, "\") for _s in ", param, "]\n");
+ strings::StrAppend(&result_, " ", attr_api_name,
+ " = [_execute.make_str(_s, \"", attr_api_name,
+ "\") for _s in ", attr_api_name, "]\n");
} else if (attr_type == "int") {
- strings::StrAppend(&result_, " ", param, " = _execute.make_int(", param,
- ", \"", param, "\")\n");
+ strings::StrAppend(&result_, " ", attr_api_name, " = _execute.make_int(",
+ attr_api_name, ", \"", attr_api_name, "\")\n");
} else if (attr_type == "list(int)") {
- strings::StrAppend(&result_, " ", param, " = [_execute.make_int(_i, \"",
- param, "\") for _i in ", param, "]\n");
+ strings::StrAppend(&result_, " ", attr_api_name,
+ " = [_execute.make_int(_i, \"", attr_api_name,
+ "\") for _i in ", attr_api_name, "]\n");
} else if (attr_type == "float") {
- strings::StrAppend(&result_, " ", param, " = _execute.make_float(",
- param, ", \"", param, "\")\n");
+ strings::StrAppend(&result_, " ", attr_api_name,
+ " = _execute.make_float(", attr_api_name, ", \"",
+ attr_api_name, "\")\n");
} else if (attr_type == "list(float)") {
- strings::StrAppend(&result_, " ", param,
- " = [_execute.make_float(_f, \"", param,
- "\") for _f in ", param, "]\n");
+ strings::StrAppend(&result_, " ", attr_api_name,
+ " = [_execute.make_float(_f, \"", attr_api_name,
+ "\") for _f in ", attr_api_name, "]\n");
} else if (attr_type == "bool") {
- strings::StrAppend(&result_, " ", param, " = _execute.make_bool(", param,
- ", \"", param, "\")\n");
+ strings::StrAppend(&result_, " ", attr_api_name,
+ " = _execute.make_bool(", attr_api_name, ", \"",
+ attr_api_name, "\")\n");
} else if (attr_type == "list(bool)") {
- strings::StrAppend(&result_, " ", param, " = [_execute.make_bool(_b, \"",
- param, "\") for _b in ", param, "]\n");
+ strings::StrAppend(&result_, " ", attr_api_name,
+ " = [_execute.make_bool(_b, \"", attr_api_name,
+ "\") for _b in ", attr_api_name, "]\n");
} else if (attr_type == "type") {
- strings::StrAppend(&result_, " ", param, " = _execute.make_type(", param,
- ", \"", param, "\")\n");
+ strings::StrAppend(&result_, " ", attr_api_name,
+ " = _execute.make_type(", attr_api_name, ", \"",
+ attr_api_name, "\")\n");
} else if (attr_type == "list(type)") {
- strings::StrAppend(&result_, " ", param, " = [_execute.make_type(_t, \"",
- param, "\") for _t in ", param, "]\n");
+ strings::StrAppend(&result_, " ", attr_api_name,
+ " = [_execute.make_type(_t, \"", attr_api_name,
+ "\") for _t in ", attr_api_name, "]\n");
} else if (attr_type == "shape") {
- strings::StrAppend(&result_, " ", param, " = _execute.make_shape(",
- param, ", \"", param, "\")\n");
+ strings::StrAppend(&result_, " ", attr_api_name,
+ " = _execute.make_shape(", attr_api_name, ", \"",
+ attr_api_name, "\")\n");
} else if (attr_type == "list(shape)") {
- strings::StrAppend(&result_, " ", param,
- " = [_execute.make_shape(_s, \"", param,
- "\") for _s in ", param, "]\n");
+ strings::StrAppend(&result_, " ", attr_api_name,
+ " = [_execute.make_shape(_s, \"", attr_api_name,
+ "\") for _s in ", attr_api_name, "]\n");
} else if (attr_type == "tensor") {
- strings::StrAppend(&result_, " ", param, " = _execute.make_tensor(",
- param, ", \"", param, "\")\n");
+ strings::StrAppend(&result_, " ", attr_api_name,
+ " = _execute.make_tensor(", attr_api_name, ", \"",
+ attr_api_name, "\")\n");
} else if (attr_type == "list(tensor)") {
- strings::StrAppend(&result_, " ", param,
- " = [_execute.make_tensor(_t, \"", param,
- "\") for _t in ", param, "]\n");
+ strings::StrAppend(&result_, " ", attr_api_name,
+ " = [_execute.make_tensor(_t, \"", attr_api_name,
+ "\") for _t in ", attr_api_name, "]\n");
} else if (attr_type != "func") {
return strings::StrCat("# No definition for ", function_name_,
" since we don't support attrs with type\n"
@@ -484,16 +522,20 @@ string GenEagerPythonOp::Code() {
bool eager_allowed = true;
string ref_arg;
- for (const auto& arg : op_def_.input_arg()) {
+ for (int i = 0; i < op_def_.input_arg_size(); ++i) {
+ const auto& arg = op_def_.input_arg(i);
if (arg.is_ref()) {
eager_allowed = false;
- ref_arg = arg.name();
+ DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name());
+ ref_arg = api_def_.in_arg(i).rename_to();
}
}
- for (const auto& arg : op_def_.output_arg()) {
+ for (int i = 0; i < op_def_.output_arg_size(); ++i) {
+ const auto& arg = op_def_.output_arg(i);
if (arg.is_ref()) {
eager_allowed = false;
- ref_arg = arg.name();
+ DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name());
+ ref_arg = api_def_.out_arg(i).rename_to();
}
}
@@ -553,6 +595,7 @@ void GenEagerPythonOp::AddEagerInferredAttrs() {
// Figure out values for inferred attrs, and cast to eager tensors.
for (int i = 0; i < op_def_.attr_size(); ++i) {
const auto& attr(op_def_.attr(i));
+ const auto& api_def_attr(api_def_.attr(i));
auto arg_list = attr_to_args_.find(attr.name());
if (arg_list != attr_to_args_.end()) {
if (attr.type() == "type") {
@@ -565,14 +608,15 @@ void GenEagerPythonOp::AddEagerInferredAttrs() {
strings::StrAppend(
&conversion, ", ",
python_op_gen_internal::AttrValueToPython(
- attr.type(), attr.default_value(), "_dtypes."));
+ attr.type(), api_def_attr.default_value(), "_dtypes."));
}
strings::StrAppend(&conversion, ")");
const string var_name = AttrVarName(attr.name(), &attr_expressions_);
if (output_sizes.size() == 1) {
// Avoid creating a temporary variable in the case where
// we can easily assign to the right value directly.
- const string inputs_var = param_names_[arg_list->second.front()];
+ const string inputs_var =
+ param_names_[arg_list->second.front()].GetRenameTo();
if (output_sizes.front().empty()) {
strings::StrAppend(&result_, " ", var_name, ", (", inputs_var,
",) = ", conversion, "\n");
@@ -589,7 +633,7 @@ void GenEagerPythonOp::AddEagerInferredAttrs() {
Unflatten(" ", output_sizes, inputs_var, &result_);
std::vector<string> p;
for (int j : arg_list->second) {
- p.emplace_back(param_names_[j]);
+ p.emplace_back(param_names_[j].GetRenameTo());
}
strings::StrAppend(&result_, " ", VectorToTuple(p), " = ",
inputs_var, "\n");
@@ -608,14 +652,14 @@ void GenEagerPythonOp::AddEagerInferredAttrs() {
std::vector<string> lists;
for (auto iter = arg_list->second.begin();
iter != arg_list->second.end(); ++iter) {
- lists.push_back(param_names_[*iter]);
+ lists.push_back(param_names_[*iter].GetRenameTo());
}
inputs_var = VectorToTuple(lists);
conversion = "_execute.args_to_mixed_eager_tensors";
} else {
// For one list(tensor) argument, we just convert every
// element of the list to an eager tensor.
- inputs_var = param_names_[arg_list->second.front()];
+ inputs_var = param_names_[arg_list->second.front()].GetRenameTo();
conversion = "_execute.convert_to_mixed_eager_tensors";
}
strings::StrAppend(&result_, " ", var_name, ", ", inputs_var, " = ",
@@ -630,7 +674,7 @@ void GenEagerPythonOp::AddEagerInputCasts() {
for (int i = 0; i < op_def_.input_arg_size(); ++i) {
const auto& arg(op_def_.input_arg(i));
if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
- const string& param = param_names_[i];
+ const string& param = param_names_[i].GetRenameTo();
const string fn = arg.number_attr().empty() ? "" : "n_";
const string dtype =
python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
diff --git a/tensorflow/python/eager/python_eager_op_gen_main.cc b/tensorflow/python/eager/python_eager_op_gen_main.cc
index cd74c438ec..05351bd8b1 100644
--- a/tensorflow/python/eager/python_eager_op_gen_main.cc
+++ b/tensorflow/python/eager/python_eager_op_gen_main.cc
@@ -21,34 +21,32 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
namespace tensorflow {
namespace {
-constexpr char kBaseApiDef[] =
- "tensorflow/core/api_def/base_api/*.pbtxt";
-constexpr char kPythonApiDef[] =
- "tensorflow/core/api_def/python_api/*.pbtxt";
-constexpr bool kUseApiDef = false;
-
-void PrintAllPythonOps(const std::vector<string>& hidden_ops) {
+void PrintAllPythonOps(const std::vector<string>& hidden_ops,
+ const std::vector<string>& api_def_dirs) {
OpList ops;
OpRegistry::Global()->Export(false, &ops);
ApiDefMap api_def_map(ops);
- if (kUseApiDef) {
+ if (!api_def_dirs.empty()) {
Env* env = Env::Default();
- std::vector<string> base_api_files;
- std::vector<string> python_api_files;
- TF_CHECK_OK(env->GetMatchingPaths(kBaseApiDef, &base_api_files));
- TF_CHECK_OK(env->GetMatchingPaths(kPythonApiDef, &python_api_files));
-
- TF_CHECK_OK(api_def_map.LoadFileList(env, base_api_files));
- TF_CHECK_OK(api_def_map.LoadFileList(env, python_api_files));
+ for (const auto& api_def_dir : api_def_dirs) {
+ std::vector<string> api_files;
+ TF_CHECK_OK(env->GetMatchingPaths(io::JoinPath(api_def_dir, "*.pbtxt"),
+ &api_files));
+ TF_CHECK_OK(api_def_map.LoadFileList(env, api_files));
+ }
+ api_def_map.UpdateDocs();
}
+
PrintEagerPythonOps(ops, api_def_map, hidden_ops, true /* require_shapes */);
}
@@ -58,8 +56,15 @@ void PrintAllPythonOps(const std::vector<string>& hidden_ops) {
int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
+ // Usage:
+ // python_eager_op_gen_main api_def_dir1,api_def_dir2,...
if (argc == 1) {
- tensorflow::PrintAllPythonOps({});
+ tensorflow::PrintAllPythonOps({}, {});
+ } else if (argc == 2) {
+ const std::vector<tensorflow::string> api_def_dirs =
+ tensorflow::str_util::Split(argv[1], ",",
+ tensorflow::str_util::SkipEmpty());
+ tensorflow::PrintAllPythonOps({}, api_def_dirs);
} else {
return -1;
}
diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc
index c57f0a9842..72d3ea90fd 100644
--- a/tensorflow/python/framework/python_op_gen.cc
+++ b/tensorflow/python/framework/python_op_gen.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <stdio.h>
#include <sstream>
#include <unordered_map>
+#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb_text.h"
@@ -480,15 +481,15 @@ string GenPythonOp::Code() {
}
// This has all the input args followed by those attrs that don't have
// defaults.
- std::vector<string> args_no_default;
+ std::vector<ParamNames> params_no_default;
// The parameters with defaults (these have to be listed after those without).
// No input args are included, just attrs.
- std::vector<string> args_with_defaults;
+ std::vector<ParamNames> params_with_default;
for (int i = 0; i < api_def_.arg_order_size(); ++i) {
const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
- args_no_default.push_back(api_def_arg.rename_to());
+ params_no_default.emplace_back(api_def_arg.name(), api_def_arg.rename_to());
if (!arg.type_attr().empty()) {
gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name());
} else if (!arg.type_list_attr().empty()) {
@@ -504,9 +505,9 @@ string GenPythonOp::Code() {
// Do not add inferred attrs to the Python function signature.
if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
if (attr.has_default_value()) {
- args_with_defaults.push_back(attr.rename_to());
+ params_with_default.emplace_back(attr.name(), attr.rename_to());
} else {
- args_no_default.push_back(attr.rename_to());
+ params_no_default.emplace_back(attr.name(), attr.rename_to());
}
}
}
@@ -515,27 +516,30 @@ string GenPythonOp::Code() {
// those with defaults go at the end.
// Get the attrs in the order we want by taking the attrs without defaults
// from the end of args_no_default, and adding args_no_default.
- attrs_.reserve(args_no_default.size() - op_def_.input_arg_size() +
- args_with_defaults.size());
- attrs_.insert(attrs_.end(),
- args_no_default.begin() + op_def_.input_arg_size(),
- args_no_default.end());
- attrs_.insert(attrs_.end(), args_with_defaults.begin(),
- args_with_defaults.end());
-
- param_names_.reserve(args_no_default.size() + args_with_defaults.size());
+ attrs_.reserve(params_no_default.size() - op_def_.input_arg_size() +
+ params_with_default.size());
+ for (int i = op_def_.input_arg_size(); i < params_no_default.size(); ++i) {
+ attrs_.push_back(params_no_default[i].GetName());
+ }
+ for (int i = 0; i < params_with_default.size(); ++i) {
+ attrs_.push_back(params_with_default[i].GetName());
+ }
+
+ param_names_.reserve(params_no_default.size() + params_with_default.size());
+ param_names_.insert(param_names_.begin(), params_no_default.begin(),
+ params_no_default.end());
+ for (const auto& param : params_with_default) {
+ param_names_.push_back(param);
+ }
+
string parameters;
- for (const string& name : args_no_default) {
+ for (const auto& param : params_no_default) {
AddDelimiter(&parameters, ", ");
- const string param = AvoidPythonReserved(name);
- strings::StrAppend(&parameters, param);
- param_names_.push_back(param);
+ strings::StrAppend(&parameters, param.GetRenameTo());
}
- for (const string& name : args_with_defaults) {
+ for (const auto& param_and_default : params_with_default) {
AddDelimiter(&parameters, ", ");
- const string param = AvoidPythonReserved(name);
- strings::StrAppend(&parameters, param, "=None");
- param_names_.push_back(param);
+ strings::StrAppend(&parameters, param_and_default.GetRenameTo(), "=None");
}
AddDelimiter(&parameters, ", ");
strings::StrAppend(&parameters, "name=None");
@@ -557,10 +561,11 @@ string GenPythonOp::Code() {
}
void GenPythonOp::AddExport() {
- if (api_def_.visibility() != api_def_.VISIBLE) {
+ if (api_def_.visibility() != ApiDef::VISIBLE) {
return;
}
- strings::StrAppend(&result_, "tf_export(");
+
+ strings::StrAppend(&result_, "@tf_export(");
// Add all endpoint names to tf_export.
bool first_endpoint = true;
@@ -603,9 +608,9 @@ void GenPythonOp::AddDocStringInputs() {
StringPiece description = api_def_arg.description();
string desc;
if (ConsumeEquals(&description)) { // Skip the generated type info.
- desc = strings::StrCat(param_names_[i], ": ");
+ desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ");
} else {
- desc = strings::StrCat(param_names_[i], ": ",
+ desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ",
ArgTypeName(op_def_, arg, inferred_attrs_, false));
}
if (!description.empty()) {
@@ -750,7 +755,8 @@ void GenPythonOp::AddBody(const string& prefix) {
void GenPythonOp::AddBodyNoReturn(const string& apply_prefix) {
string args = strings::StrCat("\"", op_def_.name(), "\", ");
for (size_t i = 0; i < param_names_.size(); ++i) {
- strings::StrAppend(&args, param_names_[i], "=", param_names_[i], ", ");
+ strings::StrAppend(&args, AvoidPythonReserved(param_names_[i].GetName()),
+ "=", param_names_[i].GetRenameTo(), ", ");
}
strings::StrAppend(&args, "name=name)");
diff --git a/tensorflow/python/framework/python_op_gen_internal.h b/tensorflow/python/framework/python_op_gen_internal.h
index c1efbf9be2..6b53825a6d 100644
--- a/tensorflow/python/framework/python_op_gen_internal.h
+++ b/tensorflow/python/framework/python_op_gen_internal.h
@@ -41,6 +41,28 @@ void GenerateLowerCaseOpName(const string& str, string* result);
string DataTypeToPython(DataType dtype, const string& dtype_module);
+// Names that corresponds to a single input parameter.
+class ParamNames {
+ public:
+ // Create param based on Arg.
+ ParamNames(const string& name, const string& rename_to) : name_(name) {
+ rename_to_ = AvoidPythonReserved(rename_to);
+ }
+
+ // Get original parameter name.
+ string GetName() const { return name_; }
+
+ // Get the name to rename the parameter to. Note that AvoidPythonReserved
+ // has already been applied.
+ string GetRenameTo() const { return rename_to_; }
+
+ private:
+ // Original parameter name.
+ string name_;
+ // API name for this parameter.
+ string rename_to_;
+};
+
class GenPythonOp {
public:
GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
@@ -84,7 +106,7 @@ class GenPythonOp {
// All parameters, including inputs & non-inferred attrs, required and those
// with defaults, except "name"
- std::vector<string> param_names_;
+ std::vector<ParamNames> param_names_;
};
} // namespace python_op_gen_internal
diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc
index 61b1d02a5e..bc5ca195da 100644
--- a/tensorflow/python/framework/python_op_gen_main.cc
+++ b/tensorflow/python/framework/python_op_gen_main.cc
@@ -34,12 +34,6 @@ limitations under the License.
namespace tensorflow {
namespace {
-constexpr char kBaseApiDef[] =
- "tensorflow/core/api_def/base_api/*.pbtxt";
-constexpr char kPythonApiDef[] =
- "tensorflow/core/api_def/python_api/*.pbtxt";
-constexpr bool kUseApiDef = false;
-
Status ReadOpListFromFile(const string& filename,
std::vector<string>* op_list) {
std::unique_ptr<RandomAccessFile> file;
@@ -110,22 +104,23 @@ string InferSourceFileName(const char* argv_zero) {
}
void PrintAllPythonOps(const std::vector<string>& op_list,
+ const std::vector<string>& api_def_dirs,
const string& source_file_name, bool require_shapes,
bool op_list_is_whitelist) {
OpList ops;
OpRegistry::Global()->Export(false, &ops);
ApiDefMap api_def_map(ops);
- if (kUseApiDef) {
+ if (!api_def_dirs.empty()) {
Env* env = Env::Default();
- std::vector<string> base_api_files;
- std::vector<string> python_api_files;
- TF_CHECK_OK(env->GetMatchingPaths(kBaseApiDef, &base_api_files));
- TF_CHECK_OK(env->GetMatchingPaths(kPythonApiDef, &python_api_files));
-
- TF_CHECK_OK(api_def_map.LoadFileList(env, base_api_files));
- TF_CHECK_OK(api_def_map.LoadFileList(env, python_api_files));
+ for (const auto& api_def_dir : api_def_dirs) {
+ std::vector<string> api_files;
+ TF_CHECK_OK(env->GetMatchingPaths(io::JoinPath(api_def_dir, "*.pbtxt"),
+ &api_files));
+ TF_CHECK_OK(api_def_map.LoadFileList(env, api_files));
+ }
+ api_def_map.UpdateDocs();
}
if (op_list_is_whitelist) {
@@ -154,23 +149,30 @@ int main(int argc, char* argv[]) {
tensorflow::InferSourceFileName(argv[0]);
// Usage:
- // gen_main [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1]
- if (argc == 2) {
- tensorflow::PrintAllPythonOps({}, source_file_name,
- tensorflow::string(argv[1]) == "1",
- false /* op_list_is_whitelist */);
- } else if (argc == 3) {
- std::vector<tensorflow::string> hidden_ops;
- TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &hidden_ops));
- tensorflow::PrintAllPythonOps(hidden_ops, source_file_name,
+ // gen_main api_def_dir1,api_def_dir2,...
+ // [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1]
+ if (argc < 3) {
+ return -1;
+ }
+ std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
+ argv[1], ",", tensorflow::str_util::SkipEmpty());
+
+ if (argc == 3) {
+ tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
tensorflow::string(argv[2]) == "1",
false /* op_list_is_whitelist */);
} else if (argc == 4) {
+ std::vector<tensorflow::string> hidden_ops;
+ TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
+ tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
+ tensorflow::string(argv[3]) == "1",
+ false /* op_list_is_whitelist */);
+ } else if (argc == 5) {
std::vector<tensorflow::string> op_list;
- TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &op_list));
- tensorflow::PrintAllPythonOps(op_list, source_file_name,
- tensorflow::string(argv[2]) == "1",
- tensorflow::string(argv[3]) == "1");
+ TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list));
+ tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name,
+ tensorflow::string(argv[3]) == "1",
+ tensorflow::string(argv[4]) == "1");
} else {
return -1;
}
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 23aa74c027..73a19e7042 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1306,7 +1306,7 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
size_splits = ops.convert_to_tensor(num_or_size_splits)
if size_splits._rank() == 0 and size_splits.dtype.is_integer:
return gen_array_ops._split(
- split_dim=axis, num_split=num_or_size_splits, value=value, name=name)
+ axis=axis, num_split=num_or_size_splits, value=value, name=name)
if num is None:
num = size_splits._shape_tuple()[0]
@@ -1316,7 +1316,7 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
return gen_array_ops._split_v(
value=value,
size_splits=size_splits,
- split_dim=axis,
+ axis=axis,
num_split=num,
name=name)
@@ -2538,9 +2538,9 @@ def where(condition, x=None, y=None, name=None):
with ops.name_scope(name, "Where", [condition]) as name:
condition = ops.convert_to_tensor(
condition, preferred_dtype=dtypes.bool, name="condition")
- return gen_array_ops.where(input=condition, name=name)
+ return gen_array_ops.where(condition=condition, name=name)
elif x is not None and y is not None:
- return gen_math_ops._select(condition=condition, t=x, e=y, name=name)
+ return gen_math_ops._select(condition=condition, x=x, y=y, name=name)
else:
raise ValueError("x and y must both be non-None or both be None.")