#!/usr/bin/env python
"""
    This is adapted from the cubic version of the elastic constants
    calculation. It seems to work for hcp.

    Things to figure out:
        * why the strain is multiplied on the right
        * why we can't use changing volumes
        * why diamond doesn't agree between cubic=True and cubic=False

"""

import ase
try:
    from ase.lattice import bulk
    print 'Imported bulk from ase.lattice' # For ASE version 3.9
except ImportError:
    from ase.structure import bulk
    print 'Imported bulk from ase.structure' # For ASE version <= 3.8
import numpy as np
import scipy.optimize as opt
import numdifftools as ndt
from kimcalculator import *
from scipy.optimize import fmin
import simplejson
import os
import sys
import json
import jinja2

class ElasticConstants(object):
    """Estimate elastic constants from the Hessian of the strain energy.

    A crystal is built with ``bulk`` from the lattice constants ``a`` and
    ``c``.  ``results`` first rescales the cell isotropically to its energy
    minimum, then evaluates numerically (with numdifftools) the 6x6 Hessian
    of the energy per reference volume with respect to the six Voigt strain
    components.  C11, C12 and B are formed from that Hessian with the
    averaging formulas inherited from the cubic version of this class.
    """

    def __init__(self, calc, element, potentialname, crystalstructure, latticeconstA, latticeconstC):
        """Build the crystal and record its reference cell and volume.

        calc             -- calculator attached to the crystal
        element          -- chemical symbol passed to ``bulk``
        potentialname    -- name of the potential (stored, not used here)
        crystalstructure -- crystal structure name passed to ``bulk``
        latticeconstA    -- lattice constant ``a`` passed to ``bulk``
        latticeconstC    -- lattice constant ``c`` passed to ``bulk``
        """
        self.calculator = calc
        self.element = element
        self.potentialname = potentialname
        self.crystalstructure = crystalstructure
        self.latticeconstA = latticeconstA
        self.latticeconstC = latticeconstC
        self.slab = self.create_slab()
        # Reference cell that every strain / scale is applied to.
        self.o_cell = self.slab.get_cell()
        self.slab.set_calculator(self.calculator)
        # Volume used to normalise the energy in energy_from_strain.
        self.o_volume = self.slab.get_volume()

    def create_slab(self):
        """Return the crystal built by ``bulk`` from the stored parameters."""
        slab = bulk(self.element, a=self.latticeconstA, c=self.latticeconstC,
                    crystalstructure=self.crystalstructure, cubic=True)
        return slab

    def voigt_to_matrix(self, voigt_vec):
        """Convert a 6-component Voigt vector into a symmetric 3x3 matrix.

        Components 0-2 fill the diagonal; components 3, 4 and 5 fill the
        (1,2), (0,2) and (0,1) off-diagonal pairs respectively.
        """
        matrix = np.zeros((3, 3))
        matrix[0, 0] = voigt_vec[0]
        matrix[1, 1] = voigt_vec[1]
        matrix[2, 2] = voigt_vec[2]
        # Assign each symmetric pair element by element.  Indexing with a
        # nested list is treated as a single array index by current NumPy
        # and would overwrite whole rows instead of two elements.
        matrix[1, 2] = matrix[2, 1] = voigt_vec[3]
        matrix[0, 2] = matrix[2, 0] = voigt_vec[4]
        matrix[0, 1] = matrix[1, 0] = voigt_vec[5]

        return matrix

    def energy_from_strain(self, strain_vec):
        """Return the energy of the strained crystal per reference volume.

        The Voigt vector ``strain_vec`` is expanded to a 3x3 matrix and
        applied to the reference cell by right-multiplication.  The energy
        is divided by ``self.o_volume`` (not the strained volume) and by
        ``ase.units.GPa``.
        """
        strain_mat = self.voigt_to_matrix(strain_vec)
        old_cell = self.o_cell
        new_cell = old_cell + np.dot(old_cell, strain_mat)
        self.slab.set_cell(new_cell, scale_atoms=True)
        energy = self.slab.get_potential_energy()

        return (energy / self.o_volume) / ase.units.GPa

    def energy_from_scale(self, scale):
        """Return the potential energy with the reference cell scaled by
        ``1 + scale`` in every direction."""
        old_cell = self.o_cell
        new_cell = old_cell * (1 + scale)
        self.slab.set_cell(new_cell, scale_atoms=True)
        energy = self.slab.get_potential_energy()
        return energy

    @staticmethod
    def _mean_and_sigma(values, errors, rows, cols):
        """Average ``values[rows[k], cols[k]]`` and propagate ``errors``.

        Returns ``(mean, sigma)`` where ``sigma`` is the quadrature sum of
        the selected error entries, each weighted by ``1/len(rows)``.
        """
        # A tuple index selects individual (row, col) elements.
        index = (list(rows), list(cols))
        weight = 1. / len(rows)
        mean = values[index].mean()
        sigma = np.sqrt(((weight * errors[index]) ** 2).sum())
        return mean, sigma

    def results(self):
        """Return a dict with C11, C12, B, their error estimates and metadata.

        Side effects: stores the optimal isotropic scale in ``self.minscale``,
        keeps a copy of the unscaled cell in ``self.oo_cell`` and replaces
        ``self.o_cell`` by the rescaled cell.

        The space group and Wyckoff code are looked up in the module-level
        ``space_groups`` and ``wyckoff_codes`` tables.
        """
        # Find the isotropic scale that minimises the energy.  fmin returns
        # an array, so take its single element explicitly.
        optimum = fmin(self.energy_from_scale, 0, xtol=0, ftol=1e-7)
        self.minscale = float(np.ravel(optimum)[0])
        self.oo_cell = self.o_cell.copy()
        self.o_cell = self.o_cell * (1 + self.minscale)
        # NOTE(review): self.o_volume is still the volume computed in
        # __init__, i.e. before this rescale -- confirm this is intended.

        func = self.energy_from_strain
        hess = ndt.Hessian(func)

        # Hessian of the energy density at zero strain, and the error
        # estimate the Hessian object exposes after evaluation.
        elastic_constants = hess(np.zeros(6))
        error_estimate = hess.error_estimate

        # C11: mean of the three normal-strain diagonal entries.
        C11, C11sig = self._mean_and_sigma(
            elastic_constants, error_estimate, [0, 1, 2], [0, 1, 2])

        # C12: mean of the six off-diagonal entries coupling normal strains.
        C12, C12sig = self._mean_and_sigma(
            elastic_constants, error_estimate,
            [1, 2, 2, 0, 0, 1], [0, 0, 1, 1, 2, 2])

        B = 1./3 * (C11 + 2 * C12)
        Bsig = np.sqrt((1./3 * C11sig)**2 + (2./3 * C12sig)**2)

        # The cubic original also reported C44 and an 'excess' term; they
        # are not computed here.
        results_dict = {'C11': C11,
                        'C11_sig': C11sig,
                        'C12': C12,
                        'C12_sig': C12sig,
                        'B': B,
                        'B_sig': Bsig,
                        'units': 'GPa',
                        "element": self.element,
                        "crystal_structure": self.crystalstructure,
                        "space_group": space_groups[self.crystalstructure],
                        "wyckoff_code": wyckoff_codes[self.crystalstructure],
                        "lattice_constant_a": self.latticeconstA,
                        "lattice_constant_c": self.latticeconstC,
                        "scale_discrepency": self.minscale,
                        }

        return results_dict

