import numpy as np
from netCDF4 import Dataset
import matplotlib.pyplot as plt

# x-axis window (years) for the comparison plot at the end of the script.
year1 = 1700
yearn = 2000

# Read CAM forcing.
# The forcing data are shifted by 500 years: years -499..2000 are stored as
# 1..2500, so `offset` is subtracted below to map back to the original years.
offset = 500
# Context manager guarantees the file is closed even if a variable read fails.
with Dataset('CCSM4_volcanic_forcing_0001-2500.nc', 'r') as nc:
    colmass = nc['colmass'][:]
    MMRVOLC = nc['MMRVOLC'][:]
    lev = nc['lev'][:]
    lat = nc['lat'][:]
    date = nc['date'][:]

# `date` is decoded as YYYYMMDD: year = date // 10000, month = next two digits.
# month/12 - 1/24 == (month - 0.5)/12, i.e. each sample is placed at the middle
# of its (1-based) month on a decimal-year axis.
year_ccsm = (date // 10000) + ((date // 100) % 100) / 12 - 1/24 - offset

# Read eruption data.
# Fixed-width text: chars [0:6] year, [7:8] an integer code (presumably the
# hemisphere, per the name `hemi` -- TODO confirm), [8:15] the eruption load.
# Years are shifted by `offset` like the forcing file and placed at mid-year
# (+0.5) to line up with the decimal-year axis of `year_ccsm`.
year = []
hemi = []
load = []
with open('CCSM4_volcanic_forcing_0001-2500.txt', 'r') as eruption_file:
    for line in eruption_file:
        if not line.strip():
            # Skip blank lines (e.g. a trailing empty line at EOF); int('\n')
            # would otherwise raise ValueError.
            continue
        year.append(int(line[0:6]) + 0.5 - offset)
        hemi.append(int(line[7:8]))
        load.append(float(line[8:15]))

# Compute global load from colmass.
# Each latitude row is weighted by cos(lat), normalised so the weights sum to
# the Earth's surface area (m^2); this treats the latitude grid as evenly
# spaced. Summing colmass * weights over axis 1 assumes colmass is
# (time, lat) -- TODO confirm against the file metadata. If colmass is in
# kg/m^2, the 1e-9 factor turns the kg total into Tg (the plot's y-axis unit).
area_earth = 4 * np.pi * 6371000**2
weights = np.cos(np.deg2rad(lat))
# ndarray.sum() instead of builtin sum(): vectorised and uses numpy's more
# accurate pairwise summation.
weights = weights / weights.sum() * area_earth
colmass_global = (colmass * weights).sum(axis=1) * 1e-9

# Plot the global load derived from the forcing file against the loads listed
# in the eruption file, and save the figure.
fig, ax = plt.subplots(1, 1)
h1, = ax.plot(year_ccsm, colmass_global, 'r', linewidth=2, label='CAM forcing file')
# A single call for all eruptions: one artist / one legend entry, and h2 is
# defined even when the eruption list is empty (the old per-point loop left it
# unbound, so ax.legend raised NameError).
h2, = ax.plot(year, load, 'k+', label='eruption file')
ax.legend(handles=(h1, h2))
ax.set_xlim((year1, yearn))
ax.set_ylim((0, 90))
ax.set_ylabel('Load (Tg)')
plt.tight_layout()
plt.savefig('verify_colmass.png')
# Release the figure once it is written to disk.
plt.close(fig)