Как использовать TensorFlow с Java

Введение Машинное обучение набирает популярность и используется во всем мире. Он уже радикально изменил способ создания определенных приложений и, вероятно, будет продолжать оставаться огромной (и постоянно увеличивающейся) частью нашей повседневной жизни. Это не приукрашивает, машинное обучение - непростое дело. Это довольно сложно и многим может показаться очень сложным. Такие компании, как Google, взяли на себя задачу приблизить концепции машинного обучения к разработчикам и позволить им постепенно, с серьезной помощью,

Вступление

Машинное обучение набирает популярность и используется во всем мире. Он уже радикально изменил способ создания определенных приложений и, вероятно, будет продолжать оставаться огромной (и постоянно увеличивающейся) частью нашей повседневной жизни.

Это не приукрашивает, машинное обучение - непростое дело. Это довольно сложно и многим может показаться очень сложным.

Такие компании, как Google, взяли на себя задачу приблизить концепции машинного обучения к разработчикам и позволить им постепенно, с серьезной помощью, делать свои первые шаги.

Так родились такие фреймворки, как TensorFlow.

Что такое TensorFlow?

TensorFlow - это среда машинного обучения с открытым исходным кодом, разработанная Google на Python и C ++.

Это помогает разработчикам легко получать данные, подготавливать и обучать модели, прогнозировать будущие состояния и выполнять крупномасштабное машинное обучение.

С его помощью мы можем обучать и запускать глубокие нейронные сети, которые чаще всего используются для оптического распознавания символов , распознавания / классификации изображений, обработки естественного языка и т. Д.

Тензоры и операции

TensorFlow основан на вычислительных графах, которые вы можете представить как классический граф с узлами и ребрами.

Каждый узел называется операцией , и они принимают ноль или более тензоров и производят ноль или более тензоров на выходе. Операция может быть очень простой, например, базовое добавление, но также может быть очень сложной.

Тензоры изображаются как края графа и являются основной единицей данных. Мы выполняем различные функции с этими тензорами по мере их передачи в операции. Они могут иметь одно или несколько измерений, которые иногда называют их рангами - (Скаляр: ранг 0, Вектор: ранг 1, Матрица: ранг 2)

Эти данные протекают через вычислительный граф через тензоры, затронутые деятельность - отсюда и название TensorFlow.

Тензоры могут хранить данные в любом количестве измерений, и есть три основных типа тензоров: заполнители , переменные и константы .

Установка TensorFlow

Используя Maven , установить TensorFlow так же просто, как включить зависимость:

 <dependency> 
 <groupId>org.tensorflow</groupId> 
 <artifactId>tensorflow</artifactId> 
 <version>1.13.1</version> 
 </dependency> 

Если ваше устройство поддерживает поддержку графического процессора , используйте эти зависимости:

 <dependency> 
 <groupId>org.tensorflow</groupId> 
 <artifactId>libtensorflow</artifactId> 
 <version>1.13.1</version> 
 </dependency> 
 
 <dependency> 
 <groupId>org.tensorflow</groupId> 
 <artifactId>libtensorflow_jni_gpu</artifactId> 
 <version>1.13.1</version> 
 </dependency> 

Вы можете проверить версию TensorFlow, установленную в настоящее время, с помощью объекта TensorFlow

 System.out.println(TensorFlow.version()); 

TensorFlow Java API

Предлагаемые Java API TensorFlow содержатся в пакете org.tensorflow В настоящее время он экспериментальный, поэтому его стабильность не гарантируется.

Обратите внимание, что единственным полностью поддерживаемым языком для TensorFlow является Python, и что API Java не так функционально.

Он знакомит нас с новыми классами, интерфейсом, перечислением и исключением.

Классы

Новые классы, представленные через API:

  • Graph : граф потока данных, представляющий вычисление TensorFlow.
  • Operation : узел Graph, который выполняет вычисления на тензорах.
  • OperationBuilder : класс построителя для операций
  • Output<T> : символьный дескриптор тензора, созданного операцией.
  • SavedModelBundle : представляет модель, загруженную из хранилища.
  • SavedModelBundle.Loader : предоставляет параметры для загрузки SavedModel
  • Server : внутрипроцессный сервер TensorFlow для использования в распределенном обучении.
  • Session : Драйвер для выполнения графа
  • Session.Run : вывод тензоров и метаданных, полученных при выполнении сеанса.
  • Session.Runner : запуск операций и оценка тензоров
  • Shape : возможно, частично известная форма тензора, созданного операцией.
  • Tensor<T> : статически типизированный многомерный массив, элементы которого относятся к типу, описанному T
  • TensorFlow : статические служебные методы, описывающие среду выполнения TensorFlow
  • Tensors : типобезопасные фабричные методы для создания Tensor-объектов.
