aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/utils/grappler_test.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/utils/grappler_test.h')
-rw-r--r--tensorflow/core/grappler/utils/grappler_test.h9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h
index 3bc7bea454..e1394b9c35 100644
--- a/tensorflow/core/grappler/utils/grappler_test.h
+++ b/tensorflow/core/grappler/utils/grappler_test.h
@@ -57,6 +57,15 @@ class GrapplerTest : public ::testing::Test {
// Count nodes of the given op-type in a graph.
int CountOpNodes(const GraphDef& graph, const string& op);
+ // Get a random tansor with given shape.
+ template <DataType DTYPE>
+ Tensor GenerateRandomTensor(const TensorShape& shape) const {
+ typedef typename EnumToDataType<DTYPE>::Type T;
+ Tensor tensor(DTYPE, shape);
+ tensor.flat<T>() = tensor.flat<T>().random();
+ return tensor;
+ }
+
private:
SessionOptions options_;
};