# This file is part of ModPipe, Copyright 1997-2020 Andrej Sali
#
# ModPipe is free software: you can redistribute it and/or modify
# it under the terms of version 2 of the GNU General Public License
# as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ModPipe.  If not, see <http://www.gnu.org/licenses/>.


import modpipe.version
import modpipe
import sys
import os

# Get Python-version specific directory for yaml
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__),
    '..', '..', 'python%d' % sys.version_info[0])))
import yaml

# Use faster C implementations if available
try:
    from yaml import CLoader as Loader
    from yaml import CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

# YAMLObjects must use the same loader/dumper as calls to dump/load_all,
# otherwise we get 'could not determine a constructor' yaml ConstructorErrors
class _YAMLObject(yaml.YAMLObject):
    """Common base for every ModPipe YAML-serializable class.

    Pins the dumper and loader to the module-wide ``Dumper``/``Loader``
    pair, so that tagged objects can be round-tripped with the same
    classes passed to ``yaml.dump``/``yaml.load_all`` in this module."""
    yaml_dumper = Dumper
    yaml_loader = Loader

class _ModPipeVersion(_YAMLObject):
    """Version header written as the first element of hits/models files.

    Stores the ModPipe version string returned by ``modpipe.version.get()``
    and an integer ``file`` field (presumably the file-format revision --
    confirm)."""
    yaml_tag = u'ModPipeVersion'

    def __init__(self):
        self.file = 1
        self.modpipe = modpipe.version.get()

class Hit(_YAMLObject):
    """A single hit reported by ModPipe.

    Every constructor argument is stored unchanged as an attribute of the
    same name; instances are serialized under the ``Hit`` YAML tag."""
    yaml_tag = u'Hit'

    def __init__(self, id, sequence, alignment, region,
                 fold_assignment_method, highest_sequence_identity,
                 templates):
        self.id = id
        self.sequence = sequence
        self.alignment = alignment
        self.region = region
        self.fold_assignment_method = fold_assignment_method
        self.highest_sequence_identity = highest_sequence_identity
        self.templates = templates

class Sequence(_YAMLObject):
    """Identifier and length of a sequence, tagged ``Sequence`` in YAML."""
    yaml_tag = u'Sequence'

    def __init__(self, id, length):
        self.id = id
        self.length = length

class Alignment(_YAMLObject):
    """Alignment summary values, tagged ``Alignment`` in YAML.

    The constructor arguments are kept verbatim as same-named attributes."""
    yaml_tag = u'Alignment'

    def __init__(self, id, evalue, gap_percentage, score_chi_squared, score_ks):
        self.id = id
        self.evalue = evalue
        self.gap_percentage = gap_percentage
        self.score_chi_squared = score_chi_squared
        self.score_ks = score_ks

class Model(_YAMLObject):
    """A single model built by ModPipe.

    Every constructor argument is kept verbatim as an attribute of the
    same name; instances are serialized under the ``Model`` YAML tag."""
    yaml_tag = u'Model'

    def __init__(self, id, sequence, alignment, region, fold_assignment_method,
                 hetatms, waters, score, highest_sequence_identity, rating,
                 templates):
        self.id = id
        self.sequence = sequence
        self.alignment = alignment
        self.region = region
        self.fold_assignment_method = fold_assignment_method
        self.hetatms = hetatms
        self.waters = waters
        self.score = score
        self.highest_sequence_identity = highest_sequence_identity
        self.rating = rating
        self.templates = templates

class TSVMod(_YAMLObject):
    """TSVMod prediction values, tagged ``TSVMod`` in YAML.

    The constructor arguments are kept verbatim as same-named attributes."""
    yaml_tag = u'TSVMod'

    def __init__(self, type, predicted_rmsd, predicted_no35,
                 features, relax_count, set_size):
        self.type = type
        self.predicted_rmsd = predicted_rmsd
        self.predicted_no35 = predicted_no35
        self.features = features
        self.relax_count = relax_count
        self.set_size = set_size

class GA341(_YAMLObject):
    """GA341 score components, tagged ``GA341`` in YAML.

    The constructor arguments are kept verbatim as same-named attributes."""
    yaml_tag = u'GA341'

    def __init__(self, total, compactness, distance, surface_area,
                 combined, z_distance, z_surface_area, z_combined):
        self.total = total
        self.compactness = compactness
        self.distance = distance
        self.surface_area = surface_area
        self.combined = combined
        self.z_distance = z_distance
        self.z_surface_area = z_surface_area
        self.z_combined = z_combined

