import _modeller
import util.top as top
import util.modutil as modutil
from energy_data import energy_data
from group_restraints import group_restraints
import alignment
from modeller.util.modobject import modobject

class model(modobject):
    """Holds a model of a protein

    Owns a C-level model object (allocated with ``_modeller.new_model`` in
    ``__init__`` and released with ``_modeller.free_model`` in ``__del__``)
    and exposes MODELLER commands on it. Most methods forward to the matching
    command on ``self.top``. Their ``io``, ``libs`` and ``edat`` arguments
    default to those of the model's private copy of the environment
    (``self.env``).
    """

    # Global reference to the DOPE restraints file, so that it's loaded
    # only once
    dope_restraints = None

    def __init__(self, env, io=None, aln=None, libs=None, **vars):
        """Create an empty model in a private copy of ``env``.

        If any extra keyword arguments are given, they are passed on to
        :meth:`read` together with ``io``, ``aln`` and ``libs``, so that
        coordinates are read immediately.
        """
        self.add_members(('_model__modpt', 'env', 'top', '_model__sched',
                          '_model__gprsr'))
        self.__modpt = _modeller.new_model()
        self.env = env.copy()
        self.top = top.top(self.env)
        self.__sched = schedule(self)
        self.group_restraints = self.env.group_restraints
        if vars:
            self.read(io, aln, libs, **vars)

    def __del__(self):
        """Release the C-level model object."""
        # If __init__ failed before the C-level model was allocated, there is
        # nothing to free. Touching self.modpt here would only raise an
        # "exception ignored in __del__" warning.
        modpt = getattr(self, '_model__modpt', None)
        if modpt is not None:
            _modeller.free_model(modpt)

    def __get_modpt(self):
        return self.__modpt

    def read(self, io=None, aln=None, libs=None, **vars):
        """Read coordinates from a file

        ``io`` and ``libs`` default to the environment's; ``aln`` defaults
        to a fresh ``alignment.alignment(self.env)``. Remaining keyword
        arguments are passed to the ``read_model`` command.
        """
        if io is None:
            io = self.env.io
        if libs is None:
            libs = self.env.libs
        if aln is None:
            aln = alignment.alignment(self.env)
        return self.top.read_model('model.read', mdl=self.modpt, aln=aln.modpt,
                                   io=io.modpt, libs=libs.modpt, **vars)

    def write(self, file, libs=None, **vars):
        """Write coordinates to a file

        ``file`` is passed to the ``write_model`` command along with any
        other keyword arguments; ``libs`` defaults to the environment's.
        """
        vars['file'] = file
        if libs is None:
            libs = self.env.libs
        return self.top.write_model('model.write', mdl=self.modpt,
                                    libs=libs.modpt, **vars)

    def energy(self, edat=None, libs=None, output='LONG', file='default',
               **vars):
        """Evaluate the objective function given restraints

        ``edat`` and ``libs`` default to the environment's. ``output``,
        ``file`` and any other keyword arguments are passed to the
        ``energy`` command, whose result is returned.
        """
        if edat is None:
            edat = self.env.edat
        if libs is None:
            libs = self.env.libs
        vars['output'] = output
        vars['file'] = file
        return self.top.energy('model.energy', mdl=self.modpt, edat=edat.modpt,
                               libs=libs.modpt, **vars)

    def assess_dope(self, output='SHORT NO_REPORT',
                    residue_span_range=(1, 9999),
                    schedule_scale=None, **vars):
        """Score the model with the DOPE potential, print and return it.

        The model's group restraints are temporarily replaced with the DOPE
        restraints. These are read from ${LIB}/atmcls-mf.lib and
        ${LIB}/dist-mf.lib on first use and cached in
        ``model.dope_restraints``. The score is then evaluated with
        :meth:`energy` using a dedicated ``energy_data``. The previous group
        restraints are restored afterwards, even if the evaluation raises.

        ``schedule_scale`` defaults to a 35-element list that is zero except
        at index 30. Presumably that slot scales the group-restraint (DOPE)
        term; confirm against the C layer. Extra keyword arguments are
        passed on to :meth:`energy`.
        """
        if schedule_scale is None:
            # Built per call rather than as a shared mutable default.
            schedule_scale = [0.]*30+[1.]+[0.]*4
        print(">> Model assessment by DOPE potential")
        edat = energy_data(contact_shell=15.0, dynamic_modeller=True,
                           dynamic_lennard=False, dynamic_sphere=False,
                           excl_local=(False, False, False, False))
        old_gprsr = self.group_restraints
        if model.dope_restraints is None:
            model.dope_restraints = \
                group_restraints(self.env, classes='${LIB}/atmcls-mf.lib',
                                 parameters='${LIB}/dist-mf.lib')
        self.group_restraints = model.dope_restraints
        try:
            molpdf = self.energy(edat=edat,
                                 residue_span_range=residue_span_range,
                                 output=output, schedule_scale=schedule_scale,
                                 **vars)
        finally:
            # Never leave the model bound to the DOPE restraints.
            self.group_restraints = old_gprsr
        print("DOPE score               : %12.6f" % molpdf)
        return molpdf

    def assess_ga341(self):
        """Run the ``assess_ga341`` command on this model and return its
        result."""
        return self.top.assess_ga341('model.assess_ga341', mdl=self.modpt,
                                     libs=self.env.libs.modpt)

    def optimize(self, edat=None, libs=None, **vars):
        """Optimize the objective function given restraints"""
        if edat is None:
            edat = self.env.edat
        if libs is None:
            libs = self.env.libs
        return self.top.optimize('model.optimize', mdl=self.modpt,
                                 edat=edat.modpt, libs=libs.modpt, **vars)

    def switch_trace(self, **vars):
        """Run the ``switch_trace`` command on this model."""
        return self.top.switch_trace('model.switch_trace', mdl=self.modpt,
                                     **vars)

    def debug_function(self, edat=None, libs=None, **vars):
        """Test code self-consistency"""
        if edat is None:
            edat = self.env.edat
        if libs is None:
            libs = self.env.libs
        return self.top.debug_function('model.debug_function', mdl=self.modpt,
                                       edat=edat.modpt, libs=libs.modpt, **vars)

    def pick_atoms(self, aln=None, libs=None, **vars):
        """Select atoms

        If ``aln`` is omitted, a fresh ``alignment.alignment(self.env)`` is
        used. The exception is when ``selection_segment`` starts with
        'LOOPS', in which case ``aln`` is required.
        """
        logname = 'model.pick_atoms'
        if libs is None:
            libs = self.env.libs
        if aln is None:
            selseg = self.top.get_argument('selection_segment', vars)
            if not isinstance(selseg, (list, tuple)):
                selseg = [selseg]
            if len(selseg) > 0 and selseg[0].upper() == 'LOOPS':
                modutil.require_argument('aln', logname)
            else:
                aln = alignment.alignment(self.env)
        return self.top.pick_atoms(logname, mdl=self.modpt,
                                   aln=aln.modpt, libs=libs.modpt, **vars)

    def superpose(self, mdl, aln, libs=None, **vars):
        """Superpose the input model on this one, given an alignment of the
           two"""
        if libs is None:
            libs = self.env.libs
        return self.top.superpose('model.superpose', mdl=self.modpt,
                                  mdl2=mdl.modpt, aln=aln.modpt,
                                  libs=libs.modpt, **vars)

    def orient(self, **vars):
        """Center and orient the model"""
        return self.top.orient_model('model.orient', mdl=self.modpt, **vars)

    def mutate(self, libs=None, **vars):
        """Mutate selected residues"""
        if libs is None:
            libs = self.env.libs
        return self.top.mutate_model('model.mutate', mdl=self.modpt,
                                     libs=libs.modpt, **vars)

    def build(self, libs=None, **vars):
        """Build coordinates from topology"""
        if libs is None:
            libs = self.env.libs
        return self.top.build_model('model.build', mdl=self.modpt,
                                    libs=libs.modpt, **vars)

    def unbuild(self, **vars):
        """Undefine all coordinates"""
        return self.top.unbuild_model('model.unbuild', mdl=self.modpt, **vars)

    def transfer_xyz(self, aln, io=None, libs=None, **vars):
        """Copy coordinates from template structures"""
        if io is None:
            io = self.env.io
        if libs is None:
            libs = self.env.libs
        return self.top.transfer_xyz('model.transfer_xyz', mdl=self.modpt,
                                     aln=aln.modpt, io=io.modpt,
                                     libs=libs.modpt, **vars)

    def res_num_from(self, mdl, aln, libs=None, **vars):
        """Copy residue numbers from the given model"""
        if libs is None:
            libs = self.env.libs
        return self.top.transfer_res_numb('model.res_num_from', mdl=self.modpt,
                                          mdl2=mdl.modpt, aln=aln.modpt,
                                          libs=libs.modpt, **vars)

    def reorder_atoms(self, libs=None, **vars):
        """Standardize atom order to match the current topology library"""
        if libs is None:
            libs = self.env.libs
        return self.top.reorder_atoms('model.reorder_atoms', mdl=self.modpt,
                                      libs=libs.modpt, **vars)

    def rename_segments(self, **vars):
        """Relabel residue numbers in each chain/segment"""
        return self.top.rename_segments('model.rename_segments', mdl=self.modpt,
                                        **vars)

    def pick_hot_atoms(self, edat=None, libs=None, **vars):
        """Pick atoms violating restraints"""
        if edat is None:
            edat = self.env.edat
        if libs is None:
            libs = self.env.libs
        return self.top.pick_hot_atoms('model.pick_hot_atoms', mdl=self.modpt,
                                       edat=edat.modpt, libs=libs.modpt, **vars)

    def randomize_xyz(self, libs=None, **vars):
        """Randomize coordinates"""
        if libs is None:
            libs = self.env.libs
        return self.top.randomize_xyz('model.randomize_xyz', mdl=self.modpt,
                                      libs=libs.modpt, **vars)

    def to_iupac(self, libs=None, **vars):
        """Make dihedral angles satisfy the IUPAC convention"""
        if libs is None:
            libs = self.env.libs
        return self.top.iupac_model('model.to_iupac', mdl=self.modpt,
                                    libs=libs.modpt, **vars)

    def rotate(self, **vars):
        """Rotate and translate the coordinates"""
        return self.top.rotate_model('model.rotate', mdl=self.modpt, **vars)

    def generate_topology(self, aln, io=None, libs=None, **vars):
        """Generate covalent topology (atomic connectivity)"""
        if io is None:
            io = self.env.io
        if libs is None:
            libs = self.env.libs
        return self.top.generate_topology('model.generate_topology',
                                          mdl=self.modpt, aln=aln.modpt,
                                          io=io.modpt, libs=libs.modpt, **vars)

    def patch(self, libs=None, **vars):
        """Patch the model topology"""
        if libs is None:
            libs = self.env.libs
        return self.top.patch('model.patch', mdl=self.modpt, libs=libs.modpt,
                              **vars)

    def patch_ss(self, libs=None, **vars):
        """Guess disulfides from the current structure"""
        if libs is None:
            libs = self.env.libs
        return self.top.patch_ss_model('model.patch_ss', mdl=self.modpt,
                                       libs=libs.modpt, **vars)

    def patch_ss_templates(self, aln, io=None, libs=None, **vars):
        """Guess disulfides from templates"""
        if io is None:
            io = self.env.io
        if libs is None:
            libs = self.env.libs
        return self.top.patch_ss_templates('model.patch_ss_templates',
                                           mdl=self.modpt, aln=aln.modpt,
                                           io=io.modpt, libs=libs.modpt, **vars)

    def write_data(self, edat=None, libs=None, **vars):
        """Write derivative model data"""
        if edat is None:
            edat = self.env.edat
        if libs is None:
            libs = self.env.libs
        return self.top.write_data('model.write_data', mdl=self.modpt,
                                   edat=edat.modpt, libs=libs.modpt, **vars)

    def write_pdb_xref(self, aln, libs=None, **vars):
        """Write residue number/index correspondence"""
        if libs is None:
            libs = self.env.libs
        return self.top.write_pdb_xref('model.write_pdb_xref', mdl=self.modpt,
                                       aln=aln.modpt, libs=libs.modpt, **vars)

    def make_region(self, libs=None, **vars):
        """Define a random surface patch of atoms"""
        if libs is None:
            libs = self.env.libs
        return self.top.make_region('model.make_region', mdl=self.modpt,
                                    libs=libs.modpt, **vars)

    def make_chains(self, libs=None, **vars):
        """Write out matching chains to separate files"""
        if libs is None:
            libs = self.env.libs
        return self.top.make_chains('model.make_chains', mdl=self.modpt,
                                    libs=libs.modpt, **vars)

    def color(self, aln, **vars):
        """Color according to the alignment"""
        return self.top.color_aln_model('model.color', mdl=self.modpt,
                                        aln=aln.modpt, **vars)

    def rotate_dihedrals(self, libs=None, **vars):
        """Optimize or randomize dihedral angles"""
        if libs is None:
            libs = self.env.libs
        return self.top.rotate_dihedrals('model.rotate_dihedrals',
                                         mdl=self.modpt, libs=libs.modpt,
                                         **vars)

    # Property accessors: each reads or writes the C-level model directly.
    def __get_resol(self):
        return _modeller.get_model_resol(self.modpt)
    def __set_resol(self, val):
        return _modeller.set_model_resol(self.modpt, val)
    def __get_rfactr(self):
        return _modeller.get_model_rfactr(self.modpt)
    def __set_rfactr(self, val):
        return _modeller.set_model_rfactr(self.modpt, val)
    def __get_seq_id(self):
        return _modeller.get_model_seq_id(self.modpt)
    def __set_seq_id(self, val):
        return _modeller.set_model_seq_id(self.modpt, val)
    def __get_natm(self):
        return _modeller.get_model_natm(self.modpt)
    def __get_nres(self):
        return _modeller.get_model_nres(self.modpt)
    def __get_nsegm(self):
        return _modeller.get_model_nsegm(self.modpt)
    def __get_header(self):
        return _modeller.get_model_header(self.modpt)
    def __set_header(self, val):
        return _modeller.set_model_header(self.modpt, val)
    def __get_remark(self):
        return _modeller.get_model_remark(self.modpt)
    def __set_remark(self, val):
        return _modeller.set_model_remark(self.modpt, val)
    def __get_atoms(self):
        return atomlist(self)
    def __get_residues(self):
        return residuelist(self)
    def __get_segments(self):
        return segmentlist(self)
    def __get_schedule(self):
        return self.__sched
    def __get_restraints(self):
        return restraints(self)
    def __get_symmetry(self):
        return symmetry(self)
    def __set_group_restraints(self, val):
        # Keep the Python object alive alongside the C-level binding; a
        # false value unbinds the group restraints in the C layer.
        self.__gprsr = val
        if val:
            _modeller.set_model_group_restraints(self.modpt,
                                                 self.env.libs.modpt, val.modpt)
        else:
            _modeller.unset_model_group_restraints(self.modpt,
                                                   self.env.libs.modpt)
    def __get_group_restraints(self):
        return self.__gprsr

    modpt = property(__get_modpt)
    resol = property(__get_resol, __set_resol)
    rfactr = property(__get_rfactr, __set_rfactr)
    seq_id = property(__get_seq_id, __set_seq_id)
    natm = property(__get_natm)
    nres = property(__get_nres)
    nsegm = property(__get_nsegm)
    header = property(__get_header, __set_header)
    remark = property(__get_remark, __set_remark)
    atoms = property(__get_atoms)
    residues = property(__get_residues)
    segments = property(__get_segments)
    schedule = property(__get_schedule)
    restraints = property(__get_restraints)
    symmetry = property(__get_symmetry)
    group_restraints = property(__get_group_restraints, __set_group_restraints)

