#!/usr/bin/env python

import os
import sys
import shutil
import numpy as np

# Auxiliary functions.
# --------------------

def checkArgs():
    """Parse the command line of this script.

    Expected usage::

        python reformat_NorCAM_volc_forcing.py infile frequency year1 month1 yearn monthn

    Returns:
        (ifile, frequency, year1, month1, yearn, monthn), with the four date
        components converted to int. Arguments beyond the sixth are ignored.

    If fewer than six arguments are given, or a date component is not an
    integer, the usage line is printed and the process exits with status 1.
    """
    usage = 'Usage: python reformat_NorCAM_volc_forcing.py infile frequency year1 month1 yearn monthn'
    args = sys.argv
    if len(args) <= 6:
        print(usage)
        print()
        # Use sys.exit rather than the site-provided exit() helper, which is
        # absent under `python -S` and in frozen builds. Exit non-zero so
        # calling scripts can detect the usage error.
        sys.exit(1)
    ifile = args[1]
    frequency = args[2]
    try:
        year1, month1, yearn, monthn = (int(a) for a in args[3:7])
    except ValueError:
        # Report malformed dates with the usage line, not a traceback.
        print(usage)
        print()
        sys.exit(1)
    return ifile, frequency, year1, month1, yearn, monthn

def isLeapYear(y):
    """Return True if year *y* is a Gregorian leap year, else False.

    Years divisible by 400 are leap years. Other century years are not.
    All remaining years are leap years exactly when divisible by 4.
    """
    if y % 400 == 0:
        return True
    if y % 100 == 0:
        return False
    return y % 4 == 0

