#!/usr/bin/env python
################################################################################
#
#  CDDL HEADER START
#
#  The contents of this file are subject to the terms of the Common Development
#  and Distribution License Version 1.0 (the "License").
#
#  You can obtain a copy of the license at
#  http://www.opensource.org/licenses/CDDL-1.0.  See the License for the
#  specific language governing permissions and limitations under the License.
#
#  When distributing Covered Code, include this CDDL HEADER in each file and
#  include the License file in a prominent location with the name LICENSE.CDDL.
#  If applicable, add the following below this CDDL HEADER, with the fields
#  enclosed by brackets "[]" replaced with your own identifying information:
#
#  Portions Copyright (c) [yyyy] [name of copyright owner]. All rights reserved.
#
#  CDDL HEADER END
#
#  Copyright (c) 2017, Regents of the University of Minnesota.
#  All rights reserved.
#
#  Contributor(s):
#     Ellad B. Tadmor
#
################################################################################

# The docstring below is vc_description
'''Determines whether a model has a continuous energy and first
derivative, i.e. belongs to the C^1 continuity class, for all possible dimers.
For a model supporting N species, there are N + N!/(2(N-2)! distinct dimers for
all possible species combinations.  (For example if N=3, there are 3+3!/2=6
dimers.  If the species are A,B,C, the 6 dimers are AA, BB, CC, AB, AC, BC.)
For each dimer, the equilibrium separation and cutoff are determined.
The continuity across the cutoff is assessed.  Then an analysis is performed
to detect any discontinuities from half the equilibrium distance to the cutoff.
Although the verification check only requires C^1 continuity to pass, continuity
up to 3rd order is checked and reported.'''

# Python 2-3 compatible code issues
from __future__ import print_function
try:
    input = raw_input
except NameError:
    pass

from kimcalculator import KIMCalculator, KIM_get_supported_species_list
import kimvc
import itertools
from ase import Atoms
import scipy.optimize as opt
import numpy as np
import sys
import math
import numdifftools as nd

# Module metadata for this verification check (VC).
__version__ = "000"
# Author name; passed to kimvc.vc_object in the main program.
__author__ = "Ellad Tadmor"
# Name of this VC; passed to kimvc.vc_object and stored in the results.
vc_name = "vc-dimer-continuity-c1"
# Description of this VC: the module docstring above after processing by
# kimvc.vc_stripall.
vc_description = kimvc.vc_stripall(__doc__)
# Category and grade basis; both are stored verbatim in the results dictionary.
vc_category = "informational"
vc_grade_basis = "passfail"
# Names of the auxiliary data files; do_vc() appends one entry for every
# energy/derivative curve file it writes.
vc_files = []

################################################################################
#
#   FUNCTIONS
#
################################################################################
def create_species_pairs_list(species):
    '''
    Build the list of dimers that can be formed from the given species.

    The result holds one same-species pair (X, X) per entry of 'species', in
    input order, followed by every unordered pair of distinct entries in the
    order produced by itertools.combinations(species, 2).
    '''
    same_species = [(name, name) for name in species]
    # combinations() yields nothing when fewer than two species are supplied,
    # so the single-species and empty cases need no special handling.
    mixed_species = list(itertools.combinations(species, 2))
    return same_species + mixed_species

################################################################################
def energy(a, dimer, positions):
    '''
    Move the second atom of 'dimer' to x = 'a' and return the potential energy
    reported by the dimer afterwards.

    The 'positions' argument is accepted for the benefit of callers that pass
    it along, but it is not used here.
    '''
    moving_atom = dimer[1]
    moving_atom.x = a
    return dimer.get_potential_energy()

################################################################################
def energy_cheat(a, dimer, positions, offset):
    '''
    Move the second atom of 'dimer' to x = 'a' and return the potential energy
    shifted by 'offset'.

    The shift lets the caller guarantee a positive value beyond the cutoff so
    that root bisection has a sign change to work with.  The 'positions'
    argument is accepted but not used here.
    '''
    moving_atom = dimer[1]
    moving_atom.x = a
    shifted_energy = dimer.get_potential_energy() + offset
    return shifted_energy

