#!/usr/bin/env python
"""
Lattice constant Test Driver

Computes the lattice constant for any material and any cubic crystal structure
by minimizing the energy using a simplex minimization

Date: 2012/09/03
Author: Alex Alemi & Matt Bierbaum

"""
import fileinput
from ase.structure import bulk
import scipy.optimize as opt
from kimcalculator import KIMCalculator
from kimservice import KIM_API_get_data_double
import simplejson
import numpy as np
import jinja2
import os
import json
import sys

def get_cutoff(model, q=None):
    """ Attempt to read the cutoff of the model and NBL type.

    A KIMCalculator is created for `model` and attached to a bulk
    structure with a=1 before the cutoff and the NBC method name are
    queried from it.

    NOTE: `symbol` and `lattice` are not parameters of this function; they
    are looked up as module-level globals.  In this file they are only
    assigned in the `__main__` section, so the function is usable only
    after those names have been defined.

    Parameters:
        model -- KIM model name passed to KIMCalculator
        q     -- optional queue-like object; when given, the result tuple
                 is also delivered through q.put() (this is how
                 run_get_cutoff collects it from the child process)

    Returns:
        (cutoff, nbcname) tuple
    """
    calc = KIMCalculator(model)
    slab = bulk(symbol, lattice, a=1)
    slab.set_calculator(calc)
    cutoff = KIM_API_get_data_double(calc.pkim, "cutoff")[0]
    nbcname = calc.get_NBC_method()

    result = (cutoff, nbcname)
    # q defaults to None, so only publish the result when a queue was
    # actually supplied instead of failing on None.put
    if q is not None:
        q.put(result)
    return result

def run_get_cutoff(model):
    """ Run get_cutoff for `model` in a separate process.

    The child process is started and joined, then the (cutoff, nbcname)
    tuple it placed on the queue is fetched without blocking.

    Raises:
        AttributeError -- if nothing could be fetched from the queue
    """
    from multiprocessing import Process, Queue

    result_queue = Queue()
    worker = Process(target=get_cutoff, args=(model, result_queue))
    worker.start()
    worker.join()

    try:
        result = result_queue.get(block=False)
    except Exception:
        # any failure to fetch is reported as an unreadable cutoff
        raise AttributeError("Cutoff could not be read from model %r" % model)
    else:
        return result

def energy(a, slab, cell, positions):
    """ Return the potential energy of `slab` scaled to lattice constant `a`.

    `positions` and `cell` are the reference values (taken at a=1 by the
    caller); both are multiplied by `a` and written back to `slab` before
    its potential energy is queried.  The positions are scaled for every
    lattice type: in this test only diamond has multiple sites, all other
    positions are zero.
    """
    scaled_positions = positions * a
    scaled_cell = cell * a

    # positions first, then the cell, then evaluate
    slab.set_positions(scaled_positions)
    slab.set_cell(scaled_cell)
    return slab.get_potential_energy()

def _build_supercell(symbol, lattice, cutoff):
    """ Build the repeated cubic cell used by the non-RVEC neighbor methods.

    Returns (slab, repeat) where `repeat` is the number of copies of the
    a=1 cubic cell along each axis: the cutoff rounded to the nearest
    integer, plus one.
    """
    smallslab = bulk(symbol, lattice, a=1, cubic=True)
    repeat = int(cutoff+0.5)+1
    slab = smallslab.repeat((repeat,)*3)
    return slab, repeat


