#!/usr/bin/env python
from ase.structure import bulk
from ase.calculators.emt import EMT
from ase.dft.kpoints import ibz_points, get_bandpath
from ase.phonons import Phonons
from kimcalculator import *
import simplejson
import jinja2
import json
import os
from functools import partial

import matplotlib
matplotlib.use('Agg')
import pylab as plt
import numpy as np
import scipy.optimize as opt

# Read three lines from stdin, in this order: the element symbol and the
# lattice type (both handed to ``bulk``) and the model name (handed to
# ``KIMCalculator``).
symbol, lattice, model = raw_input(), raw_input(), raw_input()
# Placeholder value; replaced by the computed lattice constant further down.
latticeconstant = 0

def energy(a, slab, cell):
    """Return the potential energy of ``slab`` with its cell set to ``cell * a``.

    This is the objective function handed to the lattice-constant
    minimisation.

    Args:
        a: Scale factor applied to ``cell``.
        slab: Object providing ``set_cell`` and ``get_potential_energy``.
            Its cell is modified in place.
        cell: Reference cell that gets multiplied by ``a``.

    Returns:
        The value of ``slab.get_potential_energy()`` for the scaled cell.

    Raises:
        Any exception raised while setting the cell or evaluating the energy
        propagates to the caller unchanged.
    """
    # No penalty value is substituted on failure: calculator errors are not
    # caught here, so the caller sees the real exception.
    slab.set_cell(cell * a)
    return slab.get_potential_energy()


def findLatticeConstant(calc):
    """
    Minimise the crystal energy over the lattice constant and return the optimum.

    Temporary copy of Alex's LatticeConstantCubicEnergy Test; in the future
    the result of that test should be looked up and used as input instead.

    The crystal is built from the module-level ``symbol`` and ``lattice``
    with a lattice constant of 1, so scaling its cell by ``a`` in ``energy``
    corresponds to a lattice constant of ``a``.  The search is run by
    ``scipy.optimize.fmin`` and starts from half of the model's "cutoff"
    value.

    Returns the first component of the optimum reported by ``fmin``.
    Raises ``Exception`` when ``fmin`` reports a non-zero warning flag.
    """
    trial = bulk(symbol, lattice, a=1)
    unit_cell = trial.get_cell()
    trial.set_calculator(calc)

    # Read the cutoff of the model; half of it seeds the search.
    cutoff = ks.KIM_API_get_data_double(calc.pkim, "cutoff")[0]
    start = cutoff / 2.0

    best, _fopt, _niter, _ncalls, warned = opt.fmin(
        energy, start, args=(trial, unit_cell), full_output=True)
    lattice_constant = best[0]

    if warned:
        raise Exception("Lattice constant computation hit bounds")
    return lattice_constant

def getFileInfo(filename):
    """Collect basic metadata about ``filename`` into a dictionary.

    Args:
        filename: Path to an existing file, absolute or relative to the
            current working directory.

    Returns:
        A dict with the keys ``filename`` (as given), ``path`` (directory of
        the absolute path), ``extension`` (including the leading dot),
        ``size`` (from ``os.path.getsize``), ``created`` (from
        ``os.path.getctime``), ``hash`` (hex MD5 digest of the file
        contents) and ``desc`` (a fixed description string).

    Raises:
        Whatever ``os.path.getsize``/``os.path.getctime``/``open`` raise when
        the file is missing or unreadable.
    """
    import hashlib
    import os

    abspath = os.path.abspath(filename)

    dct = {}
    dct['filename'] = filename
    dct['path'] = os.path.dirname(abspath)
    dct['extension'] = os.path.splitext(filename)[1]
    dct['size'] = os.path.getsize(abspath)
    dct['created'] = os.path.getctime(abspath)
    # Read inside a ``with`` block so the handle is closed right away
    # instead of whenever the garbage collector gets to it.
    with open(abspath, 'rb') as handle:
        dct['hash'] = hashlib.md5(handle.read()).hexdigest()
    # NOTE(review): this description mentions a "broken bond fit", which
    # does not match the phonon plot this script produces -- confirm.
    dct['desc'] = "Plot of the broken bond fit"
    return dct

# Create the KIM calculator for the requested model, find the lattice
# constant with it, and build the crystal at that lattice constant.
calc = KIMCalculator(model)
latticeconstant = findLatticeConstant(calc)
atoms = bulk(symbol, lattice, a=latticeconstant)

# Phonon calculator on an N x N x N supercell
N = 7
ph = Phonons(atoms, calc, supercell=(N,) * 3, delta=0.05)
ph.run()

# Read the forces back in and assemble the dynamical matrix
ph.read(acoustic=True)