################################################################################
def get_equilibrium_separation(dimer):
    '''
    Minimize the dimer energy with respect to the x-coordinate of the second
    atom, starting the search from that atom's current x-coordinate.

    Returns a tuple (aopt, eopt, info) holding the minimizing coordinate, the
    energy found there, and a dictionary with the optimizer's "iterations",
    "func_calls" and "warnflag" values.
    '''
    start_positions = dimer.get_positions()
    x_start = start_positions[1][0]
    fmin_output = opt.fmin(energy, x_start, args=(dimer, start_positions),
                           full_output=True, xtol=1e-8, disp=False,
                           maxiter=1000)
    x_best, e_best, n_iter, n_calls, flag = fmin_output
    info = {"iterations": n_iter,
            "func_calls": n_calls,
            "warnflag": flag}
    return x_best[0], e_best, info

################################################################################
def get_cutoff_radius(dimer,a,b,offset):
    '''
    Locate the cutoff radius by bisection on the interval [a, b].

    The root sought is that of the dimer energy shifted by 'offset' (see
    energy_cheat).  Returns the root together with the results object that
    scipy's bisect produces when full_output is requested.
    '''
    start_positions = dimer.get_positions()
    extra_args = (dimer, start_positions, offset)
    root, report = opt.bisect(energy_cheat, a, b, args=extra_args,
                              full_output=True, xtol=1e-8, maxiter=1000)
    return root, report

