#!/usr/bin/env python
################################################################################
#
#  CDDL HEADER START
#
#  The contents of this file are subject to the terms of the Common Development
#  and Distribution License Version 1.0 (the "License").
#
#  You can obtain a copy of the license at
#  http://www.opensource.org/licenses/CDDL-1.0.  See the License for the
#  specific language governing permissions and limitations under the License.
#
#  When distributing Covered Code, include this CDDL HEADER in each file and
#  include the License file in a prominent location with the name LICENSE.CDDL.
#  If applicable, add the following below this CDDL HEADER, with the fields
#  enclosed by brackets "[]" replaced with your own identifying information:
#
#  Portions Copyright (c) [yyyy] [name of copyright owner]. All rights reserved.
#
#  CDDL HEADER END
#
#  Copyright (c) 2017, Regents of the University of Minnesota.
#  All rights reserved.
#
#  Contributor(s):
#     Ellad B. Tadmor
#
################################################################################

# The docstring below is vc_description
'''Comparison of the analytical forces obtained from the model with forces
computed by numerical differentiation using Richardson extrapolation for a
randomly distorted non-periodic face-centered cubic (FCC) cube base structure.
Separate configurations are tested for each species supported by the model, as
well as one containing a random distribution of all species. Configurations
used for testing are provided as auxiliary files.'''

# TODO:
#
# This program does not deal with the issue of discontinuities in the
# model energy.  Such discontinuities could lead to incorrect numerical
# derivatives and to incorrect conclusions about the accuracy of the
# analytical forces.  A useful extension would be to check whether
# energy discontinuities are responsible for the component with the
# largest error, and if yes, ignore it and move on to the next largest,
# and so on.  Discontinuities could be found using the "detectedge"
# algorithm developed for the "Chebfun" Matlab package. See the
# discussion in "Piecewise Smooth Chebfuns", Pachon, Platte and Trefethen,
# https://www.cs.ox.ac.uk/files/717/NA-08-07.pdf
# and the code here:
# https://github.com/chebfun/chebfun/blob/development/%40fun/detectEdge.m

# Python 2-3 compatible code issues:
# - print() is used as a function throughout this file.
# - Under Python 2, alias input() to raw_input() so that the model ID typed
#   by the user (see the main program) is read as a plain string instead of
#   being evaluated.  Under Python 3 the name raw_input does not exist, the
#   NameError is ignored, and the built-in input() is used unchanged.
from __future__ import print_function
try:
    input = raw_input
except NameError:
    pass

from ase.lattice.cubic import FaceCenteredCubic
from kimcalculator import KIMCalculator, KIM_get_supported_species_list
import kimvc
import sys
import random
import numpy as np
import numdifftools as nd
import math

# Module metadata
__author__ = "Ellad Tadmor"
__version__ = "000"

# Verification-check metadata packed into the results dictionary by the
# main program.
vc_name = "vc-forces-numerical-derivative"
vc_category = "consistency"
vc_grade_basis = "graded"
# Description text is derived from the module docstring.
vc_description = kimvc.vc_stripall(__doc__)
# Names of auxiliary configuration files written during the check; do_vc()
# appends to this list.
vc_files = list()

################################################################################
#
#   FUNCTIONS
#
################################################################################

################################################################################
def negpot(p, at=0, dof=0, atoms=None):
    '''
    Return the negative of the total potential energy of 'atoms' evaluated
    with coordinate 'dof' of atom 'at' temporarily set to the value 'p'.

    Used as the objective function for the numerical derivative: its
    derivative with respect to 'p' is the force component on atom 'at' in
    direction 'dof'.  The original coordinate is always restored before
    returning, even if the energy evaluation raises an exception.

    Returns 0 if no atoms object is supplied.
    '''
    if atoms is None:
        return 0
    sve = (atoms[at].position)[dof]
    (atoms[at].position)[dof] = p
    try:
        pot = atoms.get_potential_energy()
    finally:
        # Restore the coordinate so a failed evaluation does not leave the
        # configuration distorted.
        (atoms[at].position)[dof] = sve
    return -pot

