pgvector или как хранить и обрабатывать фичи в базе данных
На Хабре было много упоминаний pgvector в обзорах Postgresso. И каждый раз новость была про место которое где-то за границей и далеко. Многие коммерческие решения для хранения и поиска векторов в базе данных нынче не доступны, а pgvector доступен любому, тем более в самой популярной базе в России.
В этой статье покажу на практическом примере как хранить, кластеризовать вектора и искать по векторам в базе данных.
Прежде всего надо установить pgvector в PostgreSQL, он доступен в виде расширения. Поскольку я работаю с базой данных из Docker, то могу просто добавить в Dockerfile строчки и пересобрать образ:
RUN git clone --branch v0.5.1 https://github.com/pgvector/pgvector.git
RUN cd pgvector && make && make install
А в самой базе данных, нужно загрузить расширение:
osmworld=# CREATE EXTENSION vector;
CREATE EXTENSION
Time: 32,606 ms
Данные для векторов можно получить, например, из модели машинного обучения в python скрипте или ML модели в spark и вставить в таблицу с колонкой типа vector. А можно создать в SQL как гистограмму определенных категорий. В этом случае можно значения в массивах float[],integer[] или double precision[], numeric[] привести к типу ::vector
Данными для примера послужат гистограммы числа объектов детской инфраструктуры в окрестностях жилых домов в Москве. Про то как рассчитать эти данные я рассказывал здесь раньше, но в этой публикации я просто возьму готовые данные и создам из них таблицу с колонкой типа одинадцатимерный vector:
create table infrastructure_for_children_features2 as
select (row_number() over ())::integer id, null::integer cluster,
district, street, housenumber,
ARRAY[kindergarten::integer, school::integer,college::integer, university::integer, language_school::integer, music_school::integer,training::integer,sports_centre::integer,community_centre::integer,playground::integer,clinic::integer]
::vector(11) feature
from infrastructure_for_children;
Так в базе создал таблицу на 30237 записей со структурой:
osmworld=# \d infrastructure_for_children_features2
Table "public.infrastructure_for_children_features2"
Column | Type |
-------------+------------|
id | integer |
cluster | integer |
district | text |
street | text |
housenumber | text |
feature | vector(11) |
Теперь хотелось бы объединить их в группы по близости векторов. Опять же можно использовать нейросети, а можно использовать классические алгоритмы кластеризации - метод k-средних(k-means) или основанную на плотности пространственную кластеризацию для приложений с шумами (DBSCAN). Для метрики близости использую Евклидово расстояние. Поскольку число кластеров мне не известно, то я выберу DBSCAN и прогоню этот крошечный набор данных через него чтобы посмотреть зависимость от epsilon числа групп и число элементов не попавших в группы:
eps|clusters|not_in_cluster
0.0 75 29667
0.5 75 29667
1.0 202 28648
1.5 475 26904
2.0 928 22630
2.5 1227 17620
3.0 1173 11778
3.5 856 7605
4.0 601 4562
4.5 364 2760
5.0 232 1574
5.5 138 972
6.0 77 604
6.5 51 377
7.0 29 265
7.5 14 168
8.0 10 96
8.5 4 54
9.0 4 37
9.5 2 30
На свой субъективный взгляд выберу eps=5.5 и запущу Java программу, которая заполнит колонку cluster значениями алгоритма DBSCAN для minPoints=3 и eps=5.5:
package com.github.isuhorukov;
import com.pgvector.PGvector;
import org.apache.commons.math3.ml.clustering.Cluster;
import org.apache.commons.math3.ml.clustering.Clusterable;
import org.apache.commons.math3.ml.clustering.DBSCANClusterer;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
public class Main {
public static void main(String[] args) throws Exception {
try (Connection connection = DriverManager.getConnection(
System.getenv("jdbc_url"), System.getenv("user"), System.getenv("password"))) {
connection.setAutoCommit(false);
PGvector.addVectorType(connection);
float eps = Float.parseFloat(System.getenv("eps"));
int minPoints = Integer.parseInt(System.getenv("minPoints"));
DBSCANClusterer<Feature> dbscanClusterer = new DBSCANClusterer<>(eps,minPoints,new EuclideanDistance());
List<Feature> features = fetchFeatures(connection,
"select id,feature from infrastructure_for_children_features");
List<Cluster<Feature>> cluster = dbscanClusterer.cluster(features);
saveClusters(connection, cluster);
}
}
private static void saveClusters(Connection connection, List<Cluster<Feature>> cluster) throws SQLException {
try (PreparedStatement clusterPs = connection.prepareStatement(
"update infrastructure_for_children_features set cluster = ? where id = ?")){
for (int idx = 0; idx < cluster.size(); idx++) {
List<Feature> featureCluster = cluster.get(idx).getPoints();
for (Feature feature : featureCluster) {
clusterPs.setInt(1, idx);
clusterPs.setInt(2, feature.id);
clusterPs.addBatch();
}
clusterPs.executeBatch();
}
connection.commit();
} catch (Exception e) {
connection.rollback();
throw new RuntimeException(e);
}
}
private static List<Feature> fetchFeatures(Connection connection, String query) {
List<Feature> features = new ArrayList<>();
try (Statement statement = connection.createStatement();
ResultSet resultSet = statement.executeQuery(query))
{
while (resultSet.next()) {
int id = resultSet.getInt(1);
float[] feature = ((PGvector) resultSet.getObject(2)).toArray();
features.add(new Feature(id, feature));
}
} catch (Exception e) {
throw new RuntimeException(e);
}
return features;
}
static class Feature implements Clusterable {
public int id;
public double[] feature;
public Feature(int id, float[] feature) {
this.id = id;
this.feature = new double[feature.length];
for (int i = 0; i < feature.length; i++) {
this.feature[i] = feature[i];
}
}
@Override
public double[] getPoint() {
return feature;
}
}
}
Для компиляции которого нужен pom.xml для maven:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.github.igor-suhorukov</groupId>
<artifactId>vectors</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.6.1</version>
</dependency>
<dependency>
<groupId>com.pgvector</groupId>
<artifactId>pgvector</artifactId>
<version>0.1.3</version>
</dependency>
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
<version>42.6.0</version>
</dependency>
</dependencies>
</project>