class schedule(object):
    """Optimization schedule attached to a model.

    Holds the C-level schedule pointer fetched from the owning model and
    forwards read/make/write requests to the model's ``top`` command layer.
    """

    def __init__(self, mdl):
        self.__mdl = mdl
        self.__modpt = _modeller.get_model_schedule(self.__mdl.modpt)
        self.top = mdl.top

    def __run(self, command, logname, vars):
        # Every schedule command acts on the owning model, so that model's
        # pointer is always supplied as ``mdl``.
        handler = getattr(self.top, command)
        return handler(logname, mdl=self.__mdl.modpt, **vars)

    def __len__(self):
        """Return the schedule length reported by the C layer."""
        return _modeller.get_schedule_nsch(self.modpt)

    def read(self, **vars):
        """Read optimization schedule from a file"""
        return self.__run('read_schedule', 'schedule.read', vars)

    def make(self, **vars):
        """Create optimization schedule"""
        return self.__run('make_schedule', 'schedule.make', vars)

    def write(self, **vars):
        """Write optimization schedule to a file"""
        return self.__run('write_schedule', 'schedule.write', vars)

    def __get_modpt(self):
        return self.__modpt

    def __get_step(self):
        return _modeller.get_schedule_step(self.modpt)

    def __set_step(self, val):
        return _modeller.set_schedule_step(self.modpt, val)

    modpt = property(__get_modpt, doc="C-level schedule pointer")
    step = property(__get_step, __set_step,
                    doc="Schedule step value, read from/written to the C layer")


