# -*- coding: utf-8 -*-
"""
Created on Wed Jul 23 09:46:27 2025

@author: u230638
"""

import sys
import os
import pandas as pd
import numpy as np
import xarray as xr
import glob
import re
import matplotlib.pyplot as plt
import seaborn as sns

# =============================================================================
# %% Read and transform into xarray dataset
# =============================================================================
# =============================================================================
# User choice
# =============================================================================
# Folder holding the CSV files (defaults to the current working directory)
input_folder = os.getcwd()  # specify other folder if needed
# Temporal resolution of the data: '1min' or '30min' or '35min'
time_resolution = '30min'
# Measured quantity: 'VOC' or 'ozone' or 'turbulence'
parameter = 'VOC'

# =============================================================================
# Start reading: announce what is about to be read, framed by asterisks
# =============================================================================
message = f"Reading {time_resolution} {parameter} CSV file(s)..."
border = "*" * (len(message) + 4)
print("\n".join([border, f"* {message} *", border]))

# =============================================================================
# Find file names
# =============================================================================
# Glob pattern of the data files for every supported
# (time_resolution, parameter) combination
file_patterns = {
    '1min': {
        'VOC': 'BE-Vie_1min_concentrations_VOC_*.csv',
        'ozone': 'BE-Vie_1min_concentrations_ozone_*.csv',
    },
    '30min': {
        'VOC': 'BE-Vie_30min_concentrations_fluxes_VOC_*.csv',
        'ozone': 'BE-Vie_30min_concentrations_fluxes_ozone_*.csv',
        'turbulence': 'BE-Vie_30min_profile_turbulence_*.csv',
    },
    '35min': {
        'VOC': 'BE-Vie_35min_profile_VOC_*.csv',
        'ozone': 'BE-Vie_35min_profile_ozone_*.csv',
    },
}


def _quoted_choices(options):
    """Return *options* formatted as "'a' or 'b' or 'c'" for error messages."""
    return ' or '.join(f"'{option}'" for option in options)


# Validate the user choice and pick the matching pattern
if time_resolution not in file_patterns:
    print(f"Error: Unknown time_resolution '{time_resolution}'. "
          f"Please specify {_quoted_choices(file_patterns)}.")
    sys.exit(1)
if parameter not in file_patterns[time_resolution]:
    print(f"Error: Unknown parameter '{parameter}'. "
          f"Please specify {_quoted_choices(file_patterns[time_resolution])}.")
    sys.exit(1)
pattern = file_patterns[time_resolution][parameter]

# Extract the path(s) and file name(s). glob() returns matches in an
# arbitrary, file-system dependent order, so sort them: later steps depend
# on file order (concatenation order, "first" file, TOP/TRUNK pairing).
full_paths = sorted(glob.glob(os.path.join(input_folder, pattern)))
matching_files = [os.path.basename(path) for path in full_paths]
if not full_paths:
    print(f"[ERROR] No files found matching the pattern: '{pattern}'")
    print("Please check the following:")
    print("  - The combination of 'time_resolution' and 'parameter' is correct.")
    print("  - The current script is located in the same folder as the data files (if applicable).")
    print(f"  - The 'input_folder' is correctly specified: {input_folder}")
    sys.exit(1)
    
# =============================================================================
# Read csv files
# =============================================================================
# Detect the number of lines to remove (header)
def count_header_lines(filepath):
    """Return how many consecutive lines at the top of *filepath* start with '#'.

    Counting stops at the first line that does not start with '#'; later
    '#' lines are not counted. An empty file gives 0.
    """
    total = 0
    with open(filepath, 'r', encoding='utf-8') as handle:
        is_comment = (line.startswith('#') for line in handle)
        # next(..., False) ends the loop at end of file as well
        while next(is_comment, False):
            total += 1
    return total

# Read every matching file into a dictionary keyed by its file name
dict_dfs = {}
for path in full_paths:
    df = pd.read_csv(path, skiprows=count_header_lines(path), sep=',')
    # Index each row by the end of its period, shifted by +1 h to express it
    # in UTC+1 (NOTE(review): assumes TIMESTAMP_END is given in UTC — confirm)
    df.index = pd.to_datetime(df['TIMESTAMP_END']) + pd.Timedelta(hours=1)
    df.index.name = 'time'
    dict_dfs[os.path.basename(path)] = df

