import _modeller
from modeller import physical, group_restraints, energy_data
from modeller.energy_profile import energy_profile
from modeller.util.matrix import matrix_to_list

class selection_iterator(object):
    """Iterator over the atom objects in a selection.

       For each integer index i among the keys of seldict, yields
       mdl.atoms[i - 1] (i.e. the stored indices are treated as 1-based).
       Keys are visited in dict iteration order, which is not guaranteed
       to be sorted."""

    def __init__(self, mdl, seldict):
        self.mdl = mdl
        self.seliter = iter(seldict)
        # Resolve the underlying iterator's advance method once: Python 2
        # iterators expose 'next', Python 3 ones only '__next__'. The builtin
        # next() is deliberately not used so this also works on Python < 2.6.
        self.__advance = getattr(self.seliter, '__next__', None) \
            or self.seliter.next

    def __iter__(self):
        return self

    def next(self):
        """Return the next selected atom; raises StopIteration when done"""
        obj = self.__advance()
        return self.mdl.atoms[obj - 1]

    # Python 3 spelling of the iterator protocol method
    __next__ = next

class selection(object):
    """An arbitrary set of atoms. 'atom' or 'residue' objects can be
       added to or removed from the selection, and selections can be
       manipulated in the same way as Python-standard sets - e.g. union,
       intersection, difference, or the equivalent |, & and - operators.
       Note that 'obj in sel' is only True if ALL atoms in 'obj'
       are in the selection sel. If you want to check for partial selection,
       use len(sel.intersection([obj])) > 0.

       Internally the selection is a dict keyed by integer atom index
       (values are unused) plus a reference to the single model that all
       selected atoms belong to (None until the first atom is added, and
       again after clear())."""

    # Global reference to the DOPE restraints files, so that they're loaded
    # only once (shared by all selection objects)
    dope_restraints = None
    dopehr_restraints = None

    def __init__(self, *atoms):
        """Create a selection containing the given atom, residue, model or
           selection objects (or lists/tuples of them)."""
        self.__selection = {}
        self.__mdl = None
        self.union_update(atoms)

    def get_atom_indices(self):
        """Return the integer indices of all atoms in this selection, in
           ascending order, as an (indices, model) tuple. This is the same
           protocol that atom and residue objects provide, so a selection
           can itself be added to another selection."""
        keys = list(self.__selection.keys())
        keys.sort()
        return (keys, self.__mdl)

    def get_model(self):
        """Return the model object which all selected atoms belong to"""
        return self.__mdl

    def __repr__(self):
        suffix = "s"
        num = len(self.__selection)
        if num == 1:
            suffix = ""
        return "Selection of %d atom%s" % (num, suffix)

    def __str__(self):
        return "<%s>" % repr(self)

    def __len__(self):
        """Return the number of selected atoms"""
        return len(self.__selection)

    def __check_object(self, obj):
        """Raise TypeError unless obj provides get_atom_indices()"""
        if not hasattr(obj, 'get_atom_indices'):
            raise TypeError(("Invalid object %s for selection - try atom, "
                             "residue, model, selection objects") % repr(obj))

    def __contains__(self, obj):
        """True only if obj belongs to this selection's model and ALL of its
           atoms are selected. Raises TypeError for unsupported objects."""
        self.__check_object(obj)
        (inds, mdl) = obj.get_atom_indices()
        if mdl is not self.__mdl:
            return False
        for ind in inds:
            if ind not in self.__selection:
                return False
        return True

    def __iter__(self):
        """Iterate over the individual selected atom objects"""
        return selection_iterator(self.__mdl, self.__selection)

    def __typecheck(self, obj):
        """Raise TypeError unless obj is a selection (used by operators)"""
        if not isinstance(obj, selection):
            raise TypeError("Must use selection objects here")

    def __typecoerce(self, obj):
        """Return obj if it is a selection, else a new selection built from
           the items of obj"""
        if isinstance(obj, selection):
            return obj
        else:
            return selection(*obj)

    def issubset(self, t):
        """Tests whether every element is also in t"""
        t = self.__typecoerce(t)
        for obj in self:
            if obj not in t:
                return False
        return True

    def __le__(self, t):
        self.__typecheck(t)
        return self.issubset(t)

    def issuperset(self, t):
        """Tests whether every element in t is also in this set"""
        t = self.__typecoerce(t)
        for obj in t:
            if obj not in self:
                return False
        return True

    def __ge__(self, t):
        self.__typecheck(t)
        return self.issuperset(t)

    def __eq__(self, t):
        """Two selections are equal if they contain exactly the same atoms.
           Comparison with a non-selection is left to Python (as for the
           standard set type, this yields False rather than an error)."""
        if not isinstance(t, selection):
            return NotImplemented
        return self.issubset(t) and t.issubset(self)

    def __ne__(self, t):
        # Python 2 does not derive != from __eq__, so define it explicitly
        result = self.__eq__(t)
        if result is NotImplemented:
            return result
        return not result

    def __lt__(self, t):
        self.__typecheck(t)
        return self.issubset(t) and not t.issubset(self)

    def __gt__(self, t):
        self.__typecheck(t)
        return t.issubset(self) and not self.issubset(t)

    def add(self, obj):
        """Adds a new object (e.g. atom, residue) to the selection.
           Lists and tuples are added element by element. Raises ValueError
           if obj belongs to a different model than the atoms already
           selected, or TypeError if obj cannot be selected."""
        if not hasattr(obj, "get_atom_indices") \
           and isinstance(obj, (list, tuple)):
            for x in obj:
                self.add(x)
        else:
            self.__check_object(obj)
            (inds, mdl) = obj.get_atom_indices()
            if mdl is None and len(inds) == 0:
                # An empty, model-less object (e.g. selection()) adds nothing;
                # it must not be mistaken for an atom from another model.
                return
            if self.__mdl is None:
                self.__mdl = mdl
            elif self.__mdl is not mdl:
                raise ValueError("All atoms must be in the same model!")
            for ind in inds:
                self.__selection[ind] = None

    def remove(self, obj):
        """Removes an object (e.g. atom, residue) from the selection;
           raises KeyError if not present"""
        if not hasattr(obj, "get_atom_indices") \
           and isinstance(obj, (list, tuple)):
            for x in obj:
                self.remove(x)
        else:
            if obj not in self:
                raise KeyError(obj)
            (inds, mdl) = obj.get_atom_indices()
            if self.__mdl is not mdl:
                raise ValueError("All atoms must be in the same model!")
            for ind in inds:
                del self.__selection[ind]

    def discard(self, obj):
        """Removes an object (e.g. atom, residue) from the selection,
           if present"""
        try:
            self.remove(obj)
        except KeyError:
            pass

    def union(self, t):
        """Returns a new selection containing objects from both selections"""
        newobj = selection()
        for obj in self:
            newobj.add(obj)
        for obj in t:
            newobj.add(obj)
        return newobj

    def __or__(self, t):
        self.__typecheck(t)
        return self.union(t)

    def intersection(self, t):
        """Returns a new selection containing objects common to both
           selections"""
        t = self.__typecoerce(t)
        newobj = selection()
        for obj in self:
            if obj in t:
                newobj.add(obj)
        return newobj

    def __and__(self, t):
        self.__typecheck(t)
        return self.intersection(t)

    def difference(self, t):
        """Returns a new selection containing objects in this selection
           but not in t"""
        t = self.__typecoerce(t)
        newobj = selection()
        for obj in self:
            if obj not in t:
                newobj.add(obj)
        return newobj

    def __sub__(self, t):
        self.__typecheck(t)
        return self.difference(t)

    def symmetric_difference(self, t):
        """Returns a new selection containing objects in either this selection
           or t, but not both"""
        t = self.__typecoerce(t)
        newobj = selection()
        for obj in self:
            if obj not in t:
                newobj.add(obj)
        for obj in t:
            if obj not in self:
                newobj.add(obj)
        return newobj

    def __xor__(self, t):
        self.__typecheck(t)
        return self.symmetric_difference(t)

    def copy(self):
        """Returns a copy of this selection"""
        newobj = selection()
        newobj.__selection = self.__selection.copy()
        newobj.__mdl = self.__mdl
        return newobj

    def union_update(self, t):
        """Add every object in t to this selection in place; returns self"""
        for obj in t:
            self.add(obj)
        return self

    def __ior__(self, t):
        self.__typecheck(t)
        return self.union_update(t)

    def intersection_update(self, t):
        """Remove, in place, every selected atom that is not in t;
           returns self"""
        t = self.__typecoerce(t)
        # Iterate over a snapshot, since atoms are removed as we go
        c = self.copy()
        for obj in c:
            if obj not in t:
                self.remove(obj)
        return self

    def __iand__(self, t):
        self.__typecheck(t)
        return self.intersection_update(t)

    def difference_update(self, t):
        """Remove, in place, every selected atom that is also in t;
           returns self"""
        t = self.__typecoerce(t)
        # Iterate over a snapshot, since atoms are removed as we go
        c = self.copy()
        for obj in c:
            if obj in t:
                self.remove(obj)
        return self

    def __isub__(self, t):
        self.__typecheck(t)
        return self.difference_update(t)

    def symmetric_difference_update(self, t):
        """Toggle, in place, each atom of t: atoms already selected are
           removed and the others are added; returns self"""
        t = self.__typecoerce(t)
        # Membership is tested against the original contents, not the
        # partially-updated selection
        c = self.copy()
        for obj in t:
            if obj in c:
                self.remove(obj)
            else:
                self.add(obj)
        return self

    def __ixor__(self, t):
        self.__typecheck(t)
        return self.symmetric_difference_update(t)

    def clear(self):
        """Remove all atoms from the selection"""
        self.__mdl = None
        self.__selection = {}

    def __require_indices(self):
        """Get atom indices, and fail if there are none"""
        (inds, mdl) = self.get_atom_indices()
        if mdl is None or len(inds) == 0:
            raise ValueError("Selection contains no atoms")
        return (inds, mdl)

    def translate(self, vector):
        """Translate the selection by the given vector"""
        (inds, mdl) = self.__require_indices()
        _modeller.translate_selection(mdl.modpt, inds, vector)

    def rotate_origin(self, axis, angle):
        """Rotate the selection about the given axis through the origin,
           by the given angle (in degrees)"""
        (inds, mdl) = self.__require_indices()
        _modeller.rotate_axis_selection(mdl.modpt, inds, axis, angle)

    def rotate_mass_center(self, axis, angle):
        """Rotate the selection about the given axis through the mass center,
           by the given angle (in degrees)"""
        com = self.mass_center
        # Move the mass center to the origin, rotate, then move back
        self.translate([-a for a in com])
        self.rotate_origin(axis, angle)
        self.translate(com)

    def transform(self, matrix):
        """Transform the selection coordinates with the given matrix"""
        (inds, mdl) = self.__require_indices()
        lst = matrix_to_list(matrix)
        _modeller.transform_selection(mdl.modpt, inds, lst)

    def saxs_intens(self, saxsd, filename, fitflag=False):
        """Calculate SAXS intensity from model
           sel.saxs_intens(saxsd, filename, fitflag)"""
        (inds, mdl) = self.__require_indices()
        return _modeller.saxs_intens(mdl.modpt, saxsd.modpt, inds,
                                     filename, fitflag)

    def find_atoms(self, restyp, atom_names, min_selected):
        """Delegate to model_topology.find_atoms for this selection's model,
           using __check_selected (with min_selected) to decide which
           candidate atom groups count as selected. Raises ValueError if the
           selection has no model."""
        mdl = self.get_model()
        if mdl is None:
            raise ValueError("Selection contains no atoms")
        import model_topology
        return model_topology.find_atoms(mdl, restyp, atom_names,
                                         self.__check_selected, min_selected)

    def find_chi1_dihedrals(self, restyp, min_selected):
        """Find chi1 dihedrals (dihedral type 5) in this selection"""
        return self.__find_dihedrals(restyp, 5, min_selected)

    def find_chi2_dihedrals(self, restyp, min_selected):
        """Find chi2 dihedrals (dihedral type 6) in this selection"""
        return self.__find_dihedrals(restyp, 6, min_selected)

    def find_chi3_dihedrals(self, restyp, min_selected):
        """Find chi3 dihedrals (dihedral type 7) in this selection"""
        return self.__find_dihedrals(restyp, 7, min_selected)

    def find_chi4_dihedrals(self, restyp, min_selected):
        """Find chi4 dihedrals (dihedral type 8) in this selection"""
        return self.__find_dihedrals(restyp, 8, min_selected)

    def __find_dihedrals(self, restyp, dihedral_type, min_selected):
        """Delegate to model_topology.find_dihedrals for this selection's
           model; raises ValueError if the selection has no model."""
        mdl = self.get_model()
        if mdl is None:
            raise ValueError("Selection contains no atoms")
        import model_topology
        return model_topology.find_dihedrals(mdl, restyp, dihedral_type,
                                             self.__check_selected,
                                             min_selected)

    def __check_selected(self, atoms, min_selected):
        """Return True if at least min_selected of the given atoms (capped at
           len(atoms)) are in this selection. Note that at least one atom
           must be selected for a True result, even if min_selected is 0."""
        num_selected = 0
        min_selected = min(min_selected, len(atoms))
        for a in atoms:
            if a in self:
                num_selected += 1
                if num_selected >= min_selected:
                    return True
        return False

    def write(self, file, **vars):
        """Write selection coordinates to a file. Extra keyword arguments are
           passed through to the model's write_model routine."""
        (inds, mdl) = self.__require_indices()

        return mdl.top.write_model('selection.write', mdl=mdl.modpt,
                                   libs=mdl.env.libs.modpt, sel1=inds,
                                   file=file, write_all_atoms=False, **vars)

    def by_residue(self):
        """Return a new selection, in which any residues in the existing
           selection, that have at least one selected atom, are now entirely
           selected"""
        return self.extend_by_residue(0)

    def extend_by_residue(self, extension):
        """Return a new selection, in which any residues with at least one
           selected atom in the existing selection are now entirely selected.
           Additionally, `extension` residues on either side of each such
           residue (in the model's residue order) are selected."""
        newobj = self.copy()
        mdl = self.__mdl
        if mdl is not None:
            selected = self.__selection
            residues = mdl.residues
            nres = len(residues)
            for (n, res) in enumerate(residues):
                # Look each of the residue's atom indices up directly in our
                # index dict. This is equivalent to
                # len(self.intersection([res])) > 0, but avoids rescanning
                # every selected atom once per residue.
                (resinds, resmdl) = res.get_atom_indices()
                touched = False
                for ind in resinds:
                    if ind in selected:
                        touched = True
                        break
                if touched:
                    for i in range(max(0, n - extension),
                                   min(nres, n + extension + 1)):
                        newobj.add(residues[i])
        return newobj

    def only_sidechain(self):
        """Returns a new selection, containing only sidechain atoms from the
           current selection"""
        return self - self.only_mainchain()

    def only_mainchain(self):
        """Returns a new selection, containing only mainchain atoms from the
           current selection"""
        return self.only_atom_types('O OT1 OT2 C CA N')

    def only_atom_types(self, atom_types):
        """Returns a new selection, containing only atoms from the current
           selection of the given space-separated type(s) (e.g. 'CA CB')"""
        if isinstance(atom_types, (list, tuple)):
            atom_types = " ".join(atom_types)
        return self.__filter(_modeller.selection_atom_types, atom_types)

    def only_residue_types(self, residue_types):
        """Returns a new selection, containing only atoms from the current
           selection in residues of the given space-separated type(s)
           (e.g. 'ALA ASP')"""
        if isinstance(residue_types, (list, tuple)):
            residue_types = " ".join(residue_types)
        libs = self.__get_libs()
        return self.__filter(_modeller.selection_residue_types, residue_types,
                             libs)

    def only_std_residues(self):
        """Returns a new selection, containing only atoms in standard residue
           types (i.e. everything but HETATM)"""
        libs = self.__get_libs()
        return self.__filter(_modeller.selection_std_residues, libs)

    def only_no_topology(self):
        """Returns a new selection, containing only atoms in residues that
           have no defined topology"""
        libs = self.__get_libs()
        return self.__filter(_modeller.selection_no_topology, libs)

    def only_het_residues(self):
        """Returns a new selection, containing only atoms in HETATM residues"""
        libs = self.__get_libs()
        return self.__filter(_modeller.selection_het_residues, libs)

    def select_sphere(self, radius):
        """Returns a new selection, containing all atoms within the given
           distance from any atom in the current selection"""
        s = selection()
        for x in self:
            s.add(x.select_sphere(radius))
        return s

    def mutate(self, residue_type):
        """Mutate selected residues"""
        (inds, mdl) = self.__require_indices()
        return _modeller.mutate_model(mdl.modpt, mdl.env.libs.modpt, inds,
                                      residue_type)

    def randomize_xyz(self, deviation):
        """Randomize coordinates"""
        (inds, mdl) = self.__require_indices()
        return _modeller.randomize_xyz(mdl.modpt, mdl.env.libs.modpt, inds,
                                       deviation)

    def superpose(self, mdl2, aln, rms_cutoff=3.5, **vars):
        """Superpose the input model on this selection, given an alignment of
           the models. Returns a superpose.superpose_data object."""
        import superpose
        (inds, mdl) = self.__require_indices()

        retval = mdl.top.superpose('selection.superpose', mdl=mdl.modpt,
                                   mdl2=mdl2.modpt, aln=aln.modpt, sel1=inds,
                                   libs=mdl.env.libs.modpt,
                                   rms_cutoff=rms_cutoff, **vars)
        return superpose.superpose_data(*retval)

    def rotate_dihedrals(self, deviation, change,
                         dihedrals=('PHI', 'PSI', 'CHI1', 'CHI2',
                                    'CHI3', 'CHI4')):
        """Optimize or randomize dihedral angles"""
        (inds, mdl) = self.__require_indices()
        return _modeller.rotate_dihedrals(mdl.modpt, mdl.env.libs.modpt, inds,
                                          deviation, change, dihedrals)

    def hot_atoms(self, edat=None, **vars):
        """Returns a new selection containing all atoms violating restraints.
           Uses the model environment's edat when edat is not given."""
        (inds, mdl) = self.__require_indices()

        if edat is None:
            edat = mdl.env.edat
        newinds = mdl.top.pick_hot_atoms('selection.hot_atoms', mdl=mdl.modpt,
                                         edat=edat.modpt,
                                         libs=mdl.env.libs.modpt, sel1=inds,
                                         **vars)
        newobj = selection()
        newobj.__mdl = mdl
        newobj.__selection = dict.fromkeys(newinds)
        return newobj

    # NOTE: the schedule_scale default is created once, when the class is
    # defined, and is shared by every call that relies on it.
    def objfunc(self, edat=None, residue_span_range=(0,99999),
                schedule_scale=physical.values(default=1.0)):
        """Get just the objective function value, without derivatives"""
        (inds, mdl) = self.__require_indices()
        if edat is None:
            edat = mdl.env.edat
        return _modeller.get_objfunc(mdl.modpt, edat.modpt, mdl.env.libs.modpt,
                                     inds, residue_span_range, schedule_scale)

    def energy(self, edat=None, output='LONG', file='default', **vars):
        """Evaluate the objective function given restraints. Returns a
           (molpdf, terms) tuple, with terms converted by
           physical.from_list."""
        (inds, mdl) = self.__require_indices()

        if edat is None:
            edat = mdl.env.edat
        (molpdf, terms) = mdl.top.energy('selection.energy', mdl=mdl.modpt,
                                         edat=edat.modpt,
                                         libs=mdl.env.libs.modpt, sel1=inds,
                                         file=file, output=output, **vars)
        terms = physical.from_list(terms)
        return (molpdf, terms)

    def get_dope_potential(self):
        """Get the spline data for the DOPE statistical potential
           (loaded on first use and cached on the class)"""
        (inds, mdl) = self.__require_indices()
        if selection.dope_restraints is None:
            selection.dope_restraints = \
                group_restraints(mdl.env, classes='${LIB}/atmcls-mf.lib',
                                 parameters='${LIB}/dist-mf.lib')
        return selection.dope_restraints

    def get_dopehr_potential(self):
        """Get the spline data for the DOPE-HR statistical potential
           (loaded on first use and cached on the class)"""
        (inds, mdl) = self.__require_indices()
        if selection.dopehr_restraints is None:
            selection.dopehr_restraints = \
                group_restraints(mdl.env, classes='${LIB}/atmcls-mf.lib',
                                 parameters='${LIB}/dist-mfhr.lib')
        return selection.dopehr_restraints

    def get_dope_energy_data(self):
        """Get ideal energy_data terms for DOPE potential evaluations"""
        return energy_data(contact_shell=15.0, dynamic_modeller=True,
                           dynamic_lennard=False, dynamic_sphere=False,
                           excl_local=(False, False, False, False))

    def assess_dope(self, **vars):
        """Assess the selection with the DOPE potential"""
        return self._dope_energy(self.get_dope_potential(), "DOPE", **vars)

    def assess_dopehr(self, **vars):
        """Assess the selection with the DOPE-HR potential"""
        return self._dope_energy(self.get_dopehr_potential(), "DOPE-HR", **vars)

    def get_energy_profile(self, edat, physical_type):
        """Get a per-residue energy profile, plus the number of restraints on
           each residue, and the RMS minimum and heavy violations.
           Only the given physical_type is scaled in (all others get 0)."""
        (inds, mdl) = self.__require_indices()
        scaln = physical.values(default=0.)
        scaln[physical_type] = 1.
        prof = _modeller.rms_profile(mdl.modpt, edat.modpt, mdl.env.libs.modpt,
                                     inds, (1, 9999), True, False,
                                     physical_type.get_num(), scaln)
        return energy_profile(*prof)

    # NOTE: the schedule_scale default is created once, when the class is
    # defined, and is shared by every call that relies on it.
    def _dope_energy(self, gprsr, name, output='SHORT NO_REPORT',
                     residue_span_range=(1, 9999),
                     schedule_scale=physical.values(default=0.,
                                                    nonbond_spline=1.), **vars):
        """Internal function to do DOPE or DOPE-HR assessment.
           Temporarily installs gprsr as the model's group restraints,
           evaluates energy() (by default scaling only the nonbond_spline
           term), restores the previous group restraints even on error, then
           prints and returns the molpdf score."""
        mdl = self.__mdl
        print(">> Model assessment by %s potential" % name)
        edat = self.get_dope_energy_data()
        old_gprsr = mdl.group_restraints
        mdl.group_restraints = gprsr
        try:
            (molpdf, terms) = \
                self.energy(edat=edat, residue_span_range=residue_span_range,
                            output=output, schedule_scale=schedule_scale,
                            **vars)
        finally:
            mdl.group_restraints = old_gprsr
        print("%s score               : %12.6f" % (name, molpdf))
        return molpdf

    def debug_function(self, edat=None, **vars):
        """Test code self-consistency"""
        (inds, mdl) = self.__require_indices()

        if edat is None:
            edat = mdl.env.edat
        return mdl.top.debug_function('selection.debug_function', mdl=mdl.modpt,
                                      edat=edat.modpt, libs=mdl.env.libs.modpt,
                                      sel1=inds, **vars)

    def unbuild(self):
        """Undefine all coordinates"""
        (inds, mdl) = self.__require_indices()
        return _modeller.unbuild_model(mdl=mdl.modpt, sel1=inds)

    def __get_libs(self):
        """Return the libraries pointer of the selection's model, or None
           if the selection has no model"""
        if self.__mdl is not None:
            return self.__mdl.env.libs.modpt
        else:
            return None

    def __filter(self, func, *args):
        """Return a new selection of the atom indices returned by
           func(model pointer, sorted indices, *args); empty if this
           selection has no model"""
        newobj = selection()
        (inds, mdl) = self.get_atom_indices()
        if mdl is not None:
            newinds = func(mdl.modpt, inds, *args)
            newobj.__mdl = mdl
            newobj.__selection = dict.fromkeys(newinds)
        return newobj

    def __get_com(self):
        (inds, mdl) = self.__require_indices()
        return _modeller.get_selection_com(mdl.modpt, inds)
    def __set_com(self, val):
        # Setting the mass center translates the whole selection
        com = self.mass_center
        self.translate([val[0] - com[0], val[1] - com[1], val[2] - com[2]])
    def __get_x(self):
        return self.__get_com()[0]
    def __get_y(self):
        return self.__get_com()[1]
    def __get_z(self):
        return self.__get_com()[2]
    def __set_x(self, val):
        self.translate([val - self.x, 0, 0])
    def __set_y(self, val):
        self.translate([0, val - self.y, 0])
    def __set_z(self, val):
        self.translate([0, 0, val - self.z])

    mass_center = property(__get_com, __set_com,
                           doc="Coordinates of mass center")
    x = property(__get_x, __set_x, doc="x coordinate of mass center")
    y = property(__get_y, __set_y, doc="y coordinate of mass center")
    z = property(__get_z, __set_z, doc="z coordinate of mass center")