# Driver script: read the test inputs from stdin, compute the elastic
# constants and write them to output/results.edn through a Jinja2 template.

# Inputs are read one per line, in this order.
symbol  = raw_input()
lattice = raw_input()
model   = raw_input()
latticeconst_result_a = raw_input()
latticeconst_result_c = raw_input()

# Lookup tables used by ElasticConstants.results(); only hcp is supported.
space_groups = {"hcp": "P63/mmc"}
wyckoff_codes = {"hcp": "2d"}

# Atomic positions of each supported structure built with a = c = 1,
# serialised with ' ' as both JSON separators.
# NOTE(review): not referenced again in the visible part of this script.
normed_basis = {
    lattice: json.dumps(bulk(symbol, lattice, a=1, c=1, cubic=True).positions.tolist(), separators=(' ', ' '))
    for lattice in space_groups.keys()
}

# The inputs are scaled by 1e10 -- presumably metres to Angstrom; confirm
# against whatever produces the lattice constants.
latticeconstA = float(latticeconst_result_a) * 1e10
latticeconstC = float(latticeconst_result_c) * 1e10

# Same text as the Python 2 statement ``print a, b, ...`` (space-separated
# str() of each value), written so it also parses on Python 3.
print("%s %s %s %s %s" % (symbol, lattice, model, latticeconstA, latticeconstC))

calc = KIMCalculator(model)
bulkmodulus = ElasticConstants(calc, symbol, model, lattice, latticeconstA, latticeconstC)
results = bulkmodulus.results()
results.update({"basis_coordinates": [
  [0, 0, 0],
  [2.0/3, 1.0/3, 0.5]
  ]})

# Repeat species to match the two basis atoms
results.update({"species": '" "'.join([symbol, symbol])})

# Custom delimiters for the EDN template; StrictUndefined makes a missing
# template variable an error instead of rendering as empty.
template_environment = jinja2.Environment(
    loader=jinja2.FileSystemLoader('/'),
    block_start_string='@[',
    block_end_string=']@',
    variable_start_string='@<',
    variable_end_string='>@',
    comment_start_string='@#',
    comment_end_string='#@',
    undefined=jinja2.StrictUndefined,
    )

# Template the EDN output.  Render fully before opening the output file so
# a template error cannot leave behind an empty or truncated results.edn.
template = template_environment.get_template(os.path.abspath("results.edn.tpl"))
rendered = template.render(**results)
with open(os.path.abspath("output/results.edn"), "w") as f:
    f.write(rendered)