Skip to content

Commit

Permalink
fix(vector): Update query_rewriter to fix dotproduct and cosine query…
Browse files Browse the repository at this point in the history
… conversion (#9083)

fixes
https://linear.app/hypermode/issue/DGR-315/querysimilarthingsbyembedding-not-working-with-dotproduct-or-cosine

---------

Co-authored-by: shivaji-dgraph <shivaji@dgraph.io>
  • Loading branch information
rderbier and shivaji-dgraph committed May 15, 2024
1 parent 6e7896e commit 207583d
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 65 deletions.
13 changes: 7 additions & 6 deletions graphql/resolve/query_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,12 +659,12 @@ func rewriteAsSimilarByIdQuery(
topK := query.ArgValue(schema.SimilarTopKArgName)
similarByField := typ.Field(similarBy)
metric := similarByField.EmbeddingSearchMetric()
distanceFormula := "math((v2 - v1) dot (v2 - v1))" // default - euclidian
distanceFormula := "math(sqrt((v2 - v1) dot (v2 - v1)))" // default - euclidian

if metric == schema.SimilarSearchMetricDotProduct {
distanceFormula = "math(v1 dot v2)"
distanceFormula = "math((1.0 - (v1 dot v2)) /2.0)"
} else if metric == schema.SimilarSearchMetricCosine {
distanceFormula = "math((v1 dot v2) / ((v1 dot v1) * (v2 dot v2)))"
distanceFormula = "math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0)"
}

// First generate the query to fetch the uid
Expand Down Expand Up @@ -819,12 +819,13 @@ func rewriteAsSimilarByEmbeddingQuery(

similarByField := typ.Field(similarBy)
metric := similarByField.EmbeddingSearchMetric()
distanceFormula := "math((v2 - $search_vector) dot (v2 - $search_vector))" // default = euclidian
distanceFormula := "math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))" // default = euclidian

if metric == schema.SimilarSearchMetricDotProduct {
distanceFormula = "math($search_vector dot v2)"
distanceFormula = "math(( 1.0 - (($search_vector) dot v2)) /2.0)"
} else if metric == schema.SimilarSearchMetricCosine {
distanceFormula = "math(($search_vector dot v2) / (($search_vector dot $search_vector) * (v2 dot v2)))"
distanceFormula = "math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector))" +
" * (v2 dot v2) ) )) / 2.0)"
}

// Save vectorString as a query variable, $search_vector
Expand Down
12 changes: 6 additions & 6 deletions graphql/resolve/query_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3367,7 +3367,7 @@
query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
var(func: similar_to(Product.productVector, 1, $search_vector)) @filter(type(Product)) {
v2 as Product.productVector
distance as math((v2 - $search_vector) dot (v2 - $search_vector))
distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))
}
querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
Product.id : Product.id
Expand Down Expand Up @@ -3397,7 +3397,7 @@
}
var(func: similar_to(Product.productVector, 3, val(v1))) {
v2 as Product.productVector
distance as math((v2 - v1) dot (v2 - v1))
distance as math(sqrt((v2 - v1) dot (v2 - v1)))
}
querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
Product.id : Product.id
Expand Down Expand Up @@ -3428,7 +3428,7 @@
}
var(func: similar_to(ProjectCosine.description_v, 3, val(v1))) {
v2 as ProjectCosine.description_v
distance as math((v1 dot v2) / ((v1 dot v1) * (v2 dot v2)))
distance as math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0)
}
querySimilarProjectCosineById(func: uid(distance), orderasc: val(distance)) {
ProjectCosine.id : ProjectCosine.id
Expand All @@ -3453,7 +3453,7 @@
query querySimilarProjectCosineByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
var(func: similar_to(ProjectCosine.description_v, 1, $search_vector)) @filter(type(ProjectCosine)) {
v2 as ProjectCosine.description_v
distance as math(($search_vector dot v2) / (($search_vector dot $search_vector) * (v2 dot v2)))
distance as math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector)) * (v2 dot v2) ) )) / 2.0)
}
querySimilarProjectCosineByEmbedding(func: uid(distance), orderasc: val(distance)) {
ProjectCosine.id : ProjectCosine.id
Expand Down Expand Up @@ -3483,7 +3483,7 @@
}
var(func: similar_to(ProjectDotProduct.description_v, 3, val(v1))) {
v2 as ProjectDotProduct.description_v
distance as math(v1 dot v2)
distance as math((1.0 - (v1 dot v2)) /2.0)
}
querySimilarProjectDotProductById(func: uid(distance), orderasc: val(distance)) {
ProjectDotProduct.id : ProjectDotProduct.id
Expand All @@ -3508,7 +3508,7 @@
query querySimilarProjectDotProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
var(func: similar_to(ProjectDotProduct.description_v, 1, $search_vector)) @filter(type(ProjectDotProduct)) {
v2 as ProjectDotProduct.description_v
distance as math($search_vector dot v2)
distance as math(( 1.0 - (($search_vector) dot v2)) /2.0)
}
querySimilarProjectDotProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
ProjectDotProduct.id : ProjectDotProduct.id
Expand Down
158 changes: 105 additions & 53 deletions query/vector/vector_graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package query

import (
"encoding/json"
"fmt"
"math/rand"
"testing"

"github.com/dgraph-io/dgraph/dgraphtest"
Expand All @@ -36,29 +38,56 @@ const (
type Project {
id: ID!
title: String! @search(by: [exact])
title_v: [Float!] @embedding @search(by: ["hnsw(metric: euclidian, exponent: 4)"])
}
`
title_v: [Float!] @embedding @search(by: ["hnsw(metric: %v, exponent: 4)"])
} `
)

var (
projects = []ProjectInput{ProjectInput{
Title: "iCreate with a Mini iPad",
TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2},
}, ProjectInput{
Title: "Resistive Touchscreen",
TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2},
}, ProjectInput{
Title: "Fitness Band",
TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2},
}, ProjectInput{
Title: "Smart Watch",
TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2},
}, ProjectInput{
Title: "Smart Ring",
TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2},
}}
)
func generateProjects(count int) []ProjectInput {
var projects []ProjectInput
for i := 0; i < count; i++ {
title := generateUniqueRandomTitle(projects)
titleV := generateRandomTitleV(5) // Assuming size is fixed at 5
project := ProjectInput{
Title: title,
TitleV: titleV,
}
projects = append(projects, project)
}
return projects
}

