#!/usr/bin/env python3
"""
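    Compute the cubic elastic constants (C11, C12, C44) and the bulk modulus
    of a crystal from the numerical Hessian of its strain energy density.

    This seems to work for fcc, bcc, sc and diamond.

    Things to figure out:
        * why the strain matrix is multiplied on the right of the cell matrix
        * why we can't use changing volumes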

Last Update: 2019/04/13 Ellad Tadmor (added ability to do diamond)

"""
# Python 2-3 compatible code issues
from __future__ import print_function

import os
import sys
import json

import numpy as np
from ase.build import bulk
from ase.units import GPa
import numdifftools as ndt
from ase.calculators.kim.kim import KIM
from scipy.optimize import fmin
from scipy.optimize import minimize
import jinja2

# Python 2 compatibility shim: on Python 2 the builtin ``input`` evaluates what
# is typed, so rebind it to ``raw_input`` to always receive the raw string.
# On Python 3 ``raw_input`` does not exist; the resulting NameError is ignored
# and the builtin ``input`` is used unchanged.
try:
    input = raw_input
except NameError:
    pass


class ElasticConstants(object):
    """Determine the cubic elastic constants from a numerical Hessian.

    The energy per reference volume of a conventional cubic cell is
    differentiated twice with respect to the six Voigt strain components
    (using numdifftools) and C11, C12, C44 and the bulk modulus are read off
    the resulting 6x6 matrix.

    Args:
        calc: ASE calculator attached to the cell.
        element: chemical symbol passed to ``ase.build.bulk``.
        potentialname: name of the potential; stored but not used here.
        crystalstructure: one of the keys of ``_SPACE_GROUPS``.
        latticeconst: lattice constant passed to ``ase.build.bulk`` as ``a``.
    """

    # Space-group symbol and Wyckoff code reported for each supported lattice.
    _SPACE_GROUPS = {
        "fcc": "Fm-3m",
        "bcc": "Im-3m",
        "sc": "Pm-3m",
        "diamond": "Fd-3m",
    }
    _WYCKOFF_CODES = {"fcc": "4a", "bcc": "2a", "sc": "1a", "diamond": "8a"}

    def __init__(self, calc, element, potentialname, crystalstructure, latticeconst):
        self.calculator = calc
        self.element = element
        self.potentialname = potentialname
        self.crystalstructure = crystalstructure
        self.latticeconst = latticeconst
        self.slab = self.create_slab()
        # Keep the reference cell as a plain ndarray so that the arithmetic in
        # this class behaves the same whether ASE returns an ndarray or a Cell
        # object from get_cell().
        self.o_cell = np.array(self.slab.get_cell(), dtype=float)
        # ``atoms.calc = ...`` is the supported spelling; ``set_calculator`` is
        # deprecated in recent ASE releases.
        self.slab.calc = self.calculator
        self.o_volume = self.slab.get_volume()

    def create_slab(self):
        """Build the conventional cubic cell for the configured crystal."""
        return bulk(
            self.element,
            a=self.latticeconst,
            crystalstructure=self.crystalstructure,
            cubic=True,
        )

    def voigt_to_matrix(self, voigt_vec):
        """Convert a voigt notation vector to a symmetric 3x3 matrix.

        Components 0-2 fill the diagonal; components 3, 4 and 5 are written
        unchanged into the (1,2), (0,2) and (0,1) off-diagonal pairs.
        """
        matrix = np.zeros((3, 3))
        matrix[0, 0] = voigt_vec[0]
        matrix[1, 1] = voigt_vec[1]
        matrix[2, 2] = voigt_vec[2]
        matrix[1, 2] = matrix[2, 1] = voigt_vec[3]
        matrix[0, 2] = matrix[2, 0] = voigt_vec[4]
        matrix[0, 1] = matrix[1, 0] = voigt_vec[5]
        return matrix

    def _set_flat_positions(self, pos):
        """Write a flat (3 * natoms) coordinate vector into the cell."""
        # reshape(-1, 3) lets NumPy infer the atom count as an integer; the
        # previous round(len(pos) / 3) yields a float on Python 2.
        self.slab.set_positions(np.reshape(pos, (-1, 3)))

    def get_energy_from_positions(self, pos):
        """Return the potential energy for the flat position vector ``pos``."""
        self._set_flat_positions(pos)
        return self.slab.get_potential_energy()

    def get_gradient_from_positions(self, pos):
        """Return the flat energy gradient (negated forces) at ``pos``."""
        self._set_flat_positions(pos)
        return -self.slab.get_forces().flatten()

    def energy_from_strain(self, strain_vec):
        """Return the energy per reference volume for a Voigt strain vector.

        The strain matrix is applied to the reference cell ``o_cell`` and the
        atoms are scaled with it. The energy is divided by the fixed
        reference volume ``o_volume`` and by ``ase.units.GPa``.
        """
        strain_mat = self.voigt_to_matrix(strain_vec)
        new_cell = self.o_cell + np.dot(self.o_cell, strain_mat)
        self.slab.set_cell(new_cell, scale_atoms=True)
        if self.crystalstructure == "diamond":
            # Relax the atomic positions inside the strained cell before
            # taking the energy (presumably because the diamond basis can
            # shift internally under strain -- confirm).
            pos0 = self.slab.get_positions().flatten()
            res = minimize(
                self.get_energy_from_positions,
                pos0,
                jac=self.get_gradient_from_positions,
            )
            energy = res["fun"]
        else:
            energy = self.slab.get_potential_energy()

        return (energy / self.o_volume) / GPa

    def energy_from_scale(self, scale):
        """Return the total energy of the reference cell scaled by 1 + scale."""
        new_cell = self.o_cell * (1 + scale)
        self.slab.set_cell(new_cell, scale_atoms=True)
        return self.slab.get_potential_energy()

    @staticmethod
    def _cubic_constants(hessian, error_estimate):
        """Reduce a 6x6 strain Hessian and its error estimate to cubic constants.

        Returns a dict with C11, C12, C44, B, the mean absolute value of the
        remaining ("excess") entries, and a ``*_sig`` uncertainty for each.
        """
        hessian = np.asarray(hessian, dtype=float)
        error_estimate = np.asarray(error_estimate, dtype=float)

        inds11 = ([0, 1, 2], [0, 1, 2])
        C11 = hessian[inds11].mean()
        C11sig = np.sqrt(((1.0 / 3 * error_estimate[inds11]) ** 2).sum())

        inds12 = ([1, 2, 2, 0, 0, 1], [0, 0, 1, 1, 2, 2])
        C12 = hessian[inds12].mean()
        C12sig = np.sqrt(((1.0 / 6 * error_estimate[inds12]) ** 2).sum())

        # voigt_to_matrix writes the shear components unscaled into both
        # off-diagonal entries, hence the factor 1/4 on the shear block.
        inds44 = ([3, 4, 5], [3, 4, 5])
        C44 = 1.0 / 4 * hessian[inds44].mean()
        C44sig = 1.0 / 4 * np.sqrt(((1.0 / 3 * error_estimate[inds44]) ** 2).sum())

        B = 1.0 / 3 * (C11 + 2 * C12)
        Bsig = np.sqrt((1.0 / 3 * C11sig) ** 2 + (2.0 / 3 * C12sig) ** 2)

        # The 24 off-diagonal entries that involve at least one shear
        # component, listed column by column.
        excess_pairs = [
            (row, col)
            for col in range(6)
            for row in range(6)
            if row != col and (row >= 3 or col >= 3)
        ]
        excessinds = (
            [row for row, _ in excess_pairs],
            [col for _, col in excess_pairs],
        )
        excess = np.abs(hessian[excessinds]).mean()
        excess_sig = np.sqrt(((1.0 / 24 * error_estimate[excessinds]) ** 2).sum())

        return {
            "C11": C11,
            "C11_sig": C11sig,
            "C12": C12,
            "C12_sig": C12sig,
            "C44": C44,
            "C44_sig": C44sig,
            "B": B,
            "B_sig": Bsig,
            "excess": excess,
            "excess_sig": excess_sig,
        }

    def results(self):
        """Return the cubic elastic constants and descriptive metadata.

        The reference cell is first rescaled isotropically to the energy
        minimum found by ``fmin``; the Hessian of ``energy_from_strain`` is
        then evaluated at zero strain.

        Raises:
            KeyError: if ``crystalstructure`` is not a supported lattice.
        """
        # get the minimum
        # fmin returns a length-1 array; index it explicitly because float()
        # on an array with ndim > 0 is deprecated in recent NumPy.
        self.minscale = float(
            fmin(self.energy_from_scale, 0, xtol=0, ftol=1e-7, maxiter=2000)[0]
        )
        self.oo_cell = self.o_cell.copy()
        self.o_cell = self.o_cell * (1 + self.minscale)
        # Normalise by the volume of the cell that is actually strained, i.e.
        # the rescaled reference cell. It stays fixed during the Hessian.
        self.slab.set_cell(self.o_cell, scale_atoms=True)
        self.o_volume = self.slab.get_volume()

        hess = ndt.Hessian(self.energy_from_strain, step=0.001, full_output=True)
        elastic_constants, info = hess(np.zeros(6, dtype=float))

        results_dict = self._cubic_constants(elastic_constants, info.error_estimate)
        results_dict["units"] = "GPa"
        # Report the instance's own configuration rather than module globals.
        results_dict["element"] = self.element
        results_dict["crystal_structure"] = self.crystalstructure
        results_dict["space_group"] = self._SPACE_GROUPS[self.crystalstructure]
        results_dict["wyckoff_code"] = self._WYCKOFF_CODES[self.crystalstructure]
        results_dict["lattice_constant"] = self.latticeconst
        results_dict["scale_discrepency"] = self.minscale

        return results_dict


# Driver inputs are read from stdin, one per line, in this order: species
# symbol, lattice type, model name, lattice constant.
symbol = input()
lattice = input()
model = input()
latticeconst_result = input()

space_groups = {
    "fcc": "Fm-3m",
    "bcc": "Im-3m",
    "sc": "Pm-3m",
    "diamond": "Fd-3m",
}
wyckoff_codes = {
    "fcc": "4a",
    "bcc": "2a",
    "sc": "1a",
    "diamond": "8a",
}

# Atom positions of the a=1 conventional cell for every supported lattice,
# serialised as space-separated nested lists.
normed_basis = {
    structure: json.dumps(
        bulk(symbol, structure, a=1, cubic=True).positions.tolist(),
        separators=(" ", " "),
    )
    for structure in space_groups
}
# Number of atoms in the conventional cell of the requested lattice.
count = len(bulk(symbol, lattice, a=1, cubic=True).positions)
# NOTE(review): the factor 1e10 presumably converts metres to angstrom --
# confirm against whatever produces the lattice constant on stdin.
latticeconst = 1e10 * float(latticeconst_result)

print(symbol, lattice, model, latticeconst)

calc = KIM(model)
bulkmodulus = ElasticConstants(calc, symbol, model, lattice, latticeconst)
results = bulkmodulus.results()
results["basis_coordinates"] = normed_basis[lattice]

# One species entry per basis atom, joined so the template can wrap the
# whole string in a single pair of quotes.
results["species"] = '" "'.join([symbol] * count)

# Report the constants on stdout
print()
print("C11 = {:12.6f} +/- {:12.6f} GPa".format(results["C11"], results["C11_sig"]))
print("C12 = {:12.6f} +/- {:12.6f} GPa".format(results["C12"], results["C12_sig"]))
print("C44 = {:12.6f} +/- {:12.6f} GPa".format(results["C44"], results["C44_sig"]))
print("B   = {:12.6f} +/- {:12.6f} GPa".format(results["B"], results["B_sig"]))
print()

# Jinja2 environment with custom @[ ]@ / @< >@ / @# #@ delimiters; undefined
# template variables raise instead of rendering as empty strings.
template_environment = jinja2.Environment(
    loader=jinja2.FileSystemLoader("/"),
    block_start_string="@[",
    block_end_string="]@",
    variable_start_string="@<",
    variable_end_string=">@",
    comment_start_string="@#",
    comment_end_string="#@",
    undefined=jinja2.StrictUndefined,
)

# Render the EDN results file from its template
output_path = os.path.abspath("output/results.edn")
template_path = os.path.abspath("results.edn.tpl")
with open(output_path, "w") as f:
    template = template_environment.get_template(template_path)
    f.write(template.render(**results))