#ifndef MLMODEL_HPP
#define MLMODEL_HPP

#include <cstdint>
#include <cstdlib>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/script.h>
#include <torch/torch.h>

#ifndef DISABLE_GRAPH
#include <torchscatter/scatter.h>
#include <torchsparse/sparse.h>
#endif

// Map common C++ scalar types to the corresponding torch dtype.
template<typename T>
torch::Dtype getTorchDtype()
{
  if (std::is_same<T, int>::value) return torch::kInt32;
  if (std::is_same<T, std::int64_t>::value) return torch::kInt64;
  if (std::is_same<T, float>::value) return torch::kFloat32;
  if (std::is_same<T, double>::value) return torch::kFloat64;
  throw std::runtime_error("Invalid datatype provided as input to the model");
}

/* Abstract base class for an ML model -- the 'product' of the factory
 * pattern. */
class MLModel
{
 public:
  static MLModel * create(std::string & /*model_file_path*/,
                          std::string & /*device_name*/,
                          int /*model_input_size*/);

  virtual void SetInputNode(int /*model_input_index*/,
                            int * /*input*/,
                            std::vector<std::int64_t> & /*size*/,
                            bool /*requires_grad*/,
                            bool /*clone*/) = 0;
  virtual void SetInputNode(int /*model_input_index*/,
                            std::int64_t * /*input*/,
                            std::vector<std::int64_t> & /*size*/,
                            bool /*requires_grad*/,
                            bool /*clone*/) = 0;
  virtual void SetInputNode(int /*model_input_index*/,
                            double * /*input*/,
                            std::vector<std::int64_t> & /*size*/,
                            bool /*requires_grad*/,
                            bool /*clone*/) = 0;

  virtual void Run(double * /*energy*/,
                   double * /*partial_energy*/,
                   double * /*forces*/,
                   bool /*backprop*/) = 0;

  virtual void WriteMLModel(std::string & /*model_path*/) = 0;

  virtual ~MLModel() = default;
};

// Concrete MLModel corresponding to PyTorch (TorchScript).
class PytorchModel : public MLModel
{
 private:
  torch::jit::script::Module module_;
  std::vector<torch::jit::IValue> model_inputs_;
  std::unique_ptr<torch::Device> device_;

  void SetExecutionDevice(std::string & /*device_name*/);

  // Index of the input tensor that requires grad; -1 until one is set.
  int grad_idx = -1;

  template<typename T>
  void SetInputNodeTemplate(int idx,
                            T * data,
                            std::vector<std::int64_t> & shape,
                            bool requires_grad,
                            bool clone)
  {
    // Configure tensor options
    torch::TensorOptions options = torch::TensorOptions()
                                       .dtype(getTorchDtype<T>())
                                       .requires_grad(requires_grad);

    // Wrap the caller's buffer without copying
    torch::Tensor input_tensor = torch::from_blob(data, shape, options);

    // Explicit copy to the execution device if not there already
    if (input_tensor.device() != *device_)
      input_tensor = input_tensor.to(*device_);

    // from_blob() aliases the caller's memory, so a clone is needed on the
    // CPU to own the data; on other devices the .to() above already
    // triggered an implicit deep copy.
    if (clone || (*device_ == torch::kCPU)) input_tensor = input_tensor.clone();

    // Workaround for a PyTorch issue: retain_grad() ensures .grad is
    // populated even if the tensor is no longer a leaf after the device copy.
    if (requires_grad)
    {
      input_tensor.retain_grad();
      grad_idx = idx;
    }

    model_inputs_[idx] = input_tensor;
  }

 public:
  std::string model_file_path_;

  PytorchModel(std::string & /*model_file_path*/,
               std::string & /*device_name*/,
               int /*input_size*/);

  void SetInputNode(int /*model_input_index*/,
                    int * /*input*/,
                    std::vector<std::int64_t> & /*size*/,
                    bool /*requires_grad*/,
                    bool /*clone*/) override;
  void SetInputNode(int /*model_input_index*/,
                    std::int64_t * /*input*/,
                    std::vector<std::int64_t> & /*size*/,
                    bool /*requires_grad*/,
                    bool /*clone*/) override;
  void SetInputNode(int /*model_input_index*/,
                    double * /*input*/,
                    std::vector<std::int64_t> & /*size*/,
                    bool /*requires_grad*/,
                    bool /*clone*/) override;

  void Run(double * /*energy*/,
           double * /*partial_energy*/,
           double * /*forces*/,
           bool /*backprop*/) override;

  void WriteMLModel(std::string & /*path*/) override;

  ~PytorchModel() override = default;
};

#endif /* MLMODEL_HPP */
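// ---------------------------------------------------------------------------
// Usage sketch (illustrative only; not part of the interface above). The
// model file name, device string, number of inputs, and tensor shapes below
// are assumptions chosen for the example, not values this header prescribes:
//
//   std::string model_path = "model.pt";  // hypothetical TorchScript file
//   std::string device = "cpu";           // or "cuda"
//   MLModel * model = MLModel::create(model_path, device,
//                                     3 /* assumed input count */);
//
//   std::vector<std::int64_t> pos_shape = {10, 3};  // hypothetical shape
//   std::vector<double> positions(10 * 3, 0.0);
//   // requires_grad = true so forces can be obtained via backprop;
//   // clone = false since the buffer outlives the call in this sketch.
//   model->SetInputNode(0, positions.data(), pos_shape, true, false);
//   // ... set the remaining inputs (indices 1 and 2) similarly ...
//
//   double energy = 0.0;
//   std::vector<double> partial_energy(10), forces(10 * 3);
//   model->Run(&energy, partial_energy.data(), forces.data(), true);
//   delete model;  // safe: MLModel has a virtual destructor
// ---------------------------------------------------------------------------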