"""Tools for reading/writing extended xyz files, rotating coordinate systems"""
import os

import numpy as np


class UnsupportedFormat(Exception):
    """Error signalling that an xyz file being read is not laid out as expected"""

    pass


def read_xyz(xyz_src):
    """
    Read an extended xyz file that lists the cell vectors and atomic
    positions in an arbitrary right-handed basis.

    Note that this function will loop over 'numatoms' atoms lines, where
    'numatoms' is the integer given on the first line of the file. If there
    are additional atoms lines, these will silently go unread.

    Parameters
    ----------
    xyz_src: string
        The name of the extended xyz file to read. This file must have a
        specific format which is a subset of the extended xyz format. The
        second line must contain a field of the form Lattice="..." holding
        nine numbers. Each atom line must contain exactly 4
        whitespace-separated elements: the first element specifying the
        species type of the atom and the remaining three elements specifying
        the positions of the atom. Elements may be separated by any amount
        of whitespace (spaces or tabs).

    Returns
    -------
    numatoms: integer
        The integer written at the top of the xyz file indicating the number of
        atoms present.

    cell: list of strings
        The cell vectors read in, specified as a list of length 9 of the form
        [Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz]. We store the actual strings read
        from the file so as to defer handling of precision to the calling
        function.

    species: list of strings
        The species read from each atom line, expressed as a list of length
        numatoms. Whether it's written as an integer or an elemental string
        in the file, this list will contain strings.

    pos: list of lists
        A two-dimensional list of dimension [numatoms, 3] containing the x, y,
        and z coordinates of each atom.  Again, we return the actual strings
        read from the file so as to defer handling of precision to the calling
        function.

    Raises
    ------
    UnsupportedFormat
        If the second line has no Lattice="..." field, if that field does not
        hold exactly nine values, or if any of the first 'numatoms' atom lines
        does not contain exactly four elements (this includes the file ending
        before 'numatoms' atom lines have been read).
    """
    import re

    # Pattern used to read cell parameters. Accepts plain decimals as well as
    # scientific notation with either exponent case and an explicit sign.
    lattice_field = re.compile(r"Lattice=\"([0-9eE.+\-\s]*)\"")

    with open(xyz_src, "r") as src:
        numatoms = int(src.readline())  # get number of atoms

        # Read box vectors from comment line
        lattice_match = lattice_field.search(src.readline())
        if lattice_match is None:
            raise UnsupportedFormat(
                "Extended xyz file {} has unsupported format. The second "
                'line must contain a Lattice="..." field.'.format(xyz_src)
            )
        # split() with no argument tolerates repeated spaces and tabs
        cell = lattice_match.group(1).split()
        if len(cell) != 9:
            raise UnsupportedFormat(
                "Extended xyz file {} has unsupported format. The Lattice "
                "field must contain exactly 9 values, found {}.".format(
                    xyz_src, len(cell)
                )
            )

        species = []
        pos = []
        for _ in range(numatoms):
            # At end of file readline() returns "", which yields zero fields
            # and is therefore reported as a malformed atom line below.
            fields = src.readline().split()
            if len(fields) != 4:
                raise UnsupportedFormat(
                    "Extended xyz file {} has unsupported "
                    "format. Each atom line must contain only the species "
                    "and position of the atom.".format(xyz_src)
                )
            species.append(fields[0])
            pos.append(fields[1:4])

    return numatoms, cell, species, pos


def write_lammps_xyz_dump(numatoms, cell, species, pos, xyz_dest, overwrite=False):
    """
    Given a simulation cell and a set of atomic species and positions, write a
    LAMMPS xyz dump file. The part of the format that's specific to LAMMPS is
    that the species given on each atom line must be an integer rather than an
    alphabetical string. If the file at path xyz_dest already exists, an
    exception is raised unless overwrite is set to True.

    All validation is performed before the destination file is opened, so
    invalid input never leaves a partially written file behind.

    Parameters
    ----------
    numatoms: integer
        The number of atoms contained in the species and positions parameters.
        Obviously this isn't necessary, but we use it to perform a consistency
        check in the hopes that this could uncover bugs in the calling function.

    cell: tuple of strings
        The cell vectors, specified as a tuple of length 9 of the form
        (Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz). These are intended to be numeric
        strings so that the handling of precision is left to the calling
        function; they are written to the file exactly as given.

    species: tuple of integers
        The species of each atom, expressed as a numatoms-dimensional tuple of
        integers (Python or NumPy integers). Atomic symbols are not allowed
        because LAMMPS expects the internal integer codes for the species to
        be given directly.

    pos: tuple of tuples
        A two-dimensional tuple of dimension [numatoms, 3] containing the x, y,
        and z coordinates of each atom.  Again, the coordinates are intended
        to be strings so that the calling function is careful about handling
        precision.

    xyz_dest: string
        Path of the file to write.

    overwrite: bool
        Overwrite the file at path xyz_dest if it already exists.

    Returns
    -------
    None

    Raises
    ------
    RuntimeError
        If species or pos does not have length numatoms, if cell does not
        have length 9, if any entry of pos does not have length 3, if species
        contains a non-integer, or if xyz_dest already exists and overwrite
        is False.
    """
    LATTICE_LINE_TPL = 'Lattice=" {} {} {} {} {} {} {} {} {} "'

    # First, check if species and positions actually contain 'numatoms'
    # elements
    for arr_name, arr in (("species", species), ("pos", pos)):
        if len(arr) != numatoms:
            raise RuntimeError("array {} has length != numatoms".format(arr_name))

    # The lattice template has exactly nine slots; str.format would silently
    # ignore extra values and fail mid-write on too few.
    if len(cell) != 9:
        raise RuntimeError(
            "cell must contain exactly 9 elements, got {}".format(len(cell))
        )

    if not all(len(i) == 3 for i in pos):
        raise RuntimeError(
            "detected that not all atoms tuples in positions array are of length 3"
        )

    if not all(isinstance(el, (int, np.integer)) for el in species):
        raise RuntimeError("species array must contain only integers")

    if not overwrite and os.path.isfile(xyz_dest):
        raise RuntimeError("file {} already exists".format(xyz_dest))

    with open(xyz_dest, mode="w") as f:
        f.write("{}\n".format(numatoms))
        f.write(LATTICE_LINE_TPL.format(*cell) + "\n")
        for atom_species, (x, y, z) in zip(species, pos):
            f.write("{:>6} {:>20} {:>20} {:>20}\n".format(atom_species, x, y, z))


