#!/usr/bin/env python3
################################################################################
#
#  CDDL HEADER START
#
#  The contents of this file are subject to the terms of the Common Development
#  and Distribution License Version 1.0 (the "License").
#
#  You can obtain a copy of the license at
#  http://www.opensource.org/licenses/CDDL-1.0.  See the License for the
#  specific language governing permissions and limitations under the License.
#
#  When distributing Covered Code, include this CDDL HEADER in each file and
#  include the License file in a prominent location with the name LICENSE.CDDL.
#  If applicable, add the following below this CDDL HEADER, with the fields
#  enclosed by brackets "[]" replaced with your own identifying information:
#
#  Portions Copyright (c) [yyyy] [name of copyright owner]. All rights reserved.
#
#  CDDL HEADER END
#
#  Copyright (c) 2019, Regents of the University of Minnesota.
#  All rights reserved.
#
#  Contributor(s):
#     Ellad B. Tadmor
#
################################################################################

# The docstring below is vc_description
"""Check whether a model is invariant with respect to rigid-body motion
(translation and rotation) as required by objectivity (material
frame-indifference). This is expected to be true for any model that does
not depend on an external field. The check is performed for a randomly
distorted non-periodic body-centered cubic (BCC) cube base structure.
Separate configurations are tested for each species supported by the model,
as well as one containing a random distribution of all species.  The energy
and forces of each configuration are compared with those of the same configuration
rotated about a random axis by an irrational angle and translated in
a random direction by an irrational distance. The verification check will
pass if the energies of all configurations that the model is able to compute
are invariant and the forces are mapped back by the inverse rotation.
Configurations used for testing are provided as auxiliary files."""

# Python 2-3 compatible code issues
from __future__ import print_function

try:
    # Python 2: alias raw_input so that input() below reads the typed line as
    # a plain string, matching Python 3's input().
    input = raw_input  # noqa: F821 (only defined on Python 2)
except NameError:
    # Python 3: raw_input does not exist; the builtin input() is used as is.
    pass

from ase.lattice.cubic import BodyCenteredCubic
from ase.calculators.kim import KIM, get_model_supported_species
import kim_python_utils.ase as kim_ase_utils
import kim_python_utils.vc as kim_vc_utils
import random
import numpy as np
import math

# Module metadata; __author__ is also passed to the VC as vc_author.
__version__, __author__ = "002", "Ellad Tadmor"

################################################################################
#
#   FUNCTIONS
#
################################################################################


################################################################################
def get_random_unit_vector():
    """
    Return a random 3D unit vector (x, y, z), uniformly distributed over the
    surface of the unit sphere.

    The azimuthal angle phi is drawn uniformly from [0, 2*pi] and cos(theta)
    uniformly from [-1, 1]. Drawing cos(theta) rather than theta is what
    makes the directions uniform over the sphere.
    Reference: stackoverflow.com/questions/5408276/python-uniform-spherical-distribution
    """
    azimuth = random.uniform(0, 2 * math.pi)
    polar = np.arccos(random.uniform(-1.0, 1.0))

    # Length of the vector's projection onto the xy-plane
    rho = np.sin(polar)
    return (rho * np.cos(azimuth), rho * np.sin(azimuth), np.cos(polar))


################################################################################
def get_random_rotation_matrix():
    """
    Return a 3x3 rotation matrix drawn uniformly at random from SO(3).

    Uses James Arvo, "Fast Random Rotation Matrices", Graphics Gems III,
    pages 117-120, 1992 (implementation following
    https://github.com/qobilidop/randrot). A random rotation about the z axis
    is combined with a Householder reflection H = I - 2*v*v^T through a random
    plane. The product is negated, which gives a proper rotation
    (determinant +1).
    """
    u1, u2, u3 = (random.uniform(0, 1.0) for _ in range(3))

    # Rotation by the angle 2*pi*u1 about the z axis
    angle = 2 * np.pi * u1
    c, s = np.cos(angle), np.sin(angle)
    rot_z = np.array([[c, s, 0], [-s, c, 0], [0, 0, 1]])

    # Unit vector defining the reflection plane
    azim = 2 * np.pi * u2
    radius = np.sqrt(u3)
    v = np.array([np.cos(azim) * radius, np.sin(azim) * radius, np.sqrt(1 - u3)])
    householder = np.eye(3) - 2 * np.outer(v, v)

    return -np.dot(householder, rot_z)


