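I implemented arbitrary-precision integers, so here is a write-up. I started the project thinking it would be good C++ practice, and it took much longer than I expected.
The full code turned out quite long, so it is at the end.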
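Usage example
First, here is what the class can do.
BigInt fact(BigInt) and BigInt fib(BigInt) are defined exactly the way they would be for int, and the results are printed with operator<<. Note that this fib uses fib(0) = fib(1) = 1, so fib(1000) is \(F_{1001}\) in the usual indexing with \(F_0 = 0\), \(F_1 = 1\).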
#include "big_int.hpp"
#include <iostream>
BigInt fact(BigInt n) {
if (n <= 1)
return 1;
else
return n * fact(n - 1);
}
BigInt fib(BigInt n) {
if (n < 2)
return 1;
BigInt a = 1;
BigInt b = 1;
for (; n > 1; --n) {
BigInt temp = b;
b = a + b;
a = temp;
}
return b;
}
int main() {
std::cout << "fact(100) = " << fact(100) << "\n";
std::cout << "fib(1000) = " << fib(1000) << "\n";
return 0;
}
Output:
fact(100) = 93326215443944152681699238856266700490715968264381621468592963895217599993229915608941463976156518286253697920827223758251185210916864000000000000000000000000
fib(1000) = 70330367711422815821835254877183549770181269836358732742604905087154537118196933579742249494562611733487750449241765991088186363265450223647106012053374121273867339111198139373125598767690091902245245323403501
What I implemented
I implemented only the minimum set of functions needed to run the usage example. The final header is as follows:
class BigInt {
private:
// ... (omitted) ...
public:
static std::mt19937_64 random_engine;
BigInt();
BigInt(int32_t i);
BigInt(int64_t i);
BigInt(uint32_t i, bool sign = false);
BigInt(uint64_t i, bool sign = false);
BigInt(const std::string &str);
static BigInt of_big_uint(BigUInt abs, bool sign = false);
static BigInt of_vector(std::vector<uint32_t> abs, bool sign = false);
static BigInt random(unsigned n = 5);
uint64_t width() const;
bool sign() const;
bool is_zero() const;
bool is_positive() const;
bool is_negative() const;
BigInt cshift_left(uint64_t shamt) const;
BigInt &shift_left(uint64_t shamt);
BigInt cshift_right(uint64_t shamt) const;
BigInt &shift_right(uint64_t shamt);
BigInt pow(BigInt exp) const;
const uint32_t &operator[](uint64_t i) const;
uint32_t &operator[](uint64_t i);
BigInt &operator++();
BigInt operator++(int);
BigInt &operator--();
BigInt operator--(int);
BigInt operator-() const;
friend bool operator==(const BigInt &lhs, const BigInt &rhs);
friend bool operator!=(const BigInt &lhs, const BigInt &rhs);
friend bool operator<(const BigInt &lhs, const BigInt &rhs);
friend bool operator>(const BigInt &lhs, const BigInt &rhs);
friend bool operator<=(const BigInt &lhs, const BigInt &rhs);
friend bool operator>=(const BigInt &lhs, const BigInt &rhs);
friend BigInt &operator+=(BigInt &lhs, const BigInt &rhs);
friend BigInt &operator-=(BigInt &lhs, const BigInt &rhs);
friend BigInt &operator*=(BigInt &lhs, const BigInt &rhs);
friend BigInt &operator/=(BigInt &lhs, const BigInt &rhs);
friend BigInt &operator%=(BigInt &lhs, const BigInt &rhs);
friend BigInt operator+(const BigInt &lhs, const BigInt &rhs);
friend BigInt operator-(const BigInt &lhs, const BigInt &rhs);
friend BigInt operator*(const BigInt &lhs, const BigInt &rhs);
friend BigInt operator/(const BigInt &lhs, const BigInt &rhs);
friend BigInt operator%(const BigInt &lhs, const BigInt &rhs);
friend std::ostream &operator<<(std::ostream &lhs, const BigInt &rhs);
};
Below are the points where I got stuck during the implementation.
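Separating the sign and the absolute value
I wrote a class BigUInt for unsigned arbitrary-precision integers, and represented a signed arbitrary-precision integer by its absolute value (a BigUInt) and its sign (a bool, true for negative).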
#include "big_uint.hpp"
class BigInt {
private:
bool m_sign;
BigUInt m_abs;
public:
// ... (omitted) ...
};
Once the sign and the absolute value are separated, most operations reduce to unsigned operations through case analysis on the signs. For example, addition looks like this:
BigInt &operator+=(BigInt &lhs, const BigInt &rhs) {
// note that <, >, +=, -= and - below all operate on BigUInt
if (lhs.m_sign == rhs.m_sign) {
lhs.m_abs += rhs.m_abs;
} else if (lhs.m_abs > rhs.m_abs) {
lhs.m_abs -= rhs.m_abs;
} else if (lhs.m_abs < rhs.m_abs) {
lhs.m_sign = !lhs.m_sign;
lhs.m_abs = rhs.m_abs - lhs.m_abs;
} else {
lhs = BigInt(0);
}
return lhs;
}
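For fixed-width integers, representing signed values in two's complement lets addition and subtraction share the same procedure. For arbitrary-precision integers, however, the two operands of a binary operator usually have different bit lengths, so subtraction cannot simply be done the same way as addition (the shorter operand would first have to be sign-extended). For that reason I did not use a two's complement representation.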
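A small sketch of the fixed-width case described above (my own illustration, not code from the implementation): with uint32_t, subtracting b yields the same bits as adding its two's complement ~b + 1, so a single adder covers both operations. The name sub_via_add is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Fixed-width case only: uint32_t arithmetic is mod 2^32, so
// a - b == a + (~b + 1), i.e. subtraction is addition of the two's complement.
constexpr uint32_t sub_via_add(uint32_t a, uint32_t b) {
  return a + (~b + 1u);
}

static_assert(sub_via_add(10, 3) == 7u, "10 - 3");
// 3 - 10 wraps to 2^32 - 7, the two's complement bit pattern of -7.
static_assert(sub_via_add(3, 10) == 0xFFFFFFF9u, "3 - 10");
```

For multi-word numbers the same trick would need the shorter operand padded with 0xFFFFFFFF words whenever it is negative, which is the extra step that the paragraph above avoids by using sign and magnitude.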
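Using base \(2^{32}\) for BigUInt
In C++, arithmetic on \(n\)-bit unsigned integers is performed modulo \(2^{n}\) (for types at least as wide as unsigned int; narrower types such as uint8_t are promoted to int first), so using \(2^{n}\) as the base is convenient for digit-by-digit operations.
In addition, multiplying arbitrary-precision integers requires multiplying digits without losing the high part of the product, so I chose base \(2^{32}\) (type uint32_t).
The product of two uint32_t values always fits in a uint64_t, since \((2^{32}-1)^2 < 2^{64}\).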
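The following sketch (a hypothetical helper split_product, not part of BigUInt) illustrates both points: a plain uint32_t product wraps modulo \(2^{32}\) and keeps only the low digit, while widening one operand to uint64_t first keeps the full product, which can then be split into a low and a high base-\(2^{32}\) digit. This mirrors the lower and upper vectors built in BigUInt's operator*.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: full product of two base-2^32 digits.
struct DigitProduct {
  uint32_t low;   // product mod 2^32 (what a plain uint32_t a * b gives)
  uint32_t high;  // product / 2^32
};

DigitProduct split_product(uint32_t a, uint32_t b) {
  // Widen before multiplying: (2^32 - 1)^2 < 2^64, so nothing is lost.
  const uint64_t p = static_cast<uint64_t>(a) * b;
  return {static_cast<uint32_t>(p), static_cast<uint32_t>(p >> 32)};
}
```

For example, split_product(0x10000, 0x10000) represents \(2^{32}\) as low = 0, high = 1.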
class BigUInt {
private:
std::vector<uint32_t> m_vec;
// ... (omitted) ...
};
Detecting overflow
Addition and subtraction of arbitrary-precision integers need to detect overflow in order to compute the carry.
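For uint32_t variables a and b, the sum a + b overflows exactly when the following holds:
\begin{align*}
a + b &\geq 2^{32} \\
\iff \quad \quad a &> (2^{32} - 1) - b
\end{align*}
The subtraction on the right-hand side of the last line never overflows, so whether overflow occurs can be checked using only uint32_t arithmetic.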
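A minimal sketch of this pre-check (the name add_would_overflow is hypothetical): it evaluates the right-hand side UINT32_MAX - b, which cannot wrap, and compares a against it.

```cpp
#include <cassert>
#include <cstdint>

// True iff a + b >= 2^32, i.e. iff a > (2^32 - 1) - b.
// UINT32_MAX - b never wraps, so only uint32_t arithmetic is needed.
constexpr bool add_would_overflow(uint32_t a, uint32_t b) {
  return a > UINT32_MAX - b;
}
```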
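In the actual code, what I needed was not "will adding these two numbers overflow?" but "did adding these two numbers overflow?", so I used the following fact.
Suppose a += b is executed for uint32_t variables a and b. Whether this addition overflowed can be determined afterwards by checking whether
\begin{align*} a < b \end{align*}
holds.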
Let \(A, B\) be the values of a and b before a += b is executed (b itself is unchanged, so \(b = B\)).
If overflow occurred,
\begin{align*}
a &= (A + B) \bmod 2^{32} \\
&= A + B - 2^{32}
\end{align*}
and therefore
\begin{align*} b - a = B - (A + B - 2^{32}) = 2^{32} - A > 0 \end{align*}
holds, i.e. \(a < b\). If overflow did not occur,
\begin{align*}
a &= (A + B) \bmod 2^{32} \\
&= A + B
\end{align*}
and therefore
\begin{align*} a - b = A \geq 0 \end{align*}
holds, i.e. \(a \geq b\).
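The argument above does not depend on the width, so it can be sanity-checked exhaustively on 8-bit integers, where all \(256 \times 256\) pairs are cheap to try. This is a hypothetical test helper, not part of the implementation; note that uint8_t operands are promoted to int, so the wrap-around happens at the cast back to uint8_t.

```cpp
#include <cassert>
#include <cstdint>

// Exhaustively checks "after a += b, overflow happened iff a < b"
// for every pair of 8-bit unsigned values.
bool check_overflow_rule_u8() {
  for (unsigned A = 0; A < 256; ++A) {
    for (unsigned B = 0; B < 256; ++B) {
      uint8_t a = static_cast<uint8_t>(A);
      const uint8_t b = static_cast<uint8_t>(B);
      a = static_cast<uint8_t>(a + b);  // wraps mod 2^8 on conversion
      const bool overflowed = A + B >= 256;
      if ((a < b) != overflowed) {
        return false;
      }
    }
  }
  return true;
}
```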
In code, this becomes:
BigUInt &operator+=(BigUInt &lhs, const BigUInt &rhs) {
uint32_t carry = 0;
uint32_t i;
lhs.m_vec.resize(std::max(lhs.width(), rhs.width()));
for (i = 0; i < rhs.width(); ++i) {
// compute lhs[i] += carry + rhs[i] and update carry
lhs[i] += carry;
carry = lhs[i] < carry;
lhs[i] += rhs[i];
carry |= lhs[i] < rhs[i];
}
for (; carry && i < lhs.width(); ++i) {
// compute lhs[i] += carry (carry is 1 here) and update carry
++lhs[i];
carry = lhs[i] < carry;
}
if (carry) {
lhs.m_vec.push_back(1);
}
return lhs;
}
Subtraction was written using similar relations (for example, x - y wraps around exactly when x < y).
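Since the subtraction is only described in words here, the following is a standalone sketch of the same idea on plain std::vector<uint32_t> digit arrays (little-endian, like BigUInt's m_vec); sub_digits is a hypothetical name, not the article's operator-=. It decides the borrow before subtracting: x - y - borrow goes below zero exactly when x < y, or when x == y and a borrow comes in.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: a -= b on little-endian base-2^32 digits, assuming a >= b.
void sub_digits(std::vector<uint32_t> &a, const std::vector<uint32_t> &b) {
  assert(a.size() >= b.size());
  uint32_t borrow = 0;
  for (std::size_t i = 0; i < a.size(); ++i) {
    const uint32_t y = i < b.size() ? b[i] : 0;
    // a[i] - y - borrow < 0 iff a[i] < y, or a[i] == y with a borrow in.
    const uint32_t next_borrow = (a[i] < y || (a[i] == y && borrow != 0)) ? 1 : 0;
    a[i] = a[i] - y - borrow;  // wraps mod 2^32 when next_borrow == 1
    borrow = next_borrow;
    if (i >= b.size() && borrow == 0) {
      break;  // nothing left to propagate
    }
  }
  assert(borrow == 0 && "a must be >= b");
  while (a.size() > 1 && a.back() == 0) {
    a.pop_back();  // drop leading zero words
  }
}
```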
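Division algorithm
For division I used Algorithm D from Section 4.3.1 of TAOCP Volume 2. The algorithm itself is correct, but the way the book states its preconditions is unusual, and that tripped me up.
Specifically, TAOCP writes the dividend as \((u_{m+n-1} \ldots u_1 u_0)_b\) and the divisor as \((v_{n-1} \ldots v_1 v_0)_b\), so the highest subscripts are \(m + n - 1\) and \(n - 1\), while the dividend actually has \(m + n\) digits, the divisor \(n\) digits, and the quotient \((q_m \ldots q_0)_b\) has \(m + 1\) digits. In addition, normalization gives the dividend an extra leading digit \(u_{m+n}\), so the index of the first quotient digit needs special handling. I misread these preconditions, which led to out-of-bounds memory accesses that were painful to track down.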
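To make the indexing concrete, here is a sketch of step D3 of Algorithm D for base \(b = 2^{32}\): it estimates one quotient digit from the top three digits of the current partial dividend and the top two digits of the normalized divisor. The function estimate_qhat and its parameter names are my own, and it assumes the invariant \((u_2 u_1)_b < (v_1 v_0)_b\) that Algorithm D maintains. The first test is \(\hat{q} \geq b\); the value \(b - 1\) (0xFFFFFFFF) is a legitimate quotient digit and must not be decremented by that test.

```cpp
#include <cassert>
#include <cstdint>

// Step D3 of Algorithm D (TAOCP Vol. 2, 4.3.1) for base b = 2^32; a sketch.
// u2, u1, u0: top three digits of the current partial dividend.
// v1, v0: top two digits of the normalized divisor (v1 >= 2^31).
// Assumes (u2 u1) < (v1 v0). The returned q_hat is the true quotient digit
// or one too large.
uint64_t estimate_qhat(uint32_t u2, uint32_t u1, uint32_t u0,
                       uint32_t v1, uint32_t v0) {
  assert(v1 >= 0x80000000u);  // divisor must be normalized
  const uint64_t b = uint64_t{1} << 32;
  const uint64_t num = (static_cast<uint64_t>(u2) << 32) | u1;
  uint64_t qhat = num / v1;
  uint64_t rhat = num % v1;
  // Decrease q_hat while it is >= b, or while the second divisor digit
  // shows it is too large; repeat the test only while r_hat < b.
  while (rhat < b && (qhat >= b || qhat * v0 > (rhat << 32) + u0)) {
    --qhat;
    rhat += v1;
  }
  return qhat;
}
```

For example, with \(v_1 = 2^{31}\), \(v_0 = 0\) and top digits 0x7FFFFFFF, 0xFFFFFFFF, 0x80000000, the estimate is 0xFFFFFFFF.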
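Efficiency
I did not pay much attention to efficiency, and the multiplication algorithm is \(O(mn)\) for operands of \(m\) and \(n\) words, so performance is poor. The usage example at the top takes about one second to run, while the same code in the Python interpreter finishes instantly (sad).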
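One cheap improvement that keeps the \(O(mn)\) schoolbook algorithm is to accumulate each digit product directly into the result, instead of building temporary numbers for every row as BigUInt's operator* does. A sketch on plain digit vectors (mul_schoolbook is a hypothetical name, not the article's code); the 64-bit accumulator cannot overflow because \((2^{32}-1)^2 + 2(2^{32}-1) = 2^{64} - 1\).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: schoolbook product of little-endian base-2^32 digits.
std::vector<uint32_t> mul_schoolbook(const std::vector<uint32_t> &a,
                                     const std::vector<uint32_t> &b) {
  std::vector<uint32_t> r(a.size() + b.size(), 0);
  for (std::size_t i = 0; i < a.size(); ++i) {
    uint64_t carry = 0;
    for (std::size_t j = 0; j < b.size(); ++j) {
      // digit product + existing digit + carry <= 2^64 - 1: no overflow.
      const uint64_t t = static_cast<uint64_t>(a[i]) * b[j] + r[i + j] + carry;
      r[i + j] = static_cast<uint32_t>(t);
      carry = t >> 32;
    }
    r[i + b.size()] = static_cast<uint32_t>(carry);  // not yet written by earlier rows
  }
  while (r.size() > 1 && r.back() == 0) {
    r.pop_back();  // drop leading zero words
  }
  return r;
}
```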
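Possible ways to speed it up further include:
- Switching the multiplication algorithm (Karatsuba algorithm, Schönhage–Strassen algorithm)
- Using uint64_t instead of uint32_t as the array element (digit products would then need a 128-bit result)
- Writing assembly (using the carry flag in the condition register / using a mulhi instruction if one exists)
- Making operator<< use a buffer, or dropping operator<< altogether
I'll do these if I ever need them.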
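For the last item, one commonly used approach (not the one in the article's operator<<, which performs a full BigUInt division for every decimal digit) is to divide repeatedly by \(10^9\) using single-word short division and collect nine decimal digits at a time into a string buffer. A sketch on plain digit vectors; to_decimal is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper: decimal string from little-endian base-2^32 digits.
std::string to_decimal(std::vector<uint32_t> digits) {
  const uint32_t chunk_base = 1000000000;  // 10^9 < 2^32
  std::vector<uint32_t> chunks;            // little-endian base-10^9 chunks
  while (!digits.empty() && digits.back() == 0) digits.pop_back();
  while (!digits.empty()) {
    // Short division of the whole number by 10^9, from the top word down.
    uint64_t rem = 0;
    for (std::size_t i = digits.size(); i-- > 0;) {
      const uint64_t cur = (rem << 32) | digits[i];
      digits[i] = static_cast<uint32_t>(cur / chunk_base);
      rem = cur % chunk_base;
    }
    chunks.push_back(static_cast<uint32_t>(rem));
    while (!digits.empty() && digits.back() == 0) digits.pop_back();
  }
  if (chunks.empty()) return "0";
  std::string out = std::to_string(chunks.back());
  for (std::size_t i = chunks.size() - 1; i-- > 0;) {
    const std::string part = std::to_string(chunks[i]);
    out.append(9 - part.size(), '0');  // zero-pad inner chunks to 9 digits
    out += part;
  }
  return out;
}
```

This is still quadratic overall, but each pass is a single linear sweep over the words and yields nine digits.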
The BigUInt class
Here are the header and implementation of the BigUInt class.
#ifndef BIG_UINT_H
#define BIG_UINT_H
#include <cstdint>
#include <ostream>
#include <random>
#include <string>
#include <vector>
class BigUInt {
private:
static std::mt19937_64 random_engine;
std::vector<uint32_t> m_vec;
BigUInt(std::vector<uint32_t> vec);
BigUInt &delete_trailing_zeros();
public:
BigUInt();
BigUInt(int32_t i);
BigUInt(int64_t i);
BigUInt(uint32_t i);
BigUInt(uint64_t i);
BigUInt(const std::string &str);
static BigUInt of_vector(std::vector<uint32_t> vec);
static BigUInt random(unsigned n = 5);
uint64_t width() const;
bool is_zero() const;
BigUInt cshift_left(uint64_t shamt) const;
BigUInt &shift_left(uint64_t shamt);
BigUInt cshift_right(uint64_t shamt) const;
BigUInt &shift_right(uint64_t shamt);
BigUInt pow(BigUInt exp) const;
const uint32_t &operator[](uint64_t i) const;
uint32_t &operator[](uint64_t i);
BigUInt &operator++();
BigUInt operator++(int);
BigUInt &operator--();
BigUInt operator--(int);
BigUInt operator-();
friend bool operator==(const BigUInt &lhs, const BigUInt &rhs);
friend bool operator!=(const BigUInt &lhs, const BigUInt &rhs);
friend bool operator<(const BigUInt &lhs, const BigUInt &rhs);
friend bool operator>(const BigUInt &lhs, const BigUInt &rhs);
friend bool operator<=(const BigUInt &lhs, const BigUInt &rhs);
friend bool operator>=(const BigUInt &lhs, const BigUInt &rhs);
friend BigUInt &operator+=(BigUInt &lhs, const BigUInt &rhs);
friend BigUInt &operator-=(BigUInt &lhs, const BigUInt &rhs);
friend BigUInt &operator*=(BigUInt &lhs, const BigUInt &rhs);
friend BigUInt &operator/=(BigUInt &lhs, const BigUInt &rhs);
friend BigUInt &operator%=(BigUInt &lhs, const BigUInt &rhs);
friend BigUInt operator+(const BigUInt &lhs, const BigUInt &rhs);
friend BigUInt operator-(const BigUInt &lhs, const BigUInt &rhs);
friend BigUInt operator*(const BigUInt &lhs, const BigUInt &rhs);
friend BigUInt operator/(const BigUInt &lhs, const BigUInt &rhs);
friend BigUInt operator%(const BigUInt &lhs, const BigUInt &rhs);
friend std::ostream &operator<<(std::ostream &lhs, const BigUInt &rhs);
};
#endif /* BIG_UINT_H */
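#include "big_uint.hpp"
#include <algorithm>   // std::max
#include <limits>      // std::numeric_limits
#include <stdexcept>
#include <type_traits> // std::is_arithmetic_v, std::is_unsigned_v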
template <typename T>
constexpr bool bit_at(T a, unsigned n) {
static_assert(std::is_arithmetic_v<T>);
return a & (static_cast<T>(1) << n);
}
// position of the highest non-zero bit.
// returns the width of the argument type if a == 0.
template <typename T>
constexpr unsigned hnz(T a) {
static_assert(std::is_unsigned_v<T>);
// count i down from digits - 1 to 0 without relying on wrap-around of i.
for (unsigned i = std::numeric_limits<T>::digits; i-- > 0;) {
if (bit_at(a, i)) {
return i;
}
}
return std::numeric_limits<T>::digits;
}
std::mt19937_64 BigUInt::random_engine((std::random_device())());
BigUInt::BigUInt(std::vector<uint32_t> vec) : m_vec(vec) {
delete_trailing_zeros();
}
BigUInt &BigUInt::delete_trailing_zeros() {
while (width() > 1 && m_vec[width() - 1] == 0) {
m_vec.pop_back();
}
return *this;
}
BigUInt::BigUInt() : m_vec({0}) {} // zero is stored as {0}, so == works on m_vec
BigUInt::BigUInt(int32_t i) : m_vec({static_cast<uint32_t>(i)}) {}
BigUInt::BigUInt(int64_t i) : BigUInt(static_cast<uint64_t>(i)) {}
BigUInt::BigUInt(uint32_t i) : m_vec({i}) {}
BigUInt::BigUInt(uint64_t i)
: m_vec({static_cast<uint32_t>(i), static_cast<uint32_t>(i >> 32)}) {
delete_trailing_zeros();
}
BigUInt::BigUInt(const std::string &str) : m_vec({0}) {
BigUInt base(1);
BigUInt ten(10);
for (uint64_t i = str.length() - 1; i != UINT64_MAX; --i) {
*this += BigUInt(static_cast<uint32_t>(std::stoul(str.substr(i, 1)))) * base;
base *= ten;
}
delete_trailing_zeros();
}
uint64_t BigUInt::width() const { return m_vec.size(); }
const uint32_t &BigUInt::operator[](uint64_t i) const { return m_vec[i]; }
uint32_t &BigUInt::operator[](uint64_t i) { return m_vec[i]; }
BigUInt BigUInt::of_vector(std::vector<uint32_t> vec) { return BigUInt(vec); }
BigUInt BigUInt::random(unsigned n) {
std::vector<uint32_t> vec(n);
for (auto &i : vec) {
i = static_cast<uint32_t>(random_engine());
}
return BigUInt(vec);
}
bool BigUInt::is_zero() const {
for (uint32_t i = 0; i < width(); ++i) {
if ((*this)[i] != 0) {
return false;
}
}
return true;
}
BigUInt BigUInt::cshift_left(uint64_t shamt) const {
BigUInt temp(*this);
return temp.shift_left(shamt);
}
BigUInt &BigUInt::shift_left(uint64_t shamt) {
if (shamt == 0) {
return *this;
}
const uint64_t shamt_quot = shamt / 32;
const uint64_t shamt_rem = shamt % 32;
m_vec.resize(width() + shamt_quot + 1);
uint64_t i = width() - 1;
if (shamt_rem == 0) {
// avoid undefined behavior by shifting 0 or 32.
for (; i != shamt_quot - 1; --i) {
(*this)[i] = (*this)[i - shamt_quot];
}
} else {
const uint64_t shamt_rem_inv = 32 - shamt_rem;
for (; i != shamt_quot; --i) {
auto lower = (*this)[i - shamt_quot - 1] >> shamt_rem_inv;
auto upper = (*this)[i - shamt_quot] << shamt_rem;
(*this)[i] = lower + upper;
}
(*this)[i--] = (*this)[0] << shamt_rem;
}
for (; i != UINT64_MAX; --i) {
(*this)[i] = 0;
}
return delete_trailing_zeros();
}
BigUInt BigUInt::cshift_right(uint64_t shamt) const {
BigUInt temp(*this);
return temp.shift_right(shamt);
}
BigUInt &BigUInt::shift_right(uint64_t shamt) {
if (shamt == 0) {
return *this;
}
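const uint64_t shamt_quot = shamt / 32;
const uint64_t shamt_rem = shamt % 32;
if (shamt_quot >= width()) {
// every word is shifted out; also avoids width() - shamt_quot wrapping around.
m_vec.assign(1, 0);
return *this;
}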
if (shamt_rem == 0) {
// avoid undefined behavior by shifting 0 or 32.
for (uint64_t i = 0; i < width() - shamt_quot; ++i) {
(*this)[i] = (*this)[i + shamt_quot];
}
} else {
const uint64_t shamt_rem_inv = 32 - shamt_rem;
for (uint64_t i = 0; i < width() - shamt_quot - 1; ++i) {
auto lower = (*this)[i + shamt_quot] >> shamt_rem;
auto upper = (*this)[i + shamt_quot + 1] << shamt_rem_inv;
(*this)[i] = lower + upper;
}
(*this)[width() - shamt_quot - 1] = (*this)[width() - 1] >> shamt_rem;
}
m_vec.resize(width() - shamt_quot);
return delete_trailing_zeros();
}
BigUInt BigUInt::operator-() {
throw std::invalid_argument("unary - of BigUInt");
}
BigUInt &BigUInt::operator++() {
*this += 1;
return *this;
}
BigUInt BigUInt::operator++(int) {
auto temp(*this);
*this += 1;
return temp;
}
BigUInt &BigUInt::operator--() {
*this -= 1;
return *this;
}
BigUInt BigUInt::operator--(int) {
auto temp(*this);
*this -= 1;
return temp;
}
bool operator==(const BigUInt &lhs, const BigUInt &rhs) {
return lhs.m_vec == rhs.m_vec;
}
bool operator!=(const BigUInt &lhs, const BigUInt &rhs) {
return lhs.m_vec != rhs.m_vec;
}
bool operator<(const BigUInt &lhs, const BigUInt &rhs) {
if (lhs.width() < rhs.width()) {
return true;
}
if (lhs.width() > rhs.width()) {
return false;
}
for (uint64_t i = lhs.width() - 1; i != UINT64_MAX; --i) {
if (lhs[i] == rhs[i]) {
continue;
}
return lhs[i] < rhs[i];
}
return false;
}
bool operator<=(const BigUInt &lhs, const BigUInt &rhs) {
return lhs == rhs || lhs < rhs;
}
bool operator>(const BigUInt &lhs, const BigUInt &rhs) { return rhs < lhs; }
bool operator>=(const BigUInt &lhs, const BigUInt &rhs) { return rhs <= lhs; }
BigUInt &operator+=(BigUInt &lhs, const BigUInt &rhs) {
uint32_t carry = 0;
uint32_t i;
lhs.m_vec.resize(std::max(lhs.width(), rhs.width()));
for (i = 0; i < rhs.width(); ++i) {
lhs[i] += carry;
carry = lhs[i] < carry;
lhs[i] += rhs[i];
carry |= lhs[i] < rhs[i];
}
for (; carry && i < lhs.width(); ++i) {
++lhs[i];
carry = lhs[i] < carry;
}
if (carry) {
lhs.m_vec.push_back(1);
}
return lhs;
}
BigUInt operator+(const BigUInt &lhs, const BigUInt &rhs) {
BigUInt ans(lhs);
return ans += rhs;
}
// if lhs < rhs, result is undefined.
BigUInt &operator-=(BigUInt &lhs, const BigUInt &rhs) {
uint32_t i;
uint32_t carry = 0;
for (i = 0; i < rhs.width(); i++) {
lhs[i] -= carry;
carry = lhs[i] == UINT32_MAX && carry == 1;
carry |= lhs[i] < rhs[i];
lhs[i] -= rhs[i];
}
for (; carry && i < lhs.width(); ++i) {
--lhs[i];
carry = lhs[i] == UINT32_MAX && carry == 1;
}
return lhs.delete_trailing_zeros();
}
BigUInt operator-(const BigUInt &lhs, const BigUInt &rhs) {
BigUInt ans(lhs);
return ans -= rhs;
}
BigUInt operator*(const BigUInt &lhs, const BigUInt &rhs) {
BigUInt ans;
std::vector<uint32_t> lower(lhs.width());
std::vector<uint32_t> upper(lhs.width() + 1);
for (uint32_t i = 0; i < rhs.width(); ++i) {
for (uint32_t j = 0; j < lhs.width(); ++j) {
upper[0] = 0;
lower[j] = lhs[j] * rhs[i];
upper[j + 1] = (static_cast<uint64_t>(lhs[j]) * rhs[i]) >> 32;
}
ans += (BigUInt(lower) + BigUInt(upper)).shift_left(32 * i);
}
return ans;
}
BigUInt &operator*=(BigUInt &lhs, const BigUInt &rhs) {
lhs = lhs * rhs;
return lhs;
}
// Knuth's Algorithm D (TAOCP Vol. 2, Section 4.3.1).
// returns the quotient, modifying dividend to be the remainder.
static BigUInt divmod(BigUInt &dividend, const BigUInt &divisor_in) {
if (divisor_in.is_zero()) {
throw std::invalid_argument("BigUInt : division by zero.");
}
if (dividend < divisor_in) {
return 0;
}
// normalize a local copy: divisor_in may be a const object (e.g. ten in
// operator<<) or the same object as dividend (x %= x).
BigUInt divisor(divisor_in);
const uint64_t shamt = 31 - hnz(divisor[divisor.width() - 1]);
dividend.shift_left(shamt);
divisor.shift_left(shamt);
const uint64_t n = divisor.width();
const uint64_t div_msb = divisor[n - 1];
const uint64_t div_ssb = n == 1 ? 0 : divisor[n - 2];
// -= drops leading zero words, so digits at or beyond dividend.width() are 0.
// this also covers k == UINT64_MAX (i == 0 and n == 1).
auto digit = [&dividend](uint64_t k) -> uint64_t {
return k < dividend.width() ? dividend[k] : 0;
};
std::vector<uint32_t> ans(dividend.width() - n + 1);
for (uint64_t i = ans.size() - 1; i != UINT64_MAX; --i) {
uint64_t num = (digit(i + n) << 32) + digit(i + n - 1);
uint64_t q = num / div_msb;
uint64_t r = num % div_msb;
// step D3: decrease q while q >= 2^32 or the test with the second divisor
// digit fails, repeating only while r < 2^32.
while (r <= UINT32_MAX &&
(q > UINT32_MAX || q * div_ssb > (r << 32) + digit(i + n - 2))) {
q--;
r += div_msb;
}
BigUInt sub = (q * divisor).cshift_left(i * 32);
if (dividend < sub) {
q--;
sub -= divisor.cshift_left(i * 32);
}
ans[i] = static_cast<uint32_t>(q);
dividend -= sub;
}
dividend.shift_right(shamt);
return BigUInt::of_vector(ans);
}
BigUInt operator/(const BigUInt &lhs, const BigUInt &rhs) {
auto temp(lhs);
return divmod(temp, rhs);
}
BigUInt &operator/=(BigUInt &lhs, const BigUInt &rhs) {
return lhs = lhs / rhs;
}
BigUInt operator%(const BigUInt &lhs, const BigUInt &rhs) {
BigUInt temp(lhs);
return temp %= rhs;
}
BigUInt &operator%=(BigUInt &lhs, const BigUInt &rhs) {
divmod(lhs, rhs);
return lhs;
}
BigUInt BigUInt::pow(BigUInt exp) const {
BigUInt x = 1;
for (; exp != 0; --exp) {
x *= *this;
}
return x;
}
std::ostream &operator<<(std::ostream &lhs, const BigUInt &rhs) {
BigUInt base(10);
const BigUInt ten(10);
while (rhs >= base) {
base *= ten;
}
auto temp(rhs);
while (base != 1) {
temp %= base;
lhs << (temp / (base /= ten))[0];
}
return lhs;
}
The BigInt class
Here are the header and implementation of the BigInt class.
#ifndef BIGINT_H
#define BIGINT_H
#include "big_uint.hpp"
class BigInt {
private:
bool m_sign;
BigUInt m_abs;
public:
static std::mt19937_64 random_engine;
BigInt();
BigInt(int32_t i);
BigInt(int64_t i);
BigInt(uint32_t i, bool sign = false);
BigInt(uint64_t i, bool sign = false);
BigInt(const std::string &str);
static BigInt of_big_uint(BigUInt abs, bool sign = false);
static BigInt of_vector(std::vector<uint32_t> abs, bool sign = false);
static BigInt random(unsigned n = 5);
uint64_t width() const;
bool sign() const;
bool is_zero() const;
bool is_positive() const;
bool is_negative() const;
BigInt cshift_left(uint64_t shamt) const;
BigInt &shift_left(uint64_t shamt);
BigInt cshift_right(uint64_t shamt) const;
BigInt &shift_right(uint64_t shamt);
BigInt pow(BigInt exp) const;
const uint32_t &operator[](uint64_t i) const;
uint32_t &operator[](uint64_t i);
BigInt &operator++();
BigInt operator++(int);
BigInt &operator--();
BigInt operator--(int);
BigInt operator-() const;
friend bool operator==(const BigInt &lhs, const BigInt &rhs);
friend bool operator!=(const BigInt &lhs, const BigInt &rhs);
friend bool operator<(const BigInt &lhs, const BigInt &rhs);
friend bool operator>(const BigInt &lhs, const BigInt &rhs);
friend bool operator<=(const BigInt &lhs, const BigInt &rhs);
friend bool operator>=(const BigInt &lhs, const BigInt &rhs);
friend BigInt &operator+=(BigInt &lhs, const BigInt &rhs);
friend BigInt &operator-=(BigInt &lhs, const BigInt &rhs);
friend BigInt &operator*=(BigInt &lhs, const BigInt &rhs);
friend BigInt &operator/=(BigInt &lhs, const BigInt &rhs);
friend BigInt &operator%=(BigInt &lhs, const BigInt &rhs);
friend BigInt operator+(const BigInt &lhs, const BigInt &rhs);
friend BigInt operator-(const BigInt &lhs, const BigInt &rhs);
friend BigInt operator*(const BigInt &lhs, const BigInt &rhs);
friend BigInt operator/(const BigInt &lhs, const BigInt &rhs);
friend BigInt operator%(const BigInt &lhs, const BigInt &rhs);
friend std::ostream &operator<<(std::ostream &lhs, const BigInt &rhs);
};
#endif // BIGINT_H
#include "big_int.hpp"
const uint32_t &BigInt::operator[](uint64_t i) const { return m_abs[i]; }
uint32_t &BigInt::operator[](uint64_t i) { return m_abs[i]; }
std::mt19937_64 BigInt::random_engine((std::random_device())());
BigInt::BigInt() : m_sign(false) {}
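BigInt::BigInt(int32_t i)
: m_sign(i < 0),
// negate in unsigned arithmetic: -i would overflow for INT32_MIN.
m_abs(m_sign ? uint32_t{0} - static_cast<uint32_t>(i) : static_cast<uint32_t>(i)) {}
BigInt::BigInt(int64_t i)
: m_sign(i < 0),
// likewise for INT64_MIN.
m_abs(m_sign ? uint64_t{0} - static_cast<uint64_t>(i) : static_cast<uint64_t>(i)) {}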
BigInt::BigInt(uint32_t i, bool sign) : m_sign(sign), m_abs(i) {}
BigInt::BigInt(uint64_t i, bool sign) : m_sign(sign), m_abs(i) {}
BigInt::BigInt(const std::string &str)
: m_sign(str[0] == '-'), m_abs(str.substr(m_sign, str.length() - m_sign)) {}
BigInt BigInt::of_big_uint(BigUInt abs, bool sign) {
BigInt temp;
temp.m_abs = abs;
temp.m_sign = sign;
return temp;
}
BigInt BigInt::of_vector(std::vector<uint32_t> abs, bool sign) {
BigInt temp;
temp.m_abs = BigUInt::of_vector(abs);
temp.m_sign = sign;
return temp;
}
BigInt BigInt::random(unsigned n) {
return BigInt::of_big_uint(BigUInt::random(n), random_engine() % 2);
}
uint64_t BigInt::width() const { return m_abs.width(); }
bool BigInt::sign() const { return m_sign; }
bool BigInt::is_zero() const { return m_abs.is_zero(); }
bool BigInt::is_positive() const { return !m_sign; }
bool BigInt::is_negative() const { return m_sign; }
BigInt BigInt::cshift_left(uint64_t shamt) const {
return BigInt::of_big_uint(m_abs.cshift_left(shamt), m_sign);
}
BigInt &BigInt::shift_left(uint64_t shamt) {
m_abs.shift_left(shamt);
return *this;
}
BigInt BigInt::cshift_right(uint64_t shamt) const {
return BigInt::of_big_uint(m_abs.cshift_right(shamt), m_sign);
}
BigInt &BigInt::shift_right(uint64_t shamt) {
m_abs.shift_right(shamt);
return *this;
}
bool operator==(const BigInt &lhs, const BigInt &rhs) {
return lhs.m_sign == rhs.m_sign && lhs.m_abs == rhs.m_abs;
}
bool operator!=(const BigInt &lhs, const BigInt &rhs) {
return lhs.m_sign != rhs.m_sign || lhs.m_abs != rhs.m_abs;
}
// negating zero keeps it non-negative, so == and < see a single zero.
BigInt BigInt::operator-() const {
return BigInt::of_big_uint(m_abs, !m_sign && !m_abs.is_zero());
}
bool operator<(const BigInt &lhs, const BigInt &rhs) {
if (lhs.is_positive() && rhs.is_negative()) {
return false;
}
if (lhs.is_negative() && rhs.is_positive()) {
return true;
}
if (lhs.is_positive() && rhs.is_positive()) {
return lhs.m_abs < rhs.m_abs;
}
return lhs.m_abs > rhs.m_abs;
}
bool operator<=(const BigInt &lhs, const BigInt &rhs) {
return lhs == rhs || lhs < rhs;
}
bool operator>(const BigInt &lhs, const BigInt &rhs) { return rhs < lhs; }
bool operator>=(const BigInt &lhs, const BigInt &rhs) { return rhs <= lhs; }
BigInt &operator+=(BigInt &lhs, const BigInt &rhs) {
if (lhs.m_sign == rhs.m_sign) {
lhs.m_abs += rhs.m_abs;
} else if (lhs.m_abs > rhs.m_abs) {
lhs.m_abs -= rhs.m_abs;
} else if (lhs.m_abs < rhs.m_abs) {
lhs.m_sign = !lhs.m_sign;
lhs.m_abs = rhs.m_abs - lhs.m_abs;
} else {
lhs = BigInt(0);
}
return lhs;
}
BigInt operator+(const BigInt &lhs, const BigInt &rhs) {
BigInt ans = lhs;
return ans += rhs;
}
BigInt &operator-=(BigInt &lhs, const BigInt &rhs) {
if (lhs.m_sign != rhs.m_sign) {
lhs.m_abs += rhs.m_abs;
} else if (lhs.m_abs > rhs.m_abs) {
lhs.m_abs -= rhs.m_abs;
} else if (lhs.m_abs < rhs.m_abs) {
lhs.m_sign = !lhs.m_sign;
lhs.m_abs = rhs.m_abs - lhs.m_abs;
} else {
lhs = BigInt(0);
}
return lhs;
}
BigInt operator-(const BigInt &lhs, const BigInt &rhs) {
auto temp(lhs);
return temp -= rhs;
}
BigInt &operator*=(BigInt &lhs, const BigInt &rhs) {
lhs.m_abs *= rhs.m_abs;
// a zero result is kept non-negative so that == and < see a single zero.
lhs.m_sign = !lhs.m_abs.is_zero() && lhs.sign() != rhs.sign();
return lhs;
}
BigInt operator*(const BigInt &lhs, const BigInt &rhs) {
auto temp(lhs);
return temp *= rhs;
}
BigInt &operator/=(BigInt &lhs, const BigInt &rhs) {
lhs.m_abs /= rhs.m_abs;
lhs.m_sign = !lhs.m_abs.is_zero() && lhs.sign() != rhs.sign();
return lhs;
}
BigInt operator/(const BigInt &lhs, const BigInt &rhs) {
auto temp(lhs);
return temp /= rhs;
}
BigInt &operator%=(BigInt &lhs, const BigInt &rhs) {
lhs.m_abs %= rhs.m_abs;
// / truncates toward zero, so the remainder keeps the sign of lhs (as for int)
// and lhs == (lhs / rhs) * rhs + lhs % rhs holds.
if (lhs.m_abs.is_zero()) {
lhs.m_sign = false;
}
return lhs;
}
BigInt operator%(const BigInt &lhs, const BigInt &rhs) {
auto temp(lhs);
return temp %= rhs;
}
BigInt &BigInt::operator++() {
*this += 1;
return *this;
}
BigInt BigInt::operator++(int) {
auto temp(*this);
*this += 1;
return temp;
}
BigInt &BigInt::operator--() {
*this -= 1;
return *this;
}
BigInt BigInt::operator--(int) {
auto temp(*this);
*this -= 1;
return temp;
}
BigInt BigInt::pow(BigInt exp) const {
BigInt x = 1;
for (; exp > 0; --exp) {
x *= *this;
}
return x;
}
std::ostream &operator<<(std::ostream &lhs, const BigInt &rhs) {
if (rhs.is_negative()) {
lhs << "-";
}
lhs << rhs.m_abs;
return lhs;
}