# INTEL CONFIDENTIAL
# Copyright 2019 Intel Corporation
#
# The source code  contained or  described herein and  all documents related to
# the source code  ("Material") are owned by Intel Corporation or its suppliers
# or licensors.  Title to the  Material  remains with  Intel Corporation or its
# suppliers  and licensors. The Material contains trade secrets and proprietary
# and  confidential  information  of  Intel  or  its  suppliers  and  licensors.
# The Material  is protected  by worldwide  copyright and trade secret laws and
# treaty provisions.  No part of the Material  may  be used, copied, reproduced,
# modified, published, uploaded, posted, transmitted, distributed, or disclosed
# in any way without Intel's prior express written permission. No license under
# any  patent,  copyright, trade secret or other intellectual property right is
# granted  to  or conferred upon you by disclosure or delivery of the Materials,
# either expressly, by implication, inducement, estoppel or otherwise.
# Any license under such intellectual property rights must be express and
# approved by Intel in writing.
from copy import deepcopy, copy
from ...utils.ordereddict import odict
from ...logging import getLogger
from ...comp import (
    ComponentGroup,
    GetComponentFromFile,
    NamedComponent,
    NamedComponentDefinition,
    )
from ...nodes import (
    NamedNodeValue,
    NamedNodeArrayValue,
    NamedNodeDefinition,
)
from ...discoveries import static
# to make sure all the accesses are registered..
from ...accesses.general import (
    AccessOffline,
    _StoredValuesFormat
)
from ...accesses import (
    AccessStoredValues,
    AccessRegisterStoredValues,
    AccessRegister,
    AccessField,
    StateNodeAccess,
    StateNodeOfflineAccess,
    StateNodeAccessNoLock,
    ScanFieldAccess,
    AccessSnapshot,
    AccessSnapshotFrozen,
    AccessSnapshotCapture,
    )
from ...accesses.general import _StoredValuesFormat
from ...precondition import MultipleAccesses
from ...plugins.nn_logregisters import _read_if_available
from ...plugins import nn_logregisters
from collections import OrderedDict
from ...nodes import NodeTypes

_LOG = getLogger("namednodes")


# this is an exception:
from ...access import AvailabilityOptions
from ...errors import SnapshotOverride
import six
from ...utils._py2to3 import *
import json
import functools

from ...discoveries.snapshot import SnapshotDiscovery

# Access-class remapping tables handed to static._switch_offline.
# Each key is an access class found on a live component; each value is the
# class that replaces it while the snapshot is in the corresponding mode.

# Capture mode (data is being recorded). Classes not listed here get a
# generated AccessSnapshotCapture subclass, built in
# Snapshot._switch_all_to_snapshot_capture.
_ACCESSES_SNAPSHOT_MAP = {
    # field accesses intentionally keep using AccessField
    AccessField: AccessField,
}

# Frozen mode (reads are served from captured data only).
# Entries that map a class to itself (AccessField, ScanFieldAccess) are
# deliberately left unchanged. State node accesses are not listed, so they
# presumably fall back to the default_access passed alongside this map
# (AccessSnapshotFrozen).
_ACCESSES_FROZEN_MAP = {
    AccessRegister: AccessSnapshotFrozen,
    AccessField: AccessField,
    ScanFieldAccess: ScanFieldAccess,
    AccessRegisterStoredValues: AccessSnapshotFrozen,
}


# in the future, I expect we will load components in to this thing maybe...
class _dummyDiscoveryManager(object):
    def __init__(self):
        self._discovered = odict()

    def _add_component(self, component):
        self._discovered[component.name] = component

    def __getattr__(self, compname):
        if compname in self._discovered:
            return self._discovered[compname]
        else:
            object.__getattribute__(self, compname)

    def get_all(self, *args, **kwargs):
        # this should not actually do anything, but needs
        # to exist to provide compatibility
        return

    def __dir__(self):
        attrs = sorted(self._discovered.keys())
        return attrs


