#!/usr/bin/env python
import matplotlib
matplotlib.use('Agg')
import ase
from ase.build import bulk
import numpy as np
from chooseSurfaces import expandList
from surface import *
from ase.calculators.kim.kim import KIM
from analysis import *
import simplejson
import pickle
import scipy.optimize as opt
import sys
import signal
import time
import os
import jinja2
import json
from functools import partial

# The driver receives its inputs on stdin, one per line, in this order:
# path of a file inside the driver directory, chemical symbol, lattice type,
# KIM model name, lattice constant.
driverpath = os.path.dirname(os.path.abspath(raw_input()))
symbol = raw_input()
lattice = raw_input()
model = raw_input()
latticeconstant_global = 0

# Convert the lattice constant from the query into Angstroms
# (the 1e10 factor suggests the query value is in metres -- TODO confirm).
latticeconstant = float(raw_input())*1e10

# Time limit, in seconds, for a single surface calculation; a surface whose
# calculation exceeds TIME_CUTOFF is skipped.
TIME_CUTOFF = 60*60

calc = KIM(model)
print("calculator established")

def sweepSurfaces(calc, latticeconstant):
    """Compute surface energies for a list of Miller indices.

    A single reference surface is calculated first to time one calculation
    (aborting the whole run through ``handler1`` if it exceeds TIME_CUTOFF);
    afterwards every index in the list is calculated, and any surface that
    exceeds TIME_CUTOFF is skipped.

    Parameters:
        calc            -- calculator attached to the bulk cell and passed on
                           to getSurfaceEnergy
        latticeconstant -- lattice constant passed to ``bulk`` and
                           getSurfaceEnergy; also stored in the module-level
                           ``latticeconstant_global``

    Returns the tuple
        (indices_calculated, energies, surfaceEnergyDict, calcTime,
         surfaceLatticeVects, surfaceEnergyUnrelaxedDict, position0, position1)
    where ``energies`` is a numpy array of the relaxed energies in calculation
    order, ``calcTime`` is the wall time in seconds of the reference
    calculation, and the dictionaries are keyed by the Miller index as a
    tuple.  ``position0`` / ``position1`` hold the fourth / fifth values
    returned by getSurfaceEnergy.

    Raises whatever getSurfaceEnergy raises (other than TimeoutException
    inside the sweep loop, which is reported and skipped).
    """
    global latticeconstant_global
    surfaceEnergyDict = {}
    surfaceEnergyUnrelaxedDict = {}
    surfaceLatticeVects = {}
    position0, position1 = {}, {}
    energies = []
    indices_calculated = []

    # Pickles are binary data: read in 'rb' mode, and let the with-block close
    # the handle even if unpickling fails.
    with open(os.path.join(driverpath, 'IndexList.pkl'), 'rb') as index_file:
        list_of_indices = pickle.load(index_file)[:100]
    # list of indices for testing
    # NOTE: this hard-coded list replaces the one loaded from IndexList.pkl.
    list_of_indices = [[1,1,1],[1,0,0],[1,2,1],[1,1,0]]#,[1,8,9],[2,5,7]]
    latticeconstant_global = latticeconstant

    atoms = bulk(symbol, lattice, a=latticeconstant)
    atoms.set_calculator(calc)
    unit_e_bulk = atoms.get_potential_energy()/atoms.get_number_of_atoms()

    # Time the calculation of one reference surface.
    if lattice == 'fcc':
        miller = [1,1,1]
    else:
        # 'bcc' and every other lattice use the (100) surface
        miller = [1,0,0]
    signal.signal(signal.SIGALRM, handler1)
    signal.alarm(TIME_CUTOFF)
    start_time = time.time()
    try:
        # Only the elapsed time matters here; the energies are recomputed in
        # the sweep below.
        getSurfaceEnergy(miller, calc, unit_e_bulk, latticeconstant)
    finally:
        # Always cancel the pending alarm, including when the call raised.
        signal.alarm(0)
    calcTime = time.time() - start_time

    signal.signal(signal.SIGALRM, handler2)
    for miller in list_of_indices:
        result = None
        signal.alarm(TIME_CUTOFF)
        try:
            print(miller)
            result = getSurfaceEnergy(miller, calc, unit_e_bulk, latticeconstant)
        except TimeoutException:
            print("surface took too long, skipping %s" % (miller,))
        finally:
            # Cancel the alarm on every path so that it cannot fire later,
            # e.g. while another exception is propagating.
            signal.alarm(0)
        if result is None:
            continue

        # Results are stored only after the alarm is cleared, so a timeout
        # can never leave the dictionaries partially updated for an index.
        E_unrelaxed, E_relaxed, surf_lattice_vect, p0, p1 = result
        key = tuple(miller)
        surfaceEnergyUnrelaxedDict[key] = E_unrelaxed
        surfaceEnergyDict[key] = E_relaxed
        surfaceLatticeVects[key] = surf_lattice_vect
        position0[key] = p0
        position1[key] = p1
        energies.append(E_relaxed)
        indices_calculated.append(miller)

    return indices_calculated, np.array(energies), surfaceEnergyDict, calcTime, surfaceLatticeVects, surfaceEnergyUnrelaxedDict, position0, position1