# High-symmetry points in the Brillouin zone.
# NOTE(review): the 'fcc' table is used whatever `lattice` was requested --
# confirm this is intended for non-fcc inputs.
points = ibz_points['fcc']
G, X, W, K, L, U = (
    points[label] for label in ('Gamma', 'X', 'W', 'K', 'L', 'U')
)

# Path through the zone and the matching tick labels for the plot
path = [G, X, U, L, G, K]
point_names = ['$\Gamma$', 'X', 'U', 'L', '$\Gamma$', 'K']

# Band structure along the path, scaled by 1000 to get meV
path_kc, q, Q = get_bandpath(path, atoms.cell, 100)
omega_kn = ph.band_structure(path_kc) * 1000

# Phonon DOS, energies scaled by the same factor (in place)
omega_e, dos_e = ph.dos(kpts=(50,) * 3, npts=5000, delta=5e-4)
omega_e *= 1000

# Draw the band structure (left panel) and the DOS (right panel), then save
# the figure under output/.
tick_size = 18
plt.figure(1, (8, 6))

# Left panel: one black line per band along the path
plt.axes([.1, .07, .67, .85])
for n in range(omega_kn.shape[1]):
    plt.plot(q, omega_kn[:, n], 'k-', lw=2)

plt.xticks(Q, point_names, fontsize=tick_size)
plt.yticks(fontsize=tick_size)
plt.xlim(q[0], q[-1])
plt.ylabel("Frequency ($\mathrm{meV}$)", fontsize=22)
plt.grid('on')

# Right panel: DOS with all ticks hidden.
# NOTE(review): the y-range here is fixed at 0-35 while the band panel is
# autoscaled, so the two panels need not line up -- confirm this is wanted.
plt.axes([.8, .07, .17, .85])
plt.fill_between(dos_e, omega_e, y2=0, color='lightgrey', edgecolor='k', lw=1)
plt.ylim(0, 35)
plt.xticks([], [])
plt.yticks([], [])
plt.xlabel("DOS", fontsize=tick_size)

filename = 'output/phonons-%s-%s-%s.png' % (model, symbol, lattice)
plt.savefig(filename)

# Labels assigned to the bands by index.
# NOTE(review): only three labels exist, so a fourth band makes
# `branches[j]` below raise IndexError -- confirm inputs never have more.
branches = ['TA', 'TA', 'LA']

space_groups = {"fcc": "Fm-3m", "bcc": "Im-3m", "sc": "Pm-3m", "diamond": "Fd-3m"}
wyckoff_codes = {"fcc": "4a", "bcc": "2a", "sc": "1a", "diamond": "8a"}

# Atom positions of the cubic cell at lattice constant 1, serialised with
# space separators, for every structure listed in `space_groups`.
normed_basis = {
    structure: json.dumps(
        bulk(symbol, structure, a=1, cubic=True).positions.tolist(),
        separators=(' ', ' '),
    )
    for structure in space_groups
}

# One record per band, numbered from 1
kpoints = [
    {
        "wavenumber": list(q),
        "frequency": list(omega_kn[:, j]),
        "branch_label": branches[j],
        "iter": j + 1,
    }
    for j in xrange(len(omega_kn[0]))
]
curr = len(kpoints)

# The DOS record takes the next number after the last band
DOS = [{"energy": list(omega_e), "density": list(dos_e), "iter": curr + 1}]

# Everything the results template gets to see
results = {
    "wavevector": np.array(path_kc).tolist(),
    "lattice_constant": latticeconstant,
    "crystal_structure": lattice,
    "element": symbol,
    "kpoints": kpoints,
    "DOS": DOS,
    "basis_atoms": normed_basis[lattice],
}

# Jinja2 environment with '@'-based delimiters in place of the default
# brace-based ones.  The loader is rooted at '/' because the template is
# requested by absolute path below.  StrictUndefined turns any variable
# missing from `results` into an error.
template_environment = jinja2.Environment(
    loader=jinja2.FileSystemLoader('/'),
    block_start_string='@[',
    block_end_string=']@',
    variable_start_string='@<',
    variable_end_string='>@',
    comment_start_string='@#',
    comment_end_string='#@',
    undefined=jinja2.StrictUndefined,
    )

# Template filters: "json" dumps indented over several lines, "jsonl" dumps
# on a single line; both use a space for the item and key separators.
jsondump = partial(json.dumps, separators=(' ', ' '), indent=4)
jsonlinedump = partial(json.dumps, separators=(' ', ' '))
template_environment.filters.update({"json": jsondump, "jsonl": jsonlinedump})

# Template the EDN output.  Render completely before opening the output
# file: opening with "w" truncates it, so a template or rendering error must
# not be able to leave an empty results file behind.
template = template_environment.get_template(os.path.abspath("results.edn.tpl"))
rendered = template.render(**results)
with open(os.path.abspath("output/results.edn"), "w") as f:
    f.write(rendered)