# =============================================================================
# Convert to xarray
# =============================================================================
message = f"Converting {time_resolution} {parameter} to xarray..."
border = "*" * (len(message) + 4)
# Frame the message with a line of asterisks above and below
print(border, f"* {message} *", border, sep="\n")

if time_resolution == '1min':
    # =====================================================================
    # 1 min concentrations of VOC / ozone (one row per time and height)
    # =====================================================================
    # Concatenate all DataFrames vertically (over years)
    df = pd.concat(dict_dfs.values(), axis=0)
    # Drop unwanted columns if present
    drop_cols = ['TIMESTAMP_START', 'TIMESTAMP_END', 'DOY_START', 'DOY_END']
    df = df.drop(columns=[col for col in drop_cols if col in df.columns])
    # Reset index to get time as column
    df = df.reset_index()
    # Remove rows where height is NaN to avoid NaN in dimension
    df_clean = df.dropna(subset=['height'])
    # Identify variable columns (all except 'time' and 'height')
    var_cols = [col for col in df_clean.columns if col not in ['time', 'height']]
    # Unique coordinates, explicitly sorted: unique() keeps the order of
    # appearance (which follows the file order), while DataFrame.pivot
    # returns sorted rows/columns. Every pivot below is also reindexed onto
    # these coordinates so data and labels always line up.
    times = pd.to_datetime(df_clean['time'].unique()).sort_values()
    heights = sorted(df_clean['height'].unique())

    if parameter == 'VOC':
        # Column names follow <basevar>_<mz>; other columns are ignored
        pattern = re.compile(r'^(?P<basevar>.+)_(?P<mz>\d+\.\d+)$')
        # basevar -> list of (mz, column_name); mz_values gathers every mz
        basevar_mz_cols = {}
        mz_values = set()
        for col in var_cols:
            m = pattern.match(col)
            if m:
                mz = float(m.group('mz'))
                mz_values.add(mz)
                basevar_mz_cols.setdefault(m.group('basevar'), []).append((mz, col))
        mz_values = sorted(mz_values)
        mz_index = {mz: i for i, mz in enumerate(mz_values)}
        # For each basevar, fill a 3D array [time, height, mz]. Starting from
        # NaN keeps the full mz axis even when a basevar lacks some mz values.
        data_vars = {}
        for basevar, mz_col_list in basevar_mz_cols.items():
            arr_3d = np.full((len(times), len(heights), len(mz_values)), np.nan)
            for mz, col_name in mz_col_list:
                pivot = df_clean.pivot(index='time', columns='height', values=col_name)
                arr_3d[:, :, mz_index[mz]] = pivot.reindex(
                    index=times, columns=heights).to_numpy()
            data_vars[basevar] = (['time', 'height', 'mz'], arr_3d)
        # Build xarray Dataset with coords time, height, mz
        ds = xr.Dataset(
            data_vars=data_vars,
            coords={
                'time': times,
                'height': heights,
                'mz': mz_values})

    elif parameter == 'ozone':
        # Build 2D arrays [time, height]
        data_vars = {}
        for var in var_cols:
            # Variable name = column name without its last "_<suffix>" part.
            # NOTE(review): two columns with the same base name overwrite
            # each other here.
            base_var = re.sub(r'_[^_]+$', '', var)
            pivot = df_clean.pivot(index='time', columns='height', values=var)
            data_vars[base_var] = (
                ['time', 'height'],
                pivot.reindex(index=times, columns=heights).to_numpy())
        # Build xarray Dataset
        ds = xr.Dataset(
            data_vars=data_vars,
            coords={
                'time': times,
                'height': heights})

