#!/usr/bin/env python
"""
LAMMPS cluster test:
This test computes the total potential energy and forces acting on a finite cluster of atoms in a rectangular box,
supplied by the user in the form of an xyz file which lies in the Test's directory.  It then relaxes this
configuration using the Polak-Ribiere version of conjugate gradient minimization (see the LAMMPS 'minimize'
documentation for more details).

The xyz file specifying the initial structure must be of the following form:
Natoms
xlo xhi ylo yhi zlo zhi
Element positionx positiony positionz
Element positionx positiony positionz
.
.
.
where xlo, xhi, etc. define a rectangular box (fixed boundary condition) in which the atoms are placed.
(Note that all entries in the xyz file are separated by single spaces)
The pipeline.stdin.tpl file which accompanies the Test that uses this Test Driver should supply three lines:
a list of elements in the form 'Element1 Element2 ...' (separated by single spaces), a list of the masses
(ordered the same as the elements list, separated by single spaces), and the name of the model to use.

Author: Daniel S. Karls, University of Minnesota (karl0100 |AT| umn DOT edu)
"""

import fileinput, json, subprocess, os, itertools, shutil, re, numpy, pprint
from collections import namedtuple, OrderedDict
from itertools import islice

def stripquotes(matchobj):
    """Return the text captured by the first group of a regex match.

    Meant to be passed to ``re.sub`` as the replacement callback, so that
    every match is replaced by the contents of its first capture group.
    """
    captured = matchobj.group(1)
    return captured

# Collect the run parameters from stdin, one line each, in the order:
# elements, masses, model name
elementinput = raw_input("elements=")
massesinput = raw_input("masses (must match order of elements above)=")
model = raw_input("modelname=")
print(' ')

# Both lists are delimited by single spaces; masses[i] belongs to elements[i]
elements = elementinput.split(' ')
masses = massesinput.split(' ')
numelements = len(elements)

# Some directories we need
THIS_DIR = os.path.dirname(__file__)
# Text of the LAMMPS input template that sits next to this script.  It is read
# inside a `with` block so the handle is closed immediately instead of being
# left for the garbage collector.
with open(os.path.join(THIS_DIR,'in.lammps')) as template_file:
    LAMMPS_INPUT_FILE = template_file.read()
# Output locations, relative to the current working directory
INDIR = os.path.join('output','lammps_inputs')
OUTDIR = os.path.join('output','lammps_output_log')
DUMPDIR = os.path.join('output','lammps_dump')

# Ensure the directories we need are created.  An OSError is ignored only when
# the directory is already there; any other failure (missing permissions, a
# regular file in the way, ...) is re-raised here instead of surfacing later
# as a less obvious error when the files inside are opened.
for needed_dir in (INDIR, OUTDIR, DUMPDIR):
    try:
        os.makedirs(needed_dir)
    except OSError:
        if not os.path.isdir(needed_dir):
            raise

# Paths of the files exchanged with LAMMPS
# Written by this script: the filled-in input script and the atomic data file
infile = os.path.join(INDIR, 'in.lammps')
datafile = os.path.join(INDIR, 'cluster.data')
# Written during the LAMMPS run: the captured output log and the dump file
outfile = os.path.join(OUTDIR, 'lammps.log')
dumpfile = os.path.join(DUMPDIR, 'lammps.dump')

# Create results dictionary
# The two metadata entries go in first; OrderedDict remembers insertion order,
# so the keys keep this order when the dictionary is serialized.
results = OrderedDict([
    ('property-id', "tag:staff@noreply.openkim.org,2014-04-15:property/configuration-cluster-relaxed"),
    ('instance-id', 1),
])
# Every reported quantity starts out as its own empty dictionary
for quantity in ('species',
                 'unrelaxed-configuration-positions',
                 'unrelaxed-configuration-forces',
                 'unrelaxed-potential-energy',
                 'relaxed-configuration-positions',
                 'relaxed-configuration-forces',
                 'relaxed-potential-energy'):
    results[quantity] = {}

