#!/usr/bin/env python
"""
Lattice constant Test Driver for Hexagonal Structure

Computes the lattice constant for:
1. Any element,
2. HCP crystal,
3. At 0 K, 0 GPa,
by simplex minimization.

Structure Overview:
1. main() handles I/O and gets its results from getLatticeConstant()
2. getLatticeConstant() starts from initial guesses for the lattice constant 'a' and calls searchLatticeConstants()
3. searchLatticeConstants() minimizes getEnergy() with the simplex algorithm
4. When no minimum with negative energy and separation < init_upper_guess is found, collapseHandler() is called

Date: 2015/08/25
Author: Junhao Li <streaver91@gmail.com>

Last Update: 2016/02/26 Junhao Li

             2017/08/28 Daniel S. Karls
"""
# Import External Modules
# ASE Modules
from ase.build import bulk
from ase import Atoms

# KIM Modules
from kimcalculator import KIMCalculator

# Other Python Modules
import sys
import json
import math
import jinja2
import os
import copy
import numpy as np
from scipy.optimize import fmin
from collections import OrderedDict

# Constants for result.edn
# Space-group symbol written under the 'space-group' key by main().
SPACE_GROUP = 'P63/mmc'
# NOTE(review): the name looks like a misspelling of "WYCKOFF"; it is not
# referenced anywhere in the visible code, so it is left untouched here.
WCYKOFF_CODE = '2a'
# Key names used by V() when it builds a value / unit / uncertainty entry.
KEY_SOURCE_VALUE = 'source-value'
KEY_SOURCE_UNIT = 'source-unit'
KEY_SOURCE_UNCERT = 'source-std-uncert-value'
# Unit strings attached to the reported quantities in main().
UNIT_LENGTH = 'angstrom'
UNIT_TEMPERATURE = 'K'
UNIT_PRESSURE = 'GPa'
UNIT_ENERGY = 'eV'
# Property identifiers written under 'property-id' for the two result instances.
STRUCTURE_PROP_ID = 'tag:staff@noreply.openkim.org,2014-04-15:property/structure-hexagonal-crystal-npt'
ENERGY_PROP_ID = 'tag:staff@noreply.openkim.org,2014-04-15:property/cohesive-free-energy-hexagonal-crystal'

# Constants frequently used in calculation
SQRT3 = math.sqrt(3.0)
# c/a ratio of ideal (hard-sphere) HCP packing, sqrt(8/3) ~ 1.633.
PERFECT_CA = math.sqrt(8.0 / 3.0)
# Edge lengths of the 4-atom orthogonal HCP cell, in units of (a, a, c);
# getEnergy() multiplies this element-wise by [a, a, c].
HCP_CUBIC_CELL = np.array([1.0, SQRT3, 1.0])
# Atom positions in the unit-sized orthogonal cell above.
# Atoms 0 and 1 form the basal layer at z = 0.  Atoms 2 and 3 form the layer
# at z = 1/2 and must sit above the hollow sites of the basal layer, i.e. at
# an in-plane distance of 1/sqrt(3) from their three nearest basal atoms.
# That places them at y = sqrt(3)/6 and y = 2*sqrt(3)/3 (the previous values
# 0.25*sqrt(3) and 0.75*sqrt(3) did not describe an HCP stacking and were
# inconsistent with getInterlayerDist() and the reported basis coordinates).
HCP_CUBIC_POSITIONS = np.array([
    [0.0, 0.0, 0.0],
    [0.5, 0.5 * SQRT3, 0.0],
    [0.5, SQRT3 / 6.0, 0.5],
    [0.0, 2.0 * SQRT3 / 3.0, 0.5]
])
HCP_CUBIC_NATOMS = 4

# Minimization Convergence Criteria
# Passed to scipy.optimize.fmin as ftol / xtol in searchLatticeConstants().
FMIN_FTOL = 1e-10
FMIN_XTOL = 1e-10

