From c3f4c80631f3a35b10ec9d82e9840b221adcb1e0 Mon Sep 17 00:00:00 2001 From: dece Date: Thu, 11 Jun 2020 20:40:08 +0200 Subject: [PATCH] engine: implement basic minimax --- src/board.rs | 9 ++++ src/engine.rs | 112 ++++++++++++++++++++++++++++++++++++-------------- src/rules.rs | 3 +- 3 files changed, 92 insertions(+), 32 deletions(-) diff --git a/src/board.rs b/src/board.rs index afb73d1..cc4b9f5 100644 --- a/src/board.rs +++ b/src/board.rs @@ -59,6 +59,15 @@ pub const fn is_piece(square: u8, piece: u8) -> bool { has_flag(square, piece) } #[inline] pub const fn opposite(color: u8) -> u8 { color ^ SQ_COLOR_MASK } +/// Pretty-print a color. +pub fn color_to_string(color: u8) -> String { + match color { + SQ_WH => "white".to_string(), + SQ_BL => "black".to_string(), + _ => panic!("Unknown color {}", color), + } +} + /// Minimum allowed value for stored Pos components. pub const POS_MIN: i8 = 0; /// Maximum allowed value for stored Pos components. diff --git a/src/engine.rs b/src/engine.rs index d4052dc..1d6968f 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -9,6 +9,9 @@ use crate::notation; use crate::rules; use crate::uci; +const MIN_F32: f32 = std::f32::NEG_INFINITY; +const MAX_F32: f32 = std::f32::INFINITY; + /// Analysis engine. pub struct Engine { /// Debug mode, log some data. @@ -229,13 +232,13 @@ impl Engine { } self.working.store(true, atomic::Ordering::Relaxed); - let node = self.node.clone(); + let mut node = self.node.clone(); let args = args.clone(); let working = self.working.clone(); let tx = match &self.mode { Mode::Uci(_, _, tx) => tx.clone(), _ => return }; let debug = self.debug; thread::spawn(move || { - analyze(&node, &args, working, tx, debug); + analyze(&mut node, &args, working, tx, debug); }); } @@ -299,7 +302,7 @@ impl Engine { } fn analyze( - node: &Node, + node: &mut Node, _args: &WorkArgs, working: Arc, tx: mpsc::Sender, @@ -309,38 +312,28 @@ fn analyze( return; } if debug { - let state_str = format!("Analysing node: {:?}", node); - tx.send(Cmd::Log(state_str)).unwrap(); - let mut s = vec!(); - board::draw(&node.board, &mut s); - let draw_str = String::from_utf8_lossy(&s).to_string(); - tx.send(Cmd::Log(draw_str)).unwrap(); - } - - let moves = rules::get_player_moves(&node.board, &node.game_state, true); - if debug { - let moves_str = format!("Legal moves: {}", notation::move_list_to_string(&moves)); + tx.send(Cmd::Log(format!("\tAnalyzing node:\n{}", node))).unwrap(); + let moves = rules::get_player_moves(&node.board, &node.game_state, true); + let moves_str = format!("\tLegal moves: {}", notation::move_list_to_string(&moves)); tx.send(Cmd::Log(moves_str)).unwrap(); } - let mut best_e = if board::is_white(node.game_state.color) { -999.0 } else { 999.0 }; - let mut best_move = None; - for m in moves { - let mut board = node.board.clone(); - let mut game_state = node.game_state.clone(); - rules::apply_move_to(&mut board, &mut game_state, &m); - let stats = board::compute_stats(&board); - let e = evaluate(&stats); - if - (board::is_white(node.game_state.color) && e > best_e) || - (board::is_black(node.game_state.color) && e < best_e) - { - best_e = e; - best_move = Some(m.clone()); - } + let (max_score, best_move) = minimax(node, 0, 3, board::is_white(node.game_state.color)); + + if best_move.is_some() { + let log_str = format!( + "\tBest move {} evaluated {}", + notation::move_to_string(&best_move.unwrap()), max_score + ); + tx.send(Cmd::Log(log_str)).unwrap(); + tx.send(Cmd::TmpBestMove(best_move)).unwrap(); + } else { + // If no best move could be found, checkmate is unavoidable; send the first legal move. + tx.send(Cmd::Log("Checkmate is unavoidable.".to_string())).unwrap(); + let moves = rules::get_player_moves(&node.board, &node.game_state, true); + let m = if moves.len() > 0 { Some(moves[0]) } else { None }; + tx.send(Cmd::TmpBestMove(m)).unwrap(); } - thread::sleep(time::Duration::from_millis(500u64)); - tx.send(Cmd::TmpBestMove(best_move)).unwrap(); // thread::sleep(time::Duration::from_secs(1)); // for _ in 0..4 { @@ -353,6 +346,39 @@ fn analyze( } +fn minimax( + node: &mut Node, + depth: u32, + max_depth: u32, + maximizing: bool +) -> (f32, Option) { + if depth == max_depth { + board::compute_stats_into(&node.board, &mut node.stats); + return (evaluate(&node.stats), None); + } + let mut minmax = if maximizing { MIN_F32 } else { MAX_F32 }; + let mut minmax_move = None; + let moves = rules::get_player_moves(&node.board, &node.game_state, true); + for m in moves { + let mut sub_node = node.clone(); + rules::apply_move_to(&mut sub_node.board, &mut sub_node.game_state, &m); + if maximizing { + let (score, _) = minimax(&mut sub_node, depth + 1, max_depth, false); + if score > minmax { + minmax = score; + minmax_move = Some(m); + } + } else { + let (score, _) = minimax(&mut sub_node, depth + 1, max_depth, true); + if score < minmax { + minmax = score; + minmax_move = Some(m); + } + } + } + (minmax, minmax_move) +} + fn evaluate(stats: &(board::BoardStats, board::BoardStats)) -> f32 { let (ws, bs) = stats; @@ -373,6 +399,30 @@ fn evaluate(stats: &(board::BoardStats, board::BoardStats)) -> f32 { #[cfg(test)] mod tests { use super::*; + use board::pos; + + #[test] + fn test_minimax() { + let mut node = Node::new(); + node.game_state.castling = 0; + + // White mates in 1 move, queen to d7. + board::set_square(&mut node.board, &pos("a1"), board::SQ_WH_K); + board::set_square(&mut node.board, &pos("c6"), board::SQ_WH_P); + board::set_square(&mut node.board, &pos("h7"), board::SQ_WH_Q); + board::set_square(&mut node.board, &pos("d8"), board::SQ_BL_K); + let (_, m) = minimax(&mut node, 0, 2, true); + assert_eq!(m.unwrap(), notation::parse_move("h7d7")); + + // Check that it works for black as well. + board::set_square(&mut node.board, &pos("a1"), board::SQ_BL_K); + board::set_square(&mut node.board, &pos("c6"), board::SQ_BL_P); + board::set_square(&mut node.board, &pos("h7"), board::SQ_BL_Q); + board::set_square(&mut node.board, &pos("d8"), board::SQ_WH_K); + node.game_state.color = board::SQ_BL; + let (_, m) = minimax(&mut node, 0, 2, true); + assert_eq!(m.unwrap(), notation::parse_move("h7d7")); + } #[test] fn test_evaluate() { diff --git a/src/rules.rs b/src/rules.rs index a4591e0..7adb482 100644 --- a/src/rules.rs +++ b/src/rules.rs @@ -40,7 +40,8 @@ impl std::fmt::Display for GameState { - en_passant: {}\n\ - halfmove: {}\n\ - fullmove: {}", - self.color, self.castling, notation::en_passant_to_string(self.en_passant), + color_to_string(self.color), self.castling, + notation::en_passant_to_string(self.en_passant), self.halfmove, self.fullmove ) }