"""Classes for handling protein structure coordinates"""

import numbers

import _modeller
import modeller.selection as selection
from modeller import sequence
from modeller.atom_type import AtomType
import modeller.util.modutil as modutil
from modeller.util.logger import log

# Docstrings in this module use epydoc's epytext markup (e.g. L{atom} links).
__docformat__ = "epytext en"

class coordinates(sequence.sequence):
    """Holds protein coordinates.

       Extends the sequence base class with atom-level data. Atoms,
       residues and chains are exposed as list views (L{AtomList},
       L{ResidueList}, L{ChainList}) that are rebuilt on every attribute
       access and read the underlying model on demand."""

    def __init__(self):
        sequence.sequence.__init__(self)

    def get_list_atom_indices(self, objlist, num):
        """Flatten C{objlist} into a list of 1-based atom indices.

           C{objlist} may be a single item or an arbitrarily nested
           list/tuple of items. Each item is either an integer atom index
           (1-based, validated against the number of atoms in this model)
           or an object with a C{get_atom_indices()} method returning
           C{(indices, model)}, such as an L{atom}, L{residue} or
           L{AtomList}.

           @param objlist: the item(s) to convert.
           @param num: if not None, the exact number of indices expected.
           @return: a list of 1-based atom indices.
           @raise IndexError: an integer index is out of range.
           @raise TypeError: an object belongs to a different model.
           @raise ValueError: the number of indices is not C{num}.
        """
        inds = []
        nat = len(self.atoms)
        if not isinstance(objlist, (list, tuple)):
            objlist = (objlist,)
        for obj in objlist:
            if isinstance(obj, (list, tuple)):
                # Nested sequences are flattened; the count check (num)
                # applies only to the outermost call.
                inds.extend(self.get_list_atom_indices(obj, None))
            elif isinstance(obj, numbers.Integral):
                # numbers.Integral also covers Python 2 longs and NumPy
                # integer scalars, which a plain 'int' check rejects (they
                # would otherwise fail below with an AttributeError).
                if obj < 1 or obj > nat:
                    raise IndexError(("Invalid atom index %d: should be " +
                                      "between %d and %d") % (obj, 1, nat))
                inds.append(int(obj))
            else:
                (objinds, mdl) = obj.get_atom_indices()
                if mdl != self:
                    raise TypeError("Incorrect model %s: expected %s"
                                    % (str(mdl), str(self)))
                inds.extend(objinds)
        if num is not None and len(inds) != num:
            raise ValueError("Expecting %d atom indices - got %d"
                             % (num, len(inds)))
        return inds

    def _indxatm(self, offset, length, suffix, indx):
        """Convert a string atom ID into an index relative to C{offset}.

           C{suffix} (e.g. ':10:A' for the atoms of one residue, or ':A' for
           a chain) is appended to C{indx} before lookup, so a bare name such
           as 'CA' is resolved within that residue or chain. The C lookup's
           result is shifted down by one (presumably 1-based, with 0 meaning
           'not found' -- TODO confirm) and then made relative to C{offset}.

           @raise KeyError: no such atom, or it lies outside this list.
           @raise TypeError: C{indx} is not a string."""
        if isinstance(indx, str):
            newindx = _modeller.mod_model_find_atom(indx+suffix,
                                                    self.modpt) - 1 - offset
            if newindx < 0 or (length is not None and newindx >= length):
                raise KeyError("No such atom: %s" % indx)
            return newindx
        raise TypeError("Atom IDs must be numbers or strings")

    def _indxres(self, offset, length, suffix, indx):
        """Convert a string residue ID into an index relative to C{offset}.

           Works like L{_indxatm}, but looks up residues (e.g. '10' + ':A').

           @raise KeyError: no such residue, or it lies outside this list.
           @raise TypeError: C{indx} is not a string."""
        if isinstance(indx, str):
            newindx = _modeller.mod_model_find_residue(indx+suffix,
                                                       self.modpt) - 1 - offset
            if newindx < 0 or (length is not None and newindx >= length):
                raise KeyError("No such residue: %s" % indx)
            return newindx
        raise TypeError("Residue IDs must be numbers or strings")

    def _indxseg(self, offset, length, indx):
        """Convert a string chain ID into an index relative to C{offset}.

           @raise KeyError: no such chain, or it lies outside this list.
           @raise TypeError: C{indx} is not a string."""
        if isinstance(indx, str):
            newindx = _modeller.mod_sequence_find_chain(indx,
                                                        self.seqpt) - 1 - offset
            if newindx < 0 or (length is not None and newindx >= length):
                raise KeyError("No such chain: %s" % indx)
            return newindx
        raise TypeError("Chain IDs must be numbers or strings")

    def residue_range(self, start, end):
        """Return a list of residues, running from start to end inclusive.

           C{start} and C{end} may be any index accepted by C{self.residues}.

           @raise ValueError: C{end} comes before C{start}."""
        start = self.residues[start]._num
        end = self.residues[end]._num
        if end < start:
            raise ValueError("End residue is before start residue")
        return ResidueList(self, start, end - start + 1)

    def atom_range(self, start, end):
        """Return a list of atoms, running from start to end inclusive.

           C{start} and C{end} may be any index accepted by C{self.atoms}.

           @raise ValueError: C{end} comes before C{start}."""
        start = self.atoms[start]._num
        end = self.atoms[end]._num
        if end < start:
            raise ValueError("End atom is before start atom")
        return AtomList(self, start, end - start + 1)

    def point(self, x, y, z):
        """Return a point in the Cartesian space of this model"""
        # 'point' here resolves to the module-level class, not this method.
        return point(self, x, y, z)

    def make_chains(self, file, structure_types='structure',
                    minimal_resolution=99.0, minimal_chain_length=30,
                    max_nonstdres=10, chop_nonstd_termini=True,
                    minimal_stdres=30, alignment_format='PIR'):
        """Write out matching chains to separate files.

           Every chain that passes L{chain.filter} with the given criteria
           is written to C{<code>.chn} (relative to the current directory),
           where C{code} comes from L{chain.atom_file_and_code}."""
        for chn in self.chains:
            if chn.filter(structure_types, minimal_resolution,
                          minimal_chain_length, max_nonstdres,
                          chop_nonstd_termini, minimal_stdres):
                (atom_file, code) = chn.atom_file_and_code(file)
                chn.write(code + '.chn', atom_file, code,
                          'C; Produced by MODELLER', alignment_format,
                          chop_nonstd_termini)
                # Only report the file once it has actually been written.
                log.message('make_chains', "Wrote chain %s.chn" % code)

    def __get_atoms(self):
        return AtomList(self)
    def __get_residues(self):
        return ResidueList(self)
    def __get_chains(self):
        return ChainList(self)
    def __get_natm(self):
        return _modeller.mod_coordinates_natm_get(self.cdpt)

    natm = property(__get_natm, doc="Number of atoms in this structure")
    atoms = property(__get_atoms, doc="List of all atoms in this structure")
    residues = property(__get_residues,
                        doc="List of all residues in this structure")
    chains = property(__get_chains,
                      doc="List of all chains/segments in this structure")
    segments = chains


