#!/usr/bin/env python3
"""Reference FEN-to-[9]u32 board-state encoder for tests and design experiments.

The full board state is represented as nine 32-bit words:

    state[0] = rank 1
    state[1] = rank 2
    ...
    state[7] = rank 8
    state[8] = metadata

Each rank word stores eight 4-bit square values:

    bits 0..3   = file a
    bits 4..7   = file b
    ...
    bits 28..31 = file h

Piece encoding uses bit 3 for color (set = white) and bits 0..2 for piece type:

    0 = empty
    black pawn/knight/bishop/rook/queen/king = 1..6
    white pawn/knight/bishop/rook/queen/king = 9..14

Metadata in state[8] (absolute bit positions):

    bit  0       active color: 1 = white, 0 = black
    bits 1..4    castling rights: bit 4 = K, bit 3 = Q, bit 2 = k, bit 1 = q
    bits 5..11   en passant, a 7-bit field: field bit 6 (absolute bit 11) is
                 the valid flag, field bits 0..5 (absolute bits 5..10) hold the
                 target square index (a1 = 0 ... h8 = 63)
    bits 12..18  halfmove clock (7 bits, 0..127)
    bits 19..26  fullmove counter (8 bits, 0..255)
    bits 27..31  reserved
"""

import operator
from dataclasses import dataclass

PIECE_TYPE = {
    "p": 1,
    "n": 2,
    "b": 3,
    "r": 4,
    "q": 5,
    "k": 6,
}

CASTLING_BITS = {
    "K": 0b1000,
    "Q": 0b0100,
    "k": 0b0010,
    "q": 0b0001,
}

ACTIVE_COLOR_SHIFT = 0
CASTLING_RIGHTS_SHIFT = 1
EN_PASSANT_SHIFT = 5
HALFMOVE_CLOCK_SHIFT = 12
FULLMOVE_COUNTER_SHIFT = 19

ACTIVE_COLOR_MASK = 0b1
CASTLING_RIGHTS_MASK = 0b1111
EN_PASSANT_MASK = 0b1111111
HALFMOVE_CLOCK_MASK = 0b1111111
FULLMOVE_COUNTER_MASK = 0b11111111


@dataclass(frozen=True)
class BoardState:
    """Immutable nine-word board state: rank words 1..8, then the metadata word.

    ``state`` may be passed as any sequence of int-like values; it is stored
    as a tuple of plain ``int`` so the frozen instance is genuinely immutable
    and hashable.

    Raises:
        ValueError: if there are not exactly 9 words or a word is outside u32.
        TypeError: if a word is not an integer (e.g. a float or a string).
    """

    state: tuple[int, ...]  # 9 u32 values: ranks 1..8, then metadata

    def __post_init__(self) -> None:
        raw = tuple(self.state)
        if len(raw) != 9:
            raise ValueError("BoardState must contain exactly 9 words")
        # operator.index accepts true integers (including numpy integer
        # scalars) and raises TypeError for floats/strings; int() guarantees
        # a plain int is stored even for int subclasses such as bool.
        words = tuple(int(operator.index(word)) for word in raw)
        for word in words:
            if not 0 <= word <= 0xFFFFFFFF:
                raise ValueError("BoardState words must fit in u32")
        # Frozen dataclasses block normal attribute assignment;
        # object.__setattr__ bypasses that to store the normalised tuple.
        object.__setattr__(self, "state", words)

    @property
    def active_color(self) -> int:
        """Side to move: 1 = white, 0 = black."""
        return (self.state[8] >> ACTIVE_COLOR_SHIFT) & ACTIVE_COLOR_MASK

    @property
    def castling_rights(self) -> int:
        """4-bit castling mask using the ``CASTLING_BITS`` values (K=8, Q=4, k=2, q=1)."""
        return (self.state[8] >> CASTLING_RIGHTS_SHIFT) & CASTLING_RIGHTS_MASK

    @property
    def en_passant(self) -> int:
        """Raw 7-bit en passant field: bit 6 = valid, bits 0..5 = square index."""
        return (self.state[8] >> EN_PASSANT_SHIFT) & EN_PASSANT_MASK

    @property
    def halfmove_clock(self) -> int:
        """7-bit half-move clock."""
        return (self.state[8] >> HALFMOVE_CLOCK_SHIFT) & HALFMOVE_CLOCK_MASK

    @property
    def fullmove_counter(self) -> int:
        """8-bit full-move counter."""
        return (self.state[8] >> FULLMOVE_COUNTER_SHIFT) & FULLMOVE_COUNTER_MASK


