#!/usr/bin/env python
################################################################################
#
#  CDDL HEADER START
#
#  The contents of this file are subject to the terms of the Common Development
#  and Distribution License Version 1.0 (the "License").
#
#  You can obtain a copy of the license at
#  http:# www.opensource.org/licenses/CDDL-1.0.  See the License for the
#  specific language governing permissions and limitations under the License.
#
#  When distributing Covered Code, include this CDDL HEADER in each file and
#  include the License file in a prominent location with the name LICENSE.CDDL.
#  If applicable, add the following below this CDDL HEADER, with the fields
#  enclosed by brackets "[]" replaced with your own identifying information:
#
#  Portions Copyright (c) [yyyy] [name of copyright owner]. All rights reserved.
#
#  CDDL HEADER END
#
#  Copyright (c) 2017, Regents of the University of Minnesota.
#  All rights reserved.
#
#  Contributor(s):
#     Ellad B. Tadmor
#
################################################################################

# The docstring below is vc_description
'''This verification check examines whether a model called in parallel by
multiple threads gives the same results as when called sequentially. A number
`num_configs` (preset in the code) of configurations is generated each
containing a different number of atoms based on a randomly distorted,
non-periodic, face-centered cubic (fcc) structure containing a random distribution
of all atoms supported by the model. Configurations used for testing are
provided as auxiliary files. The energy and forces for each configuration are
computed in sequence. Then the calculations are repeated using Python
multithreading with `num_configs` threads, and with the Global Interpreter
Lock (GIL) released to ensure true parallelism.  Each thread possesses its
own copy of memory related to the calculation. The threaded calculations
are repeated `num_cycles` times (preset in the code) with the configurations
randomly distributed to the threads to create many varying opportunities
for collisions.  To pass the test, the total energy and the forces on all atoms
for a given configuration obtained in all cycles must be identical to those
obtained in the sequential calculation.  Failure of this verification check implies
that the model is inappropriately storing information in persistent (static) memory
during or between calls. Passing this verification check provides some assurance
that the model is thread safe, but it is NOT a guarantee due to the inherent randomness
of race conditions in unsafe code. To be certain, the model code would need to
be studied and analyzed carefully.'''

# Python 2-3 compatible code issues
from __future__ import print_function
try:
    input = raw_input
except NameError:
    pass

from ase.lattice.cubic import FaceCenteredCubic
from kimcalculator import KIMCalculator, KIM_get_supported_species_list
import kimvc
import kimsm
import sys
import random
import numpy as np
import scipy.optimize
import threading
from multiprocessing import cpu_count

# import kimcluster
# from multiprocessing import Pool
# from multiprocessing.dummy import Pool as ThreadPool
# import time

__version__ = "000"
__author__ = "Ellad Tadmor"

# Verification-check metadata. These values are packed into the results
# dictionary that the main program writes out.
vc_name = "vc-thread-safe"
vc_description = kimvc.vc_stripall(__doc__)  # module docstring, run through kimvc.vc_stripall
vc_category = "mandatory"
vc_grade_basis = "passfail"
vc_files = list()  # names of auxiliary files; do_vc() appends to this list

################################################################################
#
#   FUNCTIONS
#
################################################################################
def cubic_cell_energy(alat, atoms, ncells_per_side):
    '''
    Return the energy of 'atoms' after rescaling it to lattice constant alat[0].

    'atoms' is assumed to be a cubic block of 'ncells_per_side' unit cells
    per side. 'alat' is a length-1 sequence (the form in which
    scipy.optimize.minimize passes its variable). The cell is reset to a cube
    of side alat[0]*ncells_per_side, the atoms are scaled with it, and the
    potential energy of the resized structure is returned.
    '''
    side_length = alat[0]*ncells_per_side
    cube = [side_length] * 3
    atoms.set_cell(cube, scale_atoms=True)
    energy = atoms.get_potential_energy()
    return energy