class AtomList(object):
    """A list of L{atom} objects.

       A view onto a contiguous run of atoms in model C{mdl}, starting at
       zero-based atom C{offset}. If C{length} is None the view extends to
       the model's total atom count. C{suffix} is appended to string atom
       IDs before lookup (e.g. ':10:A' to look up names within one
       residue)."""

    def __init__(self, mdl, offset=0, length=None, suffix=""):
        self.mdl = mdl
        self.offset = offset
        self.length = length
        self.suffix = suffix

    def __repr__(self):
        ln = len(self)
        s = "%d atom" % ln
        if ln != 1:
            s += 's'
        return s

    def __str__(self):
        return "<List of " + repr(self) + ">"

    def __len__(self):
        if self.length is not None:
            return self.length
        else:
            return self.mdl.natm

    def get_atom_indices(self):
        """Return the 1-based indices of all atoms in this list, and the
           model they belong to."""
        return range(self.offset + 1, self.offset + len(self) + 1), self.mdl

    def __getitem__(self, indx):
        ret = modutil.handle_seq_indx(self, indx, self.mdl._indxatm,
                                      (self.offset, self.length, self.suffix))
        if isinstance(ret, int):
            return atom(self.mdl, ret + self.offset)
        else:
            # A slice resolves to a sequence of positions.
            return [self[ind] for ind in ret]

    def __contains__(self, indx):
        """True if C{indx} (an atom name or integer position) names an atom
           in this list.

           Unknown names raise KeyError and out-of-range integer positions
           raise IndexError from the lookup; both mean 'not present'. Other
           errors (e.g. TypeError for an unsupported index type) are left to
           propagate."""
        try:
            self[indx]
        except (KeyError, IndexError):
            return False
        return True


