#!/usr/bin/env python2
# ===========
# pysap - Python library for crafting SAP's network protocols packets
#
# SECUREAUTH LABS. Copyright (C) 2021 SecureAuth Corporation. All rights reserved.
#
# The library was designed and developed by Martin Gallo from
# the SecureAuth's Innovation Labs team.
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# ==============

# Standard imports
import logging
from os import environ
from os.path import join
from argparse import ArgumentParser
# Custom imports
import pysap
from pysap.SAPPSE import SAPPSEFile
from pysap.SAPCredv2 import SAPCredv2, cipher_algorithms


# Usage text handed to ArgumentParser (``usage=``) in PySAPGenPSE.parse_options.
# It replaces argparse's auto-generated usage line and lists one example
# invocation for each action of the two supported commands.
pysapgenpse_usage = """

seclogin command:
----------------

- show SSO-credentials in a cred file

  pysapgenpse -c seclogin -l[v][f <cred_v2>]

- decrypt SSO-credentials in a cred file

  pysapgenpse -c seclogin -d[v][f <cred_v2>] [-u <username>] [-n 0] [-o <output>]


get_pse_certs command:
-------------------

- decrypt and show/export certificates in a PSE file

  pysapgenpse -c get_pse_certs -[v][f <pse_path>] [-x pin] [-n 0] [-o <output>]

"""