################################################################################
def perform_numerical_derivative_check(vc, atoms, heading, dashwidth):
    '''
    Perform a numerical derivative check for the ASE atoms object in 'atoms'.

    For every atom and Cartesian direction, the analytical force from the
    attached calculator is compared with the numerical derivative of the
    negative potential energy (see negpot) computed with numdifftools.
    numdifftools also supplies an uncertainty estimate for each derivative.
    A results table is written to the report with vc.rwrite(), followed by
    a summary of the component with the largest relative error.

    Components whose uncertainty lies beyond the upper outer fence of a box
    plot of all uncertainties are marked with '*' and excluded when
    determining the maximum error.

    Parameters
    ----------
    vc : object
        Verification-check object; only its rwrite() method is used here.
    atoms : ase.Atoms
        Configuration with a calculator attached.
    heading : str
        Title printed above the results table.
    dashwidth : int
        Width of the dashed separator lines.

    Returns
    -------
    float
        Maximum, over non-outlier components, of
        |F_model - F_numer| / max((|F_model| + |F_numer|)/2, eps).

    Raises
    ------
    RuntimeError
        If no component can be used to determine the error (every component
        was flagged as an outlier or produced an undefined (NaN) error).
    '''
    natoms = len(atoms)

    # Compute analytical forces (negative gradient of the potential energy)
    forces = atoms.get_forces()

    # Loop over atoms and compute the numerical derivatives together with
    # the error estimate reported by numdifftools
    Dnegpot = nd.Derivative(negpot, full_output=True)
    forces_num = np.zeros(shape=(natoms, 3), dtype=float, order='C')
    forces_uncert = np.zeros(shape=(natoms, 3), dtype=float, order='C')
    for at in range(natoms):
        for dof in range(3):
            p = (atoms[at].position)[dof]
            val, info = Dnegpot(p, at=at, dof=dof, atoms=atoms)
            forces_num[at, dof] = val
            forces_uncert[at, dof] = info.error_estimate

    # Identify outliers using a box plot construction with fences
    # (See http://www.itl.nist.gov/div898/handbook/prc/section1/prc16.htm)
    # Results lying beyond (strictly above) the upper outer fence are taken
    # to be outliers.
    uncert_lower_quartile, uncert_upper_quartile = \
        np.percentile(forces_uncert, [25, 75])
    uncert_interquartile_range = uncert_upper_quartile - uncert_lower_quartile
    uncert_upper_fence = uncert_upper_quartile + 3*uncert_interquartile_range

    # Initialize for printing.  The particle number and species are printed
    # only on the first line (dof == 0) of each particle.
    frmt_head = \
                "{0:>6}  {1:>4} {2:>3} {3:>25} {4:>25} {5:>15} {6:>15}"
    frmt_line = \
        {True:  "{0: 6d}  {1:4} {2: 3d} {3: 25.15e} {4: 25.15e} {5: 15.5e} {6: 15.5e} {7:1}", \
         False: "             {2: 3d} {3: 25.15e} {4: 25.15e} {5: 15.5e} {6: 15.5e} {7:1}"}
    vc.rwrite('Comparison of analytical forces obtained from the model, '
              'the force computed as a numerical derivative')
    vc.rwrite('of the energy, the difference between them, and the uncertainty '
              'in the numerical estimate of the force:')
    vc.rwrite('')
    vc.rwrite(heading)
    vc.rwrite('-'*dashwidth)
    args = ('Part', 'Spec', 'Dir', 'Force_model', 'Force_numer', \
            'Force diff', 'uncertainty')
    vc.rwrite(frmt_head.format(*args))
    vc.rwrite('-'*dashwidth)

    # Identify max error and print numerical derivative results
    eps_prec = np.finfo(float).eps
    errmax = 0.
    at_max = None
    dof_max = None
    at_least_one_result_discarded = False
    for at in range(natoms):
        for dof in range(3):
            forcediff = abs(forces[at, dof] - forces_num[at, dof])
            den = max(0.5*(abs(forces[at, dof]) +
                           abs(forces_num[at, dof])), eps_prec)
            # '<=' rather than '<': when the interquartile range is zero
            # (e.g. all uncertainties identical) the fence equals the upper
            # quartile, and a strict comparison would discard every such
            # component even though none lies beyond the fence.
            if forces_uncert[at, dof] <= uncert_upper_fence:
                # Result is not an outlier. Include it in determining max error
                lowacc_mark = ' '
                err = forcediff/den
                # The second condition records the first usable component
                # even when its error is exactly zero, so at_max/dof_max are
                # always defined for a perfect match.  NaN errors fail both
                # comparisons and are skipped.
                if (err > errmax) or (at_max is None and err == errmax):
                    errmax = err
                    at_max = at
                    dof_max = dof
            else:
                lowacc_mark = '*'
                at_least_one_result_discarded = True
            # Print results line
            args = (at+1, atoms[at].symbol, dof+1, \
                    forces[at, dof], forces_num[at, dof], \
                    forcediff, forces_uncert[at, dof], lowacc_mark)
            vc.rwrite(frmt_line[dof == 0].format(*args))
            if dof == 2:
                vc.rwrite('-'*dashwidth)
    if at_least_one_result_discarded:
        vc.rwrite('* Starred lines are suspected outliers and are not '
                  'included when determining the error.')
        vc.rwrite('  A calculation is considered an outlier if it has an '
                  'uncertainty that lies at an abnormal')
        vc.rwrite('  distance from the other uncertainties in this set of '
                  'calculations.  Outliers are determined')
        vc.rwrite('  using the box plot construction with fences. '
                  'An outlier could indicate a problem with the')
        vc.rwrite('  numerical differentiation or problems with the '
                  'potential energy, such as discontinuities.')

    if at_max is None:
        raise RuntimeError('No usable numerical derivative: every force '
                           'component was flagged as an outlier or had an '
                           'undefined error.')

    # Print summary
    vc.rwrite('')
    vc.rwrite( \
        'Maximum error obtained for particle = {0:d}, direction = {1:d}:'. \
        format(at_max+1, dof_max+1))
    vc.rwrite('')
    vc.rwrite('                   |F_model - F_numer|')
    vc.rwrite('    error = --------------------------------- = {0:.5e}'. \
              format(errmax))
    vc.rwrite('            max{(|F_model|+|F_numer|)/2, eps}')
    vc.rwrite('')
    vc.rwrite('')

    return errmax

