A walk-through of real-world AES fault injection for dummies on a shoestring budget.

Introduction

In this post we will set up a microcontroller which encrypts using AES with an unknown key, then explore how to recover the full AES key simply by (randomly) glitching the power supply, which introduces faults in the arithmetic during the AES encryption.

This post is designed to serve as a tutorial and to enable the reader to follow along. If you get stuck anywhere (or lack the hardware), just skip the current step and pick up the provided files, which will enable you to continue from the next step onwards. To this end there is an associated GitHub repository which contains the victim code, the attack and visualization scripts, as well as the raw samples I collected, which can be used in lieu of setting up the hardware yourself.

The hardware setup in this post is deliberately kept minimal.

We will apply some of the techniques found in the paper “Differential Fault Analysis of the Advanced Encryption Standard using a Single Fault” by Michael Tunstall and Debdeep Mukhopadhyay. However, we will not assume that we can precisely control where the fault is injected; instead we create a large number of faulty ciphertexts by injecting faults at random. This is much easier, as we do not rely on precise timing.

Prerequisites

If you want to follow along you should make sure you satisfy the prerequisites.

Wetware

I only assume that the reader has:

  • The ability to read/write C.
  • The ability to read/write Verilog.
  • The ability to read/write python3.
  • Basic familiarity with finite field arithmetic.

I assume only very minimal familiarity with electronics, since this post is largely aimed at the computer science / “cybersecurity” crowd. The hardware sections in this post can also be skipped entirely.

Hardware

Our target will be the common Atmega328 microcontroller:

Atmega328: Pinout of the Atmega328 in DIP package.

To follow along with the hardware section of this post, you will need:

  • AVR microcontroller (in a DIP), e.g. Atmega328, acting as the victim.
  • FTDI USB-to-Serial device, e.g. FTDI232, for programming and retrieving samples.
  • FPGA with Yosys & nextpnr support, e.g. a board with some variant of an iCE40.
  • 10uF capacitor.
  • 10kΩ resistor.
  • A protoboard.
  • Male-Male jumpers to connect it all.
  • (potentially) 2N2222 transistor (or similar).
  • (optional) 1kΩ resistor.
  • (optional) LED, used as indicator.

The transistor will only be needed if the FPGA that you use is incapable of driving the microcontroller directly from an output pin.

I will be using the Atmega328 microcontroller and an old iCE40hx1k FPGA (circa 2011) on the Go Board that I had lying around. The Go Board is driven by a 25MHz clock, which should be sufficiently fast to glitch the Atmega328 running on its internal clock at 8MHz. If your setup is different you will have to adjust the Makefiles accordingly, but it should be straightforward. However, you must ensure two things:

1. Your microcontroller must not be on a board with a large capacitor across the power rail or very sensitive brown-out detection: the capacitor would smooth out the short power dips we rely on, and the brown-out detection would reset the chip instead of letting it compute a faulty result. In particular an Arduino Uno board will not work; however, removing the Atmega328PU from an Uno and driving it off an external clock on the protoboard should: you might use the FPGA to generate the clock if you do not have a crystal lying around. The internal brown-out detection on the Atmega328 won't be a problem (the glitch is too short to trigger the time-out).

2. The clock frequency of your FPGA should be significantly higher than that of the microcontroller, since we will switch power off for a fraction of the execution time of an instruction.

Software

First you should clone the GitHub repository associated with this tutorial.

We will also need a bunch of open-source tools:

  • yosys, for Verilog synthesis.
  • nextpnr, for place-and-route.
  • avrdude, for flashing the microcontroller.
  • pyserial, for collecting samples over the serial port.
  • avr-gcc, for compiling C for the microcontroller.
  • python3, for collecting samples and analysis.

If you have the Nix package manager you can obtain a shell with yosys and nextpnr by running:

$ nix-shell -p nextpnr yosys

On Debian you can obtain the remainder from the package manager:

$ apt install python3-serial gcc-avr avrdude

In case you are not doing the hardware parts of this tutorial, you only need python.

Programming the victim

I assume that the Atmega328 already has the Arduino bootloader flashed. If not, you can find a tutorial on how to flash the bootloader onto an Atmega328 using an Arduino Uno board: go to the “Minimal Circuit (Eliminating the External Clock)” section and follow the steps.

Prepare the setup for programming the microcontroller (showing the FTDI device on the left):

Programming setup: Circuit diagram of programming setup.

Depending on your LED, you might need an additional 1kΩ resistor between the LED and ground to avoid killing it. Once you map this into 3D space it might look something like this:

Programming setup: Picture of programming setup.

Now plug the thing into your computer and make sure it is recognized by Linux (e.g. use lsusb), then go to the victim/ directory and run:

$ make flash

You may have to adjust the tty (e.g. /dev/ttyUSB0 -> /dev/ttyUSB1, whichever /dev/ttyUSB* your FTDI device shows up as) for the above to work.

This compiles and flashes the victim program to the AVR microcontroller.

Solution

See victim/victim.c:

#include <avr/io.h>
#include <stdint.h>
#include <string.h>

#include "aes.h"

#define UART_BAUD 9600

char HEX[] = {
    '0', '1', '2', '3',
    '4', '5', '6', '7',
    '8', '9', 'a', 'b',
    'c', 'd', 'e', 'f'
};

// secret AES key (pretend you did not see this)
uint8_t key[] = {
    0x00, 0x01, 0x02, 0x03,
    0x04, 0x05, 0x06, 0x07,
    0x08, 0x09, 0x0a, 0x0b,
    0x0c, 0x0d, 0x0e, 0x0f
};

void uart_init(void) {
    UBRR0 = (F_CPU / (16UL * UART_BAUD)) - 1;
    UCSR0B = _BV(TXEN0) | _BV(RXEN0);
}

void uart_putchar(char c) {
    loop_until_bit_is_set(UCSR0A, UDRE0);
    UDR0 = c;
}

void uart_putstr(char *s) {
    while (*s) {
        uart_putchar(*s);
        s++;
    }
}

void uart_hex(uint8_t c) {
    uart_putchar(HEX[c >> 4]);
    uart_putchar(HEX[c & 0xf]);
}

__attribute__((optimize("unroll-loops")))
int main (void) {
    uint8_t pt[AES_BLOCKSIZE];
    aes_expanded_key_t ekey;

    // setup serial
    uart_init();

    // set pin 5 of PORTB for output
    DDRB |= _BV(DDB5);

    // set pin 5 high to turn LED on
    PORTB |= _BV(PORTB5);

    // expand AES key
    aes_expand(&ekey, key);

    // repeatedly encrypt same plaintext
    uart_putstr("\n\r");
    while(1) {
        // pt <- 0^16
        memset(pt, 0, sizeof pt);

        // pt <- AES128(key, 0^16)
        aes_encrypt(pt, &ekey);

        // print(pt)
        for (uint8_t i = 0; i < sizeof pt; i++)
            uart_hex(pt[i]);
        uart_putchar('\n');
        uart_putchar('\r');
    }
}

The program causes the microcontroller to repeatedly output the hexadecimal encoding of \( \texttt{AES-128}(k, 0^{16}) \), the encryption of the all-zero plaintext, on the serial port. Connecting to the serial port, e.g. using miniterm.py which ships with PySerial (python3 -m serial.tools.miniterm), you should observe:

--- Available ports:
---  1: /dev/ttyUSB0         'FT232R USB UART'
--- Enter port index or full name: 1
--- Miniterm on /dev/ttyUSB0  9600,8,N,1 ---
--- Quit: Ctrl+] | Menu: Ctrl+T | Help: Ctrl+T followed by Ctrl+H ---
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
...

Going on into infinity…

Inducing Glitches

But having a microcontroller which consistently does the right thing is not exactly the goal of this post. In this section we will try to make the microcontroller occasionally do the wrong thing: to make a mistake somewhere in the call to aes_encrypt.

Our first task is to control the power of the microcontroller from the FPGA, so start by removing the jumper between Vout on the FTDI and the power rail on the protoboard. Depending on whether the output pins of your FPGA deliver sufficient current, you may or may not have to use a transistor; we consider both possibilities. In my case I will power the Atmega328 directly from the PMOD connector on the Go Board, by wiring one of the ground pins of the PMOD connector to ground on the protoboard and the 4th pin (M4) to power.

PMOD connector: PMOD connector on the Go Board

Setup powering the microcontroller directly from the FPGA:

Glitching Setup: Circuit diagram of setup without NPN BJT

Which looks approximately like this when mapped to the real world:

Glitching Setup: Picture of setup without NPN BJT

Setup powering the microcontroller from an external power source:

Glitching Setup: Circuit diagram of setup with NPN BJT

Which looks approximately like this when mapped to the real world:

Glitching Setup: Picture of setup with NPN BJT

These setups let you easily switch back and forth between programming and glitching the microcontroller, in case you want to experiment further on your own, e.g. try to glitch a comparison to bypass authentication.

Exercise.

Write a Verilog program which executes the following procedure on the FPGA:

  1. Switch power on for 1 second (allowing the microcontroller to start up).
  2. Then 5000 times:
    • Switch power off for < 1 / 8 000 000 of a second (less than one clock cycle of the microcontroller)
    • Switch power on for 1 / 2 000 of a second
  3. Switch power off for 1 second (shutting off the microcontroller).
  4. Go to the first step.

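There is a Makefile in attacker/ which should aid in the process of flashing the program to the FPGA if you are using the Go Board. If you are using another iCE40 FPGA the modifications should be straightforward. You might want to play around with the length of the glitches and the duration between them. The reference solution below uses GLITCH_LEN = 1; since PMOD4 is a registered output and the state machine uses a strict > comparison, this switches power off for 3 cycles of the 25MHz clock (120ns), just under one 125ns clock cycle of the microcontroller.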
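The durations in the exercise are easiest to reason about as counts of FPGA clock cycles. The following Python sketch does that arithmetic. The helper names are my own, and `state_cycles` encodes my reading of the reference glitch.v below (a counter starting at 0 and a strict `>` comparison), so treat it as an assumption if your state machine is written differently.

```python
# Back-of-the-envelope timing for the glitcher. Assumed clocks: the Go
# Board's 25 MHz oscillator and the Atmega328's 8 MHz internal oscillator.
CLK_HZ = 25_000_000
MCU_HZ = 8_000_000

FPGA_NS = 1e9 / CLK_HZ  # length of one FPGA clock cycle in ns
MCU_NS = 1e9 / MCU_HZ   # length of one microcontroller clock cycle in ns


def cycles(seconds):
    """Number of whole FPGA clock cycles in a duration given in seconds."""
    return round(seconds * CLK_HZ)


def state_cycles(length):
    """Cycles spent in a state of the reference glitch.v for a given *_LEN.

    The counter runs 0, 1, ..., length + 1 before `cnt > LEN` fires, so the
    state (and the registered PMOD4 level it drives) lasts length + 2 cycles.
    """
    return length + 2


boot = cycles(1.0)           # steps 1 and 3: one second on / off
power_on = cycles(1 / 2000)  # power-on time between two glitches

# Longest glitch, in whole FPGA cycles, that is still below one MCU cycle.
max_glitch = max(n for n in range(1, 16) if n * FPGA_NS < MCU_NS)

glitch_ns = state_cycles(1) * FPGA_NS  # GLITCH_LEN = 1 in glitch.v

print(boot, power_on, max_glitch, glitch_ns)  # 25000000 12500 3 120.0
```

In other words, a 25MHz FPGA leaves room for at most 3 FPGA cycles per glitch if the glitch is to stay below one instruction cycle of the Atmega328, which is why a much slower FPGA clock would not work.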

You will know that you have succeeded by connecting to the serial port (like before) and observing that the ciphertexts occasionally deviate from the correct encryption observed earlier. Here are some successful glitches (from the reference solution); the deviating ciphertexts are the ones which differ from c6a13b37878f5b826f4f8162a1c8d879:

c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
5b39c343e831f4b4787dac01a0a78b69
ff0fe5db6edd48de705bb37837e2fd44
c6a13b37878f5b826f4f8162a1c8d879
c6a13bbe878f5b826f4f816233c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
fb568222c589024cec16ee90182f6d53
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a1783787405b82ee4f8162a1c8d831
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
c6a13b37878f5b826f4f8162a1c8d879
Solution

See attacker/glitch.v:

// chip: ice40hx1k
// board: go-board

parameter CLK_HZ = 25000000; // 25 MHz
parameter CONTROL_HZ = 8000000; // 8 MHz

parameter RESET = 2'b00;  // power off
parameter POWER = 2'b01;  // power on
parameter GLITCH = 2'b10; // do the glitch
parameter WAIT = 2'b11;   // wait for boot

parameter GLITCHES = 5000; // glitch 5000 times before reset

parameter WAIT_LEN = CLK_HZ;
parameter RESET_LEN = CLK_HZ;
parameter POWER_LEN = CLK_HZ / 2000;
parameter GLITCH_LEN = 1;

module Top(
    input CLK,
    output reg LED1,
    output reg LED2,
    output reg LED3,
    output reg LED4,
    output reg PMOD4
);
    reg [1:0] state = RESET;
    reg [$clog2(CLK_HZ) + 5:0] cnt;
    reg [$clog2(GLITCHES):0] glitches;

    always @(posedge CLK) begin
        LED1 <= 0;
        LED2 <= 0;
        LED3 <= 0;
        LED4 <= 0;
        cnt <= cnt + 1;
        case (state)
            RESET : begin
                LED1 <= 1;
                PMOD4 <= 0;
                if (cnt > RESET_LEN) begin
                    cnt <= 0;
                    state <= WAIT;
                end
            end
            WAIT : begin
                LED2 <= 1;
                PMOD4 <= 1;
                if (cnt > WAIT_LEN) begin
                    cnt <= 0;
                    state <= GLITCH;
                    glitches <= 0;
                end
            end
            POWER : begin
                LED3 <= 1;
                PMOD4 <= 1;
                if (cnt > POWER_LEN) begin
                    cnt <= 0;
                    state <= GLITCH;
                end
            end
            GLITCH : begin
                PMOD4 <= 0;
                if (cnt > GLITCH_LEN) begin
                    cnt <= 0;
                    state <= POWER;
                    glitches <= glitches + 1;
                end
                if (glitches > GLITCHES) begin
                    cnt <= 0;
                    state <= RESET;
                end
            end
        endcase
    end
endmodule

Sitting around and manually collecting the deviating ciphertext samples is silly, especially since we just automated the glitching part…

So automate this part as well:

Exercise.
Write a python script to automatically collect the ciphertexts from the serial port using PySerial, remove duplicates and save them to a file.
Solution

See attacker/collect.py

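#!/usr/bin/env python3
# Usage: collect.py <serial port> <samples file>
# Appends every distinct 16-byte ciphertext read from the victim to the file.

import sys
import serial
import binascii

# 9600 baud matches UART_BAUD in victim/victim.c (and is PySerial's default)
ser = serial.Serial(sys.argv[1], 9600)

# load previously collected samples so that duplicates are skipped across runs
try:
    with open(sys.argv[2], 'r') as f:
        cts = {binascii.unhexlify(x.strip()) for x in f if x.strip()}
except IOError:
    cts = set()

out = open(sys.argv[2], 'a')

print('Samples:', len(cts))
print('Collecting samples from:', ser.name)

line = b''
try:
    while True:
        line += ser.read()
        if b'\n' in line:
            line = line.strip()
            try:
                # glitches may garble the output: skip lines that are not hex
                ct = binascii.unhexlify(line)
                if len(ct) == 16:
                    print('Ciphertext:', ct.hex())
                    if ct not in cts:
                        print('New:', ct.hex())
                        out.write(ct.hex() + '\n')
                        out.flush()
                        cts.add(ct)
            except binascii.Error:
                pass
            line = b''
except KeyboardInterrupt:
    pass  # stop collecting with Ctrl+C
finally:
    out.close()
    ser.close()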

Let it run for a while until you have a file with a few hundred distinct ciphertexts. I was able to successfully mount the attack with as few as ~20 samples, but collect a few more for good measure.

If you were unable to complete this section: use the provided attacker/samples.txt which I collected from the setup shown in this section and which will enable you to follow along in the cryptanalysis part of this post.

AES(-128) encryption

In order to make sense of the faulty ciphertexts we will need to understand the internals of AES.

The AES Wikipedia article covers most of the content in this section (from which I also pilfered the illustrations), however I include it here for completeness. If you prefer stick figures, there is even A Stick Figure Guide to the AES. For those who prefer C, it might be easier to read my didactic implementation of AES-128, which is the one we attack in this post: victim/aes.c and victim/aes.h (\(\approx\) 150 lines) in the git repository. Hopefully by the end of this section you will realize that AES has a very intuitive structure.

AES explained with stick figures: From A Stick Figure Guide to the AES.

Bytes are interpreted as elements of \( \mathbb{F}_{2^8} = \mathbb{F}_2[X] / (X^8+X^4+X^3+X+1) \) in the straightforward way, by mapping the bits to the coefficients of a canonical representative, where the least significant bit is the constant term. For those new to this, see the crash course below.

Crash-course on the AES extension field (click to expand)

Operations in AES are done over the field \( \mathbb{F}_{2^8} = \mathbb{F}_2[X] / (X^8+X^4+X^3+X+1) \), i.e. the quotient ring formed by taking univariate polynomials with coefficients in \( \mathbb{F}_2 \) modulo the irreducible polynomial \(X^8 + X^4 + X^3 + X + 1\). In other words, polynomials which differ by a multiple of \( X^8 + X^4 + X^3 + X + 1 \) represent the same element, e.g. \( X^9 + X^8 + X^5 + X^3 + X^2 + 1 = \) \( (X + 1) (X^8 + X^4 + X^3 + X + 1) \equiv \) \( 0 \mod X^8 + X^4 + X^3 + X + 1 \), analogous to arithmetic modulo a natural number (i.e. \( \mathbb{Z}_n \)). The elements of \( \mathbb{F}_{2^8} \) can be represented as bytes in a natural way:

Representation: Given a byte \( B = \sum_{i = 0}^{7} b_i \cdot 2^i \), it is interpreted as a polynomial \( p_{B}(X) = \sum_{i = 0}^{7} b_i \cdot X^i \); a canonical representative for the coset.

Addition: Addition of elements in an extension field is simply addition of polynomials: addition is done pairwise on the coefficients for every power. Since the field has characteristic 2, addition and subtraction in the field are the same and the common nomenclature for this sophisticated operation is “xor of bytes”, e.g. \[ \texttt{0x54} = \texttt{01010100}_{2} =_{repr} X^6 + X^4 + X^2 \] \[ \texttt{0x64} = \texttt{01100100}_{2} =_{repr} X^6 + X^5 + X^2 \] \[ \texttt{0x30} = \texttt{0x54} \oplus \texttt{0x64} = \texttt{00110000}_{2} \] \[ \texttt{0x30} =_{repr} (X^6 + X^4 + X^2) + (X^6 + X^5 + X^2) = (2 \cdot X^6) + X^5 + X^4 + (2 \cdot X^2) \] \[ \texttt{0x30} =_{repr} (0 \cdot X^6) + X^5 + X^4 + (0 \cdot X^2) = X^5 + X^4 \]

Multiplication: Multiplication of elements in an extension field is likewise multiplication of polynomials, however unlike addition a reduction might be required: \[ \texttt{0x54} = \texttt{01010100}_{2} =_{repr} X^6 + X^4 + X^2 \] \[ \texttt{0x64} = \texttt{01100100}_{2} =_{repr} X^6 + X^5 + X^2 \] \[ R(X) = (X^6 + X^4 + X^2) \cdot (X^6 + X^5 + X^2) = X^{12} + X^{11} + X^{10} + X^9 + 2 \cdot X^8 + X^7 + X^6 + X^4 \] Since \( 2 \equiv 0 \mod 2 \), this is \( R(X) = X^{12} + X^{11} + X^{10} + X^9 + X^7 + X^6 + X^4 \). Then we separate out all the multiples of \( X^8 + X^4 + X^3 + X + 1 \) (again computing with coefficients modulo 2): \[ R(X) = (X^7 + X^6 + X^5 + X^4 + X^3 + 1) + (X^8 + X^4 + X^3 + X + 1) (X^4 + X^3 + X^2 + X + 1) \] Hence the final reduced canonical representative is: \[ R(X) \equiv X^7 + X^6 + X^5 + X^4 + X^3 + 1 \mod X^8 + X^4 + X^3 + X + 1 \] i.e. \( \texttt{0x54} \cdot \texttt{0x64} = \texttt{11111001}_{2} = \texttt{0xf9} \).

This can be efficiently implemented using a variant of modular long multiplication on the byte/integer representations of the elements. In Python the procedure looks as follows:

def mult(a, b):
    r = 0
    p = 0b100011011  # (X^8+X^4+X^3+X+1)

    while a > 0:
        # add b if the constant term of a is set
        if a & 1:
            r ^= b
        a >>= 1

        # multiply b by "X", reducing if the degree reaches 8
        b <<= 1
        if b & 0x100:
            b ^= p

    return r

For example, mult(0x54, 0x64) returns 0xf9.

The state of AES consists of 16 field elements of \(\mathbb{F}_{2^8}\), which conveniently can be represented as bytes: a “flat” plaintext/ciphertext array of 16 bytes \( a_{0,0}, a_{1,0}, a_{2,0}, a_{3,0}, \) \( a_{0,1}, a_{1,1}, a_{2,1}, a_{3,1}, \) \( a_{0,2}, a_{1,2}, a_{2,2}, a_{3,2}, \) \( a_{0,3}, a_{1,3}, a_{2,3}, a_{3,3} \) is interpreted column-by-column as a 4x4 matrix of \(\mathbb{F}_{2^8}\) elements:

\[ \begin{bmatrix}a_{0, 0} & a_{0,1} & a_{0,2} & a_{0, 3} \\ a_{1,0} & a_{1,1} & a_{1,2} & a_{1,3} \\ a_{2,0} & a_{2,1} & a_{2,2} & a_{2,3} \\ a_{3,0} & a_{3,1} & a_{3,2} & a_{3,3} \end{bmatrix} \in M_{4}(\mathbb{F}_{2^8}) \]
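In code, the column-by-column convention means that byte \(i\) of the flat array sits in row \(i \bmod 4\) and column \(\lfloor i/4 \rfloor\). A small Python sketch of both directions (the helper names are my own, not taken from victim/aes.c):

```python
def to_matrix(flat):
    """Interpret 16 bytes column-by-column as a 4x4 state, given as rows."""
    if len(flat) != 16:
        raise ValueError("AES state is 16 bytes")
    return [[flat[4 * c + r] for c in range(4)] for r in range(4)]


def to_flat(state):
    """Read a 4x4 state (list of rows) back out column-by-column."""
    return [state[r][c] for c in range(4) for r in range(4)]


for row in to_matrix(list(range(16))):
    print(row)
# [0, 4, 8, 12]
# [1, 5, 9, 13]
# [2, 6, 10, 14]
# [3, 7, 11, 15]
```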

There are 4 different transformations applied to the state matrix in AES:

SubBytes

SubBytes replaces every element in the state by applying a permutation (the S-box). We will revisit the exact value of the S-box later, when we extract key-material from the samples. SubBytes is the only non-linear (over \(\mathbb{F}_{2^8} \)) operation applied in AES.

AES SubBytes: Applies a permutation to every element in the state.

Note: The AES S-Box is carefully chosen to enable cheap hardware implementation (without lookup tables) while providing excellent protection against linear and differential cryptanalysis.
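The exact S-box values only matter later, but for reference: the standard construction (FIPS-197) takes the multiplicative inverse in \(\mathbb{F}_{2^8}\), mapping 0 to itself, followed by a fixed affine map over \(\mathbb{F}_2\). A Python sketch that builds the table and its inverse, finding inverses by brute force (slow, but it only runs once over 256 inputs):

```python
def gf_mul(a, b):
    """Multiply two bytes in GF(2^8) modulo X^8 + X^4 + X^3 + X + 1."""
    r = 0
    while a:
        if a & 1:
            r ^= b
        a >>= 1
        b <<= 1
        if b & 0x100:
            b ^= 0x11B
    return r


def rotl8(x, n):
    """Rotate a byte left by n bits."""
    return ((x << n) | (x >> (8 - n))) & 0xFF


def make_sbox():
    """Build the AES S-box and its inverse."""
    sbox = []
    for x in range(256):
        # multiplicative inverse (0 has none and is mapped to 0)
        inv = next((y for y in range(1, 256) if gf_mul(x, y) == 1), 0)
        # affine transformation over GF(2)
        sbox.append(inv ^ rotl8(inv, 1) ^ rotl8(inv, 2)
                    ^ rotl8(inv, 3) ^ rotl8(inv, 4) ^ 0x63)
    inv_sbox = [0] * 256
    for x, s in enumerate(sbox):
        inv_sbox[s] = x
    return sbox, inv_sbox


SBOX, INV_SBOX = make_sbox()
print(hex(SBOX[0x00]), hex(SBOX[0x01]), hex(SBOX[0x53]))  # 0x63 0x7c 0xed
```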

ShiftRows

ShiftRows rotates each row by a different amount (row \(r\) is rotated left by \(r\) positions); this moves an entry from every column into every other column and provides “mixing” across columns:

AES ShiftRows: Shifts the rows of the state matrix.

Note: ShiftRows can be implemented at no cost, since it is merely a relabeling of the elements in the state.
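Because ShiftRows is only a relabeling, on the flat column-major array it is a fixed index permutation. A minimal sketch (my own helper, not the one in victim/aes.c):

```python
def shift_rows(s):
    """ShiftRows on a flat, column-major 16-byte state.

    Byte (r, c) lives at flat index 4*c + r, and row r is rotated left by r,
    so out[r][c] = in[r][(c + r) % 4].
    """
    return [s[4 * ((c + r) % 4) + r] for c in range(4) for r in range(4)]


print(shift_rows(list(range(16))))
# [0, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12, 1, 6, 11]
```

Reading the output column by column shows that each new column takes exactly one byte from every old column.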

MixColumns

MixColumns applies an MDS (Maximum Distance Separable) matrix to every column independently. MixColumns is the only operation which does not operate on a “byte-by-byte” basis and provides “diffusion” within each column:

AES MixColumns: Applies an MDS matrix to every column in the state matrix.

Note: Since the matrix is MDS, changing any single element of the input column affects every element of the output column. The MDS property is integral to the “wide-trail strategy” which protects AES against linear and differential cryptanalysis.
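Applying the matrix costs a handful of field multiplications per column. A Python sketch (helper names are mine; the column db 13 53 45 is a commonly used MixColumns test vector), which also shows the MDS property in action: changing one input byte changes all four output bytes.

```python
def gf_mul(a, b):
    """Multiply two bytes in GF(2^8) modulo X^8 + X^4 + X^3 + X + 1."""
    r = 0
    while a:
        if a & 1:
            r ^= b
        a >>= 1
        b <<= 1
        if b & 0x100:
            b ^= 0x11B
    return r


MIX = [[2, 3, 1, 1],
       [1, 2, 3, 1],
       [1, 1, 2, 3],
       [3, 1, 1, 2]]


def mix_column(col):
    """Multiply one 4-byte column by the MixColumns matrix over GF(2^8)."""
    out = []
    for row in MIX:
        acc = 0
        for coeff, x in zip(row, col):
            acc ^= gf_mul(coeff, x)  # addition in GF(2^8) is xor
        out.append(acc)
    return out


def mix_columns(s):
    """Apply mix_column to every column of a flat, column-major state."""
    out = []
    for c in range(4):
        out += mix_column(s[4 * c:4 * c + 4])
    return out


a = mix_column([0xDB, 0x13, 0x53, 0x45])
b = mix_column([0xDB, 0x13, 0x53, 0x46])  # last input byte changed
print(bytes(a).hex())                    # 8e4da1bc
print(sum(x != y for x, y in zip(a, b)))  # 4
```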

AddRoundKey

AddRoundKey adds a secret round-key matrix to the state matrix (addition of matrices):

AES AddRoundKey: Adds a secret round-key matrix to the state matrix.

Note: We will revisit how the round-keys are generated later.

Rounds

AES-128 then operates by applying the transformations above in the following order (an initial key addition followed by 10 rounds):

  • Initial round:
    • AddRoundKey
  • For 9 rounds:
    • SubBytes
    • ShiftRows
    • MixColumns
    • AddRoundKey
  • Final round:
    • SubBytes
    • ShiftRows
    • AddRoundKey
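Putting the four transformations together in this order gives a complete, if slow, AES-128. The sketch below is my own Python code, not a port of victim/aes.c; the key schedule, which is only discussed later, follows FIPS-197. With the victim's key it should reproduce the ciphertext printed on the serial port earlier.

```python
def gf_mul(a, b):
    """Multiply two bytes in GF(2^8) modulo X^8 + X^4 + X^3 + X + 1."""
    r = 0
    while a:
        if a & 1:
            r ^= b
        a >>= 1
        b <<= 1
        if b & 0x100:
            b ^= 0x11B
    return r


def rotl8(x, n):
    return ((x << n) | (x >> (8 - n))) & 0xFF


def make_sbox():
    """AES S-box: inverse in GF(2^8) (0 -> 0) followed by the affine map."""
    sbox = []
    for x in range(256):
        inv = next((y for y in range(1, 256) if gf_mul(x, y) == 1), 0)
        sbox.append(inv ^ rotl8(inv, 1) ^ rotl8(inv, 2)
                    ^ rotl8(inv, 3) ^ rotl8(inv, 4) ^ 0x63)
    return sbox


SBOX = make_sbox()


def sub_bytes(s):
    return [SBOX[x] for x in s]


def shift_rows(s):
    # byte (r, c) is at flat index 4*c + r; row r is rotated left by r
    return [s[4 * ((c + r) % 4) + r] for c in range(4) for r in range(4)]


def mix_columns(s):
    out = []
    for c in range(4):
        a0, a1, a2, a3 = s[4 * c:4 * c + 4]
        out += [gf_mul(2, a0) ^ gf_mul(3, a1) ^ a2 ^ a3,
                a0 ^ gf_mul(2, a1) ^ gf_mul(3, a2) ^ a3,
                a0 ^ a1 ^ gf_mul(2, a2) ^ gf_mul(3, a3),
                gf_mul(3, a0) ^ a1 ^ a2 ^ gf_mul(2, a3)]
    return out


def add_round_key(s, k):
    return [x ^ y for x, y in zip(s, k)]


def expand_key(key):
    """AES-128 key schedule (FIPS-197): 11 round keys of 16 bytes each."""
    rcon = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36]
    w = [list(key[4 * i:4 * i + 4]) for i in range(4)]
    for i in range(4, 44):
        t = list(w[i - 1])
        if i % 4 == 0:
            t = [SBOX[x] for x in t[1:] + t[:1]]  # SubWord(RotWord(t))
            t[0] ^= rcon[i // 4 - 1]
        w.append([x ^ y for x, y in zip(w[i - 4], t)])
    return [sum(w[4 * r:4 * r + 4], []) for r in range(11)]


def encrypt(key, pt):
    rk = expand_key(key)
    s = add_round_key(list(pt), rk[0])                   # initial round
    for r in range(1, 10):                               # 9 full rounds
        s = add_round_key(mix_columns(shift_rows(sub_bytes(s))), rk[r])
    s = add_round_key(shift_rows(sub_bytes(s)), rk[10])  # final round
    return bytes(s)


print(encrypt(bytes(range(16)), bytes(16)).hex())
# c6a13b37878f5b826f4f8162a1c8d879
```

The key 00 01 … 0f is the one from victim/victim.c; `expand_key(...)[10]` is the last round key, which is what the attack in this post recovers first.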

Finally the state matrix is again interpreted as a flat array (column-by-column) of bytes, which forms the ciphertext. Since MixColumns is linear over \(\mathbb{F}_{2^8}\) and AddRoundKey is simply an addition, the omission of MixColumns in the final round has no effect on security: let \( a_i \) be a column of the state, then MixColumns followed by AddRoundKey can be expressed as \(b_i = C(a_i) + k_i^{r} \). Given such a ciphertext, the adversary can simply compute \( C^{-1}(b_i) \) \( = C^{-1}(C(a_i)) + C^{-1}(k_i^{r}) \) \( = a_i + C^{-1}(k_i^{r})\), which gets rid of the diffusion induced by the final MixColumns and leaves only an addition of the transformed key \( C^{-1}(k_i^{r}) \).
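The argument is easy to check numerically for a single column. The sketch below uses the standard InvMixColumns matrix from FIPS-197 as \( C^{-1} \) and a random column and key with a fixed seed; the helper names are mine.

```python
import random


def gf_mul(a, b):
    """Multiply two bytes in GF(2^8) modulo X^8 + X^4 + X^3 + X + 1."""
    r = 0
    while a:
        if a & 1:
            r ^= b
        a >>= 1
        b <<= 1
        if b & 0x100:
            b ^= 0x11B
    return r


C = [[2, 3, 1, 1], [1, 2, 3, 1], [1, 1, 2, 3], [3, 1, 1, 2]]   # MixColumns
C_INV = [[14, 11, 13, 9], [9, 14, 11, 13],                      # InvMixColumns
         [13, 9, 14, 11], [11, 13, 9, 14]]


def mat_col(m, col):
    """Matrix-vector product over GF(2^8) for one 4-byte column."""
    out = []
    for row in m:
        acc = 0
        for coeff, x in zip(row, col):
            acc ^= gf_mul(coeff, x)
        out.append(acc)
    return out


def xor(u, v):
    return [x ^ y for x, y in zip(u, v)]


rng = random.Random(1)
a = [rng.randrange(256) for _ in range(4)]  # state column before MixColumns
k = [rng.randrange(256) for _ in range(4)]  # round-key column
b = xor(mat_col(C, a), k)                   # MixColumns, then AddRoundKey

# C^-1(b) = a + C^-1(k): only an addition of a transformed key remains
print(mat_col(C_INV, b) == xor(a, mat_col(C_INV, k)))  # True
```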


Now you are ready to quiz yourself on AES:

Exercise.
If a single element in the state is changed at the start of a round, how many elements are changed by the end of the round?
Solution

4 elements are changed at the end of the round:

  • SubBytes preserves the difference,
  • ShiftRows preserves the difference,
  • MixColumns causes all elements in the column to change (since it is multiplication by an MDS matrix),
  • AddRoundKey preserves the 4 resulting differences.
Exercise.
How many rounds does “full diffusion” take in AES? i.e. if a single byte is changed in the state, how many rounds are required before this causes every byte in the state to change?
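Solution: Full diffusion occurs within 2 rounds: after one round the difference fills a whole column (as per the previous exercise), the next ShiftRows spreads these 4 bytes over all 4 columns, and the next MixColumns then fills every column.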
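Both answers can be checked empirically. The sketch below (my own code, random state with a fixed seed) runs two rounds without AddRoundKey on a pair of states that differ in a single byte and counts the differing bytes after each round.

```python
import random


def gf_mul(a, b):
    """Multiply two bytes in GF(2^8) modulo X^8 + X^4 + X^3 + X + 1."""
    r = 0
    while a:
        if a & 1:
            r ^= b
        a >>= 1
        b <<= 1
        if b & 0x100:
            b ^= 0x11B
    return r


def rotl8(x, n):
    return ((x << n) | (x >> (8 - n))) & 0xFF


def make_sbox():
    sbox = []
    for x in range(256):
        inv = next((y for y in range(1, 256) if gf_mul(x, y) == 1), 0)
        sbox.append(inv ^ rotl8(inv, 1) ^ rotl8(inv, 2)
                    ^ rotl8(inv, 3) ^ rotl8(inv, 4) ^ 0x63)
    return sbox


SBOX = make_sbox()


def shift_rows(s):
    return [s[4 * ((c + r) % 4) + r] for c in range(4) for r in range(4)]


def mix_columns(s):
    out = []
    for c in range(4):
        a0, a1, a2, a3 = s[4 * c:4 * c + 4]
        out += [gf_mul(2, a0) ^ gf_mul(3, a1) ^ a2 ^ a3,
                a0 ^ gf_mul(2, a1) ^ gf_mul(3, a2) ^ a3,
                a0 ^ a1 ^ gf_mul(2, a2) ^ gf_mul(3, a3),
                gf_mul(3, a0) ^ a1 ^ a2 ^ gf_mul(2, a3)]
    return out


def keyless_round(s):
    """SubBytes, ShiftRows, MixColumns. AddRoundKey is left out: it xors the
    same key into both states, so it cannot change which bytes differ."""
    return mix_columns(shift_rows([SBOX[x] for x in s]))


rng = random.Random(0)
s = [rng.randrange(256) for _ in range(16)]
t = list(s)
t[0] ^= 0x01  # change a single byte of the state

counts = []
for _ in range(2):
    s, t = keyless_round(s), keyless_round(t)
    counts.append(sum(x != y for x, y in zip(s, t)))
print(counts)  # [4, 16]
```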
Exercise.
Explain why AES encryption starts with AddRoundKey.
Solution: Otherwise the first SubBytes, ShiftRows and MixColumns involve no key material and can be trivially inverted, so they would add nothing to security.
Exercise.
AES becomes trivially breakable if we remove any of the transformations, regardless of how they are composed or how many rounds we apply. Explain why.
Solution
  • Omitting SubBytes: AES becomes completely affine: every ciphertext \( c_i \) can be computed from its plaintext \( x_i \) as \( c_i = M x_i + k \) for some fixed \( M \) and key-dependent \( k \), which enables the adversary to recover \( k \) by Gaussian elimination.
  • Omitting ShiftRows: AES decomposes into 4 parallel 32-bit blockciphers, which enables you to construct the full codebook by encrypting only \(2^{32}\) plaintexts.
  • Omitting MixColumns: AES decomposes into 16 parallel instances of an 8-bit blockcipher, which enables you to construct the full codebook by encrypting only \(2^8\) plaintexts.
  • Omitting AddRoundKey: Clearly broken, since the cipher no longer depends on the key at all.

Recovering the last round key

With our newfound knowledge of AES, we start by getting a better grasp of the data we just collected, which also serves as a sanity check of our collection process. If you skipped the hardware section earlier, use the samples from attacker/samples.txt for the remainder of this post.

Understanding the data

Note: The Hamming distance is simply the number of differing symbols between two strings, e.g. \(D_{H}(01011, 10111) = 3\). Below, each symbol is an element of \( \mathbb{F}_{2^8} \), or, you know: a byte. Hence each string (ciphertext) contains 16 symbols and we count the number of differing bytes.

If you plot a histogram (e.g. using attacker/hist.py) of the Hamming distance between each faulty ciphertext and the correct ciphertext, you should get something like the following plot:

Histogram over faulty AES ciphertexts: Number of faulty ciphertexts with a given Hamming distance from the correct ciphertext.
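The computation behind such a histogram takes a few lines of Python. This is a minimal sketch without the plotting, not the repository's hist.py; it assumes one hex ciphertext per line, as in the samples file written by collect.py, and inlines a few ciphertexts from the glitch log earlier instead of reading a file.

```python
from collections import Counter

# Correct AES-128(k, 0^16) as printed by the unglitched victim.
CORRECT = bytes.fromhex("c6a13b37878f5b826f4f8162a1c8d879")


def byte_distance(a, b):
    """Hamming distance over bytes: number of positions where a and b differ."""
    return sum(x != y for x, y in zip(a, b))


def distance_histogram(lines):
    """Count distinct faulty ciphertexts by their distance from CORRECT."""
    cts = {bytes.fromhex(line.strip()) for line in lines if line.strip()}
    cts.discard(CORRECT)
    return Counter(byte_distance(ct, CORRECT) for ct in cts)


samples = [
    "c6a13b37878f5b826f4f8162a1c8d879",
    "5b39c343e831f4b4787dac01a0a78b69",
    "ff0fe5db6edd48de705bb37837e2fd44",
    "c6a13bbe878f5b826f4f816233c8d879",
    "fb568222c589024cec16ee90182f6d53",
    "c6a1783787405b82ee4f8162a1c8d831",
]
print(sorted(distance_histogram(samples).items()))  # [(2, 1), (4, 1), (16, 3)]
```

To run it on real data, pass an open file instead, e.g. `distance_histogram(open("attacker/samples.txt"))`.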

The shape of this distribution reveals something about the structure of AES that we learned earlier (in an exercise) and tells us where the fault for a particular faulty ciphertext was induced. This will allow us to filter out the faults that we will use for our attack. We will focus on the spikes at 1, 2, 4, 8, 15 and 16:

Distance | Justification
1 | A difference of 1 byte is caused when the fault is introduced in the final round, during either SubBytes, ShiftRows or AddRoundKey: the last round omits MixColumns and therefore operates entirely on a byte-by-byte basis. Another plausible cause of single-byte errors is a fault in the hexadecimal conversion code in the victim program.
2 | A difference of 2 bytes is caused by 2 faults in the last round.
4 | A difference of 4 bytes indicates that a single fault was induced before/during MixColumns in round 9, but after MixColumns in round 8. As noted in an earlier exercise, this causes 4 elements of the state to change in round 10, which omits the MixColumns step and hence operates entirely element-wise.
8 | A difference of 8 bytes indicates 2 faults, in different columns, before/during MixColumns in round 9, but after MixColumns in round 8.
15 & 16 | A completely random-looking ciphertext indicates that the fault was induced in an earlier round, which causes the fault to diffuse (within 2 rounds, as noted in an earlier exercise). The number of entries with a Hamming distance of 15 bytes is very close to the statistical expectation \( 144 \approx (2340 + 144) \cdot \frac{1}{256} \cdot \Big(\frac{255}{256}\Big)^{15} \cdot { 16 \choose 1 } \approx 146.4 \), i.e. what we would expect if we sampled these ciphertexts uniformly at random.
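The expectation for distance 15 is the binomial probability that exactly one of the 16 bytes of a uniformly random ciphertext happens to equal the correct byte, scaled by the number of heavily corrupted samples. A quick check using the counts 2340 and 144 from the histogram:

```python
from math import comb

faulty = 2340 + 144  # samples at Hamming distance 16 and 15
p = 1 / 256          # chance that a random byte equals the correct byte

# exactly one of the 16 bytes coincides, the remaining 15 differ
p15 = comb(16, 1) * p * (1 - p) ** 15
expected = faulty * p15
print(round(expected, 2))  # 146.4
```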

Differential faults on AES

To get started, let's consider the propagation of a single fault in the 9th round of AES, just before the final MixColumns (recall that the 10th round omits MixColumns). The remaining part of the cipher causes the fault to propagate as follows:

  • MixColumns (round 9) causes the single difference to propagate to the entire column.
  • AddRoundKey (round 9) is a bijection and hence all 4 differences are preserved.
  • SubBytes (round 10) is a bijection and hence all 4 differences are preserved.
  • ShiftRows (round 10) permutes the elements within each row, moving the 4 differences into 4 different columns.
  • AddRoundKey (round 10) is a bijection and hence all 4 differences are preserved.

The propagation of the fault is illustrated below (ignoring the element-wise bijections):

Fault Propagation: The consequences of inducing a fault in the 9th round of AES

Hence a fault introduced in any element of the first column causes a difference in every element of the column after MixColumns; ShiftRows then moves the elements of the first column from indexes [1, 2, 3, 4] to [1, 14, 11, 8] respectively. Therefore we get a non-zero difference in the 1st, 8th, 11th and 14th byte of the faulty ciphertext, as seen in the illustration above.

Exercise.
Since the state consists of 4 columns, if a single fault is induced prior to the final MixColumns application there are 4 different and disjoint (ordered) lists of indexes in which the faulty ciphertext can differ from the correct ciphertext. The first such list ([1, 14, 11, 8]) is given above; compute the remaining 3 lists.
Solution
  • 1st column: [1, 2, 3, 4] --ShiftRows--> [1, 14, 11, 8]
  • 2nd column: [5, 6, 7, 8] --ShiftRows--> [5, 2, 15, 12]
  • 3rd column: [9, 10, 11, 12] --ShiftRows--> [9, 6, 3, 16]
  • 4th column: [13, 14, 15, 16] --ShiftRows--> [13, 10, 7, 4]
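The four lists can also be derived mechanically from the ShiftRows rule (row \(r\) rotated left by \(r\)). A short sketch with my own helper names, using the same 1-based indexes as above and ordering each list by the row of the faulted column:

```python
def shift_rows_target(r, c):
    """0-based flat index that state byte (r, c) ends up at after ShiftRows.

    Row r is rotated left by r, so column c moves to column (c - r) mod 4,
    and byte (r, c) of a column-major state lives at flat index 4*c + r.
    """
    return 4 * ((c - r) % 4) + r


def fault_indexes(col):
    """1-based ciphertext indexes affected by a fault in column `col`
    (0-based) before the final MixColumns, ordered by row."""
    return [shift_rows_target(r, col) + 1 for r in range(4)]


for col in range(4):
    print(col + 1, fault_indexes(col))
# 1 [1, 14, 11, 8]
# 2 [5, 2, 15, 12]
# 3 [9, 6, 3, 16]
# 4 [13, 10, 7, 4]
```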
Exercise.

Is this observation consistent with the samples you collected?

Compare the indexes of the faulty ciphertexts with a Hamming distance of 4 from the correct ciphertext.
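One way to do the comparison is to collect the set of differing (1-based) indexes of each sample and look it up among the four column lists. A minimal sketch with hypothetical helper names, tried on two ciphertexts from the glitch log earlier:

```python
CORRECT = bytes.fromhex("c6a13b37878f5b826f4f8162a1c8d879")

# 1-based ciphertext indexes affected by a fault in each column (see above),
# stored as sets since only membership matters here
COLUMNS = [{1, 14, 11, 8}, {5, 2, 15, 12}, {9, 6, 3, 16}, {13, 10, 7, 4}]


def faulty_column(ct_hex):
    """Return the 0-based column a single 4-byte fault came from, or None."""
    ct = bytes.fromhex(ct_hex)
    diff = {i + 1 for i, (x, y) in enumerate(zip(ct, CORRECT)) if x != y}
    for col, idx in enumerate(COLUMNS):
        if diff == idx:
            return col
    return None


print(faulty_column("c6a1783787405b82ee4f8162a1c8d831"))  # 2 (third column)
print(faulty_column("c6a13bbe878f5b826f4f816233c8d879"))  # None (distance 2)
```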

Okay, so given such a scenario, how can we begin to recover key material? Key recovery in symmetric cryptography often involves guessing some part of the key material, then using this guess to partially encrypt/decrypt and test a relation part-way inside the cipher, which will be satisfied if our guess was correct. Particularly for those familiar with the key-recovery procedure used in differential cryptanalysis, the method applied here bears a striking resemblance: like differential cryptanalysis, we guess part of the last round key, attempt partial decryption and use whether the differential relation is satisfied to accept/reject our candidate.

Assume for now that we know in which of the 4 columns a fault is introduced and the difference \( \delta \neq \mathbf{0} \) introduced by the fault, a vector \( \delta \in (\mathbb{F}_{2^8})^4 \), i.e. the faulty value is \( y’ = y + \delta \) where \( y \) (unknown) was the original correct value. If we know \( \delta \) we can recover candidates for 4 bytes of the last round key as follows:

  1. Compute \( \Delta = C \ \delta \) corresponding to the difference introduced by the fault after the MixColumns transformation, i.e. since MixColumns is linear: \( \ C \ y’ = C (\delta + y) = C \ \delta + C \ y \)

  2. Given the list of faulty indexes for the column, \( [ i_{1}, \ldots, i_{4} ] = I \), for every \(j \in [1, 4]\) obtain candidates \( K_j \) for the \(i_j\)’th byte of the last round key as follows:

    1. Let \( x_{i_j} \) be the \( i_j \)'th element of the correct ciphertext and let \( x_{i_j}' \) be the \( i_j \)'th element of the faulty ciphertext. Guess the corresponding byte of the last round key \(k^{10}_{i_j}\), then partially decrypt the last round by computing \( z_{j} = S^{-1}(x_{i_j} - k^{10}_{i_j}) \) and \( z_{j}' = S^{-1}(x_{i_j}' - k^{10}_{i_j}) \), and compute the difference \( \rho_j = z'_j - z_j \) (in \( \mathbb{F}_{2^8} \), addition and subtraction are both xor). Observe that the key addition in round 9 does not affect the difference, since \( \rho_j = (z'_{j} - e) - (z_{j} - e) \) for any \( e \).

    2. If our guess for \( k^{10}_{i_j} \) was correct, then \( \rho_{j} = \Delta_{j} \). Hence we keep every guess consistent with this: if \( \rho_{j} = \Delta_j \), we add the guess for \( k^{10}_{i_j}\) to the set \( K_{j} \) of candidates for the \(i_j\)'th byte of the last round key.

  3. Compute the product \(K_1 \times K_2 \times K_3 \times K_4 \) to obtain a set of candidates for 32 bits of the last round key.

Example.

Assume that the fault introduces a difference \( \delta = \texttt{0x85} \) in the first entry (first row) of the first column. After applying MixColumns, the difference in the column becomes:

\[ \Delta = C \begin{bmatrix} \delta \\ 0 \\ 0 \\ 0 \end{bmatrix} = \begin{bmatrix} 2 & 3 & 1 & 1 \\ 1 & 2 & 3 & 1 \\ 1 & 1 & 2 & 3 \\ 3 & 1 & 1 & 2 \end{bmatrix} \begin{bmatrix} \delta \\ 0 \\ 0 \\ 0 \end{bmatrix} = \begin{bmatrix} 2 \delta \\ \delta \\ \delta \\ 3 \delta \end{bmatrix} = \begin{bmatrix} \texttt{0x11} \\ \texttt{0x85} \\ \texttt{0x85} \\ \texttt{0x94} \end{bmatrix} \]

As noted earlier, after applying ShiftRows the column is mapped to the indexes \( I = [ 1, 14, 11, 8 ] \). For each index \(i_j \in I\) we execute the procedure, i.e.:

  • For \( i_1 = 1 \), try every possible value for \( k^{10}_{1} \) and compute \( \rho_1 = S^{-1}(x_1 - k^{10}_1) - S^{-1}(x'_1 - k^{10}_1) \); if \( \rho_1 = \Delta_{1} = 2 \delta = \texttt{0x11} \), then \( k^{10}_1 \) is a candidate for the 1st byte of the last round key.
  • For \( i_2 = 14 \), try every possible value for \( k^{10}_{14} \) and compute \( \rho_2 = S^{-1}(x_{14} - k^{10}_{14}) - S^{-1}(x'_{14} - k^{10}_{14}) \); if \( \rho_2 = \Delta_{2} = \delta = \texttt{0x85} \), then \( k^{10}_{14} \) is a candidate for the 14th byte of the last round key.
  • For \( i_3 = 11 \), try every possible value for \( k^{10}_{11} \) and compute \( \rho_3 = S^{-1}(x_{11} - k^{10}_{11}) - S^{-1}(x'_{11} - k^{10}_{11}) \); if \( \rho_3 = \Delta_{3} = \delta = \texttt{0x85} \), then \( k^{10}_{11} \) is a candidate for the 11th byte of the last round key.
  • For \( i_4 = 8 \), try every possible value for \( k^{10}_{8} \) and compute \( \rho_4 = S^{-1}(x_{8} - k^{10}_{8}) - S^{-1}(x'_{8} - k^{10}_{8}) \); if \( \rho_4 = \Delta_{4} = 3 \delta = \texttt{0x94} \), then \( k^{10}_{8} \) is a candidate for the 8th byte of the last round key.
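For readers checking the numbers in the example by hand: writing bytes as polynomials over \( \mathbb{F}_2 \) and reducing modulo \( X^8+X^4+X^3+X+1 \) (so \( X^8 \equiv X^4+X^3+X+1 \)), the entries of \( \Delta \) work out as follows.

\[ \begin{aligned} \delta = \texttt{0x85} &= X^7 + X^2 + 1 \\ 2\delta = X \cdot \delta &= X^8 + X^3 + X \equiv (X^4 + X^3 + X + 1) + X^3 + X = X^4 + 1 = \texttt{0x11} \\ 3\delta = 2\delta + \delta &= (X^4 + 1) + (X^7 + X^2 + 1) = X^7 + X^4 + X^2 = \texttt{0x94} \end{aligned} \]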

After recovering the candidates for every \( i_j \) under the same \( \delta = \texttt{0x85} \), compute the direct product of the sets to obtain every 32-bit partial key candidate (for the indexes of the full round key in \( I \)). In particular, if no \( k^{10}_{i_j} \) satisfies \( \rho_{j} = \Delta_{j} \) for some \(j\), then \( \delta = \texttt{0x85} \) is an “impossible differential”: no key is consistent with the assumption \( \delta = \texttt{0x85} \).

In other words, given a vector of differences after MixColumns “D” (\(\Delta\) in the example above) and a list of indexes after ShiftRows “I” (\(I\) in the example above), pseudocode for key recovery looks as follows:

func possible_keys(D, I, ct, ct') {
    // recover candidates for each byte
    ks = [{}, {}, {}, {}]
    for i in [1, ..., 4]:
        for k in [0, ..., 0xFF]:
            x  = ct[I[i]]
            x' = ct'[I[i]]
            d  = Sinv[x xor k] xor Sinv[x' xor k]

            if d == D[i]:
                ks[i].add(k)

    // compute 32-bit partial keys
    return product(ks)
}
Exercise.

Implement the program above in Python.

You will need the concrete AES sbox (and to compute its inverse yourself):

sbox = [
    0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 
    0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
    0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 
    0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
    0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 
    0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
    0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 
    0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
    0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 
    0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
    0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 
    0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
    0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 
    0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
    0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 
    0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
    0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 
    0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
    0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 
    0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
    0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 
    0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
    0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 
    0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
    0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 
    0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
    0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 
    0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
    0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 
    0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
    0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 
    0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16
]
Solution.
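import itertools

# inverse S-box, assuming `sbox` is the table given above
sboxI = [sbox.index(v) for v in range(0x100)]

def solve(xA, xB, D):
    # every key byte k for which partial decryption yields the difference D
    ks = []
    for k in range(0x100):
        if D == sboxI[xA ^ k] ^ sboxI[xB ^ k]:
            ks.append(k)
    assert len(ks) in (0, 2, 4)
    return ks

def possible_keys(D, I, ct_correct, ct_fault):
    ks = [[], [], [], []]
    for j, i in enumerate(I):
        ks[j] = solve(ct_correct[i], ct_fault[i], D[j])
        if len(ks[j]) == 0:
            break  # D is impossible for this sample: the product below is empty

    # 32-bit partial key candidates (an iterator over 4-tuples)
    return itertools.product(*ks)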
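The `assert len(ks) in (0, 2, 4)` in solve() relies on a property of the AES S-box: for a fixed pair of ciphertext bytes, the number of key bytes producing a given difference is always 0, 2 or 4. The sketch below checks this exhaustively. To be self-contained it regenerates the S-box from its definition (field inverse followed by the affine map) instead of pasting the table; all function names are hypothetical.

```python
def gf_mult(a, b):
    """Multiply two bytes in GF(2^8) modulo X^8 + X^4 + X^3 + X + 1."""
    r = 0
    while b:
        if b & 1:
            r ^= a
        a <<= 1
        if a & 0x100:
            a ^= 0x11B
        b >>= 1
    return r

def gf_inv(x):
    """Multiplicative inverse as x^254 (maps 0 to 0, as AES requires)."""
    r = 1
    for _ in range(254):
        r = gf_mult(r, x)
    return r

def rotl8(x, n):
    return ((x << n) | (x >> (8 - n))) & 0xFF

def sbox_entry(x):
    # AES S-box: field inverse followed by the affine transformation.
    b = gf_inv(x)
    return b ^ rotl8(b, 1) ^ rotl8(b, 2) ^ rotl8(b, 3) ^ rotl8(b, 4) ^ 0x63

SBOX = [sbox_entry(x) for x in range(256)]
SBOX_INV = [0] * 256
for x, y in enumerate(SBOX):
    SBOX_INV[y] = x

def solution_counts(a):
    """counts[D] = number of key bytes k with SBOX_INV[xA ^ k] ^ SBOX_INV[xB ^ k] == D
    for any pair with xA ^ xB == a (substitute u = xA ^ k)."""
    counts = [0] * 256
    for u in range(256):
        counts[SBOX_INV[u] ^ SBOX_INV[u ^ a]] += 1
    return counts
```

Solutions come in pairs because if \(k\) is a solution, then so is \(k \oplus x_A \oplus x_B\) (it swaps the roles of the two ciphertext bytes), and the maximum of 4 is the differential uniformity of the AES S-box.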

We are also going to need some code to actually calculate “D” for every possible single element fault in any position of a column. From the example earlier it should be clear enough how to do this.

Exercise.
Write code to compute “D” based on the row and difference of the fault. Remember that multiplication (mult) is not in the integers but in \( \mathbb{F}_2[X] / (X^8+X^4+X^3+X+1) \). You can use the example earlier to test your code.
Solution.

Excerpt from attacker/extract.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
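def mult(a, b):
    # multiplication in F_2[X] / (X^8 + X^4 + X^3 + X + 1)
    assert 0x100 > a >= 0
    assert 0x100 > b >= 0

    r = 0
    p = 0b100011011  # the AES reduction polynomial

    while b:
        if b & 1:    # add the current a * X^i if bit i of b is set
            r ^= a

        a <<= 1      # a = a * X
        b >>= 1

        if a >> 8:   # reduce a modulo p
            a ^= p

    return r

def mix_column(v):
    # multiply the column v by the MixColumns matrix C
    assert len(v) == 4

    r1 = mult(2, v[0]) ^ mult(3, v[1]) ^ v[2] ^ v[3]
    r2 = v[0] ^ mult(2, v[1]) ^ mult(3, v[2]) ^ v[3]
    r3 = v[0] ^ v[1] ^ mult(2, v[2]) ^ mult(3, v[3])
    r4 = mult(3, v[0]) ^ v[1] ^ v[2] ^ mult(2, v[3])

    return (r1, r2, r3, r4)

def mix_fault(d, row):
    # difference after MixColumns for a single-byte difference d in the given row
    col = [0, 0, 0, 0]
    col[row] = d
    return mix_column(col)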
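Because the faulty column has only one non-zero entry, MixColumns reduces to scaling one column of \(C\) by \(\delta\), so all 1020 candidates for “D” can also be tabulated with a single doubling per value. The sketch below is an alternative to mix_fault (the names are hypothetical); it also confirms two facts the attack relies on: every “D” has four non-zero entries, and no two (row, \(\delta\)) guesses produce the same “D”.

```python
def xtime(a):
    """Multiply a byte by X (0x02) modulo X^8 + X^4 + X^3 + X + 1."""
    a <<= 1
    return a ^ 0x11B if a & 0x100 else a

def fault_delta(d, row):
    """MixColumns of a column that is zero except for d at position `row`,
    i.e. d times column `row` of the matrix C."""
    two, three = xtime(d), xtime(d) ^ d
    columns_of_C = [
        (two, d, d, three),   # column 0 of C is (2, 1, 1, 3)
        (three, two, d, d),   # column 1 of C is (3, 2, 1, 1)
        (d, three, two, d),   # column 2 of C is (1, 3, 2, 1)
        (d, d, three, two),   # column 3 of C is (1, 1, 3, 2)
    ]
    return columns_of_C[row]

# All 4 x 255 candidates for "D", keyed by the guessed (row, delta).
ALL_D = {(row, d): fault_delta(d, row) for row in range(4) for d in range(1, 0x100)}

print(fault_delta(0x85, 0) == (0x11, 0x85, 0x85, 0x94))  # the example above -> True
```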

So how do we figure out which fault is induced in which row and column?

The simple answer is that we do not: for each of the 4 sets of indexes “I”, we iterate over every faulty ciphertext which differs from the correct ciphertext at every index in “I”. For each such ciphertext we try every possible “D” and keep a count of how many times we see each key candidate. When our guess for “D” is correct, the counter for the correct key is incremented; when it is not, we assume that the resulting erroneous key candidates are sampled (approximately) uniformly at random.

Hence across multiple samples the correct key will have a significantly higher frequency than any other candidate: it will emerge consistently whenever we guess correctly, while the erroneous candidates will be distributed uniformly. In the cases where faults are injected very early in the cipher, every key candidate will be part of “the noise floor” of keys from erroneous guesses. This process is very similar to simple key-recovery in differential/linear cryptanalysis, where you can keep a counter for every candidate, which is then increased for every ciphertext & plaintext pair for which the linear/differential characteristic is satisfied.
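To see the counting argument at work without any hardware, here is a toy simulation. It is a sketch under simplifying assumptions, with hypothetical names that are not from the repository: it models only the single column under attack, from the input of the ninth-round MixColumns to the four ciphertext bytes at the indexes “I”, with random round-key bytes and random single-byte faults injected right before that MixColumns. Instead of looping over every key byte for each guess of “D” as solve() does, it buckets the 256 guesses per byte by the difference they imply, which yields the same candidates faster.

```python
import itertools
import random
from collections import Counter

def xtime(a):
    """Multiply a byte by X (0x02) modulo X^8 + X^4 + X^3 + X + 1."""
    a <<= 1
    return a ^ 0x11B if a & 0x100 else a

def gf_mult(a, b):
    r = 0
    while b:
        if b & 1:
            r ^= a
        a, b = xtime(a), b >> 1
    return r

def gf_inv(x):
    r = 1
    for _ in range(254):  # x^254 is the inverse of x (and maps 0 to 0)
        r = gf_mult(r, x)
    return r

def rotl8(x, n):
    return ((x << n) | (x >> (8 - n))) & 0xFF

# AES S-box: field inverse followed by the affine transformation.
SBOX = []
for x in range(256):
    b = gf_inv(x)
    SBOX.append(b ^ rotl8(b, 1) ^ rotl8(b, 2) ^ rotl8(b, 3) ^ rotl8(b, 4) ^ 0x63)
SBOX_INV = [SBOX.index(y) for y in range(256)]

def mix_column(v):
    a, b, c, d = v
    return (xtime(a) ^ xtime(b) ^ b ^ c ^ d,
            a ^ xtime(b) ^ xtime(c) ^ c ^ d,
            a ^ b ^ xtime(c) ^ xtime(d) ^ d,
            xtime(a) ^ a ^ b ^ c ^ xtime(d))

def last_rounds(y, k9, k10):
    """Toy model of one column: y enters the round-9 MixColumns; returns the
    four ciphertext bytes at the (ShiftRows-permuted) positions I."""
    z = [m ^ k for m, k in zip(mix_column(y), k9)]  # MixColumns, AddRoundKey
    return [SBOX[v] ^ k for v, k in zip(z, k10)]     # SubBytes, AddRoundKey

def vote(x, x_fault, counter):
    """One vote for every 32-bit candidate consistent with some single-byte fault."""
    tables = []
    for a, b in zip(x, x_fault):
        # bucket the 256 guesses for this key byte by the difference they imply
        table = {}
        for k in range(256):
            table.setdefault(SBOX_INV[a ^ k] ^ SBOX_INV[b ^ k], []).append(k)
        tables.append(table)
    for row in range(4):          # guess the row of the fault
        for d in range(1, 256):   # guess the fault difference
            col = [0, 0, 0, 0]
            col[row] = d
            D = mix_column(col)
            ks = [t.get(D[j], []) for j, t in enumerate(tables)]
            for cand in itertools.product(*ks):
                counter[cand] += 1

rng = random.Random(2024)
k9 = [rng.randrange(256) for _ in range(4)]
k10 = tuple(rng.randrange(256) for _ in range(4))  # the "unknown" key bytes

votes = Counter()
for _ in range(6):
    y = [rng.randrange(256) for _ in range(4)]
    y_fault = list(y)
    y_fault[rng.randrange(4)] ^= rng.randrange(1, 256)  # random single-byte fault
    vote(last_rounds(y, k9, k10), last_rounds(y_fault, k9, k10), votes)

print(votes.most_common(3))
```

In this model the correct candidate receives exactly one vote per faulty ciphertext, since a given key guess implies exactly one “D”, while a wrong candidate would have to be consistent with every sample by chance.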

Exercise.

Implement the full round key recovery program. The program should:

  1. Load the samples (from samples.txt) and the correct ciphertext.
  2. Compute a list of every possible \(\Delta\) (\( 255 \times 4 \) elements).
  3. Compute a list of every possible \(I\) (4 elements).
  4. Initialize \(\text{KEY} = 0^{16} \).
  5. For every \(I\):
    1. Initialize the map \(\text{cnt} = \emptyset\).
    2. For every sample:
      1. Compute the set \( \mathcal{I} \) of indexes in which the correct ciphertext and the sample differ. If \( I \not\subseteq \mathcal{I} \), ignore the sample.
      2. Otherwise, use the key recovery procedure earlier to recover the key candidates \(K\). For every \( k \in K\), increment \( \text{cnt}[k] \) by one.
    3. Pick the value \(k\) for which \(\text{cnt}[k]\) is maximal and update the full guess for the last round key: for every \( j \in [1,4] \) set \( \text{KEY}_{I_j} = k_j \).
  6. Output \(\text{KEY}\).
Solution.

See attacker/extract.py

#!/usr/bin/env python3

import sys
import binascii
import itertools
import collections

sbox = [
    0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 
    0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
    0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 
    0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
    0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 
    0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
    0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 
    0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
    0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 
    0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
    0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 
    0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
    0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 
    0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
    0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 
    0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
    0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 
    0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
    0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 
    0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
    0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 
    0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
    0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 
    0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
    0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 
    0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
    0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 
    0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
    0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 
    0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
    0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 
    0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16
]

# shift row permutation on indexes
shift_rows = [
    0, 13, 10, 7,
    4, 1, 14, 11,
    8, 5, 2, 15,
    12, 9, 6, 3,
]

# indexes for every column
cols = [
    (0, 1, 2, 3),
    (4, 5, 6, 7),
    (8, 9, 10, 11),
    (12, 13, 14, 15)
]

def inv_perm(perm):
    inv = [0] * len(perm)
    for i in range(len(perm)):
        inv[perm[i]] = i
    return inv

sboxI = inv_perm(sbox)

def mult(a, b):
    assert 0x100 > a >= 0
    assert 0x100 > b >= 0

    r = 0
    p = 0b100011011

    while b:
        if b & 1:
            r ^= a

        a <<= 1
        b >>= 1

        if a >> 8:
            a ^= p

    return r

def mix_column(v):
    assert len(v) == 4

    r1 = mult(2, v[0]) ^ mult(3, v[1]) ^ v[2] ^ v[3]
    r2 = v[0] ^ mult(2, v[1]) ^ mult(3, v[2]) ^ v[3]
    r3 = v[0] ^ v[1] ^ mult(2, v[2]) ^ mult(3, v[3])
    r4 = mult(3, v[0]) ^ v[1] ^ v[2] ^ mult(2, v[3])

    return (r1, r2, r3, r4)

def mix_fault(d, row):
    col = [0, 0, 0, 0]
    col[row] = d
    return mix_column(col)

def recover(ct_correct, ct_fault, I, K):
    assert len(I) == 4
    assert len(ct_correct) == len(ct_fault) == 16

    def solve(xA, xB, D):
        ks = []
        for k in range(0x100):
            if D == sboxI[xA ^ k] ^ sboxI[xB ^ k]:
                ks.append(k)
        assert len(ks) in (0, 2, 4)
        return ks

    for row in range(4): # guess the row of the fault
        for d in range(1, 0x100): # guess the difference
            D = mix_fault(d, row)

            # calculate candidates for each byte
            ks = [[], [], [], []]
            for j, i in enumerate(I):
                ks[j] = solve(ct_correct[i], ct_fault[i], D[j])
                if len(ks[j]) == 0:
                    break

            # expand the cross product to get full 32-bit keys
            for k in itertools.product(*ks):
                K[k] += 1

def diff(s1, s2):
    idx = set([])
    for i in range(len(s1)):
        if s1[i] != s2[i]:
            idx.add(i)
    return frozenset(idx)

if __name__ == '__main__':
    # sample file
    with open(sys.argv[1], 'r') as f:
        samples = map(str.strip, f.readlines())
        samples = [binascii.unhexlify(s) for s in samples]

    # correct ciphertext
    correct = binascii.unhexlify(sys.argv[2])

    # optional threshold
    try:
        threshold = int(sys.argv[3])
    except IndexError:
        threshold = None

    # group ciphertext by faulty indexes
    Is = [tuple([shift_rows[v] for v in col]) for col in cols]
    groups = { I: [] for I in Is}
    for sample in samples:
        df = diff(sample, correct)
        for I in Is:
            if all([i in df for i in I]):
                groups[I].append(sample)

    # extract from every group of faulty indexes
    KEY = [None] * 16
    for I in Is:
        print('I : %14s, ciphertexts: %d' % (I, len(groups[I])))

        # find the most seen 32-bit key candidates
        K = collections.Counter()
        for i, sample in enumerate(groups[I]):
            recover(correct, sample, I, K)

            # print status
            show = ''.join(['%02x' % v if v is not None else '??' for v in KEY])
            print('Key: %s, Sample %d / %d, Cand: %d, Top: %s' % (show, i, len(groups[I]), len(K), K.most_common(2)))

            # check optional threshold
            if threshold is not None and K:
                _, cnt = K.most_common(1)[0]
                if cnt >= threshold:
                    break

        # fill in the full round key with the recovered 32-bits
        key, _ = K.most_common(1)[0]
        for j, i in enumerate(I):
            assert KEY[i] is None
            KEY[i] = key[j]

    print('KEY:', bytes(KEY).hex())

Four faults for the price of one.

One might think that the above attack only allows us to use the ciphertexts where a fault is introduced in the ninth round: those faulty ciphertexts with a Hamming distance of exactly four from the correct ciphertext.

However, it turns out that the most useful fault is one introduced before MixColumns in the eighth round: the eighth-round MixColumns spreads such a fault over a whole column, and ShiftRows in the ninth round then moves those four faulty bytes into four different columns. The result is exactly one faulty byte in every column entering the ninth-round MixColumns: we get four faults for the effort of one. This is illustrated below:

Fault Propagation: The consequences of inducing a fault in the 8th round of AES. From the paper “Differential Fault Analysis of the Advanced Encryption Standard using a Single Fault” by Michael Tunstall and Debdeep Mukhopadhyay.
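The propagation in the figure can be checked by tracking only which bytes are faulty. This sketch (hypothetical names) uses the facts that SubBytes and AddRoundKey never change which bytes differ, that ShiftRows permutes them, and that MixColumns turns a column with a single faulty byte into an entirely faulty column (its branch number is 5). Columns with several faulty bytes are simply assumed to become entirely faulty, which ignores rare cancellations. The ShiftRows table is the shift_rows permutation from attacker/extract.py.

```python
# ShiftRows as a map from input index to output index (column-major, 0-based),
# the same permutation as shift_rows in attacker/extract.py.
SHIFT_ROWS = [0, 13, 10, 7, 4, 1, 14, 11, 8, 5, 2, 15, 12, 9, 6, 3]

def shift_rows(active):
    return {SHIFT_ROWS[i] for i in active}

def mix_columns(active):
    # Any faulty byte makes its whole column faulty (exact for a single byte,
    # an assumption for several bytes in the same column).
    out = set()
    for c in range(4):
        column = {4 * c + row for row in range(4)}
        if active & column:
            out |= column
    return out

def faulty_bytes(index, fault_round):
    """Ciphertext positions affected by a single-byte fault at state position
    `index`, injected right before MixColumns of round `fault_round` (<= 9)."""
    active = mix_columns({index})
    for r in range(fault_round + 1, 11):
        # SubBytes and AddRoundKey do not change which bytes are faulty.
        active = shift_rows(active)
        if r < 10:  # the final round has no MixColumns
            active = mix_columns(active)
    return active

print(sorted(faulty_bytes(0, 9)))  # a ninth-round fault -> [0, 7, 10, 13]
print(len(faulty_bytes(0, 8)))     # an eighth-round fault -> 16
```

The 0-based positions [0, 7, 10, 13] are the 1-based list [1, 8, 11, 14] of the first column.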

Hence such a faulty ciphertext will have a Hamming distance of exactly 16 from the correct ciphertext. Since it differs from the correct ciphertext at every index, it is included in all four groups “I”, so our program takes advantage of these faults without any modifications. Now we are finally ready to recover the key:

Exercise.
Use the program to recover the last round key.
Solution.

The last round key is:

13111d7fe3944a17f307a78b4d2b30c5

To recover the cipherkey itself we need to gain some insight into how the key schedule of AES-128 functions.

AES-128 key schedule

To recover the cipherkey (“the AES-128 key”) from the last round key, we need to understand the AES key-schedule. For those comfortable with C it might be easier to read the key-schedule code from the didactic AES implementation included in the repository (victim/aes.c).

The AES-128 key-schedule operates by splitting the 16 element/byte \( k^0_{0,0} k^0_{1,0} k^0_{2,0} k^0_{3,0} \) \( k^0_{0,1} k^0_{1,1} k^0_{2,1} k^0_{3,1} \) \( k^0_{0,2} k^0_{1,2} k^0_{2,2} k^0_{3,2} \) \( k^0_{0,3} k^0_{1,3} k^0_{2,3} k^0_{3,3} \) cipherkey into 4 “words” holding 4 field elements each. It then applies a function to the last word, adds the result (in \( \mathbb{F}_{2^8}^4 \)) to the first word, and cascades the change by adding every word to its successor (see the illustration below):

AES-128 key schedule: The Feistel-like key schedule of AES-128. The cipherkey is split into 4 words, then the last word is transformed and added to the first word. The result is then cascaded by adding every word to its successor.

This process is repeated iteratively 10 times to produce the 11 required round keys: the first round key is the cipherkey and every intermediate result is a new round key.

Exercise.
The AES key-schedule (image above) is invertible, regardless of what RotWord, SubWord and Rcon are (they could be arbitrary functions): stare at the image above and convince yourself that the process can be inverted.
Solution.

The process is similar to that of a generalized Feistel network. Recover:

  • \( k^{0}_{0,3} \leftarrow k^{1}_{0,2} \oplus k^{1}_{0,3} \)
  • \( k^{0}_{1,3} \leftarrow k^{1}_{1,2} \oplus k^{1}_{1,3} \)
  • \( k^{0}_{2,3} \leftarrow k^{1}_{2,2} \oplus k^{1}_{2,3} \)
  • \( k^{0}_{3,3} \leftarrow k^{1}_{3,2} \oplus k^{1}_{3,3} \)

  • \( k^{0}_{0,2} \leftarrow k^{1}_{0,1} \oplus k^{1}_{0,2} \)
  • \( k^{0}_{1,2} \leftarrow k^{1}_{1,1} \oplus k^{1}_{1,2} \)
  • \( k^{0}_{2,2} \leftarrow k^{1}_{2,1} \oplus k^{1}_{2,2} \)
  • \( k^{0}_{3,2} \leftarrow k^{1}_{3,1} \oplus k^{1}_{3,2} \)

  • \( k^{0}_{0,1} \leftarrow k^{1}_{0,0} \oplus k^{1}_{0,1} \)
  • \( k^{0}_{1,1} \leftarrow k^{1}_{1,0} \oplus k^{1}_{1,1} \)
  • \( k^{0}_{2,1} \leftarrow k^{1}_{2,0} \oplus k^{1}_{2,1} \)
  • \( k^{0}_{3,1} \leftarrow k^{1}_{3,0} \oplus k^{1}_{3,1} \)

Compute \( t_0, t_1, t_2, t_3 \leftarrow (\texttt{Rcon} \circ \texttt{SubWord} \circ \texttt{RotWord})(k^{0}_{0,3}, k^{0}_{1,3}, k^{0}_{2,3}, k^{0}_{3,3}) \)


  • \( k^{0}_{0,0} \leftarrow t_0 \oplus k^{1}_{0,0} \)
  • \( k^{0}_{1,0} \leftarrow t_1 \oplus k^{1}_{1,0} \)
  • \( k^{0}_{2,0} \leftarrow t_2 \oplus k^{1}_{2,0} \)
  • \( k^{0}_{3,0} \leftarrow t_3 \oplus k^{1}_{3,0} \)

Here is how RotWord, SubWord and Rcon actually work:

  • RotWord: Left-rotate the word. Given \( (w_1, w_2, w_3, w_4) \) return \( (w_2, w_3, w_4, w_1) \).

  • SubWord: Apply the S-box to every byte. Given \( (w_1, w_2, w_3, w_4) \) return \( (S[w_1], S[w_2], S[w_3], S[w_4]) \).

  • Rcon: Add (xor) a round constant into the first byte. Given \( (w_1, w_2, w_3, w_4) \) return \( (w_1 + \text{rcon}[r], w_2, w_3, w_4) \), where \(r \in [1, 10]\) is the round number and \(\text{rcon}[r]\) is the \(r\)'th entry of the array [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36].

Note: The Rcon constants are derived by repeatedly multiplying by the polynomial \( X \in \mathbb{F}_{2^8}\) (encoded as 0x02).
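To connect these definitions to the key-schedule picture, here is a sketch of the forward AES-128 key expansion. It derives the Rcon table by repeated multiplication by \(X\) (0x02), as in the note above, and regenerates the S-box from its definition so the block is self-contained. The names are hypothetical; victim/aes.c is the implementation actually used in this post.

```python
def xtime(a):
    """Multiply a byte by X (0x02) modulo X^8 + X^4 + X^3 + X + 1."""
    a <<= 1
    return a ^ 0x11B if a & 0x100 else a

def gf_mult(a, b):
    r = 0
    while b:
        if b & 1:
            r ^= a
        a, b = xtime(a), b >> 1
    return r

def gf_inv(x):
    r = 1
    for _ in range(254):  # x^254 is the inverse of x (and maps 0 to 0)
        r = gf_mult(r, x)
    return r

def rotl8(x, n):
    return ((x << n) | (x >> (8 - n))) & 0xFF

# AES S-box: field inverse followed by the affine transformation.
SBOX = []
for x in range(256):
    b = gf_inv(x)
    SBOX.append(b ^ rotl8(b, 1) ^ rotl8(b, 2) ^ rotl8(b, 3) ^ rotl8(b, 4) ^ 0x63)

# Rcon: 0x01 multiplied by X (0x02) once per round.
RCON = [0x01]
for _ in range(9):
    RCON.append(xtime(RCON[-1]))

def expand_key(key):
    """Return the 11 round keys of AES-128; round key 0 is the cipherkey."""
    assert len(key) == 16
    round_keys = [bytes(key)]
    for r in range(10):
        w = [round_keys[-1][i:i + 4] for i in range(0, 16, 4)]
        t = w[3][1:] + w[3][:1]               # RotWord
        t = bytes(SBOX[v] for v in t)         # SubWord
        t = bytes([t[0] ^ RCON[r]]) + t[1:]   # Rcon (RCON[0] for the first round)
        words = []
        for word in w:                        # add t to the first word, then cascade
            t = bytes(a ^ b for a, b in zip(t, word))
            words.append(t)
        round_keys.append(b''.join(words))
    return round_keys
```

Expanding a candidate cipherkey and comparing `expand_key(cipherkey)[10]` with the recovered last round key is a convenient way to check the output of the next exercise.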

Exercise.
Implement a script which takes the last round key of AES and recovers the cipherkey (equal to the first round key).
Solution.

See attacker/invert.py.

#!/usr/bin/env python3

import sys
import binascii

rcon = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36]

sbox = [
    0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 
    0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
    0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 
    0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
    0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 
    0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
    0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 
    0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
    0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 
    0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
    0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 
    0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
    0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 
    0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
    0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 
    0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
    0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 
    0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
    0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 
    0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
    0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 
    0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
    0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 
    0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
    0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 
    0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
    0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 
    0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
    0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 
    0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
    0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 
    0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16
]

def words(key):
    assert len(key) == 16
    return (key[0:4], key[4:8], key[8:12], key[12:])

def xor(a, b):
    assert len(a) == len(b)
    return bytes([x ^ y for (x, y) in zip(a, b)])

def rot(w):
    return w[1:] + bytes([w[0]])

def sub(w):
    return bytes([sbox[x] for x in w])

def invert(key, rcon):

    w1, w2, w3, w4 = words(key)

    w4 = xor(w4, w3)
    w3 = xor(w3, w2)
    w2 = xor(w2, w1)

    D = xor(bytes([rcon, 0, 0, 0]), sub(rot(w4)))

    w1 = xor(w1, D)

    return w1 + w2 + w3 + w4

if __name__ == '__main__':
    key = binascii.unhexlify(sys.argv[1])
    for i, r in list(enumerate(rcon))[::-1]:
        key = invert(key, r)
        print('Round %d : %s' % (i, key.hex()))
Exercise.
Use the program to obtain the cipherkey from the last round key.
Solution.

The cipherkey is:

000102030405060708090a0b0c0d0e0f

Independent Round Keys

In the attack above we made use of the fact that we can recover the cipherkey from the last round key. However, even if AES used independent round keys, we could recover all the remaining round keys by stripping rounds and repeatedly applying differential fault attacks:

  1. Use the recovered round key to decrypt the last round of every ciphertext.
  2. Use the linear trick (see note) described in the AES encryption section to remove the last MixColumns.
  3. Mount a differential fault attack on the samples from this new (n - 1) round variant of AES just as before, and recover the new last round key with inverse MixColumns applied to it.
  4. Apply the MixColumns to the recovered key to obtain the actual round key.

Repeat for every round.
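The linear trick in step 2 works because InvMixColumns is linear: if \(s = \text{MixColumns}(y) \oplus k^{9}\), then \(\text{InvMixColumns}(s) = y \oplus \text{InvMixColumns}(k^{9})\), so after stripping the last round, round 9 looks like a final round (no MixColumns) keyed with \(\text{InvMixColumns}(k^9)\). Below is a minimal sketch of steps 1 and 2, assuming the same column-major byte order and shift_rows table as attacker/extract.py; the function names are hypothetical.

```python
def xtime(a):
    """Multiply a byte by X (0x02) modulo X^8 + X^4 + X^3 + X + 1."""
    a <<= 1
    return a ^ 0x11B if a & 0x100 else a

def gf_mult(a, b):
    r = 0
    while b:
        if b & 1:
            r ^= a
        a, b = xtime(a), b >> 1
    return r

def gf_inv(x):
    r = 1
    for _ in range(254):  # x^254 is the inverse of x (and maps 0 to 0)
        r = gf_mult(r, x)
    return r

def rotl8(x, n):
    return ((x << n) | (x >> (8 - n))) & 0xFF

SBOX = []
for x in range(256):
    b = gf_inv(x)
    SBOX.append(b ^ rotl8(b, 1) ^ rotl8(b, 2) ^ rotl8(b, 3) ^ rotl8(b, 4) ^ 0x63)
SBOX_INV = [SBOX.index(y) for y in range(256)]

# ShiftRows as a map from input index to output index (column-major, 0-based).
SHIFT_ROWS = [0, 13, 10, 7, 4, 1, 14, 11, 8, 5, 2, 15, 12, 9, 6, 3]

def mix_columns(state, row0=(2, 3, 1, 1)):
    """Multiply every 4-byte column by the circulant matrix with first row row0."""
    out = []
    for c in range(4):
        col = state[4 * c:4 * c + 4]
        for i in range(4):
            acc = 0
            for j in range(4):
                acc ^= gf_mult(row0[(j - i) % 4], col[j])
            out.append(acc)
    return bytes(out)

def inv_mix_columns(state):
    return mix_columns(state, (0x0E, 0x0B, 0x0D, 0x09))

def final_round(state, key):
    """SubBytes, ShiftRows and AddRoundKey (the last round has no MixColumns)."""
    out = [0] * 16
    for i, v in enumerate(state):
        out[SHIFT_ROWS[i]] = SBOX[v]
    return bytes(o ^ k for o, k in zip(out, key))

def strip_final_round(ct, key):
    """Undo final_round, then apply InvMixColumns (steps 1 and 2 above)."""
    t = [c ^ k for c, k in zip(ct, key)]
    state = bytes(SBOX_INV[t[SHIFT_ROWS[i]]] for i in range(16))
    return inv_mix_columns(state)
```

Running the extraction code from before on the stripped correct and faulty ciphertexts would then yield \(\text{InvMixColumns}(k^9)\), and `mix_columns` turns that back into \(k^9\) (step 4).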

Exercise.
Convince yourself that this works.