def daysInMonth(y,m):
    """Return the number of days in month *m* (1-12) of year *y*.

    Both arguments are passed through int(). February has 29 days when
    isLeapYear(int(y)) is true and 28 otherwise.
    """
    february = 29 if isLeapYear(int(y)) else 28
    lengths = [31, february, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
    return lengths[int(m) - 1]

def _iter_months(year1, month1, yearn, monthn):
    """Yield (year, month) for every month from year1/month1 to yearn/monthn inclusive."""
    for y in range(year1, yearn + 1):
        m1 = 1 if y > year1 else month1
        m2 = 12 if y < yearn else monthn
        for m in range(m1, m2 + 1):
            yield y, m

def _build_time_axis(frequency, year1, month1, yearn, monthn):
    """Return integer arrays (date, datesec) for the output time axis.

    date is encoded as YYYYMMDD and datesec is the number of seconds into
    that date.

    - '6hourly' / 'sixhourly': four records per day, with datesec 0, 21600,
      43200 and 64800.
    - 'monthly': one record per month (see the note below on the exact date).

    Raises:
        ValueError: unknown frequency, a month outside 1-12, or an empty period.
    """
    if frequency not in ('6hourly', 'sixhourly', 'monthly'):
        raise ValueError("unknown frequency '" + str(frequency)
                         + "'; expected '6hourly', 'sixhourly' or 'monthly'")
    for month in (month1, monthn):
        if not 1 <= month <= 12:
            raise ValueError('month must be between 1 and 12, got ' + str(month))
    months = list(_iter_months(year1, month1, yearn, monthn))
    if not months:
        # A zero-length 'time' dimension would be created as an unlimited one.
        raise ValueError('empty period: %d-%02d is after %d-%02d'
                         % (year1, month1, yearn, monthn))

    date = []
    datesec = []
    if frequency == 'monthly':
        for y, m in months:
            ndays = daysInMonth(y, m)
            # NOTE(review): this stamps day ndays//2 at 00:00 (even-length
            # months) or 12:00 (odd-length months). That is one day before
            # the exact month midpoint. It is kept unchanged for
            # compatibility; confirm the intended convention.
            date.append(y * 10000 + m * 100 + ndays // 2)
            datesec.append(0 if ndays % 2 == 0 else 86400 // 2)
    else:
        seconds_per_step = 86400 // 4
        for y, m in months:
            for d in range(1, daysInMonth(y, m) + 1):
                for t in range(4):
                    date.append(y * 10000 + m * 100 + d)
                    datesec.append(seconds_per_step * t)
    return np.array(date, dtype=int), np.array(datesec, dtype=int)

def _altitude_interfaces(altitude):
    """Return the n+1 interface altitudes bounding n altitude levels.

    Interior interfaces are midpoints between neighbouring levels. The two
    outer interfaces lie half a level-spacing beyond the first/last level.

    Raises:
        ValueError: fewer than two altitude levels.
    """
    alt = np.asarray(altitude, dtype=float)
    if alt.ndim != 1 or alt.size < 2:
        raise ValueError('need a 1-D altitude axis with at least two levels')
    interfaces = np.empty(alt.size + 1)
    interfaces[1:-1] = (alt[:-1] + alt[1:]) / 2
    interfaces[0] = alt[0] - (alt[1] - alt[0]) / 2
    interfaces[-1] = alt[-1] + (alt[-1] - alt[-2]) / 2
    return interfaces

def _band_fields(nshort, nlong):
    """Yield (input variable, 1-based band, output variable, units) for each per-band field."""
    # NOTE: the 'units' attributes hold descriptive text; they are kept
    # verbatim so existing consumers of the output see identical metadata.
    specs = (
        ('ext_sun', nshort, 'extinction coefficient of solar bands in 1/km'),
        ('omega_sun', nshort, 'single scattering albedo of solar bands'),
        ('g_sun', nshort, 'asymmetry factor of solar bands'),
        ('ext_earth', nlong, 'extinction coefficient of terrestrial bands in 1/km'),
        ('omega_earth', nlong, 'single scattering albedo of terrestrial bands'),
        ('g_earth', nlong, 'asymmetry factor of terrestrial bands'),
    )
    for src, nband, units in specs:
        for band in range(1, nband + 1):
            yield src, band, src + str(band), units

def reformat_forcing(ifile, ofile, frequency, year1, month1, yearn, monthn):
    """Split band-stacked optical properties into one variable per band.

    Reads ext/omega/g for the solar ('_sun') and terrestrial ('_earth')
    bands from *ifile*. Each band is written to *ofile* as its own variable
    (e.g. ext_sun1 ... ext_sun14) with shape (time, altitude, lat). The
    output also gets a date/datesec time axis, altitude, altitude_int
    (level interfaces) and lat.

    Args:
        ifile: input netCDF file. It needs the variables ext_sun, omega_sun,
            g_sun, ext_earth, omega_earth, g_earth, altitude and latitude.
        ofile: output path. It is overwritten and written as NETCDF4_CLASSIC.
        frequency: '6hourly' / 'sixhourly' (four records per day) or
            'monthly' (one record per month).
        year1, month1: first month of the period (inclusive).
        yearn, monthn: last month of the period (inclusive).

    The first len(date) time records of the input are copied, so the input
    must contain at least that many.

    Raises:
        ValueError: unknown frequency, invalid or empty period, too few
            input time records, or fewer than two altitude levels.
    """
    from netCDF4 import Dataset

    # Validate the arguments and build the time axis before opening any file.
    # Previously an unknown frequency surfaced as a NameError on `date` after
    # the output file had already been created.
    date, datesec = _build_time_axis(frequency, year1, month1, yearn, monthn)
    nrec = len(date)
    nshort = 14  # number of solar bands copied
    nlong = 16   # number of terrestrial bands copied

    with Dataset(ifile, 'r') as ncin:
        # Input layout, judged by the first dimension of ext_sun:
        #   'EVA'     (first dim named 'month'): (time, band, a, b)
        #   'DEFAULT':                            (band, a, b, time)
        # (a, b) is presumably (lat, altitude): each slice is transposed onto
        # the (altitude, lat) output grid below.
        dim_order = 'EVA' if ncin.variables['ext_sun'].dimensions[0] == 'month' else 'DEFAULT'
        ntime_in = ncin.variables['ext_sun'].shape[0 if dim_order == 'EVA' else 3]
        if nrec > ntime_in:
            raise ValueError('requested period needs ' + str(nrec)
                             + ' time records but ' + ifile + ' has only '
                             + str(ntime_in))
        altitude = ncin.variables['altitude'][:]
        latitude = ncin.variables['latitude'][:]
        # Compute before creating the output so a bad axis leaves no file.
        altitude_int = _altitude_interfaces(altitude)

        with Dataset(ofile, 'w', format='NETCDF4_CLASSIC') as ncout:
            # Dimensions are sized from the input, not hard-coded to 70/71/36.
            ncout.createDimension('altitude', len(altitude))
            ncout.createDimension('altitude_int', len(altitude) + 1)
            ncout.createDimension('lat', len(latitude))
            ncout.createDimension('time', nrec)

            # Time and coordinate variables.
            ncout.createVariable('date', 'i4', ('time',))
            ncout.variables['date'].long_name = 'current date (YYYYMMDD)'
            ncout.createVariable('datesec', 'i4', ('time',))
            ncout.variables['datesec'].long_name = 'current seconds of current date'
            ncout.createVariable('altitude', 'f4', ('altitude',))
            ncout.variables['altitude'].units = 'km'
            ncout.createVariable('altitude_int', 'f4', ('altitude_int',))
            ncout.variables['altitude_int'].units = 'km'
            ncout.createVariable('lat', 'f4', ('lat',))
            ncout.variables['lat'].units = 'degrees_north'

            # One output variable per (quantity, band).
            for _src, _band, dst, units in _band_fields(nshort, nlong):
                ncout.createVariable(dst, 'f4', ('time', 'altitude', 'lat'))
                ncout.variables[dst].units = units

            ncout.variables['date'][:] = date
            ncout.variables['datesec'][:] = datesec
            ncout.variables['altitude'][:] = altitude
            ncout.variables['altitude_int'][:] = altitude_int
            ncout.variables['lat'][:] = latitude

            # Copy each band with a single read/write for all records. This
            # gives the same values as transposing each record's 2-D slice.
            for src, band, dst, _units in _band_fields(nshort, nlong):
                var = ncin.variables[src]
                if dim_order == 'EVA':
                    field = var[:nrec, band - 1, :, :].transpose(0, 2, 1)
                else:
                    field = var[band - 1, :, :, :nrec].transpose(2, 1, 0)
                ncout.variables[dst][:] = field

# Main section.
# -------------

if __name__ == '__main__':

    # Check command line arguments
    ifile, frequency, year1, month1, yearn, monthn = checkArgs()
    # Replace the file extension, whatever it is. The old ifile[:-3] blindly
    # dropped three characters, which is only correct for a '.nc' suffix.
    ofile = os.path.splitext(ifile)[0] + '_reformatted.nc'

    print('Reformat forcing file ' + ifile + ' and write output to ' + ofile)
    reformat_forcing(ifile, ofile, frequency, year1, month1, yearn, monthn)

    print('SUCCESS')