################################################################################
def perform_objectivity_check(vc, atoms, heading, dashwidth):
    """
    Perform objectivity check for the ASE atoms object in 'atoms'.

    The energy and forces of 'atoms' are computed as given. The configuration
    is then rotated by a random rotation matrix Q about the point
    (L/2, L/2, L/2), where L = cell[0][0] (a cubic cell is assumed). It is
    then translated by a random vector c of length pi. An objective model
    must return the same energy, and forces that satisfy
    f_i(Q*r_1+c,...,Q*r_N+c) = Q*f_i(r_1,...,r_N). Results are written to the
    report through vc.rwrite.

    Note: 'atoms' is left in its rotated and translated state on return.

    Parameters
    ----------
    vc : object
        Verification-check object; only its rwrite(str) method is used here.
    atoms : ase.Atoms
        Configuration with a calculator attached.
    heading : str
        Title written above the results for this configuration.
    dashwidth : int
        Width of the separator lines in the report.

    Returns
    -------
    bool
        True if the energy is unchanged and the forces are mapped by Q, both
        to within a relative tolerance of 1e-8.
    """
    # set comparison tolerance
    tole = 1e-8
    eps_prec = np.finfo(float).eps

    def rel_diff(a, b):
        # Elementwise |a - b| relative to the mean magnitude of a and b. The
        # floor at machine epsilon prevents division by zero.
        den = np.maximum(0.5 * (np.absolute(a) + np.absolute(b)), eps_prec)
        return np.absolute(a - b) / den

    # compute the energy and forces in the original location
    energy_orig = atoms.get_potential_energy()
    forces_orig = atoms.get_forces()

    # Get a random rotation matrix and rotate all atoms about the point
    # (L/2, L/2, L/2). For an objective model the choice of rotation center
    # does not matter. Positions are stored as rows, so r -> Q*r is r . Q^T.
    large_cell_len = atoms.get_cell()[0][0]
    center = np.asarray([0.5 * large_cell_len] * 3)
    rot = get_random_rotation_matrix()
    atoms.set_positions(np.dot(atoms.get_positions() - center, rot.T) + center)

    energy_rot = atoms.get_potential_energy()
    forces_rot = atoms.get_forces()

    # Get a random translation vector (length pi) and apply translation
    trans = np.multiply(get_random_unit_vector(), math.pi)
    atoms.translate(trans)
    energy_rot_trans = atoms.get_potential_energy()
    forces_rot_trans = atoms.get_forces()

    # check if energy is the same up to a numerical tolerance
    passed_energy = bool(rel_diff(energy_rot_trans, energy_orig) < tole)

    # Translation-only comparison (rotated vs. rotated+translated). It is used
    # only to diagnose a failure. It checks forces as well as energy, so the
    # "translationally invariant" note below is not based on energy alone.
    trans_passed = bool(
        rel_diff(energy_rot_trans, energy_rot) < tole
        and np.all(rel_diff(forces_rot_trans, forces_rot) < tole)
    )

    # report energy results
    vc.rwrite("")
    vc.rwrite(heading)
    vc.rwrite("-" * dashwidth)
    vc.rwrite("Rotation matrix    = {0: .8e}  {1: .8e}  {2: .8e}".format(*rot[0, :]))
    vc.rwrite("                     {0: .8e}  {1: .8e}  {2: .8e}".format(*rot[1, :]))
    vc.rwrite("                     {0: .8e}  {1: .8e}  {2: .8e}".format(*rot[2, :]))
    vc.rwrite("")
    vc.rwrite("Translation vector = {0: .8e}  {1: .8e}  {2: .8e}".format(*trans))
    vc.rwrite("")
    vc.rwrite("Energy requirement:")
    vc.rwrite("")
    vc.rwrite(
        "V(Q*r_1+c,...,Q*r_N+c) = V(r_1,...,r_N), "
        "where r_i is the position of atom i, V is the potential energy, "
    )
    vc.rwrite("Q is a rotation, and c is a translation vector.")
    vc.rwrite("")
    vc.rwrite("V(Q*r_1+c,...,Q*r_N+c) = {0}".format(energy_rot_trans))
    vc.rwrite("V(Q*r_1,...,Q*r_N)     = {0}".format(energy_rot))
    vc.rwrite("V(r_1,...,r_N)         = {0}".format(energy_orig))
    vc.rwrite("")

    # check forces for objectivity
    vc.rwrite("Forces requirement:")
    vc.rwrite("")
    vc.rwrite(
        "f_i(Q*r_1+c,...,Q*r_N+c) = Q*f_i(r_1,...,r_N), "
        "where r_i is the position of atom i, f_i is the force "
    )
    vc.rwrite("on atom i, Q is a rotation matrix, and c is a translation vector.")
    vc.rwrite("")

    hfmt = "{:>3}" + " " * 17 + "{}" + " " * 36 + "{}"
    fmt = "{:>3}   " + "{: .8e}   " * 3 + "|  " + "{: .8e}   " * 3 + "{}"
    vc.rwrite(hfmt.format("i", "f_i(Q*r_1+c,...,Q*r_N+c)", "Q*f_i(r_1,...,r_N)"))
    vc.rwrite("-" * dashwidth)
    passed_forces = True
    for i, f_lhs in enumerate(forces_rot_trans):
        # The original force rotated by Q must equal the new force
        f_rhs = np.dot(rot, forces_orig[i])
        force_ok = np.all(rel_diff(f_lhs, f_rhs) < tole)
        stat = ""
        if not force_ok:
            passed_forces = False
            stat = "ERR"
        vc.rwrite(
            fmt.format(
                i, f_lhs[0], f_lhs[1], f_lhs[2], f_rhs[0], f_rhs[1], f_rhs[2], stat
            )
        )
    vc.rwrite("-" * dashwidth)

    # determine overall result
    passed = bool(passed_energy and passed_forces)

    if passed:
        vc.rwrite(
            "PASS: Energies and forces are the same to within a "
            "relative error of {0}".format(tole)
        )
    else:
        vc.rwrite(
            "FAIL: Energies and/or forces differ by more than "
            "a relative error of {0}".format(tole)
        )
        if trans_passed:
            vc.rwrite("")
            vc.rwrite(
                "      NOTE: The model IS translationally invariant but "
                "not rotationally invariant"
            )
    vc.rwrite("-" * dashwidth)

    return passed


