#!/usr/bin/env python3
################################################################################
#
#  CDDL HEADER START
#
#  The contents of this file are subject to the terms of the Common Development
#  and Distribution License Version 1.0 (the "License").
#
#  You can obtain a copy of the license at
#  http://www.opensource.org/licenses/CDDL-1.0.  See the License for the
#  specific language governing permissions and limitations under the License.
#
#  When distributing Covered Code, include this CDDL HEADER in each file and
#  include the License file in a prominent location with the name LICENSE.CDDL.
#  If applicable, add the following below this CDDL HEADER, with the fields
#  enclosed by brackets "[]" replaced with your own identifying information:
#
#  Portions Copyright (c) [yyyy] [name of copyright owner]. All rights reserved.
#
#  CDDL HEADER END
#
#  Copyright (c) 2019, Regents of the University of Minnesota.
#  All rights reserved.
#
#  Contributor(s):
#     Ellad B. Tadmor
#
################################################################################

# The docstring below is vc_description
"""
Check that the model supports periodic boundary conditions correctly.
If the simulation box is increased by an integer factor along a periodic
direction, the total energy must multiply by that factor and the forces
on atoms that are periodic copies of each other must be the same.
The check is performed for a randomly distorted non-periodic
face-centered cubic (FCC) cube base structure. Separate configurations
are tested for each species supported by the model, as well as one containing
a random distribution of all species. For each configuration, all
possible combinations of periodic boundary conditions are tested:
TFF, FTF, FFT, TTF, TFT, FTT, TTT (where 'T' indicates periodicity
along a direction, and 'F' indicates no periodicity). The verification
check passes if the energies and forces of all configurations that the
model is able to compute satisfy these requirements.
Configurations used for testing are provided as auxiliary files."""

# Python 2-3 compatible code issues
from __future__ import print_function

# On Python 2, make ``input`` return the raw entered line (like Python 3's
# ``input``) by aliasing it to ``raw_input``. On Python 3 ``raw_input`` does
# not exist, the lookup raises NameError, and the builtin is left untouched.
try:
    _raw_input = raw_input  # noqa: F821 -- only defined on Python 2
except NameError:
    pass
else:
    input = _raw_input

import kim_python_utils.ase as kim_ase_utils
import kim_python_utils.vc as kim_vc_utils
from ase.calculators.kim import KIM, get_model_supported_species
from ase import Atoms
from ase.lattice.cubic import FaceCenteredCubic
import random
import itertools
import numpy as np

__version__ = "004"
__author__ = "Ellad Tadmor"

################################################################################
#
#   FUNCTIONS
#
################################################################################


