#!/usr/bin/env python
################################################################################
#
#  CDDL HEADER START
#
#  The contents of this file are subject to the terms of the Common Development
#  and Distribution License Version 1.0 (the "License").
#
#  You can obtain a copy of the license at
#  http://www.opensource.org/licenses/CDDL-1.0.  See the License for the
#  specific language governing permissions and limitations under the License.
#
#  When distributing Covered Code, include this CDDL HEADER in each file and
#  include the License file in a prominent location with the name LICENSE.CDDL.
#  If applicable, add the following below this CDDL HEADER, with the fields
#  enclosed by brackets "[]" replaced with your own identifying information:
#
#  Portions Copyright (c) [yyyy] [name of copyright owner]. All rights reserved.
#
#  CDDL HEADER END
#
#  Copyright (c) 2017, Regents of the University of Minnesota.
#  All rights reserved.
#
#  Contributor(s):
#     Ellad B. Tadmor
#
################################################################################

# The docstring below is vc_description
'''Check whether a model is invariant with respect to the inversion
operation where each atom is moved along a straight line through the inversion
center to a point of equal distance on the other side. This is satisfied for
all inversion centers if the model is invariant to rigid-body translation, and
has inversion symmetry about the origin. Invariance symmetry is expected from
the properties of the quantum mechanics Hamiltonian operator. It should be
true for any model that does not depend on an external field. The check
is performed for a randomly distorted non-periodic body-centered cubic (BCC)
cube base structure. Separate configurations are tested for each species
supported by the model, as well as one containing a random distribution of all
species.  The energy and forces of each configuration are compared with those
of the same configuration translated in a random direction by an irrational
distance and then inverted through the origin. The energies must be the
same and the forces must change sign.  The verification check will pass if
the energy of all configurations that the model is able to compute pass both
tests. Configurations used for testing are provided as auxiliary files.'''

# Python 2/3 compatibility: make print a function, and on Python 2 rebind
# input to raw_input so that input() returns the typed line as a string
from __future__ import print_function
try:
   input = raw_input
except NameError:
   # raw_input is not defined (Python 3); keep the built-in input
   pass

from ase.lattice.cubic import BodyCenteredCubic
from kimcalculator import KIMCalculator, KIM_get_supported_species_list
import kimvc
import sys
import random
import numpy as np
import math

# Version and author of this verification check (the author is passed to
# kimvc.vc_object in the main program)
__version__ = "000"
__author__ = "Ellad Tadmor"
# Metadata below is packed into the results dictionary by the main program
vc_name = "vc-inversion-symmetry"
# Description is the module docstring after processing by kimvc.vc_stripall
vc_description = kimvc.vc_stripall(__doc__)
vc_category = "informational"
vc_grade_basis = "passfail"
# Names of auxiliary files; do_vc() appends each configuration file it writes
vc_files = []

################################################################################
#
#   FUNCTIONS
#
################################################################################

################################################################################
def get_random_unit_vector():
    """
    Return a random direction in 3D as a tuple (x, y, z) of unit length.

    The azimuthal angle is drawn uniformly between 0 and 2*pi and the cosine
    of the polar angle is drawn uniformly between -1 and 1, which gives a
    uniform distribution over the surface of the unit sphere
    (stackoverflow.com/questions/5408276/python-uniform-spherical-distribution)
    """
    # NOTE: the azimuth must be drawn before the polar cosine so that a given
    # random seed always reproduces the same direction
    azimuth = random.uniform(0,2*math.pi)
    polar = np.arccos( random.uniform(-1.,1.) )

    sin_polar = np.sin( polar )
    return (sin_polar * np.cos( azimuth ),
            sin_polar * np.sin( azimuth ),
            np.cos( polar ))

