zig-chess/tools/fen_to_board_state.py

379 lines
11 KiB
Python

#!/usr/bin/env python3
"""Reference FEN-to-[9]u32 board-state encoder for tests and design experiments.
The full board state is represented as nine 32-bit words:
state[0] = rank 1
state[1] = rank 2
...
state[7] = rank 8
state[8] = metadata
Each rank word stores eight 4-bit square values:
bits 0..3 = file a
bits 4..7 = file b
...
bits 28..31 = file h
Piece encoding uses bit 3 for color and bits 0..2 for piece type:
0 = empty
black pawn/knight/bishop/rook/queen/king = 1..6
white pawn/knight/bishop/rook/queen/king = 9..14
Metadata in state[8]:
bit 0 active color: 1 = white, 0 = black
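bits 1..4 castling rights (q = bit 1, k = bit 2, Q = bit 3, K = bit 4)
bits 5..11 en passant, a 7-bit field: field bit 6 (word bit 11) = valid flag,
field bits 0..5 (word bits 5..10) = square index a1=0..h8=63; a target is only
stored when a pawn of the side to move can capture en passant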
bits 12..18 halfmove clock
bits 19..26 fullmove counter
bits 27..31 reserved
"""
from dataclasses import dataclass
# Piece-type codes (bits 0..2 of a square nibble), keyed by lowercase FEN
# letter: pawn, knight, bishop, rook, queen, king -> 1..6. Colour is carried
# separately in bit 3.
PIECE_TYPE = dict(zip("pnbrqk", range(1, 7)))
# Flags inside the 4-bit castling-rights field, keyed by FEN castling letter.
# White king-side is the most significant flag, black queen-side the least.
CASTLING_BITS = {
    "K": 1 << 3,  # white king-side
    "Q": 1 << 2,  # white queen-side
    "k": 1 << 1,  # black king-side
    "q": 1 << 0,  # black queen-side
}
# Bit layout of the metadata word (state[8]). A field is read back as
# (word >> <FIELD>_SHIFT) & <FIELD>_MASK; each mask is (1 << width) - 1.
ACTIVE_COLOR_SHIFT = 0  # 1-bit field
CASTLING_RIGHTS_SHIFT = 1  # 4-bit field
EN_PASSANT_SHIFT = 5  # 7-bit field
HALFMOVE_CLOCK_SHIFT = 12  # 7-bit field
FULLMOVE_COUNTER_SHIFT = 19  # 8-bit field
ACTIVE_COLOR_MASK = (1 << 1) - 1
CASTLING_RIGHTS_MASK = (1 << 4) - 1
EN_PASSANT_MASK = (1 << 7) - 1
HALFMOVE_CLOCK_MASK = (1 << 7) - 1
FULLMOVE_COUNTER_MASK = (1 << 8) - 1
@dataclass(frozen=True)
class BoardState:
    """Immutable nine-word board state: eight rank words, then metadata.

    ``state[0]`` .. ``state[7]`` hold ranks 1..8 as eight 4-bit squares each
    (file a in the lowest nibble); ``state[8]`` holds the packed metadata whose
    fields are exposed by the read-only properties below.
    """

    state: tuple[int, ...]  # 9 u32 values: ranks 1..8, then metadata

    def __post_init__(self) -> None:
        """Normalise ``state`` to a tuple and validate its size and range.

        Raises:
            ValueError: if there are not exactly nine words, or a word lies
                outside the u32 range.
        """
        # frozen=True does not stop a caller from passing a list, which would
        # leave the "immutable" state mutable and make the instance unhashable.
        # Coerce once here; object.__setattr__ is the supported way to assign a
        # field inside __post_init__ of a frozen dataclass. tuple() returns an
        # existing tuple unchanged.
        words = tuple(self.state)
        object.__setattr__(self, "state", words)
        if len(words) != 9:
            raise ValueError("BoardState must contain exactly 9 words")
        for word in words:
            if not 0 <= word <= 0xFFFFFFFF:
                raise ValueError("BoardState words must fit in u32")

    @property
    def active_color(self) -> int:
        """Side to move: 1 = white, 0 = black."""
        return (self.state[8] >> ACTIVE_COLOR_SHIFT) & ACTIVE_COLOR_MASK

    @property
    def castling_rights(self) -> int:
        """4-bit castling-rights mask built from the CASTLING_BITS flags."""
        return (self.state[8] >> CASTLING_RIGHTS_SHIFT) & CASTLING_RIGHTS_MASK

    @property
    def en_passant(self) -> int:
        """Raw 7-bit en passant field: bit 6 = valid, bits 0..5 = square index."""
        return (self.state[8] >> EN_PASSANT_SHIFT) & EN_PASSANT_MASK

    @property
    def halfmove_clock(self) -> int:
        """Half-move clock (7-bit field)."""
        return (self.state[8] >> HALFMOVE_CLOCK_SHIFT) & HALFMOVE_CLOCK_MASK

    @property
    def fullmove_counter(self) -> int:
        """Full-move counter (8-bit field)."""
        return (self.state[8] >> FULLMOVE_COUNTER_SHIFT) & FULLMOVE_COUNTER_MASK
def square_index(file_index: int, rank_num: int) -> int:
    """Map a 0-based file and a 1-based rank to a 0..63 square index.

    Squares are numbered rank by rank from a1 = 0, so h1 = 7, a8 = 56, h8 = 63.

    Raises:
        ValueError: if ``file_index`` is outside 0-7 or ``rank_num`` outside 1-8.
    """
    file_ok = 0 <= file_index <= 7
    if not file_ok:
        raise ValueError("file_index must be 0-7")
    rank_ok = 1 <= rank_num <= 8
    if not rank_ok:
        raise ValueError("rank_num must be 1-8")
    rank_offset = 8 * (rank_num - 1)
    return rank_offset + file_index
def algebraic_to_square(square: str) -> int:
    """Convert a coordinate such as ``"e4"`` into its 0..63 square index.

    Only lowercase files ``a``-``h`` followed by ranks ``1``-``8`` are accepted.

    Raises:
        ValueError: if ``square`` is not exactly two characters or either
            coordinate is out of range.
    """
    if len(square) != 2:
        raise ValueError(f"Invalid square: {square}")
    file_ch, rank_ch = square[0], square[1]
    if not "a" <= file_ch <= "h":
        raise ValueError(f"Invalid file: {square}")
    if not "1" <= rank_ch <= "8":
        raise ValueError(f"Invalid rank: {square}")
    return square_index(ord(file_ch) - ord("a"), int(rank_ch))
def square_to_algebraic(square: int) -> str:
    """Convert a 0..63 square index back to algebraic notation (0 -> "a1").

    Raises:
        ValueError: if ``square`` is outside 0-63.
    """
    if not 0 <= square <= 63:
        raise ValueError("square must be 0-63")
    rank_offset, file_offset = divmod(square, 8)
    return "abcdefgh"[file_offset] + str(rank_offset + 1)
def encode_piece(ch: str) -> int:
    """Return the 4-bit square code for a single FEN piece letter.

    Bit 3 is the colour (set for an uppercase, i.e. white, letter) and bits
    0..2 are the piece type looked up in PIECE_TYPE.

    Raises:
        KeyError: if ``ch.lower()`` is not a key of PIECE_TYPE.
    """
    color_bit = 0b1000 if ch.isupper() else 0
    return color_bit | PIECE_TYPE[ch.lower()]
def get_square(state: tuple[int, ...], algebraic: str) -> int:
    """Return the 4-bit piece code on square ``algebraic`` (0 means empty).

    ``state`` is indexed by rank (index 0 = rank 1); only its first eight
    words are read.
    """
    rank_index, file_index = divmod(algebraic_to_square(algebraic), 8)
    nibble_shift = 4 * file_index
    return (state[rank_index] >> nibble_shift) & 0b1111
def has_piece(state: tuple[int, ...], algebraic: str, piece: str) -> bool:
    """Return True if ``algebraic`` holds exactly ``piece`` (type and colour)."""
    # The square is decoded before the piece letter, so an invalid square is
    # reported ahead of an invalid piece.
    occupant = get_square(state, algebraic)
    wanted = encode_piece(piece)
    return occupant == wanted
def parse_board_placement(placement: str) -> list[int]:
    """Encode the piece-placement field of a FEN string into eight rank words.

    FEN lists ranks from 8 down to 1, separated by ``/``. Within a rank, files
    run a..h; piece letters come from PIECE_TYPE (uppercase = white) and the
    ASCII digits 1-8 stand for that many empty squares.

    Returns:
        Eight u32 words, index 0 = rank 1, with file a in bits 0..3.

    Raises:
        ValueError: if there are not exactly eight ranks, a rank does not cover
            exactly eight squares, or a character is neither a valid piece
            letter nor an empty-square count of 1-8.
    """
    ranks = placement.split("/")
    if len(ranks) != 8:
        raise ValueError("FEN board placement must contain 8 ranks")
    ranks_u32 = [0] * 8
    for fen_rank_index, rank_text in enumerate(ranks):
        # FEN runs rank 8 -> rank 1, while ranks_u32 is indexed from rank 1.
        rank_index = 7 - fen_rank_index
        file_index = 0
        for ch in rank_text:
            if ch.isdigit():
                # Only ASCII 1-8 are valid run lengths: isdigit()/int() would
                # also accept "0" (a zero-length run) and non-ASCII digits.
                if not "1" <= ch <= "8":
                    raise ValueError(f"Invalid empty-square count: {ch}")
                file_index += int(ch)
                continue
            # isascii() rejects look-alikes such as the Kelvin sign, whose
            # lower() is the ASCII "k".
            if not ch.isascii() or ch.lower() not in PIECE_TYPE:
                raise ValueError(f"Invalid piece character: {ch}")
            if file_index >= 8:
                raise ValueError(f"Too many squares in rank: {rank_text}")
            ranks_u32[rank_index] |= encode_piece(ch) << (file_index * 4)
            file_index += 1
        if file_index != 8:
            raise ValueError(f"Rank does not contain exactly 8 squares: {rank_text}")
    return ranks_u32
def en_passant_is_capturable(state: tuple[int, ...], ep_square: int, active_color: int) -> bool:
    """Return True if the side to move has a pseudo-legal en passant capture.

    ``active_color`` is the side to move (1 = white, anything else = black) and
    ``ep_square`` is the FEN en passant target as a 0..63 index.

    If white is to move, black just advanced a pawn two squares (rank 7 to
    rank 5). The target must therefore be on rank 6 and that black pawn must
    stand on rank 5 of the same file. A white pawn must also be on an adjacent
    file on rank 5. If black is to move, the picture is mirrored: target on
    rank 3, white pawn on rank 4 of the same file, and a black pawn beside it
    on rank 4.

    The target square and the square the pawn started from must both be
    empty. Pins and checks are not considered, so the result is pseudo-legal
    only.
    """
    file_index = ep_square % 8
    rank_num = (ep_square // 8) + 1
    if active_color == 1:
        target_rank, capturer, victim, toward_victim = 6, "P", "p", -8
    else:
        target_rank, capturer, victim, toward_victim = 3, "p", "P", 8
    if rank_num != target_rank:
        return False

    # The double-pushed pawn now stands next to the target on the capturer's
    # side (rank 5 / rank 4) and started on the far side (rank 7 / rank 2).
    # Without these checks, any rank-3/6 target beside a friendly pawn was
    # reported capturable even when there was nothing to capture.
    victim_square = ep_square + toward_victim
    origin_square = ep_square - toward_victim
    if not has_piece(state, square_to_algebraic(victim_square), victim):
        return False
    if get_square(state, square_to_algebraic(ep_square)) != 0:
        return False
    if get_square(state, square_to_algebraic(origin_square)) != 0:
        return False

    # A capturing pawn must stand beside the victim on an adjacent file.
    return any(
        has_piece(state, square_to_algebraic(victim_square + step), capturer)
        for step in (-1, 1)
        if 0 <= file_index + step <= 7
    )
def encode_en_passant(ep: str, state: tuple[int, ...], active_color: int) -> int:
    """Pack a FEN en passant field into the 7-bit metadata value.

    Layout: bit 6 is the valid flag and bits 0-5 hold the square index
    (a1 = 0 through h8 = 63). Both ``"-"`` and any target that
    en_passant_is_capturable rejects encode as 0, so only targets a pawn of
    the side to move can actually capture are stored.

    Raises:
        ValueError: from algebraic_to_square when ``ep`` is a malformed square.
    """
    if ep != "-":
        target = algebraic_to_square(ep)
        if en_passant_is_capturable(state, target, active_color):
            return target | 0b100_0000
    return 0
def decode_en_passant(ep_value: int) -> str | None:
if ((ep_value >> 6) & 1) == 0:
return None
square = ep_value & 0x3F
return square_to_algebraic(square)
def encode_metadata(
    active_color: int,
    castling_rights: int,
    en_passant: int,
    halfmove_clock: int,
    fullmove_counter: int,
) -> int:
    """Pack the five metadata fields into the ``state[8]`` word.

    Every field is range-checked against its mask, in parameter order, before
    anything is packed.

    Raises:
        ValueError: naming the first field that does not fit its bit width.
    """
    fields = (
        (active_color, ACTIVE_COLOR_MASK, ACTIVE_COLOR_SHIFT,
         "Active color must fit in 1 bit"),
        (castling_rights, CASTLING_RIGHTS_MASK, CASTLING_RIGHTS_SHIFT,
         "Castling rights must fit in 4 bits"),
        (en_passant, EN_PASSANT_MASK, EN_PASSANT_SHIFT,
         "En passant must fit in 7 bits"),
        (halfmove_clock, HALFMOVE_CLOCK_MASK, HALFMOVE_CLOCK_SHIFT,
         "Half-move clock must fit in 7 bits"),
        (fullmove_counter, FULLMOVE_COUNTER_MASK, FULLMOVE_COUNTER_SHIFT,
         "Full-move counter must fit in 8 bits"),
    )
    for value, limit, _, message in fields:
        if not 0 <= value <= limit:
            raise ValueError(message)
    word = 0
    for value, _, shift, _ in fields:
        word |= value << shift
    return word
def parse_fen(fen: str) -> BoardState:
parts = fen.strip().split()
if len(parts) != 6:
raise ValueError("FEN must contain exactly 6 fields")
placement, active, castling, ep, halfmove, fullmove = parts
state = parse_board_placement(placement)
if active == "w":
active_color = 1
elif active == "b":
active_color = 0
else:
raise ValueError(f"Invalid active color: {active}")
castling_rights = 0
if castling != "-":
for ch in castling:
if ch not in CASTLING_BITS:
raise ValueError(f"Invalid castling right: {ch}")
castling_rights |= CASTLING_BITS[ch]
halfmove_clock = int(halfmove)
fullmove_counter = int(fullmove)
en_passant = encode_en_passant(ep, tuple(state), active_color)
metadata = encode_metadata(
active_color=active_color,
castling_rights=castling_rights,
en_passant=en_passant,
halfmove_clock=halfmove_clock,
fullmove_counter=fullmove_counter,
)
state.append(metadata)
return BoardState(tuple(state))
def run_tests() -> None:
    """Self-check the encoder against hand-verified positions and error cases.

    Prints ``All tests passed.`` on success; any failure raises AssertionError.
    """

    def assert_rejected(fen: str) -> None:
        # parse_fen must refuse malformed input with ValueError.
        try:
            parse_fen(fen)
        except ValueError:
            return
        raise AssertionError(f"FEN should have been rejected: {fen}")

    start_fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
    start = parse_fen(start_fen)
    assert start.active_color == 1
    assert start.castling_rights == 0b1111
    assert start.en_passant == 0
    assert start.halfmove_clock == 0
    assert start.fullmove_counter == 1
    assert get_square(start.state, "a1") == encode_piece("R")
    assert get_square(start.state, "e1") == encode_piece("K")
    assert get_square(start.state, "a8") == encode_piece("r")
    assert get_square(start.state, "e8") == encode_piece("k")
    assert get_square(start.state, "e4") == 0
    assert start.state[0] == 0xCABEDBAC  # rank 1
    assert start.state[1] == 0x99999999  # rank 2
    assert start.state[6] == 0x11111111  # rank 7
    assert start.state[7] == 0x42365324  # rank 8
    assert start.state[8] == 0x0008001F

    mid_fen = "r1bqkbnr/pppp1ppp/2n5/4p3/3P4/5N2/PPP2PPP/RNBQKB1R w KQkq e3 4 5"
    mid = parse_fen(mid_fen)
    assert mid.active_color == 1
    assert mid.castling_rights == 0b1111
    assert mid.en_passant == 0  # e3 exists in FEN, but no white pawn can capture there
    assert mid.halfmove_clock == 4
    assert mid.fullmove_counter == 5
    assert get_square(mid.state, "a8") == encode_piece("r")
    assert get_square(mid.state, "c8") == encode_piece("b")
    assert get_square(mid.state, "c6") == encode_piece("n")
    assert get_square(mid.state, "e5") == encode_piece("p")
    assert get_square(mid.state, "d4") == encode_piece("P")
    assert get_square(mid.state, "f3") == encode_piece("N")
    assert get_square(mid.state, "g1") == 0

    capturable_ep_fen = "8/8/8/3Pp3/8/8/8/8 w - e6 0 1"
    capturable = parse_fen(capturable_ep_fen)
    assert decode_en_passant(capturable.en_passant) == "e6"

    # Mirror image: white just played e2-e4 next to a black pawn on d4.
    black_ep = parse_fen("8/8/8/8/3pP3/8/8/8 b - e3 0 1")
    assert black_ep.active_color == 0
    assert decode_en_passant(black_ep.en_passant) == "e3"

    # Every square index survives a round trip through algebraic notation.
    for square in range(64):
        assert algebraic_to_square(square_to_algebraic(square)) == square

    assert_rejected("8/8/8/8/8/8/8/8 w - - 0")  # too few fields
    assert_rejected("8/8/8/8/8/8/8/8 x - - 0 1")  # bad active colour
    assert_rejected("8/8/8/8/8/8/8/8 w X - 0 1")  # bad castling flag
    assert_rejected("8/8/8/8/8/8/8/9 w - - 0 1")  # rank with nine squares
    assert_rejected("8/8/8/8/8/8/8/8 w - - 128 1")  # half-move clock > 7 bits
    assert_rejected("8/8/8/8/8/8/8/8 w - - 0 256")  # full-move counter > 8 bits

    print("All tests passed.")
def format_state_hex(state: tuple[int, ...]) -> str:
    """Render board-state words as a bracketed list of 8-digit lowercase hex."""
    words = ", ".join(map("0x{:08x}".format, state))
    return f"[{words}]"
if __name__ == "__main__":
    import sys

    run_tests()
    # Encode the FEN given on the command line (either quoted as one argument
    # or as six separate fields); with no arguments, use the start position.
    default_fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
    fen = " ".join(sys.argv[1:]) or default_fen
    board_state = parse_fen(fen)
    print(board_state)
    print(f"state hex: {format_state_hex(board_state.state)}")