#!/usr/bin/env python3
"""
This Test Driver computes the total potential energy and forces acting on a
finite cluster of atoms in a rectangular box, supplied in the form of an xyz
file which lies in each of its Tests' directories.  It then relaxes this config-
uration using the Polak-Ribiere version of conjugate gradient minimization (see
LAMMPS 'minimize' documentation for more details).

The xyz file specifying the initial structure must be of the following form:

Natoms
Comment line
Element positionx positiony positionz
Element positionx positiony positionz
.
.
.

where Element is either an integer species code as used by LAMMPS or a
chemical symbol.  The numbering of the species codes starts from 1 and
corresponds to the order of elements specified in pipeline.stdin.tpl:  the
first species given in pipeline.stdin.tpl corresponds to species 1, the second
to species 2, etc. (note that multiple species are separated by spaces).  A
chemical symbol must be one of the elements given in pipeline.stdin.tpl and is
mapped to the corresponding species code.

Author: Daniel S. Karls, University of Minnesota (karl0100 |AT| umn DOT edu)
"""

import json, subprocess, os, shutil, re
from collections import OrderedDict
from itertools import islice

import numpy

import xyz_utils


def stripquotes(matchobj):
    """Return the text captured by the first group of ``matchobj``.

    Meant to be used as the replacement callback of ``re.sub`` with a pattern
    whose first group sits between two quote characters, so that the match
    is replaced by its unquoted content.
    """
    unquoted_text = matchobj.group(1)
    return unquoted_text


def run_lammps(infile, outfile):
    """Run LAMMPS with given input file and write the output to outfile.

    ``outfile`` is created (or truncated) before the ``lammps`` executable is
    started, and LAMMPS' standard output is redirected into it.

    Raises ``Exception`` if LAMMPS exits with a non-zero status.  The message
    contains the content of ``log.lammps`` in the current working directory,
    or the text "no log file" if that file cannot be read.
    """
    with open(outfile, "w") as outfl:
        try:
            subprocess.check_call(
                ["lammps", "-in", infile], shell=False, stdout=outfl
            )
        except subprocess.CalledProcessError as err:
            # Attach whatever LAMMPS logged so the failure can be diagnosed
            try:
                with open("log.lammps") as logfl:
                    extrainfo = logfl.read()
            except IOError:
                extrainfo = "no log file"
            raise Exception(
                "LAMMPS did not exit properly:\n" + extrainfo
            ) from err


def get_isolated_atom_energy(model, element, mass):
    """Run the isolated-atom LAMMPS template and return the reported energy.

    The template ``isolated_atom.lammps.tpl`` (located in ``THIS_DIR``) is
    filled in with ``modelname=model``, ``symbol=element`` and ``mass=mass``
    and written to ``isolated_atom.lammps.<element>.in`` in the current
    working directory.  LAMMPS is run on it with its output captured in
    ``isolated_atom.lammps.<element>.out``.

    Returns the number found on the output line
    ``Isolated atom energy: <value> eV`` as a float.

    Raises ``Exception`` if no such line is present in the LAMMPS output.
    """
    template_path = os.path.join(THIS_DIR, "isolated_atom.lammps.tpl")
    with open(template_path) as template_file:
        input_template = template_file.read()

    templated_input = "isolated_atom.lammps." + element + ".in"
    lammps_output = "isolated_atom.lammps." + element + ".out"
    with open(templated_input, "w") as in_file:
        in_file.write(
            input_template.format(modelname=model, symbol=element, mass=mass)
        )

    run_lammps(templated_input, lammps_output)

    # Get potential energy
    with open(lammps_output) as outfl:
        output = outfl.read()

    # '+' is part of the character class so that values printed with an
    # explicit exponent sign (e.g. 1e+15) are captured in full
    energy_match = re.search(
        r"Isolated atom energy: ([0-9e.+\-]+) eV", output
    )
    if energy_match is None:
        raise Exception("Failed to find the potential energy")

    return float(energy_match.group(1))


# Read elements, masses, and model from stdin
elementinput = input("Elements: ")
print(elementinput)
# split() without an argument ignores repeated, leading and trailing
# whitespace and yields an empty list for blank input, so the check below
# can actually detect a missing species list
elements = elementinput.split()
numelements = len(elements)
if numelements < 1:
    raise RuntimeError("ERROR: No species were listed! Exiting...")
massesinput = input("Masses in g/mol (must match order of elements above): ")
print(massesinput)
masses = massesinput.split()
# One mass is looked up per element further down; fail early with a clear
# message instead of an IndexError in the middle of the run
if len(masses) < numelements:
    raise RuntimeError(
        "ERROR: {} species were listed but only {} masses were given! "
        "Exiting...".format(numelements, len(masses))
    )
