前言

官方demo : https://github.com/microsoft/onnxruntime-inference-examples/blob/main/js/segment-anything/main.js

  • 版本

    • "@tensorflow/tfjs": "^4.17.0",
    • "onnxruntime-web": "^1.17.0",
  • 配置

    • 
      
      /**
       * Create service config by current env
       *
       * @param env The current env
       */
      export function createModelMap(env: Env.ImportMeta) {
        // const mockURL = 'https://mock.apifox.com/m1/3109515-0-default';
      
        const modelType: App.Model.ModelConfigMap = {
          dev: {
            wasm: {
              'ort-wasm.wasm': '/files/ort-wasm.wasm',
              'ort-wasm-simd.wasm': '/files/ort-wasm-simd.wasm',
              'ort-wasm-threaded.wasm': '/files/ort-wasm-threaded.wasm',
              'ort-wasm-simd-threaded.wasm': '/files/ort-wasm-simd-threaded.wasm',
            },
            onnx: {
              sam_b: ['/models/sam_vit_b_01ec64.encoder.onnx', "/models/sam_vit_b_01ec64.decoder.onnx"],
              mobile_sam: ['/models/mobile_sam.encoder.onnx', "/models/mobile_sam.decoder.onnx"]
            }
          },
          test: {
            wasm: {
              'ort-wasm.wasm': '/files/ort-wasm.wasm',
              'ort-wasm-simd.wasm': '/files/ort-wasm-simd.wasm',
              'ort-wasm-threaded.wasm': '/files/ort-wasm-threaded.wasm',
              'ort-wasm-simd-threaded.wasm': '/files/ort-wasm-simd-threaded.wasm',
            },
            onnx: {
              sam_b: ['/models/sam_vit_b_01ec64.encoder.onnx', "/models/sam_vit_b_01ec64.decoder.onnx"],
              mobile_sam: ['/models/mobile_sam.encoder.onnx', "/models/mobile_sam.decoder.onnx"]
            }
          },
          prod: {
            wasm: {
              'ort-wasm.wasm': '/files/ort-wasm.wasm',
              'ort-wasm-simd.wasm': '/files/ort-wasm-simd.wasm',
              'ort-wasm-threaded.wasm': '/files/ort-wasm-threaded.wasm',
              'ort-wasm-simd-threaded.wasm': '/files/ort-wasm-simd-threaded.wasm',
            },
            onnx: {
              sam_b: ['/models/sam_vit_b_01ec64.encoder.onnx', "/models/sam_vit_b_01ec64.decoder.onnx"],
              mobile_sam: ['/models/mobile_sam.encoder.onnx', "/models/mobile_sam.decoder.onnx"]
            }
          },
        };
      
        const { VITE_SERVICE_ENV = 'dev' } = env;
      
        return modelType[VITE_SERVICE_ENV];
      }
      
    • onnx配置

      ort.env.wasm.numThreads = data.threads;
      ort.env.wasm.simd = true;
      // ort.env.wasm.proxy = true;
      ort.env.wasm.wasmPaths = globalModelMap.wasm as ort.Env.WasmPrefixOrFilePaths;
      
    • 模型大小设置。这里需要注意的是onnx转换后,模型图像的宽高固定为1024*684,如果图像什么出现问题,需要做一下转换和计算。

封装库

  • segment.ts

import * as ort from 'onnxruntime-web';
import * as tf from '@tensorflow/tfjs'
import { ref, reactive, onMounted, onUnmounted } from 'vue'
import { captureFrame, toBase64 } from '@/utils/videoFrame'
import { BlockItem, ModelMapType, ConfigType, UploadType, FrameItem, BlockItemType, PointOperaType, SegmentOperaType } from "./types/segment.d.js"
import dayjs from 'dayjs';
import ColorMaker from '@/utils/colorUtils.js';
import { createModelMap } from '~/env.config';

const MODEL_WIDTH = 1024;
const MODEL_HEIGHT = 684;

const globalModelMap = createModelMap(import.meta.env)
const MODEL_MAP: ModelMapType = globalModelMap.onnx
const modelSess: ort.InferenceSession[] = []

