Code repository
https://gitcode.net/m0_56745306/knn_classifier.git
I. Overview of the KNN Algorithm
This section draws on: https://zhuanlan.zhihu.com/p/45453761

KNN (K-Nearest Neighbor) is one of the most basic and simplest machine-learning algorithms. It can be used for both classification and regression, and it classifies by measuring the distance between feature vectors.

The idea behind KNN is very simple: any n-dimensional input vector corresponds to a point in feature space, and the output is the class label (or predicted value) for that vector.

For an input vector x to be predicted, we find the set of k vectors in the training data closest to x, then predict x's class as the class that occurs most often among those k samples.
As shown in the figure, ω1, ω2 and ω3 denote three classes in the training set. The five points nearest to xu (k = 5) are indicated by the arrows; the majority of those five points belong to ω1, so KNN predicts xu's class as ω1.
II. A KNN Example: Predicting the Iris Species
The iris dataset records three species (setosa, versicolor, virginica) and four attributes for each flower (sepal length, sepal width, petal length, petal width). For example:
4.9,3.0,1.4,0.2,setosa
6.4,3.2,4.5,1.5,versicolor
6.0,2.2,5.0,1.5,virginica
對于給定的測試數(shù)據(jù),我們需要根據(jù)它的四種信息判斷其屬于哪一種鳶尾花。并輸出它的序號:
例如:
#假設(shè)該數(shù)據(jù)為第一條數(shù)據(jù)(對應(yīng)序號為0)
5.7,3.0,4.2,1.2
輸出可以為:
0 setosa
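A record in this format splits cleanly into a numeric feature vector and a species label, which is the same parsing the Mapper performs later. A quick standalone illustration (the helper class is my own, not part of the project):

```java
public class ParseIrisLine {
    // Split one CSV record into its numeric attributes
    static double[] features(String line) {
        String[] parts = line.split(",");
        double[] f = new double[parts.length - 1];
        for (int i = 0; i < f.length; i++) f[i] = Double.parseDouble(parts[i]);
        return f;
    }

    // The last field is the species label
    static String label(String line) {
        String[] parts = line.split(",");
        return parts[parts.length - 1];
    }

    public static void main(String[] args) {
        String line = "4.9,3.0,1.4,0.2,setosa";          // one record from the dataset
        System.out.println(label(line) + " " + features(line)[2]); // prints "setosa 1.4"
    }
}
```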
三、MapReduce+Hadoop實現(xiàn)KNN鳶尾花分類:
1. 實現(xiàn)環(huán)境
- Ubuntu20.04
- Hadoop3.3.5
- Java8
- Maven3.9.1
2. pom.xml
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.example</groupId>
<artifactId>KNN_Classifier</artifactId>
<version>1.0-SNAPSHOT</version>
<packaging>jar</packaging>
<name>KNN_Classifier</name>
<url>http://maven.apache.org</url>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>3.1.0</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
</execution>
</executions>
<configuration>
<filters>
<filter>
<artifact>*:*</artifact>
<excludes>
<exclude>module-info.class</exclude>
<exclude>META-INF/*.SF</exclude>
<exclude>META-INF/*.DSA</exclude>
<exclude>META-INF/*.RSA</exclude>
</excludes>
</filter>
</filters>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<!-- The class containing main(); change as needed -->
<mainClass>KNN_Classifier.KNN_Driver</mainClass>
</transformer>
</transformers>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.1</version>
<configuration>
<source>8</source>
<target>8</target>
<encoding>UTF-8</encoding>
</configuration>
</plugin>
</plugins>
</build>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<java.version>8</java.version>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
<version>3.3.5</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.apache.hadoop/hadoop-hdfs -->
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-hdfs</artifactId>
<version>3.3.5</version>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-mapreduce-client-core</artifactId>
<version>3.3.5</version>
</dependency>
</dependencies>
</project>
3.設(shè)計思路及代碼
1. KNN_Driver類
Diriver
類主要負責(zé)初始化job
的各項屬性,同時將訓(xùn)練數(shù)據(jù)加載到緩存中去,以便于Mapper
讀取。同時為了記錄測試數(shù)據(jù)量,在conf
中設(shè)置testDataNum
用于在map
階段記錄。
package KNN_Classifier;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;
public class KNN_Driver {
public static void main(String[] args) throws Exception {
Configuration conf = new Configuration();
GenericOptionsParser optionParser = new GenericOptionsParser(conf, args);
String[] remainingArgs = optionParser.getRemainingArgs();
if (remainingArgs.length != 3) {
System.err.println("Usage: KNN_Classifier <training dataset> <test dataset> <output>");
System.exit(2);
}
conf.setInt("K",5);//K for the KNN algorithm
conf.setInt("testDataNum",0);//global counter for the number of test records
conf.setInt("dimension",4);//feature-vector dimension
Job job = Job.getInstance(conf, "KNN_Classifier");
job.setJarByClass(KNN_Driver.class);
job.setMapperClass(KNN_Mapper.class);
job.setReducerClass(KNN_Reducer.class);
//將訓(xùn)練數(shù)據(jù)添加到CacheFile中
job.addCacheFile(new Path(remainingArgs[0]).toUri());
FileInputFormat.addInputPath(job, new Path(remainingArgs[1]));
FileOutputFormat.setOutputPath(job, new Path(remainingArgs[2]));
job.waitForCompletion(true);
System.exit(0);
}
}
2. The MyData class
This class encapsulates a single data record and is also used to compute the distance between two vectors.
package KNN_Classifier;
import java.util.Vector;
public class MyData {
//vector dimension
private Integer dimension;
//vector coordinates
private Vector<Double>vec = new Vector<Double>();
//attribute: here, the iris species
private String attr = new String();
public void setAttr(String attr)
{
this.attr = attr;
}
public void setVec(Vector<Double> vec) {
this.dimension = vec.size();
for(Double d : vec)
{
this.vec.add(d);
}
}
public double calDist(MyData data1)//Euclidean distance between two records
{
try{
if(!this.dimension.equals(data1.dimension))//use equals(): != would compare Integer references
throw new Exception("These two vectors have different dimensions.");
}
catch (Exception e)
{
System.out.println(e.getMessage());
System.exit(-1);
}
double dist = 0;
for(int i = 0;i<dimension;i++)
{
dist += Math.pow(this.vec.get(i)-data1.vec.get(i),2);
}
dist = Math.sqrt(dist);
return dist;
}
public String getAttr() {
return attr;
}
}
3. The KNN_Mapper class
- setup: loads the cached training data into a list inside the Mapper, and reads K, the dimension, and other required settings.
- readTrainingData: called by setup; loads the cached training data.
- Gaussian: computes the weight contributed by a neighbour at Euclidean distance x, using

f(x) = a e^{-\frac{(x-b)^2}{2c^2}}

As |x| grows, f(x) shrinks, which captures the effect of distance on the weight: the larger the Euclidean distance, the smaller the weight and the less influence on the label.
In practice the Gaussian's parameters should be chosen by cross-validating on the samples, but for simplicity we set a = 1, b = 0, c = 0.9 here (which gives reasonably good results).
- map: runs the KNN algorithm on each test record. Its pseudocode:

map(key, val):                        # key: byte offset; val: one line of test data
    testData = getTestData(val)       # parse the test record from val
    K_Nearest = Empty                 # the k nearest neighbours; can be kept in a max-heap
    for trainingData in trainingDataSet:    # the linear scan could be replaced by a KD-tree
        dist = CalDist(testData, trainingData)
        if sizeof(K_Nearest) < K:     # heap not yet full: just add
            K_Nearest.add(dist, trainingData.attr)
        else:
            if dist < K_Nearest.maxDist:    # closer than the current farthest: replace it
                replace pair with maxDist by (dist, trainingData.attr)
    calculate weight sum for every attr     # accumulate the weight of each label
    write(idx, max_weight_attr)             # emit the index and the highest-weight label
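The Gaussian weighting step can be checked in isolation. A small sketch using the a = 1, b = 0, c = 0.9 values fixed above (the class name is my own):

```java
public class GaussianWeight {
    // f(x) = a*exp(-(x-b)^2 / (2c^2)) with a = 1, b = 0, c = 0.9, so 2c^2 = 1.62
    static double weight(double dist) {
        double c = 0.9;
        return Math.exp(-(dist * dist) / (2 * c * c));
    }

    public static void main(String[] args) {
        // The weight shrinks monotonically as the Euclidean distance grows
        for (double d : new double[]{0.0, 0.5, 1.0, 2.0})
            System.out.printf("dist=%.1f -> weight=%.4f%n", d, weight(d));
    }
}
```

At distance 0 the weight is exactly 1, and closer neighbours always outweigh farther ones.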
Putting it together, here is the code for KNN_Mapper:
package KNN_Classifier;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.*;
import java.net.URI;
import java.io.BufferedReader;
import java.io.FileReader;
import javafx.util.Pair;
public class KNN_Mapper extends Mapper<LongWritable, Text, LongWritable, Text> {
private Text text = new Text();//output value
private LongWritable longWritable = new LongWritable();//output key
private Integer K;//K value
private Configuration conf;//global configuration
private Integer dimension;//vector dimension
private List<MyData> training_data = new ArrayList<>();
private void readTrainingData(URI uri)//read the cached training data into training_data
{
System.err.println("Read Training Data");
try{
Path patternsPath = new Path(uri.getPath());
String patternsFileName = patternsPath.getName().toString();
BufferedReader reader = new BufferedReader(new FileReader(
patternsFileName));
String line;
Vector<Double>vec = new Vector<>();
while ((line = reader.readLine()) != null) {
String[] strings = line.split(",");
for(int i=0;i<dimension;i++)
{
vec.add(Double.valueOf(strings[i]));
}
MyData myData = new MyData();
myData.setVec(vec);
myData.setAttr(strings[dimension]);
System.out.println(strings[dimension]);
training_data.add(myData);
vec.clear();
}
reader.close();
}
catch (FileNotFoundException e)
{
e.printStackTrace();
}
catch (IOException e)
{
e.printStackTrace();
}
System.err.println("Read End");
}
private double Gaussian(double dist)
{
//a = 1,b=0,c = 0.9,2*c^2 = 1.62
double weight = Math.exp(-Math.pow(dist,2)/(1.62));
return weight;
}
@Override
public void setup(Context context) throws IOException,
InterruptedException {
conf = context.getConfiguration();
this.K = conf.getInt("K",1);
this.dimension = conf.getInt("dimension",1);
URI[] uri = context.getCacheFiles();
readTrainingData(uri[0]);
}
@Override
public void map(LongWritable key, Text value, Context context
) throws IOException, InterruptedException {
String line = value.toString();
try {
String[] strings = line.split(",");
if (strings.length!=dimension) {
throw new Exception("Error line format in the table.");
}
//獲取測試數(shù)據(jù)信息
Vector<Double>vec = new Vector<>();
for(String s:strings)
{
System.err.println("S: "+s);
vec.add(Double.valueOf(s));
}
MyData testData = new MyData();
testData.setVec(vec);
//find the K nearest neighbours among the training samples
//priority queue holding the K nearest; elements are <distance, attribute> pairs
PriorityQueue<Pair<Double,String>>K_nearst = new PriorityQueue<>((a,b)->(a.getKey()>b.getKey())?-1:1);
double dist;
for(MyData data : this.training_data)
{
dist = testData.calDist(data);
if(K_nearst.size()<this.K)
{
K_nearst.add(new Pair<>(dist,data.getAttr()));
}
else{
if(dist < K_nearst.peek().getKey())
{
K_nearst.poll();
K_nearst.add(new Pair<>(dist,data.getAttr()));
}
}
}
//with the K nearest neighbours found, weight each by the Gaussian and
//accumulate the weights per attribute in a hash table
Hashtable<String,Double>weightTable = new Hashtable<>();
while(!K_nearst.isEmpty())
{
double d = K_nearst.peek().getKey();
String attr = K_nearst.peek().getValue();
double w = this.Gaussian(d);
if(!weightTable.containsKey(attr))//containsKey, not contains(): Hashtable.contains checks values
{
weightTable.put(attr,w);
}
else{
weightTable.put(attr,weightTable.get(attr)+w);
}
K_nearst.poll();
}
//選取權(quán)重最大的標(biāo)簽作為輸出
Double max_weight = Double.MIN_VALUE;
String target_attr = "";
for(Iterator<String> itr = weightTable.keySet().iterator();itr.hasNext();){
String hash_key = (String)itr.next();
Double hash_val = weightTable.get(hash_key);
if(hash_val > max_weight)
{
target_attr = hash_key;
max_weight = hash_val;
}
}
text.set(target_attr);
//獲取測試數(shù)據(jù)條數(shù),用作下標(biāo)計數(shù)
longWritable.set(conf.getLong("testDataNum",0));
conf.setLong("testDataNum",longWritable.get()+1);//計數(shù)加一
context.write(longWritable,text);
}
catch (Exception e) {
System.err.println(e.toString());
System.exit(-1);
}
}
}
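A side note on the pair type: `javafx.util.Pair` belongs to JavaFX rather than the core JDK, so it may be absent from a cluster's classpath. The same bounded max-heap can be expressed with only `java.util.AbstractMap.SimpleEntry`; a sketch (class and method names are my own):

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.Comparator;
import java.util.PriorityQueue;

public class KNearestHeap {
    // Max-heap ordered by distance, so the farthest kept neighbour sits on top
    static PriorityQueue<SimpleEntry<Double, String>> newHeap() {
        return new PriorityQueue<>(
                Comparator.comparing(SimpleEntry<Double, String>::getKey).reversed());
    }

    // Keep at most k entries: evict the current farthest when a closer one arrives
    static void offer(PriorityQueue<SimpleEntry<Double, String>> heap,
                      int k, double dist, String label) {
        if (heap.size() < k) {
            heap.add(new SimpleEntry<>(dist, label));
        } else if (dist < heap.peek().getKey()) {
            heap.poll();
            heap.add(new SimpleEntry<>(dist, label));
        }
    }

    public static void main(String[] args) {
        PriorityQueue<SimpleEntry<Double, String>> heap = newHeap();
        for (double d : new double[]{3.0, 1.0, 4.0, 0.5, 2.0})
            offer(heap, 3, d, "attr-" + d);
        System.out.println(heap.peek().getKey()); // prints 2.0 (farthest of the 3 kept)
    }
}
```

Swapping this in would remove the JavaFX dependency without changing the algorithm.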
4. The KNN_Reducer class
Since the Mapper has already done all the work, the key/value pairs arriving at the Reducer are already of the form (index, attr), so the Reducer simply writes them out.
package KNN_Classifier;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;
import java.io.IOException;
public class KNN_Reducer extends Reducer<LongWritable, Text,LongWritable,Text> {
public void reduce(LongWritable key, Iterable<Text> values,
Context context
) throws IOException, InterruptedException {
for(Text val:values)
{
context.write(key,val);
}
}
}