#!/usr/bin/env python3

###############################################################################
#
# Test driver to compute the linear thermal expansion of a cubic crystal at
# finite temperature and pressure.
#
###############################################################################
import os
import multiprocessing

import numpy as np

###############################################################################
#
# read_params
#
# read parameters from stdin
#
# Argument:
# x: input class instance
#
###############################################################################


def read_params(x):
    """Prompt for the driver parameters on stdin and store them on ``x``.

    Sets ``x.modelname``, ``x.species``, ``x.mass`` (g/mol),
    ``x.latticetype``, ``x.latticeconst`` (m), ``x.temperature`` (K) and
    ``x.press``. Every value is echoed to stdout after it is read.

    The pressure is entered in MPa but stored on ``x.press`` in bar
    (1 MPa = 10 bar); that bar value is what get_input_file substitutes for
    ``rpls_press`` and write_result reports.

    Raises:
        ValueError: from read_num when a numerical entry is not a number,
            or from is_positive when the mass, lattice constant or
            temperature is negative.
    """
    x.modelname = input("Please enter a KIM Model extended-ID:\n")
    print("Modelname = {}".format(x.modelname))

    x.species = input("Please enter the species symbol (e.g. Si, Au, Al, etc.):\n")
    print("Species = {}".format(x.species))

    msg = "Please enter the atomic mass of the species (g/mol):\n"
    x.mass = read_num(msg, "mass")
    is_positive(x.mass, "mass")
    print("Mass = {} g/mol".format(x.mass))

    x.latticetype = input("Please enter the lattice type (bcc, fcc, sc, or diamond):\n")
    print("Lattice type = {}".format(x.latticetype))

    msg = "Please specify the lattice constant (meters):\n"
    x.latticeconst = read_num(msg, "lattice constant")
    is_positive(x.latticeconst, "lattice constant")
    print("Lattice constant = {} m".format(x.latticeconst))

    msg = "Please enter the temperature (Kelvin):\n"
    x.temperature = read_num(msg, "temperature")
    is_positive(x.temperature, "temperature")
    print("Temperature = {} Kelvin".format(x.temperature))

    # The pressure is deliberately not sign-checked: a negative hydrostatic
    # pressure (tension) is a legitimate input. The label passed to read_num
    # is what appears in its error message, so it must say "pressure".
    msg = "Please enter the hydrostatic pressure (MPa):\n"
    x.press = read_num(msg, "pressure")
    print("Pressure = {} MPa".format(x.press))

    # convert unit of pressure from MPa to bar (1 MPa = 10 bar)
    x.press = x.press * 10


###############################################################################
#
# read_num
#
# Read numerical parameter; raise error if the input is not a numerical value.
#
# Arguments:
# asking: information string about the parameter
# name: parameter name
#
###############################################################################


def read_num(asking, name, reader=input):
    """Prompt for one numerical parameter and return it as a float.

    The raw reply is evaluated with ``eval`` (so expressions such as
    ``5.43e-10`` or ``2*1e-10`` are accepted), echoed to stdout, and then
    converted with ``float``.

    Args:
        asking: prompt string shown to the user.
        name: parameter name used in the error message.
        reader: callable that takes the prompt and returns the raw reply;
            defaults to the builtin ``input`` (stdin).

    Returns:
        The parameter value as a float.

    Raises:
        ValueError: if the reply cannot be evaluated or does not evaluate
            to something ``float`` accepts.
    """
    raw = reader(asking)
    try:
        # SECURITY NOTE(review): eval() executes arbitrary Python taken from
        # the input stream. This is only acceptable because stdin is expected
        # to come from the trusted driver pipeline; never feed it untrusted
        # text.
        value_str = eval(raw)
        print(value_str)
        value = float(value_str)
    except (ValueError, TypeError, SyntaxError, NameError, ArithmeticError) as err:
        # Report the raw reply: it is always defined, unlike the evaluated
        # value when eval itself is what failed.
        raise ValueError(
            "Incorrect input `%s = %s'; `%s' should be a numerical value."
            % (name, raw, name)
        ) from err
    return value


###############################################################################
#
# is_positive
#
# Whether a numerical parameter is positive or not. If not, raise error.
#
# Arguments:
# var: parameter to be tested
# name: name string of the parameter
#
###############################################################################


