#!/usr/bin/env python
"""
LAMMPS TriclinicPBCEnergyAndForces test:
This test computes the total potential energy and forces acting on a triclinic periodic 3D box of atoms, supplied by the user in the form of an extended xyz file which lies in the Test's directory.
The extended xyz file must be of the following form:
Natoms
Lattice=" Ax Ay Az Bx By Bz Cx Cy Cz "
Element positionx positiony positionz
Element positionx positiony positionz
.
.
.
where A, B, and C are the simulation box edge vectors (periodic boundary condition).
(Note that all entries in the xyz file are separated by single spaces)
The pipeline.stdin.tpl file which accompanies the Test that uses this Test Driver should
contain, each on its own line and in this order: a list of elements in the form
'Element1 Element2 ...' (separated by single spaces), a list of the masses (ordered the
same as the elements list, separated by single spaces), and the name of the model.

Author: Daniel S. Karls, University of Minnesota (karl0100 |AT| umn DOT edu)
"""

import json, fileinput, simplejson, subprocess, os, itertools, shutil, re, numpy
from collections import namedtuple, OrderedDict

def stripquotes(matchobj):
    """Return the text captured by the first group of ``matchobj``.

    Intended as a replacement callback for ``re.sub``: when the pattern wraps
    its first group in quotes, the whole quoted match is replaced by the bare
    group contents.
    """
    unquoted = matchobj.group(1)
    return unquoted

# Read elements, masses, and model from stdin
elementinput = raw_input("elements=")
# Split on any run of whitespace.  str.split(' ') never returns an empty list
# (''.split(' ') == ['']), so with it the emptiness check below could not fire,
# and doubled or trailing spaces produced empty "elements" that inflated the
# number of atom types.
elements = elementinput.split()
numelements = len(elements)
if numelements < 1:
    raise Exception("ERROR: No species were listed! Exiting...")
massesinput = raw_input("masses (must match order of elements above)=")
masses = massesinput.split()
# One mass is looked up per element when the data file is written; fail here
# with a clear message rather than with an IndexError later on.
if len(masses) < numelements:
    raise Exception("ERROR: %d species were listed but only %d masses were given! Exiting..." % (numelements, len(masses)))
model = raw_input("modelname=")
print(' ')

# Some directories we need
THIS_DIR = os.path.dirname(__file__)
# Read the LAMMPS input template that sits in the same directory as this
# script; the context manager closes the handle as soon as it has been read.
with open(os.path.join(THIS_DIR, 'in.lammps')) as template_file:
    LAMMPS_INPUT_FILE = template_file.read()
INDIR = os.path.join('output','lammps_inputs')
OUTDIR = os.path.join('output','lammps_output_log')
DUMPDIR = os.path.join('output','lammps_dump')

# Ensure the directories we need are created
for directory in (INDIR, OUTDIR, DUMPDIR):
    try:
        os.makedirs(directory)
    except OSError:
        # An already-existing directory is fine.  Any other failure (e.g. a
        # permission problem) must not be hidden, otherwise the script only
        # fails later when it tries to write into the missing directory.
        if not os.path.isdir(directory):
            raise

# Files LAMMPS uses or produces
infile = os.path.join(INDIR,"in.lammps")
outfile = os.path.join(OUTDIR,"lammps.log")
dumpfile = os.path.join(DUMPDIR,"lammps.dump")
datafile = os.path.join(INDIR,"TriclinicPBC.data")

# Create results dictionary (insertion-ordered so the keys are emitted in the
# order in which they are added)
results = OrderedDict([
    ('property-id', 'tag:staff@noreply.openkim.org,2014-04-15:property/configuration-nonorthogonal-periodic-3d-cell-fixed-particles-fixed'),
    ('instance-id', 1),
])

### REGULAR EXPRESSIONS FOR MATCHING LAMMPS OUTPUT

# Finds potential energy: the first token on the line that follows the line
# containing "PotEng".  The numeric class accepts '+' and 'E' so that values
# printed in exponent notation (e.g. 1.5e+03) are captured whole instead of
# being cut off at the sign, and a nan / -nan token is captured as well so
# that the NaN check performed on the energy further down can actually trigger.
POTENG_MATCH = re.compile(r"""
        PotEng  #MAGIC Word
        .*$\n   #till end of line
        \s*      #possible leading whitespace
        (                        #our float, grouped
            [+\-]?[nN][aA][nN]   #not-a-number token, optionally signed
            |
            [0-9eE.+\-]+         #ordinary decimal or exponent notation
        )
        """,flags=re.VERBOSE|re.MULTILINE)