### REGULAR EXPRESSIONS FOR MATCHING LAMMPS OUTPUT
# Finds potential energy: group 1 is the first token on the line that follows
# a line containing "PotEng".  The number class accepts '+' and 'E' so that
# values printed in scientific notation with a positive exponent (e.g.
# 1.5e+10) are captured whole rather than cut off after the 'e', and a
# non-finite value (nan/inf, optionally signed) is captured as well so that
# it can be recognised as such after conversion with float().
POTENG_MATCH = re.compile(r"""
        PotEng        #MAGIC Word
        .*$\n         #to end of line
        \s*           #possible leading whitespace
        (             #potential energy float, grouped
          [+\-]?(?:nan|inf)     #non-finite value
          |
          [0-9eE.+\-]+          #decimal or scientific notation
        )
        """,flags=re.VERBOSE|re.MULTILINE)

# Finds number of atoms
# Matches the "ITEM: NUMBER OF ATOMS" header line and captures, as group 1,
# the run of digits that follows it (after skipping any whitespace).  When
# used with search(), the first such header in the text is the one reported.
NATOM_MATCH = re.compile(r"""
        ITEM:\ NUMBER\ OF\ ATOMS  #MAGIC WORDS
        .*$\n                     #to end of line
        \s*                       #possible leading whitespace
        (\d+)                     #Natoms, grouped
        """,flags=re.VERBOSE|re.MULTILINE)

# Finds the ATOMS lines in the LAMMPS dumpfile
# With DOTALL the dot also matches newlines.  Group 1 is everything between
# the end of an "ITEM: ATOMS" header line and the next "ITEM: TIMESTEP", so a
# match requires that another snapshot follows the ATOMS block.  When used
# with search(), this is the ATOMS block of the first snapshot.
UNRELAXED_ATOMS_LINES = re.compile(r"""
        ITEM:\ ATOMS         #magic words
        .*?\n                #to end of line, nongreedy dot
        (.*?)                #unrelaxed atoms positions & forces, grouped
        ITEM:\ TIMESTEP
        """,flags=re.VERBOSE|re.DOTALL)

# Matches each "ITEM: ATOMS" header line.  The trailing non-greedy group has
# nothing after it to stop at, so it always captures the empty string; the
# pattern is therefore only useful for locating/counting the ATOMS headers
# (the script iterates over finditer() to count the snapshots in the dump).
RELAXED_ATOMS_LINES = re.compile(r"""
        ITEM:\ ATOMS         #magic words
        .*?\n                #to end of line, nongreedy dot
        (.*?)                #relaxed atoms positions & forces, grouped
        """,flags=re.VERBOSE|re.DOTALL)

# Read xyz file and write LAMMPS data file
# Every element needs a mass; fail early with a clear message rather than with
# an IndexError after the data file has been half written.
if len(masses) < numelements:
    raise Exception("Fewer masses (%d) than elements (%d) were supplied" % (len(masses), numelements))

positions=[]  # [x, y, z] of every atom, kept as the strings found in the xyz file
with open('cluster.xyz','r') as xyzfile, open(datafile,'w') as datafl:
    numatoms = int(xyzfile.readline()) # get number of atoms
    # Box dimensions are read from the comment line.  split() without an
    # argument also copes with repeated blanks, tabs and trailing whitespace,
    # which split(' ') would turn into empty fields.
    boxdims = xyzfile.readline().split()
    if len(boxdims) < 6:
        raise Exception("Line 2 of cluster.xyz must contain 'xlo xhi ylo yhi zlo zhi'")
    datafl.write("Cluster test driver\n")
    datafl.write("%d atoms\n" % numatoms)
    datafl.write("%s %s xlo xhi\n" % (boxdims[0],boxdims[1]))
    datafl.write("%s %s ylo yhi\n" % (boxdims[2],boxdims[3]))
    datafl.write("%s %s zlo zhi\n" % (boxdims[4],boxdims[5]))
    datafl.write(' \n')
    datafl.write("%d atom types\n" % numelements)
    datafl.write(' \n')
    datafl.write('Atoms\n')
    datafl.write(' \n')
    # One line per atom: id (1-based, in file order), type (1-based index into
    # the elements list) and the coordinates copied verbatim.
    # range() is used because it behaves the same under Python 2 and 3.
    for atom_count in range(numatoms):
        atomln = xyzfile.readline().split()
        if len(atomln) < 4:
            raise Exception("cluster.xyz: expected 'Element x y z' for atom %d of %d"
                            % (atom_count+1, numatoms))
        try:
            atomtype = elements.index(atomln[0])+1
        except ValueError:
            raise Exception("cluster.xyz: element '%s' of atom %d is not in the elements list"
                            % (atomln[0], atom_count+1))
        positions.append([atomln[1],atomln[2],atomln[3]]) # Store positions as strings
        datafl.write("%d %d %s %s %s\n" % (atom_count+1,atomtype,atomln[1],atomln[2],atomln[3]))
    datafl.write(' \n')
    datafl.write('Masses\n')
    datafl.write(' \n')
    for type_count in range(numelements):
        datafl.write("%d %s\n" % (type_count+1,masses[type_count]))
    # No explicit close(): the with statement closes both files on exit

