aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/variant_op_copy_test.cc
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-08-24 13:09:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-24 13:17:23 -0700
commit105dd2aa22c7aaac401d2d8ed3131b5d81fdc227 (patch)
tree4f92b21249d85965fb93bf6b15f517b875e329ba /tensorflow/core/framework/variant_op_copy_test.cc
parent9380e3a8327e801bee5cf3ab34b98dcbd3c72bf7 (diff)
Implement Variant changes suggested by Manjunath Kudlur; simplify API.
1. Remove the self-mutating MaybeDecodeAndGet<>() in favor of the simple, original get<>(). 2. Remove mutexes and locks from everything. To allow tensorflow RPC to continue working transparently: 3. Add a registration mechanism for Variant decoding -- used by variant_coding within TensorFlow. 4. Add a MaybeDecodeAndCopy<>() const; for use by callers of Session::Run that get Tensors back which may contain encoded Variants. 5. Small bugfixes to Variant Encode impls and tensor FromProtoField and variant_coding. PiperOrigin-RevId: 166383211
Diffstat (limited to 'tensorflow/core/framework/variant_op_copy_test.cc')
-rw-r--r--tensorflow/core/framework/variant_op_copy_test.cc15
1 files changed, 7 insertions, 8 deletions
diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc
index ff3415cb50..f02c572681 100644
--- a/tensorflow/core/framework/variant_op_copy_test.cc
+++ b/tensorflow/core/framework/variant_op_copy_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -48,6 +49,8 @@ struct StoredTensorValue {
}
};
+REGISTER_UNARY_VARIANT_DECODE_FUNCTION(StoredTensorValue, "StoredTensorValue");
+
REGISTER_OP("CreateTestVariant")
.Output("output: variant")
.SetShapeFn(shape_inference::UnknownShape);
@@ -112,16 +115,12 @@ TEST(VariantOpCopyTest, CreateCopyCPUToCPU) {
std::vector<Tensor> outputs;
TF_EXPECT_OK(session.Run({create_op, identity}, &outputs));
EXPECT_EQ(2, outputs.size());
- Variant& r1 = outputs[1].scalar<Variant>()();
+ const Variant& r1 = outputs[1].scalar<Variant>()();
- // Accessing the object more than once does not change its behavior (even
- // though the first access performed a deserialization step).
- EXPECT_EQ("StoredTensorValue", r1.TypeName());
- EXPECT_EQ(42, CHECK_NOTNULL(r1.MaybeDecodeAndGet<StoredTensorValue>())
- ->stored.scalar<int32>()());
EXPECT_EQ("StoredTensorValue", r1.TypeName());
- EXPECT_EQ(42, CHECK_NOTNULL(r1.MaybeDecodeAndGet<StoredTensorValue>())
- ->stored.scalar<int32>()());
+ const StoredTensorValue* v1 = r1.get<StoredTensorValue>();
+ EXPECT_NE(v1, nullptr);
+ EXPECT_EQ(42, v1->stored.scalar<int32>()());
}
} // end namespace tensorflow