# Finds number of atoms
NATOM_MATCH = re.compile(r"""
        ITEM:\ NUMBER\ OF\ ATOMS  #MAGIC WORDS
        .*$\n                     #till end of line
        \s*                        #possible leading whitespace
        (\d+)                     #our count, grouped
        """,flags=re.VERBOSE|re.MULTILINE)

# Finds the BOX BOUNDS lines in the LAMMPS dumpfile
#BOXBOUNDS_LINES = re.compile(r"""
#        ITEM:\ BOX\ BOUNDS\        #magic words
#        .*?\n               #till end of line, nongreedy dot
#        (.*)                #rest of file
#        """,flags=re.VERBOSE|re.DOTALL)

# Finds the ATOMS lines in the LAMMPS dumpfile
ATOMS_LINES = re.compile(r"""
        ITEM:\ ATOMS        #magic words
        .*?\n               #till end of line, nongreedy dot
        (.*)                #rest of file
        """,flags=re.VERBOSE|re.DOTALL)

# Captures the contents of the Lattice="..." field of the extended xyz comment
# line.  '+' and 'E' are accepted so that entries such as 1.0E+01 are matched.
LATTICE_FIELD = re.compile(r"Lattice=\"([0-9eE+\.\-\s]*)\"")

# Read xyz file and write LAMMPS data file
original_pos = [] # Original positions, kept as the strings read from the xyz file
with open('triclinicpbc.xyz','r') as xyzfile, open(datafile,'w') as datafl:
    numatoms = int(xyzfile.readline()) # get number of atoms
    # Read box vectors from the Lattice="..." field of the comment line
    lattice_match = LATTICE_FIELD.search(xyzfile.readline())
    if lattice_match is None:
        raise Exception("ERROR: No Lattice=\"...\" field found on the second line of triclinicpbc.xyz")
    # The regex group holds only the text between the quotes.  Split on any
    # whitespace so that repeated spaces or tabs do not yield empty entries.
    cellvecs = lattice_match.group(1).split()
    if len(cellvecs) != 9:
        raise Exception("ERROR: The Lattice field must contain exactly 9 numbers, but %d were found" % len(cellvecs))
    Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz = [float(component) for component in cellvecs]

    # Signed volume A.(BxC) of the box; V is its magnitude.
    signed_volume = Ax*(By*Cz-Cy*Bz) + Ay*(Bz*Cx-Bx*Cz) + Az*(Bx*Cy-Cx*By)
    V = numpy.fabs(signed_volume)
    # Reject a zero-volume (or NaN) box explicitly.  Dividing by a numpy float
    # that is zero only emits a warning and returns inf/nan, so a bare 1/V does
    # not stop the test driver.
    if not V > 0.0:
        raise Exception("ERROR: The box defined by the Lattice field has zero volume")

    # Compute xlo, xhi, ylo, yhi, zlo, zhi, xy, xz, yz from A, B, and C
    # [ Note 1 ]: In LAMMPS, arbitrarily oriented triclinic boxes are not permitted.  Instead, any triclinic boxes (with vectors A,B,C) must be rotated to conform to
    # a conventional orientation where the new box has vectors a,b,c.  Vector 'a' must point along the positive x-axis, vector 'b' must lie in the xy plane
    # and have a strictly positive y-component, and vector 'c' must have a strictly positive z-component.  The below definitions of ax,ay,az,bx,by,bz,cx,cy,cz
    # correspond to performing this rotation.
    # ( See http://lammps.sandia.gov/doc/Section_howto.html#howto_12 )
    ax=numpy.sqrt(Ax*Ax+Ay*Ay+Az*Az)
    ay=0.0
    az=0.0
    bx=(Bx*Ax+By*Ay+Bz*Az)/ax
    by=numpy.sqrt( (Bx*Bx+By*By+Bz*Bz)-bx*bx )
    bz=0.0
    cx=(Cx*Ax+Cy*Ay+Cz*Az)/ax
    cy=((Bx*Cx+By*Cy+Bz*Cz)-bx*cx)/by
    cz=numpy.sqrt((Cx*Cx+Cy*Cy+Cz*Cz)-cx*cx-cy*cy)
    xlo=0.0
    xhi=ax
    ylo=0.0
    yhi=by
    zlo=0.0
    zhi=cz
    xy=bx
    xz=cx
    yz=cy

    # [ Note 2 ]: If we think of the rotation referred to in Note 1 above as a matrix 'R', then the below quantities are the components of the matrix inverse of 'R'.
    # This is necessary because the test driver wants to return the atomic forces in the same basis that the user used in constructing the extended xyz file, not in the rotated basis that LAMMPS has to use.
    # With M = [A B C] and m = [a b c] as column matrices, the position transform used below is R = s * m * inverse(M),
    # where s = +1/-1 is the sign of A.(BxC), hence inverse(R) = s * M * inverse(m).  (For a left-handed input cell, s = -1
    # sends A,B,C to -a,-b,-c, which span the same periodic lattice as a,b,c.)
    cell_original = numpy.array([[Ax, Bx, Cx],
                                 [Ay, By, Cy],
                                 [Az, Bz, Cz]])
    cell_rotated = numpy.array([[ax, bx, cx],
                                [ay, by, cy],
                                [az, bz, cz]])
    handedness = signed_volume/V
    Rinv = handedness*numpy.dot(cell_original, numpy.linalg.inv(cell_rotated))
    (Rinv11, Rinv12, Rinv13), (Rinv21, Rinv22, Rinv23), (Rinv31, Rinv32, Rinv33) = Rinv

    datafl.write("Triclinic periodic calculation\n\n")
    datafl.write("%s atoms\n" % str(numatoms))
    datafl.write("%s %s xlo xhi\n" % (xlo,xhi))
    datafl.write("%s %s ylo yhi\n" % (ylo,yhi))
    datafl.write("%s %s zlo zhi\n" % (zlo,zhi))
    datafl.write("%s %s %s xy xz yz\n" % (xy,xz,yz))
    datafl.write("%s atom types\n" % str(numelements))
    datafl.write(' \n')
    datafl.write('Atoms\n')
    datafl.write(' \n')
    for atom_count in range(numatoms):
        # Whitespace split tolerates tabs and repeated spaces between columns
        atomln=xyzfile.readline().split()
        if len(atomln) < 4:
            raise Exception("ERROR: triclinicpbc.xyz declares %d atoms but the line for atom %d is missing or incomplete" % (numatoms, atom_count+1))
        if atomln[0] not in elements:
            raise Exception("ERROR: Species '%s' of atom %d is not in the list of elements read from stdin" % (atomln[0], atom_count+1))
        # LAMMPS atom types are 1-based indices into the elements list
        atomtype=elements.index(atomln[0])+1
        # The position is taken from the last three columns of the line
        posx=float(atomln[-3])
        posy=float(atomln[-2])
        posz=float(atomln[-1])
        original_pos.append([atomln[-3],atomln[-2],atomln[-1]]) # Store positions as strings
        # Transform atomic positions to the new basis: the intermediate vector is
        # (BxC, CxA, AxB) dotted with the position, i.e. det(M)*inverse(M)*pos,
        # which is then multiplied by m = [a b c] (with ay=az=bz=0) and divided by V
        intermediate_vecx=(By*Cz-Cy*Bz)*posx+(Bz*Cx-Bx*Cz)*posy+(Bx*Cy-Cx*By)*posz
        intermediate_vecy=(Cy*Az-Ay*Cz)*posx+(Cz*Ax-Cx*Az)*posy+(Cx*Ay-Ax*Cy)*posz
        intermediate_vecz=(Ay*Bz-By*Az)*posx+(Az*Bx-Ax*Bz)*posy+(Ax*By-Bx*Ay)*posz
        newposx=(ax*intermediate_vecx+bx*intermediate_vecy+cx*intermediate_vecz)/V
        newposy=(by*intermediate_vecy+cy*intermediate_vecz)/V
        newposz=(cz*intermediate_vecz)/V
        datafl.write("%s %s %s %s %s\n" % (str(atom_count+1),str(atomtype),str(newposx),str(newposy),str(newposz)))
    datafl.write(' \n')
    datafl.write('Masses\n')
    datafl.write(' \n')
    for type_count in range(numelements):
        datafl.write("%s %s\n" % (str(type_count+1),masses[type_count]))
    # Both files are closed by the 'with' statement on leaving this block