def getEnergy(cellVector, meta, cache):
    # cellVector: 1d numpy array of size 1(2), containing a (and c)
    # meta: various information...
    # cache: cache previously created atoms

    # Determine a and c based on the length of the cellVector
    a = cellVector[0]
    if len(cellVector) == 2:
        c = cellVector[1]
    else:
        c = a * PERFECT_CA # Default to perfect c/a ratio

    cellSize = HCP_CUBIC_CELL * [a, a, c]

    repeat = (1, 1, 1)
    nAtoms = HCP_CUBIC_NATOMS

    # Get atoms from cache or create a new one if haven't been created
    sizes = cache['sizes']
    if repeat in sizes:
        sizeId = sizes.index(repeat)
        atoms = cache['atoms'][sizeId]
    else:
        # create a new atoms and save to cache
        print 'Creating new atoms:', repeat
        sizes.append(repeat)
        atoms = meta['prototype'].copy()
        atoms *= repeat
        atoms.set_calculator(KIMCalculator(meta['model']))

        cache['atoms'].append(atoms)

    # Scale atoms with the scaled cellSize
    atoms.set_cell(cellSize, scale_atoms = True)
    energy = atoms.get_potential_energy() / nAtoms
    return energy

def searchLatticeConstants(latticeConstantsGuess, meta, cache):
    # Simplex searching lattice constants that minimize potential energy of atoms
    # Searching starts from latticeConstantsGuess
    print 'Simplex search starting from:', latticeConstantsGuess
    tmpLatticeConstants, tmpEnergy = fmin(
        getEnergy,
        latticeConstantsGuess,
        args = (meta, cache),
        full_output = True,
        ftol = FMIN_FTOL,
        xtol = FMIN_XTOL,
    )[:2]
    return tmpLatticeConstants, tmpEnergy

def getInterlayerDist(latticeConstants):
    """Return the distance between nearest neighbours in adjacent basal layers.

    latticeConstants: sequence [a, c].

    For an HCP stacking the two atoms are offset by a / sqrt(3) within the
    basal plane and by c / 2 along the c axis.
    """
    hrDist = latticeConstants[0] / SQRT3
    # Divide by 2.0 so integer input is not truncated under Python 2
    vDist = latticeConstants[1] / 2.0
    return math.sqrt(hrDist * hrDist + vDist * vDist)

def collapseHandler():
    """Report that no acceptable minimum was found and abort the run.

    Prints two 'ERR:' lines to stdout and then raises SystemExit.  The exit
    status is 1 so that the calling process can tell the run failed; the
    previous status of 0 signalled success despite the error.
    """
    print('ERR: Lower energy is obtained by separating atoms beyond cutoff.')
    print('ERR: System Collapsed. Structure Maybe Unstable.')
    sys.exit(1)

def getLatticeConstant(elem, model):
    # Initialize meta and cache
    meta = {
        'model': model
    }
    cache = {
        'sizes': [],
        'atoms': []
    }

    # Create prototype atoms
    atoms = Atoms(elem + '4',
        positions = HCP_CUBIC_POSITIONS,
        cell = HCP_CUBIC_CELL,
        pbc = [1, 1, 1]
    )

    meta['prototype'] = atoms

    # Relaxation with c/a Ratio Fixed to PERFECT_CA
    print 'Relaxation with c/a ratio fixed'
    init_upper_guess = 5.0
    minEnergy = 0
    for i in [x * 0.05 + 0.5 for x in range(-2, 3)]:
        print 'Simplex search starting from: %r * %i' % (init_upper_guess, i)
        tmpLatticeConstants, tmpEnergy = searchLatticeConstants(
            [init_upper_guess * i],
            meta,
            cache
        )
        print 'Tmp Lattice Constants:', tmpLatticeConstants
        print 'Tmp Energy:', tmpEnergy
        if tmpEnergy < minEnergy and tmpLatticeConstants[0] < init_upper_guess:
            # Make sure there are interatomic forces holding the structure
            minLatticeConstants = tmpLatticeConstants
            minEnergy = tmpEnergy
        print '--------'

    if minEnergy == 0:
        collapseHandler()

    # Relaxation With c/a Ratio Relaxed
    print 'Relaxation with c/a ratio relaxed'
    tmpA = minLatticeConstants[0]
    minEnergy = 0
    for i in [x * 0.05 + 1 for x in range(-4, 5)]:
        print 'Simplex search starting from c/a ratio: 1.633 *', i
        tmpLatticeConstants, tmpEnergy = searchLatticeConstants(
            [tmpA, tmpA * PERFECT_CA * i],
            meta,
            cache
        )
        print 'Tmp Lattice Constants:', tmpLatticeConstants
        print 'Tmp Energy:', tmpEnergy
        if tmpEnergy < minEnergy and \
            tmpLatticeConstants[0] < init_upper_guess and \
            getInterlayerDist(tmpLatticeConstants) < init_upper_guess:
            # Make sure there is both intralayer and interlayer forces
            minLatticeConstants = tmpLatticeConstants
            minEnergy = tmpEnergy
        print '--------'

    if minEnergy == 0:
        collapseHandler()

    print 'Lattice Constants:', minLatticeConstants
    print 'Energy:', minEnergy

    # Return Results
    return minLatticeConstants[0], minLatticeConstants[1], minEnergy

