#!/usr/bin/env python3
"""
    This is adapted from the cubic version
    This seems to work for hcp

    Things to figure out:
        * why multiply strain on right
        * why we can't use changing volumes
        * why diamond doesn't agree between cubic True and False

"""
import os
import sys
import json
import ast

import ase
from ase.build import bulk
from ase.calculators.kim.kim import KIM
import numpy as np
import numdifftools as ndt
from scipy.optimize import fmin
import jinja2

class ElasticConstants(object):
    """Estimate elastic constants from a numerically determined energy Hessian.

    The crystal is built with ``ase.build.bulk(..., cubic=True)`` from the two
    lattice constants ``a`` and ``c``.  The energy per reference volume is
    differentiated twice with respect to the six Voigt strain components, and
    the resulting 6x6 matrix is reduced to an averaged ``C11``, an averaged
    ``C12`` and the bulk modulus ``B = (C11 + 2 * C12) / 3``.

    The code was adapted from a cubic-crystal version; the averaging over the
    three axes is inherited from that version.
    """

    # Structure metadata reported alongside the elastic constants, keyed by
    # the crystal structure name passed to the constructor.
    SPACE_GROUPS = {"hcp": "P63/mmc"}
    WYCKOFF_CODES = {"hcp": "2d"}

    def __init__(
        self,
        calc,
        element,
        potentialname,
        crystalstructure,
        latticeconstA,
        latticeconstC,
    ):
        """Build the crystal, attach the calculator and record the reference cell.

        Args:
            calc: ASE-compatible calculator attached to the crystal.
            element: Chemical symbol passed to ``ase.build.bulk``.
            potentialname: Name of the interatomic model (stored only).
            crystalstructure: Crystal structure name passed to ``bulk``.
            latticeconstA: Lattice constant ``a`` passed to ``bulk``.
            latticeconstC: Lattice constant ``c`` passed to ``bulk``.
        """
        self.calculator = calc
        self.element = element
        self.potentialname = potentialname
        self.crystalstructure = crystalstructure
        self.latticeconstA = latticeconstA
        self.latticeconstC = latticeconstC
        self.slab = self.create_slab()
        self.o_cell = self.slab.get_cell()
        self.slab.set_calculator(self.calculator)
        # NOTE(review): this volume is measured before results() rescales the
        # cell to its energy minimum, yet it is used to normalise the energy
        # afterwards -- confirm that the pre-relaxation volume is intended.
        self.o_volume = self.slab.get_volume()

    def create_slab(self):
        """Return the bulk crystal for the stored element and lattice constants."""
        slab = bulk(
            self.element,
            a=self.latticeconstA,
            c=self.latticeconstC,
            crystalstructure=self.crystalstructure,
            cubic=True,
        )
        return slab

    def voigt_to_matrix(self, voigt_vec):
        """Convert a six-component Voigt vector to a symmetric 3x3 matrix.

        Components 0-2 fill the diagonal; components 3, 4 and 5 fill the
        (1, 2), (0, 2) and (0, 1) off-diagonal pairs respectively, without
        any factor of 1/2.
        """
        matrix = np.zeros((3, 3))
        matrix[0, 0] = voigt_vec[0]
        matrix[1, 1] = voigt_vec[1]
        matrix[2, 2] = voigt_vec[2]
        matrix[1, 2] = matrix[2, 1] = voigt_vec[3]
        matrix[0, 2] = matrix[2, 0] = voigt_vec[4]
        matrix[0, 1] = matrix[1, 0] = voigt_vec[5]

        return matrix

    def energy_from_strain(self, strain_vec):
        """Return the energy of the strained crystal per reference volume.

        The strain given by ``strain_vec`` (Voigt order) is applied to the
        reference cell ``self.o_cell`` and the atoms are scaled with it.  The
        energy is divided by the fixed volume ``self.o_volume`` and by
        ``ase.units.GPa``; the results are reported with units "GPa".
        """
        strain_mat = self.voigt_to_matrix(strain_vec)
        old_cell = self.o_cell
        # The strain multiplies the cell matrix from the right, i.e. it acts
        # on each row of the cell (presumably one lattice vector per row, as
        # in ASE -- TODO confirm).
        new_cell = old_cell + np.dot(old_cell, strain_mat)
        self.slab.set_cell(new_cell, scale_atoms=True)
        energy = self.slab.get_potential_energy()

        # Deliberately normalised by the fixed reference volume rather than
        # the volume of the strained cell.
        return (energy / self.o_volume) / ase.units.GPa

    def energy_from_scale(self, scale):
        """Return the energy after scaling the reference cell by ``1 + scale``."""
        old_cell = self.o_cell
        new_cell = old_cell * (1 + scale)
        self.slab.set_cell(new_cell, scale_atoms=True)
        energy = self.slab.get_potential_energy()
        return energy

    def results(self):
        """Relax the cell scale, compute the Hessian and return the results.

        Side effects: stores the optimal isotropic scale in ``self.minscale``,
        keeps the unscaled cell in ``self.oo_cell`` and replaces
        ``self.o_cell`` with the rescaled cell.  ``self.slab`` is left in the
        last strained state evaluated.

        Returns:
            dict: averaged ``C11``/``C12``, bulk modulus ``B``, their error
            estimates (``*_sig``) and the structure metadata.

        Raises:
            KeyError: if the crystal structure has no entry in
                ``SPACE_GROUPS`` / ``WYCKOFF_CODES``.
        """
        # Find the isotropic scale factor that minimises the energy so the
        # Hessian is evaluated about a stationary point.
        optimum = fmin(self.energy_from_scale, 0, xtol=0, ftol=1e-7)
        # fmin returns a 1-element array; take the element explicitly because
        # float() on a non-0-d array is deprecated in recent NumPy.
        self.minscale = float(np.ravel(optimum)[0])
        self.oo_cell = self.o_cell.copy()
        self.o_cell = self.o_cell * (1 + self.minscale)

        hess = ndt.Hessian(self.energy_from_strain, step=0.001, full_output=True)
        elastic_constants, info = hess(np.zeros(6, dtype=float))

        return self._summarize(elastic_constants, info.error_estimate)

    def _summarize(self, elastic_constants, error_estimate):
        """Reduce a 6x6 Hessian and its error estimate to the results dict."""
        # C11: mean of the three normal-strain diagonal entries.
        inds11 = ([0, 1, 2], [0, 1, 2])
        C11 = elastic_constants[inds11].mean()
        C11sig = np.sqrt(((1.0 / 3 * error_estimate[inds11]) ** 2).sum())

        # C12: mean of the six off-diagonal entries coupling normal strains.
        inds12 = ([1, 2, 2, 0, 0, 1], [0, 0, 1, 1, 2, 2])
        C12 = elastic_constants[inds12].mean()
        C12sig = np.sqrt(((1.0 / 6 * error_estimate[inds12]) ** 2).sum())

        # The shear constant (C44) and the "excess" terms of the cubic
        # original are intentionally not reported here.

        B = 1.0 / 3 * (C11 + 2 * C12)
        Bsig = np.sqrt((1.0 / 3 * C11sig) ** 2 + (2.0 / 3 * C12sig) ** 2)

        structure = self.crystalstructure
        results_dict = {
            "C11": C11,
            "C11_sig": C11sig,
            "C12": C12,
            "C12_sig": C12sig,
            "B": B,
            "B_sig": Bsig,
            "units": "GPa",
            # Use the instance's own data rather than module-level globals so
            # the class works outside the driver script.
            "element": self.element,
            "crystal_structure": structure,
            "space_group": self.SPACE_GROUPS[structure],
            "wyckoff_code": self.WYCKOFF_CODES[structure],
            "lattice_constant_a": self.latticeconstA,
            "lattice_constant_c": self.latticeconstC,
            # Key spelling is kept as-is: it is read by the output template.
            "scale_discrepency": self.minscale,
        }

        return results_dict


