#!/usr/bin/env python3

import os,re,sys
import netCDF4 as nc4
import time
import numpy as np

class norcpm_data():
    ## provide an easy way to access the data of one ensemble
    ## (a case directory holding one sub-directory per member)
    def __init__(self,path,casename='',memtag='mem',ocngridfile=''):
        ## path: directory that contains all members, something like:
        ##   path/case_mem01
        ##   path/case_mem02
        ##   path/case_mem03
        ##   ...
        ## casename: prefix of the member directories; defaults to the last
        ##           component of path
        ## memtag: tag between casename and the member number (case_<memtag>01)
        ## ocngridfile: ocean grid file, needed when ocn coordinates are requested
        self.path = path.rstrip('/')

        self.casename = casename or self.path.split('/')[-1]

        ## member directories start with <casename>_<memtag><digits>; names are
        ## escaped so characters such as '.' in a case name match literally
        memberpattern = re.compile(rf'{re.escape(self.casename)}_{re.escape(memtag)}\d+')
        subdirs = sorted(os.listdir(self.path))
        self.memberCasenames = [ i for i in subdirs if memberpattern.match(i) ]
        self.membersPath = [ f'{self.path}/{i}' for i in self.memberCasenames ]
        self.nmember = len(self.memberCasenames)
        self.components = ['atm','ocn','ice','lnd']
        ## model tag used in the history file names of each component
        ## (older setups used 'cam2' for atm and 'micom' for ocn)
        self.modeltag = {'atm': 'cam'
                        ,'lnd': 'clm2'
                        ,'ice': 'cice'
                        ,'ocn': 'blom'
                        }
        self.varnames = {}
        ## start date YYYYMMDD = trailing 8 digits of the case name, '' if absent
        self.startDate = ''
        startDate = re.match(r'.*[^\d](\d{8})$',self.casename)
        if startDate: self.startDate = startDate.group(1)

        ## ocn grid file, needed when reading ocn coordinates
        self.ocngridfile = ocngridfile

    def _getVarnames(self,fn):
        ## return the names of all variables in netCDF file fn
        with nc4.Dataset(fn,'r') as f:
            vns = list(f.variables.keys())

        return vns

    def _collect_varnames(self,component='atm',mem=1):
        ## collect variable names of each history stream of one component into
        ## self.varnames[component][hist], read from the first matching file
        ## (in sorted order) of member `mem`.
        ## mem: 1-based member number, the same convention as _fn()
        if self.varnames.get(component): return
        path = f'{self.membersPath[mem-1]}/{component}/hist'
        hist = {'ocn': ['hm','hy'], 'ice': ['h']}.get(component, ['h0','h1','h2','h3','h4','h5'])
        fns = sorted(os.listdir(path))
        self.varnames[component] = {}
        for i in hist:
            fn = next((j for j in fns if f'.{i}.' in j), None)
            if fn is not None:
                self.varnames[component][i] = self._getVarnames(f'{path}/{fn}')

    def _fn(self,component,member,hist,year,month,day='',hr=''):
        ## build the history file path
        ##   <path>/<member>/<component>/hist/<member>.<modeltag>.<hist>.YYYY-MM[-DD[-SSSSS]].nc
        ## member: 1-based member number or a member case name
        ##         (NOTE: member 0 maps to index -1, i.e. the last member)
        ## day: appended only when given; hr: hour of day, written as seconds
        ##      (hr*3600), so hr=0 gives the '-00000' suffix
        ## return the path, or print a message and return None if it does not exist
        if isinstance(member,(int,np.integer)): member = self.memberCasenames[member-1]
        path = f'{self.path}/{member}/{component}/hist/{member}.{self.modeltag[component]}.{hist}.{year:04}-{month:02}'
        if day:
            path += f'-{day:02}'
            if hr not in ('',None):
                sec = hr*3600
                path += f'-{sec:05}'
        path += '.nc'
        if os.path.isfile(path): return path
        print(f'{path} not found')

    def _fn_variable_(self,vn,component='',year='',month='',day='',hr=''):
        ## find file name with given variable and time (not implemented yet)
        pass

    def _histtag_monthly(self,component):
        ## history tag of the monthly mean files of a component
        hist = 'h0'
        if component == 'ocn': hist = 'hm'
        if component == 'ice': hist = 'h'
        return hist

    def _open_monthly(self,component,member,year,month):
        ## open the monthly history file of one member; raise FileNotFoundError
        ## with context instead of handing None to netCDF4
        fn = self._fn(component,member,self._histtag_monthly(component),year,month)
        if fn is None:
            raise FileNotFoundError(f'monthly {component} file of member {member} for {year}-{month} not found')
        return nc4.Dataset(fn)

    @staticmethod
    def _read_var(f,vn):
        ## read variable vn of the open dataset f; attach its netCDF attributes as
        ## var.attrs. 'sst' values equal to 0.0 are masked (missing in atm output)
        var = f.variables[vn][:]
        if vn.lower() == 'sst': var = np.ma.masked_equal(var,0.0)
        var.attrs = { i: f.variables[vn].getncattr(i) for i in f.variables[vn].ncattrs() }
        return var

    def _get_atm_monthly_variable(self,vn,component,member=0,year=0,month=0,latbe=[-90,90],lonbe=[0,360],dims=False):
        ## read a monthly variable with 1d coordinates; if dims, attach var.dims as
        ## a list of {'name','size','value', <coordinate attributes>} per dimension
        with self._open_monthly(component,member,year,month) as f:
            var = self._read_var(f,vn)
            if dims:  ## 1d coordinate
                var.dims = []
                for d in f.variables[vn].get_dims():
                    if d.name in f.variables:
                        coord = f.variables[d.name]
                        dimattr = {'name':d.name,'size':d.size,'value':coord[:]}
                        for j in coord.ncattrs():
                            dimattr[j] = coord.getncattr(j)
                    else:
                        ## dimension without a coordinate variable: use 0..size-1
                        dimattr = {'name':d.name,'size':d.size,'value':np.ma.array(range(d.size))}
                    var.dims.append(dimattr)
        return var

    def _get_ocn_monthly_variable(self,vn,component,member=0,year=0,month=0,latbe=[-90,90],lonbe=[0,360],dims=False):
        ## read a monthly ocn variable; if dims, attach var.dims built from the
        ## variables listed in its 'coordinates' attribute, read from self.ocngridfile
        with self._open_monthly(component,member,year,month) as f:
            var = self._read_var(f,vn)
        if dims:  ## 2d coordinate, not done yet
            if not self.ocngridfile:
                raise FileNotFoundError('ocngridfile is required to read ocn coordinates')
            var.dims = []
            coords = var.attrs['coordinates'].split()
            with nc4.Dataset(self.ocngridfile) as f:
                for i in coords:
                    d = f.variables[i]
                    dimvalue = d[:]
                    dimattr = {'name':i,'size':dimvalue.shape,'value':dimvalue}
                    for j in d.ncattrs():
                        dimattr[j] = d.getncattr(j)
                    var.dims.append(dimattr)
        return var

    def get_monthly_variable(self,vn,component,member=0,year=0,month=0,latbe=[-90,90],lonbe=[0,360],dims=False):
        ## return one member's monthly variable as a masked array with .attrs
        ## (and .dims when dims=True); only 'atm' and 'ocn' are supported,
        ## other components print a message and return False.
        ## TODO: region subsetting via latbe/lonbe is not implemented
        if latbe != [-90,90] or lonbe != [0,360]:
            print('get_monthly_variable(): latbe/lonbe are not imply yet. Be carful.')

        if component == 'atm': return self._get_atm_monthly_variable(vn,component,member,year,month,latbe,lonbe,dims)
        if component == 'ocn': return self._get_ocn_monthly_variable(vn,component,member,year,month,latbe,lonbe,dims)

        print(f'component {component} is not applied yet')
        return False

    def get_monthly_variable_ens(self,vn,component,members=None,year=0,month=0,latbe=[-90,90],lonbe=[0,360],mean=False,dims=False):
        ## stack the monthly variable of the selected members along a new leading
        ## axis (or average over it when mean=True).
        ## members: None, [] or [0] -> all members; an int -> that single member
        ##          (returned without the member axis); a one-element list -> that
        ##          member, also without the member axis.
        ## With more than one member the result always carries .dims/.attrs of the
        ## first member (its .dims does not describe the member axis).
        if members is None or members == [0] or members == []: members = self.memberCasenames
        if isinstance(members,(int,np.integer)): return self.get_monthly_variable(vn,component,members,year,month,latbe,lonbe,dims)
        nmem = len(members)
        if nmem == 1: return self.get_monthly_variable(vn,component,members[0],year,month,latbe,lonbe,dims)
        first = self.get_monthly_variable(vn,component,members[0],year,month,latbe,lonbe,dims=True)
        varall = np.ma.empty([nmem]+list(first.shape),order='F')
        varall[0,] = first
        for i,mem in enumerate(members[1:],start=1):
            varall[i,] = self.get_monthly_variable(vn,component,mem,year,month,latbe,lonbe,dims=False)
        if mean:
            varall = np.ma.mean(varall,axis=0)
        varall.dims = first.dims
        varall.attrs = first.attrs
        return varall

    def get_ensmean_clm12(self,vn,component,years=None,tag=''):
        ## 12-month climatology of the ensemble mean over `years`, months
        ## concatenated along axis 0. Returns False (with a message) if years is empty.
        ## tag: when non-empty, each month is cached in data/<vn>_<component>_<tag>_mon_MM.npy
        ##      and read back from there on later calls.
        ##      NOTE: the cache stores np.array(v), so the mask is not preserved.
        years = list(years) if years is not None else []
        ny = len(years)
        if not ny:
            print('get_ensmean_clm12(): need years range as list()')
            return False
        monthly = []
        for m in range(1,12+1):
            cachefn = f'data/{vn}_{component}_{tag}_mon_{m:02}.npy'
            if tag and os.path.exists(cachefn):
                with open(cachefn,'rb') as f:
                    v = np.load(f)
            else:
                v = self.get_monthly_variable_ens(vn,component,year=years[0],month=m,mean=True)/ny
                for y in years[1:]:
                    v += self.get_monthly_variable_ens(vn,component,year=y,month=m,mean=True)/ny
                if tag:
                    os.makedirs(os.path.dirname(cachefn),exist_ok=True)
                    with open(cachefn,'wb') as f:
                        np.save(f,np.array(v))
            monthly.append(v)
        return np.ma.concatenate(monthly,axis=0)