model = input("Model: ")
print(model)
xyz_src = input("XYZ file: ")
print(xyz_src)
print("")

# Some directories we need
THIS_DIR = os.path.dirname(__file__)
with open(os.path.join(THIS_DIR, "main.lammps.tpl")) as template_file:
    MAIN_LAMMPS_INPUT_TEMPLATE = template_file.read()
INDIR = os.path.join("output", "lammps_inputs")
OUTDIR = os.path.join("output", "lammps_output_log")
DUMPDIR = os.path.join("output", "lammps_dump")

# Ensure the directories we need are created.  Already-existing directories
# are fine; any other failure (e.g. missing permissions) is reported here
# rather than surfacing later when a file cannot be written.
for directory in (INDIR, OUTDIR, DUMPDIR):
    os.makedirs(directory, exist_ok=True)

# Files LAMMPS uses or produces
# NOTE(review): the filled-in input script is written under a ".tpl" name;
# kept as-is since the path is part of the produced output
infile = os.path.join(INDIR, "main.lammps.tpl")
logfile = os.path.join(OUTDIR, "lammps.log")
dumpfile = os.path.join(DUMPDIR, "lammps.dump")

# Copy the xyz file over so it's preserved after we're done running.  Only
# the file name is used for the destination so that an xyz path containing
# directories (or an absolute path) still lands inside INDIR.
xyzfile = os.path.join(INDIR, os.path.basename(xyz_src))
shutil.copy(xyz_src, xyzfile)

# Create results dictionary: the two fixed header entries first, followed by
# one separate empty dict per reported quantity (filled in further down)
results = OrderedDict(
    [
        (
            "property-id",
            "tag:staff@noreply.openkim.org,2014-04-15:property/configuration-cluster-relaxed",
        ),
        ("instance-id", 1),
    ]
)
results.update(
    (quantity, {})
    for quantity in (
        "species",
        "unrelaxed-configuration-positions",
        "unrelaxed-configuration-forces",
        "unrelaxed-potential-energy",
        "relaxed-configuration-positions",
        "relaxed-configuration-forces",
        "relaxed-potential-energy",
    )
)

### REGULAR EXPRESSIONS FOR MATCHING LAMMPS OUTPUT
# Finds potential energy: the first number on the line following a line that
# contains "v_pe_metal".  '+' is accepted so that values written with an
# explicit exponent sign (e.g. 1.2345679e+08) are captured in full.
POTENG_MATCH = re.compile(
    r"""
        v_pe_metal      # MAGIC WORD
        .*$\n           # until end of line
        \s*             # possible leading whitespace
        ([0-9e.+\-]+)   # potential energy float, grouped
        """,
    flags=re.VERBOSE | re.MULTILINE,
)

# Finds number of atoms
NATOM_MATCH = re.compile(
    r"""
        ITEM:\ NUMBER\ OF\ ATOMS  # MAGIC WORDS
        .*$\n                     # until end of line
        \s*                       # possible leading whitespace
        (\d+)                     # Natoms, grouped
        """,
    flags=re.VERBOSE | re.MULTILINE,
)

# Finds the ATOMS lines in the LAMMPS dumpfile: everything between the first
# "ITEM: ATOMS" header line and the next "ITEM: TIMESTEP"
UNRELAXED_ATOMS_LINES = re.compile(
    r"""
        ITEM:\ ATOMS         #magic words
        .*?\n                #to end of line, nongreedy dot
        (.*?)                #unrelaxed atoms positions & forces, grouped
        ITEM:\ TIMESTEP
        """,
    flags=re.VERBOSE | re.DOTALL,
)

# Matches each "ITEM: ATOMS" header line.  The trailing nongreedy group
# always matches the empty string, so this pattern is only useful for
# locating/counting the ATOMS blocks, not for extracting their content.
RELAXED_ATOMS_LINES = re.compile(
    r"""
        ITEM:\ ATOMS         #magic words
        .*?\n                #to end of line, nongreedy dot
        (.*?)                #relaxed atoms positions & forces, grouped
        """,
    flags=re.VERBOSE | re.DOTALL,
)

# Read xyz file, convert the cell to LAMMPS' convention and create a
# dump file to read in
# NOTE: using a dump file rather than a data file means that it's not
#       sensitive to the specific atom_style in the event that we're
#       running with an SM, e.g. if the atom_style of the SM is
#       'charge', charges of zero will simply be assigned to each atom
#       since they're not specified in the dumpfile
numatoms, orig_cell, species, orig_pos = xyz_utils.read_xyz(xyzfile)

