#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <system_error>
#include <vector>

#include <omp.h>

// the IDEA core functions exported by the transformed `libtridea.so` (linked into this binary with `-L. -ltridea`)
extern "C" {
// expands a 128-bit key (eight 16-bit words) into the 52 subkeys that this file passes to `encrypt_block`;
// the returned pointer is ignored by every caller here (presumably it is `subkeys` -- TODO confirm)
uint16_t *key_sched(const uint16_t key[8], uint16_t subkeys[52]);
// same as above, but produces the 52 subkeys that this file passes to `decrypt_block`;
// the returned pointer is likewise ignored by every caller here
uint16_t *key_sched_rev(const uint16_t key[8], uint16_t subkey_inv[52]);
// encrypts the 64-bit block `M` with the subkeys `K` (as produced by `key_sched`) and returns the result
uint64_t encrypt_block(uint64_t M, const uint16_t K[52]);
// decrypts the 64-bit block `M` with the subkeys `K` (as produced by `key_sched_rev`) and returns the result
uint64_t decrypt_block(uint64_t M, const uint16_t K[52]);
}

// the PBOXes extracted from tridea binary (e.g. using auxiliary script `export_pboxes.py`)
// region PBOXes

// byte lookup table; `gen_keys_1` indexes it with `(pkey1 + i) & 0xFF` to obtain the high byte of each key-1 word
const uint8_t PBOX1[] = {54, 19, 144, 250, 76, 184, 174, 6, 12, 42, 29, 77, 157, 13, 231, 158, 208, 182, 235, 209, 160,
                         150, 104, 229, 171, 217, 226, 82, 230, 60, 204, 137, 105, 130, 25, 248, 75, 98, 245, 135, 220,
                         67, 41, 210, 152, 140, 254, 39, 173, 50, 3, 211, 233, 81, 16, 189, 255, 97, 179, 74, 214, 43,
                         154, 103, 121, 44, 55, 52, 232, 118, 142, 110, 92, 195, 201, 63, 149, 164, 218, 7, 57, 24, 215,
                         172, 127, 237, 45, 58, 167, 138, 79, 169, 146, 48, 31, 168, 102, 193, 247, 219, 145, 205, 111,
                         236, 0, 83, 166, 88, 51, 106, 162, 216, 5, 49, 64, 170, 80, 17, 115, 99, 227, 191, 23, 242,
                         113, 190, 109, 126, 11, 132, 161, 119, 35, 197, 117, 253, 18, 86, 207, 222, 148, 21, 200, 124,
                         46, 177, 33, 28, 56, 240, 96, 61, 198, 71, 156, 122, 185, 155, 20, 241, 22, 223, 2, 206, 9, 40,
                         89, 180, 53, 128, 212, 93, 165, 252, 36, 101, 147, 131, 66, 178, 34, 136, 202, 85, 87, 112,
                         203, 196, 27, 199, 125, 187, 224, 4, 59, 14, 183, 238, 194, 62, 153, 10, 70, 186, 108, 73, 95,
                         176, 188, 159, 91, 123, 192, 15, 163, 68, 175, 221, 69, 30, 239, 225, 141, 243, 143, 114, 134,
                         84, 90, 251, 213, 234, 133, 246, 107, 65, 72, 32, 228, 151, 129, 249, 94, 26, 244, 78, 139, 37,
                         100, 47, 38, 120, 116, 8, 181, 1};

