The crashes on ExFAT are caused by a bug in the macOS ExFAT driver. It is unclear whether other OSes are affected too. https://github.com/mattn/go-sqlite3/issues/1355 We now utilize SQLite's concurrency features by creating a write pool (size 1) and a read pool, and can eliminate our own RWMutex, which prevented reads at the same time as writes. SQLite's WAL mode allows reads concurrent with writes, and our code is much cleaner. Still need to do similar for the thumbnail DB. Also could look into using prepared statements for more efficiency gains.
457 lines
13 KiB
Go
457 lines
13 KiB
Go
/*
|
|
Timelinize
|
|
Copyright (c) 2013 Matthew Holt
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU Affero General Public License as published
|
|
by the Free Software Foundation, either version 3 of the License, or
|
|
(at your option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU Affero General Public License for more details.
|
|
|
|
You should have received a copy of the GNU Affero General Public License
|
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
*/
|
|
|
|
package timeline
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
const (
	// PyHost is the loopback address on which the local Python
	// ML server listens.
	PyHost = "127.0.0.1"

	// PyPort is the TCP port of the local Python ML server.
	PyPort = 12003
)
|
|
|
|
// embeddingJob represents an embedding job. These jobs only
// insert new embeddings, they don't update existing ones.
// If existing embeddings are no longer valid/relevant, they
// should be deleted prior to this.
type embeddingJob struct {
	// ItemIDs explicitly enumerates the row IDs of the items to
	// generate embeddings for. Be sure not to include duplicates
	// in this list. If non-empty, this list takes precedence over
	// ItemsFromImportJob.
	ItemIDs []int64 `json:"item_ids,omitempty"`

	// ItemsFromImportJob, if set, infers the items to process from
	// the given import job (items created or modified by that job
	// which lack an up-to-date embedding).
	ItemsFromImportJob uint64 `json:"items_from_import_job,omitempty"`
}
|
|
|
|
func (ej embeddingJob) Run(job *ActiveJob, checkpoint []byte) error {
|
|
var chkpt embeddingJobCheckpoint
|
|
if len(checkpoint) > 0 {
|
|
if err := json.Unmarshal(checkpoint, &chkpt); err != nil {
|
|
job.logger.Error("failed to resume from checkpoint", zap.Error(err))
|
|
}
|
|
job.Logger().Info("resuming from checkpoint",
|
|
zap.Int64("page_start", chkpt.PageStart),
|
|
zap.Int("position", chkpt.IndexOnPage),
|
|
zap.Int("configured_item_count", len(ej.ItemIDs)),
|
|
zap.Uint64("items_from_import_job", ej.ItemsFromImportJob))
|
|
}
|
|
|
|
// if the embeddings to generate were explicitly enumerated, simply do those
|
|
if len(ej.ItemIDs) > 0 {
|
|
job.Logger().Info("waiting until Python server is ready")
|
|
if !pythonServerReady(job.ctx, true) {
|
|
return errors.New("python server not ready")
|
|
}
|
|
job.Logger().Info("generating embeddings using predefined list", zap.Int("count", len(ej.ItemIDs)))
|
|
return ej.processInBatches(job, ej.ItemIDs, 0, chkpt.IndexOnPage)
|
|
}
|
|
|
|
if ej.ItemsFromImportJob == 0 {
|
|
return errors.New("no items; expecting either individual items listed, or an import job")
|
|
}
|
|
|
|
logger := job.Logger().With(zap.Uint64("import_job_id", ej.ItemsFromImportJob))
|
|
|
|
logger.Info("counting total size of job")
|
|
|
|
// select items to work on if they don't have an embedding yet,
|
|
// or if their embedding was created before the modifying job
|
|
// started, or if their embedding was created before the item
|
|
// was manually modified; note that we use <= for timestamp
|
|
// equality, because precision is only in seconds, and some
|
|
// import jobs can finish very quickly so they can generate
|
|
// embeddings the same second (TODO: maybe we should use milliseconds across the board)
|
|
const mostOfQuery = `
|
|
FROM items
|
|
LEFT JOIN embeddings ON embeddings.item_id = items.id
|
|
LEFT JOIN jobs ON jobs.id = items.modified_job_id
|
|
WHERE (items.job_id=? OR items.modified_job_id=?)
|
|
AND (items.data_type LIKE 'image/%' OR items.data_type LIKE 'text/%')
|
|
AND (items.data_id IS NOT NULL OR items.data_text IS NOT NULL OR items.data_file IS NOT NULL)
|
|
AND (embeddings.id IS NULL
|
|
OR embeddings.generated <= jobs.start/1000
|
|
OR embeddings.generated <= items.modified)`
|
|
|
|
var jobSize int
|
|
err := job.tl.db.ReadPool.QueryRowContext(job.ctx, `
|
|
SELECT count()
|
|
`+mostOfQuery,
|
|
ej.ItemsFromImportJob, ej.ItemsFromImportJob).Scan(&jobSize)
|
|
if err != nil {
|
|
return fmt.Errorf("failed counting size of job: %w", err)
|
|
}
|
|
job.SetTotal(jobSize)
|
|
|
|
if jobSize == 0 {
|
|
logger.Info("nothing to do", zap.Int("count", jobSize))
|
|
return nil
|
|
}
|
|
|
|
logger.Info("waiting until Python server is ready")
|
|
|
|
if !pythonServerReady(job.ctx, true) {
|
|
return errors.New("python server not ready")
|
|
}
|
|
|
|
logger.Info("generating embeddings for items from import job", zap.Int("count", jobSize))
|
|
|
|
var thisPageStart int64
|
|
lastItemID := chkpt.PageStart
|
|
|
|
for {
|
|
var pageResults []int64
|
|
|
|
const pageSize = 1000
|
|
|
|
thisPageStart = lastItemID
|
|
|
|
// select items from the configured import job that have a data type we can generate embeddings for,
|
|
// and which actually have data, and which do not yet have embeddings (OR the embedding is outdated
|
|
// and the item was updated more recently than the embedding), and which are on this page of results
|
|
rows, err := job.tl.db.ReadPool.QueryContext(job.ctx, `
|
|
SELECT items.id, items.stored, items.data_type
|
|
`+mostOfQuery+`
|
|
AND items.id > ?
|
|
ORDER BY items.id
|
|
LIMIT ?
|
|
`, ej.ItemsFromImportJob, ej.ItemsFromImportJob, lastItemID, pageSize)
|
|
if err != nil {
|
|
return fmt.Errorf("failed querying page of database table: %w", err)
|
|
}
|
|
var hadRow bool
|
|
for rows.Next() {
|
|
hadRow = true
|
|
|
|
var rowID, stored int64
|
|
var dataType *string
|
|
|
|
err := rows.Scan(&rowID, &stored, &dataType)
|
|
if err != nil {
|
|
defer rows.Close()
|
|
return fmt.Errorf("failed to scan row from database page: %w", err)
|
|
}
|
|
|
|
// Keep the last item ID as a fast way to offset for the next page
|
|
lastItemID = rowID
|
|
|
|
if qualifiesForEmbedding(dataType) {
|
|
pageResults = append(pageResults, rowID)
|
|
}
|
|
}
|
|
rows.Close()
|
|
if err = rows.Err(); err != nil {
|
|
return fmt.Errorf("iterating rows for researching embeddings failed: %w", err)
|
|
}
|
|
|
|
if !hadRow {
|
|
break // all done!
|
|
}
|
|
|
|
if len(pageResults) == 0 {
|
|
continue
|
|
}
|
|
|
|
if err := ej.processInBatches(job, pageResults, thisPageStart, chkpt.IndexOnPage); err != nil {
|
|
return fmt.Errorf("processing page of embedding tasks: %w", err)
|
|
}
|
|
|
|
// clear this so that we don't skip items on the next page after the first one, when resuming from a checkpoint!
|
|
chkpt.IndexOnPage = 0
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ej embeddingJob) processInBatches(job *ActiveJob, itemIDs []int64, firstItemIDOfPage int64, startIndexInPage int) error {
|
|
// run each task in a goroutine which we group by batch; this allows
|
|
// us to throttle concurrent goroutines, run tasks in parallel for speed,
|
|
// and checkpoint after the entire batch has finished, which allows for
|
|
// reliable resumption
|
|
var wg sync.WaitGroup
|
|
const batchSize = 10
|
|
|
|
// all goroutines must be done before we return
|
|
defer wg.Wait()
|
|
|
|
for i := startIndexInPage; i < len(itemIDs); i++ {
|
|
if err := job.Continue(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// At the end of every batch, wait for all the goroutines to complete before
|
|
// proceeding; and once they complete, that's a good time to checkpoint
|
|
if i%batchSize == batchSize-1 {
|
|
wg.Wait()
|
|
|
|
if err := job.Checkpoint(embeddingJobCheckpoint{
|
|
PageStart: firstItemIDOfPage,
|
|
IndexOnPage: i,
|
|
}); err != nil {
|
|
job.logger.Error("failed to save checkpoint",
|
|
zap.Int("position", i),
|
|
zap.Error(err))
|
|
}
|
|
}
|
|
|
|
itemID := itemIDs[i]
|
|
|
|
// proceed to spawn a new goroutine as part of this batch
|
|
wg.Add(1)
|
|
go func(job *ActiveJob, itemID int64) {
|
|
defer wg.Done()
|
|
|
|
err := ej.generateEmbeddingForItem(job.Context(), job, itemID)
|
|
if err != nil {
|
|
job.logger.Error("failed generating embedding",
|
|
zap.Int64("item_id", itemID),
|
|
zap.Error(err))
|
|
}
|
|
|
|
job.Progress(1)
|
|
}(job, itemID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ej embeddingJob) generateEmbeddingForItem(ctx context.Context, job *ActiveJob, itemID int64) error {
|
|
var data []byte
|
|
var dataFile, dataText, dataType, filename *string
|
|
|
|
// we query all the ways the content of the item could be stored: as a file (get its filename),
|
|
// text (read the text directly from the DB), and a separate data table (read the blob directly
|
|
// from the DB); only 1 of those should be non-nil, so we set the "data" variable to whichever
|
|
// one it is, to be sure we send it in for an embedding
|
|
err := job.tl.db.ReadPool.QueryRowContext(ctx,
|
|
`SELECT items.data_file, items.data_text, items.data_type, item_data.content
|
|
FROM items
|
|
LEFT JOIN item_data ON item_data.id = items.data_id
|
|
WHERE items.id=?
|
|
LIMIT 1`, itemID).Scan(&dataFile, &dataText, &dataType, &data)
|
|
if err != nil {
|
|
return fmt.Errorf("querying item for which to generate embedding: %w", err)
|
|
}
|
|
if dataType == nil {
|
|
return fmt.Errorf("item %d has no data type", itemID)
|
|
}
|
|
|
|
// convert data file path (if there is one) into a full filename
|
|
if dataFile != nil {
|
|
fn := job.tl.FullPath(*dataFile)
|
|
filename = &fn
|
|
}
|
|
// if the item is text content in the DB, set the data as it so it gets passed in for an embedding
|
|
if data == nil && dataText != nil {
|
|
data = []byte(*dataText)
|
|
}
|
|
|
|
v, err := generateSerializedEmbedding(ctx, *dataType, data, filename)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = job.tl.db.WritePool.ExecContext(ctx, "INSERT INTO embeddings (item_id, embedding) VALUES (?, ?)", itemID, v)
|
|
if err != nil {
|
|
return fmt.Errorf("storing embedding for item %d: %w", itemID, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func generateSerializedEmbedding(ctx context.Context, dataType string, data []byte, filename *string) ([]byte, error) {
|
|
embeddingJSON, err := generateEmbedding(ctx, dataType, data, filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var embedding []float32
|
|
err = json.Unmarshal(embeddingJSON, &embedding)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unmarshaling JSON embedding: %w", err)
|
|
}
|
|
v, err := sqlite_vec.SerializeFloat32(embedding)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("serializing embedding: %w", err)
|
|
}
|
|
return v, nil
|
|
}
|
|
|
|
func generateEmbedding(ctx context.Context, dataType string, data []byte, filename *string) ([]byte, error) {
|
|
if dataType == "" {
|
|
return nil, errors.New("content type is required")
|
|
}
|
|
|
|
endpoint := pyServerURL("/embedding")
|
|
|
|
var body io.Reader
|
|
if data != nil {
|
|
body = bytes.NewReader(data)
|
|
}
|
|
if filename != nil {
|
|
qs := make(url.Values)
|
|
qs.Set("filename", *filename)
|
|
endpoint += "?" + qs.Encode()
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "QUERY", endpoint, body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("making request to generate embedding: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", dataType)
|
|
|
|
// throttle expensive operation
|
|
cpuIntensiveThrottle <- struct{}{}
|
|
defer func() { <-cpuIntensiveThrottle }()
|
|
|
|
// check here in case the job was cancelled while we waited on the throttle
|
|
if err := ctx.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("performing embedding request to ML server: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
const maxSize = 1024 * 100
|
|
msg, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading error response from ML server, HTTP %d: %w", resp.StatusCode, err)
|
|
}
|
|
return nil, fmt.Errorf("got error status from ML server: HTTP %d (message='%s')", resp.StatusCode, msg)
|
|
}
|
|
|
|
const maxSize = 1024 * 100
|
|
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return respBody, nil
|
|
}
|
|
|
|
// TODO: endpoint currently works for images only
|
|
func classify(ctx context.Context, itemFiles map[uint64]string, labels []string) (map[uint64]float64, error) {
|
|
endpoint := pyServerURL("/classify")
|
|
|
|
jsonBytes, err := json.Marshal(itemFiles)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
body := bytes.NewReader(jsonBytes)
|
|
|
|
qs := make(url.Values)
|
|
for _, label := range labels {
|
|
qs.Add("labels", label)
|
|
}
|
|
endpoint += "?" + qs.Encode()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "QUERY", endpoint, body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("making request to classify: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("performing classification request to ML server: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
const maxSize = 1024 * 10
|
|
msg, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading error response from ML server, HTTP %d: %w", resp.StatusCode, err)
|
|
}
|
|
return nil, fmt.Errorf("got error status from ML server: HTTP %d (message='%s')", resp.StatusCode, msg)
|
|
}
|
|
|
|
var scores map[uint64]float64
|
|
err = json.NewDecoder(resp.Body).Decode(&scores)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decoding JSON response: %w", err)
|
|
}
|
|
|
|
return scores, nil
|
|
}
|
|
|
|
// qualifiesForEmbedding reports whether an item with the given MIME type
// can have an embedding generated for it; currently only image and text
// types qualify. A nil MIME type never qualifies.
func qualifiesForEmbedding(mimeType *string) bool {
	if mimeType == nil {
		return false
	}
	for _, prefix := range []string{"image/", "text/"} {
		if strings.HasPrefix(*mimeType, prefix) {
			return true
		}
	}
	return false
}
|
|
|
|
// embeddingJobCheckpoint records how far an embedding job has progressed
// so that it can be resumed after an interruption.
type embeddingJobCheckpoint struct {
	// PageStart is the item ID at which the page being processed
	// started (pages are keyed by item ID, not row offset).
	PageStart int64 `json:"page_start"`

	// IndexOnPage is the position within the page at which to resume.
	IndexOnPage int `json:"index_on_page"`
}
|
|
|
|
func pythonServerReady(ctx context.Context, wait bool) bool {
|
|
healthCheckURL := pyServerURL("/health-check")
|
|
|
|
for {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, healthCheckURL, nil)
|
|
if err != nil {
|
|
panic("could not construct health check request: " + err.Error())
|
|
}
|
|
|
|
resp, err := pythonServerHTTPClient.Do(req)
|
|
if err == nil && resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
|
resp.Body.Close()
|
|
return true
|
|
}
|
|
|
|
if !wait {
|
|
return false
|
|
}
|
|
|
|
const pause = 500 * time.Millisecond
|
|
select {
|
|
case <-time.After(pause):
|
|
case <-ctx.Done():
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
|
|
func pyServerURL(path string) string {
|
|
hostPort := net.JoinHostPort(PyHost, strconv.Itoa(PyPort))
|
|
return fmt.Sprintf("http://%s%s", hostPort, path)
|
|
}
|
|
|
|
// pythonServerHTTPClient is used for health checks against the Python
// server; the short timeout is appropriate because the server is on
// loopback and the check should answer quickly.
var pythonServerHTTPClient = &http.Client{
	Timeout: 2 * time.Second,
}
|