Question: which of these classifier implementations is better, and which places need to be modified?
ID3:
/*
* ID3.java
* Copyright (C) 2005 Liangxiao Jiang
*
*/package my_weka;import weka.classifiers.*;
import weka.core.*;
import java.util.*;/**
* Implement ID3 classifier.
*/
public class ID3 extends Classifier {

  /** The node's successors, one subtree per value of the split attribute (null at leaves). */
  private ID3[] m_Successors;

  /** Attribute used for splitting at this node; null marks a leaf. */
  private Attribute m_Attribute;

  /** Instances stored at a leaf node, used to estimate its class distribution. */
  private Instances m_Instances;

  /**
   * Builds the ID3 decision tree classifier.
   *
   * @param data the training data (nominal attributes assumed)
   * @exception Exception if the classifier can't be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {
    makeTree(data);
  }

  /**
   * Recursively builds the ID3 tree using the information-gain measure.
   *
   * @param data the training data reaching this node
   * @exception Exception if the decision tree can't be built successfully
   */
  private void makeTree(Instances data) throws Exception {
    // No instances reached this node: make an empty leaf. computeDistribution
    // later yields a uniform (Laplace-smoothed) distribution for it.
    if (data.numInstances() == 0) {
      m_Attribute = null;
      m_Instances = new Instances(data);
      return;
    }
    // Find the non-class attribute with maximum information gain.
    double maxValue = 0;
    int maxIndex = -1;
    for (int i = 0; i < data.numAttributes(); i++) {
      if (i == data.classIndex()) continue; // never split on the class attribute
      double impurityReduce = computeEntropyReduce(data, data.attribute(i));
      if (impurityReduce > maxValue) {
        maxValue = impurityReduce;
        maxIndex = i;
      }
    }
    // Zero gain: make a leaf; otherwise split the data and recurse on each branch.
    if (Utils.eq(maxValue, 0)) {
      m_Attribute = null;
      m_Instances = new Instances(data);
    } else {
      m_Attribute = data.attribute(maxIndex);
      Instances[] splitData = splitData(data, m_Attribute);
      m_Successors = new ID3[m_Attribute.numValues()];
      for (int j = 0; j < m_Attribute.numValues(); j++) {
        m_Successors[j] = new ID3();
        m_Successors[j].makeTree(splitData[j]);
      }
    }
  }

  /**
   * Splits a dataset according to the values of a nominal attribute.
   *
   * @param data the data which is to be split
   * @param att the attribute to be used for splitting
   * @return one (possibly empty) dataset per attribute value
   */
  private Instances[] splitData(Instances data, Attribute att) {
    int numAttValues = att.numValues();
    Instances[] splitData = new Instances[numAttValues];
    for (int j = 0; j < numAttValues; j++) {
      splitData[j] = new Instances(data, 0); // empty set with the same header
    }
    int numInstances = data.numInstances();
    for (int i = 0; i < numInstances; i++) {
      int attVal = (int) data.instance(i).value(att);
      splitData[attVal].add(data.instance(i));
    }
    return splitData;
  }

  /**
   * Computes the information gain for an attribute: the entropy of the data
   * minus the weighted entropy of each subset produced by splitting on it.
   *
   * @param data the data for which the information gain is to be computed
   * @param att the attribute
   * @return the information gain for the given attribute and data
   * @exception Exception if the gain can't be computed
   */
  private double computeEntropyReduce(Instances data, Attribute att) throws Exception {
    double entropyReduce = computeEntropy(data);
    Instances[] splitData = splitData(data, att);
    for (int j = 0; j < att.numValues(); j++) {
      if (splitData[j].numInstances() > 0) {
        entropyReduce -= ((double) splitData[j].numInstances() / (double) data.numInstances())
            * computeEntropy(splitData[j]);
      }
    }
    return entropyReduce;
  }

  /**
   * Computes the entropy of a dataset's class distribution:
   * -sum_c p_c * log2(p_c) with p_c = count_c / n.
   *
   * @param data the data for which entropy is to be computed
   * @return the entropy of the data's class distribution
   * @exception Exception if the entropy can't be computed
   */
  private double computeEntropy(Instances data) throws Exception {
    int numClasses = data.numClasses();
    int numInstances = data.numInstances();
    double[] classCounts = new double[numClasses];
    for (int i = 0; i < numInstances; i++) {
      classCounts[(int) data.instance(i).classValue()]++;
    }
    double entropy = 0;
    for (int i = 0; i < numClasses; i++) {
      // log2 returns 0 for empty classes, so 0 * log(0) terms contribute nothing.
      entropy -= (classCounts[i] / numInstances) * log2(classCounts[i], numInstances);
    }
    return entropy;
  }

