#!/usr/bin/env python
################################################################################
#
#  CDDL HEADER START
#
#  The contents of this file are subject to the terms of the Common Development
#  and Distribution License Version 1.0 (the "License").
#
#  You can obtain a copy of the license at
#  http://www.opensource.org/licenses/CDDL-1.0.  See the License for the
#  specific language governing permissions and limitations under the License.
#
#  When distributing Covered Code, include this CDDL HEADER in each file and
#  include the License file in a prominent location with the name LICENSE.CDDL.
#  If applicable, add the following below this CDDL HEADER, with the fields
#  enclosed by brackets "[]" replaced with your own identifying information:
#
#  Portions Copyright (c) [yyyy] [name of copyright owner]. All rights reserved.
#
#  CDDL HEADER END
#
#  Copyright (c) 2017, Regents of the University of Minnesota.
#  All rights reserved.
#
#  Contributor(s):
#     Ellad B. Tadmor
#
################################################################################

# The docstring below is vc_description
'''
Check that the model supports periodic boundary conditions correctly.
If the simulation box is increased by an integer factor along a periodic
direction, the total energy must multiply by that factor and the forces
on atoms that are periodic copies of each other must be the same.
The check is performed for a randomly distorted non-periodic
face-centered cubic (FCC) cube base structure. Separate configurations
are tested for each species supported by the model, as well as one containing
a random distribution of all species. For each configuration, all
possible combinations of periodic boundary conditions are tested:
TFF, FTF, FFT, TTF, TFT, FTT, TTT (where 'T' indicates periodicity
along a direction, and 'F' indicates no periodicity). The verification
check passes if the energy of all configurations that the model is able
to compute support all periodic boundary conditions correctly.
Configurations used for testing are provided as auxiliary files.'''

# Python 2-3 compatible code issues
from __future__ import print_function
# On Python 2, rebind 'input' to 'raw_input' so that the model ID read in the
# main program is returned as a plain string. On Python 3 the name 'raw_input'
# does not exist; the resulting NameError is ignored and the built-in 'input'
# is used unchanged.
try:
   input = raw_input
except NameError:
   pass

import sys
import kimvc
from kimcalculator import KIMCalculator, KIM_get_supported_species_list
from ase import Atoms
from ase.lattice.cubic import FaceCenteredCubic
import random
import itertools
import numpy as np

# Version identifier of this verification check
__version__ = "000"
# Author name; passed to kimvc.vc_object in the main program
__author__ = "Ellad Tadmor"
# Name of the verification check; passed to kimvc.vc_object and reported in
# the results dictionary
vc_name = "vc-periodicity-support"
# Description of the check: the module docstring processed by
# kimvc.vc_stripall; passed to kimvc.vc_object and reported in the results
vc_description = kimvc.vc_stripall(__doc__)
# Category and grade basis reported in the results dictionary
vc_category = "mandatory"
vc_grade_basis = "passfail"
# Names of the auxiliary configuration files; do_vc appends one entry per
# configuration it writes, and the list is reported in the results dictionary
vc_files = []

################################################################################
#
#   FUNCTIONS
#
################################################################################

