#!/usr/bin/env python3 
from dicttoxml import dicttoxml
from xml.dom.minidom import parseString
from lxml import etree as ET
import pandas as pd
import os
import glob
import shlex
import subprocess
import sys
from Bio.Seq import Seq
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
import numpy as np
from yaml import load, Loader

def initialize_config(configfile):
    """Read the general YAML configuration file.

    :param configfile: path to the YAML configuration file
    :return: the parsed YAML document, exactly as the YAML loader returns it
    """
    # NOTE(review): ``Loader`` is PyYAML's full loader and can instantiate
    # arbitrary Python objects - only use it on configuration files you trust.
    with open(configfile, 'r') as config_handle:
        return load(config_handle, Loader)

def initialize_submission(submissionfile):
    """Read the per-submission YAML file and return its settings as a tuple.

    :param submissionfile: path to the YAML submission file
    :return: tuple of the values stored under the keys listed in
             ``field_order`` below, in that order
    :raises KeyError: if one of the expected keys is missing from the file
    """
    with open(submissionfile, 'r') as submission_handle:
        settings = load(submission_handle, Loader)

    # Order matters: callers unpack the returned tuple positionally.
    field_order = (
        'center_name',
        'stage',
        'release_now',
        'project_accession',
        'alias',
        'title',
        'description',
        'batch_folder',
        'metadata_file',
        'fasta_file',
        'multi_fasta',
        'TAXID',
        'ORGANISM',
        'HOST',
    )
    return tuple(settings[field] for field in field_order)

def load_metadata(batch_folder,metadata_file,data_type='GISAID',taxID='',species_name='',host=''):
    """
    Load sample metadata and rename its columns to ENA field names.
    ENA: https://www.ebi.ac.uk/ena/browser/view/ERC000033

    :param batch_folder: folder holding the metadata file (only used for GISAID)
    :param metadata_file: GISAID Excel file name (sheet 'Submissions') or
                          path to an SPSP JSON file
    :param data_type: 'GISAID' or 'SPSP'
    :param taxID: value written to the 'taxon id' column (GISAID only)
    :param species_name: value written to the 'species' column (GISAID only)
    :param host: value written to the 'host scientific name' column (GISAID only)
    :return: pandas DataFrame with ENA-style column names
    :raises Exception: if data_type is neither 'GISAID' nor 'SPSP'
    """

    if data_type == 'GISAID':
        df = pd.read_excel(os.path.join(batch_folder,metadata_file),sheet_name='Submissions')
        df.drop([0],inplace=True) # delete first row (header labels)
        df['taxon id'] = taxID
        df['species'] = species_name
        df['host scientific name'] = host

        # covv_location is '/'-separated; field 1 is used as country and
        # field 2 as region.
        location_parts = [loc.split('/') for loc in df['covv_location'].tolist()]
        df['geographic location (country and/or sea)'] = [parts[1].strip() for parts in location_parts]
        df['geographic location (region and locality)'] = [parts[2].strip() for parts in location_parts]

        df['collecting institution'] = df[['covv_orig_lab', 'covv_orig_lab_addr']].agg(','.join, axis=1)
        df['host subject id'] = 'unknown'
        df.rename(columns = {
            'covv_seq_technology':'sequencing method',
            'covv_collection_date':'collection date',
            'covv_host':'host common name',
            'covv_authors':'collector name',
            'covv_virus_name':'virus name',
            'covv_specimen':'isolation source host-associated',
            'covv_assembly_method':'library_construction protocol',
            'covv_coverage':'coverage',
            'covv_gender':'host sex',
            'covv_patient_age':'host age',
            'covv_patient_status':'host health state',
            'fn':'fasta file'
            },inplace=True)
    elif data_type == 'SPSP': # if data was extracted from SPSP, rename columns from SPSP to ENA
        # NOTE(review): unlike the GISAID branch, batch_folder is not joined
        # onto metadata_file here - confirm this is intended.
        df = pd.read_json(metadata_file)
        # host_species is '|'-separated; a Series has no .split()/.strip(),
        # so the element-wise .str accessor is required.
        host_parts = df['host_species'].str.split('|')
        df['host common name'] = host_parts.str[2].str.strip()
        df['host scientific name'] = host_parts.str[1].str.strip()
        df.rename(columns = { # !not tested against real SPSP exports!
            'species_ID':'taxon id',
            'species_name':'species',
            'strain_name':'virus name',
            'sequencing_platform':'sequencing method',
            'isolation_date':'collection date',
            'isolation_country':'geographic location (country and/or sea)',
            'isolation_canton':'geographic location (region and locality)',
            'depositor_name':'collector name',
            'depositor_institution':'collecting institution',
            'assembly_file':'fasta filename' # TODO: this doesn't exist yet in SPSP, define in db
            },inplace=True)
    else:
        raise Exception('Unknown data_type "{0}". Supported values: GISAID, SPSP.'.format(data_type))

    return df

