"""
This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License,
or (at your option) any later version.


This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU General Public License for more details.


You should have received a copy of the GNU General Public License
 along with this program.  If not, see <https://www.gnu.org/licenses/>.

Copyright © 2019 Cloud Linux Software Inc.

This software is also available under ImunifyAV commercial license,
see <https://www.imunify360.com/legal/eula>
"""
import os
from collections import defaultdict
from logging import getLogger

from defence360agent.contracts.config import Malware as Config
from defence360agent.contracts.messages import MessageType
from defence360agent.contracts.plugins import (
    MessageSink,
    MessageSource,
    expect,
)
from imav.malwarelib.config import MalwareScanType
from imav.malwarelib.model import MalwareIgnorePath
from imav.malwarelib.plugins.detached_scan import DetachedScanPlugin
from imav.malwarelib.scan.scanner import MalwareScanner
from defence360agent.utils import recurring_check

logger = getLogger(__name__)


class Scanner(MessageSink, MessageSource):
    """Batch incoming scan requests and scan them on a recurring schedule.

    Paths arriving in ``MalwareScanTask`` / ``MalwareRescanFiles`` messages
    are accumulated per scan type. A recurring task (decorated with
    ``recurring_check(Config.INOTIFY_SCAN_PERIOD)``) drains the accumulated
    paths, scans the ones that still exist and are not ignored with
    ``MalwareScanner``, and sends the aggregated result to the sink as a
    ``MalwareScan`` message.
    """

    _loop, _sink = None, None
    # scan_type -> set of paths queued for the next recurring scan.
    _targets = None
    # Task running _recurring_scan(); created in create_source().
    _scan_task = None

    async def create_source(self, loop, sink):
        """Remember *loop* and *sink* and start the recurring scan task."""
        self._loop = loop
        self._sink = sink

        self._scan_task = self._loop.create_task(self._recurring_scan())

    async def create_sink(self, loop):
        """Initialise the per-scan-type target buckets."""
        self._targets = defaultdict(set)

    async def shutdown(self):
        """Cancel the recurring scan task and wait for it to finish.

        Does nothing if create_source() was never called.
        """
        if self._scan_task is None:
            # The source side was never started; there is nothing to stop.
            return
        self._scan_task.cancel()
        await self._scan_task

    def _process_scan_task(self, message):
        """Queue the paths in ``message["filelist"]`` for the next scan.

        Paths go into the bucket for ``message["scan_type"]`` (default
        ``MalwareScanType.REALTIME``). Non-``str`` paths are decoded with
        ``os.fsdecode`` and reported via ``logger.error``. Paths longer than
        ``Config.MAX_PATH_LEN`` are dropped (and counted in one warning).
        Once the bucket holds ``Config.MAX_TARGETS_PER_SCAN_TYPE`` paths,
        the remainder of the batch is dropped.
        """
        scan_type = message.get("scan_type", MalwareScanType.REALTIME)
        bucket = self._targets[scan_type]
        # Snapshot FromConfig descriptors once per batch: each read goes
        # through config_to_dict() which deepcopies the merged config, and
        # with cap=100_000 we would otherwise pay ~2 deepcopies per path.
        max_path_len = Config.MAX_PATH_LEN
        max_targets = Config.MAX_TARGETS_PER_SCAN_TYPE
        dropped_long = 0
        for path in message["filelist"]:
            if not isinstance(path, str):
                t = type(path)
                path = os.fsdecode(path)
                logger.error(
                    "Received path %s as %s instead of %s. Message: %s",
                    path,
                    t,
                    # The expected type itself; type(str) would always
                    # log <class 'type'>.
                    str,
                    message,
                )
            if len(path) > max_path_len:
                dropped_long += 1
                continue
            if len(bucket) >= max_targets:
                logger.warning(
                    "MAX_TARGETS_PER_SCAN_TYPE cap (%d) reached for "
                    "scan_type=%s; dropping remaining paths in this batch",
                    max_targets,
                    scan_type,
                )
                break
            bucket.add(path)
        if dropped_long:
            logger.warning(
                "Dropped %d path(s) exceeding MAX_PATH_LEN=%d for "
                "scan_type=%s",
                dropped_long,
                max_path_len,
                scan_type,
            )

    @expect(MessageType.MalwareScanTask)
    async def process_scan_task(self, message):
        """Queue the files of a ``MalwareScanTask`` message."""
        self._process_scan_task(message)

    @expect(MessageType.MalwareRescanFiles)
    async def rescan_files(self, message):
        """Queue ``message["files"]`` as a scan task.

        The scan type is ``message["type"]`` when present, else "rescan".
        """
        filelist = message["files"]
        msg = MessageType.MalwareScanTask(
            filelist=filelist, scan_type=message.get("type", "rescan")
        )
        self._process_scan_task(msg)

    @staticmethod
    async def _filter_out(targets):
        """Return the paths in *targets* that exist and are not ignored."""
        result = list()
        for filename in targets:
            if os.path.exists(
                filename
            ) and not await MalwareIgnorePath.is_path_ignored(filename):
                result.append(filename)
        return result

    async def _scan_targets(self, targets, scan_type):
        """Scan the surviving paths of *targets* with type *scan_type*.

        Returns without scanning when no path survives ``_filter_out``.
        Otherwise the aggregated scan result is sent to the sink as a
        ``MalwareScan`` message.
        """
        if targets:
            logger.info(
                "Checking files to scan with type={}".format(scan_type)
            )

        file_list = await self._filter_out(targets)

        if not file_list:
            return

        logger.debug("Scanning files: %s", file_list)
        scanner = MalwareScanner(sink=self._sink, hooks=True)
        scanner.start(file_list, scan_type=scan_type)
        result = await scanner.async_wait()
        # NOTE(review): this was previously guarded by
        # `if scanner is not None`, which is always true right after
        # construction, so the guard was dropped. If async_wait() can
        # return None, a `result is not None` check may have been
        # intended -- confirm.
        message = await DetachedScanPlugin.aggregate_result(result)
        await self._sink.process_message(MessageType.MalwareScan(**message))

    async def _scan(self):
        """Drain the queued targets and scan each scan-type bucket in turn."""
        # Swap in a fresh dict before awaiting anything. Paths queued by
        # _process_scan_task while the scans below are running then land in
        # the next cycle instead of mutating the buckets being iterated.
        targets, self._targets = self._targets, defaultdict(set)
        for scan_type, files in targets.items():
            await self._scan_targets(files, scan_type)

    @recurring_check(Config.INOTIFY_SCAN_PERIOD)
    async def _recurring_scan(self):
        """Run one scan cycle.

        Presumably re-invoked every INOTIFY_SCAN_PERIOD by recurring_check.
        """
        await self._scan()
