import numpy as np
from netCDF4 import Dataset
import matplotlib.pyplot as plt

# read old CAM forcing
# Coordinates (lev, lat, date) are read only from the January-start file and
# reused for the July-start data below; presumably both files share them — confirm.
with Dataset('CCSM4_volcanic_forcing_0001-0015_jan.nc', 'r') as nc:
    colmass_jan = nc['colmass'][:]
    mmrvolc_jan = nc['MMRVOLC'][:]
    lev = nc['lev'][:]
    lat = nc['lat'][:]
    date = nc['date'][:]
#
with Dataset('CCSM4_volcanic_forcing_0001-0015_jul.nc', 'r') as nc:
    colmass_jul = nc['colmass'][:]
    mmrvolc_jul = nc['MMRVOLC'][:]
#
# Decode YYYYMMDD integers into fractional years: year + month/12 - 1/24
# puts month m at (m - 0.5)/12 past the start of its year (mid-month).
# Integer // and % give the year and month exactly, without float truncation.
year_ccsm = date // 10000 + ((date // 100) % 100) / 12 - 1 / 24

# read new CAM forcing
# Only MMRVOLC is read from the new files; the `with` blocks guarantee each
# file is closed even if a read fails.
with Dataset('CCSM4_volcanic_forcing_0001-0015_jan_new.nc', 'r') as nc:
    mmrvolc_jan_new = nc['MMRVOLC'][:]
#
with Dataset('CCSM4_volcanic_forcing_0001-0015_jul_new.nc', 'r') as nc:
    mmrvolc_jul_new = nc['MMRVOLC'][:]


# read CAM diagnostic output (online column volcanic mass VOLC_MASS_C)
# lat_out and date_out are read only from the January-start run and reused
# for the July-start run.
with Dataset('B1850_T31_g37_500Tg_jan.cam2.h0.0001-0015.nc', 'r') as nc:
    colmass_jan_out = nc['VOLC_MASS_C'][:]
    lat_out = nc['lat'][:]
    date_out = nc['date'][:]
with Dataset('B1850_T31_g37_500Tg_jul.cam2.h0.0001-0015.nc', 'r') as nc:
    colmass_jul_out = nc['VOLC_MASS_C'][:]
#
# YYYYMMDD -> mid-month fractional year: year + month/12 - 1/24.
# NOTE(review): if the history `date` stamps the END of each averaging
# interval rather than its start, this curve sits one month later than the
# forcing-file curves — confirm before interpreting timing offsets.
year_ccsm_out = date_out // 10000 + ((date_out // 100) % 100) / 12 - 1 / 24


# compute global load from colmass
area_earth = 4 * np.pi * 6371000**2  # surface of a sphere of radius 6371 km (m2)
KG_TO_TG = 1e-9


def _cos_lat_area(latitudes):
    """Return cos(latitude) weights rescaled so they sum to ``area_earth`` (m2).

    ``latitudes`` are in degrees; each weight is the area attributed to one
    latitude row. NOTE(review): cos-latitude weighting is an approximation on
    non-uniform (e.g. Gaussian) grids — confirm it is accurate enough here.
    """
    w = np.cos(np.deg2rad(latitudes))
    return w / np.sum(w) * area_earth


weights = _cos_lat_area(lat)
weights_out = _cos_lat_area(lat_out)
# Forcing colmass: weights broadcast against the last axis (latitude) and
# axis 1 is summed. The KG_TO_TG factor implies colmass is in kg m-2 —
# presumably; confirm against the file's units attribute.
colmass_global_jan = (colmass_jan * weights).sum(axis=1) * KG_TO_TG
colmass_global_jul = (colmass_jul * weights).sum(axis=1) * KG_TO_TG
# Online output: average over the last axis (presumably longitude) first,
# then area-weight over latitude exactly as above.
colmass_global_jan_out = (np.mean(colmass_jan_out, axis=-1) * weights_out).sum(axis=1) * KG_TO_TG
colmass_global_jul_out = (np.mean(colmass_jul_out, axis=-1) * weights_out).sum(axis=1) * KG_TO_TG

