Happy new year!

Here is a writeup for the challenge Spikey Elf from the 38C3 CTF by HXP.

The Challenge

We are given the following code:

#!/usr/bin/env sage
proof.all(False)

bits = 1024
errs = 16

p = random_prime(2^(bits//2))
q = random_prime(2^(bits//2))
n = p * q
e = 0x10001
print(f'{n = :#x}')
print(f'{e = :#x}')

flag = pow(int.from_bytes(open('flag.txt','rb').read().strip()), e, n)
print(f'{flag = :#x}')

d = inverse_mod(e, lcm(p-1, q-1))
locs = sorted(Subsets(range(bits), errs).random_element())
for loc in locs:
    d ^^= 1 << loc
print(f'{d = :#x}')

And the output:

n = 0x639d87bf6a02786607d67741ebde10aa39746dc8ed22b191ff2
    fefe9c210b3ee2ce68b185dc7f8069e78441bdec1d33e2b342c22
    6b5cde8a49f567ac11a3bcb7ff88eeededdd0d50eb981635920d2
    380a6b878d327b261821355d65b2ef9f807035a70c77252d09787
    c2b3dfafdfa4f5c6b39a1c66c5b39fe9d1ee4b36d86d5
e = 0x10001
flag = 0x40208a7900b1575431a49690030e4eb8be6269edcd3c7b2d
    97ae94a6eb744e9c622d81b95ea45b23ee6e0d773e3dd48adc6bb
    2c7c6423d8fd52eddcc6c0710f607590d5fc57a45883a36ad0d85
    1f84d4bee86ffaf65bc1773f97430080926550dce3666051befa8
    7bacc01d44dd09baa6ae93a85cedde5933f7cbbe2cb56cdd
d = 0x1a54893799cd9805600cfaee1c8a408813525db268fbc29e7f2
    a81eb47b64d2dd20dc8be52b6332e375f92a120957042a92a4bd4
    f5e13ef14e9b398bec330602dc9dbbb63cf3dfe6d33bf95d08306
    a894b052e005a57cc41673fe866f4f8b2ffb0aa26fc4c51a8f513
    5e40df2107e0259ddf4c1d9c1eb41b1f702b135c941

In other words, we are given:

  • The RSA public key (n, e)
  • The encrypted flag (flag)
  • The private key d with 16 random bits flipped
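To make the error model concrete, here is a minimal Rust sketch of the corruption step `d ^^= 1 << loc` from the challenge script. The helper names `flip_bits` and `hamming_distance` are hypothetical, and a `u128` stands in for the 1024-bit value used by the challenge.

```rust
/// Flip the bits of `d` at the given (distinct) positions, mirroring
/// `d ^^= 1 << loc` from the challenge script on a toy-sized integer.
fn flip_bits(d: u128, locs: &[u32]) -> u128 {
    locs.iter().fold(d, |acc, &loc| acc ^ (1u128 << loc))
}

/// Number of bit positions in which `a` and `b` differ.
fn hamming_distance(a: u128, b: u128) -> u32 {
    (a ^ b).count_ones()
}
```

Flipping distinct positions always yields a Hamming distance equal to the number of flips, and applying the same flips a second time restores the original value. This is why recovering the key reduces to finding the 16 positions.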

Solution

The first observation is that, since e is small (we can enumerate every value in 1..e), we can easily recover the high bits of d.

To see this, observe:

\[ \begin{aligned} e \cdot d &\equiv 1 \pmod{(p - 1) \cdot (q - 1)} \\ e \cdot d &= 1 + k \cdot (p - 1) \cdot (q - 1) \\ e \cdot d &= 1 + k \cdot (p q - p - q + 1) \\ e \cdot d &= 1 + k \cdot (n - p - q + 1) \\ d &= \frac{k \cdot n - k (p + q - 1) + 1}{e} \end{aligned} \]

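So if we simply guess \( k \) we can recover the high bits of \( d \) from \( \lfloor k \cdot n / e \rfloor \): since \( n \) is roughly twice the length of the unknown \( p + q \), the unknown term \( k (p + q - 1) / e \) only disturbs the low half of \( d \) (apart from a possible carry into the high half). This immediately lets us eliminate a bunch of the bit flips in \( d \). Note that the challenge actually computes \( d \) modulo \( \operatorname{lcm}(p - 1, q - 1) \) rather than \( (p - 1) \cdot (q - 1) \); the same argument applies with \( k / \gcd(p - 1, q - 1) \) in place of \( k \).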
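The approximation can be checked on a toy key. The following Rust sketch assumes \( d \) is the inverse of \( e \) modulo \( (p - 1)(q - 1) \), as in the derivation above, and uses 30-bit primes so that everything fits in a `u128`. The helper names `mod_inverse` and `approx_d` are hypothetical.

```rust
/// Modular inverse of `a` modulo `m` via the extended Euclidean algorithm.
/// Only used here to build a toy key; returns None if gcd(a, m) != 1.
fn mod_inverse(a: u128, m: u128) -> Option<u128> {
    let (mut old_r, mut r) = (a as i128, m as i128);
    let (mut old_s, mut s) = (1i128, 0i128);
    while r != 0 {
        let quotient = old_r / r;
        let next_r = old_r - quotient * r;
        let next_s = old_s - quotient * s;
        old_r = r;
        r = next_r;
        old_s = s;
        s = next_s;
    }
    if old_r != 1 {
        return None;
    }
    Some(old_s.rem_euclid(m as i128) as u128)
}

/// Approximation of d for a guessed multiplier k: floor(k * n / e).
/// For the correct k it overshoots the true d by less than p + q, so only
/// the low half is wrong (apart from a possible borrow into the high half).
fn approx_d(n: u128, e: u128, k: u128) -> u128 {
    k * n / e
}
```

With \( p = 1000000007 \), \( q = 998244353 \) and \( e = 65537 \), computing \( d \) with `mod_inverse` and \( k = (e \cdot d - 1) / ((p - 1)(q - 1)) \) gives `approx_d(n, e, k) - d < p + q`, so the two values agree on roughly the upper half of their bits.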

After applying this trick we are left with the following partly corrected \( d \):

d = 0x1a548937b9cd9805600cface1c8a408813525db268fbc29e7f2
    a81ea47b6cd2dd205c8be52b6332e375f92a120957040a92acbd4
    f7e13ef14e9b39cbec330602dc9dbbb63cf3dfe6d33bf95d08306
    a894b052e005a57cc41673fe866f4f8b2ffb0aa26fc4c51a8f513
    5e40df2107e0259ddf4c1d9c1eb41b1f702b135c941
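One way to perform this correction, sketched in Rust on toy-sized integers: keep the low bits of the faulty \( d \) and overwrite everything above them with the bits of the approximation. The writeup does not state exactly how the splice was done, so this is an assumption, and the names `splice_high_bits` and `corrected_flips` are hypothetical.

```rust
/// Replace everything above `low_bits` in `faulty_d` with the corresponding
/// bits of `approx`, keeping the (possibly still faulty) low bits untouched.
fn splice_high_bits(approx: u128, faulty_d: u128, low_bits: u32) -> u128 {
    let low_mask = (1u128 << low_bits) - 1;
    (approx & !low_mask) | (faulty_d & low_mask)
}

/// Number of bit flips that the splice corrected.
fn corrected_flips(faulty_d: u128, spliced: u128) -> u32 {
    (faulty_d ^ spliced).count_ones()
}
```

For example, `splice_high_bits(0xABCD_1234, 0xABC5_5678, 16)` returns `0xABCD_5678` and corrects one flipped bit. A wrong guess for \( k \) would show up as a large number of "corrected" bits, which gives a simple way to recognise the right guess.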

This leaves us with 7 errors in the low bits of \( d \) to correct. Since the low part of \( d \) is 512 bits long and contains 7 errors, a naive brute force over all error positions would take:

\[ \binom{512}{7} \approx 2^{51} \]

Which is doable, but a bit much…
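As a quick sanity check of this estimate, the count can be computed exactly. This is a small Rust sketch; the function name `binomial` is hypothetical.

```rust
/// Binomial coefficient C(n, k), computed with exact integer arithmetic.
fn binomial(n: u128, k: u128) -> u128 {
    let mut result = 1u128;
    for i in 0..k {
        // After this step `result` equals C(n, i + 1), so the division is exact.
        result = result * (n - i) / (i + 1);
    }
    result
}
```

`binomial(512, 7)` lies between \( 2^{50} \) and \( 2^{51} \), which matches the estimate above.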

The better solution is a meet-in-the-middle attack:

  • Pick a known plaintext \( m \) (e.g. \( 2 \) )

  • Compute the ciphertext \( c = m^e \bmod n \).

  • Partition the low bits of \( d \) into two random sets \( \text{lhs_bits} \subseteq [512] \) and \( \text{rhs_bits} \subseteq [512] \):

    • \( \text{lhs_bits} \) containing 3 error positions.
    • \( \text{rhs_bits} \) containing 4 error positions.

    Of course, we cannot know in advance that a given split places the errors this way, so we simply guess that it does and retry with a fresh random split on failure.
    The guess is correct with probability \( \binom{7}{3} / 2^7 \approx 27\% \).
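The success probability of a single split can be confirmed by exhaustive enumeration: each of the 7 error positions independently lands on the left or the right side, giving \( 2^7 \) equally likely assignments. This is a minimal Rust sketch; the function name `split_success_probability` is hypothetical.

```rust
/// Probability that a uniformly random left/right split puts exactly
/// `lhs_errors` of the `errors` faulty positions on the left-hand side.
fn split_success_probability(errors: u32, lhs_errors: u32) -> f64 {
    // Bit j of `mask` set means: error j was assigned to the left side.
    let total = 1u64 << errors;
    let good = (0..total)
        .filter(|mask| mask.count_ones() == lhs_errors)
        .count();
    good as f64 / total as f64
}
```

`split_success_probability(7, 3)` is exactly 35/128 = 0.2734375, so on average fewer than four random splits are needed before one has the guessed shape.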

Next observe that for the correct \( d \) we have:

\[ \begin{aligned} (m^e)^d &\equiv m \bmod n \\ (m^e)^d \cdot m^{-1} &\equiv 1 \bmod n \end{aligned} \]

We split \( d \) into:

\[ d = \mathsf{HIGH} + \sum_{i \in \text{lhs_bits}} 2^i \cdot d_i + \sum_{i \in \text{rhs_bits}} 2^i \cdot d_i \]

where \( \mathsf{HIGH} \) is a constant representing the upper bits of \( d \), which we have already recovered, and \( d_i \) is the \( i \)-th bit of \( d \).

This formulation allows us to write:

\[ 1 \equiv m^{-1} \cdot c^{\mathsf{HIGH} + \sum_{i \in \text{lhs_bits}} 2^i \cdot d_i} \cdot c^{\sum_{i \in \text{rhs_bits}} 2^i \cdot d_i} \bmod n \]

If we divide both sides by \( m^{-1} \cdot c^{\mathsf{HIGH} + \sum_{i \in \text{lhs_bits}} 2^i \cdot d_i} \) we get:

\[ m \cdot c^{- \mathsf{HIGH} - \sum_{i \in \text{lhs_bits}} 2^i \cdot d_i} \equiv c^{\sum_{i \in \text{rhs_bits}} 2^i \cdot d_i} \bmod n \]

This enables us to search through the possibilities on the left-hand and right-hand sides of the equation independently, looking for a match: we separately guess which bits in \( \text{lhs_bits} \) (3) and \( \text{rhs_bits} \) (4) are flipped. Once we have found a match, we can reconstruct \( d \) by flipping the bits in the correct positions.
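Before the full attack, the matching idea can be tried on a toy instance. The Rust sketch below is deliberately simplified and rests on several assumptions that differ from the real setting: a modulus below \( 2^{64} \), only 16 low bits, a fixed even/odd split instead of a random one, and exactly one flipped bit per side instead of 3 and 4. All names (`repair_d`, `mul_mod`, `pow_mod`, `mod_inverse`) are hypothetical.

```rust
use std::collections::HashMap;

/// Toy stand-in for the 512 low bits of the real attack.
const LOW_BITS: u32 = 16;

/// (a * b) mod n; assumes n < 2^64 so the product fits in a u128.
fn mul_mod(a: u128, b: u128, n: u128) -> u128 {
    a * b % n
}

/// Square-and-multiply modular exponentiation.
fn pow_mod(mut base: u128, mut exp: u128, n: u128) -> u128 {
    let mut acc = 1 % n;
    base %= n;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = mul_mod(acc, base, n);
        }
        base = mul_mod(base, base, n);
        exp >>= 1;
    }
    acc
}

/// Modular inverse via the extended Euclidean algorithm.
fn mod_inverse(a: u128, m: u128) -> Option<u128> {
    let (mut old_r, mut r) = (a as i128, m as i128);
    let (mut old_s, mut s) = (1i128, 0i128);
    while r != 0 {
        let quotient = old_r / r;
        let next_r = old_r - quotient * r;
        let next_s = old_s - quotient * s;
        old_r = r;
        r = next_r;
        old_s = s;
        s = next_s;
    }
    if old_r != 1 {
        return None;
    }
    Some(old_s.rem_euclid(m as i128) as u128)
}

/// Repair `faulty_d`, guessing that exactly one even-indexed bit (lhs) and
/// one odd-indexed bit (rhs) below LOW_BITS were flipped.
fn repair_d(n: u128, e: u128, faulty_d: u128) -> Option<u128> {
    let m = 2u128; // known plaintext
    let c = pow_mod(m, e, n);
    let c_inv = mod_inverse(c, n)?;

    // rhs side: odd low bits; lhs side: HIGH plus the even low bits.
    let rhs_mask: u128 = 0xAAAA;
    let rhs_part = faulty_d & rhs_mask;
    let lhs_part = faulty_d & !rhs_mask;

    // Table of m * c^(-(HIGH + lhs bits)) for every single flip on the lhs.
    let mut table: HashMap<u128, u32> = HashMap::new();
    for i in (0..LOW_BITS).step_by(2) {
        let exponent = lhs_part ^ (1u128 << i);
        table.insert(mul_mod(m, pow_mod(c_inv, exponent, n), n), i);
    }

    // Look up c^(rhs bits) for every single flip on the rhs.
    for j in (1..LOW_BITS).step_by(2) {
        let exponent = rhs_part ^ (1u128 << j);
        if let Some(&i) = table.get(&pow_mod(c, exponent, n)) {
            return Some(faulty_d ^ (1u128 << i) ^ (1u128 << j));
        }
    }
    None
}
```

The sketch recomputes a full modular exponentiation for every candidate, which keeps it short. With several flips per side it pays off to precompute \( c^{\pm 2^i} \) once and update each candidate with a single multiplication per flipped bit.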

In Rust, the attack looks like this:

use std::{collections::HashMap, sync::Mutex};

use rand::random;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rug::{integer::Order, Complete, Integer};

const LOW_BITS: usize = 512;

const TOTAL_BITS: usize = 1024;

const IS_TEST: bool = false;

fn unhex(s: &str) -> Integer {
    Integer::parse_radix(s, 16).unwrap().complete()
}

fn main() {
    let e = unhex("10001");
    let n;
    let d; // partly faulty d
    if IS_TEST {
        n = unhex("100a4ee7a760d592fd30bcff5e45d790782c176a957193c8943d6bb4f9a9d3d202af512aeea585da574b3789c22
                   e3dbca5aff1423d5959efda1120f8b37f81b6a7c86bb15a4a448aa4aabb4ec3f5897333f75185c29b9a36924f0a
                   20e807308404269c1fc74914f0dc09aca0e9db293b1462cf84bc9d3472d90874c0c8de46af");
        d = unhex("263bec095378069ef0d6942b4253c8999d59bc175b51f7aa4d707aa44fd8f807c84e8d7380971c906f33e396833
                   e25fade169d45ff665472dda1a6ae2b6eeb897d3cab0f2e742fea35a7abfcd7bd0982c2c8665ab8fd0db99d7105
                   e544ea6406345df27a7dbd86639a6f9ce59bbeb7e773301400b8b2a02241b96bf51d50189");
    } else {
        n = unhex("639d87bf6a02786607d67741ebde10aa39746dc8ed22b191ff2fefe9c210b3ee2ce68b185dc7f8069e78441bdec
                   1d33e2b342c226b5cde8a49f567ac11a3bcb7ff88eeededdd0d50eb981635920d2380a6b878d327b261821355d6
                   5b2ef9f807035a70c77252d09787c2b3dfafdfa4f5c6b39a1c66c5b39fe9d1ee4b36d86d5");
        d = unhex("1a548937b9cd9805600cface1c8a408813525db268fbc29e7f2a81ea47b6cd2dd205c8be52b6332e375f92a1209
                   57040a92acbd4f7e13ef14e9b39cbec330602dc9dbbb63cf3dfe6d33bf95d08306a894b052e005a57cc41673fe8
                   66f4f8b2ffb0aa26fc4c51a8f5135e40df2107e0259ddf4c1d9c1eb41b1f702b135c941");
    }

    let pt = Integer::from(2);
    let ct = pt.clone().pow_mod(&e, &n).unwrap();

    // compute ct^{2^i} mod n for i in 0..=TOTAL_BITS
    let mut powers: Vec<Integer> = vec![ct];
    for _i in 0..TOTAL_BITS {
        let last = powers.last().unwrap().clone();
        powers.push((last.clone() * last).modulo(&n));
    }

    // invert every power mod n
    let inv_powers: Vec<Integer> = powers
        .iter()
        .map(|x| x.clone().invert(&n).unwrap())
        .collect();

    // decompose d into bits
    let d_bits: Vec<bool> = d.to_digits(Order::Lsf);

    // bits to flip
    let sol: Mutex<Option<Vec<usize>>> = Mutex::new(None);
    while sol.lock().unwrap().is_none() {
        // randomly split the bits
        let mut lhs_bits: Vec<bool> = vec![false; LOW_BITS];
        let mut rhs_bits: Vec<bool> = vec![false; LOW_BITS];
        for i in 0..LOW_BITS {
            if random() {
                lhs_bits[i] = true;
            } else {
                rhs_bits[i] = true;
            }
        }

        // split the low bits
        // m = (m^e)^d = (m^e)^d_high * (m^e)^d_lo
        // (m^e)^-d_high * m = (m^e)^d_lo
        let mut pt_lhs = pt.clone();
        let mut pt_rhs = Integer::from(1);
        for (i, bit) in d_bits.iter().enumerate() {
            if !bit {
                continue;
            } else if i >= LOW_BITS {
                pt_lhs *= &inv_powers[i];
                pt_lhs %= &n;
            } else if lhs_bits[i] {
                pt_lhs *= &inv_powers[i];
                pt_lhs %= &n;
            } else if rhs_bits[i] {
                pt_rhs *= &powers[i];
                pt_rhs %= &n;
            }
        }

        fn flip(
            bits: &[bool],
            d_0_to_1: &[Integer], // group deltas taking "0"s to "1"s
            d_1_to_0: &[Integer], // group deltas taking "1"s to "0"s
            val: &Integer,
            n: &Integer,
            i: usize,
        ) -> Integer {
            let delta = if bits[i] {
                d_1_to_0[i].clone()
            } else {
                d_0_to_1[i].clone()
            };
            (val * delta) % n
        }

        // maps each candidate lhs value to the bit positions flipped to reach it
        let mut table: HashMap<Integer, (usize, usize, usize)> = HashMap::new();

        // forward (low errors 3)
        println!("Building table");
        let lhs = pt_lhs.clone();
        for i1 in 0..LOW_BITS {
            if !lhs_bits[i1] {
                continue;
            }
            let lhs = flip(&d_bits, &inv_powers, &powers, &lhs, &n, i1);
            for i2 in i1..LOW_BITS {
                if !lhs_bits[i2] {
                    continue;
                }
                let lhs = flip(&d_bits, &inv_powers, &powers, &lhs, &n, i2);
                for i3 in i2..LOW_BITS {
                    if !lhs_bits[i3] {
                        continue;
                    }
                    let lhs = flip(&d_bits, &inv_powers, &powers, &lhs, &n, i3);
                    table.insert(lhs, (i1, i2, i3));
                }
            }
        }

        // matching time
        println!("Matching time");
        (0..LOW_BITS).into_par_iter().for_each(|i1| {
            let rhs = pt_rhs.clone();
            if !rhs_bits[i1] {
                return;
            }
            let rhs = flip(&d_bits, &powers, &inv_powers, &rhs, &n, i1);
            for i2 in i1..LOW_BITS {
                if !rhs_bits[i2] {
                    continue;
                }
                let rhs = flip(&d_bits, &powers, &inv_powers, &rhs, &n, i2);
                for i3 in i2..LOW_BITS {
                    if !rhs_bits[i3] {
                        continue;
                    }
                    let rhs = flip(&d_bits, &powers, &inv_powers, &rhs, &n, i3);
                    for i4 in i3..LOW_BITS {
                        if !rhs_bits[i4] {
                            continue;
                        }
                        let rhs = flip(&d_bits, &powers, &inv_powers, &rhs, &n, i4);
                        if let Some((j1, j2, j3)) = table.get(&rhs) {
                            *sol.lock().unwrap() = Some(vec![i1, i2, i3, i4, *j1, *j2, *j3]);
                        }
                    }
                }
            }
        });
    }

    // flip the bits
    let sol = sol.lock().unwrap().clone().unwrap();
    let mut d_bits: Vec<bool> = d.to_digits(Order::Lsf);
    for i in sol.iter() {
        d_bits[*i] = !d_bits[*i];
    }

    // reconstruct d
    let d = Integer::from_digits(&d_bits, Order::Lsf);

    // decrypt the flag
    let flag = unhex(
        "40208a7900b1575431a49690030e4eb8be6269edcd3c7b2d
         97ae94a6eb744e9c622d81b95ea45b23ee6e0d773e3dd48adc6bb
         2c7c6423d8fd52eddcc6c0710f607590d5fc57a45883a36ad0d85
         1f84d4bee86ffaf65bc1773f97430080926550dce3666051befa8
         7bacc01d44dd09baa6ae93a85cedde5933f7cbbe2cb56cdd",
    );
    let flag = flag.pow_mod(&d, &n).unwrap();
    let flag = flag.to_string_radix(16);

    // hex decode
    let flag = hex::decode(flag).unwrap();
    println!("{}", String::from_utf8_lossy(&flag));
}

Running the code we recover the flag:

rot256@digit ~/s/rsarac (master)> cargo run --release
   Compiling rsarac v0.1.0 (/Users/rot256/src/rsarac)
    Finished `release` profile [optimized] target(s) in 0.51s
     Running `target/release/rsarac`
Building table
Matching time
hxp{fr13nd5_d0nt_l3t_fr1ends_r34d_p0wer_trac35_by_h4nd}