elif time_resolution == '30min' and parameter != 'turbulence':
    # =====================================================================
    # 30 min concentrations and fluxes of VOC / ozone (one file per system)
    # =====================================================================
    # Systems: order matters (first DataFrame = TOP, second = TRUNK)
    systems = ['TOP', 'TRUNK']
    dfs = list(dict_dfs.values())
    if len(dfs) != 2:
        raise ValueError("dict_dfs must contain exactly two DataFrames: TOP and TRUNK")
    # File order comes from glob() and is not guaranteed: column names embed
    # the system (e.g. '<var>_TOP_...'), so swap the DataFrames when only the
    # second one carries columns of the first system.
    marker = f'_{systems[0]}_'
    if (not any(marker in col for col in dfs[0].columns)
            and any(marker in col for col in dfs[1].columns)):
        dfs.reverse()
    # Time coordinate from the first DataFrame; the second one is aligned on
    # it by timestamp so rows are never paired by position alone
    time = dfs[0].index
    if not dfs[1].index.equals(time):
        dfs[1] = dfs[1].reindex(time)
    data_vars = {}

    if parameter == 'VOC':
        # Regex pattern to extract the variable name and the mz value
        pattern = r'^(?P<var>\w+)_(' + '|'.join(systems) + r')_(?P<mz>\d+\.\d+)$'
        matches = [re.match(pattern, col) for col in dfs[0].columns]
        variables = sorted(set(m.group('var') for m in matches if m))
        mz_values = sorted(set(float(m.group('mz')) for m in matches if m))
        # For each variable (e.g. 'FluxUg'), a 3D array [time, mz, system]
        for var in variables:
            data_array = np.stack([
                np.stack([
                    frame[f"{var}_{system}_{mz:.3f}"].values
                    for mz in mz_values], axis=1)  # axis=1 -> mz
                for frame, system in zip(dfs, systems)], axis=2)  # axis=2 -> system
            data_vars[var] = (['time', 'mz', 'system'], data_array)
        # Create the xarray Dataset
        ds = xr.Dataset(
            data_vars=data_vars,
            coords={'time': time, 'mz': mz_values, 'system': systems})

    elif parameter == 'ozone':
        # Regex pattern to extract base variable names
        pattern = r'^(?P<var>\w+)_(' + '|'.join(systems) + r')_ozone$'
        matches = [re.match(pattern, col) for col in dfs[0].columns]
        variables = sorted(set(m.group('var') for m in matches if m))
        # Create data variables as 2D arrays [time, system]
        for var in variables:
            arr = np.column_stack([
                frame[f"{var}_{system}_ozone"].values
                for frame, system in zip(dfs, systems)])
            data_vars[var] = (['time', 'system'], arr)
        # Create the xarray Dataset
        ds = xr.Dataset(
            data_vars=data_vars,
            coords={'time': time, 'system': systems})

elif (time_resolution == '35min' and parameter == 'ozone') \
        or (time_resolution == '30min' and parameter == 'turbulence'):
    # =====================================================================
    # 35 min concentration profile of ozone / 30 min turbulence profile
    # =====================================================================
    # Use every matching file, not just one; a stable sort keeps the time
    # axis chronological
    df = pd.concat(dict_dfs.values(), axis=0).sort_index(kind='mergesort')
    if parameter == 'ozone':
        # Column names: <var>_ozone_<height>m
        pattern = r'^(?P<var>\w+)_ozone_(?P<height>\d+)m$'
    else:
        # Column names: <var>_<height>m
        pattern = r'^(?P<var>\w+?)_(?P<height>\d+)m$'
    # (column, variable name, height) for every matching column
    valid = []
    for col in df.columns:
        m = re.match(pattern, col)
        if m:
            valid.append((col, m.group('var'), int(m.group('height'))))
    # Extract unique variable names and heights
    variables = sorted(set(v for _, v, _ in valid))
    heights = sorted(set(h for _, _, h in valid))
    height_index = {h: i for i, h in enumerate(heights)}
    # Build 2D arrays [time, height]; NaN where a variable lacks a height
    data_vars = {}
    for var in variables:
        arr = np.full((len(df), len(heights)), np.nan)
        for col, v, h in valid:
            if v == var:
                arr[:, height_index[h]] = df[col].values
        data_vars[var] = (['time', 'height'], arr)
    # Create the Dataset
    ds = xr.Dataset(
        data_vars=data_vars,
        coords={
            'time': df.index,
            'height': heights})

