/*
	Timelinize
	Copyright (c) 2013 Matthew Holt

	This program is free software: you can redistribute it and/or modify
	it under the terms of the GNU Affero General Public License as published
	by the Free Software Foundation, either version 3 of the License, or
	(at your option) any later version.

	This program is distributed in the hope that it will be useful,
	but WITHOUT ANY WARRANTY; without even the implied warranty of
	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
	GNU Affero General Public License for more details.

	You should have received a copy of the GNU Affero General Public License
	along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

package timeline

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
	"go.uber.org/zap"
)

// Address where the local Python ML server (used for generating
// embeddings and classifications) listens; loopback-only.
const (
	PyHost = "127.0.0.1"
	PyPort = 12003
)

// embeddingJob represents an embedding job. These jobs only
// insert new embeddings, they don't update existing ones.
// If existing embeddings are no longer valid/relevant, they
// should be deleted prior to this.
type embeddingJob struct { // be sure not to include duplicates in this list ItemIDs []int64 `json:"item_ids,omitempty"` // infer items from the given import job ItemsFromImportJob uint64 `json:"items_from_import_job,omitempty"` } func (ej embeddingJob) Run(job *ActiveJob, checkpoint []byte) error { var chkpt embeddingJobCheckpoint if len(checkpoint) > 0 { if err := json.Unmarshal(checkpoint, &chkpt); err != nil { job.logger.Error("failed to resume from checkpoint", zap.Error(err)) } job.Logger().Info("resuming from checkpoint", zap.Int64("page_start", chkpt.PageStart), zap.Int("position", chkpt.IndexOnPage), zap.Int("configured_item_count", len(ej.ItemIDs)), zap.Uint64("items_from_import_job", ej.ItemsFromImportJob)) } // if the embeddings to generate were explicitly enumerated, simply do those if len(ej.ItemIDs) > 0 { job.Logger().Info("waiting until Python server is ready") if !pythonServerReady(job.ctx, true) { return errors.New("python server not ready") } job.Logger().Info("generating embeddings using predefined list", zap.Int("count", len(ej.ItemIDs))) return ej.processInBatches(job, ej.ItemIDs, 0, chkpt.IndexOnPage) } if ej.ItemsFromImportJob == 0 { return errors.New("no items; expecting either individual items listed, or an import job") } logger := job.Logger().With(zap.Uint64("import_job_id", ej.ItemsFromImportJob)) logger.Info("counting total size of job") // select items to work on if they don't have an embedding yet, // or if their embedding was created before the modifying job // started, or if their embedding was created before the item // was manually modified; note that we use <= for timestamp // equality, because precision is only in seconds, and some // import jobs can finish very quickly so they can generate // embeddings the same second (TODO: maybe we should use milliseconds across the board) const mostOfQuery = ` FROM items LEFT JOIN embeddings ON embeddings.item_id = items.id LEFT JOIN jobs ON jobs.id = items.modified_job_id WHERE 
(items.job_id=? OR items.modified_job_id=?) AND (items.data_type LIKE 'image/%' OR items.data_type LIKE 'text/%') AND (items.data_id IS NOT NULL OR items.data_text IS NOT NULL OR items.data_file IS NOT NULL) AND (embeddings.id IS NULL OR embeddings.generated <= jobs.start/1000 OR embeddings.generated <= items.modified)` var jobSize int err := job.tl.db.ReadPool.QueryRowContext(job.ctx, ` SELECT count() `+mostOfQuery, ej.ItemsFromImportJob, ej.ItemsFromImportJob).Scan(&jobSize) if err != nil { return fmt.Errorf("failed counting size of job: %w", err) } job.SetTotal(jobSize) if jobSize == 0 { logger.Info("nothing to do", zap.Int("count", jobSize)) return nil } logger.Info("waiting until Python server is ready") if !pythonServerReady(job.ctx, true) { return errors.New("python server not ready") } logger.Info("generating embeddings for items from import job", zap.Int("count", jobSize)) var thisPageStart int64 lastItemID := chkpt.PageStart for { var pageResults []int64 const pageSize = 1000 thisPageStart = lastItemID // select items from the configured import job that have a data type we can generate embeddings for, // and which actually have data, and which do not yet have embeddings (OR the embedding is outdated // and the item was updated more recently than the embedding), and which are on this page of results rows, err := job.tl.db.ReadPool.QueryContext(job.ctx, ` SELECT items.id, items.stored, items.data_type `+mostOfQuery+` AND items.id > ? ORDER BY items.id LIMIT ? 
`, ej.ItemsFromImportJob, ej.ItemsFromImportJob, lastItemID, pageSize) if err != nil { return fmt.Errorf("failed querying page of database table: %w", err) } var hadRow bool for rows.Next() { hadRow = true var rowID, stored int64 var dataType *string err := rows.Scan(&rowID, &stored, &dataType) if err != nil { defer rows.Close() return fmt.Errorf("failed to scan row from database page: %w", err) } // Keep the last item ID as a fast way to offset for the next page lastItemID = rowID if qualifiesForEmbedding(dataType) { pageResults = append(pageResults, rowID) } } rows.Close() if err = rows.Err(); err != nil { return fmt.Errorf("iterating rows for researching embeddings failed: %w", err) } if !hadRow { break // all done! } if len(pageResults) == 0 { continue } if err := ej.processInBatches(job, pageResults, thisPageStart, chkpt.IndexOnPage); err != nil { return fmt.Errorf("processing page of embedding tasks: %w", err) } // clear this so that we don't skip items on the next page after the first one, when resuming from a checkpoint! 
chkpt.IndexOnPage = 0 } return nil } func (ej embeddingJob) processInBatches(job *ActiveJob, itemIDs []int64, firstItemIDOfPage int64, startIndexInPage int) error { // run each task in a goroutine which we group by batch; this allows // us to throttle concurrent goroutines, run tasks in parallel for speed, // and checkpoint after the entire batch has finished, which allows for // reliable resumption var wg sync.WaitGroup const batchSize = 10 // all goroutines must be done before we return defer wg.Wait() for i := startIndexInPage; i < len(itemIDs); i++ { if err := job.Continue(); err != nil { return err } // At the end of every batch, wait for all the goroutines to complete before // proceeding; and once they complete, that's a good time to checkpoint if i%batchSize == batchSize-1 { wg.Wait() if err := job.Checkpoint(embeddingJobCheckpoint{ PageStart: firstItemIDOfPage, IndexOnPage: i, }); err != nil { job.logger.Error("failed to save checkpoint", zap.Int("position", i), zap.Error(err)) } } itemID := itemIDs[i] // proceed to spawn a new goroutine as part of this batch wg.Add(1) go func(job *ActiveJob, itemID int64) { defer wg.Done() err := ej.generateEmbeddingForItem(job.Context(), job, itemID) if err != nil { job.logger.Error("failed generating embedding", zap.Int64("item_id", itemID), zap.Error(err)) } job.Progress(1) }(job, itemID) } return nil } func (ej embeddingJob) generateEmbeddingForItem(ctx context.Context, job *ActiveJob, itemID int64) error { var data []byte var dataFile, dataText, dataType, filename *string // we query all the ways the content of the item could be stored: as a file (get its filename), // text (read the text directly from the DB), and a separate data table (read the blob directly // from the DB); only 1 of those should be non-nil, so we set the "data" variable to whichever // one it is, to be sure we send it in for an embedding err := job.tl.db.ReadPool.QueryRowContext(ctx, `SELECT items.data_file, items.data_text, items.data_type, 
item_data.content FROM items LEFT JOIN item_data ON item_data.id = items.data_id WHERE items.id=? LIMIT 1`, itemID).Scan(&dataFile, &dataText, &dataType, &data) if err != nil { return fmt.Errorf("querying item for which to generate embedding: %w", err) } if dataType == nil { return fmt.Errorf("item %d has no data type", itemID) } // convert data file path (if there is one) into a full filename if dataFile != nil { fn := job.tl.FullPath(*dataFile) filename = &fn } // if the item is text content in the DB, set the data as it so it gets passed in for an embedding if data == nil && dataText != nil { data = []byte(*dataText) } v, err := generateSerializedEmbedding(ctx, *dataType, data, filename) if err != nil { return err } _, err = job.tl.db.WritePool.ExecContext(ctx, "INSERT INTO embeddings (item_id, embedding) VALUES (?, ?)", itemID, v) if err != nil { return fmt.Errorf("storing embedding for item %d: %w", itemID, err) } return nil } func generateSerializedEmbedding(ctx context.Context, dataType string, data []byte, filename *string) ([]byte, error) { embeddingJSON, err := generateEmbedding(ctx, dataType, data, filename) if err != nil { return nil, err } var embedding []float32 err = json.Unmarshal(embeddingJSON, &embedding) if err != nil { return nil, fmt.Errorf("unmarshaling JSON embedding: %w", err) } v, err := sqlite_vec.SerializeFloat32(embedding) if err != nil { return nil, fmt.Errorf("serializing embedding: %w", err) } return v, nil } func generateEmbedding(ctx context.Context, dataType string, data []byte, filename *string) ([]byte, error) { if dataType == "" { return nil, errors.New("content type is required") } endpoint := pyServerURL("/embedding") var body io.Reader if data != nil { body = bytes.NewReader(data) } if filename != nil { qs := make(url.Values) qs.Set("filename", *filename) endpoint += "?" 
+ qs.Encode() } req, err := http.NewRequestWithContext(ctx, "QUERY", endpoint, body) if err != nil { return nil, fmt.Errorf("making request to generate embedding: %w", err) } req.Header.Set("Content-Type", dataType) // throttle expensive operation cpuIntensiveThrottle <- struct{}{} defer func() { <-cpuIntensiveThrottle }() // check here in case the job was cancelled while we waited on the throttle if err := ctx.Err(); err != nil { return nil, err } resp, err := http.DefaultClient.Do(req) if err != nil { return nil, fmt.Errorf("performing embedding request to ML server: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { const maxSize = 1024 * 100 msg, err := io.ReadAll(io.LimitReader(resp.Body, maxSize)) if err != nil { return nil, fmt.Errorf("error reading error response from ML server, HTTP %d: %w", resp.StatusCode, err) } return nil, fmt.Errorf("got error status from ML server: HTTP %d (message='%s')", resp.StatusCode, msg) } const maxSize = 1024 * 100 respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxSize)) if err != nil { return nil, err } return respBody, nil } // TODO: endpoint currently works for images only func classify(ctx context.Context, itemFiles map[uint64]string, labels []string) (map[uint64]float64, error) { endpoint := pyServerURL("/classify") jsonBytes, err := json.Marshal(itemFiles) if err != nil { return nil, err } body := bytes.NewReader(jsonBytes) qs := make(url.Values) for _, label := range labels { qs.Add("labels", label) } endpoint += "?" 
+ qs.Encode() req, err := http.NewRequestWithContext(ctx, "QUERY", endpoint, body) if err != nil { return nil, fmt.Errorf("making request to classify: %w", err) } req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { return nil, fmt.Errorf("performing classification request to ML server: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { const maxSize = 1024 * 10 msg, err := io.ReadAll(io.LimitReader(resp.Body, maxSize)) if err != nil { return nil, fmt.Errorf("error reading error response from ML server, HTTP %d: %w", resp.StatusCode, err) } return nil, fmt.Errorf("got error status from ML server: HTTP %d (message='%s')", resp.StatusCode, msg) } var scores map[uint64]float64 err = json.NewDecoder(resp.Body).Decode(&scores) if err != nil { return nil, fmt.Errorf("decoding JSON response: %w", err) } return scores, nil } func qualifiesForEmbedding(mimeType *string) bool { return mimeType != nil && (strings.HasPrefix(*mimeType, "image/") || strings.HasPrefix(*mimeType, "text/")) } type embeddingJobCheckpoint struct { PageStart int64 `json:"page_start"` IndexOnPage int `json:"index_on_page"` } func pythonServerReady(ctx context.Context, wait bool) bool { healthCheckURL := pyServerURL("/health-check") for { req, err := http.NewRequestWithContext(ctx, http.MethodGet, healthCheckURL, nil) if err != nil { panic("could not construct health check request: " + err.Error()) } resp, err := pythonServerHTTPClient.Do(req) if err == nil && resp.StatusCode >= 200 && resp.StatusCode < 300 { resp.Body.Close() return true } if !wait { return false } const pause = 500 * time.Millisecond select { case <-time.After(pause): case <-ctx.Done(): return false } } } func pyServerURL(path string) string { hostPort := net.JoinHostPort(PyHost, strconv.Itoa(PyPort)) return fmt.Sprintf("http://%s%s", hostPort, path) } var pythonServerHTTPClient = &http.Client{ Timeout: 2 * time.Second, }