/**
IBM SOFTWARE DISCLAIMER 

Java array package (draft 0.2). Copyright (1998), International Business 
Machines Corporation.

Permission to use, copy, modify and distribute this software for any
noncommercial purpose and without fee is hereby granted, provided that
this copyright and permission notice appear on all copies of the
software. The name of the IBM Corporation may not be used in any
advertising or publicity pertaining to the use of the software. IBM
makes no warranty or representations about the suitability of the
software for any purpose.  It is provided "AS IS" without any express
or implied warranty, including the implied warranties of
merchantability, fitness for a particular purpose and non-infringement.
IBM shall not be liable for any direct, indirect, special or
consequential damages resulting from the loss of use, data or projects,
whether in an action of contract or tort, arising out of or in
connection with the use or performance of this software.
*/

import array.*;

/**
 * Small matrix-multiply example and benchmark built on the
 * <code>array</code> package. It provides an inner product, a
 * matrix-vector product and a matrix-matrix product, plus a command-line
 * driver that times the matrix-matrix product and checks its result.
 */
public class MatMul {

    /** Usage line printed for any malformed or unrecognized command line. */
    private static final String USAGE =
        "usage: MatMul [-m <m>] [-n <n>] [-p <p>] [-print]";

    /**
     * Computes the inner product of two vectors, a and b, represented
     * as one-dimensional arrays.
     *
     * @param a first vector
     * @param b second vector; must have the same number of elements as a
     * @return the sum over i of a(i)*b(i); 0.0 when the common size is 0
     * @throws NonconformingArrayException if a and b differ in size
     */
    public static double inner(doubleArray1D a, doubleArray1D b)
        throws NonconformingArrayException {

        int n = a.size();
        if (n != b.size()) throw new NonconformingArrayException();

        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            sum += a.get(i) * b.get(i);
        }