################################################################################
def perform_pbc_check(model, vc, atoms, heading, dashwidth):
    """
    Perform periodic boundary conditions check for the ASE atoms
    object in 'atoms'.

    The configuration is doubled along every direction that is periodic
    according to atoms.get_pbc(). If p directions are periodic, the doubled
    system contains n = 2^p copies of the original N atoms, and the check
    requires that
      * its potential energy equals n times the original energy, and
      * the force on atom k of the doubled system equals the force on
        original atom k % N,
    both to within a relative tolerance of 1e-8. A report is written with
    vc.rwrite().

    Parameters
    ----------
    model : str
        KIM model used to create a calculator for the doubled system.
    vc : verification-check object
        Used for report output through its rwrite() method.
    atoms : ase.Atoms
        Configuration to check; must already have a calculator attached.
    heading : str
        Heading written at the top of this configuration's report section.
    dashwidth : int
        Width of the separator lines in the report.

    Returns
    -------
    bool
        True if the energy and all force requirements are satisfied.

    Raises
    ------
    RuntimeError
        If the forces of the original or of the doubled configuration
        contain NaN. Errors raised by the calculators propagate unchanged.
    """
    # set comparison tolerance
    tole = 1e-8
    eps_prec = np.finfo(float).eps

    # compute the energy and forces of the original configuration
    energy = atoms.get_potential_energy()
    forces = atoms.get_forces()
    if np.isnan(forces).any():
        raise RuntimeError("ERROR: Computed forces include at least one nan.")

    # extract original system information
    pbc = atoms.get_pbc()
    species = atoms.get_chemical_symbols()
    positions = atoms.get_positions()
    cell = atoms.get_cell()
    num_atoms = len(atoms)

    # Construct a system that is doubled in all periodic directions. Each
    # doubling appends a copy of the current positions shifted by the
    # ORIGINAL cell vector, so every consecutive group of num_atoms atoms
    # keeps the original ordering: atom k is an image of atom k % num_atoms.
    num_periodic = int(np.count_nonzero(pbc))
    num_copies = 2 ** num_periodic
    species2 = species * num_copies
    positions2 = positions.copy()
    cell2 = cell.copy()
    for j in range(3):
        if pbc[j]:
            cell2[j] = np.multiply(cell2[j], 2.0)
            positions2 = np.vstack((positions2, positions2 + cell[j]))
    atoms2 = Atoms(species2, positions=positions2, cell=cell2, pbc=pbc)
    calc2 = KIM(model)
    atoms2.calc = calc2

    # compute the energy and forces of the doubled configuration
    try:
        energy2 = atoms2.get_potential_energy()
        forces2 = atoms2.get_forces()
    finally:
        if hasattr(calc2, "__del__"):
            # Release the extra calculator explicitly, as is done for all
            # other calculators in this VC (relevant for SMs)
            calc2.__del__()
    if np.isnan(forces2).any():
        raise RuntimeError("ERROR: Computed forces include at least one nan.")

    # check if energy scales in the expected way
    den = max(0.5 * (abs(energy2) + abs(num_copies * energy)), eps_prec)
    passed_energy = bool(abs(energy2 - num_copies * energy) / den < tole)

    # report energy results
    vc.rwrite("")
    vc.rwrite(heading)
    vc.rwrite("-" * dashwidth)
    vc.rwrite("")
    vc.rwrite(
        "The system is doubled in p={0} periodic directions, which means an "
        "increase by a factor n=2^{0}={1}".format(num_periodic, num_copies)
    )
    vc.rwrite("in the number of atoms and in the energy.")
    vc.rwrite("")
    vc.rwrite("Energy requirement:")
    vc.rwrite("")
    vc.rwrite(
        "V(DBL_p(r_1,...,r_N)) = (2^p) V(r_1,...,r_N), "
        "where r_i is the position of atom i, V is the potential energy, "
    )
    vc.rwrite(
        "and DBL_p is an operator that doubles the configuration in "
        "p periodic directions."
    )
    vc.rwrite("")
    vc.rwrite("V(DBL_p(r_1,...,r_N)) = {0}".format(energy2))
    vc.rwrite("2^p V(r_1,...,r_N)    = {0}".format(num_copies * energy))
    vc.rwrite("")

    # report force results and check if the forces in the doubled
    # periodic cell map back to the original forces as expected.
    vc.rwrite("Forces requirement:")
    vc.rwrite("")
    vc.rwrite(
        "f_k(DBL_p(r_1,...,r_N)) = f_(k % N)(r_1,...,r_N), where r_i "
        "is the position of atom i, f_k is the force on atom k "
    )
    vc.rwrite(
        "(where k runs from 1 to the number of atoms in the doubled "
        "configuration), DBL_p doubles the configuration "
    )
    vc.rwrite(
        "in p periodic directions, N is the number of atoms in the original "
        "configuration, and % is the modulo operator."
    )
    vc.rwrite("")

    hfmt = "{:>3}" + " " * 17 + "{}" + " " * 34 + "{}"
    fmt = "{:>3}   " + "{: .8e}   " * 3 + "|  " + "{: .8e}   " * 3 + "{}"
    vc.rwrite(hfmt.format("k", "f_k(DBL_p(r_1,...,r_N))", "f_(k % N)(r_1,...,r_N)"))
    vc.rwrite("-" * dashwidth)
    passed_forces = True
    for k, f_lhs in enumerate(forces2):
        f_rhs = forces[k % num_atoms]
        # component-wise relative difference, guarded against division by 0
        den = np.maximum(0.5 * (np.absolute(f_lhs) + np.absolute(f_rhs)), eps_prec)
        force_ok = np.all(np.absolute(f_lhs - f_rhs) / den < tole)
        stat = ""
        if not force_ok:
            passed_forces = False
            stat = "ERR"
        vc.rwrite(
            fmt.format(
                k, f_lhs[0], f_lhs[1], f_lhs[2], f_rhs[0], f_rhs[1], f_rhs[2], stat
            )
        )
    vc.rwrite("-" * dashwidth)

    # determine overall result
    passed = passed_energy and passed_forces

    if passed:
        vc.rwrite(
            "PASS: Energies and forces are the same to within a "
            "relative error of {0}".format(tole)
        )
    else:
        vc.rwrite(
            "FAIL: Energies and/or forces differ by more than "
            "a relative error of {0}".format(tole)
        )
    vc.rwrite("-" * dashwidth)

    return passed


