# generate test input for perturbing solar flux applied in POP

# import libraries
import numpy as np 
from netCDF4 import Dataset
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from mpl_toolkits.basemap import Basemap, shiftgrid

# Perturbation settings: flux_anomaly is the value written into SHF_QSW at the
# selected ocean points, and lonRange/latRange give the inclusive [min, max]
# window those points must fall in (presumably degrees, matching the grid's
# plon/plat -- confirm against the grid file).
flux_anomaly = 30
lonRange = [-40, 0]  # [western bound, eastern bound]
latRange = [0, 30]   # [southern bound, northern bound]

# Read grid coordinates and the grid mask, then zero the (flattened) mask
# everywhere except grid points with pmask == 1 that lie inside the
# lonRange x latRange window (inclusive bounds).
with Dataset('grid_gx1.nc', 'r') as nc:
    lon = nc['plon'][:]
    lat = nc['plat'][:]
    msk = nc['pmask'][:]
lon1 = lon.flatten()
lat1 = lat.flatten()
msk1 = msk.flatten()

# Longitude test: accept a point if its longitude, or that longitude shifted
# by +/-360, falls inside lonRange. The unshifted test is exactly the original
# one; the shifted tests make a signed window such as [-40, 0] also match grid
# longitudes stored in the 0-360 convention (and vice versa). lon1 itself is
# left untouched because it is reused for plotting.
in_lon = np.zeros(lon1.shape, dtype=bool)
for lon_shift in (-360.0, 0.0, 360.0):
    shifted_lon = lon1 + lon_shift
    in_lon |= np.ma.filled(
        (shifted_lon >= lonRange[0]) & (shifted_lon <= lonRange[1]), False
    )

# np.ma.filled(..., False) treats any masked (fill-value) entries as outside
# the domain, so they end up as 0 like in the original element-wise loop.
in_domain = in_lon & np.ma.filled(
    (msk1 == 1) & (lat1 >= latRange[0]) & (lat1 <= latRange[1]), False
)
msk1[~in_domain] = 0

# Generate test input for Jan 2023: one file per day, each holding a single
# SHF_QSW record equal to flux_anomaly inside the selected window and 0
# elsewhere.
# The field is identical for every day, so compute it once outside the loop.
anomaly_field = msk1.reshape(msk.shape) * flux_anomaly
for day in range(1, 32):  # January has 31 days
    file_name = f'popforcing_2023-01-{day:0>2d}-00000.nc'
    # Context manager guarantees the file is closed even if a write fails.
    with Dataset(file_name, 'w', format='NETCDF4_CLASSIC') as nc:
        nc.createDimension('time', None)  # unlimited record dimension
        nc.createDimension('x', msk.shape[-1])
        nc.createDimension('y', msk.shape[-2])
        shf_qsw = nc.createVariable('SHF_QSW', 'f4', ['time', 'y', 'x'])
        shf_qsw[0, :, :] = anomaly_field

# Plot the grid points kept in msk1 on an orthographic globe as a visual
# check of the perturbed domain.
fig, ax = plt.subplots(figsize=(8, 8))
m = Basemap(projection='ortho', lon_0=-10, lat_0=20, resolution='c', ax=ax)
m.fillcontinents(color='gray', lake_color='gray')
x, y = m(lon1, lat1)
# Boolean selection replaces the element-wise append loop; masked entries
# (if msk1 is a masked array) are treated as not selected.
selected = np.ma.filled(msk1 == 1, False)
m.scatter(x[selected], y[selected], 0.1, marker='o', color='r')
fig.savefig('grid_gx1.png', format='png', dpi=600, bbox_inches='tight')
plt.close(fig)

