# -*- coding: utf-8 -*-
import argparse
import os

from qiling import Qiling
from qiling.const import QL_VERBOSE

# Offset, relative to the target binary's load base, of the instruction hooked by
# decrypt(). When execution reaches it, the pkey1-pkey3 values on the stack are
# overwritten; presumably this is right after they are derived from the
# pass-phrase (TODO confirm against the disassembly).
PASS_PHRASE_TO_KEYS_HOOK_OFFSET = 0x2621
# Root filesystem handed to Qiling. '/' means the host's own filesystem serves as
# the emulated root, so the binary's loader/libraries come from the host system.
ROOTFS = '/'


def init_ql(*args) -> Qiling:
    """Build a quiet Qiling instance for the given command line.

    Args:
        *args: the emulated argv; the first element is the binary to run,
            the remaining ones are passed to it as arguments.

    Returns:
        A Qiling instance rooted at ROOTFS. Any matching ISA-check branch in
        the loader is hooked so that it gets skipped.
    """
    emu = Qiling([*args], ROOTFS, verbose=QL_VERBOSE.OFF)

    # On Ubuntu 22.04 the loader's ISA check has to be bypassed under emulation,
    # see https://github.com/qilingframework/qiling/issues/1201
    #
    #   00023881  8b8a28030000   mov  ecx, dword [rdx+0x328]
    #   00023887  89cf           mov  edi, ecx
    #   00023889  4421c7         and  edi, r8d
    #   0002388c  39f9           cmp  ecx, edi
    #   0002388e  0f85000b0000   jne  0x24394            <- hook here ...
    #   00023894  488d50f8       lea  rdx, [rax-0x8]
    #   00023898  4839f0         cmp  rax, rsi
    #   0002389b  75c3           jne  0x23860
    #   0002389d  4c89ce         mov  rsi, r9            <- ... and resume here
    #   000238a0  4c89ff         mov  rdi, r15
    #   000238a3  e8d872ffff     call sub_1ab80
    prologue = bytes.fromhex('8b8a2803000089cf4421c739f9')
    jne_insn = bytes.fromhex('0f85000b0000')
    tail = bytes.fromhex('488d50f84839f075c3')

    pattern = prologue + jne_insn + tail
    skip_len = len(jne_insn) + len(tail)

    def jump_past_check(hooked: Qiling) -> None:
        print('Bypassing ISA Check...')
        # move rip from the `jne` straight to the instruction following the skipped bytes
        hooked.arch.regs.rip += skip_len

    # only executable mappings of the dynamic loader can contain the check
    loader_code = [
        (lo, hi)
        for lo, hi, perms, name, _container in emu.mem.get_mapinfo()
        if name == 'ld-linux-x86-64.so.2' and 'x' in perms
    ]

    for lo, hi in loader_code:
        for hit in emu.mem.search(pattern, begin=lo, end=hi):
            emu.hook_address(jump_past_check, hit + len(prologue))

    return emu


def decrypt(binary_path, encrypted_file_path, decrypted_file_path, pkey1: int, pkey2: int, pkey3: int):
    """
        Use the binary (emulated with Qiling) to decrypt an input file.

        N.B. There's no need to calculate a suitable pass-phrase that would give the correct pkey1 - pkey3, we can
             hook at the right place and overwrite these values on the stack of the running process.
             Check out code block 0x24a6 - 0x24ef.

        offsets (relative to rbp of the hooked frame):
            pkey1= word ptr -22h
            pkey2= word ptr -20h
            pkey3= word ptr -1Eh

        Args:
            binary_path: path to the binary to emulate.
            encrypted_file_path: input file, passed to the binary's `dec` command.
            decrypted_file_path: output file, passed to the binary's `dec` command.
            pkey1, pkey2, pkey3: 16-bit unsigned key values written over the binary's own keys.

        Raises:
            ValueError: if any key does not fit into 16 unsigned bits.
            RuntimeError: if the load base of the emulated binary cannot be determined.
    """
    # Validate up front: otherwise a bad key only fails inside the emulation hook
    # (int.to_bytes raises OverflowError there), after the emulator has been set up.
    for key_name, key_value in (('pkey1', pkey1), ('pkey2', pkey2), ('pkey3', pkey3)):
        if not 0 <= key_value <= 0xFFFF:
            raise ValueError(f'{key_name} must be a 16-bit unsigned value, got {key_value:#x}')

    # the pass-phrase is irrelevant: the keys derived from it are overwritten by the hook below
    ql = init_ql(binary_path, 'dec', encrypted_file_path, decrypted_file_path, 'ololo')

    # (distance below rbp, value) for each of the three 16-bit keys
    key_slots = ((0x22, pkey1), (0x20, pkey2), (0x1e, pkey3))

    # handler run when execution reaches PASS_PHRASE_TO_KEYS_HOOK_OFFSET
    def overwrite_pkeys(_ql: Qiling) -> None:
        # use the Qiling instance handed to the hook rather than the enclosing one
        rbp = _ql.arch.regs.rbp
        for offset, value in key_slots:
            _ql.mem.write(rbp - offset, value.to_bytes(2, byteorder='little', signed=False))

    binary_name = os.path.basename(ql.path)
    base_addr = ql.mem.get_lib_base(binary_name)
    if base_addr is None:
        # get_lib_base returns None when no mapping matches the name
        raise RuntimeError(f'could not find the load base of {binary_name!r}')
    ql.hook_address(overwrite_pkeys, base_addr + PASS_PHRASE_TO_KEYS_HOOK_OFFSET)

    # execute the binary to decrypt the encrypted file
    ql.run()


def parse_key(s: str) -> int:
    """Parse a key given on the command line.

    Accepted forms are decimal (``'4660'``), C-style hex (``'0x1234'``, any case)
    and assembler-style hex with an ``h`` suffix (``'1234h'``, any case).

    Raises:
        ValueError: if *s* is not a valid number in any of these forms.
    """
    s = s.strip()
    if s.lower().endswith('h'):
        # assembler notation: strip the suffix, the rest is hex
        return int(s[:-1], 16)
    if s.lower().startswith('0x'):
        return int(s, 16)
    return int(s)


if __name__ == '__main__':
    # parse cmd args; keys are converted to int by argparse via parse_key
    parser = argparse.ArgumentParser()
    parser.add_argument('binary_path', type=str, help='Path to tridea binary')
    parser.add_argument('encrypted_file_path', type=str, help='Path to encrypted file')
    parser.add_argument('pkey1', type=parse_key, help='The first 16-bit key pkey1')
    parser.add_argument('pkey2', type=parse_key, help='The second 16-bit key pkey2')
    parser.add_argument('pkey3', type=parse_key, help='The third 16-bit key pkey3')
    parser.add_argument('-o', '--output', type=str, default='flag.decrypted.qiling.png', help='Path to decrypted file')

    args = parser.parse_args()

    # decrypt
    decrypt(
        os.path.abspath(args.binary_path),
        os.path.abspath(args.encrypted_file_path),
        os.path.abspath(args.output),
        args.pkey1, args.pkey2, args.pkey3
    )
