import _modeller
import util.top as top
from modeller.restraints import restraints
import model_topology
import alignment
from modeller.util.modobject import modobject
from modeller.coordinates import coordinates, get_residue_atom_indices, \
                                 atomlist, residuelist, residue

class model(coordinates):
    """Holds a model of a protein.

    The model data lives in an internal Modeller object (exposed as `modpt`);
    this class wraps it with Python methods and properties. Each model keeps
    its own copy of the environment it was created with (`env`) and a `top`
    helper built from that copy.
    """
    __modpt = None
    env = None
    top = None
    __gprsr = None

    def __new__(cls, *args, **vars):
        # Allocate the internal Modeller model object before __init__ runs;
        # any constructor arguments are handled by __init__.
        obj = modobject.__new__(cls)
        obj.__modpt = _modeller.new_model(obj)
        return obj

    def __init__(self, env, **vars):
        """Create a model using a private copy of `env`.

        Any keyword arguments are forwarded to `read()`, so coordinates can
        be read at construction time (e.g. ``model(env, file=...)``)."""
        coordinates.__init__(self)
        self.env = env.copy()
        self.top = top.top(self.env)
        self.group_restraints = self.env.group_restraints
        if vars:
            self.read(**vars)

    def __getstate__(self):
        """Return the picklable state, minus attributes wrapping internal
           objects that should not be pickled."""
        # Shallow copy so the live instance dict is never modified here.
        d = dict(coordinates.__getstate__(self))
        # 'self.__gprsr' is stored under its name-mangled key '_model__gprsr';
        # the bare string '__gprsr' would never match.
        for key in ('_model__gprsr', 'dope_restraints', 'dopehr_restraints'):
            d.pop(key, None)
        return d

    def __del__(self):
        # Release the internal Modeller model object allocated in __new__.
        _modeller.free_model(self.modpt)

    def __repr__(self):
        return "Model containing %s, %s, and %s" \
               % (repr(self.chains), repr(self.residues), repr(self.atoms))

    def __str__(self):
        return "<%s>" % repr(self)

    def __get_modpt(self):
        return self.__modpt
    def __get_seqpt(self):
        return _modeller.model_seq_get(self.__modpt)
    def __get_cdpt(self):
        return _modeller.model_cd_get(self.__modpt)

    def get_atom_indices(self):
        """Get the indices of all atoms in this model.

           Returns a (indices, model) pair; indices run from 1 to natm."""
        return (range(1, self.natm+1), self)

    def read(self, file, io=None, **vars):
        """Read coordinates from a file (using env.io if `io` is not given)"""
        if io is None:
            io = self.env.io
        return self.top.read_model('model.read', mdl=self.modpt, io=io.modpt,
                                   libs=self.env.libs.modpt, file=file, **vars)

    def write(self, file, **vars):
        """Write coordinates to a file"""
        return self.top.write_model('model.write', mdl=self.modpt,
                                    libs=self.env.libs.modpt, sel1=(),
                                    file=file, write_all_atoms=True, **vars)

    def build_sequence(self, sequence):
        """Build an extended chain from a string of one-letter residue codes"""
        a = alignment.alignment(self.env)
        a.append_sequence(sequence)
        self.clear_topology()
        self.generate_topology(a[0])
        self.build(initialize_xyz=True, build_method='INTERNAL_COORDINATES')

    def use_lennard_jones(self):
        """Set up to use Lennard Jones rather than soft sphere (the default) for
           van der Waals interactions.

           Only this model's private copy of the environment is modified."""
        edat = self.env.edat
        edat.contact_shell = 8.00
        edat.dynamic_sphere = False
        edat.dynamic_lennard = True

    def assess_ga341(self):
        """Assess with the GA341 method"""
        return _modeller.assess_ga341(self.modpt, self.env.libs.modpt)

    def fast_rmsd(self, mdl):
        """Calculate the RMSD between this model and the input."""
        return _modeller.fast_rmsd(self.modpt, mdl.modpt)

    def orient(self):
        """Center and orient the model"""
        import orient
        retval = _modeller.orient_model(self.modpt)
        return orient.orient_data(*retval)

    def build(self, **vars):
        """Build coordinates from topology"""
        return self.top.build_model('model.build', mdl=self.modpt,
                                    libs=self.env.libs.modpt, **vars)

    def transfer_xyz(self, aln, io=None, **vars):
        """Copy coordinates from template structures"""
        if io is None:
            io = self.env.io
        return self.top.transfer_xyz('model.transfer_xyz', mdl=self.modpt,
                                     aln=aln.modpt, io=io.modpt,
                                     libs=self.env.libs.modpt, **vars)

    def res_num_from(self, mdl, aln):
        """Copy residue numbers from the given model"""
        return _modeller.transfer_res_numb(self.modpt, mdl.modpt, aln.modpt,
                                           self.env.libs.modpt)

    def reorder_atoms(self):
        """Standardize atom order to match the current topology library"""
        return _modeller.reorder_atoms(self.modpt, self.env.libs.modpt)

    def rename_segments(self, **vars):
        """Relabel residue numbers in each chain/segment"""
        return self.top.rename_segments('model.rename_segments', mdl=self.modpt,
                                        **vars)

    def to_iupac(self):
        """Make dihedral angles satisfy the IUPAC convention"""
        return _modeller.iupac_model(self.modpt, self.env.libs.modpt)

    def clear_topology(self):
        """Clear all covalent topology (atomic connectivity) and sequence"""
        return _modeller.clear_model_topology(self.modpt)

    def generate_topology(self, alnseq, io=None, **vars):
        """Generate covalent topology (atomic connectivity) and sequence.

           `alnseq` must be an alignment.alnsequence (e.g. aln['code']);
           TypeError is raised otherwise."""
        if not isinstance(alnseq, alignment.alnsequence):
            raise TypeError("""Must use an 'alnsequence' object here.
For example, replace 'generate_topology(aln, sequence="foo")' with
'generate_topology(aln["foo"])'""")
        if io is None:
            io = self.env.io
        return self.top.generate_topology('model.generate_topology',
                                          mdl=self.modpt, aln=alnseq.aln.modpt,
                                          io=io.modpt, libs=self.env.libs.modpt,
                                          iseq=alnseq._num, **vars)

    def patch(self, residue_type, residues):
        """Patch the model topology.

           `residues` is a single residue or a list/tuple of residues, all of
           which must belong to this model; TypeError is raised otherwise."""
        if not isinstance(residues, (list, tuple)):
            residues = [residues]
        for r in residues:
            if not isinstance(r, residue) or r.mdl is not self:
                raise TypeError("expecting one or more 'residue' objects")
        return _modeller.patch(mdl=self.modpt, libs=self.env.libs.modpt,
                               residue_type=residue_type,
                               residue_ids=[r.index for r in residues])

    def patch_ss(self):
        """Guess disulfides from the current structure"""
        return _modeller.patch_ss_model(self.modpt, self.env.libs.modpt)

    def patch_ss_templates(self, aln, io=None):
        """Guess disulfides from templates"""
        if io is None:
            io = self.env.io
        return _modeller.patch_ss_templates(mdl=self.modpt, aln=aln.modpt,
                                            io=io.modpt,
                                            libs=self.env.libs.modpt)

    def write_data(self, edat=None, **vars):
        """Write derivative model data (using env.edat if `edat` is not given)"""
        if edat is None:
            edat = self.env.edat
        return self.top.write_data('model.write_data', mdl=self.modpt,
                                   edat=edat.modpt, libs=self.env.libs.modpt,
                                   **vars)

    def make_region(self, atom_accessibility=1.0, region_size=20):
        """Define a random surface patch of atoms"""
        return _modeller.make_region(mdl=self.modpt, libs=self.env.libs.modpt,
                                     atom_accessibility=atom_accessibility,
                                     region_size=region_size)

    def make_chains(self, **vars):
        """Write out matching chains to separate files"""
        return self.top.make_chains('model.make_chains', mdl=self.modpt,
                                    libs=self.env.libs.modpt, **vars)

    def color(self, aln):
        """Color according to the alignment"""
        return _modeller.color_aln_model(self.modpt, aln.modpt)

    def loops(self, aln, minlength, maxlength, insertion_ext, deletion_ext):
        """Returns a list of all loops (insertions or deletions) in the model,
           as defined by the alignment. Insertions come first."""
        return self.get_insertions(aln, minlength, maxlength, insertion_ext) \
               + self.get_deletions(aln, deletion_ext)

    def get_insertions(self, aln, minlength, maxlength, extension):
        """Returns a list of all insertions in the model, as defined by the
           alignment."""
        return self.__get_insdel(aln, _modeller.get_next_insertion, minlength,
                                 maxlength, extension)

    def get_deletions(self, aln, extension):
        """Returns a list of all deletions in the model, as defined by the
           alignment."""
        return self.__get_insdel(aln, _modeller.get_next_deletion, extension)

    def __assess_normalized(self, assess_name, scorer_name, label):
        """Shared implementation of the normalized DOPE/DOPE-HR assessments.

           assess_name: name of the selection method giving the raw score.
           scorer_name: name of the normalized_dope scorer class that maps
                        the raw score to a z score.
           label: text used in the '>> <label> z score' log line.
           Returns the z score."""
        from modeller.selection import selection
        import normalized_dope
        sel = selection(self)
        raw_score = getattr(sel, assess_name)()
        scorer = getattr(normalized_dope, scorer_name)(self)
        z_score = scorer.get_z_score(raw_score)
        print(">> %s z score: %.3f" % (label, z_score))
        return z_score

    def assess_normalized_dope(self):
        """Assess the model, and return a normalized DOPE score (z score)"""
        return self.__assess_normalized('assess_dope', 'dope_scorer',
                                        'Normalized DOPE')

    def assess_new_normalized_dope(self):
        """Assess the model, and return a normalized DOPE score (z score)
           computed with the newer normalized_dope.new_dope_scorer"""
        return self.__assess_normalized('assess_dope', 'new_dope_scorer',
                                        'New normalized DOPE')

    def assess_normalized_dopehr(self):
        """Assess the model, and return a normalized DOPE-HR score (z score)"""
        return self.__assess_normalized('assess_dopehr', 'dopehr_scorer',
                                        'Normalized DOPE-HR')

    def get_normalized_dope_profile(self):
        """Return a normalized DOPE per-residue profile"""
        from modeller.selection import selection
        import normalized_dope
        import physical
        sel = selection(self)
        edat = sel.get_dope_energy_data()
        # Temporarily swap in the DOPE potential; always restore the caller's
        # group restraints, even if the energy evaluation fails.
        oldgprsr = self.group_restraints
        self.group_restraints = sel.get_dope_potential()
        try:
            profile = sel.get_energy_profile(edat, physical.nonbond_spline)
        finally:
            self.group_restraints = oldgprsr
        scorer = normalized_dope.dope_scorer(self)
        return scorer.get_profile(profile)

    def write_psf(self, file, xplor=True):
        """Write the molecular topology to a PSF file, in either X-PLOR format
           (the default) or CHARMM format."""
        model_topology.write_psf(file, self, xplor)

    def find_atoms(self, residue_type, atom_names):
        """Find atoms by residue type and atom names (delegates to
           model_topology.find_atoms)"""
        return model_topology.find_atoms(self, residue_type, atom_names)

    # NOTE(review): the integers 5-8 passed below presumably select the
    # chi1-chi4 dihedral classes in model_topology.find_dihedrals - confirm.
    def find_chi1_dihedrals(self, residue_type):
        """Find chi1 dihedrals for the given residue type"""
        return model_topology.find_dihedrals(self, residue_type, 5)
    def find_chi2_dihedrals(self, residue_type):
        """Find chi2 dihedrals for the given residue type"""
        return model_topology.find_dihedrals(self, residue_type, 6)
    def find_chi3_dihedrals(self, residue_type):
        """Find chi3 dihedrals for the given residue type"""
        return model_topology.find_dihedrals(self, residue_type, 7)
    def find_chi4_dihedrals(self, residue_type):
        """Find chi4 dihedrals for the given residue type"""
        return model_topology.find_dihedrals(self, residue_type, 8)

    def __get_insdel(self, aln, func, *args):
        """Collect residue ranges reported by `func` as residuelist objects.

           `func(aln.modpt, pos, *args)` is called repeatedly, returning
           (pos, start, end), until the returned pos is negative. Ranges with
           a non-positive start or end are skipped."""
        _modeller.chk_aln_model(aln.modpt, len(aln), self.modpt,
                                self.env.libs.modpt)
        l = []
        pos = 0
        while pos >= 0:
            (pos, start, end) = func(aln.modpt, pos, *args)
            if start > 0 and end > 0:
                # start/end are presumably 1-based and inclusive; convert to
                # a 0-based offset plus a length for residuelist.
                l.append(residuelist(self, start - 1, end - start + 1))
        return l

    def __get_seq_id(self):
        return _modeller.model_seq_id_get(self.modpt)
    def __set_seq_id(self, val):
        return _modeller.model_seq_id_set(self.modpt, val)
    def __get_remark(self):
        return _modeller.model_remark_get(self.modpt)
    def __set_remark(self, val):
        return _modeller.model_remark_set(self.modpt, val)
    def __get_header(self):
        return _modeller.model_header_get(self.modpt)
    def __set_header(self, val):
        return _modeller.model_header_set(self.modpt, val)
    def __get_last_energy(self):
        return _modeller.model_last_energy_get(self.modpt)
    def __set_last_energy(self, val):
        return _modeller.model_last_energy_set(self.modpt, val)
    def __get_restraints(self):
        return restraints(self)
    def __set_group_restraints(self, val):
        # Keep the Python wrapper alive alongside the internal pointer; a
        # false value unsets the group restraints on the internal model.
        self.__gprsr = val
        if val:
            _modeller.model_group_restraints_set(self.modpt,
                                                 self.env.libs.modpt, val.modpt)
        else:
            _modeller.model_group_restraints_unset(self.modpt,
                                                   self.env.libs.modpt)
    def __get_group_restraints(self):
        return self.__gprsr
    def __get_bonds(self):
        return model_topology.bondlist(self, _modeller.model_topology_nbnd_get,
                                       _modeller.model_topology_iatb_get, 2)
    def __get_angles(self):
        return model_topology.bondlist(self, _modeller.model_topology_nang_get,
                                       _modeller.model_topology_iata_get, 3)
    def __get_dihedrals(self):
        return model_topology.bondlist(self, _modeller.model_topology_ndih_get,
                                       _modeller.model_topology_iatd_get, 4)
    def __get_impropers(self):
        return model_topology.bondlist(self, _modeller.model_topology_nimp_get,
                                       _modeller.model_topology_iati_get, 4)


    modpt = property(__get_modpt, doc="Internal Modeller object")
    cdpt = property(__get_cdpt, doc="Internal Modeller coordinates object")
    seqpt = property(__get_seqpt, doc="Internal Modeller sequence object")
    seq_id = property(__get_seq_id, __set_seq_id,
                      doc="Sequence identity between model and best template")
    header = property(__get_header, __set_header, doc="PDB header")
    last_energy = property(__get_last_energy, __set_last_energy,
                           doc="Energy from last energy or optimize")
    remark = property(__get_remark, __set_remark, doc="PDB REMARK line(s)")
    restraints = property(__get_restraints,
                          doc="All restraints acting on this model")
    group_restraints = property(__get_group_restraints, __set_group_restraints,
                                doc="Group restraints active for this model")
    bonds = property(__get_bonds, doc="All defined bonds")
    angles = property(__get_angles, doc="All defined angles")
    dihedrals = property(__get_dihedrals, doc="All defined dihedrals")
    impropers = property(__get_impropers, doc="All defined impropers")