def is_positive(var, name):
    """Reject a negative numerical parameter.

    Only values strictly below 0.0 are rejected; zero (and any value for
    which ``var < 0.0`` is false) passes silently.

    Args:
        var: the value to check.
        name: the parameter name used in the error message.

    Raises:
        ValueError: if ``var < 0.0``.
    """
    if not var < 0.0:
        return
    text = "Incorrect input `%s = %s'; `%s' should be positive." % (
        name,
        str(var),
        name,
    )
    raise ValueError(text)


###############################################################################
#
# get_input_file
#
# Generate LAMMPS input file from the template input file.
#
# Arguments:
# template: name string of the template lammps input file
# lmpinput: name string of the generated lammps input file
# x: input class instance
# T: temperature
# vol_file: name string of a file storing the volume of the system
###############################################################################


def get_input_file(template, lmpinput, x, T, vol_file):
    """Render the LAMMPS input file ``lmpinput`` from ``template``.

    The template is filled in with ``str.format`` using the ``rpls_*``
    keys below, so any literal brace in the template must be doubled.

    Args:
        template: path of the template LAMMPS input file.
        lmpinput: path of the LAMMPS input file to write.
        x: input class instance (lattice, model, species, mass, press).
        T: temperature of this run.
        vol_file: path that the LAMMPS run writes the system volume to.
    """
    fields = dict(
        rpls_latticetype=x.latticetype.lower(),
        # lattice constant is kept in meters; 1e10 converts it to Angstrom
        rpls_latticeconst=str(x.latticeconst * 1e10),
        rpls_modelname=x.modelname,
        rpls_species=x.species,
        rpls_mass=str(x.mass),
        rpls_temp=str(T),
        rpls_press=str(x.press),
        rpls_vol_file=vol_file,
    )

    with open(template, "r") as src:
        text = src.read()
    with open(lmpinput, "w") as dst:
        dst.write(text.format(**fields))


###############################################################################
#
# run_lammps
#
# Run lammps with `lmpinput' as the input.
#
# Arguments:
# lmpinput: name string of lammps input file
#
###############################################################################


def run_lammps(lmpinput, screen_file=None):
    """Run LAMMPS on ``lmpinput`` and capture its screen output in a file.

    Args:
        lmpinput: path of the LAMMPS input file.
        screen_file: path that receives LAMMPS' stdout. Defaults to the
            input path with its extension replaced by ``.screen.out``, so
            that the runs create_jobs starts concurrently each get their own
            file instead of truncating one shared ``output/screen.out``.
    """
    import subprocess

    if screen_file is None:
        screen_file = os.path.splitext(lmpinput)[0] + ".screen.out"
    # An argument list (no shell) keeps file names from being interpreted by
    # a shell. As with the previous os.system call, the exit status is not
    # checked here.
    with open(screen_file, "w") as screen:
        subprocess.run(["lammps", "-in", lmpinput], stdout=screen, check=False)


###############################################################################
#
# compute_alpha
#
# compute the linear thermal expansion coefficient
#
# Arguments:
# volfile: list of name strings of the volume files
# Temp: temperature list
#
###############################################################################