################################################################################
def do_vc(model, vc):
    '''
    Perform Dimer Continuity C1 Verification Check

    For every dimer built from the species list of 'model' (after filtering
    through kimvc.remove_species_not_supported_by_ASE), this routine:
      1. finds the equilibrium separation and the cutoff separation,
      2. evaluates the energy and its derivatives up to 3rd order at the
         cutoff and flags any whose magnitude exceeds a tolerance,
      3. samples the energy and its derivatives from half the equilibrium
         separation to the cutoff and applies a local edge detection formula
         to flag suspected interior discontinuities.
    A dimer whose equilibrium or cutoff search does not converge is reported
    and skipped; it does not affect the continuity flags.

    Arguments:
        model -- extended KIM ID of the model to be checked
        vc    -- verification check object; its rwrite() method is used to
                 emit the report and write_aux_x_y() to write the data files

    Returns:
        (vc_grade, vc_comment) where vc_grade is 'P' when the energy and the
        first derivative were found continuous and 'F' otherwise, and
        vc_comment is a sentence describing the continuity class found.

    Side effects:
        Appends the name of every auxiliary data file written to the
        module-level list vc_files.
    '''
    # Initialize
    # Largest derivative to be investigated.
    # (NOTE: If this is increased, increase nth below)
    max_deriv = 3
    # Assume function and all derivs are continuous until shown otherwise.
    continuous = [True]*(1+max_deriv)
    # Values in absolute value greater than this of the energy or its
    # derivatives at the cutoff are considered discontinuous.
    cont_tolerance = 1e-6
    # Value considered an indicator of a discontinuity for the local edge
    # detection algorithm.
    led_tolerance = 1.0
    nth = {
        0: "Energy",
        1: "1st-Deriv",
        2: "2nd-Deriv",
        3: "3rd-Deriv"}
    yesno = {
        True: "yes",
        False: "** NO **"}

    # Get supported species
    species = KIM_get_supported_species_list(model)
    species = kimvc.remove_species_not_supported_by_ASE(species)
    species.sort()

    # Print information specific to this VC
    dashwidth = 101
    vc.rwrite('')
    vc.rwrite('-'*dashwidth)
    vc.rwrite('Results for KIM Model      : %s' % model.strip())
    vc.rwrite('Supported species          : %s' % ' '.join(species))

    # Create list of dimers species
    species_pairs = create_species_pairs_list(species)

    # Loop over all dimers and determine the smoothness of each
    for pair in species_pairs:

        # Print header for this dimer
        vc.rwrite('')
        vc.rwrite('-'*dashwidth)
        vc.rwrite('DIMER {0}--{1}'.format(*pair))
        vc.rwrite('-'*dashwidth)

        # Build ASE Atoms object for the dimer
        a = 1.0
        dimer = Atoms(''.join(pair), positions=[(0,0,0),(a,0,0)],
                                   cell=(1000*a, 1000*a, 1000*a),pbc=(1,1,1))
        calc = KIMCalculator(model)
        dimer.set_calculator(calc)

        vc.rwrite('')
        vc.rwrite('E n e r g y  E x t r e m a')

        # Find equilibrium separation for this dimer
        aopt, eopt, info = get_equilibrium_separation(dimer)
        warn = 'none'
        if info['warnflag']:
            warn = '** DID NOT CONVERGE **'
        vc.rwrite('{0:25}   {1:>15}   {2:>15}   {3:>9}   {4:>9}    {5}'.
              format('', 'distance', 'energy', '#iter', '#fn-calls', 'warnings'))
        vc.rwrite('{0:25}   {1: 11.8e}   {2: 11.8e}   {3:9d}   {4:9d}    {5}'.
              format('Equilibrium separation:', aopt, eopt,
              info['iterations'],info['func_calls'],warn))
        if info['warnflag']:
            vc.rwrite('')
            vc.rwrite('WARNING: NOT CHECKING CONTINUITY FOR '+'-'.join(pair)+' DIMER.')
            vc.rwrite('')
            continue   # failed to converge, so skip this dimer

        # Find cutoff radius: keep doubling the separation until the energy
        # stops changing, which brackets the cutoff between a and b.
        positions = dimer.get_positions()
        a = aopt
        b = a
        eblast = eopt
        energy_changing = True
        while energy_changing:
            b = 2*b
            eb = energy(b, dimer, positions)
            if abs(eb-eblast)<1e-9:
                energy_changing = False
            else:
                eblast = eb
        # compute offset to ensure that energy past cutoff is
        # a bit positive to satisfy bisection algorithm
        if eb<=0:
            offset = -eb + np.finfo(float).eps
        else:
            # Energy past the cutoff is already positive, so no shift is
            # needed.  (Without this branch 'offset' was either undefined or
            # left over from the previous dimer.)
            offset = 0.0
        rcut, results = get_cutoff_radius(dimer,a,b,offset)
        warn = 'none'
        if not results.converged:
            warn = '** DID NOT CONVERGE **'
        vc.rwrite('{0:25}   {1: 11.8e}   {2: 11.8e}   {3:9d}   {4:9d}    {5}'.
              format('Cutoff separation:', rcut, energy(rcut, dimer, positions),
              results.iterations,results.function_calls,warn))
        if not results.converged:
            vc.rwrite('')
            vc.rwrite('WARNING: NOT CHECKING CONTINUITY FOR '+'-'.join(pair)+' DIMER.')
            vc.rwrite('')
            continue   # failed to converge, so skip this dimer

        vc.rwrite('')
        vc.rwrite('C u t o f f  S m o o t h n e s s')
        vc.rwrite('')

        vc.rwrite('{0:10}   {1:>15}   {2:10}'.
              format('', 'value', 'continuous'))

        # Check smoothness at cutoff using one-sided (backward) differences
        for n in range(0,max_deriv+1):
            Denergy = nd.Derivative(energy, full_output=True, method='backward', n=n)
            val, _deriv_info = Denergy(rcut, dimer=dimer, positions=positions)
            # Convert to a plain Python scalar (np.asscalar is no longer
            # available in current NumPy releases).
            value = np.asarray(val).item()
            is_continuous = abs(value) <= cont_tolerance
            if not is_continuous:
                continuous[n] = False
            vc.rwrite('{0:10}   {1: 11.8e}   {2:>10}'.
                format(nth[n], value, yesno[is_continuous]))

        vc.rwrite('')
        vc.rwrite('C o n t i n u i t y')
        vc.rwrite('')

        # Set the range and increment for exploring internal discontinuities
        amin = 0.5*aopt
        amax = rcut
        del_a = 0.01
        na = int(math.ceil((amax-amin)/del_a))
        dimer[1].x = amin
        refposns = dimer.get_positions()
        vc.rwrite('Checking continuity for r = [{0:.5f},{1:.5f}] at {2:d} points (Delta r = {3:.5f})'. \
              format(amin,amax,na,del_a))

        vc.rwrite('')
        vc.rwrite('Local edge detection based on a normalized 5th-order local difference formula T^5')
        vc.rwrite('is used to determine the presence of discontinuities. The tolerance is |T^5|>{0:.5f}.'.format(led_tolerance))
        vc.rwrite('(For details see Anne Gelb and Eitan Tadmor, J. Sci. Comp., 28:279-306, 2006.)')
        vc.rwrite('')

        for n in range(0,max_deriv+1):
            # set numerical derivative to constant step to prevent the algorithm
            # going haywire in some cases.
            Denergy = nd.Derivative(energy, full_output=True, n=n,step=0.1*del_a)
            if n==0:
                aux_file = 'dimer-energy-' + ''.join(pair) + '.dat'
            else:
                aux_file = 'dimer-energy-deriv-' + str(n) + '-'+ ''.join(pair) + '.dat'
            vc_files.append(aux_file)
            vc.rwrite('')
            vc.rwrite('Checking {0} ({1} vs. distance in file "{2}")'. \
                  format(nth[n], nth[n].lower(), aux_file))

            # Generate energy curve
            r = []
            e = []
            for j in range(0,na+1):
                a = amin + j*del_a
                val, _deriv_info = Denergy(a, dimer=dimer, positions=refposns)
                r.append(a)
                e.append(np.asarray(val).item())
            vc.write_aux_x_y(aux_file, r, e)

            # Apply local edge detection algorithm to identify discontinuities
            # Based on: A. Gelb and E. Tadmor, J. Sci. Comp., 28:279-306, 2006.
            fact = 1.0/6.0
            is_continuous = True
            for j in range(2,na-3):
                # use 5-th order local difference formula
                led = fact*(-e[j-2] + 5*e[j-1] - 10*e[j] + 10*e[j+1] - 5*e[j+2] + e[j+3])
                if abs(led) > led_tolerance:
                    continuous[n] = False
                    is_continuous = False
                    vc.rwrite('==> Suspected discontinuity encountered at r={:11.8e} (|T^5| = {:11.8e})'.format(r[j],abs(led)))
            if is_continuous:
                vc.rwrite('... No discontinuities found.')


    # Summary of results
    vc.rwrite('')
    vc.rwrite('='*dashwidth)
    vc.rwrite('')
    vc.rwrite('SUMMARY of Model Continuity Results Across All Dimers:')
    vc.rwrite('')

    vc.rwrite('{0:10}   {1:10}'.format('', 'continuous'))
    vc.rwrite('-'*29)
    for n in range(0,max_deriv+1):
        vc.rwrite('{0:10}   {1:>10}'.format(nth[n], yesno[continuous[n]]))

    # Determine continuity class and write out properties: k is the highest
    # order such that the energy and all derivatives up to that order are
    # continuous (-1 when even the energy is discontinuous).
    k=-1
    n=0
    while n<=max_deriv and continuous[n]:
        k = n
        n += 1
    vc_comment = 'The model is C^{0:d} continuous. '.format(k)
    if k==-1:
        vc_comment += 'This means that the model has discontinuous energy.'
    elif k==0:
        vc_comment += 'This means that the model has continuous energy, but a discontinuous first derivative.'
    elif k==1:
        vc_comment += 'This means that the model has continuous energy and continuous first derivative.'
    elif k==2:
        vc_comment += 'This means that the model has continuous energy and continuous derivatives up to order 2.'
    elif k==3:
        vc_comment += 'This means that the model has continuous energy and continuous derivatives at least up to order 3. (Derivatives beyond this order were not tested.)'

    vc.rwrite('')
    vc.rwrite('='*dashwidth)
    vc.rwrite('Continuity must be C^1 or higher to pass this verification check.')
    vc.rwrite('')
    if k>=1:
        vc_grade = 'P'
    else:
        vc_grade = 'F'
    vc.rwrite('Grade: {}'.format(vc_grade))
    vc.rwrite('')
    vc.rwrite('Comment: '+vc_comment)
    vc.rwrite('')

    return vc_grade, vc_comment