class Score(_YAMLObject):
    """Model assessment scores, tagged ``Score`` in YAML.

    The constructor arguments are kept verbatim as same-named attributes
    (*ga341* is presumably a :class:`GA341` instance -- confirm at callers)."""
    yaml_tag = u'Score'

    def __init__(self, objfunc, dope, dope_hr, normalized_dope, quality,
                 ga341):
        self.objfunc = objfunc
        self.dope = dope
        self.dope_hr = dope_hr
        self.normalized_dope = normalized_dope
        self.quality = quality
        self.ga341 = ga341

class NativeBenchmark(_YAMLObject):
    """Native-structure benchmark values, tagged ``NativeBenchmark`` in YAML.

    Every constructor argument is stored verbatim as an attribute of the
    same name."""
    yaml_tag = u'NativeBenchmark'

    def __init__(self, chain, code, cutoff_rms, global_num_equiv_pos,
                 global_rms, length, mean_cutoff_rms, mean_num_equiv_pos,
                 num_equiv_pos_35, region, cutoff_rms_35):
        self.chain = chain
        self.code = code
        self.cutoff_rms = cutoff_rms
        self.global_num_equiv_pos = global_num_equiv_pos
        self.global_rms = global_rms
        self.length = length
        self.mean_cutoff_rms = mean_cutoff_rms
        self.mean_num_equiv_pos = mean_num_equiv_pos
        self.num_equiv_pos_35 = num_equiv_pos_35
        self.region = region
        # This must be an attribute assignment; binding only the local name
        # would leave instances without ``cutoff_rms_35``.
        self.cutoff_rms_35 = cutoff_rms_35


class Template(_YAMLObject):
    """Template code, chain, region and identity, tagged ``Template`` in YAML.

    The constructor arguments are kept verbatim as same-named attributes."""
    yaml_tag = u'Template'

    def __init__(self, code, chain, region, sequence_identity):
        self.code = code
        self.chain = chain
        self.region = region
        self.sequence_identity = sequence_identity

def write_hits_file(hits, stream, append=False):
    """Serialize a list of :class:`Hit` objects to *stream* as YAML.

       Unless *append* is true, a version header is dumped first; with
       *append* set, the caller guarantees the header is already present
       in the stream.  When *hits* is empty, no hit list is dumped at all.

       Raises ValueError if *stream* is not file-like."""
    _check_stream(stream)
    dump_options = dict(Dumper=Dumper, default_flow_style=None)
    if not append:
        yaml.dump([_ModPipeVersion()], stream, **dump_options)
    if len(hits) == 0:
        return
    yaml.dump(hits, stream, **dump_options)

def write_models_file(models, stream, append=False, separator=False):
    """Serialize a list of :class:`Model` objects to *stream* as YAML.

       Unless *append* is true, a version header is dumped first; with
       *append* set, the caller guarantees the header is already present.
       When both *append* and *separator* are true and *models* is
       non-empty, a ``---`` document separator precedes the models, so they
       form a new YAML document.  Nothing beyond the header is written for
       an empty *models*.

       Raises ValueError if *stream* is not file-like."""
    _check_stream(stream)
    dump_options = dict(Dumper=Dumper, default_flow_style=None)
    if not append:
        yaml.dump([_ModPipeVersion()], stream, **dump_options)
    if len(models) == 0:
        return
    if append and separator:
        stream.write("---\n")
    yaml.dump(models, stream, **dump_options)

def _check_stream(stream, desc='stream'):
    """Raise ValueError unless *stream* looks like a file-like object.

    Only the presence of a ``read`` attribute is tested.  *desc* names the
    offending parameter in the error message."""
    if hasattr(stream, 'read'):
        return
    message = ("Expecting a file-like object for '%s' parameter - got %s"
               % (desc, str(stream)))
    raise ValueError(message)

