#!/usr/bin/env python3
################################################################################
#
#  CDDL HEADER START
#
#  The contents of this file are subject to the terms of the Common Development
#  and Distribution License Version 1.0 (the "License").
#
#  You can obtain a copy of the license at
#  http://www.opensource.org/licenses/CDDL-1.0.  See the License for the
#  specific language governing permissions and limitations under the License.
#
#  When distributing Covered Code, include this CDDL HEADER in each file and
#  include the License file in a prominent location with the name LICENSE.CDDL.
#  If applicable, add the following below this CDDL HEADER, with the fields
#  enclosed by brackets "[]" replaced with your own identifying information:
#
#  Portions Copyright (c) [yyyy] [name of copyright owner]. All rights reserved.
#
#  CDDL HEADER END
#
#  Copyright (c) 2019, Regents of the University of Minnesota.
#  All rights reserved.
#
#  Contributor(s):
#     Ellad B. Tadmor
#
################################################################################

# The docstring below is vc_description
"""Comparison of the analytical forces obtained from the model with forces
computed by numerical differentiation using Richardson extrapolation for a
randomly distorted non-periodic structure based off the equilibrium
face-centered cubic (fcc) structure. Separate configurations are tested for
each species supported by the model, as well as one containing a random
distribution of all species. Configurations used for testing are provided
as auxiliary files."""

# TODO:
#
# This program does not deal with the issue of discontinuities in the
# model energy.  Such discontinuities could lead to incorrect numerical
# derivatives and to incorrect conclusions about the accuracy of the
# analytical forces.  A useful extension would be to check whether
# energy discontinuities are responsible for the component with the
# largest error, and if yes, ignore it and move on to the next largest,
# and so on.  Discontinuities could be found using the "detectedge"
# algorithm developed for the "Chebfun" Matlab package. See the
# discussion in "Piecewise Smooth Chebfuns", Pachon, Platte and Trefethen,
# https://www.cs.ox.ac.uk/files/717/NA-08-07.pdf
# and the code here:
# https://github.com/chebfun/chebfun/blob/development/%40fun/detectEdge.m

# Python 2-3 compatible code issues
from __future__ import print_function

# Python 2/3 compatibility: on Python 2 rebind 'input' to 'raw_input' so that
# it returns the typed text as a string (as Python 3's 'input' does); on
# Python 3 'raw_input' is undefined, the NameError is ignored, and the builtin
# 'input' is kept.
try:
    input = raw_input  # noqa: F821
except NameError:
    pass

from ase.lattice.cubic import FaceCenteredCubic
from ase.calculators.kim import KIM, get_model_supported_species
import kim_python_utils.ase as kim_ase_utils
import kim_python_utils.vc as kim_vc_utils
import random
import numpy as np
import numdifftools as nd
from numdifftools.step_generators import MaxStepGenerator
import scipy.optimize

# Version of this Verification Check and its author (the author is passed to
# the VC runner as 'vc_author' when this file is run as a script).
__version__ = "003"
__author__ = "Ellad Tadmor"

################################################################################
#
#   FUNCTIONS
#
################################################################################


################################################################################
def negpot(p, at=0, dof=0, atoms=None):
    """
    Return the negative of the total potential energy of 'atoms' when
    Cartesian component 'dof' of the position of atom 'at' is set to 'p'.
    Used as the function differentiated by the numerical derivative method
    (its derivative with respect to 'p' is the corresponding force component).

    The original coordinate is always restored before returning, including
    when the energy evaluation raises, so 'atoms' is left unchanged.
    If 'atoms' is None, 0 is returned.
    """
    if atoms is None:
        return 0
    sve = (atoms[at].position)[dof]
    (atoms[at].position)[dof] = p
    try:
        pot = atoms.get_potential_energy()
    finally:
        # Restore the coordinate even if the energy evaluation failed (e.g.
        # because the model was evaluated outside its legal range).
        (atoms[at].position)[dof] = sve
    return -pot


