# This file is part of ModPipe, Copyright 1997-2020 Andrej Sali
#
# ModPipe is free software: you can redistribute it and/or modify
# it under the terms of version 2 of the GNU General Public License
# as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ModPipe.  If not, see <http://www.gnu.org/licenses/>.

"""Handling of amino acid sequences and common file formats"""

from __future__ import print_function
import modpipe
import re
import os
import sys
# Use hashlib module if available (Python 2.5 or later) since md5 is deprecated
try:
    import hashlib
except ImportError:
    import md5 as hashlib

class Sequence(object):
    """Representation of a single amino acid sequence.

       The constructor only sets the header-style metadata (`prottyp`,
       `range`, `atom_file`, `name`, `source`, `resolution`, `rfactor`).
       The `primary` attribute (the residue string) and the `code`
       attribute are *not* created here; they must be assigned (as the
       file readers in this module do) before :meth:`get_id` or
       :meth:`clean` is called."""

    def __init__(self):
        self.prottyp = 'sequence'
        # Two [x, y] pairs (start and end); PIRFile fills these from
        # fields 3-6 of a PIR header line and writes them back in that order
        self.range = [['', ''], ['', '']]
        self.atom_file = self.name = self.source = ''
        self.resolution = self.rfactor = ''

    def get_id(self):
        """Return the ModPipe sequence identifier.

           The identifier is the hexadecimal MD5 digest of `primary`,
           followed by the first four and then the last four characters
           of `primary` (for sequences shorter than four characters the
           two slices overlap)."""
        m = hashlib.md5()
        if sys.version_info[0] == 2:
            # Python 2 str is already a byte string
            m.update(self.primary)
        else:
            # hashlib only accepts bytes on Python 3
            m.update(self.primary.encode('utf-8'))
        return m.hexdigest() + self.primary[:4] + self.primary[-4:]

    def clean(self, keep_char = ''):
        """Clean up the primary sequence, in place.

           The steps are, in order: '.' becomes 'G'; every non-word
           character (this includes '*' and whitespace), digit and
           underscore is deleted; the result is uppercased; 'B' becomes
           'N' and 'Z' becomes 'Q'; finally any character that is not one
           of ACDEFGHIKLMNPQRSTVWY (or listed in `keep_char`) becomes 'G'.

           :param keep_char: additional characters to leave untouched in
                  the final step. None or '' keeps nothing extra. The
                  characters are taken literally (regular expression
                  metacharacters are escaped). Since the earlier steps have
                  already deleted non-word characters and uppercased the
                  sequence, only uppercase letters can have any effect.
        """
        s = self.primary
        s = s.replace('.', 'G')          # replace BLK with GLY
        # Drop non-word characters, digits and underscores (underscore is
        # a word character, so \W alone would not catch it)
        s = re.sub(r'[\W\d_]+', '', s)
        s = s.upper()                    # convert to uppercase
        s = s.replace('B', 'N')          # convert ASX to ASN
        s = s.replace('Z', 'Q')          # convert GLX to GLN

        # Convert everything else to GLY. keep_char is escaped so that
        # characters such as '\\', ']', '^' or '-' cannot corrupt (or
        # change the meaning of) the character class.
        extra = re.escape(keep_char) if keep_char else ''
        s = re.sub('[^ACDEFGHIKLMNPQRSTVWY' + extra + ']', 'G', s)
        self.primary = s


class FASTAFile(object):
    """Reader and writer for sequences stored in FASTA format"""

    def read(self, fh):
        """Parse FASTA-format text from the stream `fh`.

           This is a generator: one :class:`Sequence` object is yielded
           per '>' header line, with `code` set to the header text (minus
           the '>' and trailing whitespace) and `primary` set to the
           concatenation of the following non-blank lines (each with
           trailing whitespace stripped).

           :raises modpipe.FileFormatError: if a non-blank line is seen
                   before the first header."""
        current = None
        for lineno, raw in enumerate(fh, 1):
            if raw.startswith('>'):
                # A new header finishes off the record before it, if any
                if current:
                    yield current
                current = Sequence()
                current.primary = ''
                current.code = raw.rstrip()[1:]
                continue

            residues = raw.rstrip()
            if not residues:
                # Blank lines carry no sequence data
                continue
            if current is None:
                raise modpipe.FileFormatError(
                    "Found FASTA sequence before first header at line %d: %s"
                    % (lineno, residues))
            current.primary += residues

        # The last record has no following header to flush it
        if current:
            yield current

    def write(self, fh, seq, width=70):
        """Write one :class:`Sequence` object to the stream `fh` as FASTA:
           a '>' header line holding `seq.code`, then `seq.primary` split
           into lines of at most `width` characters."""
        print(">" + seq.code, file=fh)
        for offset in range(0, len(seq.primary), width):
            chunk = seq.primary[offset:offset + width]
            print(chunk, file=fh)


