diff options
author | Igor Ganichev <iga@google.com> | 2017-09-18 20:19:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-18 20:22:15 -0700 |
commit | 1da763a1cc94fc5e4ad1822788b444b77623538c (patch) | |
tree | 414b168e2f3b1dc8be71c58b71e36c34818ff1c6 /tensorflow/c/c_test_util.h | |
parent | d10902f0a947da40f80479d74e9a487617759085 (diff) |
Add function gradient support to C API
Also, change the internal representation of TF_Function and
rename TF_GraphAddFunction to TF_GraphAddFunctionCopy to make it
clear that a copy of the function is added to the graph. Any
subsequent modifications to the function will not be reflected
in the copy added to the graph.
PiperOrigin-RevId: 169187793
Diffstat (limited to 'tensorflow/c/c_test_util.h')
-rw-r--r-- | tensorflow/c/c_test_util.h | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index 91f96b0e5d..9cfedf36e5 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -94,6 +94,14 @@ bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def); bool GetAttrValue(TF_Operation* oper, const char* attr_name, tensorflow::AttrValue* attr_value, TF_Status* s); +// Returns a sorted vector of std::pair<function_name, gradient_func> from +// graph_def.library().gradient() +std::vector<std::pair<string, string>> GetGradDefs( + const tensorflow::GraphDef& graph_def); + +// Returns a sorted vector of names contained in `grad_def` +std::vector<string> GetFuncNames(const tensorflow::GraphDef& graph_def); + class CSession { public: CSession(TF_Graph* graph, TF_Status* s); |