def getSurfaceEnergy(miller, calc, unit_e_bulk, latticeconstant):
    """Build the slab for one Miller index and return its surface energies.

    Parameters:
        miller          -- Miller index as a sequence of three integers
        calc            -- calculator handed to ``surface_energy``
        unit_e_bulk     -- bulk energy per atom used as the reference
        latticeconstant -- lattice constant handed to ``makeSurface``

    Returns (E_unrelaxed, E_relaxed, surfvector, pos_unrelaxed, pos_relaxed),
    where both energies are excess energies over the bulk reference divided
    by twice the area spanned by the first two cell vectors.

    The slab is written to the ``output`` directory as an xyz file before and
    after the call to ``surface_energy``.
    """
    from ase.io import write

    # Fields shared by both output file names; the Miller index is rendered
    # as its components joined by dots.
    miller_tag = ".".join([str(component) for component in miller])
    name_fields = (symbol, lattice, model, miller_tag)

    surf = makeSurface(symbol, lattice, miller, size=(3, 3, 10), lattice_const=latticeconstant)
    # save the configuration before surface_energy is called
    write("output/SurfaceConfigurationUnrelaxed_%s_%s_%s_%s_.xyz" % name_fields, surf)

    e_unrelaxed, e_relaxed, pos_unrelaxed, pos_relaxed = surface_energy(surf, calc)

    # Energy the same number of atoms would have in the bulk
    bulk_reference = unit_e_bulk*surf.get_number_of_atoms()

    # Area of the cell face spanned by the first two cell vectors
    face_normal = np.cross(surf.cell[0], surf.cell[1])
    surface_area = np.sqrt(np.dot(face_normal, face_normal))

    # The denominator is twice the face area (presumably because the slab
    # exposes two surfaces -- confirm against makeSurface).
    E_unrelaxed = (e_unrelaxed - bulk_reference)/(2*surface_area)
    E_relaxed = (e_relaxed - bulk_reference)/(2*surface_area)
    surfvector = getSurfaceVector(surf)

    # save the configuration again after surface_energy has run
    write("output/SurfaceConfiguration_%s_%s_%s_%s_.xyz" % name_fields, surf)

    return E_unrelaxed, E_relaxed, surfvector, pos_unrelaxed, pos_relaxed

def fitBrokenBond(indices, energies, structure, n=3, p0=[0.1,0.1,0.01,0.0],correction=1):
    """Fit the given surface energies via ``fitSurfaceEnergies``.

    ``indices`` is converted to a numpy array; every other argument is
    forwarded unchanged.

    Returns (bfparams, cost, range_error, max_error); the second value
    returned by ``fitSurfaceEnergies`` (``cov_x``) is dropped.
    """
    fit = fitSurfaceEnergies(
        np.array(indices), energies, structure,
        n=n, p0=p0, correction=correction)
    bfparams, _cov_x, cost, range_error, max_error = fit
    return bfparams, cost, range_error, max_error

def fitSubSetBrokenBond(sample_indices,indices,energies,structure,n=3,p0=[0.1,0.1,0.01,0.0],correction=1):
    """Fit using only the energies that belong to ``sample_indices``.

    ``indices`` / ``energies`` are first passed through ``expandList``; each
    sample index is then looked up in the expanded index list to find its
    energy (a missing index raises ValueError from ``.index``).

    Returns only the fitted parameters ``bfparams``.
    """
    expandedIndices, expandedEnergies = expandList(indices, energies)

    # Energy of each sample index, in the order of sample_indices
    sample_energies = [
        expandedEnergies[expandedIndices.index(ind)]
        for ind in sample_indices
    ]

    bfparams, _cost, _range_error, _max_error = fitBrokenBond(
        sample_indices, sample_energies, structure,
        n=n, p0=p0, correction=correction)
    return bfparams

