aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/python_op_gen_internal.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/python_op_gen_internal.h')
-rw-r--r--tensorflow/python/framework/python_op_gen_internal.h24
1 files changed, 23 insertions, 1 deletions
diff --git a/tensorflow/python/framework/python_op_gen_internal.h b/tensorflow/python/framework/python_op_gen_internal.h
index c1efbf9be2..6b53825a6d 100644
--- a/tensorflow/python/framework/python_op_gen_internal.h
+++ b/tensorflow/python/framework/python_op_gen_internal.h
@@ -41,6 +41,28 @@ void GenerateLowerCaseOpName(const string& str, string* result);
string DataTypeToPython(DataType dtype, const string& dtype_module);
+// Names that corresponds to a single input parameter.
+class ParamNames {
+ public:
+ // Create param based on Arg.
+ ParamNames(const string& name, const string& rename_to) : name_(name) {
+ rename_to_ = AvoidPythonReserved(rename_to);
+ }
+
+ // Get original parameter name.
+ string GetName() const { return name_; }
+
+ // Get the name to rename the parameter to. Note that AvoidPythonReserved
+ // has already been applied.
+ string GetRenameTo() const { return rename_to_; }
+
+ private:
+ // Original parameter name.
+ string name_;
+ // API name for this parameter.
+ string rename_to_;
+};
+
class GenPythonOp {
public:
GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
@@ -84,7 +106,7 @@ class GenPythonOp {
// All parameters, including inputs & non-inferred attrs, required and those
// with defaults, except "name"
- std::vector<string> param_names_;
+ std::vector<ParamNames> param_names_;
};
} // namespace python_op_gen_internal