class ResidueList(object):
    """A list of L{residue} objects.

       A view onto C{length} consecutive residues of model C{mdl}, beginning
       at zero-based residue C{offset} (all residues when C{length} is
       None). String residue IDs have C{suffix} appended before lookup."""

    def __init__(self, mdl, offset=0, length=None, suffix=""):
        self.mdl = mdl
        self.offset = offset
        self.length = length
        self.suffix = suffix

    def __repr__(self):
        count = len(self)
        return "%d residue%s" % (count, "" if count == 1 else "s")

    def __str__(self):
        return "<List of %s>" % repr(self)

    def __len__(self):
        if self.length is None:
            return self.mdl.nres
        return self.length

    def get_atom_indices(self):
        """Return (1-based atom indices of every residue here, model)."""
        first, last = get_residue_atom_indices(self.mdl, self.offset,
                                               self.offset + len(self))
        return range(first + 1, last + 1), self.mdl

    def __getitem__(self, indx):
        lookup_args = (self.offset, self.length, self.suffix)
        found = modutil.handle_seq_indx(self, indx, self.mdl._indxres,
                                        lookup_args)
        if not isinstance(found, int):
            # slice: one residue per resolved position
            return [self[pos] for pos in found]
        return residue(self.mdl, found + self.offset)

class AtomIndices(object):
    """Placeholder class to pass 'raw' atom indices directly to selections.

       Wraps a precomputed sequence of atom indices and their model so that
       it satisfies the same C{get_atom_indices()} protocol as atoms,
       residues and lists. It is now a new-style class, like every other
       class in this module."""

    def __init__(self, inds, mdl):
        self.inds = inds
        self.mdl = mdl

    def get_atom_indices(self):
        """Return the wrapped indices and model, unchanged."""
        return self.inds, self.mdl


class point(object):
    """An arbitrary point in the Cartesian space of a model.

       Stores the owning model and plain x, y, z coordinates."""

    def __init__(self, mdl, x, y, z):
        self.mdl = mdl
        self.x, self.y, self.z = x, y, z

    def __repr__(self):
        coords = (self.x, self.y, self.z)
        return "<Point (%.2f, %.2f, %.2f)>" % coords

    def select_sphere(self, radius):
        """Select every atom of the model lying within C{radius} of this
           point, returned as a L{selection.selection}."""
        mdl = self.mdl
        found = _modeller.mod_selection_sphere(mdl.modpt, self.x, self.y,
                                               self.z, radius)
        return selection.selection(AtomIndices(found, mdl))