  /**
   * Computes the base-2 logarithm of the fraction x/y.
   *
   * @param x the numerator of the fraction
   * @param y the denominator of the fraction
   * @return log2(x/y), or 0 when either value is (near) zero
   */
  private double log2(double x, double y) {
    if (x < 1e-6 || y < 1e-6) {
      return 0.0;
    }
    return Math.log(x / y) / Math.log(2);
  }

  /**
   * Computes the class distribution for an instance using the decision tree.
   *
   * @param instance the instance for which the distribution is to be computed
   * @return the class distribution for the given instance
   * @exception Exception if the distribution can't be computed
   */
  public double[] distributionForInstance(Instance instance) throws Exception {
    if (m_Attribute == null) {
      // Leaf: estimate the distribution from the instances stored here.
      return computeDistribution(m_Instances);
    }
    // Internal node: follow the branch matching the instance's attribute value.
    return m_Successors[(int) instance.value(m_Attribute)].distributionForInstance(instance);
  }

  /**
   * Computes the Laplace-smoothed class distribution of a leaf's instances.
   *
   * @param data the instances stored at the leaf
   * @return the class probability distribution
   * @exception Exception if the distribution can't be computed
   */
  private double[] computeDistribution(Instances data) throws Exception {
    int numClasses = data.numClasses();
    double[] probs = new double[numClasses];
    double[] classCounts = new double[numClasses];
    int numInstances = data.numInstances();
    for (int i = 0; i < numInstances; i++) {
      classCounts[(int) data.instance(i).classValue()]++;
    }
    for (int i = 0; i < numClasses; i++) {
      // Laplace correction avoids zero probabilities (and handles empty leaves).
      probs[i] = (classCounts[i] + 1.0) / (numInstances + numClasses);
    }
    Utils.normalize(probs);
    return probs;
  }