def square_index(file_index: int, rank_num: int) -> int:
    """Return the 0..63 index of a square (a1 = 0, h1 = 7, a8 = 56, h8 = 63).

    ``file_index`` is 0..7 for files a..h; ``rank_num`` is the 1-based rank.

    Raises:
        ValueError: if either coordinate is out of range.
    """
    if not 0 <= file_index <= 7:
        raise ValueError("file_index must be 0-7")
    if not 1 <= rank_num <= 8:
        raise ValueError("rank_num must be 1-8")
    return (rank_num - 1) * 8 + file_index


def algebraic_to_square(square: str) -> int:
    """Convert a lowercase algebraic square such as ``"e4"`` to its 0..63 index.

    Raises:
        ValueError: if ``square`` is not exactly a file a-h followed by a rank 1-8.
    """
    if len(square) != 2:
        raise ValueError(f"Invalid square: {square}")

    file_ch = square[0]
    rank_ch = square[1]

    if file_ch < "a" or file_ch > "h":
        raise ValueError(f"Invalid file: {square}")
    if rank_ch < "1" or rank_ch > "8":
        raise ValueError(f"Invalid rank: {square}")

    return square_index(ord(file_ch) - ord("a"), int(rank_ch))


def square_to_algebraic(square: int) -> str:
    """Convert a 0..63 square index back to algebraic notation (e.g. 28 -> "e4").

    Raises:
        ValueError: if ``square`` is outside 0..63.
    """
    if not 0 <= square <= 63:
        raise ValueError("square must be 0-63")

    file_index = square % 8
    rank_num = (square // 8) + 1
    return f"{chr(ord('a') + file_index)}{rank_num}"


def encode_piece(ch: str) -> int:
    """Return the 4-bit code of a FEN piece letter (uppercase = white, bit 3 set).

    Raises:
        KeyError: if ``ch.lower()`` is not one of ``pnbrqk``.
    """
    color = 1 if ch.isupper() else 0
    piece = PIECE_TYPE[ch.lower()]
    return (color << 3) | piece


def get_square(state: tuple[int, ...], algebraic: str) -> int:
    """Return the 4-bit piece code stored on ``algebraic`` (0 means empty).

    ``state`` is indexed like ``BoardState.state``: word ``r`` holds rank ``r + 1``.
    """
    idx = algebraic_to_square(algebraic)
    rank_index = idx // 8
    file_index = idx % 8
    return (state[rank_index] >> (file_index * 4)) & 0xF
def has_piece(state: tuple[int, ...], algebraic: str, piece: str) -> bool:
    """Return True if ``algebraic`` holds exactly the FEN piece letter ``piece``."""
    return get_square(state, algebraic) == encode_piece(piece)


def parse_board_placement(placement: str) -> list[int]:
    """Parse the FEN piece-placement field into eight rank words.

    Returns a list indexed by rank (index 0 = rank 1), each word packing files
    a..h into 4-bit nibbles starting from the low end.

    Raises:
        ValueError: if there are not exactly 8 ranks, a rank does not cover
            exactly 8 squares, or a character is neither an ASCII piece letter
            nor an empty-run digit 1-8.
    """
    ranks = placement.split("/")
    if len(ranks) != 8:
        raise ValueError("FEN board placement must contain 8 ranks")

    ranks_u32 = [0] * 8

    # FEN lists rank 8 first, down to rank 1.
    for fen_rank_index, rank_text in enumerate(ranks):
        rank_index = 7 - fen_rank_index
        file_index = 0

        for ch in rank_text:
            # Only ASCII 1-8 are valid empty-square runs; str.isdigit() would
            # also accept "0", "9" and non-ASCII digits.
            if ch in "12345678":
                file_index += int(ch)
                continue

            # isascii() rejects characters such as the Kelvin sign (U+212A)
            # whose lower() is an ASCII piece letter.
            if not ch.isascii() or ch.lower() not in PIECE_TYPE:
                raise ValueError(f"Invalid piece character: {ch}")

            if file_index >= 8:
                raise ValueError(f"Too many squares in rank: {rank_text}")

            ranks_u32[rank_index] |= encode_piece(ch) << (file_index * 4)
            file_index += 1

        if file_index != 8:
            raise ValueError(f"Rank does not contain exactly 8 squares: {rank_text}")

    return ranks_u32


def en_passant_is_capturable(state: tuple[int, ...], ep_square: int, active_color: int) -> bool:
    """Return True if the side to move has a pawn that can capture en passant.

    ``active_color`` is the side to move (1 = white, anything else is treated
    as black) and ``ep_square`` is the target index (a1 = 0 ... h8 = 63).

    If white is to move, black just advanced a pawn from rank 7 to rank 5, so
    the target must be on rank 6. If black is to move, white just advanced a
    pawn from rank 2 to rank 4, so the target must be on rank 3.

    In addition, all of the following must hold:
      * the target square and the pushed pawn's origin square are empty,
      * the opponent's pawn that just moved stands on the target's file,
      * a pawn of the side to move stands on an adjacent file beside it.

    Pins and other move-legality concerns are not checked.
    """
    file_index = ep_square % 8
    rank_num = (ep_square // 8) + 1

    if active_color == 1:
        target_rank, pawn_rank, origin_rank = 6, 5, 7
        own_pawn, enemy_pawn = "P", "p"
    else:
        target_rank, pawn_rank, origin_rank = 3, 4, 2
        own_pawn, enemy_pawn = "p", "P"

    if rank_num != target_rank:
        return False

    def square_name(file: int, rank: int) -> str:
        return square_to_algebraic(square_index(file, rank))

    # The double push passed over the target and vacated the origin square.
    if get_square(state, square_name(file_index, target_rank)) != 0:
        return False
    if get_square(state, square_name(file_index, origin_rank)) != 0:
        return False
    # The pawn that just double-pushed must actually be there to be captured.
    if not has_piece(state, square_name(file_index, pawn_rank), enemy_pawn):
        return False

    return any(
        has_piece(state, square_name(adjacent_file, pawn_rank), own_pawn)
        for adjacent_file in (file_index - 1, file_index + 1)
        if 0 <= adjacent_file <= 7
    )


def encode_en_passant(ep: str, state: tuple[int, ...], active_color: int) -> int:
    """Encode the FEN en passant field as a 7-bit value.

    bit 6: valid flag
    bits 0-5: square index, a1 = 0 through h8 = 63

    Returns 0 for "-" and for any target rejected by
    ``en_passant_is_capturable``, so the field is only set when an opposing
    pawn can actually capture.

    Raises:
        ValueError: if ``ep`` is neither "-" nor a valid algebraic square.
    """
    if ep == "-":
        return 0

    ep_square = algebraic_to_square(ep)

    if not en_passant_is_capturable(state, ep_square, active_color):
        return 0

    return (1 << 6) | ep_square


def decode_en_passant(ep_value: int) -> str | None:
    """Return the target square of a 7-bit en passant value, or None if bit 6 is clear."""
    if ((ep_value >> 6) & 1) == 0:
        return None

    square = ep_value & 0x3F
    return square_to_algebraic(square)


def encode_metadata(
    active_color: int,
    castling_rights: int,
    en_passant: int,
    halfmove_clock: int,
    fullmove_counter: int,
) -> int:
    """Pack the metadata fields into the state[8] word.

    Raises:
        ValueError: if any field is negative or does not fit its bit width.
    """
    if not 0 <= active_color <= ACTIVE_COLOR_MASK:
        raise ValueError("Active color must fit in 1 bit")
    if not 0 <= castling_rights <= CASTLING_RIGHTS_MASK:
        raise ValueError("Castling rights must fit in 4 bits")
    if not 0 <= en_passant <= EN_PASSANT_MASK:
        raise ValueError("En passant must fit in 7 bits")
    if not 0 <= halfmove_clock <= HALFMOVE_CLOCK_MASK:
        raise ValueError("Half-move clock must fit in 7 bits")
    if not 0 <= fullmove_counter <= FULLMOVE_COUNTER_MASK:
        raise ValueError("Full-move counter must fit in 8 bits")

    return (
        (active_color << ACTIVE_COLOR_SHIFT)
        | (castling_rights << CASTLING_RIGHTS_SHIFT)
        | (en_passant << EN_PASSANT_SHIFT)
        | (halfmove_clock << HALFMOVE_CLOCK_SHIFT)
        | (fullmove_counter << FULLMOVE_COUNTER_SHIFT)
    )


def _parse_fen_counter(text: str, field_name: str) -> int:
    """Parse a FEN move counter that must consist of ASCII digits only."""
    # Plain int() would also accept "+3", "-1", "1_0" and non-ASCII digits.
    if not (text.isascii() and text.isdigit()):
        raise ValueError(f"Invalid {field_name}: {text}")
    return int(text)


def parse_fen(fen: str) -> BoardState:
    """Parse a six-field FEN string into a ``BoardState``.

    Raises:
        ValueError: on any malformed field, or when a counter exceeds the
            width reserved for it in the metadata word.
    """
    parts = fen.strip().split()
    if len(parts) != 6:
        raise ValueError("FEN must contain exactly 6 fields")

    placement, active, castling, ep, halfmove, fullmove = parts

    state = parse_board_placement(placement)

    if active == "w":
        active_color = 1
    elif active == "b":
        active_color = 0
    else:
        raise ValueError(f"Invalid active color: {active}")

    castling_rights = 0
    if castling != "-":
        for ch in castling:
            if ch not in CASTLING_BITS:
                raise ValueError(f"Invalid castling right: {ch}")
            if castling_rights & CASTLING_BITS[ch]:
                raise ValueError(f"Duplicate castling right: {ch}")
            castling_rights |= CASTLING_BITS[ch]

    halfmove_clock = _parse_fen_counter(halfmove, "half-move clock")
    fullmove_counter = _parse_fen_counter(fullmove, "full-move counter")

    en_passant = encode_en_passant(ep, tuple(state), active_color)

    metadata = encode_metadata(
        active_color=active_color,
        castling_rights=castling_rights,
        en_passant=en_passant,
        halfmove_clock=halfmove_clock,
        fullmove_counter=fullmove_counter,
    )

    state.append(metadata)
    return BoardState(tuple(state))


def run_tests() -> None:
    """Self-checks for the encoder; raises AssertionError on the first failure."""
    start_fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
    start = parse_fen(start_fen)

    assert start.active_color == 1
    assert start.castling_rights == 0b1111
    assert start.en_passant == 0
    assert start.halfmove_clock == 0
    assert start.fullmove_counter == 1

    assert get_square(start.state, "a1") == encode_piece("R")
    assert get_square(start.state, "e1") == encode_piece("K")
    assert get_square(start.state, "a8") == encode_piece("r")
    assert get_square(start.state, "e8") == encode_piece("k")
    assert get_square(start.state, "e4") == 0

    assert start.state[0] == 0xCABEDBAC  # rank 1
    assert start.state[1] == 0x99999999  # rank 2
    assert start.state[6] == 0x11111111  # rank 7
    assert start.state[7] == 0x42365324  # rank 8
    assert start.state[8] == 0x0008001F

    mid_fen = "r1bqkbnr/pppp1ppp/2n5/4p3/3P4/5N2/PPP2PPP/RNBQKB1R w KQkq e3 4 5"
    mid = parse_fen(mid_fen)

    assert mid.active_color == 1
    assert mid.castling_rights == 0b1111
    # e3 is dropped: with white to move an en passant target must be on rank 6.
    assert mid.en_passant == 0
    assert mid.halfmove_clock == 4
    assert mid.fullmove_counter == 5

    assert get_square(mid.state, "a8") == encode_piece("r")
    assert get_square(mid.state, "c8") == encode_piece("b")
    assert get_square(mid.state, "c6") == encode_piece("n")
    assert get_square(mid.state, "e5") == encode_piece("p")
    assert get_square(mid.state, "d4") == encode_piece("P")
    assert get_square(mid.state, "f3") == encode_piece("N")
    assert get_square(mid.state, "g1") == 0

    capturable_ep_fen = "8/8/8/3Pp3/8/8/8/8 w - e6 0 1"
    capturable = parse_fen(capturable_ep_fen)
    assert decode_en_passant(capturable.en_passant) == "e6"

    black_capturable = parse_fen("8/8/8/8/3pP3/8/8/8 b - e3 0 1")
    assert decode_en_passant(black_capturable.en_passant) == "e3"

    # No black pawn on e5, so there is nothing to capture en passant.
    missing_pushed_pawn = parse_fen("8/8/8/3P4/8/8/8/8 w - e6 0 1")
    assert missing_pushed_pawn.en_passant == 0

    invalid_fens = (
        "rnbqkbnr/pppppppp/08/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",  # "0" run
        "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KKkq - 0 1",  # duplicate right
        "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - +0 1",  # signed counter
    )
    for bad_fen in invalid_fens:
        try:
            parse_fen(bad_fen)
        except ValueError:
            continue
        raise AssertionError(f"Expected ValueError for: {bad_fen}")

    print("All tests passed.")


def format_state_hex(state: tuple[int, ...]) -> str:
    """Render the state words as a bracketed list of zero-padded 8-digit hex values."""
    return "[" + ", ".join(f"0x{word:08x}" for word in state) + "]"


if __name__ == "__main__":
    run_tests()

    fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
    board_state = parse_fen(fen)

    print(board_state)
    print(f"state hex: {format_state_hex(board_state.state)}")