Refactor Vulkan setup stage 1

This commit is contained in:
WayfinderAK 2026-05-15 00:17:56 -08:00
parent 32f6e6c14d
commit 3883827f94
No known key found for this signature in database
2 changed files with 1008 additions and 482 deletions

View File

@ -2,117 +2,73 @@ const std = @import("std");
const glfw = @import("zglfw"); const glfw = @import("zglfw");
const vk = @import("vulkan"); const vk = @import("vulkan");
// The build script compiles these GLSL files to SPIR-V and exposes them as
// anonymous imports. They are currently only loaded and printed; the program
// does not create shader modules or a graphics pipeline yet.
const square_vert_spv = @embedFile("square_vertex_shader"); const square_vert_spv = @embedFile("square_vertex_shader");
const square_frag_spv = @embedFile("square_fragment_shader"); const square_frag_spv = @embedFile("square_fragment_shader");
const VulkanContext = struct {
base: vk.BaseWrapper,
instance: vk.Instance,
vki: vk.InstanceWrapper,
fn destroy(self: *const VulkanContext) void {
self.vki.destroyInstance(self.instance, null);
}
};
pub fn main() !void { pub fn main() !void {
std.debug.print("zig-chess bootstrap\n", .{}); std.debug.print("zig-chess bootstrap\n", .{});
std.debug.print("vertex shader bytes: {}\n", .{square_vert_spv.len}); std.debug.print("vertex shader bytes: {}\n", .{square_vert_spv.len});
std.debug.print("fragment shader bytes: {}\n", .{square_frag_spv.len}); std.debug.print("fragment shader bytes: {}\n", .{square_frag_spv.len});
try glfw.init(); // ---------------------------------------------------------------------
// Window bootstrap
// ---------------------------------------------------------------------
const window = try initWindow(800, 600, "zig-chess");
defer glfw.terminate(); defer glfw.terminate();
glfw.windowHint(.client_api, .no_api);
const window = try glfw.Window.create(
800,
600,
"zig-chess",
null,
null,
);
defer window.destroy(); defer window.destroy();
std.debug.print("GLFW platform: {any}\n", .{glfw.getPlatform()});
std.debug.print("Vulkan supported by GLFW: {}\n", .{glfw.isVulkanSupported()});
const size = window.getSize();
const fb_size = window.getFramebufferSize();
std.debug.print("window size: {}x{}\n", .{ size[0], size[1] });
std.debug.print("framebuffer size: {}x{}\n", .{ fb_size[0], fb_size[1] });
std.debug.print("window visible attr: {}\n", .{window.getAttribute(.visible)});
window.show(); window.show();
window.requestAttention(); window.requestAttention();
const base = vk.BaseWrapper.load(glfw.getInstanceProcAddress); // ---------------------------------------------------------------------
const required_extensions = try glfw.getRequiredInstanceExtensions(); // Vulkan instance setup
// ---------------------------------------------------------------------
const app_info = vk.ApplicationInfo{ const vc = try initInstance("zig-chess");
.p_application_name = "zig-chess", defer vc.destroy();
.application_version = 1,
.p_engine_name = "zig-chess",
.engine_version = 1,
.api_version = @bitCast(vk.makeApiVersion(0, 1, 2, 0)),
};
const instance_create_info = vk.InstanceCreateInfo{
.p_application_info = &app_info,
.enabled_extension_count = @intCast(required_extensions.len),
.pp_enabled_extension_names = required_extensions.ptr,
};
const instance = try base.createInstance(&instance_create_info, null);
std.debug.print("required instance extensions:\n", .{});
for (required_extensions) |extension| {
std.debug.print(" {s}\n", .{extension});
}
std.debug.print("Created Vulkan Instance\n", .{});
const vki = vk.InstanceWrapper.load(instance, base.dispatch.vkGetInstanceProcAddr.?);
defer vki.destroyInstance(instance, null);
// ---------------------------------------------------------------------
// Window surface
//
// This connects the platform window to Vulkan presentation. Swapchain
// support and present-capable queue families are queried against this
// surface.
// ---------------------------------------------------------------------
var surface: vk.SurfaceKHR = undefined; var surface: vk.SurfaceKHR = undefined;
try glfw.createWindowSurface(instance, window, null, &surface); try glfw.createWindowSurface(vc.instance, window, null, &surface);
defer vki.destroySurfaceKHR(instance, surface, null); defer vc.vki.destroySurfaceKHR(vc.instance, surface, null);
std.debug.print("Created Vulkan surface\n", .{}); std.debug.print("Created Vulkan surface\n", .{});
const physical_devices = try vki.enumeratePhysicalDevicesAlloc(instance, std.heap.page_allocator); // ---------------------------------------------------------------------
// Physical device and queue-family discovery
// ---------------------------------------------------------------------
const physical_devices = try vc.vki.enumeratePhysicalDevicesAlloc(vc.instance, std.heap.page_allocator);
defer std.heap.page_allocator.free(physical_devices); defer std.heap.page_allocator.free(physical_devices);
//try debugPhysicalGPUs(vc, physical_devices, surface);
std.debug.print("physical devices: {}\n", .{physical_devices.len}); // TODO(refactor): this is intentionally temporary and machine-specific.
// Replace it with selection logic that searches for a device/queue pair
for (physical_devices, 0..) |physical_device, i| { // where queue_flags.graphics_bit is true and surface present support is
const props = vki.getPhysicalDeviceProperties(physical_device); // true. Hardcoding physical_devices[1] will fail on many systems.
std.debug.print("device {}: {s}\n", .{ i, std.mem.sliceTo(&props.device_name, 0) });
const queue_families = try vki.getPhysicalDeviceQueueFamilyPropertiesAlloc(
physical_device,
std.heap.page_allocator,
);
defer std.heap.page_allocator.free(queue_families);
for (queue_families, 0..) |queue_family, queue_index| {
const supports_graphics = queue_family.queue_flags.graphics_bit;
const supports_compute = queue_family.queue_flags.compute_bit;
const supports_transfer = queue_family.queue_flags.transfer_bit;
const supports_present = try vki.getPhysicalDeviceSurfaceSupportKHR(
physical_device,
@intCast(queue_index),
surface,
);
std.debug.print(
" queue {}: count={}, graphics={}, compute={}, transfer={}, present={}\n",
.{
queue_index,
queue_family.queue_count,
supports_graphics,
supports_compute,
supports_transfer,
supports_present,
},
);
}
}
const selected_physical_device = physical_devices[1]; const selected_physical_device = physical_devices[1];
const graphics_queue_family_index: u32 = 0; const graphics_queue_family_index: u32 = 0;
const selected_props = vki.getPhysicalDeviceProperties(selected_physical_device); const selected_props = vc.vki.getPhysicalDeviceProperties(selected_physical_device);
std.debug.print( std.debug.print(
"selected device: {s}, queue family {}\n", "selected device: {s}, queue family {}\n",
.{ .{
@ -121,6 +77,15 @@ pub fn main() !void {
}, },
); );
// ---------------------------------------------------------------------
// Logical device and queue
//
// The logical device enables VK_KHR_swapchain so we can present rendered
// images. We request one queue from the selected queue family.
//
// Refactor direction: createLogicalDevice() should return the device,
// DeviceWrapper, and graphics/present queue handles or indices.
// ---------------------------------------------------------------------
const queue_priority: f32 = 1.0; const queue_priority: f32 = 1.0;
const queue_create_info = vk.DeviceQueueCreateInfo{ const queue_create_info = vk.DeviceQueueCreateInfo{
@ -140,17 +105,26 @@ pub fn main() !void {
.pp_enabled_extension_names = &device_extensions, .pp_enabled_extension_names = &device_extensions,
}; };
const device = try vki.createDevice(selected_physical_device, &device_create_info, null); const device = try vc.vki.createDevice(selected_physical_device, &device_create_info, null);
std.debug.print("created logical device\n", .{}); std.debug.print("created logical device\n", .{});
const vkd = vk.DeviceWrapper.load(device, vki.dispatch.vkGetDeviceProcAddr.?); const vkd = vk.DeviceWrapper.load(device, vc.vki.dispatch.vkGetDeviceProcAddr.?);
defer vkd.destroyDevice(device, null); defer vkd.destroyDevice(device, null);
const graphics_queue = vkd.getDeviceQueue(device, graphics_queue_family_index, 0); const graphics_queue = vkd.getDeviceQueue(device, graphics_queue_family_index, 0);
std.debug.print("retrieved graphics queue\n", .{}); std.debug.print("retrieved graphics queue\n", .{});
const surface_caps = try vki.getPhysicalDeviceSurfaceCapabilitiesKHR( // ---------------------------------------------------------------------
// Swapchain support query and choice of format/present mode/extent
//
// These values describe how the surface can be presented to. FIFO is the
// safe baseline because Vulkan requires it to be supported.
//
// Refactor direction: create a chooseSwapchainSettings() helper returning
// the chosen format, present mode, extent, and image count.
// ---------------------------------------------------------------------
const surface_caps = try vc.vki.getPhysicalDeviceSurfaceCapabilitiesKHR(
selected_physical_device, selected_physical_device,
surface, surface,
); );
@ -171,7 +145,7 @@ pub fn main() !void {
}, },
); );
const surface_formats = try vki.getPhysicalDeviceSurfaceFormatsAllocKHR( const surface_formats = try vc.vki.getPhysicalDeviceSurfaceFormatsAllocKHR(
selected_physical_device, selected_physical_device,
surface, surface,
std.heap.page_allocator, std.heap.page_allocator,
@ -186,7 +160,7 @@ pub fn main() !void {
); );
} }
const present_modes = try vki.getPhysicalDeviceSurfacePresentModesAllocKHR( const present_modes = try vc.vki.getPhysicalDeviceSurfacePresentModesAllocKHR(
selected_physical_device, selected_physical_device,
surface, surface,
std.heap.page_allocator, std.heap.page_allocator,
@ -261,6 +235,16 @@ pub fn main() !void {
.clipped = .true, .clipped = .true,
}; };
// ---------------------------------------------------------------------
// Swapchain creation
//
// The swapchain owns the presentable images. Anything tied to its image
// format or extent must be recreated when the window is resized or Vulkan
// reports the swapchain is out of date/suboptimal.
//
// Refactor direction: group swapchain, images, image views, framebuffers,
// format, and extent into a SwapchainResources struct.
// ---------------------------------------------------------------------
const swapchain = try vkd.createSwapchainKHR(device, &swapchain_create_info, null); const swapchain = try vkd.createSwapchainKHR(device, &swapchain_create_info, null);
defer vkd.destroySwapchainKHR(device, swapchain, null); defer vkd.destroySwapchainKHR(device, swapchain, null);
@ -275,6 +259,8 @@ pub fn main() !void {
std.debug.print("swapchain images: {}\n", .{swapchain_images.len}); std.debug.print("swapchain images: {}\n", .{swapchain_images.len});
// Each swapchain image needs an image view so it can be used as a render
// pass attachment.
const swapchain_image_views = try std.heap.page_allocator.alloc( const swapchain_image_views = try std.heap.page_allocator.alloc(
vk.ImageView, vk.ImageView,
swapchain_images.len, swapchain_images.len,
@ -318,6 +304,17 @@ pub fn main() !void {
std.debug.print("created swapchain image views: {}\n", .{swapchain_image_views.len}); std.debug.print("created swapchain image views: {}\n", .{swapchain_image_views.len});
// ---------------------------------------------------------------------
// Render pass
//
// This render pass has one color attachment: the current swapchain image.
// load_op=.clear means each frame starts by clearing the image; final
// layout present_src_khr means the image is ready for presentation.
//
// Refactor direction: this can become createRenderPass(device, format).
// Later, when drawing pieces or UI, this may gain depth/stencil or change
// if we move to dynamic rendering.
// ---------------------------------------------------------------------
const color_attachment = vk.AttachmentDescription{ const color_attachment = vk.AttachmentDescription{
.format = chosen_surface_format.format, .format = chosen_surface_format.format,
.samples = .{ .@"1_bit" = true }, .samples = .{ .@"1_bit" = true },
@ -369,6 +366,12 @@ pub fn main() !void {
std.debug.print("created render pass\n", .{}); std.debug.print("created render pass\n", .{});
// ---------------------------------------------------------------------
// Framebuffers
//
// A framebuffer binds the render pass attachment description to a concrete
// image view. We need one framebuffer for each swapchain image view.
// ---------------------------------------------------------------------
const framebuffers = try std.heap.page_allocator.alloc( const framebuffers = try std.heap.page_allocator.alloc(
vk.Framebuffer, vk.Framebuffer,
swapchain_image_views.len, swapchain_image_views.len,
@ -402,6 +405,16 @@ pub fn main() !void {
std.debug.print("created framebuffers: {}\n", .{framebuffers.len}); std.debug.print("created framebuffers: {}\n", .{framebuffers.len});
// ---------------------------------------------------------------------
// Command pool and command buffers
//
// Command buffers record GPU work. Right now each swapchain image gets one
// pre-recorded command buffer that only clears the image.
//
// Refactor direction: for frame generation, introduce recordCommandBuffer()
// and call it per frame after acquiring the image. That will make dynamic
// board drawing, highlights, and resize handling easier to reason about.
// ---------------------------------------------------------------------
const command_pool_create_info = vk.CommandPoolCreateInfo{ const command_pool_create_info = vk.CommandPoolCreateInfo{
.flags = .{ .flags = .{
.reset_command_buffer_bit = true, .reset_command_buffer_bit = true,
@ -443,6 +456,8 @@ pub fn main() !void {
try vkd.beginCommandBuffer(command_buffer, &begin_info); try vkd.beginCommandBuffer(command_buffer, &begin_info);
// This is the only "drawing" currently happening: begin a render pass
// and clear the swapchain image. The embedded shaders are not used yet.
const clear_color = vk.ClearValue{ const clear_color = vk.ClearValue{
.color = .{ .color = .{
.float_32 = .{ 0.02, 0.02, 0.08, 1.0 }, .float_32 = .{ 0.02, 0.02, 0.08, 1.0 },
@ -473,6 +488,17 @@ pub fn main() !void {
std.debug.print("recorded command buffers\n", .{}); std.debug.print("recorded command buffers\n", .{});
// ---------------------------------------------------------------------
// Synchronization objects
//
// The image-available semaphore is signaled when acquireNextImageKHR has a
// swapchain image ready. The render-finished semaphore is signaled when GPU
// rendering completes and presentation may wait on it. The fence lets the
// CPU wait until submitted GPU work for this frame is done.
//
// Refactor direction: use arrays for 2 frames in flight, e.g.
// image_available[2], render_finished[2], in_flight_fences[2].
// ---------------------------------------------------------------------
const semaphore_create_info = vk.SemaphoreCreateInfo{}; const semaphore_create_info = vk.SemaphoreCreateInfo{};
const image_available_semaphore = try vkd.createSemaphore( const image_available_semaphore = try vkd.createSemaphore(
@ -504,6 +530,22 @@ pub fn main() !void {
std.debug.print("created synchronization objects\n", .{}); std.debug.print("created synchronization objects\n", .{});
// ---------------------------------------------------------------------
// Single-frame acquire/submit/present
//
// This renders exactly one frame before entering the event loop. To turn
// this into frame generation, move this whole block into drawFrame() and
// call it from the window loop below.
//
// Per-frame shape:
// 1. wait/reset the in-flight fence
// 2. acquire the next swapchain image
// 3. submit the command buffer for that image
// 4. present that image
//
// Later this block must handle out-of-date/suboptimal swapchains and call
// recreateSwapchainResources().
// ---------------------------------------------------------------------
const wait_fences = [_]vk.Fence{in_flight_fence}; const wait_fences = [_]vk.Fence{in_flight_fence};
_ = try vkd.waitForFences(device, &wait_fences, .true, std.math.maxInt(u64)); _ = try vkd.waitForFences(device, &wait_fences, .true, std.math.maxInt(u64));
try vkd.resetFences(device, &wait_fences); try vkd.resetFences(device, &wait_fences);
@ -557,7 +599,113 @@ pub fn main() !void {
std.debug.print("presented one frame\n", .{}); std.debug.print("presented one frame\n", .{});
// ---------------------------------------------------------------------
// Event loop
//
// Currently this only keeps the window alive after the one presented frame.
// Next rendering milestone: call drawFrame() each iteration after polling
// events, then wait for the device to be idle before cleanup on exit.
// ---------------------------------------------------------------------
while (!window.shouldClose()) { while (!window.shouldClose()) {
glfw.pollEvents(); glfw.pollEvents();
} }
} }
fn initWindow(x: c_int, y: c_int, name: [:0]const u8) !*glfw.Window {
try glfw.init();
errdefer glfw.terminate();
glfw.windowHint(.client_api, .no_api);
const window = try glfw.Window.create(
x,
y,
name,
null,
null,
);
std.debug.print("GLFW platform: {any}\n", .{glfw.getPlatform()});
std.debug.print("Vulkan supported by GLFW: {}\n", .{glfw.isVulkanSupported()});
const size = window.getSize();
const fb_size = window.getFramebufferSize();
std.debug.print("Window size: {}x{}\n", .{ size[0], size[1] });
std.debug.print("Framebuffer size: {}x{}\n", .{ fb_size[0], fb_size[1] });
std.debug.print("Window visible: {}\n", .{window.getAttribute(.visible)});
return window;
}
fn initInstance(name: [:0]const u8) !VulkanContext {
const base = vk.BaseWrapper.load(glfw.getInstanceProcAddress);
const required_extensions = try glfw.getRequiredInstanceExtensions();
std.debug.print("Required instance extensions:\n", .{});
for (required_extensions) |extension| {
std.debug.print(" {s}\n", .{extension});
}
const app_info = vk.ApplicationInfo{
.p_application_name = name,
.application_version = 1,
.p_engine_name = name,
.engine_version = 1,
.api_version = @bitCast(vk.makeApiVersion(0, 1, 2, 0)),
};
const instance_create_info = vk.InstanceCreateInfo{
.p_application_info = &app_info,
.enabled_extension_count = @intCast(required_extensions.len),
.pp_enabled_extension_names = required_extensions.ptr,
};
const instance = try base.createInstance(&instance_create_info, null);
std.debug.print("Created Vulkan Instance", .{});
const vki = vk.InstanceWrapper.load(instance, base.dispatch.vkGetInstanceProcAddr.?);
return .{
.instance = instance,
.base = base,
.vki = vki,
};
}
fn debugPhysicalGPUs(vc: VulkanContext, physical_devices: []vk.PhysicalDevice, surface: vk.SurfaceKHR) !void {
std.debug.print("physical devices: {}\n", .{physical_devices.len});
for (physical_devices, 0..) |physical_device, i| {
const props = vc.vki.getPhysicalDeviceProperties(physical_device);
std.debug.print("device {}: {s}\n", .{ i, std.mem.sliceTo(&props.device_name, 0) });
const queue_families = try vc.vki.getPhysicalDeviceQueueFamilyPropertiesAlloc(
physical_device,
std.heap.page_allocator,
);
defer std.heap.page_allocator.free(queue_families);
for (queue_families, 0..) |queue_family, queue_index| {
const supports_graphics = queue_family.queue_flags.graphics_bit;
const supports_compute = queue_family.queue_flags.compute_bit;
const supports_transfer = queue_family.queue_flags.transfer_bit;
const supports_present = try vc.vki.getPhysicalDeviceSurfaceSupportKHR(
physical_device,
@intCast(queue_index),
surface,
);
std.debug.print(
" queue {}: count={}, graphics={}, compute={}, transfer={}, present={}\n",
.{
queue_index,
queue_family.queue_count,
supports_graphics,
supports_compute,
supports_transfer,
supports_present,
},
);
}
}
}

378
tools/fen_to_board_state.py Normal file
View File

@ -0,0 +1,378 @@
#!/usr/bin/env python3
"""Reference FEN-to-[9]u32 board-state encoder for tests and design experiments.
The full board state is represented as nine 32-bit words:
state[0] = rank 1
state[1] = rank 2
...
state[7] = rank 8
state[8] = metadata
Each rank word stores eight 4-bit square values:
bits 0..3 = file a
bits 4..7 = file b
...
bits 28..31 = file h
Piece encoding uses bit 3 for color and bits 0..2 for piece type:
0 = empty
black pawn/knight/bishop/rook/queen/king = 1..6
white pawn/knight/bishop/rook/queen/king = 9..14
Metadata in state[8]:
bit 0 active color: 1 = white, 0 = black
bits 1..4 castling rights
bits 5..11 en passant: bit 6 valid, bits 0..5 square index
bits 12..18 halfmove clock
bits 19..26 fullmove counter
bits 27..31 reserved
"""
from dataclasses import dataclass
PIECE_TYPE = {
"p": 1,
"n": 2,
"b": 3,
"r": 4,
"q": 5,
"k": 6,
}
CASTLING_BITS = {
"K": 0b1000,
"Q": 0b0100,
"k": 0b0010,
"q": 0b0001,
}
ACTIVE_COLOR_SHIFT = 0
CASTLING_RIGHTS_SHIFT = 1
EN_PASSANT_SHIFT = 5
HALFMOVE_CLOCK_SHIFT = 12
FULLMOVE_COUNTER_SHIFT = 19
ACTIVE_COLOR_MASK = 0b1
CASTLING_RIGHTS_MASK = 0b1111
EN_PASSANT_MASK = 0b1111111
HALFMOVE_CLOCK_MASK = 0b1111111
FULLMOVE_COUNTER_MASK = 0b11111111
@dataclass(frozen=True)
class BoardState:
state: tuple[int, ...] # 9 u32 values: ranks 1..8, then metadata
def __post_init__(self) -> None:
if len(self.state) != 9:
raise ValueError("BoardState must contain exactly 9 words")
for word in self.state:
if not 0 <= word <= 0xFFFFFFFF:
raise ValueError("BoardState words must fit in u32")
@property
def active_color(self) -> int:
return (self.state[8] >> ACTIVE_COLOR_SHIFT) & ACTIVE_COLOR_MASK
@property
def castling_rights(self) -> int:
return (self.state[8] >> CASTLING_RIGHTS_SHIFT) & CASTLING_RIGHTS_MASK
@property
def en_passant(self) -> int:
return (self.state[8] >> EN_PASSANT_SHIFT) & EN_PASSANT_MASK
@property
def halfmove_clock(self) -> int:
return (self.state[8] >> HALFMOVE_CLOCK_SHIFT) & HALFMOVE_CLOCK_MASK
@property
def fullmove_counter(self) -> int:
return (self.state[8] >> FULLMOVE_COUNTER_SHIFT) & FULLMOVE_COUNTER_MASK
def square_index(file_index: int, rank_num: int) -> int:
"""
a1 = 0
h1 = 7
a8 = 56
h8 = 63
"""
if not 0 <= file_index <= 7:
raise ValueError("file_index must be 0-7")
if not 1 <= rank_num <= 8:
raise ValueError("rank_num must be 1-8")
return (rank_num - 1) * 8 + file_index
def algebraic_to_square(square: str) -> int:
if len(square) != 2:
raise ValueError(f"Invalid square: {square}")
file_ch = square[0]
rank_ch = square[1]
if file_ch < "a" or file_ch > "h":
raise ValueError(f"Invalid file: {square}")
if rank_ch < "1" or rank_ch > "8":
raise ValueError(f"Invalid rank: {square}")
return square_index(ord(file_ch) - ord("a"), int(rank_ch))
def square_to_algebraic(square: int) -> str:
if not 0 <= square <= 63:
raise ValueError("square must be 0-63")
file_index = square % 8
rank_num = (square // 8) + 1
return f"{chr(ord('a') + file_index)}{rank_num}"
def encode_piece(ch: str) -> int:
color = 1 if ch.isupper() else 0
piece = PIECE_TYPE[ch.lower()]
return (color << 3) | piece
def get_square(state: tuple[int, ...], algebraic: str) -> int:
idx = algebraic_to_square(algebraic)
rank_index = idx // 8
file_index = idx % 8
return (state[rank_index] >> (file_index * 4)) & 0xF
def has_piece(state: tuple[int, ...], algebraic: str, piece: str) -> bool:
return get_square(state, algebraic) == encode_piece(piece)
def parse_board_placement(placement: str) -> list[int]:
ranks_u32 = [0] * 8
ranks = placement.split("/")
if len(ranks) != 8:
raise ValueError("FEN board placement must contain 8 ranks")
# FEN is rank 8 to rank 1.
for fen_rank_index, rank_text in enumerate(ranks):
rank_num = 8 - fen_rank_index
rank_index = rank_num - 1
file_index = 0
for ch in rank_text:
if ch.isdigit():
file_index += int(ch)
continue
if ch.lower() not in PIECE_TYPE:
raise ValueError(f"Invalid piece character: {ch}")
if file_index >= 8:
raise ValueError(f"Too many squares in rank: {rank_text}")
ranks_u32[rank_index] |= encode_piece(ch) << (file_index * 4)
file_index += 1
if file_index != 8:
raise ValueError(f"Rank does not contain exactly 8 squares: {rank_text}")
return ranks_u32
def en_passant_is_capturable(state: tuple[int, ...], ep_square: int, active_color: int) -> bool:
"""
active_color is side to move.
If white is to move, black just advanced a pawn two squares,
so the en passant target should be on rank 6 and a white pawn
must be on an adjacent file on rank 5.
If black is to move, white just advanced a pawn two squares,
so the en passant target should be on rank 3 and a black pawn
must be on an adjacent file on rank 4.
"""
file_index = ep_square % 8
rank_num = (ep_square // 8) + 1
if active_color == 1:
if rank_num != 6:
return False
pawn_rank = 5
pawn_char = "P"
else:
if rank_num != 3:
return False
pawn_rank = 4
pawn_char = "p"
for adjacent_file in (file_index - 1, file_index + 1):
if 0 <= adjacent_file <= 7:
adjacent_square = square_to_algebraic(
square_index(adjacent_file, pawn_rank)
)
if has_piece(state, adjacent_square, pawn_char):
return True
return False
def encode_en_passant(ep: str, state: tuple[int, ...], active_color: int) -> int:
"""
7-bit encoding:
bit 6: valid
bits 0-5: square index, a1=0 through h8=63
En passant is only stored if an opposing pawn can actually capture.
"""
if ep == "-":
return 0
ep_square = algebraic_to_square(ep)
if not en_passant_is_capturable(state, ep_square, active_color):
return 0
return (1 << 6) | ep_square
def decode_en_passant(ep_value: int) -> str | None:
if ((ep_value >> 6) & 1) == 0:
return None
square = ep_value & 0x3F
return square_to_algebraic(square)
def encode_metadata(
active_color: int,
castling_rights: int,
en_passant: int,
halfmove_clock: int,
fullmove_counter: int,
) -> int:
if not 0 <= active_color <= ACTIVE_COLOR_MASK:
raise ValueError("Active color must fit in 1 bit")
if not 0 <= castling_rights <= CASTLING_RIGHTS_MASK:
raise ValueError("Castling rights must fit in 4 bits")
if not 0 <= en_passant <= EN_PASSANT_MASK:
raise ValueError("En passant must fit in 7 bits")
if not 0 <= halfmove_clock <= HALFMOVE_CLOCK_MASK:
raise ValueError("Half-move clock must fit in 7 bits")
if not 0 <= fullmove_counter <= FULLMOVE_COUNTER_MASK:
raise ValueError("Full-move counter must fit in 8 bits")
return (
(active_color << ACTIVE_COLOR_SHIFT)
| (castling_rights << CASTLING_RIGHTS_SHIFT)
| (en_passant << EN_PASSANT_SHIFT)
| (halfmove_clock << HALFMOVE_CLOCK_SHIFT)
| (fullmove_counter << FULLMOVE_COUNTER_SHIFT)
)
def parse_fen(fen: str) -> BoardState:
parts = fen.strip().split()
if len(parts) != 6:
raise ValueError("FEN must contain exactly 6 fields")
placement, active, castling, ep, halfmove, fullmove = parts
state = parse_board_placement(placement)
if active == "w":
active_color = 1
elif active == "b":
active_color = 0
else:
raise ValueError(f"Invalid active color: {active}")
castling_rights = 0
if castling != "-":
for ch in castling:
if ch not in CASTLING_BITS:
raise ValueError(f"Invalid castling right: {ch}")
castling_rights |= CASTLING_BITS[ch]
halfmove_clock = int(halfmove)
fullmove_counter = int(fullmove)
en_passant = encode_en_passant(ep, tuple(state), active_color)
metadata = encode_metadata(
active_color=active_color,
castling_rights=castling_rights,
en_passant=en_passant,
halfmove_clock=halfmove_clock,
fullmove_counter=fullmove_counter,
)
state.append(metadata)
return BoardState(tuple(state))
def run_tests() -> None:
start_fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
start = parse_fen(start_fen)
assert start.active_color == 1
assert start.castling_rights == 0b1111
assert start.en_passant == 0
assert start.halfmove_clock == 0
assert start.fullmove_counter == 1
assert get_square(start.state, "a1") == encode_piece("R")
assert get_square(start.state, "e1") == encode_piece("K")
assert get_square(start.state, "a8") == encode_piece("r")
assert get_square(start.state, "e8") == encode_piece("k")
assert get_square(start.state, "e4") == 0
assert start.state[0] == 0xCABEDBAC # rank 1
assert start.state[1] == 0x99999999 # rank 2
assert start.state[6] == 0x11111111 # rank 7
assert start.state[7] == 0x42365324 # rank 8
assert start.state[8] == 0x0008001F
mid_fen = "r1bqkbnr/pppp1ppp/2n5/4p3/3P4/5N2/PPP2PPP/RNBQKB1R w KQkq e3 4 5"
mid = parse_fen(mid_fen)
assert mid.active_color == 1
assert mid.castling_rights == 0b1111
assert mid.en_passant == 0 # e3 exists in FEN, but no white pawn can capture there
assert mid.halfmove_clock == 4
assert mid.fullmove_counter == 5
assert get_square(mid.state, "a8") == encode_piece("r")
assert get_square(mid.state, "c8") == encode_piece("b")
assert get_square(mid.state, "c6") == encode_piece("n")
assert get_square(mid.state, "e5") == encode_piece("p")
assert get_square(mid.state, "d4") == encode_piece("P")
assert get_square(mid.state, "f3") == encode_piece("N")
assert get_square(mid.state, "g1") == 0
capturable_ep_fen = "8/8/8/3Pp3/8/8/8/8 w - e6 0 1"
capturable = parse_fen(capturable_ep_fen)
assert decode_en_passant(capturable.en_passant) == "e6"
print("All tests passed.")
def format_state_hex(state: tuple[int, ...]) -> str:
return "[" + ", ".join(f"0x{word:08x}" for word in state) + "]"
if __name__ == "__main__":
run_tests()
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
board_state = parse_fen(fen)
print(board_state)
print(f"state hex: {format_state_hex(board_state.state)}")