Enum
  • DataType : представляет тип элементов в Tensor как перечисление
Интерфейс
  • Operand<T> : интерфейс, реализованный операндами операции TensorFlow.
Исключение
  • TensorFlowException : непроверенное исключение, возникающее при выполнении графиков TensorFlow

Если мы сравним все это с модулем tf в Python , есть очевидная разница. API Java не имеет почти такой же функциональности, по крайней мере, на данный момент.

Графики

Как упоминалось ранее, TensorFlow основан на вычислительных графах, где org.tensorflow.Graph - это реализация Java.

Примечание : его экземпляры являются потокобезопасными, хотя нам нужно явно освободить ресурсы, используемые Graph, после того, как мы закончим с ним.

Начнем с пустого графика:

 Graph graph = new Graph(); 

Этот график ничего не значит, он пустой. Чтобы что-нибудь с ним сделать, нам сначала нужно загрузить его с помощью Operation s.

Чтобы загрузить его с помощью операций, мы используем метод opBuilder() , который возвращает OperationBuilder который добавит операции к нашему графу, как только мы вызовем метод .build() .

Константы

Добавим в наш график константу:

 Operation x = graph.opBuilder("Const", "x") 
 .setAttr("dtype", DataType.FLOAT) 
 .setAttr("value", Tensor.create(3.0f)) 
 .build(); 

Заполнители

Заполнители - это «тип» переменной, которая не имеет значения при объявлении. Их значения будут присвоены позже. Это позволяет нам строить графики с операциями без каких-либо фактических данных:

 Operation y = graph.opBuilder("Placeholder", "y") 
 .setAttr("dtype", DataType.FLOAT) 
 .build(); 

Функции

И, наконец, чтобы завершить это, нам нужно добавить определенные функции. Они могут быть такими простыми, как умножение, деление или сложение, или сложными, как умножение матриц. Как и раньше, мы определяем функции с помощью .opBuilder() :

 Operation xy = graph.opBuilder("Mul", "xy") 
 .addInput(x.output(0)) 
 .addInput(y.output(0)) 
 .build(); 

Примечание: мы используем output(0) поскольку тензор может иметь более одного вывода.

Визуализация графиков

К сожалению, в Java API пока нет инструментов, позволяющих визуализировать графики, как в Python. Когда обновится Java API, обновится и эта статья.

Сессии

Как упоминалось ранее, Session является драйвером выполнения Graph . Он инкапсулирует среду, в которой Operation и Graph для вычисления Tensor .

Это означает, что тензоры в нашем графике, который мы построили, на самом деле не имеют никакого значения, поскольку мы не запускали график в рамках сеанса.

Давайте сначала добавим график в сеанс:

 Session session = new Session(graph); 

Наше вычисление просто умножает значения x и y . Чтобы запустить наш график и вычислить его, мы fetch() xy и передаем ей значения x и y :

 Tensor tensor = session.runner().fetch("xy").feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f)).run().get(0); 
 System.out.println(tensor.floatValue()); 

Выполнение этого фрагмента кода даст:

 10.0f 

Сохранение моделей в Python и загрузка в Java

Это может показаться немного странным, но поскольку Python является единственным хорошо поддерживаемым языком, API Java по-прежнему не имеет функции сохранения моделей.

Это означает, что Java API предназначен только для обслуживающего варианта использования, по крайней мере, до тех пор, пока он полностью не будет поддерживаться TensorFlow. По крайней мере, мы можем обучать и сохранять модели на Python, а затем загружать их на Java для их обслуживания, используя класс SavedModelBundle

 SavedModelBundle model = SavedModelBundle.load("./model", "serve"); 
 Tensor tensor = model.session().runner().fetch("xy").feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f)).run().get(0); 
 
 System.out.println(tensor.floatValue()); 

Заключение

TensorFlow - это мощный, надежный и широко используемый фреймворк. Он постоянно совершенствуется и в последнее время вводится в новые языки, включая Java и JavaScript.

Хотя Java API еще не имеет такой функциональности, как TensorFlow для Python, он все же может служить хорошим введением в TensorFlow для разработчиков Java.

comments powered by Disqus