################################################################################
def perform_inversion_symmetry_check(vc, atoms, heading, dashwidth):
    '''
    Check translation plus origin-inversion symmetry of the energy and forces
    computed for 'atoms' and write a report of the comparison to 'vc'.

    The atoms are first shifted by a random vector of length pi and then
    inverted through the origin. The energy must be unchanged after each step
    and the final forces must equal minus the original forces, all to within
    a relative tolerance of 1e-8.

    Arguments:
    vc        -- report object; only its rwrite(text) method is used
    atoms     -- atoms object providing len(), iteration over atoms that have
                 a 'position' attribute, get_potential_energy() and
                 get_forces()
    heading   -- title line written at the top of the report section
    dashwidth -- number of dashes in the horizontal rules of the report

    Returns a true value if both the energy and the force tests pass.

    NOTE: 'atoms' is modified in place and is left in the translated and
    inverted state on return.
    '''
    rel_tol = 1e-8                 # comparison tolerance (relative error)
    floor = np.finfo(float).eps    # keeps the relative-error scale nonzero
    rule = '-'*dashwidth

    def same_energy(e_new, e_ref):
        # relative difference measured against the mean magnitude
        scale = max(0.5*(abs(e_new) + abs(e_ref)), floor)
        return abs(e_new-e_ref)/scale < rel_tol

    # reference energy and forces at the original location
    energy_orig = atoms.get_potential_energy()
    forces_orig = atoms.get_forces()

    # rigid translation by a random vector of length pi
    shift = np.multiply(get_random_unit_vector(), math.pi)
    for atom in atoms:
        atom.position += shift
    energy_shifted = atoms.get_potential_energy()
    ok_shift = same_energy(energy_shifted, energy_orig)

    # inversion of the translated configuration through the origin
    for atom in atoms:
        atom.position = np.negative(atom.position)
    energy_inverted = atoms.get_potential_energy()
    forces_inverted = atoms.get_forces()
    ok_inversion = same_energy(energy_inverted, energy_shifted)

    # the energy test requires both translation and inversion invariance
    passed_energy = ok_shift and ok_inversion

    # report header, energy section and the heading of the forces table
    header_fmt = '{:>3}' + ' '*16 + '{}' + ' '*35 + '{}'
    row_fmt = '{:>3}   ' + '{: .8e}   '*3 + '|   ' + '{: .8e}   '*3 + '{}'
    for line in (
            '',
            heading,
            rule,
            'Translation vector = {0: .8e}  {1: .8e}  {2: .8e}'.format(*shift),
            '',
            'Energy requirement:',
            '',
            'V(-(r_1+c),...,-(r_N+c)) = V(r_1,...,r_N), '
            'where r_i is the position of atom i, V is the potential energy, ',
            'and c is a translation vector.',
            '',
            'V(-(r_1+c),...,-(r_N+c)) = {0}'.format(energy_inverted),
            'V(r_1+c,...,r_N+c)       = {0}'.format(energy_shifted),
            'V(r_1,...,r_N)           = {0}'.format(energy_orig),
            '',
            'Forces requirement:',
            '',
            'f_i(-(r_1+c),...,-(r_N+c)) = -f_i(r_1,...,r_N), '
            'where r_i is the position of atom i, f_i is the force ',
            'on atom i, and c is a translation vector.',
            '',
            header_fmt.format('i', 'f_i(-(r_1+c),...,-(r_N+c))',
                              '-f_i(r_1,...,r_N)'),
            rule):
        vc.rwrite(line)

    # one table row per atom; every force component must match on its own
    passed_forces = True
    for i in range(len(atoms)):
        f_new = forces_inverted[i]
        f_ref = np.negative(forces_orig[i])
        scale = np.maximum(0.5*(np.abs(f_new) + np.abs(f_ref)), floor)
        row_ok = np.all(np.abs(f_new - f_ref)/scale < rel_tol)
        flag = ''
        if not row_ok:
            flag = 'ERR'
            passed_forces = False
        vc.rwrite(row_fmt.format(i, f_new[0], f_new[1], f_new[2],
                                 f_ref[0], f_ref[1], f_ref[2], flag))
    vc.rwrite(rule)

    # overall verdict for this configuration
    passed = passed_energy and passed_forces
    if passed:
        verdict = ('PASS: Energies and forces are the same to within a '
                   'relative error of {0}'.format(rel_tol))
    else:
        verdict = ('FAIL: Energies and/or forces differ by more than '
                   'a relative error of {0}'.format(rel_tol))
    vc.rwrite(verdict)
    vc.rwrite(rule)

    return passed

################################################################################
def _randomize_positions_with_retry(atoms, pert_amp):
    '''
    Randomly perturb the positions of 'atoms' with amplitude 'pert_amp' using
    kimvc.randomize_positions and make sure forces can still be computed.

    If the perturbation or the force evaluation raises an exception, the
    positions on entry are restored and the attempt is repeated with half the
    amplitude. Returns the amplitude that finally succeeded.
    '''
    save_positions = atoms.get_positions()
    while True:
        try:
            kimvc.randomize_positions(atoms, pert_amp)
            atoms.get_forces()  # make sure forces can be computed
            return pert_amp
        except Exception:
            # Failed to compute forces; reset to original posns and retry
            atoms.set_positions(save_positions)
            pert_amp *= 0.5  # cut perturbation amplitude by half