def _gzip_in_place(path):
    """Compress *path* with the external ``gzip`` tool; raise if gzip fails."""
    # check=True: a failed compression must not go unnoticed, otherwise the
    # expected '.gz' file is simply missing later on.
    subprocess.run(['gzip', path], check=True)


def split_multi_fasta(multi_fasta_name,batch_folder,inputdir):
    """
    Extract multiple sequence fasta file and write each sequence in separate file.
    Modified from: https://www.biostars.org/p/340937/

    For every record a gzipped '<name>.fasta' and a gzipped '<name>_chr.txt'
    chromosome list are written, where <name> is the record id with '/'
    replaced by '_'.

    ! IMPORTANT !
    - We assume that we have complete genomes in multi-FASTA.
    - Since these are assumed to be complete genomes, we submit as chromosome => need to make it more general
    - Need to standardize also virus names (cf. seq_rec.id below)

    :param multi_fasta_name: multi-FASTA file name, relative to batch_folder
    :param batch_folder: folder containing the multi-FASTA file
    :param inputdir: output folder; it is joined onto batch_folder with
                     os.path.join, so an absolute path is used as-is
    :return: a short completion message
    :raises Exception: if the file contains no sequence
    :raises subprocess.CalledProcessError: if gzip exits with a non-zero status
    """

    file_count = 0
    with open(os.path.join(batch_folder,multi_fasta_name)) as FH:
        for seq_rec in SeqIO.parse(FH, "fasta"):
            file_count += 1
            virus_name = seq_rec.id.replace('/','_')
            prefix = os.path.join(batch_folder,inputdir,virus_name)
            filename = prefix+'.fasta'
            chromosome_fn = prefix+'_chr.txt' # TODO: currently assumes there's only 1 segment, i.e. complete genome

            seq_rec.description = ''
            # Keep only the third '_'-separated field of the name as sequence id.
            seq_rec.id = virus_name.split('_')[2] # TODO: make it more general

            with open(filename, 'w') as FHO:
                SeqIO.write(seq_rec, FHO, 'fasta')
            _gzip_in_place(filename) # compress once the fasta is written

            with open(chromosome_fn,'w') as CFN:
                CFN.write('{0}\t1\tLinear-Chromosome'.format(seq_rec.id))
            _gzip_in_place(chromosome_fn) # compress once the chromosome list is written

    if file_count == 0:
        raise Exception('No valid sequence in fasta file "{0}"'.format(multi_fasta_name))
    return 'Done splitting multi-FASTA file and writing associated chromosome files.'

def create_submission(info,alias,center_name,batch_folder,submission_type='PROJECT'):
    """Fill in the submission XML template and write it to batch_folder.

    :param info: configuration dict; keys 'TEMPLATE', 'SUBMISSION' and the
                 key named by submission_type are read
    :param alias: alias set on the SUBMISSION element (PROJECT submissions only)
    :param center_name: center_name set on the SUBMISSION element
    :param batch_folder: folder the filled-in XML is written to
    :param submission_type: 'PROJECT' or 'SAMPLE'
    :raises Exception: for any other submission_type
    """
    tree = ET.parse(os.path.join(info['TEMPLATE'],info['SUBMISSION']))
    root = tree.getroot()

    if submission_type not in ('PROJECT', 'SAMPLE'):
        raise Exception('Unknown submission_type "{0}" in create_submission. Supported values: PROJECT or SAMPLE.'.format(submission_type))

    # Only project submissions get the alias; for samples the template's
    # SUBMISSION element is left without it being set here.
    set_alias = submission_type == 'PROJECT'
    for submission_node in root.iter('SUBMISSION'):
        if set_alias:
            submission_node.set('alias', alias)
        submission_node.set('center_name', center_name)

    # The ADD action points at the XML file of the matching schema.
    for add_node in root.iter('ADD'):
        add_node.set('source', info[submission_type])
        add_node.set('schema', submission_type.lower())

    # write final xml
    print(os.path.join(batch_folder,info['SUBMISSION']))
    tree.write(os.path.join(batch_folder,info['SUBMISSION']), encoding='utf-8', xml_declaration=True, pretty_print=True)

