#!/usr/bin/env python3
################################################################################
#
#  CDDL HEADER START
#
#  The contents of this file are subject to the terms of the Common Development
#  and Distribution License Version 1.0 (the "License").
#
#  You can obtain a copy of the license at
#  http://www.opensource.org/licenses/CDDL-1.0.  See the License for the
#  specific language governing permissions and limitations under the License.
#
#  When distributing Covered Code, include this CDDL HEADER in each file and
#  include the License file in a prominent location with the name LICENSE.CDDL.
#  If applicable, add the following below this CDDL HEADER, with the fields
#  enclosed by brackets "[]" replaced with your own identifying information:
#
#  Portions Copyright (c) [yyyy] [name of copyright owner]. All rights reserved.
#
#  CDDL HEADER END
#
#  Copyright (c) 2020-2021, Regents of the University of Minnesota.
#  All rights reserved.
#
#  Contributor(s):
#     Ellad B. Tadmor
#     Daniel S. Karls
#
################################################################################

# The docstring below is vc_description
"""This verification check examines whether a model called in parallel by
multiple threads gives the same results as when called sequentially. A number
`num_configs` (preset in the code) of configurations is generated, each
containing a different number of atoms based on a randomly distorted,
periodic, face-centered cubic (fcc) structure containing a random distribution
of all atoms supported by the model. Configurations used for testing are
provided as auxiliary files. The energy and forces for each configuration are
computed in sequence. Then the calculations are repeated using Python
multithreading with `num_configs` threads and the Global Interpreter
Lock (GIL) released to ensure true parallelism. Each thread possesses its
own copy of memory related to the calculation. The threaded calculations
are repeated `num_cycles` times (preset in the code) with the configurations
randomly distributed to the threads to create many varying opportunities
for collisions. To pass the test, the total energy and the forces on all atoms
for a given configuration obtained in all cycles must be identical to those
obtained in the sequential calculation. Failure of this verification check
implies that the model is inappropriately storing information in persistent
(static) memory during or between calls. Passing this verification check
provides some assurance that the model is thread-safe, but it is NOT a guarantee
due to the inherent randomness of race conditions in unsafe code. To be certain,
the model code would need to be studied and analyzed carefully."""

# Python 2-3 compatible code issues
from __future__ import print_function

from multiprocessing import cpu_count
import random
import threading

import numpy as np
import scipy.optimize
from ase.lattice.cubic import FaceCenteredCubic
from ase.calculators.kim import KIM, get_model_supported_species
import kim_python_utils.ase as kim_ase_utils
import kim_python_utils.vc as kim_vc_utils
import kimpy

# Python 2 compatibility shim: rebind `input` to `raw_input` where the latter
# exists. On Python 3 `raw_input` is undefined, the resulting NameError is
# caught, and the built-in `input` is left untouched.
try:
    input = raw_input
except NameError:
    pass

# Version and author metadata for this verification check. `__author__` is
# passed on as "vc_author" in the main program at the bottom of this file.
__version__ = "005"
__author__ = "Ellad Tadmor and Daniel S. Karls"


class KimpyError(Exception):
    """
    Signals that a kimpy function call failed, i.e. that the call raised a
    RuntimeError.
    """


def check_kimpy_call(f, *args, **kwargs):
    """
    Call a kimpy function with the given arguments and, if a RuntimeError is
    raised, catch it and raise a KimpyError carrying the exception's message.

    (Starting with kimpy 2.0.0, a RuntimeError is the only exception type raised
    when something goes wrong.)

    Parameters
    ----------
    f : callable
        The kimpy function to call.
    *args, **kwargs
        Positional and keyword arguments forwarded unchanged to 'f'.

    Returns
    -------
    Whatever 'f' returns.

    Raises
    ------
    KimpyError
        If 'f' raises a RuntimeError. The original exception is chained as
        the cause so its traceback is preserved.
    """
    try:
        return f(*args, **kwargs)
    except RuntimeError as exception:
        # Not every callable has a __name__ (e.g. functools.partial objects);
        # fall back to its repr so the error report itself cannot fail.
        func_name = getattr(f, "__name__", repr(f))
        raise KimpyError(
            f'Calling kimpy function "{func_name}" failed:\n  {str(exception)}'
        ) from exception


