CodeArchive BigIntSqrt

From Ubcacm
Jump to: navigation, search

Square root of a big integer

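This uses the classic digit-by-digit (long-hand) square-root method and was tested against UVa problem 10023. We determine the digits of the root one at a time, from the most significant to the least significant, searching for each digit in turn. Ignoring the per-digit search, the algorithm runs in O(n^2), where n is the number of digits in the input: there are about n/2 root digits, and each trial subtraction touches O(n) digits. In both versions below, the input number is destroyed (overwritten in place).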

The first, simple version ran in 9.416 s on 10023. It finds each digit with a linear search over the d possible digit values (d = 10 here), so its running time is O(n^2 * d).

#include <iostream>
#include <string>
#include <algorithm>
#include <cstdio>

using namespace std;

// Splits an intermediate sum into one decimal digit and a floored carry.
// On entry `carry` holds a value v; on return `carry` holds floor(v / 10)
// and the returned digit d lies in [0, 10), so v == carry * 10 + d.
// Normalizing in int (instead of storing a possibly negative value into a
// plain `char` and testing it afterwards) keeps the caller correct whether
// plain `char` is signed or unsigned on the target platform.
static int split_digit(int& carry) {
  int digit = carry % 10;
  carry /= 10;
  if (digit < 0) { digit += 10; --carry; }
  return digit;
}

// Integer square root of a big decimal number, computed digit by digit.
//   n : number of decimal digits in x.
//   x : the digits, most significant first, stored as values 0..9 (not as
//       '0'..'9'). The array is overwritten during the computation.
//   y : receives the (n+1)/2 digits of floor(sqrt(x)), most significant
//       first, as values 0..9.
// Returns the number of digits written to y (0 when n <= 0).
int sqrt(int n, char *x, char* y) {
  if (n <= 0) return 0;

  // Leading group: 1 digit when n is odd, 2 digits when n is even.
  const bool odd = (n & 1) != 0;
  const int lead = odd ? x[0] : x[0] * 10 + x[1];
  int root = 0;
  while ((root + 1) * (root + 1) <= lead) ++root;
  y[0] = static_cast<char>(root);

  // Write lead - root^2 back as normalized digits (it is at most 2*root <= 18).
  const int rem = lead - root * root;
  if (odd) {
    x[0] = static_cast<char>(rem);
  } else {
    x[0] = static_cast<char>(rem / 10);
    x[1] = static_cast<char>(rem % 10);
  }

  for (int len = odd ? 3 : 4, i = 1; 2 * i < n; ++i, len += 2) {
    // x[0..len-1] is the running remainder R with the next two digits
    // brought down; Y = y[0..i-1] is the root found so far. Choosing digit d
    // costs (20*Y + d)*d, so raising the trial from d-1 to d costs
    // 20*Y + (2d - 1), which is what each pass below subtracts.
    int carry = 0;
    int extra = -1;  // 2d - 1 for the current trial digit d
    int d = 0;
    while (carry >= 0) {  // stop once a subtraction borrows past x[0]
      ++d;
      extra += 2;
      carry = x[len-1] - extra;
      x[len-1] = static_cast<char>(split_digit(carry));
      for (int j = len - 2, k = i - 1; j >= 0; --j, --k) {
        carry += x[j];
        if (k >= 0) carry -= 2 * y[k];  // 20*Y: Y shifted one digit left
        x[j] = static_cast<char>(split_digit(carry));
      }
    }

    // Trial d overshot: keep d-1 and add the last increment back. Every
    // term here is non-negative, so plain % and / are enough.
    --d;
    y[i] = static_cast<char>(d);
    carry = x[len-1] + extra;
    x[len-1] = static_cast<char>(carry % 10);
    carry /= 10;
    for (int j = len - 2, k = i - 1; j >= 0; --j, --k) {
      carry += x[j];
      if (k >= 0) carry += 2 * y[k];
      x[j] = static_cast<char>(carry % 10);
      carry /= 10;
    }
  }
  return (n + 1) / 2;
}

// Reads a case count followed by that many non-negative decimal integers
// (at most 1023 digits each) and prints floor(sqrt(.)) of each, with a
// blank line between consecutive outputs.
int main() {
  int num_cases;
  if (scanf("%d", &num_cases) != 1) return 0;

  char X[1024], Y[1024];
  while (num_cases--) {
    // The field width keeps scanf from writing past the end of X.
    if (scanf("%1023s", X) != 1) break;
    int n = 0;
    while (X[n]) X[n++] -= '0';  // ASCII digits -> values 0..9

    int s = sqrt(n, X, Y);
    for (int i = 0; i < s; ++i) Y[i] += '0';  // back to ASCII
    Y[s] = 0;
    printf("%s\n", Y);
    if (num_cases) printf("\n");
  }
  return 0;
}

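The second version ran in 4.271 s on 10023. It packs three decimal digits into each int (base d = 1000) and finds each digit with a binary search instead of a linear one, so its running time is O(n^2 * lg d), with n now counting base-1000 digits. The binary search only pays off when the base is large: d should be at least 1000 for this to be effective.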

#include <iostream>
#include <string>
#include <algorithm>
#include <cstdio>

using namespace std;

// Radix of one limb: this version stores three decimal digits per int.
const int DIGIT_BASE = 1000;
// Table of squares; main() fills entries 0..1000 with i*i before use.
int SQR[1024];

// In place: x[0..nx-1] -= 2*DIGIT_BASE*m*Y + ex, where Y is the limb number
// y[0..ny-1]. Both numbers are stored most significant limb first.
// The DIGIT_BASE factor shifts Y one limb to the left, so y[ny-1] lines up
// with x[nx-2] and x[nx-1] only receives ex.
// Each limb written back to x is normalized into [0, DIGIT_BASE). The return
// value is the floored carry out of x[0], which is negative exactly when the
// true difference is negative.
inline int sub_w_extra(int nx, int* x, int ny, int* y, int m, int ex) {
  const int twice_m = 2 * m;
  const int shift = nx - ny - 1;  // x index minus the aligned y index
  int acc = -ex;
  for (int pos = nx - 1; pos >= 0; --pos) {
    acc += x[pos];
    const int k = pos - shift;
    if (k >= 0 && k < ny) acc -= twice_m * y[k];
    int limb = acc % DIGIT_BASE;
    acc /= DIGIT_BASE;
    if (limb < 0) {  // % truncates toward zero; convert to a floored split
      limb += DIGIT_BASE;
      --acc;
    }
    x[pos] = limb;
  }
  return acc;
}

// In place: x[0..nx-1] += 2*DIGIT_BASE*m*Y + ex, where Y is the limb number
// y[0..ny-1]. Both numbers are stored most significant limb first, with
// y[ny-1] aligned to x[nx-2]. Returns the carry out of x[0].
// Negative limbs are not fixed up, so m, ex and the limbs of x and y are
// expected to be non-negative.
inline int add_w_extra(int nx, int* x, int ny, int* y, int m, int ex) {
  const int twice_m = 2 * m;
  const int shift = nx - ny - 1;  // x index minus the aligned y index
  int acc = ex;
  int pos = nx - 1;
  while (pos >= 0) {
    acc += x[pos];
    const int k = pos - shift;
    if (k >= 0 && k < ny) acc += twice_m * y[k];
    x[pos] = acc % DIGIT_BASE;
    acc /= DIGIT_BASE;
    --pos;
  }
  return acc;
}

// Integer square root of a big number stored as base-DIGIT_BASE limbs.
//   n : number of limbs in x.
//   x : limbs, most significant first, each in [0, DIGIT_BASE). The array is
//       overwritten during the computation.
//   y : receives the (n+1)/2 limbs of floor(sqrt(x)), most significant first.
// Returns the number of limbs written to y (0 when n <= 0).
// Each result limb is found by binary search over [0, DIGIT_BASE), using
// sub_w_extra / add_w_extra to move the trial value and to undo failed moves.
// Squares are computed directly, so this no longer depends on the global SQR
// table having been filled by main() beforehand.
int sqrt(int n, int* x, int* y) {
  if (n <= 0) return 0;

  // Leading group: one limb when n is odd, two limbs when n is even.
  const bool odd = (n & 1) != 0;
  const int lead = odd ? x[0] : x[0] * DIGIT_BASE + x[1];
  int root = 0;
  while ((root + 1) * (root + 1) <= lead) ++root;
  y[0] = root;

  // Write lead - root^2 back as normalized limbs (it is at most 2*root).
  const int rem = lead - root * root;
  if (odd) {
    x[0] = rem;
  } else {
    x[0] = rem / DIGIT_BASE;
    x[1] = rem % DIGIT_BASE;
  }

  for (int len = odd ? 3 : 4, i = 1; 2 * i < n; ++i, len += 2) {
    // Invariant: x[0..len-1] holds R - (2*DIGIT_BASE*Y + lo)*lo, where R is
    // the remainder with the next two limbs brought down and Y = y[0..i-1].
    // lo is the largest limb known to fit; hi is an exclusive upper bound.
    int lo = 0;
    int hi = DIGIT_BASE;
    while (hi - lo > 1) {
      const int mid = lo + (hi - lo) / 2;
      // Moving the trial from lo to mid costs
      // 2*DIGIT_BASE*Y*(mid - lo) + (mid^2 - lo^2).
      const int step = mid - lo;
      const int sq_delta = mid * mid - lo * lo;
      if (sub_w_extra(len, x, i, y, step, sq_delta) >= 0) {
        lo = mid;
      } else {
        add_w_extra(len, x, i, y, step, sq_delta);  // undo the overshoot
        hi = mid;
      }
    }
    y[i] = lo;
  }
  return (n + 1) / 2;
}

// Reads a case count followed by that many non-negative decimal integers
// (at most 1023 digits each), packs each into base-1000 limbs and prints
// floor(sqrt(.)), with a blank line between consecutive outputs.
int main() {
  int num_cases;
  int X[1024], Y[1024];
  char input[1024];
  // Table of squares read by sqrt(); the bound 1000 must match DIGIT_BASE.
  for (int i = 0; i <= 1000; ++i) SQR[i] = i * i;
  if (scanf("%d", &num_cases) != 1) return 0;
  while (num_cases--) {
    // The field width keeps scanf from writing past the end of input.
    if (scanf("%1023s", input) != 1) break;
    int n = 0;
    while (input[n]) input[n++] -= '0';  // ASCII digits -> values 0..9

    // Pack three decimal digits per limb, most significant limb first;
    // the leading limb takes the n % 3 leftover digits.
    int idx = 0;
    if (n % 3 == 1) X[idx++] = input[0];
    else if (n % 3 == 2) X[idx++] = input[0] * 10 + input[1];
    for (int i = n % 3; i < n; i += 3)
      X[idx++] = input[i] * 100 + input[i+1] * 10 + input[i+2];

    int s = sqrt(idx, X, Y);
    // Every limb after the first is zero-padded to three digits.
    printf("%d", Y[0]);
    for (int i = 1; i < s; ++i) printf("%03d", Y[i]);
    printf("\n");
    if (num_cases) printf("\n");
  }
  return 0;
}

-- Main.MatthewChan - 02 Feb 2006