class norcpm_forecasts:
    ## access multiple forecasts or hindcasts, each one a norcpm_data ensemble
    def __init__(self,path,caseprefix='',ocngridfile=''):
        ## path: directory containing the forecasts, e.g.
        ##  path/caseprefix_19800115
        ##  path/caseprefix_19800415
        ##  path/caseprefix_19800715
        ##  ...
        ## caseprefix: defaults to the last component of path
        ## ocngridfile: passed on to every hindcast (needed for ocn coordinates)
        self.path = path.rstrip('/')
        self.caseprefix = caseprefix or self.path.split('/')[-1]
        ## hindcast directories: <caseprefix>_ followed by exactly 8 characters
        casepattern = re.compile(rf'^{re.escape(self.caseprefix)}_.{{8}}$')
        self.hindcastsName = sorted(i for i in os.listdir(self.path) if casepattern.match(i))

        self.hindcasts = [ norcpm_data(f'{self.path}/{i}',ocngridfile=ocngridfile) for i in self.hindcastsName ]
        self.nhindcasts = len(self.hindcasts)

    def _get_year_month_from_leadmonth(self,startDate:str,leadmonth:int):
        ## startDate: string, YYYYMM or YYYYMMDD or YYYYMMDDSSSSS
        ## leadmonth: number of months to add (may be negative)
        ## return (year, month) of startDate's month shifted by leadmonth
        monthindex = int(startDate[0:4])*12 + int(startDate[4:6]) - 1 + leadmonth
        yyyy, mm0 = divmod(monthindex,12)
        return yyyy, mm0+1

    def _get_init_yyyymm_list(self):
        ## start date string of every hindcast ('' when it could not be parsed)
        return [ i.startDate for i in self.hindcasts ]

    def get_monthly_variable_ensmean_leadmonth(self,vn,component,leadmonth:int,years=None,months=None,msg=False):
        ## put the ensemble mean of every selected hindcast, taken at `leadmonth`
        ## months after its start date, along the leading (time) axis.
        ## years/months: keep only hindcasts initialised in these years/months
        ##               (empty or None = no filter); hindcasts without a numeric
        ##               start date are always skipped.
        ## Prints the selection and calls sys.exit() when nothing matches.
        ## Result: masked array with .dims, .attrs and .time (list of 'YYYYMM').
        years  = set(years)  if years  is not None else set()
        months = set(months) if months is not None else set()

        ineed = []
        for k,startDate in enumerate(self._get_init_yyyymm_list()):
            if not startDate.isdigit(): continue
            if years  and int(startDate[0:4]) not in years:  continue
            if months and int(startDate[4:6]) not in months: continue
            ineed.append(k)

        nhindcasts = len(ineed)
        if not nhindcasts:
            print('No suitable hindcast')
            print('years:  '+str(years))
            print('months: '+str(months))
            sys.exit()

        first = self.hindcasts[ineed[0]]
        year,month = self._get_year_month_from_leadmonth(first.startDate,leadmonth)
        timeaxis = [f'{year}{month:02}']
        v1 = first.get_monthly_variable_ens(vn,component,[0],year,month,dims=True,mean=True)
        ## one monthly file gives a leading axis of length 1; widen it to one slot per hindcast
        varallshape = list(v1.shape)
        varallshape[0] = nhindcasts
        varall = np.ma.empty(varallshape,order='F')
        varall[0,] = v1
        varall.dims = v1.dims
        varall.attrs = v1.attrs
        dimtime = v1.dims[0]

        t0 = time.perf_counter()
        for i,j in enumerate(ineed[1:],start=1):
            print(f'\r Reading {i+1}/{nhindcasts}',end='')
            year,month = self._get_year_month_from_leadmonth(self.hindcasts[j].startDate,leadmonth)
            ## dims=True so .dims exists even for single-member ensembles
            v1 = self.hindcasts[j].get_monthly_variable_ens(vn,component,[0],year,month,dims=True,mean=True)
            varall[i,] = v1

            ## append time dimension, should be wrong need rewrite
            ## NOTE(review): for ocn, dims[0] is the first grid coordinate, not time
            dimtime['value'] = np.append(dimtime['value'],v1.dims[0]['value'])
            timeaxis.append(f'{year}{month:02}')
        dt = time.perf_counter() -t0
        if msg: print(f'\rRead {nhindcasts} hindcasts done in {dt:.2f} sec.')
        varall.time = np.ma.array(timeaxis)
        varall.dims[0] = dimtime
        return  varall