def create_project(info,alias,center_name,title,description,batch_folder):
    """Fill in the project XML template and write it to batch_folder.

    :param info: configuration dict; keys 'TEMPLATE' and 'PROJECT' are read
    :param alias: alias attribute set on every PROJECT element
    :param center_name: center_name attribute set on every PROJECT element
    :param title: text of every TITLE element
    :param description: text of every DESCRIPTION element
    :param batch_folder: folder the filled-in XML is written to
    """
    tree = ET.parse(os.path.join(info['TEMPLATE'],info['PROJECT']))
    root = tree.getroot()

    for project_node in root.iter('PROJECT'):
        project_node.set('alias',alias)
        project_node.set('center_name', center_name)

    # Replace the text of the title and description elements.
    for tag, text in (('TITLE', title), ('DESCRIPTION', description)):
        for node in root.iter(tag):
            node.text = text

    # write final xml
    print(os.path.join(batch_folder,info['PROJECT']))
    tree.write(os.path.join(batch_folder,info['PROJECT']), encoding='utf-8', xml_declaration=True, pretty_print=True)

    
def run_cmd(cmd_line):
    """Execute a command line without a shell and return its standard output.

    :param cmd_line: the command line as a single string; it is tokenised
                     with shlex.split
    :return output: the bytes the command wrote to stdout (stderr is
                    captured and discarded; the exit status is not checked)
    From: https://github.com/usegalaxy-eu/ena-upload-cli/blob/master/ena_upload/ena_upload.py
    """
    completed = subprocess.run(shlex.split(cmd_line),
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE)
    return completed.stdout

def submit_metadata(info,batch_folder,outputdir,stage='DEV',submission_type='PROJECT'):
    """
    Upload a submission XML plus its payload XML to ENA with curl and log the receipt.

    :param info: configuration dict; keys 'USER', 'PWD', 'URL' and the XML
                 file names 'SUBMISSION', 'RELEASE', 'PROJECT', 'SAMPLE' are read
    :param batch_folder: folder containing the XML files to upload
    :param outputdir: folder the receipt log is appended to
    :param stage: key into info['URL'] selecting the endpoint
    :param submission_type: can be either PROJECT, RELEASE or SAMPLE
    :return: the raw receipt (bytes written to stdout by curl)
    :raises Exception: for an unknown submission_type
    """

    # submission_type -> ((form field, info key of the file to attach), ...)
    # NOTE(review): RELEASE also attaches the PROJECT file, including when a
    # sample is being released - confirm this is what ENA expects.
    form_fields = {
        'PROJECT': (('SUBMISSION', 'SUBMISSION'), ('PROJECT', 'PROJECT')),
        'RELEASE': (('SUBMISSION', 'RELEASE'), ('PROJECT', 'PROJECT')),
        'SAMPLE': (('SUBMISSION', 'SUBMISSION'), ('SAMPLE', 'SAMPLE')),
    }
    if submission_type not in form_fields:
        raise Exception('Unknown submission_type "{0}" in submit_metadata. Supported values: PROJECT, RELEASE, SAMPLE.'.format(submission_type))

    uploads = ' '.join(
        '-F "{0}=@{1}"'.format(field, os.path.join(batch_folder, info[key]))
        for field, key in form_fields[submission_type])
    url = info['URL'][stage]
    cmd_template = 'curl -u {0} {1} {2}'

    # Quote the credentials so that spaces or quotes in them survive the
    # shlex.split performed by run_cmd.
    cmd = cmd_template.format(
        shlex.quote('{0}:{1}'.format(info['USER'], info['PWD'])), uploads, url)

    # Never echo the password: print the command with it masked.
    print(cmd_template.format(
        shlex.quote('{0}:****'.format(info['USER'])), uploads, url))

    # run command and retrieve accession
    receipt = run_cmd(cmd)

    # write log in OUTPUT dir
    logfilename = os.path.join(outputdir,'ENA_{0}_receipt.log'.format(submission_type))
    with open(logfilename,'a+') as ofile:
        ofile.write(receipt.decode('utf-8'))
        ofile.write('\n\n\n')
    return receipt

