Java实现AC自动机全文检索

Java基础

浏览数:77

2020-5-30

实现参考博客:http://www.cppblog.com/mythit/archive/2009/04/21/80633.html

第一步,构建Trie树,定义Node类型:

/**
 * A single node of the trie that backs the Aho-Corasick automaton.
 *
 * <p>Each node carries one character, a link to its parent, a set of children
 * keyed by character, a "word ends here" flag and a failure link.
 *
 * Created by zhaoyy on 2017/2/7.
 */
interface Node {

    // ---- identity -------------------------------------------------------

    /** @return the character stored in this node */
    char value();

    /** @return {@code true} if this node is the root of the trie */
    boolean isRoot();

    /** @return the parent node, or {@code null} when there is none */
    Node parent();

    // ---- children -------------------------------------------------------

    /**
     * Registers {@code child} under the character returned by {@code child.value()}.
     *
     * @param child the node to attach
     */
    void add(Node child);

    /**
     * @param c the character to look up
     * @return the child stored for {@code c}, or {@code null} if there is none
     */
    Node childOf(char c);

    /** @return the children currently attached to this node */
    List<Node> children();

    // ---- word marker ----------------------------------------------------

    /** @return the current value of the "word ends here" flag */
    boolean exists();

    /** @param exists the new value of the "word ends here" flag */
    void setExists(boolean exists);

    // ---- failure link ---------------------------------------------------

    /** @return the failure link of this node, or {@code null} if none is set */
    Node fail();

    /** @param node the node to use as failure link */
    void setFail(Node node);
}

第二步,实现两种Node:如果所有词汇都只包含可打印的ASCII字符(编码32~126),就采用基于数组的AsciiNode;否则(比如包含汉字),使用基于哈希表的MapNode。这两种Node均继承自AbstractNode:

/**
 * Shared state for trie nodes: the stored character, the parent link, the
 * "word ends here" flag and the failure link. Subclasses only decide how the
 * children are stored.
 *
 * Created by zhaoyy on 2017/2/8.
 */
abstract class AbstractNode implements Node {

    /** Character stored in a node created through the no-argument constructor. */
    private static final char EMPTY = '\0';
    private final char value;
    private final Node parent;
    private boolean exists;
    private Node fail;

    /**
     * Creates a node below {@code parent}.
     *
     * @param parent the parent node; {@code null} makes this node a root
     * @param value  the character stored in this node
     */
    public AbstractNode(Node parent, char value) {
        this.parent = parent;
        this.value = value;
        this.exists = false;
        this.fail = null;
    }

    /** Creates a root node: no parent, and {@code '\0'} as its character. */
    public AbstractNode() {
        this(null, EMPTY);
    }

    /**
     * Text used for a parent/fail reference in the dump.
     *
     * @return the literal {@code "null"} for a missing node, otherwise its character
     */
    private static String label(Node node) {
        return node == null ? "null" : String.valueOf(node.value());
    }

    /**
     * Appends an XML-like dump of {@code node} and its whole subtree to
     * {@code out}, indenting every level with one more tab.
     */
    private static void dump(Node node, int depth, StringBuilder out) {
        StringBuilder indentBuilder = new StringBuilder();
        for (int i = 0; i < depth; i++) {
            indentBuilder.append('\t');
        }
        String indent = indentBuilder.toString();

        out.append(indent)
                .append('<')
                .append(node.value())
                .append(" exists=\"")
                .append(node.exists())
                .append('"')
                .append(" parent=\"")
                .append(label(node.parent()))
                .append('"')
                .append(" fail=\"")
                .append(label(node.fail()))
                .append("\">\r\n");
        for (Node child : node.children()) {
            dump(child, depth + 1, out);
        }
        out.append(indent)
                .append("</")
                .append(node.value())
                .append(">\r\n");
    }

    @Override
    public char value() {
        return value;
    }

    @Override
    public boolean exists() {
        return exists;
    }

    /**
     * A node is the root exactly when it has no parent. The stored character
     * is deliberately not consulted, so a regular node that happens to hold
     * {@code '\0'} is not mistaken for the root.
     */
    @Override
    public boolean isRoot() {
        return parent == null;
    }

    @Override
    public Node parent() {
        return parent;
    }

    @Override
    public Node fail() {
        return fail;
    }