def get_cell_volume(cell):
    """
    Compute the volume of the parallelepiped spanned by three cell vectors

    Parameters
    ----------
    cell: tuple (or list)
        The nine cell vector components in the order
        Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz. Each entry is converted with
        float(), so numeric strings are accepted.

    Returns
    -------
    The absolute value of the scalar triple product A . (B x C), i.e. the
    (non-negative) cell volume. A flat cell yields zero rather than an error.
    """
    Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz = map(float, cell)

    # Components of the cross product B x C
    cross_x = By * Cz - Cy * Bz
    cross_y = Bz * Cx - Bx * Cz
    cross_z = Bx * Cy - Cx * By

    # Scalar triple product A . (B x C); its sign depends on handedness
    triple_product = Ax * cross_x + Ay * cross_y + Az * cross_z

    return np.fabs(triple_product)


def convert_xyz_to_lammps_convention(numatoms, orig_cell, orig_pos):
    """
    Rotate a triclinic simulation cell, and the atomic positions inside it,
    from an arbitrary right-handed basis into the orientation LAMMPS requires
    for triclinic cells.

    LAMMPS does not permit arbitrarily oriented triclinic boxes.  A box with
    vectors A,B,C must instead be expressed as a box with vectors a,b,c where
    'a' points along the positive x-axis, 'b' lies in the xy plane with a
    strictly positive y-component, and 'c' has a strictly positive
    z-component.  The expressions for ax,bx,by,cx,cy,cz below perform this
    rotation (ay, az and bz are zero by construction).
    See https://lammps.sandia.gov/doc/Howto_triclinic.html

    Parameters
    ----------
    numatoms: integer
        The number of atoms to convert.  Exactly the first 'numatoms' entries
        of orig_pos are read, by index.

    orig_cell: tuple (or list)
        The original cell vectors, specified as a tuple or list of the form
        Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz

    orig_pos: tuple of tuples (or list of lists)
        The original atomic positions.  Each subtuple or sublist has dimension
        3.

    Returns
    -------
    new_cell: tuple
        The rotated cell vectors, specified as a tuple of the form
        ax, ay, az, bx, by, bz, cx, cy, cz

    new_pos: list of tuples
        The rotated atomic positions, one (x, y, z) tuple of floats per atom.

    Raises
    ------
    RuntimeError
        If the volume of the original cell is smaller than 1e-8.
    """
    Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz = map(float, orig_cell)

    # A (nearly) zero scalar triple product means the cell is flat
    V = get_cell_volume(orig_cell)
    if abs(V) < 1e-8:
        raise RuntimeError("original cell with parameters {} is flat".format(orig_cell))

    # Cell vectors in LAMMPS' required orientation
    ax = np.sqrt(Ax * Ax + Ay * Ay + Az * Az)
    bx = (Bx * Ax + By * Ay + Bz * Az) / ax
    by = np.sqrt((Bx * Bx + By * By + Bz * Bz) - bx * bx)
    cx = (Cx * Ax + Cy * Ay + Cz * Az) / ax
    cy = ((Bx * Cx + By * Cy + Bz * Cz) - bx * cx) / by
    cz = np.sqrt((Cx * Cx + Cy * Cy + Cz * Cz) - cx * cx - cy * cy)
    new_cell = (ax, 0.0, 0.0, bx, by, 0.0, cx, cy, cz)

    # Cross products of the original cell vectors.  Dotting a position with
    # these and dividing by V gives its coordinates along A, B and C.
    b_cross_c = (By * Cz - Cy * Bz, Bz * Cx - Bx * Cz, Bx * Cy - Cx * By)
    c_cross_a = (Cy * Az - Ay * Cz, Cz * Ax - Cx * Az, Cx * Ay - Ax * Cy)
    a_cross_b = (Ay * Bz - By * Az, Az * Bx - Ax * Bz, Ax * By - Bx * Ay)

    def _to_new_basis(raw_xyz):
        """Express one position (three float-convertible values) in the new basis."""
        x, y, z = (float(el) for el in raw_xyz)
        along_a = b_cross_c[0] * x + b_cross_c[1] * y + b_cross_c[2] * z
        along_b = c_cross_a[0] * x + c_cross_a[1] * y + c_cross_a[2] * z
        along_c = a_cross_b[0] * x + a_cross_b[1] * y + a_cross_b[2] * z
        return (
            (ax * along_a + bx * along_b + cx * along_c) / V,
            (by * along_b + cy * along_c) / V,
            (cz * along_c) / V,
        )

    new_pos = [_to_new_basis(orig_pos[idx]) for idx in range(numatoms)]

    return new_cell, new_pos