################################################################################
#
#   MAIN PROGRAM
#
###############################################################################
if __name__ == '__main__':
    # Imported here because it is only needed for error reporting below.
    import traceback

    # Get the model extended KIM ID:
    model = input("Model Extended KIM ID = ")

    # Define VC object and do verification check
    vc = kimvc.vc_object(vc_name, vc_description, __author__)
    with vc:
        # Perform verification check and get grade
        try:
            vc_grade, vc_comment = do_vc(model, vc)
        except Exception:
            # Best effort: any failure inside the check is reported as an
            # "N/A" grade rather than aborting, but the traceback is kept on
            # stderr so the cause can be diagnosed.  KeyboardInterrupt and
            # SystemExit are deliberately not caught.
            vc_grade = "N/A"
            vc_comment = "Unable to perform verification check due to an error."
            sys.stderr.write('ERROR: Unable to perform verification check.\n')
            traceback.print_exc()

        # Pack results in a dictionary and write VC property instance
        results = {"vc_name"        : vc_name,
                   "vc_description" : vc_description,
                   "vc_category"    : vc_category,
                   "vc_grade_basis" : vc_grade_basis,
                   "vc_grade"       : vc_grade,
                   "vc_comment"     : vc_comment,
                   "vc_files"       : vc_files}
        vc.write_results(results)