aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar Ben <bstriner@gmail.com>2018-08-26 15:44:13 -0400
committerGravatar Ben <bstriner@gmail.com>2018-08-26 15:44:13 -0400
commit88ec342544096d895908dac6b0bf6b44dadaaca1 (patch)
treecd570c40e6a40e37f14747d6fd387596ff324d01 /tensorflow/c
parent32d4ffeb95a344fde6a1b956a4a8d6792432bf15 (diff)
parent09792df012c22622324f085f46edde33006c7355 (diff)
Merge branch 'master' into py37
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/c_api.cc3
-rw-r--r--tensorflow/c/checkpoint_reader.h6
-rw-r--r--tensorflow/c/eager/c_api_test.cc57
-rw-r--r--tensorflow/c/tf_status_helper.h6
4 files changed, 65 insertions, 7 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 19ccb6e71d..b8adf6c127 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -202,7 +202,8 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
buf->len_ = len;
if (dtype != TF_STRING && dtype != TF_RESOURCE &&
tensorflow::DataTypeCanUseMemcpy(static_cast<DataType>(dtype)) &&
- reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) {
+ reinterpret_cast<intptr_t>(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) !=
+ 0) {
// TF_STRING and TF_RESOURCE tensors have a different representation in
// TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste
// (any alignment requirements will be taken care of by TF_TensorToTensor
diff --git a/tensorflow/c/checkpoint_reader.h b/tensorflow/c/checkpoint_reader.h
index 4de1300a7f..91654c8d4f 100644
--- a/tensorflow/c/checkpoint_reader.h
+++ b/tensorflow/c/checkpoint_reader.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_C_CHECKPOINT_READER_H
-#define TENSORFLOW_C_CHECKPOINT_READER_H
+#ifndef TENSORFLOW_C_CHECKPOINT_READER_H_
+#define TENSORFLOW_C_CHECKPOINT_READER_H_
#include <memory>
#include <string>
@@ -79,4 +79,4 @@ class CheckpointReader {
} // namespace checkpoint
} // namespace tensorflow
-#endif // TENSORFLOW_C_CHECKPOINT_READER_H
+#endif // TENSORFLOW_C_CHECKPOINT_READER_H_
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 71d5f3613c..7126227cf5 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -1471,4 +1471,61 @@ void BM_ReadVariable(int iters) {
}
BENCHMARK(BM_ReadVariable);
+TEST(CAPI, StringAttributes) {
+ // Test that TFE_OpSetAttrString doesn't hold on to the value after it
+ // returns.
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ std::vector<int64_t> dims(4, 1);
+ TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_Tensor* tensor =
+ TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float));
+ float tensor_data[] = {1};
+ memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor));
+ TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, tensor_handle, status);
+ TF_DeleteTensor(tensor);
+ TFE_DeleteTensorHandle(tensor_handle);
+
+ std::vector<int64_t> values(4, 1);
+ TFE_OpSetAttrIntList(op, "ksize", values.data(), values.size());
+ TFE_OpSetAttrIntList(op, "strides", values.data(), values.size());
+
+ const int BUFFER_SIZE = 10;
+ char buffer[BUFFER_SIZE];
+ std::strncpy(buffer, "VALID", BUFFER_SIZE);
+ TFE_OpSetAttrString(op, "padding", buffer, std::strlen(buffer));
+ // Overwriting value in "buffer", should be fine since TFE_Op
+ // shouldn't be holding on to it.
+ std::strncpy(buffer, "NHWC", BUFFER_SIZE);
+ TFE_OpSetAttrString(op, "data_format", buffer, std::strlen(buffer));
+
+ TFE_OpSetAttrType(op, "T", TF_FLOAT);
+
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(op, &retvals[0], &num_retvals, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ ASSERT_EQ(1, num_retvals);
+
+ tensor = TFE_TensorHandleResolve(retvals[0], status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ EXPECT_EQ(4, TF_TensorByteSize(tensor));
+ TF_DeleteTensor(tensor);
+ TFE_DeleteTensorHandle(retvals[0]);
+
+ TFE_DeleteOp(op);
+
+ TFE_DeleteContext(ctx);
+ TF_DeleteStatus(status);
+}
} // namespace
diff --git a/tensorflow/c/tf_status_helper.h b/tensorflow/c/tf_status_helper.h
index 86e687df20..7661a01de4 100644
--- a/tensorflow/c/tf_status_helper.h
+++ b/tensorflow/c/tf_status_helper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_C_TF_STATUS_HELPER_H
-#define TENSORFLOW_C_TF_STATUS_HELPER_H
+#ifndef TENSORFLOW_C_TF_STATUS_HELPER_H_
+#define TENSORFLOW_C_TF_STATUS_HELPER_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/lib/core/status.h"
@@ -29,4 +29,4 @@ Status StatusFromTF_Status(const TF_Status* tf_status);
} // namespace tensorflow
-#endif // TENSORFLOW_C_TF_STATUS_HELPER_H
+#endif // TENSORFLOW_C_TF_STATUS_HELPER_H_