###############################################################################
# Copyright 2014 2017 Intel Corporation.
#
# The source code, information and material ("Material") contained herein is
# owned by Intel Corporation or its suppliers or licensors, and title to such
# Material remains with Intel Corporation or its suppliers or licensors. The
# Material contains proprietary information of Intel or its suppliers and
# licensors. The Material is protected by worldwide copyright laws and treaty
# provisions. No part of the Material may be used, copied, reproduced, modified,
# published, uploaded, posted, transmitted, distributed or disclosed in any way
# without Intel's prior express written permission.
#
# No license under any patent, copyright or other intellectual property rights
# in the Material is granted to or conferred upon you, either expressly, by
# implication, inducement, estoppel or otherwise. Any license under such
#
# intellectual property rights must be express and approved by Intel in writing.
# Unless otherwise agreed by Intel in writing, you may not remove or alter
# this notice or any other notice embedded in Materials by Intel or Intel's
# suppliers or licensors in any way.
###############################################################################


import os
import time
import errno
import shutil
import datetime
import random

from . import stdiolog
from ._py2to3 import *
from .ipc_env.ipc_commands import ReportResults

try:
    import psutil
except ImportError:
    psutil = None



class PrintColors:
    """
    Terminal escape sequences used to colorize diagnostic output.

    Each color constant is meant to be emitted before the text, with ENDC
    emitted after it to restore the default rendering.
    """
    ENDC = '\x1b[0m'      # SGR 0: reset attributes
    FAIL = '\x1b[91m'     # SGR 91
    WARNING = '\x1b[93m'  # SGR 93
    OKGREEN = '\x1b[92m'  # SGR 92


def _human_readable_size(bytes):
    """
    Formats a byte count using binary (1024-based) units.

    Values of at least one KB are scaled to the largest unit that fits
    (KB, MB, GB or TB) and printed with one decimal place; smaller values
    are printed unscaled with a " B" suffix.

    Args:
      bytes : the number of bytes (the name shadows the builtin, but is kept
              so existing callers are unaffected)

    Returns:
      The formatted string, e.g. "1.5 KB".
    """
    for exponent, unit in ((4, "TB"), (3, "GB"), (2, "MB"), (1, "KB")):
        scale = 1024 ** exponent
        # ">=" so that exactly one unit reads "1.0 KB" instead of "1024 B".
        if bytes >= scale:
            # float() keeps the fractional part under Python 2 integer division.
            return "{:.1f} {}".format(bytes / float(scale), unit)
    return "{} B".format(bytes)


def _human_readable_frequency(hertz):
    """
    Formats a frequency given in hertz using decimal (1000-based) units.

    Values of at least one KHz are scaled to the largest unit that fits
    (KHz, MHz, GHz or THz) and printed with three decimal places; smaller
    values are printed unscaled with a " Hz" suffix.

    Args:
      hertz : the frequency in hertz

    Returns:
      The formatted string, e.g. "1.500 MHz".
    """
    for exponent, unit in ((4, "THz"), (3, "GHz"), (2, "MHz"), (1, "KHz")):
        scale = 1000 ** exponent
        # ">=" so that exactly one unit reads "1.000 KHz" instead of "1000 Hz".
        if hertz >= scale:
            # float() keeps the fractional part under Python 2 integer division.
            return "{:.3f} {}".format(hertz / float(scale), unit)
    return "{} Hz".format(hertz)


class DecoratedResult():
    """
    A result value bundled with hints about how it should be displayed.

    rank is used to order decorated entries relative to each other, and
    value_format controls the text produced for the value: it may be a
    str.format template, a callable taking the value, or a false value to
    return the value unformatted.
    """
    def __init__(self, value, rank=0, value_format="{}"):
        self._value_format = value_format
        self.value = value
        self.rank = rank

    def format_value(self):
        """Returns the value rendered through the configured format."""
        formatter = self._value_format
        if not formatter:
            return self.value
        return formatter(self.value) if callable(formatter) else formatter.format(self.value)

    def __str__(self):
        rendered = self.format_value()
        return str(rendered)

    def __lt__(self, other):
        # Decorated results come before non-decorated results
        if not isinstance(other, DecoratedResult):
            return True
        mine = (self.rank, self.value)
        theirs = (other.rank, other.value)
        return mine < theirs

    __repr__ = __str__


class DiagnosticResults(dict):
    """
    A dict of results, possibly nested, that pretty-prints as an indented table.

    Indexing with [] returns the plain value of a DecoratedResult entry; the
    wrapper itself is still what is stored, and is what the printing code
    below sees when it walks items().
    """

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)

    def __getitem__(self, key):
        stored = dict.__getitem__(self, key)
        return stored.value if isinstance(stored, DecoratedResult) else stored

    def __str__(self):
        return self._to_string()

    def _to_string(self, depth=0):
        """Renders this table, indenting every line by depth spaces."""
        indent = " " * depth

        def sort_key(entry):
            name, val = entry
            if isinstance(val, DecoratedResult):
                return (0, val.rank, name)  # First, DecoratedResult by rank
            if isinstance(val, DiagnosticResults):
                return (2, 0, name)  # DiagnosticResults last
            return (1, 0, name)  # Everything else by key

        pieces = []
        for name, val in sorted(self.items(), key=sort_key):
            if isinstance(val, DiagnosticResults):
                # Nested table: heading line followed by the deeper-indented body
                text = "{}{}:\n{}".format(indent, name, val._to_string(depth + 1))
            else:
                # Leaf entry; the tab pads the name out to the value column
                text = "{}{}\t: {}\n".format(indent, name, val)
            pieces.append(text.expandtabs(20))
        return "".join(pieces)

    __repr__ = __str__