def V(value, unit = '', uncert = ''):
    """Build the OrderedDict describing one value in the KIM JSON output.

    value: stored under KEY_SOURCE_VALUE (always present).
    unit: stored under KEY_SOURCE_UNIT unless it equals ''.
    uncert: stored under KEY_SOURCE_UNCERT unless it equals ''.
    """
    entries = [(KEY_SOURCE_VALUE, value)]
    if unit != '':
        entries.append((KEY_SOURCE_UNIT, unit))
    if uncert != '':
        entries.append((KEY_SOURCE_UNCERT, uncert))
    return OrderedDict(entries)

def main():
    # Input Parameters
    elem = raw_input('Element = ')
    lattice = raw_input('Lattice = ')
    model = raw_input('Model = ')

    # Parameters for Debugging
    # ========
    # elem = 'Co'
    # lattice = 'hcp'
    # model = 'EAM_Dynamo_PurjaPun_Mishin_Co__MO_885079680379_001'

    # elem = 'Si'
    # lattice = 'hcp'
    # model = 'EDIP_BOP_Bazant_Kaxiras_Si__MO_958932894036_001'

    # elem = 'Al'
    # lattice = 'hcp'
    # model = 'EAM_Dynamo_Mishin_Farkas_Al__MO_651801486679_001'

    # elem = 'C'
    # lattice = 'hcp'
    # model = 'EAM_Dynamo_Hepburn_Ackland_FeC__MO_143977152728_001'
    # model = 'MEAM_2NN_Fe_to_Ga__MO_145522277939_001'
    # model = 'model_ArCHHeXe_BOP_AIREBO__MO_154399806462_001'
    # model = 'Tersoff_LAMMPS_Erhart_Albe_CSi__MO_903987585848_001'
    # ========

    # Print Inputs
    print 'Element:', elem
    print 'Lattice:', lattice # Not used here
    print 'Model:', model

    # Obtain Lattice Constants and Cohesive Energy
    a, c, energy = getLatticeConstant(elem, model)
    print 'Lattice Constants:', a, c

    # Output Results
    structureResults = OrderedDict([
        ('property-id', STRUCTURE_PROP_ID),
        ('instance-id', 1),
        ('cauchy-stress', V([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], UNIT_PRESSURE)),
    ])
    cohesiveEnergyResults = OrderedDict([
        ('property-id', ENERGY_PROP_ID),
        ('instance-id', 2),
        ('cohesive-free-energy', V(-energy, UNIT_ENERGY)),
    ])
    commonResults = OrderedDict([
        ('short-name', V(['hcp'])),
        ('species', V([elem, elem])),
        ('a', V(a, UNIT_LENGTH)),
        ('c', V(c, UNIT_LENGTH)),
        ('basis-atom-coordinates', V([[0.0, 0.0, 0.0], [2.0/3, 1.0/3, 0.5]])),
        ('space-group', V(SPACE_GROUP)),
        ('temperature', V(0, UNIT_TEMPERATURE)),
    ])
    structureResults.update(commonResults)
    cohesiveEnergyResults.update(commonResults)
    results = [
        structureResults,
        cohesiveEnergyResults,
    ]
    resultsString = json.dumps(results, separators = (' ', ' '), indent = 4)
    print resultsString
    with open(os.path.abspath('output/results.edn'), 'w') as f:
        f.write(resultsString)

if __name__ == '__main__':
    main()