################################################################################
def _numerical_forces(atoms, forces):
    """
    Compute the force on every atom/dof of 'atoms' as the numerical derivative
    of 'negpot' (negative potential energy) using numdifftools.

    Several step generators are tried for each component, and the estimate
    closest to the analytical force in 'forces' is kept.

    Returns a tuple (forces_num, forces_uncert, forces_failed) of arrays of
    shape (len(atoms), 3):
      - the numerical forces;
      - the numdifftools error estimates;
      - an integer flag set to 1 where every attempt raised an exception.
    Entries of the other two arrays are left at zero wherever the flag is set.
    """
    # numdifftools' default step generator followed by three explicit ones
    # whose base steps span two orders of magnitude.
    step_generators = [
        MaxStepGenerator(
            base_step=base_step,
            num_steps=14,
            use_exact_steps=True,
            step_ratio=1.6,
            offset=0,
        )
        for base_step in (1e-4, 1e-3, 1e-2)
    ]
    derivatives = [nd.Derivative(negpot, full_output=True)]
    derivatives += [
        nd.Derivative(negpot, step=sg, full_output=True) for sg in step_generators
    ]

    natoms = len(atoms)
    forces_num = np.zeros(shape=(natoms, 3), dtype=float, order="C")
    forces_uncert = np.zeros(shape=(natoms, 3), dtype=float, order="C")
    forces_failed = np.zeros(shape=(natoms, 3), dtype=int, order="C")
    for at in range(natoms):
        for dof in range(3):
            p = (atoms[at].position)[dof]
            errmin = 1e30
            best = None  # (value, info) of the estimate closest to the model
            for derivative in derivatives:
                try:
                    val, info = derivative(p, at=at, dof=dof, atoms=atoms)
                    err = abs(val - forces[at, dof])
                    if err < errmin:
                        errmin = err
                        best = (val, info)
                except Exception:
                    # Failed to compute derivative, so skip this step
                    # generator. Restore the coordinate, which may have been
                    # left changed when the exception was raised.
                    (atoms[at].position)[dof] = p
            if best is None:
                # Failed all attempts for this at/dof; assume this is because
                # the potential is being evaluated outside its legal range.
                # (TODO: A more careful check on the reason for the error
                # would be good.)
                forces_failed[at, dof] = 1
            else:
                forces_num[at, dof] = best[0]
                forces_uncert[at, dof] = best[1].error_estimate
    return forces_num, forces_uncert, forces_failed


def _upper_outer_fence(values):
    """
    Return the upper outer fence, Q3 + 3*(Q3 - Q1), of the box plot
    construction for the array 'values'. Q1 and Q3 are its 25th and 75th
    percentiles.
    (See http://www.itl.nist.gov/div898/handbook/prc/section1/prc16.htm)
    """
    lower_quartile = np.percentile(values, 25)
    upper_quartile = np.percentile(values, 75)
    return upper_quartile + 3 * (upper_quartile - lower_quartile)


