from __future__ import print_function
import unittest
import modpipe
import modpipe.test
import modpipe.sequence
import sys
import io
import os

def _make_stream(*lines):
    """Return an in-memory stream holding `lines`, rewound to the start.

    Each element of `lines` is written with print(), so it is followed by
    a newline. io.StringIO is used on Python 3 and io.BytesIO on Python 2,
    since print() emits the native str type on each version.
    """
    stream = io.StringIO() if sys.version_info[0] >= 3 else io.BytesIO()
    for line in lines:
        print(line, file=stream)
    stream.seek(0)
    return stream


class SequenceTests(modpipe.test.TestCase):
    """Tests of modpipe.sequence.Sequence and the sequence file readers."""

    def test_id(self):
        """Check sequence ID"""
        # Each pair is (primary sequence, ID expected from get_id())
        for (seq, expected_id) in \
            (('MLGIIM', 'f754177b1db86fba508552fc536485f1MLGIGIIM'),
             ('MLCVGI', '5a6451f4b89e9e9938d30b4d3f765397MLCVCVGI')):
            s = modpipe.sequence.Sequence()
            s.primary = seq
            self.assertEqual(s.get_id(), expected_id)

    def test_clean(self):
        """Check Sequence.clean()"""
        # All of these sequences should be treated as identical: clean()
        # removes line breaks, spaces, non-word characters, underscores,
        # numbers and *, converts to uppercase, maps B(ASX) to N(ASN)
        # and Z(GLX) to Q(GLN), and maps everything else to G(GLY).
        # NOTE(review): 'MLCVQVNI.' only yields 'MLCVQVNIG' if '.' is mapped
        # to G rather than removed as a non-word character - confirm
        # against the Sequence.clean() implementation.
        for aaseq in ('MLC\nVQV\nNIG', 'M\rLCV\rQVNIG', 'MLC\tVQV\tNIG',
                      'MLC_VQV_NIG', 'M1LCVQ4VNIG', 'MLCV*QVNIG*',
                      'M LCVQ  VNIG', 'MLcVQVnIG', 'MLCVQVBIG', 'MLCVZVNIG',
                      'MLCVQVNIJ', 'MLC\nv_QV1NIj*', 'MLCVQVNI.'):
            s = modpipe.sequence.Sequence()
            s.primary = aaseq
            s.clean()
            self.assertEqual(s.primary, 'MLCVQVNIG')

    def test_read_fasta(self):
        """Check reading of FASTA files"""
        p = modpipe.sequence.FASTAFile()
        # Blank lines before header, or within sequence, should be ignored
        for prefix in ('', '\n'):
            for seq in ('AAP', '\nAA\n\nP\n\n'):
                s = _make_stream(prefix + ">foo\n" + seq)
                seqs = list(p.read(s))
                self.assertEqual(len(seqs), 1)
                self.assertEqual(seqs[0].code, 'foo')
                self.assertEqual(seqs[0].primary, 'AAP')

    def test_read_pir(self):
        """Check reading of PIR files"""
        p = modpipe.sequence.PIRFile()
        s = _make_stream(
                ">P1;seq1",
                "structureX:1abc:1:A:10:B:Name:Source:2.0:1.0\nAAP*")
        seqs = list(p.read(s))
        self.assertEqual(len(seqs), 1)
        self.assertEqual(seqs[0].code, 'seq1')
        self.assertEqual(seqs[0].prottyp, 'structureX')
        self.assertEqual(seqs[0].atom_file, '1abc')
        self.assertEqual(seqs[0].range, [['1', 'A'], ['10', 'B']])
        self.assertEqual(seqs[0].primary, 'AAP')
        self.assertEqual(seqs[0].name, 'Name')
        self.assertEqual(seqs[0].source, 'Source')
        self.assertEqual(seqs[0].resolution, '2.0')
        self.assertEqual(seqs[0].rfactor, '1.0')
        # Make sure that empty prottyp defaults to 'sequence'
        p = modpipe.sequence.PIRFile()
        s = _make_stream(">P1;seq1", ":::::::::\nAAP*")
        seqs = list(p.read(s))
        self.assertEqual(len(seqs), 1)
        self.assertEqual(seqs[0].prottyp, 'sequence')

    def test_read_sptr(self):
        """Check reading of SPTR files"""
        p = modpipe.sequence.SPTRFile()
        s = _make_stream("AC   Q4U9M9;\nSQ   SEQUENCE\n     AAP\n//")
        seqs = list(p.read(s))
        self.assertEqual(len(seqs), 1)
        self.assertEqual(seqs[0].code, 'Q4U9M9')
        self.assertEqual(seqs[0].primary, 'AAP')

    def test_bad_fasta(self):
        """Check handling of invalid FASTA files"""
        p = modpipe.sequence.FASTAFile()
        # Sequence without a preceding header
        s = _make_stream("AAP")
        # list() forces the reader to consume the whole stream
        self.assertRaises(modpipe.FileFormatError, list, p.read(s))

    def test_bad_pir(self):
        """Check handling of invalid PIR files"""
        for invalid in (
               # New sequence without termination of the previous one
               ">P1;seq1\nsequence:::::::::\nAAP\n>P1;seq2",
               # Sequence without a preceding header
               "AAP*",
               # Sequence not terminated at end-of-file
               ">P1;seq1\nsequence:::::::::\nAAP",
               # Invalid PIR header (wrong number of colons)
               ">P1;seq1\nsequence::::::::\nAAP*"):
            p = modpipe.sequence.PIRFile()
            s = _make_stream(invalid)
            self.assertRaises(modpipe.FileFormatError, list, p.read(s))

    def test_bad_sptr(self):
        """Check handling of invalid SPTR files"""
        p = modpipe.sequence.SPTRFile()
        # AC record before end of previous sequence
        s = _make_stream(
                "AC   Q4U9M9;\nSQ   SEQUENCE\n     AAP\nAC   P15711;")
        self.assertRaises(modpipe.FileFormatError, list, p.read(s))

    def test_extract_sequences(self):
        """Check ExtractSequences"""
        self.run_perl_code("""use PLLib::Sequence;
ExtractSequences('../db/test-pdb.fsa', 'test.fsa', 'FASTA', ['1ecsA']);""")
        # Close the file promptly rather than relying on garbage collection
        with open('test.fsa') as fh:
            out = fh.read()
        # Remove the output before comparing, so that a failed assertion
        # does not leave test.fsa behind
        os.unlink('test.fsa')
        self.assertEqual(out, """>1ecsA
TDQATPNLPSRDFDSTAAFYERLGFGIVFRDAGWMILQRGDLMLEFFAHPGLDPLASWFSCCLRLDDLAEFYRQC
KSVGIQETSSGYPRIHAPELQGWGGTMAALVDPDGTLLRLIQNEL\n""")

# Run every test in this module when the file is executed as a script
if __name__ == "__main__":
    unittest.main()