################################################################################
def _make_centered_bcc(symbol, lattice_constant, ncells_per_side, large_cell_len):
    """
    Build a non-periodic BCC cluster of ncells_per_side**3 unit cells of
    'symbol'. The cluster is embedded in a cubic cell of side
    'large_cell_len' and shifted by 0.5*large_cell_len along each axis.
    """
    atoms = BodyCenteredCubic(
        size=(ncells_per_side, ncells_per_side, ncells_per_side),
        latticeconstant=lattice_constant,
        symbol=symbol,
        pbc=False,
    )
    # Move our finite cluster of atoms into the middle of our large cell
    atoms.set_cell([large_cell_len, large_cell_len, large_cell_len])
    atoms.translate([0.5 * large_cell_len] * 3)
    return atoms


def _find_initial_config(
    calc, symbol, lattice_constant, ncells_per_side, large_cell_len, chemical_symbols=None
):
    """
    Return a centered BCC cluster, with 'calc' attached, for which
    kim_ase_utils.rescale_to_get_nonzero_forces succeeds.

    A fresh cluster is built for every attempt. If 'chemical_symbols' is
    given, it is applied to each cluster, so a random species assignment
    survives retries. After a failed attempt the lattice constant is
    increased by 0.25.

    Raises kim_ase_utils.KIMASEError unchanged, and RuntimeError once the
    lattice constant exceeds 10.0.
    """
    while True:
        atoms = _make_centered_bcc(
            symbol, lattice_constant, ncells_per_side, large_cell_len
        )
        if chemical_symbols is not None:
            atoms.set_chemical_symbols(chemical_symbols)
        atoms.calc = calc
        try:
            kim_ase_utils.rescale_to_get_nonzero_forces(atoms, 0.01)
            return atoms
        except kim_ase_utils.KIMASEError:
            # Routine failed in a non-recoverable manner
            raise  # re-raise same exception
        except Exception:
            # Initial config failed. This is most likely due to an evaluation
            # outside the legal model range. Increase lattice constant and
            # try again with a freshly built cluster.
            lattice_constant += 0.25
            if lattice_constant > 10.0:
                raise RuntimeError(
                    "Cannot find a working configuration within a reasonable "
                    "lattice constant range."
                )