def plotBrokenBondFit(indices,energies,bfparams,structure,correction=1):
    """Save three plots of the broken-bond fit into the ``output`` directory.

    The index/energy lists are expanded with ``expandList`` and each plot is
    produced by ``plotSubSet`` with a different pair of direction vectors;
    the current figure is saved and cleared after every plot.
    """
    # TODO: modify to include three crystallographic zone cutoffs
    plotindices, plotenergies = expandList(indices, energies)

    # (first vector, second vector, output file) for each plot
    plot_specs = (
        ([1,-1,0], [1,1,0], 'output/BrokenBondFit1-10.png'),
        ([1,-1,2], [1,1,0], 'output/BrokenBondFit1-12.png'),
        ([1,1,-1], [0,1,1], 'output/BrokenBondFit11-1.png'),
    )
    for first_vector, second_vector, target in plot_specs:
        plotSubSet(plotindices, plotenergies, first_vector, second_vector,
                   bfparams, structure=structure, correction=correction)
        pylab.savefig(target)
        pylab.clf()

def calcEnergy(a, calc):
    """Return the potential energy of the bulk cell with lattice constant ``a``.

    Parameters:
        a    -- lattice constant passed to ``bulk``
        calc -- calculator attached to the cell

    Returns the potential energy, or the sentinel value 1e10 if the
    calculator raises while evaluating it.
    """
    atoms = bulk(symbol, lattice, a=a)
    atoms.set_calculator(calc)
    try:
        energy = atoms.get_potential_energy()
    except Exception:
        # Deliberate best effort: a failed evaluation is reported as a very
        # large energy.  Catching Exception rather than using a bare except
        # lets KeyboardInterrupt and SystemExit propagate.
        energy = 1e10

    return energy

def handler1(signum,frame):
    """Signal handler that aborts the run by raising a plain Exception.

    Installed for SIGALRM around the first (timing) surface calculation.
    Both arguments are ignored.
    """
    message = "first calculation took too long, aborting model-test pair"
    raise Exception(message)

def handler2(signum,frame):
    """Signal handler that raises TimeoutException.

    Installed for SIGALRM during the sweep over surfaces so that a single
    slow surface can be caught and skipped.  Both arguments are ignored.
    """
    timeout = TimeoutException()
    raise timeout

class TimeoutException(Exception):
    """Raised by handler2 when a surface calculation exceeds its time limit."""

def getFileInfo(filename, desc="Plot of the broken bond fit"):
    """Collect metadata about a file into a dictionary.

    Parameters:
        filename -- path of the file, absolute or relative to the current
                    working directory
        desc     -- text stored under the 'desc' key; defaults to the
                    description used for the broken-bond fit plots

    Returns a dict with the keys 'filename' (as given), 'path' (directory of
    the absolute path), 'extension' (including the leading dot), 'size'
    (bytes), 'created' (os.path.getctime value), 'hash' (hex MD5 digest of
    the contents) and 'desc'.

    Raises OSError if the file does not exist or cannot be inspected.
    """
    import hashlib
    import os

    abspath = os.path.abspath(filename)
    size = os.path.getsize(abspath)
    created = os.path.getctime(abspath)

    # Hash in chunks inside a with-block: the handle is closed
    # deterministically and large files are never held in memory whole.
    digest = hashlib.md5()
    with open(abspath, 'rb') as handle:
        for chunk in iter(lambda: handle.read(65536), b''):
            digest.update(chunk)

    return {
        'filename': filename,
        'path': os.path.dirname(abspath),
        'extension': os.path.splitext(filename)[1],
        'size': size,
        'created': created,
        'hash': digest.hexdigest(),
        'desc': desc,
    }


# Run the sweep; the eight values are unpacked in the order sweepSurfaces
# returns them.
(indices, energies, surfaceEnergyDict, calcTime, surfaceVectorDict,
 surfaceEnergyUnrelaxed, surfaceP0, surfaceP1) = sweepSurfaces(calc, latticeconstant)
print("surfaces swept")

plotfiles = []
if len(indices) < 4:
    # Fewer than four surfaces were calculated: no fit is attempted.
    bfparams = None
    range_error = None
    max_error = None
    subbfparams = None
    #subfitcost = None
else:
    bfparams, cost, range_error, max_error = fitBrokenBond(indices, energies, lattice)
    plotBrokenBondFit(indices, energies, bfparams, lattice)

    # Record metadata for the three plots written by plotBrokenBondFit
    plot_names = (
        "output/BrokenBondFit1-10.png",
        "output/BrokenBondFit1-12.png",
        "output/BrokenBondFit11-1.png",
    )
    for name in plot_names:
        plotfiles.append(getFileInfo(name))

    # fit the minimum subset and see how well it does
    # (the subset fit itself is currently disabled)
    samplelist = [[1,1,1],[1,0,0],[1,1,2],[1,0,1]]
    #subbfparams = fitSubSetBrokenBond(samplelist,indices,energies,lattice)
    #subfitcost = np.sum(abs(residual(subbfparams,numpy.array(indices),energies,1,lattice)/energies))/len(indices)/cost