class _snapshotSvOverride(object):
    """Context manager that temporarily replaces the global DiscoveryManager's
    discoveries with SnapshotDiscovery objects built from loaded components.

    On exit the previously installed discoveries are put back.
    """
    original_discoveries = None
    manager = None

    def __init__(self, snapsv):
        # snapsv: manager-like object whose ``_discovered`` mapping holds the
        # loaded components (Snapshot passes its loaded ``sv`` here)
        self._snapsv = snapsv
        self.original_discoveries = odict()

    def __enter__(self):
        from ...discovery import DiscoveryManager
        manager = DiscoveryManager.get_manager()
        self.manager = manager
        # next_discoveries only become available once the manager is initialized
        manager.initialize()

        # group the loaded components by the discovery that originally found them
        replacements = odict()
        for component in self._snapsv._discovered.values():
            discovery_name = component.target_info.get("_discovery_name", None)
            if discovery_name is None:
                raise RuntimeError("_discovery_name is missing, this occurs if someone uses find_all instead of get_all")
            snap_discovery = replacements.get(discovery_name)
            if snap_discovery is None:
                snap_discovery = SnapshotDiscovery(discovery_name)
                replacements[discovery_name] = snap_discovery
            snap_discovery._add_component(component)

        # prime each replacement with an initial get_all so everything is present
        for snap_discovery in replacements.values():
            snap_discovery.get_all()

        # remember the live discoveries, then install the snapshot ones
        self.original_discoveries = copy(manager._all_discoveries)
        manager._all_discoveries = replacements
        return None

    def __exit__(self, *args):
        # put back (a copy of) the discoveries saved in __enter__
        self.manager._all_discoveries = copy(self.original_discoveries)
        return None