def _randomize_until_computable(atoms, pert_amp, max_iters, label):
    """
    Randomly perturb the positions of 'atoms' until forces can be computed.

    After each failed attempt, the positions 'atoms' had on entry are
    restored and the perturbation amplitude is halved. Returns the amplitude
    that succeeded. Raises RuntimeError after 'max_iters' failed attempts;
    'label' names the species in that message.
    """
    save_positions = atoms.get_positions()
    iters = 0
    while True:
        try:
            kim_ase_utils.randomize_positions(atoms, pert_amp)
            atoms.get_forces()  # make sure forces can be computed
            return pert_amp
        except Exception:
            # Failed to compute forces; reset to original posns and retry
            atoms.set_positions(save_positions)
            pert_amp *= 0.5  # cut perturbation amplitude by half
            iters += 1
            if iters >= max_iters:
                raise RuntimeError(
                    "Iteration limit exceeded when randomizing positions "
                    "during check for species {}".format(label)
                )


def _check_and_report(vc, atoms, aux_file, heading, dashwidth):
    """
    Save 'atoms' to auxiliary file 'aux_file' and run the objectivity check.

    Returns the check result (True/False). Returns None if the check raised,
    i.e. the model could not compute this configuration; a warning is written
    to the report in that case.
    """
    vc.vc_files.append(aux_file)
    vc.write_aux_ase_atoms(aux_file, atoms, "xyz")
    try:
        return perform_objectivity_check(vc, atoms, heading, dashwidth)
    except Exception:
        # Best effort: only configurations the model can compute are graded,
        # but record that this one was skipped rather than hiding it.
        vc.rwrite("")
        vc.rwrite(
            "WARNING: The objectivity check could not be completed for the "
            'configuration in file "{}". Skipping...'.format(aux_file)
        )
        vc.rwrite("")
        return None


