#! /usr/bin/env python

# Copyright (C) 2024- The University of Notre Dame
# This software is distributed under the GNU General Public License.
# See the file COPYING for details.

import sys
import matplotlib.pyplot as plt
import ast
import os
import argparse

# Divisor applied to transaction-log timestamps before they are plotted on
# axes labelled in seconds (i.e. microseconds per second).
MICROSECONDS = 1000000
# Divisor applied to logged file sizes before they are plotted on axes
# labelled in GB (decimal gigabyte; presumably the logged sizes are bytes).
GIGABYTES = 1000000000

def _clean_addr(token):
    """Strip the '(' and '):' punctuation that surrounds a logged address token."""
    return token.strip('):').strip('(')


def parse_debug(log):
    """Parse a manager debug log into transfer and per-worker file records.

    Args:
        log: Path of the debug log to read.

    Returns:
        A dict with the keys:
          'ids'              -- maps 'manager' (always 0) and every address
                                seen to an integer id, in order of appearance.
          'manager_info'     -- {'log': log, 'type': 'debug'}.
          'transfer_info'    -- list of [from_addr, to_addr, cachename, size].
          'worker_file_info' -- {addr: {'files': {cachename: size}}}.
        Sizes are the largest value logged for that cachename divided by
        GIGABYTES; a transfer whose file never had a size logged gets 0.0.

    Raises:
        ValueError: if a matching line does not have the expected number of
            whitespace-separated fields or a size is not numeric.
    """
    with open(log, 'r') as handle:
        lines = handle.read().splitlines()

    # Each entry is [FROM_ADDR, TO_ADDR, CACHENAME]; SIZE is appended at the end.
    transfer_info = []
    # cachename -> every size logged for it
    file_sizes = {}
    worker_file_info = {}
    # "ip:transfer_port" -> worker address as it appears elsewhere in the log
    transfer_addresses = {}

    for line in lines:
        if 'tx to' in line and ' file ' in line:
            date, time, manager, vine, tx, to, machine, addr, opt, cachename, size, mode = line.split(maxsplit=11)
            size = float(size)
            transfer_info.append(['manager', _clean_addr(addr), cachename])
            file_sizes.setdefault(cachename, []).append(size / GIGABYTES)
        if 'rx from' in line and ' file ' in line:
            try:
                date, time, manager, vine, rx, frm, machine, addr, opt, cachename, size, mode = line.split(maxsplit=11)
            except ValueError:
                # WQ fallback: the line has no trailing mode field
                date, time, manager, vine, rx, frm, machine, addr, opt, cachename, size = line.split(maxsplit=10)
            size = float(size)
            transfer_info.append([_clean_addr(addr), 'manager', cachename])
            file_sizes.setdefault(cachename, []).append(size / GIGABYTES)
        if 'will get' in line:
            date, time, manager, vine, machine, addr, will, get, filename, fr, url, path = line.split(maxsplit=12)
            # Drop the URL scheme as a prefix. str.strip() would treat the
            # scheme as a set of characters and could also eat trailing
            # characters of the cachename.
            scheme, sep, rest = path.partition('://')
            if sep:
                path = rest
            from_addr, cachename = path.split('/')
            # Translate the peer's transfer address to its worker address when
            # known; otherwise keep the transfer address so the record is kept.
            from_addr = transfer_addresses.get(from_addr, from_addr)
            transfer_info.append([from_addr, _clean_addr(addr), cachename])
        if "cache-update" in line:
            date, time, manager, vine, rx, frm, machine, addr, op, cachename, size, x, y, z = line.split(maxsplit=13)
            size = float(size)
            addr = _clean_addr(addr)
            if addr not in worker_file_info:
                worker_file_info[addr] = {'files': {}}
            # placeholder; replaced by the largest logged size below
            worker_file_info[addr]['files'][cachename] = 0
            file_sizes.setdefault(cachename, []).append(size / GIGABYTES)
        if "transfer-port" in line:
            date, time, manager, vine, rx, frm, machine, addr, transfer_address, tr_addr = line.split(maxsplit=9)
            addr = _clean_addr(addr)
            tr_addr = addr.split(':')[0] + ':' + tr_addr
            transfer_addresses[tr_addr] = addr

    # Attach a size to every transfer. A peer transfer may name a file whose
    # size was never logged, so fall back to 0.0 instead of raising KeyError.
    for transfer in transfer_info:
        sizes = file_sizes.get(transfer[2])
        transfer.append(max(sizes) if sizes else 0.0)

    for worker in worker_file_info:
        for cachename in worker_file_info[worker]['files']:
            worker_file_info[worker]['files'][cachename] = max(file_sizes[cachename])

    # assign ids to workers + manager
    ids = {"manager": 0}
    count = 0
    for transfer in transfer_info:
        for endpoint in (transfer[0], transfer[1]):
            if endpoint not in ids:
                count += 1
                ids[endpoint] = count
    for worker in worker_file_info:
        if worker not in ids:
            count += 1
            ids[worker] = count

    log_info = {}
    log_info['ids'] = ids
    log_info['manager_info'] = {'log': log, 'type': 'debug'}
    log_info['transfer_info'] = transfer_info
    log_info['worker_file_info'] = worker_file_info
    return log_info