class restraints(object):
    """Holds all restraints which can act on a model

    Wraps the model's C-level restraints object. Most methods forward to a
    command on the model's ``top`` layer, passing the owning model's pointer
    as ``mdl``.
    """

    # Restraint types that make() refuses to build without an explicit
    # alignment (it calls modutil.require_argument('aln', ...) for these).
    _ALN_RESTRAINT_TYPES = ('DISTANCE', 'CHI1_DIHEDRAL', 'CHI2_DIHEDRAL',
                            'CHI3_DIHEDRAL', 'CHI4_DIHEDRAL', 'PHI_DIHEDRAL',
                            'PSI_DIHEDRAL', 'OMEGA_DIHEDRAL',
                            'PHI-PSI_BINORMAL')

    def __init__(self, mdl):
        self.__mdl = mdl
        self.__modpt = _modeller.get_model_restraints(self.__mdl.modpt)
        self.top = mdl.top

    def __get_modpt(self):
        return self.__modpt

    def __len__(self):
        """Return the restraint count reported by the C layer."""
        return _modeller.get_restraints_ncsr(self.modpt)

    def add(self, **vars):
        """Add a single specified restraint"""
        return self.top.add_restraint('restraints.add', mdl=self.__mdl.modpt,
                                      **vars)

    def condense(self, **vars):
        """Delete unselected restraints"""
        return self.top.condense_restraints('restraints.condense',
                                            mdl=self.__mdl.modpt, **vars)

    def unpick(self, **vars):
        """Unselect specified restraints"""
        return self.top.delete_restraint('restraints.unpick',
                                         mdl=self.__mdl.modpt, **vars)

    def unpick_all(self):
        """Unselect all restraints"""
        return _modeller.unpick_all_restraints(self.modpt)

    def clear(self):
        """Delete all restraints"""
        return _modeller.clear_restraints(self.modpt)

    def append(self, file, **vars):
        """Append restraints from a file, and select them"""
        vars['file'] = file
        # Always append to (never replace) the existing restraint set.
        vars['add_restraints'] = True
        return self.top.read_restraints('restraints.append',
                                        mdl=self.__mdl.modpt, **vars)

    def read(self, file, **vars):
        """Read restraints from a file, and select them

        Existing restraints are cleared first, then the file is appended.
        """
        self.clear()
        return self.append(file, **vars)

    def write(self, **vars):
        """Write currently selected restraints to a file"""
        return self.top.write_restraints('restraints.write',
                                         mdl=self.__mdl.modpt, **vars)

    def pick(self, **vars):
        """Select specified restraints"""
        vars['add_restraints'] = True
        return self.top.pick_restraints('restraints.pick', mdl=self.__mdl.modpt,
                                        **vars)

    def spline(self, edat=None, libs=None, **vars):
        """Approximate selected restraints by splines

        ``edat`` and ``libs`` default to those of the model's environment.
        """
        if edat is None:
            edat = self.__mdl.env.edat
        if libs is None:
            libs = self.__mdl.env.libs
        return self.top.spline_restraints('restraints.spline',
                                          mdl=self.__mdl.modpt, edat=edat.modpt,
                                          libs=libs.modpt, **vars)

    def reindex(self, mdl, **vars):
        """Renumber restraints for a new model"""
        return self.top.reindex_restraints('restraints.reindex',
                                           mdl=self.__mdl.modpt, mdl2=mdl.modpt,
                                           **vars)

    def make(self, aln=None, edat=None, io=None, libs=None, **vars):
        """Calculates and selects new restraints of a specified type

        ``edat``, ``io`` and ``libs`` default to those of the model's
        environment. If ``aln`` is omitted and ``restraint_type`` (matched
        case-insensitively) is one of ``_ALN_RESTRAINT_TYPES``,
        ``modutil.require_argument`` is invoked to report the missing
        alignment. Otherwise a fresh ``alignment.alignment`` is built from
        the model's environment. New restraints are always added to the
        existing set (``add_restraints=True``).
        """
        logname = 'restraints.make'
        env = self.__mdl.env
        if edat is None:
            edat = env.edat
        if io is None:
            io = env.io
        if libs is None:
            libs = env.libs
        if aln is None:
            restyp = self.top.get_argument('restraint_type', vars)
            if isinstance(restyp, str) and \
               restyp.upper() in self._ALN_RESTRAINT_TYPES:
                modutil.require_argument('aln', logname)
            else:
                # A restraints object has no 'env' of its own; use the
                # owning model's environment.
                aln = alignment.alignment(env)
        vars['add_restraints'] = True
        return self.top.make_restraints(logname, mdl=self.__mdl.modpt,
                                        edat=edat.modpt, aln=aln.modpt,
                                        io=io.modpt, libs=libs.modpt, **vars)

    modpt = property(__get_modpt)