################################################################################
def _randomize_positions_safely(atoms, pert_amp, max_iters, context):
    """
    Randomly perturb the positions of 'atoms' while keeping forces computable.

    Each attempt calls kim_ase_utils.randomize_positions(atoms, pert_amp) and
    then evaluates the forces. If either step raises, or the forces contain
    NaN, the positions saved on entry are restored, the amplitude is halved
    and another attempt is made.

    Returns the perturbation amplitude used by the successful attempt.

    Raises RuntimeError if 'max_iters' attempts fail; 'context' (e.g.
    "pbc=(1, 0, 0) for species Ar") is appended to the error message.
    """
    save_positions = atoms.get_positions()
    iters = 0
    while True:
        try:
            kim_ase_utils.randomize_positions(atoms, pert_amp)
            if np.isnan(atoms.get_forces()).any():
                # Handled like any other failure by the except clause below
                raise RuntimeError("ERROR: Computed forces include at least one nan.")
            return pert_amp
        except Exception:
            # Failed to compute forces; reset to original posns and retry.
            # (Exception, not a bare except, so KeyboardInterrupt/SystemExit
            # are not swallowed by this retry loop.)
            atoms.set_positions(save_positions)
            pert_amp *= 0.5  # cut perturbation amplitude by half
            iters += 1
            if iters >= max_iters:
                raise RuntimeError(
                    "Iteration limit exceeded when randomizing positions "
                    "during check of " + context
                )


################################################################################
def _check_configuration(model, vc, atoms, calc, aux_file, heading, dashwidth):
    """
    Save 'atoms' to the auxiliary file 'aux_file' and run perform_pbc_check.

    Returns True/False if the check passed/failed, or None if the check could
    not be computed (perform_pbc_check raised, e.g. a model error or NaN
    forces); in that case a warning is written to the report. The calculator
    'calc' is explicitly released afterwards in all cases.
    """
    vc.vc_files.append(aux_file)
    vc.write_aux_ase_atoms(aux_file, atoms, "xyz")
    try:
        return perform_pbc_check(model, vc, atoms, heading, dashwidth)
    except Exception as err:
        # Configurations the model cannot compute do not count against it,
        # but are reported rather than silently ignored.
        vc.rwrite("")
        vc.rwrite(
            "WARNING: Unable to compute the configuration in file {} ({}). "
            "Skipping...".format(aux_file, err)
        )
        vc.rwrite("")
        return None
    finally:
        if hasattr(calc, "__del__"):
            # Force deallocation of calculator memory for tidiness (relevant for SMs)
            calc.__del__()