# Fill the placeholders of the input template with the values of this run and
# save the result as the script LAMMPS will read
template_values = dict(modelname=model,
                       symbol=elementinput,
                       datafile=datafile,
                       dumpfile=dumpfile)
with open(infile,'w') as in_file:
    rendered_input = LAMMPS_INPUT_FILE.format(**template_values)
    in_file.write(rendered_input)

# Run LAMMPS with given infile and write the output to outfile
with open(infile) as infl, open(outfile,'w') as outfl:
    try:
        # The input script is fed on stdin and standard output is captured in
        # outfile.  check_call raises CalledProcessError when the exit status
        # is non-zero; its return value (always 0) is of no use, so it is not
        # kept.
        subprocess.check_call('lammps', shell=False, stdout=outfl, stdin=infl)
    except subprocess.CalledProcessError:
        # Attach the contents of log.lammps (looked up in the current working
        # directory) to the error to help with diagnosis
        try:
            with open("log.lammps") as f:
                extrainfo = f.read()
        except IOError:
            extrainfo = "No LAMMPS log file (log.lammps) found"
        raise Exception("LAMMPS did not exit properly:\n"+extrainfo)

### Now process the output and dumpfile for relevant information
# Get potential energy
with open(outfile) as logfl:
    output = logfl.read()

# Unrelaxed energy: group 1 of the first POTENG_MATCH hit in the log
initial_energy_match = POTENG_MATCH.search(output)
try:
    unrelaxedpoteng = initial_energy_match.group(1)
except AttributeError:
    raise Exception("Error: Failed to find the initial potential energy in the LAMMPS output log")

# Relaxed energy: the greedy (.*) makes the match end at the last "Loop time"
# in the log; the line right before the one holding that text is taken,
# stripped of surrounding whitespace, as the relaxed potential energy.
loop_time_match = re.search(r"""(.*)(Loop time)""",output,re.DOTALL)
try:
    log_lines = loop_time_match.group().split('\n')
except AttributeError:
    raise Exception("Failed to find the relaxed potential energy in the LAMMPS output log")
second_to_last = len(log_lines)-2
relaxedpoteng = log_lines[second_to_last].strip()

# Get configuration: load the whole dump file into memory
with open(dumpfile) as dumpfl:
    dump = dumpfl.read()

# Find the number of atoms, as reported by the first NUMBER OF ATOMS item
natoms_match = NATOM_MATCH.search(dump)
try:
    natoms = int(natoms_match.group(1))
except AttributeError:
    raise Exception("Failed to find the number of atoms")

# Process the rest of the dump file, the atom positions, etc
# Unrelaxed configuration: the atom lines of the first snapshot in the dump
try:
    unrelaxedatomslines = UNRELAXED_ATOMS_LINES.search(dump).group(1)
except AttributeError:
    raise Exception("Failed to find the unrelaxed ATOMS block in the LAMMPS dump file")

# Relaxed configuration: the atom lines of the last snapshot.  Every snapshot
# is taken to span 9 header lines plus one line per atom, so the start of the
# last block follows from the number of "ITEM: ATOMS" headers in the dump.
itercount = sum(1 for _ in RELAXED_ATOMS_LINES.finditer(dump))
linecount = (itercount-1)*(9+natoms)
with open(dumpfile) as dumpfl:
    relaxedatomslines = [el.rstrip() for el in islice(dumpfl,linecount+9,linecount+9+natoms)]
