How to implement AI vector search and related posts with pgvector
By the end of this tutorial, you will be able to set up your own vector search with text embeddings in a Next.js app. The tutorial consists mostly of code samples taken directly from the Sanity codebase.
You can see the results right here on Sanity: the related-posts section underneath each post is generated with pgvector, and so is the search.
The stack I used:
- OpenAI's text-embedding-ada-002 model
- Next.js
- Prisma
- PostgreSQL
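Start by setting up the Prisma client.
This step is needed to get Prisma to cooperate with Next.js. In development, hot reloading re-evaluates modules, and creating a new PrismaClient on every reload can exhaust your database connections. Caching the client on the global object lets every reload reuse a single instance.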
// Setting up prisma
import { PrismaClient } from "@prisma/client";
import { IS_DEVELOPMENT } from "@/utils";
/**
* This is basically:
* export const IS_DEVELOPMENT = process.env.NODE_ENV === "development";
*/
/**
 * Typed view of the global object used to cache the client across module
 * re-evaluations. This replaces `@ts-ignore` on `global.prisma` with an
 * explicit, optional property.
 */
const globalForPrisma = globalThis as typeof globalThis & {
  prisma?: PrismaClient;
};

/**
 * Returns the PrismaClient for this process.
 *
 * In development (`IS_DEVELOPMENT`), the client is stored on `globalThis` and
 * reused. Next.js hot reloading re-evaluates this module, and without the
 * cache each reload would construct another client. In every other
 * environment, a fresh client is created once, when this module is evaluated.
 */
const createPrisma = (): PrismaClient => {
  if (!IS_DEVELOPMENT) {
    return new PrismaClient();
  }
  if (!globalForPrisma.prisma) {
    globalForPrisma.prisma = new PrismaClient();
  }
  return globalForPrisma.prisma;
};

// `const` so importers cannot observe a reassigned binding.
const prisma = createPrisma();

export { prisma };
Set up your Prisma schema
// schema.prisma
// Prisma Client generator. The "postgresqlExtensions" preview feature is what
// allows the datasource's `extensions = [vector]` field to be used, on Prisma
// versions where that feature is still in preview.
generator client {
provider = "prisma-client-js"
previewFeatures = ["postgresqlExtensions"]
}
// PostgreSQL connection settings, both read from environment variables.
// Per Prisma's docs, `url` is used by Prisma Client and `directUrl` by Prisma
// CLI commands such as `prisma db push`. Judging by the variable names, `url`
// is a pooled connection and `directUrl` a non-pooled one; confirm with your
// database host.
datasource db {
provider = "postgresql"
url = env("POSTGRES_PRISMA_URL")
directUrl = env("POSTGRES_URL_NON_POOLING")
// Enables the pgvector extension, which provides the `vector` column type
// used by Post.embedding.
extensions = [vector]
}
// A post, stored in the "posts" table. For vector search only `embedding`
// and the columns selected by the search queries matter; the remaining
// fields are Sanity-specific.
model Post {
// Looked up by the vector-search queries (as `mongo_id`) and returned to
// callers under the alias `postId`.
// NOTE(review): no @unique / @@index is declared for it here. Unless an index
// exists outside this schema, lookups by mongo_id scan the whole table.
mongoId String @map("mongo_id")
id String @id @default(cuid())
content String @db.VarChar(40000)
parentPostId String? @map("parent_post_id")
parentPostSlug String? @map("parent_post_slug")
userId String @map("user_id")
score Int @default(0)
commentCount Int @default(0) @map("comment_count")
tags String[]
createdAt DateTime @default(now()) @map("created_at")
slug String @unique
images Json
status PostStatus
publishedAt DateTime? @map("published_at")
// pgvector column holding the post's embedding (NULL until one is generated).
// Prisma Client cannot read or write `Unsupported` fields through its typed
// API, which is why the search code uses $queryRaw. The dimension must match
// the embedding model's output size: 1536 for text-embedding-ada-002.
embedding Unsupported("vector(1536)")?
@@map("posts")
}
// Status values stored in Post.status.
// NOTE(review): the vector-search queries do not filter on status. Drafts and
// posts awaiting deletion can therefore appear in search and related-post
// results if they have an embedding. Confirm this is intended.
enum PostStatus {
PUBLISHED
DRAFT
AWAITING_DELETION
AWAITING_PUBLICATION
}
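Once the schema is in place, push it to your database with a Prisma command such as `prisma db push`. The `vector` (pgvector) extension must be available on your PostgreSQL server for this to succeed. If Prisma is unfamiliar, I recommend reading up on it first and then returning here.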
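You don't need most of the fields on the Post model; I have them because I use them internally at Sanity. The field that matters for this tutorial is `embedding`, a 1536-dimensional pgvector column that matches the output size of text-embedding-ada-002.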
Generating embeddings
We use OpenAI's text-embedding-ada-002 model to generate embeddings.
// Generate embeddings
"use server";
import OpenAI from "openai";
import { OPEN_AI_API_KEY } from "@/utils";
const openai = new OpenAI({
  apiKey: OPEN_AI_API_KEY,
});

/**
 * Generates a text embedding for `input` with OpenAI's
 * `text-embedding-ada-002` model.
 *
 * The result is stored in, and compared against, the `vector(1536)` column, so
 * the model must keep producing 1536-dimensional vectors.
 *
 * NOTE(review): this module begins with "use server", so Next.js exposes every
 * exported async function, including this one, as a Server Action the client
 * can invoke. Consider moving it to a server-only module so browsers cannot
 * trigger arbitrary (billed) embedding requests.
 *
 * @param input - Text to embed.
 * @returns The embedding vector.
 * @throws Error if the API response contains no embedding. Errors raised by
 *   the OpenAI client itself propagate unchanged.
 */
export const generateEmbedding = async (input: string): Promise<number[]> => {
  const response = await openai.embeddings.create({
    model: "text-embedding-ada-002",
    input,
  });
  // `data` may be empty. Fail with a clear message instead of the opaque
  // TypeError that destructuring `[{ embedding }]` would throw.
  const embedding = response.data[0]?.embedding;
  if (!embedding) {
    throw new Error("OpenAI returned no embedding for the given input");
  }
  return embedding;
};
Implementing vector search
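I use the functions below to run post search and to return related posts. pgvector's `<=>` operator returns the cosine distance between two vectors, so each query reports `1 - distance` as the similarity.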
// Searching posts and getting related posts
"use server";
import { prisma } from "@/lib";
import { RegularPost } from "@/types";
import { ServerTracker } from "@/utils";
import { generateEmbedding } from "../../openai";
/**
 * Prepares raw post rows for callers.
 *
 * Every column is copied through as-is, except `publishedAt` and `createdAt`.
 * Those are replaced by their `toISOString()` value, or by `null` when the
 * field is missing or nullish.
 */
const cleanPosts = (
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  posts: any[],
) =>
  posts.map((row) => {
    const published = row.publishedAt?.toISOString() ?? null;
    const created = row.createdAt?.toISOString() ?? null;
    return { ...row, publishedAt: published, createdAt: created };
  }) as RegularPost[];
/**
 * Returns up to 25 posts that have an embedding and whose content is longer
 * than 200 characters, ranked by cosine similarity to `embedding` (most
 * similar first).
 *
 * `<=>` is pgvector's cosine-distance operator, so `1 - distance` is reported
 * as `similarity`. The ORDER BY sorts on the raw distance expression
 * (ascending) rather than on the `similarity` alias. The ranking is the same,
 * but only the `ORDER BY embedding <=> … LIMIT n` form can be served by a
 * pgvector HNSW/IVFFlat index, should one be added.
 *
 * NOTE(review): rows are not filtered by `status`, so non-published posts with
 * an embedding can be returned. Confirm this is intended.
 *
 * @param embedding - Query vector. Presumably the 1536-element output of
 *   generateEmbedding, to match the `vector(1536)` column; TODO confirm.
 * @returns Raw rows with camelCase aliases. Date columns are not yet
 *   converted (see cleanPosts).
 */
const searchByEmbedding = async (embedding: number[]) => {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const posts = await prisma.$queryRaw<any[]>`
    SELECT
      content,
      parent_post_id as "parentPostId",
      parent_post_slug as "parentPostSlug",
      user_id as "userId",
      score,
      comment_count as "commentCount",
      tags,
      created_at as "createdAt",
      slug,
      images,
      status,
      published_at as "publishedAt",
      mongo_id AS "postId",
      1 - (embedding <=> ${embedding}::vector) as similarity
    FROM posts
    WHERE embedding IS NOT NULL
      AND LENGTH(content) > 200
    ORDER BY embedding <=> ${embedding}::vector
    LIMIT 25
  `;
  return posts;
};
/** Maximum number of UTF-16 code units of a query sent to the embedding model. */
const MAX_QUERY_LENGTH = 100;

/**
 * Truncates `query` to at most `maxLength` UTF-16 code units. If the cut would
 * leave a dangling high surrogate (half of an emoji or other astral character)
 * at the end, that code unit is dropped as well.
 */
const truncateQuery = (query: string, maxLength: number): string => {
  const truncated = query.substring(0, maxLength);
  // charCodeAt(-1) on an empty string is NaN, so both comparisons are false.
  const lastCode = truncated.charCodeAt(truncated.length - 1);
  const endsWithHighSurrogate = lastCode >= 0xd800 && lastCode <= 0xdbff;
  return endsWithHighSurrogate ? truncated.slice(0, -1) : truncated;
};

/**
 * Semantic post search.
 *
 * Records the query, embeds at most the first 100 code units of it, and
 * returns the closest posts with ISO-string dates. A query that is blank after
 * truncation returns `[]` without calling the embedding API or the database.
 *
 * NOTE(review): the ServerTracker call is not awaited. If it returns a
 * Promise, confirm that fire-and-forget is intended.
 *
 * @param query - User-entered search text.
 * @returns Matching posts, most similar first.
 */
export const searchPosts = async (query: string): Promise<RegularPost[]> => {
  // Tracked before validation so every submitted query is still recorded.
  ServerTracker.trackSearchQuery({ query });
  const shortQuery = truncateQuery(query, MAX_QUERY_LENGTH);
  if (shortQuery.trim() === "") {
    return [];
  }
  const embedding = await generateEmbedding(shortQuery);
  const posts = await searchByEmbedding(embedding);
  return cleanPosts(posts);
};
/**
 * Returns up to 10 top-level posts (those with no parent post) most similar
 * to the post whose `mongo_id` is `postId`. The post itself is excluded.
 *
 * This takes two queries. The source post's embedding is first read cast to
 * text (presumably because Prisma's raw-query results cannot deserialize the
 * `vector` type), then cast back to `vector` in the similarity query.
 *
 * The similarity query orders by the raw cosine distance (ascending) rather
 * than the `similarity` alias. The ranking is identical, but this is the form
 * a pgvector index can serve.
 *
 * @param postId - Value of the `mongo_id` column (exposed to callers as `postId`).
 * @returns Related posts with ISO-string dates, or `[]` when the post does not
 *   exist or has no embedding.
 */
export const getRelatedPostsByVector = async (
  postId: string,
): Promise<RegularPost[]> => {
  const [post] = await prisma.$queryRaw<
    { mongo_id: string; embedding: string | null }[]
  >`
    SELECT mongo_id, CAST(embedding AS text) AS embedding
    FROM posts
    WHERE mongo_id = ${postId}
    LIMIT 1;
  `;
  if (!post?.embedding) {
    return [];
  }
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const relatedPosts = await prisma.$queryRaw<any[]>`
    SELECT
      content,
      parent_post_id as "parentPostId",
      parent_post_slug as "parentPostSlug",
      user_id as "userId",
      score,
      comment_count as "commentCount",
      tags,
      created_at as "createdAt",
      slug,
      images,
      status,
      published_at as "publishedAt",
      mongo_id AS "postId",
      1 - (embedding <=> ${post.embedding}::vector) as similarity
    FROM posts
    WHERE embedding IS NOT NULL
      AND parent_post_id IS NULL
      AND mongo_id != ${postId}
    ORDER BY embedding <=> ${post.embedding}::vector
    LIMIT 10
  `;
  return cleanPosts(relatedPosts);
};
Thanks for reading! If this post was helpful to you, please upvote this post or leave a comment!