// byte lookup table; `gen_keys_1` indexes it with `((pkey1 >> 8) + i) & 0xFF` to obtain the low byte of each key-1 word
const uint8_t PBOX2[] = {102, 117, 40, 74, 34, 255, 101, 181, 4, 16, 253, 91, 108, 110, 39, 233, 190, 2, 243, 73, 77,
                         202, 54, 92, 185, 210, 55, 6, 247, 132, 12, 172, 20, 24, 57, 64, 104, 22, 51, 93, 231, 227, 46,
                         154, 183, 223, 187, 209, 221, 177, 111, 89, 75, 135, 174, 116, 63, 99, 230, 148, 150, 165, 65,
                         173, 9, 125, 182, 218, 158, 61, 113, 199, 167, 139, 28, 207, 213, 95, 144, 43, 42, 13, 137,
                         127, 29, 220, 249, 80, 237, 204, 145, 206, 112, 168, 103, 62, 114, 246, 143, 214, 146, 67, 17,
                         229, 97, 161, 130, 128, 140, 134, 49, 178, 25, 189, 105, 19, 200, 160, 60, 236, 41, 159, 119,
                         78, 251, 245, 147, 129, 107, 14, 131, 240, 83, 48, 70, 33, 163, 122, 171, 242, 197, 225, 81,
                         195, 53, 109, 84, 191, 216, 176, 126, 66, 86, 47, 44, 244, 138, 198, 162, 10, 248, 241, 224,
                         252, 235, 175, 186, 120, 232, 76, 228, 164, 215, 166, 184, 205, 58, 226, 79, 136, 152, 157,
                         121, 1, 133, 85, 180, 3, 211, 37, 196, 155, 169, 21, 90, 219, 5, 201, 208, 250, 0, 254, 35, 98,
                         94, 212, 52, 18, 156, 59, 222, 26, 87, 72, 96, 32, 38, 118, 100, 141, 115, 142, 179, 149, 27,
                         36, 238, 203, 31, 192, 194, 50, 123, 234, 56, 69, 151, 82, 124, 217, 88, 193, 11, 239, 170, 7,
                         45, 106, 15, 68, 188, 153, 8, 71, 30, 23};

// byte lookup table; `gen_keys_2` indexes it with `(pkey2 + i) & 0xFF` to obtain the high byte of each key-2 word
const uint8_t PBOX3[] = {201, 5, 200, 254, 121, 97, 14, 192, 3, 199, 37, 83, 211, 70, 216, 233, 85, 135, 40, 228, 234,
                         235, 47, 202, 154, 153, 41, 249, 0, 185, 244, 151, 190, 134, 31, 116, 114, 73, 246, 27, 34, 62,
                         87, 90, 69, 6, 8, 101, 168, 204, 81, 60, 29, 12, 241, 146, 33, 148, 251, 196, 237, 232, 138,
                         242, 2, 173, 25, 176, 39, 96, 132, 184, 155, 124, 89, 183, 236, 99, 130, 229, 223, 103, 19,
                         140, 77, 54, 170, 98, 36, 35, 76, 239, 169, 79, 152, 46, 145, 4, 49, 205, 38, 255, 52, 195,
                         215, 107, 44, 109, 206, 67, 82, 253, 187, 128, 110, 22, 72, 182, 24, 164, 118, 157, 10, 224,
                         53, 227, 231, 117, 189, 123, 208, 131, 218, 171, 136, 122, 144, 219, 127, 210, 191, 165, 142,
                         247, 209, 159, 64, 71, 238, 102, 193, 149, 150, 250, 180, 129, 243, 139, 108, 16, 188, 93, 92,
                         95, 207, 59, 94, 214, 30, 217, 248, 179, 63, 115, 66, 137, 197, 113, 181, 18, 175, 126, 65, 28,
                         75, 158, 220, 43, 91, 230, 7, 119, 125, 17, 88, 23, 111, 74, 172, 84, 133, 9, 186, 105, 68,
                         161, 174, 120, 147, 55, 252, 194, 26, 20, 45, 178, 221, 198, 56, 245, 225, 51, 163, 166, 42,
                         57, 58, 100, 160, 162, 226, 11, 21, 15, 240, 213, 143, 222, 32, 156, 78, 104, 167, 106, 50,
                         177, 86, 13, 61, 1, 112, 80, 48, 212, 203, 141};