def make_update(update,action,receiptDate,status):
    """Tabulate receipt elements as a dataframe.

    :param update: iterable of objects exposing ``get('alias')`` and
                   ``get('accession')``
    :param action: key into ``status``
    :param receiptDate: value stored in the 'submission_date' column
    :param status: dict mapping an action to its status label
    :return: dataframe with columns 'alias', 'accession',
             'submission_date' and 'status', one row per element
    """
    column_names = ['alias', 'accession', 'submission_date', 'status']
    rows = []
    for element in update:
        rows.append((element.get('alias'),
                     element.get('accession'),
                     receiptDate,
                     status[action]))
    return pd.DataFrame.from_records(rows, columns=column_names)
    
def process_receipt(receipt, action='ADD'):
    '''Process submission receipt from ENA.
    :param receipt: a string of XML
    :param action: the submitted action ('ADD', 'MODIFY', 'CANCEL' or
                   'RELEASE'); selects the status label stored per row
    :return schema_update: a dictionary - {schema:update}
                           schema: a string - 'study', 'sample',
                                              'run', 'experiment'
                           update: a dataframe with columns - 'alias',
                                   'accession', 'submission_date', 'status'
    Exits the program (sys.exit) with the ENA error messages when the
    receipt's 'success' attribute is not 'true'.
    Modified from: https://github.com/usegalaxy-eu/ena-upload-cli/blob/master/ena_upload/ena_upload.py
    '''
    receipt_root = ET.fromstring(receipt)

    if receipt_root.get('success') != 'true':
        messages = [element.text for element in receipt_root.findall('MESSAGES/ERROR')]
        sys.exit('\nOops:\n' + '\n'.join(messages))

    # define expected status based on action
    status = {'ADD': 'added', 'MODIFY': 'modified',
              'CANCEL': 'cancelled', 'RELEASE': 'released'}

    receiptDate = receipt_root.get('receiptDate')

    # (schema name used as dict key, receipt tag); the 'study' entry reads
    # PROJECT elements.
    schema_tags = (('study', 'PROJECT'),
                   ('sample', 'SAMPLE'),
                   ('experiment', 'EXPERIMENT'),
                   ('run', 'RUN'))

    schema_update = {}  # schema as key, dataframe as value
    for schema, tag in schema_tags:
        elements = receipt_root.findall(tag)
        if elements:
            schema_update[schema] = make_update(elements,action,receiptDate,status)

    return schema_update
    
def create_release(info,alias,center_name,batch_folder,accession):
    """Fill in the release XML template for one accession and write it to batch_folder.

    :param info: configuration dict; keys 'TEMPLATE' and 'RELEASE' are read
    :param alias: alias attribute set on every SUBMISSION element
    :param center_name: center_name attribute set on every SUBMISSION element
    :param batch_folder: folder the filled-in XML is written to
    :param accession: value of the 'target' attribute of every RELEASE element
    """
    release_file = info['RELEASE']
    tree = ET.parse(os.path.join(info['TEMPLATE'],release_file))
    root = tree.getroot()

    for submission_node in root.iter('SUBMISSION'):
        submission_node.set('alias', alias)
        submission_node.set('center_name', center_name)

    # Point the release action at the object to make public.
    for release_node in root.iter('RELEASE'):
        release_node.set('target', accession)

    # write final xml
    output_path = os.path.join(batch_folder,release_file)
    print(output_path)
    tree.write(output_path, encoding='utf-8', xml_declaration=True, pretty_print=True)

