Skip to content

Commit

Permalink
feat: CosineSimilarity Expression
Browse files Browse the repository at this point in the history
- fixes #2005

Signed-off-by: Andreas Reichel <[email protected]>
  • Loading branch information
manticore-projects committed Sep 19, 2024
1 parent d1373c5 commit 90adf82
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import net.sf.jsqlparser.expression.operators.relational.Between;
import net.sf.jsqlparser.expression.operators.relational.ContainedBy;
import net.sf.jsqlparser.expression.operators.relational.Contains;
import net.sf.jsqlparser.expression.operators.relational.CosineSimilarity;
import net.sf.jsqlparser.expression.operators.relational.DoubleAnd;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExcludesExpression;
Expand Down Expand Up @@ -665,4 +666,6 @@ default void visit(PriorTo priorTo) {
default void visit(Inverse inverse) {
this.visit(inverse, null);
}

<S> T visit(CosineSimilarity cosineSimilarity, S context);
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import net.sf.jsqlparser.expression.operators.relational.Between;
import net.sf.jsqlparser.expression.operators.relational.ContainedBy;
import net.sf.jsqlparser.expression.operators.relational.Contains;
import net.sf.jsqlparser.expression.operators.relational.CosineSimilarity;
import net.sf.jsqlparser.expression.operators.relational.DoubleAnd;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExcludesExpression;
Expand Down Expand Up @@ -792,4 +793,11 @@ public <S> T visit(Inverse inverse, S context) {
return inverse.getExpression().accept(this, context);
}

@Override
public <S> T visit(CosineSimilarity cosineSimilarity, S context) {
cosineSimilarity.getLeftExpression().accept(this, context);
cosineSimilarity.getRightExpression().accept(this, context);
return null;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*-
* #%L
* JSQLParser library
* %%
* Copyright (C) 2004 - 2022 JSQLParser
* %%
* Dual licensed under GNU LGPL 2.1 or Apache License 2.0
* #L%
*/
package net.sf.jsqlparser.expression.operators.relational;

import net.sf.jsqlparser.expression.ExpressionVisitor;

public class CosineSimilarity extends ComparisonOperator {

public CosineSimilarity() {
super("<=>");
}

public CosineSimilarity(String operator) {
super(operator);
}

@Override
public <T, S> T accept(ExpressionVisitor<T> expressionVisitor, S context) {
return expressionVisitor.visit(this, context);
}
}
8 changes: 8 additions & 0 deletions src/main/java/net/sf/jsqlparser/util/TablesNamesFinder.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
import net.sf.jsqlparser.expression.operators.relational.Between;
import net.sf.jsqlparser.expression.operators.relational.ContainedBy;
import net.sf.jsqlparser.expression.operators.relational.Contains;
import net.sf.jsqlparser.expression.operators.relational.CosineSimilarity;
import net.sf.jsqlparser.expression.operators.relational.DoubleAnd;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExcludesExpression;
Expand Down Expand Up @@ -1624,6 +1625,13 @@ public <S> Void visit(Inverse inverse, S context) {
return null;
}

@Override
public <S> Void visit(CosineSimilarity cosineSimilarity, S context) {
cosineSimilarity.getLeftExpression().accept(this, context);
cosineSimilarity.getRightExpression().accept(this, context);
return null;
}

@Override
public <S> Void visit(VariableAssignment variableAssignment, S context) {
variableAssignment.getVariable().accept(this, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
import net.sf.jsqlparser.expression.operators.relational.Between;
import net.sf.jsqlparser.expression.operators.relational.ContainedBy;
import net.sf.jsqlparser.expression.operators.relational.Contains;
import net.sf.jsqlparser.expression.operators.relational.CosineSimilarity;
import net.sf.jsqlparser.expression.operators.relational.DoubleAnd;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExcludesExpression;
Expand Down Expand Up @@ -1742,4 +1743,11 @@ public <S> StringBuilder visit(PriorTo priorTo, S context) {
public <S> StringBuilder visit(Inverse inverse, S context) {
return buffer.append(inverse.toString());
}

@Override
public <S> StringBuilder visit(CosineSimilarity cosineSimilarity, S context) {
deparse(cosineSimilarity,
" " + cosineSimilarity.getStringExpression() + " ", context);
return buffer;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
import net.sf.jsqlparser.expression.operators.relational.Between;
import net.sf.jsqlparser.expression.operators.relational.ContainedBy;
import net.sf.jsqlparser.expression.operators.relational.Contains;
import net.sf.jsqlparser.expression.operators.relational.CosineSimilarity;
import net.sf.jsqlparser.expression.operators.relational.DoubleAnd;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExcludesExpression;
Expand Down Expand Up @@ -1265,4 +1266,11 @@ public void visit(Inverse inverse) {
visit(inverse, null);
}

@Override
public <S> Void visit(CosineSimilarity cosineSimilarity, S context) {
cosineSimilarity.getLeftExpression().accept(this, context);
cosineSimilarity.getRightExpression().accept(this, context);
return null;
}

}
1 change: 1 addition & 0 deletions src/main/jjtree/net/sf/jsqlparser/parser/JSqlParserCC.jjt
Original file line number Diff line number Diff line change
Expand Up @@ -4008,6 +4008,7 @@ Expression RegularCondition() #RegularCondition:
| "-#" { result = new JsonOperator("-#"); }
| "<->" { result = new GeometryDistance("<->"); }
| "<#>" { result = new GeometryDistance("<#>"); }
| "<=>" { result = new CosineSimilarity(); }
)

( LOOKAHEAD(2) <K_PRIOR> rightExpression=ComparisonItem() { oraclePrior = EqualsTo.ORACLE_PRIOR_END; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,12 @@ public void testContainedBy() throws JSQLParserException {
TestUtils.assertSqlCanBeParsedAndDeparsed("SELECT * FROM foo WHERE a <& b");
Assertions.assertInstanceOf(ContainedBy.class, CCJSqlParserUtil.parseExpression("a <& b"));
}

@Test
void testCosineSimilarity() throws JSQLParserException {
TestUtils.assertSqlCanBeParsedAndDeparsed(
"SELECT (embedding <=> '[3,1,2]') AS cosine_similarity FROM items;");
Assertions.assertInstanceOf(CosineSimilarity.class,
CCJSqlParserUtil.parseExpression("embedding <=> '[3,1,2]'"));
}
}

0 comments on commit 90adf82

Please sign in to comment.