# Note that this code needs Python 2.6 or later, unlike much of Modeller
# which works as far back as 2.3, so imports need to be protected by a
# version check

import modeller.features


class _CifAtom(object):
    """Map a Modeller Atom to mmCIF ids"""
    def __init__(self, atom, entity_map):
        self._atom = atom
        chain = atom.residue.chain
        # Modeller chain index is 1-based
        self.entity_id = entity_map[chain.index - 1]
        # Get seq_id offset (index is from start of model, but we need it to
        # be from the start of the chain)
        self.chain_offset = chain.residues[0].index - 1
    # Evaluate only when needed
    atom_id = property(lambda self: self._atom.name)
    comp_id = property(lambda self: self._atom.residue.name)
    asym_id = property(lambda self: self._atom.residue.chain.name)
    seq_id = property(
        lambda self: self._atom.residue.index - self.chain_offset)


class TemplateChain(object):
    """One chain of a template.

       Instances start out empty; their attributes are attached by the
       code that creates them."""


class TargetChain(object):
    """One chain of the target.

       Instances start out empty; their attributes are attached by the
       code that creates them."""


class Alignment(object):
    """Multiple alignment of a single TargetChain against one or more
       TemplateChains.

       Instances start out empty; their attributes are attached by the
       code that creates them."""


class Data(object):
    """Base class for an item of data used in the modeling."""


class DataGroup(list):
    """A group of data, held as a list.

       Being a list subclass (rather than a plain list) allows extra
       attributes, such as an id, to be set on a group."""


class TemplateData(Data):
    """Data entry for a template used in the modeling."""
    # Reported as the content_type of this entry in the _ma_data loop
    content_type = 'template structure'


class AlignmentData(Data):
    """Data entry for an alignment used in the modeling."""
    # Reported as the content_type of this entry in the _ma_data loop
    content_type = 'target-template alignment'


class RestraintData(Data):
    """Data entry for restraints used in the modeling."""
    # Reported as the content_type of this entry in the _ma_data loop
    content_type = 'restraint'


class TargetData(Data):
    """Data entry for the target sequence."""
    # Reported as the content_type of this entry in the _ma_data loop
    content_type = 'target'


class ModelData(Data):
    """Data entry for the final model output."""
    # Reported as the content_type of this entry in the _ma_data loop
    content_type = 'model coordinates'


class Restraints(object):
    """Base class for a set of restraints used in the modeling."""


class DistanceRestraints(Restraints, list):
    """Restraint set that is also a list of the individual distance
       restraints used in the modeling."""
    # Reported as the restraint_type of this set in the _ma_restraints loop
    restraint_type = 'Distance restraints'