################################################################################
def perform_pbc_check(model, vc, atoms, heading, dashwidth):
    '''
    Check periodic boundary condition support for the configuration stored
    in the ASE atoms object 'atoms'.

    The configuration is doubled along each periodic direction and evaluated
    with a new KIMCalculator for 'model'. The energy of the doubled system
    must equal the original energy times the number of copies, and the force
    on atom k of the doubled system must equal the force on atom (k % N) of
    the original system, both to within a fixed relative tolerance.

    A report is written through vc.rwrite, introduced by 'heading' and using
    separator lines that are 'dashwidth' characters long.

    Returns a true value if both the energy and the force requirements are
    satisfied, and a false value otherwise.
    '''
    # relative tolerance for all comparisons and lower bound for denominators
    rel_tol = 1e-8
    tiny = np.finfo(float).eps

    # reference energy and forces of the original configuration
    energy = atoms.get_potential_energy()
    forces = atoms.get_forces()

    # data describing the original configuration
    pbc = atoms.get_pbc()
    symbols = atoms.get_chemical_symbols()
    base_positions = atoms.get_positions()
    cell = atoms.get_cell()
    num_atoms = len(atoms)

    # build a system that is doubled along every periodic direction; each
    # doubling appends a shifted copy of all atoms accumulated so far, so
    # atom k of the big system is always an image of atom (k % num_atoms)
    num_periodic = len([flag for flag in pbc if flag])
    num_copies = 2**num_periodic
    big_symbols = symbols*num_copies
    big_positions = base_positions.copy()
    big_cell = cell.copy()
    for axis, periodic in enumerate(pbc):
        if not periodic:
            continue
        big_cell[axis] = np.multiply(big_cell[axis], 2.)
        half = len(big_positions)
        big_positions = np.tile(big_positions, (2, 1))
        # shift the second half by the original cell vector
        big_positions[half:] += cell[axis]
    doubled = Atoms(''.join(big_symbols), positions=big_positions,
                    cell=big_cell, pbc=pbc)
    doubled.set_calculator(KIMCalculator(model))

    # energy and forces of the doubled configuration
    energy2 = doubled.get_potential_energy()
    forces2 = doubled.get_forces()

    # the energy must scale with the number of copies
    expected_energy = num_copies*energy
    energy_scale = max(0.5*(abs(energy2) + abs(expected_energy)), tiny)
    passed_energy = abs(energy2 - expected_energy)/energy_scale < rel_tol

    # column layouts of the force table
    hfmt = '{:>3}' + ' '*17 + '{}' + ' '*34 + '{}'
    fmt  = '{:>3}   ' + '{: .8e}   '*3 + '|  ' + '{: .8e}   '*3 + '{}'
    rule = '-'*dashwidth

    # report the energy results, the statement of the forces requirement,
    # and the header of the force table
    report_lines = (
        '',
        heading,
        rule,
        '',
        'The system is doubled in p={0} periodic directions, which means an '
        'increase by a factor n=2^{0}={1}'.format(num_periodic, num_copies),
        'in the number of atoms and in the energy.',
        '',
        'Energy requirement:',
        '',
        'V(DBL_p(r_1,...,r_N)) = (2^p) V(r_1,...,r_N), '
        'where r_i is the position of atom i, V is the potential energy, ',
        'and DBL_p is an operator that doubles the configuration in '
        'p periodic directions.',
        '',
        'V(DBL_p(r_1,...,r_N)) = {0}'.format(energy2),
        '2^p V(r_1,...,r_N)    = {0}'.format(expected_energy),
        '',
        'Forces requirement:',
        '',
        'f_k(DBL_p(r_1,...,r_N)) = f_(k % N)(r_1,...,r_N), where r_i '
        'is the position of atom i, f_k is the force on atom k ',
        '(where k runs from 1 to the number of atoms in the doubled '
        'configuration), DBL_p doubles the configuration ',
        'in p periodic directions, N is the number of atoms in the original '
        'configuration, and % is the modulo operator.',
        '',
        hfmt.format('k', 'f_k(DBL_p(r_1,...,r_N))', 'f_(k % N)(r_1,...,r_N)'),
        rule,
    )
    for line in report_lines:
        vc.rwrite(line)

    # compare the force on every atom of the doubled system with the force
    # on the atom of the original system that it is a copy of
    passed_forces = True
    for k, f_lhs in enumerate(forces2):
        f_rhs = forces[k % num_atoms]
        force_scale = np.maximum(
            0.5*(np.absolute(f_lhs) + np.absolute(f_rhs)), tiny)
        within_tol = np.all(np.absolute(f_lhs - f_rhs)/force_scale < rel_tol)
        if within_tol:
            stat = ''
        else:
            passed_forces = False
            stat = 'ERR'
        row = (k,) + tuple(f_lhs) + tuple(f_rhs) + (stat,)
        vc.rwrite(fmt.format(*row))
    vc.rwrite(rule)

    # determine and report the overall result
    passed = passed_energy and passed_forces
    if passed:
        verdict = 'PASS: Energies and forces are the same to within a ' \
                  'relative error of {0}'.format(rel_tol)
    else:
        verdict = 'FAIL: Energies and/or forces differ by more than ' \
                  'a relative error of {0}'.format(rel_tol)
    vc.rwrite(verdict)
    vc.rwrite(rule)

    return passed