class symmetry(object):
    """Symmetry restraints acting on a model.

    Only provides :meth:`define`, which forwards to the owning model's
    ``top`` command layer.
    """

    def __init__(self, mdl):
        self.__mdl = mdl
        self.top = mdl.top

    def define(self, **vars):
        """Define similar segments"""
        command = self.top.define_symmetry
        return command('symmetry.define', mdl=self.__mdl.modpt, **vars)


class atomlist(object):
    """Sequence-like view over atoms of a model.

    An integer index (as resolved by ``modutil.handle_seq_indx``) ``i``
    yields ``atom(mdl, i + offset)``. Any other resolved value is returned
    unchanged. ``length`` fixes the reported size; when it is None the
    view reports ``mdl.natm``.
    """

    def __init__(self, mdl, offset=0, length=None):
        self.mdl = mdl
        self.offset = offset
        self.length = length

    def __len__(self):
        if self.length is None:
            return self.mdl.natm
        return self.length

    def __getitem__(self, indx):
        resolved = modutil.handle_seq_indx(self, indx)
        if type(resolved) is not int:
            return resolved
        return atom(self.mdl, resolved + self.offset)


class residuelist(object):
    """Sequence-like view over residues of a model.

    ``offset`` shifts integer indices: a position ``i`` resolved by
    ``modutil.handle_seq_indx`` maps to ``residue(mdl, i + offset)``.
    ``length`` is the reported size, or None to report ``mdl.nres``.
    """

    def __init__(self, mdl, offset=0, length=None):
        self.mdl = mdl
        self.offset = offset
        self.length = length

    def __len__(self):
        if self.length is None:
            return self.mdl.nres
        return self.length

    def __getitem__(self, indx):
        item = modutil.handle_seq_indx(self, indx)
        if type(item) is int:
            item = residue(self.mdl, item + self.offset)
        return item


