#!/usr/bin/env python

###############################################################################
#
# Test driver to compute the linear thermal expansion of a cubic crystal at
# finite temperature and pressure.
#
###############################################################################

import os
import multiprocessing
import numpy as np

###############################################################################
#
# read_params
#
# read parameters from stdin
#
# Argument:
# x: input class instance
#
###############################################################################

def read_params(x):
  """Prompt for every Test Driver parameter on stdin and store it on `x`.

  Sets x.modelname, x.species, x.mass, x.latticetype, x.latticeconst,
  x.temperature and x.press. The pressure is entered in MPa and stored in
  bar (1 MPa = 10 bar).

  Raises:
    ValueError: a numerical entry is not a number; mass, lattice constant or
      temperature fails is_positive; or the lattice type is not one of
      bcc, fcc, sc, diamond.
  """
  # raw_input only exists in Python 2; input() is its Python 3 equivalent.
  try:
    reader = raw_input
  except NameError:
    reader = input

  x.modelname = reader('Please enter a KIM Model extended-ID:\n').strip()
  print(x.modelname)

  x.species = reader('Please enter the species symbol (e.g. Si, Au, Al, etc.):\n').strip()
  print(x.species)

  msg = 'Please enter the atomic mass of the species (g/mol):\n'
  x.mass = read_num(msg, 'mass')
  is_positive(x.mass, 'mass')

  x.latticetype = reader('Please enter the lattice type (bcc, fcc, sc, or diamond):\n').strip()
  print(x.latticetype)
  # Reject unsupported lattices now. Otherwise they are only caught when the
  # results file is written, after every LAMMPS run has already finished.
  if x.latticetype.lower() not in ('bcc', 'fcc', 'sc', 'diamond'):
    raise ValueError("Incorrect input `lattice type = %s'; `lattice type' should "
                     "be one of bcc, fcc, sc or diamond." %(x.latticetype))

  msg = 'Please specify the lattice constant (meters):\n'
  x.latticeconst = read_num(msg, 'lattice constant')
  is_positive(x.latticeconst, 'lattice constant')

  msg = 'Please enter the temperature (Kelvin):\n'
  x.temperature = read_num(msg, 'temperature')
  is_positive(x.temperature, 'temperature')

  msg = 'Please enter the hydrostatic pressure (Mpa):\n'
  x.press = read_num(msg, 'pressure')
  # convert unit of pressure from MPa to bar (1 MPa = 10 bar)
  x.press = x.press*10


###############################################################################
#
# read_num
#
# Read numerical parameter; raise error if the input is not a numerical value.
#
# Arguments:
# asking: information string about the parameter
# name: parameter name
#
###############################################################################

def read_num(asking, name):
  """Prompt with `asking`, echo the reply, and return it as a float.

  Arguments:
    asking: prompt string shown to the user
    name: parameter name used in the error message

  Raises:
    ValueError: the reply cannot be converted to a float.
  """
  # raw_input only exists in Python 2; input() is its Python 3 equivalent.
  try:
    reader = raw_input
  except NameError:
    reader = input

  value_str = reader(asking)
  print(value_str)
  # Guard only the conversion, so that just a non-numerical reply is
  # reported as a bad value.
  try:
    return float(value_str)
  except ValueError:
    raise ValueError("Incorrect input `%s = %s'; `%s' should be a numerical value."
                     %(name, value_str, name))


###############################################################################
#
# is_positive
#
# Whether a numerical parameter is positive or not. If not, raise error.
#
# Arguments:
# var: parameter to be tested
# name: name string of the parameter
#
###############################################################################

def is_positive(var, name):
  """Raise ValueError when `var` is negative or NaN; return None otherwise.

  Zero is accepted. NaN is rejected explicitly: every comparison with NaN
  is False, so a plain `var < 0.0` test would let it through.

  Arguments:
    var: numerical parameter to be tested
    name: name string of the parameter, used in the error message
  """
  if not var >= 0.0:
    raise ValueError("Incorrect input `%s = %s'; `%s' should be positive."
                     %(name, str(var), name))


###############################################################################
#
# get_input_file
#
# Generate LAMMPS input file from the template input file.
#
# Arguments:
# template: name string of the template lammps input file
# lmpinput: name string of the generated lammps input file
# x: input class instance
# T: temperature
# vol_file: name string of a file storing the volume of the system
#
###############################################################################

def get_input_file(template, lmpinput, x, T, vol_file):
  """Write the LAMMPS input `lmpinput` by filling the `rpls_*` placeholders.

  Arguments:
    template: name string of the template lammps input file
    lmpinput: name string of the generated lammps input file
    x: input class instance (latticetype, latticeconst, modelname, species,
       mass, press)
    T: temperature
    vol_file: name string of a file storing the volume of the system
  """
  replacements = {'rpls_latticetype':x.latticetype.lower(),
                  'rpls_latticeconst':str(x.latticeconst*1e10), #1e10 to transfer units to Ang
                  'rpls_modelname':x.modelname,
                  'rpls_species':x.species,
                  'rpls_mass':str(x.mass),
                  'rpls_temp':str(T),
                  'rpls_press':str(x.press),
                  'rpls_vol_file':vol_file}

  with open(template, 'r') as readfile, open(lmpinput, 'w') as writefile:
    for line in readfile:
      # items() works on both Python 2 and 3 (iteritems is Python 2 only).
      for placeholder, value in replacements.items():
        line = line.replace(placeholder, value)
      writefile.write(line)