// byte lookup table; `gen_keys_2` indexes it with `((pkey2 >> 8) + i) & 0xFF` to obtain the low byte of each key-2 word
const uint8_t PBOX4[] = {133, 198, 124, 116, 77, 8, 195, 203, 155, 112, 139, 65, 39, 130, 138, 161, 182, 243, 21, 44,
                         32, 23, 179, 241, 87, 127, 59, 49, 219, 30, 247, 88, 148, 202, 22, 220, 194, 92, 117, 213, 142,
                         94, 118, 162, 43, 13, 91, 114, 90, 180, 10, 73, 189, 242, 74, 250, 120, 160, 51, 70, 60, 6,
                         215, 101, 211, 208, 52, 35, 57, 140, 177, 190, 129, 53, 173, 181, 121, 245, 7, 36, 48, 45, 150,
                         96, 149, 93, 62, 122, 169, 209, 193, 16, 29, 218, 246, 164, 61, 188, 5, 224, 178, 47, 204, 12,
                         196, 230, 110, 34, 98, 199, 75, 1, 163, 103, 233, 76, 28, 46, 151, 97, 113, 187, 186, 136, 249,
                         158, 95, 212, 54, 104, 200, 145, 63, 85, 240, 69, 123, 248, 152, 3, 38, 115, 156, 210, 0, 197,
                         31, 154, 128, 11, 56, 223, 228, 84, 214, 255, 125, 64, 235, 67, 79, 126, 55, 185, 165, 20, 72,
                         15, 99, 157, 58, 83, 66, 108, 80, 42, 17, 9, 168, 111, 159, 50, 253, 234, 102, 167, 137, 184,
                         192, 153, 166, 132, 238, 141, 170, 244, 216, 27, 252, 183, 86, 26, 78, 106, 239, 146, 236, 172,
                         226, 37, 14, 175, 232, 134, 227, 131, 217, 24, 105, 251, 68, 191, 81, 222, 82, 171, 107, 119,
                         206, 174, 89, 237, 221, 33, 201, 205, 147, 2, 207, 71, 225, 19, 176, 135, 100, 144, 254, 4, 18,
                         109, 40, 143, 229, 25, 41, 231};

// byte lookup table; `gen_keys_3` indexes it with `(pkey3 + i) & 0xFF` to obtain the high byte of each key-3 word
const uint8_t PBOX5[] = {205, 26, 60, 75, 12, 56, 64, 71, 103, 184, 147, 62, 149, 218, 8, 145, 241, 125, 181, 9, 24,
                         250, 11, 85, 83, 98, 182, 164, 171, 208, 67, 121, 193, 134, 151, 3, 28, 156, 233, 234, 4, 126,
                         143, 55, 231, 123, 37, 248, 253, 170, 148, 109, 224, 112, 111, 228, 101, 226, 161, 57, 94, 188,
                         167, 95, 174, 254, 187, 113, 87, 135, 183, 46, 244, 166, 42, 200, 0, 47, 17, 189, 38, 124, 25,
                         48, 196, 239, 159, 185, 249, 150, 88, 195, 80, 117, 66, 110, 255, 43, 204, 115, 232, 19, 175,
                         13, 222, 30, 201, 168, 41, 96, 192, 35, 21, 158, 162, 52, 51, 176, 93, 136, 100, 53, 78, 202,
                         74, 1, 63, 169, 245, 177, 6, 227, 133, 216, 81, 84, 138, 207, 120, 173, 186, 206, 141, 116,
                         217, 23, 70, 22, 246, 178, 10, 122, 2, 214, 197, 114, 219, 221, 240, 247, 40, 39, 29, 34, 191,
                         142, 213, 172, 69, 130, 199, 152, 146, 44, 32, 209, 89, 119, 180, 127, 198, 237, 223, 220, 179,
                         58, 194, 139, 211, 210, 129, 118, 131, 59, 16, 77, 243, 153, 36, 144, 14, 238, 97, 229, 165,
                         190, 49, 90, 212, 65, 72, 86, 215, 163, 45, 79, 50, 92, 107, 132, 225, 31, 104, 203, 61, 54,
                         15, 251, 154, 106, 128, 102, 82, 137, 230, 236, 157, 68, 252, 99, 33, 105, 73, 140, 7, 5, 160,
                         108, 155, 235, 91, 20, 18, 76, 27, 242};

