aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_test_util.h
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2017-09-18 20:19:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-18 20:22:15 -0700
commit1da763a1cc94fc5e4ad1822788b444b77623538c (patch)
tree414b168e2f3b1dc8be71c58b71e36c34818ff1c6 /tensorflow/c/c_test_util.h
parentd10902f0a947da40f80479d74e9a487617759085 (diff)
Add function gradient support to C API
Also, change the internal representation of TF_Function and rename TF_GraphAddFunction to TF_GraphAddFunctionCopy to make it clear that a copy of the function is added to the graph. Any subsequent modifications to the function will not be reflected in the copy added to the graph. PiperOrigin-RevId: 169187793
Diffstat (limited to 'tensorflow/c/c_test_util.h')
-rw-r--r--tensorflow/c/c_test_util.h8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h
index 91f96b0e5d..9cfedf36e5 100644
--- a/tensorflow/c/c_test_util.h
+++ b/tensorflow/c/c_test_util.h
@@ -94,6 +94,14 @@ bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def);
bool GetAttrValue(TF_Operation* oper, const char* attr_name,
tensorflow::AttrValue* attr_value, TF_Status* s);
+// Returns a sorted vector of std::pair<function_name, gradient_func> from
+// graph_def.library().gradient()
+std::vector<std::pair<string, string>> GetGradDefs(
+ const tensorflow::GraphDef& graph_def);
+
+// Returns a sorted vector of names contained in `grad_def`
+std::vector<string> GetFuncNames(const tensorflow::GraphDef& graph_def);
+
class CSession {
public:
CSession(TF_Graph* graph, TF_Status* s);