################################################################################
def _check_structure(vc, model, atoms, label, aux_tag, pert_amp, dashwidth):
    '''
    Prepare one test configuration and run the numerical derivative check.

    A new KIMCalculator for 'model' is attached to 'atoms', and the structure
    is rescaled and randomly distorted with the kimvc helpers.  The resulting
    configuration is written to the auxiliary file 'config-<aux_tag>.xyz',
    whose name is also appended to the module-level vc_files list.  'label'
    forms the first part of the table heading.

    Returns the maximum error from perform_numerical_derivative_check(), or
    None if that check raised an exception (reported on stderr).  Errors
    raised while preparing the structure are not caught.
    '''
    calc = KIMCalculator(model)
    atoms.set_calculator(calc)
    kimvc.rescale_to_get_nonzero_forces(atoms, 0.01)
    kimvc.randomize_positions(atoms, pert_amp)
    kimvc.perturb_until_all_forces_sizeable(atoms, pert_amp)
    aux_file = 'config-' + aux_tag + '.xyz'
    vc_files.append(aux_file)
    vc.write_aux_ase_atoms(aux_file, atoms, 'xyz')
    heading = label + '   (Configuration in file "' + aux_file + '")'
    try:
        return perform_numerical_derivative_check(vc, atoms,
                                                  heading, dashwidth)
    except Exception as err:
        # Best effort: skip this configuration but do not hide the reason.
        sys.stderr.write('WARNING: numerical derivative check failed for '
                         '{0}: {1}\n'.format(label, err))
        return None