def parse_txn(log):
    """Parse a manager transaction log into task, worker, library and file records.

    Each usable line is split into ``time m_pid category obj status info``;
    lines that do not split that way, or whose time is not numeric, are
    skipped. Times are divided by MICROSECONDS.

    Args:
        log: Path of the transaction log to read.

    Returns:
        A dict with the keys 'worker_info', 'manager_info', 'task_info',
        'library_info' and 'file_info'. Each worker record carries the tasks
        and libraries that ran on it plus a 'cache' map of
        filename -> [time, size], where size is the largest size logged for
        that file divided by GIGABYTES.
    """
    task_info = {}
    worker_info = {}
    library_info = {}
    file_info = {}
    file_sizes = {}
    manager_info = {'type':'txn', 'log':log, 'start':0, 'stop':0, 'first_task':float('inf'), 'last_task_done':float('-inf')}

    # parse relevant info
    with open(log, 'r') as handle:
        lines = handle.read().splitlines()

    for line in lines:
        try:
            (time, m_pid, category, obj, status, info) = line.split(maxsplit=5)
            time = float(time)/MICROSECONDS
        except ValueError:
            # comment/header line, or a record with too few fields
            continue

        if category == 'TASK':
            task = task_info.setdefault(obj, {})
            if status == 'READY':
                task['function'] = info.split()[0]
                task['ready'] = time
                task['id'] = obj
            if status == 'RUNNING':
                task['running'] = time
                task['worker'] = info.split()[0]
                if time < manager_info['first_task']:
                    manager_info['first_task'] = time
            if status == 'WAITING_RETRIEVAL':
                task['waiting_retrieval'] = time
            if status == 'RETRIEVED':
                # Four fields are unpacked, so split at most three times; the
                # trailing dict is kept whole even if it contains whitespace.
                result, exit_code, resource, info = info.split(maxsplit=3)
                info = ast.literal_eval(info)
                if 'time_worker_start' in info:
                    task['ex_start'] = float(info['time_worker_start'][0])
                    task['ex_end'] = float(info['time_worker_end'][0])
                task['retrieved'] = time
            if status == 'DONE':
                task['done'] = time
                if time > manager_info['last_task_done']:
                    manager_info['last_task_done'] = time

        if category == 'WORKER':
            worker = worker_info.setdefault(obj, {'tasks':[], 'libraries':[], 'cache':{}})
            if status == 'CONNECTION':
                worker['start'] = time
            if status == 'DISCONNECTION':
                # a worker may disconnect more than once; keep every time
                worker.setdefault('stop', []).append(time)
            if status == 'CACHE_UPDATE':
                (filename, size, wall_time, start_time) = info.split()
                size = float(size)
                file_sizes.setdefault(filename, []).append(size/GIGABYTES)
                holders = file_info.setdefault(filename, {"workers":{}})['workers']
                # only the first report from a given worker is kept per file
                if obj not in holders:
                    holders[obj] = [time, size/GIGABYTES]
                worker['cache'][filename] = [time, size/GIGABYTES]

        if category == 'LIBRARY':
            library = library_info.setdefault(obj, {})
            if status == 'STARTED':
                library['start'] = time
                library['worker'] = info

        if category == 'MANAGER':
            if status == 'START':
                manager_info['start'] = time
            if status == 'END':
                manager_info['stop'] = time

    # replace each cached size with the largest size logged for that file
    for worker in worker_info:
        for filename in worker_info[worker]['cache']:
            worker_info[worker]['cache'][filename][-1] = max(file_sizes[filename])

    # match tasks and libraries to workers
    for task in task_info.values():
        if 'worker' in task and task['worker'] in worker_info and 'ready' in task:
            worker_info[task['worker']]['tasks'].append(task)

    for library in library_info.values():
        if 'worker' in library and library['worker'] in worker_info:
            worker_info[library['worker']]['libraries'].append(library)

    log_info = {}
    log_info['worker_info'] = worker_info
    log_info['manager_info'] = manager_info
    log_info['task_info'] = task_info
    log_info['library_info'] = library_info
    log_info['file_info'] = file_info

    return log_info