// byte lookup table; `gen_keys_3` indexes it with `((pkey3 >> 8) + i) & 0xFF` to obtain the low byte of each key-3 word
const uint8_t PBOX6[] = {84, 138, 203, 188, 199, 148, 250, 141, 60, 158, 116, 50, 162, 172, 220, 143, 14, 185, 104, 164,
                         214, 107, 192, 169, 58, 228, 252, 213, 78, 246, 144, 93, 161, 117, 5, 72, 111, 66, 119, 24,
                         244, 227, 105, 34, 120, 255, 6, 151, 179, 59, 99, 3, 239, 163, 142, 230, 167, 121, 232, 131,
                         76, 236, 154, 171, 137, 77, 62, 238, 146, 165, 71, 149, 49, 115, 187, 80, 16, 18, 128, 28, 152,
                         19, 157, 91, 44, 133, 9, 39, 106, 166, 87, 229, 75, 205, 122, 225, 29, 56, 70, 88, 155, 140,
                         182, 124, 150, 147, 136, 90, 249, 211, 25, 177, 20, 64, 242, 159, 190, 176, 109, 42, 23, 17,
                         235, 240, 183, 253, 129, 215, 132, 194, 248, 7, 193, 197, 86, 31, 8, 36, 204, 73, 100, 30, 98,
                         126, 97, 237, 10, 130, 231, 221, 38, 52, 180, 95, 212, 135, 174, 251, 47, 247, 69, 202, 153,
                         125, 13, 43, 46, 168, 108, 83, 102, 37, 139, 96, 127, 210, 26, 12, 4, 145, 11, 89, 33, 63, 101,
                         61, 181, 218, 170, 53, 216, 134, 206, 94, 209, 245, 191, 79, 219, 45, 103, 208, 27, 67, 40,
                         241, 81, 85, 110, 196, 217, 65, 2, 55, 243, 57, 223, 82, 184, 189, 173, 254, 200, 178, 118, 1,
                         35, 114, 15, 234, 74, 226, 156, 22, 198, 48, 32, 0, 123, 222, 160, 186, 224, 207, 51, 41, 175,
                         195, 54, 68, 201, 233, 112, 92, 21, 113};

// endregion PBOXes


// self-check of the imported IDEA functions against published IDEA test vectors (`TestVector` + `idea_identity_check`)
// one known-answer case: a key, a plaintext block and the ciphertext block it is expected to encrypt to
struct TestVector {
    uint16_t key[8];      // 128-bit key as eight 16-bit words
    uint64_t plaintext;   // 64-bit input block
    uint64_t ciphertext;  // 64-bit expected output block
};

