import unittest
import modpipe
import modpipe.test
import modpipe.serialize
import sys
if sys.version_info[0] == 2:
    from io import BytesIO as TextIO
else:
    from io import StringIO as TextIO

# Short alias for the module under test; deleted once the fixtures are built
# so the module namespace is unchanged.
_ser = modpipe.serialize

# Component objects shared by the Hit and Model fixtures below.
_seq = _ser.Sequence(id='seqid', length=100)
_aln = _ser.Alignment(id='alnid', evalue=0.0, gap_percentage=5,
                      score_chi_squared=0.1, score_ks=0.2)
_tmpl = _ser.Template(code='1abc', chain='A', region=[1, 20],
                      sequence_identity=60)
_ga341 = _ser.GA341(total=1.0, compactness=2.0, distance=3.0,
                    surface_area=4.0, combined=5.0, z_distance=6.0,
                    z_surface_area=7.0, z_combined=8.0)
# NOTE(review): _tsvmod is neither passed to _score nor used by
# SerializeTests; confirm whether it is still needed.
_tsvmod = _ser.TSVMod(type=1.0, predicted_rmsd=2.0, predicted_no35=3.0,
                      features=4.0, relax_count=5.0, set_size=6.0)
_score = _ser.Score(objfunc=10.0, dope=20.0, dope_hr=30.0,
                    normalized_dope=40.0, quality=50.0, ga341=_ga341)

# Top-level records exercised by SerializeTests: one Hit and one Model,
# both assembled from the components above.
hit = _ser.Hit(id='foo', sequence=_seq, alignment=_aln,
               region=[1, 2], fold_assignment_method='fold',
               highest_sequence_identity=100, templates=[_tmpl])
mdl = _ser.Model(id='foo', sequence=_seq, alignment=_aln,
                 region=[1, 2], fold_assignment_method='fold',
                 hetatms=1, waters=4, score=_score,
                 highest_sequence_identity=100, rating='11001',
                 templates=[_tmpl])

del _ser

class SerializeTests(modpipe.test.TestCase):
    """Test the modpipe.serialize module.

    Each test runs the same checks for both record kinds handled by the
    module: hits (``*_hits_file``) and models (``*_models_file``).
    """

    def test_write_file(self):
        """Test the modpipe.serialize.write_*_file functions"""
        for (meth, data, expected) in \
             [(modpipe.serialize.write_hits_file, [hit], '- !<Hit>'),
              (modpipe.serialize.write_models_file, [mdl], '- !<Model>')]:
            out = TextIO()
            # Swapped arguments (stream first, then data) must be rejected
            self.assertRaises(ValueError, meth, out, data)
            meth(data, out)
            out.seek(0)
            lines = out.readlines()
            self.assertTrue(lines, 'nothing was written')
            self.assertTrue(lines[0].startswith('- !<ModPipeVersion>'),
                            'missing version header: %r' % lines[0])
            # Some versions of YAML put entire header on one line (with
            # version info within {}), others split it over 3 lines, so the
            # first record follows at index 1 or 3 respectively
            first = 1 if '{' in lines[0] else 3
            self.assertTrue(len(lines) > first,
                            'no record after version header: %r' % lines)
            self.assertTrue(lines[first].startswith(expected),
                            'expected %r, got %r' % (expected, lines[first]))

    def test_empty_write_file(self):
        """Test modpipe.serialize.write_*_file with empty lists"""
        for meth in (modpipe.serialize.write_hits_file,
                     modpipe.serialize.write_models_file):
            out = TextIO()
            meth([], out)
            out.seek(0)
            lines = out.readlines()
            # Should get only the version header. Some versions of YAML
            # output it all on one line, some split it over 3 lines; check
            # the length first so an empty output fails cleanly
            self.assertIn(len(lines), (1, 3))
            self.assertTrue(lines[0].startswith('- !<ModPipeVersion>'),
                            'missing version header: %r' % lines[0])
            # With append=True and no records, nothing at all is expected:
            out = TextIO()
            meth([], out, append=True)
            out.seek(0)
            self.assertEqual(out.readlines(), [])

    def test_read_file(self):
        """Test the modpipe.serialize.read_*_file functions"""
        for (methout, methin, methcount, data) in \
             [(modpipe.serialize.write_hits_file,
               modpipe.serialize.read_hits_file,
               modpipe.serialize.count_hits_file, [hit]),
              (modpipe.serialize.write_models_file,
               modpipe.serialize.read_models_file,
               modpipe.serialize.count_models_file, [mdl])]:
            # Error checks call the reader *inside* the callable given to
            # assertRaises and consume it, so the expected exception is
            # caught whether the reader raises on call or on iteration.
            out = TextIO()
            # Empty files should raise an error:
            self.assertRaises(modpipe.FileFormatError,
                              lambda: list(methin(out)))
            # Records written with append=True carry no version header;
            # files lacking a version header should raise an error:
            methout(data, out, append=True)
            out.seek(0)
            self.assertRaises(modpipe.FileFormatError,
                              lambda: list(methin(out)))
            out = TextIO()
            methout(data, out)
            out.seek(0)
            # A plain string ('foo') in place of a stream must raise
            # ValueError
            self.assertRaises(ValueError, lambda: list(methin('foo')))
            newdata = list(methin(out))
            self.assertEqual(len(newdata), len(data))
            out.seek(0)
            self.assertEqual(methcount(out), len(data))
            # Round-tripped records keep their class, id and region
            for (a, b) in zip(newdata, data):
                self.assertIs(type(a), type(b))
                self.assertEqual(a.id, b.id)
                self.assertEqual(a.region, b.region)

def _main():
    """Run this module's test cases with the unittest command-line runner."""
    unittest.main()


if __name__ == '__main__':
    _main()
