# This file is part of ModPipe, Copyright 1997-2020 Andrej Sali
#
# ModPipe is free software: you can redistribute it and/or modify
# it under the terms of version 2 of the GNU General Public License
# as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ModPipe.  If not, see <http://www.gnu.org/licenses/>.

import sys
from modeller import *
import operator

class ClusterModels(object):
    """Greedy clustering of structure models by pairwise structural overlap.

    Model names are read from a list file, ordered by residue count
    (longest first), and then clustered: the first unassigned model
    becomes a cluster representative and every remaining model judged
    similar to it (see check_similarity) joins that cluster.

    Results are kept in two attributes:

    * ``clusters``: dict mapping representative name -> list of members
      (the representative itself is always the first member).
    * ``comparisons``: list of 9-tuples, one per call to compare_models.
    """

    def __init__(self, file_list, models_dir, sort_models=True, rms_cutoff=3.5,
                 min_equiv_pos=30, max_nonequiv_pos=30, so_cutoff=35.0):
        """Store the clustering parameters and set up the Modeller environment.

        file_list        -- path of a text file with one model name per line
        models_dir       -- assigned to the environment's
                            io.atom_files_directory
        sort_models      -- stored on the instance.
                            NOTE(review): nothing in this class consults it;
                            make() always sorts by length.
        rms_cutoff       -- passed as rms_cutoff to the superposition
        min_equiv_pos    -- minimum number of equivalent positions required
        max_nonequiv_pos -- maximum number of non-equivalent positions allowed
        so_cutoff        -- minimum percentage of equivalent positions that
                            must also fall within the cutoff
        """
        self.file_list = file_list
        self.modenv = Environ()
        self.modenv.io.atom_files_directory = models_dir
        self.models = []
        self.modlen = {}
        self.sort_models = sort_models
        self.rms_cutoff = rms_cutoff
        self.min_equiv_pos = min_equiv_pos
        self.max_nonequiv_pos = max_nonequiv_pos
        self.so_cutoff = so_cutoff
        self.clusters = {}
        self.comparisons = []
        # Indirection used by cluster(); can be replaced with any callable
        # taking (representative, candidate) and returning a truth value.
        self.compare_elements = self.compare_models

    def read_model_list(self):
        """Append the model names listed in self.file_list to self.models.

        Trailing whitespace is stripped from every line. Lines that are
        empty after stripping are skipped, since an empty string is not a
        usable model name.
        """
        with open(self.file_list, 'r') as f:
            for line in f:
                name = line.rstrip()
                if name:
                    self.models.append(name)

    def sort_models_by_length(self):
        """Record each model's residue count and sort models longest first.

        Fills self.modlen (name -> nres) and replaces self.models with the
        keys of self.modlen in order of decreasing length. Because the names
        pass through a dict, duplicate names collapse to a single entry.
        """
        for mod in self.models:
            self.modlen[mod] = Model(self.modenv, file=mod).nres

        # sorted() is stable even with reverse=True, so models of equal
        # length keep the dict's iteration order.
        self.models = sorted(self.modlen, key=self.modlen.get, reverse=True)

    def get_model(self, mod):
        """Read model `mod` and return it with its name attribute set to mod."""
        mdl = Model(self.modenv, file=mod)
        mdl.name = mod
        return mdl

    def get_pairwise_alignment(self, model1, model2):
        """Return a two-entry alignment of model1 and model2.

        Both models are appended using their name as align code and atom
        file, then aligned locally with the '$(LIB)/id.sim.mat' matrix.
        """
        aln = Alignment(self.modenv)
        for mdl in (model1, model2):
            aln.append_model(mdl, align_codes=mdl.name, atom_files=mdl.name)
        aln.align(gap_penalties_1d=(-3000, -1000), matrix_offset=0,
                  local_alignment=True, rr_file='$(LIB)/id.sim.mat')
        return aln

    def get_structure_overlap(self, aln, model1, model2):
        """Superpose the CA atoms of model1 onto model2 using `aln`.

        Returns the tuple (num_equiv_pos, num_equiv_cutoff_pos) taken from
        the superposition result, computed with self.rms_cutoff.
        """
        ca_atoms = Selection(model1).only_atom_types('CA')
        sup = ca_atoms.superpose(model2, aln, rms_cutoff=self.rms_cutoff)
        return sup.num_equiv_pos, sup.num_equiv_cutoff_pos

    def check_similarity(self, num_equiv_pos, num_equiv_cutoff_pos,
                         num_nonequiv_pos):
        """Return True if the three overlap counts satisfy every threshold.

        All of the following must hold:

        * num_equiv_pos >= self.min_equiv_pos
        * num_nonequiv_pos <= self.max_nonequiv_pos
        * 100 * num_equiv_cutoff_pos / num_equiv_pos >= self.so_cutoff

        A pair with no equivalent positions is never similar; this also
        avoids dividing by zero when min_equiv_pos is zero or negative.
        """
        if num_equiv_pos <= 0:
            return False
        # Float arithmetic so Python 2 does not floor the percentage.
        overlap_pct = 100.0 * num_equiv_cutoff_pos / num_equiv_pos
        return (num_equiv_pos >= self.min_equiv_pos
                and num_nonequiv_pos <= self.max_nonequiv_pos
                and overlap_pct >= self.so_cutoff)

    def compare_models(self, mod1, mod2):
        """Compare two models by name, log the result, and return similarity.

        Both names must already be present in self.modlen (filled by
        sort_models_by_length). A tuple

            (mod1, len1, mod2, len2, aligned_equiv, num_equiv_pos,
             num_equiv_cutoff_pos, num_nonequiv_pos, similarity)

        is appended to self.comparisons, where aligned_equiv is
        aln[0].get_num_equiv(aln[1]) and num_nonequiv_pos is the absolute
        difference between mod2's length and num_equiv_pos.
        """
        mdl1 = self.get_model(mod1)
        mdl2 = self.get_model(mod2)

        aln = self.get_pairwise_alignment(mdl1, mdl2)
        num_equiv_pos, num_equiv_cutoff_pos = \
            self.get_structure_overlap(aln, mdl1, mdl2)
        num_nonequiv_pos = abs(self.modlen[mod2] - num_equiv_pos)
        similarity = self.check_similarity(num_equiv_pos, num_equiv_cutoff_pos,
                                           num_nonequiv_pos)
        self.comparisons.append((mod1, self.modlen[mod1], mod2,
                                 self.modlen[mod2],
                                 aln[0].get_num_equiv(aln[1]),
                                 num_equiv_pos, num_equiv_cutoff_pos,
                                 num_nonequiv_pos, similarity))
        return similarity

    def cluster(self):
        """Partition self.models into clusters stored in self.clusters.

        Repeatedly takes the first unassigned model as representative and
        compares every unassigned model (the representative included)
        against it with self.compare_elements. Matching models join the
        representative's cluster; the rest are carried to the next pass.
        """
        remaining = list(self.models)
        while remaining:
            rep = remaining[0]
            members = self.clusters[rep] = []
            unmatched = []

            for idx, candidate in enumerate(remaining):
                similar = self.compare_elements(rep, candidate)
                # The representative (idx 0) always belongs to its own
                # cluster. Without this, a model that fails its own
                # similarity test would be re-queued forever and the loop
                # would never terminate. The self-comparison is still run
                # so that it is logged as before.
                if similar or idx == 0:
                    members.append(candidate)
                else:
                    unmatched.append(candidate)

            remaining = unmatched

    def make(self):
        """Run the whole pipeline: read the list, sort by length, cluster.

        Does nothing further when the model list turns out to be empty.
        """
        self.read_model_list()
        if not self.models:
            return
        self.sort_models_by_length()
        # Compare the first and last models of the sorted list. The return
        # value is discarded; the call's lasting effect is the row it
        # appends to self.comparisons.
        self.compare_models(self.models[0], self.models[-1])
        self.cluster()

    def write_clusters(self, file=None):
        """Write one line per cluster: representative, size, and members."""
        with open(file, 'w') as f:
            for rep, mem in self.clusters.items():
                f.write('%s :(%8d): %s\n' % (rep, len(mem), ' '.join(mem)))

    def write_representatives(self, file=None):
        """Write the name of every cluster representative, one per line."""
        with open(file, 'w') as f:
            for rep in self.clusters:
                f.write('%s\n' % rep)

    def write_largest_cluster(self, file=None):
        """Write the members of the largest cluster, one per line.

        When several clusters share the largest size, the first one in the
        iteration order of self.clusters is used. With no clusters at all
        an empty file is written.
        """
        with open(file, 'w') as f:
            if not self.clusters:
                return
            # max() returns the first maximal element, matching a stable
            # descending sort followed by taking the head.
            largest = max(self.clusters.values(), key=len)
            for mem in largest:
                f.write('%s\n' % mem)

    def write_comparisons(self, file=None):
        """Write every logged comparison as one whitespace-separated line."""
        with open(file, 'w') as f:
            for comp in self.comparisons:
                f.write('%s %5d %s %5d %5d %5d %5d %5d %s\n' % tuple(comp[:9]))
