#!/usr/bin/env python3
import subprocess as sp
import yaml
import argparse
import os, re, sys, time
from pathlib import Path
#import multiprocessing as mp ## for Pool and Event
import threading,queue
import random ## for test queue part

##  Update the keyword dict() from a sequence of yml files:
##      Defaults.yml
##      Machine_HOST.yml, where HOST should be detected from the host
##          (module load or env settings)
##      plotCase.yml, which is located in the work directory
##          Contains the case name and path
##          Contains Recipes: a list of which Recipes to run
##      Recipes.yml, predefined data, variables, plots and layout.
##          Recipes.yml comes from plotCase.yml, but keywords should use those of plotCase.yml
##          Important variable: Scripts
##              A list of template codes (plotScript or Script) and keywords (ALL CAPS) to replace in the template
##
##
##  Create script/ncl/py from a template and modify it with the keyword dict()
##
##  Initial:
##      Create plotCase.yml in the work directory from a profile
##      profile: what kind of data/analysis to plot
##          e.g. NorCPM seasonal hindcast, decadal hindcast, or MPIESM IE analysis
##          Contains recipes and asks for (predefines) the necessary variables (keywords)
##
## Flatten: read everything into one big dict()
class diag_norcpm():
    def __init__(self,workdir='.',diag_root=''):
        """Store the installation/work paths and create the run-time state.

        workdir:   work directory; kept as an absolute path.
        diag_root: installation directory; when empty, the directory that
                   contains this source file is used.
        """
        ## installation path: fall back to the location of this file
        self.diag_root = diag_root or os.path.dirname(os.path.realpath(__file__))
        ## work directory, always absolute
        self.workdir = os.path.abspath(workdir)
        self.host = self._get_hostname()

        ## yml keys that _readYML() turns into lists
        self.mustBeList = ['Recipes','plotRecipes','Scripts','plotScripts','Depends','Tags']

        ## state shared between the scheduler and the worker threads
        self.progress = {}                ## recipe name -> status text
        self.event = threading.Event()    ## set whenever a status changes
        self.q = queue.Queue()            ## (recipe, scripts) jobs for the workers

        ## run-time options
        self.dryrun = False    ## True means scripts are not really run, should be set in plotCase.yml
        self.termwidth = 80    ## width used by the progress monitor
        self.include = False   ## prior directory for yml, recipe and code templates
    def _get_hostname(self):
        """Return the machine name used to select Machine_<host>.yml.

        Real host detection is still to be written; the name is hard-coded.
        """
        hostname = 'betzy'
        return hostname
    def _readYML(self,filepath,updatedic=None):
        """Read one yml file into a dict and normalise its list-valued keys.

        filepath:  yml file to read; a missing file yields an empty dict.
        updatedic: optional entries merged over the file content (they are
                   not applied when the file does not exist).

        Returns the resulting dict.  Keys named in self.mustBeList that hold
        a plain string are split on commas/whitespace into a list,
        plotRecipes/plotScripts are folded into Recipes/Scripts and '.yml'
        is removed from every Depends entry.
        """
        dd = dict()
        if not os.path.isfile(filepath): return dd

        with open(filepath,'r') as f:
            ## an empty yml file loads as None; treat it as an empty mapping
            dd = yaml.load(f,Loader=yaml.BaseLoader) or dict()
        if updatedic: dd.update(updatedic)

        ## break non-KEYWORD strings into lists (values that are already
        ## lists, or are absent, are left alone)
        for key in self.mustBeList:
            if isinstance(dd.get(key),str):
                dd[key] = [ w for w in re.split(r',|\s',dd[key]) if w ]

        ## backward compatible:
        ## combine plotRecipes into Recipes and plotScripts into Scripts
        for old,new in (('plotRecipes','Recipes'),('plotScripts','Scripts')):
            if dd.get(old):
                if dd.get(new):
                    dd[new].extend(dd.pop(old))
                else:
                    dd[new] = dd.pop(old)

        ## dependencies are referred to by recipe name, without the suffix
        if dd.get('Depends'):
            dd['Depends'] = [ i.replace('.yml','') for i in dd['Depends'] ]

        return dd
    def _readYMLIfInc(self,filepath,updatedic={}):
        """Read a yml file, preferring a same-named file in self.include.

        When self.include is set and holds a file with the same base name as
        filepath, that file is read instead; otherwise filepath itself is read.
        Arguments and return value are those of _readYML().
        """
        target = filepath
        if self.include:
            candidate = f'{self.include}/{os.path.basename(filepath)}'
            if os.path.isfile(candidate): target = candidate
        return self._readYML(target,updatedic)

    def _KEYWORDs(self,dd):
        """Return a new dict holding only the entries of dd whose key is upper case."""
        return { name: value for name,value in dd.items() if name.isupper() }

    def _gen_script(self,template:str,ofn:str,keywords:dict):
        """Render a template into the script ofn by replacing KEYWORDs.

        template: path of the template file.
        ofn:      path of the script to write; its directory is created.
        keywords: KEYWORD -> replacement text; overrides template defaults.

        Two kinds of directive lines are recognised in the template and are
        copied to the output unchanged:
          * a line containing ';DIAG_NORCPM;' sets a default, the KEYWORD
            follows the first space after the marker
                ex.  ;DIAG_NORCPM; VARONE: this is var 1
                ex.  ##;;DIAG_NORCPM;; VARONE: this is var 1
                ex.  ##;;DIAG_NORCPM;SomeCommentWithoutSpace; VARONE: this is var 1
          * a line containing ';NeedBeReplaced' lists KEYWORDs that must be
            present in keywords
                ex.  ;; ;NeedBeReplaced: VARONE, VARTWO, VARTHREE
                ex.  ##;NeedBeReplaced  VARONE, VARTWO, VARTHREE
                ex.  ##;NeedBeReplaced##AlsoCommentWithoutSpace  VARONE, VARTWO, VARTHREE

        Returns True when every required KEYWORD was supplied, False
        otherwise (a warning is printed and the script is still written).
        """
        OutputReplaces = False ## debug: print the replacement table
        ok = True

        diagcmd  = ';DIAG_NORCPM;'
        diagNeed = ';NeedBeReplaced'

        ## read template
        with open(template,'r') as f:
            lines = f.read().split('\n')

        ## necessary KEYWORDs; a marker line without a readable list is
        ## skipped instead of being taken as a keyword itself
        needPattern = re.compile(rf'.*{diagNeed}[^ ]*  *(.*)')
        necKeys = []
        for line in lines:
            if diagNeed not in line: continue
            found = needPattern.match(line)
            if found: necKeys.extend( w for w in re.split(r',|\s',found.group(1)) if w )
        for key in necKeys:
            if key not in keywords:
                ok = False
                print(f'WARNING: {key} in necessary in {template}')

        ## default KEYWORDs declared by the template itself
        cmdPattern = re.compile(diagcmd+r'[^ ]*  *(.*)')
        kws = {}
        for line in lines:
            cmdstr = cmdPattern.search(line)
            if not cmdstr: continue
            cmd = re.search("^([^ :]*): *(.*)",cmdstr.group(1))
            if cmd:
                kws[cmd.group(1)] = cmd.group(2)
            else:
                print('diag.py._gen_script(): cannot parse: '+cmdstr.group(0))
        ## caller supplied keywords win over the defaults
        kws.update(keywords)

        ## replace from long to short so that a KEYWORD which is a prefix of
        ## another one cannot clobber it
        keys = sorted(kws,key=len,reverse=True)
        if OutputReplaces:
            print('======== keys to replace with order')
            for key in keys:
                print('    '+key+': '+kws[key])
            print('===================================')

        rendered = []
        for line in lines:
            ## directive lines are kept as they are
            if diagcmd in line or diagNeed in line:
                rendered.append(line)
                continue
            for key in keys:
                ## KEYWORDs are matched literally, not as regular expressions
                if not key or key not in line: continue
                dest = kws[key]
                if '\n' in dest:
                    ## multi-line value: only add indent if the KEYWORD is
                    ## led by an indent (space or tab) at the line start
                    indent = re.match(r'(\s+)'+re.escape(key),line)
                    if indent: dest = dest.replace('\n','\n'+indent.group(1))
                line = line.replace(key,dest)
            rendered.append(line)

        ## write to ofn, every template line is terminated with a newline
        Path(os.path.dirname(ofn)).mkdir(parents=True,exist_ok=True)
        with open(ofn,'w') as f:
            f.write(''.join( text+'\n' for text in rendered ))
        return ok

    def _gen_recipes(self,evthing):
        """Generate the scripts of every recipe in evthing['Recipes'].

        evthing: the merged settings dict; each entry of evthing['Recipes']
                 is a recipe dict carrying 'Recipe' (its name) and optionally
                 'Scripts' (or 'plotScripts'), a list of dicts that name a
                 template with 'Script' (or 'plotScript').

        Scripts are written to
            {self.workdir}/{basename of Recipe}/{num}-{script file name}
        where num is the 1-based, zero padded position in the recipe.
        Settings are layered as evthing < recipe < varPack < script entry and
        only the upper-case ones are handed to _gen_script().

        Returns {Recipe: [script paths]} when all scripts were generated
        without a missing KEYWORD, False otherwise.  A recipe without any
        script is returned with an empty list.
        """
        ok = True
        recipScripts = {} ## {Recipes: [Scripts]}
        for recipe in evthing.get('Recipes') or []:
            r = evthing.copy()
            r.update(recipe)
            recipScripts[r['Recipe']] = []
            ## a recipe may carry no scripts at all (e.g. only Depends)
            Scripts = recipe.get('Scripts') or recipe.get('plotScripts') or []
            numwidth = len(str(len(Scripts)))
            for num,entry in enumerate(Scripts,start=1):
                r1 = r.copy()
                ## varPack: a list of KEYWORD, apply from predefined setting
                if entry.get('varPack'):
                    for pack in re.split(',| |\n',entry.get('varPack')):
                        if pack: r1.update(r1.get(pack))
                r1.update(entry)
                script = r1.get('Script') or r1.get('plotScript')

                ## output name: explicit rename, else the template file name
                scriptFN = r1.get('scriptRename') or re.sub('.*/','',script)

                ## the include directory, when set, takes priority over Codes/
                template = f'{self.diag_root}/Codes/{script}'
                if self.include and os.path.isfile(f'{self.include}/{script}'):
                    template = f'{self.include}/{script}'

                RecipeBaseName = os.path.basename(r1['Recipe'])
                ofn = f'{self.workdir}/{RecipeBaseName}/{num:0{numwidth}}-{scriptFN}'
                recipScripts[r1['Recipe']].append(ofn)
                ## keep generating the remaining scripts even after a failure
                ok = self._gen_script(template=template,ofn=ofn,keywords=self._KEYWORDs(r1)) and ok

        if ok: return recipScripts
        return ok

    def _expend_Recipes(self,rNames:list):
        """Load the recipe yml file of every name in rNames.

        rNames: recipe names, with or without '.yml'; None is treated as
                an empty list.

        Each recipe is looked up in self.include first (when set) and then
        in {self.diag_root}/Recipes.  Returns the list of recipe dicts, each
        carrying its name under 'Recipe'.

        Raises FileNotFoundError when a recipe file cannot be found.
        """
        recipes = list()
        for name in rNames or []:
            iiyml = name if '.yml' in name else name+'.yml'
            ii    = iiyml.replace('.yml','')
            ymlfile = f'{self.diag_root}/Recipes/{iiyml}'
            if self.include and os.path.isfile(f'{self.include}/{iiyml}'):
                ymlfile = f'{self.include}/{iiyml}'
            recipe = self._readYMLIfInc(ymlfile,updatedic={'Recipe':ii})
            ## 'Recipe' is only set when a file was really read; fail here
            ## with the file name rather than later with a bare KeyError
            if 'Recipe' not in recipe:
                raise FileNotFoundError(f'recipe file {iiyml} not found (looked for {ymlfile})')
            recipes.append(recipe)
        return recipes

    def main(self,argv):
        """Dispatch on the state of the work directory.

        argv: full argument vector; the first element is dropped.

        When {self.workdir}/plotCase.yml exists the plots are made with
        doplot(), otherwise the work directory is initialised with init().
        The result of the chosen method is returned.
        """
        args = argv[1:] or ''
        hasCase = os.path.isfile(f'{self.workdir}/plotCase.yml')
        handler = self.doplot if hasCase else self.init
        return handler(args)
            
    def init(self,argv=None):
        """Initialise the work directory of diag_norcpm.

        argv: optional argument list; its first element is the profile name
              (with or without '.yml').

        A Makefile pointing to diag.py is written to the work directory
        unless one already exists.  When a profile is given and is available
        in {self.diag_root}/Profiles, it is copied to plotCase.yml; otherwise
        the available profiles are printed.  Returns None.
        """
        profile = argv[0] if argv else ''
        profileDir = f'{self.diag_root}/Profiles'
        availables = sorted( i.replace('.yml','') for i in os.listdir(profileDir) if os.path.isfile(f'{profileDir}/{i}') )
        profileName = profile.replace('.yml','')
        profileFN = profileName+'.yml'

        makefile = f'{self.workdir}/Makefile'
        if not os.path.isfile(makefile):
            ## every line, the DIR assignment included, needs its own newline,
            ## otherwise the first target is glued to the assignment
            text = ''.join([
                'DIR := $(shell basename $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))))\n',
                'doplot:\n',
                f'\t@{self.diag_root}/diag.py\n',
                'edit:\n',
                '\t@vim plotCase.yml\n',
                'upload:\n',
                ## placeholder recipe, 'true' does nothing and succeeds
                '\ttrue ## for uploading figures, try use $(DIR)\n',
            ])
            with open(makefile,'w') as f:
                f.write(text)

        if not profile:
            print('Available profiles are:')
            print('  '+', '.join(availables))
            return

        if profileName not in availables:
            print(f'{profileName} is not exists, available profiles are:')
            print('  '+', '.join(availables))
            return

        ## copy the profile as plotCase.yml
        with open(f'{profileDir}/{profileFN}','r') as f:
            content = f.read()
        with open(f'{self.workdir}/plotCase.yml','w') as f:
            f.write(content)
        print('plotCase.yml created, please look in to for detail settings.')
        return

    def doplot(self,argv=None):
        """Generate and run the scripts of all recipes, then write index.html.

        argv: accepted for symmetry with init(); not used.

        Settings are merged in the order Defaults.yml, Machine_<host>.yml,
        plotCase.yml (later files win).  Recipes named in the settings are
        loaded together with the recipes they depend on, their scripts are
        generated and run through the worker queue, and finally the index
        page is written.  Returns None; nothing is run when script generation
        fails or yields no recipe.
        """
        evthing = dict()
        ## Get Include first, it decides where the other yml files come from
        plotCaseyml = self._readYML(f'{self.workdir}/plotCase.yml')
        self.include = plotCaseyml.get('Include')

        ## read Defaults.yml
        evthing.update(self._readYMLIfInc(f'{self.diag_root}/Defaults.yml'))
        ## read Machine_HOST.yml
        host = self._get_hostname()
        evthing.update(self._readYMLIfInc(f'{self.diag_root}/Machine_{host}.yml'))
        ## read plotCase.yml
        evthing.update(plotCaseyml)
        del plotCaseyml

        ## if just testing use dryrun; yml values arrive as strings, so the
        ## usual spellings of "off" must be recognised explicitly
        dryrun = evthing.get('Dryrun') or evthing.get('DryRun')
        if isinstance(dryrun,str):
            dryrun = dryrun.strip().lower() not in ('','false','no','off','0')
        self.dryrun = bool(dryrun)

        ## Some predefined KEYWORDs
        evthing['CODEDIR'] = f'{self.diag_root}/Codes'

        ##      expand Recipes
        evthing['Recipes'] = self._expend_Recipes(evthing.get('Recipes') or [])
        for _ in range(10): # max 10 layers of dependency
            nowhave = { i['Recipe'] for i in evthing['Recipes'] }
            need = set()
            for i in evthing['Recipes']:
                need.update(i.get('Depends') or [])
            need -= nowhave
            if not need: break
            ## sorted: keep the order of the added recipes reproducible
            evthing['Recipes'].extend(self._expend_Recipes(sorted(need)))

        ## gen_scripts for all scripts at each Recipe directory
        recipScripts = self._gen_recipes(evthing)
        if not recipScripts: return

        ## Queue system and run scripts; one worker when nothing is configured
        maxproc = max(1,int(evthing.get('MaxProc') or evthing.get('Maxprocess') or 1))
        depends = { i['Recipe']: i.get('Depends') for i in evthing['Recipes'] }
        self._queue_run(maxproc=maxproc,depends=depends,recipScripts=recipScripts)

        self._mkindex(evthing)

    def _mkindex_entry(self):
        """Collect the per-recipe html snippets for the index page.

        Every direct sub-directory of self.workdir that holds an index.html
        contributes either the content of its entry.html (preceded by an
        html comment naming it) or an html comment saying it is missing.
        Directories are visited in sorted order.  Returns the html text, ''
        when the work directory has no such directory or does not exist.
        """
        ## first os.walk item is (path, dirs, files); fall back to no dirs
        dirs = next(os.walk(self.workdir), (None, [], []))[1]
        parts = []
        for name in sorted(dirs):
            if not os.path.isfile(f'{self.workdir}/{name}/index.html'): continue
            entryfn = f'{self.workdir}/{name}/entry.html'
            if os.path.isfile(entryfn):  ## entry html generated by mk_index_page.sh
                with open(entryfn,'r') as f:
                    parts.append(f'<!--- {name}/entry.html  --->\n')
                    parts.append(f.read())
                    parts.append('\n')
            else:  ## if empty
                parts.append(f'<!--- {name}/entry.html is not exist --->\n')

        return ''.join(parts)

    def _mkindex(self,evthing):
        """Write {self.workdir}/index.html from the recipe entries.

        evthing: the merged settings; not used yet (the page title is not
                 generated so far).
        """
        pieces = [
            #### html header
            '<!DOCTYPE html> \n',
            '<html> \n',
            '<body> \n',
            #### title (not done yet), only a rule for now
            '<hr>\n',
            #### recipe items
            self._mkindex_entry(),
            #### html footer
            '<p align="right"><small>contact: pgchiu (Ping-Gin.Chiu_at_uib.no)</small></p> \n',
            '</body> \n',
            '</html> \n',
        ]

        #### write index html
        with open(f"{self.workdir}/index.html",'w') as f:
            f.write(''.join(pieces))

    def _update_progress(self,recipe,status):
        """Record the status text of a recipe and signal self.event."""
        self.progress.update({recipe: status})
        self.event.set()

    def _sp_run(self,workdir,script,logf,dryrun=False):
        """Run one script inside workdir, sending its output to logf.

        workdir: directory the script is run in.
        script:  file name of the script; its suffix selects the launcher
                 (.ncl, .m, .py, .sh; anything else is run with 'sh -e').
        logf:    open log file; it receives marker lines and the stdout and
                 stderr of the script.
        dryrun:  when true the script is not run, a short random pause is
                 made instead (useful for testing the queue).

        The working directory is given to the child process only; the
        current directory of this process, which is shared by all worker
        threads, is left untouched.
        """
        import shlex  ## only needed here, for quoting the file name

        fn = script
        launchers = {'.ncl':'ncl -Q', '.m':'matlab', '.py':'python', '.sh':'sh -e'}
        suffix = os.path.splitext(fn)[1]
        cmd = launchers.get(suffix,'sh -e')  # unknown suffix: bad idea, but try sh

        logf.write('>>>>>>>>>>>>>>>> running '+cmd+' '+fn+'\n')
        logf.flush()  ## keep the marker ahead of the output of the child
        if dryrun:
            time.sleep(random.uniform(0,1))
            logf.write('>>>>>>>>>>>>>>>> '+cmd+' '+fn+' DryRun \n')
            return

        start = time.perf_counter()
        ## quote the name so that spaces or shell characters in it are safe
        sp.run(' '.join([cmd,shlex.quote(fn)]),shell=True,cwd=workdir or None,stdout=logf,stderr=logf)
        end = time.perf_counter()
        logf.write('>>>>>>>>>>>>>>>> '+cmd+' '+fn+' [done with %2.2f secs.]\n'%(end-start))
        logf.flush()
    def _job_worker(self):
        """Worker loop: run the scripts of queued recipes, one recipe at a time.

        Jobs are taken from self.q as (recipe, scripts), scripts being a list
        of absolute script paths.  Each script is run in its own directory
        and all output goes to {self.workdir}/{recipe base name}.log.
        Progress is reported through self._update_progress():
            '(3/10) script'            while running
            'DONE - name.log'          when finished
            'DONE - name.log , with error?'
                                       when the log mentions an error
        The loop never returns; it is meant for a daemon thread.  A failure
        while handling a job is reported in the status of the recipe, so a
        broken job cannot leave the scheduler waiting forever.
        """
        while True:
            recipe,scripts = self.q.get()
            ns = len(scripts)
            recipeBaseName = os.path.basename(recipe)
            logfn = f'{self.workdir}/{recipeBaseName}.log'
            iserr = ''
            try:
                with open(logfn,'w') as logf:
                    for n,path in enumerate(scripts,start=1):
                        workdir = os.path.dirname(path)
                        scriptFN = os.path.basename(path)
                        self._update_progress(recipe,f'({n}/{ns}) {scriptFN}')
                        self._sp_run(workdir,scriptFN,logf,dryrun=self.dryrun)
                ## scripts may print bytes that are not valid text; replace
                ## them instead of failing while scanning the log
                with open(logfn,'r',errors='replace') as logf:
                    log = logf.read().lower() ## for case insensitive
                if any( word in log for word in ['fault','fatal','error','not found'] ):
                    iserr = ' , with error?'
            except Exception as err:
                ## top level of the worker thread: report, do not die
                iserr = f' , with error? ({err})'
            finally:
                self._update_progress(recipe,'DONE - '+f'{recipeBaseName}.log'+iserr)
                self.q.task_done()
            
    def _queue_run(self,maxproc:int=1,recipScripts=None,depends=None):
        """Run all recipes with a pool of worker threads, honouring Depends.

        maxproc:      number of worker threads (at least one is started).
        recipScripts: {recipe: [script paths]} as returned by _gen_recipes().
        depends:      {recipe: list of recipes it has to wait for, or None}.

        A recipe is queued once all its dependencies are DONE.  The monitor
        is redrawn whenever a status changes.  Status texts:
            Pending / Pending for X,Y   not queued yet
            Waiting for worker          queued
            (n/N) script                running
            DONE - ...                  finished
            SKIPPED - ...               never run, see below
        If nothing is queued or running and no pending recipe can be started
        (circular or unknown dependency) the pending recipes are marked
        SKIPPED and the method returns instead of waiting forever.
        """
        recipScripts = recipScripts or {}
        depends = depends or {}
        if not recipScripts: return ## nothing to run

        ## threading + queue, self.q
        ## start workers
        for _ in range(max(1,maxproc)):
            threading.Thread(target=self._job_worker, daemon=True).start()

        ## set init status
        for name in recipScripts:
            if depends.get(name):
                self.progress[name] = 'Pending for '+','.join(depends[name])
            else:
                self.progress[name] = 'Pending'

        self._refresh_monitor(first=True)
        while True:
            ## work on one consistent view; workers update self.progress
            snapshot = dict(self.progress)
            if all( v[:4] == 'DONE' for v in snapshot.values() ): break
            done = { k for k,v in snapshot.items() if v[:4] == 'DONE' }
            pend = sorted( k for k,v in snapshot.items() if v[:4] == 'Pend' )

            torun = []
            waiting = {}
            for name in pend:
                notdone = sorted(set(depends.get(name) or []).difference(done))
                if notdone:
                    waiting[name] = notdone
                    self.progress[name] = 'Pending for '+','.join(notdone)
                else:
                    torun.append(name)

            ## only DONE and Pending left and nothing can start: the
            ## remaining dependencies can never be fulfilled
            if not torun and len(done)+len(pend) == len(snapshot):
                for name in pend:
                    self.progress[name] = 'SKIPPED - unresolved dependency: '+','.join(waiting[name])
                break

            ## mark first, then queue, so a fast worker cannot be overwritten
            for name in torun:
                self.progress[name] = 'Waiting for worker'
            for name in torun:
                self.q.put((name,recipScripts[name]))

            self._refresh_monitor()
            self.event.wait()
            self.event.clear()
        self._refresh_monitor()
        
    def _refresh_monitor(self,first=False):
        """Print the progress table, redrawing it in place.

        first: True for the very first drawing; otherwise the cursor is
               moved up over the previously printed table before printing.

        The table is a rule, one 'recipe: status' line per recipe (sorted by
        name, status cut to fit self.termwidth) and a closing rule.  With no
        recipe at all only the two rules are printed.
        """
        progress = self.progress.copy()
        keys = sorted(progress)
        nline = len(keys)+2
        termwidth = self.termwidth

        klen = max(( len(k) for k in keys ),default=0)
        ## room left for the status text, never negative
        vlen = max(0,termwidth-klen-3)

        ## output
        if not first: print(f'\033[{nline}A',end='\033[K')
        print('='*(termwidth-1))
        for k in keys:
            print(f'{k:{klen}}: {progress[k][:vlen]}',end='\033[K\n')
        print('='*(termwidth-1))

## script entry: run against the current directory with the command line
if __name__ == '__main__':
    diag_norcpm().main(sys.argv)