def get_lattice_constant(model, symbol, lattice):
    """ Minimize the energy of a crystal with respect to its lattice constant.

    The model's cutoff and NBC (neighbor list) method are read first via
    run_get_cutoff; the structure is then set up according to that method
    and the energy is minimized with scipy's simplex routine (opt.fmin),
    starting from cutoff/2.

    Parameters:
        model   -- KIM model name
        symbol  -- element symbol passed to bulk()
        lattice -- crystal structure name passed to bulk()

    Returns:
        (lattice_constant, cohesive_energy, info) where cohesive_energy is
        minus the minimized energy divided by the number of contributing
        particles, and info is a dict with the optimizer statistics
        ("iterations", "func_calls", "warnflag") plus "cutoff", "nbctype"
        and "repeat".

    Raises:
        ValueError -- if the model reports an NBC method that is not one
                      of NEIGH_RVEC_F, NEIGH_PURE_F or MI_OPBC_F
    """
    cutoff, nbcname = run_get_cutoff(model)

    calc = KIMCalculator(model)
    repeat = 0

    if nbcname == "NEIGH_RVEC_F":
        # single cell as returned by bulk(); no repetition needed
        slab = bulk(symbol, lattice, a=1)
        cell = slab.get_cell()
        positions = slab.get_positions()
        slab.set_calculator(calc)

        particles = len(slab)

    elif nbcname == "NEIGH_PURE_F":
        # now that we have the cutoff, actually do some calcs
        slab, repeat = _build_supercell(symbol, lattice, cutoff)
        cell = slab.get_cell()
        positions = slab.get_positions()

        # find the only atom that matters: the one closest to half the
        # position of the last atom; every other atom is flagged as a ghost
        center = positions[-1]/2
        atom = np.argmin(np.sum((positions-center)**2, axis=-1))
        ghosts = np.ones(positions.shape[0], dtype='int8')
        ghosts[atom] = 0

        # set up the ghosts with the calculator
        slab.set_calculator(calc)
        slab.set_pbc([0,0,0])
        calc.set_ghosts(ghosts)

        # only the single non-ghost atom is counted in the energy per particle
        particles = 1

    elif nbcname == "MI_OPBC_F":
        # now that we have the cutoff, actually do some calcs
        slab, repeat = _build_supercell(symbol, lattice, cutoff)
        cell = slab.get_cell()
        positions = slab.get_positions()

        # work with slab from now on
        slab.set_calculator(calc)

        particles = len(slab)

    else:
        # without this branch slab/cell/positions would be unbound and the
        # failure would surface as a confusing NameError inside opt.fmin
        raise ValueError("Unsupported NBC method %r reported by model %r"
                         % (nbcname, model))

    aopt_arr, eopt, iterations, funccalls, warnflag = opt.fmin(
        energy, cutoff/2.0, args=(slab, cell, positions),
        full_output=True, xtol=1e-8)
    info = {"iterations": iterations,
            "func_calls": funccalls,
            "warnflag": warnflag,
            "cutoff": cutoff,
            "nbctype": nbcname,
            "repeat": repeat}

    return aopt_arr[0], -eopt/particles, info

if __name__ == "__main__":
    # read the three inputs from stdin, one per prompt.
    # NOTE: symbol and lattice must stay module-level names, get_cutoff
    # reads them as globals.
    symbol  = raw_input("element=")
    lattice = raw_input("lattice type=")
    model   = raw_input("modelname=")
    print("%s %s %s" % (symbol, lattice, model))

    space_groups = {"fcc": "Fm-3m", "bcc": "Im-3m", "sc": "Pm-3m", "diamond": "Fd-3m"}
    wyckoff_codes = {"fcc": "4a", "bcc": "2a", "sc": "1a", "diamond": "8a"}

    # fail before the minimization rather than after it: the results below
    # can only be reported for lattices present in these tables
    if lattice not in space_groups:
        raise ValueError("Unsupported lattice type %r; expected one of: %s"
                         % (lattice, ", ".join(sorted(space_groups))))

    aopt, eopt, info = get_lattice_constant(model=model, symbol=symbol, lattice=lattice)

    # atom positions of the a=1 cubic cell, written with spaces as both
    # item and key separators; only the requested lattice is needed
    basis_atoms = json.dumps(
        bulk(symbol, lattice, a=1, cubic=True).positions.tolist(),
        separators=(' ', ' '))

    # pack the results in a dictionary
    results = {"lattice_constant": aopt,
                "cohesive_energy": eopt,
                "element": symbol,
                "crystal_structure": lattice,
                "space_group": space_groups[lattice],
                "wyckoff_code": wyckoff_codes[lattice],
                "basis_atoms": basis_atoms}
    results.update(info)

    # custom delimiters (@[ ]@, @< >@, @# #@) replace the jinja2 defaults;
    # StrictUndefined makes a missing variable an error
    template_environment = jinja2.Environment(
        loader=jinja2.FileSystemLoader('/'),
        block_start_string='@[',
        block_end_string=']@',
        variable_start_string='@<',
        variable_end_string='>@',
        comment_start_string='@#',
        comment_end_string='#@',
        undefined=jinja2.StrictUndefined,
        )

    # template the EDN output; render first so a template error does not
    # leave behind an empty, truncated results file
    template = template_environment.get_template(os.path.abspath("results.edn.tpl"))
    rendered = template.render(**results)
    with open(os.path.abspath("output/results.edn"), "w") as f:
        f.write(rendered)