# Create the LAMMPS input file by filling the named placeholders of the
# template with the model name, the species line as typed by the user, and the
# paths of the data and dump files
template_fields = {
    'modelname': model,
    'symbol': elementinput,
    'datafile': datafile,
    'dumpfile': dumpfile,
}
with open(infile,'w') as in_file:
    in_file.write(LAMMPS_INPUT_FILE.format(**template_fields))

def _lammps_log_text():
    """Return the contents of log.lammps, or "no log file" if it cannot be read."""
    try:
        with open("log.lammps") as logfl:
            return logfl.read()
    except IOError:
        return "no log file"

# Run LAMMPS with given infile and write the output to outfile.  The input
# script is fed to the 'lammps' executable on stdin and its stdout is captured
# in the log file; a non-zero exit status is reported together with whatever
# LAMMPS left in log.lammps.
with open(infile) as infl, open(outfile,'w') as outfl:
    try:
        lammps_process = subprocess.check_call('lammps', shell=False, stdout=outfl, stdin=infl)
    except subprocess.CalledProcessError:
        raise Exception("LAMMPS did not exit propertly:\n"+_lammps_log_text())

### Now process the output and dumpfile for relevant information
# Get potential energy (kept as the string matched in the log)
with open(outfile) as outfl:
    output = outfl.read()