    @Override
    public void setFail(Node node) {
        this.fail = node;
    }

    @Override
    public void setExists(boolean exists) {
        this.exists = exists;
    }

    /** @return an XML-like, tab-indented dump of this node and its subtree */
    @Override
    public String toString() {
        StringBuilder out = new StringBuilder();
        dump(this, 0, out);
        return out.toString();
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////

/**
 * Trie node whose children live in a fixed array with one slot for every
 * character code from 32 to 126.
 *
 * Created by zhaoyy on 2017/2/8.
 */
final class AsciiNode extends AbstractNode implements Node {

    /** Lowest character code that has a slot. */
    private static final char FROM = 32;
    /** Highest character code that has a slot. */
    private static final char TO = 126;

    /** Slot {@code i} holds the child for character code {@code FROM + i}, or {@code null}. */
    private final Node[] children = new Node[TO - FROM + 1];

    /** Creates a root node. */
    public AsciiNode() {
        super();
    }

    /**
     * Creates a node below {@code parent}.
     *
     * @param parent the parent node
     * @param value  the character stored in this node
     */
    public AsciiNode(Node parent, char value) {
        super(parent, value);
    }

    /**
     * @param c the character to look up
     * @return the child for {@code c}; {@code null} if the slot is empty or
     *         {@code c} lies outside the supported range
     */
    @Override
    public Node childOf(char c) {
        if (c < FROM || c > TO) {
            return null;
        }
        return children[c - FROM];
    }

    /**
     * Stores {@code child} in the slot of its character. No range check is
     * performed here, so the character must lie within the supported range.
     *
     * @param child the node to attach
     */
    @Override
    public void add(Node child) {
        int slot = child.value() - FROM;
        children[slot] = child;
    }

    /** @return a new list with the non-empty slots, in ascending character order */
    @Override
    public List<Node> children() {
        List<Node> present = new ArrayList<Node>();
        for (int slot = 0; slot < children.length; slot++) {
            Node candidate = children[slot];
            if (candidate == null) {
                continue;
            }
            present.add(candidate);
        }
        return present;
    }
}


//////////////////////////////////////////////////////////////////////////////////////////////

/**
 * Trie node whose children are kept in a hash map keyed by character, so any
 * {@code char} value can be used as an edge label.
 *
 * Created by zhaoyy on 2017/2/8.
 */
final class MapNode extends AbstractNode implements Node {

    /** Children indexed by the character they carry. */
    private final Map<Character, Node> children = new HashMap<Character, Node>();

    /** Creates a root node. */
    public MapNode() {
        super();
    }

    /**
     * Creates a node below {@code parent}.
     *
     * @param parent the parent node
     * @param value  the character stored in this node
     */
    public MapNode(Node parent, char value) {
        super(parent, value);
    }

    /**
     * @param c the character to look up
     * @return the child stored for {@code c}, or {@code null} if there is none
     */
    @Override
    public Node childOf(char c) {
        return children.get(c);
    }

    /**
     * Stores {@code child} under its own character, replacing any child that
     * was previously stored for that character.
     *
     * @param child the node to attach
     */
    @Override
    public void add(Node child) {
        children.put(child.value(), child);
    }

    /** @return a new list holding a snapshot of the map's values */
    @Override
    public List<Node> children() {
        return new ArrayList<Node>(children.values());
    }
}

第三步,首先定义一个Node构造器:

/**
 * Factory that hides which concrete {@link Node} implementation a trie is
 * built from.
 *
 * Created by zhaoyy on 2017/2/8.
 */
public interface NodeMaker {

    /**
     * Creates the root node of a new trie.
     *
     * @return a fresh root node
     */
    Node makeRoot();

    /**
     * Creates a node that belongs below {@code parent}.
     *
     * @param parent the node the new node hangs under
     * @param value  the character the new node stores
     * @return the new node
     */
    Node make(Node parent, char value);
}

然后构建AC自动机,实现创建及查找方法:

/**
 * An Aho-Corasick automaton over a fixed set of words. Build one with
 * {@link #compile(Collection)}, then scan texts with {@link #findAnyIn(CharSequence)}
 * or {@link #search(CharSequence)}.
 *
 * Created by zhaoyy on 2017/2/7.
 */
public final class WordTable {

    private final Node root;

    private WordTable(Collection<? extends CharSequence> words, NodeMaker maker) {
        Node root = buildTrie(words, maker);
        setFailNode(root);
        this.root = root;
    }

    /**
     * Builds the automaton for {@code words}. Array-backed nodes are used when
     * every character of every word lies in the range 32..126; otherwise
     * map-backed nodes are used.
     *
     * @param words the words to look for; must contain at least one element
     * @return the compiled table
     * @throws IllegalArgumentException if {@code words} is {@code null} or empty
     */
    public static WordTable compile(Collection<? extends CharSequence> words) {
        if (words == null || words.isEmpty())
            throw new IllegalArgumentException("words must not be null or empty");
        final NodeMaker maker;
        if (isAscii(words))
            maker = new NodeMaker() {
                @Override
                public Node make(Node parent, char value) {
                    return new AsciiNode(parent, value);
                }

                @Override
                public Node makeRoot() {
                    return new AsciiNode();
                }
            };
        else maker = new NodeMaker() {
            @Override
            public Node make(Node parent, char value) {
                return new MapNode(parent, value);
            }

            @Override
            public Node makeRoot() {
                return new MapNode();
            }
        };
        return new WordTable(words, maker);
    }

    /** @return {@code true} if every character of every word lies in 32..126 */
    private static boolean isAscii(Collection<? extends CharSequence> words) {
        for (CharSequence word : words) {
            int len = word.length();
            for (int i = 0; i < len; i++) {
                int c = (int) word.charAt(i);
                if (c < 32 || c > 126)
                    return false;
            }
        }
        return true;
    }

    /**
     * Inserts every sequence into a fresh trie and flags the node at which
     * each non-empty sequence ends. Empty sequences add nothing.
     */
    private static Node buildTrie(Collection<? extends CharSequence> sequences, NodeMaker maker) {
        Node root = maker.makeRoot();
        for (CharSequence sequence : sequences) {
            int len = sequence.length();
            Node current = root;
            for (int i = 0; i < len; i++) {
                char c = sequence.charAt(i);
                Node node = current.childOf(c);
                if (node == null) {
                    node = maker.make(current, c);
                    current.add(node);
                }
                current = node;
            }
            if (len > 0)
                current.setExists(true);
        }
        return root;
    }

    /**
     * Assigns the failure links breadth-first. The root keeps a {@code null}
     * link, its direct children point back to the root, and every deeper node
     * points to the node reached by following its parent's failure chain until
     * a node with a child for the same character is found (the root if none).
     */
    private static void setFailNode(final Node root) {
        root.setFail(null);
        Queue<Node> queue = new LinkedList<Node>();
        queue.add(root);
        while (!queue.isEmpty()) {
            Node parent = queue.poll();
            for (Node child : parent.children()) {
                Node target = root;
                // Compare by identity: the stored character says nothing
                // reliable about whether a node is the root.
                if (parent != root) {
                    Node temp = parent.fail();
                    while (temp != null) {
                        Node node = temp.childOf(child.value());
                        if (node != null) {
                            target = node;
                            break;
                        }
                        temp = temp.fail();
                    }
                }
                child.setFail(target);
                queue.add(child);
            }
        }
    }

    /**
     * Automaton transition: consumes {@code c} in {@code state}. Failure links
     * are followed until some node has a child for {@code c}; if even the root
     * has none, the automaton stays at the root.
     */
    private Node step(Node state, char c) {
        Node current = state;
        while (true) {
            Node next = current.childOf(c);
            if (next != null)
                return next;
            Node fallback = current.fail();
            if (fallback == null)   // only the root has no failure link
                return root;
            current = fallback;
        }
    }

    /**
     * Finds the longest word that ends at {@code state}: the state itself if
     * it is flagged, otherwise the first flagged node on its failure chain.
     *
     * @return the flagged node, or {@code null} if no word ends here
     */
    private static Node longestWordEndingAt(Node state) {
        for (Node n = state; n != null; n = n.fail()) {
            if (n.exists())
                return n;
        }
        return null;
    }

    /**
     * Tells whether at least one of the compiled words occurs in {@code cs}.
     *
     * @param cs the text to scan; {@code null} is treated like an empty text
     * @return {@code true} as soon as any word is found, {@code false} otherwise
     */
    public boolean findAnyIn(CharSequence cs) {
        if (cs == null)
            return false;
        int len = cs.length();
        Node node = root;
        for (int i = 0; i < len; i++) {
            node = step(node, cs.charAt(i));
            if (longestWordEndingAt(node) != null)
                return true;
        }
        return false;
    }

    /**
     * Scans {@code cs} from left to right and collects the words found.
     *
     * <p>At the first position where one or more words end, the longest of
     * them is reported and the automaton restarts from the root. Reported
     * matches therefore never overlap, and a word that has another compiled
     * word as a proper prefix is never reported (the prefix is hit first).
     *
     * @param cs the text to scan
     * @return the matches in order of appearance; an empty list if {@code cs}
     *         is {@code null} or empty
     */
    public List<MatchInfo> search(CharSequence cs) {
        if (cs == null || cs.length() == 0)
            return Collections.emptyList();
        List<MatchInfo> result = new ArrayList<MatchInfo>();
        int len = cs.length();
        Node node = root;
        for (int i = 0; i < len; i++) {
            node = step(node, cs.charAt(i));
            Node hit = longestWordEndingAt(node);
            if (hit != null) {
                // The word of 'hit' ends exactly at position i.
                result.add(new MatchInfo(i, hit));
                node = root;
            }
        }
        return result;
    }

    /** @return the textual dump of the underlying trie, as produced by the root node */
    @Override
    public String toString() {
        return root.toString();
    }
}

定义一个保存查找结果的实体:

/**
 * One search hit: the matched word and the position in the scanned text at
 * which it starts.
 *
 * Created by zhaoyy on 2017/2/7.
 */
public final class MatchInfo {

    private final int index;
    private final String word;

    /**
     * @param index the position at which {@code word} starts
     * @param word  the matched word
     */
    public MatchInfo(int index, String word) {
        this.index = index;
        this.word = word;
    }

    /**
     * Rebuilds the matched word by walking from {@code node} up the parent
     * links, then derives the start position from the end position.
     *
     * <p>The walk stops at the node that has no parent, which contributes no
     * character. Stopping by parent link rather than by stored character keeps
     * every character of the word, including {@code '\0'}.
     *
     * @param index the position of the <em>last</em> character of the match
     * @param node  the trie node at which the matched word ends
     */
    public MatchInfo(int index, Node node) {
        StringBuilder builder = new StringBuilder();
        for (Node n = node; n != null && n.parent() != null; n = n.parent()) {
            builder.append(n.value());
        }
        // Characters were collected from the last to the first.
        String word = builder.reverse().toString();
        this.index = index + 1 - word.length();
        this.word = word;
    }

    /** @return the position at which the word starts */
    public int getIndex() {
        return index;
    }

    /** @return the matched word */
    public String getWord() {
        return word;
    }

    /** @return the hit formatted as {@code index:word} */
    @Override
    public String toString() {
        return index + ":" + word;
    }
}

第四步,调用Demo:

/**
 * Demo: compiles a small word list, prints the resulting trie, then prints
 * the matches found in a sample text.
 */
public static void main(String[] args) {
    WordTable table = WordTable.compile(
            Arrays.asList("say", "her", "he", "she", "shr", "alone"));
    String text = "1shesaynothingabouthislivinghimalone";

    System.out.println(table);
    System.out.println(table.search(text));
}

以下是输出结果:

<  exists="false" parent="null" fail="null">
	<s exists="false" parent=" " fail=" ">
		<a exists="false" parent="s" fail="a">
			<y exists="true" parent="a" fail=" ">
			</y>
		</a>
		<h exists="false" parent="s" fail="h">
			<e exists="true" parent="h" fail="e">
			</e>
			<r exists="true" parent="h" fail=" ">
			</r>
		</h>
	</s>
	<h exists="false" parent=" " fail=" ">
		<e exists="true" parent="h" fail=" ">
			<r exists="true" parent="e" fail=" ">
			</r>
		</e>
	</h>
	<a exists="false" parent=" " fail=" ">
		<l exists="false" parent="a" fail=" ">
			<o exists="false" parent="l" fail=" ">
				<n exists="false" parent="o" fail=" ">
					<e exists="true" parent="n" fail=" ">
					</e>
				</n>
			</o>
		</l>
	</a>
</ >

[1:she, 4:say, 31:alone]

 

作者:Acce1erator