class atom(point):
    """A single atom in a protein structure.

       Only the owning model and the atom's zero-based position in the
       model's arrays are kept; each property below reads from (or writes
       to) the model on every access. point.__init__ is not called, because
       x, y and z are properties here rather than stored attributes."""

    def __init__(self, mdl, num):
        self.mdl = mdl
        self._num = num

    def get_atom_indices(self):
        return [self._num + 1], self.mdl

    def __repr__(self):
        res = self.residue
        chainid = res.get_chain_suffix()
        return "<Atom %s:%s%s>" % (self.name, res.num, chainid)

    # --- Cartesian coordinates (stored on the coordinates object) ---
    def __get_x(self):
        return _modeller.mod_float1_get(
            _modeller.mod_coordinates_x_get(self.mdl.cdpt), self._num)
    def __set_x(self, val):
        _modeller.mod_float1_set(
            _modeller.mod_coordinates_x_get(self.mdl.cdpt), self._num, val)
    def __get_y(self):
        return _modeller.mod_float1_get(
            _modeller.mod_coordinates_y_get(self.mdl.cdpt), self._num)
    def __set_y(self, val):
        _modeller.mod_float1_set(
            _modeller.mod_coordinates_y_get(self.mdl.cdpt), self._num, val)
    def __get_z(self):
        return _modeller.mod_float1_get(
            _modeller.mod_coordinates_z_get(self.mdl.cdpt), self._num)
    def __set_z(self, val):
        _modeller.mod_float1_set(
            _modeller.mod_coordinates_z_get(self.mdl.cdpt), self._num, val)

    # --- derivatives and velocities (stored on the model) ---
    def __get_dvx(self):
        return _modeller.mod_float1_get(
            _modeller.mod_model_dvx_get(self.mdl.modpt), self._num)
    def __get_dvy(self):
        return _modeller.mod_float1_get(
            _modeller.mod_model_dvy_get(self.mdl.modpt), self._num)
    def __get_dvz(self):
        return _modeller.mod_float1_get(
            _modeller.mod_model_dvz_get(self.mdl.modpt), self._num)
    def __get_vx(self):
        return _modeller.mod_float1_get(
            _modeller.mod_model_vx_get(self.mdl.modpt), self._num)
    def __get_vy(self):
        return _modeller.mod_float1_get(
            _modeller.mod_model_vy_get(self.mdl.modpt), self._num)
    def __get_vz(self):
        return _modeller.mod_float1_get(
            _modeller.mod_model_vz_get(self.mdl.modpt), self._num)

    # --- per-atom scalar data ---
    def __get_mass(self):
        return self.type.mass
    def __get_biso(self):
        return _modeller.mod_float1_get(
            _modeller.mod_coordinates_biso_get(self.mdl.cdpt), self._num)
    def __set_biso(self, val):
        _modeller.mod_float1_set(
            _modeller.mod_coordinates_biso_get(self.mdl.cdpt), self._num, val)
    def __get_occ(self):
        return _modeller.mod_float1_get(
            _modeller.mod_coordinates_occ_get(self.mdl.cdpt), self._num)
    def __set_occ(self, val):
        _modeller.mod_float1_set(
            _modeller.mod_coordinates_occ_get(self.mdl.cdpt), self._num, val)
    def __get_charge(self):
        return _modeller.mod_float1_get(
            _modeller.mod_model_charge_get(self.mdl.modpt), self._num)
    def __set_charge(self, val):
        _modeller.mod_float1_set(
            _modeller.mod_model_charge_get(self.mdl.modpt), self._num, val)

    # --- identity and classification ---
    def __get_name(self):
        return _modeller.mod_coordinates_atmnam_get(self.mdl.cdpt, self._num)
    def __get_type(self):
        types = _modeller.mod_model_iattyp_get(self.mdl.modpt)
        # stored type number is shifted down by one for AtomType
        # (presumably 1-based -> 0-based; TODO confirm)
        typenum = _modeller.mod_int1_get(types, self._num) - 1
        return AtomType(self.mdl, typenum)
    def __get_gprsr_class(self):
        return _modeller.mod_int1_get(
            _modeller.mod_model_iatta_get(self.mdl.modpt), self._num)
    def __get_residue(self):
        owners = _modeller.mod_coordinates_iresatm_get(self.mdl.cdpt)
        return residue(self.mdl, _modeller.mod_int1_get(owners, self._num) - 1)
    def __get_index(self):
        return 1 + self._num

    x = property(__get_x, __set_x, doc="x coordinate")
    y = property(__get_y, __set_y, doc="y coordinate")
    z = property(__get_z, __set_z, doc="z coordinate")
    dvx = property(__get_dvx, doc="Objective function derivative, dF/dx")
    dvy = property(__get_dvy, doc="Objective function derivative, dF/dy")
    dvz = property(__get_dvz, doc="Objective function derivative, dF/dz")
    vx = property(__get_vx, doc="x component of velocity")
    vy = property(__get_vy, doc="y component of velocity")
    vz = property(__get_vz, doc="z component of velocity")
    mass = property(__get_mass, doc="Atomic mass")
    biso = property(__get_biso, __set_biso, doc="Isotropic temperature factor")
    occ = property(__get_occ, __set_occ, doc="Occupancy")
    charge = property(__get_charge, __set_charge, doc="Electrostatic charge")
    name = property(__get_name, doc="PDB name")
    type = property(__get_type, doc="Atom type")
    gprsr_class = property(__get_gprsr_class, doc="group_restraints class")
    residue = property(__get_residue, doc="Residue object containing this atom")
    index = property(__get_index, doc="Internal atom index")


