diff options
author | 2018-05-23 12:18:23 -0700 | |
---|---|---|
committer | 2018-05-23 12:21:07 -0700 | |
commit | d2309fe5895ba431a4bb11e79564d12028cc93d5 (patch) | |
tree | 7db5c2464c4b2514157dbc3f79a7e945c3709704 /tensorflow/core/framework/variant.cc | |
parent | f6e5089c41fc234ca19fabe2e529ee877a09a33d (diff) |
Introduce Encoder and Decoder classes so that platform/*coding* doesn't have to
depend on framework/resource_handler and framework/variant.
PiperOrigin-RevId: 197768387
Diffstat (limited to 'tensorflow/core/framework/variant.cc')
-rw-r--r-- | tensorflow/core/framework/variant.cc | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc index 6ad2fafee7..5a507804b0 100644 --- a/tensorflow/core/framework/variant.cc +++ b/tensorflow/core/framework/variant.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -73,4 +74,36 @@ bool DecodeVariant(const string& buf, VariantTensorDataProto* value) { return value->ParseFromString(buf); } +void EncodeVariantList(const Variant* variant_array, int64 n, + std::unique_ptr<port::StringListEncoder> e) { + for (int i = 0; i < n; ++i) { + string s; + variant_array[i].Encode(&s); + e->Append(s); + } + e->Finalize(); +} + +bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d, + Variant* variant_array, int64 n) { + std::vector<uint32> sizes(n); + if (!d->ReadSizes(&sizes)) return false; + + for (int i = 0; i < n; ++i) { + if (variant_array[i].is_empty()) { + variant_array[i] = VariantTensorDataProto(); + } + string str(d->Data(sizes[i]), sizes[i]); + if (!variant_array[i].Decode(str)) return false; + if (!DecodeUnaryVariant(&variant_array[i])) { + LOG(ERROR) << "Could not decode variant with type_name: \"" + << variant_array[i].TypeName() + << "\". Perhaps you forgot to register a " + "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?"; + return false; + } + } + return true; +} + } // end namespace tensorflow |