################################################################################
def _perturb_save_and_check(model, vc, atoms, pert_amp, aux_file, heading,
                            dashwidth):
    '''
    Randomly perturb 'atoms', save the configuration to the auxiliary file
    'aux_file' (recording its name in vc_files), and run perform_pbc_check
    on it.

    'pert_amp' is the initial perturbation amplitude; it is halved each time
    the forces of the perturbed configuration cannot be computed.

    Returns the result of perform_pbc_check, or None if perform_pbc_check
    raised an exception for this configuration.
    '''
    # Randomize positions
    save_positions = atoms.get_positions()
    while True:
        try:
            kimvc.randomize_positions(atoms, pert_amp)
            atoms.get_forces()  # make sure forces can be computed
            break
        except Exception:
            # Failed to compute forces; reset to original posns and retry
            atoms.set_positions(save_positions)
            pert_amp *= 0.5  # cut perturbation amplitude by half
    # Move atoms around until all forces are sizeable
    kimvc.perturb_until_all_forces_sizeable(atoms, pert_amp)
    vc_files.append(aux_file)
    vc.write_aux_ase_atoms(aux_file, atoms, 'xyz')
    try:
        return perform_pbc_check(model, vc, atoms, heading, dashwidth)
    except Exception:
        # Deliberate best effort: a configuration for which the check cannot
        # be evaluated is skipped and does not affect the grade.
        return None

################################################################################
def do_vc(model, vc):
    '''
    Do periodicity support check

    For each supported species, and for a mixed configuration when more than
    one species is supported, a perturbed FCC configuration is built for
    every combination of periodic boundary conditions with at least one
    periodic direction, and perform_pbc_check is run on it. Progress and
    results are written through vc.rwrite, and each configuration is saved
    as an auxiliary file whose name is appended to vc_files.

    Returns the tuple (vc_grade, vc_comment), where vc_grade is 'P' if all
    configurations that could be checked passed and 'F' otherwise.

    Raises RuntimeError if no working initial configuration is found within
    the allowed lattice constant range, or if the check could not be
    evaluated for any configuration.
    '''
    # Get supported species
    species = KIM_get_supported_species_list(model)
    species = kimvc.remove_species_not_supported_by_ASE(species)
    species.sort()

    # Basic cell parameters
    lattice_constant_orig = 3.0
    lattice_constant_step = 0.25
    lattice_constant_max = 10.0
    pert_amp_orig = 0.1*lattice_constant_orig
    ncells_per_side = 1
    seed = 13
    random.seed(seed)
    no_config_msg = 'Cannot find a working configuration within a ' \
                    'reasonable lattice constant range.'

    # Print VC defining information
    dashwidth = 120
    vc.rwrite('')
    vc.rwrite('-'*dashwidth)
    vc.rwrite('Results for KIM Model      : %s' % model.strip())
    vc.rwrite('Supported species          : %s' % ' '.join(species))
    vc.rwrite('')
    vc.rwrite('random seed                = %d'    % seed)
    vc.rwrite('lattice constant (orig)    = %0.3f' % lattice_constant_orig)
    vc.rwrite('perturbation amplitude     = %0.3f' % pert_amp_orig)
    vc.rwrite('number unit cells per side = %d'    % ncells_per_side)
    vc.rwrite('-'*dashwidth)
    vc.rwrite('')

    # Initialize variables
    got_atleast_one = False
    passed_all = True

    # Create list of all possible periodic boundary conditions
    # (excluding the fully non-periodic case)
    onezero = (1,0)
    TF = {1:'T', 0:'F'}
    pbc_options = list(itertools.product(onezero,repeat=3))
    pbc_options.remove((0,0,0))

    # Perform periodicity check for monoatomic systems
    for spec in species:

        for pbc in pbc_options:

            calc = KIMCalculator(model)
            lattice_constant = lattice_constant_orig
            while True:
                atoms = FaceCenteredCubic(
                        size=(ncells_per_side, ncells_per_side, ncells_per_side),
                        latticeconstant=lattice_constant, symbol=spec, pbc=pbc)
                atoms.set_calculator(calc)
                try:
                    kimvc.rescale_to_get_nonzero_forces(atoms, 0.01)
                    break
                except Exception:
                    # Initial config failed. This most likely due to an evaluation
                    # outside the legal model range. Increase lattice constant and
                    # try again.
                    lattice_constant += lattice_constant_step
                    if lattice_constant > lattice_constant_max:
                        raise RuntimeError(no_config_msg)
            pbcstr = ''.join(TF[x] for x in pbc)
            aux_file = 'config-'+spec+'-'+pbcstr+'.xyz'
            heading = 'MONOATOMIC STRUCTURE -- Species = ' + spec + \
                      ', PBC = ' + pbcstr + \
                      '   (Configuration in file "' + aux_file + '")'
            passed = _perturb_save_and_check(model, vc, atoms, pert_amp_orig,
                                             aux_file, heading, dashwidth)
            if passed is not None:
                passed_all = passed_all and passed
                got_atleast_one = True

    # Perform periodicity check for mixed system
    if len(species)>1:

        for pbc in pbc_options:

            lattice_constant = lattice_constant_orig
            # Grow the cell until it holds at least one atom per species
            while True:
                atoms = FaceCenteredCubic(
                    size=(ncells_per_side, ncells_per_side, ncells_per_side),
                    latticeconstant=lattice_constant, symbol="H", pbc=pbc)
                if len(atoms) < len(species):
                    ncells_per_side += 1
                else:
                    break
            kimvc.randomize_species(atoms, species)
            calc = KIMCalculator(model)
            atoms.set_calculator(calc)
            while True:
                try:
                    kimvc.rescale_to_get_nonzero_forces(atoms, 0.01)
                    break
                except Exception:
                    # Initial config failed. This most likely due to an evaluation
                    # outside the legal model range. Increase lattice constant and
                    # try again.
                    lattice_constant += lattice_constant_step
                    if lattice_constant > lattice_constant_max:
                        raise RuntimeError(no_config_msg)
                    acell = lattice_constant*ncells_per_side
                    atoms.set_cell([acell, acell, acell],scale_atoms=True)
            pbcstr = ''.join(TF[x] for x in pbc)
            aux_file = 'config-'+''.join(species)+'-'+pbcstr+'.xyz'
            heading = 'MIXED STRUCTURE -- Species = '+' '.join(species) + \
                      ', PBC = ' + pbcstr + \
                      '   (Configuration in file "' + aux_file + '")'
            passed = _perturb_save_and_check(model, vc, atoms, pert_amp_orig,
                                             aux_file, heading, dashwidth)
            if passed is not None:
                passed_all = passed_all and passed
                got_atleast_one = True

    if not got_atleast_one:
        raise RuntimeError('Failed to compute periodic boundary conditions '
                           'for all tests configurations.')

    # Compute grade
    vc.rwrite('='*dashwidth)
    vc.rwrite('')
    vc.rwrite('='*dashwidth)
    vc.rwrite('To pass this verification check the model must correctly '
              'support periodic boundary conditions ')
    vc.rwrite('for all configurations it was able to compute.')
    vc.rwrite('')

    if passed_all:
        vc_grade = 'P'
        vc_comment = \
           'Periodic boundary conditions were correctly supported ' + \
           'for all configurations that the model was able to compute.'
    else:
        vc_grade = 'F'
        vc_comment = \
           'Periodic boundary conditions were NOT supported correctly ' + \
           'for at least one configuration that the model was able to ' + \
           'compute. This is an error in the implementation of the model.'

    vc.rwrite('Grade: {}'.format(vc_grade))
    vc.rwrite('')
    vc.rwrite('Comment: '+vc_comment)
    vc.rwrite('')

    return vc_grade, vc_comment

