aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_test_util.cc
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2017-09-18 20:19:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-18 20:22:15 -0700
commit1da763a1cc94fc5e4ad1822788b444b77623538c (patch)
tree414b168e2f3b1dc8be71c58b71e36c34818ff1c6 /tensorflow/c/c_test_util.cc
parentd10902f0a947da40f80479d74e9a487617759085 (diff)
Add function gradient support to C API
Also, change the internal representation of TF_Function and rename TF_GraphAddFunction to TF_GraphAddFunctionCopy to make it clear that a copy of the function is added to the graph. Any subsequent modifications to the function will not be reflected in the copy added to the graph. PiperOrigin-RevId: 169187793
Diffstat (limited to 'tensorflow/c/c_test_util.cc')
-rw-r--r--tensorflow/c/c_test_util.cc20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc
index d1f99fe1ef..a380375db0 100644
--- a/tensorflow/c/c_test_util.cc
+++ b/tensorflow/c/c_test_util.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -339,6 +340,25 @@ bool GetAttrValue(TF_Operation* oper, const char* attr_name,
return ret;
}
+std::vector<std::pair<string, string>> GetGradDefs(
+ const tensorflow::GraphDef& graph_def) {
+ std::vector<std::pair<string, string>> grads;
+ for (const tensorflow::GradientDef& grad : graph_def.library().gradient()) {
+ grads.emplace_back(grad.function_name(), grad.gradient_func());
+ }
+ std::sort(grads.begin(), grads.end());
+ return grads;
+}
+
+std::vector<string> GetFuncNames(const tensorflow::GraphDef& graph_def) {
+ std::vector<string> names;
+ for (const tensorflow::FunctionDef& func : graph_def.library().function()) {
+ names.push_back(func.signature().name());
+ }
+ std::sort(names.begin(), names.end());
+ return names;
+}
+
CSession::CSession(TF_Graph* graph, TF_Status* s) {
TF_SessionOptions* opts = TF_NewSessionOptions();
session_ = TF_NewSession(graph, opts, s);