class TransportDiagnostics():
    """
    Provides diagnostic operations evaluating the health of the connected probes.
    """
    def __init__(self, cli):
        self._parent_cli = cli

    def info(self):
        """
        Collects probe settings and information for every debug port.

        Returns a DiagnosticResults keyed by "Debug Port <instanceid>". Each
        entry holds the plugin type, probe type and probe diagnostic info, plus
        a "Chain <instanceid>" table for every JTAG scan chain that exposes at
        least one of the optional Jtag.* configs listed below.
        """
        # (config name without the "Jtag." prefix, converter applied to the raw config value)
        optional_jtag_settings = (
            ("Voltage", lambda raw: raw),
            ("TimeoutClkSource", lambda raw: raw),
            ("InternalReferenceClock", lambda raw: DecoratedResult(int(raw), 0, _human_readable_frequency)),
            ("MuxSel", lambda raw: int(raw)),
            ("TclkSource", lambda raw: raw),
        )

        report = {}
        for port in self._parent_cli.devs.group_by(nodetype="^Debugport$"):
            port_entry = {
                "Plugin Type": port.device.plugintype,
                "Probe Type": port.device.probetype,
                "Probe Info:": port.device.probediagnosticinfo,
            }
            scan_chains = [node for node in port.children if node.nodetype == "JTAGScanChain"]
            for chain in scan_chains:
                chain_entry = {}
                for setting, convert in optional_jtag_settings:
                    # Only report configs that this chain actually exposes
                    if "Jtag.{}".format(setting) in dir(chain.config):
                        chain_entry[setting] = convert(chain.config["Jtag.{}".format(setting)])
                if chain_entry:
                    chain_label = "Chain {}".format(chain.device.instanceid)
                    port_entry[chain_label] = DiagnosticResults(chain_entry)
            port_label = "Debug Port {}".format(port.device.instanceid)
            report[port_label] = DiagnosticResults(port_entry)
        return DiagnosticResults(report)