################################################################################
#
#   MAIN PROGRAM
#
###############################################################################
if __name__ == '__main__':

    # Only needed here, to report the cause of a failed verification check
    import traceback

    # Get the model extended KIM ID:
    model = input("Model Extended KIM ID = ")

    # Define VC object and do verification check
    vc = kimvc.vc_object(vc_name, vc_description, __author__)
    with vc:
        # Perform verification check and get grade
        try:
            vc_grade, vc_comment = do_vc(model, vc)
        except (Exception, SystemExit):
            # Any failure of the check is reported as an "N/A" grade so that
            # a results file is always written. SystemExit is included so
            # that a sys.exit() call made while running the check is handled
            # the same way; KeyboardInterrupt is left to propagate.
            vc_grade = "N/A"
            vc_comment = "Unable to perform verification check due to an error."
            sys.stderr.write('ERROR: Unable to perform verification check.\n')
            # Include the cause so that the failure can be diagnosed
            sys.stderr.write(traceback.format_exc())

        # Pack results in a dictionary and write VC property instance
        results = {"vc_name"        : vc_name,
                   "vc_description" : vc_description,
                   "vc_category"    : vc_category,
                   "vc_grade_basis" : vc_grade_basis,
                   "vc_grade"       : vc_grade,
                   "vc_comment"     : vc_comment,
                   "vc_files"       : vc_files}
        vc.write_results(results)