# A truncated or unexpectedly laid out dump yields fewer lines than atoms;
# report that instead of silently returning an incomplete configuration.
if len(relaxedatomslines) != natoms:
    raise Exception("Failed to find the relaxed ATOMS blocks in the LAMMPS dump file")

# nanflag becomes 1 as soon as any reported quantity turns out to be NaN.
# Check poteng for NaN first (relaxed, then unrelaxed; any() stops at the
# first NaN it meets).
energies_have_nan = any(numpy.isnan(float(energy_text))
                        for energy_text in (relaxedpoteng, unrelaxedpoteng))
nanflag = 1 if energies_have_nan else 0

# Create Python dictionaries to pack results into and convert to EDN
configspecies = {'source-value': []}
unrelaxedconfigpos = {'source-value': [], 'source-unit': "angstrom"}
unrelaxedconfigforce = {'source-value': [], 'source-unit': "eV/angstrom"}

# NOTE(review): positions are paired with dump rows by running count, which
# assumes the dump lists the atoms in the same order as cluster.xyz -- confirm
# against the dump settings in in.lammps.
atom_count = 0
for line in unrelaxedatomslines.split('\n'):
    if not line:
        continue
    # The first column is dropped; the remaining ones are treated as
    # type, x, y, z, fx, fy, fz
    fields = [float(token) for token in line.split()[1:]]
    species = elements[int(fields[0])-1]
    for column in range(1, 7):
        if numpy.isnan(fields[column]):
            nanflag = 1
    # Report the xyz positions exactly as given in the input file (strings)
    xyz = positions[atom_count]
    atom_count += 1
    configspecies['source-value'].append(species)
    unrelaxedconfigpos['source-value'].append([xyz[0], xyz[1], xyz[2]])
    unrelaxedconfigforce['source-value'].append(fields[4:7])

results['species'] = configspecies
results['unrelaxed-configuration-positions'] = unrelaxedconfigpos
results['unrelaxed-configuration-forces'] = unrelaxedconfigforce
unrelaxed_energy = results['unrelaxed-potential-energy']
unrelaxed_energy['source-value'] = float(unrelaxedpoteng)
unrelaxed_energy['source-unit'] = "eV"

# Pack the relaxed configuration (positions and forces of the final dump
# snapshot) into the results dictionary
relaxedconfigpos = {'source-value': [], 'source-unit': "angstrom"}
relaxedconfigforce = {'source-value': [], 'source-unit': "eV/angstrom"}

atom_count = 0
for line in relaxedatomslines:
    # The first column is dropped; the remaining ones are treated as
    # type, x, y, z, fx, fy, fz
    fields = [float(token) for token in line.split()[1:]]
    # The species is not stored again here, but the lookup still checks that
    # the type column maps onto the elements list
    fields[0] = elements[int(fields[0])-1]
    for column in range(1, 7):
        if numpy.isnan(fields[column]):
            nanflag = 1
    atom_count += 1
    relaxedconfigpos['source-value'].append(fields[1:4])
    relaxedconfigforce['source-value'].append(fields[4:7])

results['relaxed-configuration-positions'] = relaxedconfigpos
results['relaxed-configuration-forces'] = relaxedconfigforce
relaxed_energy = results['relaxed-potential-energy']
relaxed_energy['source-value'] = float(relaxedpoteng)
relaxed_energy['source-unit'] = "eV"

# If none of the reported quantities was NaN, print a results.edn file
if nanflag==0:
    # Serialize before opening the file so that a serialization error cannot
    # leave an empty results.edn behind.  The blank separators replace the
    # commas and colons json.dumps would otherwise emit.
    serialized = json.dumps(results,separators=(" "," "),indent=2,sort_keys=False)
    # Numeric values that were stored as strings get their quotes removed so
    # they are written as numbers.  The pattern is a raw string: the
    # backslash escapes are meant for the regex engine, not for Python.
    unquoted = re.sub(r'"([0-9e\-\+\.]+)"',stripquotes,serialized)
    with open('output/results.edn','w') as resultsedn:
        resultsedn.write(unquoted)