###############################################################################
#
# run_lammps
#
# Run lammps with `lmpinput' as the input.
#
# Arguments:
# lmpinput: name string of lammps input file
#
###############################################################################

def run_lammps(lmpinput, screen_file=None):
  """Run `lammps -in lmpinput`, sending its stdout to `screen_file`.

  Arguments:
    lmpinput: name string of lammps input file
    screen_file: file receiving LAMMPS screen output. The default is a file
      next to `lmpinput`, named after it. create_jobs runs several of these
      concurrently, and a single shared file would be truncated and
      interleaved by every job.

  Returns the LAMMPS exit status.
  """
  import subprocess

  if screen_file is None:
    screen_file = os.path.splitext(lmpinput)[0] + '.screen.out'
  with open(screen_file, 'w') as screen:
    # An argument list (no shell) avoids quoting problems with file names.
    return subprocess.call(['lammps', '-in', lmpinput], stdout=screen)


###############################################################################
#
# compute_alpha
#
# compute the linear thermal expansion coefficient
#
# Arguments:
# volfile: list of name strings of the volume files
# Temp: temperature list
#
###############################################################################

def compute_alpha(volfile, Temp, T_ref=None):
  """Return the linear thermal expansion coefficient, or None if not converged.

  Each file in `volfile` holds either a volume or the string 'not_converged'.
  A least-squares line L(T) = slope*T + intercept is fitted to L = V**(1/3)
  against `Temp`. The result is alpha = slope / L_ref.

  Arguments:
    volfile: list of name strings of the volume files, one per temperature
    Temp: temperature list
    T_ref: optional temperature at which to normalise. When given, L_ref is
      the fitted L(T_ref); by default L_ref is the cube root of the middle
      volume in the list.
  """
  volume = []
  for fin in volfile:
    with open(fin, 'r') as f:
      content = f.read().strip()
    if content == 'not_converged':
      return None
    volume.append(float(content))

  # compute delta_a/delta_T (linear least mean square fit)
  cubic_root_volume = [v**(1/3.0) for v in volume]
  slope, intercept = np.polyfit(Temp, cubic_root_volume, 1)

  if T_ref is None:
    # '//' keeps the index an int under Python 3 as well.
    L_ref = cubic_root_volume[len(cubic_root_volume)//2]
  else:
    L_ref = slope*T_ref + intercept

  # linear thermal expansion coeff
  return slope/float(L_ref)



###############################################################################
#
# write_result: write the result in `edn' format
#
# Arguments:
# template: name string of template result file
# result: name string of result file
# x: input class instance
#
###############################################################################

def write_result(template, result, x, alpha=None):
  """Fill the `edn' results template `template` and write it to `result`.

  The placeholders are replaced by:
  - the crystal description derived from x.latticetype and x.species;
  - the input lattice constant, converted from meters to Angstrom;
  - x.temperature and the coefficient `alpha`;
  - a stress array [p p p 0 0 0] with p = x.press.

  Arguments:
    template: name string of template result file
    result: name string of result file
    x: input class instance
    alpha: linear thermal expansion coefficient. When omitted, the
      module-level `alpha` is used.

  Raises:
    ValueError: the lattice type is not bcc, fcc, sc or diamond.
  """
  lattice = x.latticetype.lower()
  # every supported lattice has its single Wyckoff site at the origin
  wyckoffcoords = '[ 0 0 0 ]'
  wyckoffspecies = '"%s"' %(x.species)
  if lattice == 'bcc':
    spacegroup = '"Im-3m"'
    wyckoffmultletter = '"2a"'
    basisatomcoords = '''[   0    0    0 ]
                         [ 0.5  0.5  0.5 ]'''
    specieslist = ('"'+x.species+'" ')*2
  elif lattice == 'fcc':
    spacegroup = '"Fm-3m"'
    wyckoffmultletter = '"4a"'
    basisatomcoords = '''[   0    0    0 ]
                         [   0  0.5  0.5 ]
                         [ 0.5    0  0.5 ]
                         [ 0.5  0.5    0 ]'''
    specieslist = ('"'+x.species+'" ')*4
  elif lattice == 'sc':
    spacegroup = '"Pm-3m"'
    wyckoffmultletter = '"1a"'
    basisatomcoords = '[ 0 0 0 ]'
    specieslist = '"%s"' %(x.species)
  elif lattice == 'diamond':
    spacegroup = '"Fd-3m"'
    wyckoffmultletter = '"8a"'
    basisatomcoords = '''[ 0 0 0 ]
                       [ 0 0.5 0.5 ]
                       [ 0.5 0.5 0 ]
                       [ 0.5 0 0.5 ]
                       [ 0.75 0.25 0.75 ]
                       [ 0.25 0.25 0.25 ]
                       [ 0.25 0.75 0.75 ]
                       [ 0.75 0.75 0.25 ]'''
    specieslist = ('"'+x.species+'" ')*8
  else:
    # ValueError is an Exception subclass, so generic handlers still catch it.
    raise ValueError("input lattice type `%s' not supported." %(x.latticetype))

  if alpha is None:
    # The script entry point calls write_result(tempfile, resultfile, param)
    # without alpha and relies on the module-level value.
    alpha = globals()['alpha']

  # same value on the three normal components, zero shear components
  stressarray = ' '.join(str(s) for s in [x.press]*3 + [0.0]*3)

  # replace placeholder strings in results file
  replacements = {'_LATTICETYPE_':'"'+lattice+'"',
                  '_LATTICECONST_':str(x.latticeconst*1e10), #1e10 to transfer units to Ang
                  '_SPECIES_':specieslist,
                  '_WYCKOFFCOORDS_':wyckoffcoords,
                  '_WYCKOFFMULTLETTER_':wyckoffmultletter,
                  '_WYCKOFFSPECIES_':wyckoffspecies,
                  '_BASISATOMCOORDS_':basisatomcoords,
                  '_SPACEGROUP_':spacegroup,
                  '_TEMPERATURE_':str(x.temperature),
                  '_THERMEXPANCOEFF_':str(alpha),
                  '_CAUCHYSTRESS_':stressarray}

  with open(template, 'r') as readfile, open(result, 'w') as writefile:
    for line in readfile:
      for placeholder, value in replacements.items():
        line = line.replace(placeholder, value)
      writefile.write(line)


###############################################################################
#
# create_jobs
# create N = len(Temp) LAMMPS input(s) and run them
#
# Arguments:
# Temp: temperature list
# volfile: list of name strings of the volume files (one per temperature)
#
###############################################################################

def create_jobs(Temp, volfile, x=None):
  """Write one LAMMPS input per temperature, run them in parallel, and wait.

  Arguments:
    Temp: temperature list
    volfile: list of name strings of volume files, indexed like Temp
    x: input class instance. When omitted, the module-level `param` is used.

  Returns the list of finished multiprocessing.Process objects.
  """
  if x is None:
    x = param

  # lammps input template lives next to this script
  TDdirectory = os.path.dirname(os.path.realpath(__file__))
  template = os.path.join(TDdirectory, 'lammps.in.template')

  # create N = len(Temp) jobs at N = len(Temp) different temperatures
  jobs = []
  for i, temp in enumerate(Temp):
    # Name each input after this job's own temperature (from Temp, not a
    # global list).
    lmpinput = 'output/lmpinfile_T'+str(temp)+'.in'
    get_input_file(template, lmpinput, x, temp, volfile[i])
    jobs.append(multiprocessing.Process(target=run_lammps, args=(lmpinput,)))

  # start every job first so they run concurrently, then wait for all
  for p in jobs:
    p.start()
  for p in jobs:
    p.join()

  return jobs


###############################################################################
#
# main function
#
###############################################################################

# container for the parameters read from stdin
class Input:
  """Plain attribute container filled in by read_params."""
  pass


# Guard the entry point: with the multiprocessing 'spawn' start method this
# module is re-imported in every worker. Without the guard, each worker would
# prompt for input and launch the LAMMPS jobs again.
if __name__ == '__main__':
  param = Input()
  read_params(param)

  # the LAMMPS inputs, volume files and results are written under output/
  if not os.path.isdir('output'):
    os.makedirs('output')

  # create temperature list and LAMMPS volume output file list
  dT = 20       # temperature interval
  # five temperatures centred on the requested one, shifted up so that none
  # is negative
  if param.temperature - 2*dT < 0.0:
    T_lowest = 0.0
  else:
    T_lowest = param.temperature - 2*dT
  T = [T_lowest + i*dT for i in range(5)]
  lmpvolfile = ['output/vol_T'+str(temp)+'.out' for temp in T]

  # run LAMMPS at N = len(T) temperatures simultaneously
  create_jobs(T, lmpvolfile)

  # compute linear thermal expansion coefficient
  alpha = compute_alpha(lmpvolfile, T)

  # write result
  if alpha is not None:
    TDdirectory = os.path.dirname(os.path.realpath(__file__))
    tempfile = TDdirectory+'/results.edn.tpl'
    resultfile = 'output/results.edn'
    write_result(tempfile, resultfile, param)
    print("ALL DONE")
  else:
    print('Error: the temperature or pessure has not converged within the simulation '
          'steps specified in the Test Driver. Linear thermal expansion coefficient '
          'cannot be obtained.')