void idea_identity_check() {
    // can be found here: https://crypto.stackexchange.com/a/92055
    TestVector test_vectors[] = {
            {{0x0001, 0x0002, 0x0003, 0x0004, 0x0005, 0x0006, 0x0007, 0x0008},
                    0x0000000100020003, 0x11FBED2B01986DE5},
            {{0x0001, 0x0002, 0x0003, 0x0004, 0x0005, 0x0006, 0x0007, 0x0008},
                    0x0102030405060708, 0x540E5FEA18C2F8B1},
            {{0x0001, 0x0002, 0x0003, 0x0004, 0x0005, 0x0006, 0x0007, 0x0008},
                    0x0019324B647D96AF, 0x9F0A0AB6E10CED78},
            {{0x0001, 0x0002, 0x0003, 0x0004, 0x0005, 0x0006, 0x0007, 0x0008},
                    0xF5202D5B9C671B08, 0xCF18FD7355E2C5C5},
            {{0x0001, 0x0002, 0x0003, 0x0004, 0x0005, 0x0006, 0x0007, 0x0008},
                    0xFAE6D2BEAA96826E, 0x85DF52005608193D},
            {{0x0001, 0x0002, 0x0003, 0x0004, 0x0005, 0x0006, 0x0007, 0x0008},
                    0x0A141E28323c4650, 0x2F7DE750212FB734},
            {{0x0001, 0x0002, 0x0003, 0x0004, 0x0005, 0x0006, 0x0007, 0x0008},
                    0x050A0F14191E2328, 0x7B7314925DE59C09},
            {{0x0005, 0x000A, 0x000F, 0x0014, 0x0019, 0x001E, 0x0023, 0x0028},
                    0x0102030405060708, 0x3EC04780BEFF6E20},
            {{0x3A98, 0x4E20, 0x0019, 0x5DB3, 0x2EE5, 0x01C8, 0xC47C, 0xEA60},
                    0x0102030405060708, 0x97BCD8200780DA86},
            {{0x0064, 0x00C8, 0x012C, 0x0190, 0x01F4, 0x0258, 0x02BC, 0x0320},
                    0x05320A6414C819FA, 0x65BE87E7A2538AED},
            {{0x9D40, 0x75C1, 0x03BC, 0x322A, 0xFB03, 0xE7BE, 0x6AB3, 0x0006},
                    0x0808080808080808, 0xF5DB1AC45E5EF9F9}
    };
    for (int i = 0; i < sizeof(test_vectors) / sizeof(TestVector); i++) {
        uint64_t M = test_vectors[i].plaintext;
        uint64_t C = test_vectors[i].ciphertext;

        uint16_t *key = test_vectors[i].key;
        uint16_t K[52], K_rev[52];
        key_sched(key, K);
        key_sched_rev(key, K_rev);

        uint64_t C1 = encrypt_block(M, K);
        if (C != C1) {
            printf("Encryption failed for test vector %d\n", i);
        }
        uint64_t M1 = decrypt_block(C, K_rev);
        if (M != M1) {
            printf("Decryption failed for test vector %d\n", i);
        }
    }
}


// these three functions turn a 16-bit value into a 128-bit encryption key (one for each of the enc-dec-enc stages)
// reverse-engineered from code block 0x24f4 - 0x261d of function 0x2409 (`pass_phrase_to_keys`)
// forward declaration: fills `key1` (eight 16-bit words) from `pkey1` using the PBOX1/PBOX2 tables
static inline void gen_keys_1(uint16_t pkey1, uint16_t key1[8]);

// forward declaration: fills `key2` (eight 16-bit words) from `pkey2` using the PBOX3/PBOX4 tables
static inline void gen_keys_2(uint16_t pkey2, uint16_t key2[8]);

// forward declaration: fills `key3` (eight 16-bit words) from `pkey3` using the PBOX5/PBOX6 tables
static inline void gen_keys_3(uint16_t pkey3, uint16_t key3[8]);

// Expands the 16-bit value `pkey1` into the eight 16-bit words of key 1.
// Word n takes its high byte from PBOX1 (walked from the low byte of `pkey1`)
// and its low byte from PBOX2 (walked from the high byte of `pkey1`); both walks wrap at 256.
void gen_keys_1(uint16_t pkey1, uint16_t key1[8]) {
    const unsigned low_seed = pkey1 & 0xFFu;
    const unsigned high_seed = pkey1 >> 8;

    for (unsigned word = 0; word < 8; ++word) {
        const unsigned high_byte = PBOX1[(low_seed + word) & 0xFFu];
        const unsigned low_byte = PBOX2[(high_seed + word) & 0xFFu];
        key1[word] = static_cast<uint16_t>((high_byte << 8) | low_byte);
    }
}

