#!/usr/bin/env python3

import csv
import json
import sys
import re

import requests

if len(sys.argv) < 3:
    sys.exit("needs two arguments: uniprot mapping file and PDBe data file")

# Taxon IDs found in the PDBe data -> taxon ID reported in the output.
# process_row() drops taxa that, after this translation, are not one of the
# values of this dict.
# NOTE(review): presumably strain/synonym taxon IDs collapsed onto a canonical
# species ID — confirm against NCBI Taxonomy before extending.
taxon_dict = {
    284812: 4896,
    4895: 4896,
    1264690: 4896,
    402676: 4897,
}

uniprot_mapping_filename = sys.argv[1]
pdbe_data_filename = sys.argv[2]

# UniProt accession -> gene uniquename.  The mapping file is tab-separated with
# the gene uniquename in column 1 and the UniProt accession in column 2; if an
# accession appears on several lines, the last one wins.
uniprot_pombe_map = {}

# Output records appended by process_row() and consumed by the later stages.
output_lines = []

# newline='' is what the csv module requires for files it reads.
with open(uniprot_mapping_filename, mode='r', newline='') as tsvfile:
    for row in csv.reader(tsvfile, delimiter='\t'):
        # Blank or truncated lines would otherwise fail with IndexError.
        if len(row) < 2:
            continue
        uniprot_pombe_map[row[1]] = row[0]

# Lines already printed by write_output(); used to suppress duplicates.
seen_output_lines = {}

print("gene_uniquename\tpdb_id\ttax_id\tuniprot_accession\ttitle\tentry_authors\tentry_authors_abbrev\treference\texperimental_method\tresolution\tchain\tposition")

# PDBe experimental method names -> short labels used in the output; methods
# not listed here are printed unchanged.
ex_method_dict = {
    'Electron Microscopy': 'EM',
    'X-ray diffraction': 'X-ray',
    'Solution NMR': 'NMR',
}

# PDBe "best structures" endpoint; accessions are appended as a path segment.
pdb_mappings_url = 'https://www.ebi.ac.uk/pdbe/api/mappings/best_structures'

def fetch_uniprot_mappings(uniprot_accessions, timeout=120):
    """Fetch PDBe "best structures" mappings for a collection of UniProt accessions.

    All accessions are sent in a single GET request to ``pdb_mappings_url``,
    joined by commas.  The decoded JSON response is also written to
    ``mapping_temp.json`` in the current directory.

    Args:
        uniprot_accessions: iterable of UniProt accession strings.
        timeout: passed to ``requests.get``; bounds how long a stalled
            connection may block instead of hanging indefinitely.

    Returns:
        The decoded JSON response.  Callers index it by accession and expect
        a list of dicts carrying pdb_id / chain_id / unp_start / unp_end.
        An empty dict is returned, without any request, when no accessions
        are given.
    """
    # Sorting makes the request URL (and the cached file) deterministic even
    # though mapping_lookup() passes a set.
    accessions = sorted(uniprot_accessions)
    if not accessions:
        # An empty path segment would request the bare endpoint URL rather
        # than any accession, so skip the network call entirely.
        return {}

    headers = {'Accept': 'application/json'}
    accessions_with_commas = ','.join(accessions)

    req = requests.get(f'{pdb_mappings_url}/{accessions_with_commas}',
                       headers=headers, timeout=timeout)

    ret = req.json()

    with open("mapping_temp.json", "w") as f:
        json.dump(ret, f, indent=2)

    return ret

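# Offline variant: uncomment (and comment out the live version above) to
# replay the response saved in mapping_temp.json instead of querying PDBe.
#def fetch_uniprot_mappings(uniprot_accessions):
#    with open("mapping_temp.json", "r") as f:
#        return json.load(f)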


def mapping_lookup(output_lines):
    """Return PDBe mappings for every distinct accession in ``output_lines``.

    The 'uniprot_accession' of each record is collected into a set, so each
    accession is requested once, and the result of fetch_uniprot_mappings()
    is passed back unchanged.
    """
    wanted = {record['uniprot_accession'] for record in output_lines}
    return fetch_uniprot_mappings(wanted)

def process_row(row):
    """Append output records for one row of the PDBe data file.

    ``row`` is a dict produced by csv.DictReader.  The columns read are pdb_id,
    tax_id, title, entry_authors, pubmed_id, experimental_method,
    molecule_type, resolution and entry_uniprot_accession.

    One record is appended to the module-level ``output_lines`` list for each
    pair of (accepted taxon, UniProt accession present in
    ``uniprot_pombe_map``).  A taxon is accepted when, after translation
    through ``taxon_dict``, it is one of ``taxon_dict``'s values.  Rows whose
    molecule_type is 'RNA', or that have no accepted taxon, add nothing.
    """
    pdb_id = row['pdb_id']
    title = row['title']
    entry_authors = row['entry_authors']

    # "Surname AB, Other CD, ..." -> "Surname AB et al.".  Author strings that
    # do not start with such a name followed by a comma are kept as-is.
    entry_authors_abbrev = re.sub(pattern=r'^((?:v[ao]n\s)?(?:[A-Z]\w+(?:-[A-Z]\w+)?\s+)+[A-Z]+(?:-[A-Z]+)?),.*',
                                  repl=r"\1 et al.",
                                  string=entry_authors)

    # Only an all-digit PubMed ID (ignoring surrounding whitespace) becomes a
    # reference; fullmatch anchors the whole string, unlike '^...$', which
    # also accepts a trailing newline.
    pubmed_id = (row['pubmed_id'] or '').strip()
    reference = 'PMID:' + pubmed_id if re.fullmatch(r'\d+', pubmed_id) else ''

    accepted_taxa = set(taxon_dict.values())
    tax_ids = []
    for raw_tax_id in (row['tax_id'] or '').split(','):
        if not raw_tax_id.strip():
            continue  # tolerate empty entries, e.g. a trailing comma
        tax_id = int(raw_tax_id)
        tax_id = taxon_dict.get(tax_id, tax_id)
        # Several source taxa may collapse onto the same output taxon; keep
        # each once so identical records are not appended repeatedly.
        if tax_id in accepted_taxa and tax_id not in tax_ids:
            tax_ids.append(tax_id)

    if not tax_ids or row['molecule_type'] == 'RNA':
        return

    experimental_method = row['experimental_method']
    experimental_method = ex_method_dict.get(experimental_method, experimental_method)
    resolution = row['resolution']
    # Strip each accession in case the list is separated by ", ".
    uniprot_accessions = [accession.strip()
                          for accession in (row['entry_uniprot_accession'] or '').split(',')]

    for tax_id in tax_ids:
        for uniprot_accession in uniprot_accessions:
            if uniprot_accession not in uniprot_pombe_map:
                continue
            output_lines.append({
                'pdb_id': pdb_id,
                'tax_id': tax_id,
                'experimental_method': experimental_method,
                'uniprot_accession': uniprot_accession,
                'resolution': resolution,
                'gene_uniquename': uniprot_pombe_map[uniprot_accession],
                'title': title,
                'entry_authors': entry_authors,
                'entry_authors_abbrev': entry_authors_abbrev,
                'reference': reference,
            })


def order_output_by_mapping(output_lines, uniprot_mappings):
    """Sort ``output_lines`` in place by accession, then by PDBe mapping rank.

    Within one accession, a record's rank is the index of the first entry in
    ``uniprot_mappings[accession]`` whose pdb_id equals the record's pdb_id.
    A record whose pdb_id is not listed, or whose accession is missing from
    ``uniprot_mappings`` entirely, gets rank 0.  Missing accessions must not
    raise here, because write_output() reports and skips them afterwards.
    The sort is stable, so ties keep their input order.  Returns None.
    """
    def line_sort_key(line):
        uniprot_accession = line['uniprot_accession']
        mapping = uniprot_mappings.get(uniprot_accession, [])
        pos = next((idx for idx, mapping_item in enumerate(mapping)
                    if mapping_item['pdb_id'] == line['pdb_id']), 0)
        return (uniprot_accession, pos)

    output_lines.sort(key=line_sort_key)

def write_output(output_lines, uniprot_mappings):
    """Print one TSV row per record in ``output_lines``, skipping repeats.

    Records whose accession has no entry in ``uniprot_mappings`` are reported
    on stderr and skipped.  For the others, every mapping entry with the
    record's pdb_id contributes its chain_id to a '/'-joined chain column.
    The position column is "unp_start-unp_end" of the last such entry, or
    "None-None" when there is none.  A row identical to one already printed
    (tracked in the module-level ``seen_output_lines`` dict) is not printed
    again.
    """
    for record in output_lines:
        pdb_id, tax_id, experimental_method, uniprot_accession, resolution, pombe_id = (
            record['pdb_id'],
            record['tax_id'],
            record['experimental_method'],
            record['uniprot_accession'],
            record['resolution'],
            record['gene_uniquename'],
        )

        if uniprot_accession not in uniprot_mappings:
            print("not found in uniprot_mappings: " + uniprot_accession, file=sys.stderr)
            continue

        title, entry_authors, entry_authors_abbrev, reference = (
            record['title'],
            record['entry_authors'],
            record['entry_authors_abbrev'],
            record['reference'],
        )

        # (chain_id, unp_start, unp_end) for each mapping entry of this PDB
        # entry, in the order the mapping lists them.
        hits = [
            (entry['chain_id'], entry['unp_start'], entry['unp_end'])
            for entry in uniprot_mappings[uniprot_accession]
            if entry['pdb_id'] == pdb_id
        ]
        chains = '/'.join(chain for chain, _, _ in hits)
        _, pos_start, pos_end = hits[-1] if hits else (None, None, None)

        fields = (pombe_id, pdb_id, tax_id, uniprot_accession, title,
                  entry_authors, entry_authors_abbrev, reference,
                  experimental_method, resolution, chains,
                  f"{pos_start}-{pos_end}")
        output_line = '\t'.join(str(field) for field in fields)

        if output_line in seen_output_lines:
            continue
        seen_output_lines[output_line] = True
        print(output_line)

# Main pipeline: build candidate records from the PDBe CSV, fetch their PDBe
# mappings in one request, order them, and print the TSV body.
# newline='' is required by the csv module so that line breaks inside quoted
# fields (e.g. in free-text columns) are parsed as part of the field.
with open(pdbe_data_filename, mode='r', newline='') as csvfile:
    for row in csv.DictReader(csvfile):
        process_row(row)

uniprot_mappings = mapping_lookup(output_lines)

order_output_by_mapping(output_lines, uniprot_mappings)

write_output(output_lines, uniprot_mappings)
