diff --git a/CMakeLists.txt b/CMakeLists.txt index a97a9c96..da1945ff 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -423,6 +423,9 @@ target_link_libraries(test-vm PRIVATE ton_crypto fift-lib) add_executable(test-smartcont test/test-td-main.cpp ${SMARTCONT_TEST_SOURCE}) target_link_libraries(test-smartcont PRIVATE smc-envelope fift-lib ton_db) +add_executable(test-bigint ${BIGINT_TEST_SOURCE}) +target_link_libraries(test-bigint PRIVATE ton_crypto) + add_executable(test-cells test/test-td-main.cpp ${CELLS_TEST_SOURCE}) target_link_libraries(test-cells PRIVATE ton_crypto) @@ -523,6 +526,7 @@ if (HAS_PARENT) ${FEC_TEST_SOURCE} ${ED25519_TEST_SOURCE} ${TONDB_TEST_SOURCE} + ${BIGNUM_TEST_SOURCE} ${CELLS_TEST_SOURCE} # ${TONVM_TEST_SOURCE} ${FIFT_TEST_SOURCE} ${TONLIB_ONLINE_TEST_SOURCE} PARENT_SCOPE) endif() @@ -536,6 +540,7 @@ set(TEST_OPTIONS "--regression ${CMAKE_CURRENT_SOURCE_DIR}/test/regression-tests separate_arguments(TEST_OPTIONS) add_test(test-ed25519-crypto crypto/test-ed25519-crypto) add_test(test-ed25519 test-ed25519) +add_test(test-bigint test-bigint) add_test(test-vm test-vm ${TEST_OPTIONS}) add_test(test-fift test-fift ${TEST_OPTIONS}) add_test(test-cells test-cells ${TEST_OPTIONS}) diff --git a/crypto/CMakeLists.txt b/crypto/CMakeLists.txt index 79406ffc..88a0272b 100644 --- a/crypto/CMakeLists.txt +++ b/crypto/CMakeLists.txt @@ -259,6 +259,11 @@ set(FIFT_TEST_SOURCE PARENT_SCOPE ) +set(BIGINT_TEST_SOURCE + ${CMAKE_CURRENT_SOURCE_DIR}/test/test-bigint.cpp + PARENT_SCOPE +) + add_library(ton_crypto STATIC ${TON_CRYPTO_SOURCE}) target_include_directories(ton_crypto PUBLIC $ diff --git a/crypto/common/bigint.hpp b/crypto/common/bigint.hpp index e78f4892..94da3a13 100644 --- a/crypto/common/bigint.hpp +++ b/crypto/common/bigint.hpp @@ -264,7 +264,7 @@ class AnyIntView { return digits[size() - 1]; } double top_double() const { - return size() > 1 ? (double)digits[size() - 1] + (double)digits[size() - 2] * (1.0 / Tr::Base) + return size() > 1 ? (double)digits[size() - 1] + (double)digits[size() - 2] * Tr::InvBase : (double)digits[size() - 1]; } bool is_odd_any() const { @@ -314,8 +314,15 @@ class BigIntG { digits[0] = x; } BigIntG(Normalize, word_t x) : n(1) { - digits[0] = x; - normalize_bool(); + if (x >= -Tr::Half && x < Tr::Half) { + digits[0] = x; + } else if (len <= 1) { + digits[0] = x; + normalize_bool(); + } else { + digits[0] = ((x + Tr::Half) & (Tr::Base - 1)) - Tr::Half; + digits[n++] = (x >> Tr::word_shift) + (digits[0] < 0); + } } BigIntG(const BigIntG& x) : n(x.n) { std::memcpy(digits, x.digits, n * sizeof(word_t)); @@ -757,7 +764,7 @@ bool AnyIntView::add_pow2_any(int exponent, int factor) { while (size() <= k) { digits[inc_size()] = 0; } - digits[k] += (factor << dm.rem); + digits[k] += ((word_t)factor << dm.rem); return true; } @@ -1087,12 +1094,16 @@ int AnyIntView::cmp_any(const AnyIntView& yp) const { template int AnyIntView::cmp_any(word_t y) const { - if (size() > 1) { - return top_word() < 0 ? -1 : 1; - } else if (size() == 1) { + if (size() == 1) { return digits[0] < y ? -1 : (digits[0] > y ? 1 : 0); - } else { + } else if (!size()) { return 0x80000000; + } else if (size() == 2 && (y >= Tr::Half || y < -Tr::Half)) { + word_t x0 = digits[0] & (Tr::Base - 1), y0 = y & (Tr::Base - 1); + word_t x1 = digits[1] + (digits[0] >> Tr::word_shift), y1 = (y >> Tr::word_shift); + return x1 < y1 ? -1 : (x1 > y1 ? 1 : (x0 < y0 ? -1 : (x0 > y0 ? 1 : 0))); + } else { + return top_word() < 0 ? -1 : 1; } } @@ -1312,17 +1323,14 @@ bool AnyIntView::mod_div_any(const AnyIntView& yp, AnyIntView& quot, if (k > quot.max_size()) { return invalidate_bool(); } - quot.set_size(max(k,1)); - for(int qi=0; qi< max(k,1); qi++) { - quot.digits[qi]=0; - } + quot.set_size(std::max(k, 1)); + quot.digits[0] = 0; } else { if (k >= quot.max_size()) { return invalidate_bool(); } quot.set_size(k + 1); - double x_top = top_double(); - word_t q = std::llrint(x_top * y_inv * Tr::InvBase); + word_t q = std::llrint(top_double() * y_inv * Tr::InvBase); quot.digits[k] = q; int i = yp.size() - 1; word_t hi = 0; @@ -1337,8 +1345,7 @@ bool AnyIntView::mod_div_any(const AnyIntView& yp, AnyIntView& quot, quot.digits[0] = 0; } while (--k >= 0) { - double x_top = top_double(); - word_t q = std::llrint(x_top * y_inv); + word_t q = std::llrint(top_double() * y_inv); quot.digits[k] = q; for (int i = yp.size() - 1; i >= 0; --i) { Tr::sub_mul(&digits[k + i + 1], &digits[k + i], q, yp.digits[i]); @@ -1346,15 +1353,18 @@ bool AnyIntView::mod_div_any(const AnyIntView& yp, AnyIntView& quot, dec_size(); digits[size() - 1] += (digits[size()] << word_shift); } - if (size() >= yp.size()) { - assert(size() == yp.size()); - double x_top = top_double(); - double t = x_top * y_inv * Tr::InvBase; + if (size() >= yp.size() - 1) { + assert(size() <= yp.size()); + bool grow = (size() < yp.size()); + double t = top_double() * y_inv * (grow ? Tr::InvBase * Tr::InvBase : Tr::InvBase); if (round_mode >= 0) { t += (round_mode ? 1 : 0.5); } word_t q = std::llrint(std::floor(t)); if (q) { + if (grow) { + digits[inc_size()] = 0; + } for (int i = 0; i < size(); i++) { digits[i] -= q * yp.digits[i]; } @@ -1411,6 +1421,7 @@ bool AnyIntView::mod_div_any(const AnyIntView& yp, AnyIntView& quot, return normalize_bool_any(); } +// works for almost-normalized numbers (digits -Base+1 .. Base-1, top non-zero), result also almost-normalized template bool AnyIntView::mod_pow2_any(int exponent) { if (!is_valid()) { @@ -1462,25 +1473,21 @@ bool AnyIntView::mod_pow2_any(int exponent) { if (exponent >= max_size() * word_shift) { return invalidate_bool(); } - if (q - word_shift >= 0) { + if (q - word_shift >= 0) { // original top digit was a non-zero multiple of Base, impossible(?) digits[size() - 1] = 0; digits[inc_size()] = ((word_t)1 << (q - word_shift)); - } - if (q - word_shift == -1 && size() < max_size() - 1) { + } else if (q - word_shift == -1 && size() < max_size()) { digits[size() - 1] = -Tr::Half; digits[inc_size()] = 1; } else { digits[size() - 1] = pow; } return true; - } else if (v >= Tr::Half) { - if (size() == max_size() - 1) { - return invalidate_bool(); - } else { - digits[size() - 1] = v | -Tr::Half; - digits[inc_size()] = ((word_t)1 << (q - word_shift)); - return true; - } + } else if (v >= Tr::Half && size() < max_size()) { + word_t w = (((v >> (word_shift - 1)) + 1) >> 1); + digits[size() - 1] = v - (w << word_shift); + digits[inc_size()] = w; + return true; } else { digits[size() - 1] = v; return true; diff --git a/crypto/common/refint.cpp b/crypto/common/refint.cpp index b79750ce..8e06da5f 100644 --- a/crypto/common/refint.cpp +++ b/crypto/common/refint.cpp @@ -128,12 +128,10 @@ RefInt256 muldiv(RefInt256 x, RefInt256 y, RefInt256 z, int round_mode) { } std::pair muldivmod(RefInt256 x, RefInt256 y, RefInt256 z, int round_mode) { - typename td::BigInt256::DoubleInt tmp{0}; + typename td::BigInt256::DoubleInt tmp{0}, quot; tmp.add_mul(*x, *y); - RefInt256 quot{true}; - tmp.mod_div(*z, quot.unique_write(), round_mode); - quot.write().normalize(); - return std::make_pair(std::move(quot), td::make_refint(tmp)); + tmp.mod_div(*z, quot, round_mode); + return std::make_pair(td::make_refint(quot.normalize()), td::make_refint(tmp)); } RefInt256 operator&(RefInt256 x, RefInt256 y) { diff --git a/crypto/fift/lib/Asm.fif b/crypto/fift/lib/Asm.fif index d61923c1..76381b85 100644 --- a/crypto/fift/lib/Asm.fif +++ b/crypto/fift/lib/Asm.fif @@ -334,9 +334,14 @@ x{A926} @Defop RSHIFTC x{A935} @Defop(8u+1) RSHIFTR# x{A936} @Defop(8u+1) RSHIFTC# x{A938} @Defop(8u+1) MODPOW2# +x{A939} @Defop(8u+1) MODPOW2R# +x{A93A} @Defop(8u+1) MODPOW2C# x{A984} @Defop MULDIV x{A985} @Defop MULDIVR +x{A988} @Defop MULMOD x{A98C} @Defop MULDIVMOD +x{A98D} @Defop MULDIVMODR +x{A98E} @Defop MULDIVMODC x{A9A4} @Defop MULRSHIFT x{A9A5} @Defop MULRSHIFTR x{A9A6} @Defop MULRSHIFTC diff --git a/crypto/test/modbigint.cpp b/crypto/test/modbigint.cpp new file mode 100644 index 00000000..851f0f9b --- /dev/null +++ b/crypto/test/modbigint.cpp @@ -0,0 +1,1070 @@ +/* + This file is part of TON Blockchain Library. + + TON Blockchain Library is free software: you can redistribute it and/or modify + it under the terms of the GNU Lesser General Public License as published by + the Free Software Foundation, either version 2 of the License, or + (at your option) any later version. + + TON Blockchain Library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public License + along with TON Blockchain Library. If not, see . +*/ +#include +#include +#include +#include +#include +#include + +namespace modint { + +enum { mod_cnt = 32 }; + +// mod_cnt = 9 => integers -2^268 .. 2^268 +// mod_cnt = 18 => integers -2^537 .. 2^537 +// mod_cnt = 32 => integers -2^955 .. 2^955 +constexpr int mod[mod_cnt] = {999999937, 999999929, 999999893, 999999883, 999999797, 999999761, 999999757, 999999751, + 999999739, 999999733, 999999677, 999999667, 999999613, 999999607, 999999599, 999999587, + 999999541, 999999527, 999999503, 999999491, 999999487, 999999433, 999999391, 999999353, + 99999337, 999999323, 999999229, 999999223, 999999197, 999999193, 999999191, 999999181}; + +// invm[i][j] = mod[i]^(-1) modulo mod[j] +int invm[mod_cnt][mod_cnt]; + +int gcdx(int a, int b, int& u, int& v); + +template +struct ModArray; + +template +struct MixedRadix; + +template +struct ArrayRawDumpRef; + +template +std::ostream& raw_dump_array(std::ostream& os, const std::array& arr) { + os << '['; + for (auto x : arr) { + os << ' ' << x; + } + return os << " ]"; +} + +template +struct MixedRadix { + enum { n = N }; + int a[N]; + MixedRadix(int v) { + set_int(v); + } + MixedRadix() = default; + MixedRadix(const MixedRadix&) = default; + MixedRadix(std::initializer_list l) { + auto sz = std::min(l.size(), (std::size_t)N); + std::copy(l.begin(), l.begin() + sz, a); + std::fill(a + sz, a + N, 0); + } + MixedRadix(const std::array& arr) { + std::copy(arr.begin(), arr.end(), a); + } + template + MixedRadix(const MixedRadix& other) { + static_assert(M >= N); + std::copy(other.a, other.a + N, a); + } + MixedRadix(const ModArray& other); + MixedRadix(const ModArray& other, bool sgnd); + + MixedRadix& set_zero() { + std::fill(a, a + N, 0); + return *this; + } + MixedRadix& set_one() { + a[0] = 1; + std::fill(a + 1, a + N, 0); + return *this; + } + MixedRadix& set_int(int v) { + a[0] = v; + std::fill(a + 1, a + N, 0); + return *this; + } + + MixedRadix copy() const { + return MixedRadix{*this}; + } + + static const int* mod_array() { + return mod; + } + + static int modulus(int i) { + return mod[i]; + } + + int sgn() const { + int i = N - 1; + while (i >= 0 && !a[i]) { + --i; + } + return i < 0 ? 0 : (a[i] > 0 ? 1 : -1); + } + + int cmp(const MixedRadix& other) const { + int i = N - 1; + while (i >= 0 && a[i] == other.a[i]) { + --i; + } + return i < 0 ? 0 : (a[i] > other.a[i] ? 1 : -1); + } + + bool is_small() const { + return !a[N - 1] || a[N - 1] == -1; + } + + bool operator==(const MixedRadix& other) const { + return std::equal(a, a + N, other.a); + } + + bool operator!=(const MixedRadix& other) const { + return !std::equal(a, a + N, other.a); + } + + bool operator<(const MixedRadix& other) const { + return cmp(other) < 0; + } + + bool operator<=(const MixedRadix& other) const { + return cmp(other) <= 0; + } + + bool operator>(const MixedRadix& other) const { + return cmp(other) > 0; + } + + bool operator>=(const MixedRadix& other) const { + return cmp(other) >= 0; + } + + explicit operator bool() const { + return sgn(); + } + + bool operator!() const { + return !sgn(); + } + + MixedRadix& negate() { + int i = 0; + while (i < N - 1 && !a[i]) { + i++; + } + a[i]--; + for (; i < N; i++) { + a[i] = mod[i] - a[i] - 1; + } + a[N - 1] -= mod[N - 1]; + return *this; + } + + static const MixedRadix& pow2(int power); + static MixedRadix negpow2(int power) { + return -pow2(power); + } + + template + const MixedRadix& as_shorter() const { + static_assert(M <= N); + return *reinterpret_cast*>(this); + } + + MixedRadix& import_mod_array(const int* data, bool sgnd = true) { + for (int i = 0; i < N; i++) { + a[i] = data[i] % mod[i]; + } + for (int i = 0; i < N; i++) { + if (a[i] < 0) { + a[i] += mod[i]; + } + for (int j = i + 1; j < N; j++) { + a[j] = (int)((long long)(a[j] - a[i]) * invm[i][j] % mod[j]); + } + } + if (sgnd && a[N - 1] > (mod[N - 1] >> 1)) { + a[N - 1] -= mod[N - 1]; + } + return *this; + } + + MixedRadix& operator=(const MixedRadix&) = default; + + template + MixedRadix& operator=(const MixedRadix& other) { + static_assert(M >= N); + std::copy(other.a, other.a + N, a); + } + + MixedRadix& import_mod_array(const ModArray& other, bool sgnd = true); + + MixedRadix& operator=(const ModArray& other) { + return import_mod_array(other); + } + + MixedRadix& set_sum(const MixedRadix& x, const MixedRadix& y, int factor = 1) { + long long carry = 0; + for (int i = 0; i < N; i++) { + long long acc = x.a[i] + carry + (long long)factor * y.a[i]; + carry = acc / mod[i]; + a[i] = (int)(acc - carry * mod[i]); + if (a[i] < 0) { + a[i] += mod[i]; + --carry; + } + } + if (a[N - 1] >= 0 && carry == -1) { + a[N - 1] -= mod[N - 1]; + } + return *this; + } + + MixedRadix& operator+=(const MixedRadix& other) { + return set_sum(*this, other); + } + + MixedRadix& operator-=(const MixedRadix& other) { + return set_sum(*this, other, -1); + } + + static const MixedRadix& zero(); + static const MixedRadix& one(); + + MixedRadix& operator*=(int factor) { + return set_sum(zero(), *this, factor); + } + + MixedRadix operator-() const { + MixedRadix copy{*this}; + copy.negate(); + return copy; + } + + MixedRadix operator+(const MixedRadix& other) const { + MixedRadix res; + res.set_sum(*this, other); + return res; + } + + MixedRadix operator-(const MixedRadix& other) const { + MixedRadix res; + res.set_sum(*this, other, -1); + return res; + } + + MixedRadix operator*(int factor) const { + MixedRadix res; + res.set_sum(zero(), *this, factor); + return res; + } + + int operator%(int b) const { + int x = a[N - 1] % b; + for (int i = N - 2; i >= 0; --i) { + x = ((long long)x * mod[i] + a[i]) % b; + } + return ((x ^ b) < 0 && x) ? x + b : x; + } + + explicit operator double() const { + double acc = 0.; + for (int i = N - 1; i >= 0; --i) { + acc = acc * mod[i] + a[i]; + } + return acc; + } + + explicit operator long long() const { + long long acc = 0.; + for (int i = N - 1; i >= 0; --i) { + acc = acc * mod[i] + a[i]; + } + return acc; + } + + MixedRadix& to_base(int base) { + int k = N - 1; + while (k > 0 && !a[k]) { + --k; + } + if (k <= 0) { + return *this; + } + for (int i = k - 1; i >= 0; --i) { + // a[i..k] := a[i+1..k] * mod[i] + a[i] + long long carry = a[i]; + for (int j = i; j < k; j++) { + long long t = (long long)a[j + 1] * mod[i] + carry; + carry = t / base; + a[j] = (int)(t - carry * base); + } + a[k] = (int)carry; + } + return *this; + } + + std::ostream& print_dec_destroy(std::ostream& os) { + int s = sgn(); + if (s < 0) { + os << '-'; + negate(); + } else if (!s) { + os << '0'; + return os; + } + to_base(1000000000); + int i = N - 1; + while (!a[i] && i > 0) { + --i; + } + os << a[i]; + while (--i >= 0) { + char buff[12]; + sprintf(buff, "%09d", a[i]); + os << buff; + } + return os; + } + + std::ostream& print_dec(std::ostream& os) const& { + MixedRadix copy{*this}; + return copy.print_dec_destroy(os); + } + + std::ostream& print_dec(std::ostream& os) && { + return print_dec_destroy(os); + } + + std::string to_dec_string_destroy() { + std::ostringstream os; + print_dec_destroy(os); + return std::move(os).str(); + } + + std::string to_dec_string() const& { + MixedRadix copy{*this}; + return copy.to_dec_string_destroy(); + } + + std::string to_dec_string() && { + return to_dec_string_destroy(); + } + + bool to_binary_destroy(unsigned char* arr, int size, bool sgnd = true) { + if (size <= 0) { + return false; + } + int s = (sgnd ? sgn() : 1); + memset(arr, 0, size); + if (s < 0) { + negate(); + } else if (!s) { + return true; + } + to_base(1 << 30); + long long acc = 0; + int bits = 0, j = size; + for (int i = 0; i < N; i++) { + if (!j && a[i]) { + return false; + } + acc += ((long long)a[i] << bits); + bits += 30; + while (bits >= 8 && j > 0) { + arr[--j] = (unsigned char)(acc & 0xff); + bits -= 8; + acc >>= 8; + } + } + while (j > 0) { + arr[--j] = (unsigned char)(acc & 0xff); + acc >>= 8; + } + if (acc) { + return false; + } + if (!sgnd) { + return true; + } + if (s >= 0) { + return arr[0] <= 0x7f; + } + j = size - 1; + while (j >= 0 && !arr[j]) { + --j; + } + assert(j >= 0); + arr[j] = (unsigned char)(-arr[j]); + while (--j >= 0) { + arr[j] = (unsigned char)~arr[j]; + } + return arr[0] >= 0x80; + } + + bool to_binary(unsigned char* arr, int size, bool sgnd = true) const& { + MixedRadix copy{*this}; + return copy.to_binary_destroy(arr, size, sgnd); + } + + bool to_binary(unsigned char* arr, int size, bool sgnd = true) && { + return to_binary_destroy(arr, size, sgnd); + } + + std::ostream& raw_dump(std::ostream& os) const { + return raw_dump_array(os, a); + } + + ArrayRawDumpRef dump() const { + return {a}; + } +}; + +template +struct ModArray { + enum { n = N }; + int a[N]; + ModArray(int v) { + set_int(v); + } + ModArray(long long v) { + set_long(v); + } + ModArray(long v) { + set_long(v); + } + ModArray() = default; + ModArray(const ModArray&) = default; + ModArray(std::initializer_list l) { + auto sz = std::min(l.size(), (std::size_t)N); + std::copy(l.begin(), l.begin() + sz, a); + std::fill(a + sz, a + N, 0); + } + ModArray(const std::array& arr) { + std::copy(arr.begin(), arr.end(), a); + } + template + ModArray(const ModArray& other) { + static_assert(M >= N); + std::copy(other.a, other.a + N, a); + } + ModArray(const int* p) : a(p) { + } + ModArray(std::string str) { + assert(from_dec_string(str) && "not a decimal number"); + } + + ModArray& set_zero() { + std::fill(a, a + N, 0); + return *this; + } + ModArray& set_one() { + std::fill(a, a + N, 1); + return *this; + } + + ModArray& set_int(int v) { + if (v >= 0) { + std::fill(a, a + N, v); + } else { + for (int i = 0; i < N; i++) { + a[i] = mod[i] + v; + } + } + return *this; + } + + ModArray& set_long(long long v) { + for (int i = 0; i < N; i++) { + a[i] = v % mod[i]; + if (a[i] < 0) { + a[i] += mod[i]; + } + } + return *this; + } + + ModArray copy() const { + return ModArray{*this}; + } + + static const int* mod_array() { + return mod; + } + + static int modulus(int i) { + return mod[i]; + } + + static const ModArray& zero(); + static const ModArray& one(); + + ModArray& operator=(const ModArray&) = default; + + template + ModArray& operator=(const ModArray& other) { + static_assert(M >= N); + std::copy(other.a, other.a + N, a); + return *this; + } + + ModArray& negate() { + for (int i = 0; i < N; i++) { + a[i] = (a[i] ? mod[i] - a[i] : 0); + } + return *this; + } + + ModArray& norm_neg() { + for (int i = 0; i < N; i++) { + if (a[i] < 0) { + a[i] += mod[i]; + } + } + return *this; + } + + ModArray& normalize() { + for (int i = 0; i < N; i++) { + a[i] %= mod[i]; + if (a[i] < 0) { + a[i] += mod[i]; + } + } + return *this; + } + + bool is_zero() const { + for (int i = 0; i < N; i++) { + if (a[i]) { + return false; + } + } + return true; + } + + explicit operator bool() const { + return !is_zero(); + } + + bool operator!() const { + return is_zero(); + } + + bool operator==(const ModArray& other) const { + return std::equal(a, a + N, other.a); + } + + bool operator!=(const ModArray& other) const { + return !std::equal(a, a + N, other.a); + } + + bool operator==(long long val) const { + for (int i = 0; i < N; i++) { + int r = (int)(val % mod[i]); + if (a[i] != (r < 0 ? r + mod[i] : r)) { + return false; + } + } + return true; + } + + bool operator!=(long long val) const { + return !operator==(val); + } + + long long try_get_long() const { + return (long long)(MixedRadix<3>(*this)); + } + + bool fits_long() const { + return operator==(try_get_long()); + } + + explicit operator long long() const { + auto v = try_get_long(); + return operator==(v) ? v : -0x8000000000000000; + } + + ModArray& set_sum(const ModArray& x, const ModArray& y) { + for (int i = 0; i < N; i++) { + a[i] = x.a[i] + y.a[i]; + if (a[i] >= mod[i]) { + a[i] -= mod[i]; + } + } + return *this; + } + + ModArray& operator+=(const ModArray& other) { + for (int i = 0; i < N; i++) { + a[i] += other.a[i]; + if (a[i] >= mod[i]) { + a[i] -= mod[i]; + } + } + return *this; + } + + ModArray& operator+=(long long v) { + for (int i = 0; i < N; i++) { + a[i] = (int)((a[i] + v) % mod[i]); + if (a[i] < 0) { + a[i] += mod[i]; + } + } + return *this; + } + + ModArray& operator-=(const ModArray& other) { + for (int i = 0; i < N; i++) { + a[i] -= other.a[i]; + if (a[i] < 0) { + a[i] += mod[i]; + } + } + return *this; + } + + ModArray& operator-=(long long v) { + return (operator+=)(-v); + } + + ModArray& mul_arr(const int other[]) { + for (int i = 0; i < N; i++) { + a[i] = (int)(((long long)a[i] * other[i]) % mod[i]); + } + return *this; + } + + ModArray& operator*=(const ModArray& other) { + return mul_arr(other.a); + } + + template + ModArray& operator*=(const ModArray& other) { + static_assert(M >= N); + return mul_arr(other.a); + } + + ModArray& operator*=(int v) { + for (int i = 0; i < N; i++) { + a[i] = (int)(((long long)a[i] * v) % mod[i]); + } + return (v >= 0 ? *this : norm_neg()); + } + + ModArray& operator*=(long long v) { + for (int i = 0; i < N; i++) { + a[i] = (int)(((long long)a[i] * (v % mod[i])) % mod[i]); + } + return (v >= 0 ? *this : norm_neg()); + } + + ModArray& mul_add(int v, long long w) { + for (int i = 0; i < N; i++) { + a[i] = (int)(((long long)a[i] * v + w) % mod[i]); + if (a[i] < 0) { + a[i] += mod[i]; + } + } + return *this; + } + + // *this = (*this * other) + w + ModArray& mul_add(const ModArray& other, long long w) { + for (int i = 0; i < N; i++) { + a[i] = (int)(((long long)a[i] * other.a[i] + w) % mod[i]); + if (a[i] < 0) { + a[i] += mod[i]; + } + } + return *this; + } + + // *this = (*this << shift) + w + ModArray& lshift_add(int shift, long long w) { + return mul_add(pow2(shift), w); + } + + // *this = *this + other * w + ModArray& add_mul(const ModArray& other, long long w) { + for (int i = 0; i < N; i++) { + a[i] = (int)((a[i] + other.a[i] * w) % mod[i]); + if (a[i] < 0) { + a[i] += mod[i]; + } + } + return *this; + } + + // *this += w << shift + ModArray& add_lshift(int shift, long long w) { + return add_mul(pow2(shift), w); + } + + ModArray operator+(const ModArray& other) const { + ModArray copy{*this}; + copy += other; + return copy; + } + + ModArray operator-(const ModArray& other) const { + ModArray copy{*this}; + copy -= other; + return copy; + } + + ModArray operator+(long long other) const { + ModArray copy{*this}; + copy += other; + return copy; + } + + ModArray operator-(long long other) const { + ModArray copy{*this}; + copy += -other; + return copy; + } + + ModArray operator-() const { + ModArray copy{*this}; + copy.negate(); + return copy; + } + + ModArray operator*(const ModArray& other) const { + ModArray copy{*this}; + copy *= other; + return copy; + } + + ModArray operator*(long long other) const { + ModArray copy{*this}; + copy *= other; + return copy; + } + + bool invert() { + for (int i = 0; i < N; i++) { + int t; + if (gcdx(a[i], mod[i], a[i], t) != 1) { + return false; + } + if (a[i] < 0) { + a[i] += mod[i]; + } + } + return true; + } + + bool try_divide(const ModArray& other) { + for (int i = 0; i < N; i++) { + int q, t; + if (gcdx(other.a[i], mod[i], q, t) != 1) { + return false; + } + a[i] = (int)((long long)a[i] * q % mod[i]); + if (a[i] < 0) { + a[i] += mod[i]; + } + } + return true; + } + + ModArray& operator/=(const ModArray& other) { + assert(try_divide(other) && "division by zero?"); + return *this; + } + + ModArray operator/(const ModArray& other) { + ModArray copy{*this}; + copy /= other; + return copy; + } + + static const ModArray& pow2(int power); + static const ModArray& negpow2(int power); + + ModArray& operator<<=(int lshift) { + return operator*=(pow2(lshift)); + } + + ModArray operator<<(int lshift) const { + return operator*(pow2(lshift)); + } + + ModArray& operator>>=(int rshift) { + return operator/=(pow2(rshift)); + } + + ModArray operator>>(int rshift) const { + return operator/(pow2(rshift)); + } + + template + const ModArray& as_shorter() const { + static_assert(M <= N); + return *reinterpret_cast*>(this); + } + + MixedRadix& to_mixed_radix(MixedRadix& dest, bool sgnd = true) const { + return dest.import_mod_array(a, sgnd); + } + + MixedRadix to_mixed_radix(bool sgnd = true) const { + return MixedRadix(*this, sgnd); + } + + int operator%(int div) const { + return to_mixed_radix() % div; + } + + explicit operator double() const { + return (double)to_mixed_radix(); + } + + std::string to_dec_string() const { + return MixedRadix(*this).to_dec_string(); + } + + std::ostream& print_dec(std::ostream& os, bool sgnd = true) const { + return MixedRadix(*this, sgnd).print_dec(os); + } + + bool to_binary(unsigned char* arr, int size, bool sgnd = true) const { + return MixedRadix(*this, sgnd).to_binary(arr, size, sgnd); + } + + template + bool to_binary(std::array& arr, bool sgnd = true) const { + return to_binary(arr.data(), M, sgnd); + } + + bool from_dec_string(const char* start, const char* end) { + set_zero(); + if (start >= end) { + return false; + } + bool sgn = (*start == '-'); + if (sgn && ++start == end) { + return false; + } + int acc = 0, pow = 1; + while (start < end) { + if (*start < '0' || *start > '9') { + return false; + } + acc = acc * 10 + (*start++ - '0'); + pow *= 10; + if (pow >= 1000000000) { + mul_add(pow, acc); + pow = 1; + acc = 0; + } + } + if (pow > 1) { + mul_add(pow, acc); + } + if (sgn) { + negate(); + } + return true; + } + + bool from_dec_string(std::string str) { + return from_dec_string(str.data(), str.data() + str.size()); + } + + ModArray& from_binary(const unsigned char* arr, int size, bool sgnd = true) { + set_zero(); + if (size <= 0) { + return *this; + } + int i = 0, pow = 0; + long long acc = (sgnd && arr[0] >= 0x80 ? -1 : 0); + while (i < size && arr[i] == (unsigned char)acc) { + i++; + } + for (; i < size; i++) { + pow += 8; + acc = (acc << 8) + arr[i]; + if (pow >= 56) { + lshift_add(pow, acc); + acc = pow = 0; + } + } + if (pow || acc) { + lshift_add(pow, acc); + } + return *this; + } + + template + ModArray& from_binary(const std::array& arr, bool sgnd = true) { + return from_binary(arr.data(), M, sgnd); + } + + std::ostream& raw_dump(std::ostream& os) const { + return raw_dump_array(os, a); + } + + ArrayRawDumpRef dump() const { + return {a}; + } +}; + +template +MixedRadix::MixedRadix(const ModArray& other) { + import_mod_array(other.a); +} + +template +MixedRadix::MixedRadix(const ModArray& other, bool sgnd) { + import_mod_array(other.a, sgnd); +} + +template +MixedRadix& MixedRadix::import_mod_array(const ModArray& other, bool sgnd) { + return import_mod_array(other.a, sgnd); +} + +template +std::ostream& operator<<(std::ostream& os, const ModArray& x) { + return x.print_dec(os); +} + +template +std::ostream& operator<<(std::ostream& os, const MixedRadix& x) { + return x.print_dec(os); +} + +template +std::ostream& operator<<(std::ostream& os, MixedRadix&& x) { + return x.print_dec_destroy(os); +} + +template +struct ArrayRawDumpRef { + const std::array& ref; + ArrayRawDumpRef(const std::array& _ref) : ref(_ref){}; +}; + +template +std::ostream& operator<<(std::ostream& os, ArrayRawDumpRef rd_ref) { + return raw_dump_array(os, rd_ref.ref); +}; + +constexpr int pow2_cnt = 1001; + +ModArray Zero(0), One(1), Pow2[pow2_cnt], NegPow2[pow2_cnt]; +MixedRadix Zero_mr(0), One_mr(1), Pow2_mr[pow2_cnt], NegPow2_mr[pow2_cnt]; + +template +const MixedRadix& MixedRadix::pow2(int power) { + return Pow2_mr[power].as_shorter(); +} + +/* +template +const MixedRadix& MixedRadix::negpow2(int power) { + return NegPow2_mr[power].as_shorter(); +} +*/ + +template +const ModArray& ModArray::pow2(int power) { + return Pow2[power].as_shorter(); +} + +template +const ModArray& ModArray::negpow2(int power) { + return NegPow2[power].as_shorter(); +} + +template +const ModArray& ModArray::zero() { + return Zero.as_shorter(); +} + +template +const ModArray& ModArray::one() { + return One.as_shorter(); +} + +template +const MixedRadix& MixedRadix::zero() { + return Zero_mr.as_shorter(); +} + +template +const MixedRadix& MixedRadix::one() { + return One_mr.as_shorter(); +} + +void init_pow2() { + Pow2[0].set_one(); + Pow2_mr[0].set_one(); + for (int i = 1; i < pow2_cnt; i++) { + Pow2[i].set_sum(Pow2[i - 1], Pow2[i - 1]); + Pow2_mr[i].set_sum(Pow2_mr[i - 1], Pow2_mr[i - 1]); + } + for (int i = 0; i < pow2_cnt; i++) { + NegPow2[i] = -Pow2[i]; + NegPow2_mr[i] = -Pow2_mr[i]; + } +} + +int gcdx(int a, int b, int& u, int& v) { + int a1 = 1, a2 = 0, b1 = 0, b2 = 1; + while (b) { + int q = a / b; + int t = a - q * b; + a = b; + b = t; + t = a1 - q * b1; + a1 = b1; + b1 = t; + t = a2 - q * b2; + a2 = b2; + b2 = t; + } + u = a1; + v = a2; + return a; +} + +void init_invm() { + for (int i = 0; i < mod_cnt; i++) { + assert(mod[i] > 0 && mod[i] <= (1 << 30)); + for (int j = 0; j < i; j++) { + assert(gcdx(mod[i], mod[j], invm[i][j], invm[j][i]) == 1); + if (invm[i][j] < 0) { + invm[i][j] += mod[j]; + } + if (invm[j][i] < 0) { + invm[j][i] += mod[i]; + } + } + } +} + +void init() { + init_invm(); + init_pow2(); +} + +} // namespace modint diff --git a/crypto/test/test-bigint.cpp b/crypto/test/test-bigint.cpp new file mode 100644 index 00000000..bf85e2ad --- /dev/null +++ b/crypto/test/test-bigint.cpp @@ -0,0 +1,876 @@ +/* + This file is part of TON Blockchain Library. + + TON Blockchain Library is free software: you can redistribute it and/or modify + it under the terms of the GNU Lesser General Public License as published by + the Free Software Foundation, either version 2 of the License, or + (at your option) any later version. + + TON Blockchain Library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public License + along with TON Blockchain Library. If not, see . +*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/refcnt.hpp" +#include "common/bigint.hpp" +#include "common/refint.h" +#include "modbigint.cpp" + +#include "td/utils/tests.h" + +int mkint_chk_mode = -1, res_chk_mode = 0; +long long iterations = 100000, cur_iteration = -1, debug_iteration = -2; +#define IFDEBUG if (cur_iteration == debug_iteration || debug_iteration == -3) + +using BInt = modint::ModArray<18>; // integers up to 2^537 +using MRInt = modint::MixedRadix<18>; // auxiliary integer representation for printing, comparing etc + +MRInt p2_256, np2_256, p2_63, np2_63; +constexpr long long ll_min = -2 * (1LL << 62), ll_max = ~ll_min; +constexpr double dbl_pow256 = 1.1579208923731619542e77 /* 0x1p256 */; // 2^256 + +std::mt19937_64 Random(666); + +template +bool equal(td::RefInt256 x, T y) { + return !td::cmp(x, y); +} + +bool equal_or_nan(td::RefInt256 x, td::RefInt256 y) { + return equal(x, y) || (!x->is_valid() && !y->fits_bits(257)) || (!y->is_valid() && !x->fits_bits(257)); +} + +#define CHECK_EQ(__x, __y) CHECK(equal(__x, __y)) +#define CHECK_EQ_NAN(__x, __y) CHECK(equal_or_nan(__x, __y)) + +bool mr_in_range(const MRInt& x) { + return x < p2_256 && x >= np2_256; +} + +bool mr_is_small(const MRInt& x) { + return x < p2_63 && x >= np2_63; +} + +bool mr_fits_bits(const MRInt& x, int bits) { + if (bits > 0) { + return x < MRInt::pow2(bits - 1) && x >= MRInt::negpow2(bits - 1); + } else { + return !bits && !x.sgn(); + } +} + +bool mr_ufits_bits(const MRInt& x, int bits) { + return bits >= 0 && x.sgn() >= 0 && x < MRInt::pow2(bits); +} + +struct ShowBin { + unsigned char* data; + ShowBin(unsigned char _data[64]) : data(_data) { + } +}; + +std::ostream& operator<<(std::ostream& os, ShowBin bin) { + int i = 0, s = bin.data[0]; + if (s == 0 || s == 0xff) { + while (i < 64 && bin.data[i] == s) { + i++; + } + } + if (i >= 3) { + os << (s ? "ff..ff" : "00..00"); + } else { + i = 0; + } + constexpr static char hex_digits[] = "0123456789abcdef"; + while (i < 64) { + int t = bin.data[i++]; + os << hex_digits[t >> 4] << hex_digits[t & 15]; + } + return os; +} + +std::ostream& operator<<(std::ostream& os, const td::AnyIntView& x) { + os << '['; + for (int i = 0; i < x.size(); i++) { + os << ' ' << x.digits[i]; + } + os << " ]"; + return os; +} + +template +bool extract_value_any_bool(BInt& val, const td::AnyIntView& x, bool chk_norm = true) { + int n = x.size(); + if (n <= 0 || n > x.max_size() || (!x.digits[n - 1] && n > 1)) { + return false; + } + assert(n == 1 || x.digits[n - 1] != 0); + val.set_zero(); + for (int i = n - 1; i >= 0; --i) { + val.lshift_add(T::word_shift, x.digits[i]); + if (chk_norm && (x.digits[i] < -T::Half || x.digits[i] >= T::Half)) { + return false; // unnormalized + } + } + return true; +} + +template +bool extract_value_bool(BInt& val, const T& x, bool chk_norm = true) { + return extract_value_any_bool(val, x.as_any_int(), chk_norm); +} + +BInt extract_value_any(const td::AnyIntView& x, bool chk_norm = true) { + BInt res; + CHECK(extract_value_any_bool(res, x, chk_norm)); + return res; +} + +template +BInt extract_value(const T& x, bool chk_norm = true) { + return extract_value_any(x.as_any_int(), chk_norm); +} + +template +BInt extract_value_alt(const T& x) { + BInt res; + const int* md = res.mod_array(); + for (int i = 0; i < res.n / 2; i++) { + T copy{x}; + int m1 = md[2 * i], m2 = md[2 * i + 1]; + long long rem = copy.divmod_short((long long)m1 * m2); + res.a[2 * i] = (int)(rem % m1); + res.a[2 * i + 1] = (int)(rem % m2); + } + if (res.n & 1) { + T copy{x}; + res.a[res.n - 1] = (int)copy.divmod_short(md[res.n - 1]); + } + return res; +} + +constexpr int min_spec_int = -0xfd08, max_spec_int = 0xfd07; +// x = sgn*(ord*256+a*16+b) => sgn*((32+a)*2^(ord-2) + b - 8) +// x = -0xfd08 => -2^256 ... x = 0xfd07 => 2^256 - 1 +td::RefInt256 make_special_int(int x, BInt* ptr = nullptr, unsigned char bin[64] = nullptr) { + bool sgn = (x < 0); + if (sgn) { + x = -x; + } + int ord = (x >> 8) - 2, a = 32 + ((x >> 4) & 15), b = (x & 15) - 8; + if (ord < 0) { + a >>= -ord; + ord = 0; + } + if (sgn) { + a = -a; + b = -b; + } + if (ptr) { + ptr->set_int(a); + *ptr <<= ord; + *ptr += b; + } + if (bin) { + int acc = b, r = ord; + for (int i = 63; i >= 0; --i) { + if (r < 8) { + acc += (a << r); + r = 1024; + } + r -= 8; + bin[i] = (unsigned char)(acc & 0xff); + acc >>= 8; + } + } + return (td::make_refint(a) << ord) + b; +} + +int rand_int(int min, int max) { + return min + (int)(Random() % (max - min + 1)); +} + +unsigned randu() { + return (unsigned)(Random() << 16); +} + +bool coin() { + return Random() & (1 << 28); +} + +// returns 0 with probability 1/2, 1 with prob. 1/4, ..., k with prob. 1/2^(k+1) +int randexp(int max = 63, int min = 0) { + return min + __builtin_clzll(Random() | (1ULL << (63 - max + min))); +} + +void bin_add_small(unsigned char bin[64], long long val, int shift = 0) { + val <<= shift & 7; + for (int i = 63 - (shift >> 3); i >= 0 && val; --i) { + val += bin[i]; + bin[i] = (unsigned char)val; + val >>= 8; + } +} + +// adds sgn * (random number less than 2^(ord - ord2)) * 2^ord2 +td::RefInt256 add_random_bits(td::RefInt256 x, BInt& val, unsigned char bin[64], int ord2, int ord, int sgn = 1) { + int t; + do { + t = std::max((ord - 1) & -16, ord2); + int a = sgn * rand_int(0, (1 << (ord - t)) - 1); + // add a << t + val.add_lshift(t, a); + x += td::make_refint(a) << t; + bin_add_small(bin, a, t); + ord = t; + } while (t > ord2); + return x; +} + +// generates a random integer in range -2^256 .. 2^256-1 (and sometimes outside) +// distribution is skewed towards +/- 2^n +/- 2^n +/- smallint, but completely random integers are also generated +td::RefInt256 make_random_int0(BInt& val, unsigned char bin[64]) { + memset(bin, 0, 64); + int ord = rand_int(-257, 257); + if (ord <= 2 && ord >= -2) { + // -2..2 represent themselves + val.set_int(ord); + bin_add_small(bin, ord); + return td::make_refint(ord); + } + int sgn = (ord < 0 ? -1 : 1); + ord = sgn * ord - 1; + int f = std::min(ord, randexp(15)), a = sgn * rand_int(1 << f, (2 << f) - 1); + ord -= f; + // first summand is a << ord + auto res = td::make_refint(a) << ord; + val.set_int(a); + val <<= ord; + bin_add_small(bin, a, ord); + if (!ord) { + // all bits ready + return res; + } + for (int s = 0; s < 2 && ord; s++) { + // decide whether we want an intermediate order (50%), and whether we want randomness above/below that order + int ord2 = (s ? 0 : std::max(0, rand_int(~ord, ord - 1))); + if (!rand_int(0, 4)) { // 20% + // random bits between ord2 and ord + res = add_random_bits(std::move(res), val, bin, ord2, ord, sgn); + } + if (rand_int(0, 4)) { // 80% + // non-zero adjustment + f = randexp(15); + a = rand_int(-(2 << f) + 1, (2 << f) - 1); + ord = std::max(ord2 - f, 0); + // add a << ord + val.add_lshift(ord, a); + res += (td::make_refint(a) << ord); + bin_add_small(bin, a, ord); + } + } + return res; +} + +td::RefInt256 make_random_int(BInt& val, unsigned char bin[64]) { + while (true) { + auto res = make_random_int0(val, bin); + if (res->fits_bits(257)) { + return res; + } + } +} + +void check_one_int_repr(td::RefInt256 x, int mode, int in_range, const BInt* valptr = nullptr, + const unsigned char bin[64] = nullptr) { + CHECK(x.not_null() && (in_range <= -2 || x->is_valid())); + if (!x->is_valid()) { + // not much to check when x is a NaN + unsigned char bytes[64]; + if (valptr) { + // check that the true answer at `valptr` is out of range + CHECK(!mr_in_range(valptr->to_mixed_radix())); + if (mode & 0x200) { + // check BInt binary export + valptr->to_binary(bytes, 64); + if (bin) { + // check that the two true answers match + CHECK(!memcmp(bin, bytes, 64)); + } else { + bin = bytes; + } + } + } + if (bin) { + // check that the true answer in `bin` is out of range + int i = 0, sgn = (bin[0] >= 0x80 ? -1 : 0); + while (i < 32 && bin[i] == (unsigned char)sgn) + ; + CHECK(i < 32); + if (valptr && (mode & 0x100)) { + // check BInt binary export + BInt val2; + val2.from_binary(bin, 64); + CHECK(*valptr == val2); + } + } + return; + } + unsigned char bytes[64]; + CHECK(x->export_bytes(bytes, 64)); + if (bin) { + CHECK(!memcmp(bytes, bin, 64)); + } + BInt val = extract_value(*x); + if (valptr) { + if (val != *valptr) { + std::cerr << "extracted " << val << " from " << x << ' ' << x->as_any_int() << ", expected " << *valptr + << std::endl; + } + CHECK(val == *valptr); + } + if (mode & 1) { + BInt val2 = extract_value_alt(*x); + CHECK(val == val2); + } + if (mode & 2) { + // check binary import + td::BigInt256 y; + y.import_bytes(bytes, 64); + CHECK(y == *x); + } + if (mode & 0x100) { + // check binary import for BInt + BInt val2; + val2.from_binary(bytes, 64); + CHECK(val == val2); + } + // check if small (fits into 64 bits) + long long xval = (long long)val; + bool is_small = (xval != ll_min || val == xval); + CHECK(is_small == x->fits_bits(64)); + if (is_small) { + // special check for small (64-bit) values + CHECK(x->to_long() == xval); + CHECK((long long)__builtin_bswap64(*(long long*)(bytes + 64 - 8)) == xval); + CHECK(in_range); + // check sign + CHECK(x->sgn() == (xval > 0 ? 1 : (xval < 0 ? -1 : 0))); + // check comparison with long long + CHECK(x == xval); + CHECK(!cmp(x, xval)); + if (mode & 4) { + // check constructor from long long + CHECK(!cmp(x, td::make_refint(xval))); + if (xval != ll_min) { + CHECK(x > xval - 1); + CHECK(x > td::make_refint(xval - 1)); + } + if (xval != ll_max) { + CHECK(x < xval + 1); + CHECK(x < td::make_refint(xval + 1)); + } + } + if (!(mode & ~0x107)) { + return; // fast check for small ints in this case + } + } + + MRInt mval(val); // somewhat slow + bool val_in_range = mr_in_range(mval); + CHECK(x->fits_bits(257) == val_in_range); + if (in_range >= 0) { + CHECK((int)val_in_range == in_range); + } + if (mode & 0x200) { + // check binary export for BInt + unsigned char bytes2[64]; + mval.to_binary(bytes2, 64); + CHECK(!memcmp(bytes, bytes2, 64)); + } + // check sign + int sgn = mval.sgn(); + CHECK(x->sgn() == sgn); + CHECK(is_small == mr_is_small(mval)); + if (is_small) { + CHECK((long long)mval == xval); + } + if (mode & 0x10) { + // check decimal export + std::string dec = mval.to_dec_string(); + CHECK(x->to_dec_string() == dec); + // check decimal import + td::BigInt256 y; + int l = y.parse_dec(dec); + CHECK((std::size_t)l == dec.size() && y == *x); + if (mode & 0x1000) { + // check decimal import for BInt + BInt val2; + CHECK(val2.from_dec_string(dec) && val2 == val); + } + } + if (mode & 0x20) { + // check binary bit size + int sz = x->bit_size(); + CHECK(sz >= 0 && sz <= 300); + CHECK(x->fits_bits(sz) && (!sz || !x->fits_bits(sz - 1))); + CHECK(mr_fits_bits(mval, sz) && !mr_fits_bits(mval, sz - 1)); + int usz = x->bit_size(false); + CHECK(sgn >= 0 || usz == 0x7fffffff); + if (sgn >= 0) { + CHECK(x->unsigned_fits_bits(usz) && (!usz || !x->unsigned_fits_bits(usz - 1))); + CHECK(mr_ufits_bits(mval, usz) && !mr_ufits_bits(mval, usz - 1)); + } else { + CHECK(!x->unsigned_fits_bits(256) && !x->unsigned_fits_bits(300)); + } + } +} + +void init_aux() { + np2_256 = p2_256 = MRInt::pow2(256); + np2_256.negate(); + CHECK(np2_256 == MRInt::negpow2(256)); + p2_63 = np2_63 = MRInt::pow2(63); + np2_63.negate(); + CHECK(np2_63 == MRInt::negpow2(63)); +} + +std::vector SpecInt; +BInt SpecIntB[max_spec_int - min_spec_int + 1]; + +void init_check_special_ints() { + std::cerr << "check special ints" << std::endl; + BInt b; + unsigned char binary[64]; + for (int idx = min_spec_int - 512; idx <= max_spec_int + 512; idx++) { + td::RefInt256 x = make_special_int(idx, &b, binary); + check_one_int_repr(x, mkint_chk_mode, idx >= min_spec_int && idx <= max_spec_int, &b, binary); + if (idx >= min_spec_int && idx <= max_spec_int) { + SpecIntB[idx - min_spec_int] = b; + SpecInt.push_back(std::move(x)); + } + } +} + +void check_res(td::RefInt256 y, const BInt& yv) { + check_one_int_repr(std::move(y), res_chk_mode, -2, &yv); +} + +void check_unary_ops_on(td::RefInt256 x, const BInt& xv) { + // NEGATE + BInt yv = -xv; + check_res(-x, yv); + // NOT + check_res(~x, yv -= 1); +} + +void check_unary_ops() { + std::cerr << "check unary ops" << std::endl; + for (int idx = min_spec_int; idx <= max_spec_int; idx++) { + check_unary_ops_on(SpecInt[idx - min_spec_int], SpecIntB[idx - min_spec_int]); + } +} + +void check_pow2_ops(int shift) { + // POW2 + td::RefInt256 r{true}; + r.unique_write().set_pow2(shift); + check_res(r, BInt::pow2(shift)); + // POW2DEC + r.unique_write().set_pow2(shift).add_tiny(-1).normalize(); + check_res(r, BInt::pow2(shift) - 1); + // NEGPOW2 + r.unique_write().set_pow2(shift).negate().normalize(); + check_res(r, -BInt::pow2(shift)); +} + +void check_pow2_ops() { + std::cerr << "check power-2 ops" << std::endl; + for (int i = 0; i <= 256; i++) { + check_pow2_ops(i); + } +} + +void check_shift_ops_on(int shift, td::RefInt256 x, const BInt& xv, const MRInt& mval) { + // LSHIFT + check_res(x << shift, xv << shift); + // FITS + CHECK(x->fits_bits(shift) == mr_fits_bits(mval, shift)); + // UFITS + CHECK(x->unsigned_fits_bits(shift) == mr_ufits_bits(mval, shift)); + // ADDPOW2 / SUBPOW2 + auto y = x; + y.write().add_pow2(shift).normalize(); + check_res(std::move(y), xv + BInt::pow2(shift)); + y = x; + y.write().sub_pow2(shift).normalize(); + check_res(std::move(y), xv - BInt::pow2(shift)); + // RSHIFT, MODPOW2 + for (int round_mode = -1; round_mode <= 1; round_mode++) { + auto r = x, q = td::rshift(x, shift, round_mode); // RSHIFT + CHECK(q.not_null() && q->is_valid()); + r.write().mod_pow2(shift, round_mode).normalize(); // MODPOW2 + CHECK(r.not_null() && r->is_valid()); + if (round_mode < 0) { + CHECK(!cmp(x >> shift, q)); // operator>> should be equivalent to td::rshift + } + BInt qv = extract_value(*q), rv = extract_value(*r); + // check main division equality (q << shift) + r == x + CHECK((qv << shift) + rv == xv); + MRInt rval(rv); + // check remainder range + switch (round_mode) { + case 1: + rval.negate(); // fallthrough + case -1: + CHECK(mr_ufits_bits(rval, shift)); + break; + case 0: + CHECK(mr_fits_bits(rval, shift)); + } + } +} + +void check_shift_ops() { + std::cerr << "check left/right shift ops" << std::endl; + for (int idx = min_spec_int; idx <= max_spec_int; idx++) { + //for (int idx : {-52240, -52239, -52238, -3, -2, -1, 0, 1, 2, 3, 52238, 52239, 52240}) { + const auto& xv = SpecIntB[idx - min_spec_int]; + MRInt mval(xv); + if (!(idx % 1000)) { + std::cerr << "# " << idx << " : " << mval << std::endl; + } + for (int i = 0; i <= 256; i++) { + check_shift_ops_on(i, SpecInt[idx - min_spec_int], xv, mval); + } + } +} + +void check_remainder_range(BInt& rv, const BInt& dv, int rmode = -1) { + if (rmode > 0) { + rv.negate(); + } else if (!rmode) { + rv *= 2; + } + MRInt d(dv), r(rv); + int ds = d.sgn(), rs = r.sgn(); + //std::cerr << "rmode=" << rmode << " ds=" << ds << " rs=" << rs << " d=" << d << " r=" << r << std::endl; + if (!rs) { + return; + } + if (rmode) { + // must have 0 < r < d or 0 > r > d + //if (rs != ds) std::cerr << "iter=" << cur_iteration << " : rmode=" << rmode << " ds=" << ds << " rs=" << rs << " d=" << d << " r=" << r << std::endl; + CHECK(rs == ds); + CHECK(ds * r.cmp(d) < 0); + } else { + // must have -d <= r < d or -d >= r > d + if (rs == -ds) { + r.negate(); + CHECK(ds * r.cmp(d) <= 0); + } else { + CHECK(ds * r.cmp(d) < 0); + } + } +} + +void check_divmod(td::RefInt256 x, const BInt& xv, long long xl, td::RefInt256 y, const BInt& yv, long long yl, + int rmode = -2) { + if (rmode < -1) { + //IFDEBUG std::cerr << " divide " << x << " / " << y << std::endl; + for (rmode = -1; rmode <= 1; rmode++) { + check_divmod(x, xv, xl, y, yv, yl, rmode); + } + return; + } + auto dm = td::divmod(x, y, rmode); + auto q = std::move(dm.first), r = std::move(dm.second); + if (!yl) { + // division by zero + CHECK(q.not_null() && !q->is_valid() && r.not_null() && !r->is_valid()); + return; + } + CHECK(q.not_null() && q->is_valid() && r.not_null() && r->is_valid()); + CHECK_EQ(x, y * q + r); + BInt qv = extract_value(*q), rv = extract_value(*r); + CHECK(xv == yv * qv + rv); + //IFDEBUG std::cerr << " quot=" << q << " rem=" << r << std::endl; + check_remainder_range(rv, yv, rmode); + if (yl != ll_min && rmode == -1) { + // check divmod_short() + auto qq = x; + auto rem = qq.write().divmod_short(yl); + qq.write().normalize(); + CHECK(qq->is_valid()); + CHECK_EQ(qq, q); + CHECK(r == rem); + if (xl != ll_min) { + auto dm = std::lldiv(xl, yl); + if (dm.rem && (dm.rem ^ yl) < 0) { + dm.rem += yl; + dm.quot--; + } + CHECK(q == dm.quot); + CHECK(r == dm.rem); + } + } +} + +void check_binary_ops_on(td::RefInt256 x, const BInt& xv, td::RefInt256 y, const BInt& yv) { + bool x_small = x->fits_bits(62), y_small = y->fits_bits(62); // not 63 + long long xl = x_small ? x->to_long() : ll_min, yl = y_small ? y->to_long() : ll_min; + if (x_small) { + CHECK(x == xl); + } + if (y_small) { + CHECK(y == yl); + } + // ADD, ADDR + auto z = x + y, w = y + x; + CHECK_EQ(z, w); + check_res(z, xv + yv); + // ADDCONST + if (y_small) { + CHECK_EQ(z, x + yl); + } + if (x_small) { + CHECK_EQ(z, y + xl); + } + if (x_small && y_small) { + CHECK_EQ(z, xl + yl); + } + // SUB + z = x - y; + check_res(z, xv - yv); + // SUBCONST + if (y_small) { + CHECK_EQ(z, x - yl); + if (x_small) { + CHECK_EQ(z, xl - yl); + } + } + // SUBR + z = y - x; + check_res(z, yv - xv); + if (x_small) { + CHECK_EQ(z, y - xl); + if (y_small) { + CHECK_EQ(z, yl - xl); + } + } + // CMP + MRInt xmr(xv), ymr(yv); + int cmpv = xmr.cmp(ymr); + CHECK(td::cmp(x, y) == cmpv); + CHECK(td::cmp(y, x) == -cmpv); + if (y_small) { + CHECK(td::cmp(x, yl) == cmpv); + } + if (x_small) { + CHECK(td::cmp(y, xl) == -cmpv); + } + if (x_small && y_small) { + CHECK(cmpv == (xl < yl ? -1 : (xl > yl ? 1 : 0))); + } + // MUL + z = x * y; + BInt zv = xv * yv; + check_res(z, zv); + CHECK_EQ(z, y * x); + // MULCONST + if (y_small) { + CHECK_EQ_NAN(z, x * yl); + } + if (x_small) { + CHECK_EQ_NAN(z, y * xl); + } + if (x_small && y_small && (!yl || std::abs(xl) <= ll_max / std::abs(yl))) { + CHECK_EQ(z, xl * yl); + } + // DIVMOD + if (z->fits_bits(257)) { + int adj = 2 * rand_int(-2, 2) - (int)z->is_odd(); + z += adj; + z >>= 1; + zv += adj; + zv >>= 1; + // z is approximately x * y / 2; divide by y + check_divmod(z, zv, z->fits_bits(62) ? z->to_long() : ll_min, y, yv, yl); + } + check_divmod(x, xv, xl, y, yv, yl); +} + +void finish_check_muldivmod(td::RefInt256 x, const BInt& xv, td::RefInt256 y, const BInt& yv, td::RefInt256 z, + const BInt& zv, td::RefInt256 q, td::RefInt256 r, int rmode) { + static constexpr double eps = 1e-14; + CHECK(q.not_null() && r.not_null()); + //std::cerr << " muldivmod: " << xv << " * " << yv << " / " << zv << " (round " << rmode << ") = " << q << " " << r << std::endl; + if (!zv) { + // division by zero + CHECK(!q->is_valid() && !r->is_valid()); + return; + } + CHECK(r->is_valid()); // remainder always exists if y != 0 + BInt xyv = xv * yv, rv = extract_value(*r); + MRInt xy_mr(xyv), z_mr(zv); + double q0 = (double)xy_mr / (double)z_mr; + if (std::abs(q0) < 1.01 * dbl_pow256) { + // result more or less in range + CHECK(q->is_valid()); + } else if (!q->is_valid()) { + // result out of range, NaN is an acceptable answer + // check that x * y - r is divisible by z + xyv -= rv; + xyv /= zv; + xy_mr = xyv; + double q1 = (double)xy_mr; + CHECK(std::abs(q1 - q0) < eps * std::abs(q0)); + } else { + BInt qv = extract_value(*q); + // must have x * y = z * q + r + CHECK(xv * yv == zv * qv + rv); + } + // check that r is in correct range [0, z) or [0, -z) or [-z/2, z/2) + check_remainder_range(rv, zv, rmode); +} + +void check_muldivmod_on(td::RefInt256 x, const BInt& xv, td::RefInt256 y, const BInt& yv, td::RefInt256 z, + const BInt& zv, int rmode = 2) { + if (rmode < -1) { + for (rmode = -1; rmode <= 1; rmode++) { + check_muldivmod_on(x, xv, y, yv, z, zv, rmode); + } + return; + } else if (rmode > 1) { + rmode = rand_int(-1, 1); + } + // MULDIVMOD + auto qr = td::muldivmod(x, y, z, rmode); + finish_check_muldivmod(std::move(x), xv, std::move(y), yv, std::move(z), zv, std::move(qr.first), + std::move(qr.second), rmode); +} + +void check_mul_rshift_on(td::RefInt256 x, const BInt& xv, td::RefInt256 y, const BInt& yv, int shift, int rmode = 2) { + if (rmode < -1) { + for (rmode = -1; rmode <= 1; rmode++) { + check_mul_rshift_on(x, xv, y, yv, shift, rmode); + } + return; + } else if (rmode > 1) { + rmode = rand_int(-1, 1); + } + // MULRSHIFTMOD + typename td::BigInt256::DoubleInt tmp{0}; + tmp.add_mul(*x, *y); + typename td::BigInt256::DoubleInt tmp2{tmp}; + tmp2.rshift(shift, rmode).normalize(); + tmp.normalize().mod_pow2(shift, rmode).normalize(); + finish_check_muldivmod(std::move(x), xv, std::move(y), yv, {}, BInt::pow2(shift), td::make_refint(tmp2), + td::make_refint(tmp), rmode); +} + +void check_lshift_div_on(td::RefInt256 x, const BInt& xv, td::RefInt256 y, const BInt& yv, int shift, int rmode = 2) { + if (rmode < -1) { + for (rmode = -1; rmode <= 1; rmode++) { + check_lshift_div_on(x, xv, y, yv, shift, rmode); + } + return; + } else if (rmode > 1) { + rmode = rand_int(-1, 1); + } + // LSHIFTDIV + typename td::BigInt256::DoubleInt tmp{*x}, quot; + tmp <<= shift; + tmp.mod_div(*y, quot, rmode); + quot.normalize(); + finish_check_muldivmod(std::move(x), xv, {}, BInt::pow2(shift), std::move(y), yv, td::make_refint(quot), + td::make_refint(tmp), rmode); +} + +void check_random_ops() { + constexpr long long chk_it = 100000; + std::cerr << "check random ops (" << iterations << " iterations)" << std::endl; + BInt xv, yv, zv; + unsigned char xbin[64], ybin[64], zbin[64]; + for (cur_iteration = 0; cur_iteration < iterations; cur_iteration++) { + auto x = make_random_int0(xv, xbin); + if (!(cur_iteration % 10000)) { + std::cerr << "#" << cur_iteration << ": check on " << xv << " = " << ShowBin(xbin) << " = " << x->as_any_int() + << std::endl; + } + check_one_int_repr(x, cur_iteration < chk_it ? -1 : 0, -1, &xv, xbin); + MRInt xmr(xv); + if (!x->fits_bits(257)) { + continue; + } + check_unary_ops_on(x, xv); + for (int j = 0; j < 10; j++) { + int shift = rand_int(0, 256); + //std::cerr << "check shift by " << shift << std::endl; + check_shift_ops_on(shift, x, xv, xmr); + auto y = make_random_int(yv, ybin); + //std::cerr << " y = " << y << " = " << yv << " = " << ShowBin(ybin) << " = " << y->as_any_int() << std::endl; + check_one_int_repr(y, 0, 1, &yv, ybin); + check_binary_ops_on(x, xv, y, yv); + //std::cerr << " *>> " << shift << std::endl; + check_mul_rshift_on(x, xv, y, yv, shift); + //std::cerr << " <as_any_int() << std::endl; + check_muldivmod_on(x, xv, y, yv, z, zv); + } + } +} + +void check_special() { + std::cerr << "run special tests" << std::endl; + check_divmod((td::make_refint(-1) << 207) - 1, BInt::negpow2(207) - 1, ll_min, (td::make_refint(1) << 207) - 1, + BInt::pow2(207) - 1, ll_min); +} + +int main(int argc, char* const argv[]) { + bool do_check_shift_ops = false; + int i; + while ((i = getopt(argc, argv, "hSs:i:")) != -1) { + switch (i) { + case 'S': + do_check_shift_ops = true; + break; + case 's': + Random.seed(atoll(optarg)); + break; + case 'i': + iterations = atoll(optarg); + break; + default: + std::cerr << "unknown option: " << (char)i << std::endl; + // fall through + case 'h': + std::cerr << "usage:\t" << argv[0] << " [-S] [-i] [-s]" << std::endl; + return 2; + } + } + modint::init(); + init_aux(); + init_check_special_ints(); + check_pow2_ops(); + check_unary_ops(); + if (do_check_shift_ops) { + check_shift_ops(); + } + check_special(); + check_random_ops(); + return 0; +} diff --git a/crypto/vm/arithops.cpp b/crypto/vm/arithops.cpp index 823b4408..2831944b 100644 --- a/crypto/vm/arithops.cpp +++ b/crypto/vm/arithops.cpp @@ -387,18 +387,16 @@ int exec_muldivmod(VmState* st, unsigned args, int quiet) { auto z = stack.pop_int(); auto y = stack.pop_int(); auto x = stack.pop_int(); - typename td::BigInt256::DoubleInt tmp{0}; + typename td::BigInt256::DoubleInt tmp{0}, quot; tmp.add_mul(*x, *y); auto q = td::make_refint(); - tmp.mod_div(*z, q.unique_write(), round_mode); + tmp.mod_div(*z, quot, round_mode); switch ((args >> 2) & 3) { case 1: - q.unique_write().normalize(); - stack.push_int_quiet(std::move(q), quiet); + stack.push_int_quiet(td::make_refint(quot.normalize()), quiet); break; case 3: - q.unique_write().normalize(); - stack.push_int_quiet(std::move(q), quiet); + stack.push_int_quiet(td::make_refint(quot.normalize()), quiet); // fallthrough case 2: stack.push_int_quiet(td::make_refint(tmp), quiet); @@ -459,7 +457,7 @@ int exec_mulshrmod(VmState* st, unsigned args, int mode) { } // fallthrough case 2: - tmp.mod_pow2(z, round_mode).normalize(); + tmp.normalize().mod_pow2(z, round_mode).normalize(); stack.push_int_quiet(td::make_refint(tmp), mode & 1); break; } @@ -520,21 +518,17 @@ int exec_shldivmod(VmState* st, unsigned args, int mode) { } auto z = stack.pop_int(); auto x = stack.pop_int(); - typename td::BigInt256::DoubleInt tmp{*x}; + typename td::BigInt256::DoubleInt tmp{*x}, quot; tmp <<= y; switch ((args >> 2) & 3) { case 1: { - auto q = td::make_refint(); - tmp.mod_div(*z, q.unique_write(), round_mode); - q.unique_write().normalize(); - stack.push_int_quiet(std::move(q), mode & 1); + tmp.mod_div(*z, quot, round_mode); + stack.push_int_quiet(td::make_refint(quot.normalize()), mode & 1); break; } case 3: { - auto q = td::make_refint(); - tmp.mod_div(*z, q.unique_write(), round_mode); - q.unique_write().normalize(); - stack.push_int_quiet(std::move(q), mode & 1); + tmp.mod_div(*z, quot, round_mode); + stack.push_int_quiet(td::make_refint(quot.normalize()), mode & 1); stack.push_int_quiet(td::make_refint(tmp), mode & 1); break; } diff --git a/doc/tvm.tex b/doc/tvm.tex index 4f887ac2..71aa7abc 100644 --- a/doc/tvm.tex +++ b/doc/tvm.tex @@ -1531,6 +1531,7 @@ Examples: \item {\tt A934$tt$} --- same as {\tt RSHIFT $tt+1$}: ($x$ -- $\lfloor x\cdot 2^{-tt-1}\rfloor$). \item {\tt A938$tt$} --- {\tt MODPOW2 $tt+1$}: ($x$ -- $x\bmod 2^{tt+1}$). \item {\tt A985} --- {\tt MULDIVR} ($x$ $y$ $z$ -- $q'$), where $q'=\lfloor xy/z+1/2\rfloor$. +\item {\tt A988} --- {\tt MULMOD} ($x$ $y$ $z$ -- $r$), where $r=xy\bmod z=xy-qz$, $q=\lfloor xy/z\rfloor$. This operation always succeeds for $z\neq0$ and returns the correct value of~$r$, even if the intermediate result $xy$ or the quotient $q$ do not fit into 257 bits. \item {\tt A98C} --- {\tt MULDIVMOD} ($x$ $y$ $z$ -- $q$ $r$), where $q:=\lfloor x\cdot y/z\rfloor$, $r:=x\cdot y\bmod z$ (same as {\tt */MOD} in Forth). \item {\tt A9A4} --- {\tt MULRSHIFT} ($x$ $y$ $z$ -- $\lfloor xy\cdot2^{-z}\rfloor$) for $0\leq z\leq 256$. \item {\tt A9A5} --- {\tt MULRSHIFTR} ($x$ $y$ $z$ -- $\lfloor xy\cdot2^{-z}+1/2\rfloor$) for $0\leq z\leq 256$. @@ -1576,10 +1577,12 @@ We opted to make all arithmetic operations ``non-quiet'' (signaling) by default, \begin{itemize} \item {\tt B7xx} --- {\tt QUIET} prefix, transforming any arithmetic operation into its ``quiet'' variant, indicated by prefixing a {\tt Q} to its mnemonic. Such operations return {\tt NaN}s instead of throwing integer overflow exceptions if the results do not fit in {\it Integer\/}s, or if one of their arguments is a {\tt NaN}. Notice that this does not extend to shift amounts and other parameters that must be within a small range (e.g., 0--1023). Also notice that this does not disable type-checking exceptions if a value of a type other than {\it Integer\/} is supplied. \item {\tt B7A0} --- {\tt QADD} ($x$ $y$ -- $x+y$), always works if $x$ and $y$ are {\it Integer\/}s, but returns a {\tt NaN} if the addition cannot be performed. +\item {\tt B7A8} --- {\tt QMUL} ($x$ $y$ -- $xy$), returns the product of $x$ and $y$ if $-2^{256}\leq xy<2^{256}$. Otherwise returns a {\tt NaN}, even if $x=0$ and $y$ is a {\tt NaN}. \item {\tt B7A904} --- {\tt QDIV} ($x$ $y$ -- $\lfloor x/y\rfloor$), returns a {\tt NaN} if $y=0$, or if $y=-1$ and $x=-2^{256}$, or if either of $x$ or $y$ is a {\tt NaN}. +\item {\tt B7A98C} --- {\tt QMULDIVMOD} ($x$ $y$ $z$ -- $q$ $r$), where $q:=\lfloor x\cdot y/z\rfloor$, $r:=x\cdot y\bmod z$. If $z=0$, or if at least one of $x$, $y$, or $z$ is a {\tt NaN}, both $q$ and $r$ are set to {\tt NaN}. Otherwise the correct value of $r$ is always returned, but $q$ is replaced with {\tt NaN} if $q<-2^{256}$ or $q\geq2^{256}$. \item {\tt B7B0} --- {\tt QAND} ($x$ $y$ -- $x\&y$), bitwise ``and'' (similar to {\tt AND}), but returns a {\tt NaN} if either $x$ or $y$ is a {\tt NaN} instead of throwing an integer overflow exception. However, if one of the arguments is zero, and the other is a {\tt NaN}, the result is zero. \item {\tt B7B1} --- {\tt QOR} ($x$ $y$ -- $x\vee y$), bitwise ``or''. If $x=-1$ or $y=-1$, the result is always $-1$, even if the other argument is a {\tt NaN}. -\item {\tt B7B507} --- {\tt QUFITS 8} ($x$ -- $x'$), checks whether $x$ is an unsigned byte (i.e., whether $0\leq x<2^8$), and replaces $x$ with a {\tt NaN} if this is not the case; leaves $x$ intact otherwise (i.e., if $x$ is an unsigned byte). +\item {\tt B7B507} --- {\tt QUFITS 8} ($x$ -- $x'$), checks whether $x$ is an unsigned byte (i.e., whether $0\leq x<2^8$), and replaces $x$ with a {\tt NaN} if this is not the case; leaves $x$ intact otherwise (i.e., if $x$ is an unsigned byte or a {\tt NaN}). \end{itemize} \mysubsection{Comparison primitives} diff --git a/rldp2/CMakeLists.txt b/rldp2/CMakeLists.txt index f93f6e1e..33ae5b67 100644 --- a/rldp2/CMakeLists.txt +++ b/rldp2/CMakeLists.txt @@ -51,7 +51,7 @@ target_include_directories(rldp PUBLIC ${OPENSSL_INCLUDE_DIR} ) if (GSL_FOUND) - target_link_libraries(rldp2 PRIVATE GSL::gsl) + target_link_libraries(rldp2 PRIVATE gsl) target_compile_definitions(rldp2 PRIVATE -DTON_HAVE_GSL=1) endif() target_link_libraries(rldp2 PUBLIC tdutils tdactor fec adnl tl_api)