class residue(sequence.SequenceResidue):
    """A single residue in a protein structure.

       Relies on C{self.mdl} and the zero-based position C{self._num}
       (presumably set up by sequence.SequenceResidue -- TODO confirm);
       properties read the model on each access."""

    def __atom_span(self):
        # zero-based [first, last) atom positions belonging to this residue
        return get_residue_atom_indices(self.mdl, self._num, self._num + 1)

    def __dihedral(self, name, idihtyp):
        # dihedral object for this residue, or None if not defined here
        return get_dihedral(name, idihtyp, self._num, self.mdl)

    def get_atom_indices(self):
        first, last = self.__atom_span()
        return range(first + 1, last + 1), self.mdl

    def __repr__(self):
        suffix = self.get_chain_suffix()
        return "Residue %s%s (type %s)" % (self.num, suffix, self.name)

    def __str__(self):
        return "<" + repr(self) + ">"

    def get_chain_suffix(self):
        """Return ':' followed by this residue's chain ID (e.g. ':A'), or an
           empty string when the chain ID is a blank space."""
        chainid = self.chain.name
        return '' if chainid == ' ' else ':' + chainid

    def __get_num(self):
        raw = _modeller.mod_coordinates_resnum_get(self.mdl.cdpt, self._num)
        return raw.strip()
    def __set_num(self, val):
        _modeller.mod_coordinates_resnum_set(self.mdl.cdpt, self._num, val)
    def __get_atoms(self):
        first, last = self.__atom_span()
        # atom names are looked up as e.g. 'CA:10:A' within this residue
        suffix = ":%s%s" % (self.num, self.get_chain_suffix())
        return AtomList(self.mdl, first, last - first, suffix)

    def __get_alpha(self):
        return self.__dihedral("alpha", 1)
    def __get_phi(self):
        return self.__dihedral("phi", 2)
    def __get_psi(self):
        return self.__dihedral("psi", 3)
    def __get_omega(self):
        return self.__dihedral("omega", 4)
    def __get_chi1(self):
        return self.__dihedral("chi1", 5)
    def __get_chi2(self):
        return self.__dihedral("chi2", 6)
    def __get_chi3(self):
        return self.__dihedral("chi3", 7)
    def __get_chi4(self):
        return self.__dihedral("chi4", 8)
    def __get_chi5(self):
        return self.__dihedral("chi5", 9)
    def __get_index(self):
        return 1 + self._num

    num = property(__get_num, __set_num, doc="PDB-style residue number")
    atoms = property(__get_atoms, doc="All atoms in this residue")
    alpha = property(__get_alpha, doc="Alpha dihedral angle")
    phi = property(__get_phi, doc="Phi dihedral angle")
    psi = property(__get_psi, doc="Psi dihedral angle")
    omega = property(__get_omega, doc="Omega dihedral angle")
    chi1 = property(__get_chi1, doc="Chi1 dihedral angle")
    chi2 = property(__get_chi2, doc="Chi2 dihedral angle")
    chi3 = property(__get_chi3, doc="Chi3 dihedral angle")
    chi4 = property(__get_chi4, doc="Chi4 dihedral angle")
    chi5 = property(__get_chi5, doc="Chi5 dihedral angle")
    index = property(__get_index, doc="Internal residue index")


def get_residue_atom_indices(mdl, start, end):
    """Return zero-based C{(first_atom, end_atom)} spanning residues
       C{start} up to (but not including) C{end}.

       The model's per-residue first-atom table is read with the stored
       values shifted down by one. When C{end} is at or past the last
       residue, the end is the model's total atom count."""
    first_atoms = _modeller.mod_coordinates_iatmr1_get(mdl.cdpt)
    first = _modeller.mod_int1_get(first_atoms, start) - 1
    if end >= mdl.nres:
        return (first, mdl.natm)
    return (first, _modeller.mod_int1_get(first_atoms, end) - 1)


class ChainList(object):
    """A list of L{chain} objects, one per chain/segment of model C{mdl}."""

    def __init__(self, mdl):
        self.mdl = mdl

    def __len__(self):
        return self.mdl.nseg

    def __repr__(self):
        count = len(self)
        return "%d chain%s" % (count, "" if count == 1 else "s")

    def __str__(self):
        return "<List of %s>" % repr(self)

    def __getitem__(self, indx):
        found = modutil.handle_seq_indx(self, indx, self.mdl._indxseg,
                                        (0, None))
        if not isinstance(found, int):
            # slice: one chain per resolved position
            return [self[pos] for pos in found]
        return chain(self.mdl, found)