class CifData(object):
    def __init__(self):
        """Start with no data and two empty data groups: one for modeling
           input and one for modeling output."""
        self.data = []
        self.data_groups = []
        self.modeling_input_data = DataGroup()
        self.modeling_output_data = DataGroup()
        # Group ids follow registration order, so the input group must be
        # registered before the output group
        for group in (self.modeling_input_data, self.modeling_output_data):
            self._add_data_group(group)

    def add_alignment_info(self, aln, knowns, seq, mdl):
        """Record the target chains, template chains and alignments.

           aln is indexed by seq to obtain the target sequence, and by each
           entry of knowns to obtain the templates. mdl is kept as self.mdl
           and also supplies the chain names of the target chains."""
        self.mdl = mdl
        target_seq = aln[seq]
        self._add_target_chains(target_seq, mdl)
        self._add_template_chains(aln, knowns, target_seq)
        self._add_alignments(aln)

    def add_restraints(self, user_restraints):
        """Reset the list of restraint sets, then register the distance
           restraints found in user_restraints."""
        self.restraints = []
        # Only handle single-feature restraints for now: anything whose
        # feature is not a Distance is skipped
        distance_only = (
            rsr for rsr in user_restraints
            if isinstance(rsr._features, modeller.features.Distance))
        self._add_distance_restraints(distance_only)

    def _add_distance_restraints(self, rs):
        """Store the restraints in rs as self.distance_restraints, and
           register the set only if it is not empty."""
        collected = DistanceRestraints(rs)
        self.distance_restraints = collected
        if collected:
            self._add_restraint(collected)

    def _add_data(self, data, data_group):
        """Add some data used in the modeling."""
        self.data.append(data)
        data.id = len(self.data)
        data_group.append(data)

    def _add_data_group(self, data_group):
        """Register a group of data and give it an id."""
        groups = self.data_groups
        groups.append(data_group)
        # ids are 1-based positions in the list of groups
        data_group.id = len(groups)

    def _add_restraint(self, rs):
        """Register a set of restraints together with a RestraintData
           entry that describes it."""
        registered = self.restraints
        registered.append(rs)
        # ids are 1-based positions in the list of restraint sets
        rs.id = len(registered)

        # Each restraint set gets its own Data entry, filed under the
        # modeling input
        rdata = RestraintData()
        rdata.name = 'User-provided restraints'
        rs.data = rdata
        self._add_data(rdata, self.modeling_input_data)

    def _get_gapped_sequence(self, chain):
        """Return a sequence including gaps for the given chain"""
        # todo: this is not quite right because leading gaps for the N terminus
        # will include trailing gaps of the C terminus of the previous chain
        def rescode_with_leading_gap(residue):
            return '-' * residue.get_leading_gaps() + residue.code
        return ''.join(rescode_with_leading_gap(r) for r in chain.residues)

    def _add_alignments(self, aln):
        """Create one Alignment per TargetChain and attach to each the
           template chains that point at that target chain."""
        # A single Data entry is shared by all of the alignments
        aln_data = AlignmentData()
        aln_data.name = 'Target Template Alignment'
        self._add_data(aln_data, self.modeling_input_data)

        self.alignments = []
        by_target = {}
        for aln_id, target in enumerate(self.target_chains, 1):
            entry = Alignment()
            entry.data = aln_data
            entry.target = target
            entry.templates = []
            entry.id = aln_id
            by_target[target] = entry
            self.alignments.append(entry)
        # Hand each template chain to the alignment of its target chain
        for template in self.template_chains:
            by_target[template.target_chain].templates.append(template)

    def _get_target_chain(self, chain, target):
        """Get the TargetChain object that aligns with this chain"""
        # We just return the first match. This will miss cases where a template
        # chain aligns with multiple target chains, but this isn't handled
        # by the dictionary anyway (and is usually a modeling error)
        for r in chain.residues:
            target_r = r.get_aligned_residue(target)
            if target_r:
                # Modeller chain index is 1-based
                return self.target_chains[target_r.chain.index - 1]

    def _add_template_chains(self, aln, knowns, target):
        """Create a TemplateChain for every chain of every template named
           in knowns, storing them all in self.template_chains."""
        self.template_chains = []
        for code in knowns:
            # One Data entry per template, shared by all of its chains
            tdata = TemplateData()
            tdata.name = 'Template Structure'
            self._add_data(tdata, self.modeling_input_data)

            for chain in aln[code].chains:
                residues = chain.residues
                # todo: handle non-standard residues
                one_letter = ''.join(r.code for r in residues)

                tc = TemplateChain()
                # Assume PDB code is first 4 characters of align code
                tc.pdb_code = code[:4].upper()
                # ids number the template chains, from 1, in creation order
                tc.id = len(self.template_chains) + 1
                tc.asym_id = chain.name
                tc.template_data = tdata
                tc.seq_range = (1, len(residues))
                tc.sequence = one_letter
                tc.gapped_sequence = self._get_gapped_sequence(chain)
                # The canonical sequence is currently the same string
                tc.sequence_can = one_letter
                tc.target_chain = self._get_target_chain(chain, target)
                self.template_chains.append(tc)

    def _get_target_entities(self, seq):
        # Mapping from chain # to entity #
        self._entity_map = {}
        seen_seqs = {}
        for nchain, chain in enumerate(seq.chains):
            s = tuple(r.type for r in chain.residues)
            seen_seqs[s] = None
            self._entity_map[nchain] = len(seen_seqs)

    def _add_target_chains(self, seq, mdl):
        """Create a TargetChain for every chain of seq, storing them all in
           self.target_chains."""
        self._get_target_entities(seq)
        self.target_chains = []
        for nchain, chain in enumerate(seq.chains):
            residues = chain.residues
            tc = TargetChain()
            # ids number the target chains from 1
            tc.id = nchain + 1
            # Take the chain ID from the output model rather than from the
            # original sequence, since chain IDs may have been changed
            tc.asym_id = mdl.chains[nchain].name
            tc.entity_id = self._entity_map[nchain]
            tc.sequence = ''.join(r.code for r in residues)
            tc.gapped_sequence = self._get_gapped_sequence(chain)
            tc.seq_range = (1, len(residues))
            self.target_chains.append(tc)

    def write_mmcif(self, writer):
        """Write all of the modeling loops to writer.

           Before anything is written, a TargetData entry is added to the
           modeling input and a ModelData entry to the modeling output."""
        target_seq = TargetData()
        target_seq.name = 'Target Sequence'
        self._add_data(target_seq, self.modeling_input_data)

        model_data = ModelData()
        model_data.name = 'Target Structure'
        self._add_data(model_data, self.modeling_output_data)

        self._write_software_groups(writer)
        self._write_template_details(writer)
        # Segments are written before the poly mapping and the alignment,
        # which both use the segment ids assigned here
        self._write_template_segments(writer)
        self._write_target_entity(writer, target_seq)
        for write_category in (self._write_poly_mapping,
                               self._write_alignment,
                               self._write_data,
                               self._write_data_groups,
                               self._write_restraints,
                               self._write_distance_restraints,
                               self._write_protocols):
            write_category(writer)
        self._write_model_list(writer, model_data)

    def _write_software_groups(self, writer):
        # Write a single group that includes software #1 (Modeller, already
        # written out by Model.write())
        with writer.loop("_ma_software_group",
                         ["ordinal_id", "group_id", "software_id",
                          "parameter_group_id"]) as lp:
            lp.write(ordinal_id=1, group_id=1, software_id=1)

    def _write_template_details(self, writer):
        with writer.loop(
                "_ma_template_trans_matrix",
                ["id",
                 "rot_matrix[1][1]", "rot_matrix[2][1]", "rot_matrix[3][1]",
                 "rot_matrix[1][2]", "rot_matrix[2][2]", "rot_matrix[3][2]",
                 "rot_matrix[1][3]", "rot_matrix[2][3]", "rot_matrix[3][3]",
                 "tr_vector[1]", "tr_vector[2]", "tr_vector[3]"]) as lp:
            lp.write(id=1, rot_matrix11=1.0, rot_matrix21=0.0,
                     rot_matrix31=0.0, rot_matrix12=0.0,
                     rot_matrix22=1.0, rot_matrix32=0.0,
                     rot_matrix13=0.0, rot_matrix23=0.0,
                     rot_matrix33=1.0, tr_vector1=0.0,
                     tr_vector2=0.0, tr_vector3=0.0)

        with writer.loop('_ma_template_details',
                         ['ordinal_id', 'template_id',
                          'template_origin',
                          'template_entity_type', 'template_trans_matrix_id',
                          'template_data_id', 'target_asym_id',
                          'template_model_num',
                          'template_auth_asym_id']) as lp:
            for t in self.template_chains:
                lp.write(ordinal_id=t.id, template_id=t.id,
                         template_origin='?',
                         template_entity_type='polymer',
                         template_trans_matrix_id=1,
                         template_data_id=t.template_data.id,
                         target_asym_id=t.target_chain.asym_id,
                         template_model_num=1, template_auth_asym_id=t.asym_id)

        with writer.loop('_ma_template_poly',
                         ['template_id', 'seq_one_letter_code',
                          'seq_one_letter_code_can']) as lp:
            for t in self.template_chains:
                lp.write(template_id=t.id,
                         seq_one_letter_code=t.sequence,
                         seq_one_letter_code_can=t.sequence_can)


    def _write_template_segments(self, writer):
        ordinal = 1
        with writer.loop('_ma_template_poly_segment',
                         ['id', 'template_id', 'residue_number_begin',
                          'residue_number_end']) as lp:
            for t in self.template_chains:
                lp.write(id=ordinal, template_id=t.id,
                         residue_number_begin=t.seq_range[0],
                         residue_number_end=t.seq_range[1])
                t.segment_id = ordinal
                ordinal += 1

    def _write_target_entity(self, writer, target_seq):
        all_entities = set(t.entity_id for t in self.target_chains)
        with writer.loop('_ma_target_entity',
                         ['entity_id', 'data_id', 'origin']) as lp:
            for entity_id in sorted(all_entities):
                lp.write(entity_id=entity_id,
                         data_id=target_seq.id, origin='?')

        with writer.loop('_ma_target_entity_instance',
                         ['asym_id', 'entity_id', 'details']) as lp:
            for t in self.target_chains:
                lp.write(asym_id=t.asym_id, entity_id=t.entity_id,
                         details="Model subunit")

    def _write_poly_mapping(self, writer):
        ordinal = 1
        with writer.loop('_ma_target_template_poly_mapping',
                         ['id', 'template_segment_id', 'target_asym_id',
                          'target_seq_id_begin', 'target_seq_id_end']) as lp:
            for a in self.alignments:
                for t in a.templates:
                    lp.write(id=ordinal, template_segment_id=t.segment_id,
                             target_asym_id=a.target.asym_id,
                             target_seq_id_begin=a.target.seq_range[0],
                             target_seq_id_end=a.target.seq_range[1])
                    ordinal += 1

    def _write_alignment(self, writer):
        # todo: populate with info on how the alignment was made
        with writer.loop('_ma_alignment_info',
                         ['alignment_id', 'data_id',
                          'alignment_length', 'alignment_type']) as lp:
            for a in self.alignments:
                lp.write(alignment_id=a.id, data_id=a.data.id,
                         alignment_length=len(a.target.gapped_sequence),
                         alignment_type='target-template pairwise alignment')
        with writer.loop('_ma_alignment_details',
                         ['ordinal_id', 'alignment_id', 'template_segment_id',
                          'target_asym_id']) as lp:
            ordinal = 1
            for a in self.alignments:
                for template in a.templates:
                    lp.write(ordinal_id=ordinal, alignment_id=a.id,
                             template_segment_id=template.segment_id,
                             target_asym_id=a.target.asym_id)

        with writer.loop('_ma_alignment',
                         ['ordinal_id', 'alignment_id', 'target_template_flag',
                          'sequence']) as lp:
            ordinal = 1
            for a in self.alignments:
                for template in a.templates:
                    lp.write(ordinal_id=ordinal, alignment_id=a.id,
                             target_template_flag=2,  # Template
                             sequence=template.gapped_sequence)
                    ordinal += 1
                lp.write(ordinal_id=ordinal, alignment_id=a.id,
                         target_template_flag=1,  # Target
                         sequence=a.target.gapped_sequence)

    def _write_data(self, writer):
        with writer.loop('_ma_data',
                         ['id', 'name', 'content_type']) as lp:
            for d in self.data:
                lp.write(id=d.id, name=d.name, content_type=d.content_type)

    def _write_data_groups(self, writer):
        ordinal = 1
        with writer.loop('_ma_data_group',
                         ['ordinal_id', 'group_id', 'data_id']) as lp:
            for g in self.data_groups:
                for d in g:
                    lp.write(ordinal_id=ordinal, group_id=g.id, data_id=d.id)
                    ordinal += 1

    def _write_restraints(self, writer):
        with writer.loop('_ma_restraints',
                         ['ordinal_id', 'restraint_id', 'data_id', 'name',
                          'restraint_type', 'details']) as lp:
            ordinal = 1
            for r in self.restraints:
                lp.write(ordinal_id=ordinal, restraint_id=r.id,
                         data_id=r.data.id, restraint_type=r.restraint_type)
                ordinal += 1

    def _write_distance_restraints(self, writer):
        with writer.loop('_ma_distance_restraints',
                         ['ordinal_id', 'restraint_id', 'group_id',
                          'entity_id_1', 'asym_id_1', 'seq_id_1', 'comp_id_1',
                          'atom_id_1', 'entity_id_2', 'asym_id_2', 'seq_id_2',
                          'comp_id_2', 'atom_id_2',
                          'distance_threshold', 'uncertainty']) as lp:
            ordinal = 1
            for r in self.distance_restraints:
                # Assume all such restraints are Gaussian for now
                atoms = r._features.indices_to_atoms(
                    self.mdl, r._features.get_atom_indices()[0])
                atoms = [_CifAtom(a, self._entity_map) for a in atoms]
                lp.write(ordinal_id=ordinal,
                         restraint_id=self.distance_restraints.id,
                         group_id=1,
                         atom_id_1=atoms[0].atom_id,
                         comp_id_1=atoms[0].comp_id,
                         asym_id_1=atoms[0].asym_id,
                         seq_id_1=atoms[0].seq_id,
                         entity_id_1=atoms[0].entity_id,
                         atom_id_2=atoms[1].atom_id,
                         comp_id_2=atoms[1].comp_id,
                         asym_id_2=atoms[1].asym_id,
                         seq_id_2=atoms[1].seq_id,
                         entity_id_2=atoms[1].entity_id,
                         distance_threshold=r._parameters[0],
                         uncertainty=r._parameters[1])
                ordinal += 1

    def _write_protocols(self, writer):
        # We write a single 'modeling' step using software group 1 (Modeller)
        with writer.loop('_ma_protocol_step',
                         ['ordinal_id', 'protocol_id', 'step_id',
                          'method_type', 'step_name', 'details',
                          'software_group_id', 'input_data_group_id',
                          'output_data_group_id']) as lp:
            lp.write(ordinal_id=1, protocol_id=1, step_id=1,
                     method_type='modeling', software_group_id=1,
                     input_data_group_id=self.modeling_input_data.id,
                     output_data_group_id=self.modeling_output_data.id)

    def _write_model_list(self, writer, model_data):
        with writer.loop('_ma_model_list',
                         ['ordinal_id', 'model_id', 'model_group_id',
                          'model_name', 'model_group_name', 'data_id',
                          'model_type', 'model_type_other_details']) as lp:
            lp.write(ordinal_id=1, model_id=1, model_group_id=1,
                     model_name="Target Structure",
                     model_group_name="All models",
                     data_id=model_data.id, model_type="Homology model")