# compute global load from mmrvolc 
def vint_mmrvolc(mmrvolc, plev, lat, p_top=1.0, g=9.81):
    """Return the global mass (Tg) of a mixing-ratio field on pressure levels.

    The field is integrated vertically with layer masses ``dp / g`` and
    horizontally with cos(latitude) weights normalised to the Earth's surface
    area. The last two axes are summed; any leading axes (e.g. time) are kept.

    Parameters
    ----------
    mmrvolc : array_like
        Mixing ratio with levels on axis -2 and latitudes on axis -1.
        Presumably kg/kg (the result is labelled in Tg) — confirm.
    plev : array_like
        Level mid-point pressures in hPa, at least two, increasing downward
        (otherwise the layer thicknesses come out negative).
    lat : array_like
        Latitudes in degrees.
    p_top : float, optional
        Pressure (hPa) of the uppermost layer interface. Defaults to 1 hPa,
        the "stratopause" value used originally. No check is made that it lies
        above ``plev[0]``.
    g : float, optional
        Gravitational acceleration (m s-2) used to turn Pa into kg m-2.

    Returns
    -------
    ndarray
        Integrated mass in Tg, with the shape of the leading axes of
        ``mmrvolc`` (at least 1-D).

    Raises
    ------
    ValueError
        If fewer than two pressure levels are given (the bottom interface is
        extrapolated from the last two levels).
    """
    plev = np.asarray(plev, dtype=float)
    if plev.size < 2:
        raise ValueError("vint_mmrvolc needs at least two pressure levels")
    # Layer interfaces: p_top above the first level, mid-points between
    # neighbouring levels, and a linear extrapolation below the last level.
    interfaces = np.concatenate((
        [p_top],
        0.5 * (plev[:-1] + plev[1:]),
        [plev[-1] + 0.5 * (plev[-1] - plev[-2])],
    ))
    dp = np.diff(interfaces) * 100 / g  # hPa -> Pa -> kg/m2
    dp = dp.reshape(1, -1, 1)
    #
    area_earth = 4 * np.pi * 6371000**2
    area = np.cos(np.asarray(lat, dtype=float) / 180. * np.pi)
    area = area / area.sum() * area_earth
    area = area.reshape(1, 1, -1)
    #
    return np.sum(dp * area * mmrvolc, axis=(-1, -2)) * 1e-9  # kg -> Tg
# Integrate each MMRVOLC field (old/new, JAN/JUL start) on the forcing grid.
mmrvolc_global_jan, mmrvolc_global_jul, mmrvolc_global_jan_new, mmrvolc_global_jul_new = (
    vint_mmrvolc(field, lev, lat)
    for field in (mmrvolc_jan, mmrvolc_jul, mmrvolc_jan_new, mmrvolc_jul_new)
)

# plot: one panel per start month, overlaying the online load with the loads
# derived from the forcing files (colmass, old MMRVOLC, new MMRVOLC).
panels = (
    ('JAN', colmass_global_jan_out, colmass_global_jan, mmrvolc_global_jan, mmrvolc_global_jan_new),
    ('JUL', colmass_global_jul_out, colmass_global_jul, mmrvolc_global_jul, mmrvolc_global_jul_new),
)
fig, ax = plt.subplots(2, 1)
for panel_ax, (month, online, file_colmass, file_mmr, file_mmr_new) in zip(ax, panels):
    prefix = month + ' start, '
    panel_ax.plot(year_ccsm_out, online, 'b', linewidth=1, label=prefix + 'online')
    panel_ax.plot(year_ccsm, file_colmass, 'r', linewidth=1, label=prefix + 'forcing file colmass')
    panel_ax.plot(year_ccsm, file_mmr, 'r:', linewidth=2, label=prefix + 'forcing file mmrvolc')
    panel_ax.plot(year_ccsm, file_mmr_new, 'g:', linewidth=2, label=prefix + 'forcing file mmrvolc new')
    # Every line carries a label, so legend() picks them up in plotting order.
    panel_ax.legend()
    panel_ax.set_xlim((0, 16))
    panel_ax.set_ylim((0, 500))
    panel_ax.set_ylabel('Load (Tg)')
#
plt.tight_layout()
plt.savefig('verify_colmass.png')