diff options
author | Ben <bstriner@gmail.com> | 2018-08-26 15:44:13 -0400 |
---|---|---|
committer | Ben <bstriner@gmail.com> | 2018-08-26 15:44:13 -0400 |
commit | 88ec342544096d895908dac6b0bf6b44dadaaca1 (patch) | |
tree | cd570c40e6a40e37f14747d6fd387596ff324d01 /tensorflow/c | |
parent | 32d4ffeb95a344fde6a1b956a4a8d6792432bf15 (diff) | |
parent | 09792df012c22622324f085f46edde33006c7355 (diff) |
Merge branch 'master' into py37
Diffstat (limited to 'tensorflow/c')
-rw-r--r-- | tensorflow/c/c_api.cc | 3 | ||||
-rw-r--r-- | tensorflow/c/checkpoint_reader.h | 6 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api_test.cc | 57 | ||||
-rw-r--r-- | tensorflow/c/tf_status_helper.h | 6 |
4 files changed, 65 insertions, 7 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 19ccb6e71d..b8adf6c127 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -202,7 +202,8 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, buf->len_ = len; if (dtype != TF_STRING && dtype != TF_RESOURCE && tensorflow::DataTypeCanUseMemcpy(static_cast<DataType>(dtype)) && - reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) { + reinterpret_cast<intptr_t>(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) != + 0) { // TF_STRING and TF_RESOURCE tensors have a different representation in // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste // (any alignment requirements will be taken care of by TF_TensorToTensor diff --git a/tensorflow/c/checkpoint_reader.h b/tensorflow/c/checkpoint_reader.h index 4de1300a7f..91654c8d4f 100644 --- a/tensorflow/c/checkpoint_reader.h +++ b/tensorflow/c/checkpoint_reader.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_C_CHECKPOINT_READER_H -#define TENSORFLOW_C_CHECKPOINT_READER_H +#ifndef TENSORFLOW_C_CHECKPOINT_READER_H_ +#define TENSORFLOW_C_CHECKPOINT_READER_H_ #include <memory> #include <string> @@ -79,4 +79,4 @@ class CheckpointReader { } // namespace checkpoint } // namespace tensorflow -#endif // TENSORFLOW_C_CHECKPOINT_READER_H +#endif // TENSORFLOW_C_CHECKPOINT_READER_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 71d5f3613c..7126227cf5 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1471,4 +1471,61 @@ void BM_ReadVariable(int iters) { } BENCHMARK(BM_ReadVariable); +TEST(CAPI, StringAttributes) { + // Test that TFE_OpSetAttrString doesn't hold on to the value after it + // returns. + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::vector<int64_t> dims(4, 1); + TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* tensor = + TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float)); + float tensor_data[] = {1}; + memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor)); + TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, tensor_handle, status); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(tensor_handle); + + std::vector<int64_t> values(4, 1); + TFE_OpSetAttrIntList(op, "ksize", values.data(), values.size()); + TFE_OpSetAttrIntList(op, "strides", values.data(), values.size()); + + const int BUFFER_SIZE = 10; + char buffer[BUFFER_SIZE]; + std::strncpy(buffer, "VALID", BUFFER_SIZE); + TFE_OpSetAttrString(op, "padding", buffer, std::strlen(buffer)); + // Overwriting value in "buffer", should be fine since TFE_Op + // shouldn't be holding on to it. + std::strncpy(buffer, "NHWC", BUFFER_SIZE); + TFE_OpSetAttrString(op, "data_format", buffer, std::strlen(buffer)); + + TFE_OpSetAttrType(op, "T", TF_FLOAT); + + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(op, &retvals[0], &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + tensor = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(4, TF_TensorByteSize(tensor)); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(op); + + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); +} } // namespace diff --git a/tensorflow/c/tf_status_helper.h b/tensorflow/c/tf_status_helper.h index 86e687df20..7661a01de4 100644 --- a/tensorflow/c/tf_status_helper.h +++ b/tensorflow/c/tf_status_helper.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_C_TF_STATUS_HELPER_H -#define TENSORFLOW_C_TF_STATUS_HELPER_H +#ifndef TENSORFLOW_C_TF_STATUS_HELPER_H_ +#define TENSORFLOW_C_TF_STATUS_HELPER_H_ #include "tensorflow/c/c_api.h" #include "tensorflow/core/lib/core/status.h" @@ -29,4 +29,4 @@ Status StatusFromTF_Status(const TF_Status* tf_status); } // namespace tensorflow -#endif // TENSORFLOW_C_TF_STATUS_HELPER_H +#endif // TENSORFLOW_C_TF_STATUS_HELPER_H_ |