def perform_numerical_derivative_check(vc, atoms, heading, dashwidth):
    """
    Perform a numerical derivative check for the ASE atoms object in 'atoms'.

    The analytical forces from the attached calculator are compared with
    forces obtained by numerically differentiating the energy. A table of the
    results is written to the report via 'vc.rwrite', titled 'heading' and
    framed by rules 'dashwidth' characters wide.

    Returns the maximum relative error
        |F_model - F_numer| / max(|F_numer|, eps)
    over all components that did not fail and are not uncertainty outliers.

    Raises RuntimeError if the numerical derivative could not be computed for
    any component, or if no non-outlier component is left to determine the
    error (e.g. when the uncertainty estimates are not finite).
    """
    # Analytical forces (negative gradient of the potential energy)
    forces = atoms.get_forces()

    forces_num, forces_uncert, forces_failed = _numerical_forces(atoms, forces)

    # Identify outliers using a box plot construction with fences. All
    # results strictly above the upper outer fence are taken to be outliers.
    # Only components whose numerical derivative was computed are included.
    succeeded = forces_failed == 0
    if not succeeded.any():
        raise RuntimeError(
            "Numerical derivative could not be computed for any atom/dof."
        )
    uncert_upper_fence = _upper_outer_fence(forces_uncert[succeeded])

    # Initialize for printing. The particle number and species are only
    # printed on the first (x) line of each particle.
    frmt_head = "{0:>6}  {1:>4} {2:>3} {3:>25} {4:>25} {5:>15} {6:>15}"
    frmt_line = {
        True: "{0: 6d}  {1:4} {2: 3d} {3: 25.15e} {4: 25.15e} {5: 15.5e} {6: 15.5e} {7:1}",
        False: "             {2: 3d} {3: 25.15e} {4: 25.15e} {5: 15.5e} {6: 15.5e} {7:1}",
    }
    frmt_fail = "{0: 6d}  {1:4} {2: 3d} {3:>25} {4:>25} {5:>15} {6:>15} {7:1}"
    vc.rwrite(
        "Comparison of analytical forces obtained from the model, "
        "the force computed as a numerical derivative"
    )
    vc.rwrite(
        "of the energy, the difference between them, and the uncertainty "
        "in the numerical estimate of the force."
    )
    vc.rwrite(
        "The computed equilibrium lattice constant for the crystal (a0) "
        "is given in the heading."
    )
    vc.rwrite("")
    vc.rwrite(heading)
    vc.rwrite("-" * dashwidth)
    args = (
        "Part",
        "Spec",
        "Dir",
        "Force_model",
        "Force_numer",
        "|Force diff|",
        "uncertainty",
    )
    vc.rwrite(frmt_head.format(*args))
    vc.rwrite("-" * dashwidth)

    # Identify max error and print numerical derivative results
    eps_prec = np.finfo(float).eps
    errmax = 0.0
    at_max = dof_max = None
    at_least_one_result_discarded = False
    at_least_one_failure = False
    for at in range(len(atoms)):
        for dof in range(3):
            if forces_failed[at, dof]:
                # skip at/dof for which force was not computed
                args = (at + 1, atoms[at].symbol, dof + 1, "", "", "", "", "F")
                vc.rwrite(frmt_fail.format(*args))
                at_least_one_failure = True
                continue
            forcediff = abs(forces[at, dof] - forces_num[at, dof])
            den = max(abs(forces_num[at, dof]), eps_prec)
            if forces_uncert[at, dof] <= uncert_upper_fence:
                # Result is not an outlier. Include it in determining max
                # error. The first such result is always recorded so that a
                # location exists even when every error is exactly zero.
                lowacc_mark = " "
                relerr = forcediff / den
                if at_max is None or relerr > errmax:
                    errmax = relerr
                    at_max = at
                    dof_max = dof
            else:
                lowacc_mark = "*"
                at_least_one_result_discarded = True
            # Print results line
            args = (
                at + 1,
                atoms[at].symbol,
                dof + 1,
                forces[at, dof],
                forces_num[at, dof],
                forcediff,
                forces_uncert[at, dof],
                lowacc_mark,
            )
            vc.rwrite(frmt_line[dof == 0].format(*args))
        # Separator after each particle (also when its last component failed)
        vc.rwrite("-" * dashwidth)

    if at_least_one_result_discarded:
        vc.rwrite(
            "* Starred lines are suspected outliers and are not "
            "included when determining the error."
        )
        vc.rwrite(
            "  A calculation is considered an outlier if it has an "
            "uncertainty that lies at an abnormal"
        )
        vc.rwrite(
            "  distance from the other uncertainties in this set of "
            "calculations.  Outliers are determined"
        )
        vc.rwrite(
            "  using the box plot construction with fences. "
            "An outlier could indicate a problem with the"
        )
        vc.rwrite(
            "  numerical differentiation or problems with the "
            "potential energy, such as discontinuities."
        )
    if at_least_one_failure:
        if at_least_one_result_discarded:
            vc.rwrite("")
        vc.rwrite(
            "WARNING: Numerical derivative could not be computed for "
            "at least one atom/dof."
        )
        vc.rwrite(
            '         Failed calculations indicated with an "F". This '
            "can be due to attempts"
        )
        vc.rwrite(
            "         to evaluate the model outside its legal range "
            "during a numerical derivative"
        )
        vc.rwrite(
            "         calculation. These lines are ignored when computing the error."
        )

    if at_max is None:
        # Only possible when the uncertainty estimates are not finite (the
        # fence is then NaN and every comparison with it is False).
        raise RuntimeError(
            "No non-outlier numerical derivative available to determine the error."
        )

    # Print summary
    vc.rwrite("")
    vc.rwrite(
        "Maximum error obtained for particle = {0:d}, direction = {1:d}:".format(
            at_max + 1, dof_max + 1
        )
    )
    vc.rwrite("")
    vc.rwrite("              |F_model - F_numer|")
    vc.rwrite("    error = ----------------------- = {0:.5e}".format(errmax))
    vc.rwrite("              max{|F_numer|, eps}")
    vc.rwrite("")
    vc.rwrite("")

    return errmax


