engine: implement basic minimax
This commit is contained in:
parent
7f1890608d
commit
c3f4c80631
|
@ -59,6 +59,15 @@ pub const fn is_piece(square: u8, piece: u8) -> bool { has_flag(square, piece) }
|
|||
#[inline]
|
||||
pub const fn opposite(color: u8) -> u8 { color ^ SQ_COLOR_MASK }
|
||||
|
||||
/// Pretty-print a color.
|
||||
pub fn color_to_string(color: u8) -> String {
|
||||
match color {
|
||||
SQ_WH => "white".to_string(),
|
||||
SQ_BL => "black".to_string(),
|
||||
_ => panic!("Unknown color {}", color),
|
||||
}
|
||||
}
|
||||
|
||||
/// Minimum allowed value for stored Pos components.
|
||||
pub const POS_MIN: i8 = 0;
|
||||
/// Maximum allowed value for stored Pos components.
|
||||
|
|
112
src/engine.rs
112
src/engine.rs
|
@ -9,6 +9,9 @@ use crate::notation;
|
|||
use crate::rules;
|
||||
use crate::uci;
|
||||
|
||||
const MIN_F32: f32 = std::f32::NEG_INFINITY;
|
||||
const MAX_F32: f32 = std::f32::INFINITY;
|
||||
|
||||
/// Analysis engine.
|
||||
pub struct Engine {
|
||||
/// Debug mode, log some data.
|
||||
|
@ -229,13 +232,13 @@ impl Engine {
|
|||
}
|
||||
|
||||
self.working.store(true, atomic::Ordering::Relaxed);
|
||||
let node = self.node.clone();
|
||||
let mut node = self.node.clone();
|
||||
let args = args.clone();
|
||||
let working = self.working.clone();
|
||||
let tx = match &self.mode { Mode::Uci(_, _, tx) => tx.clone(), _ => return };
|
||||
let debug = self.debug;
|
||||
thread::spawn(move || {
|
||||
analyze(&node, &args, working, tx, debug);
|
||||
analyze(&mut node, &args, working, tx, debug);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -299,7 +302,7 @@ impl Engine {
|
|||
}
|
||||
|
||||
fn analyze(
|
||||
node: &Node,
|
||||
node: &mut Node,
|
||||
_args: &WorkArgs,
|
||||
working: Arc<atomic::AtomicBool>,
|
||||
tx: mpsc::Sender<Cmd>,
|
||||
|
@ -309,38 +312,28 @@ fn analyze(
|
|||
return;
|
||||
}
|
||||
if debug {
|
||||
let state_str = format!("Analysing node: {:?}", node);
|
||||
tx.send(Cmd::Log(state_str)).unwrap();
|
||||
let mut s = vec!();
|
||||
board::draw(&node.board, &mut s);
|
||||
let draw_str = String::from_utf8_lossy(&s).to_string();
|
||||
tx.send(Cmd::Log(draw_str)).unwrap();
|
||||
}
|
||||
|
||||
let moves = rules::get_player_moves(&node.board, &node.game_state, true);
|
||||
if debug {
|
||||
let moves_str = format!("Legal moves: {}", notation::move_list_to_string(&moves));
|
||||
tx.send(Cmd::Log(format!("\tAnalyzing node:\n{}", node))).unwrap();
|
||||
let moves = rules::get_player_moves(&node.board, &node.game_state, true);
|
||||
let moves_str = format!("\tLegal moves: {}", notation::move_list_to_string(&moves));
|
||||
tx.send(Cmd::Log(moves_str)).unwrap();
|
||||
}
|
||||
|
||||
let mut best_e = if board::is_white(node.game_state.color) { -999.0 } else { 999.0 };
|
||||
let mut best_move = None;
|
||||
for m in moves {
|
||||
let mut board = node.board.clone();
|
||||
let mut game_state = node.game_state.clone();
|
||||
rules::apply_move_to(&mut board, &mut game_state, &m);
|
||||
let stats = board::compute_stats(&board);
|
||||
let e = evaluate(&stats);
|
||||
if
|
||||
(board::is_white(node.game_state.color) && e > best_e) ||
|
||||
(board::is_black(node.game_state.color) && e < best_e)
|
||||
{
|
||||
best_e = e;
|
||||
best_move = Some(m.clone());
|
||||
}
|
||||
let (max_score, best_move) = minimax(node, 0, 3, board::is_white(node.game_state.color));
|
||||
|
||||
if best_move.is_some() {
|
||||
let log_str = format!(
|
||||
"\tBest move {} evaluated {}",
|
||||
notation::move_to_string(&best_move.unwrap()), max_score
|
||||
);
|
||||
tx.send(Cmd::Log(log_str)).unwrap();
|
||||
tx.send(Cmd::TmpBestMove(best_move)).unwrap();
|
||||
} else {
|
||||
// If no best move could be found, checkmate is unavoidable; send the first legal move.
|
||||
tx.send(Cmd::Log("Checkmate is unavoidable.".to_string())).unwrap();
|
||||
let moves = rules::get_player_moves(&node.board, &node.game_state, true);
|
||||
let m = if moves.len() > 0 { Some(moves[0]) } else { None };
|
||||
tx.send(Cmd::TmpBestMove(m)).unwrap();
|
||||
}
|
||||
thread::sleep(time::Duration::from_millis(500u64));
|
||||
tx.send(Cmd::TmpBestMove(best_move)).unwrap();
|
||||
|
||||
// thread::sleep(time::Duration::from_secs(1));
|
||||
// for _ in 0..4 {
|
||||
|
@ -353,6 +346,39 @@ fn analyze(
|
|||
|
||||
}
|
||||
|
||||
fn minimax(
|
||||
node: &mut Node,
|
||||
depth: u32,
|
||||
max_depth: u32,
|
||||
maximizing: bool
|
||||
) -> (f32, Option<rules::Move>) {
|
||||
if depth == max_depth {
|
||||
board::compute_stats_into(&node.board, &mut node.stats);
|
||||
return (evaluate(&node.stats), None);
|
||||
}
|
||||
let mut minmax = if maximizing { MIN_F32 } else { MAX_F32 };
|
||||
let mut minmax_move = None;
|
||||
let moves = rules::get_player_moves(&node.board, &node.game_state, true);
|
||||
for m in moves {
|
||||
let mut sub_node = node.clone();
|
||||
rules::apply_move_to(&mut sub_node.board, &mut sub_node.game_state, &m);
|
||||
if maximizing {
|
||||
let (score, _) = minimax(&mut sub_node, depth + 1, max_depth, false);
|
||||
if score > minmax {
|
||||
minmax = score;
|
||||
minmax_move = Some(m);
|
||||
}
|
||||
} else {
|
||||
let (score, _) = minimax(&mut sub_node, depth + 1, max_depth, true);
|
||||
if score < minmax {
|
||||
minmax = score;
|
||||
minmax_move = Some(m);
|
||||
}
|
||||
}
|
||||
}
|
||||
(minmax, minmax_move)
|
||||
}
|
||||
|
||||
fn evaluate(stats: &(board::BoardStats, board::BoardStats)) -> f32 {
|
||||
let (ws, bs) = stats;
|
||||
|
||||
|
@ -373,6 +399,30 @@ fn evaluate(stats: &(board::BoardStats, board::BoardStats)) -> f32 {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use board::pos;
|
||||
|
||||
#[test]
|
||||
fn test_minimax() {
|
||||
let mut node = Node::new();
|
||||
node.game_state.castling = 0;
|
||||
|
||||
// White mates in 1 move, queen to d7.
|
||||
board::set_square(&mut node.board, &pos("a1"), board::SQ_WH_K);
|
||||
board::set_square(&mut node.board, &pos("c6"), board::SQ_WH_P);
|
||||
board::set_square(&mut node.board, &pos("h7"), board::SQ_WH_Q);
|
||||
board::set_square(&mut node.board, &pos("d8"), board::SQ_BL_K);
|
||||
let (_, m) = minimax(&mut node, 0, 2, true);
|
||||
assert_eq!(m.unwrap(), notation::parse_move("h7d7"));
|
||||
|
||||
// Check that it works for black as well.
|
||||
board::set_square(&mut node.board, &pos("a1"), board::SQ_BL_K);
|
||||
board::set_square(&mut node.board, &pos("c6"), board::SQ_BL_P);
|
||||
board::set_square(&mut node.board, &pos("h7"), board::SQ_BL_Q);
|
||||
board::set_square(&mut node.board, &pos("d8"), board::SQ_WH_K);
|
||||
node.game_state.color = board::SQ_BL;
|
||||
let (_, m) = minimax(&mut node, 0, 2, true);
|
||||
assert_eq!(m.unwrap(), notation::parse_move("h7d7"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_evaluate() {
|
||||
|
|
|
@ -40,7 +40,8 @@ impl std::fmt::Display for GameState {
|
|||
- en_passant: {}\n\
|
||||
- halfmove: {}\n\
|
||||
- fullmove: {}",
|
||||
self.color, self.castling, notation::en_passant_to_string(self.en_passant),
|
||||
color_to_string(self.color), self.castling,
|
||||
notation::en_passant_to_string(self.en_passant),
|
||||
self.halfmove, self.fullmove
|
||||
)
|
||||
}
|
||||
|
|
Reference in a new issue