class PIRFile(object):
    """Reader and writer for sequences stored in PIR format"""

    def _parse_pir_header(self, num, line, seq):
        """Fill in `seq` from the colon-separated PIR header `line`.

           `num` is the zero-based line index (used only for the error
           message). `seq.primary` is reset to the empty string, marking
           the header as having been read. An empty first field is
           replaced with 'sequence'.

           :raises modpipe.FileFormatError: if the line does not split
                   into exactly 10 fields."""
        seq.primary = ''
        fields = line.rstrip().split(':')
        if len(fields) != 10:
            raise modpipe.FileFormatError(
                "Invalid PIR header at line %d (expecting 10 fields "
                "split by colons): %s" % (num + 1, line))
        seq.prottyp = fields[0]
        seq.atom_file = fields[1]
        seq.range[0][0] = fields[2]
        seq.range[0][1] = fields[3]
        seq.range[1][0] = fields[4]
        seq.range[1][1] = fields[5]
        seq.name = fields[6]
        seq.source = fields[7]
        seq.resolution = fields[8]
        seq.rfactor = fields[9]
        if seq.prottyp == '':
            seq.prottyp = 'sequence'

    def read(self, fh):
        """Parse PIR-format text from the stream `fh`.

           This is a generator: one :class:`Sequence` object is yielded
           each time a sequence's terminating '*' is reached. Lines
           starting with 'C;' or 'R;' are skipped as comments. The line
           after each '>P1;' line is parsed as the header.

           :raises modpipe.FileFormatError: on a '>P1;' line before the
                   previous sequence was terminated, on sequence data with
                   no preceding '>P1;' line, on a malformed header, or if
                   the input ends inside an unterminated sequence."""
        end_marker = re.compile(r'\*\s*$')
        current = None
        for num, raw in enumerate(fh):
            # Comment lines are ignored wherever they occur
            if raw.startswith('C;') or raw.startswith('R;'):
                continue

            if raw.startswith('>P1;'):
                if current:
                    raise modpipe.FileFormatError(
                        "PIR sequence without terminating * at line %d: %s"
                        % (num + 1, raw))
                current = Sequence()
                # primary of None means "header line not read yet"
                current.primary = None
                current.code = raw.rstrip()[4:]
                continue

            if current and current.primary is None:
                self._parse_pir_header(num, raw, current)
                continue

            text = raw.rstrip()
            if not text:
                continue
            if current is None:
                raise modpipe.FileFormatError(
                    "PIR sequence found without a preceding header "
                    "at line %d: %s" % (num + 1, text))
            text, stars_removed = end_marker.subn("", text)
            current.primary += text
            # A trailing '*' marks the last line of this sequence
            if stars_removed == 1:
                yield current
                current = None

        if current:
            raise modpipe.FileFormatError(
                "PIR sequence without terminating * at end of file")

    def write(self, fh, seq, width=70):
        """Write one :class:`Sequence` object to the stream `fh` as PIR:
           a '>P1;' line holding `seq.code`, the 10-field colon-separated
           header, `seq.primary` split into lines of at most `width`
           characters, and a final '*' line."""
        print(">P1;" + seq.code, file=fh)
        first, last = seq.range
        header_fields = [seq.prottyp, seq.atom_file,
                         first[0], first[1], last[0], last[1],
                         seq.name, seq.source,
                         seq.resolution, seq.rfactor]
        print(":".join([str(field) for field in header_fields]), file=fh)
        for offset in range(0, len(seq.primary), width):
            chunk = seq.primary[offset:offset + width]
            print(chunk, file=fh)
        print('*', file=fh)


class SPTRFile(object):
    """Representation of a file containing UniProtKB/SwissProt or TrEMBL
       database entries"""

    def read(self, fh):
        """Read sequences from the given stream in SPTR format.

           This is a generator: one :class:`Sequence` object is yielded
           for each entry, at its '//' terminator line (or at end of
           input, if the last entry has started its sequence section but
           has no terminator). `code` is the first accession on the
           entry's first 'AC' line; `primary` is the text after the
           'SQ   SEQUENCE' line with all spaces removed (it is None if
           the entry is terminated without an 'SQ' line).

           An entry may list its accessions over several consecutive 'AC'
           lines; those continuation lines are skipped.

           :raises modpipe.FileFormatError: if an 'AC' line that does not
                   directly follow another 'AC' line is found before the
                   previous entry was terminated with '//'."""
        AC = re.compile(r'AC   (\w+);')
        seq = None
        # True while the previous line was an AC line of the current entry
        in_ac_lines = False
        for (num, line) in enumerate(fh):
            m = AC.match(line)
            if m:
                if seq is None:
                    seq = Sequence()
                    seq.code = m.group(1)
                    # primary of None means "SQ line not seen yet"
                    seq.primary = None
                elif not in_ac_lines:
                    raise modpipe.FileFormatError(
                        "SPTR file contains AC record before end of previous "
                        "sequence at line %d: %s" % (num + 1, line))
                # Otherwise this is a continuation AC line holding further
                # accessions for the same entry; the code from the first
                # AC line is kept.
                in_ac_lines = True
                continue

            in_ac_lines = False
            if not seq:
                # Nothing outside an entry is of interest
                continue
            if line.startswith('SQ   SEQUENCE'):
                seq.primary = ''
            elif line.startswith('//'):
                yield seq
                seq = None
            elif seq.primary is not None:
                seq.primary += line.rstrip().replace(' ', '')
        if seq and seq.primary is not None:
            yield seq


class UniqueFile(object):
    """Mapping file that lists, for each ModPipe ID, the alignment codes
       that were registered for it"""

    def __init__(self):
        # ModPipe ID -> list of align codes, in the order they were added
        self._unqseq = {}

    def add_sequence(self, modpipe_id, align_code):
        """Record that `align_code` maps to the ModPipe ID `modpipe_id`"""
        self._unqseq.setdefault(modpipe_id, []).append(align_code)

    def file_name_from_seqfile(self, seqfile):
        """Return a name for the unique file derived from the input
           sequence file `seqfile`: its base name with the extension
           replaced by '.unq' (any directory part is dropped)"""
        base = os.path.basename(seqfile)
        stem = os.path.splitext(base)[0]
        return stem + '.unq'

    def write(self, fh):
        """Write the mapping to the stream `fh`, one line per ModPipe ID
           in the form 'ID : code1 code2 ...'"""
        for modpipe_id in self._unqseq:
            codes = self._unqseq[modpipe_id]
            print("%s : %s" % (modpipe_id, " ".join(codes)), file=fh)
