diff options
author | Asim Shankar <ashankar@google.com> | 2018-06-20 10:14:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-20 10:16:31 -0700 |
commit | 2b0805301e4531dd7c2ed677d932f6408675460e (patch) | |
tree | 81ba5542d342f80fb56a3b57e2b6babf8ce31abd /tensorflow/c/eager | |
parent | af3455aad7ebf2e70c816e642f90594625e4fd44 (diff) |
[eager]: Support string attributes where the value contains `\0`.
Apparently, some custom operations stuff non-printable characters in string
valued attributes.
This change also makes the eager C API consistent with the C API for graph
construction (TF_SetAttrString and TF_SetAttrStringList).
PiperOrigin-RevId: 201372089
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 38 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api.h | 6 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api_test.cc | 4 |
3 files changed, 31 insertions, 17 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 55d9c26b0d..6e4764bcbf 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -46,6 +46,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -441,8 +442,11 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, return ret; } -void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) { - op->operation.MutableAttrs()->Set(attr_name, value); +void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value, + size_t length) { + op->operation.MutableAttrs()->Set( + attr_name, + tensorflow::StringPiece(static_cast<const char*>(value), length)); } void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) { @@ -493,16 +497,22 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, op->operation.MutableAttrs()->Set(attr_name, attr_value); } -#define TFE_OP_SET_ATTR_LIST(fn, type) \ - void fn(TFE_Op* op, const char* attr_name, const type* values, \ - int num_values) { \ - op->operation.MutableAttrs()->Set( \ - attr_name, \ - tensorflow::gtl::ArraySlice<const type>(values, num_values)); \ +void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, + const void* const* values, const size_t* lengths, + int num_values) { + std::vector<tensorflow::StringPiece> v(num_values); + for (int i = 0; i < num_values; ++i) { + v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]), + lengths[i]); } -TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*) -TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float) -#undef TFE_OP_SET_ATTR_LIST + op->operation.MutableAttrs()->Set(attr_name, v); +} + +void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, + const float* values, int num_values) { + op->operation.MutableAttrs()->Set( + attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values)); +} void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, const int64_t* values, int num_values) { @@ -675,9 +685,11 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, const char* attr_name, TF_Status* status) { switch (default_value.value_case()) { - case tensorflow::AttrValue::kS: - TFE_OpSetAttrString(op, attr_name, default_value.s().data()); + case tensorflow::AttrValue::kS: { + const string& v = default_value.s(); + TFE_OpSetAttrString(op, attr_name, v.data(), v.size()); break; + } case tensorflow::AttrValue::kI: TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i())); break; diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 1862af3ce2..fdbd5374b2 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -278,7 +278,8 @@ TF_CAPI_EXPORT extern TF_AttrType TFE_OpNameGetAttrType( TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, - const char* value); + const void* value, + size_t length); TF_CAPI_EXPORT extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value); TF_CAPI_EXPORT extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, @@ -305,7 +306,8 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op, TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, - const char** value, + const void* const* values, + const size_t* lengths, int num_values); TF_CAPI_EXPORT extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 1d71a78b75..cd035940ff 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1162,8 +1162,8 @@ TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, if (TF_GetCode(status) != TF_OK) return nullptr; TFE_OpSetAttrType(op, "dtype", TF_FLOAT); TFE_OpSetAttrShape(op, "shape", {}, 0, status); - TFE_OpSetAttrString(op, "container", ""); - TFE_OpSetAttrString(op, "shared_name", ""); + TFE_OpSetAttrString(op, "container", "", 0); + TFE_OpSetAttrString(op, "shared_name", "", 0); if (TF_GetCode(status) != TF_OK) return nullptr; TFE_TensorHandle* var_handle = nullptr; int num_retvals = 1; |