def concatenate_xml(tmp_folder,fname,m,center_name):
    """Merge all per-sample XML files of tmp_folder into one SAMPLE_SET file.

    The children of every '*.xml' file's root element are moved, in sorted
    file-name order, under a single new <SAMPLE_SET> root, which is written
    to fname. tmp_folder is deleted afterwards.

    :param tmp_folder: folder containing the per-sample XML files
    :param fname: path of the combined XML file to write
    :param m: unused, kept for backward compatibility
    :param center_name: unused, kept for backward compatibility
    """
    sample_set = ET.Element('SAMPLE_SET')
    # sorted(): glob order is arbitrary; sorting makes the output reproducible.
    for xml_file in sorted(glob.glob(os.path.join(tmp_folder,'*.xml'))):
        file_root = ET.parse(xml_file).getroot()
        # Re-parent the elements directly instead of round-tripping through
        # str(bytes), which corrupts text containing quotes, backslashes or
        # newlines. list() because append() removes the child from file_root.
        for child in list(file_root):
            sample_set.append(child)

    # write final xml
    ET.ElementTree(sample_set).write(fname, encoding='utf-8', xml_declaration=True, pretty_print=True)

    # delete files used to concatenate
    subprocess.run(["rm", "-rf", tmp_folder])
    
def _ena_attribute_value(value):
    """Return value, or 'not provided' if it is missing or a placeholder string."""
    if isinstance(value, str):
        return 'not provided' if value in ('unknown', 'NaN') else value
    # Empty cells arrive as float NaN / NaT rather than the string 'NaN'.
    try:
        missing = bool(pd.isna(value))
    except (TypeError, ValueError):
        missing = False
    return 'not provided' if missing else value


def create_sample(info,metadata,center_name,batch_folder,stage):
    """
    Write the ENA sample XML (one SAMPLE per metadata row) into batch_folder.
    https://www.ebi.ac.uk/ena/browser/view/ERC000033

    Each row is first written to its own file in a temporary 'tmp_samples'
    folder; the files are then merged by concatenate_xml into
    batch_folder/info['SAMPLE'].

    :param info: configuration dict; key 'SAMPLE' (output file name) is read
    :param metadata: dataframe with ENA-style column names (see load_metadata)
    :param center_name: center_name attribute set on every SAMPLE element
    :param batch_folder: folder the temporary folder and final XML live in
    :param stage: 'DEV' appends a random number to the alias, 'PROD' uses the
                  virus name as alias; any other value sets no alias
    :raises Exception: if metadata contains no rows
    """
    if metadata.empty:
        raise Exception('No samples found in metadata: nothing to write.')

    tmp_samples_folder = os.path.join(batch_folder,'tmp_samples')
    if os.path.isdir(tmp_samples_folder):
        print('Deleting folder: {0}'.format(tmp_samples_folder))
        subprocess.run(["rm", "-rf", tmp_samples_folder])
    os.mkdir(tmp_samples_folder)

    # Metadata columns emitted as SAMPLE_ATTRIBUTE, in output order; the tag
    # written to the XML equals the column name.
    attribute_tags = (
        'sequencing method',
        'collection date',
        'geographic location (country and/or sea)',
        'geographic location (region and locality)',
        'host common name',
        'host subject id',
        'host health state',
        'host sex',
        'host scientific name',
        'collector name',
        'collecting institution',
    )

    def singular_item_name(parent):
        # rename <item> with parent name without "s", i.e. SAMPLE_ATTRIBUTE
        return parent[:-1]

    for _, m in metadata.iterrows():
        safe_name = m['virus name'].replace('/','_')

        attributes = [{'TAG': 'isolate', 'VALUE': safe_name+'_isolate'}]
        attributes.extend(
            {'TAG': tag, 'VALUE': _ena_attribute_value(m[tag])}
            for tag in attribute_tags)

        # Key order defines the element order in the generated XML.
        txt = {
            'SAMPLE': {
                'TITLE': m['virus name'],
                'SAMPLE_NAME': {
                    'TAXON_ID': m['taxon id'],
                    'SCIENTIFIC_NAME': m['species'],
                    'COMMON_NAME': '',
                },
                'SAMPLE_ATTRIBUTES': attributes,
            }
        }
        xml = dicttoxml(txt, attr_type=False, custom_root='SAMPLE_SET', item_func=singular_item_name)

        # print to file; explicit UTF-8 because the declaration written by
        # minidom names no encoding, so XML parsers will assume UTF-8
        dom = parseString(xml)
        fname = os.path.join(tmp_samples_folder,safe_name+'.xml')
        with open(fname, 'w', encoding='utf-8') as ofile:
            dom.writexml(ofile)

        # update SAMPLE key to add alias and center_name
        mytree = ET.parse(fname)
        myroot = mytree.getroot()
        for s in myroot.iter('SAMPLE'):
            if stage == 'DEV':
                # add a random number, otherwise if alias already exists it will be refused.
                # str() of the 1-element array yields '[12345]'; submit_data_webin_cli
                # splits the alias on '[' to recover the virus name, so keep this format.
                s.set('alias', m['virus name']+str(np.random.randint(0,1000000,1)))
            elif stage == 'PROD':
                s.set('alias', m['virus name'])
            s.set('center_name', center_name)

        # write final xml
        mytree.write(fname, encoding='utf-8', xml_declaration=True, pretty_print=False)

    concatenate_xml(tmp_samples_folder,os.path.join(batch_folder,info['SAMPLE']),m,center_name)

