#!/usr/bin/python
# This file is part of ModPipe, Copyright 1997-2020 Andrej Sali
#
# ModPipe is free software: you can redistribute it and/or modify
# it under the terms of version 2 of the GNU General Public License
# as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ModPipe.  If not, see <http://www.gnu.org/licenses/>.

from optparse import OptionParser
import modpipe.version
from modpipe.cdhit import CDHit
import sys, os, re

def pick_chain_with_highest_resolution(representative, members, resol):
    """Choose the best-resolved chain among the members of a cluster.

    representative -- code used as the fallback answer
    members        -- iterable of codes to examine
    resol          -- mapping from code to resolution value

    "Highest resolution" means the numerically smallest value. Only values
    strictly between 0 and 100 take part in the comparison; anything else
    is skipped (non-positive values presumably mark an undefined
    resolution). When several members share the best value, the first one
    encountered wins.

    Note that the representative's own resolution is only used as the
    fallback result: it competes with the members solely if the
    representative is itself listed in `members`.

    Returns a (code, resolution) tuple. Raises KeyError if the
    representative or any member is missing from `resol`.
    """
    best_code = representative
    best_resol = resol[representative]
    cutoff = 100
    for candidate in members:
        value = resol[candidate]
        if not 0 < value < cutoff:
            # Undefined, out-of-range, or no better than the current best
            continue
        best_code, best_resol = candidate, value
        cutoff = value
    return best_code, best_resol

def clusterpdb(seqallpir, outbase, sequence_identity,
               throw_away_seq_length, length_difference):
    """Cluster PDB-derived sequences with CD-HIT and write the results.

    The sequences in `seqallpir` (a PIR file, as extracted from PDB files
    by Modeller) are read into a Modeller sequence database, dumped to a
    FASTA file and clustered with CD-HIT. For every cluster a
    representative is chosen with pick_chain_with_highest_resolution(),
    and three files are written:

      <prefix>.grp -- one line per cluster: "representative : members"
      <prefix>.cod -- one representative code per line
      <prefix>.pir -- the PIR entries of the representatives

    <prefix> is `outbase`, or, when `outbase` is empty, the basename of
    `seqallpir` (directory and extension stripped) followed by
    `sequence_identity`. The intermediate FASTA file is named
    <basename>.fsa; like the default outputs it carries no directory
    component, so it is created relative to the current directory.

    seqallpir             -- path of the input PIR file
    outbase               -- basename for the output files ('' for default)
    sequence_identity     -- sequence identity threshold for clustering
    throw_away_seq_length -- minimum sequence length to be clustered
    length_difference     -- maximum length difference within a cluster

    The last three arguments are handed to CDHit.cluster() unchanged.
    Nothing is returned; the results are the files listed above.
    """
    def iteritems(d):
        # Iterate over (key, value) pairs on both Python 2 and Python 3
        return d.iteritems() if sys.version_info[0] == 2 else d.items()

    # Initialize Modeller
    import modeller
    modeller.log.verbose()
    env = modeller.Environ()

    sdb = modeller.SequenceDB(env)
    sdb.read(seq_database_file=seqallpir, seq_database_format='pir',
             chains_list='all')

    # Derive the working filenames from the input file's basename
    basename = os.path.splitext(os.path.basename(seqallpir))[0]
    seqallfsa = basename + '.fsa'
    reprcdhit = basename + str(sequence_identity)

    # Output files share one prefix: the user's choice, or the default
    prefix = outbase if outbase else reprcdhit
    grpfile = prefix + '.grp'
    codfile = prefix + '.cod'
    seqfile = prefix + '.pir'

    # Write fasta file (the input handed to CD-HIT)
    sdb.write(seq_database_file=seqallfsa, seq_database_format='fasta',
              chains_list='all')

    # Create a hash-lookup of resolutions
    resol = {}
    for i in range(len(sdb)):
        seq = sdb[i]
        resol[seq.code] = seq.resol
    # The database is no longer needed; drop it before clustering
    del sdb

    # Cluster at specified sequence identity using CD-HIT
    cdh = CDHit(seqallfsa)
    cdh.cluster(reprcdhit, sequence_identity, throw_away_seq_length,
                length_difference)
    clusters = cdh.parse_clusters()

    # Re-key every cluster by its best-resolved member
    new_clusters = {}
    for rep in clusters.keys():
        members = clusters[rep]
        best = pick_chain_with_highest_resolution(rep, members, resol)[0]
        new_clusters[best] = members

    # Write out clusters
    with open(grpfile, 'w') as grpf:
        for rep, members in iteritems(new_clusters):
            grpf.write("%s : %s\n" % (rep, ' '.join(members)))

    # Write out codes
    with open(codfile, 'w') as codf:
        for rep in new_clusters:
            codf.write(rep + '\n')

    # Write out representative sequences. The Modeller file handles are
    # closed explicitly so that the output is flushed even if an error
    # interrupts the copy, rather than relying on garbage collection.
    pir_in = modeller.modfile.File(seqallpir, 'r')
    try:
        pir_out = modeller.modfile.File(seqfile, 'w')
        try:
            aln = modeller.Alignment(env)
            while aln.read_one(pir_in, alignment_format='PIR'):
                if aln[0].code in new_clusters:
                    aln.write(pir_out, alignment_format='PIR')
        finally:
            pir_out.close()
    finally:
        pir_in.close()


def main():
    """Command-line entry point: parse the options and run clusterpdb().

    Exactly one positional argument (the PIR sequence file) is required;
    otherwise the parser reports an error and exits.
    """
    usage_text = """This script takes a file with sequences, clusters
 them at a specified sequence identity threshold and returns the
 clusters and representatives. The latter is chosen as the structure
 with the best resolution. This should normally be used only for sequences
 extracted from the PDB, but could be used for other sequences as well with
 some modifications.

 Usage: %prog [options] seqfile

 seqfile is a PIR format file containing sequences to be clustered.

 Run `%prog -h` for help information

 """

    # (short flag, long flag, dest, type, help, metavar), in display order
    option_table = (
        ("-f", "--base_filename", "outbase", 'string',
         """Basename for output files""", "FILE"),
        ("-t", "--sequence_identity_threshold", "seqidthresh", 'int',
         """Threshold sequence identity for clustering""", "VALUE"),
        ("-l", "--min_sequence_length", "minseqlength", 'int',
         """Minimum length of sequences to include during
                      clustering. Sequences shorter than this limit will
                      not be considered""", "VALUE"),
        ("-d", "--max_length_difference", "lendiff", 'int',
         """Maximum difference in length to be considered
                      when clustering into the same group. A difference
                      in length greater than this value will trigger a
                      new cluster""", "VALUE"),
    )

    parser = OptionParser(version=modpipe.version.message())
    parser.set_usage(usage_text)
    parser.set_defaults(outbase='', seqidthresh=95, minseqlength=30,
                        lendiff=10)
    for short_flag, long_flag, dest, opt_type, help_text, metavar \
            in option_table:
        parser.add_option(short_flag, long_flag, dest=dest, type=opt_type,
                          help=help_text, metavar=metavar)

    opts, args = parser.parse_args()

    # parser.error() prints the message and exits, so execution only
    # continues when exactly one input file was given
    if len(args) != 1:
        parser.error("You must specify a PIR file containing sequences")

    clusterpdb(args[0], opts.outbase, opts.seqidthresh,
               opts.minseqlength, opts.lendiff)


# Run the command-line interface only when executed as a script,
# not when this file is imported as a module.
if __name__ == "__main__":
    main()