// Expands the 16-bit value `pkey2` into the eight 16-bit words of key 2.
// Word n takes its high byte from PBOX3 (walked from the low byte of `pkey2`)
// and its low byte from PBOX4 (walked from the high byte of `pkey2`); both walks wrap at 256.
void gen_keys_2(uint16_t pkey2, uint16_t key2[8]) {
    const unsigned low_seed = pkey2 & 0xFFu;
    const unsigned high_seed = pkey2 >> 8;

    for (unsigned word = 0; word < 8; ++word) {
        const unsigned high_byte = PBOX3[(low_seed + word) & 0xFFu];
        const unsigned low_byte = PBOX4[(high_seed + word) & 0xFFu];
        key2[word] = static_cast<uint16_t>((high_byte << 8) | low_byte);
    }
}

// Expands the 16-bit value `pkey3` into the eight 16-bit words of key 3.
// Word n takes its high byte from PBOX5 (walked from the low byte of `pkey3`)
// and its low byte from PBOX6 (walked from the high byte of `pkey3`); both walks wrap at 256.
void gen_keys_3(uint16_t pkey3, uint16_t key3[8]) {
    const unsigned low_seed = pkey3 & 0xFFu;
    const unsigned high_seed = pkey3 >> 8;

    for (unsigned word = 0; word < 8; ++word) {
        const unsigned high_byte = PBOX5[(low_seed + word) & 0xFFu];
        const unsigned low_byte = PBOX6[(high_seed + word) & 0xFFu];
        key3[word] = static_cast<uint16_t>((high_byte << 8) | low_byte);
    }
}

// using the original `decrypt` function requires building a pass-phrase that gives three 16-bit keys found by the MITM attack
// instead we can reverse-engineer it and change the key derivation part
//     namely, get rid of code block 0x27e7 - 0x2854 and use the three 16-bit keys directly
//     as we exported all the functions that are called in the original `decrypt`, this doesn't require a lot of work
// N.B. there are other ways to inject the three 16-bit keys, e.g. using `gdb` or DBI (see auxiliary script `decrypt.py`)
// Decrypts `data` in place using the three 16-bit keys and strips the padding.
//   data    - ciphertext buffer holding at least `*length` bytes; overwritten with the plaintext
//   length  - in: number of ciphertext bytes; out: number of plaintext bytes after unpadding
//   pkey1/pkey2/pkey3 - 16-bit values expanded to the three IDEA keys by `gen_keys_1/2/3`
// Only complete 8-byte blocks are processed, so nothing past `*length` is read or written.
void decrypt(uint8_t *data, long *length, uint16_t pkey1, uint16_t pkey2, uint16_t pkey3) {
    // 1. generate the round subkeys
    uint16_t key1[8], key2[8], key3[8];
    gen_keys_1(pkey1, key1);
    gen_keys_2(pkey2, key2);
    gen_keys_3(pkey3, key3);

    // stages 1 and 3 are undone with `decrypt_block`, stage 2 with `encrypt_block`,
    // hence the reversed schedule for keys 1 and 3 and the forward schedule for key 2
    uint16_t K1[52], K2[52], K3[52];
    key_sched_rev(key1, K1);
    key_sched(key2, K2);
    key_sched_rev(key3, K3);

    // 2. decrypt in the ECB mode, each block is processed three times
    const long block_size = static_cast<long>(sizeof(uint64_t));
    for (long offset = 0; offset + block_size <= *length; offset += block_size) {
        // memcpy keeps the native byte order of the original cast-based access
        // without the alignment / strict-aliasing undefined behaviour
        uint64_t block;
        std::memcpy(&block, data + offset, sizeof block);
        block = decrypt_block(block, K3);
        block = encrypt_block(block, K2);
        block = decrypt_block(block, K1);
        std::memcpy(data + offset, &block, sizeof block);
    }

    // 3. unpad the decrypted data: skip the trailing zero bytes, then drop the last non-zero byte as well
    //    (presumably the padding marker written by the encryptor -- TODO confirm against the original binary)
    long index = *length - 1;
    while (index >= 0 && data[index] == 0) {
        index--;
    }
    // an empty or all-zero buffer has no marker; report zero bytes instead of walking off the front
    *length = index < 0 ? 0 : index;
}