################################################################################
def _run_inversion_check(vc, atoms, heading, dashwidth):
    '''
    Run perform_inversion_symmetry_check for one configuration.

    Returns the result of the check, or None if the check raised an exception
    (e.g. the model could not evaluate the configuration). In that case a note
    is written to the report so the skipped configuration is not silently
    dropped.
    '''
    try:
        return perform_inversion_symmetry_check(vc, atoms, heading, dashwidth)
    except Exception as err:
        vc.rwrite('')
        vc.rwrite('NOTE: Unable to complete the inversion symmetry check for:')
        vc.rwrite('      ' + heading)
        vc.rwrite('      Reason: {0!r}'.format(err))
        return None

################################################################################
def do_vc(model, vc):
    '''
    Do Symmetry Inversion VC

    A perturbed non-periodic BCC cluster is built for each species supported
    by the model and, if there is more than one species, for a random mixture
    of all of them. Each configuration is written to an auxiliary xyz file
    and passed to perform_inversion_symmetry_check.

    Arguments:
    model -- extended KIM ID of the model (string)
    vc    -- report object; its rwrite and write_aux_ase_atoms methods are used

    Returns the tuple (vc_grade, vc_comment), where vc_grade is 'P' if every
    configuration that could be computed passed the check and 'F' otherwise.

    Raises RuntimeError if the check could not be completed for any
    configuration. Calls sys.exit(1) (SystemExit) if no workable initial
    configuration is found for lattice constants up to 10.0.

    Side effects: seeds the 'random' module with a fixed seed and appends the
    names of the auxiliary files to the module-level list vc_files.
    '''
    # Get supported species
    species = KIM_get_supported_species_list(model)
    species = kimvc.remove_species_not_supported_by_ASE(species)
    species.sort()

    # Basic cell parameters
    lattice_constant_orig = 3.0
    pert_amp_orig = 0.1*lattice_constant_orig
    ncells_per_side = 2
    seed = 13
    random.seed(seed)

    # Print VC info
    dashwidth = 121
    vc.rwrite('')
    vc.rwrite('-'*dashwidth)
    vc.rwrite('Results for KIM Model      : %s' % model.strip())
    vc.rwrite('Supported species          : %s' % ' '.join(species))
    vc.rwrite('')
    vc.rwrite('random seed                = %d'    % seed)
    vc.rwrite('lattice constant (orig)    = %0.3f' % lattice_constant_orig)
    vc.rwrite('perturbation amplitude     = %0.3f' % pert_amp_orig)
    vc.rwrite('number unit cells per side = %d'    % ncells_per_side)
    vc.rwrite('-'*dashwidth)
    vc.rwrite('')

    # Initialize variables
    got_atleast_one = False
    passed_all = True

    # Perform inversion check for monoatomic systems
    for spec in species:
        calc = KIMCalculator(model)
        lattice_constant = lattice_constant_orig
        got_initial_config = False
        while not got_initial_config:
            atoms = BodyCenteredCubic(
                    size=(ncells_per_side, ncells_per_side, ncells_per_side),
                    latticeconstant=lattice_constant, symbol=spec, pbc=False)
            atoms.set_calculator(calc)
            try:
                kimvc.rescale_to_get_nonzero_forces(atoms, 0.01)
                got_initial_config = True
            except Exception:
                # Initial config failed. This most likely due to an evaluation
                # outside the legal model range. Increase lattice constant and
                # try again.
                lattice_constant += 0.25
                if lattice_constant > 10.0:
                    sys.exit(1)  # Cannot find a working configuration within
                                 # a reasonable lattice constant range.
        # Randomize positions
        pert_amp = _randomize_positions_with_retry(atoms, pert_amp_orig)
        # Move atoms around until all forces are sizeable
        kimvc.perturb_until_all_forces_sizeable(atoms, pert_amp)
        aux_file = 'config-'+spec+'.xyz'
        vc_files.append(aux_file)
        vc.write_aux_ase_atoms(aux_file, atoms, 'xyz')
        heading = 'MONOATOMIC STRUCTURE -- Species = ' + spec + \
                '   (Configuration in file "' + aux_file + '")'
        passed = _run_inversion_check(vc, atoms, heading, dashwidth)
        if passed is not None:
            passed_all = passed_all and passed
            got_atleast_one = True

    # Perform inversion check for mixed system
    if len(species)>1:
        lattice_constant = lattice_constant_orig
        # Grow the cluster until there is at least one atom per species
        while True:
            atoms = BodyCenteredCubic(
                    size=(ncells_per_side, ncells_per_side, ncells_per_side),
                    latticeconstant=lattice_constant, symbol="H", pbc=False)
            if len(atoms) < len(species):
                ncells_per_side += 1
            else:
                break
        kimvc.randomize_species(atoms, species)
        calc = KIMCalculator(model)
        atoms.set_calculator(calc)
        got_initial_config = False
        while not got_initial_config:
            try:
                kimvc.rescale_to_get_nonzero_forces(atoms, 0.01)
                got_initial_config = True
            except Exception:
                # Initial config failed. This most likely due to an evaluation
                # outside the legal model range. Increase lattice constant and
                # try again.
                lattice_constant += 0.25
                if lattice_constant > 10.0:
                    sys.exit(1)  # Cannot find a working configuration within
                                 # a reasonable lattice constant range.
                acell = lattice_constant*ncells_per_side
                atoms.set_cell([acell, acell, acell],scale_atoms=True)
        # Randomize positions
        pert_amp = _randomize_positions_with_retry(atoms, pert_amp_orig)
        # Move atoms around until all forces are sizeable
        kimvc.perturb_until_all_forces_sizeable(atoms, pert_amp)
        aux_file = 'config-'+''.join(species)+'.xyz'
        vc_files.append(aux_file)
        vc.write_aux_ase_atoms(aux_file, atoms, 'xyz')
        heading = 'MIXED STRUCTURE -- Species = ' + ' '.join(species) + \
                  '   (Configuration in file "' + aux_file + '")'
        passed = _run_inversion_check(vc, atoms, heading, dashwidth)
        if passed is not None:
            passed_all = passed_all and passed
            got_atleast_one = True

    # Without a single completed configuration there is nothing to grade
    if not got_atleast_one:
        raise RuntimeError('Failed to compute all configuration for the '
                           'inversion symmetry verification check.')

    # Compute grade
    vc.rwrite('')
    vc.rwrite('='*dashwidth)
    vc.rwrite('To pass this verification check the model must be invariant '
              'with respect to')
    vc.rwrite('translation and have inversion symmetry about the origin '
              'for all configurations ')
    vc.rwrite('it was able to compute.')
    vc.rwrite('')

    if passed_all:
        vc_grade = 'P'
        vc_comment = 'Model energy has inversion symmetry for '     + \
                     'all configurations the model was able to compute.'
    else:
        vc_grade = 'F'
        vc_comment = 'Model energy does NOT have inversion symmetry '  + \
                     'for at least one configuration that the model '    + \
                     'was able to compute. This could be valid if the '  + \
                     'model includes an external field or represents '   + \
                     'a material with unusual quantum properties. '      + \
                     'Otherwise this is a error in the model '           + \
                     'implementation.'

    vc.rwrite('Grade: {}'.format(vc_grade))
    vc.rwrite('')
    vc.rwrite('Comment: '+vc_comment)
    vc.rwrite('')

    return vc_grade, vc_comment