  /**
   * Main method.
   *
   * @param args the options for the classifier
   */
  public static void main(String[] args) {
    try {
      System.out.println(Evaluation.evaluateModel(new ID3(), args));
    } catch (Exception e) {
      // Keep the stack trace (consistent with NB.main); the original printed
      // only the message, which hides the failure location.
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
}
ID3 (second copy — identical to the class above):
/*
* ID3.java
* Copyright (C) 2005 Liangxiao Jiang
*
*/package my_weka;import weka.classifiers.*;
import weka.core.*;
import java.util.*;/**
* Implement ID3 classifier.
*/
// NOTE(review): this class is a byte-for-byte duplicate of the ID3 class defined
// earlier in this file. A Java source file cannot contain two top-level classes
// with the same name — one of the two copies should be removed.
public class ID3 extends Classifier { /** The node's successors, one subtree per value of the split attribute (null at leaves). */
private ID3[] m_Successors; /** Attribute used for splitting at this node; null marks a leaf. */
private Attribute m_Attribute; /** Instances stored at a leaf node, used to estimate its class distribution. */
private Instances m_Instances; /**
 * Builds ID3 decision tree classifier.
 *
 * @param data the training data (nominal attributes assumed)
 * @exception Exception if classifier can't be built successfully
 */
public void buildClassifier(Instances data) throws Exception { //Build ID3 tree
makeTree(data);
} /**
 * Recursively builds the ID3 tree using the information gain measure.
 *
 * @param data the training data reaching this node
 * @exception Exception if decision tree can't be built successfully
 */
private void makeTree(Instances data) throws Exception { // Check if no instances have reached this node
if (data.numInstances() == 0) {
// Empty leaf: computeDistribution later yields a uniform
// (Laplace-smoothed) distribution for it.
m_Attribute = null;
m_Instances=new Instances(data);
return;
}
// Compute attribute with maximum information gain.
double impurityReduce=0;
double maxValue=0;
int maxIndex=-1;
for (int i = 0; i < data.numAttributes(); i++){
if(i == data.classIndex()) continue; // never split on the class attribute
impurityReduce=computeEntropyReduce(data,data.attribute(i));
if (impurityReduce>maxValue){
maxValue=impurityReduce;
maxIndex=i;
}
}
// Make leaf if information gain is zero, otherwise create successors.
if(Utils.eq(maxValue, 0)){
m_Attribute = null;
m_Instances=new Instances(data);
return;
}
else {
m_Attribute = data.attribute(maxIndex);
Instances[] splitData = splitData(data, m_Attribute);
m_Successors = new ID3[m_Attribute.numValues()];
for (int j = 0; j < m_Attribute.numValues(); j++) {
m_Successors[j] = new ID3();
m_Successors[j].makeTree(splitData[j]); // recurse on each branch
}
}
} /**
 * Splits a dataset according to the values of a nominal attribute.
 *
 * @param data the data which is to be split
 * @param att the attribute to be used for splitting
 * @return one (possibly empty) dataset per attribute value
 */
private Instances[] splitData(Instances data, Attribute att) { int numAttValues=att.numValues();
Instances[] splitData = new Instances[numAttValues];
for (int j = 0; j < numAttValues; j++) {
splitData[j] = new Instances(data,0); // empty set with the same header
}
int numInstances=data.numInstances();
for (int i=0;i<numInstances;i++){
int attVal=(int)data.instance(i).value(att);
splitData[attVal].add(data.instance(i));
}
return splitData;
} /**
 * Computes information gain for an attribute: the entropy of the data minus
 * the weighted entropy of each subset produced by splitting on the attribute.
 *
 * @param data the data for which info gain is to be computed
 * @param att the attribute
 * @return the information gain for the given attribute and data
 */
private double computeEntropyReduce(Instances data, Attribute att) throws Exception { double entropyReduce = computeEntropy(data);
Instances[] splitData = splitData(data, att);
for (int j = 0; j < att.numValues(); j++) {
if (splitData[j].numInstances() > 0) {
entropyReduce-=((double)splitData[j].numInstances()/(double) data.numInstances())*computeEntropy(splitData[j]);
}
}
return entropyReduce;
} /**
 * Computes the entropy of a dataset: -sum_c p_c * log2(p_c) over class values.
 *
 * @param data the data for which entropy is to be computed
 * @return the entropy of the data's class distribution
 */
private double computeEntropy(Instances data) throws Exception { int numClasses=data.numClasses();
int numInstances=data.numInstances();
double[] classCounts=new double[numClasses];
for (int i=0;i<numInstances;i++){
int classVal=(int)data.instance(i).classValue();
classCounts[classVal]++;
}
// Convert the counts in place into relative frequencies p_c = count_c / n.
for (int i=0;i<numClasses;i++){
classCounts[i]/=numInstances;
}
double Entropy=0;
for (int i=0;i<numClasses;i++){
// log2(p, 1) is log2(p); it returns 0 for p near 0, so empty classes add nothing.
Entropy-=classCounts[i]*log2(classCounts[i],1);
}
return Entropy;
} /**
 * Computes the base-2 logarithm of the fraction x/y.
 *
 * @param x the numerator of the fraction
 * @param y the denominator of the fraction
 * @return log2(x/y), or 0 when either value is (near) zero
 */
private double log2(double x,double y){ if(x<1e-6||y<1e-6)
return 0.0;
else
return Math.log(x/y)/Math.log(2);
} /**
 * Computes class distribution for instance using decision tree: walks down the
 * branch matching the instance's attribute values until a leaf is reached.
 *
 * @param instance the instance for which distribution is to be computed
 * @return the class distribution for the given instance
 */
public double[] distributionForInstance(Instance instance) throws Exception{ if (m_Attribute == null) {
return computeDistribution(m_Instances);
}
else {
return m_Successors[(int) instance.value(m_Attribute)].distributionForInstance(instance);
}
} /**
 * Computes the Laplace-smoothed class distribution of a leaf's instances.
 *
 * @param data the instances stored at the leaf
 * @return the class probability distribution
 * @exception Exception if the distribution can't be computed
 */
private double[] computeDistribution(Instances data) throws Exception { int numClasses=data.numClasses();
double[] probs=new double[numClasses];
double[] classCounts=new double[numClasses];
int numInstances=data.numInstances();
for (int i=0;i<numInstances;i++){
int classVal=(int)data.instance(i).classValue();
classCounts[classVal] ++;
}
for (int i=0;i<numClasses;i++){
// Laplace correction avoids zero probabilities (and handles empty leaves).
probs[i]=(classCounts[i]+1.0)/(numInstances+numClasses);
}
Utils.normalize(probs);
return probs;
} /**
 * Main method.
 *
 * @param args the options for the classifier
 */
public static void main(String[] args) { try {
System.out.println(Evaluation.evaluateModel(new ID3(), args));
} catch (Exception e) {
// NOTE(review): only the message is printed, so the stack trace is lost.
System.err.println(e.getMessage());
}
}}
* KNN.java
* Copyright 2004 Liangxiao Jiang
**/package my_weka;import weka.classifiers.*;
import weka.core.*;
import java.util.*;/**
* Implement an KNN classifier.
*/
public class KNN extends Classifier { /** The training instances used for classification. */
private Instances m_Train; /** The number of neighbours to use for classification (hard-coded to 10 in buildClassifier). */
private int m_kNN; /**
 * Builds KNN classifier. KNN is lazy: training only stores a copy of the data.
 *
 * @param data the training data
 * @exception Exception if classifier can't be built successfully
 */
public void buildClassifier(Instances data) throws Exception { //initial data
m_Train=new Instances(data);
m_kNN=10;
} /**
 * Computes the class distribution for a test instance from its k nearest neighbours.
 *
 * @param instance the instance for which distribution is to be computed
 * @return the class distribution for the given instance
 */
public double[] distributionForInstance(Instance instance) throws Exception { NeighborList neighborlist = findNeighbors(instance,m_kNN);
return computeDistribution(neighborInstances(neighborlist),instance);
} /**
 * Build the list of nearest k neighbors to the given test instance.
 *
 * @param instance the instance to search for neighbours
 * @param kNN the number of neighbours to keep
 * @return a list of neighbors sorted by increasing distance
 */
private NeighborList findNeighbors(Instance instance,int kNN) { double distance;
NeighborList neighborlist = new NeighborList(kNN);
for(int i=0; i<m_Train.numInstances();i++){
Instance trainInstance=m_Train.instance(i);
distance=distance(instance,trainInstance);
// The first kNN training instances are always inserted (i<kNN); after that an
// instance is inserted only if it is at least as close as the current farthest
// neighbour. insertSorted itself trims the list back down to kNN entries.
if (neighborlist.isEmpty()||i<kNN||distance<=neighborlist.m_Last.m_Distance) {
neighborlist.insertSorted(distance,trainInstance);
}
}
return neighborlist; } /**
 * Copies the neighbours held in the list into an Instances object.
 *
 * @param neighborlist the list of nearest neighboring instances
 * @return the neighbours as a dataset with the same header as the training data
 */
private Instances neighborInstances (NeighborList neighborlist) throws Exception { Instances neighborInsts = new Instances(m_Train, neighborlist.currentLength());
if (!neighborlist.isEmpty()) {
NeighborNode current = neighborlist.m_First;
while (current != null) {
neighborInsts.add(current.m_Instance);
current = current.m_Next;
}
}
return neighborInsts; } /**
 * Calculates the distance between two instances: the number of non-class
 * attributes on which their (integer-cast) values disagree — a Hamming
 * distance over nominal attributes.
 *
 * @param first the first instance
 * @param second the second instance
 * @return the distance between the two given instances
 */
private double distance(Instance first, Instance second) { double distance = 0;
for(int i=0;i<m_Train.numAttributes();i++){
if(i == m_Train.classIndex()) continue;
if((int)first.value(i)!=(int)second.value(i)){
distance+=1;
}
}
return distance;
} /**
 * Computes the Laplace-smoothed class distribution of the given neighbours.
 *
 * @param data the neighbouring instances
 * @param instance the test instance (currently unused by this method)
 * @exception Exception if the distribution can't be computed
 */
private double[] computeDistribution(Instances data,Instance instance) throws Exception { int numClasses=data.numClasses();
double[] probs=new double[numClasses];
double[] classCounts=new double[numClasses];
int numInstances=data.numInstances();
for (int i=0;i<numInstances;i++){
int classVal=(int)data.instance(i).classValue();
classCounts[classVal] ++;
}
for (int i=0;i<numClasses;i++){
// Laplace correction avoids zero probabilities for unseen classes.
probs[i]=(classCounts[i]+1.0)/(numInstances+numClasses);
}
Utils.normalize(probs);
return probs;
} /**
 * Main method.
 *
 * @param args the options for the classifier
 */
public static void main(String[] args) { try {
System.out.println(Evaluation.evaluateModel(new KNN(), args));
} catch (Exception e) {
// NOTE(review): only the message is printed, so the stack trace is lost.
System.err.println(e.getMessage());
}
} /*
 * A node of the singly-linked neighbour list: one neighbouring instance
 * together with its distance from the test instance.
 */
private class NeighborNode { /** The neighbor instance */
private Instance m_Instance; /** The distance from the current instance to this neighbor */
private double m_Distance; /** A link to the next neighbor instance */
private NeighborNode m_Next; /**
 * Create a new neighbor node.
 *
 * @param distance the distance to the neighbor
 * @param instance the neighbor instance
 * @param next the next neighbor node
 */
public NeighborNode(double distance, Instance instance, NeighborNode next){
m_Distance = distance;
m_Instance = instance;
m_Next = next;
} /**
 * Create a new neighbor node that doesn't link to any other nodes.
 *
 * @param distance the distance to the neighbor
 * @param instance the neighbor instance
 */
public NeighborNode(double distance, Instance instance) { this(distance, instance, null);
}
} /*
 * A linked list, kept sorted by increasing distance, that stores the nearest
 * k neighbours to an instance.
 */
private class NeighborList { /** The first (nearest) node in the list */
private NeighborNode m_First; /** The last (farthest) node in the list */
private NeighborNode m_Last; /** The number of nodes to attempt to maintain in the list */
private int m_Length = 1; /**
 * Creates the neighborlist with a desired length
 *
 * @param length the length of list to attempt to maintain
 */
public NeighborList(int length) { m_Length = length;
} /**
 * Gets whether the list is empty.
 *
 * @return true if so
 */
public boolean isEmpty() { return (m_First == null);
} /**
 * Gets the current length of the list by walking it (O(n)).
 *
 * @return the current length of the list
 */
public int currentLength() { int i = 0;
NeighborNode current = m_First;
while (current != null) {
i++;
current = current.m_Next;
}
return i;
} /**
 * Inserts an instance neighbor into the list, maintaining the list sorted by
 * distance, then trims the list back to (at least) m_Length entries.
 *
 * @param distance the distance to the instance
 * @param instance the neighboring instance
 */
public void insertSorted(double distance, Instance instance) { if (isEmpty()) {
m_First = m_Last = new NeighborNode(distance, instance);
} else {
NeighborNode current = m_First;
if (distance < m_First.m_Distance) {// Insert at head
m_First = new NeighborNode(distance, instance, m_First);
}
else { // Insert further down the list
for( ;(current.m_Next != null) &&
(current.m_Next.m_Distance < distance);
current = current.m_Next);
current.m_Next = new NeighborNode(distance, instance,
current.m_Next);
if (current.equals(m_Last)) {
m_Last = current.m_Next;
}
} // Trip down the list until we've got k list elements (or more if the distance to the last elements is the same).
int valcount = 0;
for(current = m_First; current.m_Next != null;
current = current.m_Next) {
valcount++;
// Ties at the cut-off distance are kept: the list is only cut once the
// next node is strictly farther, so it may exceed m_Length entries.
if ((valcount >= m_Length) && (current.m_Distance != current.m_Next.m_Distance)) {
m_Last = current;
current.m_Next = null;
break;
}
}
}
} /**
 * Prunes the list to contain the k nearest neighbors. If there are multiple
 * neighbors at the k'th distance, all will be kept.
 * NOTE(review): this method is never called anywhere in this file.
 *
 * @param k the number of neighbors to keep in the list.
 */
public void pruneToK(int k) { if (isEmpty()) {
return;
}
if (k < 1) {
k = 1;
}
int currentK = 0;
double currentDist = m_First.m_Distance;
NeighborNode current = m_First;
for(; current.m_Next != null; current = current.m_Next) {
currentK++;
currentDist = current.m_Distance;
if ((currentK >= k) && (currentDist != current.m_Next.m_Distance)) {
m_Last = current;
current.m_Next = null;
break;
}
}
} }}
* NB.java
* Copyright 2005 Liangxiao Jiang
*/package my_weka;import weka.core.*;
import weka.classifiers.*;/**
* Implement the NB classifier.
*/
public class NB extends Classifier {

  /** m_ClassAttCounts[c][v]: how often class c co-occurs with flattened attribute value v. */
  private double[][] m_ClassAttCounts;
  /** m_ClassCounts[c]: how often class value c occurs in the training data. */
  private double[] m_ClassCounts;
  /** Number of distinct values of each attribute. */
  private int[] m_NumAttValues;
  /** Start index of each attribute in the flattened attribute-value space (-1 for the class). */
  private int[] m_StartAttIndex;
  /** Total number of attribute values over all non-class attributes. */
  private int m_TotalAttValues;
  /** The number of classes in the dataset. */
  private int m_NumClasses;
  /** The number of attributes, including the class. */
  private int m_NumAttributes;
  /** The number of training instances. */
  private int m_NumInstances;
  /** The index of the class attribute. */
  private int m_ClassIndex;

  /**
   * Generates the classifier by counting class frequencies and
   * class/attribute-value co-occurrence frequencies over the training data.
   *
   * @param instances set of instances serving as training data
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {
    m_NumClasses = instances.numClasses();
    m_ClassIndex = instances.classIndex();
    m_NumAttributes = instances.numAttributes();
    m_NumInstances = instances.numInstances();
    m_TotalAttValues = 0;
    m_StartAttIndex = new int[m_NumAttributes];
    m_NumAttValues = new int[m_NumAttributes];
    // Lay the values of all non-class attributes out in one flat index space:
    // attribute i's value v lives at m_StartAttIndex[i] + v.
    for (int i = 0; i < m_NumAttributes; i++) {
      if (i != m_ClassIndex) {
        m_StartAttIndex[i] = m_TotalAttValues;
        m_NumAttValues[i] = instances.attribute(i).numValues();
        m_TotalAttValues += m_NumAttValues[i];
      } else {
        m_StartAttIndex[i] = -1; // the class attribute has no slot here
        m_NumAttValues[i] = m_NumClasses;
      }
    }
    m_ClassCounts = new double[m_NumClasses];
    m_ClassAttCounts = new double[m_NumClasses][m_TotalAttValues];
    // Single pass over the data to accumulate all counts. (The original code
    // allocated a temporary int[] per training instance just to hold indices
    // that were used once; the flat index is computed directly instead.)
    for (int k = 0; k < m_NumInstances; k++) {
      Instance inst = instances.instance(k);
      int classVal = (int) inst.classValue();
      m_ClassCounts[classVal]++;
      for (int i = 0; i < m_NumAttributes; i++) {
        if (i == m_ClassIndex) continue;
        m_ClassAttCounts[classVal][m_StartAttIndex[i] + (int) inst.value(i)]++;
      }
    }
  }

  /**
   * Calculates the class membership probabilities for the given test instance
   * as P(c) * prod_i P(a_i | c), using Laplace-smoothed relative frequencies.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if there is a problem generating the prediction
   */
  public double[] distributionForInstance(Instance instance) throws Exception {
    double[] probs = new double[m_NumClasses];
    // Map the instance's attribute values into the flat index space (-1 marks the class).
    int[] attIndex = new int[m_NumAttributes];
    for (int att = 0; att < m_NumAttributes; att++) {
      if (att == m_ClassIndex) {
        attIndex[att] = -1;
      } else {
        attIndex[att] = m_StartAttIndex[att] + (int) instance.value(att);
      }
    }
    for (int classVal = 0; classVal < m_NumClasses; classVal++) {
      // Laplace-corrected prior P(c).
      probs[classVal] = (m_ClassCounts[classVal] + 1.0) / (m_NumInstances + m_NumClasses);
      for (int att = 0; att < m_NumAttributes; att++) {
        if (attIndex[att] == -1) continue;
        // Laplace-corrected conditional P(a_att | c).
        probs[classVal] *= (m_ClassAttCounts[classVal][attIndex[att]] + 1.0)
            / (m_ClassCounts[classVal] + m_NumAttValues[att]);
      }
    }
    Utils.normalize(probs);
    return probs;
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String[] argv) {
    // Removed the dead "new NB();" the original created and discarded here.
    try {
      System.out.println(Evaluation.evaluateModel(new NB(), argv));
    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
}