################################################################################
def compute_config(atoms):
    '''
    Calculate the energy and forces for the passed 'atoms' structure.

    Requesting the potential energy makes the attached calculator run the
    model if anything has changed since its last calculation. The results
    (energy, forces, etc.) are stored with 'atoms', so this function returns
    nothing. Callers read the results back afterwards with
    get_potential_energy() / get_forces(). It is used as a thread target.
    '''
    # The returned value is deliberately discarded; only the side effect of
    # triggering (and caching) the model calculation matters here.
    atoms.get_potential_energy()
    return None

################################################################################
def _compute_config_recording_errors(atoms, errors, slot):
    '''
    Thread target that runs compute_config(atoms) and records any failure.

    An exception raised inside a threading.Thread target is reported only by
    the threading machinery. It is never seen by the thread that joins it.
    To let do_vc() detect a failed threaded calculation, the exception is
    stored in errors[slot]. Each thread writes only to its own slot.
    '''
    try:
        compute_config(atoms)
    except Exception as exc:
        errors[slot] = exc


def do_vc(model, vc):
    '''
    Perform Thread Safety VC

    Builds num_configs randomized fcc-based configurations containing all
    supported species and computes their reference energies and forces
    sequentially. It then recomputes them num_cycles times, using one thread
    per configuration with the GIL turned off, and compares the threaded
    results against the references.

    Returns a tuple (vc_grade, vc_comment):
      'P'   -- every threaded result agreed with its reference,
      'F'   -- a threaded result disagreed, or a thread raised an exception,
      'N/A' -- the model is a Simulator Model, not a KIM Model.

    Raises RuntimeError if only one core is available, or if no equilibrium
    lattice constant can be found for one of the configurations.
    '''
    # Get supported species
    species = KIM_get_supported_species_list(model)
    species = kimvc.remove_species_not_supported_by_ASE(species)
    species.sort()

    # Basic VC parameters
    num_cores = cpu_count()
    num_configs = 10
    num_cycles = 10
    lattice_constant = 2.5
    max_lattice_constant = 10.0
    mincells_per_side = 2
    maxcells_per_side = 10
    seed = 13
    random.seed(seed)

    # Print VC info
    agree_text = {True:"OK", False:"Failed"}
    dashwidth = 80
    vc.rwrite('')
    vc.rwrite('-'*dashwidth)
    vc.rwrite('Results for KIM Model       : %s' % model.strip())
    vc.rwrite('Supported species           : %s' % ' '.join(species))
    vc.rwrite('')
    vc.rwrite('Number of cores             = %d' % num_cores)
    vc.rwrite('number of configurations    = %d' % num_configs)
    vc.rwrite('number of cycles            = %d' % num_cycles)
    vc.rwrite('random seed                 = %d' % seed)
    vc.rwrite('minimum unit cells per side = %d' % mincells_per_side)
    vc.rwrite('maximum unit cells per side = %d' % maxcells_per_side)
    vc.rwrite('-'*dashwidth)
    vc.rwrite('')

    # Check the model type. VC can only be performed for KIM Models
    if kimsm.is_simulator_model(model):
        vc.rwrite(model.strip()+" is not a KIM Model.")
        vc.rwrite('')
        vc.rwrite('Thread safety can only be checked for models compatible with the KIM API.')
        vc.rwrite('Check not performed.')
        vc_grade = "N/A"
        vc_comment = "Thread safety can only be checked for KIM Models."
        return vc_grade, vc_comment

    # Make sure there are more than one processor available
    if (num_cores==1):
        vc.rwrite('Only one core available on system. Unable to perform thread safety verification check.')
        raise RuntimeError('Only one core available on system. Unable to perform thread safety verification check.')

    # Create num_config randomized configurations containing all
    # supported species
    vc.rwrite('SINGLE THREAD REFERENCE CALCULATIONS')
    vc.rwrite('')
    vc.rwrite('{0:>6}   {1:>6}   {2:>13}   {3:>13}'. \
              format('Config', '#Atoms', 'Energy', 'Ave Norm'))
    vc.rwrite('-'*47)
    alat = lattice_constant
    config_atoms = []
    config_energy = []
    config_forces = []
    for n in range(num_configs):
        # Randomly set system size based on allowed range
        ncells_per_side = random.randint(mincells_per_side, maxcells_per_side)
        # Compute equilibrium lattice constant
        not_done = True
        got_equil = False
        while not_done:
            # Build the fcc block. Enlarge it only when it is too small to
            # hold at least one atom of every supported species.
            while True:
                atoms = FaceCenteredCubic(
                        size=(ncells_per_side, ncells_per_side,
                        ncells_per_side),
                        latticeconstant=alat, symbol="H", pbc=False)
                if len(atoms) >= len(species):
                    break
                ncells_per_side += 1
            kimvc.randomize_species(atoms, species)
            calc = KIMCalculator(model)
            atoms.set_calculator(calc)
            # Find equilibrium lattice constant
            try:
                res = scipy.optimize.minimize(cubic_cell_energy, alat,
                      args=(atoms,ncells_per_side), method='Nelder-Mead',
                      tol=1e-6)
                alat = res.x[0]
                not_done = False
                got_equil = True
            except Exception:
                # failed for some reason (assume it's because of KIM error);
                # retry with a larger lattice constant up to the maximum
                alat += 0.5
                if alat > max_lattice_constant:
                    not_done = False
        if not got_equil:
            raise RuntimeError('Unable to compute equilibrium lattice constant for one of the configurations.')
        # Change periodicity to false
        atoms.set_pbc([False, False, False])
        # Rescale and perturb crystal
        kimvc.rescale_to_get_nonzero_forces(atoms, 0.01)
        pert_amp = 0.01*alat
        kimvc.randomize_positions(atoms, pert_amp)
        kimvc.perturb_until_all_forces_sizeable(atoms, pert_amp)
        # Write configuration files
        aux_file = 'config-'+str(n).strip()+'-config-'+''.join(species)+'.xyz'
        vc_files.append(aux_file)
        vc.write_aux_ase_atoms(aux_file, atoms, 'xyz')
        # Store configuration, energy and forces
        config_atoms.append(atoms.copy())
        config_atoms[n].info["GIL"]="off"  # Turn off GIL to allow multithreading
        calc = KIMCalculator(model)             # copy does not retain the calculator,
        config_atoms[n].set_calculator(calc)    # so attach a new one.
        config_energy.append(atoms.get_potential_energy())
        config_forces.append(atoms.get_forces())
        # Report to user
        natoms = len(config_atoms[n])
        energy = config_energy[n]
        avenorm = np.linalg.norm(config_forces[n])/natoms
        vc.rwrite('{0:6d}   {1:6d}   {2:13.6e}   {3:13.6e}'. \
                  format(n, natoms, energy, avenorm))
    # close table
    vc.rwrite('-'*47)
    vc.rwrite('')

    # Cycle through configuration calculations num_cycle times
    # using num_config threads to compute all configurations.
    vc.rwrite('THREADED CALCULATIONS')
    vc.rwrite('')
    all_agree = True  # defined even if no cycle runs
    for cycle in range(num_cycles):
        vc.rwrite('### Cycle {} ###'.format(cycle+1))
        vc.rwrite('')
        # Randomize order of configurations for thread calculations.
        # This way the configurations are visited in different order
        # during each cycle. random.shuffle needs a mutable sequence,
        # so a list is required (a Python 3 range would raise TypeError).
        order = list(range(num_configs))
        random.shuffle(order)
        # Change positions, request energy which will issue a calculate
        # operation, then restore positions. This will force the threaded
        # calculations to recalculate and not use the energy and forces
        # stored in the config_atoms[] objects from previous cycle
        for n in range(num_configs):
            posns = config_atoms[n].get_positions()
            config_atoms[n].positions += 1.0
            config_atoms[n].get_potential_energy()
            config_atoms[n].set_positions(posns)
        # Create and start threads. Each thread computes one configuration
        # and records any exception it raises in its own slot.
        threads = []
        num_threads = num_configs
        thread_errors = [None] * num_threads
        for n in range(num_threads):
            thr = threading.Thread(name="thread-"+str(n),
                                   target=_compute_config_recording_errors,
                                   args=(config_atoms[order[n]], thread_errors, n))
            threads.append(thr)
            thr.start()
        # Wait until all threads are done.
        for thr in threads:
            thr.join()
        # All threads are complete at this point. Each successful thread
        # left the energy and forces it computed with its configuration.
        # Compare these with the reference sequential results. A thread
        # that raised is a failure. Its configuration must not be re-read
        # here, because that would silently trigger a fresh sequential
        # calculation and hide the failure.
        vc.rwrite('{0:>6}   {1:>6}   {2:>13}   {3:>13}   {4:>6}   {5:>6}'. \
                  format('Config', '#Atoms', 'Energy', 'Ave Norm', 'Thread', 'Status'))
        vc.rwrite('-'*65)
        all_agree = True
        for n in range(num_configs):
            natoms = len(config_atoms[n])
            thrnum = order.index(n) # number of thread that processed config n
            if thread_errors[thrnum] is not None:
                vc.rwrite('{0:6d}   {1:6d}   {2:>13}   {3:>13}   {4:6d}   {5:<6}'. \
                        format(n, natoms, 'N/A', 'N/A', thrnum, agree_text[False]))
                all_agree = False
                continue
            energy = config_atoms[n].get_potential_energy()
            forces = config_atoms[n].get_forces()
            avenorm = np.linalg.norm(forces)/natoms
            agree = np.allclose(energy, config_energy[n]) and \
                    np.allclose(forces, config_forces[n])
            agree_s = agree_text[agree]
            vc.rwrite('{0:6d}   {1:6d}   {2:13.6e}   {3:13.6e}   {4:6d}   {5:<6}'. \
                    format(n, natoms, energy, avenorm, thrnum, agree_s))
            all_agree = all_agree and agree
        # Close table
        vc.rwrite('-'*65)
        vc.rwrite('')
        # Report any exceptions raised inside threads
        for thrnum, err in enumerate(thread_errors):
            if err is not None:
                vc.rwrite('Thread {} raised an exception: {}: {}'.format(
                          thrnum, type(err).__name__, err))
        if any(err is not None for err in thread_errors):
            vc.rwrite('')
        # Stop cycling if an error is detected
        if not all_agree:
            break

    # Report grade
    vc.rwrite('='*dashwidth)
    vc.rwrite('To pass this verification check all threads must give identical results,')
    vc.rwrite('i.e. the same total energy and the same average norm (force norm divided')
    vc.rwrite('by the number of atoms). This is indicated by the "Status" column in the')
    vc.rwrite('above table.')
    vc.rwrite('')
    if all_agree:
        vc_grade = 'P'
        vc_comment = 'All threads give identical results for tested case. Model appears to be thread safe.'
    else:
        vc_grade = 'F'
        vc_comment = 'One or more threads gave different results than a single thread calculation. The model is not thread safe.'
    vc.rwrite('Grade: {}'.format(vc_grade))
    vc.rwrite('')
    vc.rwrite('Comment: '+vc_comment)
    vc.rwrite('')

    return vc_grade, vc_comment

################################################################################
#
#   MAIN PROGRAM
#
###############################################################################
if __name__ == '__main__':

    import traceback

    # Get the model extended KIM ID:
    model = input("Model Extended KIM ID = ")

    # Define VC object and do verification check
    vc = kimvc.vc_object(vc_name, vc_description, __author__)
    with vc:
        # Perform verification check and get grade. Any error (including a
        # RuntimeError raised by do_vc) is graded N/A. The results are still
        # written below. KeyboardInterrupt/SystemExit are not caught.
        try:
            vc_grade, vc_comment = do_vc(model, vc)
        except Exception:
            vc_grade = "N/A"
            vc_comment = "Unable to perform verification check due to an error."
            sys.stderr.write('ERROR: Unable to perform verification check.\n')
            # Keep the diagnostic detail instead of discarding it
            traceback.print_exc()

        # Pack results in a dictionary and write VC property instance
        results = {"vc_name"        : vc_name,
                   "vc_description" : vc_description,
                   "vc_category"    : vc_category,
                   "vc_grade_basis" : vc_grade_basis,
                   "vc_grade"       : vc_grade,
                   "vc_comment"     : vc_comment,
                   "vc_files"       : vc_files}
        vc.write_results(results)