elif time_resolution == '35min' and parameter == 'VOC':
    # =====================================================================
    # 35 min concentration profile of VOC
    # =====================================================================
    # Use every matching file, not just one; a stable sort keeps the time
    # axis chronological
    df = pd.concat(dict_dfs.values(), axis=0).sort_index(kind='mergesort')
    # Column names: <var>_<mz>_<height>m, e.g. ConcMean_69.07_11m
    pattern = r'^(?P<var>\w+?)_(?P<mz>\d+\.\d+?)_(?P<height>\d+)m$'
    # (column, variable name, mz, height) for every matching column
    valid = []
    for col in df.columns:
        m = re.match(pattern, col)
        if m:
            valid.append((col, m.group('var'), float(m.group('mz')),
                          int(m.group('height'))))
    # Extract unique variable names, mz values, and heights, all sorted
    variables = sorted(set(v for _, v, _, _ in valid))
    mz_values = sorted(set(mz for _, _, mz, _ in valid))
    heights = sorted(set(h for _, _, _, h in valid))
    height_index = {h: i for i, h in enumerate(heights)}
    mz_index = {mz: i for i, mz in enumerate(mz_values)}
    # For each variable, a 3D array [time, height, mz]; NaN where missing
    data_vars = {}
    for var in variables:
        arr = np.full((len(df), len(heights), len(mz_values)), np.nan)
        for col, v, mz, h in valid:
            if v == var:
                arr[:, height_index[h], mz_index[mz]] = df[col].values
        data_vars[var] = (['time', 'height', 'mz'], arr)
    # Create the xarray Dataset with data variables and coordinates
    ds = xr.Dataset(
        data_vars=data_vars,
        coords={
            'time': df.index,
            'height': heights,
            'mz': mz_values})

# Visualize the structure of the xarray dataset
print(ds)

# =============================================================================
# %% Plotting example
# =============================================================================
# Choose plot style
sns.set_theme(style="ticks")


def _finish_plot(ax, ylabel, legend=False):
    """Add grid, optional legend and axis labels to *ax*, then show the figure."""
    ax.grid()
    if legend:
        ax.legend()
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Time [UTC+1]')
    plt.show()


if time_resolution == '1min':
    if parameter == 'VOC':
        mz_plot = 69.070
        height = 51.0
        # method='nearest' may select a neighbouring m/z: label the plot with
        # the m/z actually selected rather than the requested one
        subset = ds.sel(height=height).sel(mz=mz_plot, method='nearest')
        fig, ax = plt.subplots(figsize=(14, 8))
        ax.scatter(ds.time, subset['Conc'])
        _finish_plot(
            ax, f'm/z {float(subset.mz)} concentration at height {height} m [ppbv]')

    elif parameter == 'ozone':
        height = 51.0
        fig, ax = plt.subplots(figsize=(14, 8))
        ax.scatter(ds.time, ds.sel(height=height)['ConcMean'])
        _finish_plot(ax, f'Ozone concentration at height {height} m [ppbv]')

elif time_resolution == '30min':
    if parameter == 'VOC':
        mz_plot = 69.070
        # Same 'nearest' caveat as above: report the selected m/z
        subset = ds.sel(mz=mz_plot, method='nearest')
        fig, ax = plt.subplots(figsize=(14, 8))
        for system in ['TOP', 'TRUNK']:
            ax.scatter(ds.time, subset.sel(system=system)['FluxUg'], label=system)
        _finish_plot(
            ax, f'm/z {float(subset.mz)} flux [\u03bcg m$^{{-2}}$ s$^{{-1}}$]',
            legend=True)

    elif parameter == 'ozone':
        fig, ax = plt.subplots(figsize=(14, 8))
        for system in ['TOP', 'TRUNK']:
            ax.scatter(ds.time, ds.sel(system=system)['FluxNet'], label=system)
        _finish_plot(ax, 'Ozone flux [nmol m$^{-2}$ s$^{-1}$]', legend=True)

    elif parameter == 'turbulence':
        fig, ax = plt.subplots(figsize=(14, 8))
        # Iterate from the highest to the lowest height
        for height in reversed(ds.height):
            ax.scatter(ds.time, ds.sel(height=height)['Sigmaw'],
                       label=f'{height.values} m')
        _finish_plot(ax, 'Sigmaw [m s$^{-1}$]', legend=True)

elif time_resolution == '35min':
    if parameter == 'VOC':
        mz_plot = 69.070
        # Same 'nearest' caveat as above: report the selected m/z
        subset = ds.sel(mz=mz_plot, method='nearest')
        fig, ax = plt.subplots(figsize=(14, 8))
        for height in reversed(ds.height):
            ax.scatter(ds.time, subset.sel(height=height)['ConcMedian'],
                       label=f'{height.values} m')
        _finish_plot(ax, f'm/z {float(subset.mz)} concentration [ppbv]', legend=True)

    elif parameter == 'ozone':
        fig, ax = plt.subplots(figsize=(14, 8))
        for height in reversed(ds.height):
            ax.scatter(ds.time, ds.sel(height=height)['ConcMean'],
                       label=f'{height.values} m')
        _finish_plot(ax, 'Ozone concentration [ppbv]', legend=True)
