diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/user_computation.h')
-rw-r--r-- | tensorflow/compiler/xla/service/user_computation.h | 336 |
1 file changed, 336 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h new file mode 100644 index 0000000000..06824b01c7 --- /dev/null +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -0,0 +1,336 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ + +#include <functional> +#include <map> +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// A UserComputation is the built-up computation that users create via the +// XLA Service interface. +// +// The XLA service adds instructions to a user computation via this +// interface. 
The state of the computation is stored as a SessionComputation +// proto which holds a record of all operation-building requests received by the +// XLA service. +// +// UserComputations are lowered to HloComputations which are passed to the high +// level compiler interface. +class UserComputation { + public: + // Factory used when restoring a computation from serialized session + // computation (computation snapshot) data. Remaps any references to + // computation handle via the old_to_new mapping. + // + // An error will occur if the old_to_new mapping cannot resolve a reference to + // a computation that is present in session_computation. + static StatusOr<std::unique_ptr<UserComputation>> MakeWithRemapping( + const SessionComputation& session_computation, + const ComputationHandle& handle, + const std::map<int64, ComputationHandle>& old_to_new); + + // Creates an empty computation with the given name and computation handle. + explicit UserComputation(const string& name, const ComputationHandle& handle); + + // Enqueues a parameter-retrieving instruction onto this user computation. + // Returns an error status if the parameter number is already registered with + // different values. + StatusOr<ComputationDataHandle> AddParameterInstruction( + const ParameterRequest& parameter_request); + + // Enqueues a pad instruction onto this user computation. + StatusOr<ComputationDataHandle> AddPadInstruction( + const PadRequest& parameter_request); + + // Enqueues a tracing instruction onto this user computation. + // Returns an error status if the operand cannot be resolved. + Status AddTraceInstruction(const TraceRequest& trace_request); + + // Enqueues a random number generation instruction onto this user computation. + StatusOr<ComputationDataHandle> AddRngInstruction( + const RngRequest& rng_request); + + // Enqueues a unary instruction onto this user computation. + // Returns an error status if the operand index is out of bounds. 
+ StatusOr<ComputationDataHandle> AddUnaryInstruction( + const UnaryOpRequest& unary_request); + + // Enqueues a binary instruction onto this user computation. + // Returns an error status if the operand indices are out of bounds. + StatusOr<ComputationDataHandle> AddBinaryInstruction( + const BinaryOpRequest& binary_request); + + // Enqueues a ternary instruction onto this user computation. + // Returns an error status if the operand indices are out of bounds. + StatusOr<ComputationDataHandle> AddTernaryInstruction( + const TernaryOpRequest& request); + + // Enqueues a variadic instruction onto this user computation. + // Returns an error status if the operand indices are out of bounds. + StatusOr<ComputationDataHandle> AddVariadicInstruction( + const VariadicOpRequest& variadic_request); + + // Enqueues a constant instruction onto this user computation. + StatusOr<ComputationDataHandle> AddConstantInstruction( + const ConstantRequest& constant_request); + + // Enqueues a get tuple element instruction onto this user computation. + StatusOr<ComputationDataHandle> AddGetTupleElementInstruction( + const GetTupleElementRequest& get_tuple_element_request); + + // Enqueues a map instruction onto this user computation. + StatusOr<ComputationDataHandle> AddMapInstruction( + const MapRequest& map_request, + const UserComputation& to_apply_computation); + + // Enqueues a convolution instruction onto this user computation. + StatusOr<ComputationDataHandle> AddConvolveInstruction( + const ConvolveRequest& convolve_request); + + // Enqueues a cross replica sum instruction onto this user computation. + StatusOr<ComputationDataHandle> AddCrossReplicaSumInstruction( + const CrossReplicaSumRequest& cross_replica_sum_request); + + // Enqueues an infeed instruction onto this user computation. + StatusOr<ComputationDataHandle> AddInfeedInstruction( + const InfeedRequest& infeed_request); + + // Enqueues a call instruction onto this user computation. 
+ StatusOr<ComputationDataHandle> AddCallInstruction( + const CallRequest& call_request, + const UserComputation& to_apply_computation); + + // Enqueues a custom call instruction onto this user computation. + StatusOr<ComputationDataHandle> AddCustomCallInstruction( + const CustomCallRequest& custom_call_request); + + // Enqueues a broadcast instruction onto this user computation. + StatusOr<ComputationDataHandle> AddBroadcastInstruction( + const BroadcastRequest& broadcast_request); + + // Enqueues a reshape instruction onto this user computation. + StatusOr<ComputationDataHandle> AddReshapeInstruction( + const ReshapeRequest& reshape_request); + + // Enqueues a slice instruction onto this user computation. + StatusOr<ComputationDataHandle> AddSliceInstruction( + const SliceRequest& slice_request); + + // Enqueues a dynamic slice instruction onto this user computation. + StatusOr<ComputationDataHandle> AddDynamicSliceInstruction( + const DynamicSliceRequest& dynamic_slice_request); + + // Enqueues a dynamic update slice instruction onto this user computation. + StatusOr<ComputationDataHandle> AddDynamicUpdateSliceInstruction( + const DynamicUpdateSliceRequest& dynamic_update_slice_request); + + // Enqueues a concatenate instruction onto this user computation. + StatusOr<ComputationDataHandle> AddConcatenateInstruction( + const ConcatenateRequest& slice_request); + + // Enqueues a convert instruction onto this user computation. + StatusOr<ComputationDataHandle> AddConvertInstruction( + const ConvertRequest& convert_request); + + // Enqueues a reduce instruction onto this user computation. + StatusOr<ComputationDataHandle> AddReduceInstruction( + const ReduceRequest& reduce_request, + const UserComputation& reduction_computation); + + // Enqueues a windowed reduce instruction onto this user computation. 
+ StatusOr<ComputationDataHandle> AddReduceWindowInstruction( + const ReduceWindowRequest& reduce_window_request, + const UserComputation& reduction_computation); + + // Enqueues a select-and-scatter instruction onto this user + // computation. + StatusOr<ComputationDataHandle> AddSelectAndScatterInstruction( + const SelectAndScatterRequest& scatter_to_selected_window_element_request, + const UserComputation& select_computation, + const UserComputation& scatter_computation); + + // Enqueues a reverse instruction onto this user computation. + StatusOr<ComputationDataHandle> AddReverseInstruction( + const ReverseRequest& reverse_request); + + // Enqueues a while instruction onto this user computation. + StatusOr<ComputationDataHandle> AddWhileInstruction( + const WhileRequest& while_request, + const UserComputation& condition_computation, + const UserComputation& body_computation); + + // Enqueues a Send instruction onto this user computation. + Status AddSendInstruction(const SendRequest& send_request); + + // Enqueues a Recv instruction onto this user computation. + StatusOr<ComputationDataHandle> AddRecvInstruction( + const RecvRequest& recv_request); + + // Returns the user-provided name of this user computation, which is provided + // via the XLA computation-building API. + const string& name() const { return name_; } + + // Subsequent executions of this computation will compute the value + // represented by handle, rather than the last expression enqueued + // on the computation. + Status SetReturnValue(const ComputationDataHandle& handle); + + // Return a versioned handle for this computation. + VersionedComputationHandle GetVersionedHandle() const; + + // Return a versioned handle for this computation with a version equal to the + // point at which given operation was added to the computation. 
+ VersionedComputationHandle GetVersionedHandleAtOperation( + const ComputationDataHandle& operation) const; + + // Return a version value representing the current state of the + // computation. + VersionedComputationHandle::Version version() const; + + // Computes and returns the program shape for the user computation -- gathers + // parameters and result type into a single proto. A shared_ptr is used + // because the returned pointer refers to an internally cached value which may + // be discarded by the UserComputation object. This avoid unnecessary copies. + // + // If the parameter space is not dense (i.e. there are holes in the parameter + // numbers provided) then an error status is returned. + StatusOr<std::shared_ptr<const ProgramShape>> ComputeProgramShape( + VersionedComputationHandle::Version version) const; + + // Returns true if the given data handle does not depend on any + // parameters. That is, the value can be computed at compile time. + StatusOr<bool> IsConstant(const ComputationDataHandle& handle); + + // Returns the output shape of the operation indicated by the given handle. + StatusOr<Shape> GetShape(const ComputationDataHandle& handle); + + // Builds a HLO computation from the UserComputation. The parameter "resolver" + // is a function which returns a pointer to the HloComputation corresponding + // to the given ComputationHandle at the given version. The resolver is used + // for operations, such as map, which call other computations and need a + // pointer to the called HloComputation to construct the respective HLO + // instructions. If include_unused_computation is true, then all parameter + // instructions are lowered into HloInstructions even if the parameter is + // unused (the root of the computation is unreachable from the parameter). 
+ using HloComputationResolver = + std::function<HloComputation*(const VersionedComputationHandle& handle)>; + StatusOr<std::unique_ptr<HloComputation>> BuildHloComputation( + VersionedComputationHandle::Version version, + HloComputationResolver hlo_resolver, + bool include_unused_parameters = true) const; + + // Return a vector containing the embedded computations used by this + // UserComputation. Only embedded computations which are called directly by + // this UserComputation are included. That is, the transitive closure of + // embedded computations is not included. + std::vector<VersionedComputationHandle> GetEmbeddedComputations( + VersionedComputationHandle::Version version) const; + + // Returns the number of OperationRequest objects in this UserComputation. + // The 'version' of a computation is identical to the number of + // OperationRequests in the UserComputation. + int64 request_count(VersionedComputationHandle::Version version) const { + return version; + } + + // Returns a copy of the internal session state for this computation -- this + // is useful for serializing the guts of a user computation, though references + // to other handles (e.g. referred-to computations) must be handled with care + // in the serialization / de-serialization process. + SessionComputation CloneSessionComputation( + VersionedComputationHandle::Version version) const; + + private: + // Warning: dangerous mutating operation that doesn't respect versioning. + // This is only used at initialization time when constructing from a + // SessionComputation a la MakeWithRemapping. + // + // Remaps references to old computations (with handle values in the keys of + // old_to_new) to the computation handle given in the values. This is useful + // when loading computations from snapshots, to finish initialization, before + // the user computation is released into the wild. 
+ Status RemapEmbeddedComputations( + const std::map<int64, ComputationHandle>& old_to_new) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Returns the OperationRequestion corresponding to the root (result) of the + // computation. + const OperationRequest& GetRoot(VersionedComputationHandle::Version version) + const EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Returns the OperationRequest corresponding to the given handle value. + StatusOr<const OperationRequest*> LookupRequest( + const ComputationDataHandle& handle) const + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Creates a new ComputationDataHandle with the next available handle value. + ComputationDataHandle CreateComputationDataHandle() + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Checks whether the parameter numbers of the parameter operations are + // contiguous starting from zero. Returns appropriate error status if not. + Status CheckParametersAreContiguous( + VersionedComputationHandle::Version version) const + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Name of the computation. + string name_; + + mutable tensorflow::mutex mutex_; + + // State of the computation as a record of all operation-building requests. + SessionComputation session_computation_ GUARDED_BY(mutex_); + + // Mapping from parameter number to operation request containing the + // respective ParameterRequest. + std::map<int64, OperationRequest*> parameters_ GUARDED_BY(mutex_); + + // The next ComputationDataHandle value to assign. Handle values are assigned + // sequentially. + int64 next_handle_value_ GUARDED_BY(mutex_); + + // If handle_to_return_.has_handle() then an Execution of this Computation + // will compute the value represented by handle_to_return_, otherwise it will + // compute the value of (next_handle_value_ - 1). + ComputationDataHandle handle_to_return_ GUARDED_BY(mutex_); + + // Memoized ProgramShape and its version. A shared_ptr is used because + // references to this object are returned by ComputeProgramShape. 
+ mutable int64 program_shape_version_ GUARDED_BY(mutex_) = 0; + mutable std::shared_ptr<const ProgramShape> program_shape_ GUARDED_BY(mutex_); + + TF_DISALLOW_COPY_AND_ASSIGN(UserComputation); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ |