import numpy as np
import netCDF4 as nc
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable ## for colorbar
import sys,glob

###spectral operation

def generate_fft_index(n):
    """Return the integer wavenumber of every FFT output bin for length *n*.

    The result has the layout of an FFT output axis: the non-negative
    wavenumbers ``0 .. nup-1`` first (``nup = ceil((n+1)/2)``), followed by
    the negative wavenumbers up to ``-1``.
    """
    half = int(np.ceil((n + 1) / 2))
    # For even n the non-negative part already holds the Nyquist index, so
    # the negative part starts one step later than for odd n.
    first_negative = (2 if n % 2 == 0 else 1) - half
    non_negative = np.arange(0, half)
    negative = np.arange(first_negative, 0)
    return np.concatenate((non_negative, negative))

def pwrspec2d(field):
    """Horizontal 2-D power spectrum p(k), with k = sqrt(kx^2 + ky^2).

    Parameters
    ----------
    field : array of shape (nx, ny, nz)
        The first two axes are transformed; every index along the third
        axis is processed independently.

    Returns
    -------
    wn : float array of length nup
        Wavenumber bins ``0, 1, ..., nup-1`` where
        ``nup = max(ceil((nx+1)/2), ceil((ny+1)/2))``.
    pwr : float array of shape (nup, nz)
        ``pwr[w, z]`` is the sum of ``(|FFT| / (nx*ny))**2`` over all modes
        with ``ceil(k) == w``.  Modes with ``ceil(k) >= nup`` are not
        included in any bin.
    """
    nx, ny, nz = field.shape
    nupx = int(np.ceil((nx+1)/2))
    nupy = int(np.ceil((ny+1)/2))
    nup = max(nupx, nupy)
    wnx = generate_fft_index(nx)
    wny = generate_fft_index(ny)
    # With indexing='ij' the first output varies along axis 0 (wnx) and the
    # second along axis 1 (wny); both have shape (nx, ny).  The names must
    # follow that order so each axis gets its own rescaling factor below.
    kx, ky = np.meshgrid(wnx, wny, indexing='ij')
    # Stretch the shorter axis so both directions span the same range nup.
    k2d = np.sqrt((kx*(nup/nupx))**2 + (ky*(nup/nupy))**2)
    # Integer bin of every mode; computed once instead of per (w, z) pair.
    kbin = np.ceil(k2d).astype(int)

    # 2-D FFT over the horizontal axes for all levels at once.
    FT = np.fft.fft2(field, axes=(0, 1))
    P = (np.abs(FT)/nx/ny)**2

    wn = np.arange(0.0, nup)
    pwr = np.zeros((nup, nz))
    for w in range(nup):
        # P[mask] has shape (n_modes_in_bin, nz); sum over the modes.
        pwr[w, :] = P[kbin == w].sum(axis=0)
    return wn, pwr


def vertical_intpolate_hybridsigma2plev(var1d,PS,P0,A,B,plev):
    """Interpolate one column from hybrid-sigma levels to pressure levels.

    The pressure of each model level is ``A(k)*P0 + B(k)*PS`` (in Pa).
    *plev* is given in hPa and converted to Pa before the linear
    interpolation done by ``np.interp``.
    """
    column_pressure_pa = A*P0 + B*PS
    target_pa = [hpa*100 for hpa in plev]
    return np.interp(target_pa, column_pressure_pa, var1d)

def hybridsigma2plev_tzyx2xyzt(var,PS,P0,A,B,plev):
    """Interpolate a (t, z, y, x) field to pressure levels, returning (x, y, p, t).

    Level pressure is ``p(i,j,k) = A(k)*P0 + B(k)*PS(i,j)`` in Pa, the same
    formula used by ``vertical_intpolate_hybridsigma2plev``.

    Parameters
    ----------
    var : array of shape (nt, nz, ny, nx)
    PS : array indexed as ``PS[t, j, i]``
    P0 : scalar reference pressure
    A, B : arrays of hybrid coefficients, one value per model level
    plev : sequence of target pressure levels in hPa

    Returns
    -------
    float64 array of shape (nx, ny, len(plev), nt), Fortran-ordered.
    """
    nt, _, ny, nx = var.shape ## t,k,j,i
    # Loop-invariant pieces: target levels in Pa and the A*P0 term.
    plevPa = [p*100 for p in plev]
    baseP = A*P0
    # np.interp returns float64, so allocate a float64 array.  Passing the
    # array *class* (type(var)) as dtype would create an object array.
    varplev = np.empty(shape=(nx,ny,len(plevPa),nt),dtype=np.float64,order="F")
    for t in range(nt):
        for j in range(ny):
            for i in range(nx):
                origLevP = baseP + B*PS[t,j,i] ## units in Pa
                varplev[i,j,:,t] = np.interp(plevPa,origLevP,var[t,:,j,i])

    return varplev