poteng_found = POTENG_MATCH.search(output)
if poteng_found is None:
    raise Exception("Failed to find the potential energy")
poteng = {'source-value': poteng_found.group(1), 'source-unit': "eV"}

# Get configuration
with open(dumpfile) as dumpfl:
    dump = dumpfl.read()

# Find the number of atoms
natoms_found = NATOM_MATCH.search(dump)
if natoms_found is None:
    raise Exception("Failed to find the number of atoms")
natoms = int(natoms_found.group(1))

# Process the rest of the dump file, the atom positions, etc: keep everything
# that follows the "ITEM: ATOMS" header line
atoms_found = ATOMS_LINES.search(dump)
if atoms_found is None:
    raise Exception("Failed to find ATOMS block")
atomslines = atoms_found.group(1)



# Set lattice vectors to original orientation: report the three box vectors
# exactly as they were read from the xyz file (as strings), not the rotated
# vectors handed to LAMMPS
unrelaxedcellvec1 = {
    'source-unit': 'angstrom',
    'source-value': [cellvecs[0], cellvecs[1], cellvecs[2]],
}
unrelaxedcellvec2 = {
    'source-unit': 'angstrom',
    'source-value': [cellvecs[3], cellvecs[4], cellvecs[5]],
}
unrelaxedcellvec3 = {
    'source-unit': 'angstrom',
    'source-value': [cellvecs[6], cellvecs[7], cellvecs[8]],
}

# nanflag becomes 1 as soon as any reported quantity turns out to be NaN
nanflag=0
# Check poteng for NaN
if numpy.isnan(float(poteng['source-value'])):
    nanflag=1

atom_count=0
configspecies={'source-value': []}
unrelaxedconfigpos={'source-value': [], 'source-unit': 'angstrom'}
unrelaxedconfigforce={'source-value': [], 'source-unit': 'eV/angstrom'}
for line in atomslines.split('\n'):
    if not line:
        continue
    # Drop the first column and read the rest as floats:
    # [type, x, y, z, fx, fy, fz]
    # NOTE(review): original positions are paired with dump lines by line order,
    # which presumably relies on the dump being written in atom-id order -- confirm against in.lammps.
    fields = [float(l) for l in line.split()[1:]]
    # Convert LAMMPS 'type' to element string
    species = elements[int(fields[0])-1]
    # Set positions to original values (we do this instead of applying the inverse rotation to the LAMMPS positions because LAMMPS always wraps atoms that are outside of a periodic box.  We want to preserve the original positions in case the user provides an extended xyz which features atoms which are outside of the triclinic box.
    stored = original_pos[atom_count]
    position = [stored[0], stored[1], stored[2]]
    if any(numpy.isnan(float(coordinate)) for coordinate in position):
        nanflag=1
    # Rotate forces back to original orientation
    fx=fields[4]
    fy=fields[5]
    fz=fields[6]
    force = [
        Rinv11*fx+Rinv12*fy+Rinv13*fz,
        Rinv21*fx+Rinv22*fy+Rinv23*fz,
        Rinv31*fx+Rinv32*fy+Rinv33*fz,
    ]
    if any(numpy.isnan(component) for component in force):
        nanflag=1
    atom_count=atom_count+1
    configspecies['source-value'].append(species)
    unrelaxedconfigpos['source-value'].append(position)
    unrelaxedconfigforce['source-value'].append(force)


results['species']=configspecies
results['unrelaxed-periodic-cell-vector-1']=unrelaxedcellvec1
results['unrelaxed-periodic-cell-vector-2']=unrelaxedcellvec2
results['unrelaxed-periodic-cell-vector-3']=unrelaxedcellvec3
results['unrelaxed-configuration-positions']=unrelaxedconfigpos
results['unrelaxed-configuration-forces']=unrelaxedconfigforce
results['unrelaxed-potential-energy']=poteng

# If none of the reported quantities was NaN, print a results.edn file
if nanflag==0:
    # Serialize with single spaces as both the item and the key/value
    # separator, then strip the quotes from values that were stored as numeric
    # strings so they are written as bare numbers.
    results_text = json.dumps(results,separators=(" "," "),indent=2,sort_keys=False)
    results_text = re.sub(r'"([0-9\.e\+\-]+)"',stripquotes,results_text)
    # Build the text before opening the file, and let the context manager close
    # it, so that a failure neither leaks the handle nor leaves an empty file.
    with open('output/results.edn','w') as resultsedn:
        resultsedn.write(results_text)