import nixio as nix
import numpy as np

def read_baseline(block):
    """Read the spike data stored in a baseline block.

    The block is accepted when its name contains "baseline" in any letter
    case. This matches ``sort_blocks``, which classifies the baseline block
    via ``b.name.lower()``. A case-sensitive check would reject, for example,
    a "Baseline..." block that ``sort_blocks`` stored under ``"baseline"``.

    Args:
        block: the block to read. Its ``name`` is inspected and its first
            data array is read.

    Returns:
        The full contents (``[:]``) of the block's first data array, presumably
        the baseline spike times. Returns an empty list, after printing a
        warning, if the block name does not contain "baseline".
    """
    if "baseline" not in block.name.lower():
        print("Block %s does not appear to be a baseline block!" % block.name)
        return []
    return block.data_arrays[0][:]


def sort_blocks(nix_file):
    """Index the blocks of a nix file by their stimulus parameters.

    A block whose name contains "baseline" (any letter case) is stored under
    the key ``"baseline"``. If several blocks match, the last one wins.
    Every other block name is split on underscores, and the fields at
    positions 1, 3, 5 and 7 are read as contrast (float), condition (str),
    delta f (float) and chirp size (int).
    NOTE(review): this assumes names shaped like
    ``<x>_<contrast>_<x>_<condition>_<x>_<deltaf>_<x>_<chirpsize>``. Other
    shapes raise IndexError/ValueError.

    Args:
        nix_file: an open nix file; its ``blocks`` are iterated.

    Returns:
        tuple: ``(block_map, contrasts, deltafs, chirp_sizes, conditions)``.
        ``block_map`` maps ``(contrast, deltaf, chirp_size, condition)``
        tuples (and ``"baseline"``) to blocks. Each of the four lists holds
        the distinct values in order of first appearance.
    """
    block_map = {}
    # Dicts serve as insertion-ordered sets of the distinct values seen.
    seen_contrasts = {}
    seen_conditions = {}
    seen_deltafs = {}
    seen_sizes = {}

    for blk in nix_file.blocks:
        if "baseline" in blk.name.lower():
            block_map["baseline"] = blk
            continue

        fields = blk.name.split("_")
        contrast = float(fields[1])
        seen_contrasts.setdefault(contrast, None)
        condition = fields[3]
        seen_conditions.setdefault(condition, None)
        deltaf = float(fields[5])
        seen_deltafs.setdefault(deltaf, None)
        size = int(fields[7])
        seen_sizes.setdefault(size, None)

        block_map[(contrast, deltaf, size, condition)] = blk

    return (
        block_map,
        list(seen_contrasts),
        list(seen_deltafs),
        list(seen_sizes),
        list(seen_conditions),
    )


def get_spikes(block):
    """Collect the response spike trains stored in a block.

    A data array is selected when its ``type`` contains ``"spike_times"`` and
    its ``name`` contains ``"response"``. The integer after the last
    underscore in the name is used as the trial index. If two selected arrays
    share an index, the later one replaces the earlier one.

    Args:
        block: the block whose ``data_arrays`` are searched.

    Returns:
        list of np.ndarray: the spike trains, ordered by ascending trial index.
    """
    by_index = {
        int(arr.name.split("_")[-1]): arr
        for arr in block.data_arrays
        if "spike_times" in arr.type and "response" in arr.name
    }
    return [by_index[idx][:] for idx in sorted(by_index)]


def get_signals(block):
    """Read the stimulus signal and frequency traces from a block.

    Args:
        block: the block holding the data for one delta f, contrast and
            condition. Its data arrays are looked up by name.

    Raises:
        ValueError: if the "complete stimulus" or "self frequency" data array
            is missing, or if the block name does not contain "no-other" and
            the "other frequency" data array is missing.

    Returns:
        np.ndarray: the complete stimulus signal.
        np.ndarray: the frequency trace stored as "self frequency".
        np.ndarray or None: the frequency trace stored as "other frequency";
            ``None`` for blocks whose name contains "no-other".
        np.ndarray: the axis of the first dimension of the "complete stimulus"
            array, evaluated for as many points as the signal has samples.
    """
    arrays = block.data_arrays
    is_no_other = "no-other" in block.name

    if not ("complete stimulus" in arrays and "self frequency" in arrays):
        raise ValueError("Signals not stored in block!")
    # Only blocks that include the other fish must carry its frequency trace.
    if not is_no_other and "other frequency" not in arrays:
        raise ValueError("Signals not stored in block!")

    stimulus = arrays["complete stimulus"]
    signal = stimulus[:]
    time = np.asarray(stimulus.dimensions[0].axis(len(signal)))
    self_freq = arrays["self frequency"][:]
    other_freq = None if is_no_other else arrays["other frequency"][:]
    return signal, self_freq, other_freq, time


def get_chirp_metadata(block):
    """Read the chirp stimulus parameters from a block's metadata.

    All values come from the ``"stimulus parameter"`` section of
    ``block.metadata``.

    Args:
        block: the block whose metadata is read.

    Returns:
        tuple: ``(trial_duration, dt, chirp_size, chirp_duration, chirp_times)``.
        ``trial_duration`` and ``dt`` are converted to float. The other three
        values are returned exactly as stored in the metadata.
    """
    section = block.metadata["stimulus parameter"]
    trial_duration, dt = (float(section[key]) for key in ("duration", "dt"))
    chirp_duration, chirp_size, chirp_times = (
        section[key] for key in ("chirp_duration", "chirp_size", "chirp_times")
    )
    # Callers unpack in this order: chirp_size comes BEFORE chirp_duration.
    return trial_duration, dt, chirp_size, chirp_duration, chirp_times