def cal_noresm_pwrspec2d(archiveFile):
    """Compute the mean of the U and V 2-D power spectra of one history file.

    Reads ``hyam``, ``hybm``, ``P0``, ``PS``, ``U`` and ``V`` from
    *archiveFile*, interpolates U and V to a fixed set of pressure levels
    and computes the horizontal spectrum of the first time index.

    Returns
    -------
    wn : wavenumber bins from ``pwrspec2d``
    dstPlev : list of the target pressure levels in hPa
    pwr : ``0.5 * (pwrU + pwrV)``, shape (len(wn), len(dstPlev))
    """
    dstPlev = [1000,925,900,850,800,700,600,500,400,350,300,250,200,150,100,50,10]
    # The context manager closes the file even on error; this function is
    # called once per file in a long loop, so handles must not accumulate.
    with nc.Dataset(archiveFile) as dset:
        hyam = dset.variables['hyam'][:]
        hybm = dset.variables['hybm'][:]
        P0 = dset.variables['P0'][:]
        # Only time index 0 is used below, so read just that record.  The
        # 0:1 slice keeps the leading time axis that the interpolation
        # routine expects (var is (t, z, y, x), PS is indexed [t, j, i]).
        PS = dset.variables['PS'][0:1]
        U = dset.variables['U'][0:1]
        V = dset.variables['V'][0:1]
    Uplev = hybridsigma2plev_tzyx2xyzt(var=U,PS=PS,P0=P0,A=hyam,B=hybm,plev=dstPlev)
    Vplev = hybridsigma2plev_tzyx2xyzt(var=V,PS=PS,P0=P0,A=hyam,B=hybm,plev=dstPlev)
    wn,pwrU = pwrspec2d(Uplev[:,:,:,0])
    wn,pwrV = pwrspec2d(Vplev[:,:,:,0])
    pwr = 0.5*(pwrU+pwrV)
    return wn,dstPlev,pwr

def plot_pwrspec2d(wn,plev,pwr,figname='',title=''):
    """Draw a filled-contour plot of *pwr* against wavenumber and pressure level.

    *pwr* is transposed before plotting, so *wn* is the x axis and *plev*
    the y axis.  The y axis is reversed and the x axis is limited to 1..20.
    The figure is saved to *figname* if it is non-empty, otherwise it is
    shown; in both cases it is closed afterwards.
    """
    contour_levels = list(map(float, range(50)))
    fig, ax = plt.subplots()
    filled = ax.contourf(
        wn, plev, np.transpose(pwr),
        levels=contour_levels, cmap='PuRd', extend='both',
    )

    # Flip the y axis by swapping its current limits.
    y_first, y_last = ax.get_ylim()
    ax.set_ylim((y_last, y_first))
    ax.set_xlim([1, 20])

    if title:
        plt.title(title, loc='left')

    # Colorbar in its own narrow axes attached to the right of the plot.
    cbar_axes = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    plt.colorbar(filled, cbar_axes)

    if not figname:
        plt.show()
    else:
        plt.savefig(figname)
    plt.close()

def plot_noresm_pwrspec2d_1f(archiveFile,figname='',title=''):
    """Compute the 2-D power spectrum of a single history file and plot it."""
    wavenumbers, levels, power = cal_noresm_pwrspec2d(archiveFile)
    plot_pwrspec2d(wavenumbers, levels, power, figname, title)
def plot_noresm_pwrspec2d(archiveDir):
    """Plot the mean 2-D power spectrum over the 1981 h1 files of member 01.

    The case name is the last path component of *archiveDir*.  All files
    matching ``{archiveDir}/{member}/atm/hist/{member}.cam2.h1.1981-*.nc``
    for member directories matching ``{case}_mem01*`` are processed with
    ``cal_noresm_pwrspec2d`` and their spectra averaged.  The figure is
    written to ``2DSpec_{case}_mem01_1981.png`` in the working directory.

    Raises
    ------
    FileNotFoundError
        If no matching history file is found.
    """
    # Drop trailing slashes so '/path/case/' still yields the case name
    # instead of an empty string.
    archiveDir = archiveDir.rstrip('/')
    case = archiveDir.split('/')[-1]
    casemem = [ i.split('/')[-1] for i in glob.glob(f'{archiveDir}/{case}_mem01*')] ## only mem01
    files = []
    for m in casemem:
        files.extend(glob.glob(f'{archiveDir}/{m}/atm/hist/{m}.cam2.h1.1981-*.nc'))
    files.sort()

    nf = len(files)
    if nf == 0:
        raise FileNotFoundError(
            f'no cam2.h1.1981-*.nc files found for {case} mem01 under {archiveDir}')

    print(f"{case} mem01 cal.")
    pwr = None
    for counter, fname in enumerate(files, start=1):
        print(f'    reading {counter}/{nf} ',end='\r')
        # wn and plev are taken from the last file read; they are assumed
        # to be the same for every file of a case.
        wn,plev,pwr1 = cal_noresm_pwrspec2d(fname)
        if pwr is None:
            pwr = pwr1
        else:
            pwr += pwr1
    print()
    pwr = pwr/nf
    title = f'2DSpec. {case} mem01 1981 mean'
    figname = f'2DSpec_{case}_mem01_1981.png'
    plot_pwrspec2d(wn,plev,pwr,figname,title)
    

if __name__ == '__main__':
    basepath = '/projects/NS9039K/shared/norcpm/cases/NorCPM/Lilian'
    caselist =  ['ana_19800115_anud_10d_me','ana_19800115_anud_7d_me','ana_19800115_nud_1d_me','ana_19800115_nud_7d_me'
                ,'noa_19800115_anud_7d_me', 'ana_19800115_anud_1d_me','ana_19800115_nud_00_me','ana_19800115_nud_6h_me']

    # Set to True to additionally plot one spectrum per file for a single
    # member (previously this code sat unreachable behind sys.exit()).
    plot_single_files = False

    # Mean 1981 spectrum of member 01 for every case.
    for case in caselist:
        plot_noresm_pwrspec2d(basepath+'/'+case)

    if plot_single_files:
        case = 'ana_19800115_nud_6h_me'
        mem = 3
        casemem = f'{case}_mem{mem:02}'
        for i in range(1,30):
            date = f'1980-07-{i:02}'
            archiveFile= f'{basepath}/{case}/{casemem}/atm/hist/{casemem}.cam2.h1.{date}-00000.nc'
            title = f'2D spatial spectrum {case} {date}'
            figname = title.replace(" ","_")
            plot_noresm_pwrspec2d_1f(archiveFile,figname,title)
