aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/variant.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-23 12:18:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-23 12:21:07 -0700
commitd2309fe5895ba431a4bb11e79564d12028cc93d5 (patch)
tree7db5c2464c4b2514157dbc3f79a7e945c3709704 /tensorflow/core/framework/variant.cc
parentf6e5089c41fc234ca19fabe2e529ee877a09a33d (diff)
Introduce Encoder and Decoder classes so that platform/*coding* doesn't have to
depend on framework/resource_handler and framework/variant. PiperOrigin-RevId: 197768387
Diffstat (limited to 'tensorflow/core/framework/variant.cc')
-rw-r--r--tensorflow/core/framework/variant.cc33
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc
index 6ad2fafee7..5a507804b0 100644
--- a/tensorflow/core/framework/variant.cc
+++ b/tensorflow/core/framework/variant.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -73,4 +74,36 @@ bool DecodeVariant(const string& buf, VariantTensorDataProto* value) {
return value->ParseFromString(buf);
}
+void EncodeVariantList(const Variant* variant_array, int64 n,
+ std::unique_ptr<port::StringListEncoder> e) {
+ for (int i = 0; i < n; ++i) {
+ string s;
+ variant_array[i].Encode(&s);
+ e->Append(s);
+ }
+ e->Finalize();
+}
+
+bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
+ Variant* variant_array, int64 n) {
+ std::vector<uint32> sizes(n);
+ if (!d->ReadSizes(&sizes)) return false;
+
+ for (int i = 0; i < n; ++i) {
+ if (variant_array[i].is_empty()) {
+ variant_array[i] = VariantTensorDataProto();
+ }
+ string str(d->Data(sizes[i]), sizes[i]);
+ if (!variant_array[i].Decode(str)) return false;
+ if (!DecodeUnaryVariant(&variant_array[i])) {
+ LOG(ERROR) << "Could not decode variant with type_name: \""
+ << variant_array[i].TypeName()
+ << "\". Perhaps you forgot to register a "
+ "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
+ return false;
+ }
+ }
+ return true;
+}
+
} // end namespace tensorflow