// // Created by amit on 7/12/22. // #ifndef TORCH_ML_MODEL_DRIVER_IMPLEMENTATION_HPP #define TORCH_ML_MODEL_DRIVER_IMPLEMENTATION_HPP #include "KIM_ModelDriverHeaders.hpp" #include "MLModel.hpp" #include <memory> #ifdef USE_LIBDESC #include "Descriptors.hpp" using namespace Descriptor; #endif class TorchMLModelDriverImplementation { public: // All file params are public double influence_distance, cutoff_distance; int n_elements, n_layers; std::vector<std::string> elements_list; std::string preprocessing; std::string model_name; bool returns_forces; std::string descriptor_name = "None"; std::string descriptor_param_file = "None"; std::string descriptor_param_file_content; std::string fully_qualified_model_name; TorchMLModelDriverImplementation( KIM::ModelDriverCreate * modelDriverCreate, KIM::LengthUnit requestedLengthUnit, KIM::EnergyUnit requestedEnergyUnit, KIM::ChargeUnit requestedChargeUnit, KIM::TemperatureUnit requestedTemperatureUnit, KIM::TimeUnit requestedTimeUnit, int * ier); ~TorchMLModelDriverImplementation(); int Refresh(KIM::ModelRefresh * modelRefresh); int Refresh(KIM::ModelDriverCreate * modelRefresh); int Compute(KIM::ModelComputeArguments const * modelComputeArguments); int ComputeArgumentsCreate( KIM::ModelComputeArgumentsCreate * modelComputeArgumentsCreate); int ComputeArgumentsDestroy( KIM::ModelComputeArgumentsDestroy * modelComputeArgumentsDestroy); int WriteParameterizedModel(KIM::ModelWriteParameterizedModel const * const modelWriteParameterizedModel) const; private: // Derived or assigned variables are private int modelWillNotRequestNeighborsOfNoncontributingParticles_; int n_contributing_atoms; int number_of_inputs; std::vector<std::int64_t> species_atomic_number; std::vector<std::int64_t> contraction_array; std::unique_ptr<MLModel> ml_model; #ifdef USE_LIBDESC AvailableDescriptor descriptor_kind; std::unique_ptr<DescriptorKind> descriptor; #endif std::vector<int> num_neighbors_; std::vector<int> neighbor_list; std::vector<int> z_map; 
std::vector<double> descriptor_array; std::vector<std::vector<std::int64_t> > graph_edge_indices; void updateNeighborList(KIM::ModelComputeArguments const * modelComputeArguments); void setDefaultInputs(const KIM::ModelComputeArguments * modelComputeArguments); void setDescriptorInputs(const KIM::ModelComputeArguments * modelComputeArguments); void setGraphInputs(const KIM::ModelComputeArguments * modelComputeArguments); void readParametersFile(KIM::ModelDriverCreate * modelDriverCreate, int * ier); static void unitConversion(KIM::ModelDriverCreate * modelDriverCreate, KIM::LengthUnit requestedLengthUnit, KIM::EnergyUnit requestedEnergyUnit, KIM::ChargeUnit requestedChargeUnit, KIM::TemperatureUnit requestedTemperatureUnit, KIM::TimeUnit requestedTimeUnit, int * ier); void setSpecies(KIM::ModelDriverCreate * modelDriverCreate, int * ier); static void registerFunctionPointers(KIM::ModelDriverCreate * modelDriverCreate, int * ier); void preprocessInputs(KIM::ModelComputeArguments const * modelComputeArguments); void postprocessOutputs(KIM::ModelComputeArguments const *); void Run(KIM::ModelComputeArguments const * modelComputeArguments); void contributingAtomCounts( KIM::ModelComputeArguments const * modelComputeArguments); }; int sym_to_z(std::string &); // For hashing unordered_set of pairs // https://arxiv.org/pdf/2105.10752.pdf // It seems like this might not be the best approach for hashing edges // But surprisingly it is. 
// === Benchmarking (micro s) ===
// N_edgs  Mrgsrt(arr)  Cantor*  BoostHash  32bitPackHash  Cantor,bidirectional
// 10^1    13           7        1          2              7
// 10^2    14           22       25         30             31
// 10^3    272          158      177        195            315
// 10^4    13222        1308     1511       1595           2498
// 10^5    1347793      25009    30921      28328          46013
// 10^6    142392300    422934   489779     466556         548133
// * current method
// Most likely FMA-like instructions make Cantor pairing on par with bitwise
// hashes.
//

// Guarded on its own so this definition is protected against repeated
// inclusion independently of the rest of the header.
#ifndef TORCH_ML_MODEL_DRIVER_CANTOR_PAIRING_DEFINED
#define TORCH_ML_MODEL_DRIVER_CANTOR_PAIRING_DEFINED

#include <array>
#include <cstdint>

/// Hash functor for an ordered pair of indices, based on the Cantor pairing
/// function:
///
///     pi(k1, k2) = (k1 + k2) * (k1 + k2 + 1) / 2 + k2
///
/// The result depends on the order of the two entries, e.g. {1, 2} -> 8 while
/// {2, 1} -> 7.
class CantorPairing
{
 public:
  /// @param t  the pair {k1, k2} to hash.
  /// @return   pi(k1, k2) as defined above. For inputs where the formula fits
  ///           in int64_t this is the exact mathematical value; for larger
  ///           inputs the value wraps modulo 2^64 instead of overflowing.
  int64_t operator()(const std::array<long, 2> & t) const noexcept
  {
    // The arithmetic is done in unsigned 64-bit integers: signed overflow of
    // sum * (sum + 1) is undefined behaviour, whereas unsigned arithmetic
    // wraps, which is all a hash needs. The product of two consecutive
    // integers is never negative, so the wrapped result equals the signed one
    // whenever the latter is representable.
    const std::uint64_t first = static_cast<std::uint64_t>(t[0]);
    const std::uint64_t second = static_cast<std::uint64_t>(t[1]);

    // Order-independent alternative, kept for reference (presumably the
    // "Cantor,bidirectional" variant in the benchmark table):
    //   int64_t kmin = std::min(k1, k2);
    //   int64_t ksum = k1 + k2 + 1;
    //   return ((ksum * ksum - ksum % 2) + kmin) / 4;

    const std::uint64_t sum = first + second;
    const std::uint64_t triangleNumber = sum * (sum + 1) / 2;

    // Converting back to the signed return type is modular on all mainstream
    // compilers (and guaranteed to be from C++20 onwards).
    return static_cast<int64_t>(triangleNumber + second);
  }
};

#endif  // TORCH_ML_MODEL_DRIVER_CANTOR_PAIRING_DEFINED