################################################################################
def do_vc(model, vc):
    '''
    Perform Numerical Derivative Check VC

    The check is run on a randomly distorted non-periodic FCC cluster of
    each species supported by both the model and ASE.  When more than one
    species is supported, it is also run on a cluster with a random mixture
    of all species.  The grade is based on the largest error found over all
    configurations.

    Returns (vc_grade, vc_comment) from
    kimvc.vc_letter_grade_machine_precision().

    Raises RuntimeError if the check failed for every configuration.
    '''
    # Get supported species
    species = KIM_get_supported_species_list(model)
    species = kimvc.remove_species_not_supported_by_ASE(species)
    species.sort()

    # Basic cell parameters
    lattice_constant = 3.0
    pert_amp = 0.1*lattice_constant
    ncells_per_side = 2
    seed = 13
    random.seed(seed)

    # Print Vc info
    dashwidth = 101
    vc.rwrite('')
    vc.rwrite('-'*dashwidth)
    vc.rwrite('Results for KIM Model      : %s' % model.strip())
    vc.rwrite('Supported species          : %s' % ' '.join(species))
    vc.rwrite('')
    vc.rwrite('random seed                = %d'    % seed)
    vc.rwrite('lattice constant           = %0.3f' % lattice_constant)
    vc.rwrite('perturbation amplitude     = %0.3f' % pert_amp)
    vc.rwrite('number unit cells per side = %d'    % ncells_per_side)
    vc.rwrite('-'*dashwidth)
    vc.rwrite('')

    # Initialize variables
    errmaxmax = 0.
    got_atleast_one = False

    # Perform numerical derivative check for monatomic systems
    for spec in species:
        atoms = FaceCenteredCubic(
                size=(ncells_per_side, ncells_per_side, ncells_per_side),
                latticeconstant=lattice_constant, symbol=spec, pbc=False)
        errmax = _check_structure(
            vc, model, atoms, 'MONOATOMIC STRUCTURE -- Species = ' + spec,
            spec, pert_amp, dashwidth)
        if errmax is not None:
            errmaxmax = max(errmax, errmaxmax)
            got_atleast_one = True

    # Perform numerical derivative check for mixed system
    if len(species) > 1:
        # Grow the cluster until it has at least one site per species
        while True:
            atoms = FaceCenteredCubic(
                    size=(ncells_per_side, ncells_per_side, ncells_per_side),
                    latticeconstant=lattice_constant, symbol="H", pbc=False)
            if len(atoms) < len(species):
                ncells_per_side += 1
            else:
                break
        kimvc.randomize_species(atoms, species)
        errmax = _check_structure(
            vc, model, atoms,
            'MIXED STRUCTURE -- Species = ' + ' '.join(species),
            ''.join(species), pert_amp, dashwidth)
        if errmax is not None:
            errmaxmax = max(errmax, errmaxmax)
            got_atleast_one = True

    if not got_atleast_one:
        raise RuntimeError('Failed to compute a single numerical derivative.')

    # Compute grade
    vc.rwrite('='*dashwidth)
    vc.rwrite('Grade is based on maximum error across all systems '
              '(for a model supporting multiple species):')
    vc.rwrite('')
    vc.rwrite('Maximum error     = {:15.5e}'.format(errmaxmax))
    vc.rwrite('Machine precision = {:15.5e}'.format(np.finfo(float).eps))
    vc.rwrite('')
    vc_grade, vc_comment \
        = kimvc.vc_letter_grade_machine_precision(errmaxmax)
    vc.rwrite('Grade: {}'.format(vc_grade))
    vc.rwrite('')
    vc.rwrite('Comment: '+vc_comment)
    vc.rwrite('')

    return vc_grade, vc_comment

################################################################################
#
#   MAIN PROGRAM
#
###############################################################################
if __name__ == '__main__':
    import traceback

    # Get the model extended KIM ID:
    model = input("Model Extended KIM ID = ")

    # Define VC object and do verification check
    vc = kimvc.vc_object(vc_name, vc_description, __author__)
    with vc:
        # Perform verification check and get grade.  Any ordinary error is
        # turned into an "N/A" grade so results are still written.
        # KeyboardInterrupt/SystemExit are deliberately not caught.
        try:
            vc_grade, vc_comment = do_vc(model, vc)
        except Exception:
            vc_grade = "N/A"
            vc_comment = "Unable to perform verification check due to an error."
            sys.stderr.write('ERROR: Unable to perform verification check.\n')
            # Keep the underlying cause visible for debugging.
            traceback.print_exc()

        # Pack results in a dictionary and write VC property instance
        results = {"vc_name"        : vc_name,
                   "vc_description" : vc_description,
                   "vc_category"    : vc_category,
                   "vc_grade_basis" : vc_grade_basis,
                   "vc_grade"       : vc_grade,
                   "vc_comment"     : vc_comment,
                   "vc_files"       : vc_files}
        vc.write_results(results)