class JtagDiagnostics():
    """
    Provides diagnostic operations evaluating the health of the JTAG scan chains.
    """
    def __init__(self, cli):
        """
        Stores the cli and sets up the key pattern and overscan sizes used by
        the continuity checks.
        """
        self._parent_cli = cli
        # Number of all-ones padding bits scanned in addition to the key pattern
        self._ir_overscan_size = 128
        self._dr_overscan_size = 256
        # 32-bit pattern that is shifted in and then searched for in the scanned-out data
        pattern = cli.BitData(32, 0xD5D2600D)
        self._key_pattern = pattern
        # Mismatch counts below this threshold are reported as a corrupted key
        # rather than a missing one
        ones = pattern.PopCount()
        zeros = pattern.BitSize - ones
        self._key_pattern_range = min(ones, zeros) - 1

    def _fuzzy_find_key_once(self, data, key):
        """
        Slides key across data and returns a tuple of the offset of the closest
        window and the number of bits in that window that differ from key.

        The lowest offset wins ties, and the search stops at the first exact
        match. Returns (0, 0) when data is shorter than key.
        """
        width = key.BitSize
        last_offset = data.BitSize - width
        if last_offset < 0:
            return 0, 0
        # (offset, mismatching bit count); a window only replaces this when it
        # mismatches in fewer than width bits
        closest = (0, width)
        for offset in range(last_offset + 1):
            mismatches = (data.Extract(offset, width) ^ key).PopCount()
            if mismatches < closest[1]:
                closest = (offset, mismatches)
                if mismatches == 0:
                    break
        return closest

    def _fuzzy_find_key(self, data, key):
        """
        Finds the best match for key in data, also trying the bitwise inverse
        of key. Returns a tuple of the offset, the number of mismatching bits,
        and whether the inverted key was the better match.

        The inverted key is only preferred when it mismatches in strictly
        fewer bits. Returns (0, 0, False) when data is shorter than key.
        """
        if data.BitSize - key.BitSize < 0:
            return 0, 0, False
        offset, mismatches = self._fuzzy_find_key_once(data, key)
        if mismatches != 0:
            # No exact match, so see whether the inverted key fits better
            flipped_offset, flipped_mismatches = self._fuzzy_find_key_once(data, ~key)
            if flipped_mismatches < mismatches:
                return flipped_offset, flipped_mismatches, True
        return offset, mismatches, False

    def _get_continuity_result(self, missmatch_bit_count, inverted, key_offset):
        """
        Builds the colored DecoratedResult describing the continuity found.

        An exact key match is "Found" (or "Short" when the key came back at
        offset 0), a mismatch count below the key pattern range is "Corrupt",
        and anything else is "Not Found". Inverted matches are prefixed with
        "Inverted".
        """
        def colored(label, color):
            return DecoratedResult(label, 0, "{}{{}}{}".format(color, PrintColors.ENDC))

        if missmatch_bit_count == 0:
            key_was_delayed = key_offset > 0
            if inverted:
                return colored("Inverted" if key_was_delayed else "Inverted, Short", PrintColors.FAIL)
            if key_was_delayed:
                return colored("Found", PrintColors.OKGREEN)
            return colored("Short", PrintColors.FAIL)

        if missmatch_bit_count < self._key_pattern_range:
            return colored("Inverted, Corrupt" if inverted else "Corrupt", PrintColors.FAIL)

        return colored("Not Found", PrintColors.WARNING)

    def _get_key_confidence(self, missmatch_bit_count):
        """
        Calculates a confidence level based on how many bits of the key
        pattern mismatched.

        The confidence is 1 for an exact match and drops by 2 / key size for
        every mismatching bit. It is returned as a rank 1 DecoratedResult that
        prints as a percentage, colored green at 100%, yellow from 80% and red
        below that.
        """
        # float() forces true division; under Python 2 the integer operands
        # would otherwise floor the ratio to a whole number.
        mismatch_ratio = (missmatch_bit_count * 2) / float(self._key_pattern.BitSize)
        key_confidence_value = 1 - mismatch_ratio
        if key_confidence_value == 1:
            key_confidence_color = PrintColors.OKGREEN
        elif key_confidence_value >= .8:
            key_confidence_color = PrintColors.WARNING
        else:
            key_confidence_color = PrintColors.FAIL
        return DecoratedResult(key_confidence_value, 1, "{}{{:.0%}}{}".format(key_confidence_color, PrintColors.ENDC))

    def _chain_ir_continuity(self, chain):
        """
        Performs an IR overscan to check for IR continuity.

        The key pattern followed by all-ones padding is shifted through the
        chain's IR, and the scanned-out data is searched for the key. The
        offset at which the key reappears is reported as the IR length, and
        the bits in front of it as the IR capture.
        """
        cli = self._parent_cli
        pattern = self._key_pattern
        with cli.device_locker() as _:
            tdi = cli.BitData(self._ir_overscan_size + pattern.BitSize, -1)
            tdi.Deposit(0, pattern.BitSize, pattern)
            tdo = cli.irscan(chain, tdi, tdi.BitSize)

        key_offset, missmatch_bit_count, inverted = self._fuzzy_find_key(tdo, pattern)

        continuity = self._get_continuity_result(missmatch_bit_count, inverted, key_offset)
        key_confidence = self._get_key_confidence(missmatch_bit_count)

        if key_offset > 0:
            ir_capture = tdo[0:key_offset - 1]
        else:
            ir_capture = cli.BitData(0, 0)

        # CapIR must end with 0b01 per IEEE spec, so each such bit pair is a possible tap
        possible_taps = sum(
            1 for bit in range(1, ir_capture.BitSize)
            if ir_capture[bit] == 0 and ir_capture[bit - 1] == 1
        )

        return DiagnosticResults({
            "Continuity": continuity,
            "Key Confidence": key_confidence,
            "IR Length": key_offset,
            "IR Capture": ir_capture,
            "Max Tap Count": possible_taps,
            "IR TDI": tdi,
            "IR TDO": tdo,
        })

    def _chain_bypass_continuity(self, chain):
        """
        Writes the bypass IR and does a DR overscan to check for continuity.

        The key pattern followed by all-ones padding is shifted through the
        chain's DR, and the scanned-out data is searched for the key. The
        offset at which the key reappears is reported both as the DR length
        and as the maximum tap count.
        """
        cli = self._parent_cli
        pattern = self._key_pattern
        with cli.device_locker() as _:
            # First write IR to all 1s to ensure we are bypassing all taps
            bypass_ir = cli.BitData(self._ir_overscan_size, -1)
            cli.irscan(chain, bypass_ir, bypass_ir.BitSize)

            # TODO: Maybe grow size if we get back no key, but get data like values (not all 1s or 0s)
            tdi = cli.BitData(self._dr_overscan_size + pattern.BitSize, -1)
            tdi.Deposit(0, pattern.BitSize, pattern)
            tdo = cli.drscan(chain, tdi.BitSize, data=tdi)

        key_offset, missmatch_bit_count, inverted = self._fuzzy_find_key(tdo, pattern)

        continuity = self._get_continuity_result(missmatch_bit_count, inverted, key_offset)
        key_confidence = self._get_key_confidence(missmatch_bit_count)

        if key_offset > 0:
            dr_capture = tdo[0:key_offset - 1]
        else:
            dr_capture = cli.BitData(0, 0)

        return DiagnosticResults({
            "Continuity": continuity,
            "Key Confidence": key_confidence,
            "DR Length": key_offset,
            "DR Capture": dr_capture,
            "Max Tap Count": key_offset,
            "DR TDI": tdi,
            "DR TDO": tdo,
        })

    def _chain_manual_tapreset_idcodes(self, chain, use_pin):
        """
        Resets the taps on the chain and then does a DR overscan with the key
        pattern to capture what the chain shifts out after the reset
        (presumably the idcode registers, 32 bits per tap).

        Args:
          chain          : the JTAG scan chain to operate on
          use_pin (bool) : when True the reset is requested by pulsing the
                           TRSTn pin through the InterfacePins service,
                           otherwise the chain is moved to the TLR state

        Returns:
          DiagnosticResults with the continuity, key confidence, DR length and
          capture, tap count estimate and the raw TDI/TDO. When the probe
          reports that the TRSTn pin is not supported, a DiagnosticResults
          holding only a "Warning" entry is returned instead.

        Raises:
          py2ipc.IPC_Error: if pulsing the pin fails for any other reason.
        """
        with self._parent_cli.device_locker() as _:
            # First write IR to all 1s to ensure we aren't already on the idcode IR
            tdi = self._parent_cli.BitData(self._ir_overscan_size, -1)
            tdo = self._parent_cli.irscan(chain, tdi, tdi.BitSize)
            int(tdo)  # flush
            if use_pin:
                # Imported outside the try block: the except clause below needs
                # py2ipc to be bound, otherwise a failed import would surface as
                # an UnboundLocalError instead of the real ImportError.
                import py2ipc
                try:
                    pinService = py2ipc.IPC_GetService("InterfacePins")
                    pulseWidth = 1200  # in nano-seconds
                    pinService.PulsePinState(
                            chain.parent[0].did,
                            py2ipc.IPC_Types.IPC_InterfacePins_Signals.TRSTn,
                            py2ipc.IPC_Types.IPC_InterfacePins_State.Assert,
                            pulseWidth)
                except py2ipc.IPC_Error as err:
                    if (err.code != py2ipc.IPC_Error_Codes.InterfacePins_Pin_Not_Supported):
                        raise
                    return DiagnosticResults(
                        {
                            "Warning": DecoratedResult("TRST pin not supported", 0, "{}{{}}{}".format(PrintColors.WARNING, PrintColors.ENDC))
                        }
                    )
            else:
                self._parent_cli.jtag_goto_state(chain, "TLR", 6, "COUNT")

            key_bitcount = self._key_pattern.BitSize
            non_key_bitcount = self._dr_overscan_size
            tdi = self._parent_cli.BitData(non_key_bitcount + key_bitcount, -1)
            tdi.Deposit(0, key_bitcount, self._key_pattern)
            tdo = self._parent_cli.drscan(chain, tdi.BitSize, data=tdi)

        key_offset, missmatch_bit_count, inverted = self._fuzzy_find_key(tdo, self._key_pattern)
        continuity = self._get_continuity_result(missmatch_bit_count, inverted, key_offset)
        key_confidence = self._get_key_confidence(missmatch_bit_count)
        results = {
            "Continuity": continuity,
            "Key Confidence": key_confidence,
            "DR Length": key_offset,
            "DR Capture": tdo[0:key_offset - 1] if key_offset > 0 else self._parent_cli.BitData(0, 0),
            "Max Tap Count": int(key_offset / 32),
            "DR TDI": tdi,
            "DR TDO": tdo
        }
        return DiagnosticResults(results)

    def _chain_tapreset_idcodes(self, chain):
        """
        Performs the TapResetGetIdcodes function and returns the results.

        The returned DiagnosticResults holds the title-cased "Status" (green
        when it is "Good", red otherwise), the "Idcodes" printed as hex, and a
        "TLR Idcodes" table from the manual TLR based reset. When
        TapResetGetIdcodes raises an IPC_Error, the error text becomes the
        status and the idcodes are None.
        """
        import py2ipc
        jtagcfg = py2ipc.IPC_GetService("JtagDiagnostic")
        try:
            idcodes, status = jtagcfg.TapResetGetIdcodes(chain.did, 100)
        except py2ipc.IPC_Error as err:
            idcodes = None
            status = str(err)
        status = status.title()

        def _format_idcodes(codes):
            # codes is None when the service call failed; iterating it would
            # raise a TypeError the moment the results are printed.
            if codes is None:
                return "None"
            return ", ".join(["0x{:08X}".format(code) for code in codes])

        status_color = PrintColors.OKGREEN if status == "Good" else PrintColors.FAIL
        results = {
            'Status': DecoratedResult(status, 0, "{}{{}}{}".format(status_color, PrintColors.ENDC)),
            'Idcodes': DecoratedResult(idcodes, 1, _format_idcodes)
        }

        results["TLR Idcodes"] = self._chain_manual_tapreset_idcodes(chain, False)

        return DiagnosticResults(results)

    def _tap_stress_dr_readonly(self, tap, ir, dr_size, expected_dr, count_per_bundle, bundle_count):
        """
        Sets the IR once and then repeatedly scans the selected DR, counting
        every read that does not return the expected value.

        Reads are issued in bundle_count bundles of count_per_bundle scans,
        with a flush forced after each bundle. When expected_dr is None the
        first value read is used as the expected value.
        """
        cli = self._parent_cli
        with cli.device_locker() as _:
            # TODO: Should we have an option to write the IR every time?
            int(cli.irscan(tap, ir))  # force flush
            bad_count = 0
            for _bundle in range(bundle_count):
                captures = [cli.drscan(tap, dr_size) for _scan in range(count_per_bundle)]
                int(captures[-1])  # force flush
                for capture in captures:
                    if expected_dr is None:
                        expected_dr = capture
                    if capture != expected_dr:
                        bad_count += 1

        if bad_count:
            verdict, verdict_color = "Unstable", PrintColors.FAIL
        else:
            verdict, verdict_color = "Stable", PrintColors.OKGREEN

        return DiagnosticResults({
            "IR": DecoratedResult(ir, 1, "0x{:X}"),
            "Result": DecoratedResult(verdict, 0, "{}{{}}{}".format(verdict_color, PrintColors.ENDC)),
            "Expected Value": expected_dr,
            "Misscompares": bad_count,
            "Read Count": bundle_count * count_per_bundle,
        })

    def _tap_stress_dr_readwrite(self, tap, ir, dr_size, count_per_bundle, bundle_count, return_traffic=False):
        """
        Sets the IR once and then repeatedly scans random data into the
        selected DR, checking that each scan returns the value written by the
        scan before it.

        Scans are issued in bundle_count bundles of count_per_bundle scans,
        with a flush forced after each bundle. When return_traffic is set the
        results also carry every TDI and TDO value under "TDI" and "TDO".
        """
        cli = self._parent_cli
        with cli.device_locker() as _:
            cli.irscan(tap, ir)
            # Seed the DR so the first checked scan has a known previous value
            previous_tdi = cli.BitData(dr_size, random.getrandbits(dr_size))
            int(cli.drscan(tap, dr_size, data=previous_tdi))  # force flush

            bad_count = 0
            if return_traffic:
                every_tdi = [previous_tdi]
                every_tdo = []
            for _bundle in range(bundle_count):
                # sent[i] is what was in the DR when received[i] was scanned out
                sent = [previous_tdi]
                received = []
                for _scan in range(count_per_bundle):
                    tdi = cli.BitData(dr_size, random.getrandbits(dr_size))
                    received.append(cli.drscan(tap, dr_size, data=tdi))
                    sent.append(tdi)
                int(received[-1])  # force flush
                if return_traffic:
                    every_tdi.extend(sent[1:])
                    every_tdo.extend(received)
                bad_count += sum(1 for got, written in zip(received, sent) if got != written)
                previous_tdi = sent[-1]

        if bad_count:
            verdict, verdict_color = "Unstable", PrintColors.FAIL
        else:
            verdict, verdict_color = "Stable", PrintColors.OKGREEN

        results = {
            "Result": DecoratedResult(verdict, 0, "{}{{}}{}".format(verdict_color, PrintColors.ENDC)),
            "IR": DecoratedResult(ir, 1, "0x{:X}"),
            "Misscompares": bad_count,
            "Scan Count": bundle_count * count_per_bundle,
        }
        if return_traffic:
            results["TDI"] = every_tdi
            results["TDO"] = every_tdo
        return DiagnosticResults(results)

    def info(self):
        """
        Collects the JTAG settings of every scan chain on every debug port.

        Returns a DiagnosticResults keyed by "Debug Port <instanceid>", each
        holding a "Chain <instanceid>" table with the chain's TclkRate,
        LeadingTCKs, TrailingTCKs and the aliases of its taps.
        """
        report = {}
        for port in self._parent_cli.devs.group_by(nodetype="^Debugport$"):
            port_entry = {}
            scan_chains = [node for node in port.children if node.nodetype == "JTAGScanChain"]
            for chain in scan_chains:
                config = chain.config
                chain_entry = {
                    "TclkRate": DecoratedResult(int(config["Jtag.TclkRate"]), 0, _human_readable_frequency),
                    "LeadingTCKs": int(config["Jtag.LeadingTCKs"]),
                    "TrailingTCKs": int(config["Jtag.TrailingTCKs"]),
                    "Taps": [tap.device.alias for tap in chain.children],
                }
                chain_label = "Chain {}".format(chain.device.instanceid)
                port_entry[chain_label] = DiagnosticResults(chain_entry)
            port_label = "Debug Port {}".format(port.device.instanceid)
            report[port_label] = DiagnosticResults(port_entry)
        return DiagnosticResults(report)

    def _update_aggregate_result(self, result, key, pass_value, value):
        """
        Folds one test outcome into the aggregate result table.

        value is stored under key when it is the first outcome or a failing
        one, so a failure is never overwritten by a later pass. The
        "<key> Pass Count" entry is created on first use and incremented for
        every passing outcome. Returns True when value did not match
        pass_value.
        """
        is_decorated = isinstance(value, DecoratedResult)
        outcome = value.value if is_decorated else value
        error_occurred = outcome != pass_value

        if error_occurred or key not in result:
            result[key] = value

        counter_key = key + " Pass Count"
        if counter_key not in result:
            # The counter is decorated with the same rank as the value
            result[counter_key] = DecoratedResult(0, value.rank) if is_decorated else 0

        if not error_occurred:
            # Indexing yields the plain count; get() yields the stored object,
            # which may be the DecoratedResult wrapper that has to be updated
            incremented = result[counter_key] + 1
            stored = result.get(counter_key)
            if isinstance(stored, DecoratedResult):
                stored.value = incremented
            else:
                result[counter_key] = incremented
        return error_occurred

    def _stress_inner_loop(self, results):
        """
        Performs a single iteration of the jtag stress tests.

        Runs the IR continuity, DR continuity and tap reset checks on every
        chain and the IDCODE stress on every tap, folding each outcome into
        results. Missing port, chain and tap tables are created on first use.
        Returns True when any check failed.
        """
        cli = self._parent_cli
        failed = False
        for port in cli.devs.group_by(nodetype="^Debugport$"):
            port_key = "Debug Port {}".format(port.device.instanceid)
            if port_key not in results:
                results[port_key] = DiagnosticResults({})
            port_results = results[port_key]

            scan_chains = [node for node in port.children if node.nodetype == "JTAGScanChain"]
            for chain in scan_chains:
                chain_key = "Chain {}".format(chain.device.instanceid)
                if chain_key not in port_results:
                    port_results[chain_key] = DiagnosticResults({
                        "TclkRate": DecoratedResult(int(chain.config["Jtag.TclkRate"]), 0, _human_readable_frequency)
                    })
                chain_results = port_results[chain_key]

                # (aggregate key, passing value, check to run, field of the check's result to aggregate)
                chain_checks = (
                    ("IR Continuity", "Found", self._chain_ir_continuity, "Continuity"),
                    ("DR Continuity", "Found", self._chain_bypass_continuity, "Continuity"),
                    ("Tapreset Idcodes", "Good", self._chain_tapreset_idcodes, "Status"),
                )
                for label, pass_value, check, field in chain_checks:
                    failed |= self._update_aggregate_result(chain_results, label, pass_value, check(chain).get(field))

                for tap in chain.children:
                    tap_key = "Tap {}".format(tap.device.alias)
                    if tap_key not in chain_results:
                        chain_results[tap_key] = DiagnosticResults({})
                    tap_results = chain_results[tap_key]
                    idcode_tdi = cli.BitData(tap.device.irlength, tap.device.idcodeir)
                    stress_outcome = self._tap_stress_dr_readonly(tap, idcode_tdi, 32, None, 10, 10).get("Result")
                    failed |= self._update_aggregate_result(tap_results, "IDCODE Stress", "Stable", stress_outcome)
        return failed

    def stress(self, seconds, stop_on_error):
        """
        Repeats the JTAG stress iteration for roughly the given number of
        seconds to establish the health of the available JTAG scan chains.

        A new iteration is only started while the elapsed time plus the
        average iteration time still fits within seconds. When stop_on_error
        is set, the loop also ends after the first iteration that reports an
        error. The returned DiagnosticResults additionally holds the
        "Iterations" count and the "Error Occurred" flag.
        """
        results = {}
        error_occurred = False
        loop_count = 0
        loop_time = 0  # average duration of one iteration so far
        start = time.time()
        elapsed_time = time.time() - start
        while True:
            if not elapsed_time < (seconds - loop_time):
                break
            if stop_on_error and error_occurred:
                break
            loop_count += 1
            error_occurred |= self._stress_inner_loop(results)
            elapsed_time = time.time() - start
            loop_time = elapsed_time / loop_count

        results["Iterations"] = DecoratedResult(loop_count, 1)
        if error_occurred:
            flag_color = PrintColors.FAIL
        else:
            flag_color = PrintColors.OKGREEN
        results["Error Occurred"] = DecoratedResult(error_occurred, 0, "{}{{}}{}".format(flag_color, PrintColors.ENDC))
        return DiagnosticResults(results)

    def run(self):
        """
        Runs every JTAG check once to establish the health of the available
        JTAG scan chains.

        Returns a DiagnosticResults keyed by "Debug Port <instanceid>" and
        then "Chain <instanceid>", holding the chain's TclkRate, the full IR
        continuity, DR continuity and tap reset tables, and a
        "Tap <alias>" table with the IDCODE stress result for every tap.
        """
        cli = self._parent_cli
        report = {}
        for port in cli.devs.group_by(nodetype="^Debugport$"):
            port_entry = {}
            scan_chains = [node for node in port.children if node.nodetype == "JTAGScanChain"]
            for chain in scan_chains:
                chain_entry = {
                    "TclkRate": DecoratedResult(int(chain.config["Jtag.TclkRate"]), 0, _human_readable_frequency)
                }
                chain_entry["IR Continuity"] = self._chain_ir_continuity(chain)
                chain_entry["DR Continuity"] = self._chain_bypass_continuity(chain)
                chain_entry["Tapreset Idcodes"] = self._chain_tapreset_idcodes(chain)
                for tap in chain.children:
                    idcode_tdi = cli.BitData(tap.device.irlength, tap.device.idcodeir)
                    tap_entry = {
                        "IDCODE Stress": self._tap_stress_dr_readonly(tap, idcode_tdi, 32, None, 10, 10)
                    }
                    chain_entry["Tap {}".format(tap.device.alias)] = DiagnosticResults(tap_entry)
                chain_label = "Chain {}".format(chain.device.instanceid)
                port_entry[chain_label] = DiagnosticResults(chain_entry)
            port_label = "Debug Port {}".format(port.device.instanceid)
            report[port_label] = DiagnosticResults(port_entry)

        return DiagnosticResults(report)


