diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-08-24 13:09:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-24 13:17:23 -0700 |
commit | 105dd2aa22c7aaac401d2d8ed3131b5d81fdc227 (patch) | |
tree | 4f92b21249d85965fb93bf6b15f517b875e329ba /tensorflow/core/framework/variant_op_copy_test.cc | |
parent | 9380e3a8327e801bee5cf3ab34b98dcbd3c72bf7 (diff) |
Implement Variant changes suggested by Manjunath Kudlur; simplify API.
1. Remove the self-mutating MaybeDecodeAndGet<>() in favor of the simple,
original get<>().
2. Remove mutexes and locks from everything.
To allow tensorflow RPC to continue working transparently:
3. Add a registration mechanism for Variant decoding -- used by
variant_coding within TensorFlow.
4. Add a MaybeDecodeAndCopy<>() const; for use by callers of Session::Run
that get Tensors back which may contain encoded Variants.
5. Small bugfixes to Variant Encode impls and tensor FromProtoField and
variant_coding.
PiperOrigin-RevId: 166383211
Diffstat (limited to 'tensorflow/core/framework/variant_op_copy_test.cc')
-rw-r--r-- | tensorflow/core/framework/variant_op_copy_test.cc | 15 |
1 files changed, 7 insertions, 8 deletions
diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc index ff3415cb50..f02c572681 100644 --- a/tensorflow/core/framework/variant_op_copy_test.cc +++ b/tensorflow/core/framework/variant_op_copy_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -48,6 +49,8 @@ struct StoredTensorValue { } }; +REGISTER_UNARY_VARIANT_DECODE_FUNCTION(StoredTensorValue, "StoredTensorValue"); + REGISTER_OP("CreateTestVariant") .Output("output: variant") .SetShapeFn(shape_inference::UnknownShape); @@ -112,16 +115,12 @@ TEST(VariantOpCopyTest, CreateCopyCPUToCPU) { std::vector<Tensor> outputs; TF_EXPECT_OK(session.Run({create_op, identity}, &outputs)); EXPECT_EQ(2, outputs.size()); - Variant& r1 = outputs[1].scalar<Variant>()(); + const Variant& r1 = outputs[1].scalar<Variant>()(); - // Accessing the object more than once does not change its behavior (even - // though the first access performed a deserialization step). - EXPECT_EQ("StoredTensorValue", r1.TypeName()); - EXPECT_EQ(42, CHECK_NOTNULL(r1.MaybeDecodeAndGet<StoredTensorValue>()) - ->stored.scalar<int32>()()); EXPECT_EQ("StoredTensorValue", r1.TypeName()); - EXPECT_EQ(42, CHECK_NOTNULL(r1.MaybeDecodeAndGet<StoredTensorValue>()) - ->stored.scalar<int32>()()); + const StoredTensorValue* v1 = r1.get<StoredTensorValue>(); + EXPECT_NE(v1, nullptr); + EXPECT_EQ(42, v1->stored.scalar<int32>()()); } } // end namespace tensorflow |