def compute_alpha(volfile, Temp):
    """Return the linear thermal expansion coefficient from the MD volumes.

    Each file in ``volfile`` holds the outcome of one LAMMPS run at the
    matching temperature in ``Temp``. The content is either a volume, the
    word ``not_converged``, or a message containing
    ``Volume keeps increasing``.

    The cube root of each (cubic) cell volume is a length proportional to
    the lattice constant. It is fitted linearly against temperature by least
    squares, and the slope is divided by the length at the central
    temperature of the list.

    Args:
        volfile: list of volume file names, one per temperature.
        Temp: list of temperatures, same order as ``volfile``.

    Returns:
        The coefficient, or the sentinel string ``"not_converged"`` /
        ``"unstable"`` as soon as a file with that content is read.
    """
    volume = []
    for fname in volfile:
        with open(fname, "r") as f:
            content = f.read().strip()
        if content == "not_converged":
            return "not_converged"
        if "Volume keeps increasing" in content:
            return "unstable"
        volume.append(float(content))

    # d(length)/dT by linear least squares on the cube-root volumes
    lengths = np.cbrt(volume)
    slope = np.polyfit(Temp, lengths, 1)[0]

    # Normalise by the length at the central temperature. Integer division
    # picks the true middle for any odd count, whereas round(len / 2) uses
    # banker's rounding and picked an off-centre point for 3, 7, ... entries.
    return slope / float(lengths[len(lengths) // 2])


###############################################################################
#
# Write result in `edn' format
#
# Arguments:
# template: name string of template result file
# result: name string of result file
# x: input class instance
#
###############################################################################


def write_result(template, result, x, alpha=None):
    """Write the result file ``result`` from the ``template`` file.

    The crystal description (space group, Wyckoff data, basis atoms and
    species list) is chosen from ``x.latticetype``. The template is filled
    with ``str.format``, so its placeholders look like ``{_LATTICETYPE_}``
    and any literal brace in it must be doubled.

    Args:
        template: path of the template result file.
        result: path of the result file to write.
        x: input class instance; latticetype, species, latticeconst,
            temperature and press are read.
        alpha: linear thermal expansion coefficient to report. When omitted,
            the module-level ``alpha`` computed by the driver script is used,
            which is what this function previously relied on implicitly.

    Raises:
        Exception: if ``x.latticetype`` is not bcc, fcc, sc or diamond.
        KeyError: if ``alpha`` is omitted and no module-level ``alpha``
            exists.
    """
    if x.latticetype.lower() == "bcc":
        spacegroup = '"Im-3m"'
        wyckoffmultletter = '"2a"'
        wyckoffcoords = "[ 0 0 0 ]"
        wyckoffspecies = '"%s"' % (x.species)
        basisatomcoords = """[   0    0    0 ]
                         [ 0.5  0.5  0.5 ]"""
        specieslist = ('"' + x.species + '" ') * 2
    elif x.latticetype.lower() == "fcc":
        wyckoffmultletter = '"4a"'
        wyckoffcoords = "[ 0 0 0 ]"
        wyckoffspecies = '"%s"' % (x.species)
        spacegroup = '"Fm-3m"'
        basisatomcoords = """[   0    0    0 ]
                         [   0  0.5  0.5 ]
                         [ 0.5    0  0.5 ]
                         [ 0.5  0.5    0 ]"""
        specieslist = ('"' + x.species + '" ') * 4
    elif x.latticetype.lower() == "sc":
        wyckoffmultletter = '"1a"'
        wyckoffcoords = "[ 0 0 0 ]"
        wyckoffspecies = '"%s"' % (x.species)
        spacegroup = '"Pm-3m"'
        basisatomcoords = "[ 0 0 0 ]"
        specieslist = '"%s"' % (x.species)
    elif x.latticetype.lower() == "diamond":
        wyckoffmultletter = '"8a"'
        wyckoffcoords = "[ 0 0 0 ]"
        wyckoffspecies = '"%s"' % (x.species)
        spacegroup = '"Fd-3m"'
        basisatomcoords = """[ 0 0 0 ]
                       [ 0 0.5 0.5 ]
                       [ 0.5 0.5 0 ]
                       [ 0.5 0 0.5 ]
                       [ 0.75 0.25 0.75 ]
                       [ 0.25 0.25 0.25 ]
                       [ 0.25 0.75 0.75 ]
                       [ 0.75 0.75 0.25 ]"""
        specieslist = ('"' + x.species + '" ') * 8
    else:
        raise Exception("input lattice type `%s' not supported." % (x.latticetype))

    latticetype = '"' + x.latticetype.lower() + '"'
    # Hydrostatic state: three equal normal components, zero shear.
    # NOTE(review): x.press is in bar after read_params' conversion; confirm
    # the unit the result template declares for the stress.
    stressarray = " ".join(map(str, [x.press] * 3 + [0.0] * 3))

    if alpha is None:
        # backward compatibility with callers that do not pass alpha
        alpha = globals()["alpha"]

    # replace placeholder strings in results file
    replacements = {
        "_LATTICETYPE_": latticetype,
        "_LATTICECONST_": str(x.latticeconst * 1e10),  # 1e10 to transfer units to Ang
        "_SPECIES_": specieslist,
        "_WYCKOFFCOORDS_": wyckoffcoords,
        "_WYCKOFFMULTLETTER_": wyckoffmultletter,
        "_WYCKOFFSPECIES_": wyckoffspecies,
        "_BASISATOMCOORDS_": str(basisatomcoords),
        "_SPACEGROUP_": str(spacegroup),
        "_TEMPERATURE_": str(x.temperature),
        "_THERMEXPANCOEFF_": str(alpha),
        "_CAUCHYSTRESS_": str(stressarray),
    }

    with open(template, "r") as readfile, open(result, "w") as writefile:
        writefile.write(readfile.read().format(**replacements))


###############################################################################
#
# create_jobs
# create N = len(Temp) LAMMPS input(s) and run them
#
# Arguments:
# volfile: list of name strings of the volume files
# Temp: temperature list
#
###############################################################################


def create_jobs(volfile, Temp, x=None):
    """Create one LAMMPS input per temperature and run them in parallel.

    For every temperature an input file ``output/lmpinfile_T<T>.in`` is
    rendered from ``lammps.in.template`` (next to this script) with the
    matching entry of ``volfile`` as its volume file. All inputs are written
    first. Then one process per input is started, and the function waits
    for all of them to finish.

    Args:
        volfile: list of volume file names, one per temperature.
        Temp: list of temperatures.
        x: input class instance. Defaults to the module-level ``param`` read
            by the driver script, which this function previously used
            implicitly.

    Returns:
        The list of multiprocessing.Process objects (already joined).

    Raises:
        ValueError: if ``volfile`` and ``Temp`` differ in length.
    """
    if len(volfile) != len(Temp):
        raise ValueError("volfile and Temp must have the same length")
    if x is None:
        x = globals()["param"]

    # lammps input template lives next to this script
    template = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "lammps.in.template"
    )

    jobs = []
    for T, vol in zip(Temp, volfile):
        infile = "output/lmpinfile_T" + str(T) + ".in"
        get_input_file(template, infile, x, T, vol)
        jobs.append(multiprocessing.Process(target=run_lammps, args=(infile,)))

    # submit all jobs, then wait for every one of them to complete
    for p in jobs:
        p.start()
    for p in jobs:
        p.join()

    return jobs


###############################################################################
#
# main function
#
###############################################################################

# read parameters from stdin
class Input:
    """Bare attribute container for the driver parameters.

    read_params fills in modelname, species, mass, latticetype,
    latticeconst, temperature and press on an instance of this class.
    """


if __name__ == "__main__":
    # Guarding the driver matters because the "spawn" and "forkserver"
    # multiprocessing start methods re-import the main module in worker
    # processes; without the guard each worker would re-run read_params.
    # Names bound here are still module globals.
    param = Input()
    read_params(param)

    # all LAMMPS inputs and outputs below are written under output/
    os.makedirs("output", exist_ok=True)

    # Five temperatures spaced by dT and centred on the target temperature.
    # The grid is shifted up to start at 0 K if it would otherwise go
    # negative.
    dT = 20  # temperature interval
    T_lowest = max(0.0, param.temperature - 2 * dT)
    T = [round(T_lowest + i * dT, 2) for i in range(5)]
    lmpvolfile = ["output/vol_T" + str(t) + ".out" for t in T]

    # run LAMMPS at all temperatures simultaneously
    create_jobs(lmpvolfile, T)

    # compute linear thermal expansion coefficient
    alpha = compute_alpha(lmpvolfile, T)

    # write result
    if alpha == "not_converged":
        print(
            "Error: the temperature or pressure has not converged within the simulation "
            "steps specified in the Test Driver. Linear thermal expansion coefficient "
            "cannot be obtained."
        )
    elif alpha == "unstable":
        print(
            "Error: the system may be unstable since the volume keeps increasing. "
            "Linear thermal expansion coefficient cannot be obtained."
        )
    else:
        TDdirectory = os.path.dirname(os.path.realpath(__file__))
        tempfile = os.path.join(TDdirectory, "results.edn.tpl")
        resultfile = "output/results.edn"
        write_result(tempfile, resultfile, param)
        print("Calculation completed")
        print("")