################################################################################
#
#   FUNCTIONS
#
################################################################################
def cubic_cell_energy(alat, atoms, ncells_per_side):
    """
    Objective function for the lattice constant search: rescale the cubic
    'atoms' structure, made of 'ncells_per_side' unit cells along each edge,
    to lattice constant 'alat' and return the resulting potential energy.

    'alat' is indexable (an nd array of length 1); only its first entry is
    used. The cell of 'atoms' is modified in place, with the atomic
    positions scaled along with it.
    """
    # Edge length of the whole (cubic) supercell
    edge = alat[0] * ncells_per_side
    atoms.set_cell([edge] * 3, scale_atoms=True)
    energy = atoms.get_potential_energy()
    return energy


################################################################################
def compute_config(atoms):
    """
    Thread target: evaluate the energy and forces of the 'atoms' structure.

    Nothing is returned; the call is made purely for its side effect.
    """
    # Requesting the energy makes the model recompute if anything has
    # changed; the results (energy, forces, etc.) are then stored in atoms.
    _ = atoms.get_potential_energy()
    return None


################################################################################
def do_vc(model, vc):
    """
    Perform Thread Safety VC

    A set of randomized configurations is generated and evaluated one after
    the other to obtain reference energies and forces. The same
    configurations are then evaluated concurrently, one thread per
    configuration, for a number of cycles, and the results of every cycle are
    compared against the reference values.

    Parameters
    ----------
    model : str
        Name (extended KIM ID) of the model to check.
    vc : object
        Verification check object used for reporting. This function uses its
        'rwrite' and 'write_aux_ase_atoms' methods and appends to its
        'vc_files' list.

    Returns
    -------
    (vc_grade, vc_comment) : tuple of str
        'vc_grade' is "P" (all threaded results agree with the reference),
        "F" (a discrepancy was found) or "N/A" (the item is not a KIM
        portable model), and 'vc_comment' is a short explanation.

    Raises
    ------
    RuntimeError
        If only one core is available, or if no equilibrium lattice constant
        could be computed for one of the configurations.
    KimpyError
        If one of the kimpy calls used to query the model type fails.
    """
    # Get supported species
    species = get_model_supported_species(model)
    species = kim_ase_utils.remove_species_not_supported_by_ASE(list(species))
    species.sort()

    # Basic VC parameters
    num_cores = cpu_count()
    num_configs = 10
    num_cycles = 10
    lattice_constant = 2.5
    max_lattice_constant = 10.0
    mincells_per_side = 2
    maxcells_per_side = 5
    seed = 13
    random.seed(seed)

    # Print VC info
    agree_text = {True: "OK", False: "Failed"}
    dashwidth = 80
    vc.rwrite("")
    vc.rwrite("-" * dashwidth)
    vc.rwrite("Results for KIM Model       : %s" % model.strip())
    vc.rwrite("Supported species           : %s" % " ".join(species))
    vc.rwrite("")
    vc.rwrite("Number of cores             = %d" % num_cores)
    vc.rwrite("number of configurations    = %d" % num_configs)
    vc.rwrite("number of cycles            = %d" % num_cycles)
    vc.rwrite("random seed                 = %d" % seed)
    vc.rwrite("minimum unit cells per side = %d" % mincells_per_side)
    vc.rwrite("maximum unit cells per side = %d" % maxcells_per_side)
    vc.rwrite("-" * dashwidth)
    vc.rwrite("")

    # Check the model type. VC can only be performed for KIM Models
    col = check_kimpy_call(kimpy.collections.create)
    model_type = check_kimpy_call(col.get_item_type, model)
    if model_type != kimpy.collection_item_type.portableModel:
        vc.rwrite(model.strip() + " is not a KIM Model.")
        vc.rwrite("")
        vc.rwrite(
            "Thread safety can only be checked for models compatible with the KIM API."
        )
        vc.rwrite("Check not performed.")
        vc_grade = "N/A"
        vc_comment = "Thread safety can only be checked for KIM Models."
        return vc_grade, vc_comment

    # Make sure there is more than one processor available
    if num_cores == 1:
        single_core_msg = (
            "Only one core available on system. Unable to perform thread safety "
            "verification check."
        )
        vc.rwrite(single_core_msg)
        raise RuntimeError(single_core_msg)

    # Create num_config randomized configurations containing all
    # supported species
    vc.rwrite("SINGLE THREAD REFERENCE CALCULATIONS")
    vc.rwrite("")
    vc.rwrite(
        "{0:>6}   {1:>6}   {2:>13}   {3:>13}".format(
            "Config", "#Atoms", "Energy", "Ave Norm"
        )
    )
    vc.rwrite("-" * 47)
    # 'alat' is carried over from one configuration to the next, so each
    # lattice constant search starts from the previous result.
    alat = lattice_constant
    config_atoms = []
    config_energy = []
    config_forces = []
    for n in range(num_configs):
        # Randomly set system size based on allowed range
        ncells_per_side = random.randint(mincells_per_side, maxcells_per_side)

        # Since this VC only applies to KIM portable models, we know that the
        # calculator returned by function "KIM" will be a KIMCalculator instance,
        # which has context manager functions (__enter__ and __exit__) defined.
        # This will ensure that the destroy function of the model will be called,
        # which is helpful since some models may have global variables affected
        # by their create and destroy functions which could impact thread safety.
        with KIM(model) as calc:
            # Compute equilibrium lattice constant
            not_done = True
            got_equil = False
            while not_done:
                while True:
                    atoms = FaceCenteredCubic(
                        size=(ncells_per_side, ncells_per_side, ncells_per_side),
                        latticeconstant=alat,
                        symbol="H",
                        pbc=False,
                    )

                    # Change to the cell size in the following block should
                    # be run at most once when a selected size is too small
                    if len(atoms) < len(species):
                        ncells_per_side += 1
                    else:
                        break

                # Randomize the species in the cluster using all of the species supported by
                # the Model
                kim_ase_utils.randomize_species(atoms, species)

                # Attach calculator
                atoms.set_calculator(calc)

                # Find equilibrium lattice constant
                try:
                    res = scipy.optimize.minimize(
                        cubic_cell_energy,
                        alat,
                        args=(atoms, ncells_per_side),
                        method="Nelder-Mead",
                        tol=1e-6,
                    )
                    alat = res.x[0]
                    not_done = False
                    got_equil = True
                except Exception:
                    # failed for some reason (assume it's because of KIM error);
                    # retry from a larger lattice constant until the upper
                    # limit is exceeded. Only Exception subclasses are caught
                    # so that KeyboardInterrupt/SystemExit still propagate.
                    alat += 0.5
                    if alat > max_lattice_constant:
                        not_done = False

            if not got_equil:
                raise RuntimeError(
                    "Unable to compute equilibrium lattice constant for one of the configurations."
                )

            # Rescale and perturb crystal
            kim_ase_utils.rescale_to_get_nonzero_forces(atoms, 0.01)
            pert_amp = 0.01 * alat
            kim_ase_utils.randomize_positions(atoms, pert_amp)
            kim_ase_utils.perturb_until_all_forces_sizeable(atoms, pert_amp)

            # Write configuration files
            aux_file = (
                "config-" + str(n).strip() + "-config-" + "".join(species) + ".xyz"
            )
            vc.vc_files.append(aux_file)
            vc.write_aux_ase_atoms(aux_file, atoms, "xyz")

            # Store configuration, energy, and forces (the reference values
            # are computed while the sequential calculator is still attached)
            config_atoms.append(atoms.copy())
            config_energy.append(atoms.get_potential_energy())
            config_forces.append(atoms.get_forces())

        # Instantiate new calculator for this configuration that releases the GIL so the
        # threads will be able to execute in parallel in the next stage of this VC.
        #
        # NOTE: While defining a different calculator for each configuration means that
        # there will be multiple KIM API Model Objects (one for each configuration), only
        # a *single* copy of the model's shared library will be loaded into the memory
        # space of this python process.  This means that if the model is doing something
        # thread-unsafe, we'll be able to detect it.
        config_atoms[n].set_calculator(KIM(model, options={"release_GIL": True}))

        # Report to user
        natoms = len(config_atoms[n])
        energy = config_energy[n]
        avenorm = np.linalg.norm(config_forces[n]) / natoms
        vc.rwrite(
            "{0:6d}   {1:6d}   {2:13.6e}   {3:13.6e}".format(n, natoms, energy, avenorm)
        )

    # Close table
    vc.rwrite("-" * 47)
    vc.rwrite("")

    # Cycle through configuration calculations num_cycles times
    # using num_config threads to compute all configurations.
    vc.rwrite("THREADED CALCULATIONS")
    vc.rwrite("")
    # Bound before the loop so the grade below is well defined even if the
    # loop body never runs.
    all_agree = True
    for cycle in range(num_cycles):
        vc.rwrite("### Cycle {} ###".format(cycle + 1))
        vc.rwrite("")
        # Randomize order of configurations for thread calculations.
        # This way the configurations are visited in different order
        # during each cycle.
        order = list(range(num_configs))
        random.shuffle(order)

        # Change positions, request energy which will issue a calculate
        # operation, then restore positions. This will force the threaded
        # calculations to recalculate and not use the energy and forces
        # stored in the config_atoms[] objects from previous cycle
        for n in range(num_configs):
            posns = config_atoms[n].get_positions()
            config_atoms[n].positions += 1.0
            config_atoms[n].get_potential_energy()
            config_atoms[n].set_positions(posns)

        # Create and start threads. Each thread will compute one configuration.
        threads = []
        num_threads = num_configs
        for n in range(num_threads):
            thr = threading.Thread(
                name="thread-" + str(n),
                target=compute_config,
                args=(config_atoms[order[n]],),
            )
            threads.append(thr)
            thr.start()

        # All threads were started above; wait until they're all done executing
        for thr in threads:
            thr.join()

        # NOTE(review): an exception raised inside a worker thread is not
        # propagated to this thread by join(). Presumably the energy/forces
        # requests below would then recompute that configuration
        # sequentially, which could hide a failure that only occurs under
        # threading -- confirm against the calculator's caching behaviour.

        # All threads are complete at this point. Each configuration
        # contains the energy and forces computed by a thread. Compare
        # these with the original value computed in the earlier reference
        # sequential calculation and make sure there are no discrepancies.
        vc.rwrite(
            "{0:>6}   {1:>6}   {2:>13}   {3:>13}   {4:>6}   {5:>6}".format(
                "Config", "#Atoms", "Energy", "Ave Norm", "Thread", "Status"
            )
        )
        vc.rwrite("-" * 65)
        all_agree = True
        for n in range(num_configs):
            natoms = len(config_atoms[n])
            energy = config_atoms[n].get_potential_energy()
            forces = config_atoms[n].get_forces()
            avenorm = np.linalg.norm(forces) / natoms
            thrnum = order.index(n)  # number of thread that processed config n
            # Comparison uses numpy's default allclose tolerances
            agree = np.allclose(energy, config_energy[n]) and np.allclose(
                forces, config_forces[n]
            )
            agree_s = agree_text[agree]
            vc.rwrite(
                "{0:6d}   {1:6d}   {2:13.6e}   {3:13.6e}   {4:6d}   {5:<6}".format(
                    n, natoms, energy, avenorm, thrnum, agree_s
                )
            )
            all_agree = all_agree and agree

        # Close table
        vc.rwrite("-" * 65)
        vc.rwrite("")

        # Stop cycling if an error is detected
        if not all_agree:
            break

    # Report grade
    vc.rwrite("=" * dashwidth)
    vc.rwrite(
        "To pass this verification check all threads must give identical results,"
    )
    vc.rwrite(
        "i.e. the same total energy and the same average norm (force norm divided"
    )
    vc.rwrite(
        'by the number of atoms). This is indicated by the "Status" column in the'
    )
    vc.rwrite("above table.")
    vc.rwrite("")
    if all_agree:
        vc_grade = "P"
        vc_comment = (
            "All threads give identical results for tested case. Model "
            "appears to be thread-safe."
        )
    else:
        vc_grade = "F"
        vc_comment = (
            "One or more threads gave different results than a single thread "
            "calculation. The model is not thread-safe."
        )

    return vc_grade, vc_comment


################################################################################
#
#   MAIN PROGRAM
#
###############################################################################
if __name__ == "__main__":

    # Keyword settings handed to the VC driver together with do_vc
    vc_args = dict(
        vc_name="vc-thread-safe",
        vc_author=__author__,
        vc_description=kim_vc_utils.vc_stripall(__doc__),
        vc_category="mandatory",
        vc_grade_basis="passfail",
        vc_files=[],
        vc_debug=False,  # Set to True to get exception traceback info
    )

    # Ask for the extended KIM ID of the model to be checked
    model = input("Model Extended KIM ID = ")

    # Run the verification check on that model
    kim_vc_utils.setup_and_run_vc(do_vc, model, **vc_args)