#!/usr/bin/env python
from ase.build import bulk
from ase.dft.kpoints import ibz_points, get_bandpath
from ase.phonons import Phonons
from ase.calculators.kim.kim import KIM
import jinja2
import json
from functools import partial

import matplotlib
matplotlib.use('Agg')
import pylab as plt
import numpy as np
import os
import hashlib
import scipy.optimize as opt

# Driver inputs are read from stdin, one value per line, in this fixed order:
# chemical symbol, lattice type, KIM model name, lattice constant.
symbol, lattice, model = raw_input(), raw_input(), raw_input()

# Convert the queried lattice constant to angstroms (factor 1e10; the query
# value is presumably in meters -- confirm against the caller).
latticeconstant = float(raw_input()) * 1e10

def energy(a, slab, cell):
    """ Compute the energy of a lattice given the lattice constant.

    The cell of ``slab`` is replaced in place by ``cell * a`` and the value of
    ``slab.get_potential_energy()`` is returned.

    Parameters:
        a: scale factor (lattice constant) multiplied into ``cell``.
        slab: object exposing ``set_cell`` and ``get_potential_energy``
            (presumably an ASE ``Atoms`` with a calculator attached).
        cell: reference cell; it only has to support ``cell * a``.

    Returns:
        The result of ``slab.get_potential_energy()`` for the rescaled cell.

    Raises:
        Any exception raised while evaluating the energy propagates to the
        caller unchanged.
    """
    slab.set_cell(cell * a)
    # The previous bare ``except`` stored a 1e10 sentinel and then re-raised,
    # so the sentinel could never be returned; errors simply propagate.
    return slab.get_potential_energy()

def getFileInfo(filename):
    """ Collect basic metadata about a file on disk.

    Parameters:
        filename: path to an existing file (relative or absolute).

    Returns:
        dict with the keys
        'filename'  -- the path exactly as passed in,
        'path'      -- directory of the absolute path,
        'extension' -- extension including the leading dot,
        'size'      -- value of ``os.path.getsize`` for the file,
        'created'   -- value of ``os.path.getctime`` for the file,
        'hash'      -- hex MD5 digest of the file contents,
        'desc'      -- a fixed description string.

    Raises:
        OSError / IOError if the file cannot be stat'ed or read.
    """
    dct = {}
    abspath = os.path.abspath(filename)
    dct['filename'] = filename
    dct['path'] = os.path.dirname(abspath)
    dct['extension'] = os.path.splitext(filename)[1]
    dct['size'] = os.path.getsize(abspath)
    dct['created'] = os.path.getctime(abspath)
    # Read through a context manager so the handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(abspath, 'rb') as handle:
        dct['hash'] = hashlib.md5(handle.read()).hexdigest()
    # NOTE(review): the description is fixed text regardless of the file;
    # confirm it matches what callers actually pass in.
    dct['desc'] = "Plot of the broken bond fit"
    return dct

# Calculator for the requested KIM model and the bulk crystal it is applied to
calc = KIM(model)
atoms = bulk(symbol, lattice, a=latticeconstant)

# Phonon calculator on an N x N x N supercell
N = 7
ph = Phonons(atoms, calc, supercell=(N,) * 3, delta=0.05)
ph.run()

# Read the computed forces back in and assemble the dynamical matrix
ph.read(acoustic=True)

# High-symmetry points in the Brillouin zone
# NOTE(review): the 'fcc' table is used whatever `lattice` was requested;
# confirm this is intended for non-fcc inputs.
points = ibz_points['fcc']
G = points['Gamma']
X = points['X']
W = points['W']
K = points['K']
L = points['L']
U = points['U']

# Raw strings keep the LaTeX backslash literal instead of relying on '\G'
# being an unrecognised escape sequence (which newer Pythons warn about).
point_names = [r'$\Gamma$', 'X', 'U', 'L', r'$\Gamma$', 'K']
path = [G, X, U, L, G, K]

# Band structure along the path, scaled by 1000 to give meV
path_kc, q, Q = get_bandpath(path, atoms.cell, 100)
omega_kn = 1000 * ph.band_structure(path_kc)