def _check_list_of(models, typ, desc, first=True):
    """Validate one YAML document loaded from a hits/models file.

    :param models: the loaded document; it must be a list.
    :param typ: class that every payload element must be an instance of.
    :param desc: human-readable file description used in error messages.
    :param first: True for the first document of the stream, which must
        begin with a :class:`_ModPipeVersion` header.  Later documents
        (e.g. models appended after a ``---`` separator) carry no header,
        so all of their elements are payload.
    :raises modpipe.FileFormatError: if any check fails.
    """
    if not isinstance(models, list):
        raise modpipe.FileFormatError(desc + " should be a YAML list; "
                                      + "got %s" % str(models))
    if first:
        if len(models) < 1 or not isinstance(models[0], _ModPipeVersion):
            raise modpipe.FileFormatError(desc + " should start with version")
        payload = models[1:]
    else:
        # No header here: element 0 is payload too and must be checked,
        # otherwise a wrong-typed object would slip through to the caller.
        payload = models
    for m in payload:
        if not isinstance(m, typ):
            raise modpipe.FileFormatError(desc + " contains something"
                                          + " not a " + str(typ) + ": "
                                          + str(m))

def read_models_file(stream):
    """Yield the :class:`Model` objects stored in the given stream.

       This is a generator; the stream is only checked and parsed once
       iteration starts.  Every YAML document is validated: the file must
       start with a version header, and every other element must be a
       :class:`Model`.

       Raises ValueError if *stream* is not file-like, and
       modpipe.FileFormatError if the contents are not a valid models file
       (including a stream containing no YAML documents at all)."""
    _check_stream(stream)
    first = True
    # NOTE(review): Loader is the module's (C)Loader, not yaml.SafeLoader,
    # because it must resolve this module's custom YAML tags.  Depending on
    # the PyYAML version it may construct arbitrary Python objects, so do
    # not feed it files from untrusted sources.
    for models in yaml.load_all(stream, Loader=Loader):
        _check_list_of(models, Model, "Models file", first)
        if first:
            # The version header was just validated; drop it so that only
            # Model objects are yielded.
            del models[0]
            first = False
        for model in models:
            yield model
    # A stream with no documents has no version header; checking an empty
    # list raises the usual "should start with version" error.
    if first:
        _check_list_of([], Model, "Models file", first)

def read_hits_file(stream):
    """Yield the :class:`Hit` objects stored in the given stream.

       This is a generator; the stream is only checked and parsed once
       iteration starts.  Every YAML document is validated: the file must
       start with a version header, and every other element must be a
       :class:`Hit`.

       Raises ValueError if *stream* is not file-like, and
       modpipe.FileFormatError if the contents are not a valid hits file
       (including a stream containing no YAML documents at all)."""
    _check_stream(stream)
    first = True
    # NOTE(review): Loader is the module's (C)Loader, not yaml.SafeLoader,
    # because it must resolve this module's custom YAML tags.  Depending on
    # the PyYAML version it may construct arbitrary Python objects, so do
    # not feed it files from untrusted sources.
    for hits in yaml.load_all(stream, Loader=Loader):
        _check_list_of(hits, Hit, "Hits file", first)
        if first:
            # The version header was just validated; drop it so that only
            # Hit objects are yielded.
            del hits[0]
            first = False
        for hit in hits:
            yield hit
    # A stream with no documents has no version header; checking an empty
    # list raises the usual "should start with version" error.
    if first:
        _check_list_of([], Hit, "Hits file", first)

def _count_yaml_file(stream, count_line):
    """Return how many lines of *stream* equal *count_line*.

    Trailing carriage-return/newline characters are stripped before each
    comparison.  No YAML parsing is performed.  Raises ValueError if
    *stream* is not file-like."""
    _check_stream(stream)
    return sum(1 for text in stream if text.rstrip('\r\n') == count_line)

def count_hits_file(stream):
    """Return the number of :class:`Hit` objects in the given stream.

       Only lines reading exactly ``- !<Hit>`` are counted; nothing is
       parsed, validated or constructed, which makes this faster than
       ``len(list(read_hits_file(...)))``."""
    hit_marker = '- !<Hit>'
    return _count_yaml_file(stream, hit_marker)

def count_models_file(stream):
    """Return the number of :class:`Model` objects in the given stream.

       Only lines reading exactly ``- !<Model>`` are counted; nothing is
       parsed, validated or constructed, which makes this faster than
       ``len(list(read_models_file(...)))``."""
    model_marker = '- !<Model>'
    return _count_yaml_file(stream, model_marker)