class chain(object):
    """A single chain/segment in a protein structure.

       C{num} is the zero-based chain position within model C{mdl}."""

    def __init__(self, mdl, num):
        self.mdl = mdl
        self.num = num

    def __repr__(self):
        return "<Chain %r>" % (self.name,)

    def filter(self, structure_types='structure', minimal_resolution=99.0,
               minimal_chain_length=30, max_nonstdres=10,
               chop_nonstd_termini=True, minimal_stdres=30):
        """Return the C-level verdict on whether this chain passes every
           filter criterion."""
        return _modeller.mod_chain_filter(self.mdl.seqpt, self.num,
                                          structure_types, minimal_resolution,
                                          minimal_chain_length, max_nonstdres,
                                          chop_nonstd_termini, minimal_stdres)

    def write(self, file, atom_file, align_code, comment='', format='PIR',
              chop_nonstd_termini=True):
        """Write this chain out to an alignment file"""
        mdl = self.mdl
        return _modeller.mod_chain_write(mdl.seqpt, mdl.cdpt, self.num, file,
                                         atom_file, align_code, comment,
                                         format, chop_nonstd_termini,
                                         mdl.env.libs.modpt)

    def atom_file_and_code(self, filename):
        """Return suitable atom_file and align_codes for this chain, given
           a model filename."""
        return _modeller.mod_chain_atom_file_and_code(self.mdl.seqpt,
                                                      self.num, filename)

    def __get_resind(self):
        # zero-based start residue and exclusive end residue of this chain
        seqpt = self.mdl.seqpt
        starts = _modeller.mod_sequence_iress1_get(seqpt)
        ends = _modeller.mod_sequence_iress2_get(seqpt)
        return (_modeller.mod_int1_get(starts, self.num) - 1,
                _modeller.mod_int1_get(ends, self.num))

    def __atom_span(self):
        # zero-based [first, last) atom positions covered by this chain
        first_res, last_res = self.__get_resind()
        return get_residue_atom_indices(self.mdl, first_res, last_res)

    def __get_residues(self):
        first_res, last_res = self.__get_resind()
        return ResidueList(self.mdl, first_res, last_res - first_res,
                           ":%s" % self.name)

    def __get_atoms(self):
        first, last = self.__atom_span()
        return AtomList(self.mdl, first, last - first, ":%s" % self.name)

    def get_atom_indices(self):
        first, last = self.__atom_span()
        return (range(first + 1, last + 1), self.mdl)

    def __get_name(self):
        return _modeller.mod_sequence_segid_get(self.mdl.seqpt, self.num)
    def __set_name(self, val):
        return _modeller.mod_sequence_segid_set(self.mdl.seqpt, self.num, val)

    residues = property(__get_residues,
                        doc="List of all residues in this chain")
    atoms = property(__get_atoms, doc="List of all atoms in this chain")
    name = property(__get_name, __set_name, doc="Chain ID")


def get_dihedral(type, idihtyp, num, mdl):
    """Return a L{dihedral} of kind C{idihtyp} for residue C{num}, or None
       when the C library reports that residue has no such dihedral."""
    defined = _modeller.mod_coordinates_has_dihedral(mdl.cdpt, mdl.seqpt,
                                                     mdl.env.libs.modpt,
                                                     idihtyp, num)
    return dihedral(type, idihtyp, num, mdl) if defined else None

class dihedral(object):
    """A residue dihedral angle (e.g. alpha, omega, phi, psi, chi1).

       C{type} is the display name, C{idihtyp} the numeric dihedral kind
       passed to the C library, and C{num} the zero-based residue position."""

    def __init__(self, type, idihtyp, num, mdl):
        self.type = type
        self.idihtyp = idihtyp
        self._num = num
        self.mdl = mdl

    def __repr__(self):
        return "%s dihedral" % self.type

    def __str__(self):
        return "<" + repr(self) + ">"

    def __coord_args(self):
        # argument tuple shared by the coordinate-based C lookups
        mdl = self.mdl
        return (mdl.cdpt, mdl.seqpt, mdl.env.libs.modpt, self.idihtyp,
                self._num)

    def __get_value(self):
        return _modeller.mod_coordinates_dihedral_get(*self.__coord_args())

    def __get_dihclass(self):
        angle = self.value
        mdl = self.mdl
        return _modeller.mod_sequence_dihclass_get(mdl.seqpt,
                                                   mdl.env.libs.modpt,
                                                   self.idihtyp, angle,
                                                   self._num)

    def __get_atoms(self):
        indices = _modeller.mod_coordinates_dihatoms_get(*self.__coord_args())
        # the C call yields indices shifted by one relative to mdl.atoms
        return [self.mdl.atoms[idx - 1] for idx in indices]

    value = property(__get_value, doc="Current value, in degrees")
    dihclass = property(__get_dihclass, doc="Current dihedral class")
    atoms = property(__get_atoms, doc="Atoms defining this dihedral angle")