# Calculate phonon DOS; energies are scaled by 1000 like the bands
omega_e, dos_e = ph.dos(kpts=(50, 50, 50), npts=5000, delta=5e-4)
omega_e *= 1000

# Plot the band structure and DOS
plt.figure(1, (8, 6))

# Left panel: one curve per branch (column of omega_kn) along the k-path
plt.axes([.1, .07, .67, .85])
for omega_n in omega_kn.T:
    plt.plot(q, omega_n, 'k-', lw=2)

plt.xticks(Q, point_names, fontsize=18)
plt.yticks(fontsize=18)
plt.xlim(q[0], q[-1])
# Raw string so the LaTeX backslash is not treated as an escape sequence
plt.ylabel(r"Frequency ($\mathrm{meV}$)", fontsize=22)
plt.grid(True)

# Right panel: DOS on the horizontal axis, energy on the vertical axis
plt.axes([.8, .07, .17, .85])
plt.fill_between(dos_e, omega_e, y2=0, color='lightgrey', edgecolor='k', lw=1)
# NOTE(review): this fixed 0-35 range does not depend on the data and is not
# applied to the band panel; confirm it suits every model/element.
plt.ylim(0, 35)
plt.xticks([], [])
plt.yticks([], [])
plt.xlabel("DOS", fontsize=18)

# Written relative to the working directory; 'output/' must already exist
filename = 'output/phonons-%s-%s-%s.png' % (model, symbol, lattice)
plt.savefig(filename)

# Labels assigned to the branches by column index of omega_kn.
# NOTE(review): only three labels are defined, so a band structure with more
# than three branches fails with IndexError below -- confirm this is intended.
branches = ['TA', 'TA', 'LA']

space_groups = {"fcc": "Fm-3m", "bcc": "Im-3m", "sc": "Pm-3m", "diamond": "Fd-3m"}
wyckoff_codes = {"fcc": "4a", "bcc": "2a", "sc": "1a", "diamond": "8a"}

# Positions of the conventional cubic cell with a=1 for every known lattice,
# serialised with blank separators.
normed_basis = {
    lattice_name: json.dumps(
        bulk(symbol, lattice_name, a=1, cubic=True).positions.tolist(),
        separators=(' ', ' '),
    )
    for lattice_name in space_groups
}

# One record per branch; "iter" is a 1-based running index
kpoints = [
    {
        "wavenumber": list(q),
        "frequency": list(omega_kn[:, j]),
        "branch_label": branches[j],
        "iter": j + 1,
    }
    for j in range(len(omega_kn[0]))
]
curr = len(kpoints)

# The single DOS record continues the running index after the branches
DOS = [{"energy": list(omega_e), "density": list(dos_e), "iter": curr + 1}]

# Variables made available to the results template
results = dict(
    wavevector=np.array(path_kc).tolist(),
    lattice_constant=latticeconstant,
    crystal_structure=lattice,
    element=symbol,
    kpoints=kpoints,
    DOS=DOS,
    basis_atoms=normed_basis[lattice],
)

# Jinja2 environment with custom '@'-based delimiters in place of the default
# brace-based ones (presumably to avoid clashing with EDN syntax -- confirm).
# StrictUndefined makes a missing template variable an error rather than
# rendering as an empty string.
template_environment = jinja2.Environment(
    loader=jinja2.FileSystemLoader('/'),
    block_start_string='@[',
    block_end_string=']@',
    variable_start_string='@<',
    variable_end_string='>@',
    comment_start_string='@#',
    comment_end_string='#@',
    undefined=jinja2.StrictUndefined,
    )

# JSON filters that use a blank for both the item and the key separator:
# "json" is indented multi-line output, "jsonl" is single-line output.
jsondump = partial(json.dumps, separators=(' ', ' '), indent=4)
jsonlinedump = partial(json.dumps, separators=(' ', ' '))
template_environment.filters.update({"json": jsondump, "jsonl": jsonlinedump})

#template the EDN output
# Render first and only then open the output file, so a template error does
# not leave behind an empty, truncated results.edn.
template = template_environment.get_template(os.path.abspath("results.edn.tpl"))
rendered = template.render(**results)
with open(os.path.abspath("output/results.edn"), "w") as f:
    f.write(rendered)