function getConfig(): ConfigType {
    const data: ConfigType = {
        model: "mobile_sam",
        provider: { name: "wasm" },//wasm,webgpu,webnn
        device: "cpu",
        threads: 4,
        clear_cache: false
    };
    ort.env.wasm.numThreads = data.threads;
    return data;
}


ort.env.wasm.simd = true;
// ort.env.wasm.proxy = true;
ort.env.wasm.wasmPaths = globalModelMap.wasm as ort.Env.WasmPrefixOrFilePaths;

const config = getConfig();

/**
 * fetch and cache url
 */
async function fetchAndCache(url: string) {
    try {
        const cache = await caches.open("onnx");
        if (config.clear_cache) {
            cache.delete(url);
        }
        let cachedResponse = await cache.match(url);
        if (cachedResponse == undefined) {
            await cache.add(url);
            cachedResponse = await cache.match(url);
            // console.log(`${url} (from network)`);
        } else {
            // console.log(`${url} (from cache)`);
        }
        const data = await cachedResponse?.arrayBuffer();
        return data
    } catch (error) {
        // console.log(`${url} (from network)`);
        return await fetch(url).then(response => {
            //   {
            //   mode: "cors",
            //   headers: {
            //     "Cross-Origin-Embedder-Policy": "require-corp",
            //     "Cross-Origin-Opener-Policy": "same-origin"
            //   },
            //   referrerPolicy: "same-origin",
            //   credentials: "same-origin"
            // }
            // response.headers.set("Cross-Origin-Embedder-Policy", "require-corp");
            // response.headers.set("Cross-Origin-Opener-Policy", "same-origin");
            return response.arrayBuffer()
        });
    }
}

function loadModel(model: string[], idx: number) {
    const provider = config.provider

    switch (config.provider.name) {
        case "webnn":
            if (!("ml" in navigator)) {
                throw new Error("webnn is NOT supported");
            }
            provider.deviceType = config.device;
            provider.powerPreference = 'default'
            break;
        case "webgpu":
            if (!navigator.gpu) {
                throw new Error("webgpu is NOT supported");
            }
            break;
    }

    const opt = { executionProviders: [provider] };

    fetchAndCache(model[idx]).then(async (data: any) => {
        // sess[idx] = ort.InferenceSession.create('./models/sam_vit_b_01ec64.encoder.onnx');
        // const u8data = new Uint8Array(data);
        modelSess[idx] = await ort.InferenceSession.create(data, opt);
        // console.log(`${model[idx]} loaded.`);
        // sess[idx].then(() => {
        //   console.log(`${model[idx]} loaded.`);
        if (idx == 0) {
            loadModel(model, 1);
        }
        // }, (e) => {
        //   console.log(`${model[idx]} failed with ${e}.`);
        //   throw e;

        // });
    })
}


