379 lines
11 KiB
Python
379 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
"""Reference FEN-to-[9]u32 board-state encoder for tests and design experiments.
|
|
|
|
The full board state is represented as nine 32-bit words:
|
|
|
|
state[0] = rank 1
|
|
state[1] = rank 2
|
|
...
|
|
state[7] = rank 8
|
|
state[8] = metadata
|
|
|
|
Each rank word stores eight 4-bit square values:
|
|
|
|
bits 0..3 = file a
|
|
bits 4..7 = file b
|
|
...
|
|
bits 28..31 = file h
|
|
|
|
Piece encoding uses bit 3 for color and bits 0..2 for piece type:
|
|
|
|
0 = empty
|
|
black pawn/knight/bishop/rook/queen/king = 1..6
|
|
white pawn/knight/bishop/rook/queen/king = 9..14
|
|
|
|
Metadata in state[8]:
|
|
|
|
bit 0 active color: 1 = white, 0 = black
|
|
bits 1..4 castling rights (within the field: bit 3 = K, bit 2 = Q, bit 1 = k, bit 0 = q)
|
|
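bits 5..11 en passant (7-bit field): field bit 6 = valid flag, field bits 0..5 = square index (a1 = 0 ... h8 = 63); stored only when a pawn of the side to move can capture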
|
|
bits 12..18 halfmove clock
|
|
bits 19..26 fullmove counter
|
|
bits 27..31 reserved
|
|
"""
|
|
from dataclasses import dataclass
|
|
|
|
# Piece-type codes stored in bits 0..2 of a square nibble. The letter order
# below fixes the numbering: pawn = 1, knight = 2, ..., king = 6.
PIECE_TYPE = {letter: code for code, letter in enumerate("pnbrqk", start=1)}
|
|
|
|
# Flags inside the 4-bit castling field: white kingside is the most
# significant bit, black queenside the least significant.
CASTLING_BITS = {
    right: 1 << bit
    for right, bit in zip("KQkq", (3, 2, 1, 0))
}
|
|
|
|
# Right-aligned masks for each metadata field, written as (1 << width) - 1.
ACTIVE_COLOR_MASK = (1 << 1) - 1
CASTLING_RIGHTS_MASK = (1 << 4) - 1
EN_PASSANT_MASK = (1 << 7) - 1
HALFMOVE_CLOCK_MASK = (1 << 7) - 1
FULLMOVE_COUNTER_MASK = (1 << 8) - 1

# Fields are packed contiguously upward from bit 0 of state[8]. Each shift is
# the previous field's shift plus that field's width.
ACTIVE_COLOR_SHIFT = 0
CASTLING_RIGHTS_SHIFT = ACTIVE_COLOR_SHIFT + ACTIVE_COLOR_MASK.bit_length()
EN_PASSANT_SHIFT = CASTLING_RIGHTS_SHIFT + CASTLING_RIGHTS_MASK.bit_length()
HALFMOVE_CLOCK_SHIFT = EN_PASSANT_SHIFT + EN_PASSANT_MASK.bit_length()
FULLMOVE_COUNTER_SHIFT = HALFMOVE_CLOCK_SHIFT + HALFMOVE_CLOCK_MASK.bit_length()
|
|
|
|
|
|
@dataclass(frozen=True)
class BoardState:
    """Immutable nine-word packed board state.

    ``state[0]``..``state[7]`` hold ranks 1..8, each as eight 4-bit squares
    with file a in the lowest nibble. ``state[8]`` is the metadata word, whose
    fields are exposed by the read-only properties below.
    """

    state: tuple[int, ...]  # 9 u32 values: ranks 1..8, then metadata

    def __post_init__(self) -> None:
        """Normalize ``state`` to a tuple and validate every word.

        Raises:
            ValueError: if there are not exactly 9 words, or a word is outside
                the u32 range.
            TypeError: if a word is not an ``int``.
        """
        # Accept any sequence (e.g. a list of words) but always store a tuple.
        # Otherwise the frozen dataclass would hold a mutable, unhashable list.
        words = tuple(self.state)
        object.__setattr__(self, "state", words)

        if len(words) != 9:
            raise ValueError("BoardState must contain exactly 9 words")
        for word in words:
            # Without this check a float such as 1.5 passes the range test and
            # only fails later inside the bit-shifting properties.
            if not isinstance(word, int):
                raise TypeError("BoardState words must be integers")
            if not 0 <= word <= 0xFFFFFFFF:
                raise ValueError("BoardState words must fit in u32")

    def _metadata_field(self, shift: int, mask: int) -> int:
        """Return the right-aligned bit field of ``state[8]`` starting at ``shift``."""
        return (self.state[8] >> shift) & mask

    @property
    def active_color(self) -> int:
        """Side to move: 1 = white, 0 = black."""
        return self._metadata_field(ACTIVE_COLOR_SHIFT, ACTIVE_COLOR_MASK)

    @property
    def castling_rights(self) -> int:
        """4-bit castling mask (K = 0b1000, Q = 0b0100, k = 0b0010, q = 0b0001)."""
        return self._metadata_field(CASTLING_RIGHTS_SHIFT, CASTLING_RIGHTS_MASK)

    @property
    def en_passant(self) -> int:
        """Raw 7-bit en passant field: bit 6 = valid flag, bits 0..5 = square index."""
        return self._metadata_field(EN_PASSANT_SHIFT, EN_PASSANT_MASK)

    @property
    def halfmove_clock(self) -> int:
        """Half-move clock (7-bit field)."""
        return self._metadata_field(HALFMOVE_CLOCK_SHIFT, HALFMOVE_CLOCK_MASK)

    @property
    def fullmove_counter(self) -> int:
        """Full-move counter (8-bit field)."""
        return self._metadata_field(FULLMOVE_COUNTER_SHIFT, FULLMOVE_COUNTER_MASK)
|
|
|
|
|
|
def square_index(file_index: int, rank_num: int) -> int:
    """Return the 0..63 index of the square at ``file_index`` (0-7), ``rank_num`` (1-8).

    Squares are numbered rank by rank starting from a1:
    a1 = 0, h1 = 7, a8 = 56, h8 = 63.

    Raises:
        ValueError: if either coordinate is out of range.
    """
    file_in_range = 0 <= file_index <= 7
    if not file_in_range:
        raise ValueError("file_index must be 0-7")
    rank_in_range = 1 <= rank_num <= 8
    if not rank_in_range:
        raise ValueError("rank_num must be 1-8")

    rank_offset = 8 * (rank_num - 1)
    return rank_offset + file_index
|
|
|
|
|
|
def algebraic_to_square(square: str) -> int:
    """Convert an algebraic square name such as ``"e4"`` to its 0..63 index.

    Only lowercase files ``a``-``h`` followed by a rank ``1``-``8`` are accepted.

    Raises:
        ValueError: if ``square`` is not a valid two-character square name.
    """
    if len(square) != 2:
        raise ValueError(f"Invalid square: {square}")

    file_ch, rank_ch = square
    if not "a" <= file_ch <= "h":
        raise ValueError(f"Invalid file: {square}")
    if not "1" <= rank_ch <= "8":
        raise ValueError(f"Invalid rank: {square}")

    return square_index(ord(file_ch) - ord("a"), int(rank_ch))
|
|
|
|
|
|
def square_to_algebraic(square: int) -> str:
    """Convert a 0..63 square index (a1 = 0, h8 = 63) to a name like ``"e4"``.

    Raises:
        ValueError: if ``square`` is outside 0..63.
    """
    if not 0 <= square <= 63:
        raise ValueError("square must be 0-63")

    rank_index, file_index = divmod(square, 8)
    return "abcdefgh"[file_index] + str(rank_index + 1)
|
|
|
|
|
|
def encode_piece(ch: str) -> int:
    """Return the 4-bit square code for a FEN piece letter.

    Bits 0..2 hold the piece type from PIECE_TYPE. Bit 3 is set for uppercase
    (white) letters. An unknown letter raises KeyError from the table lookup.
    """
    white_flag = 0b1000 if ch.isupper() else 0
    return white_flag | PIECE_TYPE[ch.lower()]
|
|
|
|
|
|
def get_square(state: tuple[int, ...], algebraic: str) -> int:
    """Return the 4-bit piece code stored on square ``algebraic`` (0 = empty).

    ``state`` is indexed by rank (index 0 = rank 1). Within a rank word, file a
    occupies the lowest nibble.
    """
    rank_index, file_index = divmod(algebraic_to_square(algebraic), 8)
    nibble_shift = 4 * file_index
    return (state[rank_index] >> nibble_shift) & 0b1111
|
|
|
|
|
|
def has_piece(state: tuple[int, ...], algebraic: str, piece: str) -> bool:
    """Report whether square ``algebraic`` holds exactly the FEN piece ``piece``."""
    occupant = get_square(state, algebraic)
    wanted = encode_piece(piece)
    return occupant == wanted
|
|
|
|
|
|
def _parse_rank(rank_text: str) -> int:
    """Pack one FEN rank string (file a first) into a u32 of 4-bit squares.

    Raises:
        ValueError: on an invalid character, adjacent empty-square digits, or
            a rank that does not describe exactly 8 squares.
    """
    word = 0
    file_index = 0
    previous_was_digit = False

    for ch in rank_text:
        # Only ASCII 1-8 are FEN run lengths. str.isdigit() would also admit
        # "0", "9" and non-ASCII digits such as Arabic-Indic numerals.
        if ch in "12345678":
            # Well-formed FEN merges empty runs into one digit ("8", not "44").
            if previous_was_digit:
                raise ValueError(f"Adjacent empty-square digits in rank: {rank_text}")
            file_index += int(ch)
            previous_was_digit = True
            continue
        previous_was_digit = False

        # The ASCII check rejects look-alikes such as the Kelvin sign (U+212A),
        # whose lower() is "k" and would otherwise be encoded as a king.
        if not ch.isascii() or ch.lower() not in PIECE_TYPE:
            raise ValueError(f"Invalid piece character: {ch}")

        if file_index >= 8:
            raise ValueError(f"Too many squares in rank: {rank_text}")

        word |= encode_piece(ch) << (file_index * 4)
        file_index += 1

    if file_index != 8:
        raise ValueError(f"Rank does not contain exactly 8 squares: {rank_text}")
    return word


def parse_board_placement(placement: str) -> list[int]:
    """Parse the FEN piece-placement field into eight packed rank words.

    FEN lists rank 8 first. The returned list is indexed rank 1 first
    (index 0 = rank 1).

    Raises:
        ValueError: if there are not exactly 8 ranks or any rank is malformed.
    """
    fen_ranks = placement.split("/")
    if len(fen_ranks) != 8:
        raise ValueError("FEN board placement must contain 8 ranks")

    # Parse in FEN order (rank 8 first) so the first malformed rank in the
    # text is the one reported. Then flip to rank-1-first storage.
    ranks_u32 = [_parse_rank(rank_text) for rank_text in fen_ranks]
    ranks_u32.reverse()
    return ranks_u32
|
|
|
|
|
|
def en_passant_is_capturable(state: tuple[int, ...], ep_square: int, active_color: int) -> bool:
    """Return whether the side to move can capture en passant on ``ep_square``.

    ``active_color`` is the side to move (1 = white, 0 = black).

    If white is to move, black just advanced a pawn two squares. The target
    must then be on rank 6 and empty, the black pawn must stand on rank 5 of
    the same file, and a white pawn must be on an adjacent file of rank 5.

    If black is to move, the mirror image applies. The target is on rank 3 and
    empty, a white pawn is on rank 4 of the same file, and a black pawn is on
    an adjacent file of rank 4.

    Only piece placement is checked. Pins or checks that would make the
    capture illegal are not considered.
    """
    file_index = ep_square % 8
    rank_num = (ep_square // 8) + 1

    if active_color == 1:
        if rank_num != 6:
            return False

        pawn_rank = 5
        capturer, victim = "P", "p"
    else:
        if rank_num != 3:
            return False

        pawn_rank = 4
        capturer, victim = "p", "P"

    # Nothing can be captured unless the square the pawn skipped over is empty
    # and the double-pushed enemy pawn is actually standing beside the capturer.
    if get_square(state, square_to_algebraic(ep_square)) != 0:
        return False
    pushed_pawn_square = square_to_algebraic(square_index(file_index, pawn_rank))
    if not has_piece(state, pushed_pawn_square, victim):
        return False

    for adjacent_file in (file_index - 1, file_index + 1):
        if 0 <= adjacent_file <= 7:
            adjacent_square = square_to_algebraic(
                square_index(adjacent_file, pawn_rank)
            )
            if has_piece(state, adjacent_square, capturer):
                return True

    return False
|
|
|
|
|
|
def encode_en_passant(ep: str, state: tuple[int, ...], active_color: int) -> int:
    """Encode the FEN en passant field as a 7-bit value.

    Layout: bit 6 is the valid flag; bits 0-5 hold the target square index
    (a1 = 0 through h8 = 63).

    Returns 0 for ``"-"``, and also whenever en_passant_is_capturable reports
    that no pawn of the side to move can capture on the target.
    """
    valid_flag = 0b100_0000
    if ep != "-":
        target = algebraic_to_square(ep)
        if en_passant_is_capturable(state, target, active_color):
            return valid_flag | target
    return 0
|
|
|
|
|
|
def decode_en_passant(ep_value: int) -> str | None:
|
|
if ((ep_value >> 6) & 1) == 0:
|
|
return None
|
|
|
|
square = ep_value & 0x3F
|
|
return square_to_algebraic(square)
|
|
|
|
|
|
def encode_metadata(
    active_color: int,
    castling_rights: int,
    en_passant: int,
    halfmove_clock: int,
    fullmove_counter: int,
) -> int:
    """Pack the five metadata fields into the ``state[8]`` word.

    Every field is range-checked, in the order listed, before any of them is
    packed.

    Raises:
        ValueError: if a field does not fit its bit width.
    """
    layout = (
        (active_color, ACTIVE_COLOR_MASK, ACTIVE_COLOR_SHIFT, "Active color must fit in 1 bit"),
        (castling_rights, CASTLING_RIGHTS_MASK, CASTLING_RIGHTS_SHIFT, "Castling rights must fit in 4 bits"),
        (en_passant, EN_PASSANT_MASK, EN_PASSANT_SHIFT, "En passant must fit in 7 bits"),
        (halfmove_clock, HALFMOVE_CLOCK_MASK, HALFMOVE_CLOCK_SHIFT, "Half-move clock must fit in 7 bits"),
        (fullmove_counter, FULLMOVE_COUNTER_MASK, FULLMOVE_COUNTER_SHIFT, "Full-move counter must fit in 8 bits"),
    )

    for value, mask, _shift, message in layout:
        if not 0 <= value <= mask:
            raise ValueError(message)

    packed = 0
    for value, _mask, shift, _message in layout:
        packed |= value << shift
    return packed
|
|
|
|
|
|
def parse_fen(fen: str) -> BoardState:
|
|
parts = fen.strip().split()
|
|
|
|
if len(parts) != 6:
|
|
raise ValueError("FEN must contain exactly 6 fields")
|
|
|
|
placement, active, castling, ep, halfmove, fullmove = parts
|
|
|
|
state = parse_board_placement(placement)
|
|
|
|
if active == "w":
|
|
active_color = 1
|
|
elif active == "b":
|
|
active_color = 0
|
|
else:
|
|
raise ValueError(f"Invalid active color: {active}")
|
|
|
|
castling_rights = 0
|
|
if castling != "-":
|
|
for ch in castling:
|
|
if ch not in CASTLING_BITS:
|
|
raise ValueError(f"Invalid castling right: {ch}")
|
|
castling_rights |= CASTLING_BITS[ch]
|
|
|
|
halfmove_clock = int(halfmove)
|
|
fullmove_counter = int(fullmove)
|
|
en_passant = encode_en_passant(ep, tuple(state), active_color)
|
|
|
|
metadata = encode_metadata(
|
|
active_color=active_color,
|
|
castling_rights=castling_rights,
|
|
en_passant=en_passant,
|
|
halfmove_clock=halfmove_clock,
|
|
fullmove_counter=fullmove_counter,
|
|
)
|
|
state.append(metadata)
|
|
|
|
return BoardState(tuple(state))
|
|
|
|
|
|
def run_tests() -> None:
    """Self-check the encoder on a few known positions; print on success.

    Uses plain ``assert`` statements, so the checks are skipped under
    ``python -O``.
    """
    start = parse_fen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1")
    assert (
        start.active_color,
        start.castling_rights,
        start.en_passant,
        start.halfmove_clock,
        start.fullmove_counter,
    ) == (1, 0b1111, 0, 0, 1)
    for square, piece in (("a1", "R"), ("e1", "K"), ("a8", "r"), ("e8", "k")):
        assert get_square(start.state, square) == encode_piece(piece)
    assert get_square(start.state, "e4") == 0

    expected_words = {
        0: 0xCABEDBAC,  # rank 1
        1: 0x99999999,  # rank 2
        6: 0x11111111,  # rank 7
        7: 0x42365324,  # rank 8
        8: 0x0008001F,  # metadata
    }
    for index, word in expected_words.items():
        assert start.state[index] == word

    mid = parse_fen("r1bqkbnr/pppp1ppp/2n5/4p3/3P4/5N2/PPP2PPP/RNBQKB1R w KQkq e3 4 5")
    # e3 appears in the FEN, but no white pawn can capture there.
    assert (
        mid.active_color,
        mid.castling_rights,
        mid.en_passant,
        mid.halfmove_clock,
        mid.fullmove_counter,
    ) == (1, 0b1111, 0, 4, 5)
    mid_pieces = (
        ("a8", "r"),
        ("c8", "b"),
        ("c6", "n"),
        ("e5", "p"),
        ("d4", "P"),
        ("f3", "N"),
    )
    for square, piece in mid_pieces:
        assert get_square(mid.state, square) == encode_piece(piece)
    assert get_square(mid.state, "g1") == 0

    capturable = parse_fen("8/8/8/3Pp3/8/8/8/8 w - e6 0 1")
    assert decode_en_passant(capturable.en_passant) == "e6"

    print("All tests passed.")
|
|
|
|
|
|
def format_state_hex(state: tuple[int, ...]) -> str:
    """Render state words as a bracketed, comma-separated list of 8-digit hex."""
    hex_words = map("0x{:08x}".format, state)
    return "[{}]".format(", ".join(hex_words))
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the self-checks, then show the encoding of the starting position.
    run_tests()
    demo_state = parse_fen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1")
    print(demo_state)
    print(f"state hex: {format_state_hex(demo_state.state)}")
|