import numpy as np
import scipy.io as sp
import subprocess
import glob
import sys, os
sys.path.append(os.path.expanduser('~/code/python/pyRELACS'))
from pyRELACS.DataLoader import iload
from IPython import embed;

def read_info_file(file_name):
    """
    Reads the info file and returns the stored metadata as nested dictionaries.

    Only lines starting with '#' are parsed; all other lines are ignored.
    After removing the '#' characters and surrounding whitespace, a line
    without ': ' opens a new section, and a line with ': ' is stored as a
    key/value pair in the most recently opened section. Key/value pairs that
    appear before any section header are stored in the top-level dictionary.

    @param file_name:  The name of the info file.
    @return: list with a single element, the dictionary mapping section names
             to dictionaries of their key/value pairs (all values are strings).
    """
    root = {}
    # Until the first section header is seen, entries go to the top level
    # instead of failing on an unbound section.
    section = root
    with open(file_name, 'r') as f:
        for line in f:
            if not line.startswith("#"):
                continue
            line = line.strip("#").strip()
            if not line:
                continue
            if ": " not in line:
                section = {}
                root[line] = section
            else:
                # Split on the first separator only so values that contain
                # ': ' themselves are kept intact.
                key, value = line.split(': ', 1)
                section[key.strip()] = value.strip()
    return [root]


def get_zero_crossings(trace, threshold=0.0, rising_flank=True):
    """
    Finds the sample indices at which the trace crosses the threshold.

    A rising crossing is reported at index i when trace[i] > threshold and
    trace[i-1] <= threshold; a falling crossing when trace[i] < threshold and
    trace[i-1] >= threshold. The first sample has no predecessor and is
    therefore never reported as a crossing.

    @param trace:  array-like, the signal; it is flattened before processing.
    @param threshold:  the level that has to be crossed.
    @param rising_flank:  if True upward crossings are detected, otherwise
                          downward crossings.
    @return: numpy array of integer indices into the flattened trace.
    """
    samples = np.ravel(np.asarray(trace))
    if samples.size < 2:
        # Fewer than two samples cannot contain a crossing.
        return np.array([], dtype=np.intp)
    previous = samples[:-1]
    current = samples[1:]
    if rising_flank:
        crossings = (current > threshold) & (previous <= threshold)
    else:
        crossings = (current < threshold) & (previous >= threshold)
    # The comparison starts at the second sample, hence the offset of one.
    return np.nonzero(crossings)[0] + 1


def has_cell_type(cell, cell_type):
    """
    Checks whether the info file of a cell reports the given cell type.

    The type is read from the 'CellType' entry of the 'Cell' section of
    <cell>/info.dat and compared case-insensitively.

    @param cell:  path of the cell's data directory.
    @param cell_type:  string, the requested cell type.
    @return: True if the stored cell type matches; False if it differs or if
             the info file, the 'Cell' section or the 'CellType' entry is
             missing.
    """
    info_path = os.path.join(cell, 'info.dat')
    if not os.path.exists(info_path):
        return False
    info = read_info_file(info_path)
    # Incomplete metadata counts as "not this type" instead of raising a
    # KeyError that would abort a scan over many cells.
    stored_type = info[0].get('Cell', {}).get('CellType')
    if stored_type is None:
        return False
    return stored_type.lower() == cell_type.lower()


def has_baseline(cell):
    """
    Tells whether the cell directory contains the file basespikes1.dat.

    @param cell:  path of the cell's data directory.
    @return: True if <cell>/basespikes1.dat exists, False otherwise.
    """
    baseline_path = cell + '/basespikes1.dat'
    return os.path.exists(baseline_path)


