#!/usr/bin/env python3
################################################################################
#
#  CDDL HEADER START
#
#  The contents of this file are subject to the terms of the Common Development
#  and Distribution License Version 1.0 (the "License").
#
#  You can obtain a copy of the license at
#  http://www.opensource.org/licenses/CDDL-1.0.  See the License for the
#  specific language governing permissions and limitations under the License.
#
#  When distributing Covered Code, include this CDDL HEADER in each file and
#  include the License file in a prominent location with the name LICENSE.CDDL.
#  If applicable, add the following below this CDDL HEADER, with the fields
#  enclosed by brackets "[]" replaced with your own identifying information:
#
#  Portions Copyright (c) [yyyy] [name of copyright owner]. All rights reserved.
#
#  CDDL HEADER END
#
#  Copyright (c) 2019, Regents of the University of Minnesota.
#  All rights reserved.
#
#  Contributor(s):
#     Ellad B. Tadmor
#
################################################################################

# The docstring below is vc_description
"""This verification check ensures that a model supports all of the species that it
claims to. The verification check attempts to compute the energy for all clusters up to
a preset size for all combinations of the species listed by the model.  Several random
samples are generated for each cluster.  The check is successful if all elements appear
in at least one successful calculation, i.e. the model does not abort during the
calculation.  This algorithm accounts for the fact that some models may only support
species in certain combinations.  For example a model supporting species A and B  may be
for an AB alloy and may only support A-B interactions and not A-A or B-B. The output
provides full information on which interactions are supported and which are not.  Note
that this definition of 'support' does not imply that the model has non-trivial energy
and/or force interactions for any of the species."""

# Python 2/3 compatibility
from __future__ import print_function

import random
import itertools

import numpy as np
from ase.calculators.kim import KIM, get_model_supported_species
import kim_python_utils.ase as kim_ase_utils
import kim_python_utils.vc as kim_vc_utils
from ase import Atoms

__version__ = "002"
__author__ = "Ellad Tadmor"

# Python 2/3 compatibility shim for reading the model name at the bottom of
# this script: on Python 2, rebind 'input' to 'raw_input' so that input()
# returns the typed text as a plain string rather than evaluating it. On
# Python 3 'raw_input' does not exist, the NameError is ignored, and the
# builtin input() is used unchanged.
try:
    input = raw_input
except NameError:
    pass


################################################################################
#
#   FUNCTIONS
#
################################################################################
def pos_has_distance_less_than_dmin(pos, dmin):
    """
    Return True if any two of the points in 'pos' lie closer together than
    'dmin' (Euclidean distance), and False otherwise.

    Pairs are examined in index order and the search stops at the first pair
    found to be too close. Fewer than two points always yields False.
    """
    pair_indices = itertools.combinations(range(len(pos)), 2)
    return any(
        np.linalg.norm(pos[second] - pos[first]) < dmin
        for first, second in pair_indices
    )


################################################################################
def get_cluster_positions(N, rball, dmin):
    """
    Select 'N' positions within a sphere of radius 'rball' where no two points
    are closer than 'dmin'.

    Positions are generated by rejection sampling. A complete set of N points
    is drawn, and if any pair is closer than 'dmin' the whole set is discarded
    and redrawn. Each point gets a random direction that is uniform on the unit
    sphere (uniform phi and cos(theta)). Its radial distance is drawn uniformly
    from [0, rball]. This is NOT uniform in volume, because points are more
    likely to lie near the center than a volume-uniform draw would give.

    Randomness comes from the 'random' module, so results are reproducible
    once random.seed() has been called.

    Note: if N points separated by at least 'dmin' cannot fit inside the ball,
    the loop never terminates.

    Returns an (N, 3) numpy array of Cartesian coordinates about the origin.
    """
    pos = np.zeros((N, 3))
    points_too_close = True
    while points_too_close:
        # Draw N points inside a sphere of radius 'rball': uniform direction,
        # radius uniform in [0, rball] (not a volume-uniform distribution)
        for i in range(N):
            phi = random.uniform(0, 2 * np.pi)
            costheta = random.uniform(-1.0, 1.0)
            theta = np.arccos(costheta)
            x = np.sin(theta) * np.cos(phi)
            y = np.sin(theta) * np.sin(phi)
            z = np.cos(theta)
            rad = random.uniform(0, rball)
            pos[i] = (x * rad, y * rad, z * rad)
        # Reject the whole set if any pair violates the minimum separation
        points_too_close = pos_has_distance_less_than_dmin(pos, dmin)
    return pos


