CTF writeups/Crypto

[Cyber Apocalypse CTF 2022] The Three-Eyed Oracle

Now1z 2022. 5. 20. 13:48

1. Investigation

썸네일용

 

>> chall.py

from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
import random
import signal
import subprocess
import socketserver

FLAG = b'HTB{--REDACTED--}'
prefix = random.randbytes(12)
key = random.randbytes(16)


def encrypt(key, msg):
    msg = bytes.fromhex(msg)
    crypto = AES.new(key, AES.MODE_ECB)
    padded = pad(prefix + msg + FLAG, 16)
    return crypto.encrypt(padded).hex()


def challenge(req):
    req.sendall(b'Welcome to Klaus\'s crypto lab.\n' +
                b'It seems like there is a prefix appended to the real firmware\n' +
                b'Can you somehow extract the firmware and fix the chip?\n')
    while True:
        req.sendall(b'> ')
        try:
            msg = req.recv(4096).decode()

            ct = encrypt(key, msg)
        except:
            req.sendall(b'An error occurred! Please try again!')

        req.sendall(ct.encode() + b'\n')


class incoming(socketserver.BaseRequestHandler):
    def handle(self):
        signal.alarm(1500)
        req = self.request
        challenge(req)


class ReusableTCPServer(socketserver.ForkingMixIn, socketserver.TCPServer):
    pass


def main():
    socketserver.TCPServer.allow_reuse_address = True
    server = ReusableTCPServer(("0.0.0.0", 1337), incoming)
    server.serve_forever()


if __name__ == "__main__":
    main()

클라이언트로부터 입력받은 값을 msg 에 저장하고 랜덤한 16 바이트 크기의 key 로 AES ECB 모드를 사용하여 암호화하는 소스코드인 것을 확인할 수 있습니다.

 

특이한 점은 msg 를 패딩할 때 (prefix + msg + FLAG) 형태로 만든 뒤 패딩을 수행한다는 점인데, AES ECB 모드에서 해당 형태(즉, FLAG 앞에 공격자가 조작할 수 있는 msg 를 넣을 수 있는 형태)로 평문을 암호화하게 되면 한 바이트씩 Brute-Force 공격을 수행해 msg 뒤에 있는 평문을 알아낼 수 있다는 약점이 있습니다.

 

REF#1 : https://crypto.stackexchange.com/questions/42891/chosen-plaintext-attack-on-aes-in-ecb-mode


2. Solution

REF#1 과 함께 참고해서 보면 좋습니다. (대신 영어로 적혀있음)

prefix(12)  msg(4) msg(15)        FLAG(?)
xxxxxxxxxxxx0000 | 000000000000000? | ???????????????? | ???????????????? | ...
-------------------------------------------------------------------------------
CBLOCK1          | CBLOCK2          | CBLOCK3          | CBLOCK4          | ...
                                  ^
                                  Brute-Force byte-by-byte until we get same CBLOCK2.

위와 같이 msg 에 4 + 15 바이트의 임의의 값을 넣으면 FLAG 가 한 바이트 포함된 형태로 암호화가 블록 단위로 진행되게 됩니다. 이렇게 해서 암호문을 얻고 msg 의 마지막 1 바이트에 Brute-Force 를 수행하게 되면 결국 CBLOCK2 와 일치하는 암호문을 얻었을 때에 값이 평문이 됩니다.

 

prefix(12)  msg(4) msg(15)       FLAG(?)
xxxxxxxxxxxx0000 | 00000000000000H? | ???????????????? | ???????????????? | ...
-------------------------------------------------------------------------------
CBLOCK1          | CBLOCK2          | CBLOCK3          | CBLOCK4          | ...
                                 ^
                                 Brute-Force byte-by-byte until we get same CBLOCK2.

그 다음은 한칸 씩 왼쪽으로 밀어서 똑같이 마지막 바이트에 Brute-Force 를 수행하고 일치하는 암호문을 찾았을 때의 평문을 저장하면 됩니다. 이런식으로 모든 FLAG 를 알아낼 수 있습니다.

 

>> sol.py

from pwn import *

ip = '188.166.172.138'
port = 30328

nc = remote(ip, port)
print('='*64)

FLAG = ''
new_msg = bytearray(4+16+16)

nth_block = 2

charset = '{}_0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'

for i in range(32):    
    msg = b'00'*(4+16+15-i)
    nc.sendafter(b'> ', msg)
    print("[+] msg sent: ", msg.decode())

    ct = nc.recvline().strip()
    print("[+] ct : ", ct)
    print("-"*64)
    print("[+] cb1 : ", ct[:32])
    print("[+] cb2 : ", ct[32:64])
    print("[+] cb3 : ", ct[64:96])
    print("[+] cb4 : ", ct[96:128])
    print("[+] cb5 : ", ct[128:])
    print('='*64)
    
    for j in charset:
        #new_msg = msg + format(j, '02x').encode()
        new_msg[-1] = ord(j)
        print('.', end=' ')
        nc.sendafter(b'> ', new_msg.hex().encode())
        new_ct = nc.recvline().strip()
        #print(i, new_ct)
        if(new_ct[nth_block*32:nth_block*32+32] == ct[nth_block*32:nth_block*32+32]):
            FLAG += j
            print("\n[+] Found char: ", j)
            print("[+] new_ct: ", new_ct[nth_block*32:nth_block*32+32])
            print("[+] new_msg: ", new_msg.hex().encode())
            for k in range(i, -1, -1):
                new_msg[-(k+2)] = new_msg[-(k+1)]
            print("[+] FLAG: ", FLAG)
            print('-'*64)

            if(j == '}'):
                exit(0)
            
            break
    else:
        print("[+] No char found, end loop")
        break


# FLAG : HTB{345y************v3ry}

🚩 FLAG : HTB{345y************v3ry}