# Read inputs from pipeline.stdin
# The first line is a path; the results template is looked up in its directory.
driverpath = os.path.dirname(os.path.abspath(input()))
symbol = input("Enter a species  (e.g. Al, Fe): ")
print(symbol)
lattice = input("Enter a hexagonal lattice type (e.g. hcp): ")
print(lattice)
model = input("Enter a model name: ")
print(model)
# The lattice constants arrive as a Python literal holding two values (a, c);
# literal_eval parses it without executing arbitrary code.
latticeconst_input = ast.literal_eval(input("Enter hexagonal lattice constants 'a' and "
    "'c' (m): "))
latticeconst_result_a, latticeconst_result_c = latticeconst_input
print("{}, {}".format(latticeconst_result_a, latticeconst_result_c))

space_groups = {"hcp": "P63/mmc"}
wyckoff_codes = {"hcp": "2d"}
# Atomic positions of each supported structure for a = c = 1, serialised as a
# space-separated nested list.  The loop variable is named so that it does not
# shadow the user-supplied ``lattice``.
normed_basis = {
    structure: json.dumps(
        bulk(symbol, structure, a=1, c=1, cubic=True).positions.tolist(),
        separators=(" ", " "),
    )
    for structure in space_groups
}

# The prompt asks for metres; the factor 1e10 converts to angstroms.
latticeconstA = float(latticeconst_result_a) * 1e10
latticeconstC = float(latticeconst_result_c) * 1e10

print()
print("Calculating isothermal bulk modulus\n")

calc = KIM(model)
bulkmodulus = ElasticConstants(
    calc, symbol, model, lattice, latticeconstA, latticeconstC
)
results = bulkmodulus.results()
results.update({"basis_coordinates": [[0, 0, 0], [2.0 / 3, 1.0 / 3, 0.5]]})

print()
print("Finished calculation")
print()
print("C11 = {} GPa".format(results["C11"]))
print("C12 = {} GPa".format(results["C12"]))
print("  B = {} GPa".format(results["B"]))
print()

# Repeat species to match normed basis
results.update({"species": '" "'.join([symbol, symbol])})

# Custom delimiters are used instead of the Jinja2 defaults; StrictUndefined
# makes a variable missing from ``results`` an error instead of rendering it
# as an empty string.
template_environment = jinja2.Environment(
    loader=jinja2.FileSystemLoader("/"),
    block_start_string="@[",
    block_end_string="]@",
    variable_start_string="@<",
    variable_end_string=">@",
    comment_start_string="@#",
    comment_end_string="#@",
    undefined=jinja2.StrictUndefined,
)

# template the EDN output
# Load and render before opening the output file: opening in "w" mode
# truncates it, so a template error would otherwise leave an empty result file.
template = template_environment.get_template(
    os.path.join(driverpath, "results.edn.tpl")
)
rendered_results = template.render(**results)
with open(os.path.abspath("output/results.edn"), "w") as f:
    f.write(rendered_results)