################################################################################
def _num_combinations_with_replacement(n, k):
    """
    Return the number of size-'k' combinations with replacement drawn from
    'n' items, i.e. the binomial coefficient C(n + k - 1, k).

    This equals len(list(itertools.combinations_with_replacement(range(n), k)))
    but is computed arithmetically. Oversized cluster sets can therefore be
    rejected without first materializing every combination in memory.
    """
    count = 1
    for i in range(1, k + 1):
        # After this step 'count' == C(n + i - 1, i), which is an integer, so
        # the floor division is exact.
        count = count * (n + i - 1) // i
    return count


def _try_compute_energy(model, symbols, positions, cell_len):
    """
    Try a single KIM energy evaluation for a finite, non-periodic cluster.

    'symbols' is a sequence of chemical symbols and 'positions' holds the
    matching Cartesian coordinates. The cluster is placed in a non-periodic
    cubic cell of side 'cell_len' and shifted by half the cell length along
    each axis. A new KIM calculator for 'model' is then attached.

    Returns True if the energy was computed and False if the evaluation raised.
    A raising evaluation is treated as the model not supporting this
    configuration. Only Exception subclasses are caught, so KeyboardInterrupt
    and SystemExit still stop the run. Errors raised while constructing the
    Atoms object or the KIM calculator are not caught.
    """
    atoms = Atoms(
        list(symbols),
        positions=positions,
        cell=[cell_len, cell_len, cell_len],
        pbc=(False, False, False),
    )
    # Move the finite cluster (centered on the origin) to the middle of the cell
    atoms.translate([0.5 * cell_len] * 3)

    calc = KIM(model)
    try:
        # Use the 'calc' property instead of the deprecated set_calculator()
        atoms.calc = calc
        atoms.get_potential_energy()
        return True
    except Exception:
        return False
    finally:
        # Explicitly close calculator to ensure any allocated memory is freed
        # (relevant for SMs)
        if hasattr(calc, "__del__"):
            calc.__del__()


