from automodel import automodel
from modeller.util.modobject import modobject
from modeller import *
import refine

class loopmodel(automodel):
    """Optionally build comparative models, and then refine the loops"""
    # loop_data object holding loop-refinement settings and results (set in
    # __init__)
    loop = None
    # Optional initial model file; when set, make() refines its loops directly
    # instead of running the comparative-modeling protocol
    inimodel = None

    def __init__(self, env, sequence, alnfile=None, knowns=None, inimodel=None,
                 deviation=None, library_schedule=None, toplib=None,
                 parlib=None, topology_model=None, csrfile=None, inifile=None,
                 assess_methods=None, loop_assess_methods=None):
        """Set up comparative modeling followed by loop refinement.

        All arguments except `inimodel` and `loop_assess_methods` are passed
        to automodel.__init__ (positionally, in automodel's order, which has
        `sequence` after `knowns`). `inimodel` is an optional starting model
        file whose loops make() refines directly. `loop_assess_methods` is
        passed to assess() for every refined loop model.
        """
        automodel.__init__(self, env, alnfile, knowns, sequence, deviation,
                           library_schedule, toplib, parlib, topology_model,
                           csrfile, inifile, assess_methods)
        self.inimodel = inimodel
        self.loop = loop_data(env)
        self.loop.assess_methods = loop_assess_methods

    def make(self, exit_stage=0):
        """Build models and/or refine loops, then summarize the loop models.

        If an initial model was given, only its loops are refined (as model
        number 1) using the loop-specific environment. Otherwise the normal
        automodel protocol is run with `exit_stage`. In both cases a summary
        of self.loop.outputs is written at the end.
        """
        if self.inimodel:
            envcopy = self.env
            self.env = self.loop.env
            try:
                self.build_seq(self.inimodel, 1)
            finally:
                # Restore the caller's environment, as multiple_models() does
                self.env = envcopy
        else:
            automodel.make(self, exit_stage)
        self.write_summary(self.loop.outputs, 'loop models')

    def multiple_models(self):
        """Build the comparative models, then refine the loops of each one.

        Each model file '<sequence>.B9999<num><pdb_ext>' is loop-refined with
        build_seq() under the loop-specific environment (self.loop.env). The
        original environment is restored afterwards, even if refinement
        raises.
        """
        automodel.multiple_models(self)
        envcopy = self.env
        self.env = self.loop.env
        try:
            for num in range(self.starting_model, self.ending_model + 1):
                filename = modfile.default(root_name=self.sequence,
                                           file_id='.B', id1=9999, id2=num,
                                           file_ext=self.pdb_ext)
                self.build_seq(filename, num)
        finally:
            self.env = envcopy

    def fit_models_on_template(self):
        """Superpose every loop model on its fitted comparative model.

        After the automodel fitting step, each loop model built by
        build_seq() is superposed on the corresponding '_fit.pdb'
        comparative model using CA atoms. The result is written as
        '<sequence>.BL<id1><num>_fit.pdb'.
        """
        automodel.fit_models_on_template(self)
        for num in range(self.starting_model, self.ending_model + 1):
            filename = modfile.default(root_name=self.sequence, file_id='.B',
                                       id1=9999, id2=num, file_ext='_fit.pdb')
            mdl = model(self.env, file=filename)
            # Both entries are the same model: the loop models share its
            # sequence, so a self-alignment gives the residue correspondence.
            aln = alignment(self.env)
            aln.append_model(mdl, align_codes=self.sequence)
            aln.append_model(mdl, align_codes=self.sequence)
            for id1 in range(self.loop.starting_model,
                             self.loop.ending_model + 1):
                # Use the same extension build_seq() used when writing the
                # loop model, so the file that was written is the one read.
                filename = modfile.default(root_name=self.sequence,
                                           file_id='.BL', id1=id1, id2=num,
                                           file_ext=self.pdb_ext)
                mdl2 = model(self.env, file=filename)
                mdl.pick_atoms(pick_atoms_set=1, atom_types='CA')
                mdl.superpose(mdl2, aln, fit=True)
                filename = modfile.default(root_name=self.sequence,
                                           file_id='.BL', id1=id1, id2=num,
                                           file_ext='_fit.pdb')
                mdl2.write(file=filename)

    def build_seq(self, filename, num):
        """Refine the loops of the model in `filename`, producing model `num`.

        This method:
        - sets up the loop restraints (written to '<sequence>.lrsr');
        - prepares an initial loop structure (written as
          '<sequence>.IL0000<num><pdb_ext>');
        - for each loop model id1 from self.loop.starting_model to
          self.loop.ending_model, randomizes, optimizes and assesses the
          loop, writes it to '<sequence>.BL<id1><num><pdb_ext>' and appends
          a result dict to self.loop.outputs.

        self.group_restraints is temporarily replaced and is always restored,
        even if an error occurs.
        """
        self.csrfile = self.sequence + '.lrsr'
        self.read_top_par()
        gprsr = group_restraints(classes='$(LIB)/atmcls-melo.lib',
                                 parameters='$(LIB)/melo1-dist.lib')
        oldgprsr = self.group_restraints
        # No group restraints while reading the input model; the loop-specific
        # set is installed right after.
        self.group_restraints = None
        try:
            self.read(file=filename)
            self.group_restraints = gprsr

            # Two copies of the same model: a trivial self-alignment
            aln = alignment(self.env)
            for _ in range(2):
                aln.append_model(self, align_codes=self.sequence,
                                 atom_files=filename)
            self.generate_topology(aln, sequence=self.sequence)

            # Save self.seq_id, since otherwise transfer_xyz will set it to
            # 100%
            seq_id = self.seq_id
            self.transfer_xyz(aln)
            self.seq_id = seq_id

            self.res_num_from(model(self.env, file=filename), aln)
            self.special_patches(aln)
            self.select_loop_atoms()
            self.loop_restraints(aln)

            # Select corresponding restraints only:
            # only necessary to eliminate inefficiencies in
            # 'special_restraints' because MAKE_RSRS works with selected
            # atoms now:
            self.restraints.unpick_all()
            self.restraints.pick()
            self.restraints.condense()
            self.restraints.write(file=self.csrfile)

            # Calculate energy for the original (raw) loop:
            self.env.edat.nonbonded_sel_atoms = 1
            self.energy()

            # Prepare the starting structure (comment it out if
            # the input PDB file is a better initial structure):
            self.build_ini_loop()

            ini_model = "%s.IL%04d%04d%s" % (self.sequence, 0, num,
                                             self.pdb_ext)
            self.write(file=ini_model)

            self.schedule.make(library_schedule=6)
            for id1 in range(self.loop.starting_model,
                             self.loop.ending_model + 1):
                # Every loop model starts from the same initial structure,
                # with the selected loop atoms randomized.
                self.read(file=ini_model)
                self.select_loop_atoms()
                self.randomize_xyz(deviation=5.0)
                swfile = modfile.default(root_name=self.sequence,
                                         file_id='.DL', id1=id1, id2=num)
                self.switch_trace(file=swfile)

                outname = modfile.default(root_name=self.sequence,
                                          file_id='.BL', id1=id1, id2=num,
                                          file_ext=self.pdb_ext)
                out = {'name':outname, 'failure':None}

                # Refine without the rest of the protein:
                self.env.edat.nonbonded_sel_atoms = 2
                self.optimize_loop()
                # Refine in the context of the rest of the protein:
                self.env.edat.nonbonded_sel_atoms = 1
                self.optimize_loop()

                out['molpdf'] = self.energy()

                self.to_iupac()
                self.write(file=outname)
                # Do model assessment if requested
                self.assess(self.loop.assess_methods, out)

                self.loop.outputs.append(out)
        finally:
            self.group_restraints = oldgprsr

    def build_ini_loop(self):
        """Rebuild an initial loop conformation.

        Calls unbuild() and then rebuilds coordinates with the
        '3D_INTERPOLATION' method (initialize_xyz=False). build_seq() calls
        this after select_loop_atoms(), so presumably only the selected loop
        atoms are affected -- verify against the model API.
        """
        self.unbuild()
        self.build(build_method='3D_INTERPOLATION', initialize_xyz=False)

    def loop_restraints(self, aln):
        """Replace the current restraints with loop refinement restraints.

        Clears all restraints, then adds stereochemical, dihedral and
        binormal phi/psi restraints, followed by omega and chi1-chi4
        dihedral restraints (with explicit spline settings). All are built
        from library data only (dih_lib_only=True). Finally, user-defined
        special_restraints() are added.
        """
        dih_lib_only = True
        mnch_lib = 1
        # NOTE(review): presumably restricts restraints to the selected
        # (loop) atoms -- confirm against restraints.make documentation
        res_at = 1
        self.restraints.clear()
        for typ in ('bond', 'angle', 'improper', 'dihedral',
                    'phi-psi_binormal'):
            self.restraints.make(aln=aln, restraint_type=typ,
                                 spline_on_site=self.spline_on_site,
                                 dih_lib_only=dih_lib_only,
                                 mnch_lib=mnch_lib, restraint_sel_atoms=res_at)
        for typ in ('omega', 'chi1', 'chi2', 'chi3', 'chi4'):
            self.restraints.make(aln=aln, restraint_type=typ+'_dihedral',
                                 spline_on_site=self.spline_on_site,
                                 dih_lib_only=dih_lib_only, mnch_lib=mnch_lib,
                                 restraint_sel_atoms=res_at, spline_range=4.0,
                                 spline_dx=0.3, spline_min_points=5)
        self.special_restraints(aln)

    def optimize_loop(self):
        """Optimize the selected loop atoms.

        Runs schedule steps 1-5 with up to 200 iterations each. If
        self.loop.md_level is set, it is then called with this model. A
        final optimization of up to 1000 iterations at schedule step 5
        uses a tighter min_atom_shift.
        """
        for step in range(1,6):
            self.schedule.step = step
            self.optimize(optimization_method=1, max_iterations=200,
                          output=self.optimize_output,
                          min_atom_shift=0.001, trace_output=self.trace_output)
        if self.loop.md_level:
            self.loop.md_level(self)
        self.schedule.step = 5
        self.optimize(optimization_method=1, max_iterations=1000,
                      output=self.optimize_output,
                      min_atom_shift=0.00001, trace_output=self.trace_output)

    def select_loop_atoms(self):
        """The default loop atom selection routine. This selects all atoms near
           gaps in the alignment. You can redefine this routine to select a
           different region, and in fact this is necessary if you are refining a
           PDB file, as no alignment is available in this case.

           Raises ModellerError if no template alignment (knowns) is set."""
        if self.knowns is None:
            # Call form works on both Python 2 and Python 3
            raise ModellerError(
                  "No alignment: you must redefine select_loop_atoms")
        aln = self.read_alignment()
        self.pick_atoms(aln, selection_segment=('LOOPS',''), atom_types='ALL',
                        res_types='STD', selection_status='INITIALIZE',
                        selection_mode='ATOM', selection_search='SEGMENT',
                        selection_from='ALL')