        return sum;
    }

    /**
     * Computes the product of a matrix A, represented as a two-dimensional
     * array, and a vector b, represented as a one-dimensional array.
     *
     * @param A matrix with m rows and n columns
     * @param b vector; its size must equal the number of columns of A
     * @return a newly allocated vector c of size m, where c(i) is the
     *         inner product of row i of A and b
     * @throws NonconformingArrayException if the number of columns of A
     *         differs from the size of b
     * @throws InvalidArrayShapeException propagated from the array package
     * @throws InvalidArrayAxisException propagated from the array package
     * @throws InvalidRangeException propagated from the array package
     */
    public static doubleArray1D matvec(doubleArray2D A, doubleArray1D b)
        throws NonconformingArrayException, InvalidArrayShapeException,
               InvalidArrayAxisException, InvalidRangeException {

        int m = A.size(0);
        int n = A.size(1);
        if (n != b.size()) throw new NonconformingArrayException();

        doubleArray1D c = new doubleArray1D(m);

        /*
         * Element i of c is the inner product of row i of A by b.
         * The Range is built inside the loop on purpose: when m is 0
         * no Range is ever constructed.
         */
        for (int i = 0; i < m; i++) {
            double dot = inner(A.section(i, new Range(0, n - 1)), b);
            c.set(i, dot);
        }

        return c;
    }

    /**
     * Computes the product of two matrices, A and B, represented as
     * two-dimensional arrays. If A is mxn and B is nxp then the
     * resulting matrix C is mxp.
     *
     * @param A left operand, m rows by n columns
     * @param B right operand; its number of rows must equal the number
     *          of columns of A
     * @return a newly allocated m-by-p matrix holding A * B
     * @throws NonconformingArrayException if the number of columns of A
     *         differs from the number of rows of B
     * @throws InvalidArrayShapeException propagated from the array package
     * @throws InvalidArrayAxisException propagated from the array package
     * @throws InvalidRangeException propagated from the array package
     */
    public static doubleArray2D matmul(doubleArray2D A, doubleArray2D B)
        throws NonconformingArrayException, InvalidArrayShapeException,
               InvalidArrayAxisException, InvalidRangeException {

        int m = A.size(0);
        int n = A.size(1);
        if (n != B.size(0)) throw new NonconformingArrayException();
        int p = B.size(1);

        doubleArray2D C = new doubleArray2D(m, p);

        /*
         * Column j of C is the product of matrix A by column j of B:
         *      C(0:m-1,j) = A * B(0:n-1,j)
         * This relies on C.section(...) referring to C's own storage so
         * that assign() writes through into C.
         */
        for (int j = 0; j < p; j++) {
            doubleArray1D column = matvec(A, B.section(new Range(0, n - 1), j));
            C.section(new Range(0, m - 1), j).assign(column);
        }

        return C;
    }

    /**
     * Command-line driver: builds A (mxn) and B (nxp), multiplies them,
     * optionally prints the product, verifies it and reports throughput.
     *
     * Options: -m, -n and -p each take an integer dimension (default 4);
     * -print prints the resulting matrix. Any other argument (including
     * -h and -help), a missing or non-integer value, or a negative
     * dimension prints the usage line and returns.
     *
     * @param args command-line arguments
     * @throws InvalidArrayShapeException propagated from the array package
     * @throws InvalidRangeException propagated from the array package
     * @throws NonconformingArrayException propagated from matmul
     * @throws InvalidArrayAxisException propagated from the array package
     */
    public static void main(String[] args)
        throws InvalidArrayShapeException, InvalidRangeException,
               NonconformingArrayException, InvalidArrayAxisException {

        int m = 4;
        int n = 4;
        int p = 4;
        boolean print = false;
        double eps = 1e-09;

        /*
         * Parse command line arguments. Every malformed input ends in
         * the usage message; explicit checks are used instead of
         * throwing a generic Exception to get there.
         */
        for (int arg = 0; arg < args.length; arg++) {
            String option = args[arg];

            if ("-print".equals(option)) {
                print = true;
                continue;
            }

            boolean isM = "-m".equals(option);
            boolean isN = "-n".equals(option);
            boolean isP = "-p".equals(option);

            /*
             * Anything that is not a dimension option (this covers -h,
             * -help and unknown options), or a dimension option with no
             * value after it, is a usage error.
             */
            if (!(isM || isN || isP) || arg + 1 >= args.length) {
                System.out.println(USAGE);
                return;
            }

            int value;
            try {
                value = Integer.parseInt(args[++arg]);
            } catch (NumberFormatException e) {
                System.out.println(USAGE);
                return;
            }

            /* A negative size cannot describe a matrix. */
            if (value < 0) {
                System.out.println(USAGE);
                return;
            }

            if (isM) {
                m = value;
            } else if (isN) {
                n = value;
            } else {
                p = value;
            }
        }

        System.out.println("matrix multiply: "+m+" x "+n+" x "+p);

        /*
         * Initialize the matrices A(mxn) and B(nxp) so that
         * row i of each has all elements equal to i.
         */
        doubleArray2D A = new doubleArray2D(m, n);
        doubleArray2D B = new doubleArray2D(n, p);
        for (int i = 0; i < m; i++) {
            A.section(i, new Range(0, n - 1)).assign(i);
        }
        for (int i = 0; i < n; i++) {
            B.section(i, new Range(0, p - 1)).assign(i);
        }

        /*
         * Perform the matrix multiplication of A and B, yielding C,
         * and measure the wall-clock time it takes in milliseconds.
         */
        long start = System.currentTimeMillis();
        doubleArray2D C = matmul(A, B);
        long elapsedMillis = System.currentTimeMillis() - start;

        /*
         * If print flag is set, print the result, one row per line
         * with elements separated by a single space.
         */
        if (print) {
            for (int i = 0; i < m; i++) {
                if (i > 0) System.out.println();
                if (p > 0) System.out.print(C.get(i, 0));
                for (int j = 1; j < p; j++) {
                    System.out.print(" "+C.get(i, j));
                }
            }
            System.out.println();
        }

        /*
         * Check validity of results. Element C(i,j) should be
         * equal to i*n*(n-1)/2 (within an epsilon), because
         * C(i,j) = sum over k of A(i,k)*B(k,j) = sum over k of i*k.
         */
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < p; j++) {
                double correct = ((double)i*(double)n*(double)(n-1)/2.0);
                double error = Math.abs(C.get(i, j) - correct);
                if (error > eps) {
                    System.out.println("error in result");
                    return;
                }
            }
        }

        /*
         * Report throughput: 2*m*n*p floating-point operations over the
         * elapsed milliseconds, scaled to Mflops. A run that finishes
         * within one millisecond tick yields an elapsed time of 0;
         * dividing by it would print Infinity or NaN, so that case is
         * reported explicitly instead.
         */
        if (elapsedMillis > 0) {
            double mflops = 1.0e-3*(2.0*m*n*p)/elapsedMillis;
            System.out.println("mflops: "+mflops);
        } else {
            System.out.println("mflops: n/a (elapsed time < 1 ms)");
        }
    }
}