def do_vc(model, vc):
    """
    Perform Species Supported as Stated VC

    For each cluster size from 2 to Nmax atoms, every combination of the
    model's species is considered, without regard to order. Nsamp random
    cluster geometries are generated for each combination, and an energy
    evaluation is attempted for each geometry. A species counts as supported
    once it appears in at least one successful evaluation.

    Parameters
    ----------
    model : str
        Extended KIM ID of the model under test.
    vc : object
        Verification-check context. Only its rwrite() method is used here, to
        emit lines of the report.

    Returns
    -------
    (vc_grade, vc_comment) : tuple of str
        vc_grade is "P" if every tested species is supported in at least one
        configuration and "F" otherwise. vc_comment explains the grade.
    """
    # Get supported species, dropping any that ASE cannot represent
    species = get_model_supported_species(model)
    species = kim_ase_utils.remove_species_not_supported_by_ASE(list(species))
    species.sort()

    # Initialize
    Nmax = 4  # clusters ranging from 2 to Nmax atoms will be tested (as long
    # as 'max_combinations' is not exceeded)
    max_combinations = 10000  # maximum number of allowed combinations
    rball = 1.5  # radius of ball containing clusters
    dmin = 1.0  # minimum allowed distance between atoms
    Nsamp = 10  # number of samples of each cluster
    seed = 13
    random.seed(seed)  # fixed seed so the sampled cluster geometries repeat
    # Assume none of the species are supported. Each one is marked True once
    # it appears in a successful calculation.
    species_is_supported = dict.fromkeys(species, False)
    yesno = {True: "yes", False: "** NO **"}

    # Finite domain in which to embed the finite cluster of atoms
    large_cell_len = 7 * rball

    # Print VC Header
    dashwidth = 80
    vc.rwrite("")
    vc.rwrite("-" * dashwidth)
    vc.rwrite("Results for KIM Model      : %s" % model.strip())
    vc.rwrite("Supported species          : %s" % " ".join(species))
    vc.rwrite("")
    vc.rwrite("Maximum cluster size considered (#atoms)  = %d" % Nmax)
    vc.rwrite("Maximum number of allowed combinations    = %d" % max_combinations)
    vc.rwrite("Radius of ball containing clusters:       = %0.3f" % rball)
    vc.rwrite("Minimum allowed distance between atoms:   = %0.3f" % dmin)
    vc.rwrite("Number of random samples for each cluster = %d" % Nsamp)
    vc.rwrite("Random number seed:                       = %d" % seed)
    vc.rwrite("-" * dashwidth)
    vc.rwrite("")

    for spec in species:
        # Check if this species has non-trivial force and energy interactions
        # (informational only; does not affect the grade)
        atoms_interacting_energy, atoms_interacting_force = kim_ase_utils.check_if_atoms_interacting(
            model, symbols=[spec, spec]
        )
        if not atoms_interacting_energy:
            vc.rwrite("")
            vc.rwrite(
                "INFO: The model provided, {}, does not possess a non-trivial energy "
                "interaction for species {}, i.e. the energy computed by the model for "
                "a given set of atoms of this species is always equal to the energy "
                "it computes for an isolated atom of this species multiplied by the "
                "number of atoms present. This message is strictly informational and "
                "will not affect the results of this Verification Check.".format(
                    model, spec
                )
            )
            vc.rwrite("")

        if not atoms_interacting_force:
            vc.rwrite("")
            vc.rwrite(
                "INFO: The model provided, {}, does not possess a non-trivial force "
                "interaction for species {} i.e. the forces computed by the model for "
                "a given set of atoms of this species will always be zero. This "
                "message is strictly informational, and will not affect the results of "
                "this Verification Check.".format(model, spec)
            )
            vc.rwrite("")

    # Loop over all clusters from 2 to Nmax considering all possible
    # species combinations without considering order.
    # For example for a model supporting 3 species (A,B,C) the loop would
    # be over 2-atom and 3-atom clusters with the following combinations:
    # A-A, B-B, C-C, A-B, A-C, B-C
    # A-A-A, B-B-B, C-C-C, A-A-B, A-A-C, B-B-A, B-B-C, C-C-A, C-C-B, A-B-C
    # Test whether the model supports the configuration. The model passes
    # the check if all species appear in at least one supported configuration.

    # Results table header
    table_rule = "-" * (2 * Nmax + 24)
    vc.rwrite("%-*s   %7s     %s" % (2 * Nmax, "Cluster", "Success", "Supported"))
    vc.rwrite(table_rule)

    for N in range(2, Nmax + 1):

        # Count the species combinations for a cluster of size N before
        # generating any of them, and skip this size if there are too many
        num_combinations = _num_combinations_with_replacement(len(species), N)
        if num_combinations > max_combinations:
            vc.rwrite(
                "...... skipping %d-cluster (%d combinations)" % (N, num_combinations)
            )
            continue

        # Loop over combinations, generated lazily
        for cluster in itertools.combinations_with_replacement(species, N):

            # Loop over samples
            num_success = 0
            for _ in range(Nsamp):

                # Assign positions to atoms in a cluster of N atoms in a ball of
                # radius 'rball' with no atoms closer than 'dmin'.
                pos = get_cluster_positions(N, rball, dmin)

                if _try_compute_energy(model, cluster, pos, large_cell_len):
                    num_success += 1
                    for spec in cluster:
                        species_is_supported[spec] = True

            # Report this combination
            vc.rwrite(
                "%-*s   %2d / %2d     %s"
                % (
                    2 * Nmax,
                    "".join(cluster),
                    num_success,
                    Nsamp,
                    yesno[num_success > 0],
                )
            )
    vc.rwrite(table_rule)

    vc.rwrite("")
    vc.rwrite("=" * dashwidth)
    vc.rwrite("To pass this verification check each model element must be supported")
    vc.rwrite("in at least one cluster combination.")
    vc.rwrite("")
    # NOTE(review): if no species remain after the ASE filter, all() over an
    # empty collection is True and the check passes vacuously -- confirm intended.
    if all(species_is_supported.values()):
        vc_grade = "P"
        vc_comment = "All elements claimed by model are supported in at least one of the tested configurations."
    else:
        vc_grade = "F"
        vc_comment = "At least one of the elements claimed by the model is not supported in any tested configuration."

    return vc_grade, vc_comment


################################################################################
#
#   MAIN PROGRAM
#
###############################################################################
if __name__ == "__main__":

    # Metadata describing this verification check, forwarded as keyword
    # arguments to the VC driver together with do_vc and the model name.
    vcargs = dict(
        vc_name="vc-species-supported-as-stated",
        vc_author=__author__,
        vc_description=kim_vc_utils.vc_stripall(__doc__),
        vc_category="mandatory",
        vc_grade_basis="passfail",
        vc_files=[],
        vc_debug=False,  # Set to True to get exception traceback info
    )

    # Prompt for and read the extended KIM ID of the model to check
    model = input("Model Extended KIM ID = ")

    # Hand everything to the shared VC driver, which invokes do_vc
    kim_vc_utils.setup_and_run_vc(do_vc, model, **vcargs)