from modeller.model import model
from modeller.alignment import alignment
from modeller.topology import topology
from modeller.error import ModellerError
from modeller.energy_data import energy_data
import modeller.modfile as modfile
import modeller.util.top as top
from modeller.scripts import align_strs_seq
import refine
import generate
import randomize

class automodel(model):
    """Automatically build complete model(s) using template information"""

    # Default variables
    max_ca_ca_distance = 14.0
    max_n_o_distance   = 11.0
    max_sc_mc_distance =  5.5
    max_sc_sc_distance =  5.0
    create_restraints = True
    deviation = 4.0
    toplib = '${LIB}/top_heav.lib'
    parlib = '${LIB}/par.lib'
    topology_model = None
    spline_on_site = True
    initial_malign3d = False
    final_malign3d = False
    starting_model = 1
    ending_model = 1
    write_intermediates = False
    library_schedule = 4
    pdb_ext = '.pdb'
    repeat_optimization = 1
    fit_in_refine = 'NO_FIT'
    refine_hot_only = False
    rstrs_refined = 1
    max_molpdf = 100e3
    optimize_output = 'NO_REPORT'
    max_var_iterations = 200
    trace_output = 10
    accelrys = False
    alnfile = ''
    knowns = None
    sequence = None
    inifile = ''
    csrfile = ''
    schfile = ''
    generate_method = None
    rand_method = None
    md_level = None
    assess_methods = None
    outputs = None

    def __init__(self, env, alnfile, knowns, sequence, 
                 deviation=None, library_schedule=None, toplib=None,
                 parlib=None, topology_model=None, csrfile=None, inifile=None,
                 assess_methods=None):
        model.__init__(self, env, None, None)
        self.alnfile = alnfile
        if type(knowns) is tuple:
            knowns = list(knowns)
        if type(knowns) is not list:
            knowns = [ knowns ]
        self.knowns = knowns
        self.sequence = sequence
        self.assess_methods = assess_methods
        self.set_defaults()
        libs = self.env.libs
        if libs.topology.in_memory or libs.parameters.in_memory:
          libs.topology.clear()
          libs.parameters.clear()
          print "automodel__W> Topology and/or parameter libraries already", \
                "in memory. These will\n", \
                "              be overwritten. To specify the", \
                "libraries to use for automodel,\n", \
                "              set the toplib", \
                "and/or parlib arguments in the constructor."
        if deviation:
            self.deviation = deviation
        if library_schedule:
            self.library_schedule = library_schedule
        if toplib:
            self.toplib = toplib
        if parlib:
            self.parlib = parlib
        if topology_model:
            self.topology_model = topology_model
        if csrfile:
            self.csrfile = csrfile
            self.create_restraints = False
        if inifile:
            self.inifile = inifile
            self.generate_method = generate.read_xyz

    def make(self, exit_stage=0):
        """Build all models"""

        self.outputs = []
        self.homcsr(exit_stage)
        # Exit early?
        if exit_stage >= 1:
            return
        # Read all restraints once for the whole job (except when loops are
        # done when restraints are read for each *.B9999???? model):
        self.rd_restraints()

        # getting model(s) (topology library must be in memory; ensured
        # now by one of the three GENERATE_METHOD routines):
        self.multiple_models()

        if self.final_malign3d:
            self.fit_models_on_template()

        self.write_summary(self.outputs, 'models')

    def set_defaults(self):
        """Set most default variables"""
        self.inifile = self.sequence + '.ini'
        self.csrfile = self.sequence + '.rsr'
        self.schfile = self.sequence + '.sch'
        self.generate_method = generate.transfer_xyz
        self.rand_method = randomize.xyz
        self.md_level = refine.very_fast

    def auto_align(self):
        """Create an initial alignment for fully automated comparative
           modeling. Use only when you have high template-model sequence
           identity."""
        segfile = self.alnfile
        self.alnfile = segfile + '.ali'
        align_strs_seq(self.env, segfile, self.alnfile, self.knowns,
                       self.sequence)

    def very_fast(self):
        """Call this routine before calling 'make()' if you want really fast
           optimization"""
        self.max_ca_ca_distance = 10.0
        self.max_n_o_distance   =  6.0
        self.max_sc_mc_distance =  5.0
        self.max_sc_sc_distance =  4.5
        # Note that all models will be the same if you do not change rand_method
        self.rand_method = None
        self.max_var_iterations = 50
        self.library_schedule   =  7
        self.md_level = None

    def write_summary(self, outputs, modeltyp):
        """Print out a summary of all generated models"""
        ok = []
        failed = []
        for mdl in outputs:
            if mdl['failure']:
                failed.append(mdl)
            else:
                ok.append(mdl)
        if ok:
            self.write_ok_summary(ok, modeltyp)
        if failed:
            self.write_failure_summary(failed, modeltyp)

    def write_ok_summary(self, all, modeltyp):
        """Print out a summary of all successfully generated models"""
        print
        print ">> Summary of successfully produced %s:" % modeltyp
        fields = filter(lambda a: a.endswith(' score'), all[0].keys())
        fields.sort()
        fields = ['molpdf'] + fields
        header = reduce(lambda a,b: a + ' %14s' % b, fields,
                        '%-25s' % 'Filename')
        print header
        print '-' * len(header)
        for mdl in all:
            text = '%-25s' % mdl['name']
            for field in fields:
                text = text + ' %14.5f' % mdl[field]
            print text
        print

    def write_failure_summary(self, all, modeltyp):
        """Print out a summary of all failed models"""
        print
        print ">> Summary of failed %s:" % modeltyp
        for mdl in all:
            print "%-25s %s" % (mdl['name'], mdl['failure'])
        print
      
    def rd_restraints(self):
        """Read all restraints. You can override this in subclasses to read
           additional restraints."""
        self.restraints.clear()
        self.restraints.append(file=self.csrfile)

    def multiple_models(self):
        """Build all models, given all the previously generated restraints"""
        for num in range(self.starting_model, self.ending_model + 1):
            self.single_model(num)

    def single_model(self, num):
        """Build a single optimized model from the initial model"""
        self.switch_trace(file=modfile.default(file_ext='', file_id='.D',
                                               root_name=self.sequence,
                                               id1=0, id2=num))
        # Vary the initial structure
        # Note that you are counting on some MODEL arrays not being deleted by
        # read() (ie the charge(1:natm) array generated by generate_topology())
        self.read(file=self.inifile)
        self.select_atoms()
        if self.rand_method:
            self.rand_method(self)

        self.schedule.make(library_schedule=self.library_schedule)
        self.schedule.write(file=self.schfile)
        self.write_int(0, num)
        filename = modfile.default(file_id='.B', file_ext=self.pdb_ext,
                                   root_name=self.sequence, id1=9999, id2=num)
        out = {'name':filename, 'failure':None}
        try:
            for irepeat in range(0, self.repeat_optimization):
                self.single_model_pass(num)
            self.to_iupac()
        except ModellerError, detail:
            if len(str(detail)) > 0:
                out['failure'] = detail
            else:
                out['failure'] = 'Optimization failed'
        else:
            self.model_analysis(filename, out, num)
        self.outputs.append(out)

    def model_analysis(self, filename, out, num):
        """Energy evaluation and assessment, and write out the model"""
        if self.accelrys:
            for (id, norm) in (('.E', False), ('.NE', True)):
                self.energy(output='LONG ENERGY_PROFILE',
                            normalize_profile=norm,
                            file=modfile.default(file_id=id, file_ext='',
                                                 root_name=self.sequence,
                                                 id1=9999, id2=num))
            # The new request from Lisa/Azat to print out only
            # stereochemical restraint violations (6/24/03):
            # select only stereochemical restraints (maybe add dihedral
            # angles?):
            for (id, norm) in (('.ES', False), ('.NES', True)):
                e = self.energy(output='ENERGY_PROFILE',
                                normalize_profile=norm,
                                schedule_scale=(1, 1, 1, 1, 1, 0, 0, 0, 0,
                                                0, 0, 0, 0, 0, 0, 0, 0, 1,
                                                1, 1, 0, 0, 0, 0, 0, 0, 0,
                                                0, 0, 0, 0, 0, 0),
                                file=modfile.default(file_id=id,
                                                     file_ext='',
                                                     root_name=self.sequence,
                                                     id1=9999, id2=num))
        else:
            e = self.energy(output='LONG VIOLATIONS_PROFILE',
                            file=modfile.default(file_id='.V', file_ext='',
                                                 root_name=self.sequence,
                                                 id1=9999, id2=num))

        out['molpdf'] = e
        self.user_after_single_model()

        # Write the final model
        self.write(file=filename)

        # Do model assessment if requested
        self.assess(self.assess_methods, out)

    def assess(self, methods, out=None):
        """Assess the model using all given methods"""
        assess_list = methods
        if assess_list:
            if not isinstance(assess_list, (tuple, list)):
                assess_list = [ assess_list ]
            for method in assess_list:
                (key,value) = method(self)
                if out:
                    out[key] = value

    def single_model_pass(self, num):
        """Perform a single pass of model optimization"""
        for step in range(0, len(self.schedule)):
            self.schedule.step = step + 1
            self.restraints.unpick_all()
            self.restraints.pick(residue_span_range=(-999, -999))
            molpdf = self.optimize(output=self.optimize_output,
                                   max_iterations=self.max_var_iterations,
                                   residue_span_range=(-999, -999),
                                   trace_output=self.trace_output)
            self.write_int(step + 1, num)
            if molpdf > self.max_molpdf:
                raise ModellerError, \
                      "Obj. func. (%.3f) exceeded max_molpdf (%.3f) " \
                                   % (molpdf, self.max_molpdf)
        self.schedule.step = len(self.schedule)
        self.refine()


    def write_int(self, id1, id2):
        """Write intermediate model file during optimization, if so requested"""
        if self.write_intermediates:
            self.write(file=modfile.default(file_ext=self.pdb_ext, file_id='.B',
                                            root_name=self.sequence, id1=id1,
                                            id2=id2))

    def read_alignment(self, aln=None):
        """Read the template-sequence alignment needed for modeling"""
        if aln is None:
            aln = alignment(self.env)
        aln.clear()
        aln.append(file=self.alnfile, align_codes=self.knowns+[self.sequence])
        return aln

    def homcsr(self, exit_stage):
        """Construct the initial model and restraints"""
        # Check the alignment
        aln = self.read_alignment()

        # Since in general we do not want to loose the original alignment file
        # (which is usually not a temporary scratch file):
        if self.accelrys:
            # Accelrys code here (Azat, you may want to add the .tmp.ali part
            # so that you do not change the input alignment file, unless you
            # want to have it changed here for some other use elsewhere):
            aln.write(file=self.alnfile) # file='.tmp.ali'
            codes = [seq.code for seq in aln]
            aln.read(file=self.alnfile, align_codes=codes)
            # modfile.delete(file='.tmp.ali')

        aln.check()

        # make topology and build/read the atom coordinates:
        self.generate_method(self, aln)

        # exit early?
        if exit_stage == 2:
            return

        # make and write the stereochemical, homology, and special restraints?
        if self.create_restraints:
            self.mkhomcsr(aln)
            self.restraints.condense()
            self.restraints.write(file=self.csrfile,
                                  restraints_format='MODELLER')


    def mkhomcsr(self, aln):
        """Construct typical comparative modeling restraints"""
        rsr = self.restraints
        rsr.clear()
        rsr.make(aln=aln, restraint_type='stereo',
                 spline_on_site=self.spline_on_site,
                 residue_span_range=(0, 99999),
                 distance_rsr_model=5, restraint_group=26)

        rsr.make(aln=aln, restraint_type='phi-psi_binormal',
                 spline_on_site=self.spline_on_site,
                 residue_span_range=(0, 99999), distance_rsr_model=5,
                 restraint_group=26)

        for type in ['omega', 'chi1', 'chi2', 'chi3', 'chi4']:
            rsr.make(aln=aln, restraint_type=type+'_dihedral', spline_range=4.0,
                     spline_dx=0.3, spline_min_points=5,
                     spline_on_site=self.spline_on_site,
                     residue_span_range=(0, 99999), distance_rsr_model=5,
                     restraint_group=26)

        
        # Only do the standard residue types for CA, N, O, MNCH, SDCH dst rsrs
        # (no HET or BLK residue types):
        for (dmodel, maxdis, rsrrng, rsrsgn, rsrgrp, typ2, typ3, stdev) in \
            ((5, self.max_ca_ca_distance, (2, 99999), True, 9, 'CA', 'CA',
              (0, 1.0)),
             (6, self.max_n_o_distance,   (2, 99999), False, 10, 'N', 'O',
              (0, 1.0)),
             (6, self.max_sc_mc_distance, (1, 2), False, 23, 'SDCH', 'MNCH',
              (0.5, 1.5)),
             (6, self.max_sc_sc_distance, (2, 99999), True, 26, 'SDCH', 'SDCH',
              (0.5, 2.0))):
            for (set, attyp) in ((2, typ2), (3, typ3)):
                self.pick_atoms(aln, pick_atoms_set=set, atom_types=attyp,
                                res_types='STD')
            rsr.make(aln=aln, restraint_type='distance',
                     spline_on_site=self.spline_on_site,
                     distance_rsr_model=dmodel, restraint_group=rsrgrp,
                     maximal_distance=maxdis, residue_span_range=rsrrng,
                     residue_span_sign=rsrsgn, restraint_stdev=stdev,
                     spline_range=4.0, spline_dx=0.7, spline_min_points=5) 
        
        # Generate intra-HETATM and HETATM-protein restraints:
        self.hetatm_restraints(aln)

        # Generate intra-BLK and BLK-protein restraints:
        self.blk_restraints(aln)

        # Special restraints have to be called last so that possible cis-proline
        # changes are reflected in the current restraints:
        self.special_restraints(aln)


    def hetatm_restraints(self, aln):
      """Generate intra-HETATM and HETATM-protein restraints"""
      # Note: there are going to be duplicated HETATM-HETATM restraints:
      for (set, rest) in ((2, 'ALL'), (3, 'HET')):
          self.pick_atoms(aln, pick_atoms_set=set, atom_types='ALL',
                          res_types=rest)

      # Select MODEL=7 where you can specify stand. dev. explicitly
      self.restraints.make(aln=aln, restraint_type='distance',
                           distance_rsr_model=7, maximal_distance=7.0,
                           restraint_group=27,
                           spline_on_site=self.spline_on_site,
                           residue_span_range=(0, 99999),
                           residue_span_sign=False,
                           # Inter- and intra- residue:
                           restraint_stdev=(0.2, 0.0))


    def blk_restraints(self, aln):
        """Re-define the auxillary restraints routine to include the restraints
           between the protein and the BLK residues. The BLK atoms will be
           restrained by their distances to the protein CA atoms that are within
           MAXIMAL_DISTANCE angstroms of the selected BLK atoms in the
           templates. Note: this only works because the BLK atoms have unique
           atom names."""
        # To derive restraints from all (one) templates, comment out
        # (uncomment):
        # READ_ALIGNMENT FILE = ALNFILE, ALIGN_CODES = '3b5c' SEQUENCE
        rsr = self.restraints

        # Select MODEL=7 where you can specify stand. dev. explicitly
        rsrtyp = 'distance'
        drmodel = 7
        maxdis = 10.0
        rsrgrp = 27

        # Intra-residue:
        for set in (2, 3):
            self.pick_atoms(aln, pick_atoms_set=set, atom_types='ALL',
                            res_types='BLK')
        rsr.make(aln=aln, restraint_type=rsrtyp, distance_rsr_model=drmodel,
                 maximal_distance=maxdis, spline_on_site=self.spline_on_site,
                 restraint_group=rsrgrp, restraint_stdev=(0.05, 0.0),
                 residue_span_range=(0, 0), residue_span_sign=True)

        # Inter-residue:
        # There may be some duplicated CA BLK - CA BLK restraints:
        for (set, atype, rtype) in ((2, 'CA', 'ALL'), (3, 'ALL', 'BLK')):
            self.pick_atoms(aln, pick_atoms_set=set, atom_types=atype,
                            res_types=rtype)
        rsr.make(aln=aln, restraint_type=rsrtyp, distance_rsr_model=drmodel,
                 maximal_distance=maxdis, spline_on_site=self.spline_on_site,
                 restraint_group=rsrgrp, restraint_stdev=(0.2, 0.0),
                 residue_span_range=(1, 99999), residue_span_sign=False)

    def special_restraints(self, aln):
         """This can be redefined by the user to add special restraints.
            In this class, it does nothing."""
         pass

    def special_patches(self, aln):
         """This can be redefined by the user to add additional patches
            (for example, user-defined disulfides). In this class, it
            does nothing."""
         pass

    def read_top_par(self):
        """Read in the topology and parameter libraries"""
        libs = self.env.libs
        if not (libs.topology.in_memory and libs.parameters.in_memory):
            libs.topology.read(file=self.toplib)
            if self.topology_model is not None:
                libs.topology.submodel = self.topology_model
            libs.parameters.read(file=self.parlib)

    def create_topology(self, aln, sequence=None):
        """Build the topology for this model"""
        if sequence is None:
            sequence = self.sequence
        self.generate_topology(aln, sequence=sequence, add_segment=False)
        self.default_patches(aln)
        self.special_patches(aln)

    def default_patches(self, aln):
        """Derive the presence of disulfides from template structures (you can
           still define additional disulfides in the special_patches routine)"""
        self.patch_ss_templates(aln)

    def select_atoms(self):
        """Select atoms to be optimized in the model building procedure. By
           default, this selects all atoms, but you can redefine this routine
           to select a subset instead."""
        self.pick_atoms(selection_segment=('@:@', 'X:X'),
                        selection_search='segment', pick_atoms_set=1,
                        res_types='all', atom_types='all', selection_from='all',
                        selection_status='initialize')

    def initial_refine_hot(self):
        """Do some initial refinement of hotspots in the model"""
        viol_rc = [999] * 31
        if self.rstrs_defined == 0:
            # Refine only hotspots that have badly violated stereochemical
            # restraints:
            viol_rc = [4, 4, 4, 4, 4, 999, 999, 999, 999, 999, 999, 999,
                       999, 999, 999, 999, 999, 4, 4, 4, 999, 999, 999,
                       999, 999, 999, 999, 999, 999, 999, 999]
        elif self.rstrs_defined == 1:
            # Refine hotspots that have badly violated stereochemical
            # restraints and the important homology-derived restraints:
            viol_rc = [4, 4, 4, 4, 4, 999, 999, 999, 4, 4, 999, 999, 4,
                       999, 999, 999, 999, 4, 4, 4, 999, 999, 4, 999,
                       4, 4, 999, 999, 999, 999]
        elif self.rstrs_defined == 2:
            # Refine hotspots that have badly violated any kind of
            # restraints
            viol_rc = [4] * 31

        # Pick hot atoms (must use RESIDUE mode because of sidechains):
        self.pick_hot_atoms(pick_hot_cutoff=4.5,
                            selection_mode='RESIDUE',
                            extend_hot_spot=0, viol_report_cut=viol_rc)
        # Pick all corresponding (violated and others) restraints:
        self.restraints.unpick_all()
        self.restraints.pick(residue_span_range=(-999, -999))

        # Local optimization to prevent MD explosions:
        self.optimize(optimization_method=1,
                      max_iterations=100, output=self.optimize_output,
                      residue_span_range=(-999, -999),
                      trace_output=self.trace_output)

    def final_refine_hot(self):
        """Do some final refinement of hotspots in the model"""
        # Get conjugate gradients refined hot spots:
        self.optimize(optimization_method=1,
                      max_iterations=200, output=self.optimize_output,
                      residue_span_range=(-999, -999),
                      trace_output=self.trace_output)

        # Get all static restraints again and select all atoms
        self.select_atoms()
        self.restraints.unpick_all()
        self.restraints.pick(residue_span_range=(-999, -999))

    def refine(self):
        """Refine the optimized model with MD and CG"""
        # Save the current model:
        if self.fit_in_refine != 'NO_FIT':
            self.write(file='TO_BE_REFINED.TMP')

        # Possibly skip selecting hot atoms only and optimize all atoms:
        if self.refine_hot_only:
            self.initial_refine_hot()

        # Do simulated annealing MD:
        if self.md_level:
            self.md_level(self)

        # Possibly skip 'HOT CG' after MD:
        if self.refine_hot_only:
            self.final_refine_hot()

        # Get a final conjugate gradients refined structure:
        self.optimize(optimization_method=1,
                      max_iterations=200, output=self.optimize_output,
                      residue_span_range=(-999, -999),
                      trace_output=self.trace_output)

        # Evaluate gross changes between the initial and final refined model:
        if 'NO_FIT' not in self.fit_in_refine:
            aln = alignment(self.env)
            mdl2 = read_model(file='TO_BE_REFINED.TMP')
            self.pick_atoms(atom_types='CA', pick_atoms_set=1)
            self.superpose(mdl2, aln)
            self.pick_atoms(atom_types='ALL', pick_atoms_set=1)
            self.superpose(mdl2, aln)
            self.select_atoms()
            modfile.delete(self.env.output_directory + '/TO_BE_REFINED.TMP')

    def user_after_single_model(self):
        """Used for any user analysis after building each model. Redefine as you
           see fit."""
        pass

    def cluster(self, cluster_cut=1.5):
        """Cluster all output models, and output an optimized cluster average"""
        self.read(file=self.inifile)
        aln = alignment(self.env)
        aln.append_model(mdl=self, align_codes=self.sequence)
        aln.expand(root_name=self.sequence, file_id='.B', file_ext='.pdb',
                   expand_control=(9999, 9999, self.starting_model,
                                   self.ending_model, 0))
        aln.malign3d(gap_penalties_3d=(0, 3), fit=False)
        aln.append_model(mdl=self, align_codes='cluster',
                         atom_files='cluster.opt')
        self.transfer_xyz(aln, cluster_cut=cluster_cut)
        self.write(file='cluster.ini')
        self.read_top_par()
	self.rd_restraints()
        self.create_topology(aln, sequence='cluster')
        self.select_atoms()
        self.restraints.unpick_all()
        self.restraints.pick()
        self.restraints.condense()
        edat = energy_data(copy=self.env.edat)
        edat.nonbonded_sel_atoms = 1
        self.energy(output='LONG', edat=edat)
        self.switch_trace(file='cluster.deb')
        self.optimize(trace_output=5, max_iterations=self.max_var_iterations)
        self.energy()
        self.write(file='cluster.opt')
        aln.compare_structures(fit=True)

    def fit_models_on_template(self):
        """Superpose each of the generated models on the templates"""
        aln = self.read_alignment()
        aln.expand(expand_control=(9999, 9999, self.starting_model,
                                   self.ending_model, 0),
                   root_name=self.sequence, file_ext=self.pdb_ext, file_id='.B')
        # To take care of the '.' in segment specs:
        aln.write(file='.tmp.ali', alignment_format='PIR')
        codes = [seq.code for seq in aln]
        aln.read(file='.tmp.ali', alignment_format='PIR', align_codes=codes)
        modfile.delete('.tmp.ali')

        aln.compare_structures(fit=True, output='SHORT', fit_atoms='CA')
        aln.malign3d(gap_penalties_3d=(0, 3), write_whole_pdb=False,
                     write_fit=True, fit=False, fit_atoms='CA',
                     current_directory=True,
                     edit_file_ext=(self.pdb_ext, '_fit.pdb'),
                     output='NO_REPORT')