def write_manifest(inputdir,project_accession,sample_acc,m,stage):
    """Write the tab-separated manifest file for one assembly.

    :param inputdir: folder the manifest is written to
    :param project_accession: dict whose 'study' dataframe provides the study
                              alias (first row)
    :param sample_acc: sample accession written to the SAMPLE field
    :param m: dataframe whose first row holds the sample's metadata
    :param stage: 'DEV' appends a random number to the assembly name
    :return: path of the manifest file written
    """
    row = m.iloc[0]
    base_name = row['virus name'].replace('/','_')

    assembly_name = base_name
    if stage == 'DEV':
        assembly_name = base_name+str(np.random.randint(0,1000000,1)[0])

    # (manifest field, value) pairs in output order.
    entries = [
        ('STUDY', project_accession['study'].iloc[0]['alias']),
        ('SAMPLE', sample_acc),
        ('ASSEMBLYNAME', assembly_name),
        ('ASSEMBLY_TYPE', 'COVID-19 outbreak'),
        ('COVERAGE', row['coverage']),
        ('PROGRAM', row['library_construction protocol']),
        ('PLATFORM', row['sequencing method']),
        ('MOLECULETYPE', 'genomic RNA'),
        ('FASTA', base_name+'.fasta.gz'),
        ('CHROMOSOME_LIST', base_name+'_chr.txt.gz'),
    ]
    df = pd.DataFrame({
        'tags': [tag for tag, _ in entries],
        'attributes': [value for _, value in entries],
    })

    # save file
    manifestFileName = os.path.join(inputdir,base_name)+'_MANIFEST.txt'
    df.to_csv(manifestFileName,sep='\t',header=False,index=False)

    return manifestFileName

def submit_data_webin_cli(info,metadata,batch_folder,center_name,stage,project_accession,samples_accession,inputdir,outputdir):
    """Write one manifest per registered sample and submit it with webin-cli.

    :param info: configuration dict; keys 'WEBIN_CLI_PATH', 'USER' and 'PWD' are read
    :param metadata: dataframe with a 'virus name' column (see load_metadata)
    :param batch_folder: unused here, kept for interface compatibility
    :param center_name: value passed to webin-cli's -centerName
    :param stage: 'PROD' or 'DEV' ('DEV' adds the -test flag)
    :param project_accession: dict with a 'study' dataframe (see write_manifest)
    :param samples_accession: dict whose 'sample' dataframe has 'alias' and
                              'accession' columns
    :param inputdir: folder holding the fasta/chromosome files; manifests are
                     written here
    :param outputdir: folder webin-cli writes its output to (created if missing)
    :return: list with the raw stdout (bytes) of every webin-cli run
    """
    # create OUTPUT dir (no shell involved, so paths with spaces are safe)
    os.makedirs(outputdir, exist_ok=True)

    stage_value = {'PROD':'','DEV':'-test'}
    cmd_template = ('java -jar {0} -context genome -userName {1} -password {2} '
                    '-centerName {3} -manifest {4} -inputDir {5} -outputDir {6} -submit {7}')
    samples = samples_accession['sample']

    receipt = []
    # write manifest files and run webin-cli
    for sample_alias,sample_acc in zip(samples['alias'],samples['accession']):
        virus_name = sample_alias.split('[')[0] # randint is added with square brackets for dev stage
        # Boolean indexing keeps the column dtypes; where()+dropna() would
        # turn integer columns into floats.
        metadata_virus = metadata[metadata['virus name']==virus_name].copy()
        manifestFileName = write_manifest(inputdir,project_accession,sample_acc,metadata_virus,stage)

        # run webin-cli; every argument is quoted so that spaces or quotes
        # survive the shlex.split performed by run_cmd
        quoted = [shlex.quote(str(value)) for value in (
            info['WEBIN_CLI_PATH'],info['USER'],info['PWD'],center_name,
            manifestFileName,inputdir,outputdir)]
        cmd = cmd_template.format(*quoted, stage_value[stage])

        # Never echo the password: print the command with it masked.
        masked = list(quoted)
        masked[2] = '****'
        print(cmd_template.format(*masked, stage_value[stage]))

        receipt.append(run_cmd(cmd))

    return receipt

