import _modeller
import util.modutil as modutil
from modeller.excluded_pair import excluded_pair
from modeller.rigid_body import rigid_body
from modeller.pseudo_atom import pseudo_atom
from modeller.symmetry import symmetry
from modeller.pseudo_atom_list import pseudo_atom_list
from modeller.rigid_body_list import rigid_body_list
from modeller.excluded_pair_list import excluded_pair_list
from modeller.symmetry_list import symmetry_list
import modeller.physical as physical
import modeller.alignment as alignment

class restraints(object):
    """Holds all restraints which can act on a model.

    An instance is bound to one model (passed to the constructor); every
    operation forwards to that model through `self.top` or the `_modeller`
    extension module.
    """

    # Restraint types for which make() insists on an explicit alignment
    # instead of creating an empty default one.  Compared against the
    # upper-cased `restraint_type`.
    _ALN_REQUIRED_TYPES = ('CHI1_DIHEDRAL', 'CHI2_DIHEDRAL', 'CHI3_DIHEDRAL',
                           'CHI4_DIHEDRAL', 'PHI_DIHEDRAL', 'PSI_DIHEDRAL',
                           'OMEGA_DIHEDRAL', 'PHI-PSI_BINORMAL')

    def __init__(self, mdl):
        """Bind the restraint list to model `mdl` (uses `mdl.top`)."""
        self.__mdl = mdl
        self.top = mdl.top

    def __get_modpt(self):
        """Return the restraints handle obtained via _modeller.model_rsr_get."""
        return _modeller.model_rsr_get(self.__mdl.modpt)

    def __len__(self):
        """Return the number of restraints (restraints_ncsr_get)."""
        return _modeller.restraints_ncsr_get(self.modpt)

    def __str__(self):
        return "<List of %d restraints>" % len(self)

    def __check_addable(self, rsr):
        """Raise TypeError if `rsr` must go through a dedicated list instead.

        Note: the names below refer to the module-level classes imported at
        the top of the file.  Class-scope names such as the `symmetry`
        property are not visible inside method bodies, so they do not
        shadow them.
        """
        if isinstance(rsr, excluded_pair):
            raise TypeError("use restraints.excluded_pairs.append() instead")
        elif isinstance(rsr, rigid_body):
            raise TypeError("use restraints.rigid_bodies.append() instead")
        elif isinstance(rsr, pseudo_atom):
            raise TypeError("use restraints.pseudo_atoms.append() instead")
        elif isinstance(rsr, symmetry):
            raise TypeError("use restraints.symmetry.append() instead")

    def __selection_indices(self, atmsel):
        """Return the atom indices of `atmsel`, checking it targets this model.

        Raises:
            ValueError: if the selection is empty (it reports no model), or
                if it refers to a model other than the one these restraints
                belong to.
        """
        (inds, mdl) = atmsel.get_atom_indices()
        if mdl is None:
            raise ValueError("selection is empty")
        if mdl is not self.__mdl:
            raise ValueError("selection refers to a different model")
        return inds

    def __check_restraint_group(self, vars):
        """Raise TypeError unless `restraint_group` is a physical_type."""
        group = self.top.get_argument('restraint_group', vars)
        if not isinstance(group, physical.physical_type):
            raise TypeError("restraint_group should be a physical_type object")

    def add(self, *args, **vars):
        """Add one or more specified restraints.

        With positional arguments, each one is added through its own
        `_add_restraint(restraints, model)` method, and keyword arguments
        are ignored.  All arguments are validated before any is added, so
        an invalid argument leaves the restraint list untouched.  Previously
        only the first positional argument was added; the rest were
        silently dropped.

        Without positional arguments, the keyword arguments are forwarded to
        `self.top.add_restraint`.

        Returns:
            In positional form, the value returned by `_add_restraint` for
            the last argument (for a single argument, that argument's
            result).  Otherwise, the result of `self.top.add_restraint`.

        Raises:
            TypeError: if any argument is an excluded pair, rigid body,
                pseudo atom or symmetry restraint.  These must be appended
                to their dedicated list properties instead.
        """
        if not args:
            return self.top.add_restraint('restraints.add',
                                          mdl=self.__mdl.modpt, **vars)
        for rsr in args:
            self.__check_addable(rsr)
        result = None
        for rsr in args:
            result = rsr._add_restraint(self, self.__mdl)
        return result

    def condense(self):
        """Delete unselected restraints"""
        return _modeller.condense_restraints(self.__mdl.modpt)

    def unpick(self, *atom_ids):
        """Unselect restraints acting on specified atoms"""
        inds = self.__mdl.get_list_atom_indices(atom_ids, None)
        return self.top.unpick_restraints('restraints.unpick',
                                          mdl=self.__mdl.modpt, atom_ids=inds)

    def unpick_all(self):
        """Unselect all restraints"""
        # NOTE(review): passes the restraints handle (self.modpt), whereas
        # condense()/clear() pass the model handle -- confirm this matches
        # the signature of _modeller.unpick_all_restraints.
        return _modeller.unpick_all_restraints(self.modpt)

    def clear(self):
        """Delete all restraints"""
        return _modeller.clear_restraints(self.__mdl.modpt)

    def append(self, file):
        """Append restraints from a file, and select them"""
        return self.top.read_restraints('restraints.append',
                                        mdl=self.__mdl.modpt, file=file)

    def read(self, file):
        """Replace all restraints with those read from `file`, and select them"""
        self.clear()
        return self.append(file)

    def write(self, file):
        """Write currently selected restraints to a file"""
        return self.top.write_restraints('restraints.write',
                                         mdl=self.__mdl.modpt, file=file)

    def pick(self, atmsel, **vars):
        """Select restraints acting on the atoms in `atmsel`.

        Raises:
            ValueError: if `atmsel` is empty or belongs to another model.
        """
        inds = self.__selection_indices(atmsel)
        return self.top.pick_restraints('restraints.pick', mdl=self.__mdl.modpt,
                                        sel1=inds, **vars)

    def spline(self, edat=None, **vars):
        """Approximate selected restraints by splines.

        `edat` defaults to the model environment's `edat`.
        """
        if edat is None:
            edat = self.__mdl.env.edat
        return self.top.spline_restraints('restraints.spline',
                                          mdl=self.__mdl.modpt, edat=edat.modpt,
                                          libs=self.__mdl.env.libs.modpt,
                                          **vars)

    def reindex(self, mdl):
        """Renumber restraints for a new model `mdl`"""
        return _modeller.reindex_restraints(self.__mdl.modpt, mdl.modpt)

    def make(self, atmsel, aln=None, edat=None, io=None, **vars):
        """Calculate and select new restraints of a specified type.

        `restraint_type` and `restraint_group` are obtained with
        `self.top.get_argument(..., vars)`.  `edat` and `io` default to the
        model environment's objects.  `aln` defaults to an empty alignment,
        except for the types in `_ALN_REQUIRED_TYPES`, which require one.

        Raises:
            TypeError: if `atmsel` is not an atom selection, `restraint_type`
                is not a string, or `restraint_group` is not a physical_type.
            ValueError: if `atmsel` is empty or belongs to another model.
        """
        if not hasattr(atmsel, "get_atom_indices"):
            raise TypeError("First argument needs to be an atom selection")
        logname = 'restraints.make'
        restyp = self.top.get_argument('restraint_type', vars)
        if not isinstance(restyp, str):
            raise TypeError("restraint_type must be a string")
        restyp = restyp.upper()
        self.__check_restraint_group(vars)

        inds = self.__selection_indices(atmsel)

        if edat is None:
            edat = self.__mdl.env.edat
        if io is None:
            io = self.__mdl.env.io
        if aln is None:
            if restyp in self._ALN_REQUIRED_TYPES:
                # Presumably raises: aln is dereferenced below.
                modutil.require_argument('aln', logname)
            else:
                aln = alignment.alignment(self.__mdl.env)
        return self.top.make_restraints(logname, mdl=self.__mdl.modpt,
                                        edat=edat.modpt, aln=aln.modpt,
                                        io=io.modpt,
                                        libs=self.__mdl.env.libs.modpt,
                                        sel1=inds, sel2=(), sel3=(), **vars)

    def make_distance(self, atmsel1, atmsel2, aln=None, edat=None, io=None,
                      **vars):
        """Calculate and select new distance restraints between two selections.

        `aln` is required.  `edat` and `io` default to the model
        environment's objects.

        Raises:
            TypeError: if `restraint_group` is not a physical_type.
            ValueError: if either selection is empty or belongs to another
                model.
        """
        logname = 'restraints.make_distance'
        self.__check_restraint_group(vars)

        inds1 = self.__selection_indices(atmsel1)
        inds2 = self.__selection_indices(atmsel2)

        if edat is None:
            edat = self.__mdl.env.edat
        if io is None:
            io = self.__mdl.env.io
        if aln is None:
            modutil.require_argument('aln', logname)
        return self.top.make_restraints(logname, mdl=self.__mdl.modpt,
                                        edat=edat.modpt, aln=aln.modpt,
                                        io=io.modpt,
                                        libs=self.__mdl.env.libs.modpt,
                                        restraint_type='distance',
                                        sel1=(), sel2=inds1, sel3=inds2, **vars)

    # Property accessors: each getter builds a fresh list view over the model.
    def __get_pseudo_atoms(self):
        return pseudo_atom_list(self.__mdl)
    def __get_rigid_bodies(self):
        return rigid_body_list(self.__mdl)
    def __set_rigid_bodies(self, obj):
        modutil.set_varlist(self.rigid_bodies, obj)
    def __del_rigid_bodies(self):
        modutil.del_varlist(self.rigid_bodies)
    def __get_excluded_pairs(self):
        return excluded_pair_list(self.__mdl)
    def __set_excluded_pairs(self, obj):
        modutil.set_varlist(self.excluded_pairs, obj)
    def __del_excluded_pairs(self):
        modutil.del_varlist(self.excluded_pairs)
    def __get_symmetry(self):
        return symmetry_list(self.__mdl)

    modpt = property(__get_modpt)
    pseudo_atoms = property(__get_pseudo_atoms, doc="Pseudo atoms")
    rigid_bodies = property(__get_rigid_bodies, __set_rigid_bodies,
                            __del_rigid_bodies, doc="Rigid bodies")
    excluded_pairs = property(__get_excluded_pairs, __set_excluded_pairs,
                              __del_excluded_pairs, doc="Excluded pairs")
    symmetry = property(__get_symmetry, doc="Symmetry restraints")