# Curate species from xyz file: every entry ends up as a 1-based integer
# code indexing into the elements the Test inputted
species_codes = []
for el in species:
    try:
        code = int(el)
    except ValueError:
        # Try to map the element string given to an integer code based on
        # the elements the Test inputted
        if el not in elements:
            raise RuntimeError(
                "Found species in xyz file {} that was not "
                "given in the elements inputted".format(xyz_src)
            )
        code = elements.index(el) + 1
    else:
        # An out-of-range code would otherwise either wrap around (code 0
        # picks the last element) or fail later with an IndexError
        if not 1 <= code <= numelements:
            raise RuntimeError(
                "Found species code {} in xyz file {} that is outside the "
                "range 1-{} of the elements inputted".format(
                    code, xyz_src, numelements
                )
            )
    species_codes.append(code)

# Calculate isolated atomic energies for each species; masses[i] belongs to
# elements[i]
isolated_atom_energies = {}
for species_index, symbol in enumerate(elements):
    atom_energy = get_isolated_atom_energy(model, symbol, masses[species_index])
    isolated_atom_energies[symbol] = atom_energy
    message = "Isolated atomic energy for species {}: {} eV".format(
        symbol, isolated_atom_energies[symbol]
    )
    print(message)
print("")

# Sum together isolated atomic energies of each atom in the xyz file
# (species codes are 1-based indices into elements)
isolated_atom_energies_sum = sum(
    isolated_atom_energies[elements[code - 1]] for code in species_codes
)
print("Sum of isolated atomic energies: {} eV".format(isolated_atom_energies_sum))
print("")

# Set atomic masses: two LAMMPS commands per atom type, one defining the
# converted mass variable and one assigning it to the (1-based) type
mass_commands = []
for type_id in range(1, numelements + 1):
    mass_commands.append(
        "variable mass{}_converted equal {}*${{_u_mass}}\n".format(
            type_id, masses[type_id - 1]
        )
    )
    mass_commands.append(
        "mass {} ${{mass{}_converted}}\n".format(type_id, type_id)
    )
mass_string = "".join(mass_commands)

# Create the LAMMPS input file by filling in the main template, then run
# LAMMPS on it with its output captured in the log file
template_fields = dict(
    modelname=model,
    symbol=elementinput,
    numelements=numelements,
    xyzfile=xyzfile,
    set_masses=mass_string,
    dumpfile=dumpfile,
)
with open(infile, "w") as in_file:
    in_file.write(MAIN_LAMMPS_INPUT_TEMPLATE.format(**template_fields))
run_lammps(infile, logfile)

### Now process the output and dumpfile for relevant information
# Get unrelaxed potential energy
with open(logfile) as outfl:
    output = outfl.read()

initial_energy_match = POTENG_MATCH.search(output)
if initial_energy_match is None:
    raise Exception(
        "Error: Failed to find the initial potential energy in the LAMMPS output log"
    )

raw_unrelaxed_energy = float(initial_energy_match.group(1))
print("Raw unrelaxed total potential energy: {} eV".format(raw_unrelaxed_energy))
unrelaxedpoteng = raw_unrelaxed_energy - isolated_atom_energies_sum
print(
    "Unrelaxed energy with isolated atomic energies subtracted: {} eV "
    "".format(unrelaxedpoteng)
)
print("")
print("")

# Get relaxed potential energy: take the log up to and including the last
# occurrence of "Loop time" (the leading group is greedy) and use the line
# directly before the one holding "Loop time"
loop_time_match = re.search(r"""(.*)(Loop time)""", output, re.DOTALL)
if loop_time_match is None:
    raise Exception(
        "Failed to find the relaxed potential energy in the LAMMPS output log"
    )
log_lines_up_to_loop_time = loop_time_match.group().split("\n")
line_before_loop_time = log_lines_up_to_loop_time[
    len(log_lines_up_to_loop_time) - 2
]

raw_relaxed_energy = float(line_before_loop_time.strip())
print("Raw relaxed total potential energy: {} eV".format(raw_relaxed_energy))
relaxedpoteng = raw_relaxed_energy - isolated_atom_energies_sum
print(
    "Relaxed energy with isolated atomic energies subtracted: {} eV "
    "".format(relaxedpoteng)
)
print("")
print("")

# Get configuration
with open(dumpfile) as dumpfl:
    dump = dumpfl.read()

# Find the number of atoms
natom_match = NATOM_MATCH.search(dump)
if natom_match is None:
    raise Exception("Failed to find the number of atoms")
natoms = int(natom_match.group(1))