Найдем в базе данных дом со скриншота по идентификатору:
osmworld=# select * from infrastructure_for_children_features where id=831;
id | cluster | district | street | housenumber | feature
-----+---------+-------------+--------------------+-------------------+-----------------------+-------------+---------------------------------
831 | 4 | Пресненский район | Большая Бронная улица | 19 | [37,28,3,23,6,4,17,23,9,148,92]
(1 row)
А теперь в PostgreSQL найдем 10 домов близкие к нему по значению вектора:
osmworld=# select id, cluster, district,street,housenumber from infrastructure_for_children_features order by feature<-> '[37,28,3,23,6,4,17,23,9,148,92]' limit 10;
id | cluster | district | street | housenumber
------+---------+-------------------+------------------------------+-------------
831 | 4 | Пресненский район | Большая Бронная улица | 19
1011 | 4 | Пресненский район | Большая Бронная улица | 17
897 | 4 | Пресненский район | Сытинский переулок | 5/10 с3
827 | 4 | Пресненский район | Богословский переулок | 8/15
1019 | 4 | Пресненский район | Большая Бронная улица | 16
823 | 4 | Тверской район | Малый Палашёвский переулок | 4
893 | 4 | Пресненский район | Большой Козихинский переулок | 4
631 | 4 | Пресненский район | Большая Бронная улица | 25 с3
821 | 4 | Пресненский район | Сытинский переулок | 5/10 с4
1117 | 4 | Пресненский район | Богословский переулок | 5
(10 rows)
Time: 25,777 ms
osmworld=# explain select id, cluster, district,street,housenumber from infrastructure_for_children_features order by feature<-> '[37,28,3,23,6,4,17,23,9,148,92]' limit 10;
QUERY PLAN
--------------------------------------------------------------------------------------------------------
Limit (cost=2375.37..2375.40 rows=10 width=90)
-> Sort (cost=2375.37..2450.97 rows=30237 width=90)
Sort Key: ((feature <-> '[37,28,3,23,6,4,17,23,9,148,92]'::vector))
-> Seq Scan on infrastructure_for_children_features (cost=0.00..1721.96 rows=30237 width=90)
(4 rows)
Можно ли ускорить поиск по векторам? Да, расширение поддерживает индексы IVFFlat и HNSW. Попробуем HNSW он более быстрый и точный, если верить научным публикациям:
osmworld=# CREATE INDEX ON infrastructure_for_children_features USING hnsw (feature vector_l2_ops) WITH (m = 16, ef_construction = 64);
CREATE INDEX
Time: 12705,611 ms (00:12,706)
osmworld=# explain select id, cluster, district,street,housenumber from infrastructure_for_children_features order by feature<-> '[37,28,3,23,6,4,17,23,9,148,92]' limit 10;
QUERY PLAN
-----------------------------------------------------------------------------------------------------------------------------------------------------------
Limit (cost=5.00..5.61 rows=10 width=90)
-> Index Scan using infrastructure_for_children_features_feature_idx on infrastructure_for_children_features (cost=5.00..1861.36 rows=30237 width=90)
Order By: (feature <-> '[37,28,3,23,6,4,17,23,9,148,92]'::vector)
(3 rows)
Time: 1,022 ms
osmworld=# select id, cluster, district,street,housenumber from infrastructure_for_children_features order by feature<-> '[37,28,3,23,6,4,17,23,9,148,92]' limit 10;
id | cluster | district | street | housenumber
------+---------+-------------------+------------------------------+-------------
831 | 4 | Пресненский район | Большая Бронная улица | 19
1011 | 4 | Пресненский район | Большая Бронная улица | 17
827 | 4 | Пресненский район | Богословский переулок | 8/15
897 | 4 | Пресненский район | Сытинский переулок | 5/10 с3
1019 | 4 | Пресненский район | Большая Бронная улица | 16
823 | 4 | Тверской район | Малый Палашёвский переулок | 4
893 | 4 | Пресненский район | Большой Козихинский переулок | 4
631 | 4 | Пресненский район | Большая Бронная улица | 25 с3
821 | 4 | Пресненский район | Сытинский переулок | 5/10 с4
1117 | 4 | Пресненский район | Богословский переулок | 5
(10 rows)
Time: 1,644 ms
Запрос стал использовать этот индекс и поиск по векторам стал быстрее на порядок по сравнению с seqscan.
Вывод
C расширение pgvector PostgreSQL оказалось простым в использовании и с ним можно работать не только алгоритмами машинного обучения, но и классическими алгоритмами кластеризации из Java программы, а так же быстро искать используя поиск по близости векторов и специализированный индекс HNSW.