func isTitleExists(title string, existingTitles []ProjectInput) bool {
for _, project := range existingTitles {
if project.Title == title {
return true
}
}
return false
}

func generateUniqueRandomTitle(existingTitles []ProjectInput) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
const titleLength = 10
title := make([]byte, titleLength)
for {
for i := range title {
title[i] = charset[rand.Intn(len(charset))]
}
titleStr := string(title)
if !isTitleExists(titleStr, existingTitles) {
return titleStr
}
}
}

func generateRandomTitleV(size int) []float32 {
var titleV []float32
for i := 0; i < size; i++ {
value := rand.Float32()
titleV = append(titleV, value)
}
return titleV
}

func addProject(t *testing.T, hc *dgraphtest.HTTPClient, project ProjectInput) {
query := `
Expand All @@ -79,6 +108,7 @@ func addProject(t *testing.T, hc *dgraphtest.HTTPClient, project ProjectInput) {
_, err := hc.RunGraphqlQuery(params, false)
require.NoError(t, err)
}

func queryProjectUsingTitle(t *testing.T, hc *dgraphtest.HTTPClient, title string) ProjectInput {
query := ` query QueryProject($title: String!) {
queryProject(filter: { title: { eq: $title } }) {
Expand All @@ -96,19 +126,17 @@ func queryProjectUsingTitle(t *testing.T, hc *dgraphtest.HTTPClient, title strin
type QueryResult struct {
QueryProject []ProjectInput `json:"queryProject"`
}

var resp QueryResult
err = json.Unmarshal([]byte(string(response)), &resp)
require.NoError(t, err)

return resp.QueryProject[0]
}

func queryProjectsSimilarByEmbedding(t *testing.T, hc *dgraphtest.HTTPClient, vector []float32) []ProjectInput {
func queryProjectsSimilarByEmbedding(t *testing.T, hc *dgraphtest.HTTPClient, vector []float32, topk int) []ProjectInput {
// query similar project by embedding
queryProduct := `query QuerySimilarProjectByEmbedding($by: ProjectEmbedding!, $topK: Int!, $vector: [Float!]!) {
querySimilarProjectByEmbedding(by: $by, topK: $topK, vector: $vector) {
id
title
title_v
}
Expand All @@ -120,20 +148,19 @@ func queryProjectsSimilarByEmbedding(t *testing.T, hc *dgraphtest.HTTPClient, ve
Query: queryProduct,
Variables: map[string]interface{}{
"by": "title_v",
"topK": 3,
"topK": topk,
"vector": vector,
}}
response, err := hc.RunGraphqlQuery(params, false)
require.NoError(t, err)
type QueryResult struct {
QueryProject []ProjectInput `json:"queryProject"`
QueryProject []ProjectInput `json:"querySimilarProjectByEmbedding"`
}
var resp QueryResult
err = json.Unmarshal([]byte(string(response)), &resp)
require.NoError(t, err)

return resp.QueryProject

}

func TestVectorGraphQLAddVectorPredicate(t *testing.T) {
Expand All @@ -143,21 +170,67 @@ func TestVectorGraphQLAddVectorPredicate(t *testing.T) {
require.NoError(t, err)
hc.LoginIntoNamespace("groot", "password", 0)
// add schema
require.NoError(t, hc.UpdateGQLSchema(graphQLVectorSchema))
require.NoError(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "euclidean")))
}

func TestVectorSchema(t *testing.T) {
require.NoError(t, client.DropAll())

hc, err := dc.HTTPClient()
require.NoError(t, err)
hc.LoginIntoNamespace("groot", "password", 0)

schema := `type Project {
id: ID!
title: String! @search(by: [exact])
title_v: [Float!]
}`

// add schema
require.NoError(t, hc.UpdateGQLSchema(schema))
require.Error(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "euclidean")))
}

func TestVectorGraphQlEuclidianIndexMutationAndQuery(t *testing.T) {
require.NoError(t, client.DropAll())
hc, err := dc.HTTPClient()
require.NoError(t, err)
hc.LoginIntoNamespace("groot", "password", 0)

schema := fmt.Sprintf(graphQLVectorSchema, "euclidean")
// add schema
require.NoError(t, hc.UpdateGQLSchema(schema))
testVectorGraphQlMutationAndQuery(t, hc)
}

func TestVectorGraphQlMutationAndQuery(t *testing.T) {
func TestVectorGraphQlCosineIndexMutationAndQuery(t *testing.T) {
require.NoError(t, client.DropAll())
hc, err := dc.HTTPClient()
require.NoError(t, err)
hc.LoginIntoNamespace("groot", "password", 0)

schema := fmt.Sprintf(graphQLVectorSchema, "cosine")
// add schema
require.NoError(t, hc.UpdateGQLSchema(schema))
testVectorGraphQlMutationAndQuery(t, hc)
}

func TestVectorGraphQlDotProductIndexMutationAndQuery(t *testing.T) {
require.NoError(t, client.DropAll())
hc, err := dc.HTTPClient()
require.NoError(t, err)
hc.LoginIntoNamespace("groot", "password", 0)

schema := fmt.Sprintf(graphQLVectorSchema, "dotproduct")
// add schema
require.NoError(t, hc.UpdateGQLSchema(graphQLVectorSchema))
require.NoError(t, hc.UpdateGQLSchema(schema))
testVectorGraphQlMutationAndQuery(t, hc)
}

// add project
func testVectorGraphQlMutationAndQuery(t *testing.T, hc *dgraphtest.HTTPClient) {
var vectors [][]float32
numProjects := 100
projects := generateProjects(numProjects)
for _, project := range projects {
vectors = append(vectors, project.TitleV)
addProject(t, hc, project)
Expand All @@ -177,30 +250,9 @@ func TestVectorGraphQlMutationAndQuery(t *testing.T) {

// query similar project by embedding
for _, project := range projects {
similarProjects := queryProjectsSimilarByEmbedding(t, hc, project.TitleV)

similarProjects := queryProjectsSimilarByEmbedding(t, hc, project.TitleV, numProjects)
for _, similarVec := range similarProjects {
require.Contains(t, vectors, similarVec.TitleV)
}
}
}

func TestVectorSchema(t *testing.T) {
require.NoError(t, client.DropAll())

hc, err := dc.HTTPClient()
require.NoError(t, err)
hc.LoginIntoNamespace("groot", "password", 0)

schema := `type Project {
id: ID!
title: String! @search(by: [exact])
title_v: [Float!]
}`

// add schema
require.NoError(t, hc.UpdateGQLSchema(schema))
require.Error(t, hc.UpdateGQLSchema(graphQLVectorSchema))
require.NoError(t, client.DropAll())
require.NoError(t, hc.UpdateGQLSchema(graphQLVectorSchema))
}

0 comments on commit 207583d

Please sign in to comment.