aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_test.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-10-18 11:58:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-18 12:04:54 -0700
commitf5d3bf42b892ecfbde2ce9eb45f00b76473c824a (patch)
treefe8f8a9965f40ab5a463932b2073df32be21a1af /tensorflow/c/c_api_test.cc
parent6a725f6d0974dc71fe4ac311fc8dd16db4257452 (diff)
Add TF_GraphGetOpDef() to C API and use in Operation.op_def()
Note that this creates a small change in behavior with the C API enabled, since previously not all Python Operations had an OpDef (op_def() returns None). With the C API enabled, op_def() always returns an OpDef. PiperOrigin-RevId: 172634411
Diffstat (limited to 'tensorflow/c/c_api_test.cc')
-rw-r--r--tensorflow/c/c_api_test.cc31
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index c442029009..d220bc5e95 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -50,6 +51,11 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
namespace {
+static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
+ EXPECT_TRUE(StringPiece(s).contains(expected))
+ << "'" << s << "' does not contain '" << expected << "'";
+}
+
TEST(CAPI, Version) { EXPECT_STRNE("", TF_Version()); }
TEST(CAPI, Status) {
@@ -837,6 +843,31 @@ TEST(CAPI, ShapeInferenceError) {
TF_DeleteStatus(status);
}
+TEST(CAPI, GetOpDef) {
+ TF_Status* status = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+ TF_Buffer* buffer = TF_NewBuffer();
+
+ TF_GraphGetOpDef(graph, "Add", buffer, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status));
+ const OpDef* expected_op_def;
+ TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef("Add", &expected_op_def));
+ string expected_serialized;
+ expected_op_def->SerializeToString(&expected_serialized);
+ string actual_string(reinterpret_cast<const char*>(buffer->data),
+ buffer->length);
+ EXPECT_EQ(expected_serialized, actual_string);
+
+ TF_GraphGetOpDef(graph, "MyFakeOp", buffer, status);
+ EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status));
+ ExpectHasSubstr(TF_Message(status),
+ "Op type not registered 'MyFakeOp' in binary");
+
+ TF_DeleteBuffer(buffer);
+ TF_DeleteGraph(graph);
+ TF_DeleteStatus(status);
+}
+
void StringVectorToArrays(const std::vector<string>& v,
std::unique_ptr<const void* []>* ptrs,
std::unique_ptr<size_t[]>* lens) {