Logo ryp 的博客

博客

P1054 [NOIP2005 提高组] 等价表达式 题解

...
ryp
2025-12-01 12:50:22
She's not square

本文章由 WyOJ Shojo 从洛谷专栏拉取,原发布时间为 2024-06-18 10:33:04

看了一圈题解区只有我用了暴力展开 /jy

实际上并不难,还有点水,也不难调。

相信各位应当都会中缀表达式的计算:我们维护一个符号栈和一个数据栈,然后分类讨论根据优先级计算即可。这一点不管是别的题目还是题解区已经讲的很好了,OI Wiki 上的讲解也不错。

这篇题解的重点是把普通整数上的操作扩展到多项式。

由于只有一个字母,我们可以维护系数向量 $v$,其中 $v_k$ 代表 $a^k$ 的系数,常数就是 $a^0$ 的系数。于是,我们就可以用 $O(k)$ 时间做到加减操作。

然后考虑乘和乘方。先考虑朴素的 $O(n^2)$ 乘法。(通过手工模拟可得),代码为:

for (int i = 0; i < K; i++)
  	for (int j = 0; i + j < K; j++) 
      	r.k[i + j] += k[i] * x.k[j];

如果要求更高的速度,可以用 FFT 优化。但是本题数据范围比较小,所以不用。

乘方可以用经典的快速幂,复杂度是 $O(k^2\log n)$ 的,或者是 $O(k \log^2 n)$(FFT 优化)。

接下来就是代码时间了,其实思路是非常好想的。

但是有几点注意:

  • 数据换行符含有 \r,不能简单地用 getline

  • 最后一个点卡爆 long long,所以需要对这个多项式取模(所以不妨用 NTT

  • 括号可能不匹配 遇到这种情况,应该直接丢掉多余的括号

接下来放看起来还是很不错的代码:

#include <iostream>
#include <stack>
#include <cstring>
#define siz(x) static_cast<int> ((x).size ())
#define fi first
#define se second
#define int long long
using namespace std;
const int N = 30, K = 500, P = 998244353;

class poly {
public:
	int k[K];
	
	poly () { memset (k, 0, sizeof k); }
	poly (const poly &x) { memcpy (k, x.k, sizeof k); }
	poly &operator = (const poly &x) { memcpy (k, x.k, sizeof k); return *this; }
	poly (int x, bool flag = false) { memset (k, 0, sizeof k), k[flag] = x; }
	poly operator + (const poly &x) {
		poly r;
		for (int i = 0; i < K; i++) r.k[i] = (k[i] + x.k[i]) % P;
		return r;
	}
	
	poly operator - (const poly &x) {
		poly r;
		for (int i = 0; i < K; i++) r.k[i] = (k[i] + P - x.k[i]) % P;
		return r;
	}
	
	poly operator * (const poly &x) {
		poly r;
		for (int i = 0; i < K; i++) for (int j = 0; i + j < K; j++) (r.k[i + j] += k[i] * x.k[j] % P) %= P;
		return r;
	}
	
	poly operator ^ (int x) {
		poly r = *this, p = *this;
		--x;
		while (x) {
			if (x & 1) r = r * p;
			x >>= 1, p = p * p;
		}
		return r;
	}
	
	bool operator == (const poly &x) {
		for (int i = 0; i < K; i++) if (k[i] != x.k[i]) return false;
		return true;
	}
	
	friend ostream &operator << (ostream &out, const poly &x) {
		for (int i = K - 1; i >= 2; i--) if (x.k[i]) {
			if (x.k[i] == 1) out << "a^" << i << " + ";
			else if (x.k[i] == -1) out << "-a^" << i << " + ";
			else out << x.k[i] << "a^" << i << " + ";
		}
		if (x.k[1] == 1) out << "a + ";
		else if (x.k[1] == -1) out << "-a + ";
		else if (x.k[1]) out << x.k[1] << "a + ";
		out << x.k[0];
		return out;
	}
};


stack<char> op;
stack<poly> q;

void ins (char c)
{
	poly y, x;
	
	if (c == '(') return;
	if (q.empty ()) cout << "fuck off", exit (0);
	y = q.top (), q.pop ();
	if (q.empty ()) cout << "fuck off", exit (0);
	x = q.top (), q.pop ();
	switch (c) {
	case '+': q.push (x + y); break;
	case '-': q.push (x - y); break;
	case '*': q.push (x * y); break;
	case '^': q.push (x ^ y.k[0]); break;
	default: break;
	}
}

int prio (char c)
{
	switch (c) {
	case '(': return 0;
	case '+': case '-': return 1;
	case '*': return 2;
	case '^': return 3;
	default: return -1;
	}
}

bool space (char c) { return c == ' ' || c == '\n' || c == '\r'; }

poly expand (void)
{
	int level = 0;
	char c;

	while (!q.empty ()) q.pop ();
	while (space (c = getchar ()));
	do {
		if (isdigit (c)) {
			int x = 0;
			do x = (x * 10 % P + c - '0') % P; while (isdigit (c = getchar ()));
			while (space (c) && c != '\n') c = getchar ();
			q.push (poly (x));
			if (c == '\n') break;
			else continue;
		}
		
		switch (c) {
		case 'a': q.push (poly (1, true)); break;
		case '(': op.push ('('), level++; break;
		case ')':
			if (--level < 0) break;
			while (!op.empty () && op.top () != '(') ins (op.top ()), op.pop ();
			op.pop (); break;
		default:
			while (!op.empty () && prio (op.top ()) >= prio (c)) ins (op.top ()), op.pop ();
			op.push (c); break;
		}
		while (space (c = getchar ()) && c != '\n');
	} while (c != '\n');
	
	while (!op.empty ()) ins (op.top ()), op.pop ();
	return q.top ();
}

int read (void)
{
	int res = 0;
	char c;
	while (!isdigit (c = getchar ()));
	do res = res * 10 + c - '0'; while (isdigit (c = getchar ()));
	return res;
}

signed main (void)
{
	int n;
	poly res;
	
	res = expand ();
	n = read ();
	
	for (int i = 1; i <= n; i++) if (expand () == res) putchar ("ABCDEFGHIJKLMNOPQRSTUVWXYZ"[i - 1]);
	putchar ('\n');
	return 0;
}

评论

暂无评论

发表评论

可以用@mike来提到mike这个用户,mike会被高亮显示。如果你真的想打“@”这个字符,请用“@@”。