################################################################################
def cubic_cell_energy(alat, atoms, ncells_per_side):
    """
    Return the potential energy of 'atoms', a cubic structure made of
    'ncells_per_side' unit cells along each edge. Before the energy is
    evaluated, the structure is rescaled (atoms included) so that each unit
    cell has lattice constant alat[0].

    'alat' is indexed as a length-1 sequence because this function is used
    as the objective of scipy.optimize.minimize.
    """
    side = ncells_per_side * alat[0]
    atoms.set_cell([side] * 3, scale_atoms=True)
    energy = atoms.get_potential_energy()
    return energy


################################################################################
def _release_calculator(calc):
    """
    Explicitly close 'calc' by calling its __del__ method, if it has one, to
    ensure any allocated memory (particularly for Simulator Models) is freed.
    """
    if hasattr(calc, "__del__"):
        calc.__del__()


def _equilibrium_lattice_constant(calc, spec, alat, max_lattice_constant):
    """
    Minimize, with respect to the lattice constant, the energy of a periodic
    single-species fcc unit cell of species 'spec' computed with 'calc',
    starting from the guess 'alat'.

    If the minimization raises (assumed to be a KIM error), the starting
    guess is increased by 0.5 and the minimization retried. Retries continue
    until the guess exceeds 'max_lattice_constant'.

    Returns the equilibrium lattice constant, or None if every attempt failed.
    """
    while True:
        atoms = FaceCenteredCubic(
            size=(1, 1, 1), latticeconstant=alat, symbol=spec, pbc=True
        )
        atoms.calc = calc
        try:
            res = scipy.optimize.minimize(
                cubic_cell_energy,
                alat,
                args=(atoms, 1),
                method="Nelder-Mead",
                tol=1e-6,
            )
            return res.x[0]
        except Exception:
            # failed for some reason (assume it's because of KIM error)
            alat += 0.5
            if alat > max_lattice_constant:
                return None


def _run_derivative_check(
    vc,
    atoms,
    alat,
    ncells_per_side,
    pert_amp,
    structure,
    species_label,
    aux_file,
    dashwidth,
):
    """
    Prepare the finite cluster 'atoms' and run the numerical derivative check
    on it. 'atoms' must already have a calculator attached and must have been
    built with lattice constant 'alat' and 'ncells_per_side' cells per side.

    Preparation steps:
      - center the cluster in a cubic cell of side 7*alat*ncells_per_side;
      - rescale and randomly perturb it with the kim_ase_utils helpers
        (perturbation amplitude 'pert_amp');
      - register and write it to the auxiliary xyz file 'aux_file'.
    The report heading is built from 'structure' (e.g. "MONOATOMIC") and
    'species_label'.

    Returns the maximum error from perform_numerical_derivative_check.
    Returns None if that check raised; a warning is written to the report
    instead.
    """
    # Move our finite cluster of atoms to the center of a large cell
    large_cell_len = 7 * alat * ncells_per_side
    atoms.set_cell([large_cell_len, large_cell_len, large_cell_len])
    atoms.translate([0.5 * large_cell_len] * 3)

    # Rescale and perturb crystal for numerical derivative calculation
    kim_ase_utils.rescale_to_get_nonzero_forces(atoms, 0.01)
    kim_ase_utils.randomize_positions(atoms, pert_amp)
    kim_ase_utils.perturb_until_all_forces_sizeable(atoms, pert_amp)
    vc.vc_files.append(aux_file)
    vc.write_aux_ase_atoms(aux_file, atoms, "xyz")

    # Perform numerical derivative test
    heading = (
        structure
        + " STRUCTURE -- Species = "
        + species_label
        + " | "
        + "a0 = {}".format(alat)
        + ' | (Configuration in file "'
        + aux_file
        + '")'
    )
    try:
        return perform_numerical_derivative_check(vc, atoms, heading, dashwidth)
    except Exception as err:
        # Best effort: a configuration whose check fails does not contribute
        # to the grade, but the failure is recorded in the report.
        vc.rwrite(
            "WARNING: Numerical derivative check failed for configuration "
            'in file "{}": {}'.format(aux_file, err)
        )
        vc.rwrite("")
        return None


