diff options
author | 2017-09-18 20:19:13 -0700 | |
---|---|---|
committer | 2017-09-18 20:22:15 -0700 | |
commit | 1da763a1cc94fc5e4ad1822788b444b77623538c (patch) | |
tree | 414b168e2f3b1dc8be71c58b71e36c34818ff1c6 /tensorflow/c/c_test_util.cc | |
parent | d10902f0a947da40f80479d74e9a487617759085 (diff) |
Add function gradient support to C API
Also, change the internal representation of TF_Function and
rename TF_GraphAddFunction to TF_GraphAddFunctionCopy to make it
clear that a copy of the function is added to the graph. Any
subsequent modifications to the function will not be reflected
in the copy added to the graph.
PiperOrigin-RevId: 169187793
Diffstat (limited to 'tensorflow/c/c_test_util.cc')
-rw-r--r-- | tensorflow/c/c_test_util.cc | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index d1f99fe1ef..a380375db0 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/c_test_util.h" #include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -339,6 +340,25 @@ bool GetAttrValue(TF_Operation* oper, const char* attr_name, return ret; } +std::vector<std::pair<string, string>> GetGradDefs( + const tensorflow::GraphDef& graph_def) { + std::vector<std::pair<string, string>> grads; + for (const tensorflow::GradientDef& grad : graph_def.library().gradient()) { + grads.emplace_back(grad.function_name(), grad.gradient_func()); + } + std::sort(grads.begin(), grads.end()); + return grads; +} + +std::vector<string> GetFuncNames(const tensorflow::GraphDef& graph_def) { + std::vector<string> names; + for (const tensorflow::FunctionDef& func : graph_def.library().function()) { + names.push_back(func.signature().name()); + } + std::sort(names.begin(), names.end()); + return names; +} + CSession::CSession(TF_Graph* graph, TF_Status* s) { TF_SessionOptions* opts = TF_NewSessionOptions(); session_ = TF_NewSession(graph, opts, s); |