def has_raw_traces(cell):
    """
    Checks whether the raw trace file trace-2.raw is available for a cell.

    If only the compressed file trace-2.raw.gz exists, it is decompressed in
    place with the external 'gzip' program (a side effect on the data
    directory).

    @param cell:  path of the cell's data directory.
    @return: True if <cell>/trace-2.raw exists, possibly after decompression;
             False if neither file exists or decompression failed.
    """
    raw_path = cell + '/trace-2.raw'
    if os.path.exists(raw_path):
        return True
    compressed_path = raw_path + '.gz'
    if not os.path.exists(compressed_path):
        return False
    try:
        # check_call raises on a non-zero exit status instead of returning it,
        # so failure has to be handled as an exception.
        subprocess.check_call(['gzip', '-d', compressed_path])
    except (subprocess.CalledProcessError, OSError):
        # OSError covers a gzip executable that cannot be started.
        return False
    return os.path.exists(raw_path)


def load_eod_data(cell, max_time, sample_rate):
    """
    Loads the beginning of the raw trace stored in <cell>/trace-2.raw.

    The file is read as a flat sequence of 32-bit floats and truncated to the
    first max_time*sample_rate samples. If the file holds fewer samples, all
    of them are returned.

    @param cell:  path of the cell's data directory.
    @param max_time:  duration to keep, in the time unit matching sample_rate
                      (presumably seconds - confirm against the recording).
    @param sample_rate:  number of samples per time unit.
    @return: one-dimensional numpy float32 array.
    """
    eod = np.fromfile(cell + "/trace-2.raw", np.float32)
    # Slice bounds must be integers; the product is a float as soon as either
    # argument is one (e.g. max_time=0.5).
    sample_count = int(round(max_time * sample_rate))
    return eod[:sample_count]


def load_spike_data(cell, max_time):
    """
    Loads the baseline spike times of a cell up to max_time.

    The spike times are taken from the third element of the first item that
    iload yields for <cell>/basespikes1.dat. Stored times are divided by 1000
    before they are compared with max_time and returned (presumably a
    conversion from milliseconds to seconds - confirm against the file format).

    @param cell:  path of the cell's data directory.
    @param max_time:  required duration, in the unit of the returned times.
    @return: the spike times not later than max_time, or None if the file
             yields no data, contains no spikes, or the last spike occurs
             before max_time.
    """
    blocks = list(iload(cell + '/basespikes1.dat'))
    if not blocks:
        return None
    spike_times = blocks[0][2]
    # An empty recording is reported like a too-short one instead of failing
    # on the [-1] lookup. 1000.0 keeps the division exact under Python 2 even
    # for integer-valued times.
    if len(spike_times) == 0 or spike_times[-1] / 1000.0 < max_time:
        return None
    return spike_times[spike_times <= max_time * 1000] / 1000.0


if __name__ == '__main__':
    # Scan the 2012 recordings and export, for up to max_cell_count cells of
    # the requested type, the baseline spike times and the start of the raw
    # trace into '<cell name>_baseline.mat' in the current working directory.
    cell_type = 'p-unit'
    directory = os.path.expanduser('~/data/apteronotus')
    # glob returns entries in arbitrary order; sorting makes the selection of
    # exported cells reproducible between runs.
    cells = sorted(glob.glob(directory + '/2012-*'))

    max_cell_count = 25
    max_time = 20  # presumably seconds - confirm against the recordings
    sample_rate = 20000  # presumably samples per second - confirm
    found = 0
    for cell in cells:
        if found >= max_cell_count:
            break
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(cell)
        if not has_cell_type(cell, cell_type):
            print('wrong type')
            continue
        if not has_baseline(cell):
            print('no baseline')
            continue
        if not has_raw_traces(cell):
            print('no raw trace')
            continue
        spikes = load_spike_data(cell, max_time)
        if spikes is None:
            print('not enough baseline spikes')
            continue
        eod_data = load_eod_data(cell, max_time, sample_rate)
        name = os.path.basename(cell) + '_baseline.mat'
        sp.savemat(name, {'spike_times': spikes, 'eod': eod_data})
        found += 1