#include "TorchMLModelDriver.hpp"
#include "TorchMLModelDriverImplementation.hpp"

#include <exception>
#include <memory>
#include <string>

//==============================================================================
//
// This is the standard interface to KIM Model Drivers
//
//==============================================================================

//******************************************************************************
extern "C" {
/// Entry point called by the KIM API to create a TorchMLModelDriver.
///
/// Builds the driver object (whose constructor forwards every argument to
/// TorchMLModelDriverImplementation) and, on success, registers it as the KIM
/// model buffer so the static callbacks below can retrieve it.
///
/// @return false (0) on success. On failure, either the nonzero `ier` reported
///         by the constructor, or `true` if construction threw an exception
///         (the exception message is logged through `modelDriverCreate`).
int model_driver_create(KIM::ModelDriverCreate * const modelDriverCreate,
                        KIM::LengthUnit const requestedLengthUnit,
                        KIM::EnergyUnit const requestedEnergyUnit,
                        KIM::ChargeUnit const requestedChargeUnit,
                        KIM::TemperatureUnit const requestedTemperatureUnit,
                        KIM::TimeUnit const requestedTimeUnit)
{
  // Initialized so a constructor path that never writes *ier cannot leave an
  // indeterminate value to be read below.
  int ier = false;

  // This is a C-linkage entry point whose contract is an integer error code:
  // exceptions (std::bad_alloc, or anything thrown while the implementation is
  // being set up) are converted into a logged error instead of escaping.
  std::unique_ptr<TorchMLModelDriver> modelObject;
  try
  {
    modelObject.reset(new TorchMLModelDriver(modelDriverCreate,
                                             requestedLengthUnit,
                                             requestedEnergyUnit,
                                             requestedChargeUnit,
                                             requestedTemperatureUnit,
                                             requestedTimeUnit,
                                             &ier));
  }
  catch (std::exception const & e)
  {
    modelDriverCreate->LogEntry(
        KIM::LOG_VERBOSITY::error,
        std::string("Unable to create TorchMLModelDriver: ") + e.what(),
        __LINE__,
        __FILE__);
    return true;
  }
  catch (...)
  {
    modelDriverCreate->LogEntry(
        KIM::LOG_VERBOSITY::error,
        std::string("Unable to create TorchMLModelDriver: unknown exception"),
        __LINE__,
        __FILE__);
    return true;
  }

  if (ier)
  {
    // The constructor already reported the error; unique_ptr frees the object.
    return ier;
  }

  // Hand ownership to the KIM API; it is reclaimed by TorchMLModelDriver::Destroy.
  modelDriverCreate->SetModelBufferPointer(
      static_cast<void *>(modelObject.release()));
  return false;
}
}  // extern "C"

//==============================================================================
//
// Implementation of TorchMLModelDriver public wrapper functions
//
//==============================================================================

namespace
{
// Fetch the TorchMLModelDriver registered by model_driver_create from any KIM
// object exposing GetModelBufferPointer(void **).
//
// Reading into a genuine `void *` and converting with static_cast is the exact
// inverse of the static_cast<void *> used when the pointer was stored, and
// avoids writing a TorchMLModelDriver* object through a `void *` lvalue (which
// is what reinterpret_cast<void **>(&typedPointer) does).
template <class KimObject>
TorchMLModelDriver * GetModelObject(KimObject * const kimObject)
{
  void * buffer = nullptr;
  kimObject->GetModelBufferPointer(&buffer);
  return static_cast<TorchMLModelDriver *>(buffer);
}
}  // namespace

// *****************************************************************************
/// Create the wrapper by constructing its TorchMLModelDriverImplementation
/// with the same arguments; the implementation reports status through `ier`.
TorchMLModelDriver::TorchMLModelDriver(
    KIM::ModelDriverCreate * const modelDriverCreate,
    KIM::LengthUnit const requestedLengthUnit,
    KIM::EnergyUnit const requestedEnergyUnit,
    KIM::ChargeUnit const requestedChargeUnit,
    KIM::TemperatureUnit const requestedTemperatureUnit,
    KIM::TimeUnit const requestedTimeUnit,
    int * const ier) :
    implementation_(std::make_unique<TorchMLModelDriverImplementation>(
        modelDriverCreate,
        requestedLengthUnit,
        requestedEnergyUnit,
        requestedChargeUnit,
        requestedTemperatureUnit,
        requestedTimeUnit,
        ier))
{
}

// *****************************************************************************
// Defined here, where TorchMLModelDriverImplementation is a complete type.
TorchMLModelDriver::~TorchMLModelDriver() = default;

//******************************************************************************
// static member function
/// Delete the object that model_driver_create registered as the model buffer.
int TorchMLModelDriver::Destroy(KIM::ModelDestroy * const modelDestroy)
{
  delete GetModelObject(modelDestroy);
  return false;
}

//******************************************************************************
// static member function
/// Forward a KIM refresh request to the implementation.
int TorchMLModelDriver::Refresh(KIM::ModelRefresh * const modelRefresh)
{
  TorchMLModelDriver * const modelObject = GetModelObject(modelRefresh);
  return modelObject->implementation_->Refresh(modelRefresh);
}

//******************************************************************************
// static member function
/// Forward a KIM compute request (with its arguments) to the implementation.
int TorchMLModelDriver::Compute(
    KIM::ModelCompute const * const modelCompute,
    KIM::ModelComputeArguments const * const modelComputeArguments)
{
  TorchMLModelDriver * const modelObject = GetModelObject(modelCompute);
  return modelObject->implementation_->Compute(modelComputeArguments);
}

//******************************************************************************
// static member function
#undef KIM_LOGGER_OBJECT_NAME
#define KIM_LOGGER_OBJECT_NAME modelComputeArgumentsCreate
/// Forward creation of the compute-arguments object to the implementation.
int TorchMLModelDriver::ComputeArgumentsCreate(
    KIM::ModelCompute const * const modelCompute,
    KIM::ModelComputeArgumentsCreate * const modelComputeArgumentsCreate)
{
  TorchMLModelDriver * const modelObject = GetModelObject(modelCompute);
  return modelObject->implementation_->ComputeArgumentsCreate(
      modelComputeArgumentsCreate);
}

//******************************************************************************
// static member function
/// Forward destruction of the compute-arguments object to the implementation.
int TorchMLModelDriver::ComputeArgumentsDestroy(
    KIM::ModelCompute const * modelCompute,
    KIM::ModelComputeArgumentsDestroy * const modelComputeArgumentsDestroy)
{
  TorchMLModelDriver * const modelObject = GetModelObject(modelCompute);
  return modelObject->implementation_->ComputeArgumentsDestroy(
      modelComputeArgumentsDestroy);
}

//******************************************************************************
// static member function
/// Forward a write-parameterized-model request to the implementation.
int TorchMLModelDriver::WriteParameterizedModel(
    const KIM::ModelWriteParameterizedModel * const modelWriteParameterizedModel)
{
  TorchMLModelDriver * const modelObject
      = GetModelObject(modelWriteParameterizedModel);
  return modelObject->implementation_->WriteParameterizedModel(
      modelWriteParameterizedModel);
}