class segmentlist(object):
    """Sequence-like view over all segments of a model (``mdl.nsegm``)."""

    def __init__(self, mdl):
        self.mdl = mdl

    def __len__(self):
        return self.mdl.nsegm

    def __getitem__(self, indx):
        resolved = modutil.handle_seq_indx(self, indx)
        return segment(self.mdl, resolved) if type(resolved) is int \
            else resolved

class atom(object):
    """A single atom of a model, addressed by index ``num``.

    Every attribute read or write goes straight to the C-level model via the
    ``_modeller`` accessors; nothing is cached on the Python side.
    """

    def __init__(self, mdl, num):
        self.mdl = mdl
        self.num = num

    def _accessor(field, writable=True):
        # Build a property backed by _modeller.get_model_<field> and, when
        # writable, _modeller.set_model_<field>. Used only while the class
        # body executes; deleted below.
        def fget(self):
            return getattr(_modeller, 'get_model_' + field)(self.mdl.modpt,
                                                            self.num)
        if not writable:
            return property(fget)

        def fset(self, val):
            getattr(_modeller, 'set_model_' + field)(self.mdl.modpt,
                                                     self.num, val)
        return property(fget, fset)

    x = _accessor('x')
    y = _accessor('y')
    z = _accessor('z')
    biso = _accessor('biso')
    occ = _accessor('occ')
    radii = _accessor('radii')
    charge = _accessor('charge')
    name = _accessor('atmnam', writable=False)

    def __get_hetatm(self):
        # Unlike the other accessors, this one also needs the libraries.
        libs_ptr = self.mdl.env.libs.modpt
        return _modeller.get_model_is_hetatm(self.mdl.modpt, libs_ptr,
                                             self.num)

    def __get_residue(self):
        # The C layer's residue number is shifted down by one to get the
        # index passed to residue().
        owner = _modeller.get_model_iresatm(self.mdl.modpt, self.num) - 1
        return residue(self.mdl, owner)

    hetatm = property(__get_hetatm)
    residue = property(__get_residue)

    del _accessor