def get_inverse_rotation_matrix(orig_cell, new_cell):
    """
    Given one simulation cell, orig_cell, and its rotated counterpart,
    new_cell, calculate the inverse of the rotation matrix that mapped the
    former to the latter. This is necessary because the test driver wants to
    return the atomic forces in the same basis that the user used in
    constructing the extended xyz file, not in the rotated basis that LAMMPS
    has to use.

    With M = [A B C] and m = [a b c] the matrices whose columns are the
    original and new cell vectors, the result is

        Rinv = sign(det M) * M * inverse(m)

    where inverse(m) is evaluated through its cofactor expansion: the rows of
    det(m) * inverse(m) are the cross products b x c, c x a and a x b.

    Parameters
    ----------
    orig_cell: tuple (or list)
        The original cell vectors, specified as a tuple or list of the form
        Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz

    new_cell: tuple (or list)
        The new cell vectors, specified as a tuple or list of the form
        Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz

    Returns
    -------
    Rinv: tuple
        The inverse of the rotation that mapped orig_cell to new_cell,
        given as a tuple of floats which proceed column-wise through
        the matrix:

        Rinv11, Rinv21, Rinv31, Rinv12, Rinv22, Rinv32, Rinv13, Rinv23, Rinv33

    Raises
    ------
    RuntimeError
        If the volume of the original cell is smaller than 1e-8, or if the
        volumes of the two cells differ by more than 1e-10.
    """
    # Unpack cells
    Ax, Ay, Az, Bx, By, Bz, Cx, Cy, Cz = (float(x) for x in orig_cell)
    ax, ay, az, bx, by, bz, cx, cy, cz = (float(x) for x in new_cell)

    # Signed scalar triple products (determinants) of both cells.  The signed
    # values are needed for the inverse below, so the volumes are taken from
    # them directly instead of being recomputed.
    det_orig = (
        Ax * (By * Cz - Cy * Bz) + Ay * (Bz * Cx - Bx * Cz) + Az * (Bx * Cy - Cx * By)
    )
    det_new = (
        ax * (by * cz - cy * bz) + ay * (bz * cx - bx * cz) + az * (bx * cy - cx * by)
    )

    # Check that original cell isn't flat w/ scalar triple product
    orig_V = np.fabs(det_orig)
    if orig_V < 1e-8:
        raise RuntimeError("original cell with parameters {} is flat".format(orig_cell))

    # Check to make sure that the volume was preserved
    new_V = np.fabs(det_new)
    if abs(new_V - orig_V) > 1e-10:
        raise RuntimeError("Volume of old cell and new cell differs by more than 1e-10")

    # Rows of det(m) * inverse(m): the cross products of the new cell vectors
    b_cross_c = (by * cz - bz * cy, bz * cx - bx * cz, bx * cy - by * cx)
    c_cross_a = (cy * az - cz * ay, cz * ax - cx * az, cx * ay - cy * ax)
    a_cross_b = (ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx)

    # Rows of M, i.e. the x, y and z components of (A, B, C)
    orig_rows = ((Ax, Bx, Cx), (Ay, By, Cy), (Az, Bz, Cz))

    # orig_V / det_orig is the sign of det_orig, and the remaining 1 / det_new
    # normalises the cofactor expansion of inverse(m).  The volume checks
    # above guarantee that neither determinant is zero.
    denominator = det_new * det_orig

    def _entry(row, col):
        """Return element (row, col), both zero-based, of Rinv."""
        A_i, B_i, C_i = orig_rows[row]
        numerator = A_i * b_cross_c[col] + B_i * c_cross_a[col] + C_i * a_cross_b[col]
        return numerator * orig_V / denominator

    # Column-major order: Rinv11, Rinv21, Rinv31, Rinv12, ...
    Rinv = tuple(_entry(row, col) for col in range(3) for row in range(3))

    return Rinv