################################################################################
def do_vc(model, vc):
    """
    Do periodicity support check.

    For each model-supported species (restricted to those ASE supports) that
    has non-trivial energy and force interactions, and for a random mixture
    of all supported species when there is more than one, a randomly
    distorted FCC cube is built for each of the seven non-trivial
    combinations of periodic boundary conditions and checked with
    perform_pbc_check(). Configurations the model cannot compute are
    reported and skipped.

    Returns
    -------
    (str, str)
        Grade ("P" if every computed configuration passed, "F" otherwise)
        and an explanatory comment.

    Raises
    ------
    RuntimeError
        If no configuration could be computed, if no working initial
        configuration is found for lattice constants up to 10.0, or if the
        position-randomization iteration limit is exceeded.
    kim_ase_utils.KIMASEError
        Re-raised when the initial-configuration search fails in an
        unrecoverable manner.
    """
    # Max iterations allowed for the position-randomization retry loops
    max_iters = 2000

    # Get supported species
    species = get_model_supported_species(model)
    species = kim_ase_utils.remove_species_not_supported_by_ASE(list(species))
    species.sort()

    # Basic cell parameters
    lattice_constant_orig = 3.0
    pert_amp_orig = 0.1 * lattice_constant_orig
    ncells_per_side = 1
    seed = 13
    random.seed(seed)

    # Print VC defining information
    dashwidth = 120
    vc.rwrite("")
    vc.rwrite("-" * dashwidth)
    vc.rwrite("Results for KIM Model      : %s" % model.strip())
    vc.rwrite("Supported species          : %s" % " ".join(species))
    vc.rwrite("")
    vc.rwrite("random seed                = %d" % seed)
    vc.rwrite("lattice constant (orig)    = %0.3f" % lattice_constant_orig)
    vc.rwrite("perturbation amplitude     = %0.3f" % pert_amp_orig)
    vc.rwrite("number unit cells per side = %d" % ncells_per_side)
    vc.rwrite("-" * dashwidth)
    vc.rwrite("")

    # Initialize variables
    got_atleast_one = False
    passed_all = True

    # Create list of all possible periodic boundary conditions (excluding
    # the fully non-periodic case)
    onezero = (1, 0)
    TF = {1: "T", 0: "F"}
    pbc_options = list(itertools.product(onezero, repeat=3))
    pbc_options.remove((0, 0, 0))

    # Perform periodicity check for monoatomic systems
    for spec in species:
        # Check if this species has non-trivial force and energy interactions
        (
            atoms_interacting_energy,
            atoms_interacting_force,
        ) = kim_ase_utils.check_if_atoms_interacting(model, symbols=[spec, spec])
        if not atoms_interacting_energy:
            vc.rwrite("")
            vc.rwrite(
                "WARNING: The model provided, {}, does not possess a non-trivial energy "
                "interaction for species {} as required by this Verification "
                "Check. Skipping...".format(model, spec)
            )
            vc.rwrite("")
            continue

        if not atoms_interacting_force:
            vc.rwrite("")
            vc.rwrite(
                "WARNING: The model provided, {}, does not possess a non-trivial force "
                "interaction for species {} as required by this Verification Check.  "
                "Skipping...".format(model, spec)
            )
            vc.rwrite("")
            continue

        for pbc in pbc_options:
            lattice_constant = lattice_constant_orig
            got_initial_config = False
            calc = KIM(model)
            while not got_initial_config:
                atoms = FaceCenteredCubic(
                    size=(ncells_per_side, ncells_per_side, ncells_per_side),
                    latticeconstant=lattice_constant,
                    symbol=spec,
                    pbc=pbc,
                )
                atoms.calc = calc
                try:
                    kim_ase_utils.rescale_to_get_nonzero_forces(atoms, 0.01)
                    # Make sure forces can be computed for the imposed periodicity
                    # (Routine rescale_to_get_nonzero_forces turns off periodicity.
                    # There are cases where turning it back to the settings in 'pbc'
                    # generates an error.)
                    if hasattr(calc, "__del__"):
                        # Delete and reinstantiate calculator in case the model is an SM that is
                        # tightly bound to the number of atoms in the system (including ghost atoms)
                        calc.__del__()
                        calc = KIM(model)
                        atoms.calc = calc
                    # NaN forces are handled like any other failure below
                    if np.isnan(atoms.get_forces()).any():
                        raise RuntimeError("ERROR: Computed forces include at least one nan.")
                    got_initial_config = True
                except kim_ase_utils.KIMASEError:
                    # Routine failed in an unrecoverable manner
                    raise  # re-raise same exception
                except Exception:
                    # Initial config failed. This is most likely due to an evaluation
                    # outside the legal model range. Increase lattice constant and
                    # try again.
                    lattice_constant += 0.25
                    if lattice_constant > 10.0:
                        raise RuntimeError(
                            "Cannot find a working configuration within a reasonable "
                            "lattice constant range."
                        )

            # Randomize positions, then move atoms around until all forces are sizeable
            pert_amp = _randomize_positions_safely(
                atoms,
                pert_amp_orig,
                max_iters,
                "pbc={} for species {}".format(pbc, spec),
            )
            kim_ase_utils.perturb_until_all_forces_sizeable(atoms, pert_amp)

            pbcstr = "".join(TF[x] for x in pbc)
            aux_file = "config-" + spec + "-" + pbcstr + ".xyz"
            heading = (
                "MONOATOMIC STRUCTURE -- Species = " + spec + ", "
                "PBC = " + pbcstr + '   (Configuration in file "' + aux_file + '")'
            )
            passed = _check_configuration(
                model, vc, atoms, calc, aux_file, heading, dashwidth
            )
            if passed is not None:
                passed_all = passed_all and passed
                got_atleast_one = True

    # Perform periodicity check for mixed system
    if len(species) > 1:
        for pbc in pbc_options:

            # Grow the cube until it holds at least one atom per species
            lattice_constant = lattice_constant_orig
            while True:
                atoms = FaceCenteredCubic(
                    size=(ncells_per_side, ncells_per_side, ncells_per_side),
                    latticeconstant=lattice_constant,
                    symbol="H",
                    pbc=pbc,
                )
                if len(atoms) < len(species):
                    ncells_per_side += 1
                else:
                    break
            kim_ase_utils.randomize_species(atoms, species)
            calc = KIM(model)
            atoms.calc = calc
            got_initial_config = False
            while not got_initial_config:
                try:
                    kim_ase_utils.rescale_to_get_nonzero_forces(atoms, 0.01)
                    # Make sure forces can be computed for the imposed periodicity
                    # (Routine rescale_to_get_nonzero_forces turns off periodicity.
                    # There are cases where turning it back to the settings in 'pbc'
                    # generates an error.)
                    if hasattr(calc, "__del__"):
                        # Delete and reinstantiate calculator in case the model is an SM that is
                        # tightly bound to the number of atoms in the system (including ghost atoms)
                        calc.__del__()
                        calc = KIM(model)
                        atoms.calc = calc
                    # NaN forces are handled like any other failure below
                    if np.isnan(atoms.get_forces()).any():
                        raise RuntimeError("ERROR: Computed forces include at least one nan.")
                    got_initial_config = True
                except kim_ase_utils.KIMASEError:
                    # Routine failed in an unrecoverable manner
                    raise  # re-raise same exception
                except Exception:
                    # Initial config failed. This is most likely due to an evaluation
                    # outside the legal model range. Increase lattice constant and
                    # try again.
                    lattice_constant += 0.25
                    if lattice_constant > 10.0:
                        raise RuntimeError(
                            "Cannot find a working configuration within a reasonable "
                            "lattice constant range."
                        )
                    acell = lattice_constant * ncells_per_side
                    atoms.set_cell([acell, acell, acell], scale_atoms=True)

            # Randomize positions, then move atoms around until all forces are sizeable
            pert_amp = _randomize_positions_safely(
                atoms,
                pert_amp_orig,
                max_iters,
                "pbc={} for mixed species.".format(pbc),
            )
            kim_ase_utils.perturb_until_all_forces_sizeable(atoms, pert_amp)

            pbcstr = "".join(TF[x] for x in pbc)
            aux_file = "config-" + "".join(species) + "-" + pbcstr + ".xyz"
            heading = (
                "MIXED STRUCTURE -- Species = " + " ".join(species) + ", "
                "PBC = " + pbcstr + '   (Configuration in file "' + aux_file + '")'
            )
            passed = _check_configuration(
                model, vc, atoms, calc, aux_file, heading, dashwidth
            )
            if passed is not None:
                passed_all = passed_all and passed
                got_atleast_one = True

    if not got_atleast_one:
        msg = (
            "ERROR: Failed to compute periodic boundary conditions for all "
            "test configurations."
        )
        vc.rwrite("")
        vc.rwrite(msg)
        vc.rwrite("")
        raise RuntimeError(msg)

    # Compute grade
    vc.rwrite("=" * dashwidth)
    vc.rwrite("")
    vc.rwrite("=" * dashwidth)
    vc.rwrite(
        "To pass this verification check the model must correctly "
        "support periodic boundary conditions for all configurations "
        "it was able to compute."
    )
    vc.rwrite("")

    if passed_all:
        vc_grade = "P"
        vc_comment = (
            "Periodic boundary conditions were correctly supported "
            "for all configurations that the model was able to compute."
        )
    else:
        vc_grade = "F"
        vc_comment = (
            "Periodic boundary conditions were NOT supported correctly "
            "for at least one configuration that the model was able to "
            "compute. This is an error in the implementation of the model."
        )

    return vc_grade, vc_comment


################################################################################
#
#   MAIN PROGRAM
#
###############################################################################
if __name__ == "__main__":
    # Metadata describing this verification check to the KIM VC framework
    vcargs = dict(
        vc_name="vc-periodicity-support",
        vc_author=__author__,
        vc_description=kim_vc_utils.vc_stripall(__doc__),
        vc_category="mandatory",
        vc_grade_basis="passfail",
        vc_files=[],
        vc_debug=False,  # Set to True to get exception traceback info
    )

    # Read the extended KIM ID of the model under test from standard input
    model = input("Model Extended KIM ID = ")

    # Hand the check over to the VC driver, which calls do_vc(model, vc)
    kim_vc_utils.setup_and_run_vc(do_vc, model, **vcargs)