class PySAPGenPSE(object):
    """Command line tool implementing a basic subset of the ``sapgenpse`` tool.

    Two commands are supported (see :attr:`commands`): ``seclogin``, which lists and
    optionally decrypts the SSO credentials stored in a ``cred_v2`` file, and
    ``get_pse_certs``, which decrypts the content of a PSE file using a PIN.

    Each command is implemented by the method of the same name and is dispatched
    from :meth:`main`.
    """

    # Private attributes
    _logger = None

    # Instance attributes
    mode = None
    log_level = None
    # Verbosity flag; main() overwrites it with the -v/--verbose option. The class-level
    # default lets the logger property work when a command method is called directly.
    verbose = False

    # Names of the supported commands; each one maps to the method with the same name
    commands = ["seclogin", "get_pse_certs"]

    @staticmethod
    def parse_options():
        """Parses command-line options.

        :return: the parsed options and the list of arguments that were not recognized
        :rtype: tuple of (Namespace, list)
        """
        description = "Basic and experimental implementation of sapgenpse tool."

        parser = ArgumentParser(usage=pysapgenpse_usage, description=description, epilog=pysap.epilog)
        parser.add_argument("--version", action="version", version="%(prog)s {}".format(pysap.__version__))

        # Commands
        parser.add_argument("-c", dest="command", help="Command", choices=PySAPGenPSE.commands)
        parser.add_argument("-f", dest="filename", help="Input filename", metavar="FILE")
        parser.add_argument("-o", dest="output", help="Output filename", metavar="FILE")

        seclogin = parser.add_argument_group("seclogin command options")
        seclogin.add_argument("-l", "--list", dest="list", action="store_true", help="List the contents of a cred file")
        seclogin.add_argument("-d", "--decrypt", dest="decrypt", action="store_true",
                              help="Decrypt the contents of a cred file")
        seclogin.add_argument("--no-decrypt-provider", dest="decrypt_provider", action="store_false",
                              help="Do not attempt to further decrypt using the provider")
        seclogin.add_argument("-u", dest="username", help="Username to decrypt the cred file [$USER or $USERNAME env]")

        get_pse_certs = parser.add_argument_group("get_pse_certs command options")
        get_pse_certs.add_argument("-x", dest="pin", help="Pin to use to decrypt the PSE file")

        misc = parser.add_argument_group("Misc options")
        misc.add_argument("-v", "--verbose", dest="verbose", action="store_true", help="Verbose output")
        misc.add_argument("-n", dest="number", type=int, help="Number of credential/certificate to decrypt")

        # Unknown arguments are tolerated and returned to the caller instead of aborting
        (options, args) = parser.parse_known_args()

        return options, args

    @property
    def logger(self):
        """Returns the logger of the cli tool, creating and configuring it on first use.

        The level is DEBUG when :attr:`verbose` is set and INFO otherwise.

        :return: logger used by the tool
        :rtype: logging.Logger
        """
        if self._logger is None:
            logger = logging.getLogger("pysap.pysapgenpse")
            logger.setLevel(logging.DEBUG if self.verbose else logging.INFO)
            # logging.getLogger() hands back the same object for a given name, so only
            # attach a handler if none is present; otherwise every new instance of this
            # class would duplicate each emitted line.
            if not logger.handlers:
                logger.addHandler(logging.StreamHandler())
            self._logger = logger
        return self._logger

    def main(self):
        """Main routine for parsing options and dispatch desired command.
        """
        options, args = self.parse_options()

        # Set the verbosity (must happen before the logger is first used)
        self.verbose = options.verbose

        self.logger.info("pysapgenpse version: %s", pysap.__version__)

        # Dispatch the command; options.command is None when -c was not given
        if options.command in self.commands:
            getattr(self, options.command)(options, args)
        else:
            self.logger.error("pysapgenpse: Invalid command '%s'", options.command)
            self.logger.error("pysapgenpse: Valid commands are: %s", ",".join(self.commands))

    def open_file(self, filename, cls, type):
        """Opens a file to work with it and tries to parse it.

        The file is read in binary mode and its whole content is handed to ``cls``.
        Errors raised by ``cls`` while parsing are not handled here and propagate to
        the caller.

        :param filename: name of the file to open
        :type filename: string

        :param cls: class to use for parsing the file
        :type cls: type

        :param type: name of the file type (only used in log messages)
        :type type: string

        :return: parsed packet, or None if the file could not be read
        :rtype: ASN1_Packet
        """
        # NOTE: the "type" parameter shadows the builtin; the name is kept so callers
        # passing it as a keyword argument keep working.
        try:
            # Binary mode: text mode would translate newlines on some platforms and
            # corrupt the content handed to the parser.
            with open(filename, "rb") as f:
                data = f.read()
        except IOError as e:
            # Format the exception itself; "e.message" is deprecated and usually empty
            # for I/O errors.
            self.logger.error("pysapgenpse: Error reading {} file '{}' ({})".format(type, filename, e))
            return None

        obj = cls(data)
        self.logger.info("pysapgenpse: Reading {} file '{}'".format(type, filename))
        return obj

    def get_pse_certs(self, options, args):
        """Implements operations on PSE related files.

        Decrypts the PSE file given in ``options.filename`` using ``options.pin`` and,
        if ``options.output`` is set, writes the decrypted content to that file. Any
        failure is logged and makes the method return without raising.

        :param options: parsed command-line options
        :type options: Namespace

        :param args: remaining command-line arguments (unused)
        :type args: list
        """

        # If no filename was specified
        if not options.filename:
            self.logger.error("pysapgenpse: No PSE file specified")
            return

        # If no pin was specified
        if not options.pin:
            self.logger.error("pysapgenpse: No PIN provided")
            return

        # Parse the PSE file
        try:
            pse = self.open_file(options.filename, SAPPSEFile, "PSE")
        except Exception:
            self.logger.error("pysapgenpse: Unable to read certificates in PSE file {}\n".format(options.filename))
            return

        # Nothing to decrypt: bail out instead of calling decrypt() on a missing object
        if pse is None or not pse.enc_cont:
            self.logger.error("pysapgenpse: No encrypted content found in file {}".format(options.filename))
            return

        # Report decryption problems the same way the seclogin command does
        try:
            plain = pse.decrypt(options.pin)
        except Exception as e:
            self.logger.error("pysapgenpse: Error trying to decrypt PSE file {}".format(options.filename))
            self.logger.error(e)
            return

        self.logger.info("Decrypted PSE, {} bytes".format(len(plain)))

        self.write_output(options.output, plain)

    def seclogin(self, options, args):
        """Implements operations on SSO-credentials related files.

        Lists every credential found in the credential file and, when
        ``options.decrypt`` is set and a username is available, decrypts either all of
        them or only the one whose index equals ``options.number``.

        ``options.filename`` defaults to ``$SECUDIR/cred_v2`` and ``options.username``
        to ``$USER`` or ``$USERNAME`` when not provided.

        :param options: parsed command-line options
        :type options: Namespace

        :param args: remaining command-line arguments (unused)
        :type args: list
        """

        # If no filename was specified and SECUDIR is set, use as default
        if not options.filename:
            if "SECUDIR" in environ:
                options.filename = join(environ["SECUDIR"], "cred_v2")
            else:
                self.logger.error("pysapgenpse: No credential file specified")
                return

        # If no username was specified and USER or USERNAME env variable is set, use as default
        if not options.username:
            if "USER" in environ:
                options.username = environ["USER"]
            elif "USERNAME" in environ:
                options.username = environ["USERNAME"]

        # Parse the credential file
        try:
            cred_v2 = self.open_file(options.filename, SAPCredv2, "credentials")
        except Exception:
            self.logger.error("pysapgenpse: Unable to read credentials in file {}\n".format(options.filename))
            return

        # Validate that there are credentials there
        if not (cred_v2 and cred_v2.creds):
            self.logger.error("pysapgenpse: No credentials found in file {}\n".format(options.filename))
            return

        # Decryption needs a username; say so instead of silently skipping it
        if options.decrypt and not options.username:
            self.logger.error("pysapgenpse: No username specified, credentials will not be decrypted")

        # List each credential and decrypt it if required
        total = 0
        for index, entry in enumerate(cred_v2.creds):
            total = index + 1
            cred = entry.cred
            cn = cred.common_name
            pse_path = cred.pse_file_path
            lps = cred.lps_type_str

            self.logger.info("\n {} (LPS:{}): {}".format(index, lps, cn))
            # We're not yet parsing the LPS status of the PSE so it's reported as N/A
            self.logger.info("\t (LPS:N/A): {}".format(pse_path))

            self.logger.debug("\t\t Credential cipher format version {}, algorithm {}".format(
                cred.cipher_format_version, cipher_algorithms[cred.cipher_algorithm]))

            # Try to decrypt if specified
            if options.decrypt and options.username and (options.number is None or options.number == index):
                try:
                    self.decrypt_cred(cred, options.username, options.output, options.decrypt_provider)
                except Exception as e:
                    self.logger.error("pysapgenpse: Error trying to decrypt with username '{}'".format(options.username))
                    self.logger.error(e)

        self.logger.info("\n\n {} readable SSO-Credentials available\n".format(total))

    def decrypt_cred(self, cred, username, output_filename=None, decrypt_provider=True):
        """Decrypts a given credential and writes the output to a file.

        Errors raised while decrypting the credential itself propagate to the caller.
        A failure of the optional provider decryption step is logged and the PIN blob
        obtained from the first decryption is used instead.

        :param cred: credential to decrypt
        :type cred: SAPCredv2_Cred or SAPCredv2_Cred_LPS

        :param username: username to use when decrypting
        :type username: string

        :param output_filename: name of the file to write the output
        :type output_filename: string

        :param decrypt_provider: if further decryption should be done
        :type decrypt_provider: bool
        """
        plain = cred.decrypt(username)

        # Default to the PIN as obtained from the credential
        pin = plain.pin.val

        # If option1 specifies a credential provider, try to further decrypt the PIN using it
        if plain.option1 and plain.option1.val in plain.providers and decrypt_provider:
            try:
                pin = plain.decrypt_provider(cred)
            except Exception as e:
                # Format the exception itself; "e.message" is deprecated and often empty
                self.logger.error("pysapgenpse: Unable to decrypt using the provider {} ({}), writing plain blob".format(
                    plain.option1.val, e))

        self.logger.info("\t\t PIN:\t\t{}".format(pin))
        if plain.option1:
            self.logger.debug("\t\t Options:\t{}".format(plain.option1.val))
        if plain.option2:
            self.logger.debug("\t\t\t\t{}".format(plain.option2.val))
        if plain.option3:
            self.logger.debug("\t\t\t\t{}".format(plain.option3.val))
        self.write_output(output_filename, pin)

    def write_output(self, output_filename, output):
        """Writes some output to a file.

        Nothing is written when no output filename is given. An existing file with
        the same name is overwritten.

        :param output_filename: name of the file to write the output
        :type output_filename: string

        :param output: content to write to the file
        :type output: string
        """
        if output_filename:
            with open(output_filename, "wb") as output_file:
                output_file.write(output)
            self.logger.info("pysapgenpse: Output written to file '{}'".format(output_filename))


# Script entry point: only runs when the file is executed directly, not when imported.
if __name__ == "__main__":
    pysapgenpse = PySAPGenPSE()
    pysapgenpse.main()