def plot_cache(log_info, axs, args):
    """Plot the cumulative cached data of every worker over time.

    One line is drawn per worker, stepping up by each cached file's size in
    the order the files were recorded; each worker stop time is marked with a
    red 'x' at that worker's final cumulative total.

    Args:
        log_info: Parsed log; reads 'worker_info' and 'manager_info'.
        axs: Axes to draw on.
        args: Options; reads origin, sublabels, subtitles, r_xlim, l_xlim
            and sublegend.
    """
    worker_info = log_info['worker_info']
    print(len(worker_info))

    first_task = log_info['manager_info']['first_task']
    manager_start = log_info['manager_info']['start']
    origin = 0
    if args.origin == 'first-task':
        origin = first_task
    elif args.origin == 'manager-start':
        origin = manager_start

    stop_xs = []
    stop_ys = []
    for worker, record in worker_info.items():
        if 'tasks' not in record:
            continue
        # cache entries are [time, size]; accumulate them in time order
        events = sorted(((entry[0], entry[1]) for entry in record['cache'].values()),
                        key=lambda event: event[0])
        xs = []
        ys = []
        gb_count = 0
        for when, size in events:
            gb_count += size
            xs.append(when - origin)
            ys.append(gb_count)
        axs.plot(xs, ys, label=worker)
        for when in record.get('stop', ()):
            stop_xs.append(when - origin)
            stop_ys.append(gb_count)
    axs.plot(stop_xs, stop_ys, 'rx', label='Worker Stop')

    if args.sublabels:
        axs.set_xlabel('Manager Runtime (s)')
        axs.set_ylabel('Disk (GB)')
    if args.subtitles:
        axs.set_title(log_info['manager_info']['log'])
    # compare against None so that an explicit limit of 0 is honoured
    if args.r_xlim is not None:
        axs.set_xlim(right=args.r_xlim)
    if args.l_xlim is not None:
        axs.set_xlim(left=args.l_xlim)
    if args.sublegend:
        axs.legend()