def do_vc(model, vc):
    """
    Perform Numerical Derivative Check VC

    Monatomic checks: for each species supported by 'model' (and by ASE)
    that has non-trivial energy and force interactions:
      - relax a periodic fcc unit cell to find the equilibrium lattice
        constant;
      - build a perturbed non-periodic fcc cluster with that constant;
      - compare the analytical and numerical forces on the cluster.

    Mixed check: if more than one species is supported and at least one
    equilibrium lattice constant was found, the same is done for a cluster
    with randomly assigned species. Its starting lattice constant is the
    average over the species that produced one.

    Results are written to the report through 'vc'.

    Returns the (grade, comment) pair from
    kim_vc_utils.vc_letter_grade_machine_precision for the maximum error over
    all configurations.
    Raises RuntimeError if no numerical derivative check succeeded.
    """
    # Get supported species
    species = get_model_supported_species(model)
    species = kim_ase_utils.remove_species_not_supported_by_ASE(list(species))
    species.sort()

    # Basic cell parameters
    lattice_constant = 2.5
    max_lattice_constant = 10.0
    pert_amp = 0.1 * lattice_constant
    ncells_per_side = 2
    seed = 13
    random.seed(seed)

    # Print Vc info
    dashwidth = 101
    vc.rwrite("")
    vc.rwrite("-" * dashwidth)
    vc.rwrite("Results for KIM Model      : %s" % model.strip())
    vc.rwrite("Supported species          : %s" % " ".join(species))
    vc.rwrite("")
    vc.rwrite("random seed                = %d" % seed)
    vc.rwrite("perturbation amplitude     = %0.3f" % pert_amp)
    vc.rwrite("number unit cells per side = %d" % ncells_per_side)
    vc.rwrite("-" * dashwidth)
    vc.rwrite("")

    # Initialize variables
    errmaxmax = 0.0
    got_atleast_one_deriv = False
    alat_sum = 0.0  # sum of equilibrium lattice constants found
    n_equil = 0  # number of species for which one was found

    # Perform numerical derivative check for monotatomic systems
    for spec in species:
        # Check if this species has non-trivial force and energy interactions
        atoms_interacting_energy, atoms_interacting_force = kim_ase_utils.check_if_atoms_interacting(
            model, symbols=[spec, spec]
        )
        if not atoms_interacting_energy:
            vc.rwrite("")
            vc.rwrite(
                "WARNING: The model provided, {}, does not possess a non-trivial energy "
                "interaction for species {} as required by this Verification "
                "Check. Skipping...".format(model, spec)
            )
            vc.rwrite("")
            continue

        if not atoms_interacting_force:
            vc.rwrite("")
            vc.rwrite(
                "WARNING: The model provided, {}, does not possess a non-trivial force "
                "interaction for species {} as required by this Verification Check.  "
                "Skipping...".format(model, spec)
            )
            vc.rwrite("")
            continue

        # Find the equilibrium lattice constant, so that the numerical
        # derivatives of all potentials are evaluated in a similar portion of
        # their potential energy surface, making comparisons between
        # potentials more meaningful. The calculator is released afterwards
        # (and a fresh one created below) in case the model is an SM that is
        # tightly bound to the number of atoms in the system (including ghost
        # atoms).
        calc = KIM(model)
        try:
            alat = _equilibrium_lattice_constant(
                calc, spec, lattice_constant, max_lattice_constant
            )
        finally:
            _release_calculator(calc)

        if alat is None:
            vc.rwrite(
                "FAILED to get equilibrium constant for species =  %s; skipping" % spec
            )
            continue

        alat_sum += alat
        n_equil += 1

        # Create a non-periodic crystal with the equilibrium lattice constant
        atoms = FaceCenteredCubic(
            size=(ncells_per_side, ncells_per_side, ncells_per_side),
            latticeconstant=alat,
            symbol=spec,
            pbc=False,
        )
        calc = KIM(model)
        try:
            atoms.calc = calc
            errmax = _run_derivative_check(
                vc,
                atoms,
                alat,
                ncells_per_side,
                pert_amp,
                "MONOATOMIC",
                spec,
                "config-" + spec + ".xyz",
                dashwidth,
            )
        finally:
            _release_calculator(calc)
        if errmax is not None:
            errmaxmax = max(errmax, errmaxmax)
            got_atleast_one_deriv = True

    if n_equil > 0 and len(species) > 1:
        # Perform numerical derivative check for mixed system, starting from
        # the average equilibrium lattice constant of the species above.
        alat_ave = alat_sum / n_equil

        # Grow the crystal until it has at least one atom per species
        while True:
            atoms = FaceCenteredCubic(
                size=(ncells_per_side, ncells_per_side, ncells_per_side),
                latticeconstant=alat_ave,
                symbol="H",
                pbc=True,
            )
            if len(atoms) < len(species):
                ncells_per_side += 1
            else:
                break

        kim_ase_utils.randomize_species(atoms, species)
        calc = KIM(model)
        try:
            atoms.calc = calc
            # Find equilibrium lattice constant
            try:
                res = scipy.optimize.minimize(
                    cubic_cell_energy,
                    alat_ave,
                    args=(atoms, ncells_per_side),
                    method="Nelder-Mead",
                    tol=1e-6,
                )
            except Exception:
                # Handled like the monatomic case: report and skip
                res = None
                vc.rwrite(
                    "FAILED to get equilibrium constant for mixed structure; skipping"
                )
            if res is not None:
                alat = res.x[0]

                # Change periodicity to false
                atoms.set_pbc([False, False, False])

                errmax = _run_derivative_check(
                    vc,
                    atoms,
                    alat,
                    ncells_per_side,
                    pert_amp,
                    "MIXED",
                    " ".join(species),
                    "config-" + "".join(species) + ".xyz",
                    dashwidth,
                )
                if errmax is not None:
                    errmaxmax = max(errmax, errmaxmax)
                    got_atleast_one_deriv = True
        finally:
            _release_calculator(calc)

    if not got_atleast_one_deriv:
        msg = "ERROR: Failed to compute a single numerical derivative."
        vc.rwrite("")
        vc.rwrite(msg)
        vc.rwrite("")
        raise RuntimeError(msg)

    # Compute grade
    vc.rwrite("=" * dashwidth)
    vc.rwrite(
        "Grade is based on maximum error across all systems "
        "(for a model supporting multiple species):"
    )
    vc.rwrite("")
    vc.rwrite("Maximum error     = {:15.5e}".format(errmaxmax))
    vc.rwrite("Machine precision = {:15.5e}".format(np.finfo(float).eps))
    vc.rwrite("")
    vc_grade, vc_comment = kim_vc_utils.vc_letter_grade_machine_precision(errmaxmax)

    vc.rwrite(
        "Note: Lower grades can reflect errors in the analytical "
        "forces, a lack of smoothness in the potential, or "
        "loss of precision in the numerical derivative computation."
    )
    vc.rwrite("")

    return vc_grade, vc_comment


################################################################################
#
#   MAIN PROGRAM
#
###############################################################################
if __name__ == "__main__":

    # Metadata describing this Verification Check, forwarded to the VC runner
    vcargs = dict(
        vc_name="vc-forces-numerical-derivative",
        vc_author=__author__,
        vc_description=kim_vc_utils.vc_stripall(__doc__),
        vc_category="consistency",
        vc_grade_basis="graded",
        vc_files=[],
        vc_debug=False,  # Set to True to get exception traceback info
    )

    # Ask for the extended KIM ID of the model under test, then run the VC
    model = input("Model Extended KIM ID = ")
    kim_vc_utils.setup_and_run_vc(do_vc, model, **vcargs)