Как да изградите и внедрите ML модели в Java.

Машинното обучение (ML) даде значителни обещания в различни области както в академичните среди, така и в индустрията. Ден след ден ML увеличава ангажимента си в изчерпателен списък от приложения като изображения, разпознаване на реч, разпознаване на шаблони, оптимизиране, обработка на естествен език и препоръки и много други.

Програмирането на компютри за учене от опита в крайна сметка трябва да елиминира необходимостта от голяма част от това подробно усилие за програмиране. - Артър Самюел 1959 г.

Машинното обучение може да бъде разделено на четири основни техники: регресия, класификация, групиране и обучение с подсилване. Тези техники решават проблеми с различно естество главно в две форми: контролирано и неконтролирано обучение. Наблюдаваното обучение изисква данните да бъдат етикетирани и подготвени преди обучението на модела. Неконтролираното обучение е полезно за обработка на немаркирани данни или данни с неизвестни характеристики. Тази статия не описва концепциите на ML и не описва в дълбочина термините, използвани в тази област. Ако сте съвсем нов, моля, погледнете моята предишна статия, за да започнете вашето обучение в ML.

Библиотеки за машинно обучение в Java

Ето списък с добре познати библиотеки в Java за ML. Ще ги опишем един по един и ще дадем примери от реалния свят, използвайки някои от тези рамки.

До всяка библиотека следните икони ще показват основните категории алгоритми, предоставени във всяка рамка по подразбиране.

Weka

Weka е библиотека с отворен код, разработена от Университета на Уайкато в Нова Зеландия. Weka е написана на Java и е много добре позната за машинно обучение с общо предназначение. Weka предоставя файлов формат с данни, наречен ARFF. ARFF е разделен на две части: заглавка и действителните данни. Заглавието описва атрибутите и техните типове данни.

Apache Mahout

Apache Mahout предоставя мащабируема библиотека за машинно обучение. Mahout използва парадигмата MapReduce и може да се използва за класификация, съвместно филтриране и клъстериране. Mahout използва Apache Hadoop за обработка на множество паралелни задачи. В допълнение към класификацията и клъстерирането, Mahout предоставя алгоритми за препоръки, като например съвместно филтриране, улеснявайки скалируемостта за бързо изграждане на вашия модел.

Deeplearning4j

Deeplearning4j е друга библиотека на Java, фокусирана върху дълбокото обучение. Това е една страхотна библиотека с отворен код за дълбоко обучение за Java. Освен това е написан на Scala и Java и може да бъде интегриран с Hadoop и Spark, осигурявайки високи възможности за обработка. Текущата версия е в бета версия, но идва с отлична документация и примери за бърз старт (щракнете тук).

Чук

Mallet означава Machine Learning for Language Toolkit. Това е един от малкото специализирани инструменти за обработка на естествен език. Той предоставя възможности за моделиране на теми, класификация на документи, групиране и извличане на информация. С Mallet можем да ML модели за обработка на текстови документи.

Spark MLlib

Spark е много добре известно, че ускорява скалируемостта и цялостната производителност при обработката на огромно количество данни. Spark MLlib също има алгоритми с висока мощност за работа на spark и включени в работните процеси на Hadoop.

Рамката за машинно обучение Encog

Encog е Java и C# рамка за ML. Envog има библиотеки за изграждане на SVM, NN, Bayesian Networks, HMM и генетични алгоритми. Encog започна като изследователски проект и получи почти хиляда цитирания в Google Scholar.

MOA

Massive Online Analysis (MOA) предоставя алгоритми за класификация, регресия, групиране и препоръки. Той също така предоставя библиотеки за откриване на отклонения и откриване на отклонение. Той е предназначен за обработка в реално време на поток от произведени данни.

Weka Пример:

Ще използваме малък набор от данни за диабет. Първо ще заредим данните с помощта на Weka:

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class Main {

    public static void main(String[] args) throws Exception {
        // Specifying the datasource
        DataSource dataSource = new DataSource("data.arff");
        // Loading the dataset
        Instances dataInstances = dataSource.getDataSet();
        // Displaying the number of instances
        log.info("The number of loaded instances is: " + dataInstances.numInstances());

        log.info("data:" + dataInstances.toString());
    }
}

Има 768 екземпляра в набора от данни. Нека да видим как да получим броя на атрибутите (характеристиките), който трябва да бъде 9.

log.info("The number of attributes in the dataset: " + dataInstances.numAttributes());

Преди да изградим модел, искаме да определим коя колона е целевата колона и да видим колко класа има в тази колона:

// Identifying the label index
dataInstances.setClassIndex(dataInstances.numAttributes() - 1);
// Getting the number of 
log.info("The number of classes: " + dataInstances.numClasses());

След зареждане на набора от данни и идентифициране на нашия целеви атрибут, сега е времето за изграждане на модела. Нека направим прост дървовиден класификатор, J48.

// Creating a decision tree classifier
J48 treeClassifier = new J48();
treeClassifier.setOptions(new String[] { "-U" });
treeClassifier.buildClassifier(dataInstances);

