import sys
from modeller import modfile
from modeller.util.logger import log
from modeller.top_interpreter import util
from modeller.top_interpreter.topcmds import topcmds
import _modeller

class commands(topcmds):
    """TOP-language front end to ``topcmds``.

    Most public methods below are named after a TOP command.  Each one
    prepares the interpreter variables held in ``self.vars`` -- resolving
    'DEFAULT' file names via :meth:`default_file`, prepending
    |output_directory| via :meth:`add_outdir`, clearing existing data unless
    the relevant ``add_*`` flag is set, updating the schedule -- and then
    delegates to the ``topcmds`` method of the same name.
    """

    def __init__(self):
        # NOTE(review): topcmds.__init__ is not called here; confirm the base
        # class needs no initialization of its own.
        if hasattr(sys, 'dllhandle'):
            # sys.dllhandle only exists on Windows builds of Python
            self.__dirsep = '\\'
        else:
            self.__dirsep = '/'

    def default_file(self, varname, deffile_id, exts=None):
        """Resolve the file name stored in ``self.vars[varname]`` in place.

        If the name contains 'DEFAULT' (case-insensitive), it is rebuilt with
        ``modfile.default`` from the |root_name|, |file_id|, |id1|, |id2| and
        |file_ext| variables.  A |file_id| that itself contains 'DEFAULT' is
        replaced by `deffile_id`.  If `exts` is given, the name is then
        expanded with ``_modeller.fullfn`` against |directory|.

        :param varname: key of the file-name variable in ``self.vars``.
        :param deffile_id: file id to use when |file_id| is 'DEFAULT'.
        :param exts: extension spec passed to ``_modeller.fullfn``
            (e.g. ':.rsr'), or None to skip that expansion.
        """
        filename = self.vars[varname]
        if 'DEFAULT' in filename.upper():
            root_name = self.vars['root_name']
            file_id = self.vars['file_id']
            id1 = self.vars['id1']
            id2 = self.vars['id2']
            file_ext = self.vars['file_ext']
            if 'DEFAULT' in file_id.upper():
                file_id = deffile_id
            filename = modfile.default(root_name, file_id, id1, id2, file_ext)
            self.vars[varname] = filename
        if exts is not None:
            filename = _modeller.fullfn(filename, self.vars['directory'], exts)
            self.vars[varname] = filename

    def add_outdir(self, file):
        """Prepend |output_directory| to the file name ``self.vars[file]``.

        A separator is inserted only when the output directory is non-empty,
        does not already end with one, and the file name does not start with
        one.

        :param file: key of the file-name variable in ``self.vars`` (not the
            file name itself).
        :return: the combined path; ``self.vars`` is left unchanged.
        """
        filename = self.vars[file]
        outdir = self.vars['output_directory']
        sep = self.__dirsep
        if not outdir or outdir.endswith(sep) or filename.startswith(sep):
            return outdir + filename
        return outdir + sep + filename

    def __pickall(self):
        """Create the implicit 'all atoms' selection (set 1) for MODEL"""
        self.vars['sel1'] = _modeller.select_all(self.get_mdl(1))

    def __get_sched(self):
        """Return the schedule object of model 1 (via model_sched_get)."""
        return _modeller.model_sched_get(self.get_mdl(1))

    def __check_schedule(self):
        """Run _modeller.check_schedule on the current schedule."""
        _modeller.check_schedule(self.__get_sched())

    def __get_optimization_method(self):
        """Return |optimization_method|, falling back to the schedule's
        optimizer when the variable is -999 (presumably 'not set')."""
        ometh = self.vars['optimization_method']
        if ometh == -999:
            ometh = _modeller.get_schedule_optimizer(self.__get_sched())
        return ometh

    def __update_schedule(self):
        """Update the schedule with |residue_span_range|.

        :return: the value from _modeller.update_schedule, which callers pass
            on to topcmds as ``residue_span_range``.
        """
        return _modeller.update_schedule(self.__get_sched(),
                                         self.vars['residue_span_range'])

    def __process_seg(self, res):
        """Look for a given code in the alignment.

        A non-empty `res` that contains no ':' and does not start with '!'
        is treated as an alignment code.  The return value is then
        ``(aln, iseq)`` from _modeller.find_alignment_code, and an error is
        reported via log.error if iseq < 0.  Otherwise the return value is
        ``(None, None)``.
        """
        if res and ':' not in res and not res.startswith('!'):
            aln = self.get_aln()
            iseq = _modeller.find_alignment_code(aln, res)
            if iseq < 0:
                log.error('getchnrng',
                          "There is no such CODE in the alignment: %s" % res)
            return (aln, iseq)
        return (None, None)

    def __seg_from_aln(self, model_segment):
        """Possibly update file and model_segment with alignment information.

        :param model_segment: key in ``self.vars`` of a two-entry segment
            (presumably start/end).  If entry [1] is an alignment code, that
            sequence's atom file replaces |file|.  If entry [0] is an
            alignment code, the segment is replaced by that sequence's
            sequence_rng_get values.
        :return: ``(file, segment)``; ``self.vars`` is not modified.
        """
        filename = self.vars['file']
        segment = self.vars[model_segment]
        # Get PDB name from alignment?
        (aln, iseq) = self.__process_seg(segment[1])
        if aln:
            filename = _modeller.alignment_atom_files_get(aln, iseq)
        # Get segment range from alignment?
        (aln, iseq) = self.__process_seg(segment[0])
        if aln:
            seq = _modeller.alignment_sequence_get(aln, iseq)
            segment = [_modeller.sequence_rng_get(seq, i) for i in (0, 1)]
        return filename, segment

    def __parse_residue_ids(self, mdl, resids):
        """Given string residue IDs, return integers.

        Any ID that _modeller.iresind maps to a value <= 0 is reported via
        log.error.
        """
        num_resids = [_modeller.iresind(r, mdl) for r in resids]
        for resid, num in zip(resids, num_resids):
            if num <= 0:
                log.error("iresind",
                          "Residue identifier not found: %s" % resid)
        return num_resids

    def __make_sstruc_restraints(self, modfunc):
        """Make alpha or strand restraints.

        :param modfunc: _modeller function called as
            ``modfunc(mdl, resids, libs)``.
        """
        mdl = self.get_mdl()
        libs = self.get_libs()
        resids = self.__parse_residue_ids(mdl, self.vars['residue_ids'])
        modfunc(mdl, resids, libs)

    def __make_sheet_restraints(self):
        """Make sheet restraints from |atom_ids| and |sheet_h_bonds|."""
        mdl = self.get_mdl()
        atids = self.vars['atom_ids']
        num_atids = [_modeller.indxatm2(a, mdl) for a in atids]
        for atid, num in zip(atids, num_atids):
            if num <= 0:
                log.error("indxatm2",
                          "Atom identifier not found: %s" % atid)
        _modeller.make_sheet_restraints(mdl, num_atids,
                                        self.vars['sheet_h_bonds'])

    def write_model(self):
        """Resolve |file| (id '.B') under |output_directory|, then write."""
        self.default_file('file', '.B')
        topcmds.write_model(self, file=self.add_outdir('file'))

    def write_model2(self):
        """Resolve |file| (id '.B') under |output_directory|, then write."""
        self.default_file('file', '.B')
        topcmds.write_model2(self, file=self.add_outdir('file'))

    def build_model(self):
        """Delegate to topcmds, then re-create the all-atom selection."""
        topcmds.build_model(self)
        self.__pickall()

    def read_restraints(self):
        """Clear restraints unless |add_restraints|, resolve |file|, read."""
        if not self.vars['add_restraints']:
            _modeller.clear_restraints(self.get_mdl())
        self.default_file('file', '.C', ':.rsr')
        topcmds.read_restraints(self)

    def write_restraints(self):
        """Resolve |file| (id '.C') under |output_directory|, then write."""
        self.default_file('file', '.C')
        topcmds.write_restraints(self, file=self.add_outdir('file'))

    def read_model(self):
        """Resolve |file|, apply alignment-derived file/segment, then read."""
        # NOTE(review): file id 'B' here vs '.B' in read_model2/write_model;
        # confirm this difference is intended.
        self.default_file('file', 'B')
        (filename, model_segment) = self.__seg_from_aln('model_segment')
        topcmds.read_model(self, file=filename, model_segment=model_segment)
        self.__pickall()

    def read_model2(self):
        """Resolve |file|, apply alignment-derived file/segment, then read."""
        self.default_file('file', '.B')
        (filename, model2_segment) = self.__seg_from_aln('model2_segment')
        topcmds.read_model2(self, file=filename,
                            model2_segment=model2_segment)

    def segment_matching(self):
        """Delegate with |root_name| prefixed by |output_directory|."""
        # Was self.outdir(...), which does not exist; add_outdir is the
        # helper used by every other command for this purpose.
        topcmds.segment_matching(self,
                                 root_name=self.add_outdir('root_name'))

    def switch_trace(self):
        """Resolve |file| (id '.D'), then delegate."""
        self.default_file('file', '.D')
        topcmds.switch_trace(self)

    def read_schedule(self):
        """Resolve |file| (id 'S', extension ':.sch'), then read."""
        self.default_file('file', 'S', ':.sch')
        topcmds.read_schedule(self)

    def write_schedule(self):
        """Resolve |file| (id '.S'), then write."""
        self.default_file('file', '.S')
        topcmds.write_schedule(self)

    def id_table(self):
        """Resolve |matrix_file| (id '.M') under |output_directory|."""
        self.default_file('matrix_file', '.M')
        topcmds.id_table(self, matrix_file=self.add_outdir('matrix_file'))

    def write_topology_model(self):
        """Resolve |file| (id '.toplib') under |output_directory|."""
        self.default_file('file', '.toplib')
        topcmds.write_topology_model(self, file=self.add_outdir('file'))

    def write_alignment(self):
        """Resolve |file| (id '.A') under |output_directory|, then write."""
        self.default_file('file', '.A')
        topcmds.write_alignment(self, file=self.add_outdir('file'))

    def sequence_comparison(self):
        """Resolve the rr, matrix and variability files, then delegate."""
        self.default_file('rr_file', '.A', ':.mat')
        self.default_file('matrix_file', '.M')
        self.default_file('variability_file', '.V')
        topcmds.sequence_comparison(
            self,
            matrix_file=self.add_outdir('matrix_file'),
            variability_file=self.add_outdir('variability_file'))

    def sequence_to_ali(self):
        """Add a sequence to the alignment.

        The alignment is cleared first unless |add_sequence|.  The entries
        of |align_codes| and |atom_files| at index nseq (the sequence count
        after any clearing) are passed on, or '' if a list is too short.
        """
        aln = self.get_aln()
        if not self.vars['add_sequence']:
            _modeller.delete_alignment(aln)
        nseq = _modeller.alignment_nseq_get(aln)
        try:
            align_codes = self.vars['align_codes'][nseq]
        except IndexError:
            align_codes = ''
        try:
            atom_files = self.vars['atom_files'][nseq]
        except IndexError:
            atom_files = ''
        topcmds.sequence_to_ali(self, align_codes=align_codes,
                                atom_files=atom_files)

    def read_topology(self):
        """Clear topology unless |add_topology|, resolve |file|, then read."""
        if not self.vars['add_topology']:
            _modeller.clear_topology(self.get_tpl(), self.get_libs())
        self.default_file('file', '.T', ':.lib')
        topcmds.read_topology(self)

    def generate_topology(self):
        """Generate topology for the alignment sequence named |sequence|.

        The model topology is cleared first unless |add_segment|.
        """
        seq = self.vars['sequence']
        iseq = _modeller.find_alignment_code(self.get_aln(), seq)
        if iseq < 0:
            log.error('generate_topology',
                      "Sequence '%s' not found in alignment" % seq)
        if not self.vars['add_segment']:
            _modeller.clear_model_topology(self.get_mdl())
        topcmds.generate_topology(self, iseq=iseq)

    def energy(self):
        """Resolve |file| (id 'P'), update the schedule, then delegate."""
        self.default_file('file', 'P')
        span = self.__update_schedule()
        topcmds.energy(self, residue_span_range=span,
                       file=self.add_outdir('file'))

    def optimize(self):
        """Update and check the schedule, pick the optimizer, then delegate."""
        span = self.__update_schedule()
        self.__check_schedule()
        ometh = self.__get_optimization_method()
        topcmds.optimize(self, residue_span_range=span,
                         optimization_method=ometh)

    def debug_function(self):
        """Update the schedule, then delegate."""
        span = self.__update_schedule()
        topcmds.debug_function(self, residue_span_range=span)

    def spline_restraints(self):
        """Update the schedule, then delegate."""
        span = self.__update_schedule()
        topcmds.spline_restraints(self, residue_span_range=span)

    def pick_hot_atoms(self):
        """Update the schedule, then delegate."""
        span = self.__update_schedule()
        topcmds.pick_hot_atoms(self, residue_span_range=span)

    def make_schedule(self):
        """Update the schedule, then delegate."""
        span = self.__update_schedule()
        topcmds.make_schedule(self, residue_span_range=span)

    def make_restraints(self):
        """Make restraints of type |restraint_type|.

        'alpha', 'strand' and 'sheet' (case-insensitive) are handled here;
        any other type is delegated to topcmds.  Existing restraints are
        cleared first unless |add_restraints|.
        """
        if not self.vars['add_restraints']:
            _modeller.clear_restraints(self.get_mdl())
        # The schedule is updated for every restraint type, even the ones
        # handled locally that do not use the returned span.
        span = self.__update_schedule()
        rsrtype = self.vars['restraint_type'].lower()
        if rsrtype == 'alpha':
            self.__make_sstruc_restraints(_modeller.make_alpha_restraints)
        elif rsrtype == 'strand':
            self.__make_sstruc_restraints(_modeller.make_strand_restraints)
        elif rsrtype == 'sheet':
            self.__make_sheet_restraints()
        else:
            topcmds.make_restraints(self, residue_span_range=span)

    def pick_restraints(self):
        """Unpick all restraints unless |add_restraints|, then delegate."""
        if not self.vars['add_restraints']:
            rsr = _modeller.model_rsr_get(self.get_mdl())
            _modeller.unpick_all_restraints(rsr)
        span = self.__update_schedule()
        topcmds.pick_restraints(self, residue_span_range=span)

    def align(self):
        """Resolve |rr_file|, then delegate."""
        self.default_file('rr_file', '.A', ':.mat')
        topcmds.align(self)

    def sequence_search(self):
        """Resolve |rr_file|, then delegate."""
        self.default_file('rr_file', '.mat', ':.mat')
        topcmds.sequence_search(self)

    def malign(self):
        """Resolve |rr_file|, then delegate."""
        self.default_file('rr_file', '.A', ':.mat')
        topcmds.malign(self)

    def principal_components(self):
        """Resolve |matrix_file| and |file|, then delegate."""
        self.default_file('matrix_file', '.G')
        self.default_file('file', '.dat')
        topcmds.principal_components(self)

    def align2d(self):
        """Resolve |rr_file|, then delegate."""
        self.default_file('rr_file', '.A', ':.mat')
        topcmds.align2d(self)

    def write_pdb_xref(self):
        """Resolve |file| and |model_segment|, then write the xref file."""
        self.default_file('file', '.X')
        # Only the segment is taken from the alignment here.  The output
        # name comes from |file| via add_outdir.
        (_, model_segment) = self.__seg_from_aln('model_segment')
        topcmds.write_pdb_xref(self, model_segment=model_segment,
                               file=self.add_outdir('file'))

    def build_profile(self):
        """Resolve |rr_file|, then delegate."""
        self.default_file('rr_file', '.mat', ':.mat')
        topcmds.build_profile(self)

    def read_sequence_db(self):
        """Resolve |chains_list| (unless 'ALL') and |seq_database_file|."""
        chains = util.get_topvars('chains_list', self.vars)
        if chains.upper() != 'ALL':
            self.default_file('chains_list', 'H', ':.list')
        self.default_file('seq_database_file', 'H', ':.seq')
        topcmds.read_sequence_db(self)

    def write_sequence_db(self):
        """Resolve |chains_list| and |seq_database_file|, then delegate."""
        self.default_file('chains_list', 'H')
        self.default_file('seq_database_file', 'H')
        topcmds.write_sequence_db(self)

    def seqfilter(self):
        """Resolve |rr_file|, then delegate."""
        self.default_file('rr_file', '.mat', ':.mat')
        topcmds.seqfilter(self)

    def open(self):
        """Resolve |objects_file| (id 'TOP'), then delegate."""
        self.default_file('objects_file', 'TOP')
        topcmds.open(self)

    def inquire(self):
        """Resolve |file| (id '.B'), then delegate."""
        self.default_file('file', '.B')
        topcmds.inquire(self)

    def read_alignment(self):
        """Clear alignment 1 unless |add_sequence|, resolve |file|, read."""
        if not self.vars['add_sequence']:
            _modeller.delete_alignment(self.get_aln())
        self.default_file('file', '.A', ':.ali')
        topcmds.read_alignment(self)

    def read_alignment2(self):
        """Clear alignment 2 unless |add_sequence|, resolve |file|, read."""
        if not self.vars['add_sequence']:
            _modeller.delete_alignment(self.get_aln(2))
        self.default_file('file', '.A', ':.ali')
        topcmds.read_alignment2(self)

    def prof_to_aln(self):
        """Clear the alignment unless |append_aln|, then delegate."""
        if not self.vars['append_aln']:
            _modeller.delete_alignment(self.get_aln())
        topcmds.prof_to_aln(self)

    def read_parameters(self):
        """Clear parameters unless |add_parameters|, resolve |file|, read."""
        if not self.vars['add_parameters']:
            # Clear both CHARMM parameters and group restraint parameters
            _modeller.clear_parameters(self.get_prm())
            _modeller.clear_group_restraints_param(self.get_gprsr())
        self.default_file('file', '.toplib', ':.lib')
        topcmds.read_parameters(self)

    def delete_restraint(self):
        """Unpick restraints on the atoms named in |atom_ids|.

        Logs a warning and does nothing if any atom is absent from model 1.
        """
        atom_names = util.get_topvars('atom_ids', self.vars)
        mdl = self.get_mdl(1)
        atom_ids = [_modeller.indxatm2(a, mdl) for a in atom_names]
        # Treat any non-positive index as 'not found', matching the check
        # on indxatm2 results in __make_sheet_restraints.
        if any(i <= 0 for i in atom_ids):
            log.warning("delete_restraint",
                        "One or more atoms absent from MODEL:  "
                        + " ".join(atom_names))
        else:
            topcmds.unpick_restraints(self, atom_ids=atom_ids)

    def patch(self):
        """Convert |residue_ids| to integers, then delegate."""
        mdl = self.get_mdl()
        resids = self.__parse_residue_ids(mdl, self.vars['residue_ids'])
        topcmds.patch(self, residue_ids=resids)

    def superpose(self):
        """Delegate using the first entry of |rms_cutoffs|."""
        topcmds.superpose(self, rms_cutoff=self.vars['rms_cutoffs'][0])

    def rotate_model(self):
        """Move all atoms of model 1 in three steps, in this order.

        1. Translate by |translation|.
        2. Transform by |rotation_matrix|.
        3. Rotate by |rotation_angle| about |rotation_axis|.
        """
        mdl = self.get_mdl(1)
        inds = _modeller.select_all(mdl)
        translation = util.get_topvars('translation', self.vars)
        _modeller.translate_selection(mdl, inds, translation)
        matrix = util.get_topvars('rotation_matrix', self.vars)
        _modeller.transform_selection(mdl, inds, matrix)
        axis = util.get_topvars('rotation_axis', self.vars)
        angle = util.get_topvars('rotation_angle', self.vars)
        _modeller.rotate_axis_selection(mdl, inds, axis, angle)