def plot_worker(log_info, axs, args):
    """Draw a timeline of every worker with its tasks packed into slots.

    Each worker that has a recorded start gets a grey bar, and its tasks are
    stacked on rows ("slots") so that tasks overlapping in time do not share
    a row. Library tasks get one row each.

    Args:
        log_info: Parsed log; reads 'worker_info' and 'manager_info'.
            It is not modified.
        axs: Axes to draw on.
        args: Options; reads origin, worker_ticks, worker_y and sublabels.
    """
    worker_info = log_info['worker_info']
    manager_info = log_info['manager_info']

    y_count = 0
    y_counts = []
    core_count = 0
    core_counts = []

    task_running_info = {"ys":[], "widths":[], "lefts":[]}
    task_execution_info = {"ys":[], "widths":[], "lefts":[]}
    task_waiting_info = {"ys":[], "widths":[], "lefts":[]}
    lib_plot_info = {"ys":[], "widths":[], "lefts":[]}
    worker_plot_info = {"ys":[], "widths":[], "heights":[], "lefts":[]}

    origin = 0
    first_task = manager_info['first_task']
    manager_start = manager_info['start']
    if args.origin == 'first-task':
        origin = first_task
    elif args.origin == 'manager-start':
        origin = manager_start

    for worker in worker_info:
        record = worker_info[worker]
        if 'tasks' not in record or 'start' not in record:
            continue

        # info relevant to the worker
        y_count += 1
        start_y = y_count
        start_x = record['start']

        # collect [running, waiting_retrieval, retrieved, ex_start, ex_end]
        # for every task that has all five, sorted by time started running
        tasks = []
        for task in record['tasks']:
            try:
                tasks.append([task['running'], task['waiting_retrieval'], task['retrieved'], task['ex_start'], task['ex_end']])
            except KeyError:
                print("TASK NOT RETRIEVED", worker, task['id'], task['function'])
        tasks.sort(key=lambda x: x[0])

        # assign tasks to slots: a task goes into the first slot whose last
        # task was retrieved before this one started, otherwise a new slot
        slots = {}
        for task in tasks:
            for slot in slots:
                if task[0] > slots[slot][-1][2]:
                    slots[slot].append(task)
                    break
            else:
                slots[len(slots) + 1] = [task]

        # accum results for tasks
        for slot in slots:
            core_count += 1
            y_count += 1
            for task in slots[slot]:
                task_running_info['ys'].append(y_count)
                task_waiting_info['ys'].append(y_count)
                task_execution_info['ys'].append(y_count)
                task_running_info['widths'].append(task[1] - task[0])
                task_waiting_info['widths'].append(task[2] - task[1])
                task_execution_info['widths'].append(task[4] - task[3])
                task_running_info['lefts'].append(task[0] - origin)
                task_waiting_info['lefts'].append(task[1] - origin)
                task_execution_info['lefts'].append(task[0] - origin)

        # A worker that never disconnected is drawn up to the last finished
        # task. Keep this in a local: writing it into log_info would make
        # other plots treat it as a real stop event.
        stops = record.get('stop') or [manager_info['last_task_done']]

        # accum results for libraries
        for lib in record['libraries']:
            y_count += 1
            lib_plot_info['ys'].append(y_count)
            lib_plot_info['widths'].append(stops[0] - lib['start'])
            lib_plot_info['lefts'].append(lib['start'] - origin)

        # For y scales
        core_counts.append(core_count)
        y_counts.append(y_count)

        # accum result for worker
        stop_y = y_count
        stop_x = stops[0]

        worker_plot_info['ys'].append(start_y)
        worker_plot_info['widths'].append(stop_x - start_x)
        worker_plot_info['heights'].append(stop_y - start_y)
        worker_plot_info['lefts'].append(start_x - origin)

    if worker_plot_info['ys']:
        axs.barh(worker_plot_info['ys'], worker_plot_info['widths'], height=worker_plot_info['heights'], left=worker_plot_info['lefts'], label='worker', color='grey', align='edge')
    if task_running_info['ys']:
        axs.barh(task_running_info['ys'], task_running_info['widths'], left=task_running_info['lefts'], label='task running', height=-.8, align='edge')
    if task_waiting_info['ys']:
        axs.barh(task_waiting_info['ys'], task_waiting_info['widths'], left=task_waiting_info['lefts'], label='task waiting retrieval', height=-.8, align='edge')
    if task_execution_info['ys']:
        axs.barh(task_execution_info['ys'], task_execution_info['widths'], left=task_execution_info['lefts'], color='lightblue', label='task execution', height=-.8, align='edge')
    if lib_plot_info['ys']:
        axs.barh(lib_plot_info['ys'], lib_plot_info['widths'], left=lib_plot_info['lefts'], label='library tasks', color='green', height=-.8, align='edge')

    # y axis ticks: label every `steps`-th worker. The step must be at least
    # 1, otherwise range() raises when there are fewer workers than ticks.
    tick_count = args.worker_ticks
    if not tick_count or tick_count < 1:
        tick_count = 1
    steps = max(1, len(y_counts) // tick_count)
    tick_positions = range(steps - 1, len(y_counts), steps)
    y_ticks = [y_counts[x] for x in tick_positions]
    y_labels = []
    y_axis = args.worker_y
    if y_axis == 'workers':
        y_labels = list(range(steps, len(y_counts) + 1, steps))
    elif y_axis == 'cores':
        # one label per tick, so ticks and labels always have equal length
        y_labels = [core_counts[x] for x in tick_positions]

    if args.sublabels:
        axs.set_xlabel('Time (s)')
        axs.set_ylabel(y_axis)
    axs.set_yticks(y_ticks)
    axs.set_yticklabels(y_labels)
    axs.set_xlim(0)
    axs.legend()
def compute_time(log_info, axs, args):
    """Print the summed task run time and the overall application time.

    Compute time is the sum of (waiting_retrieval - running) over every task
    attached to a worker; tasks missing either timestamp (for example tasks
    that never finished) are skipped. Application time is last_task_done
    minus first_task from the manager record.

    Args:
        log_info: Parsed log; reads 'worker_info' and 'manager_info'.
        axs: Unused; accepted so the signature matches the plot functions.
        args: Unused.
    """
    total_compute_time = 0
    for record in log_info['worker_info'].values():
        for task in record.get('tasks', ()):
            if 'running' in task and 'waiting_retrieval' in task:
                total_compute_time += task['waiting_retrieval'] - task['running']
    print('Compute Time:', total_compute_time)
    print('Application Time:', log_info['manager_info']['last_task_done'] - log_info['manager_info']['first_task'])
        

def plot_task(log_info, axs, args):
    """Draw one horizontal bar per finished task, ordered by start time.

    Every task that has 'waiting_retrieval', 'retrieved' and 'done' times
    contributes three segments on its row: running, waiting for retrieval,
    and retrieved-until-done.

    Args:
        log_info: Parsed log; reads 'worker_info' and 'manager_info'.
        axs: Axes to draw on.
        args: Options; reads origin, sublabels and r_xlim.
    """
    worker_info = log_info['worker_info']

    first_task = log_info['manager_info']['first_task']
    manager_start = log_info['manager_info']['start']
    origin = {'first-task': first_task, 'manager-start': manager_start}.get(args.origin, 0)

    # gather (running, waiting_retrieval, retrieved, done) per finished task
    spans = []
    for record in worker_info.values():
        if 'tasks' not in record:
            continue
        for task in record['tasks']:
            if not ('waiting_retrieval' in task and 'retrieved' in task and 'done' in task):
                continue
            spans.append([task['running'], task['waiting_retrieval'], task['retrieved'], task['done']])
    spans.sort(key=lambda span: span[0])

    rows = list(range(1, len(spans) + 1))
    segments = (
        (0, 1, 'tasks running'),
        (1, 2, 'tasks waiting retrieval'),
        (2, 3, 'tasks retrieved'),
    )
    for begin, end, label in segments:
        widths = [span[end] - span[begin] for span in spans]
        lefts = [span[begin] - origin for span in spans]
        axs.barh(list(rows), widths, left=lefts, label=label)

    if args.sublabels:
        axs.set_xlabel('Time (s)')
        axs.set_ylabel('Tasks by Start Time')
    if args.r_xlim:
        axs.set_xlim(right=args.r_xlim)
    axs.legend()
    
def plot_state(log_info, axs, args):
    """Plot how many tasks are in each state over time.

    Every task contributes one event per state; walking the events in time
    order, a task entering a state raises that state's count and lowers the
    count of the state it was in before.

    Args:
        log_info: Parsed log; reads 'worker_info' and 'manager_info'.
        axs: Axes to draw on.
        args: Options; reads origin and sublabels.
    """
    worker_info = log_info['worker_info']
    first_task = log_info['manager_info']['first_task']
    manager_start = log_info['manager_info']['start']
    origin = {'first-task': first_task, 'manager-start': manager_start}.get(args.origin, 0)

    states = ['ready', 'running', 'waiting_retrieval', 'retrieved', 'done']

    # one [time, state, task id] event per task and state
    events = []
    for record in worker_info.values():
        if 'tasks' not in record:
            continue
        for task in record['tasks']:
            for name in states:
                events.append([task[name], name, task['id']])
    events.sort(key=lambda event: event[0])

    state_info = {name: {'count': 0, 'x': [], 'y': []} for name in states}
    previous_state = {}
    for when, name, task_id in events:
        entered = state_info[name]
        entered['count'] += 1
        entered['x'].append(when - origin)
        entered['y'].append(entered['count'])
        if task_id in previous_state:
            left = state_info[previous_state[task_id]]
            left['count'] -= 1
            left['x'].append(when - origin)
            left['y'].append(left['count'])
        previous_state[task_id] = name

    for name, series in state_info.items():
        axs.plot(series['x'], series['y'], label=name)

    if args.sublabels:
        axs.set_xlabel('Runtime (s)')
        axs.set_ylabel('Number of Tasks in State')
    axs.set_xlim(0)
    axs.legend()

def plot_runtime(log_info, axs, args):
    """Plot the cumulative distribution of task run times on a log x axis.

    A task's run time is its 'waiting_retrieval' time minus its 'running' time.

    Args:
        log_info: Parsed log; reads 'worker_info' and 'manager_info'.
        axs: Axes to draw on.
        args: Options; reads sublabels, subtitles, r_xlim and l_xlim.
    """
    durations = []
    for record in log_info['worker_info'].values():
        if 'tasks' not in record:
            continue
        for task in record['tasks']:
            durations.append(task['waiting_retrieval'] - task['running'])
    durations.sort()

    # fraction of tasks whose run time is at most durations[rank]
    total_tasks = len(durations)
    fractions = [(rank + 1) / total_tasks for rank in range(total_tasks)]

    axs.plot(list(durations), fractions, label='task runtimes')
    axs.set_xscale('log')
    if args.sublabels:
        axs.set_xlabel('Task Runtime (s)')
        axs.set_ylabel('Percent of Tasks')
    if args.subtitles:
        axs.set_title(log_info['manager_info']['log'])
    if args.r_xlim:
        axs.set_xlim(right=args.r_xlim)
    if args.l_xlim:
        axs.set_xlim(left=args.l_xlim)

    axs.legend()
    
def plot_completion(log_info, axs, args):
    """Plot the fraction of tasks completed against time.

    A task counts as completed at its 'waiting_retrieval' time.

    Args:
        log_info: Parsed log; reads 'manager_info' and 'worker_info'.
        axs: Axes to draw on.
        args: Options; reads origin and sublabels.
    """
    first_task = log_info['manager_info']['first_task']
    manager_start = log_info['manager_info']['start']
    origin = {'first-task': first_task, 'manager-start': manager_start}.get(args.origin, 0)

    # (running, waiting_retrieval) per task, ordered by completion
    spans = []
    for record in log_info['worker_info'].values():
        if 'tasks' not in record:
            continue
        for task in record['tasks']:
            spans.append([task['running'], task['waiting_retrieval']])
    spans.sort(key=lambda span: span[1])

    total_tasks = len(spans)
    fractions = [(rank + 1) / total_tasks for rank in range(total_tasks)]
    finished_at = [span[1] - origin for span in spans]

    axs.plot(finished_at, fractions, label='tasks completed')
    if args.sublabels:
        axs.set_xlabel('Time (s)')
        axs.set_ylabel('Percent of Tasks Completed')
    axs.legend()

def plot_file_accum(log_info, axs, args):
    """Plot, for every file, how its replica count grows over time.

    Each file gets one line that steps up by one at each time a worker first
    reported holding the file.

    Args:
        log_info: Parsed log; reads 'manager_info' and 'file_info'.
        axs: Axes to draw on.
        args: Options; reads origin and sublabels.
    """
    origin = 0
    first_task = log_info['manager_info']['first_task']
    manager_start = log_info['manager_info']['start']
    if args.origin == 'first-task':
        origin = first_task
    elif args.origin == 'manager-start':
        origin = manager_start

    file_info = log_info['file_info']
    for filename in file_info:
        # entry[0] is the time the worker reported the file
        arrivals = sorted(entry[0] for entry in file_info[filename]["workers"].values())
        xs = [when - origin for when in arrivals]
        ys = list(range(1, len(arrivals) + 1))
        axs.plot(xs, ys, label=filename)

    if args.sublabels:
        axs.set_xlabel('Manager Lifetime (s)')
        axs.set_ylabel('Individual File Replication')

def plot_file_hist(log_info, axs, args):
    """Plot a histogram of file replication counts with total disk usage.

    Files are grouped by how many workers hold them. For each group the left
    axis shows the number of files and a twin right axis shows the summed
    (largest size * replica count) of those files.

    Args:
        log_info: Parsed log; reads 'manager_info' and 'file_info'.
        axs: Axes to draw on; a twin y axis is created from it.
        args: Options; reads origin, sublabels, r_xlim, l_xlim and sublegend.
    """
    # The origin is resolved as in the other plots but is not used here.
    first_task = log_info['manager_info']['first_task']
    manager_start = log_info['manager_info']['start']
    origin = {'first-task': first_task, 'manager-start': manager_start}.get(args.origin, 0)

    # replica count -> [number of files, total size, [filenames]]
    hist = {}
    for filename, details in log_info['file_info'].items():
        replicas = 0
        largest = float('-inf')
        for entry in details['workers'].values():
            replicas += 1
            if entry[1] > largest:
                largest = entry[1]
        bucket = hist.get(replicas)
        if bucket is None:
            hist[replicas] = [1, largest*replicas, [filename]]
        else:
            bucket[0] += 1
            bucket[1] += largest*replicas
            bucket[2].append(filename)

    axs2 = axs.twinx()
    xs = list(hist)
    ys = [hist[replicas][0] for replicas in hist]
    ys2 = [hist[replicas][1] for replicas in hist]

    axs.bar(xs, ys, label='file replication count', width=-.2, align='edge')
    axs2.bar(xs, ys2, label='distributed disk usage', width=.2, align='edge', color='orange')
    if args.sublabels:
        axs.set_xlabel('File Replication Count')
        axs.set_ylabel('Number of Files')
        axs2.set_ylabel('Distributed Disk Usage (GB)')
    if args.r_xlim is not None:
        for axis in (axs, axs2):
            axis.set_xlim(right=args.r_xlim)
    if args.l_xlim is not None:
        for axis in (axs, axs2):
            axis.set_xlim(left=args.l_xlim)
    if args.sublegend:
        axs.legend(loc='upper left')
        axs2.legend(loc='upper right')

def plot_data_exg(log_info, axs, args):
    """Draw a source-by-destination heat map of the data exchanged.

    Args:
        log_info: Parsed debug log; reads 'transfer_info' and 'ids'.
        axs: Axes to draw on; a colorbar is added to its figure.
        args: Options; reads sublabels, r_xlim, l_xlim, t_ylim and b_ylim.
    """
    transfers = log_info['transfer_info']
    ids = log_info['ids']
    node_count = len(ids)

    # exchanged[source id][destination id] accumulates transfer sizes
    exchanged = [[0] * node_count for _ in range(node_count)]
    for transfer in transfers:
        source = ids[transfer[0]]
        destination = ids[transfer[1]]
        exchanged[source][destination] += transfer[3]

    mesh = axs.pcolormesh(exchanged)
    cbar = axs.get_figure().colorbar(mesh, ax=axs)
    if args.sublabels:
        axs.set_ylabel('Source(ID)')
        axs.set_xlabel('Destination(ID)')
    cbar.ax.set_ylabel("Exchanged Data (GB)")
    if args.r_xlim:
        axs.set_xlim(right=args.r_xlim)
    if args.l_xlim:
        axs.set_xlim(left=args.l_xlim)
    if args.t_ylim:
        axs.set_ylim(top=args.t_ylim)
    if args.b_ylim:
        axs.set_ylim(bottom=args.b_ylim)

def plot_trans_accum(log_info, axs, args):
    """Draw a one-row heat map of the data each node received by transfer.

    Args:
        log_info: Parsed debug log; reads 'transfer_info' and 'ids'.
        axs: Axes to draw on; a colorbar is added on top.
        args: Options; reads sublabels, r_xlim and l_xlim.
    """
    transfers = log_info['transfer_info']
    ids = log_info['ids']

    received = [0] * len(ids)
    for transfer in transfers:
        source, destination = ids[transfer[0]], ids[transfer[1]]
        received[destination] += transfer[3]

    mesh = axs.pcolormesh([received])
    axs.yaxis.set_tick_params(labelleft=False)
    axs.set_yticks([])
    cbar = axs.get_figure().colorbar(mesh, ax=axs, location='top')
    if args.sublabels:
        axs.set_xlabel('Node(ID)')
    cbar.ax.set_xlabel("Transfer Accumulated Data (GB)")
    if args.r_xlim:
        axs.set_xlim(right=args.r_xlim)
    if args.l_xlim:
        axs.set_xlim(left=args.l_xlim)

def plot_sent_accum(log_info, axs, args):
    """Draw a one-row heat map of the data each node sent by transfer.

    Args:
        log_info: Parsed debug log; reads 'transfer_info' and 'ids'.
        axs: Axes to draw on; a colorbar is added on top.
        args: Options; reads sublabels, r_xlim and l_xlim.
    """
    transfers = log_info['transfer_info']
    ids = log_info['ids']

    sent = [0] * len(ids)
    for transfer in transfers:
        source, destination = ids[transfer[0]], ids[transfer[1]]
        sent[source] += transfer[3]

    mesh = axs.pcolormesh([sent])
    axs.yaxis.set_tick_params(labelleft=False)
    axs.set_yticks([])
    cbar = axs.get_figure().colorbar(mesh, ax=axs, location='top')
    if args.sublabels:
        axs.set_xlabel('Node(ID)')
    cbar.ax.set_xlabel("Sent Accumulated Data (GB)")
    if args.r_xlim:
        axs.set_xlim(right=args.r_xlim)
    if args.l_xlim:
        axs.set_xlim(left=args.l_xlim)

def plot_total_accum(log_info, axs, args):
    """Draw a one-row heat map of the total file data held by each worker.

    Args:
        log_info: Parsed debug log; reads 'worker_file_info' and 'ids'.
        axs: Axes to draw on; a colorbar is added on top.
        args: Options; reads sublabels, r_xlim and l_xlim.
    """
    holdings = log_info['worker_file_info']
    ids = log_info['ids']

    held = [0] * len(ids)
    for worker, record in holdings.items():
        for cachename in record['files']:
            held[ids[worker]] += record['files'][cachename]

    mesh = axs.pcolormesh([held])
    axs.yaxis.set_tick_params(labelleft=False)
    axs.set_yticks([])
    cbar = axs.get_figure().colorbar(mesh, ax=axs, location='top')
    if args.sublabels:
        axs.set_xlabel('Node(ID)')
    cbar.ax.set_xlabel("Total Accumulated Data (GB)")
    if args.r_xlim:
        axs.set_xlim(right=args.r_xlim)
    if args.l_xlim:
        axs.set_xlim(left=args.l_xlim)

def plot_internal_accum(log_info, axs, args):
    """Draw a one-row heat map of the data each worker produced itself.

    For each node this is the total size of the files it holds minus the
    size of everything transferred to it, floored at zero.

    Args:
        log_info: Parsed debug log; reads 'transfer_info',
            'worker_file_info' and 'ids'.
        axs: Axes to draw on; a colorbar is added on top.
        args: Options; reads sublabels, r_xlim and l_xlim.
    """
    transfers = log_info['transfer_info']
    holdings = log_info['worker_file_info']
    ids = log_info['ids']

    produced = [0] * len(ids)
    for worker, record in holdings.items():
        for cachename in record['files']:
            produced[ids[worker]] += record['files'][cachename]
    for transfer in transfers:
        source, destination = ids[transfer[0]], ids[transfer[1]]
        produced[destination] -= transfer[3]
    # received data can exceed what is recorded as held; never go negative
    produced = [max(0, amount) for amount in produced]

    mesh = axs.pcolormesh([produced])
    axs.yaxis.set_tick_params(labelleft=False)
    axs.set_yticks([])
    cbar = axs.get_figure().colorbar(mesh, ax=axs, location='top')
    if args.sublabels:
        axs.set_xlabel('Node(ID)')
    cbar.ax.set_xlabel("Internally Generated Data (GB)")
    if args.r_xlim:
        axs.set_xlim(right=args.r_xlim)
    if args.l_xlim:
        axs.set_xlim(left=args.l_xlim)

def check_debug(log_info):
    """Parse the file named 'debug' that sits next to the given log.

    Args:
        log_info: Parsed log; reads the path in log_info['manager_info']['log'].

    Returns:
        The result of parse_debug() on '<directory of that log>/debug'.
    """
    txn_log = log_info['manager_info']['log']
    debug_log = os.path.join(os.path.dirname(txn_log), 'debug')
    print('Parsing debug log for {}'.format(debug_log))
    return parse_debug(debug_log)
    
def plot_any(log_info, plot, axs, args):
    """Run the plotting routine selected by a plot name.

    Plot names that match nothing are ignored. The debug-based plots need
    debug-log data: when log_info came from a transaction log
    (manager_info type 'txn') the neighbouring debug log is parsed through
    check_debug(); when it is already a debug log it is used directly; any
    other type draws nothing.

    Args:
        log_info: Parsed log to plot from.
        plot: Plot name, e.g. 'worker-view' or 'data-exg'.
        axs: Axes handed to the plotting routine.
        args: Options handed to the plotting routine.
    """
    debug_plots = ('data-exg', 'trans-accum', 'sent-accum', 'total-accum', 'internal-accum')

    if plot not in debug_plots:
        # txn based plots
        if plot == 'worker-view':
            plot_worker(log_info, axs, args)
        elif plot == 'worker-cache':
            plot_cache(log_info, axs, args)
        elif plot == 'task-view':
            plot_task(log_info, axs, args)
        elif plot == 'task-runtime':
            plot_runtime(log_info, axs, args)
        elif plot == 'task-completion':
            plot_completion(log_info, axs, args)
        elif plot == 'task-state':
            plot_state(log_info, axs, args)
        elif plot == 'file-accum':
            plot_file_accum(log_info, axs, args)
        elif plot == 'file-hist':
            plot_file_hist(log_info, axs, args)
        elif plot == 'compute-time':
            compute_time(log_info, axs, args)
        return

    # debug based plots: pick the debug data to plot from
    log_type = log_info['manager_info']['type']
    if log_type == 'txn':
        source = check_debug(log_info)
    elif log_type == 'debug':
        source = log_info
    else:
        return

    if plot == 'data-exg':
        plot_data_exg(source, axs, args)
    elif plot == 'trans-accum':
        plot_trans_accum(source, axs, args)
    elif plot == 'sent-accum':
        plot_sent_accum(source, axs, args)
    elif plot == 'total-accum':
        plot_total_accum(source, axs, args)
    elif plot == 'internal-accum':
        plot_internal_accum(source, axs, args)


if __name__ == '__main__':
    # Command line entry point: parse every log given and draw the requested
    # plots on a grid with one row per log and one column per plot.
    parser = argparse.ArgumentParser(description='Consolidated plotting tool with a variety of options')
    parser.add_argument('logs', nargs='+')

    # Plotting views
    parser.add_argument('--worker-view', dest='plots', action='append_const', const='worker-view')
    parser.add_argument('--worker-cache', dest='plots', action='append_const', const='worker-cache')
    parser.add_argument('--task-view', dest='plots', action='append_const', const='task-view')
    parser.add_argument('--task-runtime', dest='plots', action='append_const', const='task-runtime')
    parser.add_argument('--task-completion', dest='plots', action='append_const', const='task-completion')
    parser.add_argument('--task-state', dest='plots', action='append_const', const='task-state')
    parser.add_argument('--file-accum', dest='plots', action='append_const', const='file-accum')
    parser.add_argument('--file-hist', dest='plots', action='append_const', const='file-hist')
    parser.add_argument('--data-exg', dest='plots', action='append_const', const= 'data-exg')
    parser.add_argument('--trans-accum', dest='plots', action='append_const', const= 'trans-accum')
    parser.add_argument('--sent-accum', dest='plots', action='append_const', const= 'sent-accum')
    parser.add_argument('--internal-accum', dest='plots', action='append_const', const= 'internal-accum')
    parser.add_argument('--total-accum', dest='plots', action='append_const', const= 'total-accum')
    parser.add_argument('--compute-time', dest='plots', action='append_const', const= 'compute-time')


    # Subplot options
    parser.add_argument('--sublabels', action='store_true', default=False)
    parser.add_argument('--subtitles', action='store_true', default=False)
    parser.add_argument('--sublegend', action='store_true', default=False)
    parser.add_argument('--worker-y', nargs='?', default='workers', choices=['workers', 'cores'], type=str)
    parser.add_argument('--origin', nargs='?', default='first-task', choices=['first-task', 'manager-start'], type=str)
    parser.add_argument('--r-xlim', nargs='?', default=None, type=float)
    parser.add_argument('--l-xlim', nargs='?', default=None, type=float)
    parser.add_argument('--t-ylim', nargs='?', default=None, type=float)
    parser.add_argument('--b-ylim', nargs='?', default=None, type=float)
    parser.add_argument('--worker-ticks', nargs='?', default=5, type=int)

    # Figure Options
    parser.add_argument('--title', nargs='?', default='TaskVine Plot Composition')
    parser.add_argument('--width', nargs='?', default=6.4, type=float)
    parser.add_argument('--height', nargs='?', default=4.8, type=float)
    parser.add_argument('--scale',  action='store_true', default=False)
    parser.add_argument('--show',  action='store_true', default=False)
    parser.add_argument('--out', nargs='?', default=None)

    # Other Options
    parser.add_argument('mode', nargs='?', default='tv', choices=['tv', 'wq'], type=str)
    parser.add_argument('--parse-debug', action='store_true', default=False, help='Parse as debug log')

    args = parser.parse_args()

    # append_const leaves args.plots as None when no plot option is given;
    # report that as a usage error instead of failing on len(None).
    if not args.plots:
        parser.error('no plot selected: pass at least one plot option, e.g. --worker-view')

    ncols = len(args.plots)
    nrows = len(args.logs)

    width = args.width
    height = args.height
    if args.scale:
        # one default-sized panel per subplot
        width = 6.4 * ncols
        height = 4.8 * nrows

    # squeeze=False keeps axs two-dimensional even for a single row or column
    fig, axs = plt.subplots(nrows, ncols, squeeze=False)

    for row, log in enumerate(args.logs):
        print('Parsing log for {}'.format(log))
        if args.parse_debug:
            log_info = parse_debug(log)
        else:
            log_info = parse_txn(log)
        print('Plotting plot(s) for {}'.format(log))
        for col, plot in enumerate(args.plots):
            plot_any(log_info, plot, axs[row][col], args)

    fig.suptitle(args.title)
    fig.set_size_inches(w=width, h=height)
    if args.out:
        fig.savefig(args.out)
    if args.show:
        plt.show()