################################################################################
#
#   MAIN PROGRAM
#
###############################################################################
if __name__ == '__main__':

    # Get the model extended KIM ID:
    model = input("Model Extended KIM ID = ")

    # Define VC object and do verification check
    vc = kimvc.vc_object(vc_name, vc_description, __author__)
    with vc:
        # Perform verification check and get grade
        try:
            vc_grade, vc_comment = do_vc(model, vc)
        except (Exception, SystemExit) as err:
            # SystemExit is caught on purpose: do_vc() calls sys.exit(1) when
            # it cannot find a workable starting configuration, and a result
            # with grade "N/A" must still be written in that case.
            # KeyboardInterrupt is deliberately NOT caught so that the user
            # can abort the run.
            vc_grade = "N/A"
            vc_comment = "Unable to perform verification check due to an error."
            sys.stderr.write('ERROR: Unable to perform verification check.\n')
            sys.stderr.write('       Reason: {0!r}\n'.format(err))

        # Pack results in a dictionary and write VC property instance
        results = {"vc_name"        : vc_name,
                   "vc_description" : vc_description,
                   "vc_category"    : vc_category,
                   "vc_grade_basis" : vc_grade_basis,
                   "vc_grade"       : vc_grade,
                   "vc_comment"     : vc_comment,
                   "vc_files"       : vc_files}
        vc.write_results(results)