В трите реда по-горе посочихме опция за обозначаване на необрязано дърво и предоставихме екземплярите на данни за обучение на модела. Ако отпечатаме дървовидната структура на генерирания модел след обучение, можем да проследим как моделът вътрешно е изградил своите правила:

plas <= 127
|   mass <= 26.4
|   |   preg <= 7: tested_negative (117.0/1.0)
|   |   preg > 7
|   |   |   mass <= 0: tested_positive (2.0)
|   |   |   mass > 0: tested_negative (13.0)
|   mass > 26.4
|   |   age <= 28: tested_negative (180.0/22.0)
|   |   age > 28
|   |   |   plas <= 99: tested_negative (55.0/10.0)
|   |   |   plas > 99
|   |   |   |   pedi <= 0.56: tested_negative (84.0/34.0)
|   |   |   |   pedi > 0.56
|   |   |   |   |   preg <= 6
|   |   |   |   |   |   age <= 30: tested_positive (4.0)
|   |   |   |   |   |   age > 30
|   |   |   |   |   |   |   age <= 34: tested_negative (7.0/1.0)
|   |   |   |   |   |   |   age > 34
|   |   |   |   |   |   |   |   mass <= 33.1: tested_positive (6.0)
|   |   |   |   |   |   |   |   mass > 33.1: tested_negative (4.0/1.0)
|   |   |   |   |   preg > 6: tested_positive (13.0)
plas > 127
|   mass <= 29.9
|   |   plas <= 145: tested_negative (41.0/6.0)
|   |   plas > 145
|   |   |   age <= 25: tested_negative (4.0)
|   |   |   age > 25
|   |   |   |   age <= 61
|   |   |   |   |   mass <= 27.1: tested_positive (12.0/1.0)
|   |   |   |   |   mass > 27.1
|   |   |   |   |   |   pres <= 82
|   |   |   |   |   |   |   pedi <= 0.396: tested_positive (8.0/1.0)
|   |   |   |   |   |   |   pedi > 0.396: tested_negative (3.0)
|   |   |   |   |   |   pres > 82: tested_negative (4.0)
|   |   |   |   age > 61: tested_negative (4.0)
|   mass > 29.9
|   |   plas <= 157
|   |   |   pres <= 61: tested_positive (15.0/1.0)
|   |   |   pres > 61
|   |   |   |   age <= 30: tested_negative (40.0/13.0)
|   |   |   |   age > 30: tested_positive (60.0/17.0)
|   |   plas > 157: tested_positive (92.0/12.0)
Number of Leaves  :  22
Size of the tree :  43

Пример за Deeplearning4j:

Този пример ще изгради модел на Convolution Neural Network (CNN) за класифициране на библиотеката MNIST. Ако не сте запознати с MNIST или как CNN работи, за да класифицира ръкописните цифри, препоръчвам ви да прегледате набързо моята ранна публикация, която описва тези аспекти подробно.

Както винаги, ще заредим набора от данни и ще покажем неговия размер.

DataSetIterator MNISTTrain = new MnistDataSetIterator(batchSize,true,seed);
DataSetIterator MNISTTest = new MnistDataSetIterator(batchSize,false,seed);

Нека проверим отново дали получаваме десет уникални етикета от набора от данни:

log.info("The number of total labels found in the training dataset " + MNISTTrain.totalOutcomes());
log.info("The number of total labels found in the test dataset " + MNISTTest.totalOutcomes());

След това нека конфигурираме архитектурата на модела. Ще използваме два слоя с навиване плюс сплескан слой за изхода. Deeplearning4j има няколко опции, които можете да използвате, за да инициализирате схемата за тегло.

// Building the CNN model
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed) // random seed
        .l2(0.0005) // regularization
        .weightInit(WeightInit.XAVIER) // initialization of the weight scheme
        .updater(new Adam(1e-3)) // Setting the optimization algorithm
        .list()
        .layer(new ConvolutionLayer.Builder(5, 5)
                //Setting the stride, the kernel size, and the activation function.
                .nIn(nChannels)
                .stride(1,1)
                .nOut(20)
                .activation(Activation.IDENTITY)
                .build())
        .layer(new SubsamplingLayer.Builder(PoolingType.MAX) // downsampling the convolution
                .kernelSize(2,2)
                .stride(2,2)
                .build())
        .layer(new ConvolutionLayer.Builder(5, 5)
                // Setting the stride, kernel size, and the activation function.
                .stride(1,1)
                .nOut(50)
                .activation(Activation.IDENTITY)
                .build())
        .layer(new SubsamplingLayer.Builder(PoolingType.MAX) // downsampling the convolution
                .kernelSize(2,2)
                .stride(2,2)
                .build())
        .layer(new DenseLayer.Builder().activation(Activation.RELU)
                .nOut(500).build())
        .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .build())
        // the final output layer is 28x28 with a depth of 1.
        .setInputType(InputType.convolutionalFlat(28,28,1))
        .build();

След като архитектурата е зададена, трябва да инициализираме режима, да зададем набора от данни за обучение и да задействаме обучението на модела.

