#!/usr/bin/env python3
"""
Lattice constant Test Driver for Hexagonal Structure

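Computes the zero-temperature equilibrium lattice constants a and c for any element in
a hexagonal close-packed (hcp) geometry by using Nelder-Mead simplex minimization.
Only the hcp basis is constructed by this driver; simple hexagonal and graphite
geometries are not handled, and the "Lattice" input is read but not used.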

Date: 2015/08/25
Author: Junhao Li <streaver91@gmail.com>

Last Update: 2016/02/26 Junhao Li

             2017/08/28 Daniel S. Karls

             2019/08/27 Daniel S. Karls
                       + Convert to python3
                       + Use model cutoff to determine geometries considered
                       + Add maxiter constraint on simplex minimizations
                       + Place lower limits on lattice parameters to prevent possible
                         neighbor list overflow
                       + Extend to work with toy models.  Specifically, if the initial
                         geometry guesses still place all atoms outside of the model
                         cutoff, the lattice is shrunk before performing the search over
                         lattice constants
"""
# Python 2/3 compatibility: print_function import and raw_input alias
from __future__ import print_function

try:
    # raw_input exists only on Python 2; there it is aliased to input below
    _py2_raw_input = raw_input
except NameError:
    # Python 3: keep the builtin input(), which already returns a str
    pass
else:
    input = _py2_raw_input

import os
import json
import math
from collections import OrderedDict

import numpy as np
from scipy.optimize import fmin
from ase import Atoms
from ase.calculators.kim.kim import KIM
import kim_python_utils.ase as kim_ase_utils

# Constants for result.edn
# (space group, units, value-entry keys and property IDs used when writing
# output/results.edn in the __main__ section)
SPACE_GROUP = "P63/mmc"
# NOTE(review): "WCYKOFF" is presumably a misspelling of "Wyckoff". This constant is not
# referenced elsewhere in this file; the name is left as-is in case other code imports
# it. Also confirm that "2a" is the intended Wyckoff letter for hcp in P63/mmc.
WCYKOFF_CODE = "2a"
# Field names used inside each value entry built by V()
KEY_SOURCE_VALUE = "source-value"
KEY_SOURCE_UNIT = "source-unit"
KEY_SOURCE_UNCERT = "source-std-uncert-value"
UNIT_LENGTH = "angstrom"
UNIT_TEMPERATURE = "K"
UNIT_PRESSURE = "GPa"
UNIT_ENERGY = "eV"
# KIM property definitions for the two property instances written by this driver
STRUCTURE_PROP_ID = (
    "tag:staff@noreply.openkim.org,2014-04-15:property/structure-hexagonal-crystal-npt"
)
ENERGY_PROP_ID = "tag:staff@noreply.openkim.org,2014-04-15:property/cohesive-potential-energy-hexagonal-crystal"

# Constants frequently used in calculation
SQRT3 = math.sqrt(3.0)
# Ideal hcp axial ratio c/a = sqrt(8/3). With this ratio an atom's nearest neighbor in
# an adjacent layer is exactly a away, the same as its in-plane nearest neighbors.
PERFECT_CA = math.sqrt(8.0 / 3.0)

# Set up a 4-atom orthorhombic unit cell
# Edge lengths for a = c = 1; energy() rescales the cell to [a, sqrt(3) * a, c]
# (atoms move with it because set_cell is called with scale_atoms=True).
HCP_CUBIC_CELL = np.array([1.0, SQRT3, 1.0])
# Cartesian positions of the 4 atoms in the a = c = 1 cell above: atoms 0 and 1 lie in
# the z = 0 layer, atoms 2 and 3 in the z = c/2 layer.
HCP_CUBIC_POSITIONS = np.array(
    [
        [0.0, 0.0, 0.0],
        [0.5, 0.5 * SQRT3, 0.0],
        [0.5, (0.5 / 3.0) * SQRT3, 0.5],
        [0.0, (0.5 + 0.5 / 3.0) * SQRT3, 0.5],
    ]
)
# Number of atoms in the cell above (get_lattice_constant uses len(atoms) instead).
HCP_CUBIC_NATOMS = 4

# Minimization Convergence Criteria
# Passed as ftol / xtol to every scipy.optimize.fmin call.
FMIN_FTOL = 1e-10
FMIN_XTOL = 1e-10


def check_neigh_per_atom(atoms, a, c, model_cutoff):
    """
    Reject lattice constants that are too small relative to the model cutoff.

    Despite its name, this function does not count neighbors. It is a guard against
    lattice collapse. If either ``a`` or ``c`` is no larger than 0.1 times
    ``model_cutoff``, the geometry would pack a very large number of atoms inside the
    cutoff, which risks overflowing the model's neighbor list. Raising here aborts the
    current energy evaluation instead.

    Parameters
    ----------
    atoms : ase.Atoms
        Unused; kept so that existing callers keep working.
    a, c : float
        Lattice constants in Angstroms.
    model_cutoff : float
        Model cutoff in Angstroms.

    Raises
    ------
    RuntimeError
        If ``a <= 0.1 * model_cutoff`` or ``c <= 0.1 * model_cutoff``.
    """
    min_allowed = 0.1 * model_cutoff
    if a <= min_allowed or c <= min_allowed:
        raise RuntimeError(
            "Attempted to evaluate energy at a = {} Angstroms and c = {} "
            "Angstroms; the model cutoff was estimated to be {} Angstroms.  This "
            "may mean that the model does not possess close-range repulsive forces "
            "to prevent system collapse.  In order to prevent possible neighbor list "
            "overflow, both a and c must be greater than 0.1 times the model cutoff..."
            "".format(a, c, model_cutoff)
        )


def energy(cellVector, atoms, isolated_energy_per_atom, model_cutoff):
    """
    Return the energy of ``atoms`` at the given lattice constants, relative to the
    same number of isolated atoms. This is the objective minimized with ``fmin``.

    Parameters
    ----------
    cellVector : sequence of float
        Either ``[a]``, in which case c is set to ``a * sqrt(8/3)``, or ``[a, c]``.
        Values are in Angstroms.
    atoms : ase.Atoms
        The 4-atom orthorhombic cell. Its cell is rescaled IN PLACE (atoms move with
        it) to ``[a, sqrt(3) * a, c]`` before the energy is evaluated.
    isolated_energy_per_atom : float
        Energy of one isolated atom; subtracted once per atom in ``atoms``.
    model_cutoff : float
        Model cutoff in Angstroms, used to reject collapsed geometries.

    Returns
    -------
    float
        ``atoms.get_potential_energy() - len(atoms) * isolated_energy_per_atom``.

    Raises
    ------
    RuntimeError
        From ``check_neigh_per_atom`` if a or c is at most 0.1 * model_cutoff.
    """
    a = cellVector[0]
    # A one-element vector means c/a is held at the ideal hcp value sqrt(8/3)
    c = cellVector[1] if len(cellVector) == 2 else a * PERFECT_CA

    atoms.set_cell(HCP_CUBIC_CELL * [a, a, c], scale_atoms=True)

    # Refuse collapsed geometries (a or c tiny compared to the cutoff) before asking
    # the model for an energy, to avoid possible neighbor list overflow
    check_neigh_per_atom(atoms, a, c, model_cutoff)

    total_energy = atoms.get_potential_energy()
    return total_energy - len(atoms) * isolated_energy_per_atom


def get_interlayer_atom_dist(latticeConstants):
    """
    Return the closest distance between an atom and an atom in an adjacent layer.

    In HCP_CUBIC_POSITIONS, atom 0 sits at the origin of the z = 0 layer and atom 2
    sits in the z = c/2 layer, offset in-plane by a / sqrt(3). (Atom 1 lies in the
    same layer as atom 0, so it is not the relevant neighbor.)

    Parameters
    ----------
    latticeConstants : sequence
        ``(a, c)`` in Angstroms; only the first two entries are used.

    Returns
    -------
    float
        ``sqrt((a / sqrt(3))**2 + (c / 2)**2)`` in Angstroms. For the ideal ratio
        c / a = sqrt(8/3) this equals a.
    """
    a = latticeConstants[0]
    c = latticeConstants[1]
    return math.hypot(a / SQRT3, c / 2)


def collapse_handler(latticeConstants):
    """
    Abort because the minimization converged to a non-interacting geometry.

    Parameters
    ----------
    latticeConstants : sequence
        The ``(a, c)`` pair, in Angstroms, that the search converged to. Entries may
        be plain numbers or size-1 numpy arrays. Size-1 arrays are converted to
        scalars so the message reads ``a = 2.5`` rather than ``a = [2.5]``.

    Raises
    ------
    RuntimeError
        Always. RuntimeError subclasses Exception, so callers that catch Exception
        still catch it.
    """
    scalars = []
    for value in latticeConstants:
        arr = np.asarray(value)
        scalars.append(arr.item() if arr.size == 1 else value)
    raise RuntimeError(
        "Calculation converged to lattice constants a = {} Angstroms and c "
        "= {} Angstroms, for which the atoms have the same energy as in isolation. "
        "This may indicate the lattice is unstable.".format(*scalars)
    )


def _print_banner(title):
    """Print ``title`` between two double rules of 78 '=' characters."""
    print("\n" + "=" * 78 + "\n" + "=" * 78)
    print(title)
    print("=" * 78 + "\n" + "=" * 78 + "\n")


def _relax_a_with_ideal_ca(atoms, init_guesses_a, isolated_energy, model_cutoff, maxiter):
    """
    Run one simplex minimization over 'a' (c/a fixed at sqrt(8/3)) per initial guess.

    Returns ``(best_constants, best_energy)``. ``best_constants`` is an ``(a, c)``
    tuple of floats for the lowest energy found below zero. It is ``None`` when every
    attempt failed or converged to a non-negative energy; ``best_energy`` is then 0.
    Failed attempts are reported and skipped.
    """
    best_constants = None
    best_energy = 0
    for init_guess_a in init_guesses_a:
        print("")
        print(
            "Attempting to perform relaxation using initial lattice constant guess "
            "of a = {} Angstroms (c/a fixed at sqrt(8/3))".format(init_guess_a)
        )
        print("")
        try:
            opt, opt_energy, _, _, _ = fmin(
                energy,
                [init_guess_a],
                args=(atoms, isolated_energy, model_cutoff),
                maxiter=maxiter,
                full_output=True,
                ftol=FMIN_FTOL,
                xtol=FMIN_XTOL,
            )
        except Exception as e:
            # Best effort: a starting guess the model cannot handle must not abort
            # the whole search; larger guesses are still tried.
            print(
                "Failed to perform relaxation w.r.t. 'a' with c/a fixed at "
                "sqrt(8/3) and an initial guess of a = {} Angstroms".format(
                    init_guess_a
                )
            )
            print("Exception message:\n{}".format(str(e)))
            print("Continuing...\n")
            continue

        print("Results for this minimization:")
        print("  a = {}, total energy = {}".format(opt[0], opt_energy))

        if opt_energy < best_energy:
            # Store plain floats rather than 1-element arrays so that later
            # arithmetic, np.arange and the fmin starting point are all scalar.
            a_opt = float(opt[0])
            best_constants = (a_opt, a_opt * PERFECT_CA)
            best_energy = opt_energy
    return best_constants, best_energy


def _relax_a_and_c(
    atoms,
    a_start,
    init_guesses_c,
    isolated_energy,
    model_cutoff,
    maxiter,
    best_constants,
    best_energy,
):
    """
    Run one simplex minimization over (a, c) per initial c guess, starting at a_start.

    A result replaces ``best_constants`` only if its energy is strictly lower than
    ``best_energy``. Returns the updated ``(best_constants, best_energy)``. Failed
    attempts are reported and skipped.
    """
    for init_guess_c in init_guesses_c:
        print("")
        print(
            "Attempting to perform relaxation using initial lattice constant guess "
            "of a = {} Angstroms and c = {} = ({} * a) Angstroms".format(
                a_start, init_guess_c, init_guess_c / a_start
            )
        )
        print("")
        try:
            opt, opt_energy, _, _, _ = fmin(
                energy,
                [a_start, init_guess_c],
                args=(atoms, isolated_energy, model_cutoff),
                maxiter=maxiter,
                full_output=True,
                ftol=FMIN_FTOL,
                xtol=FMIN_XTOL,
            )
        except Exception as e:
            print(
                "     Failed to perform relaxation w.r.t. both 'a' and 'c' with an "
                "initial guess of a = {} Angstroms and c = {} Angstroms".format(
                    a_start, init_guess_c
                )
            )
            print("     Exception message:\n{}".format(str(e)))
            print("     Continuing...\n")
            continue

        print("Results for this minimization:")
        print(
            "  a = {}, c = {}, total energy = {}".format(opt[0], opt[1], opt_energy)
        )

        if opt_energy < best_energy:
            best_constants = (float(opt[0]), float(opt[1]))
            best_energy = opt_energy
    return best_constants, best_energy


def get_lattice_constant(symbol, model, maxiter):
    """
    Find the equilibrium hcp lattice constants of ``symbol`` for the KIM ``model``.

    The search runs in two stages. Each stage tries several starting guesses with
    scipy's simplex ``fmin`` and keeps the lowest energy found:

    1. Relax ``a`` with c/a fixed at sqrt(8/3). Guesses for ``a`` run from 1.0
       Angstrom up to ``max(3.5, model_cutoff)``. If 1.0 Angstrom is already at or
       beyond the model cutoff, the box is first shrunk with
       ``rescale_to_get_nonzero_energy`` and the guesses are taken around the
       rescaled ``a``.
    2. Relax ``a`` and ``c`` together, starting from the stage-1 ``a``. Guesses for
       ``c`` lie between 0.8 and 1.2 times ``sqrt(8/3) * a``.

    Parameters
    ----------
    symbol : str
        Chemical symbol of the element.
    model : str
        KIM model name, passed to the ``KIM`` calculator.
    maxiter : int
        ``maxiter`` used for each individual ``fmin`` call.

    Returns
    -------
    tuple of float
        ``(a, c, energy_per_atom)``. ``a`` and ``c`` are in Angstroms.
        ``energy_per_atom`` is the energy per atom relative to the isolated atoms, in
        eV; it is negative for a bound crystal.

    Raises
    ------
    RuntimeError
        In any of these cases:

        - the model has no energy interaction for ``symbol``;
        - no stage-1 relaxation reaches an energy below that of the isolated atoms;
        - the converged lattice has no intralayer interactions within the model
          cutoff;
        - the converged lattice has no interlayer interactions within the model
          cutoff.

        ``collapse_handler`` also raises if the best total energy is within 1e-8 of
        zero.
    """
    # Check if atoms have an energy interaction for this model
    atoms_interacting = kim_ase_utils.check_if_atoms_interacting(
        model, symbols=[symbol, symbol], check_force=False, etol=1e-6
    )
    if not atoms_interacting:
        raise RuntimeError(
            "The model provided, {}, does not possess a non-trivial energy interaction "
            "for species {} as required by this Test.  Aborting.".format(model, symbol)
        )

    model_cutoff = kim_ase_utils.get_model_energy_cutoff(
        model, symbols=[symbol, symbol]
    )
    print("Computed model cutoff of {} Angstroms".format(model_cutoff))

    # Keyed by species: the dict itself is passed to rescale_to_get_nonzero_energy
    isolated_energy_per_atom = {
        symbol: kim_ase_utils.get_isolated_energy_per_atom(model, symbol)
    }
    isolated_energy = isolated_energy_per_atom[symbol]
    print(
        "Computed isolated atomic energy of {} eV for species {}"
        "".format(isolated_energy, symbol)
    )

    # Create lattice and calculator
    atoms = Atoms(
        symbol + "4", positions=HCP_CUBIC_POSITIONS, cell=HCP_CUBIC_CELL, pbc=True
    )
    num_atoms = len(atoms)
    atoms.calc = KIM(model)

    # ---- Stage 1: relax 'a' with c/a fixed to sqrt(8/3) --------------------------
    # With this ratio the smallest distance between atoms in adjacent layers equals
    # 'a', so an unstable lattice shows up as an energy equal to the sum of the
    # isolated atomic energies.
    min_init_guess_a = 1.0
    max_init_guess_a = max(3.5, model_cutoff)
    init_guess_a_incr = 0.125

    _print_banner(
        " Starting optimization to find optimum lattice constant 'a' with "
        "c/a fixed to\n sqrt(8/3)"
    )

    # Rare (toy models): even the smallest guess leaves every atom outside the
    # cutoff. Shrink the box first, then search around the rescaled 'a'. Guesses
    # that are too small for other models simply fail and are skipped in the loop.
    if min_init_guess_a >= model_cutoff:
        print("")
        print(
            "Detected that energy at a lattice constant of a = {} Angstroms (with "
            "c/a fixed to sqrt(8/3)) was equal to the isolated energy (i.e. no "
            "interactions). Shrinking box until energy becomes non-trivial...\n"
            "".format(min_init_guess_a)
        )
        kim_ase_utils.rescale_to_get_nonzero_energy(
            atoms, isolated_energy_per_atom, etol=1e-6
        )
        rescaled_a = atoms.get_cell()[0, 0]
        min_init_guess_a = 0.5 * rescaled_a
        max_init_guess_a = 1.2 * rescaled_a
        init_guess_a_incr = 0.1 * rescaled_a

    init_guesses_a = np.arange(
        min_init_guess_a, max_init_guess_a + init_guess_a_incr, init_guess_a_incr
    )
    min_lattice_constants, min_energy = _relax_a_with_ideal_ca(
        atoms, init_guesses_a, isolated_energy, model_cutoff, maxiter
    )

    if min_lattice_constants is None:
        # Previously this case crashed with a NameError on an unset variable
        raise RuntimeError(
            "None of the relaxations of 'a' with c/a fixed to sqrt(8/3) reached an "
            "energy below that of the isolated atoms for species {} (each attempt "
            "either failed or converged to a non-negative relative energy).  This "
            "may indicate the lattice is unstable.".format(symbol)
        )
    if abs(min_energy) < 1e-8:
        collapse_handler(min_lattice_constants)

    # ---- Stage 2: relax both a and c ---------------------------------------------
    a_start = min_lattice_constants[0]
    _print_banner(" Starting optimization with both a and c/a relaxed")

    min_init_guess_c = 0.8 * PERFECT_CA * a_start
    max_init_guess_c = 1.2 * PERFECT_CA * a_start
    init_guess_c_incr = 0.05 * PERFECT_CA * a_start
    init_guesses_c = np.arange(
        min_init_guess_c, max_init_guess_c + init_guess_c_incr, init_guess_c_incr
    )
    min_lattice_constants, min_energy = _relax_a_and_c(
        atoms,
        a_start,
        init_guesses_c,
        isolated_energy,
        model_cutoff,
        maxiter,
        min_lattice_constants,
        min_energy,
    )
    # No second collapse check is needed here: stage 2 only accepts energies lower
    # than the stage-1 minimum, which already passed the check above.

    # Now check to see if there are still interlayer and intralayer interactions. If
    # either is missing, raise an exception
    final_a, final_c = min_lattice_constants

    if final_a > model_cutoff:
        raise RuntimeError(
            "Lattice converged to a = {} Angstroms and c = {} Angstroms "
            "(c/a = {:.6f}). The model cutoff computed was {} Angstroms.  Since "
            "a > model_cutoff, the minimization has converged to a state where "
            "there are no longer intralayer forces, indicating the lattice is "
            "unstable.".format(final_a, final_c, float(final_c / final_a), model_cutoff)
        )

    final_min_interlayer_atom_dist = get_interlayer_atom_dist(min_lattice_constants)

    if final_min_interlayer_atom_dist > model_cutoff:
        raise RuntimeError(
            "Lattice converged to a = {} Angstroms and c = {} "
            "Angstroms (c/a = {:.6f}), which means the distance between an atom "
            "and its nearest neighbor in an adjacent layer is {} Angstroms.  The "
            "model cutoff computed was {} Angstroms.  Since the smallest "
            "interlayer atom distance is greater than the model_cutoff, the "
            "minimization has converged to a state where there are no longer "
            "interlayer forces, indicating the lattice is unstable.".format(
                final_a,
                final_c,
                float(final_c / final_a),
                final_min_interlayer_atom_dist,
                model_cutoff,
            )
        )

    _print_banner(" Finished with optimization")
    print(" Final lattice Constants:")
    print(
        "   a = {} Angstroms, c = {} Angstroms, c/a = {:.6f}".format(
            final_a, final_c, float(final_c / final_a)
        )
    )
    print("\n Total energy: {} eV".format(min_energy))
    print("\n Cohesive energy: {} eV/atom".format(min_energy / num_atoms))
    print("\n" + "=" * 78 + "\n" + "=" * 78 + "\n\n\n")

    return (
        float(final_a),
        float(final_c),
        float(min_energy / num_atoms),
    )


def V(value, unit="", uncert=""):
    """
    Build one value entry for the KIM JSON/EDN output as an OrderedDict.

    The entry always holds ``value`` under KEY_SOURCE_VALUE. ``unit`` and ``uncert``
    are appended, in that order, only when they differ from the empty string.
    """
    entry = OrderedDict()
    entry[KEY_SOURCE_VALUE] = value
    optional_fields = ((KEY_SOURCE_UNIT, unit), (KEY_SOURCE_UNCERT, uncert))
    for key, field_value in optional_fields:
        if field_value != "":
            entry[key] = field_value
    return entry


if __name__ == "__main__":

    # Read the three driver inputs from stdin and echo them back.
    # The lattice name is echoed only; it does not affect the calculation.
    symbol = input("Element = ")
    lattice = input("Lattice = ")
    model = input("Model = ")
    for label, echoed in (("Element", symbol), ("Lattice", lattice), ("Model", model)):
        print("{}: {}".format(label, echoed))

    a, c, cohesive_energy = get_lattice_constant(symbol, model, maxiter=500)

    # Fields shared by both property instances, appended after each one's own fields
    shared_fields = [
        ("short-name", V(["hcp"])),
        ("species", V([symbol, symbol])),
        ("a", V(a, UNIT_LENGTH)),
        ("c", V(c, UNIT_LENGTH)),
        ("basis-atom-coordinates", V([[0.0, 0.0, 0.0], [2.0 / 3, 1.0 / 3, 0.5]])),
        ("space-group", V(SPACE_GROUP)),
    ]

    structure_instance = OrderedDict(
        [
            ("property-id", STRUCTURE_PROP_ID),
            ("instance-id", 1),
            ("cauchy-stress", V([0.0] * 6, UNIT_PRESSURE)),
            ("temperature", V(0, UNIT_TEMPERATURE)),
        ]
        + shared_fields
    )
    # get_lattice_constant returns energy relative to isolation (negative when
    # bound); the cohesive energy is reported with the opposite sign
    energy_instance = OrderedDict(
        [
            ("property-id", ENERGY_PROP_ID),
            ("instance-id", 2),
            ("cohesive-potential-energy", V(-cohesive_energy, UNIT_ENERGY)),
        ]
        + shared_fields
    )

    serialized = json.dumps(
        [structure_instance, energy_instance], separators=(" ", " "), indent=4
    )
    with open(os.path.abspath("output/results.edn"), "w") as results_file:
        results_file.write(serialized)