#=======================================================
# formatting the output for pipeline integration
#=======================================================
# One record per surface for the relaxed energies, numbered from 1.
# surfaceP1 holds the positions that getSurfaceEnergy returns as
# ``pos_relaxed``, so it is the set that belongs with the relaxed energies
# (the two position dictionaries were previously swapped).
energies = []
for item, (key, val) in enumerate(surfaceEnergyDict.items(), 1):
    energies.append({
        "index": item,
        "miller_index": key,
        "surface_energy": val,
        "surface": surfaceVectorDict[key],
        "positions": surfaceP1[key].tolist(),
    })
lastindex = len(energies)

# Records for the unrelaxed energies continue the numbering after the relaxed
# ones and carry the positions returned as ``pos_unrelaxed`` (surfaceP0).
unrelaxedenergies = []
for item, (key, val) in enumerate(surfaceEnergyUnrelaxed.items(), lastindex + 1):
    unrelaxedenergies.append({
        "index": item,
        "miller_index": key,
        "surface_energy": val,
        "surface": surfaceVectorDict[key],
        "positions": surfaceP0[key].tolist(),
    })


# Per-lattice lookup tables used when assembling the results.
space_groups = {
    "fcc": "Fm-3m",
    "bcc": "Im-3m",
    "sc": "Pm-3m",
    "diamond": "Fd-3m",
}
wyckoff_codes = {
    "fcc": "4a",
    "bcc": "2a",
    "sc": "1a",
    "diamond": "8a",
}
# Atom positions of the cubic cell built with a=1 for every supported lattice,
# serialised with single spaces as both JSON separators.  The loop variable is
# deliberately not called ``lattice`` so that it cannot be confused with the
# module-level input of that name.
normed_basis = {
    structure_name: json.dumps(
        bulk(symbol, structure_name, a=1, cubic=True).positions.tolist(),
        separators=(' ', ' '))
    for structure_name in space_groups
}

# Collect everything that was calculated into one dictionary.  The two
# entries below are present whether or not a fit was made.
results = {
    "calculationTimeForTestSurface": calcTime,
    "energies": energies,
}
if bfparams is None:
    # No fit was made: report why instead of the fit parameters.
    results["message from test"] = "not enough surface energy results for fit"
else:
    results.update({
        # fit parameters and fit quality
        'BrokenBond_P1': bfparams[0],
        'BrokenBond_P2': bfparams[1],
        'BrokenBond_P3': bfparams[2],
        'CorrectionParameter': bfparams[3],
        'ErrorRange': range_error,
        'MaxResidual': max_error,
        #'SubsetPredictionQuality': subfitcost,

        # description of the crystal
        "crystal_structure": lattice,
        "element": symbol,
        "lattice_constant": latticeconstant_global,
        "basis_atoms": normed_basis[lattice],
        "space_group": space_groups[lattice],
        "wyckoff_code": wyckoff_codes[lattice],

        # plots and per-surface records
        "plotfiles": plotfiles,
        "unrelaxedenergies": unrelaxedenergies,
    })

# Jinja2 environment with custom '@'-based delimiters.  Templates are looked
# up from the filesystem root, so they are addressed by absolute path, and
# StrictUndefined makes any variable missing from ``results`` an error.
template_environment = jinja2.Environment(
    loader=jinja2.FileSystemLoader('/'),
    block_start_string='@[',
    block_end_string=']@',
    variable_start_string='@<',
    variable_end_string='>@',
    comment_start_string='@#',
    comment_end_string='#@',
    undefined=jinja2.StrictUndefined,
)

# Template filters: "json" pretty-prints with an indent of 4, "jsonl" emits a
# single line; both use a single space for the item and key separators.
jsondump = partial(json.dumps, separators=(' ', ' '), indent=4)
jsonlinedump = partial(json.dumps, separators=(' ', ' '))
template_environment.filters["json"] = jsondump
template_environment.filters["jsonl"] = jsonlinedump

# Render the EDN template into output/results.edn.  The output file is opened
# before the template is loaded, as in the original ordering.
results_path = os.path.abspath("output/results.edn")
template_path = os.path.abspath("results.edn.tpl")
with open(results_path, "w") as f:
    template = template_environment.get_template(template_path)
    rendered = template.render(**results)
    f.write(rendered)