(a * b) % p = a % p * b % p
(a * b * c) % p = a % p * b % p * c % p
...
を利用することで、5 つの数の積が整数オーバーフローすることを回避します。ちなみに %
の優先度は *
と等しいです。
例:
(11 * 12 * 13) % 10
= 11 % 10 * 12 % 10 * 13 % 10
= 1 * 12 % 10 * 13 % 10
= 12 % 10 * 13 % 10
= 2 * 13 % 10
= 26 % 10
= 6
#include <iostream>
#include <vector>
int main()
{
// N 個の整数, P で割った余りが Q
int N, P, Q;
std::cin >> N >> P >> Q;
// 整数の入力
std::vector<int> A(N);
for (auto& a : A)
{
std::cin >> a;
}
// 成立した組み合わせの個数
int count = 0;
for (int i = 0; i < (N - 4); ++i)
{
// 整数オーバーフローしないよう long long 型に
const long long a = A[i];
for (int j = (i + 1); j < (N - 3); ++j)
{
const long long b = a * A[j] % P;
for (int k = (j + 1); k < (N - 2); ++k)
{
const long long c = b * A[k] % P;
for (int l = (k + 1); l < (N - 1); ++l)
{
const long long d = c * A[l] % P;
for (int m = (l + 1); m < N; ++m)
{
const long long e = d * A[m] % P;
if (e == Q)
{
++count;
}
}
}
}
}
}
std::cout << count << '\n';
}
この問題における mod 演算は、Barrett reduction というテクニックを使うことで、計算に時間がかかる %
演算子の使用を回避して定数倍高速化できます。AC Library の atcoder::modint
ライブラリには Barrett reduction が実装されているため、それを利用するとコードも短くなり一石二鳥です。
なお、atcoder::modint
は、「C++ 標準の整数型 → atcoder::modint
」の変換に通常の %
演算を使うため、プログラム全体を通してこの変換回数を減らすよう設計しないと速度向上効果が出ません(かえって遅くなることもあります)。次のように、入力された値を最初にすべて atcoder::modint
に変換して格納し、以降で変換が発生しないようにします。
#include <iostream>
#include <vector>
#include <atcoder/modint.hpp> // atcoder::modint
int main()
{
// N 個の整数, P で割った余りが Q
int N, P, Q;
std::cin >> N >> P >> Q;
atcoder::modint::set_mod(P); // mod を設定
// 整数の入力
std::vector<atcoder::modint> A(N);
for (auto& a : A)
{
int t;
std::cin >> t;
a = atcoder::modint{ t };
}
// 成立した組み合わせの個数
int count = 0;
for (int i = 0; i < (N - 4); ++i)
{
const atcoder::modint a = A[i];
for (int j = (i + 1); j < (N - 3); ++j)
{
const atcoder::modint b = a * A[j];
for (int k = (j + 1); k < (N - 2); ++k)
{
const atcoder::modint c = b * A[k];
for (int l = (k + 1); l < (N - 1); ++l)
{
const atcoder::modint d = c * A[l];
for (int m = (l + 1); m < N; ++m)
{
const atcoder::modint e = d * A[m];
if (e.val() == Q)
{
++count;
}
}
}
}
}
}
std::cout << count << '\n';
}