def main():
    """
    Usage: ./gisaid2ena.py <configfile> <submissionfile>

    Runs the whole submission: (1) register and optionally release the
    study, (2) register and optionally release the samples, (3) submit the
    assemblies with webin-cli. The INPUT/ and OUTPUT/ folders inside
    batch_folder are deleted and recreated on every run.
    """
    if len(sys.argv) < 3:
        raise Exception("Usage: ./gisaid2ena.py <config file> <submission file>")

    configfile = sys.argv[1] # config yaml
    submissionfile = sys.argv[2] # submission yaml

    # (0) Initialize basic info like credentials and template folders
    info = initialize_config(configfile) # load yaml
    center_name,stage,release_now,study_accession,alias,title,description,batch_folder,metadata_file,fasta_file,multi_fasta,TAXID,ORGANISM,HOST = initialize_submission(submissionfile) # load yaml

    # split_multi_fasta joins batch_folder and inputdir again; with a
    # relative batch_folder that would give batch_folder/batch_folder/INPUT.
    # An absolute path makes the second join a no-op.
    batch_folder = os.path.abspath(batch_folder)

    inputdir = os.path.join(batch_folder,'INPUT')
    outputdir = os.path.join(batch_folder,'OUTPUT')
    for folder in (inputdir, outputdir):
        if os.path.isdir(folder):
            print('Deleting folder: {0}'.format(folder))
            subprocess.run(["rm", "-rf", folder])
    # no shell involved, so folder names with spaces are safe
    os.makedirs(inputdir, exist_ok=True)
    os.makedirs(outputdir, exist_ok=True)

    metadata = load_metadata(batch_folder,metadata_file,'GISAID',TAXID,ORGANISM,HOST)

    if multi_fasta:
        split_multi_fasta(fasta_file,batch_folder,inputdir)

    # (1a) Register a study (project)
    # TODO: if study alias already exists, process error, retrieve accession and continue
    if not study_accession: # empty or null in the YAML: project does not exist yet, register it
        create_submission(info,alias,center_name,batch_folder)
        create_project(info,alias,center_name,title,description,batch_folder)
        project_receipt = submit_metadata(info,batch_folder,outputdir,stage,'PROJECT')
        project_accession = process_receipt(project_receipt, 'ADD')
    else:
        # Mimic the structure returned by process_receipt for an existing study.
        project_accession = {
            'study': pd.DataFrame({'accession': [study_accession], 'alias': [alias]})
        }

    # (1b) Release studies immediately (i.e. make public)
    if release_now:
        for acc in project_accession['study']['accession']:
            create_release(info,alias,center_name,batch_folder,acc)
            submit_metadata(info,batch_folder,outputdir,stage,'RELEASE')

    # (2a) Register samples
    create_submission(info,alias,center_name,batch_folder,'SAMPLE')
    create_sample(info,metadata,center_name,batch_folder,stage)
    sample_receipt = submit_metadata(info,batch_folder,outputdir,stage,'SAMPLE')
    samples_accession = process_receipt(sample_receipt, 'ADD')

    # (2b) Release samples immediately
    if release_now:
        for acc in samples_accession['sample']['accession']:
            create_release(info,alias,center_name,batch_folder,acc)
            submit_metadata(info,batch_folder,outputdir,stage,'RELEASE')

    # (3) Submit assemblies
    submission_receipt = submit_data_webin_cli(info,metadata,batch_folder,center_name,stage,project_accession,samples_accession,inputdir,outputdir)
    print(submission_receipt)

    print('\n\nDONE. Please refer to the OUTPUT/ folder in batch_folder for ENA receipts and other logs.')

# Run the submission pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()