Skip to content

A*算法

简介

A*搜索算法(英文:A*search algorithm,A*读作 A-star),简称 A*算法,是一种在图形平面上,对于有多个节点的路径求出最低通过成本的算法。它属于图遍历(英文:Graph traversal)和最佳优先搜索算法(英文:Best-first search),亦是 BFS 的改进。

定义起点 s,终点 t,从起点(初始状态)开始的距离函数 g(x),到终点(最终状态)的距离函数 h(x)h(x),以及每个点的估价函数 f(x)=g(x)+h(x)

A*算法每次从优先队列中取出一个 f 最小的元素,然后更新相邻的状态。

如果 hh,则 A*算法能找到最优解。

上述条件下,如果 h 满足三角形不等式,则 A*算法不会将重复结点加入队列。

h=0 时,A*算法变为 Dijkstra;当 h=0 并且边权为 1 时变为 BFS。

A*算法适用于搜索空间较大的情况。

例题

八数码

题目描述

在一个 3×3 的网格中,1sim88 个数字和一个 X 恰好不重不漏地分布在这 3×3 的网格中。

例如:

1 2 3
X 4 6
7 5 8

在游戏过程中,可以把 X 与其上、下、左、右四个方向之一的数字交换(如果存在)。

我们的目的是通过交换,使得网格变为如下排列(称为正确排列):

1 2 3
4 5 6
7 8 X

例如,示例中图形就可以通过让 X 先后与右、下、右三个方向的数字交换成功得到正确排列。

交换过程如下:

1 2 3   1 2 3   1 2 3   1 2 3
X 4 6   4 X 6   4 5 6   4 5 6
7 5 8   7 5 8   7 X 8   7 8 X

X 与上下左右方向数字交换的行动记录为 udlr

现在,给你一个初始网格,请你通过最少的移动次数,得到正确排列。

输入格式

输入占一行,将 3×3 的初始网格描绘出来。

例如,如果初始网格如下所示:

1 2 3 
x 4 6 
7 5 8

则输入为:1 2 3 x 4 6 7 5 8

输出格式

输出占一行,包含一个字符串,表示得到正确排列的完整行动记录。

如果答案不唯一,输出任意一种合法方案即可。

如果不存在解决方案,则输出 unsolvable

输入样例:

2  3  4  1  5  x  7  6  8

输出样例

ullddrurdllurdruldr

题解

由于搜索空间较大,有 4 个移动操作,假设我们从起点转移到终点至少需要 n 步,那么最起码我们要搜索 4n 条路径,而且这已经是很理想的状况了,我们在搜索的时候难免会“走几条弯路”(也就是搜那些不可能到的路径,而且搜了很长一段距离停不下来),那么时间复杂度就会达到一个指数级别的增长。

所以采用 **A***算法 或 双向BFS 减少搜索空间。

设计估价函数:

  • 当前状态中每个数与它的目标位置的曼哈顿距离之和。

代码实现

cpp
#include <cstring>
#include <iostream>
#include <algorithm>
#include <queue>
#include <unordered_map>

using namespace std;

int f(string state)
{
    int res = 0;
    for (int i = 0; i < state.size(); i ++ )
        if (state[i] != 'x')
        {
            int t = state[i] - '1';
            res += abs(i / 3 - t / 3) + abs(i % 3 - t % 3);
        }
    return res;
}

string bfs(string start)
{
    int dx[4] = {-1, 0, 1, 0}, dy[4] = {0, 1, 0, -1};
    char op[4] = {'u', 'r', 'd', 'l'};

    string end = "12345678x";
    unordered_map<string, int> dist;
    unordered_map<string, pair<string, char>> prev;
    priority_queue<pair<int, string>, vector<pair<int, string>>, greater<pair<int, string>>> heap;

    heap.push({f(start), start});
    dist[start] = 0;

    while (heap.size())
    {
        auto t = heap.top();
        heap.pop();

        string state = t.second;

        if (state == end) break;

        int step = dist[state];
        int x, y;
        for (int i = 0; i < state.size(); i ++ )
            if (state[i] == 'x')
            {
                x = i / 3, y = i % 3;
                break;
            }
        string source = state;
        for (int i = 0; i < 4; i ++ )
        {
            int a = x + dx[i], b = y + dy[i];
            if (a >= 0 && a < 3 && b >= 0 && b < 3)
            {
                swap(state[x * 3 + y], state[a * 3 + b]);
                if (!dist.count(state) || dist[state] > step + 1)
                {
                    dist[state] = step + 1;
                    prev[state] = {source, op[i]};
                    heap.push({dist[state] + f(state), state});
                }
                swap(state[x * 3 + y], state[a * 3 + b]);
            }
        }
    }

    string res;
    while (end != start)
    {
        res += prev[end].second;
        end = prev[end].first;
    }
    reverse(res.begin(), res.end());
    return res;
}

int main()
{
    string g, c, seq;
    while (cin >> c)
    {
        g += c;
        if (c != "x") seq += c;
    }

    int t = 0;
    for (int i = 0; i < seq.size(); i ++ )
        for (int j = i + 1; j < seq.size(); j ++ )
            if (seq[i] > seq[j])
                t ++ ;

    if (t % 2) puts("unsolvable");
    else cout << bfs(g) << endl;

    return 0;
}