class Snapshot(object):
    """Record which nodes are read, capture their values, and replay them.

    Usage implied by the methods below: entering the snapshot as a context
    manager while it is not frozen switches the sv objects into capture mode.
    ``freeze()`` then performs the deferred reads. Entering it again after
    freezing serves reads from the captured data only. A frozen snapshot can
    be written with ``save()``, and ``load()`` reads one back from disk.

    Only one snapshot can be active (entered) at a time; this is tracked on
    the class, not the instance.
    """
    _copies = None
    _active = False  # class-wide: whether a snapshot override is already in place
    # component name -> snapshot-specific target info (delayed reads, flags...)
    _snapshot_data = None
    # node path -> 1 for every node that failed to read during a freeze
    _failed_reads = None
    # whether freeze has been called
    _frozen = None
    # discovery manager (or the ``sv`` given to __init__) whose components
    # get switched in and out of snapshot mode
    _sv = None
    # set if we have loaded a set of static discoveries from the file
    sv = None
    # used to hold context override
    _sv_override = None
    # this key used in snapshot_data to track available flags per component
    # this dictionary will be key=value of comp.path=old_available flag
    _avail = "available"

    def __init__(self,  sv=None, **kwargs):
        """
        Args:
             sv : object whose ``discoveries`` are switched in/out of snapshot
                  mode (default: the global DiscoveryManager)
             all_reads_delayed : make all reads go through delayed read flow

        Raises:
            ValueError: if any other keyword argument is given.
        """
        self._all_reads_delayed = kwargs.pop("all_reads_delayed", False)
        if len(kwargs):
            raise ValueError("Unexpected kwargs: %s"%str(list(kwargs.keys())))
        if sv is None:
            from ... import discovery
            self._sv = discovery.DiscoveryManager.get_manager()
        else:
            self._sv = sv
        self._frozen = False
        # component name -> saved data dict (the per-component read cache)
        self._saved_data = odict()
        # component name -> snapshot-specific target info: delayed reads,
        # the all-reads-delayed flag and the saved "available" flags
        self._snapshot_data = odict()
        self._failed_reads = odict()

    @property
    def all_reads_delayed(self):
        """Whether reads are currently being delayed till freeze"""
        return self._all_reads_delayed

    @all_reads_delayed.setter
    def all_reads_delayed(self, value):
        """Set reads to be delayed till freeze"""
        self._all_reads_delayed = value

    def freeze(self, skip_errors=True, silent=False):
        """Perform all deferred reads and mark the snapshot as frozen.

        Args:
             skip_errors (bool) : whether to skip errors or stop on error (default=True, skip)
             silent (bool) : whether to send status to std out as we collect register values

        Raises:
            RuntimeError: if the snapshot was loaded from disk or is already frozen.
        """
        if self.sv is not None:
            raise RuntimeError("Cant use freeze on a shapshot that was loaded from disk")
        if self._frozen:
            raise RuntimeError("Snapshot already frozen")
        self._switch_all_to_snapshot_capture()
        try:
            self._do_delayed_reads(skip_errors, silent=silent)
        finally:
            self._restore_all()
        # only marked frozen AFTER all delayed reads ran and the sv objects
        # were restored (so _restore_all above did not touch available flags)
        self._frozen = True

    def jsonlogregisters(self, file, **kwargs):
        """Write the snapshot's captured register/field values as JSON.

        Supports only a subset of the normal logregisters parameters.

        Args:
            file: output path (string) or an already-open writable file object.
            binary_output (bool): forwarded to the nn_logregisters entry
                helpers (default False).

        Raises:
            RuntimeError: if the snapshot is neither frozen nor loaded.
            ValueError: if any other keyword argument is given.
        """
        if not self.is_frozen():
            raise RuntimeError("Can only run this on a frozen or loaded snapshot")

        binary_output = kwargs.pop('binary_output', False)
        if len(kwargs) > 0:
            raise ValueError("Unknown kwargs provided: %s"%str(list(kwargs.keys())))
        output_dict = OrderedDict()

        # only close the handle if we opened it ourselves; otherwise assume
        # "file" is a caller-owned file handle
        owns_file = isinstance(file, basestring)
        fout = open(file, 'w') if owns_file else file
        try:
            if self._frozen:
                sv = self._sv
                saved_data = self._saved_data
            else:
                sv = self.sv
                saved_data = {s.name: s.target_info.get(AccessOffline._cache_name, {}) for s in sv._discovered.values()}

            for disc_item_name, disc_item_data in saved_data.items():
                disc_item = getattr(sv, disc_item_name)

                for comp_path, node_dict in disc_item_data.items():
                    last_comp = disc_item.get_by_path(comp_path)
                    for node_path, node_info in node_dict.items():
                        node = last_comp.get_by_path(node_path)
                        # NOTE(review): this assigns the cached value through the
                        # node's current access. For a frozen (not loaded)
                        # snapshot used outside a ``with`` block, self._sv has
                        # been switched back online -- confirm this cannot
                        # write to the live system.
                        node.value = node_info['value']
                        node_entries = None
                        if node.type == NodeTypes.Register:
                            # registers with fields are reported via their fields
                            if not node.nodenames:
                                node_entries = nn_logregisters._register_entry(node, "", binary_output)
                        elif node.type == NodeTypes.Field:
                            node_entries = nn_logregisters._field_entry(node, "", binary_output)
                        if node_entries:
                            output_dict.update(node_entries)

            json_output = {'namednodes_simple': output_dict}
            json.dump(json_output, fout, indent=4)
        finally:
            if owns_file:
                fout.close()

    @classmethod
    def is_active(cls):
        """Reports whether snapshot is active"""
        return cls._active

    def is_frozen(self):
        """Reports whether snapshot accesses will access cached data (True) or potentially the system (False)"""
        return (self._frozen or self.sv is not None)

    def is_empty(self):
        """Report whether a frozen or loaded snapshot holds no captured data.

        Raises:
            RuntimeError: if the snapshot is not frozen/loaded.
        """
        if not self.is_frozen():
            raise RuntimeError("Can only check frozen snapshots")
        if self.sv is not None:
            caches = (d.target_info.get(AccessSnapshot._cache_name, {})
                      for d in self.sv._discovered.values())
        else:
            caches = self._saved_data.values()
        return not any(len(cache) for cache in caches)

    @property
    def freeze_read_failures(self):
        """return list of paths that failed during freeze"""
        return list(self._failed_reads.keys())

    def _switch_all_to_snapshot_capture(self):
        """Switch the sv objects so that reads are captured into the snapshot.

        Each access class without an entry in _ACCESSES_SNAPSHOT_MAP gets a
        generated AccessSnapshotCapture subclass that records reads.
        """
        access_map = copy(_ACCESSES_SNAPSHOT_MAP)

        def create_get_all(old_get_all):
            def new_get_all(*args, **kwargs):
                # cant survive args, and drop refresh since we are forcing that...
                kwargs.pop("refresh", None)
                return old_get_all(refresh=False, **kwargs)
            return new_get_all

        for discovery in self._sv.discoveries.values():
            if not hasattr(discovery, "_backup_get_all"):
                discovery._backup_get_all = discovery.get_all
            discovery.get_all = create_get_all(discovery._backup_get_all)

            for component in discovery.discovered:
                # extend the access map with capture classes for any access
                # class not already mapped; the 100% proper way here would be
                # derived classes that call their parent for the read and then
                # save the value
                for cname, cobj in component.access_classes.items():
                    if cobj not in access_map:
                        newc = type(cname+"_Snapshot", (AccessSnapshotCapture,),
                                    {
                                        'original_class': cobj,
                                        # prevent telemetry from logging Bundled Accesses
                                        '_logged_read': True,
                                        '_logged_write': True,
                                        '_logged_read_write': True,
                                        '_logged_store': True,
                                        '_logged_flush': True,
                                    }
                        )
                        access_map[cobj] = newc
                saved_data = self._saved_data.setdefault(component.name, {})
                snapshot_data = self._snapshot_data.setdefault(component.name, {})
                # set setting for all reads delayed or not...but only if we are not about to freeze
                if self._frozen:
                    snapshot_data[AccessSnapshot._all_reads_delayed] = False
                else:
                    snapshot_data[AccessSnapshot._all_reads_delayed] = self._all_reads_delayed
                static._switch_offline(component,
                                       accesses_map=access_map,
                                       default_access=AccessSnapshotCapture)
                # keep in mind we may be in offline mode and this may not do anything;
                # if a cache already exists we share it
                AccessStoredValues._init_cache(component, saved_data)
                component.target_info[AccessSnapshot._snapshot_key] = snapshot_data
                component.target_info["_in_snapshot"] = True

    def _switch_all_to_snapshot_frozen(self):
        """switch sv object so that we read only captured data and throw exceptions on missing data"""
        for discovery in self._sv.discoveries.values():
            for component in discovery.discovered:
                self._switch_one_to_snapshot_frozen(component)

    def _switch_one_to_snapshot_frozen(self, component, loading_from_file=False):
        """
        Switch a component to frozen accesses...update available flags as well

        When loading_from_file is True the component already carries its own
        cache, so only the accesses and the "_in_snapshot" marker are changed.
        """
        static._switch_offline(component,
                               accesses_map=_ACCESSES_FROZEN_MAP,
                               default_access=AccessSnapshotFrozen)
        if not loading_from_file:
            save_data = self._saved_data.setdefault(component.name, {})
            snapshot_data = self._snapshot_data.setdefault(component.name, {})
            snapshot_data.setdefault(self._avail, {})
            self._switch_available_flags_frozen(component)
            component.target_info[AccessSnapshot._cache_name] = save_data
            component.target_info[AccessSnapshot._snapshot_key] = snapshot_data

        component.target_info["_in_snapshot"] = True

    def _switch_available_flags_frozen(self, component):
        """Save the current available flags, then mark each sub-component
        Available if it has captured data, otherwise ChildrenOnly."""
        # we save old available flags in a subset of the snapshot data
        snapshot_data = self._snapshot_data.setdefault(component.name, {})
        avail = snapshot_data.setdefault(self._avail, {})
        # use the keys from saved_info for our check to know about available or not...
        save_data = self._saved_data.setdefault(component.name, {})
        # keys may be plain paths or tuples whose first item is the path
        # (currently not sure how to "get" format at this spot in the code)
        comp_paths = set(
            c[0] if isinstance(c, tuple) else c for c in save_data.keys()
        )
        for comp in component.walk_components():
            # LookupError marks "no flag was set" so restore can delete it
            avail[comp.path] = comp.target_info.get("available", LookupError)
            comp.target_info['available'] = AvailabilityOptions.Available \
                if comp.path in comp_paths else AvailabilityOptions.ChildrenOnly

    def _restore_available_flags(self, component):
        """Restore the available flags saved by _switch_available_flags_frozen."""
        snapshot_data = self._snapshot_data.setdefault(component.name, {})
        avail = snapshot_data.setdefault(self._avail, {})
        for comp in component.walk_components():
            avail_flag = avail.get(comp.path, None)
            # avail flag can be missing if there was an exception earlier during go offline flow...
            if avail_flag is None:
                continue
            if avail_flag is LookupError:
                # there was no flag originally
                del comp.target_info['available']
            else:
                comp.target_info['available'] = avail_flag

    def _restore_all(self):
        """make sure sv object is in its restored state"""
        for discovery in self._sv.discoveries.values():
            if hasattr(discovery, "_backup_get_all"):
                discovery.get_all = discovery._backup_get_all
            for component in discovery.discovered:
                # this one was never offline, nothing to do...
                if component.target_info.get("_in_snapshot", False) == False:
                    continue
                static._switch_online(component)
                # restore available flag
                if self._frozen:
                    self._restore_available_flags(component)
                # remove the saved data, it may or may not be our own dictionary depending
                # on whether we started in offline mode or not
                cache_data = component.target_info.pop(AccessSnapshot._cache_name)
                self._saved_data[component.name].update(cache_data)
                component.target_info.pop(AccessSnapshot._snapshot_key, None)
                component.target_info.pop("_in_snapshot")

    def _do_delayed_reads(self, skip_errors, silent=False):
        """Perform every read that was deferred while capturing.

        Args:
            skip_errors (bool): log failures and continue instead of raising.
            silent (bool): suppress per-component progress output on stdout.

        Paths of nodes that failed (or were not available) are recorded in
        ``self._failed_reads``.
        """
        import sys
        import traceback
        AccessSnapshotCapture._capture_time = True
        try:
            for discovery in self._sv.discoveries.values():
                for top_component in discovery.discovered:
                    with MultipleAccesses():
                        # really needs to be moved up and/or global....
                        delayed_reads = top_component.target_info[AccessSnapshot._snapshot_key].get(
                            AccessSnapshot._delayed_reads, [])
                        lastcomp = None
                        lastcomp_path = None
                        for comppath, nodepath in delayed_reads:
                            # reset per read: a failure before the node is
                            # resolved must not blame the previous read's node
                            # (or hit an unbound name on the very first read)
                            node = None
                            try:
                                if lastcomp_path != comppath:
                                    if not silent:
                                        sys.stdout.write("\nCapturing for: %s..." % comppath)
                                        sys.stdout.flush()
                                    if lastcomp is not None:
                                        # drop the cache of the component we are leaving
                                        lastcomp.definition.tolmdb.clear_cache(recursive=False, silent=True)
                                    lastcomp = top_component.get_by_path(comppath)
                                    lastcomp_path = comppath
                                    lastcomp.target_info['is_available_cache'] = {}
                                # now get node, still need by path due to arrays and such...
                                node = lastcomp.get_by_path(nodepath)
                                # now do "read" if we can
                                # make sure exceptions bubble back to this code
                                available = _read_if_available(node, raise_exceptions=True)
                                if not available:
                                    self._failed_reads[node.path] = 1
                            except Exception:
                                if not skip_errors:
                                    raise
                                _LOG.debug(traceback.format_exc())
                                _LOG.result("Error reading %s.%s for snapshot"%(
                                    comppath,
                                    nodepath))
                                if node is not None:
                                    self._failed_reads[node.path] = 1
        finally:
            # always leave capture mode, even if a read or setup step raised
            AccessSnapshotCapture._capture_time = False

    def __enter__(self):
        """Enter snapshot mode.

        - loaded snapshot: swap the global discoveries for the loaded ones
        - frozen snapshot: reads come only from captured data
        - otherwise: reads are captured (and possibly delayed until freeze)

        Raises:
            SnapshotOverride: if another snapshot is already active.
        """
        if self.__class__._active:
            raise SnapshotOverride("Only one snapshot override can be active at a time")
        self.__class__._active = True
        try:
            if self.sv is not None:
                self._sv_override = _snapshotSvOverride(self.sv)
                return self._sv_override.__enter__()
            if self._frozen:
                self._switch_all_to_snapshot_frozen()
            else:
                self._switch_all_to_snapshot_capture()
        except BaseException:
            # __exit__ is not called when __enter__ raises; without this the
            # class-wide flag would block every later snapshot
            self.__class__._active = False
            raise

    def __exit__(self, *args):
        """restores sv object to previous working state"""
        self.__class__._active = False
        if self._sv_override is not None:
            return self._sv_override.__exit__(*args)
        self._restore_all()

    def _get_deep_copies(self, sv):
        """Return deep copies of every discovered component in *sv*."""
        return [
            deepcopy(d)
            for discovered in sv.discoveries.values()
            for d in discovered.discovered
        ]

    def save(self, filepath):
        """Save the frozen snapshot's captured data to the specified path.

        Components without captured data are not written. The sv objects are
        always restored afterwards, even if writing fails.

        Raises:
            RuntimeError: if the snapshot is not frozen, or a snapshot
                override is currently active.
        """
        import sys
        if self._frozen is False:
            raise RuntimeError("Can only save off a frozen snapshot")
        if self._active:
            raise RuntimeError("Cant save while sv has been overridden")
        components = []
        try:
            for discovery in self._sv.discoveries.values():
                for component in discovery.discovered:
                    # skip components without data; they are never switched,
                    # so _restore_all skips them as well
                    if self._saved_data.get(component.name, {}) == {}:
                        continue
                    self._switch_one_to_snapshot_frozen(component)
                    components.append(component)
            static.write(components, filepath)
        except BaseException:
            # capture the original error first: a failure while restoring
            # must not replace it (six.reraise keeps this py2-compatible)
            original_error = sys.exc_info()
            try:
                self._restore_all()
            except Exception as restore_error:
                _LOG.debug("Error restoring sv objects after failed save: %s" % restore_error)
            six.reraise(*original_error)
        else:
            self._restore_all()

    def load(self, filepath):
        """Load a previously saved snapshot from the specified path.

        Raises:
            RuntimeError: if this snapshot already started collecting data,
                or a snapshot override is currently active.
        """
        # make sure the user doesn't load a snapshot in to one that is already
        # collecting data or currently overriding existing SV objects
        if len(self._saved_data) > 0:
            raise RuntimeError("Cannot load an offline snapshot in to one where we have already started collecting data")
        if self._active:
            raise RuntimeError("Cant load while sv has been overridden")
        loaded_sv = _dummyDiscoveryManager()
        for comp in static.load(filepath):
            self._switch_one_to_snapshot_frozen(comp, loading_from_file=True)
            loaded_sv._add_component(comp)
            self._load_nested_dict_format(comp)
        # publish only once every component loaded: a failed load must not
        # leave the snapshot looking "loaded" (is_frozen() checks self.sv)
        self.sv = loaded_sv

    def _load_nested_dict_format(self, comp):
        """Create any sub-components and nodes referenced by *comp*'s cached
        data that do not exist in its loaded definition."""
        cache = comp.target_info.get(AccessSnapshot._cache_name, {})
        for cpath, node_dict in cache.items():
            # with missing=True a partial match comes back as a
            # [found_or_None, "missing.dotted.path"] list instead of failing
            next_comp = comp.get_by_path(cpath, missing=True)
            if isinstance(next_comp, list):
                found_comp, missing = next_comp
                # if no sub component found, use the current one
                next_comp = comp if found_comp is None else found_comp
                for next_comp_name in missing.split("."):
                    next_def = NamedComponentDefinition(next_comp_name)
                    next_comp = next_comp.add_component(next_def.create(next_comp_name))

            for node_path, _node_value in node_dict.items():
                found = next_comp.get_by_path(node_path, missing=True)
                if not isinstance(found, list):
                    continue  # node already exists
                found_node, missing = found
                # test the node *value* before switching to its definition; the
                # original tested the .definition object, which is presumably
                # never a NamedNodeArrayValue, so this branch was unreachable
                if isinstance(found_node, NamedNodeArrayValue):
                    _LOG.error("Error loading snapshot, array %s missing element: %s, skipping it..."%
                               (found_node.path, str(missing)))
                    continue
                # if no nodes found, add to the component's definition
                parent_def = next_comp.definition if found_node is None else found_node.definition
                for next_node_name in missing.split("."):
                    parent_def = parent_def.add_node(NamedNodeDefinition(
                        next_node_name,
                        accesses={"constant": "AccessStoredValues"},
                    ))


### add to discovery a 'temporary_override({discovery_mapping})' capability
### create a "TemporarySnapshotDiscovery" -- create one for each type of thing
### use temporary_override to replace the current discoveries temporarily
###
### our "sv" object...inherit from discovery manager and just get rid of initialize?