export default function useSegment() {

    onMounted(() => {
        const model: string[] = MODEL_MAP[config.model as keyof ModelMapType];
        loadModel(model, 0);
    })
    onUnmounted(() => {
        modelSess.forEach((item) => {
            if (item != null) {
                item.release();
            }
        })
    })
    const frameList = reactive<FrameItem[]>([]);
    const frameIndex = ref(0);
    const blockList = reactive<BlockItem[]>([]);
    const blockIndex = ref(0);
    const videoUploadRef = ref();
    const videoUploadLoading = ref(false);
    const fileBase64 = ref('');
    const frameBase64 = ref('');
    const canvasRef = ref();
    const previewRef = ref();
    let colorIdx = 1, MAX_WIDTH: number = 1000, MAX_HEIGHT: number = 360
    const _useBlocks = useBlocks()
    const _useModel = useModel()
    const _useMediaUpload = useMediaUpload()
    const _useFrames = useFrames();

    // 当前帧的图像数据
    let imageImageData: ImageData | undefined, imageEmbeddings: ort.Tensor;

    function useFrames() {
        function frameAddHandle(addIndex: boolean = true, obj: { sec: number, type: number, base64: string, imageData?: ImageData, imageEmbeddings?: ort.Tensor }) {
            frameList.push({
                id: String(dayjs().valueOf()) + Math.random() * (100 - 1) + 1,
                type: obj.type,
                sec: obj.sec,
                base64: obj.base64,
                imageData: obj.imageData,
                imageEmbeddings: obj.imageEmbeddings
            });
            if (addIndex) { frameIndex.value += 1; }
        }
        function frameDelHandle() {

        }
        return {
            frameAddHandle,
            frameDelHandle
        }
    }

    function useBlocks() {
        const blockShowAll = ref(false); // true 展示全部,false 展示一个
        function blockAddHandle(obj: { type?: BlockItemType }, addIndex: boolean = true) {
            blockShowAll.value = false;
            blockList.push({
                id: String(dayjs().valueOf()) + Math.random() * (100 - 1) + 1,
                type: obj.type || 'mask',
                hasPoint: false,
                color: ColorMaker.getRgbColorArr(colorIdx),
                pointModel: [],
                pointXY: [],
                resource: '',
                resourceType: '',
                imgBase64: '',
                imgData: undefined,
                result: null,
                zIndex: 100 - blockIndex.value,
                segmentImage: '',
                segmentImageBase64: '',
                segmentImageBitmap: null
            });
            colorIdx++;
            if (addIndex) { blockIndex.value += 1; blockShowHandle(blockIndex.value); }

        }
        function blockDelHandle(index: number) {
            if (blockList.length === 1) {
                window.$message?.warning('请保留一层蒙版');
            } else {
                blockList.splice(index, 1);
                if (blockIndex.value >= blockList.length) {
                    blockIndex.value = blockList.length - 1;
                }
            }
            blockShowHandle();
        }

        function blockShowHandle(index: number = -1) {
            const canvas = canvasRef.value
            const ctx = canvas.getContext('2d');
            ctx.globalAlpha = 0.5;
            ctx.clearRect(0, 0, canvas.width, canvas.height);


            // todo: 等后面换成多帧后,需要从对应获取
            ctx.putImageData(imageImageData, 0, 0, 0, 0, canvas.width, canvas.height);
            if (index >= 0) {
                blockIndex.value = index;
                blockShowAll.value = false;
                const block = blockList[index];
                ctx.drawImage(block.segmentImageBitmap, 0, 0);
            } else {
                blockShowAll.value = true;
                for (let i = blockList.length - 1; i >= 0; i--) {
                    const block = blockList[i];
                    ctx.drawImage(block.segmentImageBitmap, 0, 0);
                }
            }
        }

        return {
            blockList,
            blockIndex,
            blockShowAll,
            blockAddHandle,
            blockDelHandle,
            blockShowHandle,
        }
    }

    function useMediaUpload() {
        const uploadType: UploadType = {
            type: 1,
            blockIndex: 0,
        };

        async function uploadHandle(type: number, bIndex: number = 0) {
            const vur = videoUploadRef.value;
            uploadType.type = type;
            uploadType.blockIndex = bIndex;
            switch (type) {
                case 1:
                    vur.accept = 'video/*'
                    break;
                case 2:
                    vur.accept = 'video/*,image/*'
                    break;
                case 3:
                    vur.accept = 'video/*'
            }
            vur.onchange = uploadVideo;
            vur.click();
        }

        async function uploadVideo() {
            if (uploadType.type === 1) {
                uploadMedia((img) => {
                    _useBlocks.blockAddHandle({ type: "mask" }, false);
                    _useModel.imageHandle(img, 'frame')
                })
            } else if (uploadType.type === 2) {
                uploadResource(uploadType.blockIndex)
            } else if (uploadType.type === 3) {
                uploadMedia((img) => {
                    // _useBlocks.blockAddHandle(false);
                    // _useModel.imageHandle(img)
                    _useModel.imageRenderHandle(img);
                })
            }
        }
        async function uploadMedia(callback: (img: HTMLImageElement) => void) {
            const vur = videoUploadRef.value;
            const file: File = vur.files[0];
            if (file) {
                frameList.length = 0;
                frameIndex.value = 0
                blockList.length = 0;
                blockIndex.value = 0;
                videoUploadLoading.value = true
                const img = new Image()
                if (file.type.startsWith("image/")) {
                    const base64 = await toBase64(file)
                    img.src = base64
                    fileBase64.value = img.src
                    frameBase64.value = img.src
                } else {
                    const info = await captureFrame(file, 1);
                    img.src = await toBase64(info.blob as Blob)
                    fileBase64.value = await toBase64(file);
                    frameBase64.value = img.src
                }
                _useFrames.frameAddHandle(false, {
                    sec: 1,
                    type: uploadType.type,
                    base64: frameBase64.value,
                })
                videoUploadLoading.value = false
                if (img.complete) {
                    if (callback != null) {
                        callback(img)
                    }
                } else {
                    img.onload = () => {
                        if (callback != null) {
                            callback(img)
                        }
                    }
                }
                vur.value = '';
            }
        }

        async function uploadResource(index: number) {
            if (index > blockList.length || index < 0) {
                return
            }
            const vur = videoUploadRef.value;
            const file: File = vur.files[0];
            if (file) {
                const item = blockList[index];
                item.resource = await toBase64(file);
                if (file.type.startsWith('image/')) {
                    item.resourceType = 'image';
                } else if (file.type.startsWith('video/')) {
                    item.resourceType = 'video';
                }
                vur.value = '';
            }
        }
        return {
            videoUploadRef,
            videoUploadLoading,
            fileBase64,
            frameBase64,
            uploadMedia,
            uploadResource,
            uploadHandle
        }
    }

    function useModel() {
        let pointOpera: PointOperaType = 'plus';
        let segmentOpera: SegmentOperaType = 'mask'

        async function imageRenderHandle(img: HTMLImageElement) {
            let renderLoading = videoUploadLoading.value ? 1 : 0;
            if (renderLoading === 0) {
                videoUploadLoading.value = true
            }
            let width = img.width;
            let height = img.height;

            if (width > MAX_WIDTH) {
                height = height * (MAX_WIDTH / width);
                width = MAX_WIDTH;
            }
            // } else {
            if (height > MAX_HEIGHT) {
                width = width * (MAX_HEIGHT / height);
                height = MAX_HEIGHT;
            }
            // }

            width = Math.round(width);
            height = Math.round(height);


            const canvas: HTMLCanvasElement = canvasRef.value;
            canvas.width = width;
            canvas.height = height;
            let context = canvas.getContext('2d');
            context?.drawImage(img, 0, 0, width, height);
            if (renderLoading === 0) {
                videoUploadLoading.value = false
            }
        }

        async function imageHandle(img: HTMLImageElement, type: "block" | "frame") {
            imageRenderHandle(img);
            videoUploadLoading.value = true;

            const frame = frameList[frameIndex.value]

            const block = blockList[blockIndex.value];
            if (block.imgBase64 === '') {
                block.imgBase64 = frameBase64.value;
            }

            const canvas: HTMLCanvasElement = canvasRef.value;
            let width = canvas.width
            let height = canvas.height
            let context = canvas.getContext('2d');

            setTimeout(async () => {
                imageImageData = context?.getImageData(0, 0, width, height);
                if (!frame.imageData) {
                    frame.imageData = imageImageData
                }

                let resizeTensor = await ort.Tensor.fromImage(imageImageData as unknown as ImageBitmap, { resizedWidth: MODEL_WIDTH, resizedHeight: MODEL_HEIGHT })
                let tfTensor = tf.tensor(resizeTensor.data, [...resizeTensor.dims]);
                tfTensor = tfTensor.reshape([3, MODEL_HEIGHT, MODEL_WIDTH]);
                tfTensor = tfTensor.transpose([1, 2, 0]).mul(255)
                const imageDataTensor = new ort.Tensor(tfTensor.dataSync(), tfTensor.shape);
                let start = Date.now();
                const feed = { "input_image": imageDataTensor };
                const res = await modelSess[0].run(feed);
                let end = Date.now();
                let time_taken = (end - start) / 1000;
                console.log(`Computing image embedding took ${time_taken} seconds`);

                imageEmbeddings = res.image_embeddings;
                if (type === "block") {
                    block.imageEmbeddings = res.image_embeddings;
                } else if (type === 'frame') {
                    frame.imageEmbeddings = res.image_embeddings;
                }
                videoUploadLoading.value = false;
            }, 500)
        }


        async function clickHandle(e: MouseEvent, label: number, fn?: (imgB64: string) => void) {
            // const originWidth = imageFrameRef.value.naturalWidth;
            // const offseWidth = imageFrameRef.value.offsetWidth;
            // const xscale = offseWidth / originWidth;

            // const originHeight = imageFrameRef.value.naturalHeight;
            // const offsetHeight = imageFrameRef.value.offsetHeight;
            // const yscale = offsetHeight / originHeight;

            // 获取x,y坐标
            // const x = Math.floor(e.offsetX / xscale);
            // const y = Math.floor(e.offsetY / yscale);

            // console.log(image_embeddings)
            // if (image_embeddings === undefined) {
            //   await sess[0];
            // }
            // const emb = await image_embeddings;

            const canvas = canvasRef.value
            const ctx = canvas.getContext('2d');

            const rect = canvas.getBoundingClientRect();
            const x = Math.floor(e.clientX - rect.left);
            const y = Math.floor(e.clientY - rect.top);
            console.log(x, y)

            const frame = frameList[frameIndex.value]
            const block = blockList[blockIndex.value];

            if (pointOpera === 'plus') {
                block.pointModel.push(label);
                block.pointXY.push(x, y);
            } else {
                block.pointModel = [label];
                block.pointXY = [x, y];
            }
            // canvas.width = imageImageData.width;
            // canvas.height = imageImageData.height;
            // console.log(canvas.width, canvas.height);

            ctx.clearRect(0, 0, canvas.width, canvas.height);

            ctx.putImageData(imageImageData, 0, 0, 0, 0, canvas.width, canvas.height);
            // ctx.fillStyle = 'blue';
            // ctx.fillRect(x, y, 10, 10);

            // const labels = [label]
            const labels = block.pointModel;
            const points = block.pointXY;

            const pointCoords = new ort.Tensor(new Float32Array(points), [1, points.length / 2, 2]);
            const pointLabels = new ort.Tensor(new Float32Array(labels), [1, labels.length]);
            const maskInput = new ort.Tensor(new Float32Array(256 * 256), [1, 1, 256, 256]);
            const hasMask = new ort.Tensor(new Float32Array([0]), [1,]);
            const origianlImageSize = new ort.Tensor(new Float32Array([MODEL_HEIGHT, MODEL_WIDTH]), [2,]);

            // const t = new ort.Tensor(emb.image_embeddings.type, Float32Array.from(emb.image_embeddings.data), emb.image_embeddings.dims);
            // console.log("t", t)
            // const t = new ort.Tensor(image_embeddings.type, Float32Array.from(image_embeddings.data), image_embeddings.dims);

            const feed = {
                "image_embeddings": imageEmbeddings,
                "point_coords": pointCoords,
                "point_labels": pointLabels,
                "mask_input": maskInput,
                "has_mask_input": hasMask,
                "orig_im_size": origianlImageSize
            }
            console.log("feed", feed)
            // const start = Date.now();
            try {
                const results = await modelSess[1].run(feed);
                console.log("Generated mask:", results);
                const mask = results.masks;

                const maskImageData = mask.toImageData();

                segmentTypeHandle(maskImageData, block)

                block.hasPoint = true;
                if (fn) {
                    fn(block.segmentImageBase64);
                }
            } catch (error) {
                console.log(`caught error: ${error}`)
            }
            // const end = Date.now();
            // console.log(`generating masks took ${(end - start) / 1000} seconds`);
        }

        function cropImage(imagedata: ImageData) {
            let left = -1, top = -1, right = -1, bottom = -1;
            const idata = imagedata.data
            const count = idata.length;
            const countWidth = imagedata.width * 4;
            console.log(count)
            for (let i = 0, y = 0; i < count; i = i + 4) {
                if (i != 0 && i % countWidth == 0) {
                    y++;
                }
                // console.log([idata[i], idata[i + 1], idata[i + 2], idata[i + 3]])
                if (
                    idata[i] !== 255
                    || idata[i + 1] !== 255
                    || idata[i + 2] !== 255
                    || idata[i + 3] !== 255
                ) {
                    const x = (i - countWidth * y) / 4;
                    if (left === -1) {
                        left = x
                    } else {
                        if (left > x) {
                            left = x;
                        }
                        if (right === -1) {
                            right = x
                        } else if (right < x) {
                            right = x
                        }
                    }
                    if (top === -1) {
                        top = y;
                    } else {
                        if (bottom === -1) {
                            bottom = y
                        } else if (bottom < y) {
                            bottom = y
                        }
                    }
                }
            }
            return [left, top, right - left, bottom - top]
        }

        async function segmentTypeHandle(maskImageData: ImageData, block: BlockItem) {
            const idata = maskImageData.data
            const count = idata.length;
            if (segmentOpera === 'mask') {
                for (let i = 0; i < count; i = i + 4) {
                    // console.log([idata[i], idata[i + 1], idata[i + 2], idata[i + 3]])
                    if (idata[i] > 0) {
                        idata[i] = block.color[0];
                        idata[i + 1] = block.color[1]
                        idata[i + 2] = block.color[2]
                        idata[i + 3] = 178; // 透明度
                    } else {
                        idata[i + 3] = 0
                    }
                }
            } else {
                for (let i = 0; i < count; i = i + 4) {
                    // console.log([idata[i], idata[i + 1], idata[i + 2], idata[i + 3]])
                    if (idata[i] > 0) {
                        idata[i + 3] = 0; // 透明度
                    } else {
                        idata[i] = 255
                        idata[i + 1] = 255
                        idata[i + 2] = 255
                        idata[i + 3] = 255
                    }
                }
            }

            const canvas = canvasRef.value
            const ctx = canvas.getContext('2d');

            // ctx.globalAlpha = 0.5;
            // convert image data to image bitmap
            // let imageBitmap = await createImageBitmap(maskImageData, { resizeWidth: canvas.width, resizeHeight: canvas.height, resizeQuality: "medium" });
            let imageBitmap = await createImageBitmap(maskImageData);
            // canvas.width = maskImageData.width;
            // canvas.height = maskImageData.height;
            ctx.drawImage(imageBitmap, 0, 0);

            if (segmentOpera === 'crop') {
                const locate = cropImage(ctx.getImageData(0, 0, canvas.width, canvas.height));
                const imagedata = ctx.getImageData(locate[0], locate[1], locate[2], locate[3])
                imageBitmap = await createImageBitmap(imagedata);

                const tmpCanvas = document.createElement('canvas');
                const tmpCtx = tmpCanvas.getContext('2d');
                // 设置canvas的尺寸并将复制的ImageData绘制到canvas上
                tmpCanvas.width = locate[2]
                tmpCanvas.height = locate[3]
                tmpCtx?.drawImage(imageBitmap, 0, 0);
                block.segmentImageBase64 = tmpCanvas.toDataURL()

                ctx.clearRect(0, 0, canvas.width, canvas.height);
                ctx.putImageData(imageImageData, 0, 0, 0, 0, canvas.width, canvas.height);
                previewHandle(imagedata)
            } else {
                const tmpCanvas = document.createElement('canvas');
                const tmpCtx = tmpCanvas.getContext('2d');
                // 设置canvas的尺寸并将复制的ImageData绘制到canvas上
                tmpCanvas.width = canvas.width;
                tmpCanvas.height = canvas.height;
                tmpCtx?.drawImage(imageBitmap, 0, 0);
                block.segmentImageBase64 = tmpCanvas.toDataURL()
            }
            block.segmentImage = canvas.toDataURL();
            block.segmentImageBitmap = imageBitmap;
        }

        function previewHandle(imagedata: ImageData) {
            if (previewRef.value) {
                const tagName = previewRef.value.tagName
                if (tagName === 'CANVAS') {
                    const canvas = previewRef.value
                    const ctx = canvas.getContext('2d')
                    canvas.width = imagedata.width;
                    canvas.height = imagedata.height;
                    ctx.putImageData(imagedata, 0, 0, 0, 0, imagedata.width, imagedata.height);
                } else if (tagName === 'IMG') {
                    const tmpCanvas = document.createElement('canvas');
                    const tmpCtx = tmpCanvas.getContext('2d');
                    // 设置canvas的尺寸并将复制的ImageData绘制到canvas上
                    tmpCanvas.width = imagedata.width;
                    tmpCanvas.height = imagedata.height;
                    tmpCtx?.putImageData(imagedata, 0, 0, 0, 0, imagedata.width, imagedata.height);
                    previewRef.value.src = tmpCanvas.toDataURL()
                }
            }
        }

        function setMaxSize(width: number, height: number) {
            MAX_WIDTH = width;
            MAX_HEIGHT = height
        }

        function setSegmentOpera(tempSegmentOpera: SegmentOperaType) {
            segmentOpera = tempSegmentOpera
        }

        function setPointOpera(type: PointOperaType) {
            pointOpera = type
        }

        return {
            canvasRef,
            previewRef,
            imageHandle,
            imageRenderHandle,
            clickHandle,
            setPointOpera,
            setSegmentOpera,
            setMaxSize,
        }
    }

    return {
        _useBlocks,
        _useModel,
        _useMediaUpload,
        _useFrames,
    }
}

  • segment.d.ts