class residue(object):
    """A single residue of a model, addressed by index ``num``."""

    def __init__(self, mdl, num):
        self.mdl = mdl
        self.num = num

    def __get_name(self):
        ptr = self.mdl.modpt
        return _modeller.get_model_resnam(ptr, self.num)

    def __get_chain(self):
        ptr = self.mdl.modpt
        return _modeller.get_model_chain(ptr, self.num)

    def __get_atoms(self):
        # The atoms of this residue run from its first atom up to (but not
        # including) the first atom of the next residue.
        first, stop = get_residue_atom_indices(self.mdl, self.num,
                                               self.num + 1)
        return atomlist(self.mdl, first, stop - first)

    name = property(__get_name, doc="Residue name from the C layer")
    chain = property(__get_chain, doc="Chain of this residue")
    atoms = property(__get_atoms, doc="atomlist of this residue's atoms")


class segment(object):
    """A single segment of a model, addressed by index ``num``."""

    def __init__(self, mdl, num):
        self.mdl = mdl
        self.num = num

    def __get_resind(self):
        # Returns (first, stop): iress1 is shifted down by one, iress2 is
        # used as-is, so stop - first is the residue count.
        first = _modeller.get_model_iress1(self.mdl.modpt, self.num) - 1
        stop = _modeller.get_model_iress2(self.mdl.modpt, self.num)
        return first, stop

    def __get_residues(self):
        first, stop = self.__get_resind()
        return residuelist(self.mdl, first, stop - first)

    def __get_atoms(self):
        res_first, res_stop = self.__get_resind()
        atm_first, atm_stop = get_residue_atom_indices(self.mdl, res_first,
                                                       res_stop)
        return atomlist(self.mdl, atm_first, atm_stop - atm_first)

    residues = property(__get_residues, doc="residuelist of this segment")
    atoms = property(__get_atoms, doc="atomlist of this segment")


def get_residue_atom_indices(mdl, start, end):
    """Return the (first, stop) atom indices covering residues start..end-1.

    ``first`` is the first atom of residue ``start``. ``stop`` is the first
    atom of residue ``end``, or ``mdl.natm`` when ``end`` is not below
    ``mdl.nres``. So ``stop - first`` is the number of atoms covered.
    """
    def first_atom(resind):
        # _modeller.get_model_iatmr1 is shifted down by one; presumably it
        # reports 1-based atom numbers.
        return _modeller.get_model_iatmr1(mdl.modpt, resind) - 1

    first = first_atom(start)
    stop = first_atom(end) if end < mdl.nres else mdl.natm
    return (first, stop)
