from __future__ import print_function
import unittest
import modpipe.test
import modpipe.serialize
import subprocess
from modpipe.main import AddSeqMP
import os
import shutil

# Test sequences, each a dict keyed by field name.
#   align_code  - FASTA header written by _make_seq_file
#   aaseq       - raw input sequence (may contain junk characters)
#   id          - expected sequence id
#   final_aaseq - expected sequence after cleaning
# After cleaning, seq1 and seq2 are redundant: they share the same id and the
# same final sequence. seq3 and seq4 likewise share an id.
seq1 = dict(align_code='test1',
            id='f754177b1db86fba508552fc536485f1MLGIGIIM',
            aaseq='MLGiI\tM',
            final_aaseq='MLGIIM')
seq2 = dict(align_code='test2',
            id='f754177b1db86fba508552fc536485f1MLGIGIIM',
            aaseq='M1LGIIM*',
            final_aaseq='MLGIIM')
seq3 = dict(align_code='test3',
            id='5a6451f4b89e9e9938d30b4d3f765397MLCVCVGI',
            aaseq='MLCVGI',
            final_aaseq='MLCVGI')
# seq4 contains an unknown residue (X). idX is the id expected when X is kept
# (used with keep='X' below); final_aaseq_X is presumably the matching
# retained-X sequence.
seq4 = dict(align_code='test4',
            id='5a6451f4b89e9e9938d30b4d3f765397MLCVCVGI',
            idX='f25547f23d5c4428632823e8255411a4MLCVCVXI',
            aaseq='MLCVXI',
            final_aaseq='MLCVGI',
            final_aaseq_X='MLCVXI')

def _make_seq_file(seqfile, seqs):
    f = open(seqfile, 'w')
    for s in seqs:
        print(">" + s['align_code'], file=f)
        print(s['aaseq'], file=f)

class AddSeqMPTests(modpipe.test.TestCase):
    """Tests for the AddSeqMP entry point (modpipe.main.AddSeqMP)."""

    def test_add_redundant(self):
        """Check that AddSeqMP.py handles redundant sequences correctly"""
        seqfile = 'test.fsa'
        unqfile = 'test.unq'
        _make_seq_file(seqfile, (seq1, seq2))
        self._run_add_seq(seqfile)
        lines = self._read_lines(unqfile)
        self.assertEqual(len(lines), 1)
        # Redundant sequences collapse onto a single unq line that lists
        # every align code sharing that id.
        self.assertEqual(lines[0].strip(),
                         '%s : %s %s' % (seq1['id'], seq1['align_code'],
                                         seq2['align_code']))
        os.unlink(unqfile)
        self._check_seq_file(seq1)

    def test_add_single(self):
        """Check that AddSeqMP.py adds single sequences correctly"""
        seqfile = 'foo.fsa'
        unqfile = 'foo.unq'
        _make_seq_file(seqfile, (seq3,))
        self._run_add_seq(seqfile)
        lines = self._read_lines(unqfile)
        self.assertEqual(len(lines), 1)
        self.assertEqual(lines[0].strip(),
                         '%s : %s' % (seq3['id'], seq3['align_code']))
        os.unlink(unqfile)
        self._check_seq_file(seq3)

    def test_add_pair(self):
        """Check that AddSeqMP.py adds pairs of sequences correctly"""
        seqfile = 'pair.fsa'
        unqfile = 'pair.unq'
        _make_seq_file(seqfile, (seq2, seq3))
        self._run_add_seq(seqfile)
        lines = self._read_lines(unqfile)
        self.assertEqual(len(lines), 2)
        lines.sort()  # unq file lines are not in guaranteed input order
        # After sorting, seq3's id ('5a...') precedes seq2's id ('f7...').
        self.assertEqual(lines[1].strip(),
                         '%s : %s' % (seq2['id'], seq2['align_code']))
        self.assertEqual(lines[0].strip(),
                         '%s : %s' % (seq3['id'], seq3['align_code']))
        os.unlink(unqfile)
        self._check_seq_file(seq2)
        self._check_seq_file(seq3)

    @staticmethod
    def _read_lines(path):
        """Return all lines of *path*, closing the file afterwards."""
        with open(path) as fh:
            return fh.readlines()

    def _check_seq_file(self, seq):
        """Check the per-sequence FASTA file for *seq*, then remove it.

        The file is expected at <id[:3]>/<id>/sequence/<id>.fsa and to
        contain exactly a '>id' header and the cleaned sequence. The file
        and its three enclosing directories are deleted afterwards.
        """
        seq_id = seq['id']
        topdir = seq_id[:3]
        iddir = os.path.join(topdir, seq_id)
        seqdir = os.path.join(iddir, 'sequence')
        fname = os.path.join(seqdir, '%s.fsa' % seq_id)
        lines = self._read_lines(fname)
        self.assertEqual(len(lines), 2)
        self.assertEqual(lines[0].strip(), '>%s' % seq_id)
        self.assertEqual(lines[1].strip(), seq['final_aaseq'])
        os.unlink(fname)
        # Remove directories deepest-first.
        os.rmdir(seqdir)
        os.rmdir(iddir)
        os.rmdir(topdir)

    def _run_add_seq(self, seqfile):
        """Run AddSeqMP on *seqfile* with DATDIR set to the current directory.

        A temporary 'modpipe.conf' is written for the run. Both it and
        *seqfile* are removed afterwards, even if the script raises.
        """
        conf = 'modpipe.conf'
        with open(conf, 'w') as fh:
            print("DATDIR  %s" % os.getcwd(), file=fh)
        try:
            self.run_script('main', AddSeqMP,
                            ['--conf_file', conf, '--sequence_file', seqfile])
        finally:
            os.unlink(conf)
            os.unlink(seqfile)