import * as ort from 'onnxruntime-web';
type xy = {
    x: number;
    y: number;
};

export type BlockItemType = "mask" | "panel"

export interface BlockItem {
    id: string;
    type: BlockItemType,// mask遮罩,panel切图
    hasPoint: boolean; // 是否操作
    color: number[]; // 颜色 rgba
    pointModel: number[]; // 操作模式
    pointXY: number[]; // 操作点
    resource: string;
    resourceType: string;
    imgBase64: string; // 图片base64
    imgData: ImageData | undefined;
    result: Api.AiTools.SegmentFrame | null;
    zIndex: number;
    segmentImage: string;// 图片遮罩
    segmentImageBase64: string;// 图片遮罩base64
    segmentImageBitmap: ImageBitmap | null;// 图片遮罩位图
    imageEmbeddings?: ort.Tensor | undefined;
}
export interface ModelMapType {
    sam_b: string[];
    mobile_sam: string[];
}

export interface ProviderType extends ort.InferenceSession.ExecutionProviderOption {
    deviceType?: string;
    powerPreference?: string;
}

export interface ConfigType {
    model: string;
    provider: ProviderType;
    device: string;
    threads: number;
    clear_cache: boolean;
}

export interface UploadType {
    type: number; // 1是主视频,2是每一帧的蒙版视频,3是普通图片
    blockIndex: number;
}

