diff options
Diffstat (limited to 'tensorflow/python/framework/python_op_gen_internal.h')
-rw-r--r-- | tensorflow/python/framework/python_op_gen_internal.h | 24 |
1 files changed, 23 insertions, 1 deletions
diff --git a/tensorflow/python/framework/python_op_gen_internal.h b/tensorflow/python/framework/python_op_gen_internal.h index c1efbf9be2..6b53825a6d 100644 --- a/tensorflow/python/framework/python_op_gen_internal.h +++ b/tensorflow/python/framework/python_op_gen_internal.h @@ -41,6 +41,28 @@ void GenerateLowerCaseOpName(const string& str, string* result); string DataTypeToPython(DataType dtype, const string& dtype_module); +// Names that corresponds to a single input parameter. +class ParamNames { + public: + // Create param based on Arg. + ParamNames(const string& name, const string& rename_to) : name_(name) { + rename_to_ = AvoidPythonReserved(rename_to); + } + + // Get original parameter name. + string GetName() const { return name_; } + + // Get the name to rename the parameter to. Note that AvoidPythonReserved + // has already been applied. + string GetRenameTo() const { return rename_to_; } + + private: + // Original parameter name. + string name_; + // API name for this parameter. + string rename_to_; +}; + class GenPythonOp { public: GenPythonOp(const OpDef& op_def, const ApiDef& api_def, @@ -84,7 +106,7 @@ class GenPythonOp { // All parameters, including inputs & non-inferred attrs, required and those // with defaults, except "name" - std::vector<string> param_names_; + std::vector<ParamNames> param_names_; }; } // namespace python_op_gen_internal |