aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/eager
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-06-20 10:14:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-20 10:16:31 -0700
commit2b0805301e4531dd7c2ed677d932f6408675460e (patch)
tree81ba5542d342f80fb56a3b57e2b6babf8ce31abd /tensorflow/c/eager
parentaf3455aad7ebf2e70c816e642f90594625e4fd44 (diff)
[eager]: Support string attributes where the value contains `\0`.
Apparently, some custom operations stuff non-printable characters in string valued attributes. This change also makes the eager C API consistent with the C API for graph construction (TF_SetAttrString and TF_SetAttrStringList). PiperOrigin-RevId: 201372089
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r--tensorflow/c/eager/c_api.cc38
-rw-r--r--tensorflow/c/eager/c_api.h6
-rw-r--r--tensorflow/c/eager/c_api_test.cc4
3 files changed, 31 insertions, 17 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 55d9c26b0d..6e4764bcbf 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -46,6 +46,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -441,8 +442,11 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
return ret;
}
-void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) {
- op->operation.MutableAttrs()->Set(attr_name, value);
+void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
+ size_t length) {
+ op->operation.MutableAttrs()->Set(
+ attr_name,
+ tensorflow::StringPiece(static_cast<const char*>(value), length));
}
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
@@ -493,16 +497,22 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
op->operation.MutableAttrs()->Set(attr_name, attr_value);
}
-#define TFE_OP_SET_ATTR_LIST(fn, type) \
- void fn(TFE_Op* op, const char* attr_name, const type* values, \
- int num_values) { \
- op->operation.MutableAttrs()->Set( \
- attr_name, \
- tensorflow::gtl::ArraySlice<const type>(values, num_values)); \
+void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
+ const void* const* values, const size_t* lengths,
+ int num_values) {
+ std::vector<tensorflow::StringPiece> v(num_values);
+ for (int i = 0; i < num_values; ++i) {
+ v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
+ lengths[i]);
}
-TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*)
-TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
-#undef TFE_OP_SET_ATTR_LIST
+ op->operation.MutableAttrs()->Set(attr_name, v);
+}
+
+void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
+ const float* values, int num_values) {
+ op->operation.MutableAttrs()->Set(
+ attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));
+}
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values) {
@@ -675,9 +685,11 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,
const char* attr_name, TF_Status* status) {
switch (default_value.value_case()) {
- case tensorflow::AttrValue::kS:
- TFE_OpSetAttrString(op, attr_name, default_value.s().data());
+ case tensorflow::AttrValue::kS: {
+ const string& v = default_value.s();
+ TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
break;
+ }
case tensorflow::AttrValue::kI:
TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
break;
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 1862af3ce2..fdbd5374b2 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -278,7 +278,8 @@ TF_CAPI_EXPORT extern TF_AttrType TFE_OpNameGetAttrType(
TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op,
const char* attr_name,
- const char* value);
+ const void* value,
+ size_t length);
TF_CAPI_EXPORT extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name,
int64_t value);
TF_CAPI_EXPORT extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name,
@@ -305,7 +306,8 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op,
TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op,
const char* attr_name,
- const char** value,
+ const void* const* values,
+ const size_t* lengths,
int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrIntList(TFE_Op* op,
const char* attr_name,
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 1d71a78b75..cd035940ff 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -1162,8 +1162,8 @@ TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op, "shape", {}, 0, status);
- TFE_OpSetAttrString(op, "container", "");
- TFE_OpSetAttrString(op, "shared_name", "");
+ TFE_OpSetAttrString(op, "container", "", 0);
+ TFE_OpSetAttrString(op, "shared_name", "", 0);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;