if __name__ == '__main__':
    ## Earlier usage examples (NorCPM V1 hindcasts), kept for reference:
    #path = '/projects/NS9039K/shared/norcpm/cases/NorCPM/NorCPM_V1/ana_19800115_me_hindcasts/ana_19800115_me_19890115/'
    #a = norcpm_data(path)
    #for i in ['atm','ocn','ice','lnd']:
    #    a._collect_varnames(i)
    #print(a.varnames)
    #print(a._fn('atm',2,'h0',1989,3))
    #v = a.get_monthly_variable('sst','ocn',1,1989,3)
    #print(v.shape)
    #
    #path = '/projects/NS9039K/shared/norcpm/cases/NorCPM/NorCPM_V1/ana_19800115_me_hindcasts'
    #path = './test_set/ana_19800115_me_hindcasts'
    #prefix = 'ana_19800115_me'
    #b = norcpm_forecasts(path,prefix)
    #c12 = b.hindcasts[1].get_monthly_variable_ens('T','atm',[0],1985,10)
    #c12 = b.hindcasts[1].get_monthly_variable_ens('sst','ocn',[1,2],1985,10)
    #print(c12.dims)
    #v = b.get_monthly_variable_ensmean_leadmonth('sst','ocn',leadmonth=2)
    #v = b.get_monthly_variable_ensmean_leadmonth('SST','atm',leadmonth=3)
    #print(v.shape)
    #print(np.count_nonzero(np.ma.getmaskarray(v) == True))
    #print(v.dims[0])

    ## free run/analysis (Betzy): 12-month climatology of ocean 'temp'
    archive_dir = '/cluster/shared/NS9039K/archive/noresm_ctl_f09_tn14_19700101'
    grid_file   = '/cluster/shared/noresm/inputdata/ocn/blom/grid/grid_tnx1v4_20170622.nc'
    run = norcpm_data(archive_dir,ocngridfile=grid_file)
    #v = run.get_monthly_variable('TS','atm',1,1989,3)
    ## NOTE(review): the cache tag says 1990-2010 but the years span 1985-2010 -- confirm which is intended
    clim = run.get_ensmean_clm12('temp','ocn',years=list(range(1985,2011)),tag='hist_clm12_1990-2010')

    ## report the result one item at a time (each evaluated just before printing)
    for describe in (lambda x: x, lambda x: x.shape, lambda x: x.max(), lambda x: x.min(), type, dir):
        print(describe(clim))