class AddSeqTests(modpipe.test.TestCase):
    """Tests for the standalone src/AddSeq.py script (FASTA -> PIR)."""

    @staticmethod
    def _read_lines(path):
        """Return all lines of *path*, closing the file afterwards."""
        with open(path) as fh:
            return fh.readlines()

    def _check_output_pir(self, seqfile, seq_id, idtype, outdir, unqfile,
                          keep=''):
        """Run AddSeq.py on *seqfile* for each directory layout and check it.

        For each of the SIMPLE, PDB and MODPIPE output layouts, run AddSeq.py
        converting FASTA to PIR with id type *idtype* and keep string *keep*.
        Then check that:
        - *unqfile* holds a single line mapping *seq_id* to seq4's align code;
        - the PIR file written under *outdir* starts with the header
          '>P1;<seq_id>'.
        The unq file, the PIR file and the layout's directories are removed
        after each run; *outdir* itself is left in place.
        """
        binary = self.get_modpipe_binary_path('src/AddSeq.py')
        layouts = (('SIMPLE', (seq_id,)),
                   ('PDB', (seq_id[1:3], seq_id)),
                   ('MODPIPE', (seq_id[:3], seq_id)))
        for dirtype, dirs in layouts:
            p = subprocess.Popen([binary, '-f', 'FASTA', '-g', 'PIR',
                                  '-c', idtype, '-s', dirtype, '-k', keep,
                                  seqfile, outdir],
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
            # communicate() drains both pipes concurrently. Reading stdout
            # to EOF before touching stderr can deadlock if the child fills
            # the stderr pipe buffer first.
            _, err = p.communicate()
            print(err)
            self.require_clean_exit(p)
            lines = self._read_lines(unqfile)
            self.assertEqual(len(lines), 1)
            self.assertEqual(lines[0].strip(),
                             '%s : %s' % (seq_id, seq4['align_code']))
            os.unlink(unqfile)
            pirname = os.path.join(outdir, os.path.join(*dirs),
                                   '%s.pir' % seq_id)
            pir = self._read_lines(pirname)
            self.assertEqual(pir[0].rstrip(), ">P1;" + seq_id)
            os.unlink(pirname)
            # Remove the layout's directories deepest-first.
            for n in range(len(dirs), 0, -1):
                os.rmdir(os.path.join(outdir, *dirs[:n]))

    def _check_seq4_conversions(self, id_types, keep):
        """Write seq4 to a FASTA file and check AddSeq.py output for it.

        *id_types* is a sequence of (expected id, id type) pairs, each passed
        to _check_output_pir with the given *keep* string. Any stale output
        directory is removed first. The input file and (now empty) output
        directory are removed afterwards.
        """
        seqfile = 'foo.fsa'
        unqfile = 'foo.unq'
        outdir = 'tmpout'
        if os.path.exists(outdir):
            shutil.rmtree(outdir)
        _make_seq_file(seqfile, (seq4,))
        for seq_id, idtype in id_types:
            self._check_output_pir(seqfile, seq_id, idtype, outdir, unqfile,
                                   keep)
        os.unlink(seqfile)
        os.rmdir(outdir)

    def test_add_single_convert(self):
        """Check that AddSeq.py adds single sequences with unknown residues correctly"""
        self._check_seq4_conversions(((seq4['id'], 'MD5'), ('test4', 'CODE')),
                                     keep='')

    def test_add_single_retain(self):
        """Check that AddSeq.py adds single sequences retaining unknown residues correctly"""
        self._check_seq4_conversions(((seq4['idX'], 'MD5'), ('test4', 'CODE')),
                                     keep='X')

if __name__ == '__main__':
    unittest.main()
