diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-10-18 11:58:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-18 12:04:54 -0700 |
commit | f5d3bf42b892ecfbde2ce9eb45f00b76473c824a (patch) | |
tree | fe8f8a9965f40ab5a463932b2073df32be21a1af /tensorflow/c/c_api_test.cc | |
parent | 6a725f6d0974dc71fe4ac311fc8dd16db4257452 (diff) |
Add TF_GraphGetOpDef() to C API and use in Operation.op_def()
Note that this creates a small change in behavior with the C API
enabled, since previously not all Python Operations had an OpDef
(op_def() returns None). With the C API enabled, op_def() always
returns an OpDef.
PiperOrigin-RevId: 172634411
Diffstat (limited to 'tensorflow/c/c_api_test.cc')
-rw-r--r-- | tensorflow/c/c_api_test.cc | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index c442029009..d220bc5e95 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -50,6 +51,11 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); namespace { +static void ExpectHasSubstr(StringPiece s, StringPiece expected) { + EXPECT_TRUE(StringPiece(s).contains(expected)) + << "'" << s << "' does not contain '" << expected << "'"; +} + TEST(CAPI, Version) { EXPECT_STRNE("", TF_Version()); } TEST(CAPI, Status) { @@ -837,6 +843,31 @@ TEST(CAPI, ShapeInferenceError) { TF_DeleteStatus(status); } +TEST(CAPI, GetOpDef) { + TF_Status* status = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + TF_Buffer* buffer = TF_NewBuffer(); + + TF_GraphGetOpDef(graph, "Add", buffer, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)); + const OpDef* expected_op_def; + TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef("Add", &expected_op_def)); + string expected_serialized; + expected_op_def->SerializeToString(&expected_serialized); + string actual_string(reinterpret_cast<const char*>(buffer->data), + buffer->length); + EXPECT_EQ(expected_serialized, actual_string); + + TF_GraphGetOpDef(graph, "MyFakeOp", buffer, status); + EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status)); + ExpectHasSubstr(TF_Message(status), + "Op type not registered 'MyFakeOp' in binary"); + + TF_DeleteBuffer(buffer); + TF_DeleteGraph(graph); + TF_DeleteStatus(status); +} + void StringVectorToArrays(const std::vector<string>& v, std::unique_ptr<const void* []>* ptrs, std::unique_ptr<size_t[]>* lens) { |