################################################################################
def do_vc(model, vc):
    """
    Do Objectivity VC.

    Runs the objectivity check on one randomly distorted BCC cluster per
    species supported by 'model'. Species without non-trivial energy and force
    interactions are skipped. When more than one species is supported, the
    check is also run on a cluster with randomly assigned species.

    Returns (vc_grade, vc_comment), where vc_grade is "P" or "F".
    Raises RuntimeError if no configuration could be checked.
    """
    # Max iterations allowed for some of the while loops below
    max_iters = 2000

    # Get supported species
    species = get_model_supported_species(model)
    species = kim_ase_utils.remove_species_not_supported_by_ASE(list(species))
    species.sort()

    # Basic cell parameters
    lattice_constant_orig = 3.0
    pert_amp_orig = 0.1 * lattice_constant_orig
    ncells_per_side = 2
    seed = 13
    random.seed(seed)

    # Finite domain in which to embed the finite cluster of atoms we'll rotate
    # and translate
    large_cell_len = 7 * lattice_constant_orig * ncells_per_side

    # Print Vc info
    dashwidth = 120
    vc.rwrite("")
    vc.rwrite("-" * dashwidth)
    vc.rwrite("Results for KIM Model      : %s" % model.strip())
    vc.rwrite("Supported species          : %s" % " ".join(species))
    vc.rwrite("")
    vc.rwrite("random seed                = %d" % seed)
    vc.rwrite("lattice constant (orig)    = %0.3f" % lattice_constant_orig)
    vc.rwrite("perturbation amplitude     = %0.3f" % pert_amp_orig)
    vc.rwrite("number unit cells per side = %d" % ncells_per_side)
    vc.rwrite("-" * dashwidth)
    vc.rwrite("")

    # Initialize variables
    got_atleast_one = False
    passed_all = True

    # Perform objectivity check for monoatomic systems
    for spec in species:
        # Check if this species has non-trivial force and energy interactions
        (
            atoms_interacting_energy,
            atoms_interacting_force,
        ) = kim_ase_utils.check_if_atoms_interacting(model, symbols=[spec, spec])
        if not atoms_interacting_energy:
            vc.rwrite("")
            vc.rwrite(
                "WARNING: The model provided, {}, does not possess a non-trivial energy "
                "interaction for species {} as required by this Verification "
                "Check. Skipping...".format(model, spec)
            )
            vc.rwrite("")
            continue

        if not atoms_interacting_force:
            vc.rwrite("")
            vc.rwrite(
                "WARNING: The model provided, {}, does not possess a non-trivial force "
                "interaction for species {} as required by this Verification Check.  "
                "Skipping...".format(model, spec)
            )
            vc.rwrite("")
            continue

        calc = KIM(model)
        try:
            atoms = _find_initial_config(
                calc, spec, lattice_constant_orig, ncells_per_side, large_cell_len
            )
            pert_amp = _randomize_until_computable(
                atoms, pert_amp_orig, max_iters, spec
            )
            # Move atoms around until all forces are sizeable
            kim_ase_utils.perturb_until_all_forces_sizeable(atoms, pert_amp)
            aux_file = "config-" + spec + ".xyz"
            heading = (
                "MONOATOMIC STRUCTURE -- Species = "
                + spec
                + '   (Configuration in file "'
                + aux_file
                + '")'
            )
            passed = _check_and_report(vc, atoms, aux_file, heading, dashwidth)
        finally:
            # Explicitly close calculator to ensure any allocated memory is freed
            # (relevant for SMs)
            if hasattr(calc, "__del__"):
                calc.__del__()

        if passed is not None:
            passed_all = passed_all and passed
            got_atleast_one = True

    # Perform objectivity check for mixed system
    if len(species) > 1:
        # Use enough unit cells that every species can appear at least once
        while True:
            atoms = _make_centered_bcc(
                "H", lattice_constant_orig, ncells_per_side, large_cell_len
            )
            if len(atoms) >= len(species):
                break
            ncells_per_side += 1

        kim_ase_utils.randomize_species(atoms, species)
        # Keep this random species assignment if the cluster has to be rebuilt
        # with a larger lattice constant
        chemical_symbols = atoms.get_chemical_symbols()
        mixed_label = " ".join(species)

        calc = KIM(model)
        try:
            atoms = _find_initial_config(
                calc,
                "H",
                lattice_constant_orig,
                ncells_per_side,
                large_cell_len,
                chemical_symbols=chemical_symbols,
            )
            pert_amp = _randomize_until_computable(
                atoms, pert_amp_orig, max_iters, mixed_label
            )
            kim_ase_utils.perturb_until_all_forces_sizeable(atoms, pert_amp)
            aux_file = "config-" + "".join(species) + ".xyz"
            heading = (
                "MIXED STRUCTURE -- Species = "
                + mixed_label
                + '   (Configuration in file "'
                + aux_file
                + '")'
            )
            passed = _check_and_report(vc, atoms, aux_file, heading, dashwidth)
        finally:
            # Explicitly close calculator to ensure any allocated memory is freed
            # (relevant for SMs)
            if hasattr(calc, "__del__"):
                calc.__del__()

        if passed is not None:
            passed_all = passed_all and passed
            got_atleast_one = True

    if not got_atleast_one:
        msg = (
            "ERROR: Failed to compute all configuration for the rigid-body translation "
            "verification check."
        )
        vc.rwrite("")
        vc.rwrite(msg)
        vc.rwrite("")
        raise RuntimeError(msg)

    # Compute grade
    vc.rwrite("")
    vc.rwrite("=" * dashwidth)
    vc.rwrite(
        "To pass this verification check the model must be invariant "
        "with respect to"
    )
    vc.rwrite("rigid-body motion (translation and rotation) for all configurations")
    vc.rwrite("it was able to compute.")
    vc.rwrite("")

    if passed_all:
        vc_grade = "P"
        vc_comment = (
            "Model energy and forces are invariant with respect "
            "to rigid-body motion (translation and rotation) for "
            "all configurations the model was able to compute."
        )
    else:
        vc_grade = "F"
        vc_comment = (
            "Model energy and/or forces are NOT invariant with respect "
            "to rigid-body translation and/or rotation for at "
            "least one configuration that the model was "
            "able to compute. This could be valid if the model "
            "includes an external field. Otherwise this is an "
            "error in the model implementation."
        )

    return vc_grade, vc_comment


################################################################################
#
#   MAIN PROGRAM
#
###############################################################################
if __name__ == "__main__":

    # Metadata describing this verification check for kim_vc_utils
    vcargs = dict(
        vc_name="vc-objectivity",
        vc_author=__author__,
        vc_description=kim_vc_utils.vc_stripall(__doc__),
        vc_category="informational",
        vc_grade_basis="passfail",
        vc_files=[],
        vc_debug=False,  # Set to True to get exception traceback info
    )

    # Ask for the model's extended KIM ID, then run the VC on it
    model = input("Model Extended KIM ID = ")
    kim_vc_utils.setup_and_run_vc(do_vc, model, **vcargs)