from __future__ import print_function
import unittest
import sys
import os
import filecmp
import modpipe.test
from modpipe.scripts import MakeChains

class MakeChainsTests(modpipe.test.TestCase):
    """Tests for the MakeChains script.

    The tests write their temporary inputs and outputs (``pdb-list``,
    ``pdb-list.pir``, ``<code><chain>.chn``) into the current working
    directory and read PDB files from ``../db/pdb``.
    """

    # Each entry is (PDB code, chain id, ...). The PDB code is written to the
    # PDB list file; each chain id corresponds to an individual
    # ``<code><chain>.chn`` file that MakeChains is expected to produce.
    # The order matters: it must match the reference file
    # ``../db/test-pdb.pir`` used by test_make_one_file.
    pdb_codes = [('1apx', 'A', 'B', 'C', 'D'),
                 ('1bwy', 'A'), ('1cbi', 'A'),
                 ('1ecs', 'A', 'B'), ('1f9z', 'A'),
                 ('1fdq', 'A'), ('1flj', 'A'),
                 ('1ftp', 'A'), ('1g7n', 'A'),
                 ('1ggl', 'A'), ('1h0z', 'A'),
                 ('1hcb', 'A'), ('1itk', 'A'),
                 ('1iyn', 'A'), ('1j4w', 'A'),
                 ('1jd0', 'A'), ('1keq', 'A'),
                 ('1kll', 'A'), ('1kop', 'A'),
                 ('1kqw', 'A'), ('1llp', 'A'),
                 ('1lpj', 'A'), ('1lug', 'A'),
                 ('1mwv', 'A'), ('1n8y', 'C'),
                 ('1o8v', 'A'), ('1oaf', 'A'),
                 ('1opa', 'A'), ('1pmp', 'A'),
                 ('1qip', 'A'), ('1rj5', 'A'),
                 ('1tou', 'A'), ('3dh4', 'A'),
                 ('3agz', 'D'),
                 ('1u2k', 'A'), ('1ub2', 'A'),
                 ('1v9e', 'A'), ('1vyf', 'A'),
                 ('1zby', 'A'), ('1znc', 'A'),
                 ('1b56', 'A'), ('1mdc', 'A'),
                 ('1hmr', 'A'), ('1cbs', 'A'),
                 ('1crb', 'A'),
                 ('1sj2', 'A', 'B')]

    # PDB codes that MakeChains is expected to filter out (they don't
    # fulfill the minimum criteria), so no chain files are expected for them.
    _filtered_codes = ('3agz',)

    def _make_pdb_list(self, prefix='pdb', suffix='.ent'):
        """Write a PDB list file naming every entry in ``pdb_codes``.

        Each line is ``prefix + code + suffix`` (e.g. ``pdb1apx.ent``).

        :param prefix: text prepended to each PDB code.
        :param suffix: text appended to each PDB code.
        :return: the name of the list file written in the current directory.
        """
        tmp = 'pdb-list'
        with open(tmp, 'w') as f:
            for pdb in self.pdb_codes:
                print(prefix + pdb[0] + suffix, file=f)
        return tmp

    @staticmethod
    def _remove_if_exists(fname):
        """Delete *fname* if it is present; do nothing otherwise."""
        if os.path.exists(fname):
            os.unlink(fname)

    def test_pdblist_exists(self):
        """Non-existing PDB list file should result in a MakeChains error"""
        self.assertRaises(IOError, self.run_script, 'scripts', MakeChains,
                          ['/does/not/exist'])

    def test_required_arguments(self):
        """Check required MakeChains arguments"""
        self.assertRaises(SystemExit, self.run_script, 'scripts',
                          MakeChains, [])

    def test_make_individual_chains(self):
        """Check production of individual chains files"""
        expected = [pdb[0] + chain + '.chn'
                    for pdb in self.pdb_codes
                    if pdb[0] not in self._filtered_codes
                    for chain in pdb[1:]]
        pdblist = self._make_pdb_list()
        try:
            self.run_script('scripts', MakeChains,
                            ['-p', '../db/pdb', pdblist])
            # Report every missing chain file at once rather than failing
            # on the first one.
            missing = [fname for fname in expected
                       if not os.path.exists(fname)]
            self.assertEqual(missing, [],
                             "Chain files not produced: %s" % missing)
        finally:
            # Clean up even if the script or the assertion failed, so stale
            # files cannot affect later runs.
            for fname in expected:
                self._remove_if_exists(fname)
            self._remove_if_exists(pdblist)

    def test_make_one_file(self):
        """Check production of one aggregated chains file"""
        pdblist = self._make_pdb_list()
        out = 'pdb-list.pir'
        # Remove any output left over from a previous run so the comparison
        # below checks freshly written content.
        self._remove_if_exists(out)
        try:
            self.run_script('scripts', MakeChains,
                            ['-o', out, '-p', '../db/pdb', pdblist])
            # shallow=False: compare file contents, not just os.stat()
            # signatures.
            self.assertTrue(filecmp.cmp(out, '../db/test-pdb.pir',
                                        shallow=False))
        finally:
            self._remove_if_exists(out)
            self._remove_if_exists(pdblist)

    def test_invalid_format_pdb_list(self):
        """Check for invalid format PDB list files for MakeChains"""
        # An empty prefix yields lines like '1apx.ent' instead of the
        # 'pdb1apx.ent' form the other tests use.
        pdblist = self._make_pdb_list(prefix='')
        out = 'pdb-list.pir'
        try:
            self.assertRaises(modpipe.FileFormatError, self.run_script,
                              'scripts', MakeChains,
                              ['-o', out, '-p', '../db/pdb', pdblist])
        finally:
            # MakeChains may fail before it creates the output file, so
            # only remove it if it is actually there.
            self._remove_if_exists(out)
            self._remove_if_exists(pdblist)

if __name__ == '__main__':
    # Executed directly (not imported by a test runner): discover and run
    # every TestCase defined in this module.
    unittest.main()