class DiagnosticsManager():
    """
    Provides diagnostic operations evaluating the health of the Transport and
    JTAG scan chains.
    """

    def __init__(self, cli):
        """
        Keeps a reference to the owning cli and creates the JTAG and
        Transport diagnostic helpers.

        Args:
          cli : the ipc enclosing object that exposes the ipc commands.
        """
        self._experiment_started = False
        # Logging preset stashed by mark_experiment() until collect() restores it
        self._previous_preset = None
        # allows access to ipc commands on the ipc enclosing object
        self._parent_cli = cli
        self._enable_color_console()
        self._jtag = JtagDiagnostics(cli)
        self._transport = TransportDiagnostics(cli)

    def _enable_color_console(self):
        if not os.name == 'nt':
            return
        import ctypes
        import ctypes.wintypes
        kernel32 = ctypes.windll.kernel32
        mode = ctypes.wintypes.DWORD()
        STD_OUTPUT_HANDLE = kernel32.GetStdHandle(-11)  # https://docs.microsoft.com/en-us/windows/console/getstdhandle        
        kernel32.GetConsoleMode(STD_OUTPUT_HANDLE, ctypes.byref(mode))
        ENABLE_VIRTUAL_TERMINAL_PROCESSING = 4  # https://docs.microsoft.com/en-us/windows/console/setconsolemode
        if not mode.value & ENABLE_VIRTUAL_TERMINAL_PROCESSING:
            mode.value |= ENABLE_VIRTUAL_TERMINAL_PROCESSING
            kernel32.SetConsoleMode(STD_OUTPUT_HANDLE, mode)

    def _filename_sanitizer(self, unsafestring):
        """Sanitizes a filename string"""
        result = "".join( x for x in unsafestring if (x.isalnum() or x in "._- "))
        return result

    def mark_experiment(self, tag='anonymous', logging_preset="All"):
        """Loads preset, flushes logs, and writes start message

        Args:
          tag (str)           : tag/title for the experiment that appears int he logs (defaults to 'anonymous')
          logging_preset (str) : the name of the logging preset to load (defaults to 'All')

        Returns:
          None.
        """
        ### Starts ipccli-level logging
        if self._previous_preset is None:
            ### start writing stdio logging to file
            self._parent_cli.log(os.path.join(os.path.expanduser('~'), '.OpenIPC','stdio.txt'), 'w')
            ### Stash previous logging preset
            self._previous_preset = self._parent_cli.logger.openipc_getloadedpreset()

        ### Sets the logging level
        self._parent_cli.logger.openipc_loadpreset(logging_preset)
        ### Flush pending log messages to disk
        self._parent_cli.logger.openipc_flush()
        ### Add a starting user comments to the log
        self._parent_cli.logger.openipc_writemessage("Beginning experiment: {}".format(tag))

    def collect(self, tag='payload', uid=None):
        """Resets log levels and compiles payload

        Compiles an archive with logging and diagnostic info and
        prints its path.  Can be used independently of mark_experiment.

        The payload is assembled in a temporary folder below ~/.OpenIPC,
        zipped next to it, and the temporary folder is then removed.

        Args:
          tag (str) : tag string to append to archive name (defaults to 'payload')
          uid (str) : upload ID (defaults to None); accepted but not used by
                this method

        Returns:
          None.  Prints archive location.

        Raises:
          OSError : if the temporary payload folder cannot be created for a
                reason other than it already existing.
        """
        now = datetime.datetime.now().strftime("%Y-%m-%d.%H-%M-%S")
        clean_tag = self._filename_sanitizer(tag)
        payload_name = 'diagnostic_{}_{}'.format(clean_tag, now)

        print("...Collecting Log Results")
        ### Turn off standard IO logging
        stdiolog.nolog()
        ### Add an ending user comment to the log
        self._parent_cli.logger.openipc_writemessage("experiment complete!")
        ### If mark_experiment stashed a logging preset then restore it, else leave untouched
        if self._previous_preset:
            print("...Logging Preset Restored to {}".format(self._previous_preset))
            self._parent_cli.logger.openipc_loadpreset(self._previous_preset)
            self._previous_preset = None

        ### Flush pending log messages to disk, and start new file
        self._parent_cli.logger.openipc_flush()
        # NOTE(review): presumably this pause lets the flush complete before archiving - confirm
        time.sleep(3)
        ### Creates a .zip file of the captured log files
        self._parent_cli.logger.openipc_archive()
        ### Set all logging back to defaults
        self._parent_cli.logger.reset()

        print("...Compiling Payload")
        root_path = os.path.join(os.path.expanduser('~'), '.OpenIPC', '')
        payload_dir = os.path.join(root_path, payload_name, '')

        # create temporary payload folder
        if not os.path.exists(payload_dir):
            try:
                os.makedirs(payload_dir)
            except OSError as exc:  # Guard against race condition
                if exc.errno != errno.EEXIST:
                    raise

        # Retrieve Buildinfo. An unset IPC_PATH is handled like a missing
        # file: logging has already been torn down at this point, so failing
        # here would lose the whole payload.
        ipc_path = os.environ.get('IPC_PATH')
        build_info = None
        if ipc_path is not None:
            build_info = os.path.join(os.path.dirname(ipc_path), "Buildinfo.txt")
        if build_info is not None and os.path.exists(build_info):
            shutil.copy(build_info, payload_dir)
        else:
            print("...Buildinfo.txt not found, continuing...")

        # collect standard io logs
        stdio_log = os.path.join(root_path, 'stdio.txt')
        if os.path.exists(stdio_log):
            shutil.move(stdio_log, payload_dir)
        else:
            print("...stdio.txt not found, continuing...")

        # collect ipython buffer history in txt file
        try:
            ipython = get_ipython()
            if ipython is not None:
                ipython.magic("history -f {}".format(os.path.join(payload_dir, 'buffer_history.txt')))
            else:
                print("...no ipython buffer, continuing...")
        except NameError:
            # get_ipython only exists when running inside IPython
            print("...no ipython instance, continuing...")

        # misc metrics (memory statistics need the optional psutil module)
        mem_report = None
        if psutil is not None:
            mem_stats = psutil.virtual_memory()
            mem_report = ReportResults({'Total Memory'     : mem_stats.total,
                                        'Available Memory' : mem_stats.available,
                                        'Percent Usage'    : str(mem_stats.percent)+'%'})

        info_lst = [str(self._parent_cli.perfreport(categories=["Host"])),
                    str(self._parent_cli.perfreport(categories=["Version"])),
                    'Debugports:\n'              + str(self._parent_cli.debugports),
                    '\nClients:\n'               + str(self._parent_cli.clients()),
                    '\nStalls:\n'                + str(self._parent_cli.stalls.stalls),
                    '\nPlatform Type:\n'         + str(self._parent_cli.debugports.device.platformtype),
                    '\nPlugin Type:\n'           + str(self._parent_cli.debugports.device.plugintype),
                    '\nProbe Type:\n'            + str(self._parent_cli.debugports.device.probetype),
                    '\nProbe Diagnostic Info:\n' + str(self._parent_cli.debugports.device.probediagnosticinfo),
                    '\nDevicelist:\n'            + str(self._parent_cli.devicelist)
                    ]
        if mem_report is not None:
            # The memory section goes right after the two perfreport sections
            info_lst.insert(2, 'Memory (bytes):\n'   + str(mem_report))
        info_str = '\n'.join(info_lst)

        # write to file
        with open(os.path.join(payload_dir, 'info.txt'), 'w') as f:
            f.write(info_str)

        # move archive(s) into payload folder
        # Note: if user calls logger archive independently, those will also get incidentally collected
        # Could grab latest archive only (if this behavior is undesired)
        for entry in os.listdir(root_path):
            if entry == payload_name:
                # A tag containing 'Archive' makes the payload folder itself
                # match below; moving a folder into itself raises an error.
                continue
            if 'Archive' in entry:
                shutil.move(os.path.join(root_path, entry), payload_dir)

        # create payload zip and rm old folder
        shutil.make_archive(os.path.join(root_path, payload_name), 'zip', payload_dir)
        shutil.rmtree(payload_dir)

        archive_path_result = os.path.join(root_path, '{}.{}'.format(payload_name, 'zip'))
        print("...Diagnostic Payload can be found at: {}".format(archive_path_result))


    def _find_procs_by_name(self, name):
        """ Return a list of processes matching 'name'.

        'name' is an fnmatch pattern. A process is selected when the pattern
        matches its process name, the first entry of its command line, or
        the base name of its executable path. Fields that could not be read
        keep their empty defaults; processes reported as no longer existing
        are skipped.
        """
        from fnmatch import fnmatch
        matches = []
        for proc in psutil.process_iter():
            # Empty defaults stay in place for whatever could not be queried
            proc_name, argv, exe_path = "", [], ""
            try:
                proc_name = proc.name()
                argv = proc.cmdline()
                exe_path = proc.exe()
            except (psutil.AccessDenied, psutil.ZombieProcess):
                # Must be tested before NoSuchProcess: match on what was read so far
                pass
            except psutil.NoSuchProcess:
                continue

            if fnmatch(proc_name, name):
                matches.append(proc)
            elif argv and fnmatch(argv[0], name):
                matches.append(proc)
            elif fnmatch(os.path.basename(exe_path), name):
                matches.append(proc)
        return matches

    def _host_info(self):
        """
        Collects information about the running host.

        The number of clients is always reported. When psutil is available
        the report also holds the resident memory of this process, the
        resident memory of the first process matching "OpenIPC_*" (if one is
        found), and the total, available and used-fraction figures of the
        system memory.

        Returns:
          DiagnosticResults
        """
        host = {"Clients": len(self._parent_cli.clients())}
        if psutil is None:
            # Memory figures need psutil; report the client count only
            return DiagnosticResults(host)

        own_process = psutil.Process(os.getpid())
        host['Client Memory'] = DecoratedResult(own_process.memory_info().rss, 6, _human_readable_size)

        server_procs = self._find_procs_by_name("OpenIPC_*")
        if server_procs:
            first_server = server_procs[0]
            host['OpenIPC Memory'] = DecoratedResult(first_server.memory_info().rss, 5, _human_readable_size)

        system_memory = psutil.virtual_memory()
        host['Total Memory'] = DecoratedResult(system_memory.total, 1, _human_readable_size)
        host['Available Memory'] = DecoratedResult(system_memory.available, 2, _human_readable_size)
        # percent is turned into a fraction for the "{:.0%}" format
        host['Memory Usage'] = DecoratedResult(system_memory.percent / 100, 3, "{:.0%}")
        return DiagnosticResults(host)

    def run(self, categories=["JTAG"]):
        """Runs the available diagnostic tests.

        Runs the diagnostic tests of the requested categories and returns a
        structure containing their results. "JTAG" is the only category
        acted upon here; without it the returned structure is empty.

        Args:
          categories (list) : a list of the categories to include in the
                report. (default is ["JTAG"]).

        Returns:
          DiagnosticResults
        """
        if "JTAG" not in categories:
            return DiagnosticResults({})
        return DiagnosticResults({"JTAG": self._jtag.run()})

    def info(self, categories=["Host", "JTAG", "Transport"]):
        """Gathers the informational reports

        Builds one report section per requested category: "Host" comes from
        this object, "JTAG" and "Transport" from their diagnostic helpers.
        Categories other than these three are ignored.

        Args:
          categories (list) : a list of the categories to include in the
                report. (default is ["Host", "JTAG", "Transport"]).

        Returns:
          DiagnosticResults
        """
        # Callables keep each section lazy: it is only gathered when requested
        sections = (
            ("Host", lambda: self._host_info()),
            ("JTAG", lambda: self._jtag.info()),
            ("Transport", lambda: self._transport.info()),
        )
        report = {}
        for category, gather in sections:
            if category in categories:
                report[category] = gather()
        return DiagnosticResults(report)

    def stress(self, categories=["JTAG"], seconds=60, stop_on_error=True):
        """Runs the available stress tests.

        Runs the stress tests of the requested categories for the specified
        amount of time and returns the aggregate of the results. "JTAG" is
        the only category acted upon here; without it the returned
        structure is empty.

        Args:
          categories (list) : a list of the categories to include in the
                report. (default is ["JTAG"]).
          seconds (int) : the number of seconds to run for. This is divided
                by the number of entries in categories. (default is 60)
          stop_on_error (bool) : passed on to the JTAG stress run; if true
                the stress execution will end once an error is detected even
                if the specified time hasn't elapsed yet. (default is True)

        Returns:
            DiagnosticResults
        """
        if "JTAG" not in categories:
            return DiagnosticResults({})
        seconds_per_category = seconds / len(categories)
        jtag_outcome = self._jtag.stress(seconds_per_category, stop_on_error)
        return DiagnosticResults({"JTAG": jtag_outcome})