class loop_data(modobject):
    """Settings and accumulated results for loop refinement.

    An instance is created by loopmodel.__init__. It holds:
    - the range of loop models built per comparative model;
    - the routine optionally run for molecular-dynamics refinement during
      loop optimization;
    - a private, loop-tuned copy of the modeling environment;
    - the list of result dictionaries, one per refined loop model;
    - the assessment methods applied to each loop model.
    """
    # Loop model numbers run from starting_model to ending_model (inclusive)
    starting_model = 1
    ending_model = 1
    # Callable invoked as md_level(model) inside loopmodel.optimize_loop;
    # a false value skips it (defaults to refine.slow in __init__)
    md_level = None
    # Private environment copy used for loop refinement
    env = None
    # One dict per refined loop model ('name', 'failure', 'molpdf', ...)
    outputs = None
    # Assessment methods passed to loopmodel.assess for each loop model
    assess_methods = None

    def __init__(self, env):
        """Copy `env` and tune its energy settings for loop refinement."""
        self.md_level = refine.slow
        self.env = env.copy()
        energy_settings = self.env.edat
        for setting, value in (('contact_shell', 7.0),
                               ('dynamic_modeller', True),
                               ('dynamic_sphere', True)):
            setattr(energy_settings, setting, value)
        self.outputs = []
