浮動小数点数を有理数に変換する in C++

IEEE 754 形式の浮動小数点数では、無限大と NaN を除き、各ビット列がある有理数をちょうど表す。

例えば倍精度の場合、

\begin{align*} \text{s} \ \text{e}_{11} \text{e}_{10} \ldots \text{e}_{1} \ \text{s}_{52} \text{s}_{51} \ldots \text{s}_{1} \end{align*}

というビット列が

\begin{align*} (-1)^{\text{s}} (1.\text{s}_{52} \text{s}_{51} \ldots \text{s}_{1})_{2} \times 2^{(\text{e}_{11} \text{e}_{10} \ldots \text{e}_{1})_{2} - 1023} \end{align*}

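という有理数を表す。ただしこれは、指数部が全ビット 0 でも全ビット 1 でもない正規化数の場合である。指数部が全ビット 0 のとき（ゼロと非正規化数）は

\begin{align*} (-1)^{\text{s}} (0.\text{s}_{52} \text{s}_{51} \ldots \text{s}_{1})_{2} \times 2^{1 - 1023} \end{align*}

を表す。全ビット 1 のときは無限大または NaN を表す。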

前の記事で浮動小数点数を扱うクラスを、別の記事で任意精度の有理数クラスを作った。この 2 つを組み合わせれば、浮動小数点数が表す有理数を計算できる。

浮動小数点数を扱うクラスを拡張してこの処理を実装した。

Code Snippet 1: floating_point.hpp

#include <bitset>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <type_traits>

#include "big_int.hpp"
#include "fraction.hpp"

template <typename T>
class FloatingPoint;

// Convenience aliases for single (float) and double precision.
using Single = FloatingPoint<float>;
using Double = FloatingPoint<double>;

// Always-false variable template, for assertions that should only fail when
// the surrounding template is actually instantiated. Declared `inline` so a
// header that defines it has exactly one definition across translation units.
template <typename...>
inline constexpr bool false_v = false;

// FloatingPoint exposes the bit fields of an IEEE 754 binary floating-point
// value. It assumes the sign bit, exponent bits, and significand bits are
// stored in this order, from the most significant bit downwards.
//
// Only float and double are supported; any other T is rejected at compile
// time.
template <typename T>
class FloatingPoint {
  static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
                "FloatingPoint supports only float and double");

 public:
  // Number of exponent bits: 8 for float, 11 for double.
  constexpr static unsigned width_exp = std::is_same_v<T, float> ? 8 : 11;

  // Number of stored significand bits, not counting the implicit leading 1:
  // 23 for float, 52 for double.
  constexpr static unsigned width_sig = std::is_same_v<T, float> ? 23 : 52;

  // Exponent bias: 127 for float, 1023 for double.
  constexpr static unsigned bias = std::is_same_v<T, float> ? 127 : 1023;

  // Total number of bits: sign + exponent + significand.
  constexpr static unsigned width = 1 + width_exp + width_sig;

 private:
  // to_fraction() reads the fields through to_ullong().
  static_assert(width_exp < 64 && width_sig < 64,
                "exponent and significand fields must fit in uint64_t");

  std::bitset<width> m_bits_all;
  bool m_bit_sign;
  std::bitset<width_exp> m_bits_exp;
  std::bitset<width_sig> m_bits_sig;
  T m_val;

  // Raw-bit conversions; explicitly specialized for float and double below.
  static std::bitset<width> bits_of_val(T val);
  static T val_of_bits(std::bitset<width> bits);

  // Splits m_bits_all into the sign, exponent, and significand fields.
  void set_bits() {
    m_bit_sign = m_bits_all[width - 1];

    for (unsigned i = 0; i < width_exp; ++i) {
      m_bits_exp[i] = m_bits_all[width_sig + i];
    }

    for (unsigned i = 0; i < width_sig; ++i) {
      m_bits_sig[i] = m_bits_all[i];
    }
  }

 public:
  // Wraps `val`. Not explicit, so a T converts to FloatingPoint<T>.
  FloatingPoint(T val) : m_bits_all(bits_of_val(val)), m_val(val) {
    set_bits();
  }

  // Builds a value from its sign, exponent, and significand fields.
  FloatingPoint(bool sign, std::bitset<width_exp> exp,
                std::bitset<width_sig> sig)
      : m_bits_all(std::to_string(sign) + exp.to_string() + sig.to_string()),
        m_bit_sign(sign),
        m_bits_exp(exp),
        m_bits_sig(sig),
        m_val(val_of_bits(m_bits_all)) {}

  // Builds a value from a '0'/'1' string (as accepted by std::bitset's string
  // constructor), most significant (sign) bit first.
  explicit FloatingPoint(const std::string &val)
      : m_bits_all(val), m_val(val_of_bits(m_bits_all)) {
    // m_bits_all is declared before m_val, so it is already initialized above.
    set_bits();
  }

  const T &value() const { return m_val; }
  std::bitset<width> bits_all() const { return m_bits_all; }
  std::bitset<width_exp> bits_exp() const { return m_bits_exp; }
  std::bitset<width_sig> bits_sig() const { return m_bits_sig; }
  bool bit_sign() const { return m_bit_sign; }

  // Returns the exact rational number represented by the bits.
  //
  //   normal    (0 < exp field < all ones):
  //     (-1)^sign * (2^width_sig + sig) * 2^(exp - bias - width_sig)
  //   subnormal (exp field == 0):
  //     (-1)^sign * sig * 2^(1 - bias - width_sig)
  //
  // Both +0 and -0 yield 0.
  // Throws std::domain_error for infinities and NaNs (exp field all ones),
  // which have no rational value.
  Fraction<BigInt> to_fraction() const {
    if (m_bits_exp.all()) {
      throw std::domain_error(
          "FloatingPoint::to_fraction: infinity or NaN is not rational");
    }

    const uint64_t exp = m_bits_exp.to_ullong();
    const uint64_t sig = m_bits_sig.to_ullong();
    if (exp == 0 && sig == 0) {
      return Fraction<BigInt>(0);  // +0 and -0
    }

    // Subnormals have no implicit leading 1 and use the minimum exponent 1.
    // uint64_t{1} (not 1ul) keeps the shift well-defined where long is 32 bits.
    const uint64_t mantissa =
        exp == 0 ? sig : sig + (uint64_t{1} << width_sig);
    const uint64_t scale = exp == 0 ? 1 : exp;
    const uint64_t total_bias = uint64_t{bias} + width_sig;

    BigInt significand(mantissa);
    if (scale > total_bias) {
      // Integer value: mantissa * 2^(scale - total_bias).
      BigInt factor = BigInt(1).shift_left(scale - total_bias);
      return m_bit_sign ? -Fraction<BigInt>(significand * factor)
                        : Fraction<BigInt>(significand * factor);
    }
    // Value: mantissa / 2^(total_bias - scale).
    BigInt denom = BigInt(1).shift_left(total_bias - scale);
    return m_bit_sign ? -Fraction<BigInt>(significand, denom)
                      : Fraction<BigInt>(significand, denom);
  }

  // Bitwise equality: +0 and -0 compare unequal, and NaNs with identical bits
  // compare equal.
  friend bool operator==(const FloatingPoint &lhs, const FloatingPoint &rhs) {
    return rhs.bits_all() == lhs.bits_all();
  }
};

// Reinterprets the 32 bits of `val` as a float.
// std::memcpy avoids the strict-aliasing UB of reinterpret_cast, and narrowing
// to uint32_t first keeps the result correct regardless of endianness or the
// width of unsigned long. `inline` prevents multiple-definition errors when
// this header is included from several translation units.
template <>
inline float Single::val_of_bits(std::bitset<Single::width> val) {
  static_assert(sizeof(float) == sizeof(std::uint32_t), "float must be 32 bits");
  const auto bits = static_cast<std::uint32_t>(val.to_ulong());
  float result;
  std::memcpy(&result, &bits, sizeof result);
  return result;
}

// Returns the 32-bit pattern of `val`.
// std::memcpy avoids the strict-aliasing UB of reinterpret_cast; `inline`
// prevents multiple-definition errors when this header is included from
// several translation units.
template <>
inline std::bitset<Single::width> Single::bits_of_val(float val) {
  static_assert(sizeof(float) == sizeof(std::uint32_t), "float must be 32 bits");
  std::uint32_t bits;
  std::memcpy(&bits, &val, sizeof bits);
  return std::bitset<Single::width>(bits);
}

// Reinterprets the 64 bits of `val` as a double.
// std::memcpy avoids the strict-aliasing UB of reinterpret_cast; `inline`
// prevents multiple-definition errors when this header is included from
// several translation units.
template <>
inline double Double::val_of_bits(std::bitset<Double::width> val) {
  static_assert(sizeof(double) == sizeof(std::uint64_t),
                "double must be 64 bits");
  const std::uint64_t bits = val.to_ullong();
  double result;
  std::memcpy(&result, &bits, sizeof result);
  return result;
}

// Returns the 64-bit pattern of `val`.
// std::memcpy avoids the strict-aliasing UB of reinterpret_cast; `inline`
// prevents multiple-definition errors when this header is included from
// several translation units.
template <>
inline std::bitset<Double::width> Double::bits_of_val(double val) {
  static_assert(sizeof(double) == sizeof(std::uint64_t),
                "double must be 64 bits");
  std::uint64_t bits;
  std::memcpy(&bits, &val, sizeof bits);
  return std::bitset<Double::width>(bits);
}

以下のように使える。

#include "floating_point.hpp"
#include <iostream>

int main() {
  // Exact rational values of single- and double-precision literals.
  std::cout << "1234.567f = " << Single(1234.567f).to_fraction() << "\n";
  std::cout << "12345678.90123456 = " << Double(12345678.90123456).to_fraction() << "\n";

  // 1.0e20f is a float literal, so it is wrapped in Single. Widening it to
  // double gives the same value but hides that it is single precision.
  std::cout << "1.0e20f = " << Single(1.0e20f).to_fraction() << "\n";
  std::cout << "1.0e50 = " << Double(1.0e50).to_fraction() << "\n";

  return 0;
}
1234.567f = 10113573 / 8192
12345678.90123456 = 1657008972741239 / 134217728
1.0e20f = 100000002004087734272
1.0e50 = 100000000000000007629769841091887003294964970946560

この答えが合っていることは、出力された文字列をググれば分かる。Python の Decimal や Haskell の toRational の使用例が出てくるので、おそらく合っているだろう。

応用として、以下のようにすると、浮動小数点数の減算で生じる丸め誤差（正確な差と計算結果との差）を計算できる。

#include "floating_point.hpp"
#include <iostream>

// Returns the rounding error of the floating-point subtraction a - b, as an
// exact rational: (exact difference of a and b) - (value of the computed a - b).
template <typename T>
Fraction<BigInt> error_by_sub(T a, T b) {
  const T rounded = a - b;
  auto lhs = FloatingPoint<T>(a).to_fraction();
  auto rhs = FloatingPoint<T>(b).to_fraction();
  auto exact_diff = lhs - rhs;
  return exact_diff - FloatingPoint<T>(rounded).to_fraction();
}

int main() {
  const double x = 1.3;
  std::cout << "x = 1.3 = " << Double(x).to_fraction() << "\n";

  // Each minuend paired with the label printed in front of its error.
  struct Case {
    const char *label;
    double minuend;
  };
  const Case cases[] = {
      {"error in 1.3 - x      : ", 1.3},
      {"error in 12.3 - x     : ", 12.3},
      {"error in 123.3 - x    : ", 123.3},
      {"error in 123456.3 - x : ", 123456.3},
      {"error in 1.0e25 - x   : ", 1.0e25},
  };
  for (const Case &c : cases) {
    std::cout << c.label << error_by_sub(c.minuend, x) << "\n";
  }
  return 0;
}
x = 1.3 = 5854679515581645 / 4503599627370496
error in 1.3 - x      : 0
error in 12.3 - x     : 3 / 4503599627370496
error in 123.3 - x    : -13 / 4503599627370496
error in 123456.3 - x : 13107 / 4503599627370496
error in 1.0e25 - x   : -5854679515581645 / 4503599627370496