// Entry point: recovers the three 16-bit keys of the enc-dec-enc scheme with a meet-in-the-middle attack
// on the first 8-byte block (assumed to be the PNG signature), then decrypts the whole file.
//   argv[1] - path of the encrypted input file
//   argv[2] - path the decrypted output is written to
// Exit codes: 0 success, 1 wrong usage, 2 input file problem, 3 output file problem, 4 no key triple found.
int main(int argc, char **argv) {
    // idea_identity_check();
    // check cmd args
    if (argc != 3) {
        std::cout << "Usage: tridea-attack <encrypted-file-path> <decrypted-file-path>" << std::endl;
        return 1;
    }

    const std::size_t key_space = 0x10000;  // number of 16-bit key candidates per stage
    const std::size_t subkey_count = 52;    // round subkeys produced by one key schedule
    const long block_size = 8;              // cipher block size in bytes

    // 1. read the encrypted file
    std::filesystem::path input_file_path{argv[1]};
    std::ifstream fin(argv[1], std::ios_base::binary);
    if (!fin) {
        std::cerr << "Failed to open input file \"" << input_file_path << "\": " << std::strerror(errno) << std::endl;
        return 2;
    }

    // the error_code overload reports failures instead of throwing
    std::error_code size_error;
    const auto raw_file_size = std::filesystem::file_size(input_file_path, size_error);
    if (size_error) {
        std::cerr << "Failed to get the size of input file \"" << input_file_path << "\": "
                  << size_error.message() << std::endl;
        return 2;
    }

    auto input_file_length = static_cast<long>(raw_file_size);
    if (input_file_length == 0) {
        std::cerr << "Empty input file \"" << input_file_path << "\"" << std::endl;
        return 2;
    }
    if (input_file_length < block_size) {
        // the attack needs one complete ciphertext block to compare against
        std::cerr << "Input file \"" << input_file_path << "\" is shorter than one cipher block" << std::endl;
        return 2;
    }

    // the buffer is rounded up past the file size (a full extra block when already aligned), zero-filled
    const long rem = input_file_length % block_size;
    std::vector<uint8_t> data(static_cast<std::size_t>(input_file_length + (rem == 0 ? block_size : block_size - rem)));

    fin.read(reinterpret_cast<char *>(data.data()), input_file_length);
    if (!fin) {
        std::cerr << "Failed to read input file \"" << input_file_path << "\"" << std::endl;
        return 2;
    }
    fin.close();

    // 2. define our known plaintext-ciphertext pair
    const uint64_t P = 0x0A1A0A0D474E5089;  // .PNG....
    uint64_t C;
    std::memcpy(&C, data.data(), sizeof C);  // same native-order load as a cast, without the aliasing UB

    // 3. build a B<->k3 mapping, where B = D_k3(C) and should match D_k2(E_k1(P)) for k1, k2, k3 we're searching for
    std::cout << "Stage 1: Building a B<->k3 mapping" << std::endl;
    std::map<uint64_t, uint16_t> map_b_k3;
    for (uint32_t pkey3 = 0; pkey3 <= 0xFFFF; ++pkey3) {
        // i. generate a putative key 3 and the corresponding round keys
        uint16_t key3[8], K3[52];
        gen_keys_3(static_cast<uint16_t>(pkey3), key3);
        key_sched_rev(key3, K3);

        // ii. decrypt the ciphertext block with the current putative key and save to the map
        const uint64_t B = decrypt_block(C, K3);
        map_b_k3[B] = static_cast<uint16_t>(pkey3);
    }

    // 4. searching for a collision: trying to decrypt-encrypt and check if the resulting block R2 is in the map
    std::cout << "Stage 2: Searching for a collision to meet in the middle" << std::endl;

    // 4.1 encrypt the plaintext block P for every possible k1 and store in an array
    // N.B. this way we can encrypt P for every k1 only once
    std::vector<uint64_t> R1s(key_space);  // 512 KiB: kept on the heap rather than the stack
    for (uint32_t pkey1 = 0; pkey1 <= 0xFFFF; ++pkey1) {
        // i. generate a putative key 1 and the corresponding round keys
        uint16_t key1[8], K1[52];
        gen_keys_1(static_cast<uint16_t>(pkey1), key1);
        key_sched(key1, K1);

        // ii. encrypt the plaintext block with the current putative key and save to the cache array
        R1s[pkey1] = encrypt_block(P, K1);
    }

    // 4.2 expand the round keys for every possible k2 once
    // N.B. the schedule depends only on k2, so recomputing it for every k1 would repeat the same work 2^16 times
    std::vector<uint16_t> K2s(key_space * subkey_count);
    for (uint32_t pkey2 = 0; pkey2 <= 0xFFFF; ++pkey2) {
        uint16_t key2[8];
        gen_keys_2(static_cast<uint16_t>(pkey2), key2);
        key_sched_rev(key2, &K2s[pkey2 * subkey_count]);
    }

    // 4.3 perform the search
    // N.B. at this point it's worth parallelizing this loop
    std::cout << "omp_get_max_threads=" << omp_get_max_threads() << std::endl;
    #pragma omp parallel for
    for (uint32_t pkey1 = 0x0; pkey1 <= 0xFFFF; ++pkey1) {
        // progress output shares one critical section with the match handler so lines never interleave
        #pragma omp critical(tridea_report)
        {
            std::cout << "....pkey1 = 0x" << std::hex << std::setw(4) << std::setfill('0') << pkey1 << std::endl;
        }

        const uint64_t R1 = R1s[pkey1];
        for (uint32_t pkey2 = 0; pkey2 <= 0xFFFF; ++pkey2) {
            // i. decrypt with the cached round keys and search for R2 in the map
            const uint64_t R2 = decrypt_block(R1, &K2s[pkey2 * subkey_count]);
            const auto s = map_b_k3.find(R2);
            if (s == map_b_k3.end()) {
                continue;
            }

            // ii. a match: only one thread at a time may touch the shared buffer and the output file
            #pragma omp critical(tridea_report)
            {
                const uint16_t pkey3 = s->second;
                std::cout << "Found the matching keys!!!" << std::endl;
                std::cout << "    pkey1 = 0x" << std::hex << std::setw(4) << std::setfill('0') << pkey1 << std::endl;
                std::cout << "    pkey2 = 0x" << std::hex << std::setw(4) << std::setfill('0') << pkey2 << std::endl;
                std::cout << "    pkey3 = 0x" << std::hex << std::setw(4) << std::setfill('0') << pkey3 << std::endl;

                // decrypt the flag image
                decrypt(data.data(), &input_file_length, static_cast<uint16_t>(pkey1),
                        static_cast<uint16_t>(pkey2), pkey3);

                // save the result
                std::filesystem::path output_file_path{argv[2]};
                std::ofstream fout(output_file_path, std::ios_base::binary | std::ios_base::out);
                if (!fout) {
                    std::cerr << "Failed to open output file \"" << output_file_path << "\": " << std::strerror(errno)
                              << std::endl;
                    std::exit(3);
                }

                fout.write(reinterpret_cast<char *>(data.data()), input_file_length);
                if (!fout) {
                    std::cerr << "Failed to write to output file \"" << output_file_path << "\"" << std::endl;
                    std::exit(3);
                }
                fout.close();

                std::cout << "Success" << std::endl;

                // a parallel loop cannot be left with `break`; terminating the process ends the search
                std::exit(0);
            }
        }
    }

    // the whole key space was searched without a collision
    std::cerr << "No matching keys found" << std::endl;
    return 4;
}
