From f41959ccb2d9d4c722fe8fc3351401d53bcf4900 Mon Sep 17 00:00:00 2001 From: Manjunath Kudlur Date: Fri, 6 Nov 2015 16:27:58 -0800 Subject: TensorFlow: Initial commit of TensorFlow library. TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108 --- tensorflow/core/kernels/ops_testutil.h | 191 +++++++++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 tensorflow/core/kernels/ops_testutil.h (limited to 'tensorflow/core/kernels/ops_testutil.h') diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h new file mode 100644 index 0000000000..7a3405bf04 --- /dev/null +++ b/tensorflow/core/kernels/ops_testutil.h @@ -0,0 +1,191 @@ +#ifndef TENSORFLOW_KERNELS_OPS_TESTUTIL_H_ +#define TENSORFLOW_KERNELS_OPS_TESTUTIL_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/util/tensor_slice_reader_cache.h" +#include + +namespace tensorflow { + +namespace test { + +// Return a NodeDef with the specified name/op/inputs. +NodeDef Node(const string& name, const string& op, + const std::vector& inputs); + +} // namespace test + +// Helpful functions to test operators. +// +// This class will eventually be replaced / heavily modified +// to use the BrainClient interface. +class OpsTestBase : public ::testing::Test { + public: + OpsTestBase() : device_type_(DEVICE_CPU) { + device_.reset( + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); + CHECK(device_.get()) << "Could not create CPU device"; + } + + ~OpsTestBase() override { + gtl::STLDeleteElements(&tensors_); + context_.reset(nullptr); + } + + void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); } + + // Clients can manipulate the underlying NodeDef via this accessor. + NodeDef* node_def() { return &node_def_; } + + // Initializes an operator that takes in 'input_types' as input + // and output types as output. + // + // Returns the status of initialization. + Status InitOp() { + Status status; + kernel_ = CreateOpKernel(device_type_, device_.get(), allocator(), + node_def_, &status); + if (kernel_ != nullptr) input_types_ = kernel_->input_types(); + return status; + } + + // Adds an input for every element described by the shape. + // 'input_mapping' maps an index (0...NumElements(shape)) to a + // value. + // + // TODO(vrv): Replace with something like a BrainClient Feed. + template + void AddInput(const TensorShape& shape, std::function input_mapping) { + CHECK_GT(input_types_.size(), inputs_.size()) + << "Adding more inputs than types; perhaps you need to call MakeOp"; + bool is_ref = IsRefType(input_types_[inputs_.size()]); + Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()), + DataTypeToEnum::v(), shape); + test::FillFn(input, input_mapping); + tensors_.push_back(input); + if (is_ref) { + CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]), + DataTypeToEnum::v()); + inputs_.push_back({&lock_for_refs_, input}); + } else { + CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum::v()); + inputs_.push_back({nullptr, input}); + } + } + + // Like AddInput but takes in an explicit arrayslice of data. + template + void AddInputFromArray(const TensorShape& shape, + const gtl::ArraySlice& data) { + CHECK_GT(input_types_.size(), inputs_.size()) + << "Adding more inputs than types; perhaps you need to call MakeOp"; + bool is_ref = IsRefType(input_types_[inputs_.size()]); + Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()), + DataTypeToEnum::v(), shape); + test::FillValues(input, data); + tensors_.push_back(input); + if (is_ref) { + CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]), + DataTypeToEnum::v()); + inputs_.push_back({&lock_for_refs_, input}); + } else { + CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum::v()); + inputs_.push_back({nullptr, input}); + } + } + + // Runs an operation producing 'num_outputs' outputs. + // + // Returns the context's status after running the operation. + Status RunOpKernel() { + OpKernelContext::Params params; + params.device = device_.get(); + params.frame_iter = FrameAndIter(0, 0); + params.inputs = &inputs_; + params.op_kernel = kernel_.get(); + params.output_alloc_attr = [this, ¶ms](int index) { + AllocatorAttributes attr; + const bool on_host = + (kernel_->output_memory_types()[index] == HOST_MEMORY); + attr.set_on_host(on_host); + return attr; + }; + checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper; + params.slice_reader_cache = &slice_reader_cache_wrapper; + + context_.reset(new OpKernelContext(params)); + device_->Compute(kernel_.get(), context_.get()); + return context_->status(); + } + + // Returns the tensor input for 'input_index'. + // + // REQUIRES: 0 <= input_index < context_->num_inputs() + const Tensor& GetInput(int input_index) const { + CHECK_LT(input_index, context_->num_inputs()); + CHECK(!IsRefType(context_->input_dtype(input_index))); + return context_->input(input_index); + } + + TensorValue mutable_input(int input_index) { + CHECK_LT(input_index, inputs_.size()); + return inputs_[input_index]; + } + // Returns the tensor output for 'output_index'. + // + // REQUIRES: 0 <= output_index < context_->num_outputs() + Tensor* GetOutput(int output_index) { + CHECK_LT(output_index, context_->num_outputs()); + return context_->mutable_output(output_index); + } + + Allocator* allocator() { + return device_->GetAllocator(AllocatorAttributes()); + } + + const DataTypeVector& output_types() const { return kernel_->output_types(); } + + protected: + std::unique_ptr device_; + + std::unique_ptr kernel_; + NodeDef node_def_; + DataTypeVector input_types_; + DeviceType device_type_; + + mutex lock_for_refs_; // Used as the Mutex for inputs added as refs + + gtl::InlinedVector inputs_; + // Owns Tensors. + std::vector tensors_; + + std::unique_ptr context_; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_OPS_TESTUTIL_H_ -- cgit v1.2.3