export interface FrameItem {
    id: string;
    type: number; // UploadType.type
    sec: number;// 视频第几秒
    base64: string;// 图片BASE64
    imageData: ImageData | undefined;
    imageEmbeddings: ort.Tensor | undefined;
}

export type PointOperaType = 'plus' | 'cover';// 追加/覆盖

export type SegmentOperaType = 'mask' | 'crop';// 遮罩/截取
  • 使用
<script setup lang="ts">
import { ref } from 'vue';
import useSegment from "@/hooks/ai/segment"
import { $t } from '@/locales';

const segment = useSegment();
const {
  videoUploadRef, videoUploadLoading, fileBase64, frameBase64, uploadHandle } = segment._useMediaUpload
const { blockList, blockIndex, blockAddHandle, blockDelHandle, blockShowHandle, } = segment._useBlocks;
const { canvasRef, clickHandle } = segment._useModel

const MediaLabelRadioRef = ref<InstanceType<typeof MediaLabelRadio>>()

async function imageClickHandle(e: MouseEvent) {
  clickHandle(e, MediaLabelRadioRef.value?.getValue() as number)
}

</script>

<template>
  <div>
    <n-layout>
      <n-layout has-sider>
        <n-layout-sider bordered content-class="p-12px" :default-collapsed="true" :collapsed-width="0"
          :show-collapsed-content="false" :width="160" show-trigger="arrow-circle">
          <n-spin :show="videoUploadLoading">
            <n-flex vertical>
              <n-button type="primary" @click="uploadHandle(1)">{{ $t('tapai.uploadVideo') }}</n-button>
              <n-button v-show="frameBase64 !== ''" type="primary" @click="blockAddHandle({})">{{ $t('tapai.addMask')
                }}</n-button>
              <n-button v-show="frameBase64 !== ''" type="primary" @click="renderVideo">{{ $t('tapai.renderVideo')
                }}</n-button>
            </n-flex>
          </n-spin>
        </n-layout-sider>
        <n-layout-content content-class="p-12px">
          <input ref="videoUploadRef" type="file" class="hidden" />
          <n-spin :show="videoUploadLoading">
            <div class="h-360px flex flex-row flex-items-center flex-justify-center">
              <n-empty v-show="frameBase64 === ''" size="large" description="请不要上传大小超过100M,时长超过10M中的视频">
                <template #extra>
                  <n-button size="small" @click="uploadHandle(1)">{{ $t('tapai.uploadVideo') }}</n-button>
                </template>
              </n-empty>
              <div v-show="frameBase64 != ''" class="h-100% position-relative">
                <canvas @click="imageClickHandle" ref="canvasRef"></canvas>
              </div>
            </div>
          </n-spin>
        </n-layout-content>
      </n-layout>
      <n-layout-footer bordered>
        <n-spin v-show="frameBase64 != ''" :show="videoUploadLoading">
          <div class="p-y-12px flex flex-row flex-justify-center">
            <div
              class="flex flex-col border-solid border-0px bg-#fff dark:bg-#2d2d32 cursor-pointer h-120px w-150px mr-10px position-relative flex-items-center flex-justify-around">
              <n-button text @click="blockShowHandle()">{{ $t('tapai.showAll') }}</n-button>
              <n-button text @click="blockAddHandle({})">{{ $t('tapai.addMask') }}</n-button>
            </div>
            <n-scrollbar x-scrollable>
              <div class="flex flex-row h-120px">
                <div v-for="(item, index) in blockList" :key="item.id"
                  class="border-solid border-0px cursor-pointer h-100% w-150px mr-10px position-relative bg-#fff dark:bg-#2d2d32 flex flex-col flex-items-center flex-justify-center"
                  :class="{ 'border-red': blockIndex === index }" @click="blockShowHandle(index)">
                  <div class="h-100px w-100% position-relative flex flex-items-center flex-justify-center">
                    <img v-if="item.segmentImage === ''" :src="frameBase64" class="w-100% h-100% object-contain" />
                    <!-- <img class="w-100% position-absolute top-50% left-50% translate--50%"
                      :src="item.result?.segmentImage" /> -->
                    <img v-else class="w-100% h-100% object-contain" :src="item.segmentImage" />
                  </div>
                  <n-popconfirm :show-icon="false" :positive-text="null" :negative-text="null">

                    <template #trigger>
                      <n-button text>{{ $t('tapai.operation') }}</n-button>
                    </template>
                    <n-space vertical>
                      <n-button text type="primary" size="tiny" @click.stop="uploadHandle(2, index)">
                        {{ item.resource != '' ? $t('tapai.reUpload') : $t('tapai.uploadResource') }}
                      </n-button>
                      <n-button :loading="videoUploadLoading" text type="error" size="tiny"
                        @click.stop="blockDelHandle(index)">
                        {{ $t('tapai.delMask') }}</n-button>
                    </n-space>
                  </n-popconfirm>
                </div>
              </div>
            </n-scrollbar>
          </div>
        </n-spin>
      </n-layout-footer>
    </n-layout>
  </div>
</template>

<style scoped></style>