MultiLayerNetwork model = new MultiLayerNetwork(conf);
// initialize the model weights.
model.init();

log.info("Step2: start training the model");
//Setting a listener every 10 iterations and evaluate on test set on every epoch
model.setListeners(new ScoreIterationListener(10), new EvaluativeListener(MNISTTest, 1, InvocationType.EPOCH_END));
// Training the model
model.fit(MNISTTrain, nEpochs);

По време на обучението слушателят на резултатите ще предостави матрицата на объркването на точността на класификацията. Нека видим точността след десет епохи на обучение:

=========================Confusion Matrix=========================
    0    1    2    3    4    5    6    7    8    9
---------------------------------------------------
  977    0    0    0    0    0    1    1    1    0 | 0 = 0
    0 1131    0    1    0    1    2    0    0    0 | 1 = 1
    1    2 1019    3    0    0    0    3    4    0 | 2 = 2
    0    0    1 1004    0    1    0    1    3    0 | 3 = 3
    0    0    0    0  977    0    2    0    1    2 | 4 = 4
    1    0    0    9    0  879    1    0    1    1 | 5 = 5
    4    2    0    0    1    1  949    0    1    0 | 6 = 6
    0    4    2    1    1    0    0 1018    1    1 | 7 = 7
    2    0    3    1    0    1    1    2  962    2 | 8 = 8
    0    2    0    2   11    2    0    3    2  987 | 9 = 9

Пример за чук:

Както споменахме по-рано, Mallet е мощен инструментариум за моделиране на естествен език. Ще използваме примерен корпус, предоставен от инструмента David Blei в пакета Mallet. Mallet има специфична библиотека за анотиране на текстови токени за класификация. Преди да заредим нашия набор от данни, Mallet има тази концепция за дефиниране на тръбопровода, където дефинирате своя тръбопровод и след това предоставяте набора от данни, през който да премине.

ArrayList<Pipe> pipeList = new ArrayList<Pipe>();

Конвейерът се дефинира като „ArrayList“, който ще съдържа типични стъпки, които винаги правим, преди да изградим модел на тема. Всеки текст в документа ще премине през следните стъпки:

  1. Ключови думи с малки букви
  2. Токенизиране на текст
  3. Премахнете стоп думите
  4. Карта на характеристиките
pipeList.add( new CharSequenceLowercase() );
pipeList.add( new CharSequence2TokenSequence(Pattern.compile("\\p{L}[\\p{L}\\p{P}]+\\p{L}")) );
// Setting the dictionary of the stop words
URL stopWordsFile = getClass().getClassLoader().getResource("stoplists/en.txt");
pipeList.add( new TokenSequenceRemoveStopwords(new File(stopWordsFile.toURI()), "UTF-8", false, false, false) );

pipeList.add( new TokenSequence2FeatureSequence() );

След като конвейерът е дефиниран, ще предадем екземплярите, представляващи оригинален текст на всеки документ.

InstanceList instances = new InstanceList (new SerialPipes(pipeList));

Сега идва стъпката за предаване на входния файл за попълване на списъка с инстанции.

URL inputFileURL = getClass().getClassLoader().getResource(inputFile);
Reader fileReader = new InputStreamReader(new FileInputStream(new File(inputFileURL.toURI())), "UTF-8");
instances.addThruPipe(new CsvIterator (fileReader, Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
        3, 2, 1)); // data, label, name fields

От последния команден ред можете да забележите, че предоставихме инструкции как е структуриран CSV файлът. Изходният файл, наличен в папката с ресурси, има около две хиляди реда. Всеки ред представлява оригинален текст на документ и се състои от три атрибута, разделени със запетая (име, етикет и съдържание на документа). Можем да отпечатаме броя на екземплярите, намерени във входния документ, като използваме следната команда:

log.info("The number of instances found in the input file is: " + instances.size());

Сега нека моделираме темите на документа. Да приемем, че имаме 100 различни теми в тези 2k документа. Mallet ни позволява да зададем две променливи: алфа и бета тегла. Алфа контролира концентрацията на разпределенията тема-дума, а бета представлява теглата на преди думата спрямо разпределенията тема-дума.

int numTopics = 100;
// defining the model 
ParallelTopicModel model = new ParallelTopicModel(numTopics, 1.0, 0.01);
// adding the instances to the model
model.addInstances(instances);

Моделът, който избираме в този пример, е реализация на LDA (латентно разпределение на Дирихле). Алгоритъмът използва група от наблюдавани прилики на ключови думи за класифициране на документи.

Едно от нещата, които харесвам в Mallet, са възможностите на API за лесно проектиране на вашата паралелна обработка. Тук можем да дефинираме многонишкова обработка за всяка подпроба.

model.setNumThreads(2);

Сега ни остават само две неща е да определим броя на повторенията за обучението на модела и да започнем обучението.

model.setNumIterations(50);
model.estimate();

Оставих повече подробности за това как да покажа резултата от моделирането на темата в пълния пример в github.

Допълнителни четения

Всички примери, предоставени в тази публикация, са достъпни в хранилището на my Github.