#include <na.h>

/* ZILDE(X): make X an empty (0-element, 0-column) integer (TYPN) array;
   the expansion evaluates to mkobj's return code. */
#define ZILDE(X)        mkobj((X), TYPN, 0, 0)

/* double constants used by the numeric routines below */
#define D_0             0.0
#define D_1             1.0
/* absolute pivot tolerance: fx_domino fails when |pivot| < D_EPS.
   NOTE(review): not scaled by the matrix magnitude, so a well-conditioned
   matrix with very small entries may be rejected as singular. */
#define D_EPS           1.0e-12

/* Determinant state shared between calls: written by fx_domino while it
   eliminates, read back by fx_det.  File-scope, so it reflects whichever
   call last wrote it. */
DOUBLE  zdet = D_1;

int fx_mult(dst, lo, hi)    /* matrix multiply */
MATRIX  dst;            /* receives lo * hi as a matrix product */
MATRIX  lo;             /* left operand: a matrix or vector */
MATRIX  hi;             /* right operand: a matrix or vector */
{
/*
 * Matrix product dst = lo * hi for integer or double operands.
 *
 * lo has nx elements in rows of cx columns (nr = nx/cx rows).  hi has
 * ny elements in rows of cy columns and must have cx rows
 * (cx * cy == ny).  Special case: when hi is a single row
 * (cy == ny) whose length equals cx, it is used as a column vector and
 * the result is an nr-element vector (nr columns).
 *
 * Mixed integer/double operands: the integer operand is replaced by a
 * double copy (via fx_cvd + sa_swap) before multiplying.
 *
 * Returns the ZILDE result for an empty operand, OK on success, ERROR on
 * a shape/type mismatch or allocation failure.
 */
BYTE   *s  = lo->ADDR;
int     dx = MTYPE(s);
int     nx = MSIZE(s);
int     cx = MCOLS(s);
BYTE   *t  = hi->ADDR;
int     dy = MTYPE(t);
int     ny = MSIZE(t);
int     cy = MCOLS(t);
int     qvec;           /* 1 when hi is used as a column vector */
int     j, k, nr, sum;
MATRIX  p;
int    *di, *a, *b, *xs, *ys;
DOUBLE *zd, *ax, *by, *zx, *zy;
DOUBLE  zsum;

       /* null absorbs    X + NULL -> NULL */
if (nx == 0 || ny == 0) return ZILDE(dst);
if (cx == 0 || cy == 0) return ERROR;

/* Only a single-row hi whose length matches lo's row length is a vector
   operand.  A single-row hi of any other length (e.g. column lo times
   row hi, cx == 1) is an ordinary matrix of nr*cy results, so it must
   not get the nr-element vector shape. */
qvec = (cy == ny && cx == cy);
if (qvec) cy = 1;                       /* column vector, stride 1 */
else if (cx * cy != ny) return ERROR;   /* hi must have cx rows */
nr = nx / cx;

switch (dx | dy) {
    case TYPT:
    default: return ERROR;
    case TYPN:          /* integer */
    case TYPD:  break;  /* double  */
    case TYPN|TYPD:     /* mixed: convert the integer side to double */
        p = (dx == TYPN) ? lo : hi;
        if (fx_cvd(dst, p) || sa_swap(p, dst)) return ERROR;
        dx = dy = TYPD;
        }

if ((qvec) ? mkobj(dst, dx, nr, nr)
           : mkobj(dst, dx, nr * cy, cy)) return ERROR;

/* row a of lo (cx elements, stride 1) dotted with column b of hi
   (cx elements, stride cy) */
if (dx == TYPN) {
    for (di = INT_PTR(dst), a = INT_PTR(lo); 0 < nr--; a += cx)
        for (b = INT_PTR(hi), j = cy; 0 < j--; b++) {
            for (xs = a, ys = b, sum = 0, k = cx; 0 < k--; ys += cy)
                sum += (*xs++) * *ys;
            *di++ = sum;
            }
    }
else if (dx == TYPD) {
    for (zd = DBL_PTR(dst), ax = DBL_PTR(lo); 0 < nr--; ax += cx)
        for (by = DBL_PTR(hi), j = cy; 0 < j--; by++) {
            for (zx = ax, zy = by, zsum = D_0, k = cx; 0 < k--; zy += cy)
                zsum += (*zx++) * *zy;
            *zd++ = zsum;
            }
    }
return OK;
}

fx_domino(dst, src)     /* inverse matrix */
MATRIX  dst;            /* destroys dst, src */
MATRIX  src;            /* this is destroyed */
{
DOUBLE  fabs();
BYTE   *s = src->ADDR;
int     dx = MTYPE(s);
int     n  = MSIZE(s);
int     nc = MCOLS(s);
int     ncc = nc * sizeof(DOUBLE);
int     nc2 = 2 * nc;
BYTE   *t;
DOUBLE  smax, xtmp, xpivot, xlambda;
DOUBLE *si, *di, *xij;
int     i, j, k, p;
int    *imap, *ix, *qmap;

if (n != nc * nc) return ERROR;
if (mkobj(dst, 'D', 2 * nc * nc, 2 * nc)) return ERROR;

if (dx == TYPN) {
    k = 1;      /* convert integer */
    ix = INT_PTR(src);
    }
else if (dx == TYPD) {
    k = 0;      /* copy elements */
    si = DBL_PTR(src);
    }
else  return ERROR;

    /* copy src to left of dst */

for (di = DBL_PTR(dst), i = nc; 0 < i--; di += 2*nc)
    for (xij = di, j = nc; 0 < j--;)
        *xij++ = (k) ? (DOUBLE) (*ix++) : *si++;

for (di = DBL_PTR(dst) + nc, i = nc; 0 < i--; di += (2*nc+1)) *di = D_1;

if (mkobj(src, 'N', 2*nc, nc)) return ERROR;
imap = INT_PTR(src);
qmap = imap + nc;

for (i = 0, zdet = D_1; i < nc; i++) {
    xij = DBL_PTR(dst) + i;
    p   = nc;

    /* find pivot in each column */

    for (smax = D_0, j = 0; j < nc; j++, xij += nc2)
        if (0 == imap[j] && (smax < (xtmp = fabs(*xij)))) {
            smax = xtmp;
            p = j;
            }

    /* perform row operations on other rows */
    if (p == nc) return ERROR;
    imap[p] = 1+i;
    xij = DBL_PTR(dst) + i;
    xpivot = xij[p * nc2];
    if (fabs(xpivot) < D_EPS) return ERROR;
    zdet *= xpivot;
    for (j = 0; j < nc; j++, xij += nc2) {
        if (j == p) continue;
        xlambda = (*xij)/xpivot;
        si = DBL_PTR(dst) + p * nc2;
        di = DBL_PTR(dst) + j * nc2;
        for (k = 0; k < nc2; k++, si++, di++)
            *di -= xlambda * *si;
        }
    }

    /*      qmap = inverse(imap)                    */

for (j = 0; j < nc; j++) qmap[imap[j]-1] = j;

    /*      divide by pivot values                  */

for (i = 0, xij = DBL_PTR(dst); i < nc; i++, xij++) {
    p = qmap[i];
    xpivot = xij[p * nc2];
    for (si = DBL_PTR(dst) + p * nc2, j = nc2; 0 < j--;)
        *(si++) /= xpivot;
    }

        /*      permute rows to left */
for (i = 0; i < nc; i++) {
    t = TEXT_PTR(dst) + 2 * ncc * i;
    s = TEXT_PTR(dst) + ncc + 2 * ncc * qmap[i];
    memcpy(t, s, ncc);
    }

/*      create and copy result */

return sa_swap(dst, src) || fx_take(dst, src, nc, nc);
}

fx_det(dst)     /* return determinant */
MATRIX  dst;
{
if (mkobj(dst, 'D', 1, 1)) return ERROR;
*(DBL_PTR(dst)) = zdet;
return OK;
}

/* fx_outer helper: one integer element x op y.  Unrecognised codes add.
   A zero divisor yields 0 for '%' and '#' so no element is left unset. */
static int
fx_outer_iop(int fc, int x, int y)
{
int     r;

switch (fc) {
    case '-':   return x - y;               /* minus */
    case '*':   return x * y;               /* multiply */
    case '<':   return x < y;               /* less */
    case '>':   return x > y;               /* greater */
    case '=':   return x == y;              /* equal */
    case 'f':   return (x > y) ? y : x;     /* floor or minimum */
    case 'c':   return (x < y) ? y : x;     /* ceiling or maximum */
    case '^':   return x & y;               /* and (bitwise) */
    case '%':   return (y != 0) ? x / y : 0;    /* divide */
    case '#':                               /* modulus */
        if (y == 0) return 0;
        r = x % y;                          /* C: sign follows x */
        if (r != 0 && (r < 0) != (y < 0))   /* floored: sign follows y */
            r += y;
        return r;
    case '+':                               /* add */
    default:    return x + y;
    }
}

/* fx_outer helper: one double element x op y for arithmetic codes.
   Unrecognised codes add; a zero divisor for '%' yields 0, as for ints. */
static DOUBLE
fx_outer_dop(int fc, DOUBLE x, DOUBLE y)
{
switch (fc) {
    case '-':   return x - y;               /* subtract */
    case '*':   return x * y;               /* multiply */
    case '%':   return (y != D_0) ? x / y : D_0;    /* divide */
    case 'f':   return (x > y) ? y : x;     /* floor or minimum */
    case 'c':   return (x < y) ? y : x;     /* ceiling or maximum */
    case '+':                               /* add */
    default:    return x + y;
    }
}

/* fx_outer helper: integer 0/1 result of comparing two doubles. */
static int
fx_outer_cmp(int fc, DOUBLE x, DOUBLE y)
{
switch (fc) {
    case '>':   return x > y;               /* greater */
    case '=':   return x == y;              /* equal */
    case '<':                               /* less */
    default:    return x < y;
    }
}

int fx_outer(dst, lo, hi)   /* matrix construction */
MATRIX  dst;            /* dst[ij] = lo[i] op hi[j]; op code read from dst */
MATRIX  lo;             /* a matrix or vector */
MATRIX  hi;
{
/*
 * Outer product: dst[i*ny + j] = lo[i] op hi[j], giving nx*ny elements
 * in rows of ny columns.  The operation code fc is read from
 * MCOLS(dst->ADDR) on entry.
 *
 *   text:    '=' only (integer 0/1 result)
 *   integer: + - * < > = f(min) c(max) ^(bitwise and) %(divide) #(modulus)
 *   double:  + - * < > = f c %      ('^' and '#' return ERROR)
 *
 * Comparisons produce integers; unrecognised codes add.  Modulus is
 * floored (result takes the divisor's sign).  A zero divisor gives 0.
 * With mixed integer/double operands the integer operand is replaced by
 * a double copy (via fx_cvd + sa_swap).
 *
 * Returns the ZILDE result for an empty operand, OK on success, ERROR
 * on unsupported types/operations or allocation failure.
 */
int     fc = MCOLS(dst->ADDR);  /* function code */
BYTE   *s  = lo->ADDR;
int     dx = MTYPE(s);
int     nx = MSIZE(s);
BYTE   *t  = hi->ADDR;
int     dy = MTYPE(t);
int     ny = MSIZE(t);
int     dz, j;
MATRIX  p;
int    *zi, *xi, *yi;
DOUBLE *zd, *xs, *ys;

       /* null absorbs    X + NULL -> NULL */
if (nx == 0 || ny == 0) return ZILDE(dst);

switch (dx | dy) {
    default: return ERROR;
    case TYPT:  /* character compare */
        if (fc != '='  ||
            mkobj(dst, 'N', nx * ny, ny)) return ERROR;
        for (zi = INT_PTR(dst), s = TEXT_PTR(lo); 0 < nx--; s++)
            for (t = TEXT_PTR(hi), j = ny; 0 < j--; t++, zi++)
                *zi = (*s == *t);
        return OK;
    case TYPN:          /* integer */
    case TYPD:  break;  /* double  */
    case TYPN|TYPD:     /* mixed: convert the integer side to double */
        p = (dx == TYPN) ? lo : hi;
        if (fx_cvd(dst, p) || sa_swap(p, dst)) return ERROR;
        dx = dy = TYPD;
        }

/* bitwise and / integer modulus have no double form here: refuse them
   instead of silently falling through to addition */
if (dx == TYPD && (fc == '^' || fc == '#')) return ERROR;

/*      type of result. Logical operations give integers */
dz = (fc == '=' || fc == '>' || fc == '<') ? TYPN : dx;

if (mkobj(dst, dz, nx * ny, ny)) return ERROR;

if (dx == TYPN) {
    for (zi = INT_PTR(dst), xi = INT_PTR(lo); 0 < nx--; xi++)
        for (yi = INT_PTR(hi), j = ny; 0 < j--; yi++)
            *zi++ = fx_outer_iop(fc, *xi, *yi);
    }
else if (dz == TYPN) {      /* doubles compared: integer result */
    for (zi = INT_PTR(dst), xs = DBL_PTR(lo); 0 < nx--; xs++)
        for (ys = DBL_PTR(hi), j = ny; 0 < j--; ys++)
            *zi++ = fx_outer_cmp(fc, *xs, *ys);
    }
else {                      /* doubles, double result */
    for (zd = DBL_PTR(dst), xs = DBL_PTR(lo); 0 < nx--; xs++)
        for (ys = DBL_PTR(hi), j = ny; 0 < j--; ys++)
            *zd++ = fx_outer_dop(fc, *xs, *ys);
    }
return OK;
}