# Process the rest of the dump file, the atom positions, etc
# The unrelaxed configuration is the first ATOMS block in the dump
unrelaxed_match = UNRELAXED_ATOMS_LINES.search(dump)
if unrelaxed_match is None:
    raise Exception("Failed to find the unrelaxed ATOMS block in the LAMMPS dump file")
unrelaxedatomslines = unrelaxed_match.group(1)

# The relaxed configuration is the last ATOMS block.  It is located by
# counting the blocks and skipping whole frames, which assumes every frame
# occupies exactly 9 header lines followed by natoms atom lines.
itercount = sum(1 for _ in RELAXED_ATOMS_LINES.finditer(dump))
if itercount == 0:
    raise Exception("Failed to find the relaxed ATOMS blocks in the LAMMPS dump file")
linecount = (itercount - 1) * (9 + natoms)
with open(dumpfile) as dumpfl:
    relaxedatomslines = [
        atom_line.rstrip()
        for atom_line in islice(dumpfl, linecount + 9, linecount + 9 + natoms)
    ]
# A truncated dump would otherwise silently produce fewer relaxed atoms
# than unrelaxed ones
if len(relaxedatomslines) != natoms:
    raise Exception("Failed to find the relaxed ATOMS blocks in the LAMMPS dump file")

def _parse_atom_line(atom_line):
    """Parse one atom line of the dump into (symbol, position, force, has_nan).

    The first column is dropped, the second is a 1-based index into
    ``elements``, the next three are stored as the position and the three
    after that as the force (presumably "id type x y z fx fy fz" - confirm
    against main.lammps.tpl).  ``has_nan`` tells whether any of those six
    numbers is NaN.
    """
    values = [float(token) for token in atom_line.split()[1:]]
    values[0] = elements[int(values[0]) - 1]
    has_nan = any(numpy.isnan(number) for number in values[1 : 6 + 1])
    position = [values[1], values[2], values[3]]
    force = [values[4], values[5], values[6]]
    return values[0], position, force, has_nan


# nanflag becomes 1 as soon as any reported quantity turns out to be NaN
nanflag = 0
# Check poteng for NaN
if any(numpy.isnan(float(energy)) for energy in (relaxedpoteng, unrelaxedpoteng)):
    nanflag = 1

# Create Python dictionaries to pack results into and convert to EDN
configspecies = {"source-value": []}
unrelaxedconfigpos = {"source-value": [], "source-unit": "angstrom"}
unrelaxedconfigforce = {"source-value": [], "source-unit": "eV/angstrom"}
for atom_line in unrelaxedatomslines.split("\n"):
    if not atom_line:
        # skip the empty pieces produced by splitting on newlines
        continue
    symbol, position, force, has_nan = _parse_atom_line(atom_line)
    if has_nan:
        nanflag = 1
    configspecies["source-value"].append(symbol)
    unrelaxedconfigpos["source-value"].append(position)
    unrelaxedconfigforce["source-value"].append(force)

results["species"] = configspecies
results["unrelaxed-configuration-positions"] = unrelaxedconfigpos
results["unrelaxed-configuration-forces"] = unrelaxedconfigforce
results["unrelaxed-potential-energy"]["source-value"] = float(unrelaxedpoteng)
results["unrelaxed-potential-energy"]["source-unit"] = "eV"

# Species are only recorded once (from the unrelaxed configuration); the
# relaxed block contributes positions and forces
relaxedconfigpos = {"source-value": [], "source-unit": "angstrom"}
relaxedconfigforce = {"source-value": [], "source-unit": "eV/angstrom"}
for atom_line in relaxedatomslines:
    _, position, force, has_nan = _parse_atom_line(atom_line)
    if has_nan:
        nanflag = 1
    relaxedconfigpos["source-value"].append(position)
    relaxedconfigforce["source-value"].append(force)

results["relaxed-configuration-positions"] = relaxedconfigpos
results["relaxed-configuration-forces"] = relaxedconfigforce
results["relaxed-potential-energy"]["source-value"] = float(relaxedpoteng)
results["relaxed-potential-energy"]["source-unit"] = "eV"

# If none of the reported quantities was NaN, print a results.edn file.
# The results are serialised with json using blank separators instead of
# "," and ":", and any quoted token made up solely of number characters has
# its quotes removed.
if nanflag == 0:
    results_text = json.dumps(
        results, separators=(" ", " "), indent=2, sort_keys=False
    )
    with open("output/results.edn", "w") as resultsedn:
